x-transformers 1.42.27__tar.gz → 1.43.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (22)
  1. {x_transformers-1.42.27/x_transformers.egg-info → x_transformers-1.43.0}/PKG-INFO +1 -1
  2. {x_transformers-1.42.27 → x_transformers-1.43.0}/README.md +11 -0
  3. {x_transformers-1.42.27 → x_transformers-1.43.0}/setup.py +1 -1
  4. {x_transformers-1.42.27 → x_transformers-1.43.0}/tests/test_x_transformers.py +16 -0
  5. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/x_transformers.py +107 -9
  6. {x_transformers-1.42.27 → x_transformers-1.43.0/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.42.27 → x_transformers-1.43.0}/LICENSE +0 -0
  8. {x_transformers-1.42.27 → x_transformers-1.43.0}/setup.cfg +0 -0
  9. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.42.27 → x_transformers-1.43.0}/x_transformers.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.42.27
3
+ Version: 1.43.0
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -2363,4 +2363,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
2363
2363
  }
2364
2364
  ```
2365
2365
 
2366
+ ```bibtex
2367
+ @article{Zhu2024HyperConnections,
2368
+ title = {Hyper-Connections},
2369
+ author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
2370
+ journal = {ArXiv},
2371
+ year = {2024},
2372
+ volume = {abs/2409.19606},
2373
+ url = {https://api.semanticscholar.org/CorpusID:272987528}
2374
+ }
2375
+ ```
2376
+
2366
2377
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
3
3
  setup(
4
4
  name = 'x-transformers',
5
5
  packages = find_packages(exclude=['examples']),
6
- version = '1.42.27',
6
+ version = '1.43.0',
7
7
  license='MIT',
8
8
  description = 'X-Transformers - Pytorch',
9
9
  author = 'Phil Wang',
@@ -590,3 +590,19 @@ def test_cross_attn_rotary(
590
590
  context_pos = context_pos,
591
591
  context_mask = context_mask
592
592
  )
593
+
594
+ def test_hyper_connections():
+ """Smoke test: a Decoder with 8 hyper-connection residual streams completes a forward pass on a (2, 1024) token batch."""
595
+ model = TransformerWrapper(
596
+ num_tokens = 20000,
597
+ max_seq_len = 1024,
598
+ attn_layers = Decoder(
599
+ dim = 128,
600
+ depth = 6,
601
+ heads = 8,
602
+ num_residual_streams = 8 # 8 dynamic hyper connection residual streams
603
+ )
604
+ )
605
+
606
+ x = torch.randint(0, 20000, (2, 1024))
607
+
608
+ # NOTE(review): this only checks that the forward pass runs; consider also asserting the output batch size is 2,
+ # i.e. that the (b s) streams were reduced back to b at the end of AttentionLayers.forward.
+ model(x)
@@ -824,12 +824,15 @@ class SimpleRMSNorm(Module):
824
824
  # residual and residual gates
825
825
 
826
826
  class Residual(Module):
827
- def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
827
+ def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
828
828
  super().__init__()
829
829
  self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
830
830
  self.scale_residual_constant = scale_residual_constant
831
831
 
832
- def forward(self, x, residual):
832
+ def prepare(self, residual):
833
+ return residual, residual, dict()
834
+
835
+ def forward(self, x, residual, **kwargs):
833
836
  if exists(self.residual_scale):
834
837
  residual = residual * self.residual_scale
835
838
 
@@ -844,7 +847,10 @@ class GRUGating(Module):
844
847
  self.gru = nn.GRUCell(dim, dim)
845
848
  self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
846
849
 
847
- def forward(self, x, residual):
850
+ def prepare(self, residual):
851
+ return residual, residual, dict()
852
+
853
+ def forward(self, x, residual, **kwargs):
848
854
  if exists(self.residual_scale):
849
855
  residual = residual * self.residual_scale
850
856
 
@@ -855,6 +861,66 @@ class GRUGating(Module):
855
861
 
856
862
  return gated_output.reshape_as(x)
857
863
 
864
+ # hyper connections
865
+
866
+ class HyperConnection(Module):
867
+ def __init__(
868
+ self,
869
+ dim,
870
+ *,
871
+ layer_index,
872
+ num_residual_streams,
873
+ **kwargs
874
+ ):
875
+ """
876
+ https://arxiv.org/abs/2409.19606
877
+ Appendix J - Algorithm 2, Dynamic only
+
+ Residual wrapper that keeps num_residual_streams parallel copies of the residual.
+ Between layers the copies are stacked into the batch dimension as ((b s), n, d).
+ prepare() mixes the streams into a single branch input plus updated streams
+ (the width connection). forward() adds the branch output back into every
+ stream, scaled per stream by beta.
+
+ dim: feature size d of each stream (used by the LayerNorm and the dynamic projections).
+ layer_index: position of this layer in the stack. At init, stream
+ (layer_index % num_residual_streams) alone supplies the branch input.
+ num_residual_streams: number of parallel residual streams s.
+ **kwargs: unused. Absorbs residual options (e.g. scale_residual,
+ scale_residual_constant) that the caller passes to every residual type.
878
+ """
879
+ super().__init__()
880
+
+ # NOTE(review): the bias argument of nn.LayerNorm needs a fairly recent PyTorch (2.1+ per its docs) - confirm against the supported torch range.
881
+ self.norm = nn.LayerNorm(dim, bias = False)
882
+
883
+ self.num_residual_streams = num_residual_streams
884
+ self.layer_index = layer_index
885
+
886
+ self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
887
+
888
+ init_alpha0 = torch.zeros((num_residual_streams, 1))
889
+ init_alpha0[layer_index % num_residual_streams, 0] = 1.
890
+
+ # static_alpha is (s, s + 1). Column 0 builds the branch input (one-hot on a single stream at init).
+ # Columns 1..s rebuild the s streams and start as the identity, so the streams pass through unchanged.
891
+ self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
892
+
+ # the dynamic projections start at zero, so tanh(...) is 0 and alpha / beta initially equal their static values
893
+ self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
894
+ self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
895
+ self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
896
+ self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
897
+
+ # residuals: ((b s), n, d). Returns the branch input (b, n, d), the mixed streams (b, n, s, d),
+ # and dict(beta = (b, n, s)), which the caller passes back into forward().
898
+ def prepare(self, residuals):
899
+
900
+ residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
901
+
902
+ normed = self.norm(residuals)
903
+
+ # (b, n, s, d) @ (d, s + 1) -> (b, n, s, s + 1); the static (s, s + 1) term broadcasts over b and n
904
+ wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
905
+ dynamic_alpha = wc_weight * self.dynamic_alpha_scale
906
+ alpha = dynamic_alpha + self.static_alpha
907
+
+ # (b, n, s, d) @ (d,) -> (b, n, s): one output weight per stream and token
908
+ dc_weight = (normed @ self.dynamic_beta_fn).tanh()
909
+ dynamic_beta = dc_weight * self.dynamic_beta_scale
910
+ beta = dynamic_beta + self.static_beta
911
+
912
+ # width connection
913
+
914
+ mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
915
+
+ # slot 0 feeds the layer; slots 1..s are the updated residual streams
916
+ branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
917
+
918
+ return branch_input, residuals, dict(beta = beta)
919
+
+ # x: branch output (b, n, d); residuals and beta as returned by prepare(). Returns ((b s), n, d) for the next layer.
920
+ def forward(self, x, residuals, *, beta):
921
+ residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
922
+ return rearrange(residuals, 'b n s d -> (b s) n d')
923
+
858
924
  # token shifting
859
925
 
860
926
  def shift(t, amount, mask = None):
@@ -1077,7 +1143,7 @@ class Attention(Module):
1077
1143
  logit_softclamp_value = 50.,
1078
1144
  neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
1079
1145
  neutreno_alpha = 0.4,
1080
- learned_value_residual_mix = True,
1146
+ learned_value_residual_mix = False,
1081
1147
  laser = False, # https://arxiv.org/abs/2411.03493v1
1082
1148
  laser_softclamp_value = 15.,
1083
1149
  onnxable = False,
@@ -1582,9 +1648,10 @@ class AttentionLayers(Module):
1582
1648
  use_layerscale = False,
1583
1649
  layerscale_init_value = 0.,
1584
1650
  unet_skips = False,
1651
+ num_residual_streams = 1,
1585
1652
  reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
1586
1653
  add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
1587
- learned_value_residual_mix = False, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
1654
+ learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
1588
1655
  rel_pos_kwargs: dict = dict(),
1589
1656
  **kwargs
1590
1657
  ):
@@ -1607,6 +1674,17 @@ class AttentionLayers(Module):
1607
1674
  self.causal = causal
1608
1675
  self.layers = ModuleList([])
1609
1676
 
1677
+ # greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
1678
+
1679
+ assert num_residual_streams > 0
1680
+
1681
+ self.num_residual_streams = num_residual_streams
1682
+ self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
1683
+
1684
+ assert not (num_residual_streams > 1 and gate_residual)
1685
+
1686
+ # positions related
1687
+
1610
1688
  self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
1611
1689
 
1612
1690
  rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
@@ -1872,9 +1950,14 @@ class AttentionLayers(Module):
1872
1950
  if exists(post_branch_fn):
1873
1951
  layer = post_branch_fn(layer)
1874
1952
 
1875
- residual_fn = GRUGating if gate_residual else Residual
1953
+ if num_residual_streams > 1:
1954
+ residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
1955
+ elif gate_residual:
1956
+ residual_fn = GRUGating
1957
+ else:
1958
+ residual_fn = Residual
1876
1959
 
1877
- residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1960
+ residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1878
1961
 
1879
1962
  # handle unet skip connection
1880
1963
 
@@ -2024,6 +2107,16 @@ class AttentionLayers(Module):
2024
2107
 
2025
2108
  iter_attn_cache = iter(attn_cache)
2026
2109
 
2110
+ # setup multistreams if needed
2111
+
2112
+ streams = self.num_residual_streams
2113
+ is_multistream = streams > 1
2114
+
2115
+ if is_multistream:
2116
+ x = repeat(x, 'b n d -> b n s d', s = streams)
2117
+ x = x + self.stream_emb
2118
+ x = rearrange(x, 'b n s d -> (b s) n d')
2119
+
2027
2120
  # outer residual - for resiDual paper
2028
2121
 
2029
2122
  outer_residual = x * self.resi_dual_scale
@@ -2090,7 +2183,7 @@ class AttentionLayers(Module):
2090
2183
  if self.training and self.cross_attn_tokens_dropout > 0.:
2091
2184
  context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
2092
2185
 
2093
- inner_residual = x
2186
+ x, inner_residual, residual_kwargs = residual_fn.prepare(x)
2094
2187
 
2095
2188
  if return_hiddens:
2096
2189
  layer_hiddens.append(x)
@@ -2148,7 +2241,7 @@ class AttentionLayers(Module):
2148
2241
  if exists(post_branch_norm):
2149
2242
  out = post_branch_norm(out)
2150
2243
 
2151
- x = residual_fn(out, inner_residual)
2244
+ x = residual_fn(out, inner_residual, **residual_kwargs)
2152
2245
 
2153
2246
  if layer_type in ('a', 'c') and return_hiddens:
2154
2247
  inter.layer_type = layer_type
@@ -2178,6 +2271,11 @@ class AttentionLayers(Module):
2178
2271
  else:
2179
2272
  x = final_norm(x)
2180
2273
 
2274
+ # take care of multistreams if needed, use sum for now
2275
+
2276
+ if is_multistream:
2277
+ x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
2278
+
2181
2279
  if not return_hiddens:
2182
2280
  return x
2183
2281
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.42.27
3
+ Version: 1.43.0
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang