x-transformers 1.42.27__py3-none-any.whl → 1.43.0__py3-none-any.whl

x_transformers/x_transformers.py

@@ -824,12 +824,15 @@ class SimpleRMSNorm(Module):
 # residual and residual gates
 
 class Residual(Module):
-    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
+    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
         super().__init__()
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
         self.scale_residual_constant = scale_residual_constant
 
-    def forward(self, x, residual):
+    def prepare(self, residual):
+        return residual, residual, dict()
+
+    def forward(self, x, residual, **kwargs):
         if exists(self.residual_scale):
             residual = residual * self.residual_scale
 
@@ -844,7 +847,10 @@ class GRUGating(Module):
         self.gru = nn.GRUCell(dim, dim)
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
 
-    def forward(self, x, residual):
+    def prepare(self, residual):
+        return residual, residual, dict()
+
+    def forward(self, x, residual, **kwargs):
         if exists(self.residual_scale):
             residual = residual * self.residual_scale
 
@@ -855,6 +861,66 @@ class GRUGating(Module):
 
         return gated_output.reshape_as(x)
 
+# hyper connections
+
+class HyperConnection(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        layer_index,
+        num_residual_streams,
+        **kwargs
+    ):
+        """
+        https://arxiv.org/abs/2409.19606
+        Appendix J - Algorithm 2, Dynamic only
+        """
+        super().__init__()
+
+        self.norm = nn.LayerNorm(dim, bias = False)
+
+        self.num_residual_streams = num_residual_streams
+        self.layer_index = layer_index
+
+        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+        init_alpha0 = torch.zeros((num_residual_streams, 1))
+        init_alpha0[layer_index % num_residual_streams, 0] = 1.
+
+        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
+        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+    def prepare(self, residuals):
+
+        residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
+
+        normed = self.norm(residuals)
+
+        wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
+        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+        alpha = dynamic_alpha + self.static_alpha
+
+        dc_weight = (normed @ self.dynamic_beta_fn).tanh()
+        dynamic_beta = dc_weight * self.dynamic_beta_scale
+        beta = dynamic_beta + self.static_beta
+
+        # width connection
+
+        mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
+
+        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+
+        return branch_input, residuals, dict(beta = beta)
+
+    def forward(self, x, residuals, *, beta):
+        residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
+        return rearrange(residuals, 'b n s d -> (b s) n d')
+
 # token shifting
 
 def shift(t, amount, mask = None):
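The new HyperConnection residual (hyper-connections, https://arxiv.org/abs/2409.19606) replaces the plain residual add with a learned width connection across several residual streams in prepare(), and a depth connection that writes the branch output back into every stream in forward(). A minimal shape-check sketch, assuming x-transformers >= 1.43.0 so the class is importable from x_transformers.x_transformers; stream count, layer index and tensor sizes are arbitrary:

```python
import torch
from x_transformers.x_transformers import HyperConnection  # assumes >= 1.43.0

streams, batch, seq, dim = 4, 2, 16, 512

hc = HyperConnection(dim, layer_index = 3, num_residual_streams = streams)

# residual streams ride folded into the batch dimension: (b s) n d
residuals = torch.randn(batch * streams, seq, dim)

# prepare() mixes the streams (width connection) and hands back the branch input,
# the remaining streams, and the kwargs (beta) that forward() needs
branch_input, residuals, residual_kwargs = hc.prepare(residuals)
assert branch_input.shape == (batch, seq, dim)

branch_out = branch_input  # stand-in for the attention / feedforward branch

# forward() adds the branch output back into every stream (depth connection)
out = hc(branch_out, residuals, **residual_kwargs)
assert out.shape == (batch * streams, seq, dim)
```

Residual and GRUGating now expose the same prepare()/forward() pair (as no-ops that return the residual unchanged), which is what lets the layer loop below treat all three residual types uniformly.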
@@ -1077,7 +1143,7 @@ class Attention(Module):
         logit_softclamp_value = 50.,
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
-        learned_value_residual_mix = True,
+        learned_value_residual_mix = False,
         laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
         onnxable = False,
@@ -1582,9 +1648,10 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        num_residual_streams = 1,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
-        learned_value_residual_mix = False, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
+        learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
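Note the swapped defaults for learned_value_residual_mix: Attention now defaults it to False while AttentionLayers defaults it to True, so the per-token learned mix is switched on from the stack level whenever value residuals are used. A hedged usage sketch, assuming Decoder forwards these kwargs to AttentionLayers; the hyperparameters are arbitrary:

```python
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(
        dim = 512,
        depth = 8,
        heads = 8,
        # ResFormer-style value residual; with the new default, the mix is
        # learned per token without any further flag
        add_value_residual = True
    )
)
```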
@@ -1607,6 +1674,17 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])
 
+        # greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+
+        assert num_residual_streams > 0
+
+        self.num_residual_streams = num_residual_streams
+        self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
+
+        assert not (num_residual_streams > 1 and gate_residual)
+
+        # positions related
+
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
 
         rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
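With num_residual_streams exposed on AttentionLayers, hyper-connections are opted into simply by asking for more than one stream (gate_residual is asserted off in that case). A hedged usage sketch, assuming Decoder forwards its kwargs to AttentionLayers; the sizes are arbitrary:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 8,
        heads = 8,
        num_residual_streams = 4  # > 1 turns on hyper-connections
    )
)

x = torch.randint(0, 20000, (2, 1024))
logits = model(x)  # (2, 1024, 20000)
```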
@@ -1872,9 +1950,14 @@ class AttentionLayers(Module):
             if exists(post_branch_fn):
                 layer = post_branch_fn(layer)
 
-            residual_fn = GRUGating if gate_residual else Residual
+            if num_residual_streams > 1:
+                residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+            elif gate_residual:
+                residual_fn = GRUGating
+            else:
+                residual_fn = Residual
 
-            residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
+            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
 
             # handle unet skip connection
 
@@ -2024,6 +2107,16 @@ class AttentionLayers(Module):
 
         iter_attn_cache = iter(attn_cache)
 
+        # setup multistreams if needed
+
+        streams = self.num_residual_streams
+        is_multistream = streams > 1
+
+        if is_multistream:
+            x = repeat(x, 'b n d -> b n s d', s = streams)
+            x = x + self.stream_emb
+            x = rearrange(x, 'b n s d -> (b s) n d')
+
         # outer residual - for resiDual paper
 
         outer_residual = x * self.resi_dual_scale
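The forward pass keeps the extra streams out of the attention and feedforward code entirely by folding them into the batch axis here, and summing them back to a single hidden state at the end (the reduce in the final hunk below). A toy einops sketch of that bookkeeping; names and sizes are illustrative, not taken from the library:

```python
import torch
from einops import repeat, rearrange, reduce

batch, seq, dim, streams = 2, 16, 8, 4

x = torch.randn(batch, seq, dim)
stream_emb = torch.randn(streams, dim) * 1e-2  # illustrative per-stream embedding

# expand the hidden state into `streams` copies, tag each with its embedding,
# then fold the stream axis into the batch so downstream blocks are untouched
x = repeat(x, 'b n d -> b n s d', s = streams)
x = x + stream_emb
x = rearrange(x, 'b n s d -> (b s) n d')   # shape: (batch * streams, seq, dim)

# ... attention / feedforward layers operate on the folded tensor ...

# collapse the streams back into one hidden state (sum, matching the diff)
out = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
assert out.shape == (batch, seq, dim)
```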
@@ -2090,7 +2183,7 @@ class AttentionLayers(Module):
                 if self.training and self.cross_attn_tokens_dropout > 0.:
                     context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
 
-            inner_residual = x
+            x, inner_residual, residual_kwargs = residual_fn.prepare(x)
 
             if return_hiddens:
                 layer_hiddens.append(x)
@@ -2148,7 +2241,7 @@ class AttentionLayers(Module):
             if exists(post_branch_norm):
                 out = post_branch_norm(out)
 
-            x = residual_fn(out, inner_residual)
+            x = residual_fn(out, inner_residual, **residual_kwargs)
 
             if layer_type in ('a', 'c') and return_hiddens:
                 inter.layer_type = layer_type
@@ -2178,6 +2271,11 @@ class AttentionLayers(Module):
         else:
             x = final_norm(x)
 
+        # take care of multistreams if needed, use sum for now
+
+        if is_multistream:
+            x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
+
         if not return_hiddens:
             return x
 
x_transformers-1.42.27.dist-info/METADATA → x_transformers-1.43.0.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.27
+Version: 1.43.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.42.27.dist-info/RECORD → x_transformers-1.43.0.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=mLAqXQuZynqueJDkTEBs-kE9Uk8mSq_DF8UG9oY65Ns,96695
+x_transformers/x_transformers.py,sha256=wAY0lqZvFlXk-fmpr4Ot6yZ6ivzEjetFXTin7z7eA88,100075
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.27.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.27.dist-info/METADATA,sha256=g6KI8a3WyHUyq9w5Tq3aQatgr89gpc5IMZ5c1zAGlHU,739
-x_transformers-1.42.27.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-x_transformers-1.42.27.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.27.dist-info/RECORD,,
+x_transformers-1.43.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.43.0.dist-info/METADATA,sha256=C6eRstMfzmbxQUxNeKnt1Mf-e9pJ45GKNJ8hsc_3uwo,738
+x_transformers-1.43.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.43.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.43.0.dist-info/RECORD,,