x-transformers 2.5.0.tar.gz → 2.5.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.5.0 → x_transformers-2.5.1}/PKG-INFO +1 -1
  2. {x_transformers-2.5.0 → x_transformers-2.5.1}/pyproject.toml +1 -1
  3. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/x_transformers.py +10 -3
  4. {x_transformers-2.5.0 → x_transformers-2.5.1}/.github/FUNDING.yml +0 -0
  5. {x_transformers-2.5.0 → x_transformers-2.5.1}/.github/workflows/python-publish.yml +0 -0
  6. {x_transformers-2.5.0 → x_transformers-2.5.1}/.github/workflows/python-test.yaml +0 -0
  7. {x_transformers-2.5.0 → x_transformers-2.5.1}/.gitignore +0 -0
  8. {x_transformers-2.5.0 → x_transformers-2.5.1}/LICENSE +0 -0
  9. {x_transformers-2.5.0 → x_transformers-2.5.1}/README.md +0 -0
  10. {x_transformers-2.5.0 → x_transformers-2.5.1}/data/README.md +0 -0
  11. {x_transformers-2.5.0 → x_transformers-2.5.1}/data/enwik8.gz +0 -0
  12. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/all-attention.png +0 -0
  13. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/attention-on-attention.png +0 -0
  14. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/cosine-sim-attention.png +0 -0
  15. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/deepnorm.png +0 -0
  16. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/dynamic-pos-bias-linear.png +0 -0
  17. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/dynamic-pos-bias-log.png +0 -0
  18. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  19. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/dynamic-pos-bias.png +0 -0
  20. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/enhanced-recurrence.png +0 -0
  21. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/fcm.png +0 -0
  22. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/ffglu.png +0 -0
  23. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/flash-attention.png +0 -0
  24. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/gate_values.png +0 -0
  25. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/gating.png +0 -0
  26. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/length-extrapolation-scale.png +0 -0
  27. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/macaron-1.png +0 -0
  28. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/macaron-2.png +0 -0
  29. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/memory-transformer.png +0 -0
  30. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/normformer.png +0 -0
  31. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/pia.png +0 -0
  32. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/qknorm-analysis.png +0 -0
  33. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/resi_dual.png +0 -0
  34. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/residual_attn.png +0 -0
  35. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/rezero.png +0 -0
  36. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/rotary.png +0 -0
  37. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/sandwich-2.png +0 -0
  38. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/sandwich.png +0 -0
  39. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/sandwich_norm.png +0 -0
  40. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/scalenorm.png +0 -0
  41. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/talking-heads.png +0 -0
  42. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/topk-attention.png +0 -0
  43. {x_transformers-2.5.0 → x_transformers-2.5.1}/images/xval.png +0 -0
  44. {x_transformers-2.5.0 → x_transformers-2.5.1}/tests/test_x_transformers.py +0 -0
  45. {x_transformers-2.5.0 → x_transformers-2.5.1}/train_belief_state.py +0 -0
  46. {x_transformers-2.5.0 → x_transformers-2.5.1}/train_copy.py +0 -0
  47. {x_transformers-2.5.0 → x_transformers-2.5.1}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.5.0 → x_transformers-2.5.1}/train_enwik8.py +0 -0
  49. {x_transformers-2.5.0 → x_transformers-2.5.1}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.5.0 → x_transformers-2.5.1}/train_parity.py +0 -0
  51. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/xval.py +0 -0
{x_transformers-2.5.0 → x_transformers-2.5.1}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.5.0
+ Version: 2.5.1
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.5.0 → x_transformers-2.5.1}/pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.5.0"
+ version = "2.5.1"
  description = "X-Transformers"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.5.0 → x_transformers-2.5.1}/x_transformers/x_transformers.py

@@ -2038,6 +2038,9 @@ class AttentionLayers(Module):
          self.causal = causal
          self.layers = ModuleList([])

+         self.attn_heads = heads
+         self.attn_dim_head = dim_head
+
          # routing related
          # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
          # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
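The two added lines record the attention width of the layer stack so downstream modules can reuse it. A minimal sketch of reading the new attributes back, assuming an x-transformers 2.5.1 install; the constructor values are illustrative:

    from x_transformers import Decoder

    # heads is a direct AttentionLayers argument; attn_dim_head is routed to the attention layers
    attn_layers = Decoder(dim = 512, depth = 6, heads = 12, attn_dim_head = 32)

    print(attn_layers.attn_heads)     # 12
    print(attn_layers.attn_dim_head)  # 32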
@@ -2758,6 +2761,8 @@ class AttentionPool(Module):
          dim_context = None,
          add_residual = False,
          depth = 1,
+         heads = 8,
+         dim_head = 64,
          squeeze_output = None,
          attn_kwargs: dict = dict()
      ):
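With heads and dim_head now exposed on the constructor, the pooling attention no longer has to rely on Attention's defaults (8 heads of dimension 64). A minimal construction sketch, assuming the signature in the hunk above; num_pooled_tokens comes from the surrounding class, and the forward call is only a guess at the usual token-sequence input:

    import torch
    from x_transformers.x_transformers import AttentionPool

    pool = AttentionPool(
        dim = 512,             # dimension of the learned pooling queries
        num_pooled_tokens = 1,
        heads = 12,            # new in 2.5.1; previously only reachable via attn_kwargs
        dim_head = 32          # new in 2.5.1; previously only reachable via attn_kwargs
    )

    tokens = torch.randn(2, 1024, 512)   # (batch, seq, dim) embeddings to pool over
    pooled = pool(tokens)                # pooled summary; exact shape depends on squeeze_output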
@@ -2771,9 +2776,11 @@ class AttentionPool(Module):

          if depth > 1:
              assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
-             self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, **attn_kwargs)
+             attn_kwargs = {f'attn_{k}': v for k, v in attn_kwargs.items()}
+
+             self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, heads = heads, attn_dim_head = dim_head, **attn_kwargs)
          else:
-             self.pooler = Attention(dim = dim, dim_context = dim_context, **attn_kwargs)
+             self.pooler = Attention(dim = dim, dim_context = dim_context, heads = heads, dim_head = dim_head, **attn_kwargs)

          self.add_residual = add_residual
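The dict comprehension added to the depth > 1 branch only renames keys: CrossAttender expects attention options under an attn_ prefix, while the single-layer branch passes attn_kwargs to Attention directly. A small illustration of the remapping; the option names are hypothetical examples:

    attn_kwargs = dict(qk_norm = True, gate_value_heads = True)
    remapped = {f'attn_{k}': v for k, v in attn_kwargs.items()}
    # remapped == {'attn_qk_norm': True, 'attn_gate_value_heads': True}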
 
@@ -2999,7 +3006,7 @@ class TransformerWrapper(Module):
          self.attn_pool = None

          if attn_pool:
-             self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth)
+             self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth, heads = self.attn_layers.attn_heads, dim_head = self.attn_layers.attn_dim_head)

          # whether to average pool the embed (`global average pool`)
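Taken together, a TransformerWrapper built with attention pooling now sizes its AttentionPool to match the attention layers it wraps, via the attn_heads and attn_dim_head attributes stored above. A minimal usage sketch, assuming the attn_pool flag shown in this hunk:

    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_pool = True,              # builds the AttentionPool configured above
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 12,
            attn_dim_head = 32
        )
    )

    # the pool now inherits the decoder's attention width
    print(model.attn_layers.attn_heads)     # 12
    print(model.attn_layers.attn_dim_head)  # 32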
 