x-transformers 1.31.14__py3-none-any.whl → 1.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +13 -6
- {x_transformers-1.31.14.dist-info → x_transformers-1.32.1.dist-info}/METADATA +1 -1
- {x_transformers-1.31.14.dist-info → x_transformers-1.32.1.dist-info}/RECORD +6 -6
- {x_transformers-1.31.14.dist-info → x_transformers-1.32.1.dist-info}/WHEEL +1 -1
- {x_transformers-1.31.14.dist-info → x_transformers-1.32.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.31.14.dist-info → x_transformers-1.32.1.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -194,18 +194,17 @@ class Attend(Module):
|
|
194
194
|
|
195
195
|
# manually handle causal mask, if another mask was given
|
196
196
|
|
197
|
-
row_is_entirely_masked = None
|
198
|
-
|
199
197
|
if exists(mask) and causal:
|
200
198
|
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
|
201
199
|
mask = mask & ~causal_mask
|
200
|
+
causal = False
|
202
201
|
|
203
|
-
|
202
|
+
# protect against an entire row being masked out
|
204
203
|
|
205
|
-
|
206
|
-
mask[..., 0] = mask[..., 0] | row_is_entirely_masked
|
204
|
+
row_is_entirely_masked = None
|
207
205
|
|
208
|
-
|
206
|
+
if exists(mask):
|
207
|
+
row_is_entirely_masked = ~mask.any(dim = -1)
|
209
208
|
|
210
209
|
# handle alibi positional bias
|
211
210
|
# convert from bool to float
|
@@ -334,6 +333,11 @@ class Attend(Module):
|
|
334
333
|
causal_mask = self.create_causal_mask(i, j, device = device)
|
335
334
|
sim = sim.masked_fill(causal_mask, mask_value)
|
336
335
|
|
336
|
+
row_is_entirely_masked = None
|
337
|
+
|
338
|
+
if exists(mask):
|
339
|
+
row_is_entirely_masked = ~mask.any(dim = -1)
|
340
|
+
|
337
341
|
if exists(self.cope):
|
338
342
|
sim = sim + self.cope(q, sim)
|
339
343
|
|
@@ -357,4 +361,7 @@ class Attend(Module):
|
|
357
361
|
post_softmax_attn = post_softmax_attn
|
358
362
|
)
|
359
363
|
|
364
|
+
if exists(row_is_entirely_masked):
|
365
|
+
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
|
366
|
+
|
360
367
|
return out, intermediates
|
@@ -1,5 +1,5 @@
|
|
1
1
|
x_transformers/__init__.py,sha256=5ms39Df8osTUHQ-XTCgP4vSUA4UiNpim9VXJtrLrIvQ,724
|
2
|
-
x_transformers/attend.py,sha256=
|
2
|
+
x_transformers/attend.py,sha256=PzJ0_MnysNDV-jcDuCbAFPEUZxgljVYOShvtMGfO7wU,11283
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
|
4
4
|
x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
@@ -7,8 +7,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T
|
|
7
7
|
x_transformers/x_transformers.py,sha256=1QG7zUe89h1R5VDMoKEAkvdRRDkzQ7h6npkqblxxR6g,76312
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
9
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
10
|
-
x_transformers-1.
|
11
|
-
x_transformers-1.
|
12
|
-
x_transformers-1.
|
13
|
-
x_transformers-1.
|
14
|
-
x_transformers-1.
|
10
|
+
x_transformers-1.32.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.32.1.dist-info/METADATA,sha256=NAyDRgkGJJ5ovqj1--2CgHKalfczAxWcyKiEr7Spxc4,661
|
12
|
+
x_transformers-1.32.1.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
13
|
+
x_transformers-1.32.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.32.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|