x-transformers 2.11.20.tar.gz → 2.11.23.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of x-transformers might be problematic.

Files changed (68)
  1. {x_transformers-2.11.20 → x_transformers-2.11.23}/PKG-INFO +1 -1
  2. {x_transformers-2.11.20 → x_transformers-2.11.23}/pyproject.toml +1 -1
  3. {x_transformers-2.11.20 → x_transformers-2.11.23}/tests/test_x_transformers.py +9 -2
  4. {x_transformers-2.11.20 → x_transformers-2.11.23}/train_enwik8.py +3 -1
  5. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/x_transformers.py +22 -4
  6. {x_transformers-2.11.20 → x_transformers-2.11.23}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.11.20 → x_transformers-2.11.23}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.11.20 → x_transformers-2.11.23}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.11.20 → x_transformers-2.11.23}/.gitignore +0 -0
  10. {x_transformers-2.11.20 → x_transformers-2.11.23}/LICENSE +0 -0
  11. {x_transformers-2.11.20 → x_transformers-2.11.23}/README.md +0 -0
  12. {x_transformers-2.11.20 → x_transformers-2.11.23}/data/README.md +0 -0
  13. {x_transformers-2.11.20 → x_transformers-2.11.23}/data/enwik8.gz +0 -0
  14. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/all-attention.png +0 -0
  15. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/deepnorm.png +0 -0
  18. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/fcm.png +0 -0
  24. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/ffglu.png +0 -0
  25. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/flash-attention.png +0 -0
  26. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/gate_values.png +0 -0
  27. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/gating.png +0 -0
  28. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/macaron-1.png +0 -0
  30. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/macaron-2.png +0 -0
  31. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/normformer.png +0 -0
  33. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/pia.png +0 -0
  34. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/resi_dual.png +0 -0
  36. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/residual_attn.png +0 -0
  37. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/rezero.png +0 -0
  38. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/rotary.png +0 -0
  39. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/sandwich.png +0 -0
  41. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/scalenorm.png +0 -0
  43. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/talking-heads.png +0 -0
  44. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/topk-attention.png +0 -0
  45. {x_transformers-2.11.20 → x_transformers-2.11.23}/images/xval.png +0 -0
  46. {x_transformers-2.11.20 → x_transformers-2.11.23}/train_belief_state.py +0 -0
  47. {x_transformers-2.11.20 → x_transformers-2.11.23}/train_copy.py +0 -0
  48. {x_transformers-2.11.20 → x_transformers-2.11.23}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.11.20 → x_transformers-2.11.23}/train_free.py +0 -0
  50. {x_transformers-2.11.20 → x_transformers-2.11.23}/train_gpt_vae.py +0 -0
  51. {x_transformers-2.11.20 → x_transformers-2.11.23}/train_length_extrapolate.py +0 -0
  52. {x_transformers-2.11.20 → x_transformers-2.11.23}/train_parity.py +0 -0
  53. {x_transformers-2.11.20 → x_transformers-2.11.23}/train_with_muon.py +0 -0
  54. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/__init__.py +0 -0
  55. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/attend.py +0 -0
  56. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/autoregressive_wrapper.py +0 -0
  57. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/free_transformer.py +0 -0
  62. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/gpt_vae.py +0 -0
  63. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/multi_input.py +0 -0
  64. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/neo_mlp.py +0 -0
  65. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/nonautoregressive_wrapper.py +0 -0
  66. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/up_wrapper.py +0 -0
  67. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.11.20 → x_transformers-2.11.23}/x_transformers/xval.py +0 -0
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.20
+Version: 2.11.23
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.11.20"
+version = "2.11.23"
 description = "X-Transformers"
 authors = [
   { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -1463,7 +1463,12 @@ def test_kv_input_residual():
 
     assert tokens.shape == out.shape
 
-def test_belief_attn():
+@param('orthog_project', (False, True))
+@param('orthog_project_per_head', (False, True))
+def test_belief_attn(
+    orthog_project,
+    orthog_project_per_head
+):
     from x_transformers import TransformerWrapper, Decoder
 
     model = TransformerWrapper(
@@ -1473,8 +1478,10 @@ def test_belief_attn():
            dim = 512,
            depth = 6,
            heads = 8,
+           attn_kv_heads = 4,
            rotary_pos_emb = True,
-           attn_orthog_projected_values = True
+           attn_orthog_projected_values = orthog_project,
+           attn_orthog_projected_values_per_head = orthog_project_per_head
        )
    )
 
--- a/train_enwik8.py
+++ b/train_enwik8.py
@@ -52,7 +52,9 @@ model = TransformerWrapper(
        dim = 512,
        depth = 6,
        heads = 8,
-       rotary_pos_emb = True
+       rotary_pos_emb = True,
+       attn_orthog_projected_values = True,
+       attn_orthog_projected_values_per_head = True
    )
 )
 
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1397,6 +1397,7 @@ class Attention(Module):
         logit_softclamp_value = 50.,
         learned_value_residual_mix = False,
         orthog_projected_values = False, # https://openreview.net/forum?id=Ard2QzPAUK
+        orthog_projected_values_per_head = False,
         laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
         qkv_receive_diff_residuals = False,
@@ -1430,6 +1431,7 @@ class Attention(Module):
         assert divisible_by(heads, kv_heads)
 
         self.kv_heads = kv_heads
+        self.groups = heads // kv_heads
 
         q_dim = dim_head * heads
         k_dim = dim_head * kv_heads
@@ -1627,6 +1629,9 @@ class Attention(Module):
         # "belief attention" - iclr 2026
 
         self.orthog_projected_values = orthog_projected_values
+        self.orthog_projected_values_per_head = orthog_projected_values_per_head
+
+        out_dim *= max(1, int(orthog_projected_values) + int(orthog_projected_values_per_head))
 
         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
 
@@ -2069,11 +2074,24 @@ class Attention(Module):
             gates = self.to_v_gate(x)
             out = out * self.to_v_gate_activation(gates)
 
-        # maybe return orthogonal projected - "belief" attention
+        # maybe orthogonal projected weighted values - "belief" attention
+
+        if self.orthog_projected_values or self.orthog_projected_values_per_head:
+            orthog_projected = []
+            v_for_proj = repeat(orig_values, 'b h n d -> b n (g h d)', g = self.groups)
+
+            if self.orthog_projected_values:
+                projected = orthog_project(out, v_for_proj)
+                orthog_projected.append(projected)
+
+            if self.orthog_projected_values_per_head:
+                v_for_proj = rearrange(v_for_proj, 'b n (h d) -> b n h d', h = h)
+                out = rearrange(out, 'b n (h d) -> b n h d', h = h)
+                projected = orthog_project(out, v_for_proj)
+                projected = rearrange(projected, 'b n h d -> b n (h d)')
+                orthog_projected.append(projected)
 
-        if self.orthog_projected_values:
-            merged_v = self.merge_heads(orig_values)
-            out = orthog_project(out, merged_v)
+            out = cat(orthog_projected, dim = -1)
 
         # combine the heads
 
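The forward-pass hunk above relies on the pre-existing orthog_project helper (already called by the code this diff removes, but not itself shown here) and on einops repeat/rearrange to expand the grouped-query value heads (self.groups = heads // kv_heads) up to the query heads. Below is a rough sketch of the idea, with orthog_project_sketch as a hypothetical stand-in for the real helper; the actual implementation, normalization and epsilon handling may differ.

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

def orthog_project_sketch(out, v, eps = 1e-6):
    # hypothetical stand-in: keep only the component of `out` orthogonal to the
    # attention-weighted values `v` (projection taken over the last dimension)
    v_unit = F.normalize(v, dim = -1, eps = eps)
    parallel = (out * v_unit).sum(dim = -1, keepdim = True) * v_unit
    return out - parallel

b, n, d = 2, 16, 64
heads, kv_heads = 8, 4
groups = heads // kv_heads                      # as stored in self.groups above

out = torch.randn(b, n, heads * d)              # merged-head attention output
orig_values = torch.randn(b, kv_heads, n, d)    # pre-attention values (fewer kv heads under GQA)

# expand kv heads to match the query heads, exactly as in the hunk above
v_for_proj = repeat(orig_values, 'b h n d -> b n (g h d)', g = groups)

# global variant: one projection over the full merged width
global_proj = orthog_project_sketch(out, v_for_proj)

# per-head variant: split into heads, project each head independently, re-merge
per_head = orthog_project_sketch(
    rearrange(out, 'b n (h d) -> b n h d', h = heads),
    rearrange(v_for_proj, 'b n (h d) -> b n h d', h = heads)
)
per_head = rearrange(per_head, 'b n h d -> b n (h d)')

# with both flags enabled, the two results are concatenated (hence the 2x out_dim)
projected = torch.cat((global_proj, per_head), dim = -1)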