mps-flash-attn 0.2.8-cp314-cp314-macosx_15_0_arm64.whl → 0.3.1-cp314-cp314-macosx_15_0_arm64.whl
Potentially problematic release: this version of mps-flash-attn might be problematic.
- mps_flash_attn/_C.cpython-314-darwin.so +0 -0
- mps_flash_attn/__init__.py +229 -52
- mps_flash_attn/csrc/mps_flash_attn.mm +30 -33
- {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.1.dist-info}/METADATA +1 -1
- {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.1.dist-info}/RECORD +8 -8
- {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.1.dist-info}/WHEEL +0 -0
- {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.1.dist-info}/top_level.txt +0 -0
mps_flash_attn/_C.cpython-314-darwin.so
Binary file
mps_flash_attn/__init__.py
CHANGED
@@ -4,13 +4,42 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.2.8"
+__version__ = "0.3.1"
+
+__all__ = [
+    # Core functions
+    "flash_attention",
+    "flash_attention_with_bias",
+    "flash_attention_chunked",
+    # Quantized attention
+    "flash_attention_fp8",
+    "flash_attention_int8",
+    "flash_attention_nf4",
+    "quantize_kv_fp8",
+    "quantize_kv_int8",
+    "quantize_kv_nf4",
+    # Utilities
+    "replace_sdpa",
+    "precompile",
+    "clear_cache",
+    "register_custom_op",
+    "is_available",
+    "convert_mask",
+    # Constants
+    "QUANT_FP8_E4M3",
+    "QUANT_FP8_E5M2",
+    "QUANT_INT8",
+    "QUANT_NF4",
+    # Version
+    "__version__",
+]
 
 import torch
-from typing import Optional
+from typing import Optional, Tuple
 import math
 import threading
 import os
+import warnings
 
 # Try to import the C++ extension
 try:
@@ -30,6 +59,20 @@ def is_available() -> bool:
     return _HAS_MFA and torch.backends.mps.is_available()
 
 
+def _ensure_contiguous(tensor: torch.Tensor, name: str) -> torch.Tensor:
+    """Ensure tensor is contiguous, with a debug warning if conversion needed."""
+    if tensor.is_contiguous():
+        return tensor
+    # Auto-convert with debug info
+    if os.environ.get("MFA_DEBUG", "0") == "1":
+        warnings.warn(
+            f"MFA: {name} tensor was not contiguous (stride={tensor.stride()}), "
+            f"auto-converting. For best performance, ensure inputs are contiguous.",
+            UserWarning
+        )
+    return tensor.contiguous()
+
+
 def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
     """
     Convert attention mask to MFA's boolean format.
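A quick illustration of the case `_ensure_contiguous` guards against: transposed views are the most common source of non-contiguous inputs. This is a minimal sketch with made-up shapes; CPU tensors are enough to show the stride behaviour, only the real attention call needs MPS.

```python
import torch

x = torch.randn(2, 8, 1024, 64)   # (B, H, S, D), contiguous
y = x.transpose(1, 2)             # (B, S, H, D) view: same storage, permuted strides
print(y.is_contiguous())          # False
print(y.stride())                 # (524288, 64, 65536, 1) - strides no longer descending
z = y.contiguous()                # what the helper falls back to (copies the data)
print(z.is_contiguous())          # True
```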
@@ -176,6 +219,20 @@ def flash_attention(
     if not torch.backends.mps.is_available():
         raise RuntimeError("MPS not available")
 
+    # Validate scale parameter
+    if scale is not None:
+        if scale <= 0:
+            raise ValueError(f"scale must be positive, got {scale}")
+        # Warn about extreme scale values that could cause numerical issues
+        default_scale = 1.0 / math.sqrt(query.shape[-1])
+        if scale < default_scale * 0.01 or scale > default_scale * 100:
+            warnings.warn(
+                f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
+                "this may cause numerical issues",
+                UserWarning,
+                stacklevel=2
+            )
+
     # Validate device
     if query.device.type != 'mps':
         raise ValueError("query must be on MPS device")
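For reference, the bounds in the new scale check worked through for an assumed head dimension of D = 64 (a value chosen for illustration, not taken from the diff):

```python
import math

D = 64                                    # assumed head dimension for illustration
default_scale = 1.0 / math.sqrt(D)        # 0.125, the value used when scale=None
low, high = default_scale * 0.01, default_scale * 100
print(default_scale, low, high)           # 0.125 0.00125 12.5
# scale <= 0 raises ValueError; scale outside (low, high) only warns
```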
@@ -186,6 +243,37 @@ def flash_attention(
     if attn_mask is not None and attn_mask.device.type != 'mps':
         raise ValueError("attn_mask must be on MPS device")
 
+    # Ensure contiguous (auto-convert with debug warning)
+    query = _ensure_contiguous(query, "query")
+    key = _ensure_contiguous(key, "key")
+    value = _ensure_contiguous(value, "value")
+    if attn_mask is not None:
+        attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
+        # Validate mask shape
+        B, H, N_q, D = query.shape
+        N_kv = key.shape[2]
+        if attn_mask.dim() != 4:
+            raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
+        mb, mh, mq, mk = attn_mask.shape
+        # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
+        if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
+            raise ValueError(
+                f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
+            )
+        # Expand broadcast mask to full shape for Metal kernel
+        if mq == 1 and N_q > 1:
+            attn_mask = attn_mask.expand(mb, mh, N_q, mk)
+        if mk == 1 and N_kv > 1:
+            attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
+        if mb != 1 and mb != B:
+            raise ValueError(
+                f"attn_mask batch size must be 1 or {B}, got {mb}"
+            )
+        if mh != 1 and mh != H:
+            raise ValueError(
+                f"attn_mask head count must be 1 or {H}, got {mh}"
+            )
+
     # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
     if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
         # Apply scale if provided
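As a usage sketch of the broadcast rules validated above: a typical padding mask built as (B, 1, 1, N_kv) and the full shape it gets expanded to. Shapes are illustrative; `True` marks positions that must not be attended.

```python
import torch

B, H, N_q, N_kv = 2, 4, 8, 8
lengths = torch.tensor([8, 5])                          # valid K/V length per batch row
pad = torch.arange(N_kv)[None, :] >= lengths[:, None]   # True = padding (masked)

attn_mask = pad[:, None, None, :]                       # (B, 1, 1, N_kv), broadcastable
full = attn_mask.expand(B, H, N_q, N_kv)                # what the validation expands it to
print(full.shape)                                       # torch.Size([2, 4, 8, 8])
```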
@@ -213,6 +301,10 @@ def replace_sdpa():
     import torch.nn.functional as F
 
     original_sdpa = F.scaled_dot_product_attention
+    _debug = os.environ.get("MFA_DEBUG", "0") == "1"
+    _call_count = [0]  # mutable for closure
+    _fallback_count = [0]  # track fallbacks for warning
+    _last_fallback_error = [None]
 
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None, enable_gqa=False, **kwargs):
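The one-element lists above are a closure idiom rather than anything MFA-specific: `patched_sdpa` needs to mutate counters defined in `replace_sdpa`, and mutating a list element avoids rebinding the outer name (a `nonlocal` declaration would work equally well). A tiny standalone sketch:

```python
def make_counter():
    count = [0]              # mutable cell shared with the nested function

    def bump():
        count[0] += 1        # in-place mutation; no rebinding, so no nonlocal needed
        return count[0]

    return bump

bump = make_counter()
print(bump(), bump())        # 1 2
```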
@@ -224,37 +316,92 @@ def replace_sdpa():
         # seq=1024: 2.3-3.7x (MFA much faster)
         # seq=2048: 2.2-3.9x (MFA much faster)
         # seq=4096: 2.1-3.7x (MFA much faster)
+        # Determine seq_len based on tensor dimensionality
+        # 4D: (B, H, S, D) -> seq_len = shape[2]
+        # 3D: (B, S, D) -> seq_len = shape[1] (single-head attention, e.g., VAE)
+        is_3d = query.ndim == 3
+        seq_len = query.shape[1] if is_3d else query.shape[2]
+
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.
+            query.ndim >= 3 and
+            seq_len >= 512):
             try:
+                q, k, v = query, key, value
+
+                # Handle 3D tensors (B, S, D) - treat as single-head attention
+                # Unsqueeze to (B, 1, S, D) for MFA, squeeze back after
+                if is_3d:
+                    q = q.unsqueeze(1)  # (B, S, D) -> (B, 1, S, D)
+                    k = k.unsqueeze(1)
+                    v = v.unsqueeze(1)
+
                 # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
                 # Common in Llama 2/3, Mistral, Qwen, etc.
-
-
+                # NOTE: Always expand when heads mismatch, not just when enable_gqa=True
+                # Transformers may pass enable_gqa=True on MPS (torch>=2.5, no mask) even though
+                # MPS SDPA doesn't support native GQA - we handle it here
+                if q.shape[1] != k.shape[1]:
                     # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
-                    n_rep =
-                    k =
-                    v =
+                    n_rep = q.shape[1] // k.shape[1]
+                    k = k.repeat_interleave(n_rep, dim=1)
+                    v = v.repeat_interleave(n_rep, dim=1)
 
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                 # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                 mfa_mask = None
                 if attn_mask is not None:
+                    if _debug:
+                        print(f"[MFA MASK] dtype={attn_mask.dtype} shape={tuple(attn_mask.shape)} min={attn_mask.min().item():.2f} max={attn_mask.max().item():.2f}")
                     if attn_mask.dtype == torch.bool:
-                        #
-
+                        # PyTorch SDPA bool mask: True = ATTEND, False = MASKED
+                        # MFA bool mask: True = MASKED, False = ATTEND
+                        # They're opposite! Invert it.
+                        mfa_mask = ~attn_mask
                     else:
                         # Float mask: typically -inf for masked positions, 0 for unmasked
                         # Convert: positions with large negative values -> True (masked)
                         # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                         mfa_mask = attn_mask <= -1e3
-
-
-
-
+                    if _debug:
+                        print(f"[MFA MASK] converted: True(masked)={mfa_mask.sum().item()} False(attend)={(~mfa_mask).sum().item()}")
+
+                out = flash_attention(q, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+
+                # Squeeze back for 3D input
+                if is_3d:
+                    out = out.squeeze(1)  # (B, 1, S, D) -> (B, S, D)
+
+                if _debug:
+                    _call_count[0] += 1
+                    print(f"[MFA #{_call_count[0]}] shape={tuple(query.shape)} is_3d={is_3d} gqa={enable_gqa} mask={attn_mask is not None} causal={is_causal}")
+
+                return out
+            except Exception as e:
+                # Fall back to original on any error, but track it
+                _fallback_count[0] += 1
+                _last_fallback_error[0] = str(e)
+                if _debug:
+                    import traceback
+                    print(f"[MFA FALLBACK #{_fallback_count[0]}] shape={tuple(query.shape)}\n{traceback.format_exc()}")
+                # Warn user after repeated fallbacks (likely a real problem)
+                if _fallback_count[0] == 10:
+                    warnings.warn(
+                        f"MFA has fallen back to native SDPA {_fallback_count[0]} times. "
+                        f"Last error: {_last_fallback_error[0]}. "
+                        f"Set MFA_DEBUG=1 for details.",
+                        UserWarning
+                    )
+
+        if _debug and query.device.type == 'mps':
+            _call_count[0] += 1
+            reason = []
+            if dropout_p != 0.0: reason.append(f"dropout={dropout_p}")
+            if query.ndim < 3: reason.append(f"ndim={query.ndim}")
+            if seq_len < 512: reason.append(f"seq={seq_len}<512")
+            print(f"[NATIVE #{_call_count[0]}] shape={tuple(query.shape)} reason={','.join(reason) or 'unknown'}")
 
         return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)
 
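The two conversions inside `patched_sdpa` are easy to check in isolation. The sketch below is illustrative only (plain CPU tensors, hypothetical shapes): it shows that inverting a PyTorch bool mask and thresholding a PyTorch additive float mask land on the same MFA-style mask, and how `repeat_interleave` expands grouped K/V heads.

```python
import torch

S = 4
# PyTorch SDPA bool mask: True = attend; MFA bool mask: True = masked.
sdpa_bool = torch.tril(torch.ones(S, S, dtype=torch.bool))   # causal: attend lower triangle
mfa_from_bool = ~sdpa_bool

# PyTorch SDPA float mask: additive, 0 = attend, -inf = masked.
sdpa_float = torch.zeros(S, S).masked_fill(~sdpa_bool, float("-inf"))
mfa_from_float = sdpa_float <= -1e3                          # same threshold as the hook
assert torch.equal(mfa_from_bool, mfa_from_float)

# GQA: expand grouped K/V heads so every query head has a matching K/V head.
B, H_q, H_kv, D = 1, 8, 2, 16
k = torch.randn(B, H_kv, S, D)
n_rep = H_q // H_kv
k_expanded = k.repeat_interleave(n_rep, dim=1)               # (B, H_q, S, D)
assert k_expanded.shape[1] == H_q
```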
@@ -461,13 +608,13 @@ def flash_attention_chunked(
         return _C.forward(query, key, value, is_causal, None, 0)
 
     # Initialize running statistics for online softmax
-    # m = running max, l = running sum of exp, acc = accumulated output
     device = query.device
     dtype = query.dtype
 
     # Use float32 for numerical stability of softmax statistics
-
-
+    # running_L: base-2 logsumexp of all attention scores seen so far (-inf means no data yet)
+    # output_acc: weighted combination of outputs (weights sum to 1 after each update)
+    running_L = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
     output_acc = torch.zeros((B, H, seq_len_q, D), device=device, dtype=torch.float32)
 
     # Process K/V in chunks
@@ -488,51 +635,81 @@
         # - Partial chunk (up to q) if start_idx <= q < end_idx
         # - None of chunk if q < start_idx
 
-
+        if is_causal:
+            # Create explicit causal mask for this chunk
+            # Query positions: 0 to seq_len_q-1
+            # Key positions in chunk: start_idx to end_idx-1
+            chunk_len = end_idx - start_idx
 
-
-
-
+            # Build mask: mask[q, k_local] = True means DON'T attend
+            # We want to attend when global_k_pos <= q
+            # global_k_pos = start_idx + k_local
+            # So: attend when start_idx + k_local <= q
+            # mask = start_idx + k_local > q
 
-
-
-
+            q_pos = torch.arange(seq_len_q, device=device).view(1, 1, seq_len_q, 1)
+            k_pos = torch.arange(chunk_len, device=device).view(1, 1, 1, chunk_len) + start_idx
+            causal_mask = k_pos > q_pos  # True = masked (don't attend)
 
-
-
-
-        # Online softmax update:
-        # new_max = max(running_max, chunk_max)
-        # For flash attention, chunk_lse ≈ chunk_max + log(chunk_sum)
-        # We approximate chunk_max ≈ chunk_lse (valid when exp sum dominates)
-
-        chunk_max = chunk_lse  # Approximation: logsumexp ≈ max when sum is dominated by max
+            # Expand to batch and heads
+            causal_mask = causal_mask.expand(B, H, seq_len_q, chunk_len)
 
-
-
+            # Call forward with explicit mask (is_causal=False since we handle it)
+            chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, causal_mask, 0)
+        else:
+            # Non-causal: just process the chunk directly
+            chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, None, 0)
 
-        #
-        #
-
-        #
-
+        # chunk_L shape: (B, H, seq_len_q)
+        # The kernel returns L = m + log2(l) where:
+        #   m = max(scores * log2(e) / sqrt(D))
+        #   l = sum(exp2(scores * log2(e) / sqrt(D) - m))
+        # This is a base-2 logsumexp: L = log2(sum(exp2(scaled_scores)))
+        chunk_L = chunk_lse.unsqueeze(-1).float()  # (B, H, seq_len_q, 1)
 
-        #
-
-        correction_new = torch.exp(chunk_max - new_max)
+        # Convert chunk output to float32 for accumulation
+        chunk_out = chunk_out.float()
 
-        #
-        #
-        #
-
+        # Online softmax algorithm using base-2 representation
+        #
+        # Flash attention returns: chunk_out = softmax(scores) @ V
+        # The output is already normalized. For online combination:
+        #   new_L = log2(2^running_L + 2^chunk_L)
+        #         = max(running_L, chunk_L) + log2(2^(running_L - max) + 2^(chunk_L - max))
+        #
+        # The weights for combining outputs are:
+        #   old_weight = 2^(running_L - new_L)
+        #   new_weight = 2^(chunk_L - new_L)
+        # These weights sum to 1, so: output = old_weight * old_out + new_weight * new_out
+
+        # Compute new base-2 logsumexp
+        max_L = torch.maximum(running_L, chunk_L)
+
+        # Handle -inf case (no previous data)
+        # Use exp2 for base-2 (matches kernel's internal representation)
+        running_exp2 = torch.where(
+            running_L == float('-inf'),
+            torch.zeros_like(running_L),
+            torch.exp2(running_L - max_L)
+        )
+        chunk_exp2 = torch.exp2(chunk_L - max_L)
+        new_L = max_L + torch.log2(running_exp2 + chunk_exp2)
+
+        # Compute correction factors using base-2 exp
+        old_weight = torch.where(
+            running_L == float('-inf'),
+            torch.zeros_like(running_L),
+            torch.exp2(running_L - new_L)
+        )
+        new_weight = torch.exp2(chunk_L - new_L)
 
         # Update accumulator
-
-
-
+        # Update accumulator
+        output_acc = output_acc * old_weight + chunk_out * new_weight
+        running_L = new_L
 
-    #
-    output = output_acc
+    # No final normalization needed - weights already sum to 1
+    output = output_acc
 
     # Convert back to original dtype
     return output.to(dtype)
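The base-2 online-softmax combination in the new chunked path can be sanity-checked with plain tensor ops. The sketch below is illustrative only: it emulates the kernel's per-chunk outputs with `torch.softmax`/`torch.logsumexp` instead of calling `_C.forward_with_lse`, and the shapes are made up. It verifies that weighting two normalized chunk outputs by `2^(L_i - new_L)` reproduces full attention with no final normalization.

```python
import math
import torch

torch.manual_seed(0)
N_q, N_kv, D = 4, 16, 8
q, k, v = torch.randn(N_q, D), torch.randn(N_kv, D), torch.randn(N_kv, D)

scores = (q @ k.T) / math.sqrt(D)                 # (N_q, N_kv)
reference = torch.softmax(scores, dim=-1) @ v     # full (unchunked) attention

def chunk_forward(s_chunk, v_chunk):
    """Emulate the kernel: normalized chunk output + base-2 logsumexp of its scores."""
    out = torch.softmax(s_chunk, dim=-1) @ v_chunk
    L = torch.logsumexp(s_chunk, dim=-1, keepdim=True) / math.log(2)
    return out, L

out1, L1 = chunk_forward(scores[:, :8], v[:8])
out2, L2 = chunk_forward(scores[:, 8:], v[8:])

# Online combination: the weights 2^(L_i - new_L) sum to 1 by construction.
max_L = torch.maximum(L1, L2)
new_L = max_L + torch.log2(torch.exp2(L1 - max_L) + torch.exp2(L2 - max_L))
combined = torch.exp2(L1 - new_L) * out1 + torch.exp2(L2 - new_L) * out2

print(torch.allclose(combined, reference, atol=1e-5))   # True
```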
mps_flash_attn/csrc/mps_flash_attn.mm
CHANGED

@@ -14,6 +14,8 @@
 #include <dlfcn.h>
 #include <string>
 #include <vector>
+#include <mutex>
+#include <atomic>
 
 // ============================================================================
 // MFA Bridge Function Types
@@ -67,7 +69,8 @@ static mfa_forward_fn g_mfa_forward = nullptr;
 static mfa_backward_fn g_mfa_backward = nullptr;
 static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
 static void* g_dylib_handle = nullptr;
-static bool g_initialized
+static std::atomic<bool> g_initialized{false};
+static std::mutex g_init_mutex;
 
 // ============================================================================
 // Load MFA Bridge Library
@@ -141,6 +144,24 @@ static bool load_mfa_bridge() {
     return true;
 }
 
+// Thread-safe initialization helper
+static void ensure_initialized() {
+    // Fast path: already initialized
+    if (g_initialized.load(std::memory_order_acquire)) {
+        return;
+    }
+    // Slow path: need to initialize with lock
+    std::lock_guard<std::mutex> lock(g_init_mutex);
+    // Double-check after acquiring lock
+    if (!g_initialized.load(std::memory_order_relaxed)) {
+        load_mfa_bridge();
+        if (!g_mfa_init()) {
+            throw std::runtime_error("Failed to initialize MFA");
+        }
+        g_initialized.store(true, std::memory_order_release);
+    }
+}
+
 // ============================================================================
 // Get MTLBuffer from PyTorch MPS Tensor
 // ============================================================================
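For readers more comfortable in Python than Objective-C++: `ensure_initialized()` above is classic double-checked locking. A sketch of the same idea (illustration only, not code from the package):

```python
import threading

_initialized = False
_init_lock = threading.Lock()

def ensure_initialized(load_backend):
    """Run load_backend() exactly once, even with concurrent callers."""
    global _initialized
    if _initialized:              # fast path: no lock once set up
        return
    with _init_lock:              # slow path: serialize first-time initialization
        if not _initialized:      # re-check: another thread may have finished already
            load_backend()
            _initialized = True
```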
@@ -359,14 +380,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     const c10::optional<at::Tensor>& attn_mask, // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
     int64_t window_size // 0 = full attention, >0 = sliding window
 ) {
-    //
-
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();
 
     // Validate inputs
     TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
@@ -562,14 +577,8 @@ at::Tensor mps_flash_attention_forward_with_bias(
     int64_t window_size,
     int64_t bias_repeat_count // >0 means bias repeats every N batches (for window attention)
 ) {
-    //
-
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();
 
     // Check that v6/v7 API is available
     TORCH_CHECK(g_mfa_create_kernel_v6 || g_mfa_create_kernel_v7,
@@ -735,14 +744,8 @@ at::Tensor mps_flash_attention_forward_quantized(
     const c10::optional<at::Tensor>& attn_mask,
     int64_t window_size
 ) {
-    //
-
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();
 
     // Check that v4 API is available
     TORCH_CHECK(g_mfa_create_kernel_v4, "Quantized attention requires MFA v4 API (update libMFABridge.dylib)");
@@ -992,14 +995,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
     int64_t window_size, // 0 = full attention, >0 = sliding window
     bool bf16_backward // true = use BF16 intermediates for ~2x faster backward
 ) {
-    //
-
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();
 
     // Validate inputs
     TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
{mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.1.dist-info}/RECORD
CHANGED

@@ -1,7 +1,7 @@
-mps_flash_attn/_C.cpython-314-darwin.so,sha256=
-mps_flash_attn/__init__.py,sha256=
+mps_flash_attn/_C.cpython-314-darwin.so,sha256=V9bjj53KRFmbMSslzTf7YV8N2l9NPa9_Ia2dORgRjqA,313448
+mps_flash_attn/__init__.py,sha256=Esm5wd3As4es3ne1GjUtlQGfBtj0LB05UuND-SaIRXo,47730
 mps_flash_attn/benchmark.py,sha256=qHhvb8Dmh07OEa_iXuPuJSEnRJlrjVF5nKzVwbWypWE,24141
-mps_flash_attn/csrc/mps_flash_attn.mm,sha256=
+mps_flash_attn/csrc/mps_flash_attn.mm,sha256=mR4S8SHLtRiksrmoFH6s2118q662SMNlFU8HmxAE3YY,51204
 mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib,sha256=_oig6f2I6ZxBCKWbJF3ofmZMySm8gB399_M-lD2NOfM,13747
 mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib,sha256=1fsmVvB5EubhN-y6s5CB-eVk_wuO2tfrabiQTwXvJJc,13171
 mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib,sha256=5WKo_yAU-PgmulBUQhnzvt0DZRteVmo4-nc4U-T6G2g,17507
@@ -26,8 +26,8 @@ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e
 mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib,sha256=qyOaQtRVwL_Wc6GGdu6z-ftf0iX84XexuY09-lNLl5o,13747
 mps_flash_attn/kernels/manifest.json,sha256=d5MkE_BjqDQuMNm1jZiwWkQKfB-yfFml3lLSeR-wCLo,1867
 mps_flash_attn/lib/libMFABridge.dylib,sha256=iKgfYISSKMSNt_iXnljjUr_hZZHyCAg2tdS3_ZjmLkc,605696
-mps_flash_attn-0.
-mps_flash_attn-0.
-mps_flash_attn-0.
-mps_flash_attn-0.
-mps_flash_attn-0.
+mps_flash_attn-0.3.1.dist-info/licenses/LICENSE,sha256=F_XmXSab2O-hHcqLpYJWeFaqB6GA_qiTEN23p2VfZWU,1237
+mps_flash_attn-0.3.1.dist-info/METADATA,sha256=hp_w8UG_IpMF6BfS7STV69sM0Ss01-n6nWz9s1S2JzM,5834
+mps_flash_attn-0.3.1.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
+mps_flash_attn-0.3.1.dist-info/top_level.txt,sha256=zbArDcWhJDnJfMUKnOUhs5TjsMgSxa2GzOlscTRfobE,15
+mps_flash_attn-0.3.1.dist-info/RECORD,,
Files without changes: {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.1.dist-info}/WHEEL, licenses/LICENSE, top_level.txt.