x-transformers 1.34.0__py3-none-any.whl → 1.34.1__py3-none-any.whl

This diff shows the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -138,9 +138,27 @@ class Attend(Module):
         # flash attention

         self.flash = flash
-        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

-        self.sdp_kwargs = sdp_kwargs
+        torch_version = version.parse(torch.__version__)
+        assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+        # torch 2.3 uses new backend and context manager
+
+        if torch_version >= version.parse('2.3'):
+            from torch.nn.attention import SDPBackend
+
+            str_to_backend = dict(
+                enable_flash = SDPBackend.FLASH_ATTENTION,
+                enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+                enable_math = SDPBackend.MATH,
+                enable_cudnn = SDPBackend.CUDNN_ATTENTION
+            )
+
+            sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+
+            self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+        else:
+            self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)

     def flash_attn(
         self,
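For context on the hunk above: torch 2.3 deprecated the boolean-flag context manager torch.backends.cuda.sdp_kernel in favor of torch.nn.attention.sdpa_kernel, which instead takes an allow-list of SDPBackend members. Below is a minimal standalone sketch of the version gate this release introduces; the sdp_kwargs values are illustrative assumptions, not copied from the package:

    from functools import partial

    import torch
    from packaging import version

    # illustrative backend flags (assumption, not the package's actual defaults)
    sdp_kwargs = dict(enable_flash = True, enable_math = True, enable_mem_efficient = True)

    if version.parse(torch.__version__) >= version.parse('2.3'):
        # new style: map each boolean flag to its backend enum member and
        # pass the enabled ones as a list
        from torch.nn.attention import SDPBackend, sdpa_kernel

        str_to_backend = dict(
            enable_flash = SDPBackend.FLASH_ATTENTION,
            enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
            enable_math = SDPBackend.MATH
        )

        backends = [str_to_backend[name] for name, enabled in sdp_kwargs.items() if enabled]
        sdp_context_manager = partial(sdpa_kernel, backends)
    else:
        # old style: one boolean keyword per backend, deprecated from torch 2.3
        sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)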
@@ -231,7 +249,7 @@ class Attend(Module):

         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

-        with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
+        with self.sdp_context_manager():
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
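Because both branches pre-bind their arguments with functools.partial, the call site stays version-agnostic, which is all the second hunk has to change. A hedged usage sketch continuing from the snippet above; the tensor shapes are arbitrary:

    import torch
    import torch.nn.functional as F

    # (batch, heads, seq_len, dim_head) - arbitrary illustrative shapes
    q = k = v = torch.randn(2, 8, 128, 64)

    # sdp_context_manager comes from the sketch after the first hunk
    with sdp_context_manager():
        out = F.scaled_dot_product_attention(q, k, v, is_causal = True)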
x_transformers-1.34.0.dist-info/METADATA → x_transformers-1.34.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.34.0
+Version: 1.34.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.34.0.dist-info/RECORD → x_transformers-1.34.1.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
+x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
 x_transformers/autoregressive_wrapper.py,sha256=ka_iiej5lEBOcbutWQgGrFVMDilz2PFWzLhBh5_tmmg,10366
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
@@ -8,8 +8,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T
 x_transformers/x_transformers.py,sha256=hs9j-lHukVGYLlpbBhn4CZhSzI7s0x6bYtEhCc33ftE,78680
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.34.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.34.0.dist-info/METADATA,sha256=aTRBJepYjojT5TFi8W2oK4j7daQGRQaWwj2HHnnwDCQ,661
-x_transformers-1.34.0.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
-x_transformers-1.34.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.34.0.dist-info/RECORD,,
+x_transformers-1.34.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.34.1.dist-info/METADATA,sha256=jSsnjS0ptrIpH-nc9h7fNMjzAvpmQGOkXYqTSWyUvGQ,661
+x_transformers-1.34.1.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
+x_transformers-1.34.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.34.1.dist-info/RECORD,,
x_transformers-1.34.0.dist-info/WHEEL → x_transformers-1.34.1.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (74.0.0)
+Generator: setuptools (74.1.2)
 Root-Is-Purelib: true
 Tag: py3-none-any
