x-transformers 2.11.19.tar.gz → 2.11.22.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of x-transformers might be problematic.

Files changed (68)
  1. {x_transformers-2.11.19 → x_transformers-2.11.22}/PKG-INFO +8 -7
  2. {x_transformers-2.11.19 → x_transformers-2.11.22}/README.md +7 -6
  3. {x_transformers-2.11.19 → x_transformers-2.11.22}/pyproject.toml +1 -1
  4. {x_transformers-2.11.19 → x_transformers-2.11.22}/tests/test_x_transformers.py +22 -8
  5. {x_transformers-2.11.19 → x_transformers-2.11.22}/train_enwik8.py +22 -1
  6. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/x_transformers.py +45 -1
  7. {x_transformers-2.11.19 → x_transformers-2.11.22}/.github/FUNDING.yml +0 -0
  8. {x_transformers-2.11.19 → x_transformers-2.11.22}/.github/workflows/python-publish.yml +0 -0
  9. {x_transformers-2.11.19 → x_transformers-2.11.22}/.github/workflows/python-test.yaml +0 -0
  10. {x_transformers-2.11.19 → x_transformers-2.11.22}/.gitignore +0 -0
  11. {x_transformers-2.11.19 → x_transformers-2.11.22}/LICENSE +0 -0
  12. {x_transformers-2.11.19 → x_transformers-2.11.22}/data/README.md +0 -0
  13. {x_transformers-2.11.19 → x_transformers-2.11.22}/data/enwik8.gz +0 -0
  14. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/all-attention.png +0 -0
  15. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/deepnorm.png +0 -0
  18. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/fcm.png +0 -0
  24. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/ffglu.png +0 -0
  25. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/flash-attention.png +0 -0
  26. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/gate_values.png +0 -0
  27. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/gating.png +0 -0
  28. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/macaron-1.png +0 -0
  30. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/macaron-2.png +0 -0
  31. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/normformer.png +0 -0
  33. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/pia.png +0 -0
  34. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/resi_dual.png +0 -0
  36. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/residual_attn.png +0 -0
  37. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/rezero.png +0 -0
  38. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/rotary.png +0 -0
  39. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/sandwich.png +0 -0
  41. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/scalenorm.png +0 -0
  43. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/talking-heads.png +0 -0
  44. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/topk-attention.png +0 -0
  45. {x_transformers-2.11.19 → x_transformers-2.11.22}/images/xval.png +0 -0
  46. {x_transformers-2.11.19 → x_transformers-2.11.22}/train_belief_state.py +0 -0
  47. {x_transformers-2.11.19 → x_transformers-2.11.22}/train_copy.py +0 -0
  48. {x_transformers-2.11.19 → x_transformers-2.11.22}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.11.19 → x_transformers-2.11.22}/train_free.py +0 -0
  50. {x_transformers-2.11.19 → x_transformers-2.11.22}/train_gpt_vae.py +0 -0
  51. {x_transformers-2.11.19 → x_transformers-2.11.22}/train_length_extrapolate.py +0 -0
  52. {x_transformers-2.11.19 → x_transformers-2.11.22}/train_parity.py +0 -0
  53. {x_transformers-2.11.19 → x_transformers-2.11.22}/train_with_muon.py +0 -0
  54. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/attend.py +0 -0
  56. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/autoregressive_wrapper.py +0 -0
  57. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/free_transformer.py +0 -0
  62. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/gpt_vae.py +0 -0
  63. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/multi_input.py +0 -0
  64. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/neo_mlp.py +0 -0
  65. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/nonautoregressive_wrapper.py +0 -0
  66. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/up_wrapper.py +0 -0
  67. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.19 → x_transformers-2.11.22}/x_transformers/xval.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.11.19
+ Version: 2.11.22
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2608,12 +2608,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  ```

  ```bibtex
- @article{elhage2022solu,
-     title = {Softmax Linear Units},
-     author = {Elhage, Nelson and Hume, Tristan and Olsson, Catherine and Nanda, Neel and Henighan, Tom and Johnston, Scott and ElShowk, Sheer and Joseph, Nicholas and DasSarma, Nova and Mann, Ben and Hernandez, Danny and Askell, Amanda and Ndousse, Kamal and Jones, Andy and Drain, Dawn and Chen, Anna and Bai, Yuntao and Ganguli, Deep and Lovitt, Liane and Hatfield-Dodds, Zac and Kernion, Jackson and Conerly, Tom and Kravec, Shauna and Fort, Stanislav and Kadavath, Saurav and Jacobson, Josh and Tran-Johnson, Eli and Kaplan, Jared and Clark, Jack and Brown, Tom and McCandlish, Sam and Amodei, Dario and Olah, Christopher},
-     year = {2022},
-     journal = {Transformer Circuits Thread},
-     note = {https://transformer-circuits.pub/2022/solu/index.html}
+ @inproceedings{anonymous2025beliefformer,
+     title = {BeliefFormer: Belief Attention in Transformer},
+     author = {Anonymous},
+     booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+     year = {2025},
+     url = {https://openreview.net/forum?id=Ard2QzPAUK},
+     note = {under review}
  }
  ```
README.md

@@ -2559,12 +2559,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  ```

  ```bibtex
- @article{elhage2022solu,
-     title = {Softmax Linear Units},
-     author = {Elhage, Nelson and Hume, Tristan and Olsson, Catherine and Nanda, Neel and Henighan, Tom and Johnston, Scott and ElShowk, Sheer and Joseph, Nicholas and DasSarma, Nova and Mann, Ben and Hernandez, Danny and Askell, Amanda and Ndousse, Kamal and Jones, Andy and Drain, Dawn and Chen, Anna and Bai, Yuntao and Ganguli, Deep and Lovitt, Liane and Hatfield-Dodds, Zac and Kernion, Jackson and Conerly, Tom and Kravec, Shauna and Fort, Stanislav and Kadavath, Saurav and Jacobson, Josh and Tran-Johnson, Eli and Kaplan, Jared and Clark, Jack and Brown, Tom and McCandlish, Sam and Amodei, Dario and Olah, Christopher},
-     year = {2022},
-     journal = {Transformer Circuits Thread},
-     note = {https://transformer-circuits.pub/2022/solu/index.html}
+ @inproceedings{anonymous2025beliefformer,
+     title = {BeliefFormer: Belief Attention in Transformer},
+     author = {Anonymous},
+     booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+     year = {2025},
+     url = {https://openreview.net/forum?id=Ard2QzPAUK},
+     note = {under review}
  }
  ```
pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.11.19"
+ version = "2.11.22"
  description = "X-Transformers"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py

@@ -1463,13 +1463,27 @@ def test_kv_input_residual():

      assert tokens.shape == out.shape

- def test_solu():
-     attn = Decoder(
-         dim = 256,
-         depth = 2,
-         heads = 4,
-         ff_solu = True
+ @param('orthog_project', (False, True))
+ @param('orthog_project_per_head', (False, True))
+ def test_belief_attn(
+     orthog_project,
+     orthog_project_per_head
+ ):
+     from x_transformers import TransformerWrapper, Decoder
+
+     model = TransformerWrapper(
+         num_tokens = 256,
+         max_seq_len = 1024,
+         attn_layers = Decoder(
+             dim = 512,
+             depth = 6,
+             heads = 8,
+             rotary_pos_emb = True,
+             attn_orthog_projected_values = orthog_project,
+             attn_orthog_projected_values_per_head = orthog_project_per_head
+         )
      )

-     tokens = torch.randn(3, 32, 256)
-     attn(tokens)
+     x = torch.randint(0, 256, (1, 10))
+
+     logits = model(x)
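For context, the configuration exercised by the new test can also be run end to end. The sketch below mirrors the test with both flags switched on; the shape assertion at the bottom is an added illustration (not part of the diff), relying on `TransformerWrapper` returning per-position logits over the vocabulary by default.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# same configuration as the new test, with both belief-attention flags enabled

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True,
        attn_orthog_projected_values = True,
        attn_orthog_projected_values_per_head = True
    )
)

x = torch.randint(0, 256, (1, 10))
logits = model(x)

# TransformerWrapper returns per-position logits over the vocabulary by default
assert logits.shape == (1, 10, 256)
```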
train_enwik8.py

@@ -1,3 +1,11 @@
+ # /// script
+ # dependencies = [
+ #     "tqdm",
+ #     "x-transformers",
+ #     "wandb"
+ # ]
+ # ///
+
  from x_transformers import TransformerWrapper, Decoder
  from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
@@ -20,6 +28,7 @@ VALIDATE_EVERY = 100
  GENERATE_EVERY = 500
  GENERATE_LENGTH = 1024
  SEQ_LEN = 1024
+ TRACK_EXPERIMENT_ONLINE = False

  # helpers
@@ -43,7 +52,9 @@ model = TransformerWrapper(
          dim = 512,
          depth = 6,
          heads = 8,
-         rotary_pos_emb = True
+         rotary_pos_emb = True,
+         attn_orthog_projected_values = True,
+         attn_orthog_projected_values_per_head = True
      )
  )
@@ -80,6 +91,12 @@ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last

  optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

+ # experiment
+
+ import wandb
+ wandb.init(project = 'enwik8', mode = 'online' if TRACK_EXPERIMENT_ONLINE else 'disabled')
+ wandb.run.name = 'baseline'
+
  # training

  for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
@@ -90,6 +107,8 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
          (loss / GRADIENT_ACCUMULATE_EVERY).backward()

      print(f'training loss: {loss.item()}')
+     wandb.log(dict(loss = loss.item()))
+
      torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
      optim.step()
      optim.zero_grad()
@@ -98,7 +117,9 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
          model.eval()
          with torch.no_grad():
              loss = model(next(val_loader))
+
              print(f'validation loss: {loss.item()}')
+             wandb.log(dict(valid_loss = loss.item()))

      if i % GENERATE_EVERY == 0:
          model.eval()
x_transformers/x_transformers.py

@@ -161,6 +161,21 @@ def or_reduce(masks):
          head = head | rest
      return head

+ def orthog_project(x, y):
+     x, packed_shape = pack([x], 'b *')
+     y, _ = pack([y], 'b *')
+
+     dtype = x.dtype
+     x, y = x.double(), y.double()
+     unit = F.normalize(y, dim = -1)
+
+     parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
+     orthog = x - parallel
+
+     orthog, = unpack(orthog, packed_shape, 'b *')
+
+     return orthog.to(dtype)
+
  # cache helpers

  def get_cached_kvs(
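For intuition, here is a minimal standalone sketch of the math inside the new `orthog_project` helper. The function name `orthog_project_simple` and the 2-D shapes are illustrative only; the version in the diff additionally flattens all non-batch dimensions with einops `pack`/`unpack` before projecting. The idea: subtract from `x` its component along `y`, so what remains is orthogonal to `y`.

```python
import torch
import torch.nn.functional as F

# illustrative helper, 2-D inputs only; the diff's version first flattens
# every non-batch dim with einops pack/unpack, then applies the same math

def orthog_project_simple(x, y):
    x, y = x.double(), y.double()
    unit = F.normalize(y, dim = -1)                              # unit vector in the direction of y
    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit   # component of x along y
    return x - parallel                                          # part of x orthogonal to y

x = torch.randn(2, 8)
y = torch.randn(2, 8)
out = orthog_project_simple(x, y)

# the result is (numerically) orthogonal to y
assert torch.allclose((out * y.double()).sum(dim = -1), torch.zeros(2, dtype = torch.double), atol = 1e-6)
```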
@@ -1381,7 +1396,9 @@ class Attention(Module):
          softclamp_logits = False,
          logit_softclamp_value = 50.,
          learned_value_residual_mix = False,
-         laser = False, # https://arxiv.org/abs/2411.03493v1
+         orthog_projected_values = False, # https://openreview.net/forum?id=Ard2QzPAUK
+         orthog_projected_values_per_head = False,
+         laser = False, # https://arxiv.org/abs/2411.03493v1
          laser_softclamp_value = 15.,
          qkv_receive_diff_residuals = False,
          use_latent_q = False,
@@ -1607,6 +1624,14 @@ class Attention(Module):

          self.attn_on_attn = on_attn

+         # return orthogonal projected weighted values on original values
+         # "belief attention" - iclr 2026
+
+         self.orthog_projected_values = orthog_projected_values
+         self.orthog_projected_values_per_head = orthog_projected_values_per_head
+
+         out_dim *= max(1, int(orthog_projected_values) + int(orthog_projected_values_per_head))
+
          # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676

          hybrid_mix = None
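A small sketch of the `out_dim` arithmetic above, assuming `out_dim` is the usual merged value width `dim_head * heads` (the helper name and sizes are illustrative): enabling either projection alone leaves the width fed into `to_out` unchanged, while enabling both doubles it, because the forward pass concatenates the two projections.

```python
# quick check of the out_dim scaling in the diff (sizes are illustrative)

def scaled_out_dim(out_dim, orthog_projected_values, orthog_projected_values_per_head):
    return out_dim * max(1, int(orthog_projected_values) + int(orthog_projected_values_per_head))

base = 64 * 8  # dim_head * heads

assert scaled_out_dim(base, False, False) == base      # feature disabled: width unchanged
assert scaled_out_dim(base, True,  False) == base      # whole-tensor projection replaces out: same width
assert scaled_out_dim(base, False, True)  == base      # per-head projection alone: same width
assert scaled_out_dim(base, True,  True)  == base * 2  # both: projections are concatenated before to_out
```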
@@ -2048,6 +2073,25 @@ class Attention(Module):
              gates = self.to_v_gate(x)
              out = out * self.to_v_gate_activation(gates)

+         # maybe orthogonal projected weighted values - "belief" attention
+
+         if self.orthog_projected_values or self.orthog_projected_values_per_head:
+             orthog_projected = []
+             v_for_proj = self.merge_heads(orig_values)
+
+             if self.orthog_projected_values:
+                 projected = orthog_project(out, v_for_proj)
+                 orthog_projected.append(projected)
+
+             if self.orthog_projected_values_per_head:
+                 v_for_proj = rearrange(v_for_proj, 'b n (h d) -> b n h d', h = h)
+                 out = rearrange(out, 'b n (h d) -> b n h d', h = h)
+                 projected = orthog_project(out, v_for_proj)
+                 projected = rearrange(projected, 'b n h d -> b n (h d)')
+                 orthog_projected.append(projected)
+
+             out = cat(orthog_projected, dim = -1)
+
          # combine the heads

          out = self.to_out(out)
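Tying the forward change to the `__init__` change, here is a shape-only sketch with placeholder tensors (the random tensors merely stand in for the two projections; no attention is computed): concatenating the whole-tensor and per-head projections doubles the width that `to_out`, roughly an `nn.Linear(out_dim, dim)`, has to accept, which is exactly the `out_dim` doubling above.

```python
import torch
from torch import nn

# shape-only illustration with made-up sizes: b = batch, n = sequence, heads * dim_head = value width

b, n, heads, dim_head, dim = 2, 16, 8, 64, 512
out_dim = dim_head * heads * 2                       # doubled, matching the __init__ change

whole_proj = torch.randn(b, n, dim_head * heads)     # stands in for orthog_project(out, v_for_proj)
per_head_proj = torch.randn(b, n, dim_head * heads)  # stands in for the per-head projection, heads merged back

out = torch.cat([whole_proj, per_head_proj], dim = -1)
to_out = nn.Linear(out_dim, dim)                     # roughly what self.to_out is in the simple case

assert to_out(out).shape == (b, n, dim)
```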