x-transformers 1.37.9__tar.gz → 1.38.0__tar.gz
- {x_transformers-1.37.9/x_transformers.egg-info → x_transformers-1.38.0}/PKG-INFO +1 -1
- {x_transformers-1.37.9 → x_transformers-1.38.0}/README.md +9 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/setup.py +1 -1
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/attend.py +48 -7
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/x_transformers.py +4 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.37.9 → x_transformers-1.38.0}/LICENSE +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/setup.cfg +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/tests/test_x_transformers.py +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/__init__.py +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/continuous.py +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/dpo.py +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/xval.py +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.37.9 → x_transformers-1.38.0}/README.md

@@ -2261,4 +2261,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Leviathan2024SelectiveAI,
+    title = {Selective Attention Improves Transformer},
+    author = {Yaniv Leviathan and Matan Kalman and Yossi Matias},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273098114}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/attend.py

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Tuple
+from typing import Tuple, Callable
 
 import torch
 from torch.nn import Module
@@ -65,6 +65,35 @@ def once(fn):
 
 print_once = once(print)
 
+# selective attention
+# https://arxiv.org/abs/2410.02703 - section 3.3
+# it is a parameter-less technique to allow each token to prevent itself from being attended to by future tokens
+# if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
+
+def selective_attn(
+    sim,
+    sim_head_gate = None,
+    no_mask_sos = True
+):
+    i, j, device = *sim.shape[-2:], sim.device
+    sim_head_gate = default(sim_head_gate, sim[:, 0])
+
+    gate = F.relu(sim_head_gate) # only positive
+
+    if no_mask_sos:
+        gate[..., -i] = 0.
+
+    eye = torch.eye(j, device = device)
+
+    if j > i:
+        eye = F.pad(eye, (j - i, 0), value = 1.)
+
+    gate = (1. - eye) * gate
+    gate = F.pad(gate, (0, 0, 1, -1), value = 0.) # only allow for masking the future
+    gate = gate.cumsum(dim = -2)
+
+    return sim - rearrange(gate, 'b i j -> b 1 i j')
+
 # alternative distance functions
 
 def qk_l2_dist_squared(q, k):
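As a quick illustration of the mechanism described in the comments above, here is a minimal sketch (not part of the package) that calls the new `selective_attn` directly on dummy pre-softmax logits, assuming the `(batch, heads, query_len, key_len)` shape convention implied by the `sim[:, 0]` gate head and the `'b i j -> b 1 i j'` rearrange:

```python
import torch
from x_transformers.attend import selective_attn

# dummy pre-softmax attention logits: batch 2, 8 heads, 16 queries over 16 keys
sim = torch.randn(2, 8, 16, 16)

# head 0 supplies the gate (sim_head_gate defaults to sim[:, 0]); positive gate values
# are accumulated down the query axis and subtracted from the logits, so a token can
# reduce how strongly *later* queries attend to it
gated = selective_attn(sim)

assert gated.shape == sim.shape  # (2, 8, 16, 16)
```

With the default `no_mask_sos = True`, the gate for the first position of the window is zeroed out, so that token always remains attendable.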
@@ -104,10 +133,12 @@ class Attend(Module):
         qk_norm = False,
         l2_distance = False,
         sigmoid = False,
+        custom_attn_fn: Callable | None = None,
         flash = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
         add_zero_kv = False,
+        selective = False,
         sigsoftmax = False,
         cope = None,
         onnxable = False,
@@ -132,7 +163,9 @@ class Attend(Module):
 
         self.sigmoid = sigmoid
 
-        if not sigmoid:
+        if exists(custom_attn_fn):
+            self.attn_fn = custom_attn_fn
+        elif not sigmoid:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
         else:
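Because `custom_attn_fn` replaces `self.attn_fn` wholesale, it receives the (already masked) attention logits and is expected to return weights normalized over the last dimension. A minimal sketch, assuming `Attend.forward` takes `(q, k, v)` shaped `(batch, heads, seq, dim_head)` and returns the output along with intermediates; the temperature-sharpened softmax is purely illustrative:

```python
import torch
import torch.nn.functional as F
from x_transformers.attend import Attend

def sharp_softmax(sim):
    # hypothetical drop-in: a temperature-sharpened softmax over the key dimension
    return F.softmax(sim * 2.0, dim = -1)

attend = Attend(causal = True, custom_attn_fn = sharp_softmax)

q = torch.randn(1, 8, 32, 64)  # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 32, 64)
v = torch.randn(1, 8, 32, 64)

out, _ = attend(q, k, v)       # out: (1, 8, 32, 64)
```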
@@ -152,6 +185,11 @@ class Attend(Module):
         self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
         self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
 
+        # selective attention
+
+        assert not (selective and not causal), 'selective attention is designed for autoregressive'
+        self.selective = selective
+
         # sparse topk
 
         assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
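As the new assertion states, `selective = True` is only meaningful for causal (autoregressive) attention. A minimal sketch of enabling it at the `Attend` level, under the same shape assumptions as above; the default non-flash path is kept, since that is where `selective_attn` is applied further below:

```python
import torch
from x_transformers.attend import Attend

# selective attention requires causality; flash stays False so the einsum path
# (where the selective gating of the logits happens) is used
attend = Attend(causal = True, selective = True)

q = torch.randn(1, 8, 32, 64)
k = torch.randn(1, 8, 32, 64)
v = torch.randn(1, 8, 32, 64)

out, _ = attend(q, k, v)
```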
@@ -404,11 +442,6 @@ class Attend(Module):
 
         mask_value = -torch.finfo(sim.dtype).max
 
-        if exists(self.sparse_topk) and self.sparse_topk < j:
-            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
-            sparse_topk_mask = sim < top_values[..., -1:]
-            mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
-
         if exists(mask):
             sim = sim.masked_fill(~mask, mask_value)
 
@@ -416,6 +449,11 @@ class Attend(Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)
 
+        if exists(self.sparse_topk):
+            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
+            sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
+            sim = sim.masked_fill(~sparse_topk_mask, mask_value)
+
         row_is_entirely_masked = None
 
         if exists(mask):
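Moving the top-k filter after the padding and causal masks, together with the added `sim > mask_value` term, keeps positions those masks removed from being re-admitted when a row has fewer than `sparse_topk` live logits. A standalone sketch of just this step, with hypothetical numbers:

```python
import torch

mask_value = -torch.finfo(torch.float32).max
sparse_topk = 3

# one query row where earlier masking already removed all but two positions
sim = torch.tensor([[0.5, 2.0, mask_value, mask_value, mask_value]])

top_values, _ = sim.topk(sparse_topk, dim = -1)  # [2.0, 0.5, mask_value]
# keep a position only if it is in the top-k AND was not already masked out
sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
sim = sim.masked_fill(~sparse_topk_mask, mask_value)

# the two real logits survive; the masked positions stay masked even though
# they tie with the k-th largest value
print(sim)
```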
@@ -424,6 +462,9 @@ class Attend(Module):
         if exists(self.cope):
             sim = sim + self.cope(q, sim)
 
+        if self.selective:
+            sim = selective_attn(sim)
+
         pre_softmax_attn = sim
 
         if self.sigsoftmax:
{x_transformers-1.37.9 → x_transformers-1.38.0}/x_transformers/x_transformers.py

@@ -925,6 +925,8 @@ class Attention(Module):
         qk_norm_dim_scale = False,
         l2_distance = False,
         sigmoid = False,
+        selective = False,
+        custom_attn_fn: Callable | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1041,6 +1043,8 @@ class Attention(Module):
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
             sigmoid = sigmoid,
+            selective = selective,
+            custom_attn_fn = custom_attn_fn,
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
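At the model level, the new flags are plain `Attention` keyword arguments. A minimal sketch (not taken from the package README) of switching selective attention on for a decoder, assuming the library's usual `attn_`-prefixed routing of keyword arguments from `Decoder` down to its `Attention` layers:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_selective = True  # forwarded as Attention(selective = True)
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```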