x-transformers 2.4.14.tar.gz → 2.5.0.tar.gz
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {x_transformers-2.4.14 → x_transformers-2.5.0}/PKG-INFO +1 -1
- {x_transformers-2.4.14 → x_transformers-2.5.0}/pyproject.toml +1 -1
- {x_transformers-2.4.14 → x_transformers-2.5.0}/tests/test_x_transformers.py +30 -2
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/__init__.py +2 -1
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/x_transformers.py +54 -12
- {x_transformers-2.4.14 → x_transformers-2.5.0}/.github/FUNDING.yml +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/.gitignore +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/LICENSE +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/README.md +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/data/README.md +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/data/enwik8.gz +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/all-attention.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/attention-on-attention.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/deepnorm.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/fcm.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/ffglu.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/flash-attention.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/gate_values.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/gating.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/macaron-1.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/macaron-2.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/memory-transformer.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/normformer.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/pia.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/resi_dual.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/residual_attn.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/rezero.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/rotary.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/sandwich-2.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/sandwich.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/sandwich_norm.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/scalenorm.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/talking-heads.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/topk-attention.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/images/xval.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/train_belief_state.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/train_copy.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/train_enwik8.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/train_length_extrapolate.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/train_parity.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/attend.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/continuous.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/dpo.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.0}/x_transformers/xval.py +0 -0
tests/test_x_transformers.py

@@ -1086,7 +1086,7 @@ def add_attn_pool():
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-
+        num_pooled_tokens = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,

@@ -1110,7 +1110,6 @@ def test_up(
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-        num_attn_pool_queries = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,

@@ -1153,3 +1152,32 @@ def test_beam_search(stochastic):

     assert beams.shape == (4, 2, 10)
     assert scores.shape == (4, 2)
+
+
+@pytest.mark.parametrize('num_pooled_tokens', (1, 3))
+@pytest.mark.parametrize('attn_pool_depth', (1, 3))
+def test_attn_pooler(
+    num_pooled_tokens,
+    attn_pool_depth
+):
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_pool = True,
+        num_pooled_tokens = num_pooled_tokens,
+        attn_pool_depth = attn_pool_depth,
+        dim_pooled_tokens = 77,
+        attn_layers = Encoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            attn_value_rmsnorm = True
+        ),
+    )
+
+    x = torch.randint(0, 256, (2, 10))
+
+    out = model(x)
+
+    assert out.shape == (2, num_pooled_tokens, 77)
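Read together, these test hunks track a rename of the attention-pool options on TransformerWrapper: the old `num_attn_pool_queries` argument disappears from the tests (and from the constructor, per the x_transformers.py hunks further down), with `num_pooled_tokens`, `attn_pool_depth` and `dim_pooled_tokens` taking its place. A minimal usage sketch of the new call, assuming only what the added test exercises and asserts (argument names and shapes are taken from the hunks, not from separate documentation):

import torch
from x_transformers import TransformerWrapper, Encoder

# hedged sketch mirroring the test added above; `num_pooled_tokens` replaces the
# removed `num_attn_pool_queries` argument, and `dim_pooled_tokens` sets the width
# of the pooled outputs (77 in the test)

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_pool = True,
    num_pooled_tokens = 3,
    attn_pool_depth = 1,
    dim_pooled_tokens = 77,
    attn_layers = Encoder(
        dim = 512,
        depth = 2,          # shallower than the test's depth = 12, just to keep the sketch light
        heads = 8
    )
)

x = torch.randint(0, 256, (2, 10))

out = model(x)              # expected shape (2, 3, 77), per the assertion in the new test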
x_transformers/__init__.py

@@ -4,12 +4,13 @@ from x_transformers.x_transformers import (
     Decoder,
     PrefixDecoder,
     CrossAttender,
+    AttentionPool,
     Attention,
     FeedForward,
     RMSNorm,
     AdaptiveRMSNorm,
     TransformerWrapper,
-    ViTransformerWrapper
+    ViTransformerWrapper,
 )

 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
x_transformers/x_transformers.py

@@ -51,6 +51,7 @@ class LayerIntermediates:
     attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    logits: Tensor | None = None
     cache_length: int = 0

 LinearNoBias = partial(nn.Linear, bias = False)
@@ -2749,6 +2750,45 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)

+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        num_pooled_tokens = 1,
+        dim_context = None,
+        add_residual = False,
+        depth = 1,
+        squeeze_output = None,
+        attn_kwargs: dict = dict()
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        squeeze_output = default(squeeze_output, False)
+        assert not (squeeze_output and num_pooled_tokens > 1)
+
+        self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
+
+        if depth > 1:
+            assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, **attn_kwargs)
+        else:
+            self.pooler = Attention(dim = dim, dim_context = dim_context, **attn_kwargs)
+
+        self.add_residual = add_residual
+
+    def forward(self, context, mask = None):
+        batch = context.shape[0]
+
+        queries = repeat(self.queries, 'n d -> b n d', b = batch)
+
+        pooled = self.pooler(queries, context, context_mask = mask)
+
+        if self.add_residual:
+            pooled = pooled + queries
+
+        return pooled
+
 class ViTransformerWrapper(Module):
     def __init__(
         self,
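Since AttentionPool is now exported from the package root (see the __init__.py hunk above), it can also be used on its own. The following is a hedged sketch: the constructor arguments and the (context, mask) forward signature are read off the hunk above, and the output shape follows the batch × num_pooled_tokens × dim convention asserted in the new test; none of it is taken from separate documentation.

import torch
from x_transformers import AttentionPool

# hedged sketch: pool a 768-wide sequence down to 4 learned tokens of width 512
pool = AttentionPool(
    dim = 512,              # width of the learned queries and of the pooled output
    dim_context = 768,      # width of the sequence being pooled (defaults to dim if omitted)
    num_pooled_tokens = 4,  # one learned query per pooled token
    depth = 2               # depth > 1 swaps the single cross attention for a CrossAttender stack
)

seq = torch.randn(2, 128, 768)                  # (batch, length, dim_context)
mask = torch.ones(2, 128, dtype = torch.bool)   # optional padding mask over the context

pooled = pool(seq, mask = mask)                 # expected (2, 4, 512) under these assumptions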
@@ -2860,8 +2900,9 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         num_cls_tokens = 1,
         attn_pool = False,
-
-
+        num_pooled_tokens = 1,
+        attn_pool_depth = 1,
+        dim_pooled_tokens = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
@@ -2958,10 +2999,7 @@ class TransformerWrapper(Module):
         self.attn_pool = None

         if attn_pool:
-            self.attn_pool =
-
-            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
-            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth)

         # whether to average pool the embed (`global average pool`)

@@ -3259,7 +3297,6 @@ class TransformerWrapper(Module):
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)

-
         # cls token(s)

         if exists(self.cls_token):
@@ -3272,13 +3309,15 @@ class TransformerWrapper(Module):

         # attention pool

-
-
+        is_encoder = not self.attn_layers.causal
+        return_pooled_tokens = exists(self.attn_pool) and is_encoder

-
+        if (
+            exists(self.attn_pool) and
+            (return_intermediates or is_encoder) # in a new paper, they use attention pooling on decoder - so we'll default to returning pooled tokens if encoder, but for decoder, they must set `return_intermediates`
+        ):

-
-            attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+            attn_pooled_tokens = self.attn_pool(x, mask = mask)

             intermediates.attn_pooled_tokens = attn_pooled_tokens

@@ -3327,6 +3366,9 @@ class TransformerWrapper(Module):
             out = (x, intermediates)
         elif return_embeddings:
             out = x
+        elif return_pooled_tokens:
+            intermediates.logits = logits
+            out = attn_pooled_tokens
         else:
             out = logits

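Per the comment in the pooling hunk above, a decoder only computes pooled tokens when `return_intermediates` is set, and they land on `intermediates.attn_pooled_tokens`; an encoder returns them directly, as in the test sketch earlier. The following sketch of the decoder path assumes the library's usual `(output, intermediates)` return when `return_intermediates = True`, and that the pooled width defaults to the model dim because `dim_pooled_tokens` is left unset — both assumptions, not verified against the release.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_pool = True,
    num_pooled_tokens = 3,      # renamed from `num_attn_pool_queries` per the hunks above
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

x = torch.randint(0, 256, (2, 10))

# assumed (output, intermediates) return pattern for return_intermediates = True
out, intermediates = model(x, return_intermediates = True)

pooled = intermediates.attn_pooled_tokens   # expected (2, 3, 512); dim_pooled_tokens defaults to dim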