mps-flash-attn 0.1.7__cp314-cp314-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mps-flash-attn might be problematic.
- mps_flash_attn/_C.cpython-314-darwin.so +0 -0
- mps_flash_attn/__init__.py +264 -0
- mps_flash_attn/csrc/mps_flash_attn.mm +569 -0
- mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- mps_flash_attn/kernels/manifest.json +27 -0
- mps_flash_attn-0.1.7.dist-info/METADATA +270 -0
- mps_flash_attn-0.1.7.dist-info/RECORD +31 -0
- mps_flash_attn-0.1.7.dist-info/WHEEL +5 -0
- mps_flash_attn-0.1.7.dist-info/licenses/LICENSE +27 -0
- mps_flash_attn-0.1.7.dist-info/top_level.txt +1 -0
mps_flash_attn/_C.cpython-314-darwin.so
ADDED

Binary file

mps_flash_attn/__init__.py
ADDED

@@ -0,0 +1,264 @@
"""
MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon

This package provides memory-efficient attention using Metal Flash Attention kernels.
"""

__version__ = "0.1.7"

import torch
from typing import Optional
import math
import threading
import os

# Try to import the C++ extension
try:
    from . import _C
    _HAS_MFA = True
except ImportError as e:
    _HAS_MFA = False
    _IMPORT_ERROR = str(e)

# Note: The C++ extension handles loading libMFABridge.dylib via dlopen.
# Set MFA_BRIDGE_PATH environment variable to specify the library location.
# Do NOT load the library here via ctypes - that causes duplicate class warnings.


def is_available() -> bool:
    """Check if MPS Flash Attention is available."""
    return _HAS_MFA and torch.backends.mps.is_available()


class FlashAttentionFunction(torch.autograd.Function):
    """Autograd function for Flash Attention with backward pass support."""

    @staticmethod
    def forward(ctx, query, key, value, is_causal, scale, attn_mask):
        # Apply scale if provided (MFA uses 1/sqrt(D) internally)
        scale_factor = 1.0
        if scale is not None:
            default_scale = 1.0 / math.sqrt(query.shape[-1])
            if abs(scale - default_scale) > 1e-6:
                scale_factor = scale / default_scale
                query = query * scale_factor

        # Forward with logsumexp for backward
        output, logsumexp = _C.forward_with_lse(query, key, value, is_causal, attn_mask)

        # Save for backward
        if attn_mask is not None:
            ctx.save_for_backward(query, key, value, output, logsumexp, attn_mask)
            ctx.has_mask = True
        else:
            ctx.save_for_backward(query, key, value, output, logsumexp)
            ctx.has_mask = False
        ctx.is_causal = is_causal
        ctx.scale_factor = scale_factor

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.has_mask:
            query, key, value, output, logsumexp, attn_mask = ctx.saved_tensors
        else:
            query, key, value, output, logsumexp = ctx.saved_tensors
            attn_mask = None

        # Compute gradients
        dQ, dK, dV = _C.backward(
            grad_output, query, key, value, output, logsumexp, ctx.is_causal, attn_mask
        )

        # If we scaled the query in forward, scale the gradient back
        if ctx.scale_factor != 1.0:
            dQ = dQ * ctx.scale_factor

        # Return gradients (None for is_causal, scale, and attn_mask since they're not tensors or don't need grad)
        return dQ, dK, dV, None, None, None


def flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute scaled dot-product attention using Flash Attention on MPS.

    This function provides O(N) memory complexity instead of O(N²) by using
    tiled computation, allowing much longer sequences on limited GPU memory.

    Supports both forward and backward passes for training.

    Args:
        query: Query tensor of shape (B, num_heads, seq_len, head_dim)
        key: Key tensor of shape (B, num_heads, seq_len, head_dim)
        value: Value tensor of shape (B, num_heads, seq_len, head_dim)
        is_causal: If True, applies causal masking (for autoregressive models)
        scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
        attn_mask: Optional boolean attention mask of shape (B, 1, seq_len_q, seq_len_kv)
            or (B, num_heads, seq_len_q, seq_len_kv). True values indicate
            positions to be masked (not attended to).

    Returns:
        Output tensor of shape (B, num_heads, seq_len, head_dim)

    Example:
        >>> import torch
        >>> from mps_flash_attn import flash_attention
        >>> q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
        >>> k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
        >>> v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
        >>> out = flash_attention(q, k, v)

        # With gradients:
        >>> q.requires_grad = True
        >>> out = flash_attention(q, k, v)
        >>> out.sum().backward()  # Computes dQ

        # With attention mask:
        >>> mask = torch.zeros(2, 1, 4096, 4096, dtype=torch.bool, device='mps')
        >>> mask[:, :, :, 2048:] = True  # mask out second half of keys
        >>> out = flash_attention(q, k, v, attn_mask=mask)
    """
    if not _HAS_MFA:
        raise RuntimeError(
            f"MPS Flash Attention C++ extension not available: {_IMPORT_ERROR}\n"
            "Please rebuild with: pip install -e ."
        )

    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available")

    # Validate device
    if query.device.type != 'mps':
        raise ValueError("query must be on MPS device")
    if key.device.type != 'mps':
        raise ValueError("key must be on MPS device")
    if value.device.type != 'mps':
        raise ValueError("value must be on MPS device")
    if attn_mask is not None and attn_mask.device.type != 'mps':
        raise ValueError("attn_mask must be on MPS device")

    # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
    if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
        # Apply scale if provided
        if scale is not None:
            default_scale = 1.0 / math.sqrt(query.shape[-1])
            if abs(scale - default_scale) > 1e-6:
                scale_factor = scale / default_scale
                query = query * scale_factor

        # Forward only - no logsumexp needed, no tensors saved
        return _C.forward(query, key, value, is_causal, attn_mask)

    # Use autograd function for gradient support
    return FlashAttentionFunction.apply(query, key, value, is_causal, scale, attn_mask)


def replace_sdpa():
    """
    Monkey-patch torch.nn.functional.scaled_dot_product_attention to use
    Flash Attention on MPS devices.

    Call this at the start of your script to automatically use Flash Attention
    for all attention operations.
    """
    import torch.nn.functional as F

    original_sdpa = F.scaled_dot_product_attention

    def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                     is_causal=False, scale=None):
        # Use MFA for MPS tensors without dropout
        # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
        # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
        if (query.device.type == 'mps' and
            dropout_p == 0.0 and
            _HAS_MFA and
            query.shape[2] >= 1024):
            try:
                # Convert float mask to bool mask if needed
                # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                mfa_mask = None
                if attn_mask is not None:
                    if attn_mask.dtype == torch.bool:
                        # Boolean mask: True means masked (don't attend)
                        mfa_mask = attn_mask
                    else:
                        # Float mask: typically -inf for masked positions, 0 for unmasked
                        # Convert: positions with large negative values -> True (masked)
                        mfa_mask = attn_mask < -1e4
                return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
            except Exception:
                # Fall back to original on any error
                pass

        return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale)

    F.scaled_dot_product_attention = patched_sdpa
    print("MPS Flash Attention: Patched F.scaled_dot_product_attention")


def precompile():
    """
    Pre-compile Metal kernels for common configurations.

    Call this once after installation to eliminate runtime compilation overhead.
    Pre-compiled kernels are cached to disk and loaded instantly on subsequent runs.

    This compiles kernels for:
    - Sequence lengths: 64, 128, 256, 512, 1024, 2048, 4096, 8192
    - Head dimensions: 32, 48, 64, 80, 96, 128
    - Both fp32 and fp16 precision

    Total: 96 kernel configurations
    """
    if not _HAS_MFA:
        raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")

    import ctypes
    import os

    # Load the Swift bridge directly
    bridge_path = os.environ.get("MFA_BRIDGE_PATH")
    if not bridge_path:
        # Try common locations
        module_dir = os.path.dirname(__file__)
        candidates = [
            os.path.join(module_dir, "lib", "libMFABridge.dylib"),  # Bundled in wheel
            os.path.join(module_dir, "..", "swift-bridge", ".build", "release", "libMFABridge.dylib"),
            os.path.join(module_dir, "libMFABridge.dylib"),
        ]
        for path in candidates:
            if os.path.exists(path):
                bridge_path = path
                break

    if not bridge_path or not os.path.exists(bridge_path):
        raise RuntimeError("Cannot find libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.")

    lib = ctypes.CDLL(bridge_path)
    lib.mfa_precompile()
    print("\nPre-compilation complete! Kernels cached to disk.")


def clear_cache():
    """Clear the pre-compiled kernel cache."""
    if not _HAS_MFA:
        raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")

    import ctypes
    import os

    bridge_path = os.environ.get("MFA_BRIDGE_PATH")
    if bridge_path and os.path.exists(bridge_path):
        lib = ctypes.CDLL(bridge_path)
        lib.mfa_clear_cache()
        print("Cache cleared.")

mps_flash_attn/csrc/mps_flash_attn.mm
ADDED
@@ -0,0 +1,569 @@
/**
 * MPS Flash Attention - PyTorch C++ Extension
 *
 * Bridges PyTorch MPS tensors to the MFA Swift library.
 */

#include <torch/extension.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>

#import <Metal/Metal.h>
#import <Foundation/Foundation.h>

#include <dlfcn.h>
#include <string>
#include <vector>

// ============================================================================
// MFA Bridge Function Types
// ============================================================================

typedef bool (*mfa_init_fn)();
typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool);
// New zero-sync encode functions that take PyTorch's command encoder
// Added mask_ptr and mask_offset parameters
typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
                                      int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                      int32_t, int32_t);
typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
                                       int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                       int32_t, int32_t);
// Legacy sync functions (fallback)
typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, void*,
                               int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                               int32_t, int32_t);
typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
                                int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                int32_t, int32_t);
typedef void (*mfa_release_kernel_fn)(void*);

// Global function pointers
static mfa_init_fn g_mfa_init = nullptr;
static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
static mfa_forward_fn g_mfa_forward = nullptr;
static mfa_backward_fn g_mfa_backward = nullptr;
static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
static void* g_dylib_handle = nullptr;
static bool g_initialized = false;

// ============================================================================
// Load MFA Bridge Library
// ============================================================================

// Get the directory containing this shared library
static std::string get_module_dir() {
    Dl_info info;
    if (dladdr((void*)get_module_dir, &info) && info.dli_fname) {
        std::string path(info.dli_fname);
        size_t last_slash = path.rfind('/');
        if (last_slash != std::string::npos) {
            return path.substr(0, last_slash);
        }
    }
    return ".";
}

static bool load_mfa_bridge() {
    if (g_dylib_handle) return true;

    // First check environment variable (highest priority)
    const char* mfa_path = getenv("MFA_BRIDGE_PATH");
    if (mfa_path) {
        g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
        if (g_dylib_handle) return true;
    }

    // Get the directory containing this extension module
    std::string module_dir = get_module_dir();

    // Try paths relative to the module directory
    std::vector<std::string> paths = {
        module_dir + "/lib/libMFABridge.dylib",  // Bundled in wheel
        module_dir + "/../swift-bridge/.build/release/libMFABridge.dylib",  // Dev build
        "libMFABridge.dylib",  // Current directory fallback
    };

    for (const auto& path : paths) {
        g_dylib_handle = dlopen(path.c_str(), RTLD_NOW);
        if (g_dylib_handle) break;
    }

    if (!g_dylib_handle) {
        throw std::runtime_error(
            "Failed to load libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.");
    }

    // Load function pointers
    g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
    g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
    g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
    g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
    g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
    g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
    g_mfa_release_kernel = (mfa_release_kernel_fn)dlsym(g_dylib_handle, "mfa_release_kernel");

    // Require at least init, create_kernel, forward_encode (for zero-sync path)
    if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward_encode) {
        throw std::runtime_error("Failed to load MFA bridge functions");
    }

    return true;
}

// ============================================================================
// Get MTLBuffer from PyTorch MPS Tensor
// ============================================================================

struct BufferInfo {
    id<MTLBuffer> buffer;
    int64_t byte_offset;
};

static BufferInfo getBufferInfo(const at::Tensor& tensor) {
    TORCH_CHECK(tensor.device().is_mps(), "Tensor must be on MPS device");
    TORCH_CHECK(tensor.is_contiguous(), "Tensor must be contiguous");

    // Get the underlying Metal buffer (covers entire storage)
    id<MTLBuffer> buffer = at::native::mps::getMTLBufferStorage(tensor);

    // Calculate byte offset: storage_offset() is in elements, multiply by element size
    int64_t element_size = tensor.element_size();
    int64_t byte_offset = tensor.storage_offset() * element_size;

    return {buffer, byte_offset};
}

// ============================================================================
// Kernel Cache
// ============================================================================

struct KernelCacheKey {
    int64_t seq_len_q;
    int64_t seq_len_kv;
    int64_t head_dim;
    bool low_precision;
    bool low_precision_outputs;
    bool causal;
    bool has_mask;

    bool operator==(const KernelCacheKey& other) const {
        return seq_len_q == other.seq_len_q &&
               seq_len_kv == other.seq_len_kv &&
               head_dim == other.head_dim &&
               low_precision == other.low_precision &&
               low_precision_outputs == other.low_precision_outputs &&
               causal == other.causal &&
               has_mask == other.has_mask;
    }
};

struct KernelCacheKeyHash {
    size_t operator()(const KernelCacheKey& k) const {
        return std::hash<int64_t>()(k.seq_len_q) ^
               (std::hash<int64_t>()(k.seq_len_kv) << 1) ^
               (std::hash<int64_t>()(k.head_dim) << 2) ^
               (std::hash<bool>()(k.low_precision) << 3) ^
               (std::hash<bool>()(k.low_precision_outputs) << 4) ^
               (std::hash<bool>()(k.causal) << 5) ^
               (std::hash<bool>()(k.has_mask) << 6);
    }
};

static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;

static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask) {
    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask};

    auto it = g_kernel_cache.find(key);
    if (it != g_kernel_cache.end()) {
        return it->second;
    }

    void* kernel = g_mfa_create_kernel(
        static_cast<int32_t>(seq_q),
        static_cast<int32_t>(seq_kv),
        static_cast<int32_t>(head_dim),
        low_prec,
        low_prec_outputs,
        causal,
        has_mask
    );

    if (!kernel) {
        throw std::runtime_error("Failed to create MFA kernel");
    }

    g_kernel_cache[key] = kernel;
    return kernel;
}

// ============================================================================
// Flash Attention Forward
// ============================================================================

std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
    const at::Tensor& query,   // (B, H, N, D)
    const at::Tensor& key,     // (B, H, N, D)
    const at::Tensor& value,   // (B, H, N, D)
    bool is_causal,
    const c10::optional<at::Tensor>& attn_mask  // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
) {
    // Initialize MFA on first call
    if (!g_initialized) {
        load_mfa_bridge();
        if (!g_mfa_init()) {
            throw std::runtime_error("Failed to initialize MFA");
        }
        g_initialized = true;
    }

    // Validate inputs
    TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
    TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
    TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
    TORCH_CHECK(query.device().is_mps(), "Query must be on MPS device");
    TORCH_CHECK(key.device().is_mps(), "Key must be on MPS device");
    TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");

    const int64_t batch_size = query.size(0);
    const int64_t num_heads_q = query.size(1);
    const int64_t num_heads_kv = key.size(1);
    const int64_t seq_len_q = query.size(2);
    const int64_t head_dim = query.size(3);
    const int64_t seq_len_kv = key.size(2);

    TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
                "Batch size mismatch");
    TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
                "Head dimension mismatch");
    TORCH_CHECK(key.size(1) == value.size(1),
                "K and V must have same number of heads");

    // Handle GQA (Grouped Query Attention): expand K/V if fewer heads than Q
    const int64_t num_heads = num_heads_q;
    at::Tensor k_expanded, v_expanded;

    if (num_heads_kv != num_heads_q) {
        // GQA: num_heads_q must be divisible by num_heads_kv
        TORCH_CHECK(num_heads_q % num_heads_kv == 0,
                    "num_heads_q (", num_heads_q, ") must be divisible by num_heads_kv (", num_heads_kv, ")");
        int64_t repeat_factor = num_heads_q / num_heads_kv;

        // Expand K and V to match Q's head count: (B, H_kv, S, D) -> (B, H_q, S, D)
        // Use repeat_interleave for proper GQA expansion
        k_expanded = key.repeat_interleave(repeat_factor, /*dim=*/1);
        v_expanded = value.repeat_interleave(repeat_factor, /*dim=*/1);
    } else {
        k_expanded = key;
        v_expanded = value;
    }

    // Determine precision
    bool low_precision = (query.scalar_type() == at::kHalf ||
                          query.scalar_type() == at::kBFloat16);

    // For fp16 inputs, we can now output directly to fp16 (no extra conversion needed!)
    bool low_precision_outputs = low_precision;

    // Make inputs contiguous
    auto q = query.contiguous();
    auto k = k_expanded.contiguous();
    auto v = v_expanded.contiguous();

    // Handle attention mask
    bool has_mask = attn_mask.has_value();
    at::Tensor mask;
    if (has_mask) {
        mask = attn_mask.value();
        TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
        TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
        // Convert to bool/uint8 if needed - kernel expects uchar (0 = attend, non-0 = mask out)
        if (mask.scalar_type() == at::kBool) {
            // Convert bool to uint8 for Metal compatibility
            mask = mask.to(at::kByte);
        }
        TORCH_CHECK(mask.scalar_type() == at::kByte,
                    "Attention mask must be bool or uint8");
        // Expand mask heads if needed (B, 1, N_q, N_kv) -> (B, H, N_q, N_kv)
        if (mask.size(1) == 1 && num_heads > 1) {
            mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
        }
        mask = mask.contiguous();
    }

    // Allocate output in the appropriate precision
    // With lowPrecisionOutputs=true, MFA writes FP16 directly
    at::Tensor output;
    if (low_precision_outputs) {
        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                           query.options().dtype(at::kHalf));
    } else {
        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                           query.options().dtype(at::kFloat));
    }

    // Allocate logsumexp (for backward pass, always fp32)
    auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
                               query.options().dtype(at::kFloat));

    // Get or create kernel with matching output precision and causal mode
    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask);

    // Get Metal buffers with byte offsets
    auto q_info = getBufferInfo(q);
    auto k_info = getBufferInfo(k);
    auto v_info = getBufferInfo(v);
    auto o_info = getBufferInfo(output);
    auto l_info = getBufferInfo(logsumexp);

    // Mask buffer info (may be nullptr if no mask)
    BufferInfo mask_info = {nil, 0};
    if (has_mask) {
        mask_info = getBufferInfo(mask);
    }

    // Use PyTorch's MPS stream command encoder for zero-sync integration
    @autoreleasepool {
        auto stream = at::mps::getCurrentMPSStream();

        // Get PyTorch's shared command encoder - this is the key for zero-sync!
        // All our dispatches go onto the same encoder that PyTorch uses.
        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();

        // Execute MFA using the shared encoder (no sync needed!)
        bool success = g_mfa_forward_encode(
            kernel,
            (__bridge void*)encoder,  // PyTorch's shared command encoder
            (__bridge void*)q_info.buffer,
            (__bridge void*)k_info.buffer,
            (__bridge void*)v_info.buffer,
            (__bridge void*)o_info.buffer,
            (__bridge void*)l_info.buffer,
            has_mask ? (__bridge void*)mask_info.buffer : nullptr,
            q_info.byte_offset,
            k_info.byte_offset,
            v_info.byte_offset,
            o_info.byte_offset,
            l_info.byte_offset,
            mask_info.byte_offset,
            static_cast<int32_t>(batch_size),
            static_cast<int32_t>(num_heads)
        );

        if (!success) {
            throw std::runtime_error("MFA forward pass failed");
        }

        // No commit needed - PyTorch will commit when it needs the results
        // The encoder stays open for coalescing more kernels
    }

    // Output is already in the correct dtype (fp16 or fp32)
    // Return both output and logsumexp (needed for backward pass)
    return std::make_tuple(output, logsumexp);
}

// Simple forward that only returns output (for inference)
at::Tensor mps_flash_attention_forward(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    bool is_causal,
    const c10::optional<at::Tensor>& attn_mask
) {
    auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal, attn_mask);
    return output;
}

// ============================================================================
// Flash Attention Backward
// ============================================================================

std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
    const at::Tensor& grad_output,  // (B, H, N, D)
    const at::Tensor& query,        // (B, H, N, D)
    const at::Tensor& key,          // (B, H, N, D)
    const at::Tensor& value,        // (B, H, N, D)
    const at::Tensor& output,       // (B, H, N, D)
    const at::Tensor& logsumexp,    // (B, H, N)
    bool is_causal,
    const c10::optional<at::Tensor>& attn_mask  // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
) {
    // Initialize MFA on first call
    if (!g_initialized) {
        load_mfa_bridge();
        if (!g_mfa_init()) {
            throw std::runtime_error("Failed to initialize MFA");
        }
        g_initialized = true;
    }

    // Validate inputs
    TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
    TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
    TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
    TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
    TORCH_CHECK(output.dim() == 4, "Output must be 4D (B, H, N, D)");
    TORCH_CHECK(logsumexp.dim() == 3, "Logsumexp must be 3D (B, H, N)");

    const int64_t batch_size = query.size(0);
    const int64_t num_heads = query.size(1);
    const int64_t seq_len_q = query.size(2);
    const int64_t head_dim = query.size(3);
    const int64_t seq_len_kv = key.size(2);

    // Determine precision
    bool low_precision = (query.scalar_type() == at::kHalf ||
                          query.scalar_type() == at::kBFloat16);
    bool low_precision_outputs = low_precision;

    // Handle attention mask
    bool has_mask = attn_mask.has_value();
    at::Tensor mask;
    if (has_mask) {
        mask = attn_mask.value();
        TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
        TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
        if (mask.scalar_type() == at::kBool) {
            mask = mask.to(at::kByte);
        }
        TORCH_CHECK(mask.scalar_type() == at::kByte,
                    "Attention mask must be bool or uint8");
        if (mask.size(1) == 1 && num_heads > 1) {
            mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
        }
        mask = mask.contiguous();
    }

    // Make inputs contiguous and upcast to fp32 for numerical stability
    // The backward pass accumulates many small values, so fp32 precision is critical
    auto q = query.contiguous().to(at::kFloat);
    auto k = key.contiguous().to(at::kFloat);
    auto v = value.contiguous().to(at::kFloat);
    auto o = output.contiguous().to(at::kFloat);
    auto dO = grad_output.contiguous().to(at::kFloat);
    auto lse = logsumexp.contiguous();

    // Get or create kernel - always use fp32 for backward pass
    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, false, false, is_causal, has_mask);

    // Allocate D buffer (dO * O reduction, always fp32)
    auto D = at::empty({batch_size, num_heads, seq_len_q},
                       query.options().dtype(at::kFloat));

    // Allocate gradients (always fp32 for numerical stability)
    auto dQ = at::zeros({batch_size, num_heads, seq_len_q, head_dim},
                        query.options().dtype(at::kFloat));
    auto dK = at::zeros({batch_size, num_heads, seq_len_kv, head_dim},
                        query.options().dtype(at::kFloat));
    auto dV = at::zeros({batch_size, num_heads, seq_len_kv, head_dim},
                        query.options().dtype(at::kFloat));

    // Get Metal buffers with byte offsets
    auto q_info = getBufferInfo(q);
    auto k_info = getBufferInfo(k);
    auto v_info = getBufferInfo(v);
    auto o_info = getBufferInfo(o);
    auto do_info = getBufferInfo(dO);
    auto l_info = getBufferInfo(lse);
    auto d_info = getBufferInfo(D);
    auto dq_info = getBufferInfo(dQ);
    auto dk_info = getBufferInfo(dK);
    auto dv_info = getBufferInfo(dV);

    // Mask buffer info (may be nullptr if no mask)
    BufferInfo mask_info = {nil, 0};
    if (has_mask) {
        mask_info = getBufferInfo(mask);
    }

    // Use PyTorch's MPS stream command encoder for zero-sync integration
    @autoreleasepool {
        auto stream = at::mps::getCurrentMPSStream();

        // Get PyTorch's shared command encoder
        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();

        bool success = g_mfa_backward_encode(
            kernel,
            (__bridge void*)encoder,  // PyTorch's shared command encoder
            (__bridge void*)q_info.buffer,
            (__bridge void*)k_info.buffer,
            (__bridge void*)v_info.buffer,
            (__bridge void*)o_info.buffer,
            (__bridge void*)do_info.buffer,
            (__bridge void*)l_info.buffer,
            (__bridge void*)d_info.buffer,
            (__bridge void*)dq_info.buffer,
            (__bridge void*)dk_info.buffer,
            (__bridge void*)dv_info.buffer,
            has_mask ? (__bridge void*)mask_info.buffer : nullptr,
            q_info.byte_offset,
            k_info.byte_offset,
            v_info.byte_offset,
            o_info.byte_offset,
            do_info.byte_offset,
            l_info.byte_offset,
            d_info.byte_offset,
            dq_info.byte_offset,
            dk_info.byte_offset,
            dv_info.byte_offset,
            mask_info.byte_offset,
            static_cast<int32_t>(batch_size),
            static_cast<int32_t>(num_heads)
        );

        if (!success) {
            throw std::runtime_error("MFA backward pass failed");
        }

        // No commit needed - PyTorch will commit when it needs the results
    }

    // Convert gradients back to input dtype if needed
    if (low_precision) {
        dQ = dQ.to(query.scalar_type());
        dK = dK.to(query.scalar_type());
        dV = dV.to(query.scalar_type());
    }

    return std::make_tuple(dQ, dK, dV);
}

// ============================================================================
// Python Bindings
// ============================================================================

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "MPS Flash Attention - Metal accelerated attention for Apple Silicon";

    m.def("forward", &mps_flash_attention_forward,
          "Flash Attention forward pass (returns output only)",
          py::arg("query"),
          py::arg("key"),
          py::arg("value"),
          py::arg("is_causal") = false,
          py::arg("attn_mask") = py::none());

    m.def("forward_with_lse", &mps_flash_attention_forward_with_lse,
          "Flash Attention forward pass (returns output and logsumexp for backward)",
          py::arg("query"),
          py::arg("key"),
          py::arg("value"),
          py::arg("is_causal") = false,
          py::arg("attn_mask") = py::none());

    m.def("backward", &mps_flash_attention_backward,
          "Flash Attention backward pass",
          py::arg("grad_output"),
          py::arg("query"),
          py::arg("key"),
          py::arg("value"),
          py::arg("output"),
          py::arg("logsumexp"),
          py::arg("is_causal") = false,
          py::arg("attn_mask") = py::none());
}

mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib
ADDED
Binary file

mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib
ADDED
Binary file

mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib
ADDED
Binary file

mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib
ADDED
Binary file

mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib
ADDED
Binary file

mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin
ADDED
Binary file

mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin
ADDED
Binary file

mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin
ADDED
Binary file

mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin
ADDED
Binary file

mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin
ADDED
Binary file

mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib
ADDED
Binary file

mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib
ADDED
Binary file

mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib
ADDED
Binary file

mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib
ADDED
Binary file

mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib
ADDED
Binary file

mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib
ADDED
Binary file

mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin
ADDED
Binary file

mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin
ADDED
Binary file

mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin
ADDED
Binary file

mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin
ADDED
Binary file

mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin
ADDED
Binary file

mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib
ADDED
Binary file

mps_flash_attn/kernels/manifest.json
ADDED
@@ -0,0 +1,27 @@
{
  "version": "1.0",
  "files": [
    "06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib",
    "adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin",
    "ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin",
    "a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib",
    "975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib",
    "2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin",
    "09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin",
    "0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin",
    "771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib",
    "f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib"
  ]
}

mps_flash_attn-0.1.7.dist-info/METADATA
ADDED
@@ -0,0 +1,270 @@
Metadata-Version: 2.4
Name: mps-flash-attn
Version: 0.1.7
Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
Author: imperatormk
License-Expression: MIT
Project-URL: Homepage, https://github.com/mpsops/mps-flash-attention
Project-URL: Repository, https://github.com/mpsops/mps-flash-attention
Project-URL: Issues, https://github.com/mpsops/mps-flash-attention/issues
Keywords: flash-attention,apple-silicon,pytorch,mps,metal,transformer,attention
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Dynamic: license-file

# MPS Flash Attention

Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).

**O(N) memory** instead of O(N²), enabling 8K+ sequence lengths on unified memory.

## Features

- **Forward pass**: 2-5x faster than PyTorch SDPA
- **Backward pass**: Full gradient support for training (fp32 precision)
- **Causal masking**: Native kernel support (only 5% overhead)
- **Attention masks**: Full boolean mask support for arbitrary masking patterns
- **FP16/FP32**: Native fp16 output (no conversion overhead)
- **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

## Performance

Tested on M1 Max, N=2048, B=4, H=8, D=64:

| Operation | MPS Flash Attn | PyTorch SDPA | Speedup |
|-----------|----------------|--------------|---------|
| Forward | 5.3ms | 15ms | 2.8x |
| Forward+Backward | 55ms | 108ms | 2.0x |
| Memory | 80MB | 592MB | 7.4x less |

## Installation

### Prerequisites

- macOS 14+ (Sonoma) or macOS 15+ (Sequoia)
- Xcode Command Line Tools (`xcode-select --install`)
- Python 3.10+ with PyTorch 2.0+

### Build from source

```bash
# Clone with submodules
git clone --recursive https://github.com/mpsops/mps-flash-attention.git
cd mps-flash-attention

# Build Swift bridge
cd swift-bridge
swift build -c release
cd ..

# Install Python package
pip install -e .
```

### Set environment variable

```bash
export MFA_BRIDGE_PATH=/path/to/mps-flash-attention/swift-bridge/.build/release/libMFABridge.dylib
```

## Usage

### Basic usage

```python
import torch
from mps_flash_attn import flash_attention

# Standard attention (B, H, N, D)
q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)

out = flash_attention(q, k, v)
```

### Causal masking (for autoregressive models)

```python
out = flash_attention(q, k, v, is_causal=True)
```

### Attention masks (for custom masking patterns)

```python
# Boolean mask: True = masked (don't attend), False = attend
mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
mask[:, :, :, 512:] = True  # Mask out positions after 512

out = flash_attention(q, k, v, attn_mask=mask)
```
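
If you already have a PyTorch-style additive float mask (0 = attend, -inf = masked), it can be converted to the boolean form `flash_attention` expects. A minimal sketch, assuming the same `-1e4` threshold that `replace_sdpa` uses internally for this conversion:

```python
import torch
from mps_flash_attn import flash_attention

# Additive mask as used by F.scaled_dot_product_attention: 0 = attend, -inf = masked
additive_mask = torch.zeros(2, 1, 4096, 4096, device='mps')
additive_mask[:, :, :, 2048:] = float('-inf')

# Boolean mask for flash_attention: True = masked (don't attend)
bool_mask = additive_mask < -1e4
out = flash_attention(q, k, v, attn_mask=bool_mask)
```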

### Training with gradients

```python
q.requires_grad = True
k.requires_grad = True
v.requires_grad = True

out = flash_attention(q, k, v, is_causal=True)
loss = out.sum()
loss.backward()  # Computes dQ, dK, dV
```

### Drop-in replacement for SDPA

```python
from mps_flash_attn import replace_sdpa

# Monkey-patch F.scaled_dot_product_attention
replace_sdpa()

# Now all attention ops use Flash Attention on MPS
```
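
After patching, existing code that calls `F.scaled_dot_product_attention` picks up the Flash path automatically. Based on the patch in `__init__.py`, it only reroutes MPS tensors with no dropout and sequence length >= 1024, and falls back to the original implementation otherwise. A small sketch:

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import replace_sdpa

replace_sdpa()

q = torch.randn(2, 8, 2048, 64, device='mps', dtype=torch.float16)
k = torch.randn(2, 8, 2048, 64, device='mps', dtype=torch.float16)
v = torch.randn(2, 8, 2048, 64, device='mps', dtype=torch.float16)

# Routed to MPS Flash Attention (MPS device, dropout_p == 0.0, seq_len >= 1024);
# shorter sequences or dropout fall back to PyTorch's own SDPA.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```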

## Architecture

```
+----------------------------------------------------------+
|                        Python API                         |
|                 mps_flash_attn/__init__.py                |
|           (flash_attention, autograd Function)            |
+----------------------------+-----------------------------+
                             |
+----------------------------v-----------------------------+
|                       C++ Extension                       |
|            mps_flash_attn/csrc/mps_flash_attn.mm          |
|     (PyTorch bindings, MTLBuffer handling, offsets)       |
+----------------------------+-----------------------------+
                             | dlopen + dlsym
+----------------------------v-----------------------------+
|                       Swift Bridge                        |
|               swift-bridge/Sources/MFABridge/             |
|            (MFABridge.swift, MetallibCache.swift)         |
|       @_cdecl exports: mfa_init, mfa_create_kernel,       |
|                 mfa_forward, mfa_backward                 |
+----------------------------+-----------------------------+
                             |
+----------------------------v-----------------------------+
|                   Metal Flash Attention                   |
|        metal-flash-attention/Sources/FlashAttention/      |
|        (AttentionDescriptor, AttentionKernel, etc.)       |
|                                                            |
|        Generates Metal shader source at runtime,          |
|        compiles to .metallib, caches pipelines            |
+----------------------------------------------------------+
```

## Project Structure

```
mps-flash-attention/
├── mps_flash_attn/              # Python package
│   ├── __init__.py              # Public API (flash_attention, replace_sdpa)
│   ├── csrc/
│   │   └── mps_flash_attn.mm    # PyTorch C++ extension
│   └── kernels/                 # Pre-compiled metallibs (optional)
│
├── swift-bridge/                # Swift -> C bridge
│   ├── Package.swift
│   └── Sources/MFABridge/
│       ├── MFABridge.swift      # C-callable API (@_cdecl)
│       └── MetallibCache.swift  # Disk caching for metallibs
│
├── metal-flash-attention/       # Upstream (git submodule)
│   └── Sources/FlashAttention/
│       └── Attention/
│           ├── AttentionDescriptor/   # Problem configuration
│           ├── AttentionKernel/       # Metal shader generation
│           └── ...
│
├── scripts/
│   └── build_metallibs.py       # Pre-compile kernels for distribution
│
└── setup.py                     # Python package setup
```

## Changes from upstream metal-flash-attention

We made the following modifications to `metal-flash-attention`:

### 1. macOS 15+ compatibility (MTLLibraryCompiler.swift)

Apple restricted `__asm` in runtime-compiled Metal shaders on macOS 15. We added a fallback that uses `xcrun metal` CLI compilation when runtime compilation fails.

### 2. Causal masking support

Added `causal` flag to AttentionDescriptor and kernel generation:

- `AttentionDescriptor.swift`: Added `causal: Bool` property
- `AttentionKernelDescriptor.swift`: Added `causal: Bool` property
- `AttentionKernel.swift`: Added `causal` field
- `AttentionKernel+Softmax.swift`: Added `maskCausal()` function
- `AttentionKernel+Source.swift`: Added causal masking to forward/backward loops

## Next Steps

### 1. PR to upstream metal-flash-attention

The macOS 15 fix and causal masking should be contributed back:

```bash
cd metal-flash-attention
git checkout -b macos15-causal-support
# Commit changes to:
# - Sources/FlashAttention/Utilities/MTLLibraryCompiler.swift (new file)
# - Sources/FlashAttention/Attention/AttentionDescriptor/*.swift
# - Sources/FlashAttention/Attention/AttentionKernel/*.swift
git push origin macos15-causal-support
# Open PR at https://github.com/philipturner/metal-flash-attention
```

### 2. Publish mps-flash-attention to PyPI

```bash
# Add pyproject.toml with proper metadata
# Build wheel with pre-compiled Swift bridge
python -m build
twine upload dist/*
```

### 3. Pre-compile kernels for zero cold start

```bash
python scripts/build_metallibs.py
# Copies metallibs to mps_flash_attn/kernels/
# These get shipped with the wheel
```
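
The installed package also exposes this warm-up from Python: `mps_flash_attn.precompile()` loads `libMFABridge.dylib` (honoring `MFA_BRIDGE_PATH`) and calls `mfa_precompile`, and `mps_flash_attn.clear_cache()` clears the disk cache. A minimal sketch:

```python
import mps_flash_attn

if mps_flash_attn.is_available():
    # One-time warm-up: compiles and caches the common kernel configurations
    mps_flash_attn.precompile()

# Later, to rebuild the cache from scratch (requires MFA_BRIDGE_PATH to be set):
# mps_flash_attn.clear_cache()
```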

## Current Status (Jan 2025)

**Working:**
- Forward pass (fp16/fp32)
- Backward pass (dQ, dK, dV gradients)
- Causal masking
- Metallib disk caching
- Pipeline binary caching (MTLBinaryArchive)

**Known limitations:**
- Sequence length must be divisible by block size (typically 64); a padding workaround is sketched after this list
- Head dimension: Best with 32, 64, 96, 128
- No dropout
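
For sequence lengths that are not a multiple of the block size, one possible workaround (an untested sketch, not part of this package) is to pad Q/K/V up to the next multiple of 64 and mask out the padded key positions:

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

def padded_flash_attention(q, k, v, block=64):
    # Pad the sequence dimension up to a multiple of the block size,
    # mask out the padded key positions, then trim the padded query rows.
    B, H, N, D = q.shape
    pad = (-N) % block
    if pad == 0:
        return flash_attention(q, k, v)
    q_p = F.pad(q, (0, 0, 0, pad))
    k_p = F.pad(k, (0, 0, 0, pad))
    v_p = F.pad(v, (0, 0, 0, pad))
    mask = torch.zeros(B, 1, N + pad, N + pad, dtype=torch.bool, device=q.device)
    mask[:, :, :, N:] = True  # True = masked (don't attend)
    return flash_attention(q_p, k_p, v_p, attn_mask=mask)[:, :, :N, :]
```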

## Credits

- [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
- [Flash Attention](https://arxiv.org/abs/2205.14135) paper by Tri Dao et al.

## License

MIT

mps_flash_attn-0.1.7.dist-info/RECORD
ADDED
@@ -0,0 +1,31 @@
mps_flash_attn/_C.cpython-314-darwin.so,sha256=Wy33hq4OLXRvTI_HrBdnml7_MVqhIAC6LfHl2501WcY,266984
mps_flash_attn/__init__.py,sha256=HCz_Qs65RcxGD3FZ6tirBfxvFo0kk4BiOhFO-3QV_Pg,9980
mps_flash_attn/csrc/mps_flash_attn.mm,sha256=CSqtWw15FbQ3O50SgoOmL2ZsG2rKnAu8s8rE0MzxWas,22567
mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib,sha256=_oig6f2I6ZxBCKWbJF3ofmZMySm8gB399_M-lD2NOfM,13747
mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib,sha256=1fsmVvB5EubhN-y6s5CB-eVk_wuO2tfrabiQTwXvJJc,13171
mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib,sha256=5WKo_yAU-PgmulBUQhnzvt0DZRteVmo4-nc4U-T6G2g,17507
mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib,sha256=OehFzjntFdIgEIvE1EW2sGxDEzUgreATvI-fm5-HjaI,12723
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib,sha256=FUcHAIglSvbw6aavubMKzBH6PO5Z1TkUhKO4IqIZiLQ,14083
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin,sha256=4qmQhKBCDqZEydpRaPhuF0WBCiDXfALSW8SbAE9HUps,36496
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin,sha256=NB1ex3aP-SARN_u4I24laoe_H_rbfRBLyDgDnZH4IYY,36496
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin,sha256=uZpt_PpesdfmS7ywNrGqFZecNy4q0yXKQIz0rKu-9yg,36496
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin,sha256=dj6zpVaJtduxjWXg3J8bTuZo_LfAEutNGc0quZXQqSA,36496
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin,sha256=gwGlY6Od0Ub1u3JfRLDkksYFQR-aT3A-igFDYaBDEYA,36496
mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib,sha256=bfG4lBjhtTXAJEInsCBHfE9CzwUkLAZWduLy-jr6alA,12819
mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib,sha256=rPg7TwGLETfD5-GJHAHy084Wza0N5BcsifgxKEQ-HkU,12707
mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib,sha256=2q-7oJOeS8GR66iJJisbfNbIVuY44E41wwh_UgfBlug,17203
mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib,sha256=QVQMAdhrinNhTJ6vYHgeWWz9-tv4Kp3Z9O4tJvLfjLM,13747
mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib,sha256=v20uzcsN3HJAxopiFfu90cKJ_zNeUb13OZzm1ubGXaA,13747
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib,sha256=YhZNagEBf68RxgO8ZqTujfu0YjcUD85dq0Iv5fC8QLE,13171
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin,sha256=13GAdfAiz-3wWcKhvvMbydtoo9pCMXqUsYujQSv5A_8,34112
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin,sha256=DOicBen4ow2FSk2Jv2KPiDZ5-6H93-LvwafvOPqfoF0,34112
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin,sha256=CyRMosbr9mYqxARS3ta_v2cBoYuqnHmEwwQ6aaXsDUY,34112
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin,sha256=IHdTXUtoGnKKYBXeO-Yj5quXy91lgmfxhrD4Bn_yGzA,34112
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin,sha256=4BjRprnMycnhZql9829R6FS3HW30jejuDJM9p9vzVPs,34112
mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib,sha256=qyOaQtRVwL_Wc6GGdu6z-ftf0iX84XexuY09-lNLl5o,13747
mps_flash_attn/kernels/manifest.json,sha256=d5MkE_BjqDQuMNm1jZiwWkQKfB-yfFml3lLSeR-wCLo,1867
mps_flash_attn-0.1.7.dist-info/licenses/LICENSE,sha256=F_XmXSab2O-hHcqLpYJWeFaqB6GA_qiTEN23p2VfZWU,1237
mps_flash_attn-0.1.7.dist-info/METADATA,sha256=eIKQbJo__DR8opFhCTok4KsjNZyJynRXaL_r_riA7M4,8634
mps_flash_attn-0.1.7.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
mps_flash_attn-0.1.7.dist-info/top_level.txt,sha256=zbArDcWhJDnJfMUKnOUhs5TjsMgSxa2GzOlscTRfobE,15
mps_flash_attn-0.1.7.dist-info/RECORD,,

mps_flash_attn-0.1.7.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,27 @@
MIT License

Copyright (c) 2025 imperatormk

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---

This project includes code from metal-flash-attention by Philip Turner,
also licensed under the MIT License:
https://github.com/philipturner/metal-flash-attention

mps_flash_attn-0.1.7.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
mps_flash_attn