x-transformers 2.1.29.tar.gz → 2.1.31.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {x_transformers-2.1.29 → x_transformers-2.1.31}/PKG-INFO +10 -1
  2. {x_transformers-2.1.29 → x_transformers-2.1.31}/README.md +9 -0
  3. {x_transformers-2.1.29 → x_transformers-2.1.31}/pyproject.toml +1 -1
  4. {x_transformers-2.1.29 → x_transformers-2.1.31}/tests/test_x_transformers.py +26 -2
  5. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/x_transformers.py +32 -3
  6. {x_transformers-2.1.29 → x_transformers-2.1.31}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.1.29 → x_transformers-2.1.31}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.1.29 → x_transformers-2.1.31}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.1.29 → x_transformers-2.1.31}/.gitignore +0 -0
  10. {x_transformers-2.1.29 → x_transformers-2.1.31}/LICENSE +0 -0
  11. {x_transformers-2.1.29 → x_transformers-2.1.31}/data/README.md +0 -0
  12. {x_transformers-2.1.29 → x_transformers-2.1.31}/data/enwik8.gz +0 -0
  13. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/all-attention.png +0 -0
  14. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/deepnorm.png +0 -0
  17. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/fcm.png +0 -0
  23. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/ffglu.png +0 -0
  24. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/flash-attention.png +0 -0
  25. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/gate_values.png +0 -0
  26. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/gating.png +0 -0
  27. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/macaron-1.png +0 -0
  29. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/macaron-2.png +0 -0
  30. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/normformer.png +0 -0
  32. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/pia.png +0 -0
  33. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/resi_dual.png +0 -0
  35. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/residual_attn.png +0 -0
  36. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/rezero.png +0 -0
  37. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/rotary.png +0 -0
  38. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/sandwich.png +0 -0
  40. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/scalenorm.png +0 -0
  42. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/talking-heads.png +0 -0
  43. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/topk-attention.png +0 -0
  44. {x_transformers-2.1.29 → x_transformers-2.1.31}/images/xval.png +0 -0
  45. {x_transformers-2.1.29 → x_transformers-2.1.31}/train_belief_state.py +0 -0
  46. {x_transformers-2.1.29 → x_transformers-2.1.31}/train_copy.py +0 -0
  47. {x_transformers-2.1.29 → x_transformers-2.1.31}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.29 → x_transformers-2.1.31}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.29 → x_transformers-2.1.31}/train_parity.py +0 -0
  50. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/__init__.py +0 -0
  51. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/attend.py +0 -0
  52. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/autoregressive_wrapper.py +0 -0
  53. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/belief_state_wrapper.py +0 -0
  54. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/continuous.py +0 -0
  55. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/dpo.py +0 -0
  56. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/multi_input.py +0 -0
  57. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/neo_mlp.py +0 -0
  58. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/nonautoregressive_wrapper.py +0 -0
  59. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.29 → x_transformers-2.1.31}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.29
+Version: 2.1.31
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2455,4 +2455,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Zhu2025TransformersWN,
+    title   = {Transformers without Normalization},
+    author  = {Jiachen Zhu and Xinlei Chen and Kaiming He and Yann LeCun and Zhuang Liu},
+    year    = {2025},
+    url     = {https://api.semanticscholar.org/CorpusID:276961218}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
README.md
@@ -2407,4 +2407,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Zhu2025TransformersWN,
+    title   = {Transformers without Normalization},
+    author  = {Jiachen Zhu and Xinlei Chen and Kaiming He and Yann LeCun and Zhuang Liu},
+    year    = {2025},
+    url     = {https://api.semanticscholar.org/CorpusID:276961218}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
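The newly cited paper introduces Dynamic Tanh (DyT), an elementwise replacement for layer normalization: the input is squashed by tanh after scaling with a learnable scalar alpha, then scaled and shifted per channel. A minimal standalone sketch of that transform (illustrative names only, not the library's implementation, which appears further down in this diff):

```python
import torch

def dyt(x, alpha, gamma, beta):
    # DyT(x) = tanh(alpha * x) * gamma + beta -- no mean / variance statistics needed
    return torch.tanh(alpha * x) * gamma + beta

x = torch.randn(2, 16, 128)
out = dyt(x, alpha = torch.tensor(1.), gamma = torch.ones(128), beta = torch.zeros(128))
assert out.shape == x.shape
```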
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.29"
+version = "2.1.31"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -697,10 +697,12 @@ def test_lime(
 @pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
 @pytest.mark.parametrize('goal_suffix', (False, True))
 @pytest.mark.parametrize('pred_distance', (False, True))
+@pytest.mark.parametrize('variable_len', (False, True))
 def test_belief_state_wrapper(
     backward_ar_loss_weight,
     goal_suffix,
-    pred_distance
+    pred_distance,
+    variable_len
 ):
     from x_transformers.belief_state_wrapper import BeliefStateWrapper
 
@@ -735,7 +737,12 @@ def test_belief_state_wrapper(
 
     seq = torch.randint(0, 20000, (2, 16))
 
-    loss = model(seq) # backwards happen automatically
+    lens = None
+
+    if variable_len:
+        lens = torch.randint(4, 16, (2,))
+
+    loss = model(seq, lens = lens) # backwards happen automatically
 
     suffix = None
     if goal_suffix:
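For context, a per-sequence length tensor like `lens` is conventionally expanded into a boolean key-padding mask. The wrapper's actual handling of `lens` is not part of this diff, so the snippet below is only a plausible sketch of the equivalent masking:

```python
import torch

lens = torch.randint(4, 16, (2,))                      # per-sequence valid lengths, as in the test
seq_len = 16

# hypothetical mask: True where a position falls within the valid length
mask = torch.arange(seq_len)[None, :] < lens[:, None]  # shape (2, 16), dtype torch.bool
```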
@@ -743,3 +750,20 @@ def test_belief_state_wrapper(
 
     sampled = model.generate_with_suffix_cond(seq[:, :1], 16, suffix = suffix)
     assert sampled.shape == (2, 16)
+
+def test_dynamic_tanh():
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            use_dynamic_tanh = True,
+            dynamic_tanh_init_alpha = 1.5
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    model(x)
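For reference, `TransformerWrapper` returns token logits by default, so the `model(x)` call above produces a `(batch, seq_len, num_tokens)` tensor; continuing from the test's `model` and `x`, a quick check might look like:

```python
logits = model(x)
assert logits.shape == (2, 1024, 20000)  # (batch, seq_len, num_tokens)
```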
x_transformers/x_transformers.py
@@ -9,7 +9,7 @@ from packaging import version
 import torch
 from torch.amp import autocast
 import torch.nn.functional as F
-from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
+from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
@@ -266,7 +266,6 @@ class TokenEmbedding(Module):
             return
         nn.init.kaiming_normal_(self.emb.weight)
 
-
 # positional embeddings
 
 class AbsolutePositionalEmbedding(Module):
@@ -849,6 +848,32 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
 
+class DynamicTanh(Module):
+    """ https://arxiv.org/abs/2503.10622 """
+    def __init__(
+        self,
+        dim,
+        init_alpha = 1.,
+        gamma = 1.,
+        beta = 0.,
+        unit_offset = False
+    ):
+        super().__init__()
+        self.pre_tanh_scale = nn.Parameter(tensor(init_alpha))
+
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.beta = nn.Parameter(torch.zeros(dim))
+
+        self.unit_offset = int(unit_offset)
+
+        nn.init.constant_(self.pre_tanh_scale, 1. - float(unit_offset))
+        nn.init.constant_(self.gamma, 1. - float(unit_offset))
+
+    def forward(self, x):
+        pre_tanh_scale = self.pre_tanh_scale + self.unit_offset
+        gamma = self.gamma + self.unit_offset
+        return (x * pre_tanh_scale).tanh() * gamma + self.beta
+
 # residual and residual gates
 
 class Residual(Module):
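A minimal sketch of exercising the new module on its own, assuming it is importable from `x_transformers.x_transformers` once this version is installed:

```python
import torch
from x_transformers.x_transformers import DynamicTanh

dyt = DynamicTanh(dim = 128)
x = torch.randn(2, 16, 128)

out = dyt(x)                 # tanh(pre_tanh_scale * x) * gamma + beta, applied elementwise
assert out.shape == x.shape  # shape-preserving, like the norm layers it can stand in for
```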
@@ -1863,6 +1888,8 @@ class AttentionLayers(Module):
         only_cross = False,
         use_scalenorm = False,
         use_rmsnorm = False,
+        use_dynamic_tanh = False,
+        dynamic_tanh_init_alpha = 1.,
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
@@ -2012,7 +2039,7 @@ class AttentionLayers(Module):
 
         # determine norm
 
-        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
+        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
         norm_need_condition = False
         dim_condition = default(dim_condition, dim)
@@ -2027,6 +2054,8 @@ class AttentionLayers(Module):
             norm_class = RMSNorm
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
+        elif use_dynamic_tanh:
+            norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
 elif use_adaptive_layernorm:
             norm_need_condition = True
             norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
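Like the other norm options, the `partial` above fixes the extra keyword so the layer stack can construct the module from `dim` alone; an illustrative check of that call contract (not code from the library):

```python
from functools import partial

import torch
from x_transformers.x_transformers import DynamicTanh

norm_class = partial(DynamicTanh, init_alpha = 1.5)
norm = norm_class(128)       # constructed with dim alone, matching how the other norm classes are built

x = torch.randn(2, 16, 128)
assert norm(x).shape == x.shape
```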