x-transformers 1.42.28-py3-none-any.whl → 1.43.1-py3-none-any.whl

x_transformers/x_transformers.py

@@ -824,12 +824,15 @@ class SimpleRMSNorm(Module):
 # residual and residual gates

 class Residual(Module):
-    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
+    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
         super().__init__()
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
         self.scale_residual_constant = scale_residual_constant

-    def forward(self, x, residual):
+    def prepare(self, residual):
+        return residual, residual, dict()
+
+    def forward(self, x, residual, **kwargs):
         if exists(self.residual_scale):
             residual = residual * self.residual_scale

@@ -844,7 +847,10 @@ class GRUGating(Module):
         self.gru = nn.GRUCell(dim, dim)
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

-    def forward(self, x, residual):
+    def prepare(self, residual):
+        return residual, residual, dict()
+
+    def forward(self, x, residual, **kwargs):
         if exists(self.residual_scale):
             residual = residual * self.residual_scale

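The residual wrappers now follow a two-phase protocol: `prepare` receives the incoming residual and returns the branch input, the residual state to keep, and any extra keyword arguments for the merge step, while `forward` folds the branch output back in. Below is a minimal, self-contained sketch of how a layer loop drives this protocol; the `branch` module is a stand-in for the library's attention or feedforward blocks, and `Residual` here is a stripped-down copy of the class above (scaling options omitted).

```python
import torch
from torch import nn

class Residual(nn.Module):
    # stripped-down copy of the class above, scaling options omitted

    def prepare(self, residual):
        # returns (branch input, residual state to keep, extra kwargs for forward)
        return residual, residual, dict()

    def forward(self, x, residual, **kwargs):
        return x + residual

residual_fn = Residual()
branch = nn.Linear(512, 512)  # stand-in for an attention / feedforward block

x = torch.randn(2, 1024, 512)

# each layer in AttentionLayers.forward now runs: prepare -> branch -> merge
x, inner_residual, residual_kwargs = residual_fn.prepare(x)
out = branch(x)
x = residual_fn(out, inner_residual, **residual_kwargs)
```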
@@ -855,6 +861,69 @@ class GRUGating(Module):

         return gated_output.reshape_as(x)

+# hyper connections
+
+class HyperConnection(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        layer_index,
+        num_residual_streams,
+        tanh = True,
+        **kwargs
+    ):
+        """
+        https://arxiv.org/abs/2409.19606
+        Appendix J - Algorithm 2, Dynamic only
+        """
+        super().__init__()
+
+        self.act = nn.Tanh() if tanh else nn.Identity()
+
+        self.norm = nn.LayerNorm(dim, bias = False)
+
+        self.num_residual_streams = num_residual_streams
+        self.layer_index = layer_index
+
+        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+        init_alpha0 = torch.zeros((num_residual_streams, 1))
+        init_alpha0[layer_index % num_residual_streams, 0] = 1.
+
+        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
+        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+    def prepare(self, residuals):
+
+        residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
+
+        normed = self.norm(residuals)
+
+        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
+        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+        alpha = dynamic_alpha + self.static_alpha
+
+        dc_weight = self.act(normed @ self.dynamic_beta_fn)
+        dynamic_beta = dc_weight * self.dynamic_beta_scale
+        beta = dynamic_beta + self.static_beta
+
+        # width connection
+
+        mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
+
+        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+
+        return branch_input, residuals, dict(beta = beta)
+
+    def forward(self, x, residuals, *, beta):
+        residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
+        return rearrange(residuals, 'b n s d -> (b s) n d')
+
 # token shifting

 def shift(t, amount, mask = None):
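
With `s` residual streams, `prepare` above unfolds the stream axis out of the batch, builds an `(s, s + 1)` mixing matrix `alpha` (column 0 mixes the streams into the branch input; the remaining identity-initialized columns are the width connection among the streams) and a depth-connection weight `beta` that scatters the branch output back across all streams. A shape-only sketch using just the static weights (the learned dynamic terms and normalization are omitted here):

```python
import torch

b, n, s, d = 2, 16, 4, 64  # batch, sequence, residual streams, model dim
layer_index = 1

residuals = torch.randn(b, n, s, d)

# alpha: (s, s + 1) - column 0 picks the branch input (stream layer_index % s
# at init), the remaining s columns are an identity width connection
init_alpha0 = torch.zeros(s, 1)
init_alpha0[layer_index % s, 0] = 1.
alpha = torch.cat([init_alpha0, torch.eye(s)], dim = 1)

# beta: at init the branch output is added to every stream equally
beta = torch.ones(s)

mix_h = torch.einsum('s t, b n s d -> b n t d', alpha, residuals)
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]

branch_output = torch.randn(b, n, d)  # stand-in for the attention / ff branch

residuals = torch.einsum('b n d, s -> b n s d', branch_output, beta) + residuals

print(branch_input.shape, residuals.shape)  # (2, 16, 64) and (2, 16, 4, 64)
```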
@@ -1582,10 +1651,12 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        num_residual_streams = 1,
         reinject_input = False, # seen first in the DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroborated by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - this setting may soon become a necessity for every transformer
         learned_value_residual_mix = True, # seeing big improvements when the value residual mix is learned per token - credit goes to @faresobeid for taking the first step with a learned scalar mix, then @Blinkdl for taking it a step further with a data-dependent version. here we use a per-token learned mix
         rel_pos_kwargs: dict = dict(),
+        residual_fn_kwargs: dict = dict(),
         **kwargs
     ):
         super().__init__()
@@ -1607,6 +1678,17 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])

+        # greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+
+        assert num_residual_streams > 0
+
+        self.num_residual_streams = num_residual_streams
+        self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
+
+        assert not (num_residual_streams > 1 and gate_residual)
+
+        # positions related
+
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))

         rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
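
From the public API, hyper-connections are switched on with the new `num_residual_streams` keyword (note that it cannot be combined with `gate_residual`, per the assert above). A usage sketch, assuming the keyword reaches `AttentionLayers` through the usual wrapper plumbing; the model hyperparameters are illustrative:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 8,
        heads = 8,
        num_residual_streams = 4  # maintain 4 residual streams instead of 1
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```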
@@ -1872,9 +1954,14 @@ class AttentionLayers(Module):
             if exists(post_branch_fn):
                 layer = post_branch_fn(layer)

-            residual_fn = GRUGating if gate_residual else Residual
+            if num_residual_streams > 1:
+                residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+            elif gate_residual:
+                residual_fn = GRUGating
+            else:
+                residual_fn = Residual

-            residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
+            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant, **residual_fn_kwargs)

             # handle unet skip connection

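Because `residual_fn` is now built with `functools.partial` and handed the loop index, each layer's `static_alpha` starts out reading a different stream, cycling through them with depth via `layer_index % num_residual_streams`; the plain `Residual` and gated paths simply ignore the extra keyword thanks to the `**kwargs` added earlier. A toy illustration of the cycling (stream and depth counts assumed for illustration):

```python
# which stream initially feeds each layer's branch, per the
# init_alpha0[layer_index % num_residual_streams] initialization above
num_residual_streams, depth = 4, 8

for layer_index in range(depth):
    print(f'layer {layer_index}: branch reads stream {layer_index % num_residual_streams}')
# cycles 0, 1, 2, 3, 0, 1, 2, 3, so consecutive layers tap different streams
```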
@@ -2024,6 +2111,16 @@ class AttentionLayers(Module):

         iter_attn_cache = iter(attn_cache)

+        # setup multistreams if needed
+
+        streams = self.num_residual_streams
+        is_multistream = streams > 1
+
+        if is_multistream:
+            x = repeat(x, 'b n d -> b n s d', s = streams)
+            x = x + self.stream_emb
+            x = rearrange(x, 'b n s d -> (b s) n d')
+
         # outer residual - for resiDual paper

         outer_residual = x * self.resi_dual_scale
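
Folding the streams into the batch dimension means every attention and feedforward block downstream still sees ordinary `(batch, seq, dim)` tensors and needs no changes. A standalone sketch of the expansion, with a plain tensor standing in for the module's learned `stream_emb` parameter:

```python
import torch
from einops import rearrange, repeat

b, n, d, s = 2, 1024, 512, 4

x = torch.randn(b, n, d)
stream_emb = torch.zeros(s, d)  # nn.Parameter in the actual module

x = repeat(x, 'b n d -> b n s d', s = s)  # copy the input into s streams
x = x + stream_emb                        # per-stream offset, broadcast over (b, n)
x = rearrange(x, 'b n s d -> (b s) n d')  # fold streams into the batch

print(x.shape)  # torch.Size([8, 1024, 512])
```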
@@ -2090,7 +2187,7 @@ class AttentionLayers(Module):
                 if self.training and self.cross_attn_tokens_dropout > 0.:
                     context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)

-            inner_residual = x
+            x, inner_residual, residual_kwargs = residual_fn.prepare(x)

             if return_hiddens:
                 layer_hiddens.append(x)
@@ -2148,7 +2245,7 @@ class AttentionLayers(Module):
             if exists(post_branch_norm):
                 out = post_branch_norm(out)

-            x = residual_fn(out, inner_residual)
+            x = residual_fn(out, inner_residual, **residual_kwargs)

             if layer_type in ('a', 'c') and return_hiddens:
                 inter.layer_type = layer_type
@@ -2178,6 +2275,11 @@ class AttentionLayers(Module):
         else:
             x = final_norm(x)

+        # take care of multistreams if needed, use sum for now
+
+        if is_multistream:
+            x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
+
         if not return_hiddens:
             return x

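At the output, the streams are unfolded from the batch axis and summed back into a single stream (the comment above notes sum reduction "for now"). A sketch of that final reduction:

```python
import torch
from einops import reduce

b, n, d, s = 2, 1024, 512, 4

x = torch.randn(b * s, n, d)  # multistream activations, streams folded into batch
out = reduce(x, '(b s) n d -> b n d', 'sum', s = s)

print(out.shape)  # torch.Size([2, 1024, 512])
```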
x_transformers-1.43.1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.28
+Version: 1.43.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.43.1.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=X4HegsAtCnaL3MAxu07RkZ5WBMgtdbi0W-2c9bXQxew,96696
+x_transformers/x_transformers.py,sha256=JG38kcXdhRBKT5_FHMhV5dQabSGrAHsuIQkHjPalDiI,100384
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.28.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.28.dist-info/METADATA,sha256=txhDZvzsfiBEPBUg3Ipszv2cWu9sXyd7hhDz4BGsbfc,739
-x_transformers-1.42.28.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-x_transformers-1.42.28.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.28.dist-info/RECORD,,
+x_transformers-1.43.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.43.1.dist-info/METADATA,sha256=V57c6Bps0GjG0GLEBpxkHdbvxIWzXss2Xu5_KQJJXPc,738
+x_transformers-1.43.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.43.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.43.1.dist-info/RECORD,,