x-transformers 1.31.14__py3-none-any.whl → 1.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -194,18 +194,21 @@ class Attend(Module):
194
194
 
195
195
  # manually handle causal mask, if another mask was given
196
196
 
197
- row_is_entirely_masked = None
198
-
199
197
  if exists(mask) and causal:
200
198
  causal_mask = self.create_causal_mask(q_len, k_len, device = device)
201
199
  mask = mask & ~causal_mask
200
+ causal = False
201
+
202
+ # protect against an entire row being masked out
202
203
 
203
- # protect against an entire row being masked out
204
+ row_is_entirely_masked = None
204
205
 
206
+ if exists(mask):
205
207
  row_is_entirely_masked = ~mask.any(dim = -1)
206
- mask[..., 0] = mask[..., 0] | row_is_entirely_masked
207
208
 
208
- causal = False
209
+ if row_is_entirely_masked.any():
210
+ mask = mask.clone()
211
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
209
212
 
210
213
  # handle alibi positional bias
211
214
  # convert from bool to float
@@ -334,6 +337,15 @@ class Attend(Module):
334
337
  causal_mask = self.create_causal_mask(i, j, device = device)
335
338
  sim = sim.masked_fill(causal_mask, mask_value)
336
339
 
340
+ row_is_entirely_masked = None
341
+
342
+ if exists(mask):
343
+ row_is_entirely_masked = ~mask.any(dim = -1)
344
+
345
+ if row_is_entirely_masked.any():
346
+ mask = mask.clone()
347
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
348
+
337
349
  if exists(self.cope):
338
350
  sim = sim + self.cope(q, sim)
339
351
 
@@ -357,4 +369,7 @@ class Attend(Module):
357
369
  post_softmax_attn = post_softmax_attn
358
370
  )
359
371
 
372
+ if exists(row_is_entirely_masked):
373
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
374
+
360
375
  return out, intermediates
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.31.14
3
+ Version: 1.32.0
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -1,5 +1,5 @@
1
1
  x_transformers/__init__.py,sha256=5ms39Df8osTUHQ-XTCgP4vSUA4UiNpim9VXJtrLrIvQ,724
2
- x_transformers/attend.py,sha256=oAS0vSy5qH7iTCXzHKfM4k7m_fvuZIR49PStZO8OFJo,11089
2
+ x_transformers/attend.py,sha256=2nN708coYLzvTy937KCKR1iI_uhgfmTnY9GWGsSXjHw,11587
3
3
  x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
4
4
  x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
5
5
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
@@ -7,8 +7,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T
7
7
  x_transformers/x_transformers.py,sha256=1QG7zUe89h1R5VDMoKEAkvdRRDkzQ7h6npkqblxxR6g,76312
8
8
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
9
9
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
10
- x_transformers-1.31.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
- x_transformers-1.31.14.dist-info/METADATA,sha256=Qj5yRxhBmF87HtYeWuFTiiYVZf-eDdXabhW_P5McQ7w,662
12
- x_transformers-1.31.14.dist-info/WHEEL,sha256=-oYQCr74JF3a37z2nRlQays_SX2MqOANoqVjBBAP2yE,91
13
- x_transformers-1.31.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
- x_transformers-1.31.14.dist-info/RECORD,,
10
+ x_transformers-1.32.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
+ x_transformers-1.32.0.dist-info/METADATA,sha256=3GkF9cqLxmReELQRflSZpqSXP9tt10A23eiR6wRGzIs,661
12
+ x_transformers-1.32.0.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
13
+ x_transformers-1.32.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
+ x_transformers-1.32.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (71.0.3)
2
+ Generator: setuptools (72.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5