x-transformers 2.4.14__tar.gz → 2.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.4.14 → x_transformers-2.5.1}/PKG-INFO +1 -1
  2. {x_transformers-2.4.14 → x_transformers-2.5.1}/pyproject.toml +1 -1
  3. {x_transformers-2.4.14 → x_transformers-2.5.1}/tests/test_x_transformers.py +30 -2
  4. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/__init__.py +2 -1
  5. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/x_transformers.py +61 -12
  6. {x_transformers-2.4.14 → x_transformers-2.5.1}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.4.14 → x_transformers-2.5.1}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.4.14 → x_transformers-2.5.1}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.4.14 → x_transformers-2.5.1}/.gitignore +0 -0
  10. {x_transformers-2.4.14 → x_transformers-2.5.1}/LICENSE +0 -0
  11. {x_transformers-2.4.14 → x_transformers-2.5.1}/README.md +0 -0
  12. {x_transformers-2.4.14 → x_transformers-2.5.1}/data/README.md +0 -0
  13. {x_transformers-2.4.14 → x_transformers-2.5.1}/data/enwik8.gz +0 -0
  14. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/all-attention.png +0 -0
  15. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/deepnorm.png +0 -0
  18. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/fcm.png +0 -0
  24. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/ffglu.png +0 -0
  25. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/flash-attention.png +0 -0
  26. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/gate_values.png +0 -0
  27. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/gating.png +0 -0
  28. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/macaron-1.png +0 -0
  30. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/macaron-2.png +0 -0
  31. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/normformer.png +0 -0
  33. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/pia.png +0 -0
  34. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/resi_dual.png +0 -0
  36. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/residual_attn.png +0 -0
  37. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/rezero.png +0 -0
  38. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/rotary.png +0 -0
  39. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/sandwich.png +0 -0
  41. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/scalenorm.png +0 -0
  43. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/talking-heads.png +0 -0
  44. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/topk-attention.png +0 -0
  45. {x_transformers-2.4.14 → x_transformers-2.5.1}/images/xval.png +0 -0
  46. {x_transformers-2.4.14 → x_transformers-2.5.1}/train_belief_state.py +0 -0
  47. {x_transformers-2.4.14 → x_transformers-2.5.1}/train_copy.py +0 -0
  48. {x_transformers-2.4.14 → x_transformers-2.5.1}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.4.14 → x_transformers-2.5.1}/train_enwik8.py +0 -0
  50. {x_transformers-2.4.14 → x_transformers-2.5.1}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.4.14 → x_transformers-2.5.1}/train_parity.py +0 -0
  52. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/xval.py +0 -0

--- x_transformers-2.4.14/PKG-INFO
+++ x_transformers-2.5.1/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.14
+Version: 2.5.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

--- x_transformers-2.4.14/pyproject.toml
+++ x_transformers-2.5.1/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.4.14"
+version = "2.5.1"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

--- x_transformers-2.4.14/tests/test_x_transformers.py
+++ x_transformers-2.5.1/tests/test_x_transformers.py
@@ -1086,7 +1086,7 @@ def add_attn_pool():
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-        num_attn_pool_queries = 3,
+        num_pooled_tokens = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,
@@ -1110,7 +1110,6 @@ def test_up(
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-        num_attn_pool_queries = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,
@@ -1153,3 +1152,32 @@ def test_beam_search(stochastic):

     assert beams.shape == (4, 2, 10)
     assert scores.shape == (4, 2)
+
+
+@pytest.mark.parametrize('num_pooled_tokens', (1, 3))
+@pytest.mark.parametrize('attn_pool_depth', (1, 3))
+def test_attn_pooler(
+    num_pooled_tokens,
+    attn_pool_depth
+):
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_pool = True,
+        num_pooled_tokens = num_pooled_tokens,
+        attn_pool_depth = attn_pool_depth,
+        dim_pooled_tokens = 77,
+        attn_layers = Encoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            attn_value_rmsnorm = True
+        ),
+    )
+
+    x = torch.randint(0, 256, (2, 10))
+
+    out = model(x)
+
+    assert out.shape == (2, num_pooled_tokens, 77)

--- x_transformers-2.4.14/x_transformers/__init__.py
+++ x_transformers-2.5.1/x_transformers/__init__.py
@@ -4,12 +4,13 @@ from x_transformers.x_transformers import (
     Decoder,
     PrefixDecoder,
     CrossAttender,
+    AttentionPool,
     Attention,
     FeedForward,
     RMSNorm,
     AdaptiveRMSNorm,
     TransformerWrapper,
-    ViTransformerWrapper
+    ViTransformerWrapper,
 )

 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

--- x_transformers-2.4.14/x_transformers/x_transformers.py
+++ x_transformers-2.5.1/x_transformers/x_transformers.py
@@ -51,6 +51,7 @@ class LayerIntermediates:
     attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    logits: Tensor | None = None
     cache_length: int = 0

 LinearNoBias = partial(nn.Linear, bias = False)
@@ -2037,6 +2038,9 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])

+        self.attn_heads = heads
+        self.attn_dim_head = dim_head
+
         # routing related
         # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
         # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
@@ -2749,6 +2753,49 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)

+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        num_pooled_tokens = 1,
+        dim_context = None,
+        add_residual = False,
+        depth = 1,
+        heads = 8,
+        dim_head = 64,
+        squeeze_output = None,
+        attn_kwargs: dict = dict()
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        squeeze_output = default(squeeze_output, False)
+        assert not (squeeze_output and num_pooled_tokens > 1)
+
+        self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
+
+        if depth > 1:
+            assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
+            attn_kwargs = {f'attn_{k}': v for k, v in attn_kwargs.items()}
+
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, heads = heads, attn_dim_head = dim_head, **attn_kwargs)
+        else:
+            self.pooler = Attention(dim = dim, dim_context = dim_context, heads = heads, dim_head = dim_head, **attn_kwargs)
+
+        self.add_residual = add_residual
+
+    def forward(self, context, mask = None):
+        batch = context.shape[0]
+
+        queries = repeat(self.queries, 'n d -> b n d', b = batch)
+
+        pooled = self.pooler(queries, context, context_mask = mask)
+
+        if self.add_residual:
+            pooled = pooled + queries
+
+        return pooled
+
 class ViTransformerWrapper(Module):
     def __init__(
         self,
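
The `AttentionPool` module added above is also exported from the package root (see the `__init__.py` hunk). A minimal standalone sketch, assuming only the constructor and `forward(context, mask = None)` signatures shown in this hunk; the dimensions and batch sizes below are illustrative, not taken from the package:

    import torch
    from x_transformers import AttentionPool

    # pool a variable-length sequence of context embeddings down to a
    # fixed number of learned query tokens
    pool = AttentionPool(
        dim = 512,              # dimension of the learned queries / pooled output
        dim_context = 768,      # dimension of the tokens being pooled over
        num_pooled_tokens = 4,  # number of learned queries, i.e. output tokens
        depth = 2               # depth > 1 swaps the single Attention for a CrossAttender
    )

    context = torch.randn(2, 128, 768)   # (batch, seq, dim_context)
    mask = torch.ones(2, 128).bool()     # optional boolean mask over context positions

    pooled = pool(context, mask = mask)  # (2, 4, 512)
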
@@ -2860,8 +2907,9 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         num_cls_tokens = 1,
         attn_pool = False,
-        num_attn_pool_queries = 1,
-        dim_attn_pool_query = None,
+        num_pooled_tokens = 1,
+        attn_pool_depth = 1,
+        dim_pooled_tokens = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
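
The constructor change above renames the attention-pooling keyword arguments and adds a pooler depth, matching the test updates earlier in this diff. A minimal before/after sketch; all values are illustrative:

    from x_transformers import TransformerWrapper, Encoder

    # 2.4.14: attn_pool = True, num_attn_pool_queries = 3, dim_attn_pool_query = 77
    # 2.5.1:
    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_pool = True,
        num_pooled_tokens = 3,   # was num_attn_pool_queries
        dim_pooled_tokens = 77,  # was dim_attn_pool_query
        attn_pool_depth = 1,     # new: depth > 1 uses a cross-attention transformer as the pooler
        attn_layers = Encoder(
            dim = 512,
            depth = 6,
            heads = 8
        )
    )
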
@@ -2958,10 +3006,7 @@
         self.attn_pool = None

         if attn_pool:
-            self.attn_pool = Attention(dim = default(dim_attn_pool_query, dim), dim_context = dim)
-
-            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
-            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth, heads = self.attn_layers.attn_heads, dim_head = self.attn_layers.attn_dim_head)

         # whether to average pool the embed (`global average pool`)

@@ -3259,7 +3304,6 @@
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)

-
         # cls token(s)

         if exists(self.cls_token):
@@ -3272,13 +3316,15 @@

         # attention pool

-        if exists(self.attn_pool) and return_intermediates:
-            queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
+        is_encoder = not self.attn_layers.causal
+        return_pooled_tokens = exists(self.attn_pool) and is_encoder

-            attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
+        if (
+            exists(self.attn_pool) and
+            (return_intermediates or is_encoder) # in a new paper, they use attention pooling on decoder - so we'll default to returning pooled tokens if encoder, but for decoder, they must set `return_intermediates`
+        ):

-            if attn_pooled_tokens.shape[1] == 1:
-                attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+            attn_pooled_tokens = self.attn_pool(x, mask = mask)

             intermediates.attn_pooled_tokens = attn_pooled_tokens
@@ -3327,6 +3373,9 @@
             out = (x, intermediates)
         elif return_embeddings:
             out = x
+        elif return_pooled_tokens:
+            intermediates.logits = logits
+            out = attn_pooled_tokens
         else:
             out = logits

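
Taken together, the last two hunks change what a `TransformerWrapper` with `attn_pool = True` returns: for a non-causal (encoder) stack the pooled tokens become the default output, with the logits stashed on `intermediates.logits`, while a causal (decoder) stack must still opt in via `return_intermediates` and read `intermediates.attn_pooled_tokens`. A minimal sketch, assuming the `(output, intermediates)` tuple return of `return_intermediates = True` behaves as in prior releases; shapes follow the assertions in the new test above:

    import torch
    from x_transformers import TransformerWrapper, Encoder, Decoder

    x = torch.randint(0, 256, (2, 10))

    encoder = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_pool = True,
        num_pooled_tokens = 3,
        dim_pooled_tokens = 77,
        attn_layers = Encoder(dim = 512, depth = 2, heads = 8)
    )

    pooled = encoder(x)               # encoder + attn_pool now returns the pooled tokens directly
    assert pooled.shape == (2, 3, 77)

    decoder = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_pool = True,
        num_pooled_tokens = 3,
        attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
    )

    # decoders only compute the pooled tokens when intermediates are requested
    logits, intermediates = decoder(x, return_intermediates = True)
    pooled = intermediates.attn_pooled_tokens
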