x-transformers 1.16.12__py3-none-any.whl → 1.16.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -138,7 +138,7 @@ class Attend(nn.Module):
138
138
 
139
139
  if causal:
140
140
  causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
141
- mask = mask & ~causal_mask
141
+ mask = mask | causal_mask
142
142
  causal = False
143
143
 
144
144
  # handle alibi positional bias
@@ -153,7 +153,7 @@ class Attend(nn.Module):
153
153
  mask_value = -torch.finfo(q.dtype).max
154
154
 
155
155
  if exists(mask):
156
- attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
156
+ attn_bias = attn_bias.masked_fill(mask, mask_value // 2)
157
157
  elif causal:
158
158
  causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
159
159
  attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
@@ -163,6 +163,8 @@ class Attend(nn.Module):
163
163
  # make it an additive bias here
164
164
 
165
165
  mask = attn_bias
166
+ else:
167
+ mask = ~mask
166
168
 
167
169
  # Check if there is a compatible device for flash attention
168
170
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.16.12
3
+ Version: 1.16.14
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -1,12 +1,12 @@
1
1
  x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
2
- x_transformers/attend.py,sha256=D-fIw9jTyda2h0GR0SY2cOaDnveMdw9p9gRULKe9Rhc,10381
2
+ x_transformers/attend.py,sha256=Mm3KSsGKJ1PTqoCl0Za5K7S06Bs6cwaCMkC9-fqw1QY,10418
3
3
  x_transformers/autoregressive_wrapper.py,sha256=u2celA8KeHm_Gd83Q7qaiLbJnwaDGdsbUck-JiokpKg,4446
4
4
  x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
5
5
  x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
6
6
  x_transformers/x_transformers.py,sha256=n8NeIWo7WijFK4t9L5-ixlgvFnuxP-imEtW7T8e-pWg,54202
7
7
  x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
8
- x_transformers-1.16.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
9
- x_transformers-1.16.12.dist-info/METADATA,sha256=V4nm8BFPP5KNmCRnwyWuKucXuzK676h_bZ500x6M00o,666
10
- x_transformers-1.16.12.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
11
- x_transformers-1.16.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
12
- x_transformers-1.16.12.dist-info/RECORD,,
8
+ x_transformers-1.16.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
9
+ x_transformers-1.16.14.dist-info/METADATA,sha256=lubdgZENatrU56P9BEsAvluo_yewrPydqf57coFFmZ0,666
10
+ x_transformers-1.16.14.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
11
+ x_transformers-1.16.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
12
+ x_transformers-1.16.14.dist-info/RECORD,,