x-transformers 2.0.1__tar.gz → 2.0.3__tar.gz

Files changed (55)
  1. {x_transformers-2.0.1 → x_transformers-2.0.3}/PKG-INFO +13 -2
  2. {x_transformers-2.0.1 → x_transformers-2.0.3}/README.md +11 -1
  3. {x_transformers-2.0.1 → x_transformers-2.0.3}/pyproject.toml +6 -2
  4. {x_transformers-2.0.1 → x_transformers-2.0.3}/train_parity.py +16 -8
  5. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/x_transformers.py +10 -3
  6. {x_transformers-2.0.1 → x_transformers-2.0.3}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.0.1 → x_transformers-2.0.3}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.0.1 → x_transformers-2.0.3}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.0.1 → x_transformers-2.0.3}/.gitignore +0 -0
  10. {x_transformers-2.0.1 → x_transformers-2.0.3}/LICENSE +0 -0
  11. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/all-attention.png +0 -0
  12. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/attention-on-attention.png +0 -0
  13. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/cosine-sim-attention.png +0 -0
  14. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/deepnorm.png +0 -0
  15. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/dynamic-pos-bias-linear.png +0 -0
  16. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/dynamic-pos-bias-log.png +0 -0
  17. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  18. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/dynamic-pos-bias.png +0 -0
  19. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/enhanced-recurrence.png +0 -0
  20. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/fcm.png +0 -0
  21. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/ffglu.png +0 -0
  22. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/flash-attention.png +0 -0
  23. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/gate_values.png +0 -0
  24. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/gating.png +0 -0
  25. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/length-extrapolation-scale.png +0 -0
  26. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/macaron-1.png +0 -0
  27. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/macaron-2.png +0 -0
  28. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/memory-transformer.png +0 -0
  29. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/normformer.png +0 -0
  30. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/pia.png +0 -0
  31. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/qknorm-analysis.png +0 -0
  32. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/resi_dual.png +0 -0
  33. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/residual_attn.png +0 -0
  34. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/rezero.png +0 -0
  35. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/rotary.png +0 -0
  36. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/sandwich-2.png +0 -0
  37. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/sandwich.png +0 -0
  38. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/sandwich_norm.png +0 -0
  39. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/scalenorm.png +0 -0
  40. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/talking-heads.png +0 -0
  41. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/topk-attention.png +0 -0
  42. {x_transformers-2.0.1 → x_transformers-2.0.3}/images/xval.png +0 -0
  43. {x_transformers-2.0.1 → x_transformers-2.0.3}/tests/test_x_transformers.py +0 -0
  44. {x_transformers-2.0.1 → x_transformers-2.0.3}/train_copy.py +0 -0
  45. {x_transformers-2.0.1 → x_transformers-2.0.3}/train_enwik8.py +0 -0
  46. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/__init__.py +0 -0
  47. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/attend.py +0 -0
  48. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/autoregressive_wrapper.py +0 -0
  49. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/continuous.py +0 -0
  50. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/dpo.py +0 -0
  51. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/multi_input.py +0 -0
  52. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/neo_mlp.py +0 -0
  53. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/nonautoregressive_wrapper.py +0 -0
  54. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  55. {x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/xval.py +0 -0
{x_transformers-2.0.1 → x_transformers-2.0.3}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.0.1
+ Version: 2.0.3
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -40,6 +40,7 @@ Requires-Dist: loguru
  Requires-Dist: packaging>=21.0
  Requires-Dist: torch>=2.0
  Provides-Extra: examples
+ Requires-Dist: lion-pytorch; extra == 'examples'
  Requires-Dist: torchvision; extra == 'examples'
  Requires-Dist: tqdm; extra == 'examples'
  Provides-Extra: test
@@ -949,7 +950,8 @@ model_xl = TransformerWrapper(
  dim = 512,
  depth = 6,
  heads = 8,
- rotary_pos_emb = True
+ rotary_pos_emb = True,
+ rotate_num_heads = 4 # only rotate 4 out of the 8 attention heads
  )
  )
 
@@ -1838,6 +1840,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```
 
+ ```bibtex
+ @inproceedings{Yang2025RopeTN,
+ title = {Rope to Nope and Back Again: A New Hybrid Attention Strategy},
+ author = {Bowen Yang and Bharat Venkitesh and Dwarak Talupuru and Hangyu Lin and David Cairuz and Phil Blunsom and Acyr F. Locatelli},
+ year = {2025},
+ url = {https://api.semanticscholar.org/CorpusID:276079501}
+ }
+ ```
+
  ```bibtex
  @inproceedings{Chen2023ExtendingCW,
  title = {Extending Context Window of Large Language Models via Positional Interpolation},
{x_transformers-2.0.1 → x_transformers-2.0.3}/README.md
@@ -901,7 +901,8 @@ model_xl = TransformerWrapper(
  dim = 512,
  depth = 6,
  heads = 8,
- rotary_pos_emb = True
+ rotary_pos_emb = True,
+ rotate_num_heads = 4 # only rotate 4 out of the 8 attention heads
  )
  )
 
@@ -1790,6 +1791,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```
 
+ ```bibtex
+ @inproceedings{Yang2025RopeTN,
+ title = {Rope to Nope and Back Again: A New Hybrid Attention Strategy},
+ author = {Bowen Yang and Bharat Venkitesh and Dwarak Talupuru and Hangyu Lin and David Cairuz and Phil Blunsom and Acyr F. Locatelli},
+ year = {2025},
+ url = {https://api.semanticscholar.org/CorpusID:276079501}
+ }
+ ```
+
  ```bibtex
  @inproceedings{Chen2023ExtendingCW,
  title = {Extending Context Window of Large Language Models via Positional Interpolation},
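The `rotate_num_heads` keyword added above is passed through `Decoder`/`AttentionLayers` into each self-attention layer (see the x_transformers/x_transformers.py hunks below). A minimal usage sketch; `num_tokens` and `max_seq_len` are illustrative assumptions rather than values taken from the README:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,        # assumed vocabulary size
    max_seq_len = 1024,        # assumed context length
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True,
        rotate_num_heads = 4   # only rotate 4 out of the 8 attention heads
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)              # (1, 1024, 20000)
```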
{x_transformers-2.0.1 → x_transformers-2.0.3}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.0.1"
+ version = "2.0.3"
  description = "X-Transformers"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,7 +34,11 @@ Homepage = "https://pypi.org/project/x-transformers/"
  Repository = "https://github.com/lucidrains/x-transformers"
 
  [project.optional-dependencies]
- examples = ["tqdm", "torchvision"]
+ examples = [
+ "lion-pytorch",
+ "tqdm",
+ "torchvision"
+ ]
 
  test = [
  "pytest",
{x_transformers-2.0.1 → x_transformers-2.0.3}/train_parity.py
@@ -7,12 +7,16 @@ from x_transformers import TransformerWrapper, Decoder
 
  # constants
 
- NUM_BATCHES = 100000
  BATCH_SIZE = 256
  LEARNING_RATE = 3e-4
  EVAL_EVERY = 500
- TRAIN_MAX_LENGTH = 64
+
  EVAL_LENGTHS = (16, 32, 64, 128, 256, 512)
+ TRAIN_MAX_LENGTH = EVAL_LENGTHS[-2]
+
+ LOSS_THRES_INCREASE_LEN = 1e-3
+ MEET_CRITERIA_THRES_INCREASE_LEN = 10
+
  HYBRIDIZE_WITH_RNN = True
 
  # rnn for fully resolving state tracking by hybridization
@@ -28,6 +32,7 @@ if HYBRIDIZE_WITH_RNN:
 
  decoder_kwargs = dict(
  attn_hybrid_fold_axial_dim = 4, # even if recurrence is every 4 tokens, can generalize for parity
+ attn_hybrid_learned_mix = True,
  attn_hybrid_module = GRU(dim, dim_head * heads, batch_first = True)
  )
 
@@ -48,7 +53,9 @@ model = TransformerWrapper(
 
  # optimizer
 
- adam = optim.Adam(model.parameters(), lr = LEARNING_RATE)
+ from lion_pytorch.cautious_lion import Lion
+
+ optimizer = Lion(model.parameters(), lr = LEARNING_RATE, cautious_factor = 0.1)
 
  # data generator
 
@@ -73,7 +80,8 @@ meet_criteria = 0
  train_seq_len = 1
  stop_length = EVAL_LENGTHS[-2]
 
- with tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training') as pbar:
+ with tqdm.tqdm(mininterval = 10., desc = 'training') as pbar:
+
  while train_seq_len < stop_length:
  model.train()
 
@@ -90,12 +98,12 @@ with tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training') as pbar
  last_loss = loss[:, -1].mean()
  loss.mean().backward()
 
- if last_loss.item() < 0.001:
+ if last_loss.item() < LOSS_THRES_INCREASE_LEN:
  meet_criteria += 1
  else:
  meet_criteria = 0
 
- if meet_criteria >= 10:
+ if meet_criteria >= MEET_CRITERIA_THRES_INCREASE_LEN:
  meet_criteria = 0
  train_seq_len += 1
  print(f'criteria met, incrementing to {train_seq_len}')
@@ -103,8 +111,8 @@ with tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training') as pbar
  print(f'({train_seq_len})| {i}: {last_loss.item()}')
  torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
 
- adam.step()
- adam.zero_grad()
+ optimizer.step()
+ optimizer.zero_grad()
 
  last_step = train_seq_len == stop_length
 
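Taken in isolation, the optimizer swap in train_parity.py amounts to the snippet below. The import path and the `cautious_factor` argument come from the hunks above; the toy module and learning rate are stand-in assumptions so the sketch runs on its own:

```python
import torch
from torch import nn
from lion_pytorch.cautious_lion import Lion  # provided by the new 'examples' extra (lion-pytorch)

model = nn.Linear(512, 512)  # stand-in for the TransformerWrapper built in train_parity.py

# previously: optimizer = optim.Adam(model.parameters(), lr = 3e-4)
optimizer = Lion(model.parameters(), lr = 3e-4, cautious_factor = 0.1)

loss = model(torch.randn(2, 512)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```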
{x_transformers-2.0.1 → x_transformers-2.0.3}/x_transformers/x_transformers.py
@@ -1204,6 +1204,7 @@ class Attention(Module):
  hybrid_module: Module | None = None,
  hybrid_mask_kwarg: str | None = None,
  hybrid_fold_axial_dim: int | None = None,
+ hybrid_learned_mix = False,
  one_kv_head = False,
  kv_heads = None,
  value_dim_head = None,
@@ -1446,7 +1447,7 @@
 
  if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
  hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
- hybrid_mix = LinearNoBias(dim, heads)
+ hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None
 
  hybrid_norms = ModuleList([
  MultiheadRMSNorm(dim_head, heads = heads),
@@ -1779,7 +1780,12 @@
  out = out_norm(out)
  hybrid_out = hybrid_out_norm(hybrid_out)
 
- out = 0.5 * (out + hybrid_out)
+ if exists(self.hybrid_mix):
+ mix = self.hybrid_mix(x)
+ mix = rearrange(mix, 'b n h -> b h n 1')
+ out = out.lerp(hybrid_out, mix.sigmoid())
+ else:
+ out = 0.5 * (out + hybrid_out)
 
  # merge heads
 
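The hunk above replaces the fixed 0.5 average of the attention output and the hybrid-module output (e.g. the GRU branch) with a learned per-head, per-position sigmoid gate. A standalone sketch of just that gating step, with tensor shapes assumed to match the surrounding code:

```python
import torch
from torch import nn
from einops import rearrange

b, h, n, d, dim = 2, 8, 16, 64, 512            # batch, heads, seq len, head dim, model dim

x          = torch.randn(b, n, dim)            # layer input fed to the mix projection
out        = torch.randn(b, h, n, d)           # attention branch output (post head norm)
hybrid_out = torch.randn(b, h, n, d)           # hybrid branch output (post head norm)

hybrid_mix = nn.Linear(dim, h, bias = False)   # the LinearNoBias(dim, heads) from the hunk

mix = rearrange(hybrid_mix(x), 'b n h -> b h n 1')
gated = out.lerp(hybrid_out, mix.sigmoid())    # out + sigmoid(mix) * (hybrid_out - out)

averaged = 0.5 * (out + hybrid_out)            # prior behavior, kept when hybrid_learned_mix = False
```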
@@ -1839,6 +1845,7 @@ class AttentionLayers(Module):
  rotary_interpolation_factor = 1.,
  rotary_xpos_scale_base = 512,
  rotary_base_rescale_factor = 1.,
+ rotate_num_heads = None,
  weight_tie_layers = False,
  custom_layers: tuple[str, ...] | None = None,
  layers_execute_order: tuple[int, ...] | None = None,
@@ -2141,7 +2148,7 @@
 
  if layer_type == 'a':
  self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
- layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
+ layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
  is_first_self_attn = False
  elif layer_type == 'c':
  layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
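For intuition: with `rotate_num_heads = 4` and `heads = 8`, rotary embeddings are applied to the queries and keys of four heads while the remaining heads see no positional rotation, the RoPE/NoPE hybrid from the newly cited Yang et al. paper. The head-splitting logic itself lives in parts of x_transformers.py not shown in this diff, so the following is a conceptual sketch rather than the library's implementation:

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary(t, freqs):
    # standard RoPE: rotate channel pairs by position-dependent angles
    return t * freqs.cos() + rotate_half(t) * freqs.sin()

heads, rotate_num_heads = 8, 4
q = torch.randn(1, heads, 16, 64)     # (batch, heads, seq, dim_head)
freqs = torch.randn(16, 64)           # stand-in rotary frequencies, broadcast over heads

q_rope, q_nope = q[:, :rotate_num_heads], q[:, rotate_num_heads:]
q = torch.cat((apply_rotary(q_rope, freqs), q_nope), dim = 1)   # only 4 of the 8 heads rotated
```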