mps-flash-attn 0.2.8.tar.gz → 0.3.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mps-flash-attn might be problematic.
Files changed (40)
  1. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/PKG-INFO +1 -1
  2. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/__init__.py +223 -52
  3. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/csrc/mps_flash_attn.mm +30 -33
  4. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn.egg-info/PKG-INFO +1 -1
  5. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn.egg-info/SOURCES.txt +1 -0
  6. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/pyproject.toml +1 -1
  7. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/setup.py +1 -1
  8. mps_flash_attn-0.3.0/tests/test_issues.py +446 -0
  9. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/LICENSE +0 -0
  10. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/README.md +0 -0
  11. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/benchmark.py +0 -0
  12. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  13. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  14. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  15. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  16. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  17. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  18. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  19. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  20. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  21. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  22. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  23. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  24. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  25. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  26. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  27. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  28. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  29. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  30. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  31. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  32. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  33. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  34. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/kernels/manifest.json +0 -0
  35. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  36. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  37. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn.egg-info/requires.txt +0 -0
  38. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/mps_flash_attn.egg-info/top_level.txt +0 -0
  39. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/setup.cfg +0 -0
  40. {mps_flash_attn-0.2.8 → mps_flash_attn-0.3.0}/tests/test_mfa_v2.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.2.8
+ Version: 0.3.0
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
@@ -4,13 +4,42 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
  This package provides memory-efficient attention using Metal Flash Attention kernels.
  """

- __version__ = "0.2.8"
+ __version__ = "0.3.0"
+
+ __all__ = [
+     # Core functions
+     "flash_attention",
+     "flash_attention_with_bias",
+     "flash_attention_chunked",
+     # Quantized attention
+     "flash_attention_fp8",
+     "flash_attention_int8",
+     "flash_attention_nf4",
+     "quantize_kv_fp8",
+     "quantize_kv_int8",
+     "quantize_kv_nf4",
+     # Utilities
+     "replace_sdpa",
+     "precompile",
+     "clear_cache",
+     "register_custom_op",
+     "is_available",
+     "convert_mask",
+     # Constants
+     "QUANT_FP8_E4M3",
+     "QUANT_FP8_E5M2",
+     "QUANT_INT8",
+     "QUANT_NF4",
+     # Version
+     "__version__",
+ ]

  import torch
- from typing import Optional
+ from typing import Optional, Tuple
  import math
  import threading
  import os
+ import warnings

  # Try to import the C++ extension
  try:
@@ -30,6 +59,20 @@ def is_available() -> bool:
      return _HAS_MFA and torch.backends.mps.is_available()


+ def _ensure_contiguous(tensor: torch.Tensor, name: str) -> torch.Tensor:
+     """Ensure tensor is contiguous, with a debug warning if conversion needed."""
+     if tensor.is_contiguous():
+         return tensor
+     # Auto-convert with debug info
+     if os.environ.get("MFA_DEBUG", "0") == "1":
+         warnings.warn(
+             f"MFA: {name} tensor was not contiguous (stride={tensor.stride()}), "
+             f"auto-converting. For best performance, ensure inputs are contiguous.",
+             UserWarning
+         )
+     return tensor.contiguous()
+
+
  def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
      """
      Convert attention mask to MFA's boolean format.
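A quick standalone illustration of when the new _ensure_contiguous path applies: a transposed view has permuted strides and is not contiguous, so with MFA_DEBUG=1 the warning above would fire before the copy. This sketch uses plain PyTorch only and does not call the library:

import os
import torch

os.environ["MFA_DEBUG"] = "1"          # opt in to the debug warning added above

x = torch.randn(2, 64, 4, 32)          # (B, S, H, D)
q = x.transpose(1, 2)                  # view with permuted strides: (B, H, S, D)
print(q.is_contiguous(), q.stride())   # False, strides are no longer row-major
q = q.contiguous()                     # the copy _ensure_contiguous performs internally
print(q.is_contiguous())               # True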
@@ -176,6 +219,20 @@ def flash_attention(
      if not torch.backends.mps.is_available():
          raise RuntimeError("MPS not available")

+     # Validate scale parameter
+     if scale is not None:
+         if scale <= 0:
+             raise ValueError(f"scale must be positive, got {scale}")
+         # Warn about extreme scale values that could cause numerical issues
+         default_scale = 1.0 / math.sqrt(query.shape[-1])
+         if scale < default_scale * 0.01 or scale > default_scale * 100:
+             warnings.warn(
+                 f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
+                 "this may cause numerical issues",
+                 UserWarning,
+                 stacklevel=2
+             )
+
      # Validate device
      if query.device.type != 'mps':
          raise ValueError("query must be on MPS device")
@@ -186,6 +243,31 @@ def flash_attention(
      if attn_mask is not None and attn_mask.device.type != 'mps':
          raise ValueError("attn_mask must be on MPS device")

+     # Ensure contiguous (auto-convert with debug warning)
+     query = _ensure_contiguous(query, "query")
+     key = _ensure_contiguous(key, "key")
+     value = _ensure_contiguous(value, "value")
+     if attn_mask is not None:
+         attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
+         # Validate mask shape
+         B, H, N_q, D = query.shape
+         N_kv = key.shape[2]
+         if attn_mask.dim() != 4:
+             raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
+         mb, mh, mq, mk = attn_mask.shape
+         if mq != N_q or mk != N_kv:
+             raise ValueError(
+                 f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv})"
+             )
+         if mb != 1 and mb != B:
+             raise ValueError(
+                 f"attn_mask batch size must be 1 or {B}, got {mb}"
+             )
+         if mh != 1 and mh != H:
+             raise ValueError(
+                 f"attn_mask head count must be 1 or {H}, got {mh}"
+             )
+
      # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
      if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
          # Apply scale if provided
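In terms of accepted shapes, the rules in this hunk come down to: the mask must be 4D, its last two dims must equal (N_q, N_kv), and the batch and head dims may each be either 1 (broadcast) or the exact size. A small illustration of masks that pass versus fail those checks (CPU tensors, shapes only, never passed to the kernel here):

import torch

B, H, N_q, N_kv = 2, 8, 1024, 1024

# Accepted (bool masks, True = masked):
accepted = [
    torch.zeros(B, H, N_q, N_kv, dtype=torch.bool),      # full mask
    torch.zeros(1, H, N_q, N_kv, dtype=torch.bool),      # broadcast over batch
    torch.zeros(B, 1, N_q, N_kv, dtype=torch.bool),      # broadcast over heads
    torch.zeros(1, 1, N_q, N_kv, dtype=torch.bool),      # broadcast over both
]

# Rejected with ValueError:
rejected = [
    torch.zeros(H, N_q, N_kv, dtype=torch.bool),          # 3D: not 4D
    torch.zeros(B, H, N_q, N_kv // 2, dtype=torch.bool),  # wrong key length
    torch.zeros(3, H, N_q, N_kv, dtype=torch.bool),       # batch is neither 1 nor B
]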
@@ -213,6 +295,10 @@ def replace_sdpa():
      import torch.nn.functional as F

      original_sdpa = F.scaled_dot_product_attention
+     _debug = os.environ.get("MFA_DEBUG", "0") == "1"
+     _call_count = [0]  # mutable for closure
+     _fallback_count = [0]  # track fallbacks for warning
+     _last_fallback_error = [None]

      def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                       is_causal=False, scale=None, enable_gqa=False, **kwargs):
@@ -224,37 +310,92 @@ def replace_sdpa():
          # seq=1024: 2.3-3.7x (MFA much faster)
          # seq=2048: 2.2-3.9x (MFA much faster)
          # seq=4096: 2.1-3.7x (MFA much faster)
+         # Determine seq_len based on tensor dimensionality
+         # 4D: (B, H, S, D) -> seq_len = shape[2]
+         # 3D: (B, S, D) -> seq_len = shape[1] (single-head attention, e.g., VAE)
+         is_3d = query.ndim == 3
+         seq_len = query.shape[1] if is_3d else query.shape[2]
+
          if (query.device.type == 'mps' and
              dropout_p == 0.0 and
              _HAS_MFA and
-             query.shape[2] >= 512):
+             query.ndim >= 3 and
+             seq_len >= 512):
              try:
+                 q, k, v = query, key, value
+
+                 # Handle 3D tensors (B, S, D) - treat as single-head attention
+                 # Unsqueeze to (B, 1, S, D) for MFA, squeeze back after
+                 if is_3d:
+                     q = q.unsqueeze(1)  # (B, S, D) -> (B, 1, S, D)
+                     k = k.unsqueeze(1)
+                     v = v.unsqueeze(1)
+
                  # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
                  # Common in Llama 2/3, Mistral, Qwen, etc.
-                 k, v = key, value
-                 if enable_gqa and query.shape[1] != key.shape[1]:
+                 # NOTE: Always expand when heads mismatch, not just when enable_gqa=True
+                 # Transformers may pass enable_gqa=True on MPS (torch>=2.5, no mask) even though
+                 # MPS SDPA doesn't support native GQA - we handle it here
+                 if q.shape[1] != k.shape[1]:
                      # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
-                     n_rep = query.shape[1] // key.shape[1]
-                     k = key.repeat_interleave(n_rep, dim=1)
-                     v = value.repeat_interleave(n_rep, dim=1)
+                     n_rep = q.shape[1] // k.shape[1]
+                     k = k.repeat_interleave(n_rep, dim=1)
+                     v = v.repeat_interleave(n_rep, dim=1)

                  # Convert float mask to bool mask if needed
                  # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                  # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                  mfa_mask = None
                  if attn_mask is not None:
+                     if _debug:
+                         print(f"[MFA MASK] dtype={attn_mask.dtype} shape={tuple(attn_mask.shape)} min={attn_mask.min().item():.2f} max={attn_mask.max().item():.2f}")
                      if attn_mask.dtype == torch.bool:
-                         # Boolean mask: True means masked (don't attend)
-                         mfa_mask = attn_mask
+                         # PyTorch SDPA bool mask: True = ATTEND, False = MASKED
+                         # MFA bool mask: True = MASKED, False = ATTEND
+                         # They're opposite! Invert it.
+                         mfa_mask = ~attn_mask
                      else:
                          # Float mask: typically -inf for masked positions, 0 for unmasked
                          # Convert: positions with large negative values -> True (masked)
                          # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                          mfa_mask = attn_mask <= -1e3
-                 return flash_attention(query, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
-             except Exception:
-                 # Fall back to original on any error
-                 pass
+                     if _debug:
+                         print(f"[MFA MASK] converted: True(masked)={mfa_mask.sum().item()} False(attend)={(~mfa_mask).sum().item()}")
+
+                 out = flash_attention(q, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+
+                 # Squeeze back for 3D input
+                 if is_3d:
+                     out = out.squeeze(1)  # (B, 1, S, D) -> (B, S, D)
+
+                 if _debug:
+                     _call_count[0] += 1
+                     print(f"[MFA #{_call_count[0]}] shape={tuple(query.shape)} is_3d={is_3d} gqa={enable_gqa} mask={attn_mask is not None} causal={is_causal}")
+
+                 return out
+             except Exception as e:
+                 # Fall back to original on any error, but track it
+                 _fallback_count[0] += 1
+                 _last_fallback_error[0] = str(e)
+                 if _debug:
+                     import traceback
+                     print(f"[MFA FALLBACK #{_fallback_count[0]}] shape={tuple(query.shape)}\n{traceback.format_exc()}")
+                 # Warn user after repeated fallbacks (likely a real problem)
+                 if _fallback_count[0] == 10:
+                     warnings.warn(
+                         f"MFA has fallen back to native SDPA {_fallback_count[0]} times. "
+                         f"Last error: {_last_fallback_error[0]}. "
+                         f"Set MFA_DEBUG=1 for details.",
+                         UserWarning
+                     )
+
+         if _debug and query.device.type == 'mps':
+             _call_count[0] += 1
+             reason = []
+             if dropout_p != 0.0: reason.append(f"dropout={dropout_p}")
+             if query.ndim < 3: reason.append(f"ndim={query.ndim}")
+             if seq_len < 512: reason.append(f"seq={seq_len}<512")
+             print(f"[NATIVE #{_call_count[0]}] shape={tuple(query.shape)} reason={','.join(reason) or 'unknown'}")

          return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)

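Since the bool-mask inversion is the subtle part of this hunk, here is a standalone illustration of the two conversions patched_sdpa performs (PyTorch bool masks use True = attend, float masks are additive) plus the GQA head expansion. Plain CPU tensors, no MFA call involved:

import torch

# PyTorch SDPA bool mask: True = attend; MFA bool mask: True = masked.
sdpa_bool = torch.tensor([[True, True, False]])      # last key position excluded
mfa_bool = ~sdpa_bool                                 # -> [[False, False, True]]

# PyTorch SDPA float mask is additive: 0 = attend, large negative = masked.
sdpa_float = torch.tensor([[0.0, 0.0, float("-inf")]])
mfa_from_float = sdpa_float <= -1e3                   # same [[False, False, True]]
assert torch.equal(mfa_bool, mfa_from_float)

# GQA: expand 2 KV heads to 8 query heads by repeating each KV head 4 times.
B, S, D = 1, 16, 32
k = torch.randn(B, 2, S, D)
n_rep = 8 // k.shape[1]
k_expanded = k.repeat_interleave(n_rep, dim=1)        # (B, 8, S, D)
assert k_expanded.shape[1] == 8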
@@ -461,13 +602,13 @@ def flash_attention_chunked(
          return _C.forward(query, key, value, is_causal, None, 0)

      # Initialize running statistics for online softmax
-     # m = running max, l = running sum of exp, acc = accumulated output
      device = query.device
      dtype = query.dtype

      # Use float32 for numerical stability of softmax statistics
-     running_max = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
-     running_sum = torch.zeros((B, H, seq_len_q, 1), device=device, dtype=torch.float32)
+     # running_L: base-2 logsumexp of all attention scores seen so far (-inf means no data yet)
+     # output_acc: weighted combination of outputs (weights sum to 1 after each update)
+     running_L = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
      output_acc = torch.zeros((B, H, seq_len_q, D), device=device, dtype=torch.float32)

      # Process K/V in chunks
@@ -488,51 +629,81 @@ def flash_attention_chunked(
          # - Partial chunk (up to q) if start_idx <= q < end_idx
          # - None of chunk if q < start_idx

-         chunk_is_causal = is_causal and (end_idx <= seq_len_q)
+         if is_causal:
+             # Create explicit causal mask for this chunk
+             # Query positions: 0 to seq_len_q-1
+             # Key positions in chunk: start_idx to end_idx-1
+             chunk_len = end_idx - start_idx

-         # Compute attention for this chunk
-         # forward_with_lse returns (output, logsumexp) where logsumexp = m + log(l)
-         chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, chunk_is_causal, None, 0)
+             # Build mask: mask[q, k_local] = True means DON'T attend
+             # We want to attend when global_k_pos <= q
+             # global_k_pos = start_idx + k_local
+             # So: attend when start_idx + k_local <= q
+             # mask = start_idx + k_local > q

-         # chunk_lse shape: (B, H, seq_len_q)
-         # We need to convert logsumexp to (max, sum) for online algorithm
-         chunk_lse = chunk_lse.unsqueeze(-1)  # (B, H, seq_len_q, 1)
+             q_pos = torch.arange(seq_len_q, device=device).view(1, 1, seq_len_q, 1)
+             k_pos = torch.arange(chunk_len, device=device).view(1, 1, 1, chunk_len) + start_idx
+             causal_mask = k_pos > q_pos  # True = masked (don't attend)

-         # Convert chunk output to float32 for accumulation
-         chunk_out = chunk_out.float()
-
-         # Online softmax update:
-         # new_max = max(running_max, chunk_max)
-         # For flash attention, chunk_lse ≈ chunk_max + log(chunk_sum)
-         # We approximate chunk_max ≈ chunk_lse (valid when exp sum dominates)
-
-         chunk_max = chunk_lse  # Approximation: logsumexp ≈ max when sum is dominated by max
+             # Expand to batch and heads
+             causal_mask = causal_mask.expand(B, H, seq_len_q, chunk_len)

-         # Compute new max
-         new_max = torch.maximum(running_max, chunk_max)
+             # Call forward with explicit mask (is_causal=False since we handle it)
+             chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, causal_mask, 0)
+         else:
+             # Non-causal: just process the chunk directly
+             chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, None, 0)

-         # Rescale previous accumulator
-         # correction_old = exp(running_max - new_max)
-         correction_old = torch.exp(running_max - new_max)
-         # Clip to avoid inf * 0 issues when running_max was -inf
-         correction_old = torch.where(running_max == float('-inf'), torch.zeros_like(correction_old), correction_old)
+         # chunk_L shape: (B, H, seq_len_q)
+         # The kernel returns L = m + log2(l) where:
+         #   m = max(scores * log2(e) / sqrt(D))
+         #   l = sum(exp2(scores * log2(e) / sqrt(D) - m))
+         # This is a base-2 logsumexp: L = log2(sum(exp2(scaled_scores)))
+         chunk_L = chunk_lse.unsqueeze(-1).float()  # (B, H, seq_len_q, 1)

-         # Rescale chunk output
-         # correction_new = exp(chunk_max - new_max)
-         correction_new = torch.exp(chunk_max - new_max)
+         # Convert chunk output to float32 for accumulation
+         chunk_out = chunk_out.float()

-         # For the sum, we need exp(chunk_lse - new_max) = exp(chunk_max + log(chunk_sum) - new_max)
-         #                                               = exp(chunk_max - new_max) * chunk_sum
-         # But we only have logsumexp, so: exp(chunk_lse - new_max)
-         chunk_sum_scaled = torch.exp(chunk_lse - new_max)
+         # Online softmax algorithm using base-2 representation
+         #
+         # Flash attention returns: chunk_out = softmax(scores) @ V
+         # The output is already normalized. For online combination:
+         #   new_L = log2(2^running_L + 2^chunk_L)
+         #         = max(running_L, chunk_L) + log2(2^(running_L - max) + 2^(chunk_L - max))
+         #
+         # The weights for combining outputs are:
+         #   old_weight = 2^(running_L - new_L)
+         #   new_weight = 2^(chunk_L - new_L)
+         # These weights sum to 1, so: output = old_weight * old_out + new_weight * new_out
+
+         # Compute new base-2 logsumexp
+         max_L = torch.maximum(running_L, chunk_L)
+
+         # Handle -inf case (no previous data)
+         # Use exp2 for base-2 (matches kernel's internal representation)
+         running_exp2 = torch.where(
+             running_L == float('-inf'),
+             torch.zeros_like(running_L),
+             torch.exp2(running_L - max_L)
+         )
+         chunk_exp2 = torch.exp2(chunk_L - max_L)
+         new_L = max_L + torch.log2(running_exp2 + chunk_exp2)
+
+         # Compute correction factors using base-2 exp
+         old_weight = torch.where(
+             running_L == float('-inf'),
+             torch.zeros_like(running_L),
+             torch.exp2(running_L - new_L)
+         )
+         new_weight = torch.exp2(chunk_L - new_L)

          # Update accumulator
-         output_acc = output_acc * correction_old + chunk_out * correction_new
-         running_sum = running_sum * correction_old + chunk_sum_scaled
-         running_max = new_max
+         # Update accumulator
+         output_acc = output_acc * old_weight + chunk_out * new_weight
+         running_L = new_L

-     # Final normalization
-     output = output_acc / running_sum
+     # No final normalization needed - weights already sum to 1
+     output = output_acc

      # Convert back to original dtype
      return output.to(dtype)
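The merge rule in this hunk can be checked in isolation: each chunk contributes a normalized output plus a base-2 logsumexp of its already-scaled scores, and the 2^(L_i - new_L) weights recombine them into exactly the full softmax. A small self-contained check using plain PyTorch softmax in place of the MFA kernel (the kernel's L differs only in where the 1/sqrt(D) scaling is folded in):

import math
import torch

torch.manual_seed(0)
q = torch.randn(1, 4)                         # one query row
k = torch.randn(6, 4)
v = torch.randn(6, 3)
scores = (q @ k.T) / math.sqrt(4)             # scaled attention scores, shape (1, 6)

def chunk_result(s, vals):
    # Per-chunk result in the same form the loop consumes:
    # a normalized output and a base-2 logsumexp L of the scaled scores.
    out = torch.softmax(s, dim=-1) @ vals
    L = torch.logsumexp(s, dim=-1, keepdim=True) / math.log(2)
    return out, L

out1, L1 = chunk_result(scores[:, :3], v[:3])   # "chunk" of the first 3 keys
out2, L2 = chunk_result(scores[:, 3:], v[3:])   # remaining keys

# Merge exactly as in the loop above; the two weights sum to 1.
max_L = torch.maximum(L1, L2)
new_L = max_L + torch.log2(torch.exp2(L1 - max_L) + torch.exp2(L2 - max_L))
merged = out1 * torch.exp2(L1 - new_L) + out2 * torch.exp2(L2 - new_L)

full = torch.softmax(scores, dim=-1) @ v
assert torch.allclose(merged, full, atol=1e-5)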
@@ -14,6 +14,8 @@
  #include <dlfcn.h>
  #include <string>
  #include <vector>
+ #include <mutex>
+ #include <atomic>

  // ============================================================================
  // MFA Bridge Function Types
@@ -67,7 +69,8 @@ static mfa_forward_fn g_mfa_forward = nullptr;
  static mfa_backward_fn g_mfa_backward = nullptr;
  static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
  static void* g_dylib_handle = nullptr;
- static bool g_initialized = false;
+ static std::atomic<bool> g_initialized{false};
+ static std::mutex g_init_mutex;

  // ============================================================================
  // Load MFA Bridge Library
@@ -141,6 +144,24 @@ static bool load_mfa_bridge() {
      return true;
  }

+ // Thread-safe initialization helper
+ static void ensure_initialized() {
+     // Fast path: already initialized
+     if (g_initialized.load(std::memory_order_acquire)) {
+         return;
+     }
+     // Slow path: need to initialize with lock
+     std::lock_guard<std::mutex> lock(g_init_mutex);
+     // Double-check after acquiring lock
+     if (!g_initialized.load(std::memory_order_relaxed)) {
+         load_mfa_bridge();
+         if (!g_mfa_init()) {
+             throw std::runtime_error("Failed to initialize MFA");
+         }
+         g_initialized.store(true, std::memory_order_release);
+     }
+ }
+
  // ============================================================================
  // Get MTLBuffer from PyTorch MPS Tensor
  // ============================================================================
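For readers who live mostly on the Python side, the same double-checked locking idea sketched in Python (illustrative only; the actual fix is the C++ ensure_initialized above, and load_bridge/init_fn here are placeholder callables):

import threading

_initialized = False              # plays the role of the atomic g_initialized flag
_init_lock = threading.Lock()     # plays the role of g_init_mutex

def ensure_initialized(load_bridge, init_fn):
    """Run one-time initialization at most once, even with concurrent callers."""
    global _initialized
    if _initialized:              # fast path: no lock once initialized
        return
    with _init_lock:              # slow path: serialize first-time init
        if not _initialized:      # double-check under the lock
            load_bridge()
            if not init_fn():
                raise RuntimeError("Failed to initialize MFA")
            _initialized = True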
@@ -359,14 +380,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
      const c10::optional<at::Tensor>& attn_mask,  // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
      int64_t window_size  // 0 = full attention, >0 = sliding window
  ) {
-     // Initialize MFA on first call
-     if (!g_initialized) {
-         load_mfa_bridge();
-         if (!g_mfa_init()) {
-             throw std::runtime_error("Failed to initialize MFA");
-         }
-         g_initialized = true;
-     }
+     // Thread-safe initialization
+     ensure_initialized();

      // Validate inputs
      TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
@@ -562,14 +577,8 @@ at::Tensor mps_flash_attention_forward_with_bias(
      int64_t window_size,
      int64_t bias_repeat_count  // >0 means bias repeats every N batches (for window attention)
  ) {
-     // Initialize MFA on first call
-     if (!g_initialized) {
-         load_mfa_bridge();
-         if (!g_mfa_init()) {
-             throw std::runtime_error("Failed to initialize MFA");
-         }
-         g_initialized = true;
-     }
+     // Thread-safe initialization
+     ensure_initialized();

      // Check that v6/v7 API is available
      TORCH_CHECK(g_mfa_create_kernel_v6 || g_mfa_create_kernel_v7,
@@ -735,14 +744,8 @@ at::Tensor mps_flash_attention_forward_quantized(
      const c10::optional<at::Tensor>& attn_mask,
      int64_t window_size
  ) {
-     // Initialize MFA on first call
-     if (!g_initialized) {
-         load_mfa_bridge();
-         if (!g_mfa_init()) {
-             throw std::runtime_error("Failed to initialize MFA");
-         }
-         g_initialized = true;
-     }
+     // Thread-safe initialization
+     ensure_initialized();

      // Check that v4 API is available
      TORCH_CHECK(g_mfa_create_kernel_v4, "Quantized attention requires MFA v4 API (update libMFABridge.dylib)");
@@ -992,14 +995,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
      int64_t window_size,  // 0 = full attention, >0 = sliding window
      bool bf16_backward  // true = use BF16 intermediates for ~2x faster backward
  ) {
-     // Initialize MFA on first call
-     if (!g_initialized) {
-         load_mfa_bridge();
-         if (!g_mfa_init()) {
-             throw std::runtime_error("Failed to initialize MFA");
-         }
-         g_initialized = true;
-     }
+     // Thread-safe initialization
+     ensure_initialized();

      // Validate inputs
      TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.2.8
+ Version: 0.3.0
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
@@ -34,4 +34,5 @@ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e
  mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib
  mps_flash_attn/kernels/manifest.json
  mps_flash_attn/lib/libMFABridge.dylib
+ tests/test_issues.py
  tests/test_mfa_v2.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "mps-flash-attn"
- version = "0.2.8"
+ version = "0.3.0"
  description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
  readme = "README.md"
  license = "MIT"
@@ -72,7 +72,7 @@ def get_extensions():

  setup(
      name="mps-flash-attn",
-     version="0.1.5",
+     version="0.3.0",
      packages=find_packages(),
      package_data={
          "mps_flash_attn": [
@@ -0,0 +1,446 @@
+ """
+ Tests for known issues and fixes in mps-flash-attention.
+
+ These tests verify that:
+ 1. Version is consistent between setup.py and __init__.py
+ 2. __all__ exports are properly defined
+ 3. Scale validation works correctly
+ 4. Chunked attention accuracy is within expected bounds
+ 5. All exported symbols are importable
+ """
+
+ import pytest
+ import torch
+ import math
+ import warnings
+ import re
+ import sys
+ from pathlib import Path
+
+
+ # Get the version from setup.py for comparison
+ def get_setup_version():
+     setup_path = Path(__file__).parent.parent / "setup.py"
+     content = setup_path.read_text()
+     match = re.search(r'version="([^"]+)"', content)
+     return match.group(1) if match else None
+
+
+ class TestVersionConsistency:
+     """Test that version is consistent across the package."""
+
+     def test_version_matches_setup(self):
+         """Version in __init__.py should match setup.py."""
+         import mps_flash_attn
+         setup_version = get_setup_version()
+         assert setup_version is not None, "Could not parse version from setup.py"
+         assert mps_flash_attn.__version__ == setup_version, (
+             f"Version mismatch: __init__.py has {mps_flash_attn.__version__}, "
+             f"setup.py has {setup_version}"
+         )
+
+
+ class TestExports:
+     """Test that __all__ exports are properly defined."""
+
+     def test_all_is_defined(self):
+         """__all__ should be defined."""
+         import mps_flash_attn
+         assert hasattr(mps_flash_attn, "__all__"), "__all__ not defined"
+         assert isinstance(mps_flash_attn.__all__, list), "__all__ should be a list"
+         assert len(mps_flash_attn.__all__) > 0, "__all__ should not be empty"
+
+     def test_all_symbols_exist(self):
+         """All symbols in __all__ should be importable."""
+         import mps_flash_attn
+         missing = []
+         for name in mps_flash_attn.__all__:
+             if not hasattr(mps_flash_attn, name):
+                 missing.append(name)
+         assert not missing, f"Missing symbols from __all__: {missing}"
+
+     def test_core_functions_exported(self):
+         """Core functions should be in __all__."""
+         import mps_flash_attn
+         required = [
+             "flash_attention",
+             "flash_attention_chunked",
+             "is_available",
+             "__version__",
+         ]
+         for name in required:
+             assert name in mps_flash_attn.__all__, f"{name} missing from __all__"
+
+
+ class TestScaleValidation:
+     """Test scale parameter validation."""
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_negative_scale_raises(self):
+         """Negative scale should raise ValueError."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         q = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+         k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+         v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+
+         with pytest.raises(ValueError, match="scale must be positive"):
+             mps_flash_attn.flash_attention(q, k, v, scale=-1.0)
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_zero_scale_raises(self):
+         """Zero scale should raise ValueError."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         q = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+         k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+         v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+
+         with pytest.raises(ValueError, match="scale must be positive"):
+             mps_flash_attn.flash_attention(q, k, v, scale=0.0)
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_extreme_scale_warns(self):
+         """Extreme scale values should warn."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         q = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+         k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+         v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+
+         # Very large scale (1000x default)
+         with warnings.catch_warnings(record=True) as w:
+             warnings.simplefilter("always")
+             mps_flash_attn.flash_attention(q, k, v, scale=100.0)  # default is ~0.177
+         assert any("very different from default" in str(warning.message) for warning in w), \
+             "Should warn about extreme scale"
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_normal_scale_no_warning(self):
+         """Normal scale values should not warn."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         q = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+         k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+         v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+
+         default_scale = 1.0 / math.sqrt(32)
+         with warnings.catch_warnings(record=True) as w:
+             warnings.simplefilter("always")
+             mps_flash_attn.flash_attention(q, k, v, scale=default_scale)
+         scale_warnings = [x for x in w if "scale" in str(x.message).lower()]
+         assert not scale_warnings, f"Should not warn about normal scale: {scale_warnings}"
+
+
+ class TestChunkedAccuracy:
+     """Test chunked attention accuracy vs non-chunked."""
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_chunked_vs_regular_accuracy(self):
+         """Chunked attention should match regular within tolerance.
+
+         Note: Online softmax approximation has inherent numerical error.
+         The max difference is typically ~0.05-0.06 due to the approximation
+         at line 556 where chunk_max = chunk_lse (simplified online softmax).
+         """
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         torch.manual_seed(42)
+         # Use a moderate size to keep test fast
+         B, H, N, D = 2, 4, 2048, 64
+
+         q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+         k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+         v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+
+         # Regular flash attention
+         out_regular = mps_flash_attn.flash_attention(q, k, v)
+
+         # Chunked flash attention with smaller chunk
+         out_chunked = mps_flash_attn.flash_attention_chunked(q, k, v, chunk_size=512)
+
+         torch.mps.synchronize()
+
+         # Compute differences
+         max_diff = (out_regular - out_chunked).abs().max().item()
+         mean_diff = (out_regular - out_chunked).abs().mean().item()
+
+         # Online softmax has inherent ~0.05 max error
+         # This is NOT a regression - it's expected behavior
+         assert max_diff < 0.10, f"Max diff {max_diff:.6f} too large (expected < 0.10)"
+         assert mean_diff < 0.01, f"Mean diff {mean_diff:.6f} too large (expected < 0.01)"
+
+         # Log actual values for transparency
+         print(f"\nChunked accuracy: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_chunked_causal_accuracy(self):
+         """Chunked causal attention should also be accurate."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         torch.manual_seed(123)
+         B, H, N, D = 2, 4, 1024, 64
+
+         q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+         k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+         v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+
+         out_regular = mps_flash_attn.flash_attention(q, k, v, is_causal=True)
+         out_chunked = mps_flash_attn.flash_attention_chunked(q, k, v, is_causal=True, chunk_size=256)
+
+         torch.mps.synchronize()
+
+         max_diff = (out_regular - out_chunked).abs().max().item()
+         mean_diff = (out_regular - out_chunked).abs().mean().item()
+
+         assert max_diff < 0.10, f"Causal max diff {max_diff:.6f} too large"
+         assert mean_diff < 0.01, f"Causal mean diff {mean_diff:.6f} too large"
+
+
+ class TestInputValidation:
+     """Test input validation catches errors early."""
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_wrong_device_raises(self):
+         """Tensors on wrong device should raise."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         q = torch.randn(1, 4, 64, 32)  # CPU tensor
+         k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+         v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+
+         with pytest.raises((RuntimeError, ValueError)):
+             mps_flash_attn.flash_attention(q, k, v)
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_wrong_dims_raises(self):
+         """Wrong tensor dimensions should raise."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         q = torch.randn(1, 64, 32, device="mps", dtype=torch.float16)  # 3D instead of 4D
+         k = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+         v = torch.randn(1, 4, 64, 32, device="mps", dtype=torch.float16)
+
+         with pytest.raises(RuntimeError):
+             mps_flash_attn.flash_attention(q, k, v)
+
+
+ class TestMaskValidation:
+     """Test attention mask shape validation."""
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_valid_mask_shape(self):
+         """Valid mask shapes should work."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         B, H, N, D = 2, 4, 64, 32
+         q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+         k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+         v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+
+         # Full mask
+         mask = torch.zeros(B, H, N, N, dtype=torch.bool, device="mps")
+         out = mps_flash_attn.flash_attention(q, k, v, attn_mask=mask)
+         assert out.shape == (B, H, N, D)
+
+         # Broadcast over batch
+         mask = torch.zeros(1, H, N, N, dtype=torch.bool, device="mps")
+         out = mps_flash_attn.flash_attention(q, k, v, attn_mask=mask)
+         assert out.shape == (B, H, N, D)
+
+         # Broadcast over heads
+         mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device="mps")
+         out = mps_flash_attn.flash_attention(q, k, v, attn_mask=mask)
+         assert out.shape == (B, H, N, D)
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_invalid_mask_dims_raises(self):
+         """Wrong mask dimensions should raise."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         B, H, N, D = 2, 4, 64, 32
+         q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+         k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+         v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+
+         # 3D mask (missing batch dim)
+         mask = torch.zeros(H, N, N, dtype=torch.bool, device="mps")
+         with pytest.raises(ValueError, match="must be 4D"):
+             mps_flash_attn.flash_attention(q, k, v, attn_mask=mask)
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_invalid_mask_seq_len_raises(self):
+         """Wrong mask sequence length should raise."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         B, H, N, D = 2, 4, 64, 32
+         q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+         k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+         v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
+
+         # Wrong N_kv
+         mask = torch.zeros(B, H, N, N // 2, dtype=torch.bool, device="mps")
+         with pytest.raises(ValueError, match="shape mismatch"):
+             mps_flash_attn.flash_attention(q, k, v, attn_mask=mask)
+
+
+ class TestAutoContiguous:
+     """Test auto-contiguous conversion."""
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_non_contiguous_works(self):
+         """Non-contiguous tensors should be auto-converted."""
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         B, H, N, D = 2, 4, 64, 32
+         # Create non-contiguous tensor via transpose
+         q = torch.randn(B, N, H, D, device="mps", dtype=torch.float16).transpose(1, 2)
+         k = torch.randn(B, N, H, D, device="mps", dtype=torch.float16).transpose(1, 2)
+         v = torch.randn(B, N, H, D, device="mps", dtype=torch.float16).transpose(1, 2)
+
+         assert not q.is_contiguous(), "Test setup: q should be non-contiguous"
+
+         # Should work without error
+         out = mps_flash_attn.flash_attention(q, k, v)
+         assert out.shape == (B, H, N, D)
+
+
+ class TestFallbackTracking:
+     """Test SDPA fallback tracking."""
+
+     def test_fallback_counter_exists(self):
+         """The replace_sdpa function should track fallbacks."""
+         import mps_flash_attn
+         # Just verify the module has the replace_sdpa function
+         assert hasattr(mps_flash_attn, 'replace_sdpa')
+
+
+ class TestThreadSafety:
+     """Test thread-safe initialization.
+
+     Note: MPS itself doesn't support concurrent operations from multiple threads,
+     so we can't test concurrent flash_attention calls. But we CAN test that
+     the initialization code is thread-safe by checking that multiple threads
+     attempting to initialize don't cause crashes or double-initialization.
+     """
+
+     @pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
+     def test_concurrent_initialization(self):
+         """Multiple threads initializing simultaneously should not crash.
+
+         This tests the thread-safe initialization in the C++ code.
+         We spawn multiple threads that all try to call flash_attention
+         with a barrier to synchronize their start times. Only ONE should
+         actually initialize the MFA bridge.
+
+         Note: We serialize the actual MPS operations to avoid driver crashes,
+         but the initialization race is what we're testing.
+         """
+         import mps_flash_attn
+         if not mps_flash_attn.is_available():
+             pytest.skip("MFA not available")
+
+         import threading
+         import queue
+
+         num_threads = 4
+         results = queue.Queue()
+         barrier = threading.Barrier(num_threads)
+
+         # Use a lock to serialize MPS operations (avoid driver crash)
+         # but let initialization race
+         mps_lock = threading.Lock()
+
+         def worker(thread_id):
+             try:
+                 # All threads wait here, then start simultaneously
+                 barrier.wait()
+
+                 # Each thread creates its own tensors and calls flash_attention
+                 # The initialization will race, but should be thread-safe
+                 with mps_lock:  # Serialize MPS operations
+                     q = torch.randn(1, 2, 32, 16, device="mps", dtype=torch.float16)
+                     k = torch.randn(1, 2, 32, 16, device="mps", dtype=torch.float16)
+                     v = torch.randn(1, 2, 32, 16, device="mps", dtype=torch.float16)
+
+                     out = mps_flash_attn.flash_attention(q, k, v)
+                     torch.mps.synchronize()
+
+                 results.put((thread_id, "success", out.shape))
+             except Exception as e:
+                 results.put((thread_id, "error", str(e)))
+
+         # Spawn threads
+         threads = []
+         for i in range(num_threads):
+             t = threading.Thread(target=worker, args=(i,))
+             threads.append(t)
+             t.start()
+
+         # Wait for completion
+         for t in threads:
+             t.join(timeout=30)
+
+         # Check results
+         successes = 0
+         errors = []
+         while not results.empty():
+             thread_id, status, data = results.get()
+             if status == "success":
+                 successes += 1
+                 assert data == (1, 2, 32, 16), f"Thread {thread_id} got wrong shape: {data}"
+             else:
+                 errors.append(f"Thread {thread_id}: {data}")
+
+         assert successes == num_threads, f"Only {successes}/{num_threads} succeeded. Errors: {errors}"
+
+     def test_atomic_flag_exists(self):
+         """Verify the C++ code uses atomic for g_initialized.
+
+         This is a static check - we grep the source to ensure the fix is in place.
+         """
+         cpp_file = Path(__file__).parent.parent / "mps_flash_attn" / "csrc" / "mps_flash_attn.mm"
+         if not cpp_file.exists():
+             pytest.skip("C++ source not found")
+
+         content = cpp_file.read_text()
+
+         # Check for atomic<bool>
+         assert "std::atomic<bool>" in content or "atomic<bool>" in content, \
+             "g_initialized should be std::atomic<bool>"
+
+         # Check for mutex
+         assert "std::mutex" in content or "mutex" in content, \
+             "Should have mutex for thread-safe init"
+
+         # Check for ensure_initialized helper
+         assert "ensure_initialized" in content, \
+             "Should have ensure_initialized() helper function"
+
+
+ if __name__ == "__main__":
+     pytest.main([__file__, "-v"])