x-transformers 2.3.22__py3-none-any.whl → 2.3.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -276,21 +276,22 @@ class Attend(Module):
 
         # torch 2.3 uses new backend and context manager
 
-        if torch_version >= version.parse('2.3'):
-            from torch.nn.attention import SDPBackend
-
-            str_to_backend = dict(
-                enable_flash = SDPBackend.FLASH_ATTENTION,
-                enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
-                enable_math = SDPBackend.MATH,
-                enable_cudnn = SDPBackend.CUDNN_ATTENTION
-            )
+        if self.flash:
+            if torch_version >= version.parse('2.3'):
+                from torch.nn.attention import SDPBackend
 
-            sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+                str_to_backend = dict(
+                    enable_flash = SDPBackend.FLASH_ATTENTION,
+                    enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+                    enable_math = SDPBackend.MATH,
+                    enable_cudnn = SDPBackend.CUDNN_ATTENTION
+                )
 
-            self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
-        else:
-            self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
+                sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+
+                self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+            else:
+                self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
 
     def flash_attn(
         self,
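
The functional change above is the new `if self.flash:` guard: the SDP backend mapping and kernel-selection context manager are now built only when flash attention is enabled, rather than unconditionally. A minimal sketch of how such an `sdpa_kernel`-based context manager is used around PyTorch's scaled-dot-product attention (torch >= 2.3 assumed; the tensor shapes and `sdp_kwargs` flags below are illustrative, not taken from the package):

# Sketch only: mirrors the backend selection in the diff above, then runs
# attention inside the resulting context manager. Requires torch >= 2.3;
# SDPBackend.CUDNN_ATTENTION availability depends on the torch build.

from functools import partial

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

sdp_kwargs = dict(                # illustrative flags, keys as in the diff
    enable_flash = True,
    enable_mem_efficient = True,
    enable_math = True,
    enable_cudnn = True
)

str_to_backend = dict(
    enable_flash = SDPBackend.FLASH_ATTENTION,
    enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
    enable_math = SDPBackend.MATH,
    enable_cudnn = SDPBackend.CUDNN_ATTENTION
)

# keep only the enabled backends, as the diff does
sdpa_backends = [str_to_backend[k] for k, enabled in sdp_kwargs.items() if enabled]
sdp_context_manager = partial(sdpa_kernel, sdpa_backends)

q = k = v = torch.randn(1, 8, 128, 64)   # (batch, heads, seq_len, dim_head)

with sdp_context_manager():
    out = F.scaled_dot_product_attention(q, k, v)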
{x_transformers-2.3.22.dist-info → x_transformers-2.3.23.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.22
+Version: 2.3.23
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.3.22.dist-info → x_transformers-2.3.23.dist-info}/RECORD RENAMED
@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
-x_transformers/attend.py,sha256=fXMuwHuBAFB4f4_U6j5_uVeK7N4cV0PDd6UTqtkjKKM,17333
+x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
 x_transformers/autoregressive_wrapper.py,sha256=BWFaO-3YWzCcEfp-EC1ZkdckqDpPIOQG6_uyyP6AmhM,11753
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=7phSZvP1_SDRIkVMwVR4cz1dFU2UlR2Wf1HJHEQlcQg,116222
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.22.dist-info/METADATA,sha256=_8m8ftpHRKjbEUDuoeYPcVh4yan1FxNRj3seJwiZzl8,89897
-x_transformers-2.3.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.3.22.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.3.22.dist-info/RECORD,,
+x_transformers-2.3.23.dist-info/METADATA,sha256=xRMZP1TSYdcbc0F5GX-WcaHhAbQPdGeFIbjHBZYG9_0,89897
+x_transformers-2.3.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.23.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.23.dist-info/RECORD,,
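
For reference, each RECORD row is `path,sha256=<digest>,<size-in-bytes>`, where the digest is the urlsafe-base64 encoding of the file's raw sha256 with the `=` padding stripped (per the wheel spec); the attend.py row above thus also records the file growing from 17333 to 17404 bytes. A small sketch of how one such entry could be checked against an unpacked wheel (the local path is an assumption):

# Sketch: recompute a wheel RECORD digest for one file and compare it to
# the value in the RECORD diff above. The path is illustrative.

import base64
import hashlib
from pathlib import Path

def record_digest(path):
    # wheel RECORD digests: urlsafe base64 of the raw sha256, '=' padding stripped
    raw = hashlib.sha256(Path(path).read_bytes()).digest()
    return base64.urlsafe_b64encode(raw).rstrip(b'=').decode()

assert record_digest('x_transformers/attend.py') == 'Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA'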