x-transformers 2.3.24.tar.gz → 2.3.26.tar.gz

This diff compares publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (62)
  1. {x_transformers-2.3.24 → x_transformers-2.3.26}/PKG-INFO +1 -1
  2. {x_transformers-2.3.24 → x_transformers-2.3.26}/pyproject.toml +1 -1
  3. {x_transformers-2.3.24 → x_transformers-2.3.26}/tests/test_x_transformers.py +40 -0
  4. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/autoregressive_wrapper.py +17 -2
  5. {x_transformers-2.3.24 → x_transformers-2.3.26}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.3.24 → x_transformers-2.3.26}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.3.24 → x_transformers-2.3.26}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.3.24 → x_transformers-2.3.26}/.gitignore +0 -0
  9. {x_transformers-2.3.24 → x_transformers-2.3.26}/LICENSE +0 -0
  10. {x_transformers-2.3.24 → x_transformers-2.3.26}/README.md +0 -0
  11. {x_transformers-2.3.24 → x_transformers-2.3.26}/data/README.md +0 -0
  12. {x_transformers-2.3.24 → x_transformers-2.3.26}/data/enwik8.gz +0 -0
  13. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/all-attention.png +0 -0
  14. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/deepnorm.png +0 -0
  17. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/fcm.png +0 -0
  23. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/ffglu.png +0 -0
  24. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/flash-attention.png +0 -0
  25. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/gate_values.png +0 -0
  26. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/gating.png +0 -0
  27. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/macaron-1.png +0 -0
  29. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/macaron-2.png +0 -0
  30. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/normformer.png +0 -0
  32. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/pia.png +0 -0
  33. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/resi_dual.png +0 -0
  35. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/residual_attn.png +0 -0
  36. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/rezero.png +0 -0
  37. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/rotary.png +0 -0
  38. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/sandwich.png +0 -0
  40. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/scalenorm.png +0 -0
  42. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/talking-heads.png +0 -0
  43. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/topk-attention.png +0 -0
  44. {x_transformers-2.3.24 → x_transformers-2.3.26}/images/xval.png +0 -0
  45. {x_transformers-2.3.24 → x_transformers-2.3.26}/train_belief_state.py +0 -0
  46. {x_transformers-2.3.24 → x_transformers-2.3.26}/train_copy.py +0 -0
  47. {x_transformers-2.3.24 → x_transformers-2.3.26}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.3.24 → x_transformers-2.3.26}/train_enwik8.py +0 -0
  49. {x_transformers-2.3.24 → x_transformers-2.3.26}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.3.24 → x_transformers-2.3.26}/train_parity.py +0 -0
  51. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/belief_state_wrapper.py +0 -0
  54. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/continuous.py +0 -0
  55. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/dpo.py +0 -0
  56. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/entropy_based_tokenizer.py +0 -0
  57. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/multi_input.py +0 -0
  58. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/neo_mlp.py +0 -0
  59. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/nonautoregressive_wrapper.py +0 -0
  60. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/x_transformers.py +0 -0
  61. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  62. {x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/xval.py +0 -0
{x_transformers-2.3.24 → x_transformers-2.3.26}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.24
+Version: 2.3.26
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.3.24 → x_transformers-2.3.26}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.3.24"
+version = "2.3.26"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.3.24 → x_transformers-2.3.26}/tests/test_x_transformers.py
@@ -1036,3 +1036,43 @@ def test_autoregressive_wrapper(
     loss = wrapper(x)

     loss.backward()
+
+def test_prepend_embed():
+
+    from x_transformers import AutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8
+        )
+    )
+
+    model = AutoregressiveWrapper(model)
+
+    x = torch.randint(0, 256, (2, 10))
+    prepend_embeds = torch.randn(2, 3, 512)
+
+    loss = model(x, prepend_embeds = prepend_embeds)
+    loss.backward()
+
+    sample = model.generate(
+        prompts = x[:, :1],
+        seq_len = 100,
+        temperature = 0.,
+        prepend_embeds = prepend_embeds,
+        cache_kv = True,
+    )
+
+    sample_no_cache = model.generate(
+        prompts = x[:, :1],
+        seq_len = 100,
+        temperature = 0.,
+        prepend_embeds = prepend_embeds,
+        cache_kv = False,
+    )
+
+    assert torch.allclose(sample, sample_no_cache)
{x_transformers-2.3.24 → x_transformers-2.3.26}/x_transformers/autoregressive_wrapper.py
@@ -309,7 +309,13 @@ class AutoregressiveWrapper(Module):

         return out

-    def forward(self, x, return_outputs = False, **kwargs):
+    def forward(
+        self,
+        x,
+        return_outputs = False,
+        prepend_embeds = None,
+        **kwargs
+    ):
         seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head

         inp, target = x, x[:, 1:]
@@ -328,6 +334,7 @@ class AutoregressiveWrapper(Module):
             return_intermediates = True,
             return_attn_z_loss = add_attn_z_loss,
             return_next_embed_pred = add_next_embed_loss,
+            prepend_embeds = prepend_embeds,
             **kwargs
         )

@@ -338,6 +345,14 @@ class AutoregressiveWrapper(Module):
         else:
             logits = out

+        # if there are prepended embeds, excise it out
+
+        if exists(prepend_embeds):
+            prepend_len = prepend_embeds.shape[1]
+            logits = logits[:, prepend_len:]
+
+        # take all tokens but the last
+
         logits = logits[:, :-1]

         # loss function
@@ -356,7 +371,7 @@ class AutoregressiveWrapper(Module):
             loss = loss + cache.attn_z_loss

         if add_next_embed_loss:
-            mask = inp[:, :-1] != ignore_index
+            mask = target != ignore_index
             embed_pred = next_embed_pred[:, :-1]
             cont_targets = init_embeds[:, 1:].detach()

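Taken together, the substantive change in this release is the new prepend_embeds keyword on AutoregressiveWrapper.forward: the prepended embeddings are passed through to the underlying TransformerWrapper, the corresponding positions are excised from the logits before the autoregressive loss is computed, and the continuous-prediction-head mask now uses target rather than inp[:, :-1]. A minimal usage sketch based on the new test above; the smaller depth, shorter seq_len, and variable names here are illustrative, not taken from the release:

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
))

x = torch.randint(0, 256, (2, 10))          # token ids to train on
prefix = torch.randn(2, 3, 512)             # 3 prepended embedding vectors per sequence

loss = model(x, prepend_embeds = prefix)    # loss is computed only over the token positions
loss.backward()

sampled = model.generate(
    prompts = x[:, :1],
    seq_len = 32,
    temperature = 0.,                       # greedy decoding, as in the cache-consistency check in the test
    prepend_embeds = prefix,                # same conditioning applied at inference
    cache_kv = True,
)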