x-transformers 1.43.0.tar.gz → 1.43.2.tar.gz

Files changed (22)
  1. {x_transformers-1.43.0/x_transformers.egg-info → x_transformers-1.43.2}/PKG-INFO +1 -1
  2. {x_transformers-1.43.0 → x_transformers-1.43.2}/README.md +1 -1
  3. {x_transformers-1.43.0 → x_transformers-1.43.2}/setup.py +1 -1
  4. {x_transformers-1.43.0 → x_transformers-1.43.2}/tests/test_x_transformers.py +11 -4
  5. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/x_transformers.py +9 -5
  6. {x_transformers-1.43.0 → x_transformers-1.43.2/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.43.0 → x_transformers-1.43.2}/LICENSE +0 -0
  8. {x_transformers-1.43.0 → x_transformers-1.43.2}/setup.cfg +0 -0
  9. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.43.0 → x_transformers-1.43.2}/x_transformers.egg-info/top_level.txt +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.43.0
+Version: 1.43.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
README.md

@@ -2240,7 +2240,7 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
-```
+```bibtex
 @article{Yang2017BreakingTS,
     title = {Breaking the Softmax Bottleneck: A High-Rank RNN Language Model},
     author = {Zhilin Yang and Zihang Dai and Ruslan Salakhutdinov and William W. Cohen},
setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.43.0',
+  version = '1.43.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
tests/test_x_transformers.py

@@ -409,7 +409,8 @@ def test_custom_alibi(flash: bool):
 
     logits = model(x, pos = pos)
 
-def test_custom_rotary_pos_emb():
+@pytest.mark.parametrize('rotary_xpos', (True, False))
+def test_custom_rotary_pos_emb(rotary_xpos):
     from einops import repeat
 
     model = TransformerWrapper(
@@ -419,7 +420,8 @@ def test_custom_rotary_pos_emb():
             dim = 512,
             depth = 2,
             heads = 8,
-            rotary_pos_emb = True
+            rotary_pos_emb = True,
+            rotary_xpos = rotary_xpos
         )
     )
 
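The new parametrization exercises rotary embeddings both with and without xpos length-extrapolation scaling, driven by explicitly supplied positions. Below is a minimal usage sketch of the same configuration outside the test suite; the hyperparameters and tensor shapes are illustrative, not prescribed by the diff.

```python
import torch
from einops import repeat
from x_transformers import TransformerWrapper, Decoder

# Sketch of the configuration the parametrized test covers: rotary embeddings
# with the optional xpos scaling, fed custom per-batch positions.
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_pos_emb = True,
        rotary_xpos = True   # the test runs with this both True and False
    )
)

x = torch.randint(0, 20000, (2, 256))
pos = repeat(torch.arange(256), 'n -> b n', b = 2)  # custom positions, one row per batch element

logits = model(x, pos = pos)  # (2, 256, 20000)
```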
@@ -591,7 +593,9 @@ def test_cross_attn_rotary(
        context_mask = context_mask
    )
 
-def test_hyper_connections():
+@pytest.mark.parametrize('tanh', (True, False))
+def test_hyper_connections(tanh):
+
    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
@@ -599,7 +603,10 @@ def test_hyper_connections():
            dim = 128,
            depth = 6,
            heads = 8,
-            num_residual_streams = 8 # 8 dynamic hyper connection residual streams
+            num_residual_streams = 8, # 8 dynamic hyper connection residual streams
+            residual_fn_kwargs = dict(
+                tanh = tanh
+            )
        )
    )
 
x_transformers/x_transformers.py

@@ -666,7 +666,7 @@ class RotaryEmbedding(Module):
            return freqs, 1.
 
        power = (t - (max_pos // 2)) / self.scale_base
-        scale = self.scale ** rearrange(power, 'n -> n 1')
+        scale = self.scale ** rearrange(power, '... n -> ... n 1')
        scale = torch.stack((scale, scale), dim = -1)
        scale = rearrange(scale, '... d r -> ... (d r)')
 
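The only change here is the rearrange pattern. A small standalone sketch (tensors made up for illustration) of why the leading `...` matters: the old pattern accepts only a 1-D position tensor, while the new one also handles batched positions such as the per-batch `pos` used in the parametrized rotary test above.

```python
import torch
from einops import rearrange

power = torch.randn(1024)             # 1-D positions: both patterns work
batched_power = torch.randn(2, 1024)  # batched positions, e.g. custom per-sample `pos`

rearrange(power, 'n -> n 1').shape                   # torch.Size([1024, 1])
rearrange(batched_power, '... n -> ... n 1').shape   # torch.Size([2, 1024, 1])

# rearrange(batched_power, 'n -> n 1') would raise an einops error,
# since the old pattern expects exactly one axis.
```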
@@ -870,6 +870,7 @@ class HyperConnection(Module):
        *,
        layer_index,
        num_residual_streams,
+        tanh = True,
        **kwargs
    ):
        """
@@ -878,6 +879,8 @@ class HyperConnection(Module):
        """
        super().__init__()
 
+        self.act = nn.Tanh() if tanh else nn.Identity()
+
        self.norm = nn.LayerNorm(dim, bias = False)
 
        self.num_residual_streams = num_residual_streams
@@ -901,11 +904,11 @@ class HyperConnection(Module):
 
        normed = self.norm(residuals)
 
-        wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
+        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
        alpha = dynamic_alpha + self.static_alpha
 
-        dc_weight = (normed @ self.dynamic_beta_fn).tanh()
+        dc_weight = self.act(normed @ self.dynamic_beta_fn)
        dynamic_beta = dc_weight * self.dynamic_beta_scale
        beta = dynamic_beta + self.static_beta
 
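Taken together, these three hunks add a `tanh` switch to `HyperConnection`: the dynamic connection weights were previously always squashed through `.tanh()`, and can now optionally be left unbounded. A toy sketch of the toggle follows; the tensor and projection shapes are made up for illustration and are not the module's actual parameter shapes.

```python
import torch
from torch import nn

tanh = False
act = nn.Tanh() if tanh else nn.Identity()   # the new switch, defaulting to tanh = True upstream

normed = torch.randn(2, 1024, 128)      # pretend normalized residual streams
dynamic_alpha_fn = torch.randn(128, 5)  # pretend projection to stream-mixing weights

wc_weight = act(normed @ dynamic_alpha_fn)
# with tanh = True the weights are bounded to (-1, 1); with tanh = False they pass through unchanged
```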
@@ -1650,9 +1653,10 @@ class AttentionLayers(Module):
        unet_skips = False,
        num_residual_streams = 1,
        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may become a necessity for every transformer soon
        learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
        rel_pos_kwargs: dict = dict(),
+        residual_fn_kwargs: dict = dict(),
        **kwargs
    ):
        super().__init__()
@@ -1957,7 +1961,7 @@ class AttentionLayers(Module):
            else:
                residual_fn = Residual
 
-            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
+            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant, **residual_fn_kwargs)
 
            # handle unet skip connection
 
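With this passthrough, anything placed in `residual_fn_kwargs` is forwarded to the residual module constructor, which is how the new `tanh` flag reaches `HyperConnection` when hyper connections are enabled. A minimal end-to-end sketch mirroring the updated test (hyperparameters are illustrative):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        num_residual_streams = 8,                 # enables the HyperConnection residual streams
        residual_fn_kwargs = dict(tanh = False)   # forwarded verbatim to HyperConnection
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)  # (1, 256, 20000)
```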
x_transformers.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.43.0
+Version: 1.43.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
The remaining 16 files are unchanged.