x-transformers 1.16.14__py3-none-any.whl → 1.16.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -138,7 +138,7 @@ class Attend(nn.Module):
 
             if causal:
                 causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
-                mask = mask | causal_mask
+                mask = mask & ~causal_mask
                 causal = False
 
         # handle alibi positional bias
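Note: this hunk flips the boolean mask convention inside Attend from "True = masked out" to "True = may attend". A minimal standalone sketch of why the combining operator changes from | to & ~ (made-up shapes and padding mask; not code from the package):

import torch

q_len, k_len = 3, 5

# hypothetical padding mask, True = real token that may be attended to
mask = torch.tensor([True, True, True, True, False]).expand(q_len, k_len)

# True strictly above the shifted diagonal marks future (disallowed) keys
causal_mask = torch.ones((q_len, k_len), dtype = torch.bool).triu(k_len - q_len + 1)

# the old convention (True = masked) accumulated exclusions with `mask | causal_mask`;
# the new convention (True = attend) must instead clear the future positions:
mask = mask & ~causal_mask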
@@ -153,7 +153,7 @@ class Attend(nn.Module):
             mask_value = -torch.finfo(q.dtype).max
 
             if exists(mask):
-                attn_bias = attn_bias.masked_fill(mask, mask_value // 2)
+                attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
             elif causal:
                 causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
                 attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
@@ -163,8 +163,6 @@ class Attend(nn.Module):
             # make it an additive bias here
 
             mask = attn_bias
-        else:
-            mask = ~mask
 
         # Check if there is a compatible device for flash attention
 
@@ -226,7 +224,7 @@ class Attend(nn.Module):
         mask_value = -torch.finfo(dots.dtype).max
 
         if exists(mask):
-            dots = dots.masked_fill(mask, mask_value)
+            dots = dots.masked_fill(~mask, mask_value)
 
         if self.causal:
             i, j = dots.shape[-2:]
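Note: the same inversion in the non-flash (einsum) path. A sketch with made-up tensors showing that filling ~mask before the softmax drives the excluded keys' attention weights to zero:

import torch

dots = torch.randn(2, 4, 3, 5)               # (batch, heads, q_len, k_len) logits
mask = torch.ones(2, 1, 3, 5, dtype = torch.bool)
mask[..., -1] = False                        # e.g. last key is padding

mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(~mask, mask_value)   # True = attend, so fill the complement

attn = dots.softmax(dim = -1)                # masked key gets ~0 probability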
@@ -815,7 +815,7 @@ class Attention(nn.Module):
                 masks.append(sparse_topk_mask)
 
         if len(masks) > 0:
-            final_attn_mask = or_reduce(masks)
+            final_attn_mask = ~or_reduce(masks)
 
         # prepare relative positional bias, if needed
 
x_transformers-{1.16.14 → 1.16.15}.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.16.14
+Version: 1.16.15
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-{1.16.14 → 1.16.15}.dist-info/RECORD RENAMED
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=Mm3KSsGKJ1PTqoCl0Za5K7S06Bs6cwaCMkC9-fqw1QY,10418
+x_transformers/attend.py,sha256=tDgL6ryV3QV2bU855iiyF19dbDWWdZ52dCPKk99sfOs,10382
 x_transformers/autoregressive_wrapper.py,sha256=u2celA8KeHm_Gd83Q7qaiLbJnwaDGdsbUck-JiokpKg,4446
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=n8NeIWo7WijFK4t9L5-ixlgvFnuxP-imEtW7T8e-pWg,54202
+x_transformers/x_transformers.py,sha256=rroq3Qq4XQ190XeVXeaiIsJdcYfidIOEt69N5DV_Rdo,54203
 x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
-x_transformers-1.16.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.16.14.dist-info/METADATA,sha256=lubdgZENatrU56P9BEsAvluo_yewrPydqf57coFFmZ0,666
-x_transformers-1.16.14.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
-x_transformers-1.16.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.16.14.dist-info/RECORD,,
+x_transformers-1.16.15.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.16.15.dist-info/METADATA,sha256=-vQVqSIeWbDuvNGVV0JXOAyukJ4Ge6I0BHHcWhdTIdM,666
+x_transformers-1.16.15.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
+x_transformers-1.16.15.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.16.15.dist-info/RECORD,,