x-transformers 2.4.14__py3-none-any.whl → 2.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/__init__.py +2 -1
- x_transformers/x_transformers.py +61 -12
- {x_transformers-2.4.14.dist-info → x_transformers-2.5.1.dist-info}/METADATA +1 -1
- {x_transformers-2.4.14.dist-info → x_transformers-2.5.1.dist-info}/RECORD +6 -6
- {x_transformers-2.4.14.dist-info → x_transformers-2.5.1.dist-info}/WHEEL +0 -0
- {x_transformers-2.4.14.dist-info → x_transformers-2.5.1.dist-info}/licenses/LICENSE +0 -0
x_transformers/__init__.py
CHANGED
@@ -4,12 +4,13 @@ from x_transformers.x_transformers import (
     Decoder,
     PrefixDecoder,
     CrossAttender,
+    AttentionPool,
     Attention,
     FeedForward,
     RMSNorm,
     AdaptiveRMSNorm,
     TransformerWrapper,
-    ViTransformerWrapper
+    ViTransformerWrapper,
 )
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
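The only functional change to `__init__.py` is that `AttentionPool` is now re-exported at the package top level (plus a trailing comma after `ViTransformerWrapper`). A minimal import sketch, assuming the wheel is installed under its usual distribution name; the second line only restates exports that already existed:

from x_transformers import AttentionPool                          # newly exported in 2.5.x
from x_transformers import TransformerWrapper, Encoder, Decoder   # pre-existing exports, unchanged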
x_transformers/x_transformers.py
CHANGED
@@ -51,6 +51,7 @@ class LayerIntermediates:
     attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    logits: Tensor | None = None
     cache_length: int = 0
 
 LinearNoBias = partial(nn.Linear, bias = False)
@@ -2037,6 +2038,9 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])
 
+        self.attn_heads = heads
+        self.attn_dim_head = dim_head
+
         # routing related
         # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
         # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
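`AttentionLayers` now remembers its head configuration; the `TransformerWrapper` change later in this diff reads these attributes to size its attention pool. A small hedged sketch (`Encoder` subclasses `AttentionLayers`; the attribute names come from the added lines, the numbers are only illustrative):

from x_transformers import Encoder

enc = Encoder(dim = 512, depth = 6, heads = 8, attn_dim_head = 64)

# the new attributes simply record the constructor arguments
assert enc.attn_heads == 8
assert enc.attn_dim_head == 64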
@@ -2749,6 +2753,49 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)
 
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        num_pooled_tokens = 1,
+        dim_context = None,
+        add_residual = False,
+        depth = 1,
+        heads = 8,
+        dim_head = 64,
+        squeeze_output = None,
+        attn_kwargs: dict = dict()
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        squeeze_output = default(squeeze_output, False)
+        assert not (squeeze_output and num_pooled_tokens > 1)
+
+        self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
+
+        if depth > 1:
+            assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
+            attn_kwargs = {f'attn_{k}': v for k, v in attn_kwargs.items()}
+
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, heads = heads, attn_dim_head = dim_head, )
+        else:
+            self.pooler = Attention(dim = dim, dim_context = dim_context, heads = heads, dim_head = dim_head, **attn_kwargs)
+
+        self.add_residual = add_residual
+
+    def forward(self, context, mask = None):
+        batch = context.shape[0]
+
+        queries = repeat(self.queries, 'n d -> b n d', b = batch)
+
+        pooled = self.pooler(queries, context, context_mask = mask)
+
+        if self.add_residual:
+            pooled = pooled + queries
+
+        return pooled
+
 class ViTransformerWrapper(Module):
     def __init__(
         self,
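A hedged usage sketch of the new module: `num_pooled_tokens` learned queries cross-attend over a token sequence and return that many pooled vectors. The import path follows the `__init__.py` change above; the shapes are assumptions read off the added code, not verified against the built wheel:

import torch
from x_transformers import AttentionPool

pool = AttentionPool(
    dim = 512,             # dimension of the learned queries and pooled outputs
    dim_context = 768,     # dimension of the tokens being pooled
    num_pooled_tokens = 4,
    depth = 1,             # depth > 1 swaps the single Attention for a CrossAttender stack
    heads = 8,
    dim_head = 64
)

tokens = torch.randn(2, 1024, 768)             # (batch, seq, dim_context)
mask = torch.ones(2, 1024, dtype = torch.bool)

pooled = pool(tokens, mask = mask)             # expected shape (2, 4, 512)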
@@ -2860,8 +2907,9 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         num_cls_tokens = 1,
         attn_pool = False,
-
-
+        num_pooled_tokens = 1,
+        attn_pool_depth = 1,
+        dim_pooled_tokens = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
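The three new keyword arguments configure the `AttentionPool` that `TransformerWrapper` now builds internally (next hunk). A hedged construction sketch; the argument names are taken from the changed signature, the surrounding call follows the library's usual usage, and the concrete values are only illustrative:

from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8),
    attn_pool = True,           # pre-existing flag
    num_pooled_tokens = 4,      # new: number of learned pooling queries
    attn_pool_depth = 2,        # new: depth > 1 uses a CrossAttender stack inside AttentionPool
    dim_pooled_tokens = 256     # new: dimension of the pooled tokens, defaults to dim
)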
@@ -2958,10 +3006,7 @@ class TransformerWrapper(Module):
         self.attn_pool = None
 
         if attn_pool:
-            self.attn_pool = 
-
-            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
-            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth, heads = self.attn_layers.attn_heads, dim_head = self.attn_layers.attn_dim_head)
 
         # whether to average pool the embed (`global average pool`)
 
@@ -3259,7 +3304,6 @@ class TransformerWrapper(Module):
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
-
         # cls token(s)
 
         if exists(self.cls_token):
@@ -3272,13 +3316,15 @@ class TransformerWrapper(Module):
 
         # attention pool
 
-
-
+        is_encoder = not self.attn_layers.causal
+        return_pooled_tokens = exists(self.attn_pool) and is_encoder
 
-
+        if (
+            exists(self.attn_pool) and
+            (return_intermediates or is_encoder) # in a new paper, they use attention pooling on decoder - so we'll default to returning pooled tokens if encoder, but for decoder, they must set `return_intermediates`
+        ):
 
-
-            attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+            attn_pooled_tokens = self.attn_pool(x, mask = mask)
 
             intermediates.attn_pooled_tokens = attn_pooled_tokens
 
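Per the inline comment, pooling now also runs for causal (decoder) stacks, but only when intermediates are requested; non-causal (encoder) stacks pool unconditionally. A hedged sketch of the decoder path, assuming the usual two-element return when `return_intermediates = True` (the exact contents of the first element depend on the other `return_*` flags); the `attn_pooled_tokens` attribute comes from `LayerIntermediates`:

import torch
from x_transformers import TransformerWrapper, Decoder

decoder = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8),
    attn_pool = True
)

x = torch.randint(0, 20000, (2, 1024))

# for a decoder, return_intermediates must be set for pooling to happen at all
out, intermediates = decoder(x, return_intermediates = True)
pooled = intermediates.attn_pooled_tokens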
@@ -3327,6 +3373,9 @@ class TransformerWrapper(Module):
             out = (x, intermediates)
         elif return_embeddings:
             out = x
+        elif return_pooled_tokens:
+            intermediates.logits = logits
+            out = attn_pooled_tokens
         else:
             out = logits
 
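So for a non-causal stack with an attention pool, the default return value becomes the pooled tokens, while the logits are recorded on the new `intermediates.logits` field. A hedged end-to-end sketch reusing the encoder `model` constructed in the earlier sketch; the output shape is inferred from `num_pooled_tokens` and `dim_pooled_tokens`, not verified against the wheel:

import torch

x = torch.randint(0, 20000, (2, 1024))
mask = torch.ones(2, 1024, dtype = torch.bool)

pooled = model(x, mask = mask)   # expected shape (2, 4, 256) for the encoder configured above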
@@ -1,4 +1,4 @@
|
|
|
1
|
-
x_transformers/__init__.py,sha256=
|
|
1
|
+
x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
|
|
2
2
|
x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
|
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=y798kS9_VvPOY_5Ilits_64aXNqYvGuilsky1y07ryE,17834
|
|
4
4
|
x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
|
|
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=l6RSwfvWxeSQWXiZVQzomyn8CyMvfmXr8oBVfb9zbBM,119679
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
+x_transformers-2.5.1.dist-info/METADATA,sha256=H2y1hNleReXPTU2G1VX1vowYLQFFmaotou1QmXqJM-Q,90223
+x_transformers-2.5.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.5.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.5.1.dist-info/RECORD,,
{x_transformers-2.4.14.dist-info → x_transformers-2.5.1.dist-info}/WHEEL
File without changes

{x_transformers-2.4.14.dist-info → x_transformers-2.5.1.dist-info}/licenses/LICENSE
File without changes