x-transformers 1.16.12-py3-none-any.whl → 1.16.14-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +4 -2
- {x_transformers-1.16.12.dist-info → x_transformers-1.16.14.dist-info}/METADATA +1 -1
- {x_transformers-1.16.12.dist-info → x_transformers-1.16.14.dist-info}/RECORD +6 -6
- {x_transformers-1.16.12.dist-info → x_transformers-1.16.14.dist-info}/LICENSE +0 -0
- {x_transformers-1.16.12.dist-info → x_transformers-1.16.14.dist-info}/WHEEL +0 -0
- {x_transformers-1.16.12.dist-info → x_transformers-1.16.14.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -138,7 +138,7 @@ class Attend(nn.Module):
 
             if causal:
                 causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
-                mask = mask
+                mask = mask | causal_mask
                 causal = False
 
         # handle alibi positional bias
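For orientation, a minimal standalone sketch of what the corrected line does: when an explicit key mask is supplied and the layer is causal, the causal constraint is folded into that mask with a boolean OR, with True marking masked-out positions. The shapes and the padding pattern below are invented for illustration; only the causal_mask construction is taken from the hunk above.

```python
import torch

q_len, k_len, device = 4, 4, 'cpu'

# hypothetical key-padding mask broadcast to (q_len, k_len); True = masked out
pad_mask = torch.tensor([False, False, False, True]).expand(q_len, k_len)

# strictly upper-triangular causal mask, built as in the hunk above
causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)

# the fixed line: the union masks out padding and future positions at once
mask = pad_mask | causal_mask
```

Merging the causal constraint into the explicit mask is what allows the very next line to set causal = False, so the constraint is not applied a second time downstream.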
@@ -153,7 +153,7 @@ class Attend(nn.Module):
             mask_value = -torch.finfo(q.dtype).max
 
             if exists(mask):
-                attn_bias = attn_bias.masked_fill(
+                attn_bias = attn_bias.masked_fill(mask, mask_value // 2)
             elif causal:
                 causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
                 attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
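The pattern in this hunk, sketched with invented tensors: masked positions of an additive bias (such as ALiBi) are filled with half of the most negative value representable in the query dtype, leaving headroom so that a later addition to the attention logits cannot overflow.

```python
import torch

q = torch.randn(2, 8, 4, 64)            # only the dtype of q matters here
attn_bias = torch.randn(2, 8, 4, 4)     # invented additive bias, e.g. ALiBi
mask = torch.rand(2, 8, 4, 4) > 0.5     # invented boolean mask; True = masked out

mask_value = -torch.finfo(q.dtype).max

# fill with half the extreme value; the bias is subsequently *added* to
# similarity scores, and filling with the full -max could overflow there
attn_bias = attn_bias.masked_fill(mask, mask_value // 2)
```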
@@ -163,6 +163,8 @@
             # make it an additive bias here
 
             mask = attn_bias
+        else:
+            mask = ~mask
 
         # Check if there is a compatible device for flash attention
 
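For context on the two added lines, a sketch (assuming PyTorch >= 2.0) of the convention being reconciled: F.scaled_dot_product_attention accepts attn_mask either as a boolean tensor where True means "may attend", or as a float tensor that is added to the logits. The internal mask in the diff uses the opposite boolean convention (True = masked out), hence the inversion when no additive bias is present. All tensors below are invented for illustration.

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 8, 4, 64) for _ in range(3))

masked_out = torch.zeros(1, 8, 4, 4, dtype = torch.bool)
masked_out[..., -1] = True   # e.g. last key position is padding

# boolean path: SDPA expects True = "allowed to attend", hence the ~
out = F.scaled_dot_product_attention(q, k, v, attn_mask = ~masked_out)

# float path: an additive bias travels through the same argument
bias = torch.zeros(1, 8, 4, 4).masked_fill(masked_out, -torch.finfo(q.dtype).max // 2)
out = F.scaled_dot_product_attention(q, k, v, attn_mask = bias)
```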
{x_transformers-1.16.12.dist-info → x_transformers-1.16.14.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=Mm3KSsGKJ1PTqoCl0Za5K7S06Bs6cwaCMkC9-fqw1QY,10418
 x_transformers/autoregressive_wrapper.py,sha256=u2celA8KeHm_Gd83Q7qaiLbJnwaDGdsbUck-JiokpKg,4446
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
 x_transformers/x_transformers.py,sha256=n8NeIWo7WijFK4t9L5-ixlgvFnuxP-imEtW7T8e-pWg,54202
 x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
-x_transformers-1.16.
-x_transformers-1.16.
-x_transformers-1.16.
-x_transformers-1.16.
-x_transformers-1.16.
+x_transformers-1.16.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.16.14.dist-info/METADATA,sha256=lubdgZENatrU56P9BEsAvluo_yewrPydqf57coFFmZ0,666
+x_transformers-1.16.14.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
+x_transformers-1.16.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.16.14.dist-info/RECORD,,
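Each RECORD line has the form path,sha256=digest,size, where the digest is the urlsafe-base64 sha256 of the file with trailing "=" padding stripped, per PEP 376 and PEP 427. A small sketch of recomputing one entry; the path is only an example:

```python
import base64, hashlib, os

def record_entry(path):
    # RECORD hash format: urlsafe base64 of the raw sha256 digest,
    # trailing '=' padding removed, followed by the size in bytes
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    b64 = base64.urlsafe_b64encode(digest).rstrip(b'=').decode()
    return f'{path},sha256={b64},{os.path.getsize(path)}'

# record_entry('x_transformers/attend.py') should reproduce the
# 'Mm3KSs...,10418' line above when run against the 1.16.14 wheel contents
```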
{x_transformers-1.16.12.dist-info → x_transformers-1.16.14.dist-info}/LICENSE
File without changes
{x_transformers-1.16.12.dist-info → x_transformers-1.16.14.dist-info}/WHEEL
File without changes
{x_transformers-1.16.12.dist-info → x_transformers-1.16.14.dist-info}/top_level.txt
File without changes