x-transformers 2.4.12__tar.gz → 2.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.4.12 → x_transformers-2.5.0}/PKG-INFO +1 -1
- {x_transformers-2.4.12 → x_transformers-2.5.0}/pyproject.toml +1 -1
- {x_transformers-2.4.12 → x_transformers-2.5.0}/tests/test_x_transformers.py +30 -2
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/__init__.py +2 -1
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/x_transformers.py +63 -12
- {x_transformers-2.4.12 → x_transformers-2.5.0}/.github/FUNDING.yml +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/.gitignore +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/LICENSE +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/README.md +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/data/README.md +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/data/enwik8.gz +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/all-attention.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/attention-on-attention.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/deepnorm.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/fcm.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/ffglu.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/flash-attention.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/gate_values.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/gating.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/macaron-1.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/macaron-2.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/memory-transformer.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/normformer.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/pia.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/resi_dual.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/residual_attn.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/rezero.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/rotary.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/sandwich-2.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/sandwich.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/sandwich_norm.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/scalenorm.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/talking-heads.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/topk-attention.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/images/xval.png +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/train_belief_state.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/train_copy.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/train_enwik8.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/train_length_extrapolate.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/train_parity.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/attend.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/continuous.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/dpo.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/xval.py +0 -0
{x_transformers-2.4.12 → x_transformers-2.5.0}/tests/test_x_transformers.py

```diff
@@ -1086,7 +1086,7 @@ def add_attn_pool():
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-
+        num_pooled_tokens = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,
```
```diff
@@ -1110,7 +1110,6 @@ def test_up(
         num_tokens = 256,
         max_seq_len = 1024,
         attn_pool = True,
-        num_attn_pool_queries = 3,
         attn_layers = Decoder(
             dim = 512,
             depth = 12,
```
```diff
@@ -1153,3 +1152,32 @@ def test_beam_search(stochastic):
 
     assert beams.shape == (4, 2, 10)
     assert scores.shape == (4, 2)
+
+
+@pytest.mark.parametrize('num_pooled_tokens', (1, 3))
+@pytest.mark.parametrize('attn_pool_depth', (1, 3))
+def test_attn_pooler(
+    num_pooled_tokens,
+    attn_pool_depth
+):
+
+    model = TransformerWrapper(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        attn_pool = True,
+        num_pooled_tokens = num_pooled_tokens,
+        attn_pool_depth = attn_pool_depth,
+        dim_pooled_tokens = 77,
+        attn_layers = Encoder(
+            dim = 512,
+            depth = 12,
+            heads = 8,
+            attn_value_rmsnorm = True
+        ),
+    )
+
+    x = torch.randint(0, 256, (2, 10))
+
+    out = model(x)
+
+    assert out.shape == (2, num_pooled_tokens, 77)
```
{x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/__init__.py

```diff
@@ -4,12 +4,13 @@ from x_transformers.x_transformers import (
     Decoder,
     PrefixDecoder,
     CrossAttender,
+    AttentionPool,
     Attention,
     FeedForward,
     RMSNorm,
     AdaptiveRMSNorm,
     TransformerWrapper,
-    ViTransformerWrapper
+    ViTransformerWrapper,
 )
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
```
{x_transformers-2.4.12 → x_transformers-2.5.0}/x_transformers/x_transformers.py

```diff
@@ -51,6 +51,7 @@ class LayerIntermediates:
     attn_pooled_tokens: Tensor | None = None
     memory_tokens: Tensor | None = None
     logit_entropies: Tensor | None = None
+    logits: Tensor | None = None
     cache_length: int = 0
 
 LinearNoBias = partial(nn.Linear, bias = False)
```
```diff
@@ -1304,6 +1305,7 @@ class Attention(Module):
         qk_norm_groups = 1,
         qk_norm_scale = 10,
         qk_norm_dim_scale = False,
+        value_rmsnorm = False, # used in alphagenome and bytedance's GR3 for further stability
         l2_distance = False,
         sigmoid = False,
         selective = False,
```
```diff
@@ -1458,6 +1460,10 @@ class Attention(Module):
         assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
         assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
 
+        # value rms norm
+
+        self.value_rmsnorm = MultiheadRMSNorm(dim_head, heads = heads) if value_rmsnorm else None
+
         # contextual positional encoding
         # https://arxiv.org/html/2405.18719v2
 
```
```diff
@@ -1697,6 +1703,10 @@ class Attention(Module):
             q = q * self.qk_norm_q_scale
             k = k * self.qk_norm_k_scale
 
+        # maybe value rmsnorm
+
+        v = maybe(self.value_rmsnorm)(v)
+
         # take care of caching
 
         if not is_multi_latent_attn:
```
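A minimal sketch of enabling the new value normalization through the attention-layer kwargs, mirroring the `attn_value_rmsnorm = True` setting used by the new test in this release (the model hyperparameters and tensor sizes below are illustrative, not taken from the diff):

```python
import torch
from x_transformers import TransformerWrapper, Encoder

# sketch: `attn_value_rmsnorm = True` is forwarded to the new `value_rmsnorm`
# flag on Attention, applying a per-head RMSNorm to the values before attention
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_value_rmsnorm = True
    )
)

x = torch.randint(0, 256, (2, 128))
logits = model(x)  # (2, 128, 256)
```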
```diff
@@ -2740,6 +2750,45 @@ class CrossAttender(AttentionLayers):
     def __init__(self, **kwargs):
         super().__init__(cross_attend = True, only_cross = True, **kwargs)
 
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        num_pooled_tokens = 1,
+        dim_context = None,
+        add_residual = False,
+        depth = 1,
+        squeeze_output = None,
+        attn_kwargs: dict = dict()
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        squeeze_output = default(squeeze_output, False)
+        assert not (squeeze_output and num_pooled_tokens > 1)
+
+        self.queries = nn.Parameter(torch.randn(num_pooled_tokens, dim) * 1e-2)
+
+        if depth > 1:
+            assert not add_residual, 'residual already in effect when doing a full cross attention based transformer for pooling'
+            self.pooler = CrossAttender(dim = dim, cross_attn_dim_context = dim_context, depth = depth, **attn_kwargs)
+        else:
+            self.pooler = Attention(dim = dim, dim_context = dim_context, **attn_kwargs)
+
+        self.add_residual = add_residual
+
+    def forward(self, context, mask = None):
+        batch = context.shape[0]
+
+        queries = repeat(self.queries, 'n d -> b n d', b = batch)
+
+        pooled = self.pooler(queries, context, context_mask = mask)
+
+        if self.add_residual:
+            pooled = pooled + queries
+
+        return pooled
+
 class ViTransformerWrapper(Module):
     def __init__(
         self,
```
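For reference, a minimal sketch of using the new `AttentionPool` on its own: it cross-attends a small set of learned query tokens against a context sequence, and switches from a single `Attention` layer to a `CrossAttender` stack when `depth > 1`. The module is also re-exported from the package root per the `__init__.py` hunk above; the tensor shapes here are illustrative:

```python
import torch
from x_transformers import AttentionPool

# pool an arbitrary sequence of embeddings down to 3 learned query tokens
pool = AttentionPool(
    dim = 512,              # dimension of the learned queries and the output
    num_pooled_tokens = 3,
    depth = 1               # depth > 1 uses a CrossAttender stack instead
)

context = torch.randn(2, 1024, 512)   # (batch, seq, dim) - illustrative shapes
mask = torch.ones(2, 1024).bool()

pooled = pool(context, mask = mask)   # (2, 3, 512)
```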
```diff
@@ -2851,8 +2900,9 @@ class TransformerWrapper(Module):
         use_cls_token = False,
         num_cls_tokens = 1,
         attn_pool = False,
-
-
+        num_pooled_tokens = 1,
+        attn_pool_depth = 1,
+        dim_pooled_tokens = None,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
```
```diff
@@ -2949,10 +2999,7 @@ class TransformerWrapper(Module):
         self.attn_pool = None
 
         if attn_pool:
-            self.attn_pool =
-
-            self.attn_pool_queries = nn.Parameter(torch.zeros(num_attn_pool_queries, dim))
-            nn.init.normal_(self.attn_pool_queries, std = 0.02)
+            self.attn_pool = AttentionPool(dim = default(dim_pooled_tokens, dim), dim_context = dim, num_pooled_tokens = num_pooled_tokens, depth = attn_pool_depth)
 
         # whether to average pool the embed (`global average pool`)
 
```
```diff
@@ -3250,7 +3297,6 @@ class TransformerWrapper(Module):
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
-
         # cls token(s)
 
         if exists(self.cls_token):
```
```diff
@@ -3263,13 +3309,15 @@ class TransformerWrapper(Module):
 
         # attention pool
 
-
-
+        is_encoder = not self.attn_layers.causal
+        return_pooled_tokens = exists(self.attn_pool) and is_encoder
 
-
+        if (
+            exists(self.attn_pool) and
+            (return_intermediates or is_encoder) # in a new paper, they use attention pooling on decoder - so we'll default to returning pooled tokens if encoder, but for decoder, they must set `return_intermediates`
+        ):
 
-
-            attn_pooled_tokens = rearrange(attn_pooled_tokens, 'b 1 d -> b d')
+            attn_pooled_tokens = self.attn_pool(x, mask = mask)
 
             intermediates.attn_pooled_tokens = attn_pooled_tokens
 
```
```diff
@@ -3318,6 +3366,9 @@ class TransformerWrapper(Module):
             out = (x, intermediates)
         elif return_embeddings:
             out = x
+        elif return_pooled_tokens:
+            intermediates.logits = logits
+            out = attn_pooled_tokens
         else:
             out = logits
 
```
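Per the gating added above, an encoder with `attn_pool = True` now returns the pooled tokens directly (with the logits stashed on `intermediates.logits`), while a decoder only computes them when `return_intermediates = True`. A hedged sketch of the decoder path follows; the hyperparameters and tensor sizes are illustrative, and the `(logits, intermediates)` unpacking assumes the library's usual `return_intermediates` return convention:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_pool = True,
    num_pooled_tokens = 3,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8
    )
)

x = torch.randint(0, 256, (2, 64))

# for a decoder, pooled tokens are only produced when intermediates are requested
logits, intermediates = model(x, return_intermediates = True)
pooled = intermediates.attn_pooled_tokens  # (2, 3, 512) - dim_pooled_tokens defaults to dim
```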