x-transformers 2.1.16__py3-none-any.whl → 2.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/belief_state_wrapper.py

@@ -23,7 +23,7 @@ from x_transformers.x_transformers import (
 )
 
 import einx
-from einops import rearrange, repeat
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
 # helper functions
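
Note: pack and unpack (newly imported above) are used later in this file to normalize the prompt's batch shape during generation. A minimal sketch of the einops round-trip they provide, with illustrative shapes that are not taken from the wrapper itself:

    import torch
    from einops import pack, unpack

    prompt = torch.randint(0, 256, (5,))   # a single, unbatched sequence of 5 token ids

    packed, ps = pack([prompt], '* d')     # shape (1, 5): leading dims are flattened into one batch axis
    restored, = unpack(packed, ps, '* d')  # shape (5,): the original leading shape is recovered from ps
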
@@ -69,6 +69,8 @@ class BeliefStateWrapper(Module):
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
 
+        self.to_forward_logits = nn.Linear(dim, num_tokens, bias = False)
+
         # the suffix token
 
         self.suffix_token = nn.Parameter(torch.zeros(dim))
@@ -126,7 +128,7 @@ class BeliefStateWrapper(Module):
         prompts,
         seq_len,
         temperature = 1.25,
-        cache_kv = True,
+        cache_kv = False,
         suffix: Tensor | None = None, # the goal conditioning
         filter_logits_fn = min_p,
         filter_kwargs = dict(
@@ -136,6 +138,8 @@ class BeliefStateWrapper(Module):
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
 
+        prompts, batch_ps = pack([prompts], '* d')
+
         batch, orig_seq_len = prompts.shape
 
         out = prompts
@@ -197,14 +201,19 @@ class BeliefStateWrapper(Module):
 
             # concat sample
 
-            out = torch.cat((out, sample), dim=-1)
+            out = torch.cat((out, sample), dim = -1)
+
+        out = out[:, orig_seq_len:]
 
-        return out[:, orig_seq_len:]
+        out, = unpack(out, batch_ps, '* n')
+
+        return out
 
     def forward(
         self,
         seq,
-        backward = True
+        return_loss_only = False,
+        loss_scale = 1.
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
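
With the pack at the top of the generation method and the unpack before the return, the sampling loop can always assume a (batch, length) prompt, while callers may now pass either a batched or an unbatched prompt and get only the newly generated tokens back (prompt stripped) in the same leading shape. A standalone shape walk-through of that bookkeeping; fake_generate and its random "sample" are stand-ins for illustration, not code from this file:

    import torch
    from einops import pack, unpack

    def fake_generate(prompts, seq_len):
        # force a leading batch dim, remembering the caller's original leading shape
        prompts, batch_ps = pack([prompts], '* d')
        batch, orig_seq_len = prompts.shape

        out = prompts
        for _ in range(seq_len):
            sample = torch.randint(0, 256, (batch, 1))  # stand-in for one decoded token per step
            out = torch.cat((out, sample), dim = -1)

        out = out[:, orig_seq_len:]          # drop the prompt, keep only generated tokens
        out, = unpack(out, batch_ps, '* n')  # restore the caller's leading shape
        return out

    fake_generate(torch.randint(0, 256, (5,)), 3).shape    # torch.Size([3])
    fake_generate(torch.randint(0, 256, (2, 5)), 3).shape  # torch.Size([2, 3])
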
@@ -244,6 +253,7 @@ class BeliefStateWrapper(Module):
         # f - forward, b - backward, i - indices
 
         fi, bi = fb_pairs.unbind(dim = -1)
+
         valid_mask = (bi - fi) >= 2
 
         fb_pairs = fb_pairs[valid_mask]
@@ -265,8 +275,9 @@ class BeliefStateWrapper(Module):
 
         labels_fi, labels_bi = (fi + 1), bi
 
-        forward_labels, backward_labels = seq[:, fi], seq[:, bi]
-        labels = stack((forward_labels, backward_labels), dim = -1)
+        forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]
+
+        labels = cat((forward_labels, backward_labels), dim = -1)
 
         # get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions
 
@@ -281,7 +292,7 @@ class BeliefStateWrapper(Module):
 
         loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-            rearrange(labels, 'b n fb -> b (fb n)'),
+            labels,
             reduction = 'none' if self.needs_loss_weight else 'mean'
         )
 
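
Since the labels are now built with cat as a (batch, 2 * num_pairs) tensor, the second rearrange inside F.cross_entropy is no longer needed: concatenating forward then backward targets along the last axis produces exactly the layout the old stack-and-rearrange produced. A small equivalence check on toy shapes (the names mirror the diff, but the tensors are made up):

    import torch
    from einops import rearrange

    b, n = 2, 3
    forward_labels = torch.randint(0, 10, (b, n))
    backward_labels = torch.randint(0, 10, (b, n))

    old = rearrange(torch.stack((forward_labels, backward_labels), dim = -1), 'b n fb -> b (fb n)')
    new = torch.cat((forward_labels, backward_labels), dim = -1)

    assert torch.equal(old, new)  # forward targets first, then backward, along the last axis
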
@@ -290,12 +301,12 @@ class BeliefStateWrapper(Module):
         if exists(self.to_terminal_logit):
             terminal_logits = self.to_terminal_logit(fb_embeds)
 
-            labels = ((bi - fi) == 2).float() # distance is exactly 2
-            labels = repeat(labels, 'n -> b n', b = batch)
+            terminal_labels = ((bi - fi) == 2).float() # distance is exactly 2
+            terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)
 
             is_end_loss = F.binary_cross_entropy_with_logits(
                 terminal_logits,
-                labels
+                terminal_labels
             )
 
             loss = (
@@ -303,6 +314,11 @@ class BeliefStateWrapper(Module):
                 is_end_loss * self.pred_terminal_loss_weight
             )
 
+        # maybe early return loss
+
+        if return_loss_only:
+            return loss
+
         # maybe loss weighting
 
         if self.needs_loss_weight:
@@ -312,18 +328,9 @@ class BeliefStateWrapper(Module):
 
         # backwards
 
-        orig_backward = getattr(loss, 'backward')
-
-        def patched_backward_fn(*args, **kwargs):
-            orig_backward(*args, **kwargs)
-            orig_forward_embeds.backward(forward_embeds.grad)
-            orig_backward_embeds.backward(backward_embeds.grad)
-
-        # can allow the researcher to call .backward from the outside
+        (loss * loss_scale).backward()
 
-        if backward:
-            patched_backward_fn()
-        else:
-            setattr(loss, 'backward', patched_backward_fn)
+        orig_forward_embeds.backward(forward_embeds.grad)
+        orig_backward_embeds.backward(backward_embeds.grad)
 
         return loss
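
After this change forward() always drives the backward pass itself: it scales the loss by loss_scale, calls .backward(), and then pushes the gradients accumulated on the detached forward/backward embeddings into the two underlying decoders. The previous backward= flag (which patched loss.backward) is gone; return_loss_only = True instead returns before the backward section, and loss_scale supports gradient accumulation. A hedged sketch of how a caller might use the new arguments; wrapper, optim, get_batches and val_seq are placeholders, not names defined in this file:

    import torch

    accum_steps = 4  # scale the internally applied backward pass for gradient accumulation

    for step, seq in enumerate(get_batches()):
        loss = wrapper(seq, loss_scale = 1. / accum_steps)  # .backward() happens inside forward()

        if (step + 1) % accum_steps == 0:
            optim.step()
            optim.zero_grad()

    with torch.no_grad():
        val_loss = wrapper(val_seq, return_loss_only = True)  # early return, no internal backward
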
x_transformers-2.1.16.dist-info/METADATA → x_transformers-2.1.17.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.16
+Version: 2.1.17
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,7 +1,7 @@
1
1
  x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
2
2
  x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
3
3
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
4
- x_transformers/belief_state_wrapper.py,sha256=VjUB73yFBWevN6xMc6_1s-Yc58pJv8SDAUUEXwpR-W0,9842
4
+ x_transformers/belief_state_wrapper.py,sha256=7zi_inz75jZVQE4zFVA2H45yR2V0RLTgX_aGO86rN9s,9878
5
5
  x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
10
10
  x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
11
11
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
12
12
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
13
- x_transformers-2.1.16.dist-info/METADATA,sha256=MXaag1fuq1BsAmyh9k8sSRiHpy-jAUQW2Hn1GC53MnQ,87571
14
- x_transformers-2.1.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- x_transformers-2.1.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
- x_transformers-2.1.16.dist-info/RECORD,,
13
+ x_transformers-2.1.17.dist-info/METADATA,sha256=PzCTyZz0yqv6vXVpzZwtWyu6sjR-kKwcgVPAsrn5TnI,87571
14
+ x_transformers-2.1.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ x_transformers-2.1.17.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
+ x_transformers-2.1.17.dist-info/RECORD,,