x-transformers 2.8.1__py3-none-any.whl → 2.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/gpt_vae.py +18 -14
- x_transformers/x_transformers.py +19 -5
- {x_transformers-2.8.1.dist-info → x_transformers-2.8.3.dist-info}/METADATA +22 -1
- {x_transformers-2.8.1.dist-info → x_transformers-2.8.3.dist-info}/RECORD +6 -6
- {x_transformers-2.8.1.dist-info → x_transformers-2.8.3.dist-info}/WHEEL +0 -0
- {x_transformers-2.8.1.dist-info → x_transformers-2.8.3.dist-info}/licenses/LICENSE +0 -0
x_transformers/gpt_vae.py
CHANGED
@@ -46,25 +46,29 @@ class GPTVAE(Module):
         vae_kl_loss_weight = 1.,
         latents_dropout_prob = 0.5, # what percentage of the time to dropout the latents completely
         pad_id = -1,
+        encoder: Module | None = None,
         **kwargs
     ):
         super().__init__()
         dim_latent = default(dim_latent, dim)
 
-        self.encoder = TransformerWrapper(
-            num_tokens = num_tokens,
-            max_seq_len = max_seq_len + 1,
-            return_only_embed = True,
-            average_pool_embed = True,
-            attn_layers = Encoder(
-                dim = dim,
-                depth = enc_depth,
-                attn_dim_head = attn_dim_head,
-                heads = heads,
-                **kwargs,
-                **enc_kwargs
-            ),
-        )
+        if not exists(encoder):
+            encoder = TransformerWrapper(
+                num_tokens = num_tokens,
+                max_seq_len = max_seq_len + 1,
+                return_only_embed = True,
+                average_pool_embed = True,
+                attn_layers = Encoder(
+                    dim = dim,
+                    depth = enc_depth,
+                    attn_dim_head = attn_dim_head,
+                    heads = heads,
+                    **kwargs,
+                    **enc_kwargs
+                ),
+            )
+
+        self.encoder = encoder
 
         self.to_latent_mean_log_variance = nn.Sequential(
            nn.Linear(dim, dim_latent * 2),
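The change above makes the GPT-VAE encoder injectable: any module that maps a token sequence to a single `dim`-sized embedding can replace the default `TransformerWrapper` built in `__init__`. A minimal sketch of passing a custom encoder follows; the `GPTVAE` keyword names other than `encoder` (`num_tokens`, `max_seq_len`, `dim`, `depth`) are assumptions inferred from how the constructor body references them, not confirmed from the full source.

```python
from x_transformers import TransformerWrapper, Encoder
from x_transformers.gpt_vae import GPTVAE  # import path taken from the file location above

# custom encoder: pools the sequence into one embedding of size `dim`,
# mirroring the default encoder construction shown in the diff
custom_encoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024 + 1,        # the default encoder is built with max_seq_len + 1
    return_only_embed = True,
    average_pool_embed = True,
    attn_layers = Encoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

# keyword names besides `encoder` are assumed, mirroring the defaults above
model = GPTVAE(
    num_tokens = 256,
    max_seq_len = 1024,
    dim = 512,
    depth = 6,
    encoder = custom_encoder       # new optional argument; skips the built-in encoder
)
```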
x_transformers/x_transformers.py
CHANGED
@@ -4,6 +4,11 @@ from typing import Callable
 import math
 from copy import deepcopy
 from random import random, randrange
+from functools import partial, wraps
+from itertools import chain
+from collections import namedtuple
+from contextlib import nullcontext
+from dataclasses import dataclass
 from packaging import version
 
 import torch
@@ -13,11 +18,6 @@ from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
 from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from torch.nn import Module, ModuleList, ModuleDict
 
-from functools import partial, wraps
-from collections import namedtuple
-from contextlib import nullcontext
-from dataclasses import dataclass
-
 from loguru import logger
 
 from x_transformers.attend import Attend, Intermediates
@@ -1279,6 +1279,17 @@ class FeedForward(Module):
         if zero_init_output:
             init_zero_(proj_out)
 
+    def muon_parameters(self):
+        weights = []
+
+        for m in self.modules():
+            if not isinstance(m, nn.Linear):
+                continue
+
+            weights.append(m.weight)
+
+        return weights
+
     def forward(
         self,
         x,
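The new `muon_parameters` hook on `FeedForward` gathers just the 2-D `nn.Linear` weight matrices, the kind of hidden-layer weights the Muon optimizer (cited in the README additions further down) is designed for, leaving biases and norms to a conventional optimizer. A short sketch of what the hook yields; the `FeedForward(dim, mult = ...)` signature is assumed from the library's usual usage.

```python
from x_transformers.x_transformers import FeedForward

ff = FeedForward(512, mult = 4)   # signature assumed: (dim, mult = ...)

muon_weights = list(ff.muon_parameters())          # nn.Linear weight matrices only
assert all(w.ndim == 2 for w in muon_weights)      # all 2-D, as Muon expects

# everything else (e.g. biases, if enabled) stays with a standard optimizer
muon_ids = {id(w) for w in muon_weights}
other_params = [p for p in ff.parameters() if id(p) not in muon_ids]
```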
@@ -1644,6 +1655,9 @@ class Attention(Module):
         q_weight.mul_(qk_weight_scale)
         k_weight.mul_(qk_weight_scale)
 
+    def muon_parameters(self):
+        return chain(self.to_v.parameters(), self.to_out.parameters())
+
     def forward(
         self,
         x,
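`Attention` exposes only its value and output projections to Muon; query/key weights, embeddings and norms are left to the base optimizer. Below is a sketch of how the two hooks could be combined into separate parameter groups: walking `model.modules()` and checking for the hook is glue code of my own, and `Muon` itself is not shipped with x-transformers, so that optimizer line is a hypothetical placeholder.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# gather every weight exposed by a `muon_parameters` hook, deduplicated by identity
seen, muon_params = set(), []
for module in model.modules():
    if not hasattr(module, 'muon_parameters'):
        continue
    for p in module.muon_parameters():
        if id(p) not in seen:
            seen.add(id(p))
            muon_params.append(p)

# everything else (token embeddings, norms, q/k projections, biases) goes to AdamW
adamw_params = [p for p in model.parameters() if id(p) not in seen]

adamw_opt = torch.optim.AdamW(adamw_params, lr = 3e-4)
# muon_opt = Muon(muon_params, lr = 2e-2)   # hypothetical: plug in whichever Muon implementation you use
```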
{x_transformers-2.8.1.dist-info → x_transformers-2.8.3.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.8.1
+Version: 2.8.3
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2552,4 +2552,25 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{jordan2024muon,
+    author = {Keller Jordan and Yuchen Jin and Vlado Boza and Jiacheng You and Franz Cesista and Laker Newhouse and Jeremy Bernstein},
+    title  = {Muon: An optimizer for hidden layers in neural networks},
+    year   = {2024},
+    url    = {https://kellerjordan.github.io/posts/muon/}
+}
+```
+
+```bibtex
+@misc{wang2025muonoutperformsadamtailend,
+    title         = {Muon Outperforms Adam in Tail-End Associative Memory Learning},
+    author        = {Shuche Wang and Fengzhuo Zhang and Jiaxiang Li and Cunxiao Du and Chao Du and Tianyu Pang and Zhuoran Yang and Mingyi Hong and Vincent Y. F. Tan},
+    year          = {2025},
+    eprint        = {2509.26030},
+    archivePrefix = {arXiv},
+    primaryClass  = {cs.LG},
+    url           = {https://arxiv.org/abs/2509.26030},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.8.1.dist-info → x_transformers-2.8.3.dist-info}/RECORD
CHANGED
@@ -5,15 +5,15 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/gpt_vae.py,sha256=
+x_transformers/gpt_vae.py,sha256=myYSgcx66V0M4zeEGKyhY1P2HlPDHcezhaZEoo_uMdo,5715
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=gnmhtxPdmVQTd59MFXcGSm9HCKH9jv1fTBBYWAu5qaI,125113
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.8.1.dist-info/METADATA,sha256=
-x_transformers-2.8.1.dist-info/WHEEL,sha256=
-x_transformers-2.8.1.dist-info/licenses/LICENSE,sha256=
-x_transformers-2.8.1.dist-info/RECORD,,
+x_transformers-2.8.3.dist-info/METADATA,sha256=vB7jRRZOX58zB9QhBagiQ3u61t6Xd6XMzWwnDngroVw,94924
+x_transformers-2.8.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.8.3.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.8.3.dist-info/RECORD,,

{x_transformers-2.8.1.dist-info → x_transformers-2.8.3.dist-info}/WHEEL
File without changes

{x_transformers-2.8.1.dist-info → x_transformers-2.8.3.dist-info}/licenses/LICENSE
File without changes