x-transformers 2.8.2__tar.gz → 2.8.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. {x_transformers-2.8.2 → x_transformers-2.8.4}/PKG-INFO +22 -1
  2. {x_transformers-2.8.2 → x_transformers-2.8.4}/README.md +21 -0
  3. {x_transformers-2.8.2 → x_transformers-2.8.4}/pyproject.toml +1 -1
  4. {x_transformers-2.8.2 → x_transformers-2.8.4}/tests/test_x_transformers.py +13 -0
  5. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/x_transformers.py +33 -5
  6. {x_transformers-2.8.2 → x_transformers-2.8.4}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.8.2 → x_transformers-2.8.4}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.8.2 → x_transformers-2.8.4}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.8.2 → x_transformers-2.8.4}/.gitignore +0 -0
  10. {x_transformers-2.8.2 → x_transformers-2.8.4}/LICENSE +0 -0
  11. {x_transformers-2.8.2 → x_transformers-2.8.4}/data/README.md +0 -0
  12. {x_transformers-2.8.2 → x_transformers-2.8.4}/data/enwik8.gz +0 -0
  13. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/all-attention.png +0 -0
  14. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/deepnorm.png +0 -0
  17. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/fcm.png +0 -0
  23. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/ffglu.png +0 -0
  24. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/flash-attention.png +0 -0
  25. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/gate_values.png +0 -0
  26. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/gating.png +0 -0
  27. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/macaron-1.png +0 -0
  29. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/macaron-2.png +0 -0
  30. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/normformer.png +0 -0
  32. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/pia.png +0 -0
  33. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/resi_dual.png +0 -0
  35. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/residual_attn.png +0 -0
  36. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/rezero.png +0 -0
  37. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/rotary.png +0 -0
  38. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/sandwich.png +0 -0
  40. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/scalenorm.png +0 -0
  42. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/talking-heads.png +0 -0
  43. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/topk-attention.png +0 -0
  44. {x_transformers-2.8.2 → x_transformers-2.8.4}/images/xval.png +0 -0
  45. {x_transformers-2.8.2 → x_transformers-2.8.4}/train_belief_state.py +0 -0
  46. {x_transformers-2.8.2 → x_transformers-2.8.4}/train_copy.py +0 -0
  47. {x_transformers-2.8.2 → x_transformers-2.8.4}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.8.2 → x_transformers-2.8.4}/train_enwik8.py +0 -0
  49. {x_transformers-2.8.2 → x_transformers-2.8.4}/train_gpt_vae.py +0 -0
  50. {x_transformers-2.8.2 → x_transformers-2.8.4}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.8.2 → x_transformers-2.8.4}/train_parity.py +0 -0
  52. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/__init__.py +0 -0
  53. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/attend.py +0 -0
  54. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/autoregressive_wrapper.py +0 -0
  55. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/belief_state_wrapper.py +0 -0
  56. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/continuous.py +0 -0
  57. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/dpo.py +0 -0
  58. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/entropy_based_tokenizer.py +0 -0
  59. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/gpt_vae.py +0 -0
  60. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/multi_input.py +0 -0
  61. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/neo_mlp.py +0 -0
  62. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/nonautoregressive_wrapper.py +0 -0
  63. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/up_wrapper.py +0 -0
  64. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  65. {x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/xval.py +0 -0
{x_transformers-2.8.2 → x_transformers-2.8.4}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.8.2
+ Version: 2.8.4
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2552,4 +2552,25 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @misc{jordan2024muon,
+     author = {Keller Jordan and Yuchen Jin and Vlado Boza and Jiacheng You and Franz Cesista and Laker Newhouse and Jeremy Bernstein},
+     title = {Muon: An optimizer for hidden layers in neural networks},
+     year = {2024},
+     url = {https://kellerjordan.github.io/posts/muon/}
+ }
+ ```
+
+ ```bibtex
+ @misc{wang2025muonoutperformsadamtailend,
+     title = {Muon Outperforms Adam in Tail-End Associative Memory Learning},
+     author = {Shuche Wang and Fengzhuo Zhang and Jiaxiang Li and Cunxiao Du and Chao Du and Tianyu Pang and Zhuoran Yang and Mingyi Hong and Vincent Y. F. Tan},
+     year = {2025},
+     eprint = {2509.26030},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2509.26030},
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.8.2 → x_transformers-2.8.4}/README.md
@@ -2504,4 +2504,25 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @misc{jordan2024muon,
+     author = {Keller Jordan and Yuchen Jin and Vlado Boza and Jiacheng You and Franz Cesista and Laker Newhouse and Jeremy Bernstein},
+     title = {Muon: An optimizer for hidden layers in neural networks},
+     year = {2024},
+     url = {https://kellerjordan.github.io/posts/muon/}
+ }
+ ```
+
+ ```bibtex
+ @misc{wang2025muonoutperformsadamtailend,
+     title = {Muon Outperforms Adam in Tail-End Associative Memory Learning},
+     author = {Shuche Wang and Fengzhuo Zhang and Jiaxiang Li and Cunxiao Du and Chao Du and Tianyu Pang and Zhuoran Yang and Mingyi Hong and Vincent Y. F. Tan},
+     year = {2025},
+     eprint = {2509.26030},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2509.26030},
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.8.2 → x_transformers-2.8.4}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.8.2"
+ version = "2.8.4"
  description = "X-Transformers"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.8.2 → x_transformers-2.8.4}/tests/test_x_transformers.py
@@ -1360,3 +1360,16 @@ def test_vae():
      style = torch.randint(0, 256, (1, 1024))

      out = model.generate(seq[:, :512], 512, seq_for_latents = style)
+
+ def test_muon_params():
+     from x_transformers import Attention, FeedForward, Encoder
+
+     attn = Attention(dim = 512, dim_out = 384)
+     assert len(list(attn.muon_parameters())) == 2
+
+     ff = FeedForward(dim = 512)
+
+     assert len(list(ff.muon_parameters())) == 2
+
+     enc = Encoder(dim = 512, depth = 2)
+     assert len(enc.muon_parameters()) > 0
{x_transformers-2.8.2 → x_transformers-2.8.4}/x_transformers/x_transformers.py
@@ -4,6 +4,11 @@ from typing import Callable
  import math
  from copy import deepcopy
  from random import random, randrange
+ from functools import partial, wraps
+ from itertools import chain
+ from collections import namedtuple
+ from contextlib import nullcontext
+ from dataclasses import dataclass
  from packaging import version

  import torch
@@ -13,11 +18,6 @@ from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
  from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
  from torch.nn import Module, ModuleList, ModuleDict

- from functools import partial, wraps
- from collections import namedtuple
- from contextlib import nullcontext
- from dataclasses import dataclass
-
  from loguru import logger

  from x_transformers.attend import Attend, Intermediates
@@ -1279,6 +1279,17 @@ class FeedForward(Module):
          if zero_init_output:
              init_zero_(proj_out)

+     def muon_parameters(self):
+         weights = []
+
+         for m in self.modules():
+             if not isinstance(m, nn.Linear):
+                 continue
+
+             weights.append(m.weight)
+
+         return weights
+
      def forward(
          self,
          x,
@@ -1644,6 +1655,9 @@ class Attention(Module):
          q_weight.mul_(qk_weight_scale)
          k_weight.mul_(qk_weight_scale)

+     def muon_parameters(self):
+         return chain(self.to_v.parameters(), self.to_out.parameters())
+
      def forward(
          self,
          x,
@@ -2479,6 +2493,17 @@ class AttentionLayers(Module):
          for attn_layer, attn_inter in zip(attn_layers, attn_intermeds):
              attn_layer.qk_clip_(attn_inter, tau = tau)

+     def muon_parameters(self):
+         params = []
+
+         for m in self.modules():
+             if not isinstance(m, (Attention, FeedForward)):
+                 continue
+
+             params.extend(list(m.muon_parameters()))
+
+         return params
+
      def forward(
          self,
          x,
@@ -3216,6 +3241,9 @@ class TransformerWrapper(Module):
      ):
          self.attn_layers.attn_qk_clip_(intermediates, tau = tau)

+     def muon_parameters(self):
+         return self.attn_layers.muon_parameters()
+
      def forward(
          self,
          x,
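
Taken together, these hunks add a `muon_parameters()` hook at every level (`FeedForward`, `Attention`, `AttentionLayers`, `TransformerWrapper`) that surfaces the hidden-layer linear weights a Muon-style optimizer targets. Below is a minimal sketch of how the new hooks could be used to split parameters between Muon and AdamW; the `Muon` constructor is a stand-in for whichever implementation you use (it is not shipped with x-transformers), and the learning rates are illustrative only.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# hidden-layer weights surfaced by the new hook:
# attention to_v / to_out projections plus the feedforward linears
muon_params = list(model.muon_parameters())

# everything else (embeddings, norms, to_q / to_k, output head) stays on AdamW
muon_ids = set(map(id, muon_params))
adamw_params = [p for p in model.parameters() if id(p) not in muon_ids]

# `Muon` is assumed to follow the usual torch.optim constructor convention;
# swap in your preferred implementation here
# muon_opt = Muon(muon_params, lr = 2e-2)
adamw_opt = torch.optim.AdamW(adamw_params, lr = 3e-4)
```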