x-transformers 2.9.2__py3-none-any.whl → 2.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


x_transformers/attend.py CHANGED

@@ -183,6 +183,7 @@ class Attend(Module):
         gumbel_softmax = False,
         gumbel_softmax_temp = 1.,
         gumbel_softmax_hard = True,
+        cog_signed = False,
         custom_attn_fn: Callable | None = None,
         flash = False,
         softclamp_logits = False,
@@ -260,6 +261,12 @@ class Attend(Module):
         assert not (selective and not causal), 'selective attention is designed for autoregressive'
         self.selective = selective
 
+        # cog attention - negative weights for expressiveness
+        # https://openreview.net/forum?id=ezRrwwbxd0
+
+        assert not (flash and cog_signed), 'cog attention not available for flash'
+        self.cog_signed = cog_signed
+
         # l2 distance attention
 
         self.l2_distance = l2_distance
@@ -509,6 +516,14 @@ class Attend(Module):
         if self.softclamp_logits:
             sim = softclamp(sim, self.logit_softclamp_value)
 
+        # pre-masking - handle cog by storing sign
+
+        if self.cog_signed:
+            sim_sign = sim.sign()
+            sim = sim.abs()
+
+        # masking
+
         i, j, dtype = *sim.shape[-2:], sim.dtype
 
         mask_value = -torch.finfo(sim.dtype).max
@@ -542,6 +557,11 @@
         attn = attn.type(dtype)
 
+        # add back the sign
+
+        if self.cog_signed:
+            attn = attn * sim_sign
+
         post_softmax_attn = attn
 
         if self.head_learned_sink:
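For readers skimming the hunks above: the new `cog_signed` path stores the sign of the attention logits before masking, runs softmax over their absolute values, then multiplies the signs back in, so post-softmax weights can be negative. A minimal standalone sketch of that flow, in plain PyTorch with simplified masking; `cog_signed_attend` is a hypothetical helper for illustration, not the library's `Attend.forward`:

```python
import torch

def cog_signed_attend(q, k, v, causal = True):
    # minimal sketch of the signed softmax above: attend over logit
    # magnitudes, then restore the signs so weights may be negative

    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    # pre-masking - store the sign, keep the magnitude (as in the diff)

    sim_sign = sim.sign()
    sim = sim.abs()

    # masking operates on the magnitudes; softmax then drives masked slots to ~0,
    # so the sign multiplied back in afterwards leaves them at ~0

    if causal:
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), device = sim.device, dtype = torch.bool).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)

    # add back the sign

    attn = attn * sim_sign

    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 16, 64)
out = cog_signed_attend(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```

Note the rows of the attention matrix no longer sum to 1 once signs are restored, which is the added expressiveness the cited paper argues for; it is also why the diff asserts the path is unavailable under flash attention, whose fused kernel assumes a plain softmax.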
x_transformers/x_transformers.py CHANGED

@@ -1340,6 +1340,7 @@ class Attention(Module):
         gumbel_softmax_temp = 1.,
         gumbel_softmax_hard = True,
         selective = False,
+        cog_signed = False,
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
@@ -1548,6 +1549,7 @@ class Attention(Module):
             gumbel_softmax_temp = gumbel_softmax_temp,
             gumbel_softmax_hard = gumbel_softmax_hard,
             selective = selective,
+            cog_signed = cog_signed,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
             head_learned_sink = head_learned_sink,
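Assuming the wiring above is the whole change, the flag should be switchable on directly when constructing an `Attention` block. A hedged usage sketch: the constructor arguments follow the diff, the tuple return mirrors the library's existing `(out, intermediates)` convention, and per the new assertion `flash` must stay off:

```python
import torch
from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,
    heads = 8,
    causal = True,
    cog_signed = True  # new in 2.10.x, forwarded to Attend per the diff above
)

x = torch.randn(1, 1024, 512)
out, *_ = attn(x)  # Attention returns (out, intermediates)
print(out.shape)   # torch.Size([1, 1024, 512])
```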
x_transformers-2.9.2.dist-info/METADATA → x_transformers-2.10.1.dist-info/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.9.2
+Version: 2.10.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2586,4 +2586,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{lv2025expressiveattentionnegativeweights,
+    title = {More Expressive Attention with Negative Weights},
+    author = {Ang Lv and Ruobing Xie and Shuaipeng Li and Jiayi Liao and Xingwu Sun and Zhanhui Kang and Di Wang and Rui Yan},
+    year = {2025},
+    eprint = {2411.07176},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL},
+    url = {https://arxiv.org/abs/2411.07176},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
x_transformers-2.9.2.dist-info/RECORD → x_transformers-2.10.1.dist-info/RECORD CHANGED

@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
-x_transformers/attend.py,sha256=RZJT9pPlpqSG3nOUqQHNRR6jOeJ2r-Fvvar2wdu9HLw,18687
+x_transformers/attend.py,sha256=uu4lIEfiwzZLeuBY2dJLG9709DZbWK8-on4ds8SCCJ0,19207
 x_transformers/autoregressive_wrapper.py,sha256=BsGO9xfVYkvynqbU1__tu_S_cxl7gss0YwnkhIa2baY,18401
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
@@ -10,10 +10,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=o6B10urcC7MRUrmoHOgYJgkrVDzHhX-jt6zZY3pZEgA,125700
+x_transformers/x_transformers.py,sha256=ADr83Fz2cehj_F7N1bMwxhAg-r48fGhlaZqw3hxoxMQ,125765
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.9.2.dist-info/METADATA,sha256=3JsbSIp9fsGpuXopeIaIq4ffjYTJIHyqdRLxM21cfUM,95381
-x_transformers-2.9.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.9.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.9.2.dist-info/RECORD,,
+x_transformers-2.10.1.dist-info/METADATA,sha256=sEfcxJr3l0W4Yga0NLHq1sMk90Zr5-Lpr-9fIlmG9H4,95799
+x_transformers-2.10.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.10.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.10.1.dist-info/RECORD,,