x-transformers 2.3.26__py3-none-any.whl → 2.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +46 -4
- {x_transformers-2.3.26.dist-info → x_transformers-2.4.0.dist-info}/METADATA +1 -1
- {x_transformers-2.3.26.dist-info → x_transformers-2.4.0.dist-info}/RECORD +5 -5
- {x_transformers-2.3.26.dist-info → x_transformers-2.4.0.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.26.dist-info → x_transformers-2.4.0.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -47,6 +47,8 @@ class LayerIntermediates:
     layer_hiddens: list[Tensor] | None = None
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
+    last_layer_hiddens: Tensor | None = None
+    attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
     cache_length: int = 0
@@ -1926,7 +1928,7 @@ class Attention(Module):
 
         out = maybe(self.sublayer_dropout)(out)
 
-        if exists(mask):
+        if exists(mask) and not exists(cache):
             out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
         if not return_intermediates:
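For readers unfamiliar with einx, the masking line above zeroes out the feature vectors at padded positions; the 2.4.0 change simply skips that zeroing whenever a key/value cache is passed. A minimal plain-PyTorch sketch of the same masking operation (illustrative only, not the library's code):

```python
# Illustrative equivalent of: einx.where('b n, b n d, -> b n d', mask, out, 0.)
# mask: (batch, seq) boolean, out: (batch, seq, dim) -> zero features where mask is False
import torch

batch, seq, dim = 2, 5, 8
out = torch.randn(batch, seq, dim)
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True,  True ]])

out = torch.where(mask.unsqueeze(-1), out, torch.zeros_like(out))
```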
@@ -2484,7 +2486,7 @@ class AttentionLayers(Module):
         attn_cache = []
 
         if exists(cache):
-            assert self.causal and not
+            assert self.causal and not exists(attn_mask)
 
             prev_cache_length = cache.cache_length
 
@@ -2848,6 +2850,9 @@ class TransformerWrapper(Module):
         average_pool_embed = False,
         use_cls_token = False,
         num_cls_tokens = 1,
+        attn_pool = False,
+        num_attn_pool_queries = 1,
+        dim_attn_pool_query = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
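A hedged usage sketch of the new flags. The `attn_pool`, `num_attn_pool_queries` and `dim_attn_pool_query` kwargs come from this diff; the surrounding `TransformerWrapper` / `Encoder` arguments follow the standard x-transformers API and are assumed here rather than taken from the diff:

```python
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8),
    attn_pool = True,               # enable learned-query attention pooling
    num_attn_pool_queries = 1,      # one pooled embedding per sequence
    # dim_attn_pool_query = 512,    # optional: defaults to the model dimension
)
```

Note the mutual exclusion asserted in the next hunk: `use_cls_token` and `attn_pool` cannot both be enabled.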
@@ -2927,6 +2932,10 @@ class TransformerWrapper(Module):
 
         self.train_max_recycle_steps = train_max_recycle_steps
 
+        # either cls token or attn pool, but not both
+
+        assert not (use_cls_token and attn_pool)
+
         # classic cls token from the bert days
 
         self.cls_token = None
@@ -2935,6 +2944,16 @@ class TransformerWrapper(Module):
             self.cls_token = nn.Parameter(torch.zeros(num_cls_tokens, dim))
             nn.init.normal_(self.cls_token, std = 0.02)
 
+        # attn pool
+
+        self.attn_pool = None
+
+        if attn_pool:
+            self.attn_pool = Attention(dim = default(dim_attn_pool_query, dim), dim_context = dim)
+
+            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
+            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+
         # whether to average pool the embed (`global average pool`)
 
         self.average_pool_embed = average_pool_embed
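Conceptually, the block above adds a small set of learned query vectors plus a cross-attention module that pools the encoder's hidden states into summary token(s). A standalone sketch of that technique in plain PyTorch (an illustration of the idea, not the library's Attention class; all names here are hypothetical):

```python
# Learned-query attention pooling: learned queries cross-attend over the sequence
# of hidden states to produce a fixed number of pooled summary tokens.
import torch
from torch import nn

class AttnPool(nn.Module):
    def __init__(self, dim, num_queries = 1, heads = 8):
        super().__init__()
        self.queries = nn.Parameter(torch.zeros(num_queries, dim))
        nn.init.normal_(self.queries, std = 0.02)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, hiddens, key_padding_mask = None):
        # hiddens: (batch, seq, dim); key_padding_mask: (batch, seq), True = ignore
        queries = self.queries.unsqueeze(0).expand(hiddens.shape[0], -1, -1)
        pooled, _ = self.attn(queries, hiddens, hiddens, key_padding_mask = key_padding_mask)
        return pooled  # (batch, num_queries, dim)

pool = AttnPool(dim = 512)
pooled = pool(torch.randn(2, 128, 512))  # -> (2, 1, 512)
```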
@@ -3222,14 +3241,37 @@ class TransformerWrapper(Module):
 
                 x = x[:, :mem_seq]
 
+        # store last layer hiddens, for access in case of cls token or attention pooling
+
+        intermediates.last_layer_hiddens = x
+
         # global average pool
 
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
+
+        # cls token(s)
+
         if exists(self.cls_token):
-            x,
-
+            x, last_layer_hiddens = unpack(x, cls_packed_shape, 'b * d')
+
+            intermediates.last_layer_hiddens = last_layer_hiddens
+
+            if x.shape[1] == 1:
+                x = rearrange(x, 'b 1 d -> b d') # Remove sequence dimension if num_cls_tokens=1 to keep previous behavior
+
+        # attention pool
+
+        if exists(self.attn_pool):
+            queries = repeat(self.attn_pool_queries, 'n d -> b n d', b = x.shape[0])
+
+            attn_pooled_tokens = self.attn_pool(queries, context = x, context_mask = mask)
+
+            if attn_pooled_tokens.shape[1] == 1:
+                attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+
+            intermediates.attn_pooled_tokens = attn_pooled_tokens
 
         # handle expansion to mixture if needed (for mixture of softmax)
 
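Putting the pieces together, a hedged end-to-end sketch of reading the new intermediates after a forward pass. The field names come from this diff; the `return_intermediates` flag and the rest of the call signature are assumed from the usual x-transformers API rather than confirmed by the diff:

```python
# Sketch only: access the new LayerIntermediates fields after a forward pass.
import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8),
    attn_pool = True,
)

tokens = torch.randint(0, 20000, (2, 256))
logits, intermediates = model(tokens, return_intermediates = True)  # assumed flag

hiddens = intermediates.last_layer_hiddens   # final hidden states, (2, 256, 512)
pooled  = intermediates.attn_pooled_tokens   # (2, 512) with a single pooling query
```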
{x_transformers-2.3.26.dist-info → x_transformers-2.4.0.dist-info}/RECORD
CHANGED
@@ -8,10 +8,10 @@ x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaY
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=IelVhLUuDmRnv6zXlQNvwUluW2RqVQQE2vYKCqctJyY,117583
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
+x_transformers-2.4.0.dist-info/METADATA,sha256=RyKkjmTnjbGUHA4EL-znJCPR17VF6i7ebvvgMKpTXVY,89896
+x_transformers-2.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.4.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.4.0.dist-info/RECORD,,
{x_transformers-2.3.26.dist-info → x_transformers-2.4.0.dist-info}/WHEEL
File without changes
{x_transformers-2.3.26.dist-info → x_transformers-2.4.0.dist-info}/licenses/LICENSE
File without changes