x-transformers 2.1.29__py3-none-any.whl → 2.1.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -9,7 +9,7 @@ from packaging import version
 import torch
 from torch.amp import autocast
 import torch.nn.functional as F
-from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
+from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
@@ -266,7 +266,6 @@ class TokenEmbedding(Module):
             return
         nn.init.kaiming_normal_(self.emb.weight)
 
-
 # positional embeddings
 
 class AbsolutePositionalEmbedding(Module):
@@ -849,6 +848,32 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
 
+class DynamicTanh(Module):
+    """ https://arxiv.org/abs/2503.10622 """
+    def __init__(
+        self,
+        dim,
+        init_alpha = 1.,
+        gamma = 1.,
+        beta = 0.,
+        unit_offset = False
+    ):
+        super().__init__()
+        self.pre_tanh_scale = nn.Parameter(tensor(init_alpha))
+
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.beta = nn.Parameter(torch.zeros(dim))
+
+        self.unit_offset = int(unit_offset)
+
+        nn.init.constant_(self.pre_tanh_scale, 1. - float(unit_offset))
+        nn.init.constant_(self.gamma, 1. - float(unit_offset))
+
+    def forward(self, x):
+        pre_tanh_scale = self.pre_tanh_scale + self.unit_offset
+        gamma = self.gamma + self.unit_offset
+        return (x * pre_tanh_scale).tanh() * gamma + self.beta
+
 # residual and residual gates
 
 class Residual(Module):
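The new `DynamicTanh` module (DyT, from "Transformers without Normalization") replaces a normalization layer with an elementwise `tanh(pre_tanh_scale * x)` followed by a learned per-channel scale `gamma` and shift `beta`; with `unit_offset = True`, `pre_tanh_scale` and `gamma` are stored as zero-initialized offsets and the unit is added back in `forward`. A minimal standalone sketch, assuming the class can be imported from `x_transformers.x_transformers` (import path inferred from this diff, not documented upstream):

```python
import torch
# import path assumed: DynamicTanh is defined in x_transformers/x_transformers.py per this diff
from x_transformers.x_transformers import DynamicTanh

dyt = DynamicTanh(dim = 512)      # per-channel gamma / beta over the last dimension
x = torch.randn(2, 1024, 512)

out = dyt(x)                      # tanh(pre_tanh_scale * x) * gamma + beta, elementwise
assert out.shape == x.shape       # shape is preserved, like a norm layer
```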
@@ -1863,6 +1888,8 @@ class AttentionLayers(Module):
         only_cross = False,
         use_scalenorm = False,
         use_rmsnorm = False,
+        use_dynamic_tanh = False,
+        dynamic_tanh_init_alpha = 1.,
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
@@ -2012,7 +2039,7 @@
 
         # determine norm
 
-        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
+        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
         norm_need_condition = False
         dim_condition = default(dim_condition, dim)
@@ -2027,6 +2054,8 @@
             norm_class = RMSNorm
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
+        elif use_dynamic_tanh:
+            norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
         elif use_adaptive_layernorm:
             norm_need_condition = True
             norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
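The new flag is selected the same way as the existing norm options. A hedged usage sketch, assuming `Decoder` forwards `use_dynamic_tanh` and `dynamic_tanh_init_alpha` to `AttentionLayers` just like `use_rmsnorm` (the model sizes below are illustrative only):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_dynamic_tanh = True,         # selects DynamicTanh as the layer "norm"
        dynamic_tanh_init_alpha = 1.     # initial value of the pre-tanh scale (init_alpha)
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)                   # (1, 1024, 20000)
```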
x_transformers-2.1.31.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.29
+Version: 2.1.31
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2455,4 +2455,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Zhu2025TransformersWN,
+    title  = {Transformers without Normalization},
+    author = {Jiachen Zhu and Xinlei Chen and Kaiming He and Yann LeCun and Zhuang Liu},
+    year   = {2025},
+    url    = {https://api.semanticscholar.org/CorpusID:276961218}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
x_transformers-2.1.31.dist-info/RECORD

@@ -7,10 +7,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
+x_transformers/x_transformers.py,sha256=X6xE_y_rCP6m4Ov2GahbPY5QABsc94y38PrnWSS9F9U,111428
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.29.dist-info/METADATA,sha256=CI6GLna-OqlmDEjv8sP0CcfI7SNCAbL-nQCm2sQqdbc,87875
-x_transformers-2.1.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.29.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.29.dist-info/RECORD,,
+x_transformers-2.1.31.dist-info/METADATA,sha256=tdus1LlIJQXbeq4G_Qt1QaEGkMkOoB7flxSoEe0CYNw,88161
+x_transformers-2.1.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.31.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.31.dist-info/RECORD,,