mps-flash-attn 0.2.7__tar.gz → 0.3.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/PKG-INFO +1 -1
  2. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/__init__.py +364 -54
  3. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/csrc/mps_flash_attn.mm +356 -35
  4. mps_flash_attn-0.3.7/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  5. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn.egg-info/PKG-INFO +1 -1
  6. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn.egg-info/SOURCES.txt +1 -0
  7. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/pyproject.toml +1 -1
  8. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/setup.py +1 -1
  9. mps_flash_attn-0.3.7/tests/test_issues.py +446 -0
  10. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/tests/test_mfa_v2.py +296 -0
  11. mps_flash_attn-0.2.7/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  12. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/LICENSE +0 -0
  13. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/README.md +0 -0
  14. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/benchmark.py +0 -0
  15. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  16. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  17. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  18. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  19. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  20. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  21. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  22. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  23. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  24. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  25. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  26. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  27. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  28. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  29. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  30. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  31. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  32. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  33. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  34. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  35. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  36. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  37. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/manifest.json +0 -0
  38. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  39. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn.egg-info/requires.txt +0 -0
  40. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn.egg-info/top_level.txt +0 -0
  41. {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.2.7
+ Version: 0.3.7
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
@@ -4,13 +4,43 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
  This package provides memory-efficient attention using Metal Flash Attention kernels.
  """

- __version__ = "0.2.7"
+ __version__ = "0.3.7"
+
+ __all__ = [
+     # Core functions
+     "flash_attention",
+     "flash_attention_with_bias",
+     "flash_attention_chunked",
+     # Quantized attention
+     "flash_attention_fp8",
+     "flash_attention_int8",
+     "flash_attention_nf4",
+     "quantize_kv_fp8",
+     "quantize_kv_int8",
+     "quantize_kv_nf4",
+     # Utilities
+     "replace_sdpa",
+     "precompile",
+     "clear_cache",
+     "register_custom_op",
+     "is_available",
+     "convert_mask",
+     # Constants
+     "QUANT_FP8_E4M3",
+     "QUANT_FP8_E5M2",
+     "QUANT_INT8",
+     "QUANT_NF4",
+     # Version
+     "__version__",
+ ]

  import torch
- from typing import Optional
+ import torch.nn.functional as F
+ from typing import Optional, Tuple
  import math
  import threading
  import os
+ import warnings

  # Try to import the C++ extension
  try:
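
Note: the new `__all__` above defines the public surface of the package. For orientation, a minimal usage sketch of the exported entry points (illustrative only, not part of the diff; assumes an Apple Silicon machine with a working MPS backend):

```python
# Minimal sketch of the public API exported by __all__ above (not from the package).
# Shapes follow the (B, H, N, D) layout that flash_attention requires later in this diff.
import torch
import mps_flash_attn as mfa

if mfa.is_available():
    q = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
    k = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
    v = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
    out = mfa.flash_attention(q, k, v, is_causal=True)   # (1, 8, 1024, 64)
```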
@@ -30,6 +60,20 @@ def is_available() -> bool:
      return _HAS_MFA and torch.backends.mps.is_available()


+ def _ensure_contiguous(tensor: torch.Tensor, name: str) -> torch.Tensor:
+     """Ensure tensor is contiguous, with a debug warning if conversion needed."""
+     if tensor.is_contiguous():
+         return tensor
+     # Auto-convert with debug info
+     if os.environ.get("MFA_DEBUG", "0") == "1":
+         warnings.warn(
+             f"MFA: {name} tensor was not contiguous (stride={tensor.stride()}), "
+             f"auto-converting. For best performance, ensure inputs are contiguous.",
+             UserWarning
+         )
+     return tensor.contiguous()
+
+
  def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
      """
      Convert attention mask to MFA's boolean format.
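
Note: `_ensure_contiguous` silently copies non-contiguous inputs and only warns when `MFA_DEBUG=1`. A small sketch of the situation it guards against (illustrative, device-agnostic):

```python
# A transposed view has permuted strides and is not contiguous; the helper above would
# copy it with .contiguous() and, under MFA_DEBUG=1, emit a UserWarning.
import torch

x = torch.randn(2, 8, 512, 64)        # in the package these would live on the MPS device
x_t = x.transpose(1, 2)               # (2, 512, 8, 64) view, not contiguous
assert not x_t.is_contiguous()
x_fixed = x_t.contiguous()            # what _ensure_contiguous does before calling the kernel
assert x_fixed.is_contiguous()
```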
@@ -54,6 +98,98 @@ def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
      return attn_mask <= -1e3


+ def _validate_and_expand_mask(
+     attn_mask: Optional[torch.Tensor],
+     B: int,
+     H: int,
+     N_q: int,
+     N_kv: int,
+ ) -> Optional[torch.Tensor]:
+     """
+     Validate attention mask shape and expand broadcast dimensions.
+
+     Args:
+         attn_mask: Optional mask of shape (B, H, N_q, N_kv) or broadcastable
+         B: Batch size
+         H: Number of heads
+         N_q: Query sequence length
+         N_kv: Key/Value sequence length
+
+     Returns:
+         Expanded mask of shape (mb, mh, N_q, N_kv) or None
+     """
+     if attn_mask is None:
+         return None
+
+     attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
+
+     if attn_mask.dim() != 4:
+         raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
+
+     mb, mh, mq, mk = attn_mask.shape
+
+     # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
+     if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
+         raise ValueError(
+             f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
+         )
+
+     # Expand broadcast mask to full shape for Metal kernel
+     if mq == 1 and N_q > 1:
+         attn_mask = attn_mask.expand(mb, mh, N_q, mk)
+     if mk == 1 and N_kv > 1:
+         attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
+
+     if mb != 1 and mb != B:
+         raise ValueError(f"attn_mask batch size must be 1 or {B}, got {mb}")
+     if mh != 1 and mh != H:
+         raise ValueError(f"attn_mask head count must be 1 or {H}, got {mh}")
+
+     return attn_mask
+
+
+ class FlashAttentionWithBiasFunction(torch.autograd.Function):
+     """Autograd function for Flash Attention with bias - native C++ backward."""
+
+     @staticmethod
+     def forward(ctx, query, key, value, attn_bias, is_causal, scale, window_size, bias_repeat_count):
+         # Apply scale if provided
+         scale_factor = 1.0
+         if scale is not None:
+             default_scale = 1.0 / math.sqrt(query.shape[-1])
+             if abs(scale - default_scale) > 1e-6:
+                 scale_factor = scale / default_scale
+                 query = query * scale_factor
+
+         # Call C++ forward with bias (returns output and logsumexp)
+         output, logsumexp = _C.forward_with_bias_lse(query, key, value, attn_bias, is_causal, window_size, bias_repeat_count)
+
+         # Save for backward
+         ctx.save_for_backward(query, key, value, output, logsumexp, attn_bias)
+         ctx.is_causal = is_causal
+         ctx.scale_factor = scale_factor
+         ctx.window_size = window_size
+         ctx.bias_repeat_count = bias_repeat_count
+
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         query, key, value, output, logsumexp, attn_bias = ctx.saved_tensors
+
+         # Call native C++ backward with bias
+         dQ, dK, dV = _C.backward_with_bias(
+             grad_output, query, key, value, output, logsumexp, attn_bias,
+             ctx.is_causal, ctx.window_size, ctx.bias_repeat_count
+         )
+
+         # Scale dQ back if we scaled query
+         if ctx.scale_factor != 1.0:
+             dQ = dQ * ctx.scale_factor
+
+         return dQ, dK, dV, None, None, None, None, None
+
+
  class FlashAttentionFunction(torch.autograd.Function):
      """Autograd function for Flash Attention with backward pass support."""

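Note: `_validate_and_expand_mask` accepts 4D masks with broadcastable batch, head, and query dimensions. A sketch of the common key-padding case (illustrative; shapes are made up):

```python
# Building a (B, 1, 1, N_kv) key-padding mask that the validator above accepts and expands.
# Convention: True = do not attend (MFA-style boolean mask).
import torch

B, H, N_q, N_kv = 2, 8, 1024, 1024
lengths = torch.tensor([1024, 700])                          # valid key tokens per batch row
key_pad = torch.arange(N_kv)[None, :] >= lengths[:, None]    # (B, N_kv), True past the end
attn_mask = key_pad[:, None, None, :]                        # (B, 1, 1, N_kv)
# Inside flash_attention this is expanded along the query axis to (B, 1, N_q, N_kv)
# before being handed to the Metal kernel.
expanded = attn_mask.expand(B, 1, N_q, N_kv)
```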
@@ -176,6 +312,20 @@ def flash_attention(
      if not torch.backends.mps.is_available():
          raise RuntimeError("MPS not available")

+     # Validate scale parameter
+     if scale is not None:
+         if scale <= 0:
+             raise ValueError(f"scale must be positive, got {scale}")
+         # Warn about extreme scale values that could cause numerical issues
+         default_scale = 1.0 / math.sqrt(query.shape[-1])
+         if scale < default_scale * 0.01 or scale > default_scale * 100:
+             warnings.warn(
+                 f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
+                 "this may cause numerical issues",
+                 UserWarning,
+                 stacklevel=2
+             )
+
      # Validate device
      if query.device.type != 'mps':
          raise ValueError("query must be on MPS device")
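
Note: throughout this release a custom `scale` is folded into the query rather than passed to the kernel, which always applies the default `1/sqrt(D)`. A quick check of why that is equivalent (sketch, plain PyTorch, not from the package):

```python
# (Q * s/d) @ K^T * d == (Q @ K^T) * s, where d = 1/sqrt(D) is the kernel's built-in scale.
import math
import torch

D = 64
d = 1.0 / math.sqrt(D)
s = 0.05                                   # caller-supplied custom scale
q, k = torch.randn(4, D), torch.randn(4, D)

scores_requested = (q @ k.T) * s
scores_folded = ((q * (s / d)) @ k.T) * d  # what the scale_factor trick computes
assert torch.allclose(scores_requested, scores_folded, atol=1e-5)
```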
@@ -186,6 +336,24 @@ def flash_attention(
      if attn_mask is not None and attn_mask.device.type != 'mps':
          raise ValueError("attn_mask must be on MPS device")

+     # Ensure contiguous (auto-convert with debug warning)
+     query = _ensure_contiguous(query, "query")
+     key = _ensure_contiguous(key, "key")
+     value = _ensure_contiguous(value, "value")
+
+     # Validate tensor dimensions
+     if query.dim() != 4:
+         raise RuntimeError(f"query must be 4D (B, H, N, D), got {query.dim()}D")
+     if key.dim() != 4:
+         raise RuntimeError(f"key must be 4D (B, H, N, D), got {key.dim()}D")
+     if value.dim() != 4:
+         raise RuntimeError(f"value must be 4D (B, H, N, D), got {value.dim()}D")
+
+     # Validate and expand broadcast mask
+     B, H, N_q, D = query.shape
+     N_kv = key.shape[2]
+     attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
      # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
      if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
          # Apply scale if provided
@@ -213,6 +381,10 @@ def replace_sdpa():
      import torch.nn.functional as F

      original_sdpa = F.scaled_dot_product_attention
+     _debug = os.environ.get("MFA_DEBUG", "0") == "1"
+     _call_count = [0] # mutable for closure
+     _fallback_count = [0] # track fallbacks for warning
+     _last_fallback_error = [None]

      def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                       is_causal=False, scale=None, enable_gqa=False, **kwargs):
@@ -224,28 +396,92 @@ def replace_sdpa():
          # seq=1024: 2.3-3.7x (MFA much faster)
          # seq=2048: 2.2-3.9x (MFA much faster)
          # seq=4096: 2.1-3.7x (MFA much faster)
+         # Determine seq_len based on tensor dimensionality
+         # 4D: (B, H, S, D) -> seq_len = shape[2]
+         # 3D: (B, S, D) -> seq_len = shape[1] (single-head attention, e.g., VAE)
+         is_3d = query.ndim == 3
+         seq_len = query.shape[1] if is_3d else query.shape[2]
+
          if (query.device.type == 'mps' and
              dropout_p == 0.0 and
              _HAS_MFA and
-             query.shape[2] >= 512):
+             query.ndim >= 3 and
+             seq_len >= 512):
              try:
+                 q, k, v = query, key, value
+
+                 # Handle 3D tensors (B, S, D) - treat as single-head attention
+                 # Unsqueeze to (B, 1, S, D) for MFA, squeeze back after
+                 if is_3d:
+                     q = q.unsqueeze(1) # (B, S, D) -> (B, 1, S, D)
+                     k = k.unsqueeze(1)
+                     v = v.unsqueeze(1)
+
+                 # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
+                 # Common in Llama 2/3, Mistral, Qwen, etc.
+                 # NOTE: Always expand when heads mismatch, not just when enable_gqa=True
+                 # Transformers may pass enable_gqa=True on MPS (torch>=2.5, no mask) even though
+                 # MPS SDPA doesn't support native GQA - we handle it here
+                 if q.shape[1] != k.shape[1]:
+                     # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
+                     n_rep = q.shape[1] // k.shape[1]
+                     k = k.repeat_interleave(n_rep, dim=1)
+                     v = v.repeat_interleave(n_rep, dim=1)
+
                  # Convert float mask to bool mask if needed
                  # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                  # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                  mfa_mask = None
                  if attn_mask is not None:
+                     if _debug:
+                         print(f"[MFA MASK] dtype={attn_mask.dtype} shape={tuple(attn_mask.shape)} min={attn_mask.min().item():.2f} max={attn_mask.max().item():.2f}")
                      if attn_mask.dtype == torch.bool:
-                         # Boolean mask: True means masked (don't attend)
-                         mfa_mask = attn_mask
+                         # PyTorch SDPA bool mask: True = ATTEND, False = MASKED
+                         # MFA bool mask: True = MASKED, False = ATTEND
+                         # They're opposite! Invert it.
+                         mfa_mask = ~attn_mask
                      else:
                          # Float mask: typically -inf for masked positions, 0 for unmasked
                          # Convert: positions with large negative values -> True (masked)
                          # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                          mfa_mask = attn_mask <= -1e3
-                 return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
-             except Exception:
-                 # Fall back to original on any error
-                 pass
+                     if _debug:
+                         print(f"[MFA MASK] converted: True(masked)={mfa_mask.sum().item()} False(attend)={(~mfa_mask).sum().item()}")
+
+                 out = flash_attention(q, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+
+                 # Squeeze back for 3D input
+                 if is_3d:
+                     out = out.squeeze(1) # (B, 1, S, D) -> (B, S, D)
+
+                 if _debug:
+                     _call_count[0] += 1
+                     print(f"[MFA #{_call_count[0]}] shape={tuple(query.shape)} is_3d={is_3d} gqa={enable_gqa} mask={attn_mask is not None} causal={is_causal}")
+
+                 return out
+             except Exception as e:
+                 # Fall back to original on any error, but track it
+                 _fallback_count[0] += 1
+                 _last_fallback_error[0] = str(e)
+                 if _debug:
+                     import traceback
+                     print(f"[MFA FALLBACK #{_fallback_count[0]}] shape={tuple(query.shape)}\n{traceback.format_exc()}")
+                 # Warn user after repeated fallbacks (likely a real problem)
+                 if _fallback_count[0] == 10:
+                     warnings.warn(
+                         f"MFA has fallen back to native SDPA {_fallback_count[0]} times. "
+                         f"Last error: {_last_fallback_error[0]}. "
+                         f"Set MFA_DEBUG=1 for details.",
+                         UserWarning
+                     )
+
+         if _debug and query.device.type == 'mps':
+             _call_count[0] += 1
+             reason = []
+             if dropout_p != 0.0: reason.append(f"dropout={dropout_p}")
+             if query.ndim < 3: reason.append(f"ndim={query.ndim}")
+             if seq_len < 512: reason.append(f"seq={seq_len}<512")
+             print(f"[NATIVE #{_call_count[0]}] shape={tuple(query.shape)} reason={','.join(reason) or 'unknown'}")

          return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)

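Note: the patched SDPA above performs two conversions before calling MFA: it flips the boolean mask convention and expands grouped K/V heads. A standalone sketch of both (illustrative; model shapes are made up, not from the package):

```python
import torch

# 1) Mask convention. PyTorch SDPA bool masks mean True = attend; MFA means True = masked.
sdpa_bool_mask = torch.tensor([[True, True, False]])     # attend, attend, masked (SDPA view)
mfa_mask = ~sdpa_bool_mask                                # -> False, False, True (MFA view)

# Additive float masks (0 = attend, large negative = masked) use the -1e3 threshold.
sdpa_float_mask = torch.tensor([[0.0, 0.0, float("-inf")]])
mfa_mask_from_float = sdpa_float_mask <= -1e3             # same boolean result as above

# 2) GQA. Expand K/V heads so they match the number of query heads before calling MFA.
q = torch.randn(1, 32, 2048, 128)                         # 32 query heads
k = torch.randn(1, 8, 2048, 128)                          # 8 KV heads (Llama-style grouping)
n_rep = q.shape[1] // k.shape[1]
k_expanded = k.repeat_interleave(n_rep, dim=1)            # (1, 32, 2048, 128)
```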
@@ -302,6 +538,7 @@ def flash_attention_with_bias(
      value: torch.Tensor,
      attn_bias: torch.Tensor,
      is_causal: bool = False,
+     scale: Optional[float] = None,
      window_size: int = 0,
      bias_repeat_count: int = 0,
  ) -> torch.Tensor:
@@ -309,13 +546,15 @@ def flash_attention_with_bias(
      Compute scaled dot-product attention with additive attention bias.

      This function supports additive attention bias (like relative position encodings
-     or ALiBi) which is added to the attention scores before softmax:
+     or ALiBi) which is added to the attention scores:
+
+         Attention(Q, K, V) = softmax((Q @ K.T + bias) * scale) @ V

-     Attention(Q, K, V) = softmax((Q @ K.T) / sqrt(d) + bias) @ V
+     IMPORTANT: The bias is added to UNSCALED scores, then the sum is scaled.
+     This differs from PyTorch SDPA which does: softmax((Q @ K.T) * scale + bias).

-     IMPORTANT: MFA adds bias to UNSCALED scores internally and scales during softmax.
-     If your bias was computed for scaled scores (like PyTorch SDPA), you need to
-     pre-scale it by multiplying by sqrt(head_dim).
+     To convert from SDPA-style bias to MFA-style:
+         bias_mfa = bias_sdpa * sqrt(head_dim) # when using default scale

      Args:
          query: Query tensor of shape (B, H, N_q, D)
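
Note: the docstring change above pins down the bias convention: MFA adds the bias to unscaled scores and then scales the sum, while SDPA scales first and adds the bias afterwards. A CPU-only equivalence check of the stated conversion (sketch, not from the package):

```python
# softmax((Q @ K^T) * d + bias_sdpa) == softmax((Q @ K^T + bias_sdpa * sqrt(D)) * d), d = 1/sqrt(D)
import math
import torch

B, H, N, D = 1, 2, 16, 64
d = 1.0 / math.sqrt(D)
q, k = torch.randn(B, H, N, D), torch.randn(B, H, N, D)
bias_sdpa = torch.randn(B, H, N, N)

scores_sdpa = (q @ k.transpose(-1, -2)) * d + bias_sdpa
scores_mfa = ((q @ k.transpose(-1, -2)) + bias_sdpa * math.sqrt(D)) * d
assert torch.allclose(scores_sdpa, scores_mfa, atol=1e-5)
```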
@@ -326,6 +565,7 @@ def flash_attention_with_bias(
              - (1, H, N_q, N_kv): Broadcast across batch
              - (H, N_q, N_kv): Broadcast across batch (3D)
          is_causal: If True, applies causal masking
+         scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
          window_size: Sliding window attention size (0 = full attention)
          bias_repeat_count: If > 0, the bias tensor repeats every N batches.
              Useful for window attention where multiple windows share the same
@@ -343,10 +583,13 @@ def flash_attention_with_bias(
          >>> v = torch.randn(4, 8, 64, 64, device='mps', dtype=torch.float16)
          >>> # Position bias: (1, num_heads, seq_len, seq_len)
          >>> bias = torch.randn(1, 8, 64, 64, device='mps', dtype=torch.float16)
-         >>> # Pre-scale bias since MFA uses unscaled scores
+         >>> # Pre-scale bias since default scale is 1/sqrt(head_dim)
          >>> scaled_bias = bias * math.sqrt(64) # sqrt(head_dim)
          >>> out = flash_attention_with_bias(q, k, v, scaled_bias)

+         >>> # With custom scale
+         >>> out = flash_attention_with_bias(q, k, v, bias, scale=0.1)
+
          >>> # Window attention with repeating bias pattern
          >>> n_windows = 16
          >>> q = torch.randn(n_windows * 4, 8, 49, 64, device='mps', dtype=torch.float16)
@@ -363,6 +606,20 @@ def flash_attention_with_bias(
      if not torch.backends.mps.is_available():
          raise RuntimeError("MPS not available")

+     # Validate scale parameter
+     if scale is not None:
+         if scale <= 0:
+             raise ValueError(f"scale must be positive, got {scale}")
+         # Warn about extreme scale values
+         default_scale = 1.0 / math.sqrt(query.shape[-1])
+         if scale < default_scale * 0.01 or scale > default_scale * 100:
+             warnings.warn(
+                 f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
+                 "this may cause numerical issues",
+                 UserWarning,
+                 stacklevel=2
+             )
+
      # Validate device
      if query.device.type != 'mps':
          raise ValueError("query must be on MPS device")
@@ -373,7 +630,10 @@ def flash_attention_with_bias(
      if attn_bias.device.type != 'mps':
          raise ValueError("attn_bias must be on MPS device")

-     return _C.forward_with_bias(query, key, value, attn_bias, is_causal, window_size, bias_repeat_count)
+     # Use autograd Function for backward support
+     return FlashAttentionWithBiasFunction.apply(
+         query, key, value, attn_bias, is_causal, scale, window_size, bias_repeat_count
+     )


  def flash_attention_chunked(
@@ -452,13 +712,13 @@ def flash_attention_chunked(
          return _C.forward(query, key, value, is_causal, None, 0)

      # Initialize running statistics for online softmax
-     # m = running max, l = running sum of exp, acc = accumulated output
      device = query.device
      dtype = query.dtype

      # Use float32 for numerical stability of softmax statistics
-     running_max = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
-     running_sum = torch.zeros((B, H, seq_len_q, 1), device=device, dtype=torch.float32)
+     # running_L: base-2 logsumexp of all attention scores seen so far (-inf means no data yet)
+     # output_acc: weighted combination of outputs (weights sum to 1 after each update)
+     running_L = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
      output_acc = torch.zeros((B, H, seq_len_q, D), device=device, dtype=torch.float32)

      # Process K/V in chunks
@@ -479,51 +739,81 @@ def flash_attention_chunked(
          # - Partial chunk (up to q) if start_idx <= q < end_idx
          # - None of chunk if q < start_idx

-         chunk_is_causal = is_causal and (end_idx <= seq_len_q)
+         if is_causal:
+             # Create explicit causal mask for this chunk
+             # Query positions: 0 to seq_len_q-1
+             # Key positions in chunk: start_idx to end_idx-1
+             chunk_len = end_idx - start_idx

-         # Compute attention for this chunk
-         # forward_with_lse returns (output, logsumexp) where logsumexp = m + log(l)
-         chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, chunk_is_causal, None, 0)
+             # Build mask: mask[q, k_local] = True means DON'T attend
+             # We want to attend when global_k_pos <= q
+             # global_k_pos = start_idx + k_local
+             # So: attend when start_idx + k_local <= q
+             # mask = start_idx + k_local > q

-         # chunk_lse shape: (B, H, seq_len_q)
-         # We need to convert logsumexp to (max, sum) for online algorithm
-         chunk_lse = chunk_lse.unsqueeze(-1) # (B, H, seq_len_q, 1)
+             q_pos = torch.arange(seq_len_q, device=device).view(1, 1, seq_len_q, 1)
+             k_pos = torch.arange(chunk_len, device=device).view(1, 1, 1, chunk_len) + start_idx
+             causal_mask = k_pos > q_pos # True = masked (don't attend)

-         # Convert chunk output to float32 for accumulation
-         chunk_out = chunk_out.float()
-
-         # Online softmax update:
-         # new_max = max(running_max, chunk_max)
-         # For flash attention, chunk_lse ≈ chunk_max + log(chunk_sum)
-         # We approximate chunk_max ≈ chunk_lse (valid when exp sum dominates)
+             # Expand to batch and heads
+             causal_mask = causal_mask.expand(B, H, seq_len_q, chunk_len)

-         chunk_max = chunk_lse # Approximation: logsumexp max when sum is dominated by max
-
-         # Compute new max
-         new_max = torch.maximum(running_max, chunk_max)
+             # Call forward with explicit mask (is_causal=False since we handle it)
+             chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, causal_mask, 0)
+         else:
+             # Non-causal: just process the chunk directly
+             chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, None, 0)

-         # Rescale previous accumulator
-         # correction_old = exp(running_max - new_max)
-         correction_old = torch.exp(running_max - new_max)
-         # Clip to avoid inf * 0 issues when running_max was -inf
-         correction_old = torch.where(running_max == float('-inf'), torch.zeros_like(correction_old), correction_old)
+         # chunk_L shape: (B, H, seq_len_q)
+         # The kernel returns L = m + log2(l) where:
+         # m = max(scores * log2(e) / sqrt(D))
+         # l = sum(exp2(scores * log2(e) / sqrt(D) - m))
+         # This is a base-2 logsumexp: L = log2(sum(exp2(scaled_scores)))
+         chunk_L = chunk_lse.unsqueeze(-1).float() # (B, H, seq_len_q, 1)

-         # Rescale chunk output
-         # correction_new = exp(chunk_max - new_max)
-         correction_new = torch.exp(chunk_max - new_max)
+         # Convert chunk output to float32 for accumulation
+         chunk_out = chunk_out.float()

-         # For the sum, we need exp(chunk_lse - new_max) = exp(chunk_max + log(chunk_sum) - new_max)
-         # = exp(chunk_max - new_max) * chunk_sum
-         # But we only have logsumexp, so: exp(chunk_lse - new_max)
-         chunk_sum_scaled = torch.exp(chunk_lse - new_max)
+         # Online softmax algorithm using base-2 representation
+         #
+         # Flash attention returns: chunk_out = softmax(scores) @ V
+         # The output is already normalized. For online combination:
+         # new_L = log2(2^running_L + 2^chunk_L)
+         # = max(running_L, chunk_L) + log2(2^(running_L - max) + 2^(chunk_L - max))
+         #
+         # The weights for combining outputs are:
+         # old_weight = 2^(running_L - new_L)
+         # new_weight = 2^(chunk_L - new_L)
+         # These weights sum to 1, so: output = old_weight * old_out + new_weight * new_out
+
+         # Compute new base-2 logsumexp
+         max_L = torch.maximum(running_L, chunk_L)
+
+         # Handle -inf case (no previous data)
+         # Use exp2 for base-2 (matches kernel's internal representation)
+         running_exp2 = torch.where(
+             running_L == float('-inf'),
+             torch.zeros_like(running_L),
+             torch.exp2(running_L - max_L)
+         )
+         chunk_exp2 = torch.exp2(chunk_L - max_L)
+         new_L = max_L + torch.log2(running_exp2 + chunk_exp2)
+
+         # Compute correction factors using base-2 exp
+         old_weight = torch.where(
+             running_L == float('-inf'),
+             torch.zeros_like(running_L),
+             torch.exp2(running_L - new_L)
+         )
+         new_weight = torch.exp2(chunk_L - new_L)

          # Update accumulator
-         output_acc = output_acc * correction_old + chunk_out * correction_new
-         running_sum = running_sum * correction_old + chunk_sum_scaled
-         running_max = new_max
+         # Update accumulator
+         output_acc = output_acc * old_weight + chunk_out * new_weight
+         running_L = new_L

-     # Final normalization
-     output = output_acc / running_sum
+     # No final normalization needed - weights already sum to 1
+     output = output_acc

      # Convert back to original dtype
      return output.to(dtype)
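
Note: the rewritten chunked path above merges per-chunk outputs that are already normalized, using base-2 logsumexp weights, so no final division is needed. A plain-PyTorch reference of the merge rule (sketch; `chunk_forward` is a stand-in for `_C.forward_with_lse`, not the real Metal kernel):

```python
import math
import torch

torch.manual_seed(0)
N_q, N_kv, D = 4, 32, 16
q, k, v = torch.randn(N_q, D), torch.randn(N_kv, D), torch.randn(N_kv, D)
scale = 1.0 / math.sqrt(D)
LN2 = math.log(2.0)

def chunk_forward(k_c, v_c):
    """Stand-in for _C.forward_with_lse: normalized chunk output plus base-2 logsumexp."""
    t = (q @ k_c.T) * scale * math.log2(math.e)                  # scores in base-2 units
    L = torch.logsumexp(t * LN2, dim=-1, keepdim=True) / LN2     # log2(sum(2^t))
    return torch.exp2(t - L) @ v_c, L                            # softmax over this chunk only

out1, L1 = chunk_forward(k[:16], v[:16])
out2, L2 = chunk_forward(k[16:], v[16:])

max_L = torch.maximum(L1, L2)
new_L = max_L + torch.log2(torch.exp2(L1 - max_L) + torch.exp2(L2 - max_L))
merged = out1 * torch.exp2(L1 - new_L) + out2 * torch.exp2(L2 - new_L)   # weights sum to 1

reference = torch.softmax((q @ k.T) * scale, dim=-1) @ v
assert torch.allclose(merged, reference, atol=1e-5)
```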
@@ -745,6 +1035,11 @@ def flash_attention_fp8(
              scale_factor = scale / default_scale
              query = query * scale_factor

+     # Validate and expand broadcast mask
+     B, H, N_q, D = query.shape
+     N_kv = key.shape[2]
+     attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
      quant_type = QUANT_FP8_E5M2 if use_e5m2 else QUANT_FP8_E4M3
      return _C.forward_quantized(
          query, key, value, k_scale, v_scale,
@@ -798,6 +1093,11 @@ def flash_attention_int8(
              scale_factor = scale / default_scale
              query = query * scale_factor

+     # Validate and expand broadcast mask
+     B, H, N_q, D = query.shape
+     N_kv = key.shape[2]
+     attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
      return _C.forward_quantized(
          query, key, value, k_scale, v_scale,
          QUANT_INT8, is_causal, attn_mask, window_size
@@ -854,6 +1154,11 @@ def flash_attention_nf4(
              scale_factor = scale / default_scale
              query = query * scale_factor

+     # Validate and expand broadcast mask
+     B, H, N_q, D = query.shape
+     N_kv = key.shape[2]
+     attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
      return _C.forward_quantized(
          query, key, value, k_scale, v_scale,
          QUANT_NF4, is_causal, attn_mask, window_size
@@ -908,6 +1213,11 @@ def flash_attention_quantized(
              scale_factor = scale / default_scale
              query = query * scale_factor

+     # Validate and expand broadcast mask
+     B, H, N_q, D = query.shape
+     N_kv = key.shape[2]
+     attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
      return _C.forward_quantized(
          query, key, value, k_scale, v_scale,
          quant_type, is_causal, attn_mask, window_size