x-transformers 1.31.14__py3-none-any.whl → 1.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +20 -5
- {x_transformers-1.31.14.dist-info → x_transformers-1.32.0.dist-info}/METADATA +1 -1
- {x_transformers-1.31.14.dist-info → x_transformers-1.32.0.dist-info}/RECORD +6 -6
- {x_transformers-1.31.14.dist-info → x_transformers-1.32.0.dist-info}/WHEEL +1 -1
- {x_transformers-1.31.14.dist-info → x_transformers-1.32.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.31.14.dist-info → x_transformers-1.32.0.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -194,18 +194,21 @@ class Attend(Module):
|
|
194
194
|
|
195
195
|
# manually handle causal mask, if another mask was given
|
196
196
|
|
197
|
-
row_is_entirely_masked = None
|
198
|
-
|
199
197
|
if exists(mask) and causal:
|
200
198
|
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
|
201
199
|
mask = mask & ~causal_mask
|
200
|
+
causal = False
|
201
|
+
|
202
|
+
# protect against an entire row being masked out
|
202
203
|
|
203
|
-
|
204
|
+
row_is_entirely_masked = None
|
204
205
|
|
206
|
+
if exists(mask):
|
205
207
|
row_is_entirely_masked = ~mask.any(dim = -1)
|
206
|
-
mask[..., 0] = mask[..., 0] | row_is_entirely_masked
|
207
208
|
|
208
|
-
|
209
|
+
if row_is_entirely_masked.any():
|
210
|
+
mask = mask.clone()
|
211
|
+
mask[..., 0] = mask[..., 0] | row_is_entirely_masked
|
209
212
|
|
210
213
|
# handle alibi positional bias
|
211
214
|
# convert from bool to float
|
@@ -334,6 +337,15 @@ class Attend(Module):
|
|
334
337
|
causal_mask = self.create_causal_mask(i, j, device = device)
|
335
338
|
sim = sim.masked_fill(causal_mask, mask_value)
|
336
339
|
|
340
|
+
row_is_entirely_masked = None
|
341
|
+
|
342
|
+
if exists(mask):
|
343
|
+
row_is_entirely_masked = ~mask.any(dim = -1)
|
344
|
+
|
345
|
+
if row_is_entirely_masked.any():
|
346
|
+
mask = mask.clone()
|
347
|
+
mask[..., 0] = mask[..., 0] | row_is_entirely_masked
|
348
|
+
|
337
349
|
if exists(self.cope):
|
338
350
|
sim = sim + self.cope(q, sim)
|
339
351
|
|
@@ -357,4 +369,7 @@ class Attend(Module):
|
|
357
369
|
post_softmax_attn = post_softmax_attn
|
358
370
|
)
|
359
371
|
|
372
|
+
if exists(row_is_entirely_masked):
|
373
|
+
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
|
374
|
+
|
360
375
|
return out, intermediates
|
@@ -1,5 +1,5 @@
|
|
1
1
|
x_transformers/__init__.py,sha256=5ms39Df8osTUHQ-XTCgP4vSUA4UiNpim9VXJtrLrIvQ,724
|
2
|
-
x_transformers/attend.py,sha256=
|
2
|
+
x_transformers/attend.py,sha256=2nN708coYLzvTy937KCKR1iI_uhgfmTnY9GWGsSXjHw,11587
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
|
4
4
|
x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
@@ -7,8 +7,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T
|
|
7
7
|
x_transformers/x_transformers.py,sha256=1QG7zUe89h1R5VDMoKEAkvdRRDkzQ7h6npkqblxxR6g,76312
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
9
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
10
|
-
x_transformers-1.
|
11
|
-
x_transformers-1.
|
12
|
-
x_transformers-1.
|
13
|
-
x_transformers-1.
|
14
|
-
x_transformers-1.
|
10
|
+
x_transformers-1.32.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.32.0.dist-info/METADATA,sha256=3GkF9cqLxmReELQRflSZpqSXP9tt10A23eiR6wRGzIs,661
|
12
|
+
x_transformers-1.32.0.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
13
|
+
x_transformers-1.32.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.32.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|