mps-flash-attn 0.1.8.tar.gz → 0.1.14.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/PKG-INFO +1 -1
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/__init__.py +30 -4
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/csrc/mps_flash_attn.mm +69 -25
- mps_flash_attn-0.1.14/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/pyproject.toml +1 -1
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/setup.py +21 -0
- mps_flash_attn-0.1.8/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/LICENSE +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/README.md +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/setup.cfg +0 -0
- {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/tests/test_attention.py +0 -0
{mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/__init__.py

```diff
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.1.8"
+__version__ = "0.1.14"
 
 import torch
 from typing import Optional
```
```diff
@@ -30,6 +30,30 @@ def is_available() -> bool:
     return _HAS_MFA and torch.backends.mps.is_available()
 
 
+def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+    """
+    Convert attention mask to MFA's boolean format.
+
+    MFA uses boolean masks where True = masked (don't attend).
+    PyTorch SDPA uses additive float masks where -inf/large negative = masked.
+
+    Args:
+        attn_mask: Optional mask, either:
+            - None: no mask
+            - bool tensor: already in MFA format (True = masked)
+            - float tensor: additive mask (large negative = masked)
+
+    Returns:
+        Boolean mask suitable for flash_attention(), or None
+    """
+    if attn_mask is None:
+        return None
+    if attn_mask.dtype == torch.bool:
+        return attn_mask
+    # Float mask: large negative values indicate masked positions
+    return attn_mask <= -1e3
+
+
 class FlashAttentionFunction(torch.autograd.Function):
     """Autograd function for Flash Attention with backward pass support."""
 
```
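A quick sketch of how the new `convert_mask` helper behaves, based only on the code in the hunk above (it requires the package to be importable; nothing else is assumed):

```python
import torch
from mps_flash_attn import convert_mask

# An additive SDPA-style mask becomes the boolean (True = masked) layout MFA expects.
additive = torch.zeros(1, 1, 4, 4)
additive[..., 2:] = float("-inf")          # block attention to the last two keys

bool_mask = convert_mask(additive)         # values <= -1e3 (incl. -inf) -> True
assert bool_mask.dtype == torch.bool
assert bool_mask[..., 2:].all() and not bool_mask[..., :2].any()

assert convert_mask(None) is None          # passthrough cases
assert convert_mask(bool_mask) is bool_mask
```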
```diff
@@ -176,12 +200,13 @@ def replace_sdpa():
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None):
         # Use MFA for MPS tensors without dropout
-        # Only use MFA for seq_len >=
+        # Only use MFA for seq_len >= 1536 where it outperforms PyTorch's math backend
         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
+        # Benchmark (BF16, heads=30, head_dim=128): crossover is ~1200-1500
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.shape[2] >=
+            query.shape[2] >= 1536):
             try:
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
@@ -194,7 +219,8 @@ def replace_sdpa():
             else:
                 # Float mask: typically -inf for masked positions, 0 for unmasked
                 # Convert: positions with large negative values -> True (masked)
-
+                # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
+                mfa_mask = attn_mask <= -1e3
             return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
         except Exception:
             # Fall back to original on any error
```
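A usage sketch of the patched dispatch, assuming `replace_sdpa()` and `is_available()` are meant to be called directly (their definitions in `__init__.py` suggest so); after patching, SDPA calls on MPS with no dropout and `seq_len >= 1536` are routed to the MFA kernel, while shorter sequences keep PyTorch's default path:

```python
import torch
import torch.nn.functional as F
import mps_flash_attn

if mps_flash_attn.is_available():
    mps_flash_attn.replace_sdpa()

# Requires an Apple Silicon machine with the MPS backend available.
q, k, v = (torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)
           for _ in range(3))

# seq_len (dim 2) is 2048 >= 1536 and dropout_p defaults to 0.0 -> MFA path
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```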
{mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/csrc/mps_flash_attn.mm

```diff
@@ -21,6 +21,7 @@
 
 typedef bool (*mfa_init_fn)();
 typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool);
+typedef void* (*mfa_create_kernel_v2_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool);
 // New zero-sync encode functions that take PyTorch's command encoder
 // Added mask_ptr and mask_offset parameters
 typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
@@ -41,6 +42,7 @@ typedef void (*mfa_release_kernel_fn)(void*);
 // Global function pointers
 static mfa_init_fn g_mfa_init = nullptr;
 static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
+static mfa_create_kernel_v2_fn g_mfa_create_kernel_v2 = nullptr;
 static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
 static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
 static mfa_forward_fn g_mfa_forward = nullptr;
@@ -99,6 +101,7 @@ static bool load_mfa_bridge() {
     // Load function pointers
     g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
     g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
+    g_mfa_create_kernel_v2 = (mfa_create_kernel_v2_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v2");
     g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
     g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
     g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
```
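The bridge is loaded with `dlopen`/`dlsym`, and the new `mfa_create_kernel_v2` symbol is looked up optionally so an older dylib still works. A rough Python analogue of that optional-symbol pattern, purely illustrative (the real loader is the Objective-C++ above, and the dylib path here is an assumption):

```python
import ctypes

# Hypothetical path; the installed package ships the dylib under mps_flash_attn/lib/.
lib = ctypes.CDLL("libMFABridge.dylib")

# Optional symbol: only newer bridges export it. getattr() with a default returns
# None instead of raising when the symbol is missing, mirroring the dlsym check.
create_kernel_v2 = getattr(lib, "mfa_create_kernel_v2", None)
if create_kernel_v2 is not None:
    create_kernel_v2.restype = ctypes.c_void_p
    # (seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16)
    create_kernel_v2.argtypes = [ctypes.c_int32] * 3 + [ctypes.c_bool] * 5
else:
    # Fall back to the original 7-argument mfa_create_kernel, as the C++ code does.
    pass
```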
```diff
@@ -148,6 +151,7 @@ struct KernelCacheKey {
     bool low_precision_outputs;
     bool causal;
     bool has_mask;
+    bool use_bf16;
 
     bool operator==(const KernelCacheKey& other) const {
         return seq_len_q == other.seq_len_q &&
@@ -156,7 +160,8 @@ struct KernelCacheKey {
                low_precision == other.low_precision &&
                low_precision_outputs == other.low_precision_outputs &&
                causal == other.causal &&
-               has_mask == other.has_mask
+               has_mask == other.has_mask &&
+               use_bf16 == other.use_bf16;
     }
 };
 
```
```diff
@@ -168,29 +173,46 @@ struct KernelCacheKeyHash {
                (std::hash<bool>()(k.low_precision) << 3) ^
                (std::hash<bool>()(k.low_precision_outputs) << 4) ^
                (std::hash<bool>()(k.causal) << 5) ^
-               (std::hash<bool>()(k.has_mask) << 6)
+               (std::hash<bool>()(k.has_mask) << 6) ^
+               (std::hash<bool>()(k.use_bf16) << 7);
     }
 };
 
 static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;
 
-static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask) {
-    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask};
+static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask, bool use_bf16 = false) {
+    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16};
 
     auto it = g_kernel_cache.find(key);
     if (it != g_kernel_cache.end()) {
         return it->second;
     }
 
-    void* kernel =
-
-
-
-
-
-
-
-
+    void* kernel = nullptr;
+    if (use_bf16 && g_mfa_create_kernel_v2) {
+        // Use v2 API with BF16 support
+        kernel = g_mfa_create_kernel_v2(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask,
+            use_bf16
+        );
+    } else {
+        // Legacy API
+        kernel = g_mfa_create_kernel(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask
+        );
+    }
 
     if (!kernel) {
         throw std::runtime_error("Failed to create MFA kernel");
```
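A minimal Python sketch of the caching idea in this hunk: the cache key must include every flag that changes the compiled pipeline, now including `use_bf16`, or a kernel built for one dtype could be handed back for another. `_build_kernel` below is a hypothetical stand-in for the `mfa_create_kernel` / `mfa_create_kernel_v2` calls:

```python
_kernel_cache = {}

def _build_kernel(key):
    # Stand-in for the Metal kernel construction done by the Swift bridge.
    return object()

def get_or_create_kernel(seq_q, seq_kv, head_dim, low_prec, low_prec_outputs,
                         causal, has_mask, use_bf16=False):
    key = (seq_q, seq_kv, head_dim, low_prec, low_prec_outputs,
           causal, has_mask, use_bf16)
    kernel = _kernel_cache.get(key)
    if kernel is None:
        kernel = _build_kernel(key)
        _kernel_cache[key] = kernel
    return kernel

# Same shape but different BF16 flag -> distinct cache entries, distinct kernels.
a = get_or_create_kernel(2048, 2048, 128, False, True, True, False, use_bf16=True)
b = get_or_create_kernel(2048, 2048, 128, False, True, True, False, use_bf16=False)
assert a is not b
```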
```diff
@@ -261,18 +283,27 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         v_expanded = value;
     }
 
-    // Determine precision
-    bool
-
+    // Determine precision - MFA kernel supports FP16, BF16, and FP32
+    bool is_bfloat16 = (query.scalar_type() == at::kBFloat16);
+    bool is_fp16 = (query.scalar_type() == at::kHalf);
 
-    //
-    bool
+    // Use native BF16 kernel if available, otherwise fall back to FP32
+    bool use_bf16_kernel = is_bfloat16 && g_mfa_create_kernel_v2;
+    bool low_precision = is_fp16;  // FP16 path
+    bool low_precision_outputs = is_fp16 || use_bf16_kernel;
 
     // Make inputs contiguous
     auto q = query.contiguous();
     auto k = k_expanded.contiguous();
     auto v = v_expanded.contiguous();
 
+    // For BF16 without native kernel support, convert to FP32 (avoids FP16 overflow)
+    if (is_bfloat16 && !use_bf16_kernel) {
+        q = q.to(at::kFloat);
+        k = k.to(at::kFloat);
+        v = v.to(at::kFloat);
+    }
+
     // Handle attention mask
     bool has_mask = attn_mask.has_value();
     at::Tensor mask;
```
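The precision policy above can be summarized as: BF16 input plus the v2 symbol runs natively in BF16; BF16 without it is upcast to FP32 (not FP16, whose narrower range could overflow); FP16 stays on the existing low-precision path. A small illustrative mirror of that decision in Python:

```python
import torch

def pick_precision(input_dtype: torch.dtype, has_v2_kernel: bool):
    """Illustrative mirror of the precision policy in the forward path above."""
    is_bf16 = input_dtype == torch.bfloat16
    is_fp16 = input_dtype == torch.float16
    use_bf16_kernel = is_bf16 and has_v2_kernel

    if use_bf16_kernel:
        compute_dtype = torch.bfloat16   # native BF16 kernel via mfa_create_kernel_v2
    elif is_fp16:
        compute_dtype = torch.float16    # existing FP16 (low precision) path
    else:
        # FP32 covers float32 inputs and the BF16 fallback: upcasting BF16 to
        # FP32 is lossless, whereas squeezing it into FP16 could overflow.
        compute_dtype = torch.float32
    return use_bf16_kernel, compute_dtype

# BF16 with an old bridge (no v2 symbol) runs in FP32 and is cast back afterwards.
assert pick_precision(torch.bfloat16, has_v2_kernel=False) == (False, torch.float32)
assert pick_precision(torch.bfloat16, has_v2_kernel=True) == (True, torch.bfloat16)
```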
```diff
@@ -287,20 +318,28 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         }
         TORCH_CHECK(mask.scalar_type() == at::kByte,
                     "Attention mask must be bool or uint8");
-        // Expand mask
-
+        // Expand mask dimensions if needed -> (B, H, N_q, N_kv)
+        // Handle (B, 1, N_q, N_kv) -> expand heads
+        // Handle (B, H, 1, N_kv) -> expand query dim (1D key mask)
+        // Handle (B, 1, 1, N_kv) -> expand both
+        if (mask.size(1) == 1 || mask.size(2) == 1) {
             mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
         }
         mask = mask.contiguous();
     }
 
     // Allocate output in the appropriate precision
-    // With lowPrecisionOutputs=true, MFA writes FP16 directly
     at::Tensor output;
-    if (
+    if (use_bf16_kernel) {
+        // Native BF16 kernel outputs BF16
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kBFloat16));
+    } else if (low_precision_outputs) {
+        // FP16 kernel outputs FP16
         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                            query.options().dtype(at::kHalf));
     } else {
+        // FP32 kernel outputs FP32
         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                            query.options().dtype(at::kFloat));
     }
```
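The mask-broadcasting rule above is plain PyTorch: any mask with a singleton head or query dimension is expanded to the full (B, H, N_q, N_kv) layout the kernel expects. A short sketch of that expansion for a key-padding mask:

```python
import torch

B, H, N_q, N_kv = 2, 8, 1024, 1024

# Key-padding mask: True = masked, per MFA's convention.
key_padding = torch.zeros(B, 1, 1, N_kv, dtype=torch.bool)
key_padding[:, :, :, -7:] = True           # ignore the last 7 (padded) keys

# (B, 1, 1, N_kv) -> (B, H, N_q, N_kv), then materialized for the Metal kernel.
mask = key_padding.expand(B, H, N_q, N_kv).contiguous()
assert mask.shape == (B, H, N_q, N_kv)
```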
```diff
@@ -309,8 +348,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
                                query.options().dtype(at::kFloat));
 
-    // Get or create kernel with matching
-    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask);
+    // Get or create kernel with matching precision and causal mode
+    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask, use_bf16_kernel);
 
     // Get Metal buffers with byte offsets
     auto q_info = getBufferInfo(q);
```
```diff
@@ -361,7 +400,12 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         // The encoder stays open for coalescing more kernels
     }
 
-    //
+    // Convert output back to BF16 if input was BF16 and we used FP32 fallback
+    // (native BF16 kernel already outputs BF16, no conversion needed)
+    if (is_bfloat16 && !use_bf16_kernel) {
+        output = output.to(at::kBFloat16);
+    }
+
     // Return both output and logsumexp (needed for backward pass)
     return std::make_tuple(output, logsumexp);
 }
```
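Putting the BF16 changes together, an end-to-end sketch (assuming `flash_attention` and `convert_mask` are exported at the package root, as the SDPA patch in `__init__.py` suggests, and that an MPS device is available): whether the native BF16 kernel or the FP32 fallback runs, callers get BF16 outputs back.

```python
import torch
from mps_flash_attn import flash_attention, convert_mask

q = torch.randn(1, 8, 2048, 128, device="mps", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Additive float mask -> boolean MFA mask (True = masked); the (1, 1, N_q, N_kv)
# shape is broadcast to all heads by the mask-expansion code shown above.
additive = torch.zeros(1, 1, 2048, 2048, device="mps")
additive[..., -16:] = float("-inf")

out = flash_attention(q, k, v, attn_mask=convert_mask(additive))
assert out.dtype == torch.bfloat16   # BF16 in, BF16 out on either code path
```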
Binary file: mps_flash_attn/lib/libMFABridge.dylib (0.1.14 build); no text diff.
{mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/setup.py

```diff
@@ -4,6 +4,7 @@ Setup script for MPS Flash Attention
 
 import os
 import sys
+import shutil
 from setuptools import setup, find_packages, Extension
 from setuptools.command.build_ext import build_ext
 
@@ -36,6 +37,26 @@ class ObjCppBuildExt(build_ext):
 
         super().build_extensions()
 
+        # Copy libMFABridge.dylib to lib/ after building
+        self._copy_swift_bridge()
+
+    def _copy_swift_bridge(self):
+        """Copy Swift bridge dylib to package lib/ directory."""
+        src_path = os.path.join(
+            os.path.dirname(__file__),
+            "swift-bridge", ".build", "release", "libMFABridge.dylib"
+        )
+        dst_dir = os.path.join(os.path.dirname(__file__), "mps_flash_attn", "lib")
+        dst_path = os.path.join(dst_dir, "libMFABridge.dylib")
+
+        if os.path.exists(src_path):
+            os.makedirs(dst_dir, exist_ok=True)
+            shutil.copy2(src_path, dst_path)
+            print(f"Copied libMFABridge.dylib to {dst_path}")
+        else:
+            print(f"Warning: {src_path} not found. Build swift-bridge first with:")
+            print("  cd swift-bridge && swift build -c release")
+
 
 def get_extensions():
     if sys.platform != "darwin":
```
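Per the warning text in the hunk, building from source requires compiling the Swift bridge first (`cd swift-bridge && swift build -c release`); the build step then copies the dylib into `mps_flash_attn/lib/`. A small illustrative check that the bridge actually shipped with an installed copy of the package (the `lib/libMFABridge.dylib` location matches the file list at the top of this diff):

```python
import os
import mps_flash_attn

dylib = os.path.join(os.path.dirname(mps_flash_attn.__file__),
                     "lib", "libMFABridge.dylib")
print("MFA bridge bundled:", os.path.exists(dylib))
```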
Binary file: mps_flash_attn/lib/libMFABridge.dylib (0.1.8 build); no text diff.
All remaining files (LICENSE, README.md, the precompiled .metallib/.bin kernel artifacts, the egg-info metadata, setup.cfg, and tests/test_attention.py) are unchanged between 0.1.8 and 0.1.14.