x-transformers 1.43.0__py3-none-any.whl → 1.43.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -870,6 +870,7 @@ class HyperConnection(Module):
870
870
  *,
871
871
  layer_index,
872
872
  num_residual_streams,
873
+ tanh = True,
873
874
  **kwargs
874
875
  ):
875
876
  """
@@ -878,6 +879,8 @@ class HyperConnection(Module):
878
879
  """
879
880
  super().__init__()
880
881
 
882
+ self.act = nn.Tanh() if tanh else nn.Identity()
883
+
881
884
  self.norm = nn.LayerNorm(dim, bias = False)
882
885
 
883
886
  self.num_residual_streams = num_residual_streams
@@ -901,11 +904,11 @@ class HyperConnection(Module):
901
904
 
902
905
  normed = self.norm(residuals)
903
906
 
904
- wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
907
+ wc_weight = self.act(normed @ self.dynamic_alpha_fn)
905
908
  dynamic_alpha = wc_weight * self.dynamic_alpha_scale
906
909
  alpha = dynamic_alpha + self.static_alpha
907
910
 
908
- dc_weight = (normed @ self.dynamic_beta_fn).tanh()
911
+ dc_weight = self.act(normed @ self.dynamic_beta_fn)
909
912
  dynamic_beta = dc_weight * self.dynamic_beta_scale
910
913
  beta = dynamic_beta + self.static_beta
911
914
 
@@ -1650,9 +1653,10 @@ class AttentionLayers(Module):
1650
1653
  unet_skips = False,
1651
1654
  num_residual_streams = 1,
1652
1655
  reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
1653
- add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
1656
+ add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
1654
1657
  learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
1655
1658
  rel_pos_kwargs: dict = dict(),
1659
+ residual_fn_kwargs: dict = dict(),
1656
1660
  **kwargs
1657
1661
  ):
1658
1662
  super().__init__()
@@ -1957,7 +1961,7 @@ class AttentionLayers(Module):
1957
1961
  else:
1958
1962
  residual_fn = Residual
1959
1963
 
1960
- residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1964
+ residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant, **residual_fn_kwargs)
1961
1965
 
1962
1966
  # handle unet skip connection
1963
1967
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.43.0
3
+ Version: 1.43.1
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
6
6
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
7
7
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
8
8
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
9
- x_transformers/x_transformers.py,sha256=wAY0lqZvFlXk-fmpr4Ot6yZ6ivzEjetFXTin7z7eA88,100075
9
+ x_transformers/x_transformers.py,sha256=JG38kcXdhRBKT5_FHMhV5dQabSGrAHsuIQkHjPalDiI,100384
10
10
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
11
11
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
12
- x_transformers-1.43.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
- x_transformers-1.43.0.dist-info/METADATA,sha256=C6eRstMfzmbxQUxNeKnt1Mf-e9pJ45GKNJ8hsc_3uwo,738
14
- x_transformers-1.43.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
15
- x_transformers-1.43.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
- x_transformers-1.43.0.dist-info/RECORD,,
12
+ x_transformers-1.43.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
13
+ x_transformers-1.43.1.dist-info/METADATA,sha256=V57c6Bps0GjG0GLEBpxkHdbvxIWzXss2Xu5_KQJJXPc,738
14
+ x_transformers-1.43.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
15
+ x_transformers-1.43.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
16
+ x_transformers-1.43.1.dist-info/RECORD,,