x-transformers 2.6.3__tar.gz → 2.6.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.6.3 → x_transformers-2.6.5}/PKG-INFO +11 -1
  2. {x_transformers-2.6.3 → x_transformers-2.6.5}/README.md +10 -0
  3. {x_transformers-2.6.3 → x_transformers-2.6.5}/pyproject.toml +1 -1
  4. {x_transformers-2.6.3 → x_transformers-2.6.5}/tests/test_x_transformers.py +17 -0
  5. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/attend.py +22 -4
  6. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/x_transformers.py +2 -0
  7. {x_transformers-2.6.3 → x_transformers-2.6.5}/.github/FUNDING.yml +0 -0
  8. {x_transformers-2.6.3 → x_transformers-2.6.5}/.github/workflows/python-publish.yml +0 -0
  9. {x_transformers-2.6.3 → x_transformers-2.6.5}/.github/workflows/python-test.yaml +0 -0
  10. {x_transformers-2.6.3 → x_transformers-2.6.5}/.gitignore +0 -0
  11. {x_transformers-2.6.3 → x_transformers-2.6.5}/LICENSE +0 -0
  12. {x_transformers-2.6.3 → x_transformers-2.6.5}/data/README.md +0 -0
  13. {x_transformers-2.6.3 → x_transformers-2.6.5}/data/enwik8.gz +0 -0
  14. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/all-attention.png +0 -0
  15. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/deepnorm.png +0 -0
  18. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/fcm.png +0 -0
  24. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/ffglu.png +0 -0
  25. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/flash-attention.png +0 -0
  26. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/gate_values.png +0 -0
  27. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/gating.png +0 -0
  28. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/macaron-1.png +0 -0
  30. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/macaron-2.png +0 -0
  31. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/normformer.png +0 -0
  33. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/pia.png +0 -0
  34. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/resi_dual.png +0 -0
  36. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/residual_attn.png +0 -0
  37. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/rezero.png +0 -0
  38. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/rotary.png +0 -0
  39. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/sandwich.png +0 -0
  41. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/scalenorm.png +0 -0
  43. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/talking-heads.png +0 -0
  44. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/topk-attention.png +0 -0
  45. {x_transformers-2.6.3 → x_transformers-2.6.5}/images/xval.png +0 -0
  46. {x_transformers-2.6.3 → x_transformers-2.6.5}/train_belief_state.py +0 -0
  47. {x_transformers-2.6.3 → x_transformers-2.6.5}/train_copy.py +0 -0
  48. {x_transformers-2.6.3 → x_transformers-2.6.5}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.6.3 → x_transformers-2.6.5}/train_enwik8.py +0 -0
  50. {x_transformers-2.6.3 → x_transformers-2.6.5}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.6.3 → x_transformers-2.6.5}/train_parity.py +0 -0
  52. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/__init__.py +0 -0
  53. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/xval.py +0 -0
{x_transformers-2.6.3 → x_transformers-2.6.5}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.6.3
+ Version: 2.6.5
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2507,4 +2507,14 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @misc{openai_gpt_oss,
+     author = {OpenAI},
+     title = {Introducing gpt-oss},
+     howpublished = {https://openai.com/index/introducing-gpt-oss},
+     month = {August},
+     year = {2025}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.6.3 → x_transformers-2.6.5}/README.md

@@ -2459,4 +2459,14 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @misc{openai_gpt_oss,
+     author = {OpenAI},
+     title = {Introducing gpt-oss},
+     howpublished = {https://openai.com/index/introducing-gpt-oss},
+     month = {August},
+     year = {2025}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.6.3 → x_transformers-2.6.5}/pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.6.3"
+ version = "2.6.5"
  description = "X-Transformers"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.6.3 → x_transformers-2.6.5}/tests/test_x_transformers.py

@@ -1235,3 +1235,20 @@ def test_external_key_values():
      additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()

      logits = model(seq, self_attn_additional_kv = key_values, additional_kv_mask = additional_kv_mask)
+
+ def test_learned_head_attn_sink():
+
+     model = TransformerWrapper(
+         num_tokens = 20000,
+         max_seq_len = 1024,
+         attn_layers = Decoder(
+             dim = 512,
+             depth = 12,
+             heads = 8,
+             attn_head_learned_sinks = 4
+         )
+     )
+
+     seq = torch.randint(0, 20000, (3, 1024))
+
+     logits = model(seq)
{x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/attend.py

@@ -4,8 +4,8 @@ from functools import partial
  from typing import Tuple, Callable

  import torch
- from torch.nn import Module
- from torch import nn, einsum, Tensor
+ from torch.nn import Module, Parameter
+ from torch import cat, nn, einsum, Tensor
  import torch.nn.functional as F

  from collections import namedtuple

@@ -176,6 +176,7 @@ class Attend(Module):
          softclamp_logits = False,
          logit_softclamp_value = 50.,
          add_zero_kv = False,
+         head_learned_sinks = 0,
          selective = False,
          hard = False,
          cope = None,

@@ -254,6 +255,13 @@ class Attend(Module):

          self.add_zero_kv = add_zero_kv

+         # learned sink concatted pre-softmax, working solution from gpt-oss
+
+         self.has_head_learned_sinks = head_learned_sinks > 0
+         assert not (self.has_head_learned_sinks and flash), f'not supported for flash attention yet'
+
+         self.head_attn_sinks = Parameter(torch.zeros(heads, head_learned_sinks)) if self.has_head_learned_sinks else None
+
          # soft clamp attention logit value

          if softclamp_logits:

@@ -315,10 +323,10 @@ class Attend(Module):
          if self.l2_distance:
              k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
              k = F.pad(k, (0, 1), value = -1.)
-             k = torch.cat((k, k_norm_sq), dim = -1)
+             k = cat((k, k_norm_sq), dim = -1)

              q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
-             q = torch.cat((2 * q, q_norm_sq), dim = -1)
+             q = cat((2 * q, q_norm_sq), dim = -1)
              q = F.pad(q, (0, 1), value = -1.)

          # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention

@@ -509,6 +517,12 @@ class Attend(Module):
          if self.selective:
              sim = selective_attn(sim)

+         if self.has_head_learned_sinks:
+             # add learned attention sink
+             num_sinks = self.head_attn_sinks.shape[-1]
+             attn_sink = repeat(self.head_attn_sinks, 'h sinks -> b h i sinks', b = sim.shape[0], i = sim.shape[2])
+             sim = cat((attn_sink, sim), dim = -1)
+
          pre_softmax_attn = sim

          attn = self.attn_fn(sim)

@@ -517,6 +531,10 @@ class Attend(Module):

          post_softmax_attn = attn

+         if self.has_head_learned_sinks:
+             # remove attention sink
+             attn = attn[..., num_sinks:]
+
          attn = self.attn_dropout(attn)

          if exists(self.post_softmax_talking_heads):
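
Taken together, the attend.py hunks add gpt-oss style learned attention sinks: each head owns a small set of learned sink logits that are concatenated to the attention scores before the softmax and sliced off afterwards, letting a head park probability mass on the sinks instead of the real keys. Below is a minimal standalone sketch of that mechanism (illustrative shapes and names, not the library's exact code):

```python
import torch
from torch import nn
from einops import repeat

batch, heads, seq_len, dim_head, num_sinks = 2, 8, 16, 64, 4

# learned sink logits, one row per head (initialized at zero, as in the diff)
head_attn_sinks = nn.Parameter(torch.zeros(heads, num_sinks))

sim = torch.randn(batch, heads, seq_len, seq_len)    # raw attention logits (q · k)

# prepend the sink logits along the key dimension, pre-softmax
attn_sink = repeat(head_attn_sinks, 'h s -> b h i s', b = batch, i = seq_len)
sim = torch.cat((attn_sink, sim), dim = -1)

attn = sim.softmax(dim = -1)        # softmax over sinks + real keys
attn = attn[..., num_sinks:]        # drop the sink columns; rows may now sum to < 1

v = torch.randn(batch, heads, seq_len, dim_head)
out = attn @ v                      # aggregate values as usual
```

Because the sink columns are discarded after normalization, a row of attention weights can sum to less than one, which is the intended effect. The diff also asserts the feature off under flash attention, presumably because the fused kernel never materializes `sim`.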
{x_transformers-2.6.3 → x_transformers-2.6.5}/x_transformers/x_transformers.py

@@ -1319,6 +1319,7 @@ class Attention(Module):
          value_dim_head = None,
          dim_out = None,
          add_zero_kv = False, # same as add_zero_attn in pytorch
+         head_learned_sinks = 0,
          rotate_num_heads = None,
          data_dependent_alibi = False,
          data_dependent_alibi_per_row = False,

@@ -1515,6 +1516,7 @@ class Attention(Module):
              selective = selective,
              custom_attn_fn = custom_attn_fn,
              add_zero_kv = add_zero_kv,
+             head_learned_sinks = head_learned_sinks,
              flash = flash,
              softclamp_logits = softclamp_logits,
              logit_softclamp_value = logit_softclamp_value,
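
With this wiring, the feature can be enabled from the public API exactly as the new test does: the `attn_` prefix on `Decoder` keyword arguments forwards `head_learned_sinks` through `Attention` into `Attend`. A usage sketch mirroring the test:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_head_learned_sinks = 4   # number of learned sink logits per head
    )
)

seq = torch.randint(0, 20000, (3, 1024))
logits = model(seq)   # (3, 1024, 20000)
```

Given the assert added in attend.py, leave flash attention (`attn_flash = True`) disabled on layers that use learned sinks for now.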