x-transformers 1.31.14__py3-none-any.whl → 1.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -194,18 +194,17 @@ class Attend(Module):
194
194
 
195
195
  # manually handle causal mask, if another mask was given
196
196
 
197
- row_is_entirely_masked = None
198
-
199
197
  if exists(mask) and causal:
200
198
  causal_mask = self.create_causal_mask(q_len, k_len, device = device)
201
199
  mask = mask & ~causal_mask
200
+ causal = False
202
201
 
203
- # protect against an entire row being masked out
202
+ # protect against an entire row being masked out
204
203
 
205
- row_is_entirely_masked = ~mask.any(dim = -1)
206
- mask[..., 0] = mask[..., 0] | row_is_entirely_masked
204
+ row_is_entirely_masked = None
207
205
 
208
- causal = False
206
+ if exists(mask):
207
+ row_is_entirely_masked = ~mask.any(dim = -1)
209
208
 
210
209
  # handle alibi positional bias
211
210
  # convert from bool to float
@@ -334,6 +333,11 @@ class Attend(Module):
334
333
  causal_mask = self.create_causal_mask(i, j, device = device)
335
334
  sim = sim.masked_fill(causal_mask, mask_value)
336
335
 
336
+ row_is_entirely_masked = None
337
+
338
+ if exists(mask):
339
+ row_is_entirely_masked = ~mask.any(dim = -1)
340
+
337
341
  if exists(self.cope):
338
342
  sim = sim + self.cope(q, sim)
339
343
 
@@ -357,4 +361,7 @@ class Attend(Module):
357
361
  post_softmax_attn = post_softmax_attn
358
362
  )
359
363
 
364
+ if exists(row_is_entirely_masked):
365
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
366
+
360
367
  return out, intermediates
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.31.14
3
+ Version: 1.32.1
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -1,5 +1,5 @@
1
1
  x_transformers/__init__.py,sha256=5ms39Df8osTUHQ-XTCgP4vSUA4UiNpim9VXJtrLrIvQ,724
2
- x_transformers/attend.py,sha256=oAS0vSy5qH7iTCXzHKfM4k7m_fvuZIR49PStZO8OFJo,11089
2
+ x_transformers/attend.py,sha256=PzJ0_MnysNDV-jcDuCbAFPEUZxgljVYOShvtMGfO7wU,11283
3
3
  x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
4
4
  x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
5
5
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
@@ -7,8 +7,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T
7
7
  x_transformers/x_transformers.py,sha256=1QG7zUe89h1R5VDMoKEAkvdRRDkzQ7h6npkqblxxR6g,76312
8
8
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
9
9
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
10
- x_transformers-1.31.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
- x_transformers-1.31.14.dist-info/METADATA,sha256=Qj5yRxhBmF87HtYeWuFTiiYVZf-eDdXabhW_P5McQ7w,662
12
- x_transformers-1.31.14.dist-info/WHEEL,sha256=-oYQCr74JF3a37z2nRlQays_SX2MqOANoqVjBBAP2yE,91
13
- x_transformers-1.31.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
- x_transformers-1.31.14.dist-info/RECORD,,
10
+ x_transformers-1.32.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
+ x_transformers-1.32.1.dist-info/METADATA,sha256=NAyDRgkGJJ5ovqj1--2CgHKalfczAxWcyKiEr7Spxc4,661
12
+ x_transformers-1.32.1.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
13
+ x_transformers-1.32.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
+ x_transformers-1.32.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (71.0.3)
2
+ Generator: setuptools (72.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5