x-transformers 2.4.12.tar.gz → 2.5.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.4.12 → x_transformers-2.5.0}/PKG-INFO +1 -1
  2. {x_transformers-2.4.12 → x_transformers-2.5.0}/pyproject.toml +1 -1
  3. {x_transformers-2.4.12 → x_transformers-2.5.0}/tests/test_x_transformers.py +30 -2
  4. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/__init__.py +2 -1
  5. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/x_transformers.py +63 -12
  6. {x_transformers-2.4.12 → x_transformers-2.5.0}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.4.12 → x_transformers-2.5.0}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.4.12 → x_transformers-2.5.0}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.4.12 → x_transformers-2.5.0}/.gitignore +0 -0
  10. {x_transformers-2.4.12 → x_transformers-2.5.0}/LICENSE +0 -0
  11. {x_transformers-2.4.12 → x_transformers-2.5.0}/README.md +0 -0
  12. {x_transformers-2.4.12 → x_transformers-2.5.0}/data/README.md +0 -0
  13. {x_transformers-2.4.12 → x_transformers-2.5.0}/data/enwik8.gz +0 -0
  14. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/all-attention.png +0 -0
  15. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/deepnorm.png +0 -0
  18. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/fcm.png +0 -0
  24. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/ffglu.png +0 -0
  25. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/flash-attention.png +0 -0
  26. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/gate_values.png +0 -0
  27. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/gating.png +0 -0
  28. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/macaron-1.png +0 -0
  30. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/macaron-2.png +0 -0
  31. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/normformer.png +0 -0
  33. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/pia.png +0 -0
  34. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/resi_dual.png +0 -0
  36. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/residual_attn.png +0 -0
  37. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/rezero.png +0 -0
  38. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/rotary.png +0 -0
  39. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/sandwich.png +0 -0
  41. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/scalenorm.png +0 -0
  43. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/talking-heads.png +0 -0
  44. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/topk-attention.png +0 -0
  45. {x_transformers-2.4.12 → x_transformers-2.5.0}/images/xval.png +0 -0
  46. {x_transformers-2.4.12 → x_transformers-2.5.0}/train_belief_state.py +0 -0
  47. {x_transformers-2.4.12 → x_transformers-2.5.0}/train_copy.py +0 -0
  48. {x_transformers-2.4.12 → x_transformers-2.5.0}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.4.12 → x_transformers-2.5.0}/train_enwik8.py +0 -0
  50. {x_transformers-2.4.12 → x_transformers-2.5.0}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.4.12 → x_transformers-2.5.0}/train_parity.py +0 -0
  52. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/xval.py +0 -0
{x_transformers-2.4.12 → x_transformers-2.5.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.12
+Version: 2.5.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.4.12 → x_transformers-2.5.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.4.12"
+version = "2.5.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.4.12 → x_transformers-2.5.0}/tests/test_x_transformers.py
@@ -1086,7 +1086,7 @@ def add_attn_pool():
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-        num_attn_pool_queries = 3,
+        num_pooled_tokens = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,
@@ -1110,7 +1110,6 @@ def test_up(
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-        num_attn_pool_queries = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,
@@ -1153,3 +1152,32 @@ def test_beam_search(stochastic):
 
     assert beams.shape == (4, 2, 10)
     assert scores.shape == (4, 2)
+
+
+@pytest.mark.parametrize('num_pooled_tokens', (1, 3))
+@pytest.mark.parametrize('attn_pool_depth', (1, 3))
+def test_attn_pooler(
+    num_pooled_tokens,
+    attn_pool_depth
+):
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_pool = True,
+        num_pooled_tokens = num_pooled_tokens,
+        attn_pool_depth = attn_pool_depth,
+        dim_pooled_tokens = 77,
+        attn_layers = Encoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            attn_value_rmsnorm = True
+        ),
+    )
+
+    x = torch.randint(0, 256, (2, 10))
+
+    out = model(x)
+
+    assert out.shape == (2, num_pooled_tokens, 77)
{x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/__init__.py
@@ -4,12 +4,13 @@ from x_transformers.x_transformers import (
     Decoder,
     PrefixDecoder,
     CrossAttender,
+    AttentionPool,
     Attention,
     FeedForward,
     RMSNorm,
     AdaptiveRMSNorm,
     TransformerWrapper,
-    ViTransformerWrapper
+    ViTransformerWrapper,
 )
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
{x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/x_transformers.py
@@ -51,6 +51,7 @@ class LayerIntermediates:
     attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    logits: Tensor | None = None
     cache_length: int = 0
 
 LinearNoBias = partial(nn.Linear, bias = False)
@@ -1304,6 +1305,7 @@ class Attention(Module):
         qk_norm_groups = 1,
         qk_norm_scale = 10,
         qk_norm_dim_scale = False,
+        value_rmsnorm = False, # used in alphagenome and bytedance's GR3 for further stability
         l2_distance = False,
         sigmoid = False,
         selective = False,
@@ -1458,6 +1460,10 @@ class Attention(Module):
         assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
         assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
 
+        # value rms norm
+
+        self.value_rmsnorm = MultiheadRMSNorm(dim_head, heads = heads) if value_rmsnorm else None
+
         # contextual positional encoding
         # https://arxiv.org/html/2405.18719v2
 
@@ -1697,6 +1703,10 @@ class Attention(Module):
             q = q * self.qk_norm_q_scale
             k = k * self.qk_norm_k_scale
 
+        # maybe value rmsnorm
+
+        v = maybe(self.value_rmsnorm)(v)
+
         # take care of caching
 
         if not is_multi_latent_attn:
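
Note (not part of the diff): the new `value_rmsnorm` option is reachable through the usual `attn_`-prefixed layer kwargs, which is how the added test enables it (`attn_value_rmsnorm = True`). A minimal sketch, with illustrative dimensions and depth:

    # minimal sketch: enable the new per-head value RMSNorm via layer kwargs,
    # mirroring `attn_value_rmsnorm = True` in the added test
    import torch
    from x_transformers import TransformerWrapper, Encoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Encoder(
            dim = 512,
            depth = 6,                  # depth and heads are illustrative
            heads = 8,
            attn_value_rmsnorm = True   # forwarded to Attention(value_rmsnorm = True)
        )
    )

    logits = model(torch.randint(0, 256, (2, 128)))   # (2, 128, 256)
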
@@ -2740,6 +2750,45 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)
 
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        num_pooled_tokens = 1,
+        dim_context = None,
+        add_residual = False,
+        depth = 1,
+        squeeze_output = None,
+        attn_kwargs: dict = dict()
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        squeeze_output = default(squeeze_output, False)
+        assert not (squeeze_output and num_pooled_tokens > 1)
+
+        self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
+
+        if depth > 1:
+            assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, **attn_kwargs)
+        else:
+            self.pooler = Attention(dim = dim, dim_context = dim_context, **attn_kwargs)
+
+        self.add_residual = add_residual
+
+    def forward(self, context, mask = None):
+        batch = context.shape[0]
+
+        queries = repeat(self.queries, 'n d -> b n d', b = batch)
+
+        pooled = self.pooler(queries, context, context_mask = mask)
+
+        if self.add_residual:
+            pooled = pooled + queries
+
+        return pooled
+
 class ViTransformerWrapper(Module):
     def __init__(
         self,
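
Note (not part of the diff): the new `AttentionPool` module is also exported from the package root, per the `__init__.py` hunk above, so it can be used standalone. A rough sketch with made-up dimensions; for `depth > 1` the pooler is a small cross-attention-only transformer (`CrossAttender`), otherwise a single `Attention` block:

    # rough sketch of standalone usage, assuming the export shown above
    import torch
    from x_transformers import AttentionPool

    pool = AttentionPool(
        dim = 77,               # dimension of the learned queries / pooled output
        dim_context = 512,      # dimension of the sequence being pooled
        num_pooled_tokens = 3,
        depth = 2               # depth > 1 routes through a CrossAttender pooler
    )

    embeds = torch.randn(2, 10, 512)   # (batch, seq, dim_context)
    pooled = pool(embeds)              # (2, 3, 77)
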
@@ -2851,8 +2900,9 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         num_cls_tokens = 1,
         attn_pool = False,
-        num_attn_pool_queries = 1,
-        dim_attn_pool_query = None,
+        num_pooled_tokens = 1,
+        attn_pool_depth = 1,
+        dim_pooled_tokens = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
@@ -2949,10 +2999,7 @@ class TransformerWrapper(Module):
         self.attn_pool = None
 
         if attn_pool:
-            self.attn_pool = Attention(dim = default(dim_attn_pool_query, dim), dim_context = dim)
-
-            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
-            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth)
 
         # whether to average pool the embed (`global average pool`)
 
@@ -3250,7 +3297,6 @@ class TransformerWrapper(Module):
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
-
         # cls token(s)
 
         if exists(self.cls_token):
@@ -3263,13 +3309,15 @@ class TransformerWrapper(Module):
 
         # attention pool
 
-        if exists(self.attn_pool) and return_intermediates:
-            queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
+        is_encoder = not self.attn_layers.causal
+        return_pooled_tokens = exists(self.attn_pool) and is_encoder
 
-            attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
+        if (
+            exists(self.attn_pool) and
+            (return_intermediates or is_encoder) # in a new paper, they use attention pooling on decoder - so we'll default to returning pooled tokens if encoder, but for decoder, they must set `return_intermediates`
+        ):
 
-            if attn_pooled_tokens.shape[1] == 1:
-                attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+            attn_pooled_tokens = self.attn_pool(x, mask = mask)
             intermediates.attn_pooled_tokens = attn_pooled_tokens
 
 
@@ -3318,6 +3366,9 @@ class TransformerWrapper(Module):
             out = (x, intermediates)
         elif return_embeddings:
             out = x
+        elif return_pooled_tokens:
+            intermediates.logits = logits
+            out = attn_pooled_tokens
         else:
             out = logits
 