x-transformers 2.11.23__tar.gz → 2.12.0__tar.gz

This diff shows the changes between the two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.

Files changed (68)
  1. {x_transformers-2.11.23 → x_transformers-2.12.0}/PKG-INFO +25 -1
  2. {x_transformers-2.11.23 → x_transformers-2.12.0}/README.md +24 -0
  3. {x_transformers-2.11.23 → x_transformers-2.12.0}/pyproject.toml +1 -1
  4. {x_transformers-2.11.23 → x_transformers-2.12.0}/tests/test_x_transformers.py +38 -0
  5. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/x_transformers.py +96 -2
  6. {x_transformers-2.11.23 → x_transformers-2.12.0}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.11.23 → x_transformers-2.12.0}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.11.23 → x_transformers-2.12.0}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.11.23 → x_transformers-2.12.0}/.gitignore +0 -0
  10. {x_transformers-2.11.23 → x_transformers-2.12.0}/LICENSE +0 -0
  11. {x_transformers-2.11.23 → x_transformers-2.12.0}/data/README.md +0 -0
  12. {x_transformers-2.11.23 → x_transformers-2.12.0}/data/enwik8.gz +0 -0
  13. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/all-attention.png +0 -0
  14. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/deepnorm.png +0 -0
  17. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/fcm.png +0 -0
  23. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/ffglu.png +0 -0
  24. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/flash-attention.png +0 -0
  25. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/gate_values.png +0 -0
  26. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/gating.png +0 -0
  27. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/macaron-1.png +0 -0
  29. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/macaron-2.png +0 -0
  30. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/normformer.png +0 -0
  32. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/pia.png +0 -0
  33. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/resi_dual.png +0 -0
  35. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/residual_attn.png +0 -0
  36. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/rezero.png +0 -0
  37. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/rotary.png +0 -0
  38. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/sandwich.png +0 -0
  40. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/scalenorm.png +0 -0
  42. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/talking-heads.png +0 -0
  43. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/topk-attention.png +0 -0
  44. {x_transformers-2.11.23 → x_transformers-2.12.0}/images/xval.png +0 -0
  45. {x_transformers-2.11.23 → x_transformers-2.12.0}/train_belief_state.py +0 -0
  46. {x_transformers-2.11.23 → x_transformers-2.12.0}/train_copy.py +0 -0
  47. {x_transformers-2.11.23 → x_transformers-2.12.0}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.11.23 → x_transformers-2.12.0}/train_enwik8.py +0 -0
  49. {x_transformers-2.11.23 → x_transformers-2.12.0}/train_free.py +0 -0
  50. {x_transformers-2.11.23 → x_transformers-2.12.0}/train_gpt_vae.py +0 -0
  51. {x_transformers-2.11.23 → x_transformers-2.12.0}/train_length_extrapolate.py +0 -0
  52. {x_transformers-2.11.23 → x_transformers-2.12.0}/train_parity.py +0 -0
  53. {x_transformers-2.11.23 → x_transformers-2.12.0}/train_with_muon.py +0 -0
  54. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/attend.py +0 -0
  56. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/autoregressive_wrapper.py +0 -0
  57. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/free_transformer.py +0 -0
  62. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/gpt_vae.py +0 -0
  63. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/multi_input.py +0 -0
  64. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/neo_mlp.py +0 -0
  65. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  66. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/up_wrapper.py +0 -0
  67. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/xval.py +0 -0
{x_transformers-2.11.23 → x_transformers-2.12.0}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.11.23
+ Version: 2.12.0
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2618,4 +2618,28 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @misc{chen2025strongernormalizationfreetransformers,
+     title = {Stronger Normalization-Free Transformers},
+     author = {Mingzhi Chen and Taiming Lu and Jiachen Zhu and Mingjie Sun and Zhuang Liu},
+     year = {2025},
+     eprint = {2512.10938},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2512.10938},
+ }
+ ```
+
+ ```bibtex
+ @misc{gopalakrishnan2025decouplingwhatwherepolar,
+     title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
+     author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
+     year = {2025},
+     eprint = {2509.10534},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2509.10534},
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.11.23 → x_transformers-2.12.0}/README.md

@@ -2569,4 +2569,28 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @misc{chen2025strongernormalizationfreetransformers,
+     title = {Stronger Normalization-Free Transformers},
+     author = {Mingzhi Chen and Taiming Lu and Jiachen Zhu and Mingjie Sun and Zhuang Liu},
+     year = {2025},
+     eprint = {2512.10938},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2512.10938},
+ }
+ ```
+
+ ```bibtex
+ @misc{gopalakrishnan2025decouplingwhatwherepolar,
+     title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
+     author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
+     year = {2025},
+     eprint = {2509.10534},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2509.10534},
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.11.23 → x_transformers-2.12.0}/pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.11.23"
+ version = "2.12.0"
  description = "X-Transformers"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.11.23 → x_transformers-2.12.0}/tests/test_x_transformers.py

@@ -1488,3 +1488,41 @@ def test_belief_attn(
      x = torch.randint(0, 256, (1, 10))

      logits = model(x)
+
+ def test_derf():
+     from x_transformers import TransformerWrapper, Decoder
+
+     model = TransformerWrapper(
+         num_tokens = 256,
+         max_seq_len = 1024,
+         attn_layers = Decoder(
+             dim = 512,
+             depth = 6,
+             heads = 8,
+             attn_kv_heads = 4,
+             rotary_pos_emb = True,
+             use_derf = True
+         )
+     )
+
+     x = torch.randint(0, 256, (1, 10))
+
+     logits = model(x)
+
+ def test_pope():
+     from x_transformers import TransformerWrapper, Decoder
+
+     model = TransformerWrapper(
+         num_tokens = 256,
+         max_seq_len = 1024,
+         attn_layers = Decoder(
+             dim = 512,
+             depth = 6,
+             heads = 8,
+             polar_pos_emb = True,
+         )
+     )
+
+     x = torch.randint(0, 256, (1, 10))
+
+     logits = model(x)
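
The new `use_derf` and `polar_pos_emb` flags are only exercised through a single forward pass above. As a rough usage sketch (not taken from the diff), the same flags plug into the usual `TransformerWrapper` / `AutoregressiveWrapper` workflow; `AutoregressiveWrapper` is the existing x-transformers helper, and everything else here is illustrative:

```python
# sketch: enabling PoPE (polar_pos_emb) in a decoder-only LM, then training and sampling as usual
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        polar_pos_emb = True   # mutually exclusive with rotary_pos_emb (asserted in this release)
    )
)

wrapped = AutoregressiveWrapper(model)

seq = torch.randint(0, 256, (1, 128))
loss = wrapped(seq)                       # next-token cross entropy
loss.backward()

prompt  = seq[:, :1]
sampled = wrapped.generate(prompt, 64)    # (1, 64) sampled continuation
```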
{x_transformers-2.11.23 → x_transformers-2.12.0}/x_transformers/x_transformers.py

@@ -779,6 +779,49 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):

      return out.type(orig_dtype)

+ class PolarEmbedding(Module):
+     """ https://arxiv.org/abs/2509.10534 """
+
+     def __init__(
+         self,
+         dim,
+         bias_uniform_init = False,
+         base = 10000,
+     ):
+         super().__init__()
+         inv_freq = 1. / (base ** (arange(0, dim).float() / dim))
+         self.register_buffer('inv_freq', inv_freq)
+
+         self.learned_bias = nn.Parameter(torch.zeros(dim))
+
+         if bias_uniform_init:
+             self.learned_bias.uniform_(-2. * math.pi, 0.)
+
+     @autocast('cuda', enabled = False)
+     def forward(self, t, offset = 0):
+         max_pos = t.max() + 1
+
+         if t.ndim == 1:
+             t = rearrange(t, 'n -> 1 n')
+
+         freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq)
+
+         bias = self.learned_bias.clamp(-2. * math.pi, 0.)
+
+         return freqs, bias
+
+ @autocast('cuda', enabled = False)
+ def apply_polar_pos_emb(t, freqs):
+     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
+     freqs = freqs[:, -seq_len:]
+
+     t = t.float()
+
+     t = F.softplus(t)
+     out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)
+
+     return out.type(orig_dtype)
+
  # norms

  class Scale(Module):
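
A quick standalone sketch of what `apply_polar_pos_emb` does to a tensor, mirroring the math added above rather than importing it; `polar_map` and the toy shapes are illustrative only. The two things to notice are that the magnitude goes through `softplus` (so it is non-negative) and that the last dimension doubles, since cosine and sine components are concatenated:

```python
import torch
import torch.nn.functional as F

def polar_map(t, freqs):                      # illustrative mirror of apply_polar_pos_emb above
    mag = F.softplus(t.float())               # the "what": non-negative magnitude per feature
    return torch.cat((mag * freqs.cos(), mag * freqs.sin()), dim = -1)

dim_head, seq_len = 8, 16
inv_freq = 1. / (10000 ** (torch.arange(dim_head).float() / dim_head))
pos      = torch.arange(seq_len).float()
freqs    = torch.einsum('i, j -> i j', pos, inv_freq)   # (seq_len, dim_head) phase angles

q = torch.randn(seq_len, dim_head)
print(polar_map(q, freqs).shape)              # torch.Size([16, 16]) - head dim doubled
```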
@@ -941,6 +984,31 @@ class DynamicTanh(Module):
          gamma = self.gamma + self.gamma_offset
          return (x * pre_tanh_scale).tanh() * gamma + self.beta

+ class Derf(Module):
+     """ https://arxiv.org/abs/2512.10938 """
+     def __init__(
+         self,
+         dim,
+         init_alpha = 0.5,
+         init_bias = 0.,
+         unit_offset = False
+     ):
+         super().__init__()
+         scale_offset = 1. if unit_offset else 0.
+
+         self.alpha = nn.Parameter(tensor(init_alpha) - scale_offset)
+         self.s = nn.Parameter(tensor(init_bias))
+
+         self.gamma = nn.Parameter(torch.ones(dim) - scale_offset)
+         self.beta = nn.Parameter(torch.zeros(dim))
+
+         self.scale_offset = scale_offset
+
+     def forward(self, x):
+         x = x * (self.alpha + self.scale_offset) + self.s
+         activated = torch.erf(x)
+         return activated * (self.gamma + self.scale_offset) + self.beta
+
  # residual and residual gates

  class Residual(Module):
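
For reference, with default arguments (`init_alpha = 0.5`, `init_bias = 0.`, `unit_offset = False`) the `Derf` forward pass above reduces to an elementwise `erf` where `DynamicTanh` uses `tanh`. A minimal standalone sketch:

```python
import torch

dim = 512
x = torch.randn(2, 10, dim)

alpha = torch.tensor(0.5)    # learned scalar gain  (init_alpha)
s     = torch.tensor(0.)     # learned scalar shift (init_bias)
gamma = torch.ones(dim)      # learned per-feature scale
beta  = torch.zeros(dim)     # learned per-feature bias

out = torch.erf(alpha * x + s) * gamma + beta   # Derf(x): bounded like tanh, but via the error function
assert out.shape == x.shape
```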
@@ -1720,6 +1788,7 @@ class Attention(Module):
          attn_bias = None,
          rotary_pos_emb = None,
          context_rotary_pos_emb = None,
+         polar_pos_emb = None,
          pos = None, # for custom alibi positions
          prev_attn = None,
          mem = None,
@@ -1871,6 +1940,11 @@
                  q = cat((q_rest, q), dim = 1)
                  k = cat((k_rest, k), dim = 1)

+         if exists(polar_pos_emb):
+             freqs, bias = polar_pos_emb
+             q = apply_polar_pos_emb(q, freqs)
+             k = apply_polar_pos_emb(k, freqs + bias)
+
          input_mask = context_mask

          if not exists(input_mask) and not has_context:
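
Why queries get `freqs` while keys get `freqs + bias`: under the polar map, the query/key dot product collapses to a sum of non-negative magnitudes times the cosine of the phase difference, so the learned bias only shifts the relative-position ("where") term. A small numeric check of that identity, using standalone stand-ins rather than the library internals:

```python
import torch
import torch.nn.functional as F

def polar(t, theta):                       # stand-in for apply_polar_pos_emb
    mag = F.softplus(t)
    return torch.cat((mag * theta.cos(), mag * theta.sin()), dim = -1)

dim = 8
q, k = torch.randn(dim), torch.randn(dim)
theta_q = torch.rand(dim) * 6.28           # m * inv_freq          (query position phases)
theta_k = torch.rand(dim) * 6.28 + 0.3     # n * inv_freq + bias   (key position phases, shifted)

lhs = polar(q, theta_q) @ polar(k, theta_k)
rhs = (F.softplus(q) * F.softplus(k) * (theta_q - theta_k).cos()).sum()
assert torch.allclose(lhs, rhs, atol = 1e-5)   # attention logit = "what" * cos("where")
```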
@@ -2123,6 +2197,7 @@ class AttentionLayers(Module):
          use_scalenorm = False,
          use_rmsnorm = False,
          use_dynamic_tanh = False,
+         use_derf = False,
          dynamic_tanh_init_alpha = 1.,
          use_simple_rmsnorm = False,
          use_adaptive_layernorm = False,
@@ -2148,6 +2223,8 @@
          rotary_xpos_scale_base = 512,
          rotary_base_rescale_factor = 1.,
          rotate_num_heads = None,
+         polar_pos_emb = False,
+         polar_bias_uniform_init = False,
          weight_tie_layers = False,
          custom_layers: tuple[str, ...] | None = None,
          layers_execute_order: tuple[int, ...] | None = None,
@@ -2240,9 +2317,14 @@
          if verbose and rotary_emb_dim < 32:
              logger.warning('when training language model, rotary embedding dimension should be at least 32')

+         assert at_most_one_of(rotary_pos_emb, polar_pos_emb), f'either rotary positional embedding or polar positional embedding can be turned on'
          assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
          self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None

+         # polar positional embedding (PoPE) - https://arxiv.org/abs/2509.10534
+
+         self.polar_pos_emb = PolarEmbedding(dim_head, polar_bias_uniform_init) if polar_pos_emb else None
+
          assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), dynamic tanh, or T5 relative positional bias'
          assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'

@@ -2277,7 +2359,7 @@

          # determine norm

-         assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
+         assert at_most_one_of(use_scalenorm, use_rmsnorm, use_dynamic_tanh, use_derf, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'

          norm_need_condition = False
          dim_condition = default(dim_condition, dim)
@@ -2295,6 +2377,8 @@
          elif use_dynamic_tanh:
              assert pre_norm, 'dynamic tanh norm only tested for pre-norm'
              norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
+         elif use_derf:
+             norm_class = Derf
          elif use_adaptive_layernorm:
              norm_need_condition = True
              norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
@@ -2598,6 +2682,7 @@
          cache_age = 1,
          return_hiddens = False,
          rotary_pos_emb = None,
+         polar_pos_emb = None,
          pos = None,
          context_pos = None,
          attn_bias = None,
@@ -2693,6 +2778,15 @@
                  context_rotary_pos_emb = context_rotary_pos_emb
              )

+         # polar positions
+
+         if exists(self.polar_pos_emb):
+             if not exists(polar_pos_emb):
+                 if not exists(pos):
+                     pos = arange(x.shape[1] + seq_pos_offset, device = x.device)
+
+                 polar_pos_emb = self.polar_pos_emb(pos)
+
          # assume cached key / values

          prev_cache_length = 0
@@ -2882,7 +2976,7 @@
              # forward depending on layer type

              if layer_type == 'a':
-                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, polar_pos_emb = polar_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
              elif layer_type == 'c':
                  out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), kv_input_residual = next(cross_attn_kv_residuals_iter, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
              elif layer_type == 'f':