x-transformers 2.1.30__py3-none-any.whl → 2.1.32__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
x_transformers/x_transformers.py

@@ -852,6 +852,7 @@ class DynamicTanh(Module):
     """ https://arxiv.org/abs/2503.10622 """
     def __init__(
         self,
+        dim,
         init_alpha = 1.,
         gamma = 1.,
         beta = 0.,
@@ -860,8 +861,8 @@ class DynamicTanh(Module):
         super().__init__()
         self.pre_tanh_scale = nn.Parameter(tensor(init_alpha))
 
-        self.gamma = nn.Parameter(tensor(init_alpha))
-        self.beta = nn.Parameter(tensor(init_alpha))
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.beta = nn.Parameter(torch.zeros(dim))
 
         self.unit_offset = int(unit_offset)
 
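In 2.1.30, gamma and beta were scalar parameters (and were initialized from init_alpha); 2.1.32 gives the module a dim argument and makes them per-feature vectors initialized to ones and zeros, matching the elementwise affine of the DyT formulation y = gamma * tanh(alpha * x) + beta from the linked paper. Below is a minimal sketch of that formulation under the new parameter layout; it is illustrative only, since the library's actual forward pass (including its unit_offset handling) is not part of this diff and may differ.

import torch
from torch import nn, tensor
from torch.nn import Module

class DynamicTanhSketch(Module):
    # sketch of DyT: y = gamma * tanh(alpha * x) + beta, with per-feature gamma/beta
    def __init__(self, dim, init_alpha = 1.):
        super().__init__()
        self.pre_tanh_scale = nn.Parameter(tensor(init_alpha))  # learned scalar alpha
        self.gamma = nn.Parameter(torch.ones(dim))   # per-feature scale (was a scalar in 2.1.30)
        self.beta = nn.Parameter(torch.zeros(dim))   # per-feature shift (was a scalar in 2.1.30)

    def forward(self, x):
        return torch.tanh(x * self.pre_tanh_scale) * self.gamma + self.beta

x = torch.randn(2, 16, 512)
out = DynamicTanhSketch(512)(x)   # same shape out; a normalization-free stand-in for LayerNorm
assert out.shape == x.shape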
@@ -2040,7 +2041,6 @@ class AttentionLayers(Module):
 
         assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
-        norm_fn = None
         norm_need_condition = False
         dim_condition = default(dim_condition, dim)
         dim_condition_mult = 1
@@ -2055,7 +2055,8 @@ class AttentionLayers(Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         elif use_dynamic_tanh:
-            norm_fn = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
+            assert pre_norm, 'only tested for pre-norm'
+            norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
         elif use_adaptive_layernorm:
             norm_need_condition = True
             norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
@@ -2065,8 +2066,7 @@ class AttentionLayers(Module):
         else:
             norm_class = LayerNorm
 
-        if not exists(norm_fn):
-            norm_fn = partial(norm_class, dim)
+        norm_fn = partial(norm_class, dim)
 
         if not norm_need_condition and norm_add_unit_offset:
             # researcher Ohad Rubin shares in a blog post by adding an offset to gammas, they can be subjected to weight decay safely
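On the AttentionLayers side, the dynamic-tanh branch now assigns norm_class like the other branches (with a new guard that it has only been tested with pre-norm), so the single norm_fn = partial(norm_class, dim) line supplies the dim argument the module now requires. Below is a hedged usage sketch, assuming the usual Decoder/TransformerWrapper entry points forward these keyword arguments to AttentionLayers as they do for the other norm flags.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_dynamic_tanh = True,          # swaps LayerNorm for DynamicTanh
        dynamic_tanh_init_alpha = 1.5     # initial value of the scalar pre-tanh scale
    )
)

# pre-norm is the default, so the new `assert pre_norm` guard passes
logits = model(torch.randint(0, 256, (1, 1024)))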
METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.30
+Version: 2.1.32
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
RECORD

@@ -7,10 +7,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=a7k6tR9H1kCRX44PP0N9nMMb3V1_cIgFweTBK84VtEk,111476
+x_transformers/x_transformers.py,sha256=YQODc4PDB_ddgm7vi0uktV5GGetgEuwADzt3CaIdAXs,111484
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.30.dist-info/METADATA,sha256=IdKgXNQf9aTZ_JiOYhc3q1J44ITmwBO-VRPRDaZtnEU,88161
-x_transformers-2.1.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.30.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.30.dist-info/RECORD,,
+x_transformers-2.1.32.dist-info/METADATA,sha256=Jmgj9CByp1_kveLWVm5-QM_n55A9s3MFzUiV1ciD034,88161
+x_transformers-2.1.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.32.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.32.dist-info/RECORD,,
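The RECORD rows above pair each file path with a digest and a byte size; the x_transformers.py row is the only content change (hash and size differ), while the dist-info rows simply follow the version rename. A small sketch of how such an entry can be recomputed locally, assuming the standard wheel RECORD convention of urlsafe base64-encoded sha256 with the trailing '=' padding stripped:

import base64, hashlib

def record_entry(path):
    # returns "sha256=<digest>" and the byte size, in the format used by RECORD rows
    data = open(path, 'rb').read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'sha256={digest}', len(data)

print(record_entry('x_transformers/x_transformers.py'))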