mps-flash-attn 0.1.0 (cp314-cp314-macosx_15_0_arm64.whl)

This diff shows the content of a publicly released package version as it appears in its public registry, and is provided for informational purposes only.

Potentially problematic release: this version of mps-flash-attn might be problematic.

Files changed (33)
  1. mps_flash_attn/_C.cpython-314-darwin.so +0 -0
  2. mps_flash_attn/__init__.py +246 -0
  3. mps_flash_attn/csrc/mps_flash_attn.cpp +441 -0
  4. mps_flash_attn/csrc/mps_flash_attn.mm +441 -0
  5. mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  6. mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  7. mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  8. mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  9. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  10. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  11. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  12. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  13. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  14. mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  15. mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  16. mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  17. mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  18. mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  19. mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  20. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  21. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  22. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  23. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  24. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  25. mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  26. mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  27. mps_flash_attn/kernels/manifest.json +27 -0
  28. mps_flash_attn/lib/libMFABridge.dylib +0 -0
  29. mps_flash_attn-0.1.0.dist-info/METADATA +264 -0
  30. mps_flash_attn-0.1.0.dist-info/RECORD +33 -0
  31. mps_flash_attn-0.1.0.dist-info/WHEEL +5 -0
  32. mps_flash_attn-0.1.0.dist-info/licenses/LICENSE +27 -0
  33. mps_flash_attn-0.1.0.dist-info/top_level.txt +1 -0
Binary files (the compiled extension, the .metallib kernel libraries, the .bin pipeline binaries, and libMFABridge.dylib) have no textual diff and are not shown.
mps_flash_attn/__init__.py
@@ -0,0 +1,246 @@
+"""
+MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
+
+This package provides memory-efficient attention using Metal Flash Attention kernels.
+"""
+
+import torch
+from typing import Optional
+import math
+import threading
+import os
+
+# Try to import the C++ extension
+try:
+    from . import _C
+    _HAS_MFA = True
+except ImportError as e:
+    _HAS_MFA = False
+    _IMPORT_ERROR = str(e)
+
+# Set up shipped kernels directory for zero-compilation loading
+def _init_shipped_kernels():
+    """Point the Swift bridge to pre-shipped kernel binaries."""
+    try:
+        import ctypes
+        bridge_path = os.environ.get("MFA_BRIDGE_PATH")
+        if not bridge_path:
+            module_dir = os.path.dirname(__file__)
+            candidates = [
+                os.path.join(module_dir, "lib", "libMFABridge.dylib"),  # Bundled in wheel
+                os.path.join(module_dir, "..", "swift-bridge", ".build", "release", "libMFABridge.dylib"),
+                os.path.join(module_dir, "libMFABridge.dylib"),
+            ]
+            for path in candidates:
+                if os.path.exists(path):
+                    bridge_path = path
+                    break
+
+        if bridge_path and os.path.exists(bridge_path):
+            lib = ctypes.CDLL(bridge_path)
+
+            # Set shipped kernels directory (pre-compiled metallibs + pipeline binaries)
+            kernels_dir = os.path.join(os.path.dirname(__file__), "kernels")
+            if os.path.exists(kernels_dir):
+                lib.mfa_set_kernels_dir(kernels_dir.encode('utf-8'))
+
+            lib.mfa_init()
+    except Exception:
+        pass  # Init is optional, will fall back to runtime compilation
+
+# Initialize shipped kernels on import
+if _HAS_MFA:
+    _init_shipped_kernels()
+
+
+def is_available() -> bool:
+    """Check if MPS Flash Attention is available."""
+    return _HAS_MFA and torch.backends.mps.is_available()
+
+
+class FlashAttentionFunction(torch.autograd.Function):
+    """Autograd function for Flash Attention with backward pass support."""
+
+    @staticmethod
+    def forward(ctx, query, key, value, is_causal, scale):
+        # Apply scale if provided (MFA uses 1/sqrt(D) internally)
+        scale_factor = 1.0
+        if scale is not None:
+            default_scale = 1.0 / math.sqrt(query.shape[-1])
+            if abs(scale - default_scale) > 1e-6:
+                scale_factor = scale / default_scale
+                query = query * scale_factor
+
+        # Forward with logsumexp for backward
+        output, logsumexp = _C.forward_with_lse(query, key, value, is_causal)
+
+        # Save for backward
+        ctx.save_for_backward(query, key, value, output, logsumexp)
+        ctx.is_causal = is_causal
+        ctx.scale_factor = scale_factor
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        query, key, value, output, logsumexp = ctx.saved_tensors
+
+        # Compute gradients
+        dQ, dK, dV = _C.backward(
+            grad_output, query, key, value, output, logsumexp, ctx.is_causal
+        )
+
+        # If we scaled the query in forward, scale the gradient back
+        if ctx.scale_factor != 1.0:
+            dQ = dQ * ctx.scale_factor
+
+        # Return gradients (None for is_causal and scale since they're not tensors)
+        return dQ, dK, dV, None, None
+
+
+def flash_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Compute scaled dot-product attention using Flash Attention on MPS.
+
+    This function provides O(N) memory complexity instead of O(N²) by using
+    tiled computation, allowing much longer sequences on limited GPU memory.
+
+    Supports both forward and backward passes for training.
+
+    Args:
+        query: Query tensor of shape (B, num_heads, seq_len, head_dim)
+        key: Key tensor of shape (B, num_heads, seq_len, head_dim)
+        value: Value tensor of shape (B, num_heads, seq_len, head_dim)
+        is_causal: If True, applies causal masking (for autoregressive models)
+        scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
+
+    Returns:
+        Output tensor of shape (B, num_heads, seq_len, head_dim)
+
+    Example:
+        >>> import torch
+        >>> from mps_flash_attn import flash_attention
+        >>> q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+        >>> k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+        >>> v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+        >>> out = flash_attention(q, k, v)
+
+        # With gradients:
+        >>> q.requires_grad = True
+        >>> out = flash_attention(q, k, v)
+        >>> out.sum().backward()  # Computes dQ
+    """
+    if not _HAS_MFA:
+        raise RuntimeError(
+            f"MPS Flash Attention C++ extension not available: {_IMPORT_ERROR}\n"
+            "Please rebuild with: pip install -e ."
+        )
+
+    if not torch.backends.mps.is_available():
+        raise RuntimeError("MPS not available")
+
+    # Validate device
+    if query.device.type != 'mps':
+        raise ValueError("query must be on MPS device")
+    if key.device.type != 'mps':
+        raise ValueError("key must be on MPS device")
+    if value.device.type != 'mps':
+        raise ValueError("value must be on MPS device")
+
+    # Use autograd function for gradient support
+    return FlashAttentionFunction.apply(query, key, value, is_causal, scale)
+
+
+def replace_sdpa():
+    """
+    Monkey-patch torch.nn.functional.scaled_dot_product_attention to use
+    Flash Attention on MPS devices.
+
+    Call this at the start of your script to automatically use Flash Attention
+    for all attention operations.
+    """
+    import torch.nn.functional as F
+
+    original_sdpa = F.scaled_dot_product_attention
+
+    def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
+                     is_causal=False, scale=None):
+        # Use MFA for MPS tensors without mask/dropout
+        if (query.device.type == 'mps' and
+                attn_mask is None and
+                dropout_p == 0.0 and
+                _HAS_MFA):
+            try:
+                return flash_attention(query, key, value, is_causal=is_causal, scale=scale)
+            except Exception:
+                # Fall back to original on any error
+                pass
+
+        return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
+
+    F.scaled_dot_product_attention = patched_sdpa
+    print("MPS Flash Attention: Patched F.scaled_dot_product_attention")
+
+
+def precompile():
+    """
+    Pre-compile Metal kernels for common configurations.
+
+    Call this once after installation to eliminate runtime compilation overhead.
+    Pre-compiled kernels are cached to disk and loaded instantly on subsequent runs.
+
+    This compiles kernels for:
+    - Sequence lengths: 64, 128, 256, 512, 1024, 2048, 4096, 8192
+    - Head dimensions: 32, 48, 64, 80, 96, 128
+    - Both fp32 and fp16 precision
+
+    Total: 96 kernel configurations
+    """
+    if not _HAS_MFA:
+        raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
+
+    import ctypes
+    import os
+
+    # Load the Swift bridge directly
+    bridge_path = os.environ.get("MFA_BRIDGE_PATH")
+    if not bridge_path:
+        # Try common locations
+        module_dir = os.path.dirname(__file__)
+        candidates = [
+            os.path.join(module_dir, "lib", "libMFABridge.dylib"),  # Bundled in wheel
+            os.path.join(module_dir, "..", "swift-bridge", ".build", "release", "libMFABridge.dylib"),
+            os.path.join(module_dir, "libMFABridge.dylib"),
+        ]
+        for path in candidates:
+            if os.path.exists(path):
+                bridge_path = path
+                break
+
+    if not bridge_path or not os.path.exists(bridge_path):
+        raise RuntimeError("Cannot find libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.")
+
+    lib = ctypes.CDLL(bridge_path)
+    lib.mfa_precompile()
+    print("\nPre-compilation complete! Kernels cached to disk.")
+
+
+def clear_cache():
+    """Clear the pre-compiled kernel cache."""
+    if not _HAS_MFA:
+        raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
+
+    import ctypes
+    import os
+
+    bridge_path = os.environ.get("MFA_BRIDGE_PATH")
+    if bridge_path and os.path.exists(bridge_path):
+        lib = ctypes.CDLL(bridge_path)
+        lib.mfa_clear_cache()
+        print("Cache cleared.")
mps_flash_attn/csrc/mps_flash_attn.cpp (a .mm copy with the same +441 line count is also shipped)
@@ -0,0 +1,441 @@
+/**
+ * MPS Flash Attention - PyTorch C++ Extension
+ *
+ * Bridges PyTorch MPS tensors to the MFA Swift library.
+ */
+
+#include <torch/extension.h>
+#include <ATen/mps/MPSStream.h>
+#include <ATen/native/mps/OperationUtils.h>
+
+#import <Metal/Metal.h>
+#import <Foundation/Foundation.h>
+
+#include <dlfcn.h>
+
+// ============================================================================
+// MFA Bridge Function Types
+// ============================================================================
+
+typedef bool (*mfa_init_fn)();
+typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool);  // Added causal param
+typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
+typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                int32_t, int32_t);
+typedef void (*mfa_release_kernel_fn)(void*);
+
+// Global function pointers
+static mfa_init_fn g_mfa_init = nullptr;
+static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
+static mfa_forward_fn g_mfa_forward = nullptr;
+static mfa_backward_fn g_mfa_backward = nullptr;
+static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
+static void* g_dylib_handle = nullptr;
+static bool g_initialized = false;
+
+// ============================================================================
+// Load MFA Bridge Library
+// ============================================================================
+
+static bool load_mfa_bridge() {
+    if (g_dylib_handle) return true;
+
+    // Try to find the dylib relative to this extension
+    // First try the standard location
+    const char* paths[] = {
+        "libMFABridge.dylib",
+        "./libMFABridge.dylib",
+        "../swift-bridge/.build/release/libMFABridge.dylib",
+        nullptr
+    };
+
+    for (int i = 0; paths[i] != nullptr; i++) {
+        g_dylib_handle = dlopen(paths[i], RTLD_NOW);
+        if (g_dylib_handle) break;
+    }
+
+    if (!g_dylib_handle) {
+        // Try with absolute path from environment
+        const char* mfa_path = getenv("MFA_BRIDGE_PATH");
+        if (mfa_path) {
+            g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
+        }
+    }
+
+    if (!g_dylib_handle) {
+        throw std::runtime_error(
+            "Failed to load libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.");
+    }
+
+    // Load function pointers
+    g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
+    g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
+    g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
+    g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
+    g_mfa_release_kernel = (mfa_release_kernel_fn)dlsym(g_dylib_handle, "mfa_release_kernel");
+
+    if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward || !g_mfa_backward || !g_mfa_release_kernel) {
+        throw std::runtime_error("Failed to load MFA bridge functions");
+    }
+
+    return true;
+}
+
+// ============================================================================
+// Get MTLBuffer from PyTorch MPS Tensor
+// ============================================================================
+
+struct BufferInfo {
+    id<MTLBuffer> buffer;
+    int64_t byte_offset;
+};
+
+static BufferInfo getBufferInfo(const at::Tensor& tensor) {
+    TORCH_CHECK(tensor.device().is_mps(), "Tensor must be on MPS device");
+    TORCH_CHECK(tensor.is_contiguous(), "Tensor must be contiguous");
+
+    // Get the underlying Metal buffer (covers entire storage)
+    id<MTLBuffer> buffer = at::native::mps::getMTLBufferStorage(tensor);
+
+    // Calculate byte offset: storage_offset() is in elements, multiply by element size
+    int64_t element_size = tensor.element_size();
+    int64_t byte_offset = tensor.storage_offset() * element_size;
+
+    return {buffer, byte_offset};
+}
+
+// ============================================================================
+// Kernel Cache
+// ============================================================================
+
+struct KernelCacheKey {
+    int64_t seq_len_q;
+    int64_t seq_len_kv;
+    int64_t head_dim;
+    bool low_precision;
+    bool low_precision_outputs;
+    bool causal;
+
+    bool operator==(const KernelCacheKey& other) const {
+        return seq_len_q == other.seq_len_q &&
+               seq_len_kv == other.seq_len_kv &&
+               head_dim == other.head_dim &&
+               low_precision == other.low_precision &&
+               low_precision_outputs == other.low_precision_outputs &&
+               causal == other.causal;
+    }
+};
+
+struct KernelCacheKeyHash {
+    size_t operator()(const KernelCacheKey& k) const {
+        return std::hash<int64_t>()(k.seq_len_q) ^
+               (std::hash<int64_t>()(k.seq_len_kv) << 1) ^
+               (std::hash<int64_t>()(k.head_dim) << 2) ^
+               (std::hash<bool>()(k.low_precision) << 3) ^
+               (std::hash<bool>()(k.low_precision_outputs) << 4) ^
+               (std::hash<bool>()(k.causal) << 5);
+    }
+};
+
+static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;
+
+static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal) {
+    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal};
+
+    auto it = g_kernel_cache.find(key);
+    if (it != g_kernel_cache.end()) {
+        return it->second;
+    }
+
+    void* kernel = g_mfa_create_kernel(
+        static_cast<int32_t>(seq_q),
+        static_cast<int32_t>(seq_kv),
+        static_cast<int32_t>(head_dim),
+        low_prec,
+        low_prec_outputs,
+        causal
+    );
+
+    if (!kernel) {
+        throw std::runtime_error("Failed to create MFA kernel");
+    }
+
+    g_kernel_cache[key] = kernel;
+    return kernel;
+}
+
+// ============================================================================
+// Flash Attention Forward
+// ============================================================================
+
+std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
+    const at::Tensor& query,   // (B, H, N, D)
+    const at::Tensor& key,     // (B, H, N, D)
+    const at::Tensor& value,   // (B, H, N, D)
+    bool is_causal
+) {
+    // Initialize MFA on first call
+    if (!g_initialized) {
+        load_mfa_bridge();
+        if (!g_mfa_init()) {
+            throw std::runtime_error("Failed to initialize MFA");
+        }
+        g_initialized = true;
+    }
+
+    // Validate inputs
+    TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
+    TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
+    TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
+    TORCH_CHECK(query.device().is_mps(), "Query must be on MPS device");
+    TORCH_CHECK(key.device().is_mps(), "Key must be on MPS device");
+    TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");
+
+    const int64_t batch_size = query.size(0);
+    const int64_t num_heads = query.size(1);
+    const int64_t seq_len_q = query.size(2);
+    const int64_t head_dim = query.size(3);
+    const int64_t seq_len_kv = key.size(2);
+
+    TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
+                "Batch size mismatch");
+    TORCH_CHECK(key.size(1) == num_heads && value.size(1) == num_heads,
+                "Number of heads mismatch");
+    TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
+                "Head dimension mismatch");
+
+    // Determine precision
+    bool low_precision = (query.scalar_type() == at::kHalf ||
+                          query.scalar_type() == at::kBFloat16);
+
+    // For fp16 inputs, we can now output directly to fp16 (no extra conversion needed!)
+    bool low_precision_outputs = low_precision;
+
+    // Make inputs contiguous
+    auto q = query.contiguous();
+    auto k = key.contiguous();
+    auto v = value.contiguous();
+
+    // Allocate output in the appropriate precision
+    // With lowPrecisionOutputs=true, MFA writes FP16 directly
+    at::Tensor output;
+    if (low_precision_outputs) {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kHalf));
+    } else {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kFloat));
+    }
+
+    // Allocate logsumexp (for backward pass, always fp32)
+    auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
+                               query.options().dtype(at::kFloat));
+
+    // Get or create kernel with matching output precision and causal mode
+    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+
+    // Get Metal buffers with byte offsets
+    auto q_info = getBufferInfo(q);
+    auto k_info = getBufferInfo(k);
+    auto v_info = getBufferInfo(v);
+    auto o_info = getBufferInfo(output);
+    auto l_info = getBufferInfo(logsumexp);
+
+    // Synchronize with PyTorch's MPS stream
+    @autoreleasepool {
+        // Wait for PyTorch operations to complete
+        auto stream = at::mps::getCurrentMPSStream();
+        stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
+
+        // Execute MFA with storage byte offsets
+        bool success = g_mfa_forward(
+            kernel,
+            (__bridge void*)q_info.buffer,
+            (__bridge void*)k_info.buffer,
+            (__bridge void*)v_info.buffer,
+            (__bridge void*)o_info.buffer,
+            (__bridge void*)l_info.buffer,
+            q_info.byte_offset,
+            k_info.byte_offset,
+            v_info.byte_offset,
+            o_info.byte_offset,
+            l_info.byte_offset,
+            static_cast<int32_t>(batch_size),
+            static_cast<int32_t>(num_heads)
+        );
+
+        if (!success) {
+            throw std::runtime_error("MFA forward pass failed");
+        }
+    }
+
+    // Output is already in the correct dtype (fp16 or fp32)
+    // Return both output and logsumexp (needed for backward pass)
+    return std::make_tuple(output, logsumexp);
+}
+
+// Simple forward that only returns output (for inference)
+at::Tensor mps_flash_attention_forward(
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    bool is_causal
+) {
+    auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal);
+    return output;
+}
+
+// ============================================================================
+// Flash Attention Backward
+// ============================================================================
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
+    const at::Tensor& grad_output,  // (B, H, N, D)
+    const at::Tensor& query,        // (B, H, N, D)
+    const at::Tensor& key,          // (B, H, N, D)
+    const at::Tensor& value,        // (B, H, N, D)
+    const at::Tensor& output,       // (B, H, N, D)
+    const at::Tensor& logsumexp,    // (B, H, N)
+    bool is_causal
+) {
+    // Initialize MFA on first call
+    if (!g_initialized) {
+        load_mfa_bridge();
+        if (!g_mfa_init()) {
+            throw std::runtime_error("Failed to initialize MFA");
+        }
+        g_initialized = true;
+    }
+
+    // Validate inputs
+    TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
+    TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
+    TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
+    TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
+    TORCH_CHECK(output.dim() == 4, "Output must be 4D (B, H, N, D)");
+    TORCH_CHECK(logsumexp.dim() == 3, "Logsumexp must be 3D (B, H, N)");
+
+    const int64_t batch_size = query.size(0);
+    const int64_t num_heads = query.size(1);
+    const int64_t seq_len_q = query.size(2);
+    const int64_t head_dim = query.size(3);
+    const int64_t seq_len_kv = key.size(2);
+
+    // Determine precision
+    bool low_precision = (query.scalar_type() == at::kHalf ||
+                          query.scalar_type() == at::kBFloat16);
+    bool low_precision_outputs = low_precision;
+
+    // Make inputs contiguous
+    auto q = query.contiguous();
+    auto k = key.contiguous();
+    auto v = value.contiguous();
+    auto o = output.contiguous();
+    auto dO = grad_output.contiguous();
+    auto lse = logsumexp.contiguous();
+
+    // Get or create kernel (with causal mode)
+    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+
+    // Allocate D buffer (dO * O reduction, always fp32)
+    auto D = at::empty({batch_size, num_heads, seq_len_q},
+                       query.options().dtype(at::kFloat));
+
+    // Allocate gradients (always fp32 for numerical stability)
+    auto dQ = at::zeros({batch_size, num_heads, seq_len_q, head_dim},
+                        query.options().dtype(at::kFloat));
+    auto dK = at::zeros({batch_size, num_heads, seq_len_kv, head_dim},
+                        query.options().dtype(at::kFloat));
+    auto dV = at::zeros({batch_size, num_heads, seq_len_kv, head_dim},
+                        query.options().dtype(at::kFloat));
+
+    // Get Metal buffers with byte offsets
+    auto q_info = getBufferInfo(q);
+    auto k_info = getBufferInfo(k);
+    auto v_info = getBufferInfo(v);
+    auto o_info = getBufferInfo(o);
+    auto do_info = getBufferInfo(dO);
+    auto l_info = getBufferInfo(lse);
+    auto d_info = getBufferInfo(D);
+    auto dq_info = getBufferInfo(dQ);
+    auto dk_info = getBufferInfo(dK);
+    auto dv_info = getBufferInfo(dV);
+
+    // Execute backward pass
+    @autoreleasepool {
+        auto stream = at::mps::getCurrentMPSStream();
+        stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
+
+        bool success = g_mfa_backward(
+            kernel,
+            (__bridge void*)q_info.buffer,
+            (__bridge void*)k_info.buffer,
+            (__bridge void*)v_info.buffer,
+            (__bridge void*)o_info.buffer,
+            (__bridge void*)do_info.buffer,
+            (__bridge void*)l_info.buffer,
+            (__bridge void*)d_info.buffer,
+            (__bridge void*)dq_info.buffer,
+            (__bridge void*)dk_info.buffer,
+            (__bridge void*)dv_info.buffer,
+            q_info.byte_offset,
+            k_info.byte_offset,
+            v_info.byte_offset,
+            o_info.byte_offset,
+            do_info.byte_offset,
+            l_info.byte_offset,
+            d_info.byte_offset,
+            dq_info.byte_offset,
+            dk_info.byte_offset,
+            dv_info.byte_offset,
+            static_cast<int32_t>(batch_size),
+            static_cast<int32_t>(num_heads)
+        );
+
+        if (!success) {
+            throw std::runtime_error("MFA backward pass failed");
+        }
+    }
+
+    // Convert gradients back to input dtype if needed
+    if (low_precision) {
+        dQ = dQ.to(query.scalar_type());
+        dK = dK.to(query.scalar_type());
+        dV = dV.to(query.scalar_type());
+    }
+
+    return std::make_tuple(dQ, dK, dV);
+}
+
+// ============================================================================
+// Python Bindings
+// ============================================================================
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.doc() = "MPS Flash Attention - Metal accelerated attention for Apple Silicon";
+
+    m.def("forward", &mps_flash_attention_forward,
+          "Flash Attention forward pass (returns output only)",
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("is_causal") = false);
+
+    m.def("forward_with_lse", &mps_flash_attention_forward_with_lse,
+          "Flash Attention forward pass (returns output and logsumexp for backward)",
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("is_causal") = false);
+
+    m.def("backward", &mps_flash_attention_backward,
+          "Flash Attention backward pass",
+          py::arg("grad_output"),
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("output"),
+          py::arg("logsumexp"),
+          py::arg("is_causal") = false);
+}
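
The bindings above expose forward, forward_with_lse, and backward to Python. A short sanity-check sketch, not shipped with the package, that compares the extension's forward path against PyTorch's built-in scaled_dot_product_attention on a small fp32 problem; the printed difference is only illustrative, and actual numerical agreement depends on the Metal kernels.

import torch
import torch.nn.functional as F

from mps_flash_attn import flash_attention

if torch.backends.mps.is_available():
    # Small (B, H, N, D) problem in fp32 so both paths return fp32 outputs.
    q = torch.randn(2, 4, 256, 64, device="mps")
    k = torch.randn(2, 4, 256, 64, device="mps")
    v = torch.randn(2, 4, 256, 64, device="mps")

    # MFA path (calls _C.forward_with_lse under the hood) vs. PyTorch's reference SDPA.
    out_mfa = flash_attention(q, k, v, is_causal=False)
    out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)

    print("max abs difference:", (out_mfa - out_ref).abs().max().item())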