x-transformers 2.1.28__tar.gz → 2.1.30__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {x_transformers-2.1.28 → x_transformers-2.1.30}/PKG-INFO +10 -1
  2. {x_transformers-2.1.28 → x_transformers-2.1.30}/README.md +9 -0
  3. {x_transformers-2.1.28 → x_transformers-2.1.30}/pyproject.toml +1 -1
  4. {x_transformers-2.1.28 → x_transformers-2.1.30}/tests/test_x_transformers.py +26 -2
  5. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/belief_state_wrapper.py +2 -0
  6. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/x_transformers.py +34 -4
  7. {x_transformers-2.1.28 → x_transformers-2.1.30}/.github/FUNDING.yml +0 -0
  8. {x_transformers-2.1.28 → x_transformers-2.1.30}/.github/workflows/python-publish.yml +0 -0
  9. {x_transformers-2.1.28 → x_transformers-2.1.30}/.github/workflows/python-test.yaml +0 -0
  10. {x_transformers-2.1.28 → x_transformers-2.1.30}/.gitignore +0 -0
  11. {x_transformers-2.1.28 → x_transformers-2.1.30}/LICENSE +0 -0
  12. {x_transformers-2.1.28 → x_transformers-2.1.30}/data/README.md +0 -0
  13. {x_transformers-2.1.28 → x_transformers-2.1.30}/data/enwik8.gz +0 -0
  14. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/all-attention.png +0 -0
  15. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/deepnorm.png +0 -0
  18. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/fcm.png +0 -0
  24. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/ffglu.png +0 -0
  25. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/flash-attention.png +0 -0
  26. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/gate_values.png +0 -0
  27. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/gating.png +0 -0
  28. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/macaron-1.png +0 -0
  30. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/macaron-2.png +0 -0
  31. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/normformer.png +0 -0
  33. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/pia.png +0 -0
  34. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/resi_dual.png +0 -0
  36. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/residual_attn.png +0 -0
  37. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/rezero.png +0 -0
  38. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/rotary.png +0 -0
  39. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/sandwich.png +0 -0
  41. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/scalenorm.png +0 -0
  43. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/talking-heads.png +0 -0
  44. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/topk-attention.png +0 -0
  45. {x_transformers-2.1.28 → x_transformers-2.1.30}/images/xval.png +0 -0
  46. {x_transformers-2.1.28 → x_transformers-2.1.30}/train_belief_state.py +0 -0
  47. {x_transformers-2.1.28 → x_transformers-2.1.30}/train_copy.py +0 -0
  48. {x_transformers-2.1.28 → x_transformers-2.1.30}/train_enwik8.py +0 -0
  49. {x_transformers-2.1.28 → x_transformers-2.1.30}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.1.28 → x_transformers-2.1.30}/train_parity.py +0 -0
  51. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/continuous.py +0 -0
  55. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/dpo.py +0 -0
  56. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/multi_input.py +0 -0
  57. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/neo_mlp.py +0 -0
  58. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/nonautoregressive_wrapper.py +0 -0
  59. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/xval.py +0 -0
{x_transformers-2.1.28 → x_transformers-2.1.30}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.1.28
+ Version: 2.1.30
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2455,4 +2455,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @inproceedings{Zhu2025TransformersWN,
+     title = {Transformers without Normalization},
+     author = {Jiachen Zhu and Xinlei Chen and Kaiming He and Yann LeCun and Zhuang Liu},
+     year = {2025},
+     url = {https://api.semanticscholar.org/CorpusID:276961218}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.1.28 → x_transformers-2.1.30}/README.md

@@ -2407,4 +2407,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @inproceedings{Zhu2025TransformersWN,
+     title = {Transformers without Normalization},
+     author = {Jiachen Zhu and Xinlei Chen and Kaiming He and Yann LeCun and Zhuang Liu},
+     year = {2025},
+     url = {https://api.semanticscholar.org/CorpusID:276961218}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.1.28 → x_transformers-2.1.30}/pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.1.28"
+ version = "2.1.30"
  description = "X-Transformers"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.1.28 → x_transformers-2.1.30}/tests/test_x_transformers.py

@@ -697,10 +697,12 @@ def test_lime(
  @pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
  @pytest.mark.parametrize('goal_suffix', (False, True))
  @pytest.mark.parametrize('pred_distance', (False, True))
+ @pytest.mark.parametrize('variable_len', (False, True))
  def test_belief_state_wrapper(
      backward_ar_loss_weight,
      goal_suffix,
-     pred_distance
+     pred_distance,
+     variable_len
  ):
      from x_transformers.belief_state_wrapper import BeliefStateWrapper

@@ -735,7 +737,12 @@ def test_belief_state_wrapper(

      seq = torch.randint(0, 20000, (2, 16))

-     loss = model(seq) # backwards happen automatically
+     lens = None
+
+     if variable_len:
+         lens = torch.randint(4, 16, (2,))
+
+     loss = model(seq, lens = lens) # backwards happen automatically

      suffix = None
      if goal_suffix:

@@ -743,3 +750,20 @@

      sampled = model.generate_with_suffix_cond(seq[:, :1], 16, suffix = suffix)
      assert sampled.shape == (2, 16)
+
+ def test_dynamic_tanh():
+     model = TransformerWrapper(
+         num_tokens = 20000,
+         max_seq_len = 1024,
+         attn_layers = Decoder(
+             dim = 128,
+             depth = 6,
+             heads = 8,
+             use_dynamic_tanh = True,
+             dynamic_tanh_init_alpha = 1.5
+         )
+     )
+
+     x = torch.randint(0, 20000, (2, 1024))
+
+     model(x)
{x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/belief_state_wrapper.py

@@ -258,6 +258,8 @@ class BeliefStateWrapper(Module):

          # handle variable length sequences

+         seq_for_labels = seq
+
          if exists(lens):
              mask = einx.less('j, i -> i j', arange(seq_len, device = device), lens)
              seq_for_labels = torch.where(mask, seq, -1)
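The hunk above gives `seq_for_labels` a default (the unmasked `seq`), so the masking branch is only taken when `lens` is provided. As a purely illustrative sketch (the wrapper itself uses `einx.less` for the comparison), the variable-length handling amounts to:

```python
# Illustrative sketch of the masking in the hunk above: token positions at or
# beyond each sequence's length are overwritten with -1, marking them so they
# can be excluded from the label/loss computation downstream.
import torch

seq  = torch.randint(0, 20000, (2, 16))
lens = torch.tensor([5, 12])

# mask[i, j] is True while j < lens[i]
mask = torch.arange(seq.shape[-1])[None, :] < lens[:, None]

seq_for_labels = torch.where(mask, seq, torch.full_like(seq, -1))
```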
{x_transformers-2.1.28 → x_transformers-2.1.30}/x_transformers/x_transformers.py

@@ -9,7 +9,7 @@ from packaging import version
  import torch
  from torch.amp import autocast
  import torch.nn.functional as F
- from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
+ from torch import nn, einsum, tensor, Tensor, cat, stack, arange, is_tensor
  from torch.utils._pytree import tree_flatten, tree_unflatten
  from torch.nn import Module, ModuleList, ModuleDict

@@ -266,7 +266,6 @@ class TokenEmbedding(Module):
              return
          nn.init.kaiming_normal_(self.emb.weight)

-
  # positional embeddings

  class AbsolutePositionalEmbedding(Module):

@@ -849,6 +848,31 @@ class MultiheadRMSNorm(Module):
      def forward(self, x):
          return self.rmsnorm(x) * (self.gamma + 1.)

+ class DynamicTanh(Module):
+     """ https://arxiv.org/abs/2503.10622 """
+     def __init__(
+         self,
+         init_alpha = 1.,
+         gamma = 1.,
+         beta = 0.,
+         unit_offset = False
+     ):
+         super().__init__()
+         self.pre_tanh_scale = nn.Parameter(tensor(init_alpha))
+
+         self.gamma = nn.Parameter(tensor(init_alpha))
+         self.beta = nn.Parameter(tensor(init_alpha))
+
+         self.unit_offset = int(unit_offset)
+
+         nn.init.constant_(self.pre_tanh_scale, 1. - float(unit_offset))
+         nn.init.constant_(self.gamma, 1. - float(unit_offset))
+
+     def forward(self, x):
+         pre_tanh_scale = self.pre_tanh_scale + self.unit_offset
+         gamma = self.gamma + self.unit_offset
+         return (x * pre_tanh_scale).tanh() * gamma + self.beta
+
  # residual and residual gates

  class Residual(Module):
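For readability, the forward pass of the new `DynamicTanh` layer can be restated elementwise; the snippet below only re-expresses the code above with example parameter values, it is not additional library code:

```python
# DynamicTanh (DyT, https://arxiv.org/abs/2503.10622), applied elementwise:
#   out = tanh(pre_tanh_scale * x) * gamma + beta
# With unit_offset = True the learned pre_tanh_scale and gamma are stored as 0
# and a constant 1 is added back in forward, so weight decay pulls them toward
# an effective value of 1 rather than 0 (same trick as the norm gammas below).
import torch

x = torch.randn(2, 1024, 128)                  # (batch, seq, dim)
pre_tanh_scale, gamma, beta = 1.5, 1.0, 0.0    # example values only

out = torch.tanh(x * pre_tanh_scale) * gamma + beta
```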
@@ -1863,6 +1887,8 @@
          only_cross = False,
          use_scalenorm = False,
          use_rmsnorm = False,
+         use_dynamic_tanh = False,
+         dynamic_tanh_init_alpha = 1.,
          use_simple_rmsnorm = False,
          use_adaptive_layernorm = False,
          use_adaptive_rmsnorm = False,

@@ -2012,8 +2038,9 @@

          # determine norm

-         assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
+         assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'

+         norm_fn = None
          norm_need_condition = False
          dim_condition = default(dim_condition, dim)
          dim_condition_mult = 1

@@ -2027,6 +2054,8 @@
              norm_class = RMSNorm
          elif use_simple_rmsnorm:
              norm_class = SimpleRMSNorm
+         elif use_dynamic_tanh:
+             norm_fn = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
          elif use_adaptive_layernorm:
              norm_need_condition = True
              norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)

@@ -2036,7 +2065,8 @@
          else:
              norm_class = LayerNorm

-         norm_fn = partial(norm_class, dim)
+         if not exists(norm_fn):
+             norm_fn = partial(norm_class, dim)

          if not norm_need_condition and norm_add_unit_offset:
              # researcher Ohad Rubin shares in a blog post by adding an offset to gammas, they can be subjected to weight decay safely
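Taken together, these hunks expose DynamicTanh as an alternative to the norm classes through two new `AttentionLayers` keyword arguments, `use_dynamic_tanh` and `dynamic_tanh_init_alpha` (the latter is forwarded to `DynamicTanh`'s `init_alpha`). A usage sketch that mirrors the new `test_dynamic_tanh` test:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# same settings as the new test_dynamic_tanh test above
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        use_dynamic_tanh = True,
        dynamic_tanh_init_alpha = 1.5
    )
)

logits = model(torch.randint(0, 20000, (2, 1024)))
```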