x-transformers 1.34.0__py3-none-any.whl → 1.34.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +21 -3
- {x_transformers-1.34.0.dist-info → x_transformers-1.34.1.dist-info}/METADATA +1 -1
- {x_transformers-1.34.0.dist-info → x_transformers-1.34.1.dist-info}/RECORD +6 -6
- {x_transformers-1.34.0.dist-info → x_transformers-1.34.1.dist-info}/WHEEL +1 -1
- {x_transformers-1.34.0.dist-info → x_transformers-1.34.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.34.0.dist-info → x_transformers-1.34.1.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -138,9 +138,27 @@ class Attend(Module):
         # flash attention
 
         self.flash = flash
-        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
 
-
+        torch_version = version.parse(torch.__version__)
+        assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+        # torch 2.3 uses new backend and context manager
+
+        if torch_version >= version.parse('2.3'):
+            from torch.nn.attention import SDPBackend
+
+            str_to_backend = dict(
+                enable_flash = SDPBackend.FLASH_ATTENTION,
+                enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+                enable_math = SDPBackend.MATH,
+                enable_cudnn = SDPBackend.CUDNN_ATTENTION
+            )
+
+            sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+
+            self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+        else:
+            self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
 
     def flash_attn(
         self,

@@ -231,7 +249,7 @@ class Attend(Module):
 
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
 
-        with
+        with self.sdp_context_manager():
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
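In effect, 1.34.1 swaps the boolean-flag context manager torch.backends.cuda.sdp_kernel (deprecated since torch 2.3) for torch.nn.attention.sdpa_kernel, which instead takes a list of SDPBackend members. The sketch below reproduces that dispatch pattern in isolation; make_sdp_context is a hypothetical helper name (not from x-transformers), and the cuDNN backend is left out here since its availability varies across 2.3.x builds.

from functools import partial

import torch
import torch.nn.functional as F
from packaging import version

def make_sdp_context(sdp_kwargs = None):
    # default backend flags, mirroring the keys the patch maps onto SDPBackend members
    sdp_kwargs = sdp_kwargs or dict(
        enable_flash = True,
        enable_math = True,
        enable_mem_efficient = True
    )

    torch_version = version.parse(torch.__version__)

    if torch_version >= version.parse('2.3'):
        # torch >= 2.3: select backends by passing SDPBackend members to sdpa_kernel
        from torch.nn.attention import SDPBackend, sdpa_kernel

        str_to_backend = dict(
            enable_flash = SDPBackend.FLASH_ATTENTION,
            enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
            enable_math = SDPBackend.MATH
        )

        backends = [str_to_backend[k] for k, enabled in sdp_kwargs.items() if enabled]
        return partial(sdpa_kernel, backends)

    # torch 2.0 - 2.2: the older boolean-flag context manager
    return partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)

sdp_context = make_sdp_context()

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, dim_head)

with sdp_context():
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

print(out.shape)  # torch.Size([1, 8, 16, 64])

Building the context manager once in __init__ also means the per-call site stays a single line, whichever torch version is installed.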
{x_transformers-1.34.0.dist-info → x_transformers-1.34.1.dist-info}/RECORD
CHANGED

@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
 x_transformers/autoregressive_wrapper.py,sha256=ka_iiej5lEBOcbutWQgGrFVMDilz2PFWzLhBh5_tmmg,10366
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450

@@ -8,8 +8,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T
 x_transformers/x_transformers.py,sha256=hs9j-lHukVGYLlpbBhn4CZhSzI7s0x6bYtEhCc33ftE,78680
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.34.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.34.0.dist-info/METADATA,sha256=
-x_transformers-1.34.0.dist-info/WHEEL,sha256=
-x_transformers-1.34.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.34.0.dist-info/RECORD,,
+x_transformers-1.34.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.34.1.dist-info/METADATA,sha256=jSsnjS0ptrIpH-nc9h7fNMjzAvpmQGOkXYqTSWyUvGQ,661
+x_transformers-1.34.1.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
+x_transformers-1.34.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.34.1.dist-info/RECORD,,
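For reference, each RECORD entry has the form path,sha256=<digest>,<size>, where the digest is the urlsafe-base64-encoded sha256 of the file bytes with trailing '=' padding stripped (per the wheel spec, PEP 427), and the final field is the file size in bytes. A small sketch of how an entry such as the attend.py line above could be recomputed from an unpacked wheel; record_entry is an illustrative helper, not part of any packaging tool:

import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    # sha256 the file bytes, urlsafe-base64 encode, strip '=' padding (wheel spec)
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'{path},sha256={digest},{len(data)}'

# run from the root of the unpacked 1.34.1 wheel; should reproduce the line
# x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
print(record_entry('x_transformers/attend.py'))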
{x_transformers-1.34.0.dist-info → x_transformers-1.34.1.dist-info}/LICENSE

File without changes

{x_transformers-1.34.0.dist-info → x_transformers-1.34.1.dist-info}/top_level.txt

File without changes