x-transformers 1.37.9.tar.gz → 1.38.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {x_transformers-1.37.9/x_transformers.egg-info → x_transformers-1.38.0}/PKG-INFO +1 -1
  2. {x_transformers-1.37.9 → x_transformers-1.38.0}/README.md +9 -0
  3. {x_transformers-1.37.9 → x_transformers-1.38.0}/setup.py +1 -1
  4. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/attend.py +48 -7
  5. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/x_transformers.py +4 -0
  6. {x_transformers-1.37.9 → x_transformers-1.38.0/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.37.9 → x_transformers-1.38.0}/LICENSE +0 -0
  8. {x_transformers-1.37.9 → x_transformers-1.38.0}/setup.cfg +0 -0
  9. {x_transformers-1.37.9 → x_transformers-1.38.0}/tests/test_x_transformers.py +0 -0
  10. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/__init__.py +0 -0
  11. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  16. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.37.9
+ Version: 1.38.0
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
README.md
@@ -2261,4 +2261,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```
 
+ ```bibtex
+ @inproceedings{Leviathan2024SelectiveAI,
+     title = {Selective Attention Improves Transformer},
+     author = {Yaniv Leviathan and Matan Kalman and Yossi Matias},
+     year = {2024},
+     url = {https://api.semanticscholar.org/CorpusID:273098114}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
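The citation above corresponds to the new `selective` option added in this release. A minimal sketch of turning it on from the top-level API, assuming the library's usual `attn_`-prefixed kwarg routing from `Decoder` into each `Attention` layer (this usage snippet is not part of the diff itself):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_selective = True  # assumed routing: forwarded as selective = True to Attention / Attend
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```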
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.37.9',
+   version = '1.38.0',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
x_transformers/attend.py
@@ -1,7 +1,7 @@
  from __future__ import annotations
 
  from functools import partial
- from typing import Tuple
+ from typing import Tuple, Callable
 
  import torch
  from torch.nn import Module
@@ -65,6 +65,35 @@ def once(fn):
 
  print_once = once(print)
 
+ # selective attention
+ # https://arxiv.org/abs/2410.02703 - section 3.3
+ # it is a parameter-less technique to allow each token to prevent itself from being attended to by future tokens
+ # if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
+
+ def selective_attn(
+     sim,
+     sim_head_gate = None,
+     no_mask_sos = True
+ ):
+     i, j, device = *sim.shape[-2:], sim.device
+     sim_head_gate = default(sim_head_gate, sim[:, 0])
+
+     gate = F.relu(sim_head_gate) # only positive
+
+     if no_mask_sos:
+         gate[..., -i] = 0.
+
+     eye = torch.eye(j, device = device)
+
+     if j > i:
+         eye = F.pad(eye, (j - i, 0), value = 1.)
+
+     gate = (1. - eye) * gate
+     gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # only allow for masking the future
+     gate = gate.cumsum(dim = -2)
+
+     return sim - rearrange(gate, 'b i j -> b 1 i j')
+
  # alternative distance functions
 
  def qk_l2_dist_squared(q, k):
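A stand-alone sketch of what the new `selective_attn` helper does to the attention logits, assuming x-transformers 1.38.0 is installed (in normal use it is called internally when `selective = True` is set on `Attend`):

```python
import torch
from x_transformers.attend import selective_attn  # module-level helper added in this hunk

# pre-softmax attention logits: (batch, heads, query length, key length)
sim = torch.randn(2, 8, 16, 16)

gated = selective_attn(sim)

# shape is preserved, and logits are only ever pushed down: each token's ReLU'd gate
# (taken from the first head by default) accumulates over later query positions and
# is subtracted from sim, letting a token suppress future attention to itself
print(gated.shape)           # torch.Size([2, 8, 16, 16])
print((gated <= sim).all())  # tensor(True)
```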
@@ -104,10 +133,12 @@ class Attend(Module):
          qk_norm = False,
          l2_distance = False,
          sigmoid = False,
+         custom_attn_fn: Callable | None = None,
          flash = False,
          softclamp_logits = False,
          logit_softclamp_value = 50.,
          add_zero_kv = False,
+         selective = False,
          sigsoftmax = False,
          cope = None,
          onnxable = False,
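`custom_attn_fn` lets callers replace the softmax in `Attend` entirely, as the next hunk shows (`self.attn_fn = custom_attn_fn`). A hedged sketch of supplying one, under the assumption that the function receives the masked pre-softmax logits and must return attention weights of the same shape; `sharpened_softmax` is a made-up example, not part of the library:

```python
import torch
import torch.nn.functional as F
from x_transformers.attend import Attend

# hypothetical custom attention function: a temperature-sharpened softmax.
# assumption: Attend calls it on the masked logits and expects normalized
# weights of the same shape in return.
def sharpened_softmax(sim, **kwargs):
    return F.softmax(sim / 0.5, dim = -1)

attend = Attend(causal = True, custom_attn_fn = sharpened_softmax)

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

out, intermediates = attend(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```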
@@ -132,7 +163,9 @@
 
          self.sigmoid = sigmoid
 
-         if not sigmoid:
+         if exists(custom_attn_fn):
+             self.attn_fn = custom_attn_fn
+         elif not sigmoid:
              softmax_fn = partial(F.softmax, dim = -1)
              self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
          else:
@@ -152,6 +185,11 @@
              self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
              self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
 
+         # selective attention
+
+         assert not (selective and not causal), 'selective attention is designed for autoregressive'
+         self.selective = selective
+
          # sparse topk
 
          assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
@@ -404,11 +442,6 @@
 
          mask_value = -torch.finfo(sim.dtype).max
 
-         if exists(self.sparse_topk) and self.sparse_topk < j:
-             top_values, _ = sim.topk(self.sparse_topk, dim = -1)
-             sparse_topk_mask = sim < top_values[..., -1:]
-             mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
-
          if exists(mask):
              sim = sim.masked_fill(~mask, mask_value)
 
@@ -416,6 +449,11 @@
              causal_mask = self.create_causal_mask(i, j, device = device)
              sim = sim.masked_fill(causal_mask, mask_value)
 
+         if exists(self.sparse_topk):
+             top_values, _ = sim.topk(self.sparse_topk, dim = -1)
+             sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
+             sim = sim.masked_fill(~sparse_topk_mask, mask_value)
+
          row_is_entirely_masked = None
 
          if exists(mask):
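Note that the sparse-topk filter has moved from before the key padding mask (removed in the previous hunk) to after both the padding and causal masks, so positions that are already masked can no longer occupy top-k slots. A stand-alone toy illustration of the relocated logic (the values and single-row shape are made up for the example):

```python
import torch

# one query row of attention logits; the last key is already masked out
mask_value = -torch.finfo(torch.float32).max
sim = torch.tensor([[4., 1., 3., mask_value]])

sparse_topk = 2
top_values, _ = sim.topk(sparse_topk, dim = -1)

# keep logits that are both in the top-k and not already at the mask value
sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
sim = sim.masked_fill(~sparse_topk_mask, mask_value)

print(sim)  # only the two largest unmasked logits (4. and 3.) survive
```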
@@ -424,6 +462,9 @@
          if exists(self.cope):
              sim = sim + self.cope(q, sim)
 
+         if self.selective:
+             sim = selective_attn(sim)
+
          pre_softmax_attn = sim
 
          if self.sigsoftmax:
x_transformers/x_transformers.py
@@ -925,6 +925,8 @@ class Attention(Module):
          qk_norm_dim_scale = False,
          l2_distance = False,
          sigmoid = False,
+         selective = False,
+         custom_attn_fn: Callable | None = None,
          one_kv_head = False,
          kv_heads = None,
          shared_kv = False,
@@ -1041,6 +1043,8 @@
              scale = qk_norm_scale if qk_norm else self.scale,
              l2_distance = l2_distance,
              sigmoid = sigmoid,
+             selective = selective,
+             custom_attn_fn = custom_attn_fn,
              add_zero_kv = add_zero_kv,
              flash = flash,
              softclamp_logits = softclamp_logits,
x_transformers.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.37.9
+ Version: 1.38.0
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
The remaining files listed above are unchanged.