x-transformers 2.10.1__tar.gz → 2.10.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of x-transformers may be problematic.
- {x_transformers-2.10.1 → x_transformers-2.10.2}/PKG-INFO +1 -1
- {x_transformers-2.10.1 → x_transformers-2.10.2}/pyproject.toml +1 -1
- {x_transformers-2.10.1 → x_transformers-2.10.2}/tests/test_x_transformers.py +37 -32
- {x_transformers-2.10.1 → x_transformers-2.10.2}/train_copy.py +10 -8
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/attend.py +5 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/.github/FUNDING.yml +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/.gitignore +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/LICENSE +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/README.md +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/data/README.md +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/data/enwik8.gz +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/all-attention.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/attention-on-attention.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/deepnorm.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/fcm.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/ffglu.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/flash-attention.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/gate_values.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/gating.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/macaron-1.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/macaron-2.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/memory-transformer.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/normformer.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/pia.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/resi_dual.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/residual_attn.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/rezero.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/rotary.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/sandwich-2.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/sandwich.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/sandwich_norm.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/scalenorm.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/talking-heads.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/topk-attention.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/images/xval.png +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/train_belief_state.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/train_enwik8.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/train_gpt_vae.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/train_length_extrapolate.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/train_parity.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/train_with_muon.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/__init__.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/continuous.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/dpo.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.10.1 → x_transformers-2.10.2}/x_transformers/xval.py +0 -0
tests/test_x_transformers.py (+37 -32)

```diff
@@ -1,4 +1,5 @@
 import pytest
+param = pytest.mark.parametrize
 
 import torch
 from torch import nn
@@ -186,7 +187,7 @@ def test_average_pool_embed():
 
     assert logits.shape == (2, 20000)
 
-@pytest.mark.parametrize('num_cls_tokens', (1, 2))
+@param('num_cls_tokens', (1, 2))
 def test_cls_token(num_cls_tokens):
     model = TransformerWrapper(
         num_tokens = 20000,
@@ -234,7 +235,7 @@ def test_squeeze_logit_dim_one():
 
     assert logits.shape == (2,)
 
-@pytest.mark.parametrize('depth', (4, 5))
+@param('depth', (4, 5))
 def test_unet_skip(depth):
 
     model = TransformerWrapper(
@@ -294,7 +295,7 @@ def test_mos():
 
     eval_logits = model(x)
 
-@pytest.mark.parametrize('attn_one_kv_head', (True, False))
+@param('attn_one_kv_head', (True, False))
 def test_l2_distance(attn_one_kv_head):
 
     model = TransformerWrapper(
@@ -331,7 +332,7 @@ def test_reinject_input():
 
     model(x) # (1, 1024, 20000)
 
-@pytest.mark.parametrize('learned_value_residual_mix', (False, True))
+@param('learned_value_residual_mix', (False, True))
 def test_value_residual(
     learned_value_residual_mix: bool
 ):
@@ -352,7 +353,7 @@ def test_value_residual(
 
     model(x)
 
-@pytest.mark.parametrize('has_num_mem_kv', (False, True))
+@param('has_num_mem_kv', (False, True))
 def test_forgetting_transformer(
     has_num_mem_kv: bool
 ):
@@ -388,7 +389,7 @@ def test_neo_mlp():
     out = mlp(x)
     assert out.shape == (3, 7)
 
-@pytest.mark.parametrize('flash', (True, False))
+@param('flash', (True, False))
 def test_custom_alibi(flash: bool):
 
     model = TransformerWrapper(
@@ -409,7 +410,7 @@ def test_custom_alibi(flash: bool):
 
     logits = model(x, pos = pos)
 
-@pytest.mark.parametrize('rotary_xpos', (True, False))
+@param('rotary_xpos', (True, False))
 def test_custom_rotary_pos_emb(rotary_xpos):
     from einops import repeat
 
@@ -433,7 +434,7 @@ def test_custom_rotary_pos_emb(rotary_xpos):
     logits2 = model(x)
     assert torch.allclose(logits1, logits2)
 
-@pytest.mark.parametrize('flash', (True, False))
+@param('flash', (True, False))
 def test_custom_alibi_across_heads(flash: bool):
     model = Decoder(
         dim = 512,
@@ -455,7 +456,7 @@ def test_custom_alibi_across_heads(flash: bool):
 
     embed = model(x, pos = pos)
 
-@pytest.mark.parametrize('embedder_type', ('embedding', 'none', 'custom'))
+@param('embedder_type', ('embedding', 'none', 'custom'))
 def test_embedder(embedder_type):
     num_tokens = 20000
     dim = 128
@@ -502,7 +503,7 @@ def test_embedder(embedder_type):
     assert output.shape == (2, 1024, 20000)
 
 
-@pytest.mark.parametrize("to_logits", ('linear', 'none', 'pointer'))
+@param("to_logits", ('linear', 'none', 'pointer'))
 def test_to_logits(to_logits):
     num_tokens = 20000
     dim = 128
@@ -560,8 +561,8 @@ def test_laser():
 
     model(x)
 
-@pytest.mark.parametrize('self_attn_custom_pos', (True, False))
-@pytest.mark.parametrize('cross_attn_rotary', (True, False))
+@param('self_attn_custom_pos', (True, False))
+@param('cross_attn_rotary', (True, False))
 def test_cross_attn_rotary(
     self_attn_custom_pos: bool,
     cross_attn_rotary: bool
@@ -593,7 +594,7 @@ def test_cross_attn_rotary(
         context_mask = context_mask
     )
 
-@pytest.mark.parametrize('tanh', (True, False))
+@param('tanh', (True, False))
 def test_hyper_connections(tanh):
 
     model = TransformerWrapper(
@@ -614,7 +615,7 @@ def test_hyper_connections(tanh):
 
     model(x)
 
-@pytest.mark.parametrize('hybrid_axial_dim', (1, 4))
+@param('hybrid_axial_dim', (1, 4))
 def test_hybrid(hybrid_axial_dim):
     from torch.nn import GRU
 
@@ -770,8 +771,8 @@ def test_multi_latent_attention():
 
     model(x)
 
-@pytest.mark.parametrize('num_residual_streams', (1, 4))
-@pytest.mark.parametrize('integrate_layers', (False, True))
+@param('num_residual_streams', (1, 4))
+@param('integrate_layers', (False, True))
 def test_lime(
     num_residual_streams,
     integrate_layers
@@ -792,10 +793,10 @@ def test_lime(
 
     model(x)
 
-@pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
-@pytest.mark.parametrize('goal_suffix', (False, True))
-@pytest.mark.parametrize('pred_distance', (False, True))
-@pytest.mark.parametrize('variable_len', (False, True))
+@param('backward_ar_loss_weight', (1., 0.5))
+@param('goal_suffix', (False, True))
+@param('pred_distance', (False, True))
+@param('variable_len', (False, True))
 def test_belief_state_wrapper(
     backward_ar_loss_weight,
     goal_suffix,
@@ -867,7 +868,7 @@ def test_dynamic_tanh():
 
     model(x)
 
-@pytest.mark.parametrize('var_length', (False, True))
+@param('var_length', (False, True))
 def test_entropy_based_tokenizer(
     var_length
 ):
@@ -966,9 +967,9 @@ def test_ff_deep_embed():
 
     assert logits.shape == (2, 1024, 20000)
 
-@pytest.mark.parametrize('probabilistic', (False, True))
-@pytest.mark.parametrize('cache_kv', (False, True))
-@pytest.mark.parametrize('rollout_steps', (1, 4))
+@param('probabilistic', (False, True))
+@param('cache_kv', (False, True))
+@param('rollout_steps', (1, 4))
 def test_continuous(
     probabilistic,
     cache_kv,
@@ -1012,7 +1013,7 @@ def test_continuous(
     generated = model.generate(start_emb, 17, cache_kv = cache_kv) # (17, 777)
     assert generated.shape == (17, 777)
 
-@pytest.mark.parametrize('add_continuous_pred_head', (False, True))
+@param('add_continuous_pred_head', (False, True))
 def test_autoregressive_wrapper(
     add_continuous_pred_head
 ):
@@ -1100,7 +1101,7 @@ def add_attn_pool():
 
     assert intermediates.attn_pooled_tokens.shape[1] == 3
 
-@pytest.mark.parametrize('keep_buffer_on_cpu', (False, True))
+@param('keep_buffer_on_cpu', (False, True))
 def test_up(
     keep_buffer_on_cpu
 ):
@@ -1126,7 +1127,7 @@ def test_up(
     loss = up_wrapper()
     loss.backward()
 
-@pytest.mark.parametrize('stochastic', (False, True))
+@param('stochastic', (False, True))
 def test_beam_search(stochastic):
     from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper
 
@@ -1154,8 +1155,8 @@ def test_beam_search(stochastic):
     assert scores.shape == (4, 2)
 
 
-@pytest.mark.parametrize('num_pooled_tokens', (1, 3))
-@pytest.mark.parametrize('attn_pool_depth', (1, 3))
+@param('num_pooled_tokens', (1, 3))
+@param('attn_pool_depth', (1, 3))
 def test_attn_pooler(
     num_pooled_tokens,
     attn_pool_depth
@@ -1288,7 +1289,7 @@ def test_accept_layer_intermediates():
 
     assert embeds.shape == (3, 32, 512)
 
-@pytest.mark.parametrize('use_loss_weight', (False, True))
+@param('use_loss_weight', (False, True))
 def test_simple_mdlm(
     use_loss_weight
 ):
@@ -1386,7 +1387,10 @@ def test_stochastic_attn():
     log_probs = log_prob_from_hard_attend(intermediate)
     assert log_probs.shape == (1, 8, 1024)
 
-def test_attn_negative_weights():
+@param('head_learned_sink', (True, False))
+def test_attn_negative_weights(
+    head_learned_sink
+):
     from x_transformers import TransformerWrapper, Decoder
 
     model = TransformerWrapper(
@@ -1396,7 +1400,8 @@ def test_attn_negative_weights():
             dim = 512,
            depth = 12,
             heads = 8,
-            attn_cog_signed = True
+            attn_cog_signed = True,
+            attn_head_learned_sink = True
        ),
    )
 
```
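The test-suite changes above are mechanical: 2.10.2 aliases `pytest.mark.parametrize` to `param` and rewrites every decorator to use the shorter name, and additionally parametrizes `test_attn_negative_weights` over a learned attention sink. A minimal sketch of the same aliasing pattern follows; the test name and parameter choices are illustrative, not taken from the package.

```python
import pytest

# alias the decorator once at module level, as the updated test file does
param = pytest.mark.parametrize

# stacked decorators expand into the cartesian product of cases (2 x 2 here)
@param('flash', (True, False))
@param('depth', (4, 5))
def test_example_stub(flash, depth):
    assert isinstance(flash, bool)
    assert depth in (4, 5)
```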
train_copy.py (+10 -8)

```diff
@@ -17,27 +17,29 @@ DEC_SEQ_LEN = 64 + 1
 
 def cycle():
     while True:
-        prefix = torch.ones((BATCH_SIZE, 1)).long()
-        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long()
+        prefix = torch.ones((BATCH_SIZE, 1)).long()
+        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long()
         tgt = torch.cat((prefix, src, src), 1)
-        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool()
+        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool()
         yield (src, tgt, src_mask)
 
 # instantiate model
 
 model = XTransformer(
-    dim =
+    dim = 128,
     tie_token_emb = True,
     return_tgt_loss = True,
     enc_num_tokens=NUM_TOKENS,
     enc_depth = 3,
     enc_heads = 8,
     enc_max_seq_len = ENC_SEQ_LEN,
+    enc_attn_cog_signed = True,
     dec_num_tokens = NUM_TOKENS,
     dec_depth = 3,
     dec_heads = 8,
-    dec_max_seq_len = DEC_SEQ_LEN
-
+    dec_max_seq_len = DEC_SEQ_LEN,
+    dec_attn_cog_signed = True
+)
 
 # optimizer
 
@@ -61,10 +63,10 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         model.eval()
         src, _, src_mask = next(cycle())
         src, src_mask = src[:1], src_mask[:1]
-        start_tokens = (torch.ones((1, 1)) * 1).long()
+        start_tokens = (torch.ones((1, 1)) * 1).long()
 
         sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
-        incorrects = (src != sample).abs().sum()
+        incorrects = (src != sample).long().abs().sum()
 
         print(f"input: ", src)
         print(f"predicted output: ", sample)
```
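Two behavioral tweaks stand out in `train_copy.py`: the toy copy task now exercises signed attention by passing `enc_attn_cog_signed = True` and `dec_attn_cog_signed = True` to `XTransformer`, and the error count casts the boolean mismatch mask to integers before reducing. The snippet below sketches why the cast helps, on the assumption (suggested by the change itself) that recent PyTorch rejects `.abs()` on bool tensors; the example data is made up.

```python
import torch

src = torch.randint(2, 10, (1, 8))
sample = src.clone()
sample[0, 3] += 1                    # introduce a single mismatch

mismatch = (src != sample)           # bool tensor
# mismatch.abs() errors out on bool dtype in recent PyTorch,
# hence the added .long() cast before .abs().sum()
incorrects = mismatch.long().abs().sum()
print(incorrects.item())             # 1
```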
x_transformers/attend.py (+5 -0)

```diff
@@ -549,6 +549,11 @@ class Attend(Module):
         if self.head_learned_sink:
             # add learned attention sink
             attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
+
+            if self.cog_signed:
+                attn_sink, attn_sink_sign = attn_sink.abs(), attn_sink.sign()
+                sim_sign = cat((attn_sink_sign, sim_sign), dim = -1)
+
             sim = cat((attn_sink, sim), dim = -1)
 
         pre_softmax_attn = sim
```
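The only library change is in `Attend`: when signed attention (`cog_signed`) and a per-head learned attention sink are both enabled, the sink logit is now split into magnitude and sign so the prepended sink column gets a matching entry in the sign tensor before the softmax. The sketch below reproduces just that tensor bookkeeping with made-up shapes; `sim`, `sim_sign`, and `head_attn_sink` stand in for the module's internal state and are not the library's public API.

```python
import torch
from torch import cat
from einops import repeat

b, h, i, j = 2, 4, 8, 8                      # batch, heads, query len, key len
sim = torch.randn(b, h, i, j)                # attention logits
sim, sim_sign = sim.abs(), sim.sign()        # signed-attention split done earlier in the module
head_attn_sink = torch.randn(h)              # one learned sink logit per head

attn_sink = repeat(head_attn_sink, 'h -> b h i 1', b = b, i = i)

# new in 2.10.2: keep the sink's sign alongside the other key signs
attn_sink, attn_sink_sign = attn_sink.abs(), attn_sink.sign()
sim_sign = cat((attn_sink_sign, sim_sign), dim = -1)

sim = cat((attn_sink, sim), dim = -1)        # prepend the sink column, as before
assert sim.shape == sim_sign.shape == (b, h, i, j + 1)
```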