x-transformers 2.4.14-py3-none-any.whl → 2.5.1-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
x_transformers/__init__.py

@@ -4,12 +4,13 @@ from x_transformers.x_transformers import (
     Decoder,
     PrefixDecoder,
     CrossAttender,
+    AttentionPool,
     Attention,
     FeedForward,
     RMSNorm,
     AdaptiveRMSNorm,
     TransformerWrapper,
-    ViTransformerWrapper
+    ViTransformerWrapper,
 )
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
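
The only functional change above is that the new AttentionPool module (defined later in this diff) is re-exported at the package root, alongside a trailing-comma cleanup of the import list. A quick sketch of the new import path:

# new export in this release: AttentionPool is importable from the package root
from x_transformers import AttentionPool
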

x_transformers/x_transformers.py

@@ -51,6 +51,7 @@ class LayerIntermediates:
     attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    logits: Tensor | None = None
     cache_length: int = 0
 
 LinearNoBias = partial(nn.Linear, bias = False)
@@ -2037,6 +2038,9 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])
 
+        self.attn_heads = heads
+        self.attn_dim_head = dim_head
+
         # routing related
         # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
         # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
@@ -2749,6 +2753,49 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)
 
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        num_pooled_tokens = 1,
+        dim_context = None,
+        add_residual = False,
+        depth = 1,
+        heads = 8,
+        dim_head = 64,
+        squeeze_output = None,
+        attn_kwargs: dict = dict()
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        squeeze_output = default(squeeze_output, False)
+        assert not (squeeze_output and num_pooled_tokens > 1)
+
+        self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
+
+        if depth > 1:
+            assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
+            attn_kwargs = {f'attn_{k}': v for k, v in attn_kwargs.items()}
+
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, heads = heads, attn_dim_head = dim_head, **attn_kwargs)
+        else:
+            self.pooler = Attention(dim = dim, dim_context = dim_context, heads = heads, dim_head = dim_head, **attn_kwargs)
+
+        self.add_residual = add_residual
+
+    def forward(self, context, mask = None):
+        batch = context.shape[0]
+
+        queries = repeat(self.queries, 'n d -> b n d', b = batch)
+
+        pooled = self.pooler(queries, context, context_mask = mask)
+
+        if self.add_residual:
+            pooled = pooled + queries
+
+        return pooled
+
 class ViTransformerWrapper(Module):
     def __init__(
         self,
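
A minimal usage sketch of the new AttentionPool module defined above. The dimensions, batch size, and sequence length are illustrative; with the default depth of 1 the pooler is a single cross-attending Attention block, while depth > 1 switches to a small CrossAttender stack.

import torch
from x_transformers import AttentionPool

pool = AttentionPool(
    dim = 512,              # dimension of the learned queries / pooled output
    dim_context = 512,      # dimension of the tokens being pooled
    num_pooled_tokens = 1,
    depth = 1,
    heads = 8,
    dim_head = 64
)

tokens = torch.randn(2, 1024, 512)   # (batch, seq, dim) embeddings to pool
mask = torch.ones(2, 1024).bool()    # True for valid positions

pooled = pool(tokens, mask = mask)   # -> (2, 1, 512): one learned query attends over the sequence
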
@@ -2860,8 +2907,9 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         num_cls_tokens = 1,
         attn_pool = False,
-        num_attn_pool_queries = 1,
-        dim_attn_pool_query = None,
+        num_pooled_tokens = 1,
+        attn_pool_depth = 1,
+        dim_pooled_tokens = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
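
With the constructor change in the next hunk delegating to AttentionPool, the pooling options are configured through the renamed keyword arguments above; a sketch of a typical encoder configuration (all hyperparameter values are illustrative). Note that the pooler's heads and dim_head are now mirrored from the wrapped attn_layers rather than configured separately.

from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    ),
    attn_pool = True,           # enable the attention-pooling head
    num_pooled_tokens = 1,      # was `num_attn_pool_queries` in 2.4.x
    attn_pool_depth = 1,        # new: > 1 pools with a CrossAttender stack
    dim_pooled_tokens = None    # was `dim_attn_pool_query`; defaults to the model dim
)
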
@@ -2958,10 +3006,7 @@ class TransformerWrapper(Module):
         self.attn_pool = None
 
         if attn_pool:
-            self.attn_pool = Attention(dim = default(dim_attn_pool_query, dim), dim_context = dim)
-
-            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
-            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth, heads = self.attn_layers.attn_heads, dim_head = self.attn_layers.attn_dim_head)
 
         # whether to average pool the embed (`global average pool`)
 
@@ -3259,7 +3304,6 @@
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
-
         # cls token(s)
 
         if exists(self.cls_token):
@@ -3272,13 +3316,15 @@
 
         # attention pool
 
-        if exists(self.attn_pool) and return_intermediates:
-            queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
+        is_encoder = not self.attn_layers.causal
+        return_pooled_tokens = exists(self.attn_pool) and is_encoder
 
-            attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
+        if (
+            exists(self.attn_pool) and
+            (return_intermediates or is_encoder) # in a new paper, they use attention pooling on decoder - so we'll default to returning pooled tokens if encoder, but for decoder, they must set `return_intermediates`
+        ):
 
-            if attn_pooled_tokens.shape[1] == 1:
-                attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+            attn_pooled_tokens = self.attn_pool(x, mask = mask)
 
             intermediates.attn_pooled_tokens = attn_pooled_tokens
 
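
Per the new in-line comment, a causal (decoder) model only runs the pooler when intermediates are explicitly requested, while a non-causal (encoder) model pools by default. A sketch of the decoder path, assuming the usual two-tuple return when return_intermediates = True (hyperparameters and shapes illustrative):

import torch
from x_transformers import TransformerWrapper, Decoder

decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8),
    attn_pool = True
)

seq = torch.randint(0, 20000, (2, 1024))

# causal + attn_pool: pooled tokens are only computed when intermediates are returned
out, intermediates = decoder(seq, return_intermediates = True)
pooled = intermediates.attn_pooled_tokens   # (2, 1, 512) with the default single pooled token
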
@@ -3327,6 +3373,9 @@
             out = (x, intermediates)
         elif return_embeddings:
             out = x
+        elif return_pooled_tokens:
+            intermediates.logits = logits
+            out = attn_pooled_tokens
         else:
             out = logits
 
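
Combined with the new logits field on LayerIntermediates, this appears to change what an encoder with attn_pool = True returns by default: the pooled tokens become the primary output, with the token logits stashed on the intermediates along that code path. A hedged sketch (hyperparameters and shapes illustrative):

import torch
from x_transformers import TransformerWrapper, Encoder

encoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8),
    attn_pool = True
)

seq = torch.randint(0, 20000, (2, 1024))
mask = torch.ones(2, 1024).bool()

# non-causal + attn_pool: the plain forward call now yields the pooled tokens,
# with logits recorded on the intermediates (new `logits` slot) for this branch
pooled = encoder(seq, mask = mask)   # (2, 1, 512) with the default single pooled token
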

x_transformers-2.4.14.dist-info/METADATA → x_transformers-2.5.1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.14
+Version: 2.5.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

x_transformers-2.4.14.dist-info/RECORD → x_transformers-2.5.1.dist-info/RECORD

@@ -1,4 +1,4 @@
-x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
+x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
 x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
 x_transformers/autoregressive_wrapper.py,sha256=y798kS9_VvPOY_5Ilits_64aXNqYvGuilsky1y07ryE,17834
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=bNp6hWuuqn7x5yKFfYocvu3X1YCjpfwrWMh-kAanS48,117906
+x_transformers/x_transformers.py,sha256=l6RSwfvWxeSQWXiZVQzomyn8CyMvfmXr8oBVfb9zbBM,119679
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.4.14.dist-info/METADATA,sha256=KScRZIcmRXCv8NnhzQ3Uzo9uHE2oI51chzj78Wh_OVo,90224
-x_transformers-2.4.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.4.14.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.4.14.dist-info/RECORD,,
+x_transformers-2.5.1.dist-info/METADATA,sha256=H2y1hNleReXPTU2G1VX1vowYLQFFmaotou1QmXqJM-Q,90223
+x_transformers-2.5.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.5.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.5.1.dist-info/RECORD,,