x-transformers 1.39.3__py3-none-any.whl → 1.40.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +33 -27
- x_transformers/x_transformers.py +2 -2
- {x_transformers-1.39.3.dist-info → x_transformers-1.40.0.dist-info}/METADATA +1 -1
- {x_transformers-1.39.3.dist-info → x_transformers-1.40.0.dist-info}/RECORD +7 -7
- {x_transformers-1.39.3.dist-info → x_transformers-1.40.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.39.3.dist-info → x_transformers-1.40.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.39.3.dist-info → x_transformers-1.40.0.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
```diff
@@ -107,14 +107,32 @@ def qk_l2_dist_squared(q, k):
     l2_dist_squared = torch.cdist(q, k) ** 2
     return unpack_one(l2_dist_squared, packed_shape, '* i j')
 
-#
+# one-hot straight through softmax
 
-def
-    one_hot_indices =
-    one_hot = torch.zeros_like(
+def one_hot_straight_through(logits, temperature = 1.):
+    one_hot_indices = logits.argmax(dim = -1, keepdim = True)
+    one_hot = torch.zeros_like(logits).scatter(-1, one_hot_indices, 1.)
 
-
-    return one_hot +
+    soft_attn = (logits / temperature).softmax(dim = -1)
+    return one_hot + soft_attn - soft_attn.detach()
+
+# sparse topk attention - only keep topk attn logits for softmax
+# optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`
+
+def sparse_topk_attn(logits, sparse_topk, temperature = 1., straight_through = False):
+    orig_logits = logits
+
+    mask_value = -torch.finfo(logits.dtype).max
+    top_values, _ = logits.topk(sparse_topk, dim = -1)
+    sparse_topk_mask = (logits >= top_values[..., -1:]) & (logits > mask_value)
+    logits = logits.masked_fill(~sparse_topk_mask, mask_value)
+    topk_attn = logits.softmax(dim = -1)
+
+    if not straight_through:
+        return topk_attn
+
+    soft_attn = (orig_logits / temperature).softmax(dim = -1)
+    return topk_attn + soft_attn - soft_attn.detach()
 
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
```
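The new `sparse_topk_attn` keeps only the top-k logits per row for the forward softmax; with `straight_through = True` it adds the dense softmax minus its detached copy, so the forward value stays equal to the sparse softmax while gradients flow through the dense one. A minimal sketch of calling the module-level function directly (assuming it is importable from `x_transformers.attend` as shown in the hunk above; the shapes are illustrative only):

```python
import torch
from x_transformers.attend import sparse_topk_attn  # module-level function added in 1.40.0

# illustrative layout: (batch, heads, query len, key len)
logits = torch.randn(2, 8, 16, 16, requires_grad = True)

attn = sparse_topk_attn(logits, sparse_topk = 8, straight_through = True)

# forward value is the top-8 softmax: rows still sum to 1, non-topk entries are zeroed out
assert torch.allclose(attn.sum(dim = -1), torch.ones(2, 8, 16), atol = 1e-5)

# straight-through: the backward pass routes through the dense softmax,
# so even positions masked out of the forward value receive gradient
attn.sum().backward()
assert logits.grad is not None
```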
```diff
@@ -141,6 +159,7 @@ class Attend(Module):
         post_talking_heads = False,
         pre_scale_post_talking_heads = False,
         sparse_topk = None,
+        sparse_topk_straight_through = False,
         scale = None,
         qk_norm = False,
         l2_distance = False,
@@ -152,7 +171,6 @@ class Attend(Module):
         add_zero_kv = False,
         selective = False,
         hard = False,
-        sigsoftmax = False,
         cope = None,
         onnxable = False,
         sdp_kwargs: dict = dict(
@@ -171,16 +189,22 @@ class Attend(Module):
 
         # attention type
 
+        is_sparse_topk_attn = exists(sparse_topk)
+
         assert not (flash and sigmoid), 'sigmoid attention not available for flash'
         assert not (flash and hard), 'hard attention not available for flash'
-        assert
+        assert not (flash and is_sparse_topk_attn), 'topk attention not available for flash'
+
+        assert at_most_one_of(sigmoid, hard, l2_distance, is_sparse_topk_attn)
 
         if exists(custom_attn_fn):
             self.attn_fn = custom_attn_fn
         elif sigmoid:
             self.attn_fn = F.sigmoid
         elif hard:
-            self.attn_fn =
+            self.attn_fn = one_hot_straight_through
+        elif is_sparse_topk_attn:
+            self.attn_fn = partial(sparse_topk_attn, sparse_topk = sparse_topk, straight_through = sparse_topk_straight_through)
         else:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
```
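At the `Attend` level, the new flag simply routes `attn_fn` to `sparse_topk_attn` (flash must stay off, per the new assert). A usage sketch, assuming `Attend.forward` still takes `(q, k, v)` shaped `(batch, heads, seq, dim_head)` and returns `(out, intermediates)` as in previous releases:

```python
import torch
from x_transformers.attend import Attend

attend = Attend(
    causal = True,
    sparse_topk = 8,                      # existing option, now handled through attn_fn
    sparse_topk_straight_through = True,  # new in 1.40.0
)

# assumed layout: (batch, heads, seq len, dim per head)
q, k, v = (torch.randn(1, 8, 128, 64) for _ in range(3))

out, intermediates = attend(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```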
```diff
@@ -214,16 +238,6 @@ class Attend(Module):
         assert not (selective and not causal), 'selective attention is designed for autoregressive'
         self.selective = selective
 
-        # sparse topk
-
-        assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
-        self.sparse_topk = sparse_topk
-
-        # sig softmax
-
-        assert not (flash and sigsoftmax), 'sigsoftmax not available for flash attention'
-        self.sigsoftmax = sigsoftmax
-
         # l2 distance attention
 
         self.l2_distance = l2_distance
@@ -476,11 +490,6 @@ class Attend(Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)
 
-        if exists(self.sparse_topk):
-            top_values, _ = sim.topk(self.sparse_topk, dim = -1)
-            sparse_topk_mask = (sim >= top_values[..., -1:]) & (sim > mask_value)
-            sim = sim.masked_fill(~sparse_topk_mask, mask_value)
-
         row_is_entirely_masked = None
 
         if exists(mask):
@@ -494,9 +503,6 @@ class Attend(Module):
 
         pre_softmax_attn = sim
 
-        if self.sigsoftmax:
-            sim = sim + sim.sigmoid().log()
-
         attn = self.attn_fn(sim)
 
         attn = attn.type(dtype)
```
x_transformers/x_transformers.py
CHANGED
```diff
@@ -912,6 +912,7 @@ class Attention(Module):
         pre_scale_post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
+        sparse_topk_straight_through = False,
         num_mem_kv = 0,
         dropout = 0.,
         on_attn = False,
@@ -920,7 +921,6 @@ class Attention(Module):
         gate_values = False,
         zero_init_output = False,
         hard = False,
-        sigsoftmax = False,
         max_attend_past = None,
         qk_norm = False,
         qk_norm_groups = 1,
@@ -1044,6 +1044,7 @@ class Attention(Module):
             pre_scale_post_talking_heads = pre_scale_post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
+            sparse_topk_straight_through = sparse_topk_straight_through,
             hard = hard,
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
@@ -1054,7 +1055,6 @@ class Attention(Module):
             add_zero_kv = add_zero_kv,
             flash = flash,
             softclamp_logits = softclamp_logits,
-            sigsoftmax = sigsoftmax,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
             onnxable = onnxable
```
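At the model level, `Attention` kwargs are reachable through the usual `attn_` prefix on an attention-layers stack, which is where the `attn_sparse_topk_straight_through = True` name mentioned in the new attend.py comment comes from. A sketch using the standard `TransformerWrapper`/`Decoder` entry points:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_sparse_topk = 8,                      # pre-existing sparse top-k attention
        attn_sparse_topk_straight_through = True   # new flag, forwarded down to Attend
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```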
{x_transformers-1.39.3.dist-info → x_transformers-1.40.0.dist-info}/RECORD
CHANGED
```diff
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=eoBEK0HdDCWaJgxwGZPeO36ydBt1NbB-gpij_Jkj4Mw,17212
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=TGXJZXCWR5BiMkS5Kx-JhFQ85AxkiJabLiHnrCTC874,84562
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.40.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.0.dist-info/METADATA,sha256=WxBpjG7F8utkdJN1AF9vZMAfbltT0lmpdgtUErBbYQY,661
+x_transformers-1.40.0.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.0.dist-info/RECORD,,
```
{x_transformers-1.39.3.dist-info → x_transformers-1.40.0.dist-info}/LICENSE
File without changes
{x_transformers-1.39.3.dist-info → x_transformers-1.40.0.dist-info}/WHEEL
File without changes
{x_transformers-1.39.3.dist-info → x_transformers-1.40.0.dist-info}/top_level.txt
File without changes