x-transformers 1.16.14__py3-none-any.whl → 1.16.15__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- x_transformers/attend.py +3 -5
- x_transformers/x_transformers.py +1 -1
- {x_transformers-1.16.14.dist-info → x_transformers-1.16.15.dist-info}/METADATA +1 -1
- {x_transformers-1.16.14.dist-info → x_transformers-1.16.15.dist-info}/RECORD +7 -7
- {x_transformers-1.16.14.dist-info → x_transformers-1.16.15.dist-info}/LICENSE +0 -0
- {x_transformers-1.16.14.dist-info → x_transformers-1.16.15.dist-info}/WHEEL +0 -0
- {x_transformers-1.16.14.dist-info → x_transformers-1.16.15.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -138,7 +138,7 @@ class Attend(nn.Module):
 
             if causal:
                 causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
-                mask = mask
+                mask = mask & ~causal_mask
                 causal = False
 
         # handle alibi positional bias
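The new line folds the causal restriction directly into the caller-supplied mask, under the convention that `mask` is True where attention is allowed. A minimal standalone sketch of that combination (not the library's code; the 4×4 shapes and the padding pattern are made up for illustration):

```python
import torch

q_len, k_len, device = 4, 4, 'cpu'

# keep-mask: True = this key may be attended to (e.g. it is not padding)
mask = torch.tensor([True, True, True, False], device = device).expand(q_len, k_len)

# causal mask: True = a future position that must stay hidden
causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)

# fold causality into the keep-mask, as the added line does
mask = mask & ~causal_mask
print(mask)
# row i keeps keys 0..i, and the padded last key is always dropped
```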
@@ -153,7 +153,7 @@ class Attend(nn.Module):
             mask_value = -torch.finfo(q.dtype).max
 
             if exists(mask):
-                attn_bias = attn_bias.masked_fill(mask, mask_value // 2)
+                attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
             elif causal:
                 causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
                 attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
@@ -163,8 +163,6 @@ class Attend(nn.Module):
             # make it an additive bias here
 
             mask = attn_bias
-        else:
-            mask = ~mask
 
         # Check if there is a compatible device for flash attention
 
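Both hunks follow from the same convention change: the alibi bias is now knocked out with `~mask` (fill where attention is not allowed), and the explicit `mask = ~mask` before the flash path is gone because PyTorch's `F.scaled_dot_product_attention` already treats a boolean `attn_mask` as True-means-attend. A rough sketch of that hand-off, assuming PyTorch ≥ 2.0; the shapes and the zero alibi bias are placeholders, not the package's API:

```python
import torch
import torch.nn.functional as F

b, h, n, d = 1, 2, 4, 8
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# keep-mask: True = attend; hide the last key from every query
mask = torch.ones(b, h, n, n, dtype = torch.bool)
mask[..., -1] = False

# additive bias (stand-in for alibi); push disallowed positions far negative
attn_bias = torch.zeros(b, h, n, n)
mask_value = -torch.finfo(q.dtype).max
attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)   # note the ~

# SDPA accepts either the boolean keep-mask or the additive float bias,
# so no extra inversion is needed in either case
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)
out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask = attn_bias)
print(out_bool.shape, out_bias.shape)
```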
@@ -226,7 +224,7 @@ class Attend(nn.Module):
         mask_value = -torch.finfo(dots.dtype).max
 
         if exists(mask):
-            dots = dots.masked_fill(mask, mask_value)
+            dots = dots.masked_fill(~mask, mask_value)
 
         if self.causal:
             i, j = dots.shape[-2:]
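The non-flash einsum path gets the same flip: entries where the keep-mask is False are filled with the most negative finite value before the softmax, so they end up with essentially zero attention weight. A small illustration (standalone, not the package's code):

```python
import torch

q_len, k_len = 4, 4
dots = torch.randn(q_len, k_len)                  # stand-in attention logits
mask = torch.tensor([True, True, True, False]).expand(q_len, k_len)

mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(~mask, mask_value)        # fill where NOT allowed
attn = dots.softmax(dim = -1)
print(attn[:, -1])                                # ~0: the masked key gets no weight
```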
x_transformers/x_transformers.py
CHANGED

{x_transformers-1.16.14.dist-info → x_transformers-1.16.15.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=tDgL6ryV3QV2bU855iiyF19dbDWWdZ52dCPKk99sfOs,10382
 x_transformers/autoregressive_wrapper.py,sha256=u2celA8KeHm_Gd83Q7qaiLbJnwaDGdsbUck-JiokpKg,4446
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=rroq3Qq4XQ190XeVXeaiIsJdcYfidIOEt69N5DV_Rdo,54203
 x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
-x_transformers-1.16.
-x_transformers-1.16.
-x_transformers-1.16.
-x_transformers-1.16.
-x_transformers-1.16.
+x_transformers-1.16.15.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.16.15.dist-info/METADATA,sha256=-vQVqSIeWbDuvNGVV0JXOAyukJ4Ge6I0BHHcWhdTIdM,666
+x_transformers-1.16.15.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
+x_transformers-1.16.15.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.16.15.dist-info/RECORD,,
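Each RECORD row is the standard wheel triple `path,sha256=<urlsafe-base64 digest with '=' padding stripped>,<size in bytes>`, so the two new hashes simply track the edited sources. A quick way to recompute an entry and compare it with the values above (the path argument is only an example):

```python
import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    """Rebuild one wheel RECORD row for the given file."""
    data = Path(path).read_bytes()
    digest = hashlib.sha256(data).digest()
    # urlsafe base64 with the trailing '=' padding removed, per the wheel spec
    encoded = base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')
    return f"{path},sha256={encoded},{len(data)}"

# e.g. record_entry('x_transformers/attend.py'), run inside the unpacked 1.16.15
# wheel, should reproduce the tDgL6ryV3QV2bU855iiyF19dbDWWdZ52dCPKk99sfOs,10382 row
```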
{x_transformers-1.16.14.dist-info → x_transformers-1.16.15.dist-info}/LICENSE
File without changes

{x_transformers-1.16.14.dist-info → x_transformers-1.16.15.dist-info}/WHEEL
File without changes

{x_transformers-1.16.14.dist-info → x_transformers-1.16.15.dist-info}/top_level.txt
File without changes