x-transformers 1.32.15__py3-none-any.whl → 1.34.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +21 -3
- x_transformers/x_transformers.py +3 -3
- {x_transformers-1.32.15.dist-info → x_transformers-1.34.1.dist-info}/METADATA +1 -1
- {x_transformers-1.32.15.dist-info → x_transformers-1.34.1.dist-info}/RECORD +7 -7
- {x_transformers-1.32.15.dist-info → x_transformers-1.34.1.dist-info}/WHEEL +1 -1
- {x_transformers-1.32.15.dist-info → x_transformers-1.34.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.15.dist-info → x_transformers-1.34.1.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
```diff
@@ -138,9 +138,27 @@ class Attend(Module):
         # flash attention

         self.flash = flash
-        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

-
+        torch_version = version.parse(torch.__version__)
+        assert not (flash and torch_version < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+        # torch 2.3 uses new backend and context manager
+
+        if torch_version >= version.parse('2.3'):
+            from torch.nn.attention import SDPBackend
+
+            str_to_backend = dict(
+                enable_flash = SDPBackend.FLASH_ATTENTION,
+                enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+                enable_math = SDPBackend.MATH,
+                enable_cudnn = SDPBackend.CUDNN_ATTENTION
+            )
+
+            sdpa_backends = [str_to_backend[enable_str] for enable_str, enable in sdp_kwargs.items() if enable]
+
+            self.sdp_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+        else:
+            self.sdp_context_manager = partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)

     def flash_attn(
         self,
```
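The constructor change above gates backend selection on the PyTorch version: on 2.3+ the boolean `sdp_kwargs` flags are translated into `torch.nn.attention.SDPBackend` members and handed to `torch.nn.attention.sdpa_kernel`, otherwise the older `torch.backends.cuda.sdp_kernel` keyword API is kept. Below is a minimal standalone sketch of the same idea; `make_sdp_context_manager` and its default flags are hypothetical illustrations rather than library code, and the defensive `getattr` is an extra precaution because the cuDNN enum member may not be present on every 2.3.x build.

```python
from functools import partial

import torch
from packaging import version


def make_sdp_context_manager(sdp_kwargs = dict(
    enable_flash = True,
    enable_math = True,
    enable_mem_efficient = True
)):
    # hypothetical helper mirroring the version-gated logic in the diff
    torch_version = version.parse(torch.__version__)

    if torch_version >= version.parse('2.3'):
        # torch >= 2.3: pick backends through the SDPBackend enum and sdpa_kernel
        from torch.nn.attention import SDPBackend

        str_to_backend = dict(
            enable_flash = SDPBackend.FLASH_ATTENTION,
            enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
            enable_math = SDPBackend.MATH,
            # cuDNN attention is only exposed on newer builds, so look it up defensively
            enable_cudnn = getattr(SDPBackend, 'CUDNN_ATTENTION', None)
        )

        backends = [
            str_to_backend[name]
            for name, enabled in sdp_kwargs.items()
            if enabled and str_to_backend.get(name) is not None
        ]
        return partial(torch.nn.attention.sdpa_kernel, backends)

    # older torch: boolean keyword flags on the (now deprecated) context manager
    return partial(torch.backends.cuda.sdp_kernel, **sdp_kwargs)
```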
```diff
@@ -231,7 +249,7 @@ class Attend(Module):

         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

-        with
+        with self.sdp_context_manager():
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
```
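At the call site in `flash_attn`, the inline `with` block is replaced by the prebuilt `self.sdp_context_manager()`, so the attention call itself no longer needs to know which PyTorch API is in use. A rough usage sketch, assuming a CUDA device and reusing the hypothetical helper from the previous sketch:

```python
import torch
import torch.nn.functional as F

sdp_context_manager = make_sdp_context_manager()  # hypothetical helper sketched above

# (batch, heads, seq, dim_head) half-precision tensors, purely illustrative shapes
q = torch.randn(2, 8, 1024, 64, device = 'cuda', dtype = torch.float16)
k = torch.randn(2, 8, 1024, 64, device = 'cuda', dtype = torch.float16)
v = torch.randn(2, 8, 1024, 64, device = 'cuda', dtype = torch.float16)

# restrict which fused kernels scaled_dot_product_attention may dispatch to
with sdp_context_manager():
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
```

Both branches produce a context manager, so the `with` block at the call site reads the same regardless of the installed PyTorch version.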
x_transformers/x_transformers.py
CHANGED
```diff
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
 from torch.nn import Module, ModuleList, ModuleDict
-from torch.cuda.amp import autocast
+from torch.amp import autocast

 from functools import partial, wraps
 from collections import namedtuple
```
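The import swap reflects PyTorch consolidating autocast under the device-agnostic `torch.amp` namespace, where the device type is passed as the first argument instead of being implied by the module path. A small illustrative sketch, assuming a CUDA device:

```python
import torch
from torch.amp import autocast

x = torch.randn(4, 4, device = 'cuda')

# new-style autocast: the device type is explicit
with autocast('cuda', dtype = torch.float16):
    y = x @ x  # matmul runs in float16 under autocast

# the older torch.cuda.amp.autocast(enabled = ...) spelling behaves the same
# but is CUDA-specific and deprecated in recent PyTorch releases
```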
```diff
@@ -521,7 +521,7 @@ class RotaryEmbedding(Module):
         t = torch.arange(seq_len, device = device)
         return self.forward(t)

-    @autocast(enabled = False)
+    @autocast('cuda', enabled = False)
     def forward(self, t):
         max_pos = t.max() + 1

```
```diff
@@ -545,7 +545,7 @@ def rotate_half(x):
     x = torch.stack((-x2, x1), dim = -1)
     return rearrange(x, '... d r -> ... (d r)')

-@autocast(enabled = False)
+@autocast('cuda', enabled = False)
 def apply_rotary_pos_emb(t, freqs, scale = 1):
     rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype

```
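Both rotary-embedding call sites keep autocast disabled, only switching to the new explicit device argument, so the frequency math stays in float32 even when the surrounding forward pass runs under mixed precision. A hedged sketch of that pattern; `toy_rotary_freqs` is a simplified stand-in rather than the library's implementation, and it assumes a CUDA device:

```python
import torch
from torch.amp import autocast


@autocast('cuda', enabled = False)   # opt out of mixed precision inside this function
def toy_rotary_freqs(seq_len, dim, theta = 10000, device = 'cuda'):
    # standard rotary frequency construction, kept in float32 for numerical stability
    inv_freq = 1. / (theta ** (torch.arange(0, dim, 2, device = device).float() / dim))
    t = torch.arange(seq_len, device = device).float()
    return torch.einsum('i, j -> i j', t, inv_freq)


with autocast('cuda', dtype = torch.float16):
    freqs = toy_rotary_freqs(1024, 64)
    assert freqs.dtype == torch.float32   # unaffected by the surrounding autocast region
```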
{x_transformers-1.32.15.dist-info → x_transformers-1.34.1.dist-info}/RECORD
CHANGED
```diff
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=7q996VGYHGIsc0FQnN8WNiwHn3xny3i1biRwx7yW5vg,12090
 x_transformers/autoregressive_wrapper.py,sha256=ka_iiej5lEBOcbutWQgGrFVMDilz2PFWzLhBh5_tmmg,10366
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=hs9j-lHukVGYLlpbBhn4CZhSzI7s0x6bYtEhCc33ftE,78680
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.34.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.34.1.dist-info/METADATA,sha256=jSsnjS0ptrIpH-nc9h7fNMjzAvpmQGOkXYqTSWyUvGQ,661
+x_transformers-1.34.1.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
+x_transformers-1.34.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.34.1.dist-info/RECORD,,
```
{x_transformers-1.32.15.dist-info → x_transformers-1.34.1.dist-info}/LICENSE
File without changes

{x_transformers-1.32.15.dist-info → x_transformers-1.34.1.dist-info}/top_level.txt
File without changes