x-transformers 2.1.29__py3-none-any.whl → 2.1.30__py3-none-any.whl

This diff covers the contents of publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
@@ -9,7 +9,7 @@ from packaging import version
  import torch
  from torch.amp import autocast
  import torch.nn.functional as F
- from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
+ from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
  from torch.utils._pytree import tree_flatten, tree_unflatten
  from torch.nn import Module, ModuleList, ModuleDict
 
@@ -266,7 +266,6 @@ class TokenEmbedding(Module):
  return
  nn.init.kaiming_normal_(self.emb.weight)
 
-
  # positional embeddings
 
  class AbsolutePositionalEmbedding(Module):
@@ -849,6 +848,31 @@ class MultiheadRMSNorm(Module):
  def forward(self, x):
  return self.rmsnorm(x) * (self.gamma + 1.)
 
+ class DynamicTanh(Module):
+     """ https://arxiv.org/abs/2503.10622 """
+     def __init__(
+         self,
+         init_alpha = 1.,
+         gamma = 1.,
+         beta = 0.,
+         unit_offset = False
+     ):
+         super().__init__()
+         self.pre_tanh_scale = nn.Parameter(tensor(init_alpha))
+
+         self.gamma = nn.Parameter(tensor(init_alpha))
+         self.beta = nn.Parameter(tensor(init_alpha))
+
+         self.unit_offset = int(unit_offset)
+
+         nn.init.constant_(self.pre_tanh_scale, 1. - float(unit_offset))
+         nn.init.constant_(self.gamma, 1. - float(unit_offset))
+
+     def forward(self, x):
+         pre_tanh_scale = self.pre_tanh_scale + self.unit_offset
+         gamma = self.gamma + self.unit_offset
+         return (x * pre_tanh_scale).tanh() * gamma + self.beta
+
  # residual and residual gates
 
  class Residual(Module):
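
The added `DynamicTanh` module is the DyT layer from "Transformers without Normalization" (the paper linked in its docstring): instead of computing normalization statistics, it applies an element-wise `tanh` with learned scalars for the pre-tanh scale, gain and shift. A minimal standalone sketch of the computation it performs (plain tensors stand in for the learned `nn.Parameter`s; this illustrates the formula, not the library class itself):

```python
import torch

# DyT: y = tanh(alpha * x) * gamma + beta, with alpha, gamma, beta learned scalars
alpha = torch.tensor(1.0)   # pre_tanh_scale
gamma = torch.tensor(1.0)   # gain
beta  = torch.tensor(0.0)   # shift

x = torch.randn(2, 128, 512)              # (batch, seq, dim)
y = torch.tanh(alpha * x) * gamma + beta  # same shape, no mean/variance computed

print(y.shape)  # torch.Size([2, 128, 512])
```

Note that nothing here depends on the feature dimension, which is why the wiring changes further down construct this norm without passing `dim`.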
@@ -1863,6 +1887,8 @@ class AttentionLayers(Module):
  only_cross = False,
  use_scalenorm = False,
  use_rmsnorm = False,
+ use_dynamic_tanh = False,
+ dynamic_tanh_init_alpha = 1.,
  use_simple_rmsnorm = False,
  use_adaptive_layernorm = False,
  use_adaptive_rmsnorm = False,
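
The two new keyword arguments surface through the usual attention-layer classes. A usage sketch, assuming the standard `TransformerWrapper` / `Decoder` entry points of x-transformers (only `use_dynamic_tanh` and `dynamic_tanh_init_alpha` are new in this release; the rest is existing API):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_dynamic_tanh = True,        # replace the pre-norms with DynamicTanh (DyT)
        dynamic_tanh_init_alpha = 1.5   # initial value of the learned pre-tanh scale
    )
)

ids = torch.randint(0, 256, (1, 128))
logits = model(ids)  # (1, 128, 256)
```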
@@ -2012,8 +2038,9 @@ class AttentionLayers(Module):
 
  # determine norm
 
- assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
+ assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
+ norm_fn = None
  norm_need_condition = False
  dim_condition = default(dim_condition, dim)
  dim_condition_mult = 1
@@ -2027,6 +2054,8 @@ class AttentionLayers(Module):
  norm_class = RMSNorm
  elif use_simple_rmsnorm:
  norm_class = SimpleRMSNorm
+ elif use_dynamic_tanh:
+ norm_fn = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
  elif use_adaptive_layernorm:
  norm_need_condition = True
  norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
@@ -2036,7 +2065,8 @@ class AttentionLayers(Module):
  else:
  norm_class = LayerNorm
 
- norm_fn = partial(norm_class, dim)
+ if not exists(norm_fn):
+ norm_fn = partial(norm_class, dim)
 
  if not norm_need_condition and norm_add_unit_offset:
  # researcher Ohad Rubin shares in a blog post by adding an offset to gammas, they can be subjected to weight decay safely
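
The `norm_fn = None` / `if not exists(norm_fn)` dispatch above exists because `DynamicTanh` is the one option whose constructor takes no feature dimension (its parameters are scalars), so it cannot go through `partial(norm_class, dim)`. A simplified, self-contained sketch of that dispatch; `DyTStandIn` and the flag values are illustrative stand-ins, not the library's code:

```python
from functools import partial

import torch
from torch import nn

def exists(v):
    return v is not None  # same None-check helper the library uses

class DyTStandIn(nn.Module):
    # hypothetical stand-in for DynamicTanh; note: no `dim` argument, only a scalar parameter
    def __init__(self, init_alpha = 1.):
        super().__init__()
        self.pre_tanh_scale = nn.Parameter(torch.tensor(init_alpha))

    def forward(self, x):
        return torch.tanh(x * self.pre_tanh_scale)

dim = 512
use_dynamic_tanh = True

norm_fn = None

if use_dynamic_tanh:
    # DyT needs no feature dimension, so norm_fn is built directly
    norm_fn = partial(DyTStandIn, init_alpha = 1.5)

if not exists(norm_fn):
    # default path for dimension-aware norms (LayerNorm, RMSNorm, ...)
    norm_fn = partial(nn.LayerNorm, dim)

norm = norm_fn()  # instantiated with no arguments wherever a pre-norm is needed
print(norm(torch.randn(2, 128, dim)).shape)  # torch.Size([2, 128, 512])
```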
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.1.29
+ Version: 2.1.30
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2455,4 +2455,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```
 
+ ```bibtex
+ @inproceedings{Zhu2025TransformersWN,
+ title = {Transformers without Normalization},
+ author = {Jiachen Zhu and Xinlei Chen and Kaiming He and Yann LeCun and Zhuang Liu},
+ year = {2025},
+ url = {https://api.semanticscholar.org/CorpusID:276961218}
+ }
+ ```
+
 
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -7,10 +7,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
+ x_transformers/x_transformers.py,sha256=a7k6tR9H1kCRX44PP0N9nMMb3V1_cIgFweTBK84VtEk,111476
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
- x_transformers-2.1.29.dist-info/METADATA,sha256=CI6GLna-OqlmDEjv8sP0CcfI7SNCAbL-nQCm2sQqdbc,87875
- x_transformers-2.1.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- x_transformers-2.1.29.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-2.1.29.dist-info/RECORD,,
+ x_transformers-2.1.30.dist-info/METADATA,sha256=IdKgXNQf9aTZ_JiOYhc3q1J44ITmwBO-VRPRDaZtnEU,88161
+ x_transformers-2.1.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ x_transformers-2.1.30.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.1.30.dist-info/RECORD,,