x-transformers 2.8.1__py3-none-any.whl → 2.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/gpt_vae.py CHANGED
@@ -46,25 +46,29 @@ class GPTVAE(Module):
         vae_kl_loss_weight = 1.,
         latents_dropout_prob = 0.5, # what percentage of the time to dropout the latents completely
         pad_id = -1,
+        encoder: Module | None = None,
         **kwargs
     ):
         super().__init__()
         dim_latent = default(dim_latent, dim)
 
-        self.encoder = TransformerWrapper(
-            num_tokens = num_tokens,
-            max_seq_len = max_seq_len + 1,
-            return_only_embed = True,
-            average_pool_embed = True,
-            attn_layers = Encoder(
-                dim = dim,
-                depth = enc_depth,
-                attn_dim_head = attn_dim_head,
-                heads = heads,
-                **kwargs,
-                **enc_kwargs
-            ),
-        )
+        if not exists(encoder):
+            encoder = TransformerWrapper(
+                num_tokens = num_tokens,
+                max_seq_len = max_seq_len + 1,
+                return_only_embed = True,
+                average_pool_embed = True,
+                attn_layers = Encoder(
+                    dim = dim,
+                    depth = enc_depth,
+                    attn_dim_head = attn_dim_head,
+                    heads = heads,
+                    **kwargs,
+                    **enc_kwargs
+                ),
+            )
+
+        self.encoder = encoder
 
         self.to_latent_mean_log_variance = nn.Sequential(
             nn.Linear(dim, dim_latent * 2),
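With this change, GPTVAE only builds its default average-pooled `Encoder` stack when no encoder is supplied, so a caller can inject any Module that pools token ids down to a single `dim`-sized embedding. A minimal sketch of that usage follows; only the `encoder` keyword is confirmed by the hunk above, while the remaining GPTVAE arguments (including `depth`) are assumptions drawn from names visible in the deleted default block.

```python
# hypothetical usage sketch -- only `encoder = ...` is confirmed by this diff
from x_transformers import TransformerWrapper, Encoder
from x_transformers.gpt_vae import GPTVAE

# the replaced default encoder used max_seq_len + 1 with average pooling, mirrored here
custom_encoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024 + 1,
    return_only_embed = True,
    average_pool_embed = True,
    attn_layers = Encoder(dim = 512, depth = 2, heads = 8)
)

model = GPTVAE(
    num_tokens = 256,
    max_seq_len = 1024,
    dim = 512,
    depth = 6,                 # assumed decoder depth argument
    encoder = custom_encoder   # new keyword: skips constructing the built-in encoder
)
```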
x_transformers/x_transformers.py CHANGED
@@ -4,6 +4,11 @@ from typing import Callable
 import math
 from copy import deepcopy
 from random import random, randrange
+from functools import partial, wraps
+from itertools import chain
+from collections import namedtuple
+from contextlib import nullcontext
+from dataclasses import dataclass
 from packaging import version
 
 import torch
@@ -13,11 +18,6 @@ from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
 from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from torch.nn import Module, ModuleList, ModuleDict
 
-from functools import partial, wraps
-from collections import namedtuple
-from contextlib import nullcontext
-from dataclasses import dataclass
-
 from loguru import logger
 
 from x_transformers.attend import Attend, Intermediates
@@ -1279,6 +1279,17 @@ class FeedForward(Module):
         if zero_init_output:
             init_zero_(proj_out)
 
+    def muon_parameters(self):
+        weights = []
+
+        for m in self.modules():
+            if not isinstance(m, nn.Linear):
+                continue
+
+            weights.append(m.weight)
+
+        return weights
+
     def forward(
         self,
         x,
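A quick sketch of what the new hook collects on a standalone `FeedForward` block; the constructor arguments below are assumptions for illustration, only `muon_parameters` itself comes from the diff.

```python
from x_transformers.x_transformers import FeedForward

ff = FeedForward(dim = 512, mult = 4)  # assumed arguments

# yields the weight matrix of every nn.Linear in the block -- the 2-D
# "hidden layer" parameters a Muon optimizer is meant to update
for weight in ff.muon_parameters():
    print(tuple(weight.shape))
```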
@@ -1644,6 +1655,9 @@ class Attention(Module):
         q_weight.mul_(qk_weight_scale)
         k_weight.mul_(qk_weight_scale)
 
+    def muon_parameters(self):
+        return chain(self.to_v.parameters(), self.to_out.parameters())
+
     def forward(
         self,
         x,
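Together with the `FeedForward` hook above, each block can now expose its Muon-eligible matrices (value and output projections for attention, all linear weights for the feedforward). Below is a hedged sketch of splitting a model's parameters between Muon and AdamW; the `TransformerWrapper`/`Decoder` construction and the external Muon optimizer are assumptions, only the per-module `muon_parameters` methods come from this diff.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# gather every parameter exposed through a muon_parameters hook
# (Attention and FeedForward blocks, per the hunks above)
muon_params = []
for module in model.modules():
    if hasattr(module, 'muon_parameters'):
        muon_params.extend(module.muon_parameters())

# everything else (embeddings, norms, q/k projections, logits head) stays on AdamW
muon_ids = {id(p) for p in muon_params}
adamw_params = [p for p in model.parameters() if id(p) not in muon_ids]

adamw_opt = torch.optim.AdamW(adamw_params, lr = 3e-4)
# muon_params would go to a separate Muon optimizer implementation
# (e.g. the reference code from the Jordan et al. post cited in the README diff below)
```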
x_transformers-2.8.1.dist-info/METADATA → x_transformers-2.8.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.8.1
+Version: 2.8.3
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2552,4 +2552,25 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{jordan2024muon,
+    author = {Keller Jordan and Yuchen Jin and Vlado Boza and Jiacheng You and Franz Cesista and Laker Newhouse and Jeremy Bernstein},
+    title = {Muon: An optimizer for hidden layers in neural networks},
+    year = {2024},
+    url = {https://kellerjordan.github.io/posts/muon/}
+}
+```
+
+```bibtex
+@misc{wang2025muonoutperformsadamtailend,
+    title = {Muon Outperforms Adam in Tail-End Associative Memory Learning},
+    author = {Shuche Wang and Fengzhuo Zhang and Jiaxiang Li and Cunxiao Du and Chao Du and Tianyu Pang and Zhuoran Yang and Mingyi Hong and Vincent Y. F. Tan},
+    year = {2025},
+    eprint = {2509.26030},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2509.26030},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
x_transformers-2.8.1.dist-info/RECORD → x_transformers-2.8.3.dist-info/RECORD CHANGED
@@ -5,15 +5,15 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/gpt_vae.py,sha256=Q2pzQ6iXRnP2Bfa6g-fs4US-JTouXB5-MfKw3sTwWmU,5561
+x_transformers/gpt_vae.py,sha256=myYSgcx66V0M4zeEGKyhY1P2HlPDHcezhaZEoo_uMdo,5715
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=odnCZAKZKrQLXmpaWhiPVB5elGjt8kerDbO3-yeC-60,124764
+x_transformers/x_transformers.py,sha256=gnmhtxPdmVQTd59MFXcGSm9HCKH9jv1fTBBYWAu5qaI,125113
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.8.1.dist-info/METADATA,sha256=_PnvoOSFJAgrpEfpNNljxdeYQ3BhDYJvVOp7yjaF-iM,94136
-x_transformers-2.8.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.8.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.8.1.dist-info/RECORD,,
+x_transformers-2.8.3.dist-info/METADATA,sha256=vB7jRRZOX58zB9QhBagiQ3u61t6Xd6XMzWwnDngroVw,94924
+x_transformers-2.8.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.8.3.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.8.3.dist-info/RECORD,,