mps-flash-attn 0.1.0__cp314-cp314-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mps-flash-attn might be problematic.
- mps_flash_attn/_C.cpython-314-darwin.so +0 -0
- mps_flash_attn/__init__.py +246 -0
- mps_flash_attn/csrc/mps_flash_attn.cpp +441 -0
- mps_flash_attn/csrc/mps_flash_attn.mm +441 -0
- mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- mps_flash_attn/kernels/manifest.json +27 -0
- mps_flash_attn/lib/libMFABridge.dylib +0 -0
- mps_flash_attn-0.1.0.dist-info/METADATA +264 -0
- mps_flash_attn-0.1.0.dist-info/RECORD +33 -0
- mps_flash_attn-0.1.0.dist-info/WHEEL +5 -0
- mps_flash_attn-0.1.0.dist-info/licenses/LICENSE +27 -0
- mps_flash_attn-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,441 @@
/**
 * MPS Flash Attention - PyTorch C++ Extension
 *
 * Bridges PyTorch MPS tensors to the MFA Swift library.
 */

#include <torch/extension.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>

#import <Metal/Metal.h>
#import <Foundation/Foundation.h>

#include <dlfcn.h>

// ============================================================================
// MFA Bridge Function Types
// ============================================================================

typedef bool (*mfa_init_fn)();
typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool); // Added causal param
typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
                                int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                int32_t, int32_t);
typedef void (*mfa_release_kernel_fn)(void*);

// Global function pointers
static mfa_init_fn g_mfa_init = nullptr;
static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
static mfa_forward_fn g_mfa_forward = nullptr;
static mfa_backward_fn g_mfa_backward = nullptr;
static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
static void* g_dylib_handle = nullptr;
static bool g_initialized = false;

// ============================================================================
// Load MFA Bridge Library
// ============================================================================

static bool load_mfa_bridge() {
    if (g_dylib_handle) return true;

    // Try to find the dylib relative to this extension
    // First try the standard location
    const char* paths[] = {
        "libMFABridge.dylib",
        "./libMFABridge.dylib",
        "../swift-bridge/.build/release/libMFABridge.dylib",
        nullptr
    };

    for (int i = 0; paths[i] != nullptr; i++) {
        g_dylib_handle = dlopen(paths[i], RTLD_NOW);
        if (g_dylib_handle) break;
    }

    if (!g_dylib_handle) {
        // Try with absolute path from environment
        const char* mfa_path = getenv("MFA_BRIDGE_PATH");
        if (mfa_path) {
            g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
        }
    }

    if (!g_dylib_handle) {
        throw std::runtime_error(
            "Failed to load libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.");
    }

    // Load function pointers
    g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
    g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
    g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
    g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
    g_mfa_release_kernel = (mfa_release_kernel_fn)dlsym(g_dylib_handle, "mfa_release_kernel");

    if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward || !g_mfa_backward || !g_mfa_release_kernel) {
        throw std::runtime_error("Failed to load MFA bridge functions");
    }

    return true;
}

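The loader above only probes two relative locations before falling back to the MFA_BRIDGE_PATH environment variable, and load_mfa_bridge() runs lazily on the first forward/backward call. The wheel also bundles its own copy at mps_flash_attn/lib/libMFABridge.dylib, which the Python side presumably points the loader at; that code is not shown in this diff. A minimal Python sketch of overriding the bridge with a custom build, with a purely illustrative path:

import os

# Hypothetical location of a locally built bridge library; the variable only
# needs to be set before the first attention call, which is when
# load_mfa_bridge() actually runs.
os.environ["MFA_BRIDGE_PATH"] = "/opt/mfa/swift-bridge/.build/release/libMFABridge.dylib"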
// ============================================================================
// Get MTLBuffer from PyTorch MPS Tensor
// ============================================================================

struct BufferInfo {
    id<MTLBuffer> buffer;
    int64_t byte_offset;
};

static BufferInfo getBufferInfo(const at::Tensor& tensor) {
    TORCH_CHECK(tensor.device().is_mps(), "Tensor must be on MPS device");
    TORCH_CHECK(tensor.is_contiguous(), "Tensor must be contiguous");

    // Get the underlying Metal buffer (covers entire storage)
    id<MTLBuffer> buffer = at::native::mps::getMTLBufferStorage(tensor);

    // Calculate byte offset: storage_offset() is in elements, multiply by element size
    int64_t element_size = tensor.element_size();
    int64_t byte_offset = tensor.storage_offset() * element_size;

    return {buffer, byte_offset};
}

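The byte offset matters because getMTLBufferStorage() returns the Metal buffer for the tensor's entire storage, while the tensor itself may begin partway into that storage. The same arithmetic can be checked from Python; this is just an illustration of what getBufferInfo computes, not part of the package:

import torch

base = torch.randn(4, 8, device="mps", dtype=torch.float16)
view = base[2:]                        # contiguous view sharing base's storage
elems = view.storage_offset()          # 16 elements into the storage
nbytes = elems * view.element_size()   # 32 bytes: the offset handed to MFA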
// ============================================================================
// Kernel Cache
// ============================================================================

struct KernelCacheKey {
    int64_t seq_len_q;
    int64_t seq_len_kv;
    int64_t head_dim;
    bool low_precision;
    bool low_precision_outputs;
    bool causal;

    bool operator==(const KernelCacheKey& other) const {
        return seq_len_q == other.seq_len_q &&
               seq_len_kv == other.seq_len_kv &&
               head_dim == other.head_dim &&
               low_precision == other.low_precision &&
               low_precision_outputs == other.low_precision_outputs &&
               causal == other.causal;
    }
};

struct KernelCacheKeyHash {
    size_t operator()(const KernelCacheKey& k) const {
        return std::hash<int64_t>()(k.seq_len_q) ^
               (std::hash<int64_t>()(k.seq_len_kv) << 1) ^
               (std::hash<int64_t>()(k.head_dim) << 2) ^
               (std::hash<bool>()(k.low_precision) << 3) ^
               (std::hash<bool>()(k.low_precision_outputs) << 4) ^
               (std::hash<bool>()(k.causal) << 5);
    }
};

static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;

static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal) {
    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal};

    auto it = g_kernel_cache.find(key);
    if (it != g_kernel_cache.end()) {
        return it->second;
    }

    void* kernel = g_mfa_create_kernel(
        static_cast<int32_t>(seq_q),
        static_cast<int32_t>(seq_kv),
        static_cast<int32_t>(head_dim),
        low_prec,
        low_prec_outputs,
        causal
    );

    if (!kernel) {
        throw std::runtime_error("Failed to create MFA kernel");
    }

    g_kernel_cache[key] = kernel;
    return kernel;
}

// ============================================================================
// Flash Attention Forward
// ============================================================================

std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
    const at::Tensor& query,   // (B, H, N, D)
    const at::Tensor& key,     // (B, H, N, D)
    const at::Tensor& value,   // (B, H, N, D)
    bool is_causal
) {
    // Initialize MFA on first call
    if (!g_initialized) {
        load_mfa_bridge();
        if (!g_mfa_init()) {
            throw std::runtime_error("Failed to initialize MFA");
        }
        g_initialized = true;
    }

    // Validate inputs
    TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
    TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
    TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
    TORCH_CHECK(query.device().is_mps(), "Query must be on MPS device");
    TORCH_CHECK(key.device().is_mps(), "Key must be on MPS device");
    TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");

    const int64_t batch_size = query.size(0);
    const int64_t num_heads = query.size(1);
    const int64_t seq_len_q = query.size(2);
    const int64_t head_dim = query.size(3);
    const int64_t seq_len_kv = key.size(2);

    TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
                "Batch size mismatch");
    TORCH_CHECK(key.size(1) == num_heads && value.size(1) == num_heads,
                "Number of heads mismatch");
    TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
                "Head dimension mismatch");

    // Determine precision
    bool low_precision = (query.scalar_type() == at::kHalf ||
                          query.scalar_type() == at::kBFloat16);

    // For fp16 inputs, we can now output directly to fp16 (no extra conversion needed!)
    bool low_precision_outputs = low_precision;

    // Make inputs contiguous
    auto q = query.contiguous();
    auto k = key.contiguous();
    auto v = value.contiguous();

    // Allocate output in the appropriate precision
    // With lowPrecisionOutputs=true, MFA writes FP16 directly
    at::Tensor output;
    if (low_precision_outputs) {
        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                           query.options().dtype(at::kHalf));
    } else {
        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                           query.options().dtype(at::kFloat));
    }

    // Allocate logsumexp (for backward pass, always fp32)
    auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
                               query.options().dtype(at::kFloat));

    // Get or create kernel with matching output precision and causal mode
    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);

    // Get Metal buffers with byte offsets
    auto q_info = getBufferInfo(q);
    auto k_info = getBufferInfo(k);
    auto v_info = getBufferInfo(v);
    auto o_info = getBufferInfo(output);
    auto l_info = getBufferInfo(logsumexp);

    // Synchronize with PyTorch's MPS stream
    @autoreleasepool {
        // Wait for PyTorch operations to complete
        auto stream = at::mps::getCurrentMPSStream();
        stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);

        // Execute MFA with storage byte offsets
        bool success = g_mfa_forward(
            kernel,
            (__bridge void*)q_info.buffer,
            (__bridge void*)k_info.buffer,
            (__bridge void*)v_info.buffer,
            (__bridge void*)o_info.buffer,
            (__bridge void*)l_info.buffer,
            q_info.byte_offset,
            k_info.byte_offset,
            v_info.byte_offset,
            o_info.byte_offset,
            l_info.byte_offset,
            static_cast<int32_t>(batch_size),
            static_cast<int32_t>(num_heads)
        );

        if (!success) {
            throw std::runtime_error("MFA forward pass failed");
        }
    }

    // Output is already in the correct dtype (fp16 or fp32)
    // Return both output and logsumexp (needed for backward pass)
    return std::make_tuple(output, logsumexp);
}

// Simple forward that only returns output (for inference)
at::Tensor mps_flash_attention_forward(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value,
    bool is_causal
) {
    auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal);
    return output;
}

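These two entry points are exposed to Python as forward_with_lse and forward by the pybind11 bindings at the end of this file, and the compiled module ships as mps_flash_attn/_C.cpython-314-darwin.so. A hedged inference sketch that calls the extension directly, bypassing whatever convenience API the package's __init__.py (not shown in this diff) may add:

import torch
from mps_flash_attn import _C  # compiled extension bundled in this wheel

B, H, N, D = 2, 8, 1024, 64
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)

out = _C.forward(q, k, v, is_causal=True)                  # (B, H, N, D); fp16 in, fp16 out
out_lse, lse = _C.forward_with_lse(q, k, v, is_causal=True)  # lse: (B, H, N), always fp32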
// ============================================================================
// Flash Attention Backward
// ============================================================================

std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
    const at::Tensor& grad_output, // (B, H, N, D)
    const at::Tensor& query,       // (B, H, N, D)
    const at::Tensor& key,         // (B, H, N, D)
    const at::Tensor& value,       // (B, H, N, D)
    const at::Tensor& output,      // (B, H, N, D)
    const at::Tensor& logsumexp,   // (B, H, N)
    bool is_causal
) {
    // Initialize MFA on first call
    if (!g_initialized) {
        load_mfa_bridge();
        if (!g_mfa_init()) {
            throw std::runtime_error("Failed to initialize MFA");
        }
        g_initialized = true;
    }

    // Validate inputs
    TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
    TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
    TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
    TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
    TORCH_CHECK(output.dim() == 4, "Output must be 4D (B, H, N, D)");
    TORCH_CHECK(logsumexp.dim() == 3, "Logsumexp must be 3D (B, H, N)");

    const int64_t batch_size = query.size(0);
    const int64_t num_heads = query.size(1);
    const int64_t seq_len_q = query.size(2);
    const int64_t head_dim = query.size(3);
    const int64_t seq_len_kv = key.size(2);

    // Determine precision
    bool low_precision = (query.scalar_type() == at::kHalf ||
                          query.scalar_type() == at::kBFloat16);
    bool low_precision_outputs = low_precision;

    // Make inputs contiguous
    auto q = query.contiguous();
    auto k = key.contiguous();
    auto v = value.contiguous();
    auto o = output.contiguous();
    auto dO = grad_output.contiguous();
    auto lse = logsumexp.contiguous();

    // Get or create kernel (with causal mode)
    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);

    // Allocate D buffer (dO * O reduction, always fp32)
    auto D = at::empty({batch_size, num_heads, seq_len_q},
                       query.options().dtype(at::kFloat));

    // Allocate gradients (always fp32 for numerical stability)
    auto dQ = at::zeros({batch_size, num_heads, seq_len_q, head_dim},
                        query.options().dtype(at::kFloat));
    auto dK = at::zeros({batch_size, num_heads, seq_len_kv, head_dim},
                        query.options().dtype(at::kFloat));
    auto dV = at::zeros({batch_size, num_heads, seq_len_kv, head_dim},
                        query.options().dtype(at::kFloat));

    // Get Metal buffers with byte offsets
    auto q_info = getBufferInfo(q);
    auto k_info = getBufferInfo(k);
    auto v_info = getBufferInfo(v);
    auto o_info = getBufferInfo(o);
    auto do_info = getBufferInfo(dO);
    auto l_info = getBufferInfo(lse);
    auto d_info = getBufferInfo(D);
    auto dq_info = getBufferInfo(dQ);
    auto dk_info = getBufferInfo(dK);
    auto dv_info = getBufferInfo(dV);

    // Execute backward pass
    @autoreleasepool {
        auto stream = at::mps::getCurrentMPSStream();
        stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);

        bool success = g_mfa_backward(
            kernel,
            (__bridge void*)q_info.buffer,
            (__bridge void*)k_info.buffer,
            (__bridge void*)v_info.buffer,
            (__bridge void*)o_info.buffer,
            (__bridge void*)do_info.buffer,
            (__bridge void*)l_info.buffer,
            (__bridge void*)d_info.buffer,
            (__bridge void*)dq_info.buffer,
            (__bridge void*)dk_info.buffer,
            (__bridge void*)dv_info.buffer,
            q_info.byte_offset,
            k_info.byte_offset,
            v_info.byte_offset,
            o_info.byte_offset,
            do_info.byte_offset,
            l_info.byte_offset,
            d_info.byte_offset,
            dq_info.byte_offset,
            dk_info.byte_offset,
            dv_info.byte_offset,
            static_cast<int32_t>(batch_size),
            static_cast<int32_t>(num_heads)
        );

        if (!success) {
            throw std::runtime_error("MFA backward pass failed");
        }
    }

    // Convert gradients back to input dtype if needed
    if (low_precision) {
        dQ = dQ.to(query.scalar_type());
        dK = dK.to(query.scalar_type());
        dV = dV.to(query.scalar_type());
    }

    return std::make_tuple(dQ, dK, dV);
}

// ============================================================================
// Python Bindings
// ============================================================================

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "MPS Flash Attention - Metal accelerated attention for Apple Silicon";

    m.def("forward", &mps_flash_attention_forward,
          "Flash Attention forward pass (returns output only)",
          py::arg("query"),
          py::arg("key"),
          py::arg("value"),
          py::arg("is_causal") = false);

    m.def("forward_with_lse", &mps_flash_attention_forward_with_lse,
          "Flash Attention forward pass (returns output and logsumexp for backward)",
          py::arg("query"),
          py::arg("key"),
          py::arg("value"),
          py::arg("is_causal") = false);

    m.def("backward", &mps_flash_attention_backward,
          "Flash Attention backward pass",
          py::arg("grad_output"),
          py::arg("query"),
          py::arg("key"),
          py::arg("value"),
          py::arg("output"),
          py::arg("logsumexp"),
          py::arg("is_causal") = false);
}
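The bindings expose a forward/backward pair but no autograd glue; presumably the package's __init__.py (246 lines, not shown in this diff) provides that. The sketch below is an illustrative torch.autograd.Function wiring forward_with_lse and backward together, not a copy of the packaged wrapper; the flash_attention helper name is hypothetical:

import torch
from mps_flash_attn import _C

class _MPSFlashAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, is_causal=False):
        # Forward pass returns the attention output plus the logsumexp
        # needed by the backward kernel.
        out, lse = _C.forward_with_lse(q, k, v, is_causal)
        ctx.save_for_backward(q, k, v, out, lse)
        ctx.is_causal = is_causal
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, out, lse = ctx.saved_tensors
        dq, dk, dv = _C.backward(grad_out, q, k, v, out, lse, ctx.is_causal)
        return dq, dk, dv, None  # no gradient for the is_causal flag

def flash_attention(q, k, v, is_causal=False):
    # q, k, v: (B, H, N, D) tensors on the MPS device, fp16/bf16/fp32
    return _MPSFlashAttention.apply(q, k, v, is_causal)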
mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib ADDED
Binary file
mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib ADDED
Binary file
mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib ADDED
Binary file
mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib ADDED
Binary file
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib ADDED
Binary file
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin ADDED
Binary file
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin ADDED
Binary file
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin ADDED
Binary file
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin ADDED
Binary file
mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin ADDED
Binary file
mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib ADDED
Binary file
mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib ADDED
Binary file
mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib ADDED
Binary file
mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib ADDED
Binary file
mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib ADDED
Binary file
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib ADDED
Binary file
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin ADDED
Binary file
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin ADDED
Binary file
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin ADDED
Binary file
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin ADDED
Binary file
mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin ADDED
Binary file
mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib ADDED
Binary file
mps_flash_attn/kernels/manifest.json ADDED
@@ -0,0 +1,27 @@
{
  "version": "1.0",
  "files": [
    "06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib",
    "adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin",
    "ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin",
    "a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib",
    "975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib",
    "2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin",
    "09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin",
    "0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin",
    "73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin",
    "771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin",
    "eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib",
    "f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib"
  ]
}
mps_flash_attn/lib/libMFABridge.dylib ADDED
Binary file