x-transformers 2.4.14__py3-none-any.whl → 2.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
@@ -4,12 +4,13 @@ from x_transformers.x_transformers import (
     Decoder,
     PrefixDecoder,
     CrossAttender,
+    AttentionPool,
     Attention,
     FeedForward,
     RMSNorm,
     AdaptiveRMSNorm,
     TransformerWrapper,
-    ViTransformerWrapper
+    ViTransformerWrapper,
 )

 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
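AttentionPool joins the package re-exports here (this hunk is in x_transformers/__init__.py, consistent with that file's RECORD entry growing from 985 to 1005 bytes below), so in 2.5.0 it should be importable from the top-level package; a minimal sketch, assuming the re-export lands as shown:

# assumes the re-export above is in x_transformers/__init__.py
from x_transformers import AttentionPool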
@@ -51,6 +51,7 @@ class LayerIntermediates:
     attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    logits: Tensor | None = None
     cache_length: int = 0

 LinearNoBias = partial(nn.Linear, bias = False)
@@ -2749,6 +2750,45 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)

+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        num_pooled_tokens = 1,
+        dim_context = None,
+        add_residual = False,
+        depth = 1,
+        squeeze_output = None,
+        attn_kwargs: dict = dict()
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        squeeze_output = default(squeeze_output, False)
+        assert not (squeeze_output and num_pooled_tokens > 1)
+
+        self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
+
+        if depth > 1:
+            assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, **attn_kwargs)
+        else:
+            self.pooler = Attention(dim = dim, dim_context = dim_context, **attn_kwargs)
+
+        self.add_residual = add_residual
+
+    def forward(self, context, mask = None):
+        batch = context.shape[0]
+
+        queries = repeat(self.queries, 'n d -> b n d', b = batch)
+
+        pooled = self.pooler(queries, context, context_mask = mask)
+
+        if self.add_residual:
+            pooled = pooled + queries
+
+        return pooled
+
 class ViTransformerWrapper(Module):
     def __init__(
         self,
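A minimal usage sketch of the new AttentionPool, with hypothetical dimensions; the shapes follow from the forward above, where (num_pooled_tokens, dim) learned queries are repeated across the batch and cross-attend to the context:

import torch
from x_transformers.x_transformers import AttentionPool

pool = AttentionPool(dim = 512, num_pooled_tokens = 4, depth = 2)   # depth > 1 builds a CrossAttender pooler

context = torch.randn(2, 1024, 512)   # (batch, seq, dim)
mask = torch.ones(2, 1024).bool()     # optional padding mask

pooled = pool(context, mask = mask)   # expected shape (2, 4, 512)

With depth = 1 the pooler is a single Attention layer instead, and add_residual = True adds the queries back onto the pooled output (disallowed for depth > 1, per the assert).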
@@ -2860,8 +2900,9 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         num_cls_tokens = 1,
         attn_pool = False,
-        num_attn_pool_queries = 1,
-        dim_attn_pool_query = None,
+        num_pooled_tokens = 1,
+        attn_pool_depth = 1,
+        dim_pooled_tokens = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
@@ -2958,10 +2999,7 @@ class TransformerWrapper(Module):
         self.attn_pool = None

         if attn_pool:
-            self.attn_pool = Attention(dim = default(dim_attn_pool_query, dim), dim_context = dim)
-
-            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
-            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth)

         # whether to average pool the embed (`global average pool`)
 
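Together with the renamed constructor arguments above (num_attn_pool_queries -> num_pooled_tokens, dim_attn_pool_query -> dim_pooled_tokens, plus the new attn_pool_depth), attention pooling on the wrapper now routes through AttentionPool. A configuration sketch with hypothetical hyperparameters:

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8),
    attn_pool = True,
    num_pooled_tokens = 4,   # was `num_attn_pool_queries` in 2.4.x
    attn_pool_depth = 2      # new in 2.5.0: depth > 1 uses a small CrossAttender as the pooler
)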
@@ -3259,7 +3297,6 @@ class TransformerWrapper(Module):
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)

-
         # cls token(s)

         if exists(self.cls_token):
@@ -3272,13 +3309,15 @@ class TransformerWrapper(Module):

         # attention pool

-        if exists(self.attn_pool) and return_intermediates:
-            queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
+        is_encoder = not self.attn_layers.causal
+        return_pooled_tokens = exists(self.attn_pool) and is_encoder

-            attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
+        if (
+            exists(self.attn_pool) and
+            (return_intermediates or is_encoder) # in a new paper, they use attention pooling on decoder - so we'll default to returning pooled tokens if encoder, but for decoder, they must set `return_intermediates`
+        ):

-            if attn_pooled_tokens.shape[1] == 1:
-                attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+            attn_pooled_tokens = self.attn_pool(x, mask = mask)

             intermediates.attn_pooled_tokens = attn_pooled_tokens
 
@@ -3327,6 +3366,9 @@ class TransformerWrapper(Module):
             out = (x, intermediates)
         elif return_embeddings:
             out = x
+        elif return_pooled_tokens:
+            intermediates.logits = logits
+            out = attn_pooled_tokens
         else:
             out = logits
 
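With the new return path, a non-causal (encoder) wrapper that has attn_pool enabled returns the pooled tokens as its default output and stashes the logits on the new LayerIntermediates.logits field; a causal decoder must still pass return_intermediates = True and read intermediates.attn_pooled_tokens. A sketch continuing the hypothetical encoder configured above:

tokens = torch.randint(0, 20000, (2, 1024))
mask = torch.ones_like(tokens).bool()

# default output is now the pooled tokens, expected shape (2, 4, 512), not the logits
pooled = model(tokens, mask = mask)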
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.4.14
+Version: 2.5.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,4 +1,4 @@
-x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
+x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
 x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
 x_transformers/autoregressive_wrapper.py,sha256=y798kS9_VvPOY_5Ilits_64aXNqYvGuilsky1y07ryE,17834
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=bNp6hWuuqn7x5yKFfYocvu3X1YCjpfwrWMh-kAanS48,117906
+x_transformers/x_transformers.py,sha256=NOTTbqDk5qZEY2MPpdIwJv4BvGGhXt_nIffrgQDXTf4,119346
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.4.14.dist-info/METADATA,sha256=KScRZIcmRXCv8NnhzQ3Uzo9uHE2oI51chzj78Wh_OVo,90224
-x_transformers-2.4.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.4.14.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.4.14.dist-info/RECORD,,
+x_transformers-2.5.0.dist-info/METADATA,sha256=ZcxK61msWcsm8vEPG-FirnPTnwG6HNKtvy2ZoLluJHM,90223
+x_transformers-2.5.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.5.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.5.0.dist-info/RECORD,,