x-transformers 2.6.2.tar.gz → 2.6.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.6.2 → x_transformers-2.6.3}/PKG-INFO +1 -1
  2. {x_transformers-2.6.2 → x_transformers-2.6.3}/pyproject.toml +1 -1
  3. {x_transformers-2.6.2 → x_transformers-2.6.3}/tests/test_x_transformers.py +2 -2
  4. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/x_transformers.py +7 -1
  5. {x_transformers-2.6.2 → x_transformers-2.6.3}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.6.2 → x_transformers-2.6.3}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.6.2 → x_transformers-2.6.3}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.6.2 → x_transformers-2.6.3}/.gitignore +0 -0
  9. {x_transformers-2.6.2 → x_transformers-2.6.3}/LICENSE +0 -0
  10. {x_transformers-2.6.2 → x_transformers-2.6.3}/README.md +0 -0
  11. {x_transformers-2.6.2 → x_transformers-2.6.3}/data/README.md +0 -0
  12. {x_transformers-2.6.2 → x_transformers-2.6.3}/data/enwik8.gz +0 -0
  13. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/all-attention.png +0 -0
  14. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/deepnorm.png +0 -0
  17. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/fcm.png +0 -0
  23. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/ffglu.png +0 -0
  24. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/flash-attention.png +0 -0
  25. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/gate_values.png +0 -0
  26. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/gating.png +0 -0
  27. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/macaron-1.png +0 -0
  29. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/macaron-2.png +0 -0
  30. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/normformer.png +0 -0
  32. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/pia.png +0 -0
  33. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/resi_dual.png +0 -0
  35. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/residual_attn.png +0 -0
  36. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/rezero.png +0 -0
  37. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/rotary.png +0 -0
  38. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/sandwich.png +0 -0
  40. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/scalenorm.png +0 -0
  42. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/talking-heads.png +0 -0
  43. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/topk-attention.png +0 -0
  44. {x_transformers-2.6.2 → x_transformers-2.6.3}/images/xval.png +0 -0
  45. {x_transformers-2.6.2 → x_transformers-2.6.3}/train_belief_state.py +0 -0
  46. {x_transformers-2.6.2 → x_transformers-2.6.3}/train_copy.py +0 -0
  47. {x_transformers-2.6.2 → x_transformers-2.6.3}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.6.2 → x_transformers-2.6.3}/train_enwik8.py +0 -0
  49. {x_transformers-2.6.2 → x_transformers-2.6.3}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.6.2 → x_transformers-2.6.3}/train_parity.py +0 -0
  51. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.6.2 → x_transformers-2.6.3}/x_transformers/xval.py +0 -0
--- x_transformers-2.6.2/PKG-INFO
+++ x_transformers-2.6.3/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.2
+Version: 2.6.3
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.6.2/pyproject.toml
+++ x_transformers-2.6.3/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.6.2"
+version = "2.6.3"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.6.2/tests/test_x_transformers.py
+++ x_transformers-2.6.3/tests/test_x_transformers.py
@@ -1228,8 +1228,8 @@ def test_external_key_values():
     seq = torch.randint(0, 20000, (3, 1024))

     key_values = [
-        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
-        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
+        (torch.randn(3, 2, 32, 16), torch.randn(3, 2, 32, 16)),
+        (torch.randn(3, 2, 32, 16), torch.randn(3, 2, 32, 16)),
     ]

     additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()
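
The updated test now supplies external key / values with 2 heads instead of 8, shape (batch, kv_heads, kv_len, dim_head) = (3, 2, 32, 16), exercising the head-expansion path added in the source hunk below (external heads are repeated up to the layer's query-head count when they differ). The external entries are prepended to the layer's own keys / values, so the masks have to be extended to cover them. A minimal sketch of that mask bookkeeping in plain PyTorch (F.pad stands in for the library's pad_at_dim helper; tensor names here are illustrative):

    import torch
    import torch.nn.functional as F

    batch, seq_len, added_kv_len = 3, 1024, 32

    input_mask = torch.ones(batch, seq_len).bool()                      # per-token mask over the sequence
    added_kv_mask = torch.randint(0, 2, (batch, added_kv_len)).bool()   # mask over the 32 external kv positions

    # branch with only a token mask: left-pad with True so the prepended
    # external positions are attendable by default
    only_input = F.pad(input_mask, (added_kv_len, 0), value = True)
    assert only_input.shape == (batch, added_kv_len + seq_len)          # (3, 1056)

    # branch with only an external kv mask: right-pad with True over the sequence positions
    only_added = F.pad(added_kv_mask, (0, seq_len), value = True)
    assert only_added.shape == (batch, added_kv_len + seq_len)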
--- x_transformers-2.6.2/x_transformers/x_transformers.py
+++ x_transformers-2.6.3/x_transformers/x_transformers.py
@@ -1795,6 +1795,13 @@ class Attention(Module):
            seq_len = k.shape[-2]

            added_k, added_v = additional_key_values
+           added_kv_heads, added_kv_len = added_k.shape[1], added_k.shape[-2]
+
+           # take care of expanding to query heads if mismatch between key / value heads with the ones coming from vlm
+
+           if added_kv_heads != kv_h:
+               assert divisible_by(h, added_kv_heads)
+               k, v, added_k, added_v = tuple(repeat(t, 'b h ... -> b (r h) ...', r = h // t.shape[1]) for t in (k, v, added_k, added_v))

            k = cat((added_k, k), dim = -2)
            v = cat((added_v, v), dim = -2)
@@ -1802,7 +1809,6 @@ class Attention(Module):
            if (exists(input_mask) or exists(additional_key_value_mask)):

                if not exists(additional_key_value_mask):
-                   added_kv_len = added_k.shape[-2]
                    input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
                elif not exists(input_mask):
                    input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
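
The new branch covers externally supplied key / values (e.g. coming from a VLM) whose head count differs from the layer's own key / value heads: everything is repeated up to the query-head count h before the external entries are concatenated in front of the sequence keys / values, and added_kv_len is now computed up front so the later mask padding can reuse it. A standalone sketch of the expansion step, using the same einops repeat pattern as the added line (names and shapes here are illustrative, not the layer's internals):

    import torch
    from einops import repeat

    h = 8                                           # query heads of the attention layer
    batch, n, dim_head = 3, 1024, 16

    k = torch.randn(batch, 2, n, dim_head)          # layer's own keys, 2 kv heads (GQA-style)
    added_k = torch.randn(batch, 2, 32, dim_head)   # external keys, also 2 heads

    def expand_heads(t, target_heads):
        # repeat each head so 'b h ... -> b (r h) ...' reaches target_heads
        assert target_heads % t.shape[1] == 0
        return repeat(t, 'b h ... -> b (r h) ...', r = target_heads // t.shape[1])

    k, added_k = (expand_heads(t, h) for t in (k, added_k))

    # both now carry 8 heads, so they concatenate cleanly along the length dimension
    merged = torch.cat((added_k, k), dim = -2)
    assert merged.shape == (batch, h, 32 + n, dim_head)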