x-transformers 2.4.14__tar.gz → 2.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.4.14 → x_transformers-2.5.0}/PKG-INFO +1 -1
  2. {x_transformers-2.4.14 → x_transformers-2.5.0}/pyproject.toml +1 -1
  3. {x_transformers-2.4.14 → x_transformers-2.5.0}/tests/test_x_transformers.py +30 -2
  4. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/__init__.py +2 -1
  5. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/x_transformers.py +54 -12
  6. {x_transformers-2.4.14 → x_transformers-2.5.0}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.4.14 → x_transformers-2.5.0}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.4.14 → x_transformers-2.5.0}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.4.14 → x_transformers-2.5.0}/.gitignore +0 -0
  10. {x_transformers-2.4.14 → x_transformers-2.5.0}/LICENSE +0 -0
  11. {x_transformers-2.4.14 → x_transformers-2.5.0}/README.md +0 -0
  12. {x_transformers-2.4.14 → x_transformers-2.5.0}/data/README.md +0 -0
  13. {x_transformers-2.4.14 → x_transformers-2.5.0}/data/enwik8.gz +0 -0
  14. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/all-attention.png +0 -0
  15. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/deepnorm.png +0 -0
  18. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/fcm.png +0 -0
  24. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/ffglu.png +0 -0
  25. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/flash-attention.png +0 -0
  26. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/gate_values.png +0 -0
  27. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/gating.png +0 -0
  28. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/macaron-1.png +0 -0
  30. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/macaron-2.png +0 -0
  31. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/normformer.png +0 -0
  33. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/pia.png +0 -0
  34. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/resi_dual.png +0 -0
  36. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/residual_attn.png +0 -0
  37. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/rezero.png +0 -0
  38. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/rotary.png +0 -0
  39. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/sandwich.png +0 -0
  41. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/scalenorm.png +0 -0
  43. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/talking-heads.png +0 -0
  44. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/topk-attention.png +0 -0
  45. {x_transformers-2.4.14 → x_transformers-2.5.0}/images/xval.png +0 -0
  46. {x_transformers-2.4.14 → x_transformers-2.5.0}/train_belief_state.py +0 -0
  47. {x_transformers-2.4.14 → x_transformers-2.5.0}/train_copy.py +0 -0
  48. {x_transformers-2.4.14 → x_transformers-2.5.0}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.4.14 → x_transformers-2.5.0}/train_enwik8.py +0 -0
  50. {x_transformers-2.4.14 → x_transformers-2.5.0}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.4.14 → x_transformers-2.5.0}/train_parity.py +0 -0
  52. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.14
+Version: 2.5.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.4.14"
+version = "2.5.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -1086,7 +1086,7 @@ def add_attn_pool():
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-        num_attn_pool_queries = 3,
+        num_pooled_tokens = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,
@@ -1110,7 +1110,6 @@ def test_up(
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-        num_attn_pool_queries = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,
@@ -1153,3 +1152,32 @@ def test_beam_search(stochastic):

     assert beams.shape == (4, 2, 10)
     assert scores.shape == (4, 2)
+
+
+@pytest.mark.parametrize('num_pooled_tokens', (1, 3))
+@pytest.mark.parametrize('attn_pool_depth', (1, 3))
+def test_attn_pooler(
+    num_pooled_tokens,
+    attn_pool_depth
+):
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_pool = True,
+        num_pooled_tokens = num_pooled_tokens,
+        attn_pool_depth = attn_pool_depth,
+        dim_pooled_tokens = 77,
+        attn_layers = Encoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            attn_value_rmsnorm = True
+        ),
+    )
+
+    x = torch.randint(0, 256, (2, 10))
+
+    out = model(x)
+
+    assert out.shape == (2, num_pooled_tokens, 77)
x_transformers/__init__.py
@@ -4,12 +4,13 @@ from x_transformers.x_transformers import (
     Decoder,
     PrefixDecoder,
     CrossAttender,
+    AttentionPool,
     Attention,
     FeedForward,
     RMSNorm,
     AdaptiveRMSNorm,
     TransformerWrapper,
-    ViTransformerWrapper
+    ViTransformerWrapper,
 )

 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
x_transformers/x_transformers.py
@@ -51,6 +51,7 @@ class LayerIntermediates:
     attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    logits: Tensor | None = None
     cache_length: int = 0

 LinearNoBias = partial(nn.Linear, bias = False)
@@ -2749,6 +2750,45 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)

+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        num_pooled_tokens = 1,
+        dim_context = None,
+        add_residual = False,
+        depth = 1,
+        squeeze_output = None,
+        attn_kwargs: dict = dict()
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        squeeze_output = default(squeeze_output, False)
+        assert not (squeeze_output and num_pooled_tokens > 1)
+
+        self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
+
+        if depth > 1:
+            assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, **attn_kwargs)
+        else:
+            self.pooler = Attention(dim = dim, dim_context = dim_context, **attn_kwargs)
+
+        self.add_residual = add_residual
+
+    def forward(self, context, mask = None):
+        batch = context.shape[0]
+
+        queries = repeat(self.queries, 'n d -> b n d', b = batch)
+
+        pooled = self.pooler(queries, context, context_mask = mask)
+
+        if self.add_residual:
+            pooled = pooled + queries
+
+        return pooled
+
 class ViTransformerWrapper(Module):
     def __init__(
         self,
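
Illustrative usage (not part of the diff): a minimal sketch of the new AttentionPool module defined above, which is now also exported from the package __init__ per the earlier hunk. The depth, mask, and tensor sizes below are assumptions chosen for the example, not values taken from this release.

import torch
from x_transformers import AttentionPool

# pool a sequence of 512-dim embeddings down to 3 learned query tokens;
# depth > 1 switches the pooler from a single cross-attention to a CrossAttender stack
pool = AttentionPool(
    dim = 512,
    num_pooled_tokens = 3,
    depth = 2
)

context = torch.randn(2, 1024, 512)              # (batch, seq, dim)
mask = torch.ones(2, 1024, dtype = torch.bool)   # optional padding mask on the context

pooled = pool(context, mask = mask)              # expected shape: (2, 3, 512)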
@@ -2860,8 +2900,9 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         num_cls_tokens = 1,
         attn_pool = False,
-        num_attn_pool_queries = 1,
-        dim_attn_pool_query = None,
+        num_pooled_tokens = 1,
+        attn_pool_depth = 1,
+        dim_pooled_tokens = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
@@ -2958,10 +2999,7 @@
         self.attn_pool = None

         if attn_pool:
-            self.attn_pool = Attention(dim = default(dim_attn_pool_query, dim), dim_context = dim)
-
-            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
-            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth)

         # whether to average pool the embed (`global average pool`)

@@ -3259,7 +3297,6 @@
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)

-
         # cls token(s)

         if exists(self.cls_token):
@@ -3272,13 +3309,15 @@

         # attention pool

-        if exists(self.attn_pool) and return_intermediates:
-            queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
+        is_encoder = not self.attn_layers.causal
+        return_pooled_tokens = exists(self.attn_pool) and is_encoder

-            attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
+        if (
+            exists(self.attn_pool) and
+            (return_intermediates or is_encoder) # in a new paper, they use attention pooling on decoder - so we'll default to returning pooled tokens if encoder, but for decoder, they must set `return_intermediates`
+        ):

-            if attn_pooled_tokens.shape[1] == 1:
-                attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+            attn_pooled_tokens = self.attn_pool(x, mask = mask)

             intermediates.attn_pooled_tokens = attn_pooled_tokens

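Illustrative usage (not part of the diff): per the inline comment above, a causal decoder does not return pooled tokens directly; they are computed only when return_intermediates = True and are read off intermediates.attn_pooled_tokens. The sketch below assumes the usual tuple return with the intermediates as the second element; the model sizes are illustrative.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_pool = True,
    num_pooled_tokens = 3,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

x = torch.randint(0, 256, (2, 10))

# pooled tokens for a decoder are only populated when intermediates are requested
_, intermediates = model(x, return_intermediates = True)
pooled = intermediates.attn_pooled_tokens        # expected shape: (2, 3, 512)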
@@ -3327,6 +3366,9 @@
             out = (x, intermediates)
         elif return_embeddings:
             out = x
+        elif return_pooled_tokens:
+            intermediates.logits = logits
+            out = attn_pooled_tokens
         else:
             out = logits

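Illustrative usage (not part of the diff): the new encoder default, mirroring the added test_attn_pooler. With attn_pool = True on a non-causal Encoder, a plain forward call now returns the pooled tokens themselves, and the logits are stashed on intermediates.logits. The asserted output shape comes from the test; the remaining hyperparameters are illustrative.

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_pool = True,
    num_pooled_tokens = 3,
    attn_pool_depth = 2,
    dim_pooled_tokens = 77,     # pooled tokens may live in a different dimension than the encoder
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 256, (2, 10))

pooled = model(x)                                # expected shape: (2, 3, 77), as asserted in the new test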