x-transformers 2.3.27__py3-none-any.whl → 2.4.1__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +44 -2
- {x_transformers-2.3.27.dist-info → x_transformers-2.4.1.dist-info}/METADATA +1 -1
- {x_transformers-2.3.27.dist-info → x_transformers-2.4.1.dist-info}/RECORD +5 -5
- {x_transformers-2.3.27.dist-info → x_transformers-2.4.1.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.27.dist-info → x_transformers-2.4.1.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -47,6 +47,8 @@ class LayerIntermediates:
     layer_hiddens: list[Tensor] | None = None
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
+    last_layer_hiddens: Tensor | None = None
+    attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
     cache_length: int = 0
@@ -2848,6 +2850,9 @@ class TransformerWrapper(Module):
         average_pool_embed = False,
         use_cls_token = False,
         num_cls_tokens = 1,
+        attn_pool = False,
+        num_attn_pool_queries = 1,
+        dim_attn_pool_query = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
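The three added keyword arguments expose a learned-query attention pool alongside the existing cls-token flags. A minimal construction sketch, assuming the usual TransformerWrapper / Encoder entry points; all values here are illustrative and dim_attn_pool_query is left at its default:

from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8),
    attn_pool = True,             # enable the learned-query attention pool added in this release
    num_attn_pool_queries = 4,    # pool into four tokens instead of the default single query
)

Note that, per the assertion added in the next hunk, use_cls_token and attn_pool cannot be enabled together.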
@@ -2927,6 +2932,10 @@ class TransformerWrapper(Module):
 
         self.train_max_recycle_steps = train_max_recycle_steps
 
+        # either cls token or attn pool, but not both
+
+        assert not (use_cls_token and attn_pool)
+
         # classic cls token from the bert days
 
         self.cls_token = None
@@ -2935,6 +2944,16 @@ class TransformerWrapper(Module):
             self.cls_token = nn.Parameter(torch.zeros(num_cls_tokens, dim))
             nn.init.normal_(self.cls_token, std = 0.02)
 
+        # attn pool
+
+        self.attn_pool = None
+
+        if attn_pool:
+            self.attn_pool = Attention(dim = default(dim_attn_pool_query, dim), dim_context = dim)
+
+            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
+            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+
         # whether to average pool the embed (`global average pool`)
 
         self.average_pool_embed = average_pool_embed
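The construction above pairs a stack of zero-initialized query vectors with a cross-attention block whose context dimension is the model dimension. A self-contained sketch of that pattern in plain PyTorch, using nn.MultiheadAttention as a stand-in for the library's own Attention class, so it is illustrative rather than a copy of the implementation:

import torch
from torch import nn

class SimpleAttnPool(nn.Module):
    # learned queries cross-attend over a variable-length sequence of hidden states
    def __init__(self, dim, num_queries = 1, heads = 8):
        super().__init__()
        self.queries = nn.Parameter(torch.zeros(num_queries, dim))
        nn.init.normal_(self.queries, std = 0.02)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, context, mask = None):
        # context: (batch, seq, dim); mask: (batch, seq) bool with True marking valid positions
        queries = self.queries.unsqueeze(0).expand(context.shape[0], -1, -1)
        # MultiheadAttention's key_padding_mask marks positions to ignore, hence the inversion
        key_padding_mask = ~mask if mask is not None else None
        pooled, _ = self.attn(queries, context, context, key_padding_mask = key_padding_mask)
        return pooled  # (batch, num_queries, dim)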
@@ -3222,14 +3241,37 @@ class TransformerWrapper(Module):
 
             x = x[:, :mem_seq]
 
+        # store last layer hiddens, for access in case of cls token or attention pooling
+
+        intermediates.last_layer_hiddens = x
+
         # global average pool
 
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
+
+        # cls token(s)
+
         if exists(self.cls_token):
-            x,
-
+            x, last_layer_hiddens = unpack(x, cls_packed_shape, 'b * d')
+
+            intermediates.last_layer_hiddens = last_layer_hiddens
+
+            if x.shape[1] == 1:
+                x = rearrange(x, 'b 1 d -> b d') # Remove sequence dimension if num_cls_tokens=1 to keep previous behavior
+
+        # attention pool
+
+        if exists(self.attn_pool) and return_intermediates:
+            queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
+
+            attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
+
+            if attn_pooled_tokens.shape[1] == 1:
+                attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+
+            intermediates.attn_pooled_tokens = attn_pooled_tokens
 
         # handle expansion to mixture if needed (for mixture of softmax)
 
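A hedged end-to-end sketch of how the new intermediates could be read back. Per the hunk above, attn_pooled_tokens is only populated when intermediates are requested, and a single pool query is squeezed down to (batch, dim); the shapes in the comments follow from the added code rather than documented guarantees:

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Encoder(dim = 512, depth = 2),
    attn_pool = True,
)

tokens = torch.randint(0, 256, (4, 128))
mask = torch.ones(4, 128).bool()

logits, intermediates = model(tokens, mask = mask, return_intermediates = True)

hiddens = intermediates.last_layer_hiddens   # final-layer token states, expected (4, 128, 512)
pooled  = intermediates.attn_pooled_tokens   # expected (4, 512) with the default single query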
{x_transformers-2.3.27.dist-info → x_transformers-2.4.1.dist-info}/RECORD
CHANGED
@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
+x_transformers-2.4.1.dist-info/METADATA,sha256=riX6ywwN305W2El7pkwY3GSCXrn5CNHFROTlcXN3yvo,89896
+x_transformers-2.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.4.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.4.1.dist-info/RECORD,,
{x_transformers-2.3.27.dist-info → x_transformers-2.4.1.dist-info}/WHEEL
File without changes
{x_transformers-2.3.27.dist-info → x_transformers-2.4.1.dist-info}/licenses/LICENSE
File without changes