x-transformers 2.4.14__py3-none-any.whl → 2.5.0__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- x_transformers/__init__.py +2 -1
- x_transformers/x_transformers.py +54 -12
- {x_transformers-2.4.14.dist-info → x_transformers-2.5.0.dist-info}/METADATA +1 -1
- {x_transformers-2.4.14.dist-info → x_transformers-2.5.0.dist-info}/RECORD +6 -6
- {x_transformers-2.4.14.dist-info → x_transformers-2.5.0.dist-info}/WHEEL +0 -0
- {x_transformers-2.4.14.dist-info → x_transformers-2.5.0.dist-info}/licenses/LICENSE +0 -0
x_transformers/__init__.py
CHANGED
@@ -4,12 +4,13 @@ from x_transformers.x_transformers import (
     Decoder,
     PrefixDecoder,
     CrossAttender,
+    AttentionPool,
     Attention,
     FeedForward,
     RMSNorm,
     AdaptiveRMSNorm,
     TransformerWrapper,
-    ViTransformerWrapper
+    ViTransformerWrapper,
 )
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
x_transformers/x_transformers.py
CHANGED
@@ -51,6 +51,7 @@ class LayerIntermediates:
     attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    logits: Tensor | None = None
     cache_length: int = 0
 
 LinearNoBias = partial(nn.Linear, bias = False)
@@ -2749,6 +2750,45 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)
 
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        num_pooled_tokens = 1,
+        dim_context = None,
+        add_residual = False,
+        depth = 1,
+        squeeze_output = None,
+        attn_kwargs: dict = dict()
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        squeeze_output = default(squeeze_output, False)
+        assert not (squeeze_output and num_pooled_tokens > 1)
+
+        self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
+
+        if depth > 1:
+            assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, **attn_kwargs)
+        else:
+            self.pooler = Attention(dim = dim, dim_context = dim_context, **attn_kwargs)
+
+        self.add_residual = add_residual
+
+    def forward(self, context, mask = None):
+        batch = context.shape[0]
+
+        queries = repeat(self.queries, 'n d -> b n d', b = batch)
+
+        pooled = self.pooler(queries, context, context_mask = mask)
+
+        if self.add_residual:
+            pooled = pooled + queries
+
+        return pooled
+
 class ViTransformerWrapper(Module):
     def __init__(
         self,
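The hunk above adds `AttentionPool` as a standalone module (and the `__init__.py` change above exports it from the package root). A minimal usage sketch inferred from that signature follows; the dimensions, batch size, and mask below are illustrative assumptions, not taken from the package documentation:

    import torch
    from x_transformers import AttentionPool

    # hypothetical sizes, chosen only for illustration
    pool = AttentionPool(
        dim = 512,              # dimension of the learned queries and of the pooled output
        dim_context = 512,      # dimension of the token sequence being pooled
        num_pooled_tokens = 1,  # one learned query -> one pooled token per sequence
        depth = 1               # depth > 1 switches to the CrossAttender-based pooler
    )

    tokens = torch.randn(2, 1024, 512)   # (batch, seq, dim_context)
    mask = torch.ones(2, 1024).bool()    # optional padding mask

    pooled = pool(tokens, mask = mask)   # (batch, num_pooled_tokens, dim) -> (2, 1, 512)

With `depth = 1` the pooler is a single `Attention` block cross-attending from the learned queries to the sequence; `add_residual = True` adds the queries back onto the pooled output, which is why the assert above disallows it when a full `CrossAttender` (which already carries residuals) is used.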
@@ -2860,8 +2900,9 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         num_cls_tokens = 1,
         attn_pool = False,
-
-
+        num_pooled_tokens = 1,
+        attn_pool_depth = 1,
+        dim_pooled_tokens = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
@@ -2958,10 +2999,7 @@ class TransformerWrapper(Module):
         self.attn_pool = None
 
         if attn_pool:
-            self.attn_pool =
-
-            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
-            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth)
 
         # whether to average pool the embed (`global average pool`)
 
@@ -3259,7 +3297,6 @@ class TransformerWrapper(Module):
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
-
         # cls token(s)
 
         if exists(self.cls_token):
@@ -3272,13 +3309,15 @@ class TransformerWrapper(Module):
 
         # attention pool
 
-
-
+        is_encoder = not self.attn_layers.causal
+        return_pooled_tokens = exists(self.attn_pool) and is_encoder
 
-
+        if (
+            exists(self.attn_pool) and
+            (return_intermediates or is_encoder) # in a new paper, they use attention pooling on decoder - so we'll default to returning pooled tokens if encoder, but for decoder, they must set `return_intermediates`
+        ):
 
-
-            attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+            attn_pooled_tokens = self.attn_pool(x, mask = mask)
 
             intermediates.attn_pooled_tokens = attn_pooled_tokens
 
@@ -3327,6 +3366,9 @@ class TransformerWrapper(Module):
             out = (x, intermediates)
         elif return_embeddings:
             out = x
+        elif return_pooled_tokens:
+            intermediates.logits = logits
+            out = attn_pooled_tokens
         else:
             out = logits
 
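Taken together, the TransformerWrapper hunks above replace the old inline attention-pool parameters with the new `AttentionPool` module, expose `num_pooled_tokens`, `attn_pool_depth`, and `dim_pooled_tokens`, and add a `return_pooled_tokens` path: for a non-causal (encoder) stack with `attn_pool = True`, the forward pass now returns the pooled tokens and stashes the logits on `intermediates.logits`. A hedged sketch of how this might be exercised, assuming the rest of the TransformerWrapper API is unchanged; the hyperparameters below are placeholders, not from the release:

    import torch
    from x_transformers import TransformerWrapper, Encoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_pool = True,          # builds the new AttentionPool internally
        num_pooled_tokens = 1,     # new in 2.5.0
        attn_pool_depth = 1,       # new in 2.5.0: > 1 uses a CrossAttender pooler
        attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
    )

    x = torch.randint(0, 20000, (2, 1024))

    # per the new `elif return_pooled_tokens:` branch, an encoder with attn_pool
    # enabled should return the attention-pooled tokens directly
    pooled = model(x)   # expected shape (2, num_pooled_tokens, 512)

For a decoder (causal stack), pooled tokens are only computed when `return_intermediates = True` is passed, and they are then available as `intermediates.attn_pooled_tokens`, per the inline comment in the diff.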
{x_transformers-2.4.14.dist-info → x_transformers-2.5.0.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-x_transformers/__init__.py,sha256=
+x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
 x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
 x_transformers/autoregressive_wrapper.py,sha256=y798kS9_VvPOY_5Ilits_64aXNqYvGuilsky1y07ryE,17834
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=NOTTbqDk5qZEY2MPpdIwJv4BvGGhXt_nIffrgQDXTf4,119346
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
+x_transformers-2.5.0.dist-info/METADATA,sha256=ZcxK61msWcsm8vEPG-FirnPTnwG6HNKtvy2ZoLluJHM,90223
+x_transformers-2.5.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.5.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.5.0.dist-info/RECORD,,
{x_transformers-2.4.14.dist-info → x_transformers-2.5.0.dist-info}/WHEEL
File without changes
{x_transformers-2.4.14.dist-info → x_transformers-2.5.0.dist-info}/licenses/LICENSE
File without changes