x-transformers 2.3.26__py3-none-any.whl → 2.4.0__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
@@ -47,6 +47,8 @@ class LayerIntermediates:
     layer_hiddens: list[Tensor] | None = None
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
+    last_layer_hiddens: Tensor | None = None
+    attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
     cache_length: int = 0
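
The two new fields are filled in by TransformerWrapper.forward (see the later hunks): last_layer_hiddens holds the final layer's hidden states, and attn_pooled_tokens holds the output of the new attention-pooling head. A minimal sketch of inspecting them, assuming the dataclass keeps all-default fields as the hunk above suggests; the construction below is purely illustrative:

    from x_transformers.x_transformers import LayerIntermediates

    # an empty container can be built directly; a real forward pass with
    # return_intermediates = True is what actually populates these fields
    inter = LayerIntermediates()

    assert inter.last_layer_hiddens is None   # final-layer hiddens, (batch, seq, dim) once set
    assert inter.attn_pooled_tokens is None   # pooled summary tokens, only set when attn_pool = True
    assert inter.cache_length == 0
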
@@ -1926,7 +1928,7 @@ class Attention(Module):
 
         out = maybe(self.sublayer_dropout)(out)
 
-        if exists(mask):
+        if exists(mask) and not exists(cache):
            out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
         if not return_intermediates:
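
Together with the relaxed assertion in the next AttentionLayers hunk, this lets a key-padding mask accompany a decoding cache: the output zero-masking is skipped on cached steps, where out covers only the newly decoded positions while the mask spans the whole sequence. A hedged sketch of what that permits, assuming TransformerWrapper.forward accepts mask, cache and return_intermediates as in prior releases, and assuming the mask at a cached step is meant to span all positions seen so far; shapes and hyperparameters below are illustrative:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 64,
        attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
    )

    prompt = torch.randint(0, 256, (2, 10))
    mask = torch.ones(2, 10, dtype = torch.bool)
    mask[1, :5] = False                      # second sequence is left-padded

    # prime the kv cache over the (padded) prompt
    logits, cache = model(prompt, mask = mask, return_intermediates = True)

    # one cached decoding step; previously mask + cache tripped the assertion
    next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
    mask = torch.cat((mask, torch.ones(2, 1, dtype = torch.bool)), dim = -1)
    logits, cache = model(next_token, mask = mask, cache = cache, return_intermediates = True)
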
@@ -2484,7 +2486,7 @@ class AttentionLayers(Module):
         attn_cache = []
 
         if exists(cache):
-            assert self.causal and not any([*map(exists, (mask, attn_mask))])
+            assert self.causal and not exists(attn_mask)
 
             prev_cache_length = cache.cache_length
 
@@ -2848,6 +2850,9 @@ class TransformerWrapper(Module):
         average_pool_embed = False,
         use_cls_token = False,
         num_cls_tokens = 1,
+        attn_pool = False,
+        num_attn_pool_queries = 1,
+        dim_attn_pool_query = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
@@ -2927,6 +2932,10 @@ class TransformerWrapper(Module):
 
         self.train_max_recycle_steps = train_max_recycle_steps
 
+        # either cls token or attn pool, but not both
+
+        assert not (use_cls_token and attn_pool)
+
         # classic cls token from the bert days
 
         self.cls_token = None
@@ -2935,6 +2944,16 @@ class TransformerWrapper(Module):
             self.cls_token = nn.Parameter(torch.zeros(num_cls_tokens, dim))
             nn.init.normal_(self.cls_token, std = 0.02)
 
+        # attn pool
+
+        self.attn_pool = None
+
+        if attn_pool:
+            self.attn_pool = Attention(dim = default(dim_attn_pool_query, dim), dim_context = dim)
+
+            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
+            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+
         # whether to average pool the embed (`global average pool`)
 
         self.average_pool_embed = average_pool_embed
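
The pooling head added above is a cross-attention read-out: a small set of learned query vectors (initialized like the cls token, std 0.02) attend over the final hidden states, in the spirit of the PMA block from Set Transformer or a ViT-style attention-pool head. Below is a library-agnostic sketch of the idea in plain PyTorch, not the library's implementation; every name and shape is illustrative:

    import torch
    from torch import nn

    class NaiveAttentionPool(nn.Module):
        # one learned query per pooled token; softmax(q k^T / sqrt(d)) v over the sequence
        def __init__(self, dim, num_queries = 1):
            super().__init__()
            self.queries = nn.Parameter(torch.randn(num_queries, dim) * 0.02)
            self.scale = dim ** -0.5
            self.to_kv = nn.Linear(dim, dim * 2, bias = False)

        def forward(self, hiddens, mask = None):                 # hiddens: (batch, seq, dim)
            batch = hiddens.shape[0]
            q = self.queries.unsqueeze(0).expand(batch, -1, -1)  # (batch, num_queries, dim)
            k, v = self.to_kv(hiddens).chunk(2, dim = -1)
            sim = torch.einsum('bmd,bnd->bmn', q, k) * self.scale
            if mask is not None:                                 # mask: (batch, seq) bool, True = keep
                sim = sim.masked_fill(~mask.unsqueeze(1), float('-inf'))
            attn = sim.softmax(dim = -1)
            return torch.einsum('bmn,bnd->bmd', attn, v)         # (batch, num_queries, dim)

    pooled = NaiveAttentionPool(dim = 64)(torch.randn(2, 16, 64))   # (2, 1, 64)

In the library itself the read-out reuses the full Attention module with dim_context = dim, so the query width can differ from the model width via dim_attn_pool_query.
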
@@ -3222,14 +3241,37 @@ class TransformerWrapper(Module):
 
             x = x[:, :mem_seq]
 
+        # store last layer hiddens, for access in case of cls token or attention pooling
+
+        intermediates.last_layer_hiddens = x
+
         # global average pool
 
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
+
+        # cls token(s)
+
         if exists(self.cls_token):
-            x, _ = unpack(x, cls_packed_shape, 'b * d')
-            x = x.squeeze(1) # Remove sequence dimension if num_cls_tokens=1 to keep previous behavior
+            x, last_layer_hiddens = unpack(x, cls_packed_shape, 'b * d')
+
+            intermediates.last_layer_hiddens = last_layer_hiddens
+
+            if x.shape[1] == 1:
+                x = rearrange(x, 'b 1 d -> b d') # Remove sequence dimension if num_cls_tokens=1 to keep previous behavior
+
+        # attention pool
+
+        if exists(self.attn_pool):
+            queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
+
+            attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
+
+            if attn_pooled_tokens.shape[1] == 1:
+                attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+
+            intermediates.attn_pooled_tokens = attn_pooled_tokens
 
         # handle expansion to mixture if needed (for mixture of softmax)
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.26
+Version: 2.4.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=7phSZvP1_SDRIkVMwVR4cz1dFU2UlR2Wf1HJHEQlcQg,116222
+x_transformers/x_transformers.py,sha256=IelVhLUuDmRnv6zXlQNvwUluW2RqVQQE2vYKCqctJyY,117583
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.26.dist-info/METADATA,sha256=Qc0zIph59FLOC0GPGIe41M6P1SD_lljzKg5ytoMyPAI,89897
-x_transformers-2.3.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.3.26.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.3.26.dist-info/RECORD,,
+x_transformers-2.4.0.dist-info/METADATA,sha256=RyKkjmTnjbGUHA4EL-znJCPR17VF6i7ebvvgMKpTXVY,89896
+x_transformers-2.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.4.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.4.0.dist-info/RECORD,,
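
Putting the x_transformers.py hunks together, enabling the new flag and reading the pooled summary back from the intermediates might look like the following. This is a sketch inferred from the constructor and forward hunks above, not documented 2.4.0 usage; hyperparameters and shapes are placeholders:

    import torch
    from x_transformers import TransformerWrapper, Encoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 512,
        attn_pool = True,                # new in 2.4.0; mutually exclusive with use_cls_token
        num_attn_pool_queries = 1,       # one learned query -> one summary vector
        attn_layers = Encoder(dim = 128, depth = 4, heads = 8)
    )

    tokens = torch.randint(0, 20000, (2, 512))
    mask = torch.ones(2, 512, dtype = torch.bool)

    logits, intermediates = model(tokens, mask = mask, return_intermediates = True)

    # per the forward hunk: the query dimension is squeezed when num_attn_pool_queries == 1
    pooled = intermediates.attn_pooled_tokens     # expected (2, 128)
    hiddens = intermediates.last_layer_hiddens    # expected (2, 512, 128)
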