x-transformers 2.9.1__tar.gz → 2.10.0__tar.gz

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.

Files changed (66)
  1. {x_transformers-2.9.1 → x_transformers-2.10.0}/PKG-INFO +12 -1
  2. {x_transformers-2.9.1 → x_transformers-2.10.0}/README.md +11 -0
  3. {x_transformers-2.9.1 → x_transformers-2.10.0}/pyproject.toml +1 -1
  4. {x_transformers-2.9.1 → x_transformers-2.10.0}/tests/test_x_transformers.py +23 -1
  5. {x_transformers-2.9.1 → x_transformers-2.10.0}/train_with_muon.py +2 -4
  6. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/attend.py +31 -0
  7. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/x_transformers.py +2 -0
  8. {x_transformers-2.9.1 → x_transformers-2.10.0}/.github/FUNDING.yml +0 -0
  9. {x_transformers-2.9.1 → x_transformers-2.10.0}/.github/workflows/python-publish.yml +0 -0
  10. {x_transformers-2.9.1 → x_transformers-2.10.0}/.github/workflows/python-test.yaml +0 -0
  11. {x_transformers-2.9.1 → x_transformers-2.10.0}/.gitignore +0 -0
  12. {x_transformers-2.9.1 → x_transformers-2.10.0}/LICENSE +0 -0
  13. {x_transformers-2.9.1 → x_transformers-2.10.0}/data/README.md +0 -0
  14. {x_transformers-2.9.1 → x_transformers-2.10.0}/data/enwik8.gz +0 -0
  15. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/all-attention.png +0 -0
  16. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/attention-on-attention.png +0 -0
  17. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/cosine-sim-attention.png +0 -0
  18. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/deepnorm.png +0 -0
  19. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/dynamic-pos-bias-linear.png +0 -0
  20. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/dynamic-pos-bias-log.png +0 -0
  21. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  22. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/dynamic-pos-bias.png +0 -0
  23. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/enhanced-recurrence.png +0 -0
  24. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/fcm.png +0 -0
  25. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/ffglu.png +0 -0
  26. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/flash-attention.png +0 -0
  27. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/gate_values.png +0 -0
  28. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/gating.png +0 -0
  29. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/length-extrapolation-scale.png +0 -0
  30. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/macaron-1.png +0 -0
  31. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/macaron-2.png +0 -0
  32. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/memory-transformer.png +0 -0
  33. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/normformer.png +0 -0
  34. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/pia.png +0 -0
  35. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/qknorm-analysis.png +0 -0
  36. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/resi_dual.png +0 -0
  37. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/residual_attn.png +0 -0
  38. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/rezero.png +0 -0
  39. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/rotary.png +0 -0
  40. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/sandwich-2.png +0 -0
  41. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/sandwich.png +0 -0
  42. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/sandwich_norm.png +0 -0
  43. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/scalenorm.png +0 -0
  44. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/talking-heads.png +0 -0
  45. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/topk-attention.png +0 -0
  46. {x_transformers-2.9.1 → x_transformers-2.10.0}/images/xval.png +0 -0
  47. {x_transformers-2.9.1 → x_transformers-2.10.0}/train_belief_state.py +0 -0
  48. {x_transformers-2.9.1 → x_transformers-2.10.0}/train_copy.py +0 -0
  49. {x_transformers-2.9.1 → x_transformers-2.10.0}/train_entropy_tokenizer.py +0 -0
  50. {x_transformers-2.9.1 → x_transformers-2.10.0}/train_enwik8.py +0 -0
  51. {x_transformers-2.9.1 → x_transformers-2.10.0}/train_gpt_vae.py +0 -0
  52. {x_transformers-2.9.1 → x_transformers-2.10.0}/train_length_extrapolate.py +0 -0
  53. {x_transformers-2.9.1 → x_transformers-2.10.0}/train_parity.py +0 -0
  54. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/autoregressive_wrapper.py +0 -0
  56. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/belief_state_wrapper.py +0 -0
  57. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/continuous.py +0 -0
  58. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/dpo.py +0 -0
  59. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/entropy_based_tokenizer.py +0 -0
  60. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/gpt_vae.py +0 -0
  61. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/multi_input.py +0 -0
  62. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/neo_mlp.py +0 -0
  63. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  64. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/up_wrapper.py +0 -0
  65. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  66. {x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/xval.py +0 -0
{x_transformers-2.9.1 → x_transformers-2.10.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.9.1
+Version: 2.10.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2586,4 +2586,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{anonymous2025more,
+    title = {More Expressive Attention with Negative Weights},
+    author = {Anonymous},
+    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+    year = {2025},
+    url = {https://openreview.net/forum?id=ezRrwwbxd0},
+    note = {under review}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.9.1 → x_transformers-2.10.0}/README.md

@@ -2537,4 +2537,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{anonymous2025more,
+    title = {More Expressive Attention with Negative Weights},
+    author = {Anonymous},
+    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+    year = {2025},
+    url = {https://openreview.net/forum?id=ezRrwwbxd0},
+    note = {under review}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.9.1 → x_transformers-2.10.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.9.1"
+version = "2.10.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.9.1 → x_transformers-2.10.0}/tests/test_x_transformers.py

@@ -1378,6 +1378,28 @@ def test_stochastic_attn():
     from x_transformers import Attention
 
     attn = Attention(dim = 512, gumbel_softmax = True)
-    out = attn(torch.randn(1, 1024, 512))
+    out, intermediate = attn(torch.randn(1, 1024, 512), return_intermediates = True)
 
     assert out.shape == (1, 1024, 512)
+
+    from x_transformers.attend import log_prob_from_hard_attend
+    log_probs = log_prob_from_hard_attend(intermediate)
+    assert log_probs.shape == (1, 8, 1024)
+
+def test_attn_negative_weights():
+    from x_transformers import TransformerWrapper, Decoder
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            attn_cog_signed = True
+        ),
+    )
+
+    x = torch.randint(0, 256, (1, 10))
+
+    logits = model(x)
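The updated stochastic-attention test exercises `log_prob_from_hard_attend`, a helper added in `x_transformers/attend.py` (see the hunk below): it log-softmaxes the pre-softmax attention logits and gathers the log-probability of the key each query actually selected under hard gumbel-softmax attention. A standalone restatement under the test's shapes; the function and variable names here are illustrative, and the suggested use in score-function (REINFORCE-style) objectives is an assumption, not stated in the diff:

```python
import torch
import torch.nn.functional as F

def log_prob_of_hard_choice(pre_softmax_attn, post_softmax_attn):
    # pre_softmax_attn: raw logits, (batch, heads, query, key)
    # post_softmax_attn: hard (one-hot) attention of the same shape
    log_probs = pre_softmax_attn.log_softmax(dim = -1)

    # recover which key each query picked, then gather its log-probability
    picked = post_softmax_attn.argmax(dim = -1, keepdim = True)
    return log_probs.gather(-1, picked).squeeze(-1)  # (batch, heads, query)

pre = torch.randn(1, 8, 1024, 1024)
hard = F.one_hot(pre.argmax(dim = -1), num_classes = 1024).float()

log_probs = log_prob_of_hard_choice(pre, hard)
assert log_probs.shape == (1, 8, 1024)  # the shape the test asserts
```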
{x_transformers-2.9.1 → x_transformers-2.10.0}/train_with_muon.py

@@ -1,7 +1,7 @@
 # /// script
 # dependencies = [
 #   "x-transformers",
-#   "adam-atan2-pytorch>=0.2.2",
+#   "adam-atan2-pytorch>=0.2.4",
 # ]
 # ///
 
@@ -25,7 +25,6 @@ NUM_BATCHES = int(1e5)
 BATCH_SIZE = 4
 GRADIENT_ACCUMULATE_EVERY = 4
 LEARNING_RATE = 1e-4
-MUON_LEARNING_RATE = 1e-3
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 1024

@@ -92,8 +91,7 @@ optim = MuonAdamAtan2(
     muon_params = model.muon_parameters(),
     params = model.parameters(),
     remove_muon_params_from_params = True,
-    lr = LEARNING_RATE,
-    muon_lr = MUON_LEARNING_RATE,
+    lr = LEARNING_RATE
 )
 
 # training
{x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/attend.py

@@ -67,6 +67,15 @@ def once(fn):
 
 print_once = once(print)
 
+# gumbel softmax attention related
+
+def log_prob_from_hard_attend(intermeds: Intermediates):
+    log_probs = intermeds.pre_softmax_attn.log_softmax(dim = -1)
+
+    one_hot = intermeds.post_softmax_attn.argmax(dim = -1, keepdim = True)
+    log_prob = log_probs.gather(-1, one_hot)
+    return rearrange(log_prob, 'b h i 1 -> b h i')
+
 # selective attention
 # https://arxiv.org/abs/2410.02703 - section 3.3
 # it is a technique to allow each token to prevent itself from being attended to by future tokens

@@ -174,6 +183,7 @@ class Attend(Module):
         gumbel_softmax = False,
         gumbel_softmax_temp = 1.,
         gumbel_softmax_hard = True,
+        cog_signed = False,
         custom_attn_fn: Callable | None = None,
         flash = False,
         softclamp_logits = False,

@@ -251,6 +261,12 @@ class Attend(Module):
         assert not (selective and not causal), 'selective attention is designed for autoregressive'
         self.selective = selective
 
+        # cog attention - negative weights for expressiveness
+        # https://openreview.net/forum?id=ezRrwwbxd0
+
+        assert not (flash and cog_signed), 'cog attention not available for flash'
+        self.cog_signed = cog_signed
+
         # l2 distance attention
 
         self.l2_distance = l2_distance

@@ -500,6 +516,13 @@ class Attend(Module):
         if self.softclamp_logits:
             sim = softclamp(sim, self.logit_softclamp_value)
 
+        # pre-masking - handle cog by storing sign
+
+        if self.cog_signed:
+            sim_sign = sim.sign()
+
+        # masking
+
         i, j, dtype = *sim.shape[-2:], sim.dtype
 
         mask_value = -torch.finfo(sim.dtype).max

@@ -529,10 +552,18 @@
 
         pre_softmax_attn = sim
 
+        if self.cog_signed:
+            sim = sim.abs()
+
         attn = self.attn_fn(sim)
 
         attn = attn.type(dtype)
 
+        # add back the sign
+
+        if self.cog_signed:
+            attn = attn * sim_sign
+
         post_softmax_attn = attn
 
         if self.head_learned_sink:
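The heart of the change is a signed softmax: the sign of every logit is stored before masking, the softmax is taken over absolute values, and the signs are multiplied back in afterwards, so attention weights can be negative. Flash attention is asserted off for this path, presumably because the fused kernel offers no hook for reapplying signs after its softmax. A minimal standalone sketch of the idea (the function name is hypothetical, and the mask is applied to the magnitudes here for simplicity, whereas the library stores the sign ahead of its usual masking path):

```python
import torch
import torch.nn.functional as F

def cog_signed_softmax(sim, mask = None):
    # sim: raw attention logits, (batch, heads, query, key)
    sign = sim.sign()  # remember each logit's sign

    mag = sim.abs()    # the softmax only ever sees magnitudes

    if mask is not None:
        mag = mag.masked_fill(~mask, -torch.finfo(mag.dtype).max)

    attn = F.softmax(mag, dim = -1)
    return attn * sign  # reattach signs: weights may now be negative

q, k = torch.randn(2, 1, 8, 16, 64).unbind(0)
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * 64 ** -0.5

weights = cog_signed_softmax(sim)

# each row's absolute weights still sum to one, but the signed row sum
# can land anywhere in [-1, 1] - the added expressiveness
assert torch.allclose(weights.abs().sum(dim = -1), torch.ones(1, 8, 16), atol = 1e-6)
```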
{x_transformers-2.9.1 → x_transformers-2.10.0}/x_transformers/x_transformers.py

@@ -1340,6 +1340,7 @@ class Attention(Module):
         gumbel_softmax_temp = 1.,
         gumbel_softmax_hard = True,
         selective = False,
+        cog_signed = False,
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,

@@ -1548,6 +1549,7 @@ class Attention(Module):
             gumbel_softmax_temp = gumbel_softmax_temp,
             gumbel_softmax_hard = gumbel_softmax_hard,
             selective = selective,
+            cog_signed = cog_signed,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
             head_learned_sink = head_learned_sink,
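With the plumbing above, the flag can also be set directly on `Attention`, mirroring the gumbel test earlier in this diff. A usage sketch, assuming the same shapes as that test:

```python
import torch
from x_transformers import Attention

# cog_signed is forwarded to Attend per the hunk above;
# flash stays at its default of False, as the assert requires
attn = Attention(dim = 512, heads = 8, cog_signed = True)

out = attn(torch.randn(1, 1024, 512))
assert out.shape == (1, 1024, 512)
```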