x-transformers 2.3.25__tar.gz → 2.3.27__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {x_transformers-2.3.25 → x_transformers-2.3.27}/PKG-INFO +1 -1
  2. {x_transformers-2.3.25 → x_transformers-2.3.27}/pyproject.toml +1 -1
  3. {x_transformers-2.3.25 → x_transformers-2.3.27}/tests/test_x_transformers.py +43 -0
  4. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/autoregressive_wrapper.py +16 -1
  5. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/x_transformers.py +2 -2
  6. {x_transformers-2.3.25 → x_transformers-2.3.27}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.3.25 → x_transformers-2.3.27}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.3.25 → x_transformers-2.3.27}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.3.25 → x_transformers-2.3.27}/.gitignore +0 -0
  10. {x_transformers-2.3.25 → x_transformers-2.3.27}/LICENSE +0 -0
  11. {x_transformers-2.3.25 → x_transformers-2.3.27}/README.md +0 -0
  12. {x_transformers-2.3.25 → x_transformers-2.3.27}/data/README.md +0 -0
  13. {x_transformers-2.3.25 → x_transformers-2.3.27}/data/enwik8.gz +0 -0
  14. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/all-attention.png +0 -0
  15. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/deepnorm.png +0 -0
  18. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/fcm.png +0 -0
  24. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/ffglu.png +0 -0
  25. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/flash-attention.png +0 -0
  26. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/gate_values.png +0 -0
  27. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/gating.png +0 -0
  28. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/macaron-1.png +0 -0
  30. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/macaron-2.png +0 -0
  31. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/normformer.png +0 -0
  33. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/pia.png +0 -0
  34. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/resi_dual.png +0 -0
  36. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/residual_attn.png +0 -0
  37. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/rezero.png +0 -0
  38. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/rotary.png +0 -0
  39. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/sandwich.png +0 -0
  41. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/scalenorm.png +0 -0
  43. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/talking-heads.png +0 -0
  44. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/topk-attention.png +0 -0
  45. {x_transformers-2.3.25 → x_transformers-2.3.27}/images/xval.png +0 -0
  46. {x_transformers-2.3.25 → x_transformers-2.3.27}/train_belief_state.py +0 -0
  47. {x_transformers-2.3.25 → x_transformers-2.3.27}/train_copy.py +0 -0
  48. {x_transformers-2.3.25 → x_transformers-2.3.27}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.3.25 → x_transformers-2.3.27}/train_enwik8.py +0 -0
  50. {x_transformers-2.3.25 → x_transformers-2.3.27}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.3.25 → x_transformers-2.3.27}/train_parity.py +0 -0
  52. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/__init__.py +0 -0
  53. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/attend.py +0 -0
  54. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  62. {x_transformers-2.3.25 → x_transformers-2.3.27}/x_transformers/xval.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.25
+Version: 2.3.27
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.3.25"
+version = "2.3.27"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1036,3 +1036,46 @@ def test_autoregressive_wrapper(
     loss = wrapper(x)
 
     loss.backward()
+
+def test_prepend_embed():
+
+    from x_transformers import AutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8
+        )
+    )
+
+    model = AutoregressiveWrapper(model)
+
+    x = torch.randint(0, 256, (2, 10))
+    prepend_embeds = torch.randn(2, 3, 512)
+    prepend_mask = torch.randint(0, 2, (2, 3)).bool()
+
+    loss = model(x, prepend_mask = prepend_mask, prepend_embeds = prepend_embeds)
+    loss.backward()
+
+    sample = model.generate(
+        prompts = x[:, :1],
+        seq_len = 100,
+        temperature = 0.,
+        prepend_embeds = prepend_embeds,
+        prepend_mask = prepend_mask,
+        cache_kv = True,
+    )
+
+    sample_no_cache = model.generate(
+        prompts = x[:, :1],
+        seq_len = 100,
+        temperature = 0.,
+        prepend_embeds = prepend_embeds,
+        prepend_mask = prepend_mask,
+        cache_kv = False,
+    )
+
+    assert torch.allclose(sample, sample_no_cache)
@@ -309,7 +309,13 @@ class AutoregressiveWrapper(Module):
 
         return out
 
-    def forward(self, x, return_outputs = False, **kwargs):
+    def forward(
+        self,
+        x,
+        return_outputs = False,
+        prepend_embeds = None,
+        **kwargs
+    ):
         seq, ignore_index, add_attn_z_loss, add_next_embed_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss, self.add_continuous_pred_head
 
         inp, target = x, x[:, 1:]
@@ -328,6 +334,7 @@ class AutoregressiveWrapper(Module):
             return_intermediates = True,
             return_attn_z_loss = add_attn_z_loss,
             return_next_embed_pred = add_next_embed_loss,
+            prepend_embeds = prepend_embeds,
             **kwargs
         )
 
@@ -338,6 +345,14 @@ class AutoregressiveWrapper(Module):
         else:
             logits = out
 
+        # if there are prepended embeds, excise it out
+
+        if exists(prepend_embeds):
+            prepend_len = prepend_embeds.shape[1]
+            logits = logits[:, prepend_len:]
+
+        # take all tokens but the last
+
         logits = logits[:, :-1]
 
         # loss function
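
For context on the two slices above, here is a small shape walk-through with sizes borrowed from the new test (a sketch with assumed sizes, not code from the package):

    import torch

    # assumed sizes mirroring the new test: batch 2, prepend length 3,
    # token sequence length 10, vocab size 256
    b, p, n, vocab = 2, 3, 10, 256
    logits = torch.randn(b, p + n, vocab)    # network output covers the prepended positions too

    logits = logits[:, p:]                   # excise the prepended positions -> (2, 10, 256)
    logits = logits[:, :-1]                  # drop the final position        -> (2, 9, 256)

    x = torch.randint(0, vocab, (b, n))
    target = x[:, 1:]                        # (2, 9)
    assert logits.shape[:2] == target.shape  # one prediction per next-token target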
@@ -1926,7 +1926,7 @@ class Attention(Module):
 
         out = maybe(self.sublayer_dropout)(out)
 
-        if exists(mask):
+        if exists(mask) and not exists(cache):
             out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
         if not return_intermediates:
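
One plausible reading of the added cache condition (an assumption, not stated in the diff): during cached decoding the attention output covers only the newly decoded step, while the padding mask spans every position seen so far, so the einx.where zero-fill would no longer line up. A minimal shape illustration:

    import torch

    # illustrative sizes only: prepend (3) + prompt (10) + one new token = 14 positions
    b, d = 2, 512
    mask = torch.ones(b, 14).bool()  # padding mask over the full sequence so far
    out = torch.randn(b, 1, d)       # attention output for a single cached decoding step

    # the lengths differ, so zeroing `out` against this mask is skipped when a cache exists
    assert mask.shape[1] != out.shape[1]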
@@ -2484,7 +2484,7 @@ class AttentionLayers(Module):
         attn_cache = []
 
         if exists(cache):
-            assert self.causal and not any([*map(exists, (mask, attn_mask))])
+            assert self.causal and not exists(attn_mask)
 
             prev_cache_length = cache.cache_length
 
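The relaxed assertion still requires a causal stack and still rejects an explicit attn_mask when a cache is present, but a padding mask may now accompany the cache, which is what the new test exercises by generating with prepend_mask and cache_kv = True. A standalone sketch of that condition, with illustrative names that are not the library's internals:

    # hypothetical helper restating the guard; `mask` is accepted but no longer checked
    def cache_inputs_allowed(causal: bool, mask, attn_mask, cache) -> bool:
        if cache is None:
            return True
        return causal and attn_mask is None

    # a padding mask alongside a kv cache now passes the guard
    assert cache_inputs_allowed(causal = True, mask = object(), attn_mask = None, cache = object())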