mps-flash-attn 0.2.8-cp314-cp314-macosx_15_0_arm64.whl → 0.3.2-cp314-cp314-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of mps-flash-attn has been flagged as a potentially problematic release.

Binary file changed: mps_flash_attn/_C.cpython-314-darwin.so

mps_flash_attn/__init__.py
@@ -4,13 +4,42 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
  This package provides memory-efficient attention using Metal Flash Attention kernels.
  """
 
- __version__ = "0.2.8"
+ __version__ = "0.3.2"
+
+ __all__ = [
+     # Core functions
+     "flash_attention",
+     "flash_attention_with_bias",
+     "flash_attention_chunked",
+     # Quantized attention
+     "flash_attention_fp8",
+     "flash_attention_int8",
+     "flash_attention_nf4",
+     "quantize_kv_fp8",
+     "quantize_kv_int8",
+     "quantize_kv_nf4",
+     # Utilities
+     "replace_sdpa",
+     "precompile",
+     "clear_cache",
+     "register_custom_op",
+     "is_available",
+     "convert_mask",
+     # Constants
+     "QUANT_FP8_E4M3",
+     "QUANT_FP8_E5M2",
+     "QUANT_INT8",
+     "QUANT_NF4",
+     # Version
+     "__version__",
+ ]
 
  import torch
- from typing import Optional
+ from typing import Optional, Tuple
  import math
  import threading
  import os
+ import warnings
 
  # Try to import the C++ extension
  try:
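
Note: the new __all__ list pins down the public API surface. A minimal usage sketch (not part of the package; shapes, dtype and the causal flag are illustrative, and an Apple Silicon machine with an MPS build of PyTorch is assumed):

    import torch
    import mps_flash_attn as mfa

    if mfa.is_available():
        # (B, H, N, D) tensors on the MPS device
        q = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
        k = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
        v = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
        out = mfa.flash_attention(q, k, v, is_causal=True)

        # Or patch torch.nn.functional.scaled_dot_product_attention globally
        mfa.replace_sdpa()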
@@ -30,6 +59,20 @@ def is_available() -> bool:
      return _HAS_MFA and torch.backends.mps.is_available()
 
 
+ def _ensure_contiguous(tensor: torch.Tensor, name: str) -> torch.Tensor:
+     """Ensure tensor is contiguous, with a debug warning if conversion needed."""
+     if tensor.is_contiguous():
+         return tensor
+     # Auto-convert with debug info
+     if os.environ.get("MFA_DEBUG", "0") == "1":
+         warnings.warn(
+             f"MFA: {name} tensor was not contiguous (stride={tensor.stride()}), "
+             f"auto-converting. For best performance, ensure inputs are contiguous.",
+             UserWarning
+         )
+     return tensor.contiguous()
+
+
  def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
      """
      Convert attention mask to MFA's boolean format.
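
Note: _ensure_contiguous only converts; it never rejects. A small illustration of the kind of input it targets (plain PyTorch, hypothetical shapes; MFA_DEBUG=1 is the switch the helper checks):

    import os
    import torch

    os.environ["MFA_DEBUG"] = "1"                     # opt in to the warning

    x = torch.randn(1, 1024, 8, 64).transpose(1, 2)   # a (B, H, N, D) view
    print(x.is_contiguous())                          # False: transpose only swaps strides
    x = x.contiguous()                                # what _ensure_contiguous does for you
    print(x.is_contiguous())                          # True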
@@ -54,6 +97,56 @@ def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
      return attn_mask <= -1e3
 
 
+ def _validate_and_expand_mask(
+     attn_mask: Optional[torch.Tensor],
+     B: int,
+     H: int,
+     N_q: int,
+     N_kv: int,
+ ) -> Optional[torch.Tensor]:
+     """
+     Validate attention mask shape and expand broadcast dimensions.
+
+     Args:
+         attn_mask: Optional mask of shape (B, H, N_q, N_kv) or broadcastable
+         B: Batch size
+         H: Number of heads
+         N_q: Query sequence length
+         N_kv: Key/Value sequence length
+
+     Returns:
+         Expanded mask of shape (mb, mh, N_q, N_kv) or None
+     """
+     if attn_mask is None:
+         return None
+
+     attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
+
+     if attn_mask.dim() != 4:
+         raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
+
+     mb, mh, mq, mk = attn_mask.shape
+
+     # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
+     if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
+         raise ValueError(
+             f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
+         )
+
+     # Expand broadcast mask to full shape for Metal kernel
+     if mq == 1 and N_q > 1:
+         attn_mask = attn_mask.expand(mb, mh, N_q, mk)
+     if mk == 1 and N_kv > 1:
+         attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
+
+     if mb != 1 and mb != B:
+         raise ValueError(f"attn_mask batch size must be 1 or {B}, got {mb}")
+     if mh != 1 and mh != H:
+         raise ValueError(f"attn_mask head count must be 1 or {H}, got {mh}")
+
+     return attn_mask
+
+
  class FlashAttentionFunction(torch.autograd.Function):
      """Autograd function for Flash Attention with backward pass support."""
 
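
Note: the helper accepts broadcastable masks. A sketch of the typical padding-mask pattern and the expansion it performs (standalone PyTorch; shapes are illustrative):

    import torch
    from mps_flash_attn import convert_mask

    B, H, N_q, N_kv = 2, 8, 128, 128
    # Additive padding mask: 0 = attend, -inf = masked, broadcast over heads and query positions
    float_mask = torch.zeros(B, 1, 1, N_kv)
    float_mask[:, :, :, 100:] = float("-inf")

    bool_mask = convert_mask(float_mask)             # True = masked
    expanded = bool_mask.expand(B, 1, N_q, N_kv)     # what _validate_and_expand_mask does when mq == 1
    print(bool_mask.shape, expanded.shape)           # (2, 1, 1, 128) -> (2, 1, 128, 128)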
@@ -176,6 +269,20 @@ def flash_attention(
      if not torch.backends.mps.is_available():
          raise RuntimeError("MPS not available")
 
+     # Validate scale parameter
+     if scale is not None:
+         if scale <= 0:
+             raise ValueError(f"scale must be positive, got {scale}")
+         # Warn about extreme scale values that could cause numerical issues
+         default_scale = 1.0 / math.sqrt(query.shape[-1])
+         if scale < default_scale * 0.01 or scale > default_scale * 100:
+             warnings.warn(
+                 f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
+                 "this may cause numerical issues",
+                 UserWarning,
+                 stacklevel=2
+             )
+
      # Validate device
      if query.device.type != 'mps':
          raise ValueError("query must be on MPS device")
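
Note: the warning only fires for scales two orders of magnitude away from the default 1/sqrt(D). A quick check of the thresholds (pure arithmetic; D chosen for illustration):

    import math

    D = 64
    default_scale = 1.0 / math.sqrt(D)            # 0.125
    low, high = default_scale * 0.01, default_scale * 100
    print(low, high)                              # 0.00125 12.5
    # scale=1.0 passes silently; scale=20.0 would trigger the warning; scale=-1 raises ValueError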
@@ -186,6 +293,24 @@ def flash_attention(
      if attn_mask is not None and attn_mask.device.type != 'mps':
          raise ValueError("attn_mask must be on MPS device")
 
+     # Ensure contiguous (auto-convert with debug warning)
+     query = _ensure_contiguous(query, "query")
+     key = _ensure_contiguous(key, "key")
+     value = _ensure_contiguous(value, "value")
+
+     # Validate tensor dimensions
+     if query.dim() != 4:
+         raise RuntimeError(f"query must be 4D (B, H, N, D), got {query.dim()}D")
+     if key.dim() != 4:
+         raise RuntimeError(f"key must be 4D (B, H, N, D), got {key.dim()}D")
+     if value.dim() != 4:
+         raise RuntimeError(f"value must be 4D (B, H, N, D), got {value.dim()}D")
+
+     # Validate and expand broadcast mask
+     B, H, N_q, D = query.shape
+     N_kv = key.shape[2]
+     attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
      # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
      if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
          # Apply scale if provided
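
Note: with these checks, non-contiguous views are accepted (and converted) while non-4D inputs are rejected up front. A sketch, assuming an MPS device is available (shapes and dtype illustrative):

    import torch
    from mps_flash_attn import flash_attention

    q = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
    k = torch.randn(1, 1024, 8, 64, device="mps", dtype=torch.float16).transpose(1, 2)  # non-contiguous view
    v = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
    mask = torch.zeros(1, 1, 1, 1024, dtype=torch.bool, device="mps")  # broadcast mask, True = masked

    out = flash_attention(q, k, v, attn_mask=mask)   # k is made contiguous, mask is expanded

    try:
        flash_attention(q[0], k[0], v[0])            # 3D inputs are now rejected
    except RuntimeError as e:
        print(e)                                     # query must be 4D (B, H, N, D), got 3D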
@@ -213,6 +338,10 @@ def replace_sdpa():
      import torch.nn.functional as F
 
      original_sdpa = F.scaled_dot_product_attention
+     _debug = os.environ.get("MFA_DEBUG", "0") == "1"
+     _call_count = [0] # mutable for closure
+     _fallback_count = [0] # track fallbacks for warning
+     _last_fallback_error = [None]
 
      def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                       is_causal=False, scale=None, enable_gqa=False, **kwargs):
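
Note: _debug is captured once, when replace_sdpa() is called, so the environment variable has to be set beforehand. A sketch of how the counters surface (assumes MPS; shapes illustrative):

    import os
    os.environ["MFA_DEBUG"] = "1"      # must be set before calling replace_sdpa()

    import torch
    import torch.nn.functional as F
    from mps_flash_attn import replace_sdpa

    replace_sdpa()                     # patches F.scaled_dot_product_attention

    q = k = v = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # prints an [MFA #1] line when the MFA fast path is taken, or a [NATIVE #1] line with the reason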
@@ -224,37 +353,92 @@ def replace_sdpa():
          # seq=1024: 2.3-3.7x (MFA much faster)
          # seq=2048: 2.2-3.9x (MFA much faster)
          # seq=4096: 2.1-3.7x (MFA much faster)
+         # Determine seq_len based on tensor dimensionality
+         # 4D: (B, H, S, D) -> seq_len = shape[2]
+         # 3D: (B, S, D) -> seq_len = shape[1] (single-head attention, e.g., VAE)
+         is_3d = query.ndim == 3
+         seq_len = query.shape[1] if is_3d else query.shape[2]
+
          if (query.device.type == 'mps' and
              dropout_p == 0.0 and
              _HAS_MFA and
-             query.shape[2] >= 512):
+             query.ndim >= 3 and
+             seq_len >= 512):
              try:
+                 q, k, v = query, key, value
+
+                 # Handle 3D tensors (B, S, D) - treat as single-head attention
+                 # Unsqueeze to (B, 1, S, D) for MFA, squeeze back after
+                 if is_3d:
+                     q = q.unsqueeze(1) # (B, S, D) -> (B, 1, S, D)
+                     k = k.unsqueeze(1)
+                     v = v.unsqueeze(1)
+
                  # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
                  # Common in Llama 2/3, Mistral, Qwen, etc.
-                 k, v = key, value
-                 if enable_gqa and query.shape[1] != key.shape[1]:
+                 # NOTE: Always expand when heads mismatch, not just when enable_gqa=True
+                 # Transformers may pass enable_gqa=True on MPS (torch>=2.5, no mask) even though
+                 # MPS SDPA doesn't support native GQA - we handle it here
+                 if q.shape[1] != k.shape[1]:
                      # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
-                     n_rep = query.shape[1] // key.shape[1]
-                     k = key.repeat_interleave(n_rep, dim=1)
-                     v = value.repeat_interleave(n_rep, dim=1)
+                     n_rep = q.shape[1] // k.shape[1]
+                     k = k.repeat_interleave(n_rep, dim=1)
+                     v = v.repeat_interleave(n_rep, dim=1)
 
                  # Convert float mask to bool mask if needed
                  # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                  # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                  mfa_mask = None
                  if attn_mask is not None:
+                     if _debug:
+                         print(f"[MFA MASK] dtype={attn_mask.dtype} shape={tuple(attn_mask.shape)} min={attn_mask.min().item():.2f} max={attn_mask.max().item():.2f}")
                      if attn_mask.dtype == torch.bool:
-                         # Boolean mask: True means masked (don't attend)
-                         mfa_mask = attn_mask
+                         # PyTorch SDPA bool mask: True = ATTEND, False = MASKED
+                         # MFA bool mask: True = MASKED, False = ATTEND
+                         # They're opposite! Invert it.
+                         mfa_mask = ~attn_mask
                      else:
                          # Float mask: typically -inf for masked positions, 0 for unmasked
                          # Convert: positions with large negative values -> True (masked)
                          # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                          mfa_mask = attn_mask <= -1e3
-                 return flash_attention(query, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
-             except Exception:
-                 # Fall back to original on any error
-                 pass
+                     if _debug:
+                         print(f"[MFA MASK] converted: True(masked)={mfa_mask.sum().item()} False(attend)={(~mfa_mask).sum().item()}")
+
+                 out = flash_attention(q, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+
+                 # Squeeze back for 3D input
+                 if is_3d:
+                     out = out.squeeze(1) # (B, 1, S, D) -> (B, S, D)
+
+                 if _debug:
+                     _call_count[0] += 1
+                     print(f"[MFA #{_call_count[0]}] shape={tuple(query.shape)} is_3d={is_3d} gqa={enable_gqa} mask={attn_mask is not None} causal={is_causal}")
+
+                 return out
+             except Exception as e:
+                 # Fall back to original on any error, but track it
+                 _fallback_count[0] += 1
+                 _last_fallback_error[0] = str(e)
+                 if _debug:
+                     import traceback
+                     print(f"[MFA FALLBACK #{_fallback_count[0]}] shape={tuple(query.shape)}\n{traceback.format_exc()}")
+                 # Warn user after repeated fallbacks (likely a real problem)
+                 if _fallback_count[0] == 10:
+                     warnings.warn(
+                         f"MFA has fallen back to native SDPA {_fallback_count[0]} times. "
+                         f"Last error: {_last_fallback_error[0]}. "
+                         f"Set MFA_DEBUG=1 for details.",
+                         UserWarning
+                     )
+
+         if _debug and query.device.type == 'mps':
+             _call_count[0] += 1
+             reason = []
+             if dropout_p != 0.0: reason.append(f"dropout={dropout_p}")
+             if query.ndim < 3: reason.append(f"ndim={query.ndim}")
+             if seq_len < 512: reason.append(f"seq={seq_len}<512")
+             print(f"[NATIVE #{_call_count[0]}] shape={tuple(query.shape)} reason={','.join(reason) or 'unknown'}")
 
          return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)
 
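
Note: the bool-mask fix above is the behavioral core of this hunk: torch.nn.functional.scaled_dot_product_attention treats True as "attend", while the MFA kernel treats True as "masked". A standalone illustration of the two conversions and of the GQA head expansion (plain PyTorch, toy shapes):

    import torch

    # PyTorch SDPA boolean convention: True = attend, False = masked
    sdpa_bool_mask = torch.tensor([[True, True, False]])
    print(~sdpa_bool_mask)                     # tensor([[False, False,  True]]) -> MFA: True = masked

    # PyTorch SDPA float convention: 0 = attend, -inf (or very negative) = masked
    sdpa_float_mask = torch.tensor([[0.0, 0.0, float("-inf")]])
    print(sdpa_float_mask <= -1e3)             # tensor([[False, False,  True]])

    # GQA: replicate 2 KV heads to match 8 query heads, as repeat_interleave does above
    k = torch.randn(1, 2, 16, 64)
    k_expanded = k.repeat_interleave(8 // k.shape[1], dim=1)
    print(k_expanded.shape)                    # torch.Size([1, 8, 16, 64])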
@@ -461,13 +645,13 @@ def flash_attention_chunked(
          return _C.forward(query, key, value, is_causal, None, 0)
 
      # Initialize running statistics for online softmax
-     # m = running max, l = running sum of exp, acc = accumulated output
      device = query.device
      dtype = query.dtype
 
      # Use float32 for numerical stability of softmax statistics
-     running_max = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
-     running_sum = torch.zeros((B, H, seq_len_q, 1), device=device, dtype=torch.float32)
+     # running_L: base-2 logsumexp of all attention scores seen so far (-inf means no data yet)
+     # output_acc: weighted combination of outputs (weights sum to 1 after each update)
+     running_L = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
      output_acc = torch.zeros((B, H, seq_len_q, D), device=device, dtype=torch.float32)
 
      # Process K/V in chunks
@@ -488,51 +672,81 @@ def flash_attention_chunked(
          # - Partial chunk (up to q) if start_idx <= q < end_idx
          # - None of chunk if q < start_idx
 
-         chunk_is_causal = is_causal and (end_idx <= seq_len_q)
-
-         # Compute attention for this chunk
-         # forward_with_lse returns (output, logsumexp) where logsumexp = m + log(l)
-         chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, chunk_is_causal, None, 0)
-
-         # chunk_lse shape: (B, H, seq_len_q)
-         # We need to convert logsumexp to (max, sum) for online algorithm
-         chunk_lse = chunk_lse.unsqueeze(-1) # (B, H, seq_len_q, 1)
+         if is_causal:
+             # Create explicit causal mask for this chunk
+             # Query positions: 0 to seq_len_q-1
+             # Key positions in chunk: start_idx to end_idx-1
+             chunk_len = end_idx - start_idx
 
-         # Convert chunk output to float32 for accumulation
-         chunk_out = chunk_out.float()
+             # Build mask: mask[q, k_local] = True means DON'T attend
+             # We want to attend when global_k_pos <= q
+             # global_k_pos = start_idx + k_local
+             # So: attend when start_idx + k_local <= q
+             # mask = start_idx + k_local > q
 
-         # Online softmax update:
-         # new_max = max(running_max, chunk_max)
-         # For flash attention, chunk_lse = chunk_max + log(chunk_sum)
-         # We approximate chunk_max ≈ chunk_lse (valid when exp sum dominates)
+             q_pos = torch.arange(seq_len_q, device=device).view(1, 1, seq_len_q, 1)
+             k_pos = torch.arange(chunk_len, device=device).view(1, 1, 1, chunk_len) + start_idx
+             causal_mask = k_pos > q_pos # True = masked (don't attend)
 
-         chunk_max = chunk_lse # Approximation: logsumexp ≈ max when sum is dominated by max
+             # Expand to batch and heads
+             causal_mask = causal_mask.expand(B, H, seq_len_q, chunk_len)
 
-         # Compute new max
-         new_max = torch.maximum(running_max, chunk_max)
+             # Call forward with explicit mask (is_causal=False since we handle it)
+             chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, causal_mask, 0)
+         else:
+             # Non-causal: just process the chunk directly
+             chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, None, 0)
 
-         # Rescale previous accumulator
-         # correction_old = exp(running_max - new_max)
-         correction_old = torch.exp(running_max - new_max)
-         # Clip to avoid inf * 0 issues when running_max was -inf
-         correction_old = torch.where(running_max == float('-inf'), torch.zeros_like(correction_old), correction_old)
+         # chunk_L shape: (B, H, seq_len_q)
+         # The kernel returns L = m + log2(l) where:
+         # m = max(scores * log2(e) / sqrt(D))
+         # l = sum(exp2(scores * log2(e) / sqrt(D) - m))
+         # This is a base-2 logsumexp: L = log2(sum(exp2(scaled_scores)))
+         chunk_L = chunk_lse.unsqueeze(-1).float() # (B, H, seq_len_q, 1)
 
-         # Rescale chunk output
-         # correction_new = exp(chunk_max - new_max)
-         correction_new = torch.exp(chunk_max - new_max)
+         # Convert chunk output to float32 for accumulation
+         chunk_out = chunk_out.float()
 
-         # For the sum, we need exp(chunk_lse - new_max) = exp(chunk_max + log(chunk_sum) - new_max)
-         # = exp(chunk_max - new_max) * chunk_sum
-         # But we only have logsumexp, so: exp(chunk_lse - new_max)
-         chunk_sum_scaled = torch.exp(chunk_lse - new_max)
+         # Online softmax algorithm using base-2 representation
+         #
+         # Flash attention returns: chunk_out = softmax(scores) @ V
+         # The output is already normalized. For online combination:
+         # new_L = log2(2^running_L + 2^chunk_L)
+         #       = max(running_L, chunk_L) + log2(2^(running_L - max) + 2^(chunk_L - max))
+         #
+         # The weights for combining outputs are:
+         # old_weight = 2^(running_L - new_L)
+         # new_weight = 2^(chunk_L - new_L)
+         # These weights sum to 1, so: output = old_weight * old_out + new_weight * new_out
+
+         # Compute new base-2 logsumexp
+         max_L = torch.maximum(running_L, chunk_L)
+
+         # Handle -inf case (no previous data)
+         # Use exp2 for base-2 (matches kernel's internal representation)
+         running_exp2 = torch.where(
+             running_L == float('-inf'),
+             torch.zeros_like(running_L),
+             torch.exp2(running_L - max_L)
+         )
+         chunk_exp2 = torch.exp2(chunk_L - max_L)
+         new_L = max_L + torch.log2(running_exp2 + chunk_exp2)
+
+         # Compute correction factors using base-2 exp
+         old_weight = torch.where(
+             running_L == float('-inf'),
+             torch.zeros_like(running_L),
+             torch.exp2(running_L - new_L)
+         )
+         new_weight = torch.exp2(chunk_L - new_L)
 
          # Update accumulator
-         output_acc = output_acc * correction_old + chunk_out * correction_new
-         running_sum = running_sum * correction_old + chunk_sum_scaled
-         running_max = new_max
+         # Update accumulator
+         output_acc = output_acc * old_weight + chunk_out * new_weight
+         running_L = new_L
 
-     # Final normalization
-     output = output_acc / running_sum
+     # No final normalization needed - weights already sum to 1
+     output = output_acc
 
      # Convert back to original dtype
      return output.to(dtype)
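
Note: the rewritten loop relies on the identity that full attention is a convex combination of per-chunk attention outputs, with weights w_i = 2^(L_i - L) and L = log2(2^L1 + 2^L2). A self-contained numerical check of that combination rule (non-causal, two chunks, plain PyTorch standing in for the _C kernel; all names here are local to the sketch):

    import math
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    B, H, Nq, Nkv, D = 1, 2, 4, 8, 16
    q = torch.randn(B, H, Nq, D)
    k = torch.randn(B, H, Nkv, D)
    v = torch.randn(B, H, Nkv, D)
    scale = 1.0 / math.sqrt(D)

    def chunk_attn(k_chunk, v_chunk):
        # Per-chunk normalized output plus base-2 logsumexp of the scaled scores
        scores = (q @ k_chunk.transpose(-2, -1)) * scale
        out = torch.softmax(scores, dim=-1) @ v_chunk
        L = torch.logsumexp(scores, dim=-1, keepdim=True) / math.log(2.0)
        return out, L

    out1, L1 = chunk_attn(k[:, :, :4], v[:, :, :4])
    out2, L2 = chunk_attn(k[:, :, 4:], v[:, :, 4:])

    # Combine exactly as above: the weights 2^(L_i - new_L) sum to 1
    max_L = torch.maximum(L1, L2)
    new_L = max_L + torch.log2(torch.exp2(L1 - max_L) + torch.exp2(L2 - max_L))
    combined = torch.exp2(L1 - new_L) * out1 + torch.exp2(L2 - new_L) * out2

    reference = F.scaled_dot_product_attention(q, k, v)
    print(torch.allclose(combined, reference, atol=1e-5))   # True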
@@ -754,6 +968,11 @@ def flash_attention_fp8(
          scale_factor = scale / default_scale
          query = query * scale_factor
 
+     # Validate and expand broadcast mask
+     B, H, N_q, D = query.shape
+     N_kv = key.shape[2]
+     attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
      quant_type = QUANT_FP8_E5M2 if use_e5m2 else QUANT_FP8_E4M3
      return _C.forward_quantized(
          query, key, value, k_scale, v_scale,
@@ -807,6 +1026,11 @@ def flash_attention_int8(
          scale_factor = scale / default_scale
          query = query * scale_factor
 
+     # Validate and expand broadcast mask
+     B, H, N_q, D = query.shape
+     N_kv = key.shape[2]
+     attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
      return _C.forward_quantized(
          query, key, value, k_scale, v_scale,
          QUANT_INT8, is_causal, attn_mask, window_size
@@ -863,6 +1087,11 @@ def flash_attention_nf4(
          scale_factor = scale / default_scale
          query = query * scale_factor
 
+     # Validate and expand broadcast mask
+     B, H, N_q, D = query.shape
+     N_kv = key.shape[2]
+     attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
      return _C.forward_quantized(
          query, key, value, k_scale, v_scale,
          QUANT_NF4, is_causal, attn_mask, window_size
@@ -917,6 +1146,11 @@ def flash_attention_quantized(
          scale_factor = scale / default_scale
          query = query * scale_factor
 
+     # Validate and expand broadcast mask
+     B, H, N_q, D = query.shape
+     N_kv = key.shape[2]
+     attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
      return _C.forward_quantized(
          query, key, value, k_scale, v_scale,
          quant_type, is_causal, attn_mask, window_size

mps_flash_attn/csrc/mps_flash_attn.mm
@@ -14,6 +14,8 @@
  #include <dlfcn.h>
  #include <string>
  #include <vector>
+ #include <mutex>
+ #include <atomic>
 
  // ============================================================================
  // MFA Bridge Function Types
@@ -67,7 +69,8 @@ static mfa_forward_fn g_mfa_forward = nullptr;
  static mfa_backward_fn g_mfa_backward = nullptr;
  static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
  static void* g_dylib_handle = nullptr;
- static bool g_initialized = false;
+ static std::atomic<bool> g_initialized{false};
+ static std::mutex g_init_mutex;
 
  // ============================================================================
  // Load MFA Bridge Library
@@ -141,6 +144,24 @@ static bool load_mfa_bridge() {
      return true;
  }
 
+ // Thread-safe initialization helper
+ static void ensure_initialized() {
+     // Fast path: already initialized
+     if (g_initialized.load(std::memory_order_acquire)) {
+         return;
+     }
+     // Slow path: need to initialize with lock
+     std::lock_guard<std::mutex> lock(g_init_mutex);
+     // Double-check after acquiring lock
+     if (!g_initialized.load(std::memory_order_relaxed)) {
+         load_mfa_bridge();
+         if (!g_mfa_init()) {
+             throw std::runtime_error("Failed to initialize MFA");
+         }
+         g_initialized.store(true, std::memory_order_release);
+     }
+ }
+
  // ============================================================================
  // Get MTLBuffer from PyTorch MPS Tensor
  // ============================================================================
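
Note: the C++ change above is classic double-checked locking around a one-time load. The same idea, sketched in Python purely for illustration (names here are hypothetical stand-ins; the real implementation is the std::atomic / std::mutex code shown above):

    import threading

    _initialized = False           # stand-in for std::atomic<bool> g_initialized
    _init_lock = threading.Lock()  # stand-in for std::mutex g_init_mutex

    def _load_and_init():
        pass                       # placeholder for load_mfa_bridge() + g_mfa_init()

    def ensure_initialized():
        global _initialized
        if _initialized:           # fast path: no lock once initialized
            return
        with _init_lock:           # slow path: serialize first-time initialization
            if not _initialized:   # double-check after acquiring the lock
                _load_and_init()
                _initialized = True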
@@ -359,14 +380,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
      const c10::optional<at::Tensor>& attn_mask, // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
      int64_t window_size // 0 = full attention, >0 = sliding window
  ) {
-     // Initialize MFA on first call
-     if (!g_initialized) {
-         load_mfa_bridge();
-         if (!g_mfa_init()) {
-             throw std::runtime_error("Failed to initialize MFA");
-         }
-         g_initialized = true;
-     }
+     // Thread-safe initialization
+     ensure_initialized();
 
      // Validate inputs
      TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
@@ -562,14 +577,8 @@ at::Tensor mps_flash_attention_forward_with_bias(
      int64_t window_size,
      int64_t bias_repeat_count // >0 means bias repeats every N batches (for window attention)
  ) {
-     // Initialize MFA on first call
-     if (!g_initialized) {
-         load_mfa_bridge();
-         if (!g_mfa_init()) {
-             throw std::runtime_error("Failed to initialize MFA");
-         }
-         g_initialized = true;
-     }
+     // Thread-safe initialization
+     ensure_initialized();
 
      // Check that v6/v7 API is available
      TORCH_CHECK(g_mfa_create_kernel_v6 || g_mfa_create_kernel_v7,
@@ -735,14 +744,8 @@ at::Tensor mps_flash_attention_forward_quantized(
      const c10::optional<at::Tensor>& attn_mask,
      int64_t window_size
  ) {
-     // Initialize MFA on first call
-     if (!g_initialized) {
-         load_mfa_bridge();
-         if (!g_mfa_init()) {
-             throw std::runtime_error("Failed to initialize MFA");
-         }
-         g_initialized = true;
-     }
+     // Thread-safe initialization
+     ensure_initialized();
 
      // Check that v4 API is available
      TORCH_CHECK(g_mfa_create_kernel_v4, "Quantized attention requires MFA v4 API (update libMFABridge.dylib)");
@@ -992,14 +995,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
      int64_t window_size, // 0 = full attention, >0 = sliding window
      bool bf16_backward // true = use BF16 intermediates for ~2x faster backward
  ) {
-     // Initialize MFA on first call
-     if (!g_initialized) {
-         load_mfa_bridge();
-         if (!g_mfa_init()) {
-             throw std::runtime_error("Failed to initialize MFA");
-         }
-         g_initialized = true;
-     }
+     // Thread-safe initialization
+     ensure_initialized();
 
      // Validate inputs
      TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");

mps_flash_attn-0.2.8.dist-info/METADATA → mps_flash_attn-0.3.2.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.2.8
+ Version: 0.3.2
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT

mps_flash_attn-0.2.8.dist-info/RECORD → mps_flash_attn-0.3.2.dist-info/RECORD
@@ -1,7 +1,7 @@
- mps_flash_attn/_C.cpython-314-darwin.so,sha256=Npw6N708E8Fzy14nnd99XgwPSxkXIRE4iAhwVTwaZHc,313160
- mps_flash_attn/__init__.py,sha256=Fdz8SYpudWlogBU6QiJxy6ybyFR4Eqq7GMPdBusMl6U,40221
+ mps_flash_attn/_C.cpython-314-darwin.so,sha256=GtWa4KIcynqjbCQYw-uTBpkX5NTcxyKuK0APoGnRlQM,313448
+ mps_flash_attn/__init__.py,sha256=u6B_WenZOTk1WMe9u4-PbyPy7pt8NValR5i8Oz6bI-U,49252
  mps_flash_attn/benchmark.py,sha256=qHhvb8Dmh07OEa_iXuPuJSEnRJlrjVF5nKzVwbWypWE,24141
- mps_flash_attn/csrc/mps_flash_attn.mm,sha256=d7Bjcm2VOTNANmdqUevN-mqa5aOEVMMxAuTYINAeSr0,51215
+ mps_flash_attn/csrc/mps_flash_attn.mm,sha256=mR4S8SHLtRiksrmoFH6s2118q662SMNlFU8HmxAE3YY,51204
  mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib,sha256=_oig6f2I6ZxBCKWbJF3ofmZMySm8gB399_M-lD2NOfM,13747
  mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib,sha256=1fsmVvB5EubhN-y6s5CB-eVk_wuO2tfrabiQTwXvJJc,13171
  mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib,sha256=5WKo_yAU-PgmulBUQhnzvt0DZRteVmo4-nc4U-T6G2g,17507
@@ -26,8 +26,8 @@ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e
  mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib,sha256=qyOaQtRVwL_Wc6GGdu6z-ftf0iX84XexuY09-lNLl5o,13747
  mps_flash_attn/kernels/manifest.json,sha256=d5MkE_BjqDQuMNm1jZiwWkQKfB-yfFml3lLSeR-wCLo,1867
  mps_flash_attn/lib/libMFABridge.dylib,sha256=iKgfYISSKMSNt_iXnljjUr_hZZHyCAg2tdS3_ZjmLkc,605696
- mps_flash_attn-0.2.8.dist-info/licenses/LICENSE,sha256=F_XmXSab2O-hHcqLpYJWeFaqB6GA_qiTEN23p2VfZWU,1237
- mps_flash_attn-0.2.8.dist-info/METADATA,sha256=-Nx-lkEs-hfr1QevTiUt9yLU81x2JfKmsf7HqoSEKeg,5834
- mps_flash_attn-0.2.8.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
- mps_flash_attn-0.2.8.dist-info/top_level.txt,sha256=zbArDcWhJDnJfMUKnOUhs5TjsMgSxa2GzOlscTRfobE,15
- mps_flash_attn-0.2.8.dist-info/RECORD,,
+ mps_flash_attn-0.3.2.dist-info/licenses/LICENSE,sha256=F_XmXSab2O-hHcqLpYJWeFaqB6GA_qiTEN23p2VfZWU,1237
+ mps_flash_attn-0.3.2.dist-info/METADATA,sha256=vhcu8d8NdzmuQbOqVUpzacXJF__Eu-BW1C7Em_CNoyg,5834
+ mps_flash_attn-0.3.2.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
+ mps_flash_attn-0.3.2.dist-info/top_level.txt,sha256=zbArDcWhJDnJfMUKnOUhs5TjsMgSxa2GzOlscTRfobE,15
+ mps_flash_attn-0.3.2.dist-info/RECORD,,