x-transformers 2.1.29__py3-none-any.whl → 2.1.31__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- x_transformers/x_transformers.py +32 -3
- {x_transformers-2.1.29.dist-info → x_transformers-2.1.31.dist-info}/METADATA +10 -1
- {x_transformers-2.1.29.dist-info → x_transformers-2.1.31.dist-info}/RECORD +5 -5
- {x_transformers-2.1.29.dist-info → x_transformers-2.1.31.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.29.dist-info → x_transformers-2.1.31.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -9,7 +9,7 @@ from packaging import version
 import torch
 from torch.amp import autocast
 import torch.nn.functional as F
-from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
+from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
@@ -266,7 +266,6 @@ class TokenEmbedding(Module):
             return
         nn.init.kaiming_normal_(self.emb.weight)
 
-
 # positional embeddings
 
 class AbsolutePositionalEmbedding(Module):
@@ -849,6 +848,32 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
 
+class DynamicTanh(Module):
+    """ https://arxiv.org/abs/2503.10622 """
+    def __init__(
+        self,
+        dim,
+        init_alpha = 1.,
+        gamma = 1.,
+        beta = 0.,
+        unit_offset = False
+    ):
+        super().__init__()
+        self.pre_tanh_scale = nn.Parameter(tensor(init_alpha))
+
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.beta = nn.Parameter(torch.zeros(dim))
+
+        self.unit_offset = int(unit_offset)
+
+        nn.init.constant_(self.pre_tanh_scale, 1. - float(unit_offset))
+        nn.init.constant_(self.gamma, 1. - float(unit_offset))
+
+    def forward(self, x):
+        pre_tanh_scale = self.pre_tanh_scale + self.unit_offset
+        gamma = self.gamma + self.unit_offset
+        return (x * pre_tanh_scale).tanh() * gamma + self.beta
+
 # residual and residual gates
 
 class Residual(Module):
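In plain terms, the new module replaces a normalization layer with the elementwise map `tanh(alpha * x) * gamma + beta`; when `unit_offset` is set, `pre_tanh_scale` and `gamma` are stored shifted down by 1 and the offset is added back in `forward`, presumably so that a zero-valued (or weight-decayed) parameter corresponds to an effective scale of 1. A minimal standalone restatement of that forward pass; the shapes and names below are illustrative, not part of the library:

```python
import torch

# toy restatement of the DynamicTanh forward pass above (illustrative only)
dim = 4
x = torch.randn(2, 3, dim)          # (batch, seq, dim) activations

alpha = torch.tensor(1.)            # scalar pre-tanh scale (init_alpha)
gamma = torch.ones(dim)             # per-channel scale
beta  = torch.zeros(dim)            # per-channel shift

out = torch.tanh(x * alpha) * gamma + beta   # stands in for e.g. LayerNorm(x)
```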
@@ -1863,6 +1888,8 @@ class AttentionLayers(Module):
         only_cross = False,
         use_scalenorm = False,
         use_rmsnorm = False,
+        use_dynamic_tanh = False,
+        dynamic_tanh_init_alpha = 1.,
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
@@ -2012,7 +2039,7 @@ class AttentionLayers(Module):
 
         # determine norm
 
-        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
+        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
         norm_need_condition = False
         dim_condition = default(dim_condition, dim)
@@ -2027,6 +2054,8 @@ class AttentionLayers(Module):
             norm_class = RMSNorm
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
+        elif use_dynamic_tanh:
+            norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
         elif use_adaptive_layernorm:
             norm_need_condition = True
             norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
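Together with the constructor flags above, this wires `DynamicTanh` into the same slot as the other norm classes. A hedged usage sketch: `TransformerWrapper` and `Decoder` are the library's usual entry points, and the two new keyword arguments are the ones added in this diff; all other values below are arbitrary example settings.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# sketch: enabling the DynamicTanh "norm" via the flags added in this release
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_dynamic_tanh = True,        # new in this release
        dynamic_tanh_init_alpha = 1.    # new in this release
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)
```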
{x_transformers-2.1.29.dist-info → x_transformers-2.1.31.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.29
+Version: 2.1.31
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2455,4 +2455,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Zhu2025TransformersWN,
+    title = {Transformers without Normalization},
+    author = {Jiachen Zhu and Xinlei Chen and Kaiming He and Yann LeCun and Zhuang Liu},
+    year = {2025},
+    url = {https://api.semanticscholar.org/CorpusID:276961218}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.1.29.dist-info → x_transformers-2.1.31.dist-info}/RECORD
CHANGED
@@ -7,10 +7,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=X6xE_y_rCP6m4Ov2GahbPY5QABsc94y38PrnWSS9F9U,111428
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.31.dist-info/METADATA,sha256=tdus1LlIJQXbeq4G_Qt1QaEGkMkOoB7flxSoEe0CYNw,88161
+x_transformers-2.1.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.31.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.31.dist-info/RECORD,,
File without changes
|
File without changes
|