mps-flash-attn 0.1.4.tar.gz → 0.1.6.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of mps-flash-attn has been flagged by the registry as a potentially problematic release.
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/PKG-INFO +1 -1
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/__init__.py +5 -2
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/csrc/mps_flash_attn.mm +57 -15
- mps_flash_attn-0.1.6/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/pyproject.toml +1 -1
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/setup.py +2 -2
- mps_flash_attn-0.1.4/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/LICENSE +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/README.md +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/setup.cfg +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/tests/test_attention.py +0 -0
mps_flash_attn/__init__.py:

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.1.4"
+__version__ = "0.1.6"
 
 import torch
 from typing import Optional
@@ -174,10 +174,13 @@ def replace_sdpa():
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None):
         # Use MFA for MPS tensors without mask/dropout
+        # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+        # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
         if (query.device.type == 'mps' and
             attn_mask is None and
             dropout_p == 0.0 and
-            _HAS_MFA):
+            _HAS_MFA and
+            query.shape[2] >= 1024):
            try:
                return flash_attention(query, key, value, is_causal=is_causal, scale=scale)
            except Exception:
mps_flash_attn/csrc/mps_flash_attn.mm:

@@ -18,7 +18,13 @@
 // ============================================================================
 
 typedef bool (*mfa_init_fn)();
-typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool);
+typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool);
+// New zero-sync encode functions that take PyTorch's command encoder
+typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
+typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                       int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                       int32_t, int32_t);
+// Legacy sync functions (fallback)
 typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
 typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
                                 int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
@@ -28,6 +34,8 @@ typedef void (*mfa_release_kernel_fn)(void*);
 // Global function pointers
 static mfa_init_fn g_mfa_init = nullptr;
 static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
+static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
+static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
 static mfa_forward_fn g_mfa_forward = nullptr;
 static mfa_backward_fn g_mfa_backward = nullptr;
 static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
@@ -71,11 +79,14 @@ static bool load_mfa_bridge() {
     // Load function pointers
     g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
     g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
+    g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
+    g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
     g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
     g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
     g_mfa_release_kernel = (mfa_release_kernel_fn)dlsym(g_dylib_handle, "mfa_release_kernel");
 
-
+    // Require at least init, create_kernel, forward_encode (for zero-sync path)
+    if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward_encode) {
         throw std::runtime_error("Failed to load MFA bridge functions");
     }
 
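As an aside, the dlopen/dlsym pattern above can be mimicked from Python to sanity-check that a given libMFABridge.dylib build exports the new symbols. This is purely illustrative and not part of the package; the symbol names and the argument layout of mfa_create_kernel are taken from the typedefs in the hunk further up, and the dylib path assumes the wheel layout shown in the file list (mps_flash_attn/lib/).

import ctypes
import os
import mps_flash_attn

# Locate the bundled bridge library (shipped under mps_flash_attn/lib/ in the wheel).
lib_path = os.path.join(os.path.dirname(mps_flash_attn.__file__), "lib", "libMFABridge.dylib")
bridge = ctypes.CDLL(lib_path)

bridge.mfa_init.restype = ctypes.c_bool
# Mirrors: void* mfa_create_kernel(int32_t, int32_t, int32_t, bool, bool, bool)
bridge.mfa_create_kernel.restype = ctypes.c_void_p
bridge.mfa_create_kernel.argtypes = [ctypes.c_int32] * 3 + [ctypes.c_bool] * 3

assert bridge.mfa_init()
# The zero-sync entry points need a live Metal command encoder, so we only
# verify that the symbols resolve rather than calling them from Python.
assert hasattr(bridge, "mfa_forward_encode") and hasattr(bridge, "mfa_backward_encode")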
@@ -193,17 +204,37 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");
 
     const int64_t batch_size = query.size(0);
-    const int64_t num_heads = query.size(1);
+    const int64_t num_heads_q = query.size(1);
+    const int64_t num_heads_kv = key.size(1);
     const int64_t seq_len_q = query.size(2);
     const int64_t head_dim = query.size(3);
     const int64_t seq_len_kv = key.size(2);
 
     TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
                 "Batch size mismatch");
-    TORCH_CHECK(key.size(1) == num_heads && value.size(1) == num_heads,
-                "Number of heads mismatch");
     TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
                 "Head dimension mismatch");
+    TORCH_CHECK(key.size(1) == value.size(1),
+                "K and V must have same number of heads");
+
+    // Handle GQA (Grouped Query Attention): expand K/V if fewer heads than Q
+    const int64_t num_heads = num_heads_q;
+    at::Tensor k_expanded, v_expanded;
+
+    if (num_heads_kv != num_heads_q) {
+        // GQA: num_heads_q must be divisible by num_heads_kv
+        TORCH_CHECK(num_heads_q % num_heads_kv == 0,
+                    "num_heads_q (", num_heads_q, ") must be divisible by num_heads_kv (", num_heads_kv, ")");
+        int64_t repeat_factor = num_heads_q / num_heads_kv;
+
+        // Expand K and V to match Q's head count: (B, H_kv, S, D) -> (B, H_q, S, D)
+        // Use repeat_interleave for proper GQA expansion
+        k_expanded = key.repeat_interleave(repeat_factor, /*dim=*/1);
+        v_expanded = value.repeat_interleave(repeat_factor, /*dim=*/1);
+    } else {
+        k_expanded = key;
+        v_expanded = value;
+    }
 
     // Determine precision
     bool low_precision = (query.scalar_type() == at::kHalf ||
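A small, self-contained illustration of the expansion step (plain PyTorch, independent of this extension): repeat_interleave along the head dimension repeats each KV head consecutively, so each group of query heads attends to its own copy of the corresponding KV head.

import torch

B, H_q, H_kv, S, D = 1, 8, 2, 16, 64
repeat_factor = H_q // H_kv  # 4; H_q must be divisible by H_kv

key = torch.randn(B, H_kv, S, D)
value = torch.randn(B, H_kv, S, D)

# (B, H_kv, S, D) -> (B, H_q, S, D)
k_expanded = key.repeat_interleave(repeat_factor, dim=1)
v_expanded = value.repeat_interleave(repeat_factor, dim=1)

assert k_expanded.shape == (B, H_q, S, D)
# Heads 0..3 of k_expanded are copies of KV head 0, heads 4..7 of KV head 1
assert torch.equal(k_expanded[:, 0], key[:, 0]) and torch.equal(k_expanded[:, 3], key[:, 0])
assert torch.equal(k_expanded[:, 4], key[:, 1])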
@@ -214,8 +245,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
 
     // Make inputs contiguous
     auto q = query.contiguous();
-    auto k = key.contiguous();
-    auto v = value.contiguous();
+    auto k = k_expanded.contiguous();
+    auto v = v_expanded.contiguous();
 
     // Allocate output in the appropriate precision
     // With lowPrecisionOutputs=true, MFA writes FP16 directly
@@ -242,15 +273,18 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     auto o_info = getBufferInfo(output);
     auto l_info = getBufferInfo(logsumexp);
 
-    //
+    // Use PyTorch's MPS stream command encoder for zero-sync integration
     @autoreleasepool {
-        // Wait for PyTorch operations to complete
         auto stream = at::mps::getCurrentMPSStream();
-        stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
 
-        //
-        bool success = g_mfa_forward(
+        // Get PyTorch's shared command encoder - this is the key for zero-sync!
+        // All our dispatches go onto the same encoder that PyTorch uses.
+        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+        // Execute MFA using the shared encoder (no sync needed!)
+        bool success = g_mfa_forward_encode(
             kernel,
+            (__bridge void*)encoder,  // PyTorch's shared command encoder
             (__bridge void*)q_info.buffer,
             (__bridge void*)k_info.buffer,
             (__bridge void*)v_info.buffer,
@@ -268,6 +302,9 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         if (!success) {
             throw std::runtime_error("MFA forward pass failed");
         }
+
+        // No commit needed - PyTorch will commit when it needs the results
+        // The encoder stays open for coalescing more kernels
     }
 
     // Output is already in the correct dtype (fp16 or fp32)
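Because the kernels are now encoded onto PyTorch's own MPS command stream without an intermediate commit-and-wait, results materialize lazily like any other MPS op. A hedged usage sketch (again assuming replace_sdpa() is the public entry point): force completion only when the result is actually needed on the CPU, for example before timing.

import torch
import torch.nn.functional as F
import mps_flash_attn

mps_flash_attn.replace_sdpa()

q = torch.randn(2, 8, 4096, 64, device="mps", dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device="mps", dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device="mps", dtype=torch.float16)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # encoded, not yet forced

# PyTorch commits the shared command buffer when the value is needed;
# synchronize explicitly only for benchmarking or CPU readbacks.
torch.mps.synchronize()
print(out.float().abs().mean().item())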
@@ -362,13 +399,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
     auto dk_info = getBufferInfo(dK);
     auto dv_info = getBufferInfo(dV);
 
-    //
+    // Use PyTorch's MPS stream command encoder for zero-sync integration
     @autoreleasepool {
         auto stream = at::mps::getCurrentMPSStream();
-        stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
 
-        bool success = g_mfa_backward(
+        // Get PyTorch's shared command encoder
+        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+        bool success = g_mfa_backward_encode(
             kernel,
+            (__bridge void*)encoder,  // PyTorch's shared command encoder
             (__bridge void*)q_info.buffer,
             (__bridge void*)k_info.buffer,
             (__bridge void*)v_info.buffer,
@@ -396,6 +436,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
         if (!success) {
             throw std::runtime_error("MFA backward pass failed");
        }
+
+        // No commit needed - PyTorch will commit when it needs the results
     }
 
     // Convert gradients back to input dtype if needed
Binary file (libMFABridge.dylib)
setup.py:

@@ -44,14 +44,14 @@ def get_extensions():
     return [Extension(
         name="mps_flash_attn._C",
         sources=["mps_flash_attn/csrc/mps_flash_attn.mm"],
-        extra_compile_args=["-std=c++17", "-O3"],
+        extra_compile_args=["-std=c++17", "-O3", "-DTORCH_EXTENSION_NAME=_C"],
         extra_link_args=["-framework", "Metal", "-framework", "Foundation"],
     )]
 
 
 setup(
     name="mps-flash-attn",
-    version="0.1.4",
+    version="0.1.5",
     packages=find_packages(),
     package_data={
         "mps_flash_attn": [
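One note on the new compile flag: torch/extension.h names the pybind11 module after the TORCH_EXTENSION_NAME macro, and torch.utils.cpp_extension normally defines it automatically. Since this setup.py builds with a plain setuptools Extension, the macro is presumably passed by hand so that the module's init symbol matches the extension file name _C. A hypothetical smoke test after building:

# The import only succeeds if the compiled module's init symbol matches its
# filename (mps_flash_attn/_C.*.so), which is what -DTORCH_EXTENSION_NAME=_C ensures.
import importlib

_C = importlib.import_module("mps_flash_attn._C")
print([name for name in dir(_C) if not name.startswith("_")])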
Binary file (libMFABridge.dylib)
File without changes: the remaining 31 files (LICENSE, README.md, kernel .metallib/.bin files, egg-info metadata, setup.cfg, tests/test_attention.py) are identical between 0.1.4 and 0.1.6.