x-transformers 2.4.14__tar.gz → 2.5.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.4.14 → x_transformers-2.5.1}/PKG-INFO +1 -1
- {x_transformers-2.4.14 → x_transformers-2.5.1}/pyproject.toml +1 -1
- {x_transformers-2.4.14 → x_transformers-2.5.1}/tests/test_x_transformers.py +30 -2
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/__init__.py +2 -1
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/x_transformers.py +61 -12
- {x_transformers-2.4.14 → x_transformers-2.5.1}/.github/FUNDING.yml +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/.gitignore +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/LICENSE +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/README.md +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/data/README.md +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/data/enwik8.gz +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/all-attention.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/attention-on-attention.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/deepnorm.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/fcm.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/ffglu.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/flash-attention.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/gate_values.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/gating.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/macaron-1.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/macaron-2.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/memory-transformer.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/normformer.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/pia.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/resi_dual.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/residual_attn.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/rezero.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/rotary.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/sandwich-2.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/sandwich.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/sandwich_norm.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/scalenorm.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/talking-heads.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/topk-attention.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/images/xval.png +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/train_belief_state.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/train_copy.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/train_enwik8.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/train_length_extrapolate.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/train_parity.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/attend.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/continuous.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/dpo.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.4.14 → x_transformers-2.5.1}/x_transformers/xval.py +0 -0
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -1086,7 +1086,7 @@ def add_attn_pool():
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-
+        num_pooled_tokens = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,
@@ -1110,7 +1110,6 @@ def test_up(
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-        num_attn_pool_queries = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,
@@ -1153,3 +1152,32 @@ def test_beam_search(stochastic):
 
     assert beams.shape == (4, 2, 10)
     assert scores.shape == (4, 2)
+
+
+@pytest.mark.parametrize('num_pooled_tokens', (1, 3))
+@pytest.mark.parametrize('attn_pool_depth', (1, 3))
+def test_attn_pooler(
+    num_pooled_tokens,
+    attn_pool_depth
+):
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_pool = True,
+        num_pooled_tokens = num_pooled_tokens,
+        attn_pool_depth = attn_pool_depth,
+        dim_pooled_tokens = 77,
+        attn_layers = Encoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            attn_value_rmsnorm = True
+        ),
+    )
+
+    x = torch.randint(0, 256, (2, 10))
+
+    out = model(x)
+
+    assert out.shape == (2, num_pooled_tokens, 77)
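Note: the test changes above drop `num_attn_pool_queries` in favour of `num_pooled_tokens`, and the new `test_attn_pooler` exercises the added `attn_pool_depth` and `dim_pooled_tokens` options. A minimal usage sketch mirroring that test (illustrative values, not part of the diff):

```python
import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_pool = True,
    num_pooled_tokens = 3,      # the tests above use this in place of num_attn_pool_queries
    attn_pool_depth = 1,        # depth > 1 switches the pooler to a small cross-attention stack
    dim_pooled_tokens = 77,     # optional; pooled tokens default to the model dimension
    attn_layers = Encoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

x = torch.randint(0, 256, (2, 10))

pooled = model(x)               # (2, 3, 77) - an encoder with attn_pool now returns the pooled tokens
```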
--- a/x_transformers/__init__.py
+++ b/x_transformers/__init__.py
@@ -4,12 +4,13 @@ from x_transformers.x_transformers import (
     Decoder,
     PrefixDecoder,
     CrossAttender,
+    AttentionPool,
     Attention,
     FeedForward,
     RMSNorm,
     AdaptiveRMSNorm,
     TransformerWrapper,
-    ViTransformerWrapper
+    ViTransformerWrapper,
 )
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
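The new `AttentionPool` module is re-exported at the package root, so it can be imported alongside the existing wrappers, e.g.:

```python
from x_transformers import AttentionPool, TransformerWrapper, Encoder
```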
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -51,6 +51,7 @@ class LayerIntermediates:
     attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    logits: Tensor | None = None
     cache_length: int = 0
 
 LinearNoBias = partial(nn.Linear, bias = False)
@@ -2037,6 +2038,9 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])
 
+        self.attn_heads = heads
+        self.attn_dim_head = dim_head
+
         # routing related
         # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
         # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
@@ -2749,6 +2753,49 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)
 
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        num_pooled_tokens = 1,
+        dim_context = None,
+        add_residual = False,
+        depth = 1,
+        heads = 8,
+        dim_head = 64,
+        squeeze_output = None,
+        attn_kwargs: dict = dict()
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        squeeze_output = default(squeeze_output, False)
+        assert not (squeeze_output and num_pooled_tokens > 1)
+
+        self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
+
+        if depth > 1:
+            assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
+            attn_kwargs = {f'attn_{k}': v for k, v in attn_kwargs.items()}
+
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, heads = heads, attn_dim_head = dim_head, )
+        else:
+            self.pooler = Attention(dim = dim, dim_context = dim_context, heads = heads, dim_head = dim_head, **attn_kwargs)
+
+        self.add_residual = add_residual
+
+    def forward(self, context, mask = None):
+        batch = context.shape[0]
+
+        queries = repeat(self.queries, 'n d -> b n d', b = batch)
+
+        pooled = self.pooler(queries, context, context_mask = mask)
+
+        if self.add_residual:
+            pooled = pooled + queries
+
+        return pooled
+
 class ViTransformerWrapper(Module):
     def __init__(
         self,
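For orientation, a minimal standalone sketch of the new `AttentionPool` (values are illustrative, and it assumes the single-`Attention` path returns just the attended output, as the `forward` above relies on):

```python
import torch
from x_transformers import AttentionPool

pool = AttentionPool(
    dim = 512,              # dimension of the learned queries / pooled outputs
    num_pooled_tokens = 4,  # one learned query per pooled token
    depth = 1               # depth > 1 builds a CrossAttender stack instead of a single Attention
)

tokens = torch.randn(2, 128, 512)     # (batch, seq, dim) embeddings from any encoder
mask = torch.ones(2, 128).bool()      # optional padding mask

pooled = pool(tokens, mask = mask)    # (2, 4, 512)
```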
@@ -2860,8 +2907,9 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         num_cls_tokens = 1,
         attn_pool = False,
-
-
+        num_pooled_tokens = 1,
+        attn_pool_depth = 1,
+        dim_pooled_tokens = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
@@ -2958,10 +3006,7 @@ class TransformerWrapper(Module):
         self.attn_pool = None
 
         if attn_pool:
-            self.attn_pool =
-
-            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
-            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth, heads = self.attn_layers.attn_heads, dim_head = self.attn_layers.attn_dim_head)
 
         # whether to average pool the embed (`global average pool`)
 
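For reference, the single added line above is roughly equivalent to constructing the pooler directly; the values here are hypothetical (`dim_pooled_tokens = 77`, model `dim = 512`, 8 heads of size 64), chosen only to show where each argument comes from:

```python
from x_transformers import AttentionPool

# roughly what `attn_pool = True` now wires up inside TransformerWrapper (illustrative values)
pooler = AttentionPool(
    dim = 77,              # default(dim_pooled_tokens, dim)
    dim_context = 512,     # the transformer's model dimension
    num_pooled_tokens = 3,
    depth = 1,             # attn_pool_depth
    heads = 8,             # mirrors attn_layers.attn_heads
    dim_head = 64          # mirrors attn_layers.attn_dim_head
)
```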
@@ -3259,7 +3304,6 @@ class TransformerWrapper(Module):
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
-
         # cls token(s)
 
         if exists(self.cls_token):
@@ -3272,13 +3316,15 @@ class TransformerWrapper(Module):
 
         # attention pool
 
-
-
+        is_encoder = not self.attn_layers.causal
+        return_pooled_tokens = exists(self.attn_pool) and is_encoder
 
-
+        if (
+            exists(self.attn_pool) and
+            (return_intermediates or is_encoder) # in a new paper, they use attention pooling on decoder - so we'll default to returning pooled tokens if encoder, but for decoder, they must set `return_intermediates`
+        ):
 
-
-            attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+            attn_pooled_tokens = self.attn_pool(x, mask = mask)
 
             intermediates.attn_pooled_tokens = attn_pooled_tokens
 
@@ -3327,6 +3373,9 @@ class TransformerWrapper(Module):
             out = (x, intermediates)
         elif return_embeddings:
             out = x
+        elif return_pooled_tokens:
+            intermediates.logits = logits
+            out = attn_pooled_tokens
         else:
             out = logits
 
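Per the inline comment above, a decoder can also use the pooler, but its pooled tokens are only computed when `return_intermediates = True`, in which case they are exposed on `LayerIntermediates.attn_pooled_tokens`; when pooled tokens become the primary output of an encoder, the logits are preserved on the new `intermediates.logits` field. A hedged decoder-side sketch, assuming the library's usual two-element return for `return_intermediates = True`:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_pool = True,
    num_pooled_tokens = 3,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

x = torch.randint(0, 256, (2, 10))

_, intermediates = model(x, return_intermediates = True)

pooled = intermediates.attn_pooled_tokens   # (2, 3, 512) - dim_pooled_tokens defaults to dim
```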