x-transformers 2.10.1__tar.gz → 2.11.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of x-transformers has been flagged as a potentially problematic release.

Files changed (68)
  1. {x_transformers-2.10.1 → x_transformers-2.11.0}/PKG-INFO +10 -1
  2. {x_transformers-2.10.1 → x_transformers-2.11.0}/README.md +9 -0
  3. {x_transformers-2.10.1 → x_transformers-2.11.0}/pyproject.toml +1 -1
  4. {x_transformers-2.10.1 → x_transformers-2.11.0}/tests/test_x_transformers.py +37 -32
  5. {x_transformers-2.10.1 → x_transformers-2.11.0}/train_copy.py +11 -8
  6. x_transformers-2.11.0/train_free.py +134 -0
  7. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/attend.py +5 -0
  8. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/autoregressive_wrapper.py +4 -0
  9. x_transformers-2.11.0/x_transformers/free_transformer.py +330 -0
  10. {x_transformers-2.10.1 → x_transformers-2.11.0}/.github/FUNDING.yml +0 -0
  11. {x_transformers-2.10.1 → x_transformers-2.11.0}/.github/workflows/python-publish.yml +0 -0
  12. {x_transformers-2.10.1 → x_transformers-2.11.0}/.github/workflows/python-test.yaml +0 -0
  13. {x_transformers-2.10.1 → x_transformers-2.11.0}/.gitignore +0 -0
  14. {x_transformers-2.10.1 → x_transformers-2.11.0}/LICENSE +0 -0
  15. {x_transformers-2.10.1 → x_transformers-2.11.0}/data/README.md +0 -0
  16. {x_transformers-2.10.1 → x_transformers-2.11.0}/data/enwik8.gz +0 -0
  17. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/all-attention.png +0 -0
  18. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/attention-on-attention.png +0 -0
  19. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/cosine-sim-attention.png +0 -0
  20. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/deepnorm.png +0 -0
  21. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/dynamic-pos-bias-linear.png +0 -0
  22. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/dynamic-pos-bias-log.png +0 -0
  23. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  24. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/dynamic-pos-bias.png +0 -0
  25. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/enhanced-recurrence.png +0 -0
  26. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/fcm.png +0 -0
  27. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/ffglu.png +0 -0
  28. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/flash-attention.png +0 -0
  29. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/gate_values.png +0 -0
  30. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/gating.png +0 -0
  31. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/length-extrapolation-scale.png +0 -0
  32. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/macaron-1.png +0 -0
  33. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/macaron-2.png +0 -0
  34. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/memory-transformer.png +0 -0
  35. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/normformer.png +0 -0
  36. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/pia.png +0 -0
  37. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/qknorm-analysis.png +0 -0
  38. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/resi_dual.png +0 -0
  39. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/residual_attn.png +0 -0
  40. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/rezero.png +0 -0
  41. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/rotary.png +0 -0
  42. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/sandwich-2.png +0 -0
  43. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/sandwich.png +0 -0
  44. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/sandwich_norm.png +0 -0
  45. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/scalenorm.png +0 -0
  46. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/talking-heads.png +0 -0
  47. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/topk-attention.png +0 -0
  48. {x_transformers-2.10.1 → x_transformers-2.11.0}/images/xval.png +0 -0
  49. {x_transformers-2.10.1 → x_transformers-2.11.0}/train_belief_state.py +0 -0
  50. {x_transformers-2.10.1 → x_transformers-2.11.0}/train_entropy_tokenizer.py +0 -0
  51. {x_transformers-2.10.1 → x_transformers-2.11.0}/train_enwik8.py +0 -0
  52. {x_transformers-2.10.1 → x_transformers-2.11.0}/train_gpt_vae.py +0 -0
  53. {x_transformers-2.10.1 → x_transformers-2.11.0}/train_length_extrapolate.py +0 -0
  54. {x_transformers-2.10.1 → x_transformers-2.11.0}/train_parity.py +0 -0
  55. {x_transformers-2.10.1 → x_transformers-2.11.0}/train_with_muon.py +0 -0
  56. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/__init__.py +0 -0
  57. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/gpt_vae.py +0 -0
  62. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/multi_input.py +0 -0
  63. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/neo_mlp.py +0 -0
  64. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  65. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/up_wrapper.py +0 -0
  66. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/x_transformers.py +0 -0
  67. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/xval.py +0 -0
{x_transformers-2.10.1 → x_transformers-2.11.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.10.1
+Version: 2.11.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2598,4 +2598,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Fleuret2025TheFT,
+title = {The Free Transformer},
+author = {François Fleuret},
+year = {2025},
+url = {https://api.semanticscholar.org/CorpusID:282210283}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.10.1 → x_transformers-2.11.0}/README.md
@@ -2549,4 +2549,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Fleuret2025TheFT,
+title = {The Free Transformer},
+author = {François Fleuret},
+year = {2025},
+url = {https://api.semanticscholar.org/CorpusID:282210283}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.10.1 → x_transformers-2.11.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.10.1"
+version = "2.11.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.10.1 → x_transformers-2.11.0}/tests/test_x_transformers.py
@@ -1,4 +1,5 @@
 import pytest
+param = pytest.mark.parametrize
 
 import torch
 from torch import nn
@@ -186,7 +187,7 @@ def test_average_pool_embed():
 
     assert logits.shape == (2, 20000)
 
-@pytest.mark.parametrize('num_cls_tokens', (1, 2))
+@param('num_cls_tokens', (1, 2))
 def test_cls_token(num_cls_tokens):
     model = TransformerWrapper(
         num_tokens = 20000,
@@ -234,7 +235,7 @@ def test_squeeze_logit_dim_one():
 
     assert logits.shape == (2,)
 
-@pytest.mark.parametrize('depth', (4, 5))
+@param('depth', (4, 5))
 def test_unet_skip(depth):
 
     model = TransformerWrapper(
@@ -294,7 +295,7 @@ def test_mos():
 
     eval_logits = model(x)
 
-@pytest.mark.parametrize('attn_one_kv_head', (True, False))
+@param('attn_one_kv_head', (True, False))
 def test_l2_distance(attn_one_kv_head):
 
     model = TransformerWrapper(
@@ -331,7 +332,7 @@ def test_reinject_input():
 
     model(x) # (1, 1024, 20000)
 
-@pytest.mark.parametrize('learned_value_residual_mix', (False, True))
+@param('learned_value_residual_mix', (False, True))
 def test_value_residual(
     learned_value_residual_mix: bool
 ):
@@ -352,7 +353,7 @@ def test_value_residual(
 
     model(x)
 
-@pytest.mark.parametrize('has_num_mem_kv', (False, True))
+@param('has_num_mem_kv', (False, True))
 def test_forgetting_transformer(
     has_num_mem_kv: bool
 ):
@@ -388,7 +389,7 @@ def test_neo_mlp():
     out = mlp(x)
     assert out.shape == (3, 7)
 
-@pytest.mark.parametrize('flash', (True, False))
+@param('flash', (True, False))
 def test_custom_alibi(flash: bool):
 
     model = TransformerWrapper(
@@ -409,7 +410,7 @@ def test_custom_alibi(flash: bool):
 
     logits = model(x, pos = pos)
 
-@pytest.mark.parametrize('rotary_xpos', (True, False))
+@param('rotary_xpos', (True, False))
 def test_custom_rotary_pos_emb(rotary_xpos):
     from einops import repeat
 
@@ -433,7 +434,7 @@ def test_custom_rotary_pos_emb(rotary_xpos):
     logits2 = model(x)
     assert torch.allclose(logits1, logits2)
 
-@pytest.mark.parametrize('flash', (True, False))
+@param('flash', (True, False))
 def test_custom_alibi_across_heads(flash: bool):
     model = Decoder(
         dim = 512,
@@ -455,7 +456,7 @@ def test_custom_alibi_across_heads(flash: bool):
 
     embed = model(x, pos = pos)
 
-@pytest.mark.parametrize('embedder_type', ('embedding', 'none', 'custom'))
+@param('embedder_type', ('embedding', 'none', 'custom'))
 def test_embedder(embedder_type):
     num_tokens = 20000
     dim = 128
@@ -502,7 +503,7 @@ def test_embedder(embedder_type):
     assert output.shape == (2, 1024, 20000)
 
 
-@pytest.mark.parametrize("to_logits", ('linear', 'none', 'pointer'))
+@param("to_logits", ('linear', 'none', 'pointer'))
 def test_to_logits(to_logits):
     num_tokens = 20000
     dim = 128
@@ -560,8 +561,8 @@ def test_laser():
 
     model(x)
 
-@pytest.mark.parametrize('self_attn_custom_pos', (True, False))
-@pytest.mark.parametrize('cross_attn_rotary', (True, False))
+@param('self_attn_custom_pos', (True, False))
+@param('cross_attn_rotary', (True, False))
 def test_cross_attn_rotary(
     self_attn_custom_pos: bool,
     cross_attn_rotary: bool
@@ -593,7 +594,7 @@ def test_cross_attn_rotary(
         context_mask = context_mask
     )
 
-@pytest.mark.parametrize('tanh', (True, False))
+@param('tanh', (True, False))
 def test_hyper_connections(tanh):
 
     model = TransformerWrapper(
@@ -614,7 +615,7 @@ def test_hyper_connections(tanh):
 
     model(x)
 
-@pytest.mark.parametrize('hybrid_axial_dim', (1, 4))
+@param('hybrid_axial_dim', (1, 4))
 def test_hybrid(hybrid_axial_dim):
     from torch.nn import GRU
 
@@ -770,8 +771,8 @@ def test_multi_latent_attention():
 
     model(x)
 
-@pytest.mark.parametrize('num_residual_streams', (1, 4))
-@pytest.mark.parametrize('integrate_layers', (False, True))
+@param('num_residual_streams', (1, 4))
+@param('integrate_layers', (False, True))
 def test_lime(
     num_residual_streams,
     integrate_layers
@@ -792,10 +793,10 @@ def test_lime(
 
     model(x)
 
-@pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
-@pytest.mark.parametrize('goal_suffix', (False, True))
-@pytest.mark.parametrize('pred_distance', (False, True))
-@pytest.mark.parametrize('variable_len', (False, True))
+@param('backward_ar_loss_weight', (1., 0.5))
+@param('goal_suffix', (False, True))
++@param('pred_distance', (False, True))
+@param('variable_len', (False, True))
 def test_belief_state_wrapper(
     backward_ar_loss_weight,
     goal_suffix,
@@ -867,7 +868,7 @@ def test_dynamic_tanh():
 
     model(x)
 
-@pytest.mark.parametrize('var_length', (False, True))
+@param('var_length', (False, True))
 def test_entropy_based_tokenizer(
     var_length
 ):
@@ -966,9 +967,9 @@ def test_ff_deep_embed():
 
    assert logits.shape == (2, 1024, 20000)
 
-@pytest.mark.parametrize('probabilistic', (False, True))
-@pytest.mark.parametrize('cache_kv', (False, True))
-@pytest.mark.parametrize('rollout_steps', (1, 4))
+@param('probabilistic', (False, True))
+@param('cache_kv', (False, True))
+@param('rollout_steps', (1, 4))
 def test_continuous(
     probabilistic,
     cache_kv,
@@ -1012,7 +1013,7 @@ def test_continuous(
     generated = model.generate(start_emb, 17, cache_kv = cache_kv) # (17, 777)
     assert generated.shape == (17, 777)
 
-@pytest.mark.parametrize('add_continuous_pred_head', (False, True))
+@param('add_continuous_pred_head', (False, True))
 def test_autoregressive_wrapper(
     add_continuous_pred_head
 ):
@@ -1100,7 +1101,7 @@ def add_attn_pool():
 
     assert intermediates.attn_pooled_tokens.shape[1] == 3
 
-@pytest.mark.parametrize('keep_buffer_on_cpu', (False, True))
+@param('keep_buffer_on_cpu', (False, True))
 def test_up(
     keep_buffer_on_cpu
 ):
@@ -1126,7 +1127,7 @@ def test_up(
     loss = up_wrapper()
     loss.backward()
 
-@pytest.mark.parametrize('stochastic', (False, True))
+@param('stochastic', (False, True))
 def test_beam_search(stochastic):
     from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper
 
@@ -1154,8 +1155,8 @@ def test_beam_search(stochastic):
     assert scores.shape == (4, 2)
 
 
-@pytest.mark.parametrize('num_pooled_tokens', (1, 3))
-@pytest.mark.parametrize('attn_pool_depth', (1, 3))
+@param('num_pooled_tokens', (1, 3))
+@param('attn_pool_depth', (1, 3))
 def test_attn_pooler(
     num_pooled_tokens,
     attn_pool_depth
@@ -1288,7 +1289,7 @@ def test_accept_layer_intermediates():
 
     assert embeds.shape == (3, 32, 512)
 
-@pytest.mark.parametrize('use_loss_weight', (False, True))
+@param('use_loss_weight', (False, True))
 def test_simple_mdlm(
     use_loss_weight
 ):
@@ -1386,7 +1387,10 @@ def test_stochastic_attn():
     log_probs = log_prob_from_hard_attend(intermediate)
     assert log_probs.shape == (1, 8, 1024)
 
-def test_attn_negative_weights():
+@param('head_learned_sink', (True, False))
+def test_attn_negative_weights(
+    head_learned_sink
+):
     from x_transformers import TransformerWrapper, Decoder
 
     model = TransformerWrapper(
@@ -1396,7 +1400,8 @@ def test_attn_negative_weights():
             dim = 512,
             depth = 12,
             heads = 8,
-            attn_cog_signed = True
+            attn_cog_signed = True,
+            attn_head_learned_sink = head_learned_sink
         ),
     )
 
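Every hunk above except the last one is the same mechanical rename: `pytest.mark.parametrize` becomes the short alias `param` defined once at the top of the file, while the final hunk threads the new `head_learned_sink` flag through the negative-weights test. For readers unfamiliar with the aliasing pattern, a minimal, self-contained illustration (the test below is hypothetical, not part of the package):

```python
import pytest

# alias the decorator once, then reuse it, as the test suite above now does
param = pytest.mark.parametrize

@param('flash', (True, False))
@param('depth', (4, 5))
def test_example(flash, depth):
    # pytest expands the stacked decorators into all four (flash, depth) combinations
    assert isinstance(flash, bool) and depth in (4, 5)
```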
{x_transformers-2.10.1 → x_transformers-2.11.0}/train_copy.py
@@ -12,32 +12,35 @@ GENERATE_EVERY = 100
 NUM_TOKENS = 16 + 2
 ENC_SEQ_LEN = 32
 DEC_SEQ_LEN = 64 + 1
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # helpers
 
 def cycle():
     while True:
-        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
-        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
+        prefix = torch.ones((BATCH_SIZE, 1)).long().to(DEVICE)
+        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().to(DEVICE)
         tgt = torch.cat((prefix, src, src), 1)
-        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
+        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().to(DEVICE)
         yield (src, tgt, src_mask)
 
 # instantiate model
 
 model = XTransformer(
-    dim = 512,
+    dim = 128,
     tie_token_emb = True,
     return_tgt_loss = True,
     enc_num_tokens=NUM_TOKENS,
     enc_depth = 3,
     enc_heads = 8,
     enc_max_seq_len = ENC_SEQ_LEN,
+    enc_attn_cog_signed = True,
     dec_num_tokens = NUM_TOKENS,
     dec_depth = 3,
     dec_heads = 8,
-    dec_max_seq_len = DEC_SEQ_LEN
-).cuda()
+    dec_max_seq_len = DEC_SEQ_LEN,
+    dec_attn_cog_signed = True
+).to(DEVICE)
 
 # optimizer
 
@@ -61,10 +64,10 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         model.eval()
         src, _, src_mask = next(cycle())
         src, src_mask = src[:1], src_mask[:1]
-        start_tokens = (torch.ones((1, 1)) * 1).long().cuda()
+        start_tokens = (torch.ones((1, 1)) * 1).long().to(DEVICE)
 
         sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
-        incorrects = (src != sample).abs().sum()
+        incorrects = (src != sample).long().abs().sum()
 
         print(f"input: ", src)
         print(f"predicted output: ", sample)
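The train_copy.py changes make the script device-agnostic (a `DEVICE` constant instead of hard-coded `.cuda()` calls), shrink the model, enable `attn_cog_signed` on both encoder and decoder, and cast the boolean mismatch mask to integers before reducing it, presumably because elementwise ops such as `abs` are not defined for bool tensors. A small sketch of the device and mismatch-count idioms in isolation (tensor names are illustrative, not from the script):

```python
import torch

# run on GPU when available, otherwise fall back to CPU
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

src = torch.randint(2, 18, (1, 32)).to(DEVICE)
sample = torch.randint(2, 18, (1, 32)).to(DEVICE)

# cast the bool mask to long before reducing, mirroring the incorrects fix above
incorrects = (src != sample).long().sum()
print(incorrects.item())
```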
x_transformers-2.11.0/train_free.py (new file)
@@ -0,0 +1,134 @@
+# /// script
+# dependencies = [
+#   "tqdm",
+#   "x-transformers>=2.11.0",
+# ]
+# ///
+
+from x_transformers.free_transformer import FreeTransformer
+
+from math import log
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+import torch.optim as optim
+from torch import tensor
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 1e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 250
+GENERATE_LENGTH = 512
+PRIME_LENGTH = 32
+SEQ_LEN = 512
+
+LATENT_BITS = 8
+NAT = log(2)
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+model = FreeTransformer(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    dim = 512,
+    heads = 8,
+    rotary_pos_emb = True,
+    dec_head_depth = 4,
+    dec_tail_depth = 4,
+    enc_depth = 3,
+    kl_loss_weight = 1.,
+    kl_loss_threshold = NAT,
+    latent_bits = LATENT_BITS
+).cuda()
+
+rand_index = torch.randint(0, 2 ** LATENT_BITS, ())
+latents = F.one_hot(rand_index, 2 ** LATENT_BITS).float().cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    train_x, valid_x = np.split(data, [int(90e6)])
+    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss, (ar_loss, vae_kl_loss) = model(next(train_loader), return_all_losses = True)
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+
+    print(f'training loss: {ar_loss.item():.4f}\t| kl loss: {vae_kl_loss.item():.4f}')
+
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss, (ar_loss, _) = model(next(val_loader), return_all_losses = True)
+            print(f'validation loss: {ar_loss.item():.4f}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:PRIME_LENGTH]
+        prime = decode_tokens(inp)
+        print('%s \n\n %s' % (prime, '*' * 100))
+
+        sample = model.generate(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            latents = latents
+        )
+
+        output_str = decode_tokens(sample)
+
+        print(f'\n\nlatent {rand_index.tolist()} - ', output_str)
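For orientation, here is a condensed sketch of the API this script exercises. The constructor arguments and the `return_all_losses` / `generate` signatures come from the files in this diff; the tiny sizes, CPU execution, and random data are stand-ins so the snippet runs without enwik8:

```python
import torch
import torch.nn.functional as F
from x_transformers.free_transformer import FreeTransformer

model = FreeTransformer(
    num_tokens = 256,
    max_seq_len = 128,
    dim = 64,
    heads = 4,
    rotary_pos_emb = True,
    dec_head_depth = 1,
    dec_tail_depth = 1,
    enc_depth = 1,
    latent_bits = 4
)

seq = torch.randint(0, 256, (2, 128))

# total loss = autoregressive loss + weighted KL; both pieces can be returned
loss, (ar_loss, kl_loss) = model(seq, return_all_losses = True)
loss.backward()

# pick one of the 2 ** latent_bits codes and condition generation on it
latent = F.one_hot(torch.tensor(3), 2 ** 4).float()
sample = model.generate(prompts = seq[0, :16], seq_len = 64, latents = latent)
print(sample.shape)  # torch.Size([64])
```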
{x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/attend.py
@@ -549,6 +549,11 @@ class Attend(Module):
         if self.head_learned_sink:
             # add learned attention sink
             attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
+
+            if self.cog_signed:
+                attn_sink, attn_sink_sign = attn_sink.abs(), attn_sink.sign()
+                sim_sign = cat((attn_sink_sign, sim_sign), dim = -1)
+
             sim = cat((attn_sink, sim), dim = -1)
 
         pre_softmax_attn = sim
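The new branch keeps only the magnitude of the learned sink in the score matrix and folds its sign into the running sign tensor (`sim_sign`), so the sink column participates in the cog-signed attention like any other score. A toy illustration of the abs/sign split (the standalone `sim` / `sim_sign` tensors and shapes here are made up; in the module they already exist at this point):

```python
import torch
from torch import cat

b, h, i, j = 1, 2, 4, 4

# pretend attention scores, split into magnitude and sign as the
# surrounding cog-signed code does before this point
sim = torch.randn(b, h, i, j)
sim, sim_sign = sim.abs(), sim.sign()

# a learned per-head sink score, broadcast to (b, h, i, 1) as in the diff
head_attn_sink = torch.randn(h)
attn_sink = head_attn_sink.view(1, h, 1, 1).expand(b, h, i, 1)

# keep the magnitude for the softmax, fold the sign into the sign tensor
attn_sink, attn_sink_sign = attn_sink.abs(), attn_sink.sign()
sim_sign = cat((attn_sink_sign, sim_sign), dim = -1)
sim = cat((attn_sink, sim), dim = -1)

print(sim.shape, sim_sign.shape)  # both torch.Size([1, 2, 4, 5])
```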
{x_transformers-2.10.1 → x_transformers-2.11.0}/x_transformers/autoregressive_wrapper.py
@@ -43,6 +43,10 @@ def log(t, eps = 1e-20):
 def gumbel_noise(t):
     return -log(-log(torch.rand_like(t)))
 
+def gumbel_sample(logits, temperature = 1., eps = 1e-6):
+    noise = gumbel_noise(logits)
+    return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)
+
 # function for modifying all the cached key / values
 
 def modify_cached_kv(cache, fn):
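The new `gumbel_sample` helper draws a categorical sample by adding Gumbel noise to the temperature-scaled logits and taking the argmax, which is equivalent to sampling from the corresponding softmax distribution; as temperature approaches zero it reduces to greedy argmax. A self-contained sanity check (the helpers are re-declared so the snippet runs on its own):

```python
import torch

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def gumbel_noise(t):
    return -log(-log(torch.rand_like(t)))

def gumbel_sample(logits, temperature = 1., eps = 1e-6):
    noise = gumbel_noise(logits)
    return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)

logits = torch.tensor([[2.0, 0.5, -1.0]])

# empirical sample frequencies approach softmax(logits) ~ [0.79, 0.18, 0.04]
samples = torch.stack([gumbel_sample(logits) for _ in range(1000)])
print(samples.flatten().bincount(minlength = 3) / 1000.)

# a near-zero temperature collapses to the argmax token
print(gumbel_sample(logits, temperature = 1e-6))  # tensor([0])
```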
x_transformers-2.11.0/x_transformers/free_transformer.py (new file)
@@ -0,0 +1,330 @@
+from __future__ import annotations
+
+# https://arxiv.org/abs/2510.17558
+# François Fleuret
+# https://www.youtube.com/watch?v=Nao16-6l6dQ
+
+import math
+
+import torch
+from torch import nn, Tensor, is_tensor, tensor, arange
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+
+from x_transformers.x_transformers import (
+    Encoder,
+    Decoder,
+    TransformerWrapper
+)
+
+from x_transformers.autoregressive_wrapper import (
+    gumbel_sample,
+    top_p,
+    top_k
+)
+
+from einops.layers.torch import Rearrange, Reduce
+from einops import rearrange, reduce, repeat, einsum, pack, unpack
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
+
+def pack_with_inverse(t, pattern):
+    packed, ps = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        unpacked, = unpack(out, ps, inv_pattern)
+        return unpacked
+
+    return packed, inverse
+
+# binary mapper
+
+NAT = math.log(2)
+
+def binary_entropy(logits):
+    prob = logits.sigmoid()
+    not_prob = 1. - prob
+    return -(prob * F.logsigmoid(logits) + not_prob * F.logsigmoid(-logits)).sum(dim = -1)
+
+class BinaryMapper(Module):
+    def __init__(
+        self,
+        bits = 1,
+        kl_loss_threshold = NAT # 1 bit
+    ):
+        super().__init__()
+
+        self.bits = bits
+        self.num_codes = 2 ** bits
+        self.kl_loss_threshold = kl_loss_threshold
+
+        power_two = 2 ** arange(bits)
+        codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
+
+        self.register_buffer('power_two', power_two, persistent = False)
+        self.register_buffer('codes', codes, persistent = False)
+
+    def forward(
+        self,
+        logits,
+        temperature = 1.,
+        straight_through = None
+    ):
+        straight_through = default(straight_through, self.training)
+
+        assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
+
+        # temperature and prob for sampling
+
+        prob_for_sample = (logits / temperature).sigmoid()
+
+        # sampling
+
+        sampled_bits = (torch.rand_like(logits) <= prob_for_sample).long()
+        indices = (self.power_two * sampled_bits).sum(dim = -1)
+
+        one_hot = F.one_hot(indices, self.num_codes).float()
+
+        # return hard one hot if not training or overridden
+
+        if not straight_through:
+            return one_hot
+
+        # calculate negative entropy
+
+        kl_div = self.bits * NAT - binary_entropy(logits)
+        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
+
+        # get the soft G for the gradients and do a straight through
+
+        soft_G = (
+            einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
+            einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
+        ).exp()
+
+        # straight through
+
+        one_hot = one_hot + soft_G - soft_G.detach()
+
+        return one_hot, aux_kl_loss
+
+# classes
+
+class FreeTransformer(Module):
+    def __init__(
+        self,
+        *,
+        num_tokens,
+        dim,
+        dec_head_depth,
+        dec_tail_depth,
+        enc_depth,
+        max_seq_len,
+        dim_latent = None,
+        attn_dim_head = 64,
+        heads = 8,
+        latent_bits = 16,
+        kl_loss_threshold = NAT,
+        binary_mapper_kwargs: dict = dict(),
+        enc_kwargs: dict = dict(),
+        dec_kwargs: dict = dict(),
+        kl_loss_weight = 1.,
+        pad_id = -1,
+        encoder: Module | None = None,
+        **kwargs
+    ):
+        super().__init__()
+        dim_latent = default(dim_latent, dim)
+
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
+        self.token_unembed = nn.Linear(dim, num_tokens, bias = False)
+
+        if not exists(encoder):
+            encoder = Encoder(
+                dim = dim,
+                depth = enc_depth,
+                attn_dim_head = attn_dim_head,
+                heads = heads,
+                **kwargs,
+                **enc_kwargs
+            )
+
+        self.encoder = encoder
+
+        self.to_latent_bit_logits = nn.Sequential(
+            Reduce('b n d -> b d', 'mean'),
+            nn.Linear(dim, latent_bits, bias = False),
+        )
+
+        self.binary_mapper = BinaryMapper(
+            latent_bits,
+            kl_loss_threshold,
+            **binary_mapper_kwargs
+        )
+
+        self.from_latent_to_condition = nn.Sequential(
+            nn.Linear(2 ** latent_bits, dim, bias = False),
+            Rearrange('b d -> b 1 d')
+        )
+
+        self.decoder_head = Decoder(
+            dim = dim,
+            depth = dec_head_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            pre_norm_has_final_norm = False,
+            **kwargs,
+            **dec_kwargs
+        )
+
+        self.decoder_tail = Decoder(
+            dim = dim,
+            depth = dec_tail_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            pre_norm_has_final_norm = True,
+            **kwargs,
+            **dec_kwargs
+        )
+
+        self.pad_id = pad_id
+
+        self.kl_loss_weight = kl_loss_weight
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def encode_to_latents(
+        self,
+        seq,
+        mask = None,
+        return_kl_loss = False
+    ):
+        pooled = self.encoder(seq, mask = mask)
+
+        bit_logits = self.to_latent_bit_logits(pooled)
+
+        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, straight_through = True)
+
+        if not return_kl_loss:
+            return one_hot_latents
+
+        return one_hot_latents, kl_loss
+
+    @torch.no_grad()
+    def generate(
+        self,
+        prompts,
+        seq_len,
+        latents = None,
+        filter_logits_fn = top_p,
+        logit_filter_kwargs: dict = dict(thres = 0.9)
+    ):
+        prompts, inverse_pack = pack_with_inverse(prompts, '* n')
+
+        batch = prompts.shape[0]
+
+        # prepend embeds
+
+        condition = None
+        if exists(latents):
+            if not is_tensor(latents):
+                latents = tensor(latents, device = self.device)
+
+            if latents.ndim == 1: # repeat latents
+                latents = repeat(latents, 'd -> b d', b = batch)
+
+            condition = self.from_latent_to_condition(latents)
+
+        # generated
+
+        prompt_len = prompts.shape[-1]
+
+        generated = prompts
+
+        tokens = self.token_emb(generated)
+
+        for _ in range(max(0, seq_len - prompt_len)):
+
+            head_embed = self.decoder_head(tokens)
+
+            if exists(condition):
+                head_embed = head_embed + condition
+
+            tail_embed = self.decoder_tail(head_embed)
+
+            tail_embed = tail_embed[:, -1]
+
+            logits = self.token_unembed(tail_embed)
+
+            logits = filter_logits_fn(logits, **logit_filter_kwargs)
+
+            sampled = gumbel_sample(logits)
+
+            generated, _ = pack((generated, sampled), 'b *')
+            tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')
+
+        return inverse_pack(generated)
+
+    def forward(
+        self,
+        seq,
+        return_all_losses = False
+    ):
+        batch, device = seq.shape[0], seq.device
+
+        seq, labels = seq[:, :-1], seq[:, 1:]
+
+        encoder_mask = seq != self.pad_id
+
+        tokens = self.token_emb(seq)
+
+        # decoder head
+
+        tokens = self.decoder_head(tokens)
+
+        # get latent Z
+
+        latents, kl_loss = self.encode_to_latents(tokens, mask = encoder_mask, return_kl_loss = True)
+
+        condition = self.from_latent_to_condition(latents)
+
+        # decoder tail, conditioned on the sampled latent
+
+        tokens = self.decoder_tail(tokens + condition)
+
+        # cross entropy loss
+
+        logits = self.token_unembed(tokens)
+
+        ar_loss = F.cross_entropy(
+            rearrange(logits, 'b n l -> b l n'),
+            labels,
+            ignore_index = self.pad_id
+        )
+
+        # return losses
+
+        total_loss = (
+            ar_loss +
+            kl_loss * self.kl_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        losses = (ar_loss, kl_loss)
+
+        return total_loss, losses
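To make the `BinaryMapper` above concrete: during training it returns a straight-through one-hot over all `2 ** bits` codes together with a hinge-style penalty, `relu(KL - kl_loss_threshold).mean()`, where the KL term measures how far the factorized Bernoulli bit distribution is from uniform; outside the straight-through path it returns just a hard sample. A quick sketch against the module as added in this release (the batch size, bit count, and random logits are arbitrary):

```python
import torch
from x_transformers.free_transformer import BinaryMapper

mapper = BinaryMapper(bits = 4)  # 2 ** 4 = 16 possible latent codes

logits = torch.randn(2, 4)  # (batch, bits) logits, e.g. from the encoder

# training path: straight-through one-hot plus the auxiliary KL penalty
one_hot, kl_loss = mapper(logits, straight_through = True)
print(one_hot.shape)  # torch.Size([2, 16])
print(kl_loss)        # scalar tensor, zero until the KL exceeds the threshold

# inference path: a hard one-hot sample, no auxiliary loss
hard = mapper(logits, straight_through = False)
print(hard.argmax(dim = -1))  # the sampled code index per batch element
```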
The remaining files listed above (+0 -0) are unchanged between 2.10.1 and 2.11.0.