x-transformers 1.42.28.tar.gz → 1.43.1.tar.gz

Files changed (22)
  1. {x_transformers-1.42.28/x_transformers.egg-info → x_transformers-1.43.1}/PKG-INFO +1 -1
  2. {x_transformers-1.42.28 → x_transformers-1.43.1}/README.md +12 -1
  3. {x_transformers-1.42.28 → x_transformers-1.43.1}/setup.py +1 -1
  4. {x_transformers-1.42.28 → x_transformers-1.43.1}/tests/test_x_transformers.py +21 -0
  5. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/x_transformers.py +110 -8
  6. {x_transformers-1.42.28 → x_transformers-1.43.1/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.42.28 → x_transformers-1.43.1}/LICENSE +0 -0
  8. {x_transformers-1.42.28 → x_transformers-1.43.1}/setup.cfg +0 -0
  9. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.28/x_transformers.egg-info → x_transformers-1.43.1}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.42.28
+ Version: 1.43.1
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x_transformers-1.42.28 → x_transformers-1.43.1}/README.md

@@ -2240,7 +2240,7 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

- ```
+ ```bibtex
  @article{Yang2017BreakingTS,
      title   = {Breaking the Softmax Bottleneck: A High-Rank RNN Language Model},
      author  = {Zhilin Yang and Zihang Dai and Ruslan Salakhutdinov and William W. Cohen},

@@ -2363,4 +2363,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @article{Zhu2024HyperConnections,
+     title   = {Hyper-Connections},
+     author  = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+     journal = {ArXiv},
+     year    = {2024},
+     volume  = {abs/2409.19606},
+     url     = {https://api.semanticscholar.org/CorpusID:272987528}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
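The headline change in this release is support for hyper connections (multiple residual streams). A minimal usage sketch, mirroring the new `test_hyper_connections` test added further down in this diff (not a verbatim upstream README example):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        num_residual_streams = 8,       # > 1 switches the residual to dynamic hyper connections
        residual_fn_kwargs = dict(
            tanh = True                 # tanh on the dynamic alpha / beta weights
        )
    )
)

x = torch.randint(0, 20000, (2, 1024))
logits = model(x)                       # (2, 1024, 20000)
```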
{x_transformers-1.42.28 → x_transformers-1.43.1}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.42.28',
+   version = '1.43.1',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
{x_transformers-1.42.28 → x_transformers-1.43.1}/tests/test_x_transformers.py

@@ -590,3 +590,24 @@ def test_cross_attn_rotary(
          context_pos = context_pos,
          context_mask = context_mask
      )
+
+ @pytest.mark.parametrize('tanh', (True, False))
+ def test_hyper_connections(tanh):
+
+     model = TransformerWrapper(
+         num_tokens = 20000,
+         max_seq_len = 1024,
+         attn_layers = Decoder(
+             dim = 128,
+             depth = 6,
+             heads = 8,
+             num_residual_streams = 8, # 8 dynamic hyper connection residual streams
+             residual_fn_kwargs = dict(
+                 tanh = tanh
+             )
+         )
+     )
+
+     x = torch.randint(0, 20000, (2, 1024))
+
+     model(x)
{x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/x_transformers.py

@@ -824,12 +824,15 @@ class SimpleRMSNorm(Module):
  # residual and residual gates

  class Residual(Module):
-     def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
+     def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
          super().__init__()
          self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
          self.scale_residual_constant = scale_residual_constant

-     def forward(self, x, residual):
+     def prepare(self, residual):
+         return residual, residual, dict()
+
+     def forward(self, x, residual, **kwargs):
          if exists(self.residual_scale):
              residual = residual * self.residual_scale

@@ -844,7 +847,10 @@ class GRUGating(Module):
          self.gru = nn.GRUCell(dim, dim)
          self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

-     def forward(self, x, residual):
+     def prepare(self, residual):
+         return residual, residual, dict()
+
+     def forward(self, x, residual, **kwargs):
          if exists(self.residual_scale):
              residual = residual * self.residual_scale

@@ -855,6 +861,69 @@ class GRUGating(Module):

          return gated_output.reshape_as(x)

+ # hyper connections
+
+ class HyperConnection(Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         layer_index,
+         num_residual_streams,
+         tanh = True,
+         **kwargs
+     ):
+         """
+         https://arxiv.org/abs/2409.19606
+         Appendix J - Algorithm 2, Dynamic only
+         """
+         super().__init__()
+
+         self.act = nn.Tanh() if tanh else nn.Identity()
+
+         self.norm = nn.LayerNorm(dim, bias = False)
+
+         self.num_residual_streams = num_residual_streams
+         self.layer_index = layer_index
+
+         self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+         init_alpha0 = torch.zeros((num_residual_streams, 1))
+         init_alpha0[layer_index % num_residual_streams, 0] = 1.
+
+         self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+         self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
+         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+     def prepare(self, residuals):
+
+         residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
+
+         normed = self.norm(residuals)
+
+         wc_weight = self.act(normed @ self.dynamic_alpha_fn)
+         dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+         alpha = dynamic_alpha + self.static_alpha
+
+         dc_weight = self.act(normed @ self.dynamic_beta_fn)
+         dynamic_beta = dc_weight * self.dynamic_beta_scale
+         beta = dynamic_beta + self.static_beta
+
+         # width connection
+
+         mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
+
+         branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+
+         return branch_input, residuals, dict(beta = beta)
+
+     def forward(self, x, residuals, *, beta):
+         residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
+         return rearrange(residuals, 'b n s d -> (b s) n d')
+
  # token shifting

  def shift(t, amount, mask = None):
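For orientation, here is a minimal standalone sketch of the shape flow behind `prepare` and `forward` above (illustrative names only; the random `alpha` and `beta` stand in for the module's static + dynamic weights): the `s` residual streams travel folded into the batch, `prepare` mixes them into one branch input plus `s` updated residuals via `alpha` (width connection), and `forward` writes the branch output back into every stream weighted by `beta` (depth connection).

```python
import torch
from torch import einsum
from einops import rearrange, repeat

b, n, d, s = 2, 16, 8, 4                                                # batch, seq len, dim, residual streams

residuals = repeat(torch.randn(b, n, d), 'b n d -> (b s) n d', s = s)   # streams folded into the batch

# prepare: unfold the streams, mix them into one branch input + s residuals (width connection)
res = rearrange(residuals, '(b s) n d -> b n s d', s = s)
alpha = torch.randn(b, n, s, s + 1)                                     # stand-in for static_alpha + dynamic alpha
mix = einsum('b n s t, b n s d -> b n t d', alpha, res)
branch_input, res = mix[..., 0, :], mix[..., 1:, :]
assert branch_input.shape == (b, n, d) and res.shape == (b, n, s, d)

# forward: depth connection - add the branch output back to every stream, weighted by beta
branch_output = torch.randn(b, n, d)                                    # what the attention / feedforward block returns
beta = torch.randn(b, n, s)                                             # stand-in for static_beta + dynamic beta
res = einsum('b n d, b n s -> b n s d', branch_output, beta) + res
residuals = rearrange(res, 'b n s d -> (b s) n d')                      # ready for the next layer
assert residuals.shape == (b * s, n, d)
```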
@@ -1582,10 +1651,12 @@ class AttentionLayers(Module):
          use_layerscale = False,
          layerscale_init_value = 0.,
          unet_skips = False,
+         num_residual_streams = 1,
          reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroborated by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may become a necessity for every transformer soon
          learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
          rel_pos_kwargs: dict = dict(),
+         residual_fn_kwargs: dict = dict(),
          **kwargs
      ):
          super().__init__()
@@ -1607,6 +1678,17 @@ class AttentionLayers(Module):
          self.causal = causal
          self.layers = ModuleList([])

+         # greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+
+         assert num_residual_streams > 0
+
+         self.num_residual_streams = num_residual_streams
+         self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
+
+         assert not (num_residual_streams > 1 and gate_residual)
+
+         # positions related
+
          self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))

          rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
@@ -1872,9 +1954,14 @@ class AttentionLayers(Module):
              if exists(post_branch_fn):
                  layer = post_branch_fn(layer)

-             residual_fn = GRUGating if gate_residual else Residual
+             if num_residual_streams > 1:
+                 residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+             elif gate_residual:
+                 residual_fn = GRUGating
+             else:
+                 residual_fn = Residual

-             residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
+             residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant, **residual_fn_kwargs)

              # handle unet skip connection

@@ -2024,6 +2111,16 @@ class AttentionLayers(Module):

          iter_attn_cache = iter(attn_cache)

+         # setup multistreams if needed
+
+         streams = self.num_residual_streams
+         is_multistream = streams > 1
+
+         if is_multistream:
+             x = repeat(x, 'b n d -> b n s d', s = streams)
+             x = x + self.stream_emb
+             x = rearrange(x, 'b n s d -> (b s) n d')
+
          # outer residual - for resiDual paper

          outer_residual = x * self.resi_dual_scale
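A small sketch (assumed toy dimensions, names mirroring the diff) of how the streams enter the stack here and are collapsed again at the end of `forward`: each stream receives its own learned bias from `stream_emb`, the streams are folded into the batch so every layer runs unchanged, and after the final norm they are summed back into a single representation.

```python
import torch
from einops import rearrange, reduce, repeat

b, n, d, s = 2, 16, 8, 4                    # batch, seq len, dim, residual streams

x = torch.randn(b, n, d)
stream_emb = torch.zeros(s, d)              # per-stream bias, as in the new stream_emb parameter

# entry: copy the input into s streams, bias each stream, fold the streams into the batch
x = repeat(x, 'b n d -> b n s d', s = s)
x = x + stream_emb                          # broadcasts (s, d) over (b, n, s, d)
x = rearrange(x, 'b n s d -> (b s) n d')
assert x.shape == (b * s, n, d)

# exit: after the final norm, sum the streams back down to one representation
out = reduce(x, '(b s) n d -> b n d', 'sum', s = s)
assert out.shape == (b, n, d)
```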
@@ -2090,7 +2187,7 @@ class AttentionLayers(Module):
                if self.training and self.cross_attn_tokens_dropout > 0.:
                    context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)

-            inner_residual = x
+            x, inner_residual, residual_kwargs = residual_fn.prepare(x)

            if return_hiddens:
                layer_hiddens.append(x)
@@ -2148,7 +2245,7 @@ class AttentionLayers(Module):
            if exists(post_branch_norm):
                out = post_branch_norm(out)

-            x = residual_fn(out, inner_residual)
+            x = residual_fn(out, inner_residual, **residual_kwargs)

            if layer_type in ('a', 'c') and return_hiddens:
                inter.layer_type = layer_type
@@ -2178,6 +2275,11 @@ class AttentionLayers(Module):
        else:
            x = final_norm(x)

+        # take care of multistreams if needed, use sum for now
+
+        if is_multistream:
+            x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
+
        if not return_hiddens:
            return x

{x_transformers-1.42.28 → x_transformers-1.43.1/x_transformers.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.42.28
+ Version: 1.43.1
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang