x-transformers 2.10.0__tar.gz → 2.10.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (66)
  1. {x_transformers-2.10.0 → x_transformers-2.10.2}/PKG-INFO +8 -7
  2. {x_transformers-2.10.0 → x_transformers-2.10.2}/README.md +7 -6
  3. {x_transformers-2.10.0 → x_transformers-2.10.2}/pyproject.toml +1 -1
  4. {x_transformers-2.10.0 → x_transformers-2.10.2}/tests/test_x_transformers.py +37 -32
  5. {x_transformers-2.10.0 → x_transformers-2.10.2}/train_copy.py +10 -8
  6. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/attend.py +6 -3
  7. {x_transformers-2.10.0 → x_transformers-2.10.2}/.github/FUNDING.yml +0 -0
  8. {x_transformers-2.10.0 → x_transformers-2.10.2}/.github/workflows/python-publish.yml +0 -0
  9. {x_transformers-2.10.0 → x_transformers-2.10.2}/.github/workflows/python-test.yaml +0 -0
  10. {x_transformers-2.10.0 → x_transformers-2.10.2}/.gitignore +0 -0
  11. {x_transformers-2.10.0 → x_transformers-2.10.2}/LICENSE +0 -0
  12. {x_transformers-2.10.0 → x_transformers-2.10.2}/data/README.md +0 -0
  13. {x_transformers-2.10.0 → x_transformers-2.10.2}/data/enwik8.gz +0 -0
  14. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/all-attention.png +0 -0
  15. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/deepnorm.png +0 -0
  18. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/fcm.png +0 -0
  24. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/ffglu.png +0 -0
  25. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/flash-attention.png +0 -0
  26. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/gate_values.png +0 -0
  27. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/gating.png +0 -0
  28. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/macaron-1.png +0 -0
  30. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/macaron-2.png +0 -0
  31. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/normformer.png +0 -0
  33. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/pia.png +0 -0
  34. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/resi_dual.png +0 -0
  36. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/residual_attn.png +0 -0
  37. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/rezero.png +0 -0
  38. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/rotary.png +0 -0
  39. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/sandwich.png +0 -0
  41. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/scalenorm.png +0 -0
  43. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/talking-heads.png +0 -0
  44. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/topk-attention.png +0 -0
  45. {x_transformers-2.10.0 → x_transformers-2.10.2}/images/xval.png +0 -0
  46. {x_transformers-2.10.0 → x_transformers-2.10.2}/train_belief_state.py +0 -0
  47. {x_transformers-2.10.0 → x_transformers-2.10.2}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.10.0 → x_transformers-2.10.2}/train_enwik8.py +0 -0
  49. {x_transformers-2.10.0 → x_transformers-2.10.2}/train_gpt_vae.py +0 -0
  50. {x_transformers-2.10.0 → x_transformers-2.10.2}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.10.0 → x_transformers-2.10.2}/train_parity.py +0 -0
  52. {x_transformers-2.10.0 → x_transformers-2.10.2}/train_with_muon.py +0 -0
  53. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/__init__.py +0 -0
  54. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/autoregressive_wrapper.py +0 -0
  55. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/belief_state_wrapper.py +0 -0
  56. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/continuous.py +0 -0
  57. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/dpo.py +0 -0
  58. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/entropy_based_tokenizer.py +0 -0
  59. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/gpt_vae.py +0 -0
  60. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/multi_input.py +0 -0
  61. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/neo_mlp.py +0 -0
  62. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  63. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/up_wrapper.py +0 -0
  64. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/x_transformers.py +0 -0
  65. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  66. {x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/xval.py +0 -0
{x_transformers-2.10.0 → x_transformers-2.10.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.10.0
+Version: 2.10.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2587,13 +2587,14 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 ```
 
 ```bibtex
-@inproceedings{anonymous2025more,
-    title = {More Expressive Attention with Negative Weights},
-    author = {Anonymous},
-    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+@misc{lv2025expressiveattentionnegativeweights,
+    title = {More Expressive Attention with Negative Weights},
+    author = {Ang Lv and Ruobing Xie and Shuaipeng Li and Jiayi Liao and Xingwu Sun and Zhanhui Kang and Di Wang and Rui Yan},
     year = {2025},
-    url = {https://openreview.net/forum?id=ezRrwwbxd0},
-    note = {under review}
+    eprint = {2411.07176},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL},
+    url = {https://arxiv.org/abs/2411.07176},
 }
 ```
 
{x_transformers-2.10.0 → x_transformers-2.10.2}/README.md

@@ -2538,13 +2538,14 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 ```
 
 ```bibtex
-@inproceedings{anonymous2025more,
-    title = {More Expressive Attention with Negative Weights},
-    author = {Anonymous},
-    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+@misc{lv2025expressiveattentionnegativeweights,
+    title = {More Expressive Attention with Negative Weights},
+    author = {Ang Lv and Ruobing Xie and Shuaipeng Li and Jiayi Liao and Xingwu Sun and Zhanhui Kang and Di Wang and Rui Yan},
     year = {2025},
-    url = {https://openreview.net/forum?id=ezRrwwbxd0},
-    note = {under review}
+    eprint = {2411.07176},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL},
+    url = {https://arxiv.org/abs/2411.07176},
 }
 ```
 
{x_transformers-2.10.0 → x_transformers-2.10.2}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.10.0"
+version = "2.10.2"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.10.0 → x_transformers-2.10.2}/tests/test_x_transformers.py

@@ -1,4 +1,5 @@
 import pytest
+param = pytest.mark.parametrize
 
 import torch
 from torch import nn

@@ -186,7 +187,7 @@ def test_average_pool_embed():
 
     assert logits.shape == (2, 20000)
 
-@pytest.mark.parametrize('num_cls_tokens', (1, 2))
+@param('num_cls_tokens', (1, 2))
 def test_cls_token(num_cls_tokens):
     model = TransformerWrapper(
         num_tokens = 20000,

@@ -234,7 +235,7 @@ def test_squeeze_logit_dim_one():
 
     assert logits.shape == (2,)
 
-@pytest.mark.parametrize('depth', (4, 5))
+@param('depth', (4, 5))
 def test_unet_skip(depth):
 
     model = TransformerWrapper(

@@ -294,7 +295,7 @@ def test_mos():
 
     eval_logits = model(x)
 
-@pytest.mark.parametrize('attn_one_kv_head', (True, False))
+@param('attn_one_kv_head', (True, False))
 def test_l2_distance(attn_one_kv_head):
 
     model = TransformerWrapper(

@@ -331,7 +332,7 @@ def test_reinject_input():
 
     model(x) # (1, 1024, 20000)
 
-@pytest.mark.parametrize('learned_value_residual_mix', (False, True))
+@param('learned_value_residual_mix', (False, True))
 def test_value_residual(
     learned_value_residual_mix: bool
 ):

@@ -352,7 +353,7 @@ def test_value_residual(
 
     model(x)
 
-@pytest.mark.parametrize('has_num_mem_kv', (False, True))
+@param('has_num_mem_kv', (False, True))
 def test_forgetting_transformer(
     has_num_mem_kv: bool
 ):

@@ -388,7 +389,7 @@ def test_neo_mlp():
     out = mlp(x)
     assert out.shape == (3, 7)
 
-@pytest.mark.parametrize('flash', (True, False))
+@param('flash', (True, False))
 def test_custom_alibi(flash: bool):
 
     model = TransformerWrapper(

@@ -409,7 +410,7 @@ def test_custom_alibi(flash: bool):
 
     logits = model(x, pos = pos)
 
-@pytest.mark.parametrize('rotary_xpos', (True, False))
+@param('rotary_xpos', (True, False))
 def test_custom_rotary_pos_emb(rotary_xpos):
     from einops import repeat
 

@@ -433,7 +434,7 @@ def test_custom_rotary_pos_emb(rotary_xpos):
     logits2 = model(x)
     assert torch.allclose(logits1, logits2)
 
-@pytest.mark.parametrize('flash', (True, False))
+@param('flash', (True, False))
 def test_custom_alibi_across_heads(flash: bool):
     model = Decoder(
         dim = 512,

@@ -455,7 +456,7 @@ def test_custom_alibi_across_heads(flash: bool):
 
     embed = model(x, pos = pos)
 
-@pytest.mark.parametrize('embedder_type', ('embedding', 'none', 'custom'))
+@param('embedder_type', ('embedding', 'none', 'custom'))
 def test_embedder(embedder_type):
     num_tokens = 20000
     dim = 128

@@ -502,7 +503,7 @@ def test_embedder(embedder_type):
     assert output.shape == (2, 1024, 20000)
 
 
-@pytest.mark.parametrize("to_logits", ('linear', 'none', 'pointer'))
+@param("to_logits", ('linear', 'none', 'pointer'))
 def test_to_logits(to_logits):
     num_tokens = 20000
     dim = 128

@@ -560,8 +561,8 @@ def test_laser():
 
     model(x)
 
-@pytest.mark.parametrize('self_attn_custom_pos', (True, False))
-@pytest.mark.parametrize('cross_attn_rotary', (True, False))
+@param('self_attn_custom_pos', (True, False))
+@param('cross_attn_rotary', (True, False))
 def test_cross_attn_rotary(
     self_attn_custom_pos: bool,
     cross_attn_rotary: bool

@@ -593,7 +594,7 @@ def test_cross_attn_rotary(
         context_mask = context_mask
     )
 
-@pytest.mark.parametrize('tanh', (True, False))
+@param('tanh', (True, False))
 def test_hyper_connections(tanh):
 
     model = TransformerWrapper(

@@ -614,7 +615,7 @@ def test_hyper_connections(tanh):
 
     model(x)
 
-@pytest.mark.parametrize('hybrid_axial_dim', (1, 4))
+@param('hybrid_axial_dim', (1, 4))
 def test_hybrid(hybrid_axial_dim):
     from torch.nn import GRU
 

@@ -770,8 +771,8 @@ def test_multi_latent_attention():
 
     model(x)
 
-@pytest.mark.parametrize('num_residual_streams', (1, 4))
-@pytest.mark.parametrize('integrate_layers', (False, True))
+@param('num_residual_streams', (1, 4))
+@param('integrate_layers', (False, True))
 def test_lime(
     num_residual_streams,
     integrate_layers

@@ -792,10 +793,10 @@ def test_lime(
 
     model(x)
 
-@pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
-@pytest.mark.parametrize('goal_suffix', (False, True))
-@pytest.mark.parametrize('pred_distance', (False, True))
-@pytest.mark.parametrize('variable_len', (False, True))
+@param('backward_ar_loss_weight', (1., 0.5))
+@param('goal_suffix', (False, True))
+@param('pred_distance', (False, True))
+@param('variable_len', (False, True))
 def test_belief_state_wrapper(
     backward_ar_loss_weight,
     goal_suffix,

@@ -867,7 +868,7 @@ def test_dynamic_tanh():
 
     model(x)
 
-@pytest.mark.parametrize('var_length', (False, True))
+@param('var_length', (False, True))
 def test_entropy_based_tokenizer(
     var_length
 ):

@@ -966,9 +967,9 @@ def test_ff_deep_embed():
 
     assert logits.shape == (2, 1024, 20000)
 
-@pytest.mark.parametrize('probabilistic', (False, True))
-@pytest.mark.parametrize('cache_kv', (False, True))
-@pytest.mark.parametrize('rollout_steps', (1, 4))
+@param('probabilistic', (False, True))
+@param('cache_kv', (False, True))
+@param('rollout_steps', (1, 4))
 def test_continuous(
     probabilistic,
     cache_kv,

@@ -1012,7 +1013,7 @@ def test_continuous(
     generated = model.generate(start_emb, 17, cache_kv = cache_kv) # (17, 777)
     assert generated.shape == (17, 777)
 
-@pytest.mark.parametrize('add_continuous_pred_head', (False, True))
+@param('add_continuous_pred_head', (False, True))
 def test_autoregressive_wrapper(
     add_continuous_pred_head
 ):

@@ -1100,7 +1101,7 @@ def add_attn_pool():
 
     assert intermediates.attn_pooled_tokens.shape[1] == 3
 
-@pytest.mark.parametrize('keep_buffer_on_cpu', (False, True))
+@param('keep_buffer_on_cpu', (False, True))
 def test_up(
     keep_buffer_on_cpu
 ):

@@ -1126,7 +1127,7 @@ def test_up(
     loss = up_wrapper()
     loss.backward()
 
-@pytest.mark.parametrize('stochastic', (False, True))
+@param('stochastic', (False, True))
 def test_beam_search(stochastic):
     from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper
 

@@ -1154,8 +1155,8 @@ def test_beam_search(stochastic):
     assert scores.shape == (4, 2)
 
 
-@pytest.mark.parametrize('num_pooled_tokens', (1, 3))
-@pytest.mark.parametrize('attn_pool_depth', (1, 3))
+@param('num_pooled_tokens', (1, 3))
+@param('attn_pool_depth', (1, 3))
 def test_attn_pooler(
     num_pooled_tokens,
     attn_pool_depth

@@ -1288,7 +1289,7 @@ def test_accept_layer_intermediates():
 
     assert embeds.shape == (3, 32, 512)
 
-@pytest.mark.parametrize('use_loss_weight', (False, True))
+@param('use_loss_weight', (False, True))
 def test_simple_mdlm(
     use_loss_weight
 ):

@@ -1386,7 +1387,10 @@ def test_stochastic_attn():
     log_probs = log_prob_from_hard_attend(intermediate)
     assert log_probs.shape == (1, 8, 1024)
 
-def test_attn_negative_weights():
+@param('head_learned_sink', (True, False))
+def test_attn_negative_weights(
+    head_learned_sink
+):
     from x_transformers import TransformerWrapper, Decoder
 
     model = TransformerWrapper(

@@ -1396,7 +1400,8 @@ def test_attn_negative_weights():
             dim = 512,
             depth = 12,
             heads = 8,
-            attn_cog_signed = True
+            attn_cog_signed = True,
+            attn_head_learned_sink = True
         ),
     )
 
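The updated test above parametrizes test_attn_negative_weights over the new head_learned_sink flag, combining signed ("cog") attention with a per-head learned attention sink. Below is a minimal sketch of exercising that combination outside the test suite, assuming x-transformers 2.10.2 is installed; the model size and input shape are illustrative, not taken from the package:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# decoder with signed attention weights and a learned per-head sink,
# the same two flags toggled in the test diff above
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,                      # the test uses depth 12; smaller here for speed
        heads = 8,
        attn_cog_signed = True,         # attention weights may be negative
        attn_head_learned_sink = True   # learned sink logit per head
    )
)

x = torch.randint(0, 20000, (1, 256))   # (batch, seq) token ids
logits = model(x)                       # (1, 256, 20000)
```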
{x_transformers-2.10.0 → x_transformers-2.10.2}/train_copy.py

@@ -17,27 +17,29 @@ DEC_SEQ_LEN = 64 + 1
 
 def cycle():
     while True:
-        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
-        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
+        prefix = torch.ones((BATCH_SIZE, 1)).long()
+        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long()
         tgt = torch.cat((prefix, src, src), 1)
-        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().cuda()
+        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool()
         yield (src, tgt, src_mask)
 
 # instantiate model
 
 model = XTransformer(
-    dim = 512,
+    dim = 128,
     tie_token_emb = True,
     return_tgt_loss = True,
     enc_num_tokens=NUM_TOKENS,
     enc_depth = 3,
     enc_heads = 8,
     enc_max_seq_len = ENC_SEQ_LEN,
+    enc_attn_cog_signed = True,
     dec_num_tokens = NUM_TOKENS,
     dec_depth = 3,
     dec_heads = 8,
-    dec_max_seq_len = DEC_SEQ_LEN
-).cuda()
+    dec_max_seq_len = DEC_SEQ_LEN,
+    dec_attn_cog_signed = True
+)
 
 # optimizer
 

@@ -61,10 +63,10 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         model.eval()
         src, _, src_mask = next(cycle())
         src, src_mask = src[:1], src_mask[:1]
-        start_tokens = (torch.ones((1, 1)) * 1).long().cuda()
+        start_tokens = (torch.ones((1, 1)) * 1).long()
 
         sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
-        incorrects = (src != sample).abs().sum()
+        incorrects = (src != sample).long().abs().sum()
 
         print(f"input: ", src)
         print(f"predicted output: ", sample)
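Besides enabling cog-signed attention on both encoder and decoder and shrinking dim, train_copy.py now stays on CPU (the .cuda() calls are gone) and casts the token comparison to long before reducing, presumably because .abs() is not implemented for bool tensors in current PyTorch. A tiny self-contained illustration of that mismatch count (not part of the package):

```python
import torch

src    = torch.tensor([[1, 2, 3, 4]])
sample = torch.tensor([[1, 9, 3, 7]])

# (src != sample) is a bool tensor; cast to long before .abs()/.sum()
incorrects = (src != sample).long().abs().sum()   # tensor(2)
print(incorrects)
```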
{x_transformers-2.10.0 → x_transformers-2.10.2}/x_transformers/attend.py

@@ -520,6 +520,7 @@ class Attend(Module):
 
         if self.cog_signed:
             sim_sign = sim.sign()
+            sim = sim.abs()
 
         # masking
 

@@ -548,13 +549,15 @@ class Attend(Module):
         if self.head_learned_sink:
             # add learned attention sink
             attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
+
+            if self.cog_signed:
+                attn_sink, attn_sink_sign = attn_sink.abs(), attn_sink.sign()
+                sim_sign = cat((attn_sink_sign, sim_sign), dim = -1)
+
             sim = cat((attn_sink, sim), dim = -1)
 
         pre_softmax_attn = sim
 
-        if self.cog_signed:
-            sim = sim.abs()
-
         attn = self.attn_fn(sim)
 
         attn = attn.type(dtype)
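For orientation, the cog_signed path records the sign of each pre-softmax similarity, runs the softmax over the magnitudes, and (outside the hunks shown) reapplies the signs to the resulting weights so that attention weights can be negative. Taking the absolute value right after capturing the signs, rather than just before the softmax, keeps the additive mask fill applied later negative, which appears to be the point of moving that line; the new block also keeps the learned sink column's sign aligned with the rest of the row. A standalone sketch of the idea, with names of my own choosing rather than the library's:

```python
import torch
import torch.nn.functional as F

def signed_softmax_attention(sim, sink_logit = None):
    # sim: (batch, heads, query_len, key_len) pre-softmax similarities
    if sink_logit is not None:
        # prepend a learned per-head sink as an extra "key" column
        sink = sink_logit.view(1, -1, 1, 1).expand(sim.shape[0], -1, sim.shape[2], 1)
        sim = torch.cat((sink, sim), dim = -1)

    sign = sim.sign()
    attn = F.softmax(sim.abs(), dim = -1)   # softmax over magnitudes
    attn = attn * sign                      # re-attach signs -> weights may be negative

    if sink_logit is not None:
        attn = attn[..., 1:]                # drop the sink column before aggregating values
    return attn

sim  = torch.randn(1, 8, 4, 4)
sink = torch.randn(8)
weights = signed_softmax_attention(sim, sink)   # entries can be < 0; rows need not sum to 1
```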