x-transformers 2.6.4.tar.gz → 2.6.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {x_transformers-2.6.4 → x_transformers-2.6.5}/PKG-INFO +1 -1
  2. {x_transformers-2.6.4 → x_transformers-2.6.5}/pyproject.toml +1 -1
  3. {x_transformers-2.6.4 → x_transformers-2.6.5}/tests/test_x_transformers.py +1 -1
  4. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/attend.py +9 -8
  5. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/x_transformers.py +2 -2
  6. {x_transformers-2.6.4 → x_transformers-2.6.5}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.6.4 → x_transformers-2.6.5}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.6.4 → x_transformers-2.6.5}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.6.4 → x_transformers-2.6.5}/.gitignore +0 -0
  10. {x_transformers-2.6.4 → x_transformers-2.6.5}/LICENSE +0 -0
  11. {x_transformers-2.6.4 → x_transformers-2.6.5}/README.md +0 -0
  12. {x_transformers-2.6.4 → x_transformers-2.6.5}/data/README.md +0 -0
  13. {x_transformers-2.6.4 → x_transformers-2.6.5}/data/enwik8.gz +0 -0
  14. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/all-attention.png +0 -0
  15. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/deepnorm.png +0 -0
  18. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/fcm.png +0 -0
  24. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/ffglu.png +0 -0
  25. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/flash-attention.png +0 -0
  26. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/gate_values.png +0 -0
  27. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/gating.png +0 -0
  28. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/macaron-1.png +0 -0
  30. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/macaron-2.png +0 -0
  31. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/normformer.png +0 -0
  33. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/pia.png +0 -0
  34. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/resi_dual.png +0 -0
  36. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/residual_attn.png +0 -0
  37. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/rezero.png +0 -0
  38. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/rotary.png +0 -0
  39. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/sandwich.png +0 -0
  41. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/scalenorm.png +0 -0
  43. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/talking-heads.png +0 -0
  44. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/topk-attention.png +0 -0
  45. {x_transformers-2.6.4 → x_transformers-2.6.5}/images/xval.png +0 -0
  46. {x_transformers-2.6.4 → x_transformers-2.6.5}/train_belief_state.py +0 -0
  47. {x_transformers-2.6.4 → x_transformers-2.6.5}/train_copy.py +0 -0
  48. {x_transformers-2.6.4 → x_transformers-2.6.5}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.6.4 → x_transformers-2.6.5}/train_enwik8.py +0 -0
  50. {x_transformers-2.6.4 → x_transformers-2.6.5}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.6.4 → x_transformers-2.6.5}/train_parity.py +0 -0
  52. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/__init__.py +0 -0
  53. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/continuous.py +0 -0
  56. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/up_wrapper.py +0 -0
  62. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  63. {x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/xval.py +0 -0
{x_transformers-2.6.4 → x_transformers-2.6.5}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.4
+Version: 2.6.5
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.6.4 → x_transformers-2.6.5}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.6.4"
+version = "2.6.5"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.6.4 → x_transformers-2.6.5}/tests/test_x_transformers.py
@@ -1245,7 +1245,7 @@ def test_learned_head_attn_sink():
             dim = 512,
             depth = 12,
             heads = 8,
-            attn_head_learned_sink = True
+            attn_head_learned_sinks = 4
         )
     )

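The updated test reflects the API change in this release: the boolean attn_head_learned_sink flag is replaced by an integer attn_head_learned_sinks count, so each head can learn several attention sinks instead of one. A minimal usage sketch against the new keyword; the TransformerWrapper arguments other than the sink count are illustrative, not taken from this diff:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,     # illustrative vocab / sequence settings
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_head_learned_sinks = 4   # 2.6.4 accepted attn_head_learned_sink = True instead
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)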
{x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/attend.py
@@ -176,7 +176,7 @@ class Attend(Module):
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
-        head_learned_sink = False,
+        head_learned_sinks = 0,
         selective = False,
         hard = False,
         cope = None,
@@ -257,10 +257,10 @@ class Attend(Module):

         # learned sink concatted pre-softmax, working solution from gpt-oss

-        assert not (head_learned_sink and flash), f'not supported for flash attention yet'
+        self.has_head_learned_sinks = head_learned_sinks > 0
+        assert not (self.has_head_learned_sinks and flash), f'not supported for flash attention yet'

-        self.head_learned_sink = head_learned_sink
-        self.head_attn_sink = Parameter(torch.zeros(heads)) if head_learned_sink else None
+        self.head_attn_sinks = Parameter(torch.zeros(heads, head_learned_sinks)) if self.has_head_learned_sinks else None

         # soft clamp attention logit value

@@ -517,9 +517,10 @@ class Attend(Module):
         if self.selective:
             sim = selective_attn(sim)

-        if self.head_learned_sink:
+        if self.has_head_learned_sinks:
             # add learned attention sink
-            attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
+            num_sinks = self.head_attn_sinks.shape[-1]
+            attn_sink = repeat(self.head_attn_sinks, 'h sinks -> b h i sinks', b = sim.shape[0], i = sim.shape[2])
             sim = cat((attn_sink, sim), dim = -1)

         pre_softmax_attn = sim
@@ -530,9 +531,9 @@ class Attend(Module):

         post_softmax_attn = attn

-        if self.head_learned_sink:
+        if self.has_head_learned_sinks:
             # remove attention sink
-            attn = attn[..., 1:]
+            attn = attn[..., num_sinks:]

         attn = self.attn_dropout(attn)

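The attend.py hunks generalize the learned attention sink from one scalar logit per head (a Parameter of shape (heads,)) to head_learned_sinks logits per head (shape (heads, head_learned_sinks)). The sink logits are broadcast and concatenated in front of the attention logits before the softmax, then the first num_sinks columns are dropped afterwards, so the sinks absorb probability mass without contributing to the output. A standalone sketch of that mechanism, not the library code itself; attend_with_sinks is a hypothetical helper:

import torch
from einops import repeat

def attend_with_sinks(q, k, v, head_attn_sinks):
    # q, k, v: (batch, heads, seq_len, dim_head)
    # head_attn_sinks: (heads, num_sinks), analogous to Parameter(torch.zeros(heads, head_learned_sinks))
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    num_sinks = head_attn_sinks.shape[-1]

    # broadcast the sink logits and prepend them to every query row, pre-softmax
    attn_sink = repeat(head_attn_sinks, 'h s -> b h i s', b = sim.shape[0], i = sim.shape[2])
    sim = torch.cat((attn_sink, sim), dim = -1)

    attn = sim.softmax(dim = -1)

    # strip the sink columns post-softmax; their mass simply disappears from the output
    attn = attn[..., num_sinks:]

    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 16, 64)
sinks = torch.zeros(8, 4)
out = attend_with_sinks(q, k, v, sinks)   # (1, 8, 16, 64)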
{x_transformers-2.6.4 → x_transformers-2.6.5}/x_transformers/x_transformers.py
@@ -1319,7 +1319,7 @@ class Attention(Module):
         value_dim_head = None,
         dim_out = None,
         add_zero_kv = False, # same as add_zero_attn in pytorch
-        head_learned_sink = False,
+        head_learned_sinks = 0,
         rotate_num_heads = None,
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
@@ -1516,7 +1516,7 @@ class Attention(Module):
             selective = selective,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
-            head_learned_sink = head_learned_sink,
+            head_learned_sinks = head_learned_sinks,
             flash = flash,
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
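The x_transformers.py hunks only propagate the rename: Attention now takes head_learned_sinks as an integer count (0 disables the feature) and forwards it to Attend. For code that builds the attention module directly, the updated keyword would look roughly like this; dim and heads here are illustrative:

from x_transformers.x_transformers import Attention

# 2.6.5: pass how many learned sinks each head should get (0 = off)
attn = Attention(dim = 512, heads = 8, head_learned_sinks = 4)

# the 2.6.4 equivalent of a single sink per head was:
# attn = Attention(dim = 512, heads = 8, head_learned_sink = True)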