x-transformers 2.3.24__py3-none-any.whl → 2.3.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- x_transformers/autoregressive_wrapper.py
+++ x_transformers/autoregressive_wrapper.py
@@ -309,7 +309,13 @@ class AutoregressiveWrapper(Module):
 
         return out
 
-    def forward(self, x, return_outputs = False, **kwargs):
+    def forward(
+        self,
+        x,
+        return_outputs = False,
+        prepend_embeds = None,
+        **kwargs
+    ):
         seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head
 
         inp, target = x, x[:, 1:]
@@ -328,6 +334,7 @@ class AutoregressiveWrapper(Module):
             return_intermediates = True,
             return_attn_z_loss = add_attn_z_loss,
             return_next_embed_pred = add_next_embed_loss,
+            prepend_embeds = prepend_embeds,
             **kwargs
         )
 
@@ -338,6 +345,14 @@ class AutoregressiveWrapper(Module):
         else:
             logits = out
 
+        # if there are prepended embeds, excise it out
+
+        if exists(prepend_embeds):
+            prepend_len = prepend_embeds.shape[1]
+            logits = logits[:, prepend_len:]
+
+        # take all tokens but the last
+
         logits = logits[:, :-1]
 
         # loss function
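With `prepend_embeds` supported, the inner network emits logits for the prefix positions as well, so every token logit is shifted right by the prefix length; the slice above restores alignment before the loss. A minimal sketch of that shape arithmetic (all names and sizes here are illustrative, not taken from the package):

```python
import torch

# hypothetical sizes, for illustration only
batch, prepend_len, seq, num_tokens, dim = 2, 4, 16, 256, 512

prepend_embeds = torch.randn(batch, prepend_len, dim)

# the inner network sees prepend_len + seq positions and
# returns logits for all of them
logits = torch.randn(batch, prepend_len + seq, num_tokens)

# excise the prepended positions, as the new code does
logits = logits[:, prepend_embeds.shape[1]:]

# then drop the final position, which has no next-token target
logits = logits[:, :-1]

assert logits.shape == (batch, seq - 1, num_tokens)
```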
@@ -356,7 +371,7 @@ class AutoregressiveWrapper(Module):
             loss = loss + cache.attn_z_loss
 
         if add_next_embed_loss:
-            mask = inp[:, :-1] != ignore_index
+            mask = target != ignore_index
             embed_pred = next_embed_pred[:, :-1]
             cont_targets = init_embeds[:, 1:].detach()
 
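The hunk above also corrects the mask for the continuous prediction head: it is now computed from `target` (`x[:, 1:]`) instead of `inp[:, :-1]`, so `ignore_index` masks the positions being predicted rather than the positions conditioning the prediction. Taken together, the 2.3.26 changes let callers pass soft prefix embeddings through the training forward pass. A hedged usage sketch, assuming the standard `TransformerWrapper`/`Decoder` construction from the project README, and assuming the last dimension of `prepend_embeds` must match the decoder dimension:

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
))

seq = torch.randint(0, 256, (1, 128))

# four prefix vectors; the last dim is assumed to need to
# match the decoder dim (512 here)
prepend_embeds = torch.randn(1, 4, 512)

# new in 2.3.26: the prefix is forwarded to the inner network and
# excised from the logits before the autoregressive loss is computed
loss = model(seq, prepend_embeds = prepend_embeds)
loss.backward()
```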
--- x_transformers-2.3.24.dist-info/METADATA
+++ x_transformers-2.3.26.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.24
+Version: 2.3.26
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.3.24.dist-info/RECORD
+++ x_transformers-2.3.26.dist-info/RECORD
@@ -1,6 +1,6 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
 x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
-x_transformers/autoregressive_wrapper.py,sha256=tMVbIC8iXTpfGDxRhPqqHTulkxB8aZqNML77WbGhfac,11755
+x_transformers/autoregressive_wrapper.py,sha256=3tDUiY5kNcxNUjRERoeuFV0mXztOvgGrckoACIfHvqI,12091
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=7phSZvP1_SDRIkVMwVR4cz1dFU2UlR2Wf1HJHEQlcQg,116222
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.24.dist-info/METADATA,sha256=vqW6_PFF3JiQirofvdzEMXwAt_x9luG_TOQimtIWwg8,89897
-x_transformers-2.3.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.3.24.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.3.24.dist-info/RECORD,,
+x_transformers-2.3.26.dist-info/METADATA,sha256=Qc0zIph59FLOC0GPGIe41M6P1SD_lljzKg5ytoMyPAI,89897
+x_transformers-2.3.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.26.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.26.dist-info/RECORD,,