mps-flash-attn 0.2.6-cp314-cp314-macosx_15_0_arm64.whl → 0.3.7-cp314-cp314-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Binary file
mps_flash_attn/__init__.py
@@ -4,13 +4,43 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.2.5"
+__version__ = "0.3.7"
+
+__all__ = [
+    # Core functions
+    "flash_attention",
+    "flash_attention_with_bias",
+    "flash_attention_chunked",
+    # Quantized attention
+    "flash_attention_fp8",
+    "flash_attention_int8",
+    "flash_attention_nf4",
+    "quantize_kv_fp8",
+    "quantize_kv_int8",
+    "quantize_kv_nf4",
+    # Utilities
+    "replace_sdpa",
+    "precompile",
+    "clear_cache",
+    "register_custom_op",
+    "is_available",
+    "convert_mask",
+    # Constants
+    "QUANT_FP8_E4M3",
+    "QUANT_FP8_E5M2",
+    "QUANT_INT8",
+    "QUANT_NF4",
+    # Version
+    "__version__",
+]
 
 import torch
-from typing import Optional
+import torch.nn.functional as F
+from typing import Optional, Tuple
 import math
 import threading
 import os
+import warnings
 
 # Try to import the C++ extension
 try:
@@ -30,6 +60,20 @@ def is_available() -> bool:
     return _HAS_MFA and torch.backends.mps.is_available()
 
 
+def _ensure_contiguous(tensor: torch.Tensor, name: str) -> torch.Tensor:
+    """Ensure tensor is contiguous, with a debug warning if conversion needed."""
+    if tensor.is_contiguous():
+        return tensor
+    # Auto-convert with debug info
+    if os.environ.get("MFA_DEBUG", "0") == "1":
+        warnings.warn(
+            f"MFA: {name} tensor was not contiguous (stride={tensor.stride()}), "
+            f"auto-converting. For best performance, ensure inputs are contiguous.",
+            UserWarning
+        )
+    return tensor.contiguous()
+
+
 def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
     """
     Convert attention mask to MFA's boolean format.
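Note: the `_ensure_contiguous` helper added above only warns (under `MFA_DEBUG=1`) and then falls back to `.contiguous()`. A minimal standalone sketch, using plain PyTorch on CPU and not part of the package, of how a non-contiguous input typically arises and what that fallback does:

    import torch

    x = torch.randn(2, 8, 128, 64)          # (B, H, N, D), contiguous
    t = x.transpose(1, 2)                   # a strided view, no longer contiguous
    print(t.is_contiguous())                # False
    print(t.contiguous().is_contiguous())   # True - the same conversion the helper performs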
@@ -54,6 +98,98 @@ def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
     return attn_mask <= -1e3
 
 
+def _validate_and_expand_mask(
+    attn_mask: Optional[torch.Tensor],
+    B: int,
+    H: int,
+    N_q: int,
+    N_kv: int,
+) -> Optional[torch.Tensor]:
+    """
+    Validate attention mask shape and expand broadcast dimensions.
+
+    Args:
+        attn_mask: Optional mask of shape (B, H, N_q, N_kv) or broadcastable
+        B: Batch size
+        H: Number of heads
+        N_q: Query sequence length
+        N_kv: Key/Value sequence length
+
+    Returns:
+        Expanded mask of shape (mb, mh, N_q, N_kv) or None
+    """
+    if attn_mask is None:
+        return None
+
+    attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
+
+    if attn_mask.dim() != 4:
+        raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
+
+    mb, mh, mq, mk = attn_mask.shape
+
+    # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
+    if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
+        raise ValueError(
+            f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
+        )
+
+    # Expand broadcast mask to full shape for Metal kernel
+    if mq == 1 and N_q > 1:
+        attn_mask = attn_mask.expand(mb, mh, N_q, mk)
+    if mk == 1 and N_kv > 1:
+        attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
+
+    if mb != 1 and mb != B:
+        raise ValueError(f"attn_mask batch size must be 1 or {B}, got {mb}")
+    if mh != 1 and mh != H:
+        raise ValueError(f"attn_mask head count must be 1 or {H}, got {mh}")
+
+    return attn_mask
+
+
+class FlashAttentionWithBiasFunction(torch.autograd.Function):
+    """Autograd function for Flash Attention with bias - native C++ backward."""
+
+    @staticmethod
+    def forward(ctx, query, key, value, attn_bias, is_causal, scale, window_size, bias_repeat_count):
+        # Apply scale if provided
+        scale_factor = 1.0
+        if scale is not None:
+            default_scale = 1.0 / math.sqrt(query.shape[-1])
+            if abs(scale - default_scale) > 1e-6:
+                scale_factor = scale / default_scale
+                query = query * scale_factor
+
+        # Call C++ forward with bias (returns output and logsumexp)
+        output, logsumexp = _C.forward_with_bias_lse(query, key, value, attn_bias, is_causal, window_size, bias_repeat_count)
+
+        # Save for backward
+        ctx.save_for_backward(query, key, value, output, logsumexp, attn_bias)
+        ctx.is_causal = is_causal
+        ctx.scale_factor = scale_factor
+        ctx.window_size = window_size
+        ctx.bias_repeat_count = bias_repeat_count
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        query, key, value, output, logsumexp, attn_bias = ctx.saved_tensors
+
+        # Call native C++ backward with bias
+        dQ, dK, dV = _C.backward_with_bias(
+            grad_output, query, key, value, output, logsumexp, attn_bias,
+            ctx.is_causal, ctx.window_size, ctx.bias_repeat_count
+        )
+
+        # Scale dQ back if we scaled query
+        if ctx.scale_factor != 1.0:
+            dQ = dQ * ctx.scale_factor
+
+        return dQ, dK, dV, None, None, None, None, None
+
+
 class FlashAttentionFunction(torch.autograd.Function):
     """Autograd function for Flash Attention with backward pass support."""
 
@@ -176,6 +312,20 @@ def flash_attention(
     if not torch.backends.mps.is_available():
         raise RuntimeError("MPS not available")
 
+    # Validate scale parameter
+    if scale is not None:
+        if scale <= 0:
+            raise ValueError(f"scale must be positive, got {scale}")
+        # Warn about extreme scale values that could cause numerical issues
+        default_scale = 1.0 / math.sqrt(query.shape[-1])
+        if scale < default_scale * 0.01 or scale > default_scale * 100:
+            warnings.warn(
+                f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
+                "this may cause numerical issues",
+                UserWarning,
+                stacklevel=2
+            )
+
     # Validate device
     if query.device.type != 'mps':
         raise ValueError("query must be on MPS device")
@@ -186,6 +336,24 @@ def flash_attention(
     if attn_mask is not None and attn_mask.device.type != 'mps':
         raise ValueError("attn_mask must be on MPS device")
 
+    # Ensure contiguous (auto-convert with debug warning)
+    query = _ensure_contiguous(query, "query")
+    key = _ensure_contiguous(key, "key")
+    value = _ensure_contiguous(value, "value")
+
+    # Validate tensor dimensions
+    if query.dim() != 4:
+        raise RuntimeError(f"query must be 4D (B, H, N, D), got {query.dim()}D")
+    if key.dim() != 4:
+        raise RuntimeError(f"key must be 4D (B, H, N, D), got {key.dim()}D")
+    if value.dim() != 4:
+        raise RuntimeError(f"value must be 4D (B, H, N, D), got {value.dim()}D")
+
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
     if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
         # Apply scale if provided
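The validation added above expects 4D `(B, H, N, D)` inputs and accepts masks that broadcast over batch, heads, or query positions. A minimal sketch, plain PyTorch and not part of the package, of preparing inputs in that layout and a key-padding mask in the broadcastable form that `_validate_and_expand_mask` expands for the kernel (True = masked, per the MFA convention):

    import torch

    B, H, N_q, N_kv, D = 2, 8, 16, 16, 64
    q = torch.randn(B, N_q, H, D).transpose(1, 2).contiguous()   # to (B, H, N_q, D)
    key_padding = torch.zeros(B, 1, 1, N_kv, dtype=torch.bool)   # broadcastable mask
    key_padding[:, :, :, -4:] = True                             # mask the last 4 key positions
    full = key_padding.expand(B, H, N_q, N_kv)                   # what the kernel ultimately sees
    print(q.shape, full.shape)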
@@ -213,39 +381,107 @@ def replace_sdpa():
     import torch.nn.functional as F
 
     original_sdpa = F.scaled_dot_product_attention
+    _debug = os.environ.get("MFA_DEBUG", "0") == "1"
+    _call_count = [0]  # mutable for closure
+    _fallback_count = [0]  # track fallbacks for warning
+    _last_fallback_error = [None]
 
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None, enable_gqa=False, **kwargs):
         # Use MFA for MPS tensors without dropout
-        # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+        # Only use MFA for seq_len >= 512 where it outperforms PyTorch's math backend
         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
         # Benchmark results (B=1-4, H=8, D=64-128, fp16/bf16):
-        # seq=512: 0.3-0.5x (MFA slower)
-        # seq=1024: 1.1-2.0x (MFA faster)
-        # seq=2048: 1.7-3.7x (MFA much faster)
-        # seq=4096: 2.0-3.9x (MFA much faster)
+        # seq=512: 1.2-1.6x (MFA faster)
+        # seq=1024: 2.3-3.7x (MFA much faster)
+        # seq=2048: 2.2-3.9x (MFA much faster)
+        # seq=4096: 2.1-3.7x (MFA much faster)
+        # Determine seq_len based on tensor dimensionality
+        # 4D: (B, H, S, D) -> seq_len = shape[2]
+        # 3D: (B, S, D) -> seq_len = shape[1] (single-head attention, e.g., VAE)
+        is_3d = query.ndim == 3
+        seq_len = query.shape[1] if is_3d else query.shape[2]
+
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.shape[2] >= 1024):
+            query.ndim >= 3 and
+            seq_len >= 512):
             try:
+                q, k, v = query, key, value
+
+                # Handle 3D tensors (B, S, D) - treat as single-head attention
+                # Unsqueeze to (B, 1, S, D) for MFA, squeeze back after
+                if is_3d:
+                    q = q.unsqueeze(1)  # (B, S, D) -> (B, 1, S, D)
+                    k = k.unsqueeze(1)
+                    v = v.unsqueeze(1)
+
+                # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
+                # Common in Llama 2/3, Mistral, Qwen, etc.
+                # NOTE: Always expand when heads mismatch, not just when enable_gqa=True
+                # Transformers may pass enable_gqa=True on MPS (torch>=2.5, no mask) even though
+                # MPS SDPA doesn't support native GQA - we handle it here
+                if q.shape[1] != k.shape[1]:
+                    # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
+                    n_rep = q.shape[1] // k.shape[1]
+                    k = k.repeat_interleave(n_rep, dim=1)
+                    v = v.repeat_interleave(n_rep, dim=1)
+
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                 # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                 mfa_mask = None
                 if attn_mask is not None:
+                    if _debug:
+                        print(f"[MFA MASK] dtype={attn_mask.dtype} shape={tuple(attn_mask.shape)} min={attn_mask.min().item():.2f} max={attn_mask.max().item():.2f}")
                     if attn_mask.dtype == torch.bool:
-                        # Boolean mask: True means masked (don't attend)
-                        mfa_mask = attn_mask
+                        # PyTorch SDPA bool mask: True = ATTEND, False = MASKED
+                        # MFA bool mask: True = MASKED, False = ATTEND
+                        # They're opposite! Invert it.
+                        mfa_mask = ~attn_mask
                     else:
                         # Float mask: typically -inf for masked positions, 0 for unmasked
                        # Convert: positions with large negative values -> True (masked)
                         # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                         mfa_mask = attn_mask <= -1e3
-                return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
-            except Exception:
-                # Fall back to original on any error
-                pass
+                    if _debug:
+                        print(f"[MFA MASK] converted: True(masked)={mfa_mask.sum().item()} False(attend)={(~mfa_mask).sum().item()}")
+
+                out = flash_attention(q, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+
+                # Squeeze back for 3D input
+                if is_3d:
+                    out = out.squeeze(1)  # (B, 1, S, D) -> (B, S, D)
+
+                if _debug:
+                    _call_count[0] += 1
+                    print(f"[MFA #{_call_count[0]}] shape={tuple(query.shape)} is_3d={is_3d} gqa={enable_gqa} mask={attn_mask is not None} causal={is_causal}")
+
+                return out
+            except Exception as e:
+                # Fall back to original on any error, but track it
+                _fallback_count[0] += 1
+                _last_fallback_error[0] = str(e)
+                if _debug:
+                    import traceback
+                    print(f"[MFA FALLBACK #{_fallback_count[0]}] shape={tuple(query.shape)}\n{traceback.format_exc()}")
+                # Warn user after repeated fallbacks (likely a real problem)
+                if _fallback_count[0] == 10:
+                    warnings.warn(
+                        f"MFA has fallen back to native SDPA {_fallback_count[0]} times. "
+                        f"Last error: {_last_fallback_error[0]}. "
+                        f"Set MFA_DEBUG=1 for details.",
+                        UserWarning
+                    )
+
+        if _debug and query.device.type == 'mps':
+            _call_count[0] += 1
+            reason = []
+            if dropout_p != 0.0: reason.append(f"dropout={dropout_p}")
+            if query.ndim < 3: reason.append(f"ndim={query.ndim}")
+            if seq_len < 512: reason.append(f"seq={seq_len}<512")
+            print(f"[NATIVE #{_call_count[0]}] shape={tuple(query.shape)} reason={','.join(reason) or 'unknown'}")
 
         return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)
 
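Two conversions inside `patched_sdpa` are easy to get wrong, so here is a small self-contained sketch (plain PyTorch, no MFA required, not part of the package) of both: inverting an SDPA-style boolean mask (True = attend) into the MFA convention (True = masked), and expanding grouped K/V heads so they match the query head count:

    import torch

    # 1) Boolean mask convention flip
    sdpa_bool_mask = torch.tensor([[True, True, False]])   # SDPA: attend, attend, mask
    mfa_mask = ~sdpa_bool_mask                              # MFA:  False, False, True

    # 2) GQA expansion: 2 KV heads serving 8 query heads
    B, S, D = 1, 16, 64
    q = torch.randn(B, 8, S, D)
    k = torch.randn(B, 2, S, D)
    n_rep = q.shape[1] // k.shape[1]
    k_expanded = k.repeat_interleave(n_rep, dim=1)          # (B, 8, S, D)
    print(mfa_mask.tolist(), k_expanded.shape)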
@@ -302,6 +538,7 @@ def flash_attention_with_bias(
     value: torch.Tensor,
     attn_bias: torch.Tensor,
     is_causal: bool = False,
+    scale: Optional[float] = None,
     window_size: int = 0,
     bias_repeat_count: int = 0,
 ) -> torch.Tensor:
@@ -309,13 +546,15 @@ def flash_attention_with_bias(
     Compute scaled dot-product attention with additive attention bias.
 
     This function supports additive attention bias (like relative position encodings
-    or ALiBi) which is added to the attention scores before softmax:
+    or ALiBi) which is added to the attention scores:
+
+        Attention(Q, K, V) = softmax((Q @ K.T + bias) * scale) @ V
 
-        Attention(Q, K, V) = softmax((Q @ K.T) / sqrt(d) + bias) @ V
+    IMPORTANT: The bias is added to UNSCALED scores, then the sum is scaled.
+    This differs from PyTorch SDPA which does: softmax((Q @ K.T) * scale + bias).
 
-    IMPORTANT: MFA adds bias to UNSCALED scores internally and scales during softmax.
-    If your bias was computed for scaled scores (like PyTorch SDPA), you need to
-    pre-scale it by multiplying by sqrt(head_dim).
+    To convert from SDPA-style bias to MFA-style:
+        bias_mfa = bias_sdpa * sqrt(head_dim)  # when using default scale
 
     Args:
         query: Query tensor of shape (B, H, N_q, D)
@@ -326,6 +565,7 @@ def flash_attention_with_bias(
             - (1, H, N_q, N_kv): Broadcast across batch
            - (H, N_q, N_kv): Broadcast across batch (3D)
         is_causal: If True, applies causal masking
+        scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
         window_size: Sliding window attention size (0 = full attention)
         bias_repeat_count: If > 0, the bias tensor repeats every N batches.
             Useful for window attention where multiple windows share the same
@@ -343,10 +583,13 @@ def flash_attention_with_bias(
         >>> v = torch.randn(4, 8, 64, 64, device='mps', dtype=torch.float16)
         >>> # Position bias: (1, num_heads, seq_len, seq_len)
         >>> bias = torch.randn(1, 8, 64, 64, device='mps', dtype=torch.float16)
-        >>> # Pre-scale bias since MFA uses unscaled scores
+        >>> # Pre-scale bias since default scale is 1/sqrt(head_dim)
         >>> scaled_bias = bias * math.sqrt(64)  # sqrt(head_dim)
         >>> out = flash_attention_with_bias(q, k, v, scaled_bias)
 
+        >>> # With custom scale
+        >>> out = flash_attention_with_bias(q, k, v, bias, scale=0.1)
+
         >>> # Window attention with repeating bias pattern
         >>> n_windows = 16
         >>> q = torch.randn(n_windows * 4, 8, 49, 64, device='mps', dtype=torch.float16)
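The pre-scaling rule documented above follows from softmax algebra: with the default scale s = 1/sqrt(D), adding a bias after scaling is the same as adding bias·sqrt(D) before scaling. A short numerical check, plain PyTorch and independent of the kernel (not part of the package):

    import math
    import torch

    D = 64
    s = 1.0 / math.sqrt(D)
    scores = torch.randn(4, 4)                       # stand-in for Q @ K.T
    bias_sdpa = torch.randn(4, 4)
    lhs = torch.softmax(scores * s + bias_sdpa, dim=-1)                    # SDPA convention
    rhs = torch.softmax((scores + bias_sdpa * math.sqrt(D)) * s, dim=-1)   # MFA convention
    print(torch.allclose(lhs, rhs, atol=1e-6))       # True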
@@ -363,6 +606,20 @@ def flash_attention_with_bias(
     if not torch.backends.mps.is_available():
         raise RuntimeError("MPS not available")
 
+    # Validate scale parameter
+    if scale is not None:
+        if scale <= 0:
+            raise ValueError(f"scale must be positive, got {scale}")
+        # Warn about extreme scale values
+        default_scale = 1.0 / math.sqrt(query.shape[-1])
+        if scale < default_scale * 0.01 or scale > default_scale * 100:
+            warnings.warn(
+                f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
+                "this may cause numerical issues",
+                UserWarning,
+                stacklevel=2
+            )
+
     # Validate device
     if query.device.type != 'mps':
         raise ValueError("query must be on MPS device")
@@ -373,7 +630,10 @@ def flash_attention_with_bias(
     if attn_bias.device.type != 'mps':
         raise ValueError("attn_bias must be on MPS device")
 
-    return _C.forward_with_bias(query, key, value, attn_bias, is_causal, window_size, bias_repeat_count)
+    # Use autograd Function for backward support
+    return FlashAttentionWithBiasFunction.apply(
+        query, key, value, attn_bias, is_causal, scale, window_size, bias_repeat_count
+    )
 
 
 def flash_attention_chunked(
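With the autograd wiring above, gradients now flow through the biased attention path. A usage sketch, assuming an Apple Silicon machine with the extension importable as `mps_flash_attn` (illustrative, not part of the diff):

    import math
    import torch
    from mps_flash_attn import flash_attention_with_bias

    q = torch.randn(1, 8, 64, 64, device="mps", dtype=torch.float16, requires_grad=True)
    k = torch.randn(1, 8, 64, 64, device="mps", dtype=torch.float16, requires_grad=True)
    v = torch.randn(1, 8, 64, 64, device="mps", dtype=torch.float16, requires_grad=True)
    bias = torch.randn(1, 8, 64, 64, device="mps", dtype=torch.float16) * math.sqrt(64)

    out = flash_attention_with_bias(q, k, v, bias)
    out.sum().backward()
    print(q.grad.shape)   # torch.Size([1, 8, 64, 64]); note the bias itself receives no gradient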
@@ -452,13 +712,13 @@ def flash_attention_chunked(
         return _C.forward(query, key, value, is_causal, None, 0)
 
     # Initialize running statistics for online softmax
-    # m = running max, l = running sum of exp, acc = accumulated output
     device = query.device
     dtype = query.dtype
 
     # Use float32 for numerical stability of softmax statistics
-    running_max = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
-    running_sum = torch.zeros((B, H, seq_len_q, 1), device=device, dtype=torch.float32)
+    # running_L: base-2 logsumexp of all attention scores seen so far (-inf means no data yet)
+    # output_acc: weighted combination of outputs (weights sum to 1 after each update)
+    running_L = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
    output_acc = torch.zeros((B, H, seq_len_q, D), device=device, dtype=torch.float32)
 
    # Process K/V in chunks
@@ -479,51 +739,81 @@ def flash_attention_chunked(
         # - Partial chunk (up to q) if start_idx <= q < end_idx
         # - None of chunk if q < start_idx
 
-        chunk_is_causal = is_causal and (end_idx <= seq_len_q)
+        if is_causal:
+            # Create explicit causal mask for this chunk
+            # Query positions: 0 to seq_len_q-1
+            # Key positions in chunk: start_idx to end_idx-1
+            chunk_len = end_idx - start_idx
 
-        # Compute attention for this chunk
-        # forward_with_lse returns (output, logsumexp) where logsumexp = m + log(l)
-        chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, chunk_is_causal, None, 0)
+            # Build mask: mask[q, k_local] = True means DON'T attend
+            # We want to attend when global_k_pos <= q
+            # global_k_pos = start_idx + k_local
+            # So: attend when start_idx + k_local <= q
+            # mask = start_idx + k_local > q
 
-        # chunk_lse shape: (B, H, seq_len_q)
-        # We need to convert logsumexp to (max, sum) for online algorithm
-        chunk_lse = chunk_lse.unsqueeze(-1)  # (B, H, seq_len_q, 1)
+            q_pos = torch.arange(seq_len_q, device=device).view(1, 1, seq_len_q, 1)
+            k_pos = torch.arange(chunk_len, device=device).view(1, 1, 1, chunk_len) + start_idx
+            causal_mask = k_pos > q_pos  # True = masked (don't attend)
 
-        # Convert chunk output to float32 for accumulation
-        chunk_out = chunk_out.float()
-
-        # Online softmax update:
-        # new_max = max(running_max, chunk_max)
-        # For flash attention, chunk_lse ≈ chunk_max + log(chunk_sum)
-        # We approximate chunk_max ≈ chunk_lse (valid when exp sum dominates)
+            # Expand to batch and heads
+            causal_mask = causal_mask.expand(B, H, seq_len_q, chunk_len)
 
-        chunk_max = chunk_lse  # Approximation: logsumexp max when sum is dominated by max
-
-        # Compute new max
-        new_max = torch.maximum(running_max, chunk_max)
+            # Call forward with explicit mask (is_causal=False since we handle it)
+            chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, causal_mask, 0)
+        else:
+            # Non-causal: just process the chunk directly
+            chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, None, 0)
 
-        # Rescale previous accumulator
-        # correction_old = exp(running_max - new_max)
-        correction_old = torch.exp(running_max - new_max)
-        # Clip to avoid inf * 0 issues when running_max was -inf
-        correction_old = torch.where(running_max == float('-inf'), torch.zeros_like(correction_old), correction_old)
+        # chunk_L shape: (B, H, seq_len_q)
+        # The kernel returns L = m + log2(l) where:
+        # m = max(scores * log2(e) / sqrt(D))
+        # l = sum(exp2(scores * log2(e) / sqrt(D) - m))
+        # This is a base-2 logsumexp: L = log2(sum(exp2(scaled_scores)))
+        chunk_L = chunk_lse.unsqueeze(-1).float()  # (B, H, seq_len_q, 1)
 
-        # Rescale chunk output
-        # correction_new = exp(chunk_max - new_max)
-        correction_new = torch.exp(chunk_max - new_max)
+        # Convert chunk output to float32 for accumulation
+        chunk_out = chunk_out.float()
 
-        # For the sum, we need exp(chunk_lse - new_max) = exp(chunk_max + log(chunk_sum) - new_max)
-        #                                               = exp(chunk_max - new_max) * chunk_sum
-        # But we only have logsumexp, so: exp(chunk_lse - new_max)
-        chunk_sum_scaled = torch.exp(chunk_lse - new_max)
+        # Online softmax algorithm using base-2 representation
+        #
+        # Flash attention returns: chunk_out = softmax(scores) @ V
+        # The output is already normalized. For online combination:
+        # new_L = log2(2^running_L + 2^chunk_L)
+        #       = max(running_L, chunk_L) + log2(2^(running_L - max) + 2^(chunk_L - max))
+        #
+        # The weights for combining outputs are:
+        # old_weight = 2^(running_L - new_L)
+        # new_weight = 2^(chunk_L - new_L)
+        # These weights sum to 1, so: output = old_weight * old_out + new_weight * new_out
+
+        # Compute new base-2 logsumexp
+        max_L = torch.maximum(running_L, chunk_L)
+
+        # Handle -inf case (no previous data)
+        # Use exp2 for base-2 (matches kernel's internal representation)
+        running_exp2 = torch.where(
+            running_L == float('-inf'),
+            torch.zeros_like(running_L),
+            torch.exp2(running_L - max_L)
+        )
+        chunk_exp2 = torch.exp2(chunk_L - max_L)
+        new_L = max_L + torch.log2(running_exp2 + chunk_exp2)
+
+        # Compute correction factors using base-2 exp
+        old_weight = torch.where(
+            running_L == float('-inf'),
+            torch.zeros_like(running_L),
+            torch.exp2(running_L - new_L)
+        )
+        new_weight = torch.exp2(chunk_L - new_L)
 
         # Update accumulator
-        output_acc = output_acc * correction_old + chunk_out * correction_new
-        running_sum = running_sum * correction_old + chunk_sum_scaled
-        running_max = new_max
+        output_acc = output_acc * old_weight + chunk_out * new_weight
+        running_L = new_L
 
-    # Final normalization
-    output = output_acc / running_sum
+    # No final normalization needed - weights already sum to 1
+    output = output_acc
 
     # Convert back to original dtype
     return output.to(dtype)
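The base-2 merge above can be checked in isolation: each chunk's output is already a normalized softmax average, and the base-2 logsumexps determine mixing weights that sum to 1. A small sketch (plain PyTorch, ignoring the kernel's 1/sqrt(D) scaling, not part of the package) comparing the two-chunk merge against a full softmax:

    import math
    import torch

    torch.manual_seed(0)
    scores = torch.randn(1, 6)                            # one query row, 6 key positions
    values = torch.randn(6, 4)
    s1, s2 = scores[:, :3], scores[:, 3:]
    v1, v2 = values[:3], values[3:]

    out1 = torch.softmax(s1, dim=-1) @ v1                 # normalized chunk outputs
    out2 = torch.softmax(s2, dim=-1) @ v2
    L1 = torch.logsumexp(s1, dim=-1, keepdim=True) / math.log(2)   # base-2 logsumexp
    L2 = torch.logsumexp(s2, dim=-1, keepdim=True) / math.log(2)

    max_L = torch.maximum(L1, L2)
    new_L = max_L + torch.log2(torch.exp2(L1 - max_L) + torch.exp2(L2 - max_L))
    merged = torch.exp2(L1 - new_L) * out1 + torch.exp2(L2 - new_L) * out2

    print(torch.allclose(merged, torch.softmax(scores, dim=-1) @ values, atol=1e-6))   # True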
@@ -745,6 +1035,11 @@ def flash_attention_fp8(
             scale_factor = scale / default_scale
             query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     quant_type = QUANT_FP8_E5M2 if use_e5m2 else QUANT_FP8_E4M3
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
@@ -798,6 +1093,11 @@ def flash_attention_int8(
             scale_factor = scale / default_scale
             query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_INT8, is_causal, attn_mask, window_size
@@ -854,6 +1154,11 @@ def flash_attention_nf4(
             scale_factor = scale / default_scale
             query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_NF4, is_causal, attn_mask, window_size
@@ -908,6 +1213,11 @@ def flash_attention_quantized(
             scale_factor = scale / default_scale
             query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         quant_type, is_causal, attn_mask, window_size
mps_flash_attn/csrc/mps_flash_attn.mm
@@ -14,6 +14,8 @@
 #include <dlfcn.h>
 #include <string>
 #include <vector>
+#include <mutex>
+#include <atomic>
 
 // ============================================================================
 // MFA Bridge Function Types
@@ -41,6 +43,9 @@ typedef bool (*mfa_forward_encode_quantized_fn)(void*, void*, void*, void*, void
 typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
                                        int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                        int32_t, int32_t);
+typedef bool (*mfa_backward_encode_bias_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                            int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                            int32_t, int32_t);
 // Legacy sync functions (fallback)
 typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, void*,
                                int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
@@ -63,11 +68,13 @@ static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
 static mfa_forward_encode_bias_fn g_mfa_forward_encode_bias = nullptr;
 static mfa_forward_encode_quantized_fn g_mfa_forward_encode_quantized = nullptr;
 static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
+static mfa_backward_encode_bias_fn g_mfa_backward_encode_bias = nullptr;
 static mfa_forward_fn g_mfa_forward = nullptr;
 static mfa_backward_fn g_mfa_backward = nullptr;
 static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
 static void* g_dylib_handle = nullptr;
-static bool g_initialized = false;
+static std::atomic<bool> g_initialized{false};
+static std::mutex g_init_mutex;
 
 // ============================================================================
 // Load MFA Bridge Library
@@ -129,6 +136,7 @@ static bool load_mfa_bridge() {
     g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
     g_mfa_forward_encode_bias = (mfa_forward_encode_bias_fn)dlsym(g_dylib_handle, "mfa_forward_encode_bias");
     g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
+    g_mfa_backward_encode_bias = (mfa_backward_encode_bias_fn)dlsym(g_dylib_handle, "mfa_backward_encode_bias");
     g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
     g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
     g_mfa_release_kernel = (mfa_release_kernel_fn)dlsym(g_dylib_handle, "mfa_release_kernel");
@@ -141,6 +149,24 @@ static bool load_mfa_bridge() {
     return true;
 }
 
+// Thread-safe initialization helper
+static void ensure_initialized() {
+    // Fast path: already initialized
+    if (g_initialized.load(std::memory_order_acquire)) {
+        return;
+    }
+    // Slow path: need to initialize with lock
+    std::lock_guard<std::mutex> lock(g_init_mutex);
+    // Double-check after acquiring lock
+    if (!g_initialized.load(std::memory_order_relaxed)) {
+        load_mfa_bridge();
+        if (!g_mfa_init()) {
+            throw std::runtime_error("Failed to initialize MFA");
+        }
+        g_initialized.store(true, std::memory_order_release);
+    }
+}
+
 // ============================================================================
 // Get MTLBuffer from PyTorch MPS Tensor
 // ============================================================================
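The `ensure_initialized()` helper above is the classic double-checked locking pattern: a cheap atomic flag check on the fast path, a mutex plus a re-check on the slow path. A Python analogue of the same idea (illustrative only and not part of the package; the C++ version additionally relies on acquire/release ordering):

    import threading

    _initialized = False
    _init_lock = threading.Lock()

    def _expensive_init():
        print("initializing once")

    def ensure_initialized():
        global _initialized
        if _initialized:              # fast path: no lock
            return
        with _init_lock:              # slow path: serialize initializers
            if not _initialized:      # re-check under the lock
                _expensive_init()
                _initialized = True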
@@ -359,14 +385,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     const c10::optional<at::Tensor>& attn_mask, // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
     int64_t window_size // 0 = full attention, >0 = sliding window
 ) {
-    // Initialize MFA on first call
-    if (!g_initialized) {
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();
 
     // Validate inputs
     TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
@@ -562,14 +582,8 @@ at::Tensor mps_flash_attention_forward_with_bias(
     int64_t window_size,
     int64_t bias_repeat_count // >0 means bias repeats every N batches (for window attention)
 ) {
-    // Initialize MFA on first call
-    if (!g_initialized) {
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();
 
     // Check that v6/v7 API is available
     TORCH_CHECK(g_mfa_create_kernel_v6 || g_mfa_create_kernel_v7,
@@ -630,9 +644,13 @@ at::Tensor mps_flash_attention_forward_with_bias(
         q = q.to(at::kFloat);
         k = k.to(at::kFloat);
         v = v.to(at::kFloat);
-        bias = bias.to(at::kFloat);
     }
 
+    // IMPORTANT: Bias is always FP32 in the Metal kernel (registerPrecisions[.S] = .FP32
+    // when lowPrecisionIntermediates = false, which is the common case)
+    // Always convert bias to FP32 to match the kernel's expected type
+    bias = bias.to(at::kFloat);
+
     // Allocate output
     at::Tensor output;
     if (use_bf16_kernel) {
@@ -711,6 +729,156 @@ at::Tensor mps_flash_attention_forward_with_bias(
     return output;
 }
 
+// Forward with bias returning both output and logsumexp (for backward pass)
+std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_bias_lse(
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    const at::Tensor& attn_bias,
+    bool is_causal,
+    int64_t window_size,
+    int64_t bias_repeat_count
+) {
+    // Thread-safe initialization
+    ensure_initialized();
+
+    // Check that v6/v7 API is available
+    TORCH_CHECK(g_mfa_create_kernel_v6 || g_mfa_create_kernel_v7,
+                "Attention bias requires MFA v6+ API (update libMFABridge.dylib)");
+    TORCH_CHECK(g_mfa_forward_encode_bias,
+                "Attention bias requires mfa_forward_encode_bias");
+
+    // Validate inputs
+    TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
+    TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
+    TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
+    TORCH_CHECK(query.device().is_mps(), "Query must be on MPS device");
+    TORCH_CHECK(key.device().is_mps(), "Key must be on MPS device");
+    TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");
+    TORCH_CHECK(attn_bias.device().is_mps(), "Attention bias must be on MPS device");
+
+    const int64_t batch_size = query.size(0);
+    const int64_t num_heads = query.size(1);
+    const int64_t seq_len_q = query.size(2);
+    const int64_t head_dim = query.size(3);
+    const int64_t seq_len_kv = key.size(2);
+
+    // Determine bias strides for broadcasting
+    uint32_t bias_batch_stride = 0;
+    uint32_t bias_head_stride = static_cast<uint32_t>(seq_len_q * seq_len_kv);
+
+    if (attn_bias.dim() == 4) {
+        if (attn_bias.size(0) > 1) {
+            bias_batch_stride = static_cast<uint32_t>(attn_bias.size(1) * seq_len_q * seq_len_kv);
+        }
+        if (attn_bias.size(1) == 1) {
+            bias_head_stride = 0;
+        }
+    } else if (attn_bias.dim() == 3) {
+        bias_batch_stride = 0;
+        if (attn_bias.size(0) == 1) {
+            bias_head_stride = 0;
+        }
+    }
+
+    // Determine precision
+    bool is_bfloat16 = (query.scalar_type() == at::kBFloat16);
+    bool is_fp16 = (query.scalar_type() == at::kHalf);
+    bool use_bf16_kernel = is_bfloat16 && g_mfa_create_kernel_v2;
+    bool low_precision = is_fp16;
+    bool low_precision_outputs = is_fp16 || use_bf16_kernel;
+
+    // Make inputs contiguous
+    auto q = query.contiguous();
+    auto k = key.contiguous();
+    auto v = value.contiguous();
+    auto bias = attn_bias.contiguous();
+
+    // For BF16 without native kernel, convert to FP32
+    if (is_bfloat16 && !use_bf16_kernel) {
+        q = q.to(at::kFloat);
+        k = k.to(at::kFloat);
+        v = v.to(at::kFloat);
+    }
+
+    // Bias is always FP32 in the Metal kernel
+    bias = bias.to(at::kFloat);
+
+    // Allocate output
+    at::Tensor output;
+    if (use_bf16_kernel) {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kBFloat16));
+    } else if (low_precision_outputs) {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kHalf));
+    } else {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kFloat));
+    }
+
+    // Allocate logsumexp (always fp32)
+    auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
+                               query.options().dtype(at::kFloat));
+
+    // Get or create kernel with bias support
+    void* kernel = get_or_create_kernel(
+        seq_len_q, seq_len_kv, head_dim,
+        low_precision, low_precision_outputs, is_causal, false,
+        use_bf16_kernel,
+        static_cast<uint32_t>(window_size > 0 ? window_size : 0),
+        0, false, true,
+        bias_batch_stride, bias_head_stride,
+        static_cast<uint32_t>(bias_repeat_count > 0 ? bias_repeat_count : 0)
+    );
+
+    // Get Metal buffers
+    auto q_info = getBufferInfo(q);
+    auto k_info = getBufferInfo(k);
+    auto v_info = getBufferInfo(v);
+    auto o_info = getBufferInfo(output);
+    auto l_info = getBufferInfo(logsumexp);
+    auto bias_info = getBufferInfo(bias);
+
+    // Execute
+    @autoreleasepool {
+        auto stream = at::mps::getCurrentMPSStream();
+        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+        bool success = g_mfa_forward_encode_bias(
+            kernel,
+            (__bridge void*)encoder,
+            (__bridge void*)q_info.buffer,
+            (__bridge void*)k_info.buffer,
+            (__bridge void*)v_info.buffer,
+            (__bridge void*)o_info.buffer,
+            (__bridge void*)l_info.buffer,
+            nullptr,
+            (__bridge void*)bias_info.buffer,
+            q_info.byte_offset,
+            k_info.byte_offset,
+            v_info.byte_offset,
+            o_info.byte_offset,
+            l_info.byte_offset,
+            0,
+            bias_info.byte_offset,
+            static_cast<int32_t>(batch_size),
+            static_cast<int32_t>(num_heads)
+        );
+
+        if (!success) {
+            throw std::runtime_error("MFA forward with bias failed");
+        }
+    }
+
+    // Convert output back to BF16 if needed
+    if (is_bfloat16 && !use_bf16_kernel) {
+        output = output.to(at::kBFloat16);
+    }
+
+    return std::make_tuple(output, logsumexp);
+}
+
 // ============================================================================
 // Quantized Flash Attention Forward (FP8, INT8, NF4)
 // ============================================================================
@@ -735,14 +903,8 @@ at::Tensor mps_flash_attention_forward_quantized(
     const c10::optional<at::Tensor>& attn_mask,
     int64_t window_size
 ) {
-    // Initialize MFA on first call
-    if (!g_initialized) {
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();
 
     // Check that v4 API is available
     TORCH_CHECK(g_mfa_create_kernel_v4, "Quantized attention requires MFA v4 API (update libMFABridge.dylib)");
@@ -992,14 +1154,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
     int64_t window_size, // 0 = full attention, >0 = sliding window
     bool bf16_backward // true = use BF16 intermediates for ~2x faster backward
 ) {
-    // Initialize MFA on first call
-    if (!g_initialized) {
-        load_mfa_bridge();
-        if (!g_mfa_init()) {
-            throw std::runtime_error("Failed to initialize MFA");
-        }
-        g_initialized = true;
-    }
+    // Thread-safe initialization
+    ensure_initialized();
 
     // Validate inputs
     TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
@@ -1160,6 +1316,146 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
     return std::make_tuple(dQ, dK, dV);
 }
 
+// Backward with bias support
+std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward_with_bias(
+    const at::Tensor& grad_output,
+    const at::Tensor& query,
+    const at::Tensor& key,
+    const at::Tensor& value,
+    const at::Tensor& output,
+    const at::Tensor& logsumexp,
+    const at::Tensor& attn_bias,
+    bool is_causal,
+    int64_t window_size,
+    int64_t bias_repeat_count
+) {
+    ensure_initialized();
+
+    TORCH_CHECK(g_mfa_backward_encode_bias,
+                "Backward with bias requires mfa_backward_encode_bias (update libMFABridge.dylib)");
+
+    TORCH_CHECK(grad_output.dim() == 4, "grad_output must be 4D (B, H, N, D)");
+    TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
+    TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
+    TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
+    TORCH_CHECK(output.dim() == 4, "Output must be 4D (B, H, N, D)");
+    TORCH_CHECK(logsumexp.dim() == 3, "Logsumexp must be 3D (B, H, N)");
+
+    const int64_t batch_size = query.size(0);
+    const int64_t num_heads = query.size(1);
+    const int64_t seq_len_q = query.size(2);
+    const int64_t head_dim = query.size(3);
+    const int64_t seq_len_kv = key.size(2);
+
+    // Determine bias strides
+    uint32_t bias_batch_stride = 0;
+    uint32_t bias_head_stride = static_cast<uint32_t>(seq_len_q * seq_len_kv);
+
+    if (attn_bias.dim() == 4) {
+        if (attn_bias.size(0) > 1) {
+            bias_batch_stride = static_cast<uint32_t>(attn_bias.size(1) * seq_len_q * seq_len_kv);
+        }
+        if (attn_bias.size(1) == 1) {
+            bias_head_stride = 0;
+        }
+    } else if (attn_bias.dim() == 3) {
+        bias_batch_stride = 0;
+        if (attn_bias.size(0) == 1) {
+            bias_head_stride = 0;
+        }
+    }
+
+    bool low_precision = (query.scalar_type() == at::kHalf || query.scalar_type() == at::kBFloat16);
+
+    // Standard backward: upcast to FP32
+    auto q = query.contiguous().to(at::kFloat);
+    auto k = key.contiguous().to(at::kFloat);
+    auto v = value.contiguous().to(at::kFloat);
+    auto o = output.contiguous().to(at::kFloat);
+    auto dO = grad_output.contiguous().to(at::kFloat);
+    auto lse_tensor = logsumexp.contiguous();
+    auto bias = attn_bias.contiguous().to(at::kFloat);
+
+    auto D = at::empty({batch_size, num_heads, seq_len_q}, query.options().dtype(at::kFloat));
+
+    // Get kernel with bias support
+    void* kernel = get_or_create_kernel(
+        seq_len_q, seq_len_kv, head_dim,
+        false, false, is_causal, false, false,
+        static_cast<uint32_t>(window_size > 0 ? window_size : 0),
+        0, false, true,
+        bias_batch_stride, bias_head_stride,
+        static_cast<uint32_t>(bias_repeat_count > 0 ? bias_repeat_count : 0)
+    );
+
+    // Allocate gradients
+    auto dQ = at::zeros({batch_size, num_heads, seq_len_q, head_dim}, query.options().dtype(at::kFloat));
+    auto dK = at::zeros({batch_size, num_heads, seq_len_kv, head_dim}, query.options().dtype(at::kFloat));
+    auto dV = at::zeros({batch_size, num_heads, seq_len_kv, head_dim}, query.options().dtype(at::kFloat));
+
+    // Get Metal buffers
+    auto q_info = getBufferInfo(q);
+    auto k_info = getBufferInfo(k);
+    auto v_info = getBufferInfo(v);
+    auto o_info = getBufferInfo(o);
+    auto do_info = getBufferInfo(dO);
+    auto l_info = getBufferInfo(lse_tensor);
+    auto d_info = getBufferInfo(D);
+    auto dq_info = getBufferInfo(dQ);
+    auto dk_info = getBufferInfo(dK);
+    auto dv_info = getBufferInfo(dV);
+    auto bias_info = getBufferInfo(bias);
+
+    @autoreleasepool {
+        auto stream = at::mps::getCurrentMPSStream();
+        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+        bool success = g_mfa_backward_encode_bias(
+            kernel,
+            (__bridge void*)encoder,
+            (__bridge void*)q_info.buffer,
+            (__bridge void*)k_info.buffer,
+            (__bridge void*)v_info.buffer,
+            (__bridge void*)o_info.buffer,
+            (__bridge void*)do_info.buffer,
+            (__bridge void*)l_info.buffer,
+            (__bridge void*)d_info.buffer,
+            (__bridge void*)dq_info.buffer,
+            (__bridge void*)dk_info.buffer,
+            (__bridge void*)dv_info.buffer,
+            nullptr,  // no mask
+            (__bridge void*)bias_info.buffer,
+            q_info.byte_offset,
+            k_info.byte_offset,
+            v_info.byte_offset,
+            o_info.byte_offset,
+            do_info.byte_offset,
+            l_info.byte_offset,
+            d_info.byte_offset,
+            dq_info.byte_offset,
+            dk_info.byte_offset,
+            dv_info.byte_offset,
+            0,  // mask_offset
+            bias_info.byte_offset,
+            static_cast<int32_t>(batch_size),
+            static_cast<int32_t>(num_heads)
+        );
+
+        if (!success) {
+            throw std::runtime_error("MFA backward with bias failed");
+        }
+    }
+
+    // Convert gradients back to input dtype
+    if (low_precision) {
+        dQ = dQ.to(query.scalar_type());
+        dK = dK.to(query.scalar_type());
+        dV = dV.to(query.scalar_type());
+    }
+
+    return std::make_tuple(dQ, dK, dV);
+}
+
 // ============================================================================
 // Python Bindings
 // ============================================================================
@@ -1200,10 +1496,35 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 
     // Forward with additive attention bias (e.g., relative position encoding)
     m.def("forward_with_bias", &mps_flash_attention_forward_with_bias,
-          "Flash Attention forward with additive attention bias",
+          "Flash Attention forward with additive attention bias (returns output only)",
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("attn_bias"),
+          py::arg("is_causal") = false,
+          py::arg("window_size") = 0,
+          py::arg("bias_repeat_count") = 0);
+
+    // Forward with bias returning logsumexp (for backward)
+    m.def("forward_with_bias_lse", &mps_flash_attention_forward_with_bias_lse,
+          "Flash Attention forward with bias (returns output and logsumexp)",
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("attn_bias"),
+          py::arg("is_causal") = false,
+          py::arg("window_size") = 0,
+          py::arg("bias_repeat_count") = 0);
+
+    // Backward with bias
+    m.def("backward_with_bias", &mps_flash_attention_backward_with_bias,
+          "Flash Attention backward with bias",
+          py::arg("grad_output"),
           py::arg("query"),
           py::arg("key"),
           py::arg("value"),
+          py::arg("output"),
+          py::arg("logsumexp"),
          py::arg("attn_bias"),
          py::arg("is_causal") = false,
          py::arg("window_size") = 0,
Binary file
mps_flash_attn-0.3.7.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.2.6
+Version: 0.3.7
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
mps_flash_attn-0.3.7.dist-info/RECORD
@@ -1,7 +1,7 @@
-mps_flash_attn/_C.cpython-314-darwin.so,sha256=-KfSzahwyV6JcqoT9TstAFmTQiFwS8Bl6NhiphCUzdA,313160
-mps_flash_attn/__init__.py,sha256=tNSb4nu1MhLqstuFAEA8ezWYKuhpwDya8TVxGmA9VMw,39711
+mps_flash_attn/_C.cpython-314-darwin.so,sha256=MzbWTzTvd0siau5-Vk7TaovGYq3nBursRoRxnMzVpxo,335544
+mps_flash_attn/__init__.py,sha256=tocFoOTiCauMObdVdvaD_f1sV2ioUW2wZr8vV5RZIaY,51807
 mps_flash_attn/benchmark.py,sha256=qHhvb8Dmh07OEa_iXuPuJSEnRJlrjVF5nKzVwbWypWE,24141
-mps_flash_attn/csrc/mps_flash_attn.mm,sha256=d7Bjcm2VOTNANmdqUevN-mqa5aOEVMMxAuTYINAeSr0,51215
+mps_flash_attn/csrc/mps_flash_attn.mm,sha256=4RVXZZHceOXPX_XxKVlHjjedtXMOsEFLkQDIgwJBwtQ,63498
 mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib,sha256=_oig6f2I6ZxBCKWbJF3ofmZMySm8gB399_M-lD2NOfM,13747
 mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib,sha256=1fsmVvB5EubhN-y6s5CB-eVk_wuO2tfrabiQTwXvJJc,13171
 mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib,sha256=5WKo_yAU-PgmulBUQhnzvt0DZRteVmo4-nc4U-T6G2g,17507
@@ -25,9 +25,9 @@ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e
 mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin,sha256=4BjRprnMycnhZql9829R6FS3HW30jejuDJM9p9vzVPs,34112
 mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib,sha256=qyOaQtRVwL_Wc6GGdu6z-ftf0iX84XexuY09-lNLl5o,13747
 mps_flash_attn/kernels/manifest.json,sha256=d5MkE_BjqDQuMNm1jZiwWkQKfB-yfFml3lLSeR-wCLo,1867
-mps_flash_attn/lib/libMFABridge.dylib,sha256=iKgfYISSKMSNt_iXnljjUr_hZZHyCAg2tdS3_ZjmLkc,605696
-mps_flash_attn-0.2.6.dist-info/licenses/LICENSE,sha256=F_XmXSab2O-hHcqLpYJWeFaqB6GA_qiTEN23p2VfZWU,1237
-mps_flash_attn-0.2.6.dist-info/METADATA,sha256=uxBPVD-lDaQrg9cAKjtS65ze_NiyHi37x5_LsABU9Cc,5834
-mps_flash_attn-0.2.6.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
-mps_flash_attn-0.2.6.dist-info/top_level.txt,sha256=zbArDcWhJDnJfMUKnOUhs5TjsMgSxa2GzOlscTRfobE,15
-mps_flash_attn-0.2.6.dist-info/RECORD,,
+mps_flash_attn/lib/libMFABridge.dylib,sha256=k9a015iToIw2DcEaDmMAxPdb4FpW4sJ9IqBS4jJms0g,608768
+mps_flash_attn-0.3.7.dist-info/licenses/LICENSE,sha256=F_XmXSab2O-hHcqLpYJWeFaqB6GA_qiTEN23p2VfZWU,1237
+mps_flash_attn-0.3.7.dist-info/METADATA,sha256=chjYcevEKiWpsZJFSKkZ5xm5k5oD9_LqUAy_23JyIq4,5834
+mps_flash_attn-0.3.7.dist-info/WHEEL,sha256=uAzMRtb2noxPlbYLbRgeD25pPgKOo3k59IS71Dg5Qjs,110
+mps_flash_attn-0.3.7.dist-info/top_level.txt,sha256=zbArDcWhJDnJfMUKnOUhs5TjsMgSxa2GzOlscTRfobE,15
+mps_flash_attn-0.3.7.dist-info/RECORD,,