mps-flash-attn 0.2.8__cp314-cp314-macosx_15_0_arm64.whl → 0.3.2__cp314-cp314-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of mps-flash-attn might be problematic.
- mps_flash_attn/_C.cpython-314-darwin.so +0 -0
- mps_flash_attn/__init__.py +286 -52
- mps_flash_attn/csrc/mps_flash_attn.mm +30 -33
- {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.2.dist-info}/METADATA +1 -1
- {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.2.dist-info}/RECORD +8 -8
- {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.2.dist-info}/WHEEL +0 -0
- {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.2.dist-info}/top_level.txt +0 -0
mps_flash_attn/_C.cpython-314-darwin.so
Binary file

mps_flash_attn/__init__.py
CHANGED

@@ -4,13 +4,42 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """

-__version__ = "0.2.8"
+__version__ = "0.3.2"
+
+__all__ = [
+    # Core functions
+    "flash_attention",
+    "flash_attention_with_bias",
+    "flash_attention_chunked",
+    # Quantized attention
+    "flash_attention_fp8",
+    "flash_attention_int8",
+    "flash_attention_nf4",
+    "quantize_kv_fp8",
+    "quantize_kv_int8",
+    "quantize_kv_nf4",
+    # Utilities
+    "replace_sdpa",
+    "precompile",
+    "clear_cache",
+    "register_custom_op",
+    "is_available",
+    "convert_mask",
+    # Constants
+    "QUANT_FP8_E4M3",
+    "QUANT_FP8_E5M2",
+    "QUANT_INT8",
+    "QUANT_NF4",
+    # Version
+    "__version__",
+]

 import torch
-from typing import Optional
+from typing import Optional, Tuple
 import math
 import threading
 import os
+import warnings

 # Try to import the C++ extension
 try:
@@ -30,6 +59,20 @@ def is_available() -> bool:
     return _HAS_MFA and torch.backends.mps.is_available()


+def _ensure_contiguous(tensor: torch.Tensor, name: str) -> torch.Tensor:
+    """Ensure tensor is contiguous, with a debug warning if conversion needed."""
+    if tensor.is_contiguous():
+        return tensor
+    # Auto-convert with debug info
+    if os.environ.get("MFA_DEBUG", "0") == "1":
+        warnings.warn(
+            f"MFA: {name} tensor was not contiguous (stride={tensor.stride()}), "
+            f"auto-converting. For best performance, ensure inputs are contiguous.",
+            UserWarning
+        )
+    return tensor.contiguous()
+
+
 def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
     """
     Convert attention mask to MFA's boolean format.
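Note (illustration, not part of the diff): the contiguity guard above follows standard PyTorch semantics. A minimal CPU sketch, with illustrative shapes and no MFA calls, of the situation the new helper handles:

import torch

x = torch.randn(2, 8, 1024, 64)
kt = x.transpose(1, 2)                  # a view with permuted strides
print(kt.is_contiguous())               # False - this is what triggers the MFA_DEBUG warning
print(kt.contiguous().is_contiguous())  # True  - the copy that _ensure_contiguous returns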
@@ -54,6 +97,56 @@ def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
     return attn_mask <= -1e3


+def _validate_and_expand_mask(
+    attn_mask: Optional[torch.Tensor],
+    B: int,
+    H: int,
+    N_q: int,
+    N_kv: int,
+) -> Optional[torch.Tensor]:
+    """
+    Validate attention mask shape and expand broadcast dimensions.
+
+    Args:
+        attn_mask: Optional mask of shape (B, H, N_q, N_kv) or broadcastable
+        B: Batch size
+        H: Number of heads
+        N_q: Query sequence length
+        N_kv: Key/Value sequence length
+
+    Returns:
+        Expanded mask of shape (mb, mh, N_q, N_kv) or None
+    """
+    if attn_mask is None:
+        return None
+
+    attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
+
+    if attn_mask.dim() != 4:
+        raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
+
+    mb, mh, mq, mk = attn_mask.shape
+
+    # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
+    if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
+        raise ValueError(
+            f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
+        )
+
+    # Expand broadcast mask to full shape for Metal kernel
+    if mq == 1 and N_q > 1:
+        attn_mask = attn_mask.expand(mb, mh, N_q, mk)
+    if mk == 1 and N_kv > 1:
+        attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
+
+    if mb != 1 and mb != B:
+        raise ValueError(f"attn_mask batch size must be 1 or {B}, got {mb}")
+    if mh != 1 and mh != H:
+        raise ValueError(f"attn_mask head count must be 1 or {H}, got {mh}")
+
+    return attn_mask
+
+
 class FlashAttentionFunction(torch.autograd.Function):
     """Autograd function for Flash Attention with backward pass support."""

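Note (illustration, not part of the diff): the broadcast expansion above uses plain torch semantics. A small CPU sketch, with hypothetical shapes and no MFA calls, of how an additive per-batch padding mask maps to the boolean format this package expects:

import torch

B, H, N_q, N_kv = 2, 4, 6, 6
# SDPA-style additive mask: 0 = attend, -inf = masked; broadcast over heads and query positions
float_mask = torch.zeros(B, 1, 1, N_kv)
float_mask[:, :, :, -2:] = float('-inf')        # last two keys are padding

bool_mask = float_mask <= -1e3                  # what convert_mask() does: True = masked
expanded = bool_mask.expand(B, 1, N_q, N_kv)    # what _validate_and_expand_mask() does when mq == 1
print(expanded.shape)                           # torch.Size([2, 1, 6, 6])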
@@ -176,6 +269,20 @@ def flash_attention(
     if not torch.backends.mps.is_available():
         raise RuntimeError("MPS not available")

+    # Validate scale parameter
+    if scale is not None:
+        if scale <= 0:
+            raise ValueError(f"scale must be positive, got {scale}")
+        # Warn about extreme scale values that could cause numerical issues
+        default_scale = 1.0 / math.sqrt(query.shape[-1])
+        if scale < default_scale * 0.01 or scale > default_scale * 100:
+            warnings.warn(
+                f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
+                "this may cause numerical issues",
+                UserWarning,
+                stacklevel=2
+            )
+
     # Validate device
     if query.device.type != 'mps':
         raise ValueError("query must be on MPS device")
@@ -186,6 +293,24 @@ def flash_attention(
     if attn_mask is not None and attn_mask.device.type != 'mps':
         raise ValueError("attn_mask must be on MPS device")

+    # Ensure contiguous (auto-convert with debug warning)
+    query = _ensure_contiguous(query, "query")
+    key = _ensure_contiguous(key, "key")
+    value = _ensure_contiguous(value, "value")
+
+    # Validate tensor dimensions
+    if query.dim() != 4:
+        raise RuntimeError(f"query must be 4D (B, H, N, D), got {query.dim()}D")
+    if key.dim() != 4:
+        raise RuntimeError(f"key must be 4D (B, H, N, D), got {key.dim()}D")
+    if value.dim() != 4:
+        raise RuntimeError(f"value must be 4D (B, H, N, D), got {value.dim()}D")
+
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
     if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
         # Apply scale if provided
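Note (illustration, not part of the diff): taken together, these validation hunks imply a calling convention roughly like the sketch below. It assumes an Apple Silicon machine with the wheel installed, so treat it as a hedged sketch rather than a tested snippet; shapes are hypothetical.

import torch
from mps_flash_attn import flash_attention

# 4D (B, H, N, D) tensors on the MPS device, contiguous, no dropout
q = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)
k = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)
v = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)

out = flash_attention(q, k, v, is_causal=True)   # default scale is 1/sqrt(D), per the hunk above
print(out.shape)                                 # torch.Size([1, 8, 2048, 64])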
@@ -213,6 +338,10 @@ def replace_sdpa():
     import torch.nn.functional as F

     original_sdpa = F.scaled_dot_product_attention
+    _debug = os.environ.get("MFA_DEBUG", "0") == "1"
+    _call_count = [0]  # mutable for closure
+    _fallback_count = [0]  # track fallbacks for warning
+    _last_fallback_error = [None]

     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None, enable_gqa=False, **kwargs):
@@ -224,37 +353,92 @@ def replace_sdpa():
         # seq=1024: 2.3-3.7x (MFA much faster)
         # seq=2048: 2.2-3.9x (MFA much faster)
         # seq=4096: 2.1-3.7x (MFA much faster)
+        # Determine seq_len based on tensor dimensionality
+        # 4D: (B, H, S, D) -> seq_len = shape[2]
+        # 3D: (B, S, D) -> seq_len = shape[1] (single-head attention, e.g., VAE)
+        is_3d = query.ndim == 3
+        seq_len = query.shape[1] if is_3d else query.shape[2]
+
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.
+            query.ndim >= 3 and
+            seq_len >= 512):
             try:
+                q, k, v = query, key, value
+
+                # Handle 3D tensors (B, S, D) - treat as single-head attention
+                # Unsqueeze to (B, 1, S, D) for MFA, squeeze back after
+                if is_3d:
+                    q = q.unsqueeze(1)  # (B, S, D) -> (B, 1, S, D)
+                    k = k.unsqueeze(1)
+                    v = v.unsqueeze(1)
+
                 # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
                 # Common in Llama 2/3, Mistral, Qwen, etc.
-
-
+                # NOTE: Always expand when heads mismatch, not just when enable_gqa=True
+                # Transformers may pass enable_gqa=True on MPS (torch>=2.5, no mask) even though
+                # MPS SDPA doesn't support native GQA - we handle it here
+                if q.shape[1] != k.shape[1]:
                     # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
-                    n_rep =
-                    k =
-                    v =
+                    n_rep = q.shape[1] // k.shape[1]
+                    k = k.repeat_interleave(n_rep, dim=1)
+                    v = v.repeat_interleave(n_rep, dim=1)

                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                 # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                 mfa_mask = None
                 if attn_mask is not None:
+                    if _debug:
+                        print(f"[MFA MASK] dtype={attn_mask.dtype} shape={tuple(attn_mask.shape)} min={attn_mask.min().item():.2f} max={attn_mask.max().item():.2f}")
                     if attn_mask.dtype == torch.bool:
-                        #
-
+                        # PyTorch SDPA bool mask: True = ATTEND, False = MASKED
+                        # MFA bool mask: True = MASKED, False = ATTEND
+                        # They're opposite! Invert it.
+                        mfa_mask = ~attn_mask
                     else:
                         # Float mask: typically -inf for masked positions, 0 for unmasked
                         # Convert: positions with large negative values -> True (masked)
                         # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                         mfa_mask = attn_mask <= -1e3
-
-
-
-
+                    if _debug:
+                        print(f"[MFA MASK] converted: True(masked)={mfa_mask.sum().item()} False(attend)={(~mfa_mask).sum().item()}")
+
+                out = flash_attention(q, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+
+                # Squeeze back for 3D input
+                if is_3d:
+                    out = out.squeeze(1)  # (B, 1, S, D) -> (B, S, D)
+
+                if _debug:
+                    _call_count[0] += 1
+                    print(f"[MFA #{_call_count[0]}] shape={tuple(query.shape)} is_3d={is_3d} gqa={enable_gqa} mask={attn_mask is not None} causal={is_causal}")
+
+                return out
+            except Exception as e:
+                # Fall back to original on any error, but track it
+                _fallback_count[0] += 1
+                _last_fallback_error[0] = str(e)
+                if _debug:
+                    import traceback
+                    print(f"[MFA FALLBACK #{_fallback_count[0]}] shape={tuple(query.shape)}\n{traceback.format_exc()}")
+                # Warn user after repeated fallbacks (likely a real problem)
+                if _fallback_count[0] == 10:
+                    warnings.warn(
+                        f"MFA has fallen back to native SDPA {_fallback_count[0]} times. "
+                        f"Last error: {_last_fallback_error[0]}. "
+                        f"Set MFA_DEBUG=1 for details.",
+                        UserWarning
+                    )
+
+        if _debug and query.device.type == 'mps':
+            _call_count[0] += 1
+            reason = []
+            if dropout_p != 0.0: reason.append(f"dropout={dropout_p}")
+            if query.ndim < 3: reason.append(f"ndim={query.ndim}")
+            if seq_len < 512: reason.append(f"seq={seq_len}<512")
+            print(f"[NATIVE #{_call_count[0]}] shape={tuple(query.shape)} reason={','.join(reason) or 'unknown'}")

         return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)

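Note (illustration, not part of the diff): the path exercised by this hunk is the replace_sdpa() monkey-patch. The sketch below assumes Apple Silicon and that, as the function name and the patched_sdpa closure suggest, replace_sdpa() installs the patched function over F.scaled_dot_product_attention; treat it as a hedged usage sketch.

import os
os.environ["MFA_DEBUG"] = "1"        # print [MFA]/[NATIVE]/[MFA FALLBACK] routing decisions

import torch
import torch.nn.functional as F
import mps_flash_attn

mps_flash_attn.replace_sdpa()

q = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)

# MPS device, no dropout, seq_len >= 512, ndim >= 3 -> routed through the MFA kernel;
# anything else (or any kernel error) falls back to the original SDPA.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)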
@@ -461,13 +645,13 @@ def flash_attention_chunked(
         return _C.forward(query, key, value, is_causal, None, 0)

     # Initialize running statistics for online softmax
-    # m = running max, l = running sum of exp, acc = accumulated output
     device = query.device
     dtype = query.dtype

     # Use float32 for numerical stability of softmax statistics
-
-
+    # running_L: base-2 logsumexp of all attention scores seen so far (-inf means no data yet)
+    # output_acc: weighted combination of outputs (weights sum to 1 after each update)
+    running_L = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
     output_acc = torch.zeros((B, H, seq_len_q, D), device=device, dtype=torch.float32)

     # Process K/V in chunks
@@ -488,51 +672,81 @@
         # - Partial chunk (up to q) if start_idx <= q < end_idx
         # - None of chunk if q < start_idx

-
-
-
-
-
-
-        # chunk_lse shape: (B, H, seq_len_q)
-        # We need to convert logsumexp to (max, sum) for online algorithm
-        chunk_lse = chunk_lse.unsqueeze(-1)  # (B, H, seq_len_q, 1)
+        if is_causal:
+            # Create explicit causal mask for this chunk
+            # Query positions: 0 to seq_len_q-1
+            # Key positions in chunk: start_idx to end_idx-1
+            chunk_len = end_idx - start_idx

-
-
+            # Build mask: mask[q, k_local] = True means DON'T attend
+            # We want to attend when global_k_pos <= q
+            # global_k_pos = start_idx + k_local
+            # So: attend when start_idx + k_local <= q
+            # mask = start_idx + k_local > q

-
-
-
-        # We approximate chunk_max ≈ chunk_lse (valid when exp sum dominates)
+            q_pos = torch.arange(seq_len_q, device=device).view(1, 1, seq_len_q, 1)
+            k_pos = torch.arange(chunk_len, device=device).view(1, 1, 1, chunk_len) + start_idx
+            causal_mask = k_pos > q_pos  # True = masked (don't attend)

-
+            # Expand to batch and heads
+            causal_mask = causal_mask.expand(B, H, seq_len_q, chunk_len)

-
-
+            # Call forward with explicit mask (is_causal=False since we handle it)
+            chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, causal_mask, 0)
+        else:
+            # Non-causal: just process the chunk directly
+            chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, None, 0)

-        #
-        #
-
-        #
-
+        # chunk_L shape: (B, H, seq_len_q)
+        # The kernel returns L = m + log2(l) where:
+        #   m = max(scores * log2(e) / sqrt(D))
+        #   l = sum(exp2(scores * log2(e) / sqrt(D) - m))
+        # This is a base-2 logsumexp: L = log2(sum(exp2(scaled_scores)))
+        chunk_L = chunk_lse.unsqueeze(-1).float()  # (B, H, seq_len_q, 1)

-        #
-
-        correction_new = torch.exp(chunk_max - new_max)
+        # Convert chunk output to float32 for accumulation
+        chunk_out = chunk_out.float()

-        #
-        #
-        #
-
+        # Online softmax algorithm using base-2 representation
+        #
+        # Flash attention returns: chunk_out = softmax(scores) @ V
+        # The output is already normalized. For online combination:
+        #   new_L = log2(2^running_L + 2^chunk_L)
+        #         = max(running_L, chunk_L) + log2(2^(running_L - max) + 2^(chunk_L - max))
+        #
+        # The weights for combining outputs are:
+        #   old_weight = 2^(running_L - new_L)
+        #   new_weight = 2^(chunk_L - new_L)
+        # These weights sum to 1, so: output = old_weight * old_out + new_weight * new_out
+
+        # Compute new base-2 logsumexp
+        max_L = torch.maximum(running_L, chunk_L)
+
+        # Handle -inf case (no previous data)
+        # Use exp2 for base-2 (matches kernel's internal representation)
+        running_exp2 = torch.where(
+            running_L == float('-inf'),
+            torch.zeros_like(running_L),
+            torch.exp2(running_L - max_L)
+        )
+        chunk_exp2 = torch.exp2(chunk_L - max_L)
+        new_L = max_L + torch.log2(running_exp2 + chunk_exp2)
+
+        # Compute correction factors using base-2 exp
+        old_weight = torch.where(
+            running_L == float('-inf'),
+            torch.zeros_like(running_L),
+            torch.exp2(running_L - new_L)
+        )
+        new_weight = torch.exp2(chunk_L - new_L)

         # Update accumulator
-
-
-
+        # Update accumulator
+        output_acc = output_acc * old_weight + chunk_out * new_weight
+        running_L = new_L

-    #
-    output = output_acc
+    # No final normalization needed - weights already sum to 1
+    output = output_acc

     # Convert back to original dtype
     return output.to(dtype)
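Note (illustration, not part of the diff): the base-2 online-softmax update above can be checked on CPU with plain torch. In this sketch chunk_L is computed directly with torch.logsumexp (converted to base 2) instead of coming from _C.forward_with_lse, and the shapes are illustrative.

import math
import torch

torch.manual_seed(0)
B, H, N_q, N_kv, D = 1, 2, 4, 8, 16
q = torch.randn(B, H, N_q, D)
k = torch.randn(B, H, N_kv, D)
v = torch.randn(B, H, N_kv, D)

scale = 1.0 / math.sqrt(D)
scores = (q @ k.transpose(-2, -1)) * scale               # (B, H, N_q, N_kv)
reference = torch.softmax(scores, dim=-1) @ v            # full (unchunked) attention

running_L = torch.full((B, H, N_q, 1), float("-inf"))
output_acc = torch.zeros(B, H, N_q, D)

for start in range(0, N_kv, 4):                          # process K/V in chunks of 4
    s = scores[..., start:start + 4]
    chunk_out = torch.softmax(s, dim=-1) @ v[..., start:start + 4, :]
    chunk_L = torch.logsumexp(s, dim=-1, keepdim=True) / math.log(2)   # base-2 logsumexp

    max_L = torch.maximum(running_L, chunk_L)
    running_exp2 = torch.where(running_L == float("-inf"),
                               torch.zeros_like(running_L),
                               torch.exp2(running_L - max_L))
    new_L = max_L + torch.log2(running_exp2 + torch.exp2(chunk_L - max_L))

    old_weight = torch.where(running_L == float("-inf"),
                             torch.zeros_like(running_L),
                             torch.exp2(running_L - new_L))
    new_weight = torch.exp2(chunk_L - new_L)
    output_acc = output_acc * old_weight + chunk_out * new_weight
    running_L = new_L

print(torch.allclose(output_acc, reference, atol=1e-5))  # True: chunked result matches full attention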
@@ -754,6 +968,11 @@ def flash_attention_fp8(
         scale_factor = scale / default_scale
         query = query * scale_factor

+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     quant_type = QUANT_FP8_E5M2 if use_e5m2 else QUANT_FP8_E4M3
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
@@ -807,6 +1026,11 @@ def flash_attention_int8(
         scale_factor = scale / default_scale
         query = query * scale_factor

+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_INT8, is_causal, attn_mask, window_size
@@ -863,6 +1087,11 @@ def flash_attention_nf4(
         scale_factor = scale / default_scale
         query = query * scale_factor

+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_NF4, is_causal, attn_mask, window_size
@@ -917,6 +1146,11 @@ def flash_attention_quantized(
         scale_factor = scale / default_scale
         query = query * scale_factor

+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         quant_type, is_causal, attn_mask, window_size
mps_flash_attn/csrc/mps_flash_attn.mm
CHANGED

@@ -14,6 +14,8 @@
 #include <dlfcn.h>
 #include <string>
 #include <vector>
+#include <mutex>
+#include <atomic>

 // ============================================================================
 // MFA Bridge Function Types
@@ -67,7 +69,8 @@ static mfa_forward_fn g_mfa_forward = nullptr;
 static mfa_backward_fn g_mfa_backward = nullptr;
 static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
 static void* g_dylib_handle = nullptr;
-static bool g_initialized
+static std::atomic<bool> g_initialized{false};
+static std::mutex g_init_mutex;

 // ============================================================================
 // Load MFA Bridge Library
@@ -141,6 +144,24 @@ static bool load_mfa_bridge() {
     return true;
 }

+// Thread-safe initialization helper
+static void ensure_initialized() {
+    // Fast path: already initialized
+    if (g_initialized.load(std::memory_order_acquire)) {
+        return;
+    }
+    // Slow path: need to initialize with lock
+    std::lock_guard<std::mutex> lock(g_init_mutex);
+    // Double-check after acquiring lock
+    if (!g_initialized.load(std::memory_order_relaxed)) {
+        load_mfa_bridge();
+        if (!g_mfa_init()) {
+            throw std::runtime_error("Failed to initialize MFA");
+        }
+        g_initialized.store(true, std::memory_order_release);
+    }
+}
+
 // ============================================================================
 // Get MTLBuffer from PyTorch MPS Tensor
 // ============================================================================
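Note (illustration, not part of the diff): the new ensure_initialized() is the standard double-checked locking pattern. For readers more at home in Python, the same idea looks roughly like this sketch; the setup function is a hypothetical stand-in for load_mfa_bridge() plus g_mfa_init().

import threading

_initialized = False
_init_lock = threading.Lock()

def _expensive_one_time_setup():
    # hypothetical stand-in for loading the bridge dylib and initializing MFA
    pass

def ensure_initialized():
    global _initialized
    if _initialized:              # fast path: no lock once set up
        return
    with _init_lock:              # slow path: serialize first-time initialization
        if not _initialized:      # double-check after acquiring the lock
            _expensive_one_time_setup()
            _initialized = True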
@@ -359,14 +380,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     const c10::optional<at::Tensor>& attn_mask, // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
     int64_t window_size // 0 = full attention, >0 = sliding window
 ) {
-    //
-
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();

     // Validate inputs
     TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
@@ -562,14 +577,8 @@ at::Tensor mps_flash_attention_forward_with_bias(
     int64_t window_size,
     int64_t bias_repeat_count // >0 means bias repeats every N batches (for window attention)
 ) {
-    //
-
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();

     // Check that v6/v7 API is available
     TORCH_CHECK(g_mfa_create_kernel_v6 || g_mfa_create_kernel_v7,
@@ -735,14 +744,8 @@ at::Tensor mps_flash_attention_forward_quantized(
     const c10::optional<at::Tensor>& attn_mask,
     int64_t window_size
 ) {
-    //
-
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();

     // Check that v4 API is available
     TORCH_CHECK(g_mfa_create_kernel_v4, "Quantized attention requires MFA v4 API (update libMFABridge.dylib)");
@@ -992,14 +995,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
     int64_t window_size, // 0 = full attention, >0 = sliding window
     bool bf16_backward // true = use BF16 intermediates for ~2x faster backward
 ) {
-    //
-
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();

     // Validate inputs
     TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
{mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.2.dist-info}/RECORD
CHANGED

@@ -1,7 +1,7 @@
-mps_flash_attn/_C.cpython-314-darwin.so,sha256=
-mps_flash_attn/__init__.py,sha256=
+mps_flash_attn/_C.cpython-314-darwin.so,sha256=GtWa4KIcynqjbCQYw-uTBpkX5NTcxyKuK0APoGnRlQM,313448
+mps_flash_attn/__init__.py,sha256=u6B_WenZOTk1WMe9u4-PbyPy7pt8NValR5i8Oz6bI-U,49252
 mps_flash_attn/benchmark.py,sha256=qHhvb8Dmh07OEa_iXuPuJSEnRJlrjVF5nKzVwbWypWE,24141
-mps_flash_attn/csrc/mps_flash_attn.mm,sha256=
+mps_flash_attn/csrc/mps_flash_attn.mm,sha256=mR4S8SHLtRiksrmoFH6s2118q662SMNlFU8HmxAE3YY,51204
 mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib,sha256=_oig6f2I6ZxBCKWbJF3ofmZMySm8gB399_M-lD2NOfM,13747
 mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib,sha256=1fsmVvB5EubhN-y6s5CB-eVk_wuO2tfrabiQTwXvJJc,13171
 mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib,sha256=5WKo_yAU-PgmulBUQhnzvt0DZRteVmo4-nc4U-T6G2g,17507

@@ -26,8 +26,8 @@ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e
 mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib,sha256=qyOaQtRVwL_Wc6GGdu6z-ftf0iX84XexuY09-lNLl5o,13747
 mps_flash_attn/kernels/manifest.json,sha256=d5MkE_BjqDQuMNm1jZiwWkQKfB-yfFml3lLSeR-wCLo,1867
 mps_flash_attn/lib/libMFABridge.dylib,sha256=iKgfYISSKMSNt_iXnljjUr_hZZHyCAg2tdS3_ZjmLkc,605696
-mps_flash_attn-0.2.
-mps_flash_attn-0.2.
-mps_flash_attn-0.2.
-mps_flash_attn-0.2.
-mps_flash_attn-0.2.
+mps_flash_attn-0.3.2.dist-info/licenses/LICENSE,sha256=F_XmXSab2O-hHcqLpYJWeFaqB6GA_qiTEN23p2VfZWU,1237
+mps_flash_attn-0.3.2.dist-info/METADATA,sha256=vhcu8d8NdzmuQbOqVUpzacXJF__Eu-BW1C7Em_CNoyg,5834
+mps_flash_attn-0.3.2.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
+mps_flash_attn-0.3.2.dist-info/top_level.txt,sha256=zbArDcWhJDnJfMUKnOUhs5TjsMgSxa2GzOlscTRfobE,15
+mps_flash_attn-0.3.2.dist-info/RECORD,,

{mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.2.dist-info}/WHEEL: File without changes
{mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.2.dist-info}/licenses/LICENSE: File without changes
{mps_flash_attn-0.2.8.dist-info → mps_flash_attn-0.3.2.dist-info}/top_level.txt: File without changes