mps-flash-attn 0.1.7__cp314-cp314-macosx_15_0_arm64.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of mps-flash-attn might be problematic.
Files changed (31)
  1. mps_flash_attn/_C.cpython-314-darwin.so +0 -0
  2. mps_flash_attn/__init__.py +264 -0
  3. mps_flash_attn/csrc/mps_flash_attn.mm +569 -0
  4. mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  5. mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  6. mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  7. mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  8. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  9. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  10. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  11. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  12. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  13. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  14. mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  15. mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  16. mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  17. mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  18. mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  19. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  20. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  21. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  22. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  23. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  24. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  25. mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  26. mps_flash_attn/kernels/manifest.json +27 -0
  27. mps_flash_attn-0.1.7.dist-info/METADATA +270 -0
  28. mps_flash_attn-0.1.7.dist-info/RECORD +31 -0
  29. mps_flash_attn-0.1.7.dist-info/WHEEL +5 -0
  30. mps_flash_attn-0.1.7.dist-info/licenses/LICENSE +27 -0
  31. mps_flash_attn-0.1.7.dist-info/top_level.txt +1 -0
Binary file
@@ -0,0 +1,264 @@
1
+ """
2
+ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
3
+
4
+ This package provides memory-efficient attention using Metal Flash Attention kernels.
5
+ """
6
+
7
+ __version__ = "0.1.7"
8
+
9
+ import torch
10
+ from typing import Optional
11
+ import math
12
+ import threading
13
+ import os
14
+
15
+ # Try to import the C++ extension
16
+ try:
17
+ from . import _C
18
+ _HAS_MFA = True
19
+ except ImportError as e:
20
+ _HAS_MFA = False
21
+ _IMPORT_ERROR = str(e)
22
+
23
+ # Note: The C++ extension handles loading libMFABridge.dylib via dlopen.
24
+ # Set MFA_BRIDGE_PATH environment variable to specify the library location.
25
+ # Do NOT load the library here via ctypes - that causes duplicate class warnings.
26
+
27
+
28
+ def is_available() -> bool:
29
+ """Check if MPS Flash Attention is available."""
30
+ return _HAS_MFA and torch.backends.mps.is_available()
31
+
32
+
33
+ class FlashAttentionFunction(torch.autograd.Function):
34
+ """Autograd function for Flash Attention with backward pass support."""
35
+
36
+ @staticmethod
37
+ def forward(ctx, query, key, value, is_causal, scale, attn_mask):
38
+ # Apply scale if provided (MFA uses 1/sqrt(D) internally)
39
+ scale_factor = 1.0
40
+ if scale is not None:
41
+ default_scale = 1.0 / math.sqrt(query.shape[-1])
42
+ if abs(scale - default_scale) > 1e-6:
43
+ scale_factor = scale / default_scale
44
+ query = query * scale_factor
45
+
46
+ # Forward with logsumexp for backward
47
+ output, logsumexp = _C.forward_with_lse(query, key, value, is_causal, attn_mask)
48
+
49
+ # Save for backward
50
+ if attn_mask is not None:
51
+ ctx.save_for_backward(query, key, value, output, logsumexp, attn_mask)
52
+ ctx.has_mask = True
53
+ else:
54
+ ctx.save_for_backward(query, key, value, output, logsumexp)
55
+ ctx.has_mask = False
56
+ ctx.is_causal = is_causal
57
+ ctx.scale_factor = scale_factor
58
+
59
+ return output
60
+
61
+ @staticmethod
62
+ def backward(ctx, grad_output):
63
+ if ctx.has_mask:
64
+ query, key, value, output, logsumexp, attn_mask = ctx.saved_tensors
65
+ else:
66
+ query, key, value, output, logsumexp = ctx.saved_tensors
67
+ attn_mask = None
68
+
69
+ # Compute gradients
70
+ dQ, dK, dV = _C.backward(
71
+ grad_output, query, key, value, output, logsumexp, ctx.is_causal, attn_mask
72
+ )
73
+
74
+ # If we scaled the query in forward, scale the gradient back
75
+ if ctx.scale_factor != 1.0:
76
+ dQ = dQ * ctx.scale_factor
77
+
78
+ # Return gradients (None for is_causal, scale, and attn_mask since they're not tensors or don't need grad)
79
+ return dQ, dK, dV, None, None, None
80
+
81
+
82
+ def flash_attention(
83
+ query: torch.Tensor,
84
+ key: torch.Tensor,
85
+ value: torch.Tensor,
86
+ is_causal: bool = False,
87
+ scale: Optional[float] = None,
88
+ attn_mask: Optional[torch.Tensor] = None,
89
+ ) -> torch.Tensor:
90
+ """
91
+ Compute scaled dot-product attention using Flash Attention on MPS.
92
+
93
+ This function provides O(N) memory complexity instead of O(N²) by using
94
+ tiled computation, allowing much longer sequences on limited GPU memory.
95
+
96
+ Supports both forward and backward passes for training.
97
+
98
+ Args:
99
+ query: Query tensor of shape (B, num_heads, seq_len, head_dim)
100
+ key: Key tensor of shape (B, num_heads, seq_len, head_dim)
101
+ value: Value tensor of shape (B, num_heads, seq_len, head_dim)
102
+ is_causal: If True, applies causal masking (for autoregressive models)
103
+ scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
104
+ attn_mask: Optional boolean attention mask of shape (B, 1, seq_len_q, seq_len_kv)
105
+ or (B, num_heads, seq_len_q, seq_len_kv). True values indicate
106
+ positions to be masked (not attended to).
107
+
108
+ Returns:
109
+ Output tensor of shape (B, num_heads, seq_len, head_dim)
110
+
111
+ Example:
112
+ >>> import torch
113
+ >>> from mps_flash_attn import flash_attention
114
+ >>> q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
115
+ >>> k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
116
+ >>> v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
117
+ >>> out = flash_attention(q, k, v)
118
+
119
+ # With gradients:
120
+ >>> q.requires_grad = True
121
+ >>> out = flash_attention(q, k, v)
122
+ >>> out.sum().backward() # Computes dQ
123
+
124
+ # With attention mask:
125
+ >>> mask = torch.zeros(2, 1, 4096, 4096, dtype=torch.bool, device='mps')
126
+ >>> mask[:, :, :, 2048:] = True # mask out second half of keys
127
+ >>> out = flash_attention(q, k, v, attn_mask=mask)
128
+ """
129
+ if not _HAS_MFA:
130
+ raise RuntimeError(
131
+ f"MPS Flash Attention C++ extension not available: {_IMPORT_ERROR}\n"
132
+ "Please rebuild with: pip install -e ."
133
+ )
134
+
135
+ if not torch.backends.mps.is_available():
136
+ raise RuntimeError("MPS not available")
137
+
138
+ # Validate device
139
+ if query.device.type != 'mps':
140
+ raise ValueError("query must be on MPS device")
141
+ if key.device.type != 'mps':
142
+ raise ValueError("key must be on MPS device")
143
+ if value.device.type != 'mps':
144
+ raise ValueError("value must be on MPS device")
145
+ if attn_mask is not None and attn_mask.device.type != 'mps':
146
+ raise ValueError("attn_mask must be on MPS device")
147
+
148
+ # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
149
+ if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
150
+ # Apply scale if provided
151
+ if scale is not None:
152
+ default_scale = 1.0 / math.sqrt(query.shape[-1])
153
+ if abs(scale - default_scale) > 1e-6:
154
+ scale_factor = scale / default_scale
155
+ query = query * scale_factor
156
+
157
+ # Forward only - no logsumexp needed, no tensors saved
158
+ return _C.forward(query, key, value, is_causal, attn_mask)
159
+
160
+ # Use autograd function for gradient support
161
+ return FlashAttentionFunction.apply(query, key, value, is_causal, scale, attn_mask)
162
+
163
+
164
+ def replace_sdpa():
165
+ """
166
+ Monkey-patch torch.nn.functional.scaled_dot_product_attention to use
167
+ Flash Attention on MPS devices.
168
+
169
+ Call this at the start of your script to automatically use Flash Attention
170
+ for all attention operations.
171
+ """
172
+ import torch.nn.functional as F
173
+
174
+ original_sdpa = F.scaled_dot_product_attention
175
+
176
+ def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
177
+ is_causal=False, scale=None):
178
+ # Use MFA for MPS tensors without dropout
179
+ # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
180
+ # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
181
+ if (query.device.type == 'mps' and
182
+ dropout_p == 0.0 and
183
+ _HAS_MFA and
184
+ query.shape[2] >= 1024):
185
+ try:
186
+ # Convert float mask to bool mask if needed
187
+ # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
188
+ # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
189
+ mfa_mask = None
190
+ if attn_mask is not None:
191
+ if attn_mask.dtype == torch.bool:
192
+ # Boolean mask: True means masked (don't attend)
193
+ mfa_mask = attn_mask
194
+ else:
195
+ # Float mask: typically -inf for masked positions, 0 for unmasked
196
+ # Convert: positions with large negative values -> True (masked)
197
+ mfa_mask = attn_mask < -1e4
198
+ return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
199
+ except Exception:
200
+ # Fall back to original on any error
201
+ pass
202
+
203
+ return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale)
204
+
205
+ F.scaled_dot_product_attention = patched_sdpa
206
+ print("MPS Flash Attention: Patched F.scaled_dot_product_attention")
207
+
208
+
209
+ def precompile():
210
+ """
211
+ Pre-compile Metal kernels for common configurations.
212
+
213
+ Call this once after installation to eliminate runtime compilation overhead.
214
+ Pre-compiled kernels are cached to disk and loaded instantly on subsequent runs.
215
+
216
+ This compiles kernels for:
217
+ - Sequence lengths: 64, 128, 256, 512, 1024, 2048, 4096, 8192
218
+ - Head dimensions: 32, 48, 64, 80, 96, 128
219
+ - Both fp32 and fp16 precision
220
+
221
+ Total: 96 kernel configurations
222
+ """
223
+ if not _HAS_MFA:
224
+ raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
225
+
226
+ import ctypes
227
+ import os
228
+
229
+ # Load the Swift bridge directly
230
+ bridge_path = os.environ.get("MFA_BRIDGE_PATH")
231
+ if not bridge_path:
232
+ # Try common locations
233
+ module_dir = os.path.dirname(__file__)
234
+ candidates = [
235
+ os.path.join(module_dir, "lib", "libMFABridge.dylib"), # Bundled in wheel
236
+ os.path.join(module_dir, "..", "swift-bridge", ".build", "release", "libMFABridge.dylib"),
237
+ os.path.join(module_dir, "libMFABridge.dylib"),
238
+ ]
239
+ for path in candidates:
240
+ if os.path.exists(path):
241
+ bridge_path = path
242
+ break
243
+
244
+ if not bridge_path or not os.path.exists(bridge_path):
245
+ raise RuntimeError("Cannot find libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.")
246
+
247
+ lib = ctypes.CDLL(bridge_path)
248
+ lib.mfa_precompile()
249
+ print("\nPre-compilation complete! Kernels cached to disk.")
250
+
251
+
252
+ def clear_cache():
253
+ """Clear the pre-compiled kernel cache."""
254
+ if not _HAS_MFA:
255
+ raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
256
+
257
+ import ctypes
258
+ import os
259
+
260
+ bridge_path = os.environ.get("MFA_BRIDGE_PATH")
261
+ if bridge_path and os.path.exists(bridge_path):
262
+ lib = ctypes.CDLL(bridge_path)
263
+ lib.mfa_clear_cache()
264
+ print("Cache cleared.")
@@ -0,0 +1,569 @@
1
+ /**
2
+ * MPS Flash Attention - PyTorch C++ Extension
3
+ *
4
+ * Bridges PyTorch MPS tensors to the MFA Swift library.
5
+ */
6
+
7
+ #include <torch/extension.h>
8
+ #include <ATen/mps/MPSStream.h>
9
+ #include <ATen/native/mps/OperationUtils.h>
10
+
11
+ #import <Metal/Metal.h>
12
+ #import <Foundation/Foundation.h>
13
+
14
+ #include <dlfcn.h>
15
+ #include <string>
16
+ #include <vector>
17
+
18
+ // ============================================================================
19
+ // MFA Bridge Function Types
20
+ // ============================================================================
21
+
22
+ typedef bool (*mfa_init_fn)();
23
+ typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool);
24
+ // New zero-sync encode functions that take PyTorch's command encoder
25
+ // Added mask_ptr and mask_offset parameters
26
+ typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
27
+ int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
28
+ int32_t, int32_t);
29
+ typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
30
+ int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
31
+ int32_t, int32_t);
32
+ // Legacy sync functions (fallback)
33
+ typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, void*,
34
+ int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
35
+ int32_t, int32_t);
36
+ typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
37
+ int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
38
+ int32_t, int32_t);
39
+ typedef void (*mfa_release_kernel_fn)(void*);
40
+
41
+ // Global function pointers
42
+ static mfa_init_fn g_mfa_init = nullptr;
43
+ static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
44
+ static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
45
+ static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
46
+ static mfa_forward_fn g_mfa_forward = nullptr;
47
+ static mfa_backward_fn g_mfa_backward = nullptr;
48
+ static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
49
+ static void* g_dylib_handle = nullptr;
50
+ static bool g_initialized = false;
51
+
52
+ // ============================================================================
53
+ // Load MFA Bridge Library
54
+ // ============================================================================
55
+
56
+ // Get the directory containing this shared library
57
+ static std::string get_module_dir() {
58
+ Dl_info info;
59
+ if (dladdr((void*)get_module_dir, &info) && info.dli_fname) {
60
+ std::string path(info.dli_fname);
61
+ size_t last_slash = path.rfind('/');
62
+ if (last_slash != std::string::npos) {
63
+ return path.substr(0, last_slash);
64
+ }
65
+ }
66
+ return ".";
67
+ }
68
+
69
+ static bool load_mfa_bridge() {
70
+ if (g_dylib_handle) return true;
71
+
72
+ // First check environment variable (highest priority)
73
+ const char* mfa_path = getenv("MFA_BRIDGE_PATH");
74
+ if (mfa_path) {
75
+ g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
76
+ if (g_dylib_handle) return true;
77
+ }
78
+
79
+ // Get the directory containing this extension module
80
+ std::string module_dir = get_module_dir();
81
+
82
+ // Try paths relative to the module directory
83
+ std::vector<std::string> paths = {
84
+ module_dir + "/lib/libMFABridge.dylib", // Bundled in wheel
85
+ module_dir + "/../swift-bridge/.build/release/libMFABridge.dylib", // Dev build
86
+ "libMFABridge.dylib", // Current directory fallback
87
+ };
88
+
89
+ for (const auto& path : paths) {
90
+ g_dylib_handle = dlopen(path.c_str(), RTLD_NOW);
91
+ if (g_dylib_handle) break;
92
+ }
93
+
94
+ if (!g_dylib_handle) {
95
+ throw std::runtime_error(
96
+ "Failed to load libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.");
97
+ }
98
+
99
+ // Load function pointers
100
+ g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
101
+ g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
102
+ g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
103
+ g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
104
+ g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
105
+ g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
106
+ g_mfa_release_kernel = (mfa_release_kernel_fn)dlsym(g_dylib_handle, "mfa_release_kernel");
107
+
108
+ // Require at least init, create_kernel, forward_encode (for zero-sync path)
109
+ if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward_encode) {
110
+ throw std::runtime_error("Failed to load MFA bridge functions");
111
+ }
112
+
113
+ return true;
114
+ }
115
+
116
+ // ============================================================================
117
+ // Get MTLBuffer from PyTorch MPS Tensor
118
+ // ============================================================================
119
+
120
+ struct BufferInfo {
121
+ id<MTLBuffer> buffer;
122
+ int64_t byte_offset;
123
+ };
124
+
125
+ static BufferInfo getBufferInfo(const at::Tensor& tensor) {
126
+ TORCH_CHECK(tensor.device().is_mps(), "Tensor must be on MPS device");
127
+ TORCH_CHECK(tensor.is_contiguous(), "Tensor must be contiguous");
128
+
129
+ // Get the underlying Metal buffer (covers entire storage)
130
+ id<MTLBuffer> buffer = at::native::mps::getMTLBufferStorage(tensor);
131
+
132
+ // Calculate byte offset: storage_offset() is in elements, multiply by element size
133
+ int64_t element_size = tensor.element_size();
134
+ int64_t byte_offset = tensor.storage_offset() * element_size;
135
+
136
+ return {buffer, byte_offset};
137
+ }
138
+
139
+ // ============================================================================
140
+ // Kernel Cache
141
+ // ============================================================================
142
+
143
+ struct KernelCacheKey {
144
+ int64_t seq_len_q;
145
+ int64_t seq_len_kv;
146
+ int64_t head_dim;
147
+ bool low_precision;
148
+ bool low_precision_outputs;
149
+ bool causal;
150
+ bool has_mask;
151
+
152
+ bool operator==(const KernelCacheKey& other) const {
153
+ return seq_len_q == other.seq_len_q &&
154
+ seq_len_kv == other.seq_len_kv &&
155
+ head_dim == other.head_dim &&
156
+ low_precision == other.low_precision &&
157
+ low_precision_outputs == other.low_precision_outputs &&
158
+ causal == other.causal &&
159
+ has_mask == other.has_mask;
160
+ }
161
+ };
162
+
163
+ struct KernelCacheKeyHash {
164
+ size_t operator()(const KernelCacheKey& k) const {
165
+ return std::hash<int64_t>()(k.seq_len_q) ^
166
+ (std::hash<int64_t>()(k.seq_len_kv) << 1) ^
167
+ (std::hash<int64_t>()(k.head_dim) << 2) ^
168
+ (std::hash<bool>()(k.low_precision) << 3) ^
169
+ (std::hash<bool>()(k.low_precision_outputs) << 4) ^
170
+ (std::hash<bool>()(k.causal) << 5) ^
171
+ (std::hash<bool>()(k.has_mask) << 6);
172
+ }
173
+ };
174
+
175
+ static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;
176
+
177
+ static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask) {
178
+ KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask};
179
+
180
+ auto it = g_kernel_cache.find(key);
181
+ if (it != g_kernel_cache.end()) {
182
+ return it->second;
183
+ }
184
+
185
+ void* kernel = g_mfa_create_kernel(
186
+ static_cast<int32_t>(seq_q),
187
+ static_cast<int32_t>(seq_kv),
188
+ static_cast<int32_t>(head_dim),
189
+ low_prec,
190
+ low_prec_outputs,
191
+ causal,
192
+ has_mask
193
+ );
194
+
195
+ if (!kernel) {
196
+ throw std::runtime_error("Failed to create MFA kernel");
197
+ }
198
+
199
+ g_kernel_cache[key] = kernel;
200
+ return kernel;
201
+ }
202
+
203
+ // ============================================================================
204
+ // Flash Attention Forward
205
+ // ============================================================================
206
+
207
+ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
208
+ const at::Tensor& query, // (B, H, N, D)
209
+ const at::Tensor& key, // (B, H, N, D)
210
+ const at::Tensor& value, // (B, H, N, D)
211
+ bool is_causal,
212
+ const c10::optional<at::Tensor>& attn_mask // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
213
+ ) {
214
+ // Initialize MFA on first call
215
+ if (!g_initialized) {
216
+ load_mfa_bridge();
217
+ if (!g_mfa_init()) {
218
+ throw std::runtime_error("Failed to initialize MFA");
219
+ }
220
+ g_initialized = true;
221
+ }
222
+
223
+ // Validate inputs
224
+ TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
225
+ TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
226
+ TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
227
+ TORCH_CHECK(query.device().is_mps(), "Query must be on MPS device");
228
+ TORCH_CHECK(key.device().is_mps(), "Key must be on MPS device");
229
+ TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");
230
+
231
+ const int64_t batch_size = query.size(0);
232
+ const int64_t num_heads_q = query.size(1);
233
+ const int64_t num_heads_kv = key.size(1);
234
+ const int64_t seq_len_q = query.size(2);
235
+ const int64_t head_dim = query.size(3);
236
+ const int64_t seq_len_kv = key.size(2);
237
+
238
+ TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
239
+ "Batch size mismatch");
240
+ TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
241
+ "Head dimension mismatch");
242
+ TORCH_CHECK(key.size(1) == value.size(1),
243
+ "K and V must have same number of heads");
244
+
245
+ // Handle GQA (Grouped Query Attention): expand K/V if fewer heads than Q
246
+ const int64_t num_heads = num_heads_q;
247
+ at::Tensor k_expanded, v_expanded;
248
+
249
+ if (num_heads_kv != num_heads_q) {
250
+ // GQA: num_heads_q must be divisible by num_heads_kv
251
+ TORCH_CHECK(num_heads_q % num_heads_kv == 0,
252
+ "num_heads_q (", num_heads_q, ") must be divisible by num_heads_kv (", num_heads_kv, ")");
253
+ int64_t repeat_factor = num_heads_q / num_heads_kv;
254
+
255
+ // Expand K and V to match Q's head count: (B, H_kv, S, D) -> (B, H_q, S, D)
256
+ // Use repeat_interleave for proper GQA expansion
257
+ k_expanded = key.repeat_interleave(repeat_factor, /*dim=*/1);
258
+ v_expanded = value.repeat_interleave(repeat_factor, /*dim=*/1);
259
+ } else {
260
+ k_expanded = key;
261
+ v_expanded = value;
262
+ }
263
+
264
+ // Determine precision
265
+ bool low_precision = (query.scalar_type() == at::kHalf ||
266
+ query.scalar_type() == at::kBFloat16);
267
+
268
+ // For fp16 inputs, we can now output directly to fp16 (no extra conversion needed!)
269
+ bool low_precision_outputs = low_precision;
270
+
271
+ // Make inputs contiguous
272
+ auto q = query.contiguous();
273
+ auto k = k_expanded.contiguous();
274
+ auto v = v_expanded.contiguous();
275
+
276
+ // Handle attention mask
277
+ bool has_mask = attn_mask.has_value();
278
+ at::Tensor mask;
279
+ if (has_mask) {
280
+ mask = attn_mask.value();
281
+ TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
282
+ TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
283
+ // Convert to bool/uint8 if needed - kernel expects uchar (0 = attend, non-0 = mask out)
284
+ if (mask.scalar_type() == at::kBool) {
285
+ // Convert bool to uint8 for Metal compatibility
286
+ mask = mask.to(at::kByte);
287
+ }
288
+ TORCH_CHECK(mask.scalar_type() == at::kByte,
289
+ "Attention mask must be bool or uint8");
290
+ // Expand mask heads if needed (B, 1, N_q, N_kv) -> (B, H, N_q, N_kv)
291
+ if (mask.size(1) == 1 && num_heads > 1) {
292
+ mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
293
+ }
294
+ mask = mask.contiguous();
295
+ }
296
+
297
+ // Allocate output in the appropriate precision
298
+ // With lowPrecisionOutputs=true, MFA writes FP16 directly
299
+ at::Tensor output;
300
+ if (low_precision_outputs) {
301
+ output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
302
+ query.options().dtype(at::kHalf));
303
+ } else {
304
+ output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
305
+ query.options().dtype(at::kFloat));
306
+ }
307
+
308
+ // Allocate logsumexp (for backward pass, always fp32)
309
+ auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
310
+ query.options().dtype(at::kFloat));
311
+
312
+ // Get or create kernel with matching output precision and causal mode
313
+ void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask);
314
+
315
+ // Get Metal buffers with byte offsets
316
+ auto q_info = getBufferInfo(q);
317
+ auto k_info = getBufferInfo(k);
318
+ auto v_info = getBufferInfo(v);
319
+ auto o_info = getBufferInfo(output);
320
+ auto l_info = getBufferInfo(logsumexp);
321
+
322
+ // Mask buffer info (may be nullptr if no mask)
323
+ BufferInfo mask_info = {nil, 0};
324
+ if (has_mask) {
325
+ mask_info = getBufferInfo(mask);
326
+ }
327
+
328
+ // Use PyTorch's MPS stream command encoder for zero-sync integration
329
+ @autoreleasepool {
330
+ auto stream = at::mps::getCurrentMPSStream();
331
+
332
+ // Get PyTorch's shared command encoder - this is the key for zero-sync!
333
+ // All our dispatches go onto the same encoder that PyTorch uses.
334
+ id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
335
+
336
+ // Execute MFA using the shared encoder (no sync needed!)
337
+ bool success = g_mfa_forward_encode(
338
+ kernel,
339
+ (__bridge void*)encoder, // PyTorch's shared command encoder
340
+ (__bridge void*)q_info.buffer,
341
+ (__bridge void*)k_info.buffer,
342
+ (__bridge void*)v_info.buffer,
343
+ (__bridge void*)o_info.buffer,
344
+ (__bridge void*)l_info.buffer,
345
+ has_mask ? (__bridge void*)mask_info.buffer : nullptr,
346
+ q_info.byte_offset,
347
+ k_info.byte_offset,
348
+ v_info.byte_offset,
349
+ o_info.byte_offset,
350
+ l_info.byte_offset,
351
+ mask_info.byte_offset,
352
+ static_cast<int32_t>(batch_size),
353
+ static_cast<int32_t>(num_heads)
354
+ );
355
+
356
+ if (!success) {
357
+ throw std::runtime_error("MFA forward pass failed");
358
+ }
359
+
360
+ // No commit needed - PyTorch will commit when it needs the results
361
+ // The encoder stays open for coalescing more kernels
362
+ }
363
+
364
+ // Output is already in the correct dtype (fp16 or fp32)
365
+ // Return both output and logsumexp (needed for backward pass)
366
+ return std::make_tuple(output, logsumexp);
367
+ }
368
+
369
+ // Simple forward that only returns output (for inference)
370
+ at::Tensor mps_flash_attention_forward(
371
+ const at::Tensor& query,
372
+ const at::Tensor& key,
373
+ const at::Tensor& value,
374
+ bool is_causal,
375
+ const c10::optional<at::Tensor>& attn_mask
376
+ ) {
377
+ auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal, attn_mask);
378
+ return output;
379
+ }
380
+
381
+ // ============================================================================
382
+ // Flash Attention Backward
383
+ // ============================================================================
384
+
385
+ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
386
+ const at::Tensor& grad_output, // (B, H, N, D)
387
+ const at::Tensor& query, // (B, H, N, D)
388
+ const at::Tensor& key, // (B, H, N, D)
389
+ const at::Tensor& value, // (B, H, N, D)
390
+ const at::Tensor& output, // (B, H, N, D)
391
+ const at::Tensor& logsumexp, // (B, H, N)
392
+ bool is_causal,
393
+ const c10::optional<at::Tensor>& attn_mask // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
394
+ ) {
395
+ // Initialize MFA on first call
396
+ if (!g_initialized) {
397
+ load_mfa_bridge();
398
+ if (!g_mfa_init()) {
399
+ throw std::runtime_error("Failed to initialize MFA");
400
+ }
401
+ g_initialized = true;
402
+ }
403
+
404
+ // Validate inputs
405
+ TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
406
+ TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
407
+ TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
408
+ TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
409
+ TORCH_CHECK(output.dim() == 4, "Output must be 4D (B, H, N, D)");
410
+ TORCH_CHECK(logsumexp.dim() == 3, "Logsumexp must be 3D (B, H, N)");
411
+
412
+ const int64_t batch_size = query.size(0);
413
+ const int64_t num_heads = query.size(1);
414
+ const int64_t seq_len_q = query.size(2);
415
+ const int64_t head_dim = query.size(3);
416
+ const int64_t seq_len_kv = key.size(2);
417
+
418
+ // Determine precision
419
+ bool low_precision = (query.scalar_type() == at::kHalf ||
420
+ query.scalar_type() == at::kBFloat16);
421
+ bool low_precision_outputs = low_precision;
422
+
423
+ // Handle attention mask
424
+ bool has_mask = attn_mask.has_value();
425
+ at::Tensor mask;
426
+ if (has_mask) {
427
+ mask = attn_mask.value();
428
+ TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
429
+ TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
430
+ if (mask.scalar_type() == at::kBool) {
431
+ mask = mask.to(at::kByte);
432
+ }
433
+ TORCH_CHECK(mask.scalar_type() == at::kByte,
434
+ "Attention mask must be bool or uint8");
435
+ if (mask.size(1) == 1 && num_heads > 1) {
436
+ mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
437
+ }
438
+ mask = mask.contiguous();
439
+ }
440
+
441
+ // Make inputs contiguous and upcast to fp32 for numerical stability
442
+ // The backward pass accumulates many small values, so fp32 precision is critical
443
+ auto q = query.contiguous().to(at::kFloat);
444
+ auto k = key.contiguous().to(at::kFloat);
445
+ auto v = value.contiguous().to(at::kFloat);
446
+ auto o = output.contiguous().to(at::kFloat);
447
+ auto dO = grad_output.contiguous().to(at::kFloat);
448
+ auto lse = logsumexp.contiguous();
449
+
450
+ // Get or create kernel - always use fp32 for backward pass
451
+ void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, false, false, is_causal, has_mask);
452
+
453
+ // Allocate D buffer (dO * O reduction, always fp32)
454
+ auto D = at::empty({batch_size, num_heads, seq_len_q},
455
+ query.options().dtype(at::kFloat));
456
+
457
+ // Allocate gradients (always fp32 for numerical stability)
458
+ auto dQ = at::zeros({batch_size, num_heads, seq_len_q, head_dim},
459
+ query.options().dtype(at::kFloat));
460
+ auto dK = at::zeros({batch_size, num_heads, seq_len_kv, head_dim},
461
+ query.options().dtype(at::kFloat));
462
+ auto dV = at::zeros({batch_size, num_heads, seq_len_kv, head_dim},
463
+ query.options().dtype(at::kFloat));
464
+
465
+ // Get Metal buffers with byte offsets
466
+ auto q_info = getBufferInfo(q);
467
+ auto k_info = getBufferInfo(k);
468
+ auto v_info = getBufferInfo(v);
469
+ auto o_info = getBufferInfo(o);
470
+ auto do_info = getBufferInfo(dO);
471
+ auto l_info = getBufferInfo(lse);
472
+ auto d_info = getBufferInfo(D);
473
+ auto dq_info = getBufferInfo(dQ);
474
+ auto dk_info = getBufferInfo(dK);
475
+ auto dv_info = getBufferInfo(dV);
476
+
477
+ // Mask buffer info (may be nullptr if no mask)
478
+ BufferInfo mask_info = {nil, 0};
479
+ if (has_mask) {
480
+ mask_info = getBufferInfo(mask);
481
+ }
482
+
483
+ // Use PyTorch's MPS stream command encoder for zero-sync integration
484
+ @autoreleasepool {
485
+ auto stream = at::mps::getCurrentMPSStream();
486
+
487
+ // Get PyTorch's shared command encoder
488
+ id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
489
+
490
+ bool success = g_mfa_backward_encode(
491
+ kernel,
492
+ (__bridge void*)encoder, // PyTorch's shared command encoder
493
+ (__bridge void*)q_info.buffer,
494
+ (__bridge void*)k_info.buffer,
495
+ (__bridge void*)v_info.buffer,
496
+ (__bridge void*)o_info.buffer,
497
+ (__bridge void*)do_info.buffer,
498
+ (__bridge void*)l_info.buffer,
499
+ (__bridge void*)d_info.buffer,
500
+ (__bridge void*)dq_info.buffer,
501
+ (__bridge void*)dk_info.buffer,
502
+ (__bridge void*)dv_info.buffer,
503
+ has_mask ? (__bridge void*)mask_info.buffer : nullptr,
504
+ q_info.byte_offset,
505
+ k_info.byte_offset,
506
+ v_info.byte_offset,
507
+ o_info.byte_offset,
508
+ do_info.byte_offset,
509
+ l_info.byte_offset,
510
+ d_info.byte_offset,
511
+ dq_info.byte_offset,
512
+ dk_info.byte_offset,
513
+ dv_info.byte_offset,
514
+ mask_info.byte_offset,
515
+ static_cast<int32_t>(batch_size),
516
+ static_cast<int32_t>(num_heads)
517
+ );
518
+
519
+ if (!success) {
520
+ throw std::runtime_error("MFA backward pass failed");
521
+ }
522
+
523
+ // No commit needed - PyTorch will commit when it needs the results
524
+ }
525
+
526
+ // Convert gradients back to input dtype if needed
527
+ if (low_precision) {
528
+ dQ = dQ.to(query.scalar_type());
529
+ dK = dK.to(query.scalar_type());
530
+ dV = dV.to(query.scalar_type());
531
+ }
532
+
533
+ return std::make_tuple(dQ, dK, dV);
534
+ }
535
+
536
+ // ============================================================================
537
+ // Python Bindings
538
+ // ============================================================================
539
+
540
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
541
+ m.doc() = "MPS Flash Attention - Metal accelerated attention for Apple Silicon";
542
+
543
+ m.def("forward", &mps_flash_attention_forward,
544
+ "Flash Attention forward pass (returns output only)",
545
+ py::arg("query"),
546
+ py::arg("key"),
547
+ py::arg("value"),
548
+ py::arg("is_causal") = false,
549
+ py::arg("attn_mask") = py::none());
550
+
551
+ m.def("forward_with_lse", &mps_flash_attention_forward_with_lse,
552
+ "Flash Attention forward pass (returns output and logsumexp for backward)",
553
+ py::arg("query"),
554
+ py::arg("key"),
555
+ py::arg("value"),
556
+ py::arg("is_causal") = false,
557
+ py::arg("attn_mask") = py::none());
558
+
559
+ m.def("backward", &mps_flash_attention_backward,
560
+ "Flash Attention backward pass",
561
+ py::arg("grad_output"),
562
+ py::arg("query"),
563
+ py::arg("key"),
564
+ py::arg("value"),
565
+ py::arg("output"),
566
+ py::arg("logsumexp"),
567
+ py::arg("is_causal") = false,
568
+ py::arg("attn_mask") = py::none());
569
+ }
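The bindings above can also be exercised directly, which is occasionally handy when debugging the extension in isolation; `flash_attention` remains the supported entry point. A small sketch using the signatures registered in `PYBIND11_MODULE` (assumes the wheel is installed and MPS is available):

```python
import torch
from mps_flash_attn import _C

q = torch.randn(1, 4, 1024, 64, device="mps", dtype=torch.float16)
k = torch.randn(1, 4, 1024, 64, device="mps", dtype=torch.float16)
v = torch.randn(1, 4, 1024, 64, device="mps", dtype=torch.float16)

# forward(query, key, value, is_causal=False, attn_mask=None) -> output
out = _C.forward(q, k, v, is_causal=True, attn_mask=None)

# forward_with_lse returns (output, logsumexp) as consumed by backward()
out, lse = _C.forward_with_lse(q, k, v, is_causal=True, attn_mask=None)
print(out.shape, lse.shape)  # (1, 4, 1024, 64) and (1, 4, 1024)
```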
@@ -0,0 +1,27 @@
1
+ {
2
+ "version": "1.0",
3
+ "files": [
4
+ "06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib",
5
+ "adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib",
6
+ "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin",
7
+ "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin",
8
+ "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin",
9
+ "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin",
10
+ "ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib",
11
+ "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin",
12
+ "a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib",
13
+ "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib",
14
+ "975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib",
15
+ "2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib",
16
+ "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin",
17
+ "09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib",
18
+ "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin",
19
+ "0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib",
20
+ "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin",
21
+ "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin",
22
+ "771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib",
23
+ "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin",
24
+ "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib",
25
+ "f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib"
26
+ ]
27
+ }
@@ -0,0 +1,270 @@
1
+ Metadata-Version: 2.4
2
+ Name: mps-flash-attn
3
+ Version: 0.1.7
4
+ Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
5
+ Author: imperatormk
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/mpsops/mps-flash-attention
8
+ Project-URL: Repository, https://github.com/mpsops/mps-flash-attention
9
+ Project-URL: Issues, https://github.com/mpsops/mps-flash-attention/issues
10
+ Keywords: flash-attention,apple-silicon,pytorch,mps,metal,transformer,attention
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Operating System :: MacOS :: MacOS X
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Requires-Python: >=3.10
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: torch>=2.0.0
24
+ Dynamic: license-file
25
+
26
+ # MPS Flash Attention
27
+
28
+ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
29
+
30
+ **O(N) memory** instead of O(N²), enabling 8K+ sequence lengths on unified memory.
31
+
32
+ ## Features
33
+
34
+ - **Forward pass**: 2-5x faster than PyTorch SDPA
35
+ - **Backward pass**: Full gradient support for training (fp32 precision)
36
+ - **Causal masking**: Native kernel support (only 5% overhead)
37
+ - **Attention masks**: Full boolean mask support for arbitrary masking patterns
38
+ - **FP16/FP32**: Native fp16 output (no conversion overhead)
39
+ - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)
40
+
41
+ ## Performance
42
+
43
+ Tested on M1 Max, N=2048, B=4, H=8, D=64:
44
+
45
+ | Operation | MPS Flash Attn | PyTorch SDPA | Speedup |
46
+ |-----------|----------------|--------------|---------|
47
+ | Forward | 5.3ms | 15ms | 2.8x |
48
+ | Forward+Backward | 55ms | 108ms | 2.0x |
49
+ | Memory | 80MB | 592MB | 7.4x less |
50
+
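These numbers come from the project's own benchmarks and will vary by chip and PyTorch version. A minimal timing sketch along the same lines (not the exact benchmark script; assumes a recent PyTorch with `torch.mps.synchronize()`):

```python
import time
import torch
from mps_flash_attn import flash_attention

B, H, N, D = 4, 8, 2048, 64
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)

# Warm up (first call loads/compiles the kernel), then time the forward pass.
for _ in range(3):
    flash_attention(q, k, v)
torch.mps.synchronize()

t0 = time.perf_counter()
for _ in range(20):
    flash_attention(q, k, v)
torch.mps.synchronize()
print(f"forward: {(time.perf_counter() - t0) / 20 * 1e3:.1f} ms")
```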
51
+ ## Installation
52
+
53
+ ### Prerequisites
54
+
55
+ - macOS 14 (Sonoma) or later, including macOS 15 (Sequoia)
56
+ - Xcode Command Line Tools (`xcode-select --install`)
57
+ - Python 3.10+ with PyTorch 2.0+
58
+
59
+ ### Build from source
60
+
61
+ ```bash
62
+ # Clone with submodules
63
+ git clone --recursive https://github.com/mpsops/mps-flash-attention.git
64
+ cd mps-flash-attention
65
+
66
+ # Build Swift bridge
67
+ cd swift-bridge
68
+ swift build -c release
69
+ cd ..
70
+
71
+ # Install Python package
72
+ pip install -e .
73
+ ```
74
+
75
+ ### Set environment variable
76
+
77
+ ```bash
78
+ export MFA_BRIDGE_PATH=/path/to/mps-flash-attention/swift-bridge/.build/release/libMFABridge.dylib
79
+ ```
80
+
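A quick way to confirm the install is usable is the `is_available()` helper exported by the package (see `__init__.py` above); it checks both that the C++ extension imported and that the MPS backend is present:

```python
import torch
import mps_flash_attn

print(torch.backends.mps.is_available())  # True on Apple Silicon with a recent PyTorch
print(mps_flash_attn.is_available())      # True once the C++ extension loaded correctly
```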
81
+ ## Usage
82
+
83
+ ### Basic usage
84
+
85
+ ```python
86
+ import torch
+ from mps_flash_attn import flash_attention
87
+
88
+ # Standard attention (B, H, N, D)
89
+ q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
90
+ k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
91
+ v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
92
+
93
+ out = flash_attention(q, k, v)
94
+ ```
95
+
96
+ ### Causal masking (for autoregressive models)
97
+
98
+ ```python
99
+ out = flash_attention(q, k, v, is_causal=True)
100
+ ```
101
+
102
+ ### Attention masks (for custom masking patterns)
103
+
104
+ ```python
105
+ # Boolean mask: True = masked (don't attend), False = attend
106
+ mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
107
+ mask[:, :, :, 512:] = True # Mask out positions after 512
108
+
109
+ out = flash_attention(q, k, v, attn_mask=mask)
110
+ ```
111
+
112
+ ### Training with gradients
113
+
114
+ ```python
115
+ q.requires_grad = True
116
+ k.requires_grad = True
117
+ v.requires_grad = True
118
+
119
+ out = flash_attention(q, k, v, is_causal=True)
120
+ loss = out.sum()
121
+ loss.backward() # Computes dQ, dK, dV
122
+ ```
123
+
124
+ ### Drop-in replacement for SDPA
125
+
126
+ ```python
127
+ from mps_flash_attn import replace_sdpa
128
+
129
+ # Monkey-patch F.scaled_dot_product_attention
130
+ replace_sdpa()
131
+
132
+ # Now all attention ops use Flash Attention on MPS
133
+ ```
134
+
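To sanity-check the patch, the flash_attention output can be compared against PyTorch's reference SDPA on the same inputs. A short sketch, not part of the package's test suite (fp16 tolerances are necessarily loose):

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

q = torch.randn(2, 8, 2048, 64, device="mps", dtype=torch.float16)
k = torch.randn(2, 8, 2048, 64, device="mps", dtype=torch.float16)
v = torch.randn(2, 8, 2048, 64, device="mps", dtype=torch.float16)

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = flash_attention(q, k, v, is_causal=True)

print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))
print((out - ref).abs().max())
```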
135
+ ## Architecture
136
+
137
+ ```
138
+ +----------------------------------------------------------+
139
+ | Python API |
140
+ | mps_flash_attn/__init__.py |
141
+ | (flash_attention, autograd Function) |
142
+ +----------------------------+-----------------------------+
143
+ |
144
+ +----------------------------v-----------------------------+
145
+ | C++ Extension |
146
+ | mps_flash_attn/csrc/mps_flash_attn.mm |
147
+ | (PyTorch bindings, MTLBuffer handling, offsets) |
148
+ +----------------------------+-----------------------------+
149
+ | dlopen + dlsym
150
+ +----------------------------v-----------------------------+
151
+ | Swift Bridge |
152
+ | swift-bridge/Sources/MFABridge/ |
153
+ | (MFABridge.swift, MetallibCache.swift) |
154
+ | @_cdecl exports: mfa_init, mfa_create_kernel, |
155
+ | mfa_forward, mfa_backward |
156
+ +----------------------------+-----------------------------+
157
+ |
158
+ +----------------------------v-----------------------------+
159
+ | Metal Flash Attention |
160
+ | metal-flash-attention/Sources/FlashAttention/ |
161
+ | (AttentionDescriptor, AttentionKernel, etc.) |
162
+ | |
163
+ | Generates Metal shader source at runtime, |
164
+ | compiles to .metallib, caches pipelines |
165
+ +----------------------------------------------------------+
166
+ ```
167
+
168
+ ## Project Structure
169
+
170
+ ```
171
+ mps-flash-attention/
172
+ ├── mps_flash_attn/ # Python package
173
+ │ ├── __init__.py # Public API (flash_attention, replace_sdpa)
174
+ │ ├── csrc/
175
+ │ │ └── mps_flash_attn.mm # PyTorch C++ extension
176
+ │ └── kernels/ # Pre-compiled metallibs (optional)
177
+
178
+ ├── swift-bridge/ # Swift -> C bridge
179
+ │ ├── Package.swift
180
+ │ └── Sources/MFABridge/
181
+ │ ├── MFABridge.swift # C-callable API (@_cdecl)
182
+ │ └── MetallibCache.swift # Disk caching for metallibs
183
+
184
+ ├── metal-flash-attention/ # Upstream (git submodule)
185
+ │ └── Sources/FlashAttention/
186
+ │ └── Attention/
187
+ │ ├── AttentionDescriptor/ # Problem configuration
188
+ │ ├── AttentionKernel/ # Metal shader generation
189
+ │ └── ...
190
+
191
+ ├── scripts/
192
+ │ └── build_metallibs.py # Pre-compile kernels for distribution
193
+
194
+ └── setup.py # Python package setup
195
+ ```
196
+
197
+ ## Changes from upstream metal-flash-attention
198
+
199
+ We made the following modifications to `metal-flash-attention`:
200
+
201
+ ### 1. macOS 15+ compatibility (MTLLibraryCompiler.swift)
202
+
203
+ Apple restricted `__asm` in runtime-compiled Metal shaders on macOS 15. We added a fallback that uses `xcrun metal` CLI compilation when runtime compilation fails.
204
+
205
+ ### 2. Causal masking support
206
+
207
+ Added `causal` flag to AttentionDescriptor and kernel generation:
208
+
209
+ - `AttentionDescriptor.swift`: Added `causal: Bool` property
210
+ - `AttentionKernelDescriptor.swift`: Added `causal: Bool` property
211
+ - `AttentionKernel.swift`: Added `causal` field
212
+ - `AttentionKernel+Softmax.swift`: Added `maskCausal()` function
213
+ - `AttentionKernel+Source.swift`: Added causal masking to forward/backward loops
214
+
215
+ ## Next Steps
216
+
217
+ ### 1. PR to upstream metal-flash-attention
218
+
219
+ The macOS 15 fix and causal masking should be contributed back:
220
+
221
+ ```bash
222
+ cd metal-flash-attention
223
+ git checkout -b macos15-causal-support
224
+ # Commit changes to:
225
+ # - Sources/FlashAttention/Utilities/MTLLibraryCompiler.swift (new file)
226
+ # - Sources/FlashAttention/Attention/AttentionDescriptor/*.swift
227
+ # - Sources/FlashAttention/Attention/AttentionKernel/*.swift
228
+ git push origin macos15-causal-support
229
+ # Open PR at https://github.com/philipturner/metal-flash-attention
230
+ ```
231
+
232
+ ### 2. Publish mps-flash-attention to PyPI
233
+
234
+ ```bash
235
+ # Add pyproject.toml with proper metadata
236
+ # Build wheel with pre-compiled Swift bridge
237
+ python -m build
238
+ twine upload dist/*
239
+ ```
240
+
241
+ ### 3. Pre-compile kernels for zero cold start
242
+
243
+ ```bash
244
+ python scripts/build_metallibs.py
245
+ # Copies metallibs to mps_flash_attn/kernels/
246
+ # These get shipped with the wheel
247
+ ```
248
+
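Besides the build-time script, the installed package exposes a runtime helper that compiles and caches the same configurations (see `precompile()` in `__init__.py` above):

```python
import mps_flash_attn

# One-off: compile the common seq-len / head-dim / precision combinations
# and cache them to disk so later runs skip runtime compilation.
mps_flash_attn.precompile()
```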
249
+ ## Current Status (Jan 2025)
250
+
251
+ **Working:**
252
+ - Forward pass (fp16/fp32)
253
+ - Backward pass (dQ, dK, dV gradients)
254
+ - Causal masking
255
+ - Metallib disk caching
256
+ - Pipeline binary caching (MTLBinaryArchive)
257
+
258
+ **Known limitations:**
259
+ - Sequence length must be divisible by the block size (typically 64); see the padding sketch below
260
+ - Head dimension: Best with 32, 64, 96, 128
261
+ - No dropout
262
+
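For sequence lengths that are not a multiple of the block size, one workaround is to pad Q/K/V up to the next multiple of 64, mask the padded key positions, and slice the padded rows off the output. A hedged sketch of that idea, relying only on the boolean-mask semantics documented above (`padded_flash_attention` is not an API provided by the package, and the non-causal case is shown for simplicity):

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

def padded_flash_attention(q, k, v, block=64):
    B, H, N, D = q.shape
    pad = (-N) % block
    if pad == 0:
        return flash_attention(q, k, v)
    # Zero-pad the sequence dimension (dim 2) of Q, K and V.
    q_p = F.pad(q, (0, 0, 0, pad))
    k_p = F.pad(k, (0, 0, 0, pad))
    v_p = F.pad(v, (0, 0, 0, pad))
    # True = masked: block attention to the padded key positions.
    mask = torch.zeros(B, 1, N + pad, N + pad, dtype=torch.bool, device=q.device)
    mask[:, :, :, N:] = True
    out = flash_attention(q_p, k_p, v_p, attn_mask=mask)
    return out[:, :, :N, :]  # drop the padded query rows
```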
263
+ ## Credits
264
+
265
+ - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
266
+ - [Flash Attention](https://arxiv.org/abs/2205.14135) paper by Tri Dao et al.
267
+
268
+ ## License
269
+
270
+ MIT
@@ -0,0 +1,31 @@
1
+ mps_flash_attn/_C.cpython-314-darwin.so,sha256=Wy33hq4OLXRvTI_HrBdnml7_MVqhIAC6LfHl2501WcY,266984
2
+ mps_flash_attn/__init__.py,sha256=HCz_Qs65RcxGD3FZ6tirBfxvFo0kk4BiOhFO-3QV_Pg,9980
3
+ mps_flash_attn/csrc/mps_flash_attn.mm,sha256=CSqtWw15FbQ3O50SgoOmL2ZsG2rKnAu8s8rE0MzxWas,22567
4
+ mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib,sha256=_oig6f2I6ZxBCKWbJF3ofmZMySm8gB399_M-lD2NOfM,13747
5
+ mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib,sha256=1fsmVvB5EubhN-y6s5CB-eVk_wuO2tfrabiQTwXvJJc,13171
6
+ mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib,sha256=5WKo_yAU-PgmulBUQhnzvt0DZRteVmo4-nc4U-T6G2g,17507
7
+ mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib,sha256=OehFzjntFdIgEIvE1EW2sGxDEzUgreATvI-fm5-HjaI,12723
8
+ mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib,sha256=FUcHAIglSvbw6aavubMKzBH6PO5Z1TkUhKO4IqIZiLQ,14083
9
+ mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin,sha256=4qmQhKBCDqZEydpRaPhuF0WBCiDXfALSW8SbAE9HUps,36496
10
+ mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin,sha256=NB1ex3aP-SARN_u4I24laoe_H_rbfRBLyDgDnZH4IYY,36496
11
+ mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin,sha256=uZpt_PpesdfmS7ywNrGqFZecNy4q0yXKQIz0rKu-9yg,36496
12
+ mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin,sha256=dj6zpVaJtduxjWXg3J8bTuZo_LfAEutNGc0quZXQqSA,36496
13
+ mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin,sha256=gwGlY6Od0Ub1u3JfRLDkksYFQR-aT3A-igFDYaBDEYA,36496
14
+ mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib,sha256=bfG4lBjhtTXAJEInsCBHfE9CzwUkLAZWduLy-jr6alA,12819
15
+ mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib,sha256=rPg7TwGLETfD5-GJHAHy084Wza0N5BcsifgxKEQ-HkU,12707
16
+ mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib,sha256=2q-7oJOeS8GR66iJJisbfNbIVuY44E41wwh_UgfBlug,17203
17
+ mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib,sha256=QVQMAdhrinNhTJ6vYHgeWWz9-tv4Kp3Z9O4tJvLfjLM,13747
18
+ mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib,sha256=v20uzcsN3HJAxopiFfu90cKJ_zNeUb13OZzm1ubGXaA,13747
19
+ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib,sha256=YhZNagEBf68RxgO8ZqTujfu0YjcUD85dq0Iv5fC8QLE,13171
20
+ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin,sha256=13GAdfAiz-3wWcKhvvMbydtoo9pCMXqUsYujQSv5A_8,34112
21
+ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin,sha256=DOicBen4ow2FSk2Jv2KPiDZ5-6H93-LvwafvOPqfoF0,34112
22
+ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin,sha256=CyRMosbr9mYqxARS3ta_v2cBoYuqnHmEwwQ6aaXsDUY,34112
23
+ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin,sha256=IHdTXUtoGnKKYBXeO-Yj5quXy91lgmfxhrD4Bn_yGzA,34112
24
+ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin,sha256=4BjRprnMycnhZql9829R6FS3HW30jejuDJM9p9vzVPs,34112
25
+ mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib,sha256=qyOaQtRVwL_Wc6GGdu6z-ftf0iX84XexuY09-lNLl5o,13747
26
+ mps_flash_attn/kernels/manifest.json,sha256=d5MkE_BjqDQuMNm1jZiwWkQKfB-yfFml3lLSeR-wCLo,1867
27
+ mps_flash_attn-0.1.7.dist-info/licenses/LICENSE,sha256=F_XmXSab2O-hHcqLpYJWeFaqB6GA_qiTEN23p2VfZWU,1237
28
+ mps_flash_attn-0.1.7.dist-info/METADATA,sha256=eIKQbJo__DR8opFhCTok4KsjNZyJynRXaL_r_riA7M4,8634
29
+ mps_flash_attn-0.1.7.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
30
+ mps_flash_attn-0.1.7.dist-info/top_level.txt,sha256=zbArDcWhJDnJfMUKnOUhs5TjsMgSxa2GzOlscTRfobE,15
31
+ mps_flash_attn-0.1.7.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.10.2)
3
+ Root-Is-Purelib: false
4
+ Tag: cp314-cp314-macosx_15_0_arm64
5
+
@@ -0,0 +1,27 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 imperatormk
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
23
+ ---
24
+
25
+ This project includes code from metal-flash-attention by Philip Turner,
26
+ also licensed under the MIT License:
27
+ https://github.com/philipturner/metal-flash-attention
@@ -0,0 +1 @@
1
+ mps_flash_attn