x-transformers 2.9.1__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


x_transformers/attend.py CHANGED
@@ -67,6 +67,15 @@ def once(fn):
 
 print_once = once(print)
 
+# gumbel softmax attention related
+
+def log_prob_from_hard_attend(intermeds: Intermediates):
+    log_probs = intermeds.pre_softmax_attn.log_softmax(dim = -1)
+
+    one_hot = intermeds.post_softmax_attn.argmax(dim = -1, keepdim = True)
+    log_prob = log_probs.gather(-1, one_hot)
+    return rearrange(log_prob, 'b h i 1 -> b h i')
+
 # selective attention
 # https://arxiv.org/abs/2410.02703 - section 3.3
 # it is a technique to allow each token to prevent itself from being attended to by future tokens
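
For reference, a minimal standalone sketch of what the new log_prob_from_hard_attend helper computes, with dummy tensors standing in for the Intermediates object; the (b, h, i, j) logit shape is inferred from the rearrange pattern above:

```python
# minimal sketch, not the library's code: recover the log-probability of the
# hard (argmax) attention choice from the saved softmax intermediates
import torch
from einops import rearrange

b, h, i, j = 2, 4, 16, 16                       # (batch, heads, queries, keys)
pre_softmax_attn = torch.randn(b, h, i, j)      # raw attention logits
post_softmax_attn = pre_softmax_attn.softmax(dim = -1)

log_probs = pre_softmax_attn.log_softmax(dim = -1)

# index of the key each query attends to hardest
hard_index = post_softmax_attn.argmax(dim = -1, keepdim = True)

# log-probability the softmax assigned to that hard choice, one scalar per query
log_prob = log_probs.gather(-1, hard_index)
log_prob = rearrange(log_prob, 'b h i 1 -> b h i')
```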
@@ -174,6 +183,7 @@ class Attend(Module):
         gumbel_softmax = False,
         gumbel_softmax_temp = 1.,
         gumbel_softmax_hard = True,
+        cog_signed = False,
         custom_attn_fn: Callable | None = None,
         flash = False,
         softclamp_logits = False,
@@ -251,6 +261,12 @@ class Attend(Module):
         assert not (selective and not causal), 'selective attention is designed for autoregressive'
         self.selective = selective
 
+        # cog attention - negative weights for expressiveness
+        # https://openreview.net/forum?id=ezRrwwbxd0
+
+        assert not (flash and cog_signed), 'cog attention not available for flash'
+        self.cog_signed = cog_signed
+
         # l2 distance attention
 
         self.l2_distance = l2_distance
@@ -500,6 +516,13 @@ class Attend(Module):
         if self.softclamp_logits:
             sim = softclamp(sim, self.logit_softclamp_value)
 
+        # pre-masking - handle cog by storing sign
+
+        if self.cog_signed:
+            sim_sign = sim.sign()
+
+        # masking
+
         i, j, dtype = *sim.shape[-2:], sim.dtype
 
         mask_value = -torch.finfo(sim.dtype).max
@@ -529,10 +552,18 @@ class Attend(Module):
 
         pre_softmax_attn = sim
 
+        if self.cog_signed:
+            sim = sim.abs()
+
         attn = self.attn_fn(sim)
 
         attn = attn.type(dtype)
 
+        # add back the sign
+
+        if self.cog_signed:
+            attn = attn * sim_sign
+
         post_softmax_attn = attn
 
         if self.head_learned_sink:
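
Taken together, the two hunks above implement a signed softmax: the sign of each logit is stored before masking, the softmax runs on the magnitudes, and the signs are multiplied back in afterwards. A minimal sketch of that computation, with masking omitted and the (b, h, i, j) shape assumed:

```python
# minimal sketch of the cog_signed path, masking omitted
import torch

sim = torch.randn(2, 4, 16, 16)       # pre-softmax attention logits (b, h, i, j)

sim_sign = sim.sign()                 # store each logit's sign

attn = sim.abs().softmax(dim = -1)    # softmax over magnitudes only

attn = attn * sim_sign                # restore signs: weights may now be negative
```

The absolute weights in each row still sum to 1, but individual weights can be negative, which is the added expressiveness the linked OpenReview paper argues for.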
x_transformers/x_transformers.py CHANGED

@@ -1340,6 +1340,7 @@ class Attention(Module):
         gumbel_softmax_temp = 1.,
         gumbel_softmax_hard = True,
         selective = False,
+        cog_signed = False,
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
@@ -1548,6 +1549,7 @@ class Attention(Module):
             gumbel_softmax_temp = gumbel_softmax_temp,
             gumbel_softmax_hard = gumbel_softmax_hard,
             selective = selective,
+            cog_signed = cog_signed,
             custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
             head_learned_sink = head_learned_sink,
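
A hedged usage sketch: the new flag is surfaced on Attention and forwarded to the inner Attend, as the hunk above shows. Only the cog_signed kwarg is confirmed by this diff; the (q, k, v) call signature and the (out, intermediates) return are assumptions from the broader x-transformers codebase:

```python
# hypothetical usage sketch; only the cog_signed kwarg is confirmed by this diff
import torch
from x_transformers.attend import Attend

attend = Attend(causal = True, cog_signed = True)

q = k = v = torch.randn(1, 8, 128, 64)   # assumed (batch, heads, seq, dim_head) layout
out, intermediates = attend(q, k, v)

# per the assert added above, the flag cannot be combined with flash attention:
# Attend(flash = True, cog_signed = True)  # -> AssertionError
```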
x_transformers-2.9.1.dist-info/METADATA → x_transformers-2.10.0.dist-info/METADATA RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.9.1
+Version: 2.10.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2586,4 +2586,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{anonymous2025more,
+    title = {More Expressive Attention with Negative Weights},
+    author = {Anonymous},
+    booktitle = {Submitted to The Fourteenth International Conference on Learning Representations},
+    year = {2025},
+    url = {https://openreview.net/forum?id=ezRrwwbxd0},
+    note = {under review}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
x_transformers-2.9.1.dist-info/RECORD → x_transformers-2.10.0.dist-info/RECORD RENAMED

@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
-x_transformers/attend.py,sha256=mQj3p4RMfifL_P-S-hCtEolIU_GDDLwHZJ2cT8wBf7Q,18356
+x_transformers/attend.py,sha256=l968RkOaypWMb_Ba-n82zKms4b62Ng337wtigvPAums,19236
 x_transformers/autoregressive_wrapper.py,sha256=BsGO9xfVYkvynqbU1__tu_S_cxl7gss0YwnkhIa2baY,18401
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249

@@ -10,10 +10,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=o6B10urcC7MRUrmoHOgYJgkrVDzHhX-jt6zZY3pZEgA,125700
+x_transformers/x_transformers.py,sha256=ADr83Fz2cehj_F7N1bMwxhAg-r48fGhlaZqw3hxoxMQ,125765
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.9.1.dist-info/METADATA,sha256=XMP7KEX4fg8VvvcpULmJT_1KBKfwIG8yyiHvCnscOyg,95381
-x_transformers-2.9.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.9.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.9.1.dist-info/RECORD,,
+x_transformers-2.10.0.dist-info/METADATA,sha256=1tiahG4NWO99cWEZ_qRgdgKHSWRIUKdf0xl2j0BfIXQ,95736
+x_transformers-2.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.10.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.10.0.dist-info/RECORD,,