mps-flash-attn 0.2.9__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mps-flash-attn might be problematic. Click here for more details.
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/PKG-INFO +1 -1
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/__init__.py +167 -43
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/csrc/mps_flash_attn.mm +30 -33
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn.egg-info/SOURCES.txt +1 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/pyproject.toml +1 -1
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/setup.py +1 -1
- mps_flash_attn-0.3.0/tests/test_issues.py +446 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/LICENSE +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/README.md +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/setup.cfg +0 -0
- {mps_flash_attn-0.2.9 → mps_flash_attn-0.3.0}/tests/test_mfa_v2.py +0 -0
|
@@ -4,13 +4,42 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
|
|
|
4
4
|
This package provides memory-efficient attention using Metal Flash Attention kernels.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
__version__ = "0.
|
|
7
|
+
__version__ = "0.3.0"
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
# Core functions
|
|
11
|
+
"flash_attention",
|
|
12
|
+
"flash_attention_with_bias",
|
|
13
|
+
"flash_attention_chunked",
|
|
14
|
+
# Quantized attention
|
|
15
|
+
"flash_attention_fp8",
|
|
16
|
+
"flash_attention_int8",
|
|
17
|
+
"flash_attention_nf4",
|
|
18
|
+
"quantize_kv_fp8",
|
|
19
|
+
"quantize_kv_int8",
|
|
20
|
+
"quantize_kv_nf4",
|
|
21
|
+
# Utilities
|
|
22
|
+
"replace_sdpa",
|
|
23
|
+
"precompile",
|
|
24
|
+
"clear_cache",
|
|
25
|
+
"register_custom_op",
|
|
26
|
+
"is_available",
|
|
27
|
+
"convert_mask",
|
|
28
|
+
# Constants
|
|
29
|
+
"QUANT_FP8_E4M3",
|
|
30
|
+
"QUANT_FP8_E5M2",
|
|
31
|
+
"QUANT_INT8",
|
|
32
|
+
"QUANT_NF4",
|
|
33
|
+
# Version
|
|
34
|
+
"__version__",
|
|
35
|
+
]
|
|
8
36
|
|
|
9
37
|
import torch
|
|
10
|
-
from typing import Optional
|
|
38
|
+
from typing import Optional, Tuple
|
|
11
39
|
import math
|
|
12
40
|
import threading
|
|
13
41
|
import os
|
|
42
|
+
import warnings
|
|
14
43
|
|
|
15
44
|
# Try to import the C++ extension
|
|
16
45
|
try:
|
|
@@ -30,6 +59,20 @@ def is_available() -> bool:
|
|
|
30
59
|
return _HAS_MFA and torch.backends.mps.is_available()
|
|
31
60
|
|
|
32
61
|
|
|
62
|
+
def _ensure_contiguous(tensor: torch.Tensor, name: str) -> torch.Tensor:
|
|
63
|
+
"""Ensure tensor is contiguous, with a debug warning if conversion needed."""
|
|
64
|
+
if tensor.is_contiguous():
|
|
65
|
+
return tensor
|
|
66
|
+
# Auto-convert with debug info
|
|
67
|
+
if os.environ.get("MFA_DEBUG", "0") == "1":
|
|
68
|
+
warnings.warn(
|
|
69
|
+
f"MFA: {name} tensor was not contiguous (stride={tensor.stride()}), "
|
|
70
|
+
f"auto-converting. For best performance, ensure inputs are contiguous.",
|
|
71
|
+
UserWarning
|
|
72
|
+
)
|
|
73
|
+
return tensor.contiguous()
|
|
74
|
+
|
|
75
|
+
|
|
33
76
|
def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
|
34
77
|
"""
|
|
35
78
|
Convert attention mask to MFA's boolean format.
|
|
@@ -176,6 +219,20 @@ def flash_attention(
|
|
|
176
219
|
if not torch.backends.mps.is_available():
|
|
177
220
|
raise RuntimeError("MPS not available")
|
|
178
221
|
|
|
222
|
+
# Validate scale parameter
|
|
223
|
+
if scale is not None:
|
|
224
|
+
if scale <= 0:
|
|
225
|
+
raise ValueError(f"scale must be positive, got {scale}")
|
|
226
|
+
# Warn about extreme scale values that could cause numerical issues
|
|
227
|
+
default_scale = 1.0 / math.sqrt(query.shape[-1])
|
|
228
|
+
if scale < default_scale * 0.01 or scale > default_scale * 100:
|
|
229
|
+
warnings.warn(
|
|
230
|
+
f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
|
|
231
|
+
"this may cause numerical issues",
|
|
232
|
+
UserWarning,
|
|
233
|
+
stacklevel=2
|
|
234
|
+
)
|
|
235
|
+
|
|
179
236
|
# Validate device
|
|
180
237
|
if query.device.type != 'mps':
|
|
181
238
|
raise ValueError("query must be on MPS device")
|
|
@@ -186,6 +243,31 @@ def flash_attention(
|
|
|
186
243
|
if attn_mask is not None and attn_mask.device.type != 'mps':
|
|
187
244
|
raise ValueError("attn_mask must be on MPS device")
|
|
188
245
|
|
|
246
|
+
# Ensure contiguous (auto-convert with debug warning)
|
|
247
|
+
query = _ensure_contiguous(query, "query")
|
|
248
|
+
key = _ensure_contiguous(key, "key")
|
|
249
|
+
value = _ensure_contiguous(value, "value")
|
|
250
|
+
if attn_mask is not None:
|
|
251
|
+
attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
|
|
252
|
+
# Validate mask shape
|
|
253
|
+
B, H, N_q, D = query.shape
|
|
254
|
+
N_kv = key.shape[2]
|
|
255
|
+
if attn_mask.dim() != 4:
|
|
256
|
+
raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
|
|
257
|
+
mb, mh, mq, mk = attn_mask.shape
|
|
258
|
+
if mq != N_q or mk != N_kv:
|
|
259
|
+
raise ValueError(
|
|
260
|
+
f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv})"
|
|
261
|
+
)
|
|
262
|
+
if mb != 1 and mb != B:
|
|
263
|
+
raise ValueError(
|
|
264
|
+
f"attn_mask batch size must be 1 or {B}, got {mb}"
|
|
265
|
+
)
|
|
266
|
+
if mh != 1 and mh != H:
|
|
267
|
+
raise ValueError(
|
|
268
|
+
f"attn_mask head count must be 1 or {H}, got {mh}"
|
|
269
|
+
)
|
|
270
|
+
|
|
189
271
|
# Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
|
|
190
272
|
if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
|
|
191
273
|
# Apply scale if provided
|
|
@@ -215,6 +297,8 @@ def replace_sdpa():
|
|
|
215
297
|
original_sdpa = F.scaled_dot_product_attention
|
|
216
298
|
_debug = os.environ.get("MFA_DEBUG", "0") == "1"
|
|
217
299
|
_call_count = [0] # mutable for closure
|
|
300
|
+
_fallback_count = [0] # track fallbacks for warning
|
|
301
|
+
_last_fallback_error = [None]
|
|
218
302
|
|
|
219
303
|
def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
|
|
220
304
|
is_causal=False, scale=None, enable_gqa=False, **kwargs):
|
|
@@ -290,10 +374,20 @@ def replace_sdpa():
|
|
|
290
374
|
|
|
291
375
|
return out
|
|
292
376
|
except Exception as e:
|
|
293
|
-
# Fall back to original on any error
|
|
377
|
+
# Fall back to original on any error, but track it
|
|
378
|
+
_fallback_count[0] += 1
|
|
379
|
+
_last_fallback_error[0] = str(e)
|
|
294
380
|
if _debug:
|
|
295
|
-
|
|
296
|
-
|
|
381
|
+
import traceback
|
|
382
|
+
print(f"[MFA FALLBACK #{_fallback_count[0]}] shape={tuple(query.shape)}\n{traceback.format_exc()}")
|
|
383
|
+
# Warn user after repeated fallbacks (likely a real problem)
|
|
384
|
+
if _fallback_count[0] == 10:
|
|
385
|
+
warnings.warn(
|
|
386
|
+
f"MFA has fallen back to native SDPA {_fallback_count[0]} times. "
|
|
387
|
+
f"Last error: {_last_fallback_error[0]}. "
|
|
388
|
+
f"Set MFA_DEBUG=1 for details.",
|
|
389
|
+
UserWarning
|
|
390
|
+
)
|
|
297
391
|
|
|
298
392
|
if _debug and query.device.type == 'mps':
|
|
299
393
|
_call_count[0] += 1
|
|
@@ -508,13 +602,13 @@ def flash_attention_chunked(
|
|
|
508
602
|
return _C.forward(query, key, value, is_causal, None, 0)
|
|
509
603
|
|
|
510
604
|
# Initialize running statistics for online softmax
|
|
511
|
-
# m = running max, l = running sum of exp, acc = accumulated output
|
|
512
605
|
device = query.device
|
|
513
606
|
dtype = query.dtype
|
|
514
607
|
|
|
515
608
|
# Use float32 for numerical stability of softmax statistics
|
|
516
|
-
|
|
517
|
-
|
|
609
|
+
# running_L: base-2 logsumexp of all attention scores seen so far (-inf means no data yet)
|
|
610
|
+
# output_acc: weighted combination of outputs (weights sum to 1 after each update)
|
|
611
|
+
running_L = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
|
|
518
612
|
output_acc = torch.zeros((B, H, seq_len_q, D), device=device, dtype=torch.float32)
|
|
519
613
|
|
|
520
614
|
# Process K/V in chunks
|
|
@@ -535,51 +629,81 @@ def flash_attention_chunked(
|
|
|
535
629
|
# - Partial chunk (up to q) if start_idx <= q < end_idx
|
|
536
630
|
# - None of chunk if q < start_idx
|
|
537
631
|
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
# chunk_lse shape: (B, H, seq_len_q)
|
|
545
|
-
# We need to convert logsumexp to (max, sum) for online algorithm
|
|
546
|
-
chunk_lse = chunk_lse.unsqueeze(-1) # (B, H, seq_len_q, 1)
|
|
632
|
+
if is_causal:
|
|
633
|
+
# Create explicit causal mask for this chunk
|
|
634
|
+
# Query positions: 0 to seq_len_q-1
|
|
635
|
+
# Key positions in chunk: start_idx to end_idx-1
|
|
636
|
+
chunk_len = end_idx - start_idx
|
|
547
637
|
|
|
548
|
-
|
|
549
|
-
|
|
638
|
+
# Build mask: mask[q, k_local] = True means DON'T attend
|
|
639
|
+
# We want to attend when global_k_pos <= q
|
|
640
|
+
# global_k_pos = start_idx + k_local
|
|
641
|
+
# So: attend when start_idx + k_local <= q
|
|
642
|
+
# mask = start_idx + k_local > q
|
|
550
643
|
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
# We approximate chunk_max ≈ chunk_lse (valid when exp sum dominates)
|
|
644
|
+
q_pos = torch.arange(seq_len_q, device=device).view(1, 1, seq_len_q, 1)
|
|
645
|
+
k_pos = torch.arange(chunk_len, device=device).view(1, 1, 1, chunk_len) + start_idx
|
|
646
|
+
causal_mask = k_pos > q_pos # True = masked (don't attend)
|
|
555
647
|
|
|
556
|
-
|
|
648
|
+
# Expand to batch and heads
|
|
649
|
+
causal_mask = causal_mask.expand(B, H, seq_len_q, chunk_len)
|
|
557
650
|
|
|
558
|
-
|
|
559
|
-
|
|
651
|
+
# Call forward with explicit mask (is_causal=False since we handle it)
|
|
652
|
+
chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, causal_mask, 0)
|
|
653
|
+
else:
|
|
654
|
+
# Non-causal: just process the chunk directly
|
|
655
|
+
chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, None, 0)
|
|
560
656
|
|
|
561
|
-
#
|
|
562
|
-
#
|
|
563
|
-
|
|
564
|
-
#
|
|
565
|
-
|
|
657
|
+
# chunk_L shape: (B, H, seq_len_q)
|
|
658
|
+
# The kernel returns L = m + log2(l) where:
|
|
659
|
+
# m = max(scores * log2(e) / sqrt(D))
|
|
660
|
+
# l = sum(exp2(scores * log2(e) / sqrt(D) - m))
|
|
661
|
+
# This is a base-2 logsumexp: L = log2(sum(exp2(scaled_scores)))
|
|
662
|
+
chunk_L = chunk_lse.unsqueeze(-1).float() # (B, H, seq_len_q, 1)
|
|
566
663
|
|
|
567
|
-
#
|
|
568
|
-
|
|
569
|
-
correction_new = torch.exp(chunk_max - new_max)
|
|
664
|
+
# Convert chunk output to float32 for accumulation
|
|
665
|
+
chunk_out = chunk_out.float()
|
|
570
666
|
|
|
571
|
-
#
|
|
572
|
-
#
|
|
573
|
-
#
|
|
574
|
-
|
|
667
|
+
# Online softmax algorithm using base-2 representation
|
|
668
|
+
#
|
|
669
|
+
# Flash attention returns: chunk_out = softmax(scores) @ V
|
|
670
|
+
# The output is already normalized. For online combination:
|
|
671
|
+
# new_L = log2(2^running_L + 2^chunk_L)
|
|
672
|
+
# = max(running_L, chunk_L) + log2(2^(running_L - max) + 2^(chunk_L - max))
|
|
673
|
+
#
|
|
674
|
+
# The weights for combining outputs are:
|
|
675
|
+
# old_weight = 2^(running_L - new_L)
|
|
676
|
+
# new_weight = 2^(chunk_L - new_L)
|
|
677
|
+
# These weights sum to 1, so: output = old_weight * old_out + new_weight * new_out
|
|
678
|
+
|
|
679
|
+
# Compute new base-2 logsumexp
|
|
680
|
+
max_L = torch.maximum(running_L, chunk_L)
|
|
681
|
+
|
|
682
|
+
# Handle -inf case (no previous data)
|
|
683
|
+
# Use exp2 for base-2 (matches kernel's internal representation)
|
|
684
|
+
running_exp2 = torch.where(
|
|
685
|
+
running_L == float('-inf'),
|
|
686
|
+
torch.zeros_like(running_L),
|
|
687
|
+
torch.exp2(running_L - max_L)
|
|
688
|
+
)
|
|
689
|
+
chunk_exp2 = torch.exp2(chunk_L - max_L)
|
|
690
|
+
new_L = max_L + torch.log2(running_exp2 + chunk_exp2)
|
|
691
|
+
|
|
692
|
+
# Compute correction factors using base-2 exp
|
|
693
|
+
old_weight = torch.where(
|
|
694
|
+
running_L == float('-inf'),
|
|
695
|
+
torch.zeros_like(running_L),
|
|
696
|
+
torch.exp2(running_L - new_L)
|
|
697
|
+
)
|
|
698
|
+
new_weight = torch.exp2(chunk_L - new_L)
|
|
575
699
|
|
|
576
700
|
# Update accumulator
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
701
|
+
# Update accumulator
|
|
702
|
+
output_acc = output_acc * old_weight + chunk_out * new_weight
|
|
703
|
+
running_L = new_L
|
|
580
704
|
|
|
581
|
-
#
|
|
582
|
-
output = output_acc
|
|
705
|
+
# No final normalization needed - weights already sum to 1
|
|
706
|
+
output = output_acc
|
|
583
707
|
|
|
584
708
|
# Convert back to original dtype
|
|
585
709
|
return output.to(dtype)
|
|
@@ -14,6 +14,8 @@
|
|
|
14
14
|
#include <dlfcn.h>
|
|
15
15
|
#include <string>
|
|
16
16
|
#include <vector>
|
|
17
|
+
#include <mutex>
|
|
18
|
+
#include <atomic>
|
|
17
19
|
|
|
18
20
|
// ============================================================================
|
|
19
21
|
// MFA Bridge Function Types
|
|
@@ -67,7 +69,8 @@ static mfa_forward_fn g_mfa_forward = nullptr;
|
|
|
67
69
|
static mfa_backward_fn g_mfa_backward = nullptr;
|
|
68
70
|
static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
|
|
69
71
|
static void* g_dylib_handle = nullptr;
|
|
70
|
-
static bool g_initialized
|
|
72
|
+
static std::atomic<bool> g_initialized{false};
|
|
73
|
+
static std::mutex g_init_mutex;
|
|
71
74
|
|
|
72
75
|
// ============================================================================
|
|
73
76
|
// Load MFA Bridge Library
|
|
@@ -141,6 +144,24 @@ static bool load_mfa_bridge() {
|
|
|
141
144
|
return true;
|
|
142
145
|
}
|
|
143
146
|
|
|
147
|
+
// Thread-safe initialization helper
|
|
148
|
+
static void ensure_initialized() {
|
|
149
|
+
// Fast path: already initialized
|
|
150
|
+
if (g_initialized.load(std::memory_order_acquire)) {
|
|
151
|
+
return;
|
|
152
|
+
}
|
|
153
|
+
// Slow path: need to initialize with lock
|
|
154
|
+
std::lock_guard<std::mutex> lock(g_init_mutex);
|
|
155
|
+
// Double-check after acquiring lock
|
|
156
|
+
if (!g_initialized.load(std::memory_order_relaxed)) {
|
|
157
|
+
load_mfa_bridge();
|
|
158
|
+
if (!g_mfa_init()) {
|
|
159
|
+
throw std::runtime_error("Failed to initialize MFA");
|
|
160
|
+
}
|
|
161
|
+
g_initialized.store(true, std::memory_order_release);
|
|
162
|
+
}
|
|
163
|
+
}
|
|
164
|
+
|
|
144
165
|
// ============================================================================
|
|
145
166
|
// Get MTLBuffer from PyTorch MPS Tensor
|
|
146
167
|
// ============================================================================
|
|
@@ -359,14 +380,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
|
|
|
359
380
|
const c10::optional<at::Tensor>& attn_mask, // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
|
|
360
381
|
int64_t window_size // 0 = full attention, >0 = sliding window
|
|
361
382
|
) {
|
|
362
|
-
//
|
|
363
|
-
|
|
364
|
-
load_mfa_bridge();
|
|
365
|
-
if (!g_mfa_init()) {
|
|
366
|
-
throw std::runtime_error("Failed to initialize MFA");
|
|
367
|
-
}
|
|
368
|
-
g_initialized = true;
|
|
369
|
-
}
|
|
383
|
+
// Thread-safe initialization
|
|
384
|
+
ensure_initialized();
|
|
370
385
|
|
|
371
386
|
// Validate inputs
|
|
372
387
|
TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
|
|
@@ -562,14 +577,8 @@ at::Tensor mps_flash_attention_forward_with_bias(
|
|
|
562
577
|
int64_t window_size,
|
|
563
578
|
int64_t bias_repeat_count // >0 means bias repeats every N batches (for window attention)
|
|
564
579
|
) {
|
|
565
|
-
//
|
|
566
|
-
|
|
567
|
-
load_mfa_bridge();
|
|
568
|
-
if (!g_mfa_init()) {
|
|
569
|
-
throw std::runtime_error("Failed to initialize MFA");
|
|
570
|
-
}
|
|
571
|
-
g_initialized = true;
|
|
572
|
-
}
|
|
580
|
+
// Thread-safe initialization
|
|
581
|
+
ensure_initialized();
|
|
573
582
|
|
|
574
583
|
// Check that v6/v7 API is available
|
|
575
584
|
TORCH_CHECK(g_mfa_create_kernel_v6 || g_mfa_create_kernel_v7,
|
|
@@ -735,14 +744,8 @@ at::Tensor mps_flash_attention_forward_quantized(
|
|
|
735
744
|
const c10::optional<at::Tensor>& attn_mask,
|
|
736
745
|
int64_t window_size
|
|
737
746
|
) {
|
|
738
|
-
//
|
|
739
|
-
|
|
740
|
-
load_mfa_bridge();
|
|
741
|
-
if (!g_mfa_init()) {
|
|
742
|
-
throw std::runtime_error("Failed to initialize MFA");
|
|
743
|
-
}
|
|
744
|
-
g_initialized = true;
|
|
745
|
-
}
|
|
747
|
+
// Thread-safe initialization
|
|
748
|
+
ensure_initialized();
|
|
746
749
|
|
|
747
750
|
// Check that v4 API is available
|
|
748
751
|
TORCH_CHECK(g_mfa_create_kernel_v4, "Quantized attention requires MFA v4 API (update libMFABridge.dylib)");
|
|
@@ -992,14 +995,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
|
|
|
992
995
|
int64_t window_size, // 0 = full attention, >0 = sliding window
|
|
993
996
|
bool bf16_backward // true = use BF16 intermediates for ~2x faster backward
|
|
994
997
|
) {
|
|
995
|
-
//
|
|
996
|
-
|
|
997
|
-
load_mfa_bridge();
|
|
998
|
-
if (!g_mfa_init()) {
|
|
999
|
-
throw std::runtime_error("Failed to initialize MFA");
|
|
1000
|
-
}
|
|
1001
|
-
g_initialized = true;
|
|
1002
|
-
}
|
|
998
|
+
// Thread-safe initialization
|
|
999
|
+
ensure_initialized();
|
|
1003
1000
|
|
|
1004
1001
|
// Validate inputs
|
|
1005
1002
|
TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
|
|
@@ -34,4 +34,5 @@ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e
|
|
|
34
34
|
mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib
|
|
35
35
|
mps_flash_attn/kernels/manifest.json
|
|
36
36
|
mps_flash_attn/lib/libMFABridge.dylib
|
|
37
|
+
tests/test_issues.py
|
|
37
38
|
tests/test_mfa_v2.py
|
|
@@ -0,0 +1,446 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for known issues and fixes in mps-flash-attention.
|
|
3
|
+
|
|
4
|
+
These tests verify that:
|
|
5
|
+
1. Version is consistent between setup.py and __init__.py
|
|
6
|
+
2. __all__ exports are properly defined
|
|
7
|
+
3. Scale validation works correctly
|
|
8
|
+
4. Chunked attention accuracy is within expected bounds
|
|
9
|
+
5. All exported symbols are importable
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import pytest
|
|
13
|
+
import torch
|
|
14
|
+
import math
|
|
15
|
+
import warnings
|
|
16
|
+
import re
|
|
17
|
+
import sys
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Get the version from setup.py for comparison
|
|
22
|
+
def get_setup_version():
|
|
23
|
+
setup_path = Path(__file__).parent.parent / "setup.py"
|
|
24
|
+
content = setup_path.read_text()
|
|
25
|
+
match = re.search(r'version="([^"]+)"', content)
|
|
26
|
+
return match.group(1) if match else None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TestVersionConsistency:
|
|
30
|
+
"""Test that version is consistent across the package."""
|
|
31
|
+
|
|
32
|
+
def test_version_matches_setup(self):
|
|
33
|
+
"""Version in __init__.py should match setup.py."""
|
|
34
|
+
import mps_flash_attn
|
|
35
|
+
setup_version = get_setup_version()
|
|
36
|
+
assert setup_version is not None, "Could not parse version from setup.py"
|
|
37
|
+
assert mps_flash_attn.__version__ == setup_version, (
|
|
38
|
+
f"Version mismatch: __init__.py has {mps_flash_attn.__version__}, "
|
|
39
|
+
f"setup.py has {setup_version}"
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class TestExports:
|
|
44
|
+
"""Test that __all__ exports are properly defined."""
|
|
45
|
+
|
|
46
|
+
def test_all_is_defined(self):
|
|
47
|
+
"""__all__ should be defined."""
|
|
48
|
+
import mps_flash_attn
|
|
49
|
+
assert hasattr(mps_flash_attn, "__all__"), "__all__ not defined"
|
|
50
|
+
assert isinstance(mps_flash_attn.__all__, list), "__all__ should be a list"
|
|
51
|
+
assert len(mps_flash_attn.__all__) > 0, "__all__ should not be empty"
|
|
52
|
+
|
|
53
|
+
def test_all_symbols_exist(self):
|
|
54
|
+
"""All symbols in __all__ should be importable."""
|
|
55
|
+
import mps_flash_attn
|
|
56
|
+
missing = []
|
|
57
|
+
for name in mps_flash_attn.__all__:
|
|
58
|
+
if not hasattr(mps_flash_attn, name):
|
|
59
|
+
missing.append(name)
|
|
60
|
+
assert not missing, f"Missing symbols from __all__: {missing}"
|
|
61
|
+
|
|
62
|
+
def test_core_functions_exported(self):
|
|
63
|
+
"""Core functions should be in __all__."""
|
|
64
|
+
import mps_flash_attn
|
|
65
|
+
required = [
|
|
66
|
+
"flash_attention",
|
|
67
|
+
"flash_attention_chunked",
|
|
68
|
+
"is_available",
|
|
69
|
+
"__version__",
|
|
70
|
+
]
|
|
71
|
+
for name in required:
|
|
72
|
+
assert name in mps_flash_attn.__all__, f"{name} missing from __all__"
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class TestScaleValidation:
|
|
76
|
+
"""Test scale parameter validation."""
|
|
77
|
+
|
|
78
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
79
|
+
def test_negative_scale_raises(self):
|
|
80
|
+
"""Negative scale should raise ValueError."""
|
|
81
|
+
import mps_flash_attn
|
|
82
|
+
if not mps_flash_attn.is_available():
|
|
83
|
+
pytest.skip("MFA not available")
|
|
84
|
+
|
|
85
|
+
q = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
86
|
+
k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
87
|
+
v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
88
|
+
|
|
89
|
+
with pytest.raises(ValueError, match="scale must be positive"):
|
|
90
|
+
mps_flash_attn.flash_attention(q, k, v, scale=-1.0)
|
|
91
|
+
|
|
92
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
93
|
+
def test_zero_scale_raises(self):
|
|
94
|
+
"""Zero scale should raise ValueError."""
|
|
95
|
+
import mps_flash_attn
|
|
96
|
+
if not mps_flash_attn.is_available():
|
|
97
|
+
pytest.skip("MFA not available")
|
|
98
|
+
|
|
99
|
+
q = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
100
|
+
k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
101
|
+
v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
102
|
+
|
|
103
|
+
with pytest.raises(ValueError, match="scale must be positive"):
|
|
104
|
+
mps_flash_attn.flash_attention(q, k, v, scale=0.0)
|
|
105
|
+
|
|
106
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
107
|
+
def test_extreme_scale_warns(self):
|
|
108
|
+
"""Extreme scale values should warn."""
|
|
109
|
+
import mps_flash_attn
|
|
110
|
+
if not mps_flash_attn.is_available():
|
|
111
|
+
pytest.skip("MFA not available")
|
|
112
|
+
|
|
113
|
+
q = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
114
|
+
k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
115
|
+
v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
116
|
+
|
|
117
|
+
# Very large scale (1000x default)
|
|
118
|
+
with warnings.catch_warnings(record=True) as w:
|
|
119
|
+
warnings.simplefilter("always")
|
|
120
|
+
mps_flash_attn.flash_attention(q, k, v, scale=100.0) # default is ~0.177
|
|
121
|
+
assert any("very different from default" in str(warning.message) for warning in w), \
|
|
122
|
+
"Should warn about extreme scale"
|
|
123
|
+
|
|
124
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
125
|
+
def test_normal_scale_no_warning(self):
|
|
126
|
+
"""Normal scale values should not warn."""
|
|
127
|
+
import mps_flash_attn
|
|
128
|
+
if not mps_flash_attn.is_available():
|
|
129
|
+
pytest.skip("MFA not available")
|
|
130
|
+
|
|
131
|
+
q = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
132
|
+
k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
133
|
+
v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
134
|
+
|
|
135
|
+
default_scale = 1.0 / math.sqrt(32)
|
|
136
|
+
with warnings.catch_warnings(record=True) as w:
|
|
137
|
+
warnings.simplefilter("always")
|
|
138
|
+
mps_flash_attn.flash_attention(q, k, v, scale=default_scale)
|
|
139
|
+
scale_warnings = [x for x in w if "scale" in str(x.message).lower()]
|
|
140
|
+
assert not scale_warnings, f"Should not warn about normal scale: {scale_warnings}"
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class TestChunkedAccuracy:
|
|
144
|
+
"""Test chunked attention accuracy vs non-chunked."""
|
|
145
|
+
|
|
146
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
147
|
+
def test_chunked_vs_regular_accuracy(self):
|
|
148
|
+
"""Chunked attention should match regular within tolerance.
|
|
149
|
+
|
|
150
|
+
Note: Online softmax approximation has inherent numerical error.
|
|
151
|
+
The max difference is typically ~0.05-0.06 due to the approximation
|
|
152
|
+
at line 556 where chunk_max = chunk_lse (simplified online softmax).
|
|
153
|
+
"""
|
|
154
|
+
import mps_flash_attn
|
|
155
|
+
if not mps_flash_attn.is_available():
|
|
156
|
+
pytest.skip("MFA not available")
|
|
157
|
+
|
|
158
|
+
torch.manual_seed(42)
|
|
159
|
+
# Use a moderate size to keep test fast
|
|
160
|
+
B, H, N, D = 2, 4, 2048, 64
|
|
161
|
+
|
|
162
|
+
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
163
|
+
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
164
|
+
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
165
|
+
|
|
166
|
+
# Regular flash attention
|
|
167
|
+
out_regular = mps_flash_attn.flash_attention(q, k, v)
|
|
168
|
+
|
|
169
|
+
# Chunked flash attention with smaller chunk
|
|
170
|
+
out_chunked = mps_flash_attn.flash_attention_chunked(q, k, v, chunk_size=512)
|
|
171
|
+
|
|
172
|
+
torch.mps.synchronize()
|
|
173
|
+
|
|
174
|
+
# Compute differences
|
|
175
|
+
max_diff = (out_regular - out_chunked).abs().max().item()
|
|
176
|
+
mean_diff = (out_regular - out_chunked).abs().mean().item()
|
|
177
|
+
|
|
178
|
+
# Online softmax has inherent ~0.05 max error
|
|
179
|
+
# This is NOT a regression - it's expected behavior
|
|
180
|
+
assert max_diff < 0.10, f"Max diff {max_diff:.6f} too large (expected < 0.10)"
|
|
181
|
+
assert mean_diff < 0.01, f"Mean diff {mean_diff:.6f} too large (expected < 0.01)"
|
|
182
|
+
|
|
183
|
+
# Log actual values for transparency
|
|
184
|
+
print(f"\nChunked accuracy: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
|
|
185
|
+
|
|
186
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
187
|
+
def test_chunked_causal_accuracy(self):
|
|
188
|
+
"""Chunked causal attention should also be accurate."""
|
|
189
|
+
import mps_flash_attn
|
|
190
|
+
if not mps_flash_attn.is_available():
|
|
191
|
+
pytest.skip("MFA not available")
|
|
192
|
+
|
|
193
|
+
torch.manual_seed(123)
|
|
194
|
+
B, H, N, D = 2, 4, 1024, 64
|
|
195
|
+
|
|
196
|
+
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
197
|
+
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
198
|
+
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
199
|
+
|
|
200
|
+
out_regular = mps_flash_attn.flash_attention(q, k, v, is_causal=True)
|
|
201
|
+
out_chunked = mps_flash_attn.flash_attention_chunked(q, k, v, is_causal=True, chunk_size=256)
|
|
202
|
+
|
|
203
|
+
torch.mps.synchronize()
|
|
204
|
+
|
|
205
|
+
max_diff = (out_regular - out_chunked).abs().max().item()
|
|
206
|
+
mean_diff = (out_regular - out_chunked).abs().mean().item()
|
|
207
|
+
|
|
208
|
+
assert max_diff < 0.10, f"Causal max diff {max_diff:.6f} too large"
|
|
209
|
+
assert mean_diff < 0.01, f"Causal mean diff {mean_diff:.6f} too large"
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class TestInputValidation:
|
|
213
|
+
"""Test input validation catches errors early."""
|
|
214
|
+
|
|
215
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
216
|
+
def test_wrong_device_raises(self):
|
|
217
|
+
"""Tensors on wrong device should raise."""
|
|
218
|
+
import mps_flash_attn
|
|
219
|
+
if not mps_flash_attn.is_available():
|
|
220
|
+
pytest.skip("MFA not available")
|
|
221
|
+
|
|
222
|
+
q = torch.randn(1, 4, 64, 32) # CPU tensor
|
|
223
|
+
k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
224
|
+
v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
225
|
+
|
|
226
|
+
with pytest.raises((RuntimeError, ValueError)):
|
|
227
|
+
mps_flash_attn.flash_attention(q, k, v)
|
|
228
|
+
|
|
229
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
230
|
+
def test_wrong_dims_raises(self):
|
|
231
|
+
"""Wrong tensor dimensions should raise."""
|
|
232
|
+
import mps_flash_attn
|
|
233
|
+
if not mps_flash_attn.is_available():
|
|
234
|
+
pytest.skip("MFA not available")
|
|
235
|
+
|
|
236
|
+
q = torch.randn(1, 64, 32, device="mps", dtype=torch.float16) # 3D instead of 4D
|
|
237
|
+
k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
238
|
+
v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
|
|
239
|
+
|
|
240
|
+
with pytest.raises(RuntimeError):
|
|
241
|
+
mps_flash_attn.flash_attention(q, k, v)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class TestMaskValidation:
|
|
245
|
+
"""Test attention mask shape validation."""
|
|
246
|
+
|
|
247
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
248
|
+
def test_valid_mask_shape(self):
|
|
249
|
+
"""Valid mask shapes should work."""
|
|
250
|
+
import mps_flash_attn
|
|
251
|
+
if not mps_flash_attn.is_available():
|
|
252
|
+
pytest.skip("MFA not available")
|
|
253
|
+
|
|
254
|
+
B, H, N, D = 2, 4, 64, 32
|
|
255
|
+
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
256
|
+
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
257
|
+
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
258
|
+
|
|
259
|
+
# Full mask
|
|
260
|
+
mask = torch.zeros(B, H, N, N, dtype=torch.bool, device="mps")
|
|
261
|
+
out = mps_flash_attn.flash_attention(q, k, v, attn_mask=mask)
|
|
262
|
+
assert out.shape == (B, H, N, D)
|
|
263
|
+
|
|
264
|
+
# Broadcast over batch
|
|
265
|
+
mask = torch.zeros(1, H, N, N, dtype=torch.bool, device="mps")
|
|
266
|
+
out = mps_flash_attn.flash_attention(q, k, v, attn_mask=mask)
|
|
267
|
+
assert out.shape == (B, H, N, D)
|
|
268
|
+
|
|
269
|
+
# Broadcast over heads
|
|
270
|
+
mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device="mps")
|
|
271
|
+
out = mps_flash_attn.flash_attention(q, k, v, attn_mask=mask)
|
|
272
|
+
assert out.shape == (B, H, N, D)
|
|
273
|
+
|
|
274
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
275
|
+
def test_invalid_mask_dims_raises(self):
|
|
276
|
+
"""Wrong mask dimensions should raise."""
|
|
277
|
+
import mps_flash_attn
|
|
278
|
+
if not mps_flash_attn.is_available():
|
|
279
|
+
pytest.skip("MFA not available")
|
|
280
|
+
|
|
281
|
+
B, H, N, D = 2, 4, 64, 32
|
|
282
|
+
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
283
|
+
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
284
|
+
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
285
|
+
|
|
286
|
+
# 3D mask (missing batch dim)
|
|
287
|
+
mask = torch.zeros(H, N, N, dtype=torch.bool, device="mps")
|
|
288
|
+
with pytest.raises(ValueError, match="must be 4D"):
|
|
289
|
+
mps_flash_attn.flash_attention(q, k, v, attn_mask=mask)
|
|
290
|
+
|
|
291
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
292
|
+
def test_invalid_mask_seq_len_raises(self):
|
|
293
|
+
"""Wrong mask sequence length should raise."""
|
|
294
|
+
import mps_flash_attn
|
|
295
|
+
if not mps_flash_attn.is_available():
|
|
296
|
+
pytest.skip("MFA not available")
|
|
297
|
+
|
|
298
|
+
B, H, N, D = 2, 4, 64, 32
|
|
299
|
+
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
300
|
+
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
301
|
+
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
|
|
302
|
+
|
|
303
|
+
# Wrong N_kv
|
|
304
|
+
mask = torch.zeros(B, H, N, N // 2, dtype=torch.bool, device="mps")
|
|
305
|
+
with pytest.raises(ValueError, match="shape mismatch"):
|
|
306
|
+
mps_flash_attn.flash_attention(q, k, v, attn_mask=mask)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
class TestAutoContiguous:
|
|
310
|
+
"""Test auto-contiguous conversion."""
|
|
311
|
+
|
|
312
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
313
|
+
def test_non_contiguous_works(self):
|
|
314
|
+
"""Non-contiguous tensors should be auto-converted."""
|
|
315
|
+
import mps_flash_attn
|
|
316
|
+
if not mps_flash_attn.is_available():
|
|
317
|
+
pytest.skip("MFA not available")
|
|
318
|
+
|
|
319
|
+
B, H, N, D = 2, 4, 64, 32
|
|
320
|
+
# Create non-contiguous tensor via transpose
|
|
321
|
+
q = torch.randn(B, N, H, D, device="mps", dtype=torch.float16).transpose(1, 2)
|
|
322
|
+
k = torch.randn(B, N, H, D, device="mps", dtype=torch.float16).transpose(1, 2)
|
|
323
|
+
v = torch.randn(B, N, H, D, device="mps", dtype=torch.float16).transpose(1, 2)
|
|
324
|
+
|
|
325
|
+
assert not q.is_contiguous(), "Test setup: q should be non-contiguous"
|
|
326
|
+
|
|
327
|
+
# Should work without error
|
|
328
|
+
out = mps_flash_attn.flash_attention(q, k, v)
|
|
329
|
+
assert out.shape == (B, H, N, D)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
class TestFallbackTracking:
|
|
333
|
+
"""Test SDPA fallback tracking."""
|
|
334
|
+
|
|
335
|
+
def test_fallback_counter_exists(self):
|
|
336
|
+
"""The replace_sdpa function should track fallbacks."""
|
|
337
|
+
import mps_flash_attn
|
|
338
|
+
# Just verify the module has the replace_sdpa function
|
|
339
|
+
assert hasattr(mps_flash_attn, 'replace_sdpa')
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
class TestThreadSafety:
|
|
343
|
+
"""Test thread-safe initialization.
|
|
344
|
+
|
|
345
|
+
Note: MPS itself doesn't support concurrent operations from multiple threads,
|
|
346
|
+
so we can't test concurrent flash_attention calls. But we CAN test that
|
|
347
|
+
the initialization code is thread-safe by checking that multiple threads
|
|
348
|
+
attempting to initialize don't cause crashes or double-initialization.
|
|
349
|
+
"""
|
|
350
|
+
|
|
351
|
+
@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
|
|
352
|
+
def test_concurrent_initialization(self):
|
|
353
|
+
"""Multiple threads initializing simultaneously should not crash.
|
|
354
|
+
|
|
355
|
+
This tests the thread-safe initialization in the C++ code.
|
|
356
|
+
We spawn multiple threads that all try to call flash_attention
|
|
357
|
+
with a barrier to synchronize their start times. Only ONE should
|
|
358
|
+
actually initialize the MFA bridge.
|
|
359
|
+
|
|
360
|
+
Note: We serialize the actual MPS operations to avoid driver crashes,
|
|
361
|
+
but the initialization race is what we're testing.
|
|
362
|
+
"""
|
|
363
|
+
import mps_flash_attn
|
|
364
|
+
if not mps_flash_attn.is_available():
|
|
365
|
+
pytest.skip("MFA not available")
|
|
366
|
+
|
|
367
|
+
import threading
|
|
368
|
+
import queue
|
|
369
|
+
|
|
370
|
+
num_threads = 4
|
|
371
|
+
results = queue.Queue()
|
|
372
|
+
barrier = threading.Barrier(num_threads)
|
|
373
|
+
|
|
374
|
+
# Use a lock to serialize MPS operations (avoid driver crash)
|
|
375
|
+
# but let initialization race
|
|
376
|
+
mps_lock = threading.Lock()
|
|
377
|
+
|
|
378
|
+
def worker(thread_id):
|
|
379
|
+
try:
|
|
380
|
+
# All threads wait here, then start simultaneously
|
|
381
|
+
barrier.wait()
|
|
382
|
+
|
|
383
|
+
# Each thread creates its own tensors and calls flash_attention
|
|
384
|
+
# The initialization will race, but should be thread-safe
|
|
385
|
+
with mps_lock: # Serialize MPS operations
|
|
386
|
+
q = torch.randn(1, 2, 32, 16, device="mps", dtype=torch.float16)
|
|
387
|
+
k = torch.randn(1, 2, 32, 16, device="mps", dtype=torch.float16)
|
|
388
|
+
v = torch.randn(1, 2, 32, 16, device="mps", dtype=torch.float16)
|
|
389
|
+
|
|
390
|
+
out = mps_flash_attn.flash_attention(q, k, v)
|
|
391
|
+
torch.mps.synchronize()
|
|
392
|
+
|
|
393
|
+
results.put((thread_id, "success", out.shape))
|
|
394
|
+
except Exception as e:
|
|
395
|
+
results.put((thread_id, "error", str(e)))
|
|
396
|
+
|
|
397
|
+
# Spawn threads
|
|
398
|
+
threads = []
|
|
399
|
+
for i in range(num_threads):
|
|
400
|
+
t = threading.Thread(target=worker, args=(i,))
|
|
401
|
+
threads.append(t)
|
|
402
|
+
t.start()
|
|
403
|
+
|
|
404
|
+
# Wait for completion
|
|
405
|
+
for t in threads:
|
|
406
|
+
t.join(timeout=30)
|
|
407
|
+
|
|
408
|
+
# Check results
|
|
409
|
+
successes = 0
|
|
410
|
+
errors = []
|
|
411
|
+
while not results.empty():
|
|
412
|
+
thread_id, status, data = results.get()
|
|
413
|
+
if status == "success":
|
|
414
|
+
successes += 1
|
|
415
|
+
assert data == (1, 2, 32, 16), f"Thread {thread_id} got wrong shape: {data}"
|
|
416
|
+
else:
|
|
417
|
+
errors.append(f"Thread {thread_id}: {data}")
|
|
418
|
+
|
|
419
|
+
assert successes == num_threads, f"Only {successes}/{num_threads} succeeded. Errors: {errors}"
|
|
420
|
+
|
|
421
|
+
def test_atomic_flag_exists(self):
|
|
422
|
+
"""Verify the C++ code uses atomic for g_initialized.
|
|
423
|
+
|
|
424
|
+
This is a static check - we grep the source to ensure the fix is in place.
|
|
425
|
+
"""
|
|
426
|
+
cpp_file = Path(__file__).parent.parent / "mps_flash_attn" / "csrc" / "mps_flash_attn.mm"
|
|
427
|
+
if not cpp_file.exists():
|
|
428
|
+
pytest.skip("C++ source not found")
|
|
429
|
+
|
|
430
|
+
content = cpp_file.read_text()
|
|
431
|
+
|
|
432
|
+
# Check for atomic<bool>
|
|
433
|
+
assert "std::atomic<bool>" in content or "atomic<bool>" in content, \
|
|
434
|
+
"g_initialized should be std::atomic<bool>"
|
|
435
|
+
|
|
436
|
+
# Check for mutex
|
|
437
|
+
assert "std::mutex" in content or "mutex" in content, \
|
|
438
|
+
"Should have mutex for thread-safe init"
|
|
439
|
+
|
|
440
|
+
# Check for ensure_initialized helper
|
|
441
|
+
assert "ensure_initialized" in content, \
|
|
442
|
+
"Should have ensure_initialized() helper function"
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
if __name__ == "__main__":
|
|
446
|
+
pytest.main([__file__, "-v"])
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|