x-transformers 2.3.22__py3-none-any.whl → 2.3.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +14 -13
- x_transformers/autoregressive_wrapper.py +1 -1
- {x_transformers-2.3.22.dist-info → x_transformers-2.3.24.dist-info}/METADATA +1 -1
- {x_transformers-2.3.22.dist-info → x_transformers-2.3.24.dist-info}/RECORD +6 -6
- {x_transformers-2.3.22.dist-info → x_transformers-2.3.24.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.22.dist-info → x_transformers-2.3.24.dist-info}/licenses/LICENSE +0 -0
x_transformers/attend.py
CHANGED
@@ -276,21 +276,22 @@ class Attend(Module):
 
         # torch 2.3 uses new backend and context manager
 
-        if torch_version >= version.parse('2.3'):
-            from torch.nn.attention import SDPBackend
-
-            str_to_backend = dict(
-                enable_flash = SDPBackend.FLASH_ATTENTION,
-                enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
-                enable_math = SDPBackend.MATH,
-                enable_cudnn = SDPBackend.CUDNN_ATTENTION
-            )
+        if self.flash:
+            if torch_version >= version.parse('2.3'):
+                from torch.nn.attention import SDPBackend
 
-            sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+                str_to_backend = dict(
+                    enable_flash = SDPBackend.FLASH_ATTENTION,
+                    enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+                    enable_math = SDPBackend.MATH,
+                    enable_cudnn = SDPBackend.CUDNN_ATTENTION
+                )
 
-            self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
-        else:
-            self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
+                sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+
+                self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+            else:
+                self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
 
     def flash_attn(
         self,
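In 2.3.24 the backend selection above is wrapped in `if self.flash:`, so the SDP kernel context manager is only built when flash attention is actually enabled. As a minimal sketch of what that context manager does on torch >= 2.3 (the tensor shapes and the allowed-backend list below are illustrative, not taken from the library):

# Minimal sketch, assuming torch >= 2.3; shapes and the allowed-backend list
# are illustrative placeholders, not x-transformers internals.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, dim_head)

# only these kernels may be dispatched inside the block; MATH keeps a CPU-friendly fallback
allowed = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

with sdpa_kernel(allowed):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 128, 64])

On torch < 2.3 the else branch keeps the older torch.backends.cuda.sdp_kernel(**sdp_kwargs) context manager, which expresses the same allow-list through boolean keyword arguments.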
x_transformers/autoregressive_wrapper.py
CHANGED
@@ -324,7 +324,7 @@ class AutoregressiveWrapper(Module):
         kwargs.update(self_attn_kv_mask = mask)
 
         out, cache = self.net(
-
+            inp,
             return_intermediates = True,
             return_attn_z_loss = add_attn_z_loss,
             return_next_embed_pred = add_next_embed_loss,
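For background, autoregressive wrappers like this typically train with teacher forcing: the network receives the sequence shifted left by one position and is scored against the next token, and the new line passes that shifted input (`inp`) to `self.net`. A minimal sketch of the shift itself, with illustrative tensor names and sizes rather than the wrapper's internals:

# Minimal sketch of the input/target shift used in teacher-forced training.
# `seq`, `inp`, `target` and the vocab size of 256 are illustrative placeholders.
import torch
import torch.nn.functional as F

seq = torch.randint(0, 256, (2, 10))    # (batch, seq_len) of token ids
inp, target = seq[:, :-1], seq[:, 1:]   # model sees tokens 0..n-1, predicts 1..n

logits = torch.randn(2, 9, 256)         # stand-in for the network's output over inp
loss = F.cross_entropy(logits.transpose(1, 2), target)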
{x_transformers-2.3.22.dist-info → x_transformers-2.3.24.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
-x_transformers/attend.py,sha256=
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
+x_transformers/autoregressive_wrapper.py,sha256=tMVbIC8iXTpfGDxRhPqqHTulkxB8aZqNML77WbGhfac,11755
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=7phSZvP1_SDRIkVMwVR4cz1dFU2UlR2Wf1HJHEQlcQg,116222
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.22.dist-info/METADATA,sha256=
-x_transformers-2.3.22.dist-info/WHEEL,sha256=
-x_transformers-2.3.22.dist-info/licenses/LICENSE,sha256=
-x_transformers-2.3.22.dist-info/RECORD,,
+x_transformers-2.3.24.dist-info/METADATA,sha256=vqW6_PFF3JiQirofvdzEMXwAt_x9luG_TOQimtIWwg8,89897
+x_transformers-2.3.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.24.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.24.dist-info/RECORD,,
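Each RECORD line follows the wheel format: the file path, then sha256= followed by the urlsafe-base64 digest with its trailing padding stripped, then the file size in bytes. A small sketch of how such an entry can be recomputed from an unpacked wheel (point it at any file path):

# Recompute a RECORD-style entry: path,sha256=<urlsafe b64 digest, no padding>,<size>
import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    data = Path(path).read_bytes()
    digest = hashlib.sha256(data).digest()
    b64 = base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')
    return f"{path},sha256={b64},{len(data)}"

# Run against attend.py from the unpacked 2.3.24 wheel, this should reproduce the
# updated entry above:
# x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404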
{x_transformers-2.3.22.dist-info → x_transformers-2.3.24.dist-info}/WHEEL
File without changes

{x_transformers-2.3.22.dist-info → x_transformers-2.3.24.dist-info}/licenses/LICENSE
File without changes