mps-flash-attn 0.2.7__tar.gz → 0.3.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/PKG-INFO +1 -1
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/__init__.py +364 -54
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/csrc/mps_flash_attn.mm +356 -35
- mps_flash_attn-0.3.7/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn.egg-info/SOURCES.txt +1 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/pyproject.toml +1 -1
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/setup.py +1 -1
- mps_flash_attn-0.3.7/tests/test_issues.py +446 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/tests/test_mfa_v2.py +296 -0
- mps_flash_attn-0.2.7/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/LICENSE +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/README.md +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/setup.cfg +0 -0
{mps_flash_attn-0.2.7 → mps_flash_attn-0.3.7}/mps_flash_attn/__init__.py

@@ -4,13 +4,43 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.2.7"
+__version__ = "0.3.7"
+
+__all__ = [
+    # Core functions
+    "flash_attention",
+    "flash_attention_with_bias",
+    "flash_attention_chunked",
+    # Quantized attention
+    "flash_attention_fp8",
+    "flash_attention_int8",
+    "flash_attention_nf4",
+    "quantize_kv_fp8",
+    "quantize_kv_int8",
+    "quantize_kv_nf4",
+    # Utilities
+    "replace_sdpa",
+    "precompile",
+    "clear_cache",
+    "register_custom_op",
+    "is_available",
+    "convert_mask",
+    # Constants
+    "QUANT_FP8_E4M3",
+    "QUANT_FP8_E5M2",
+    "QUANT_INT8",
+    "QUANT_NF4",
+    # Version
+    "__version__",
+]
 
 import torch
-
+import torch.nn.functional as F
+from typing import Optional, Tuple
 import math
 import threading
 import os
+import warnings
 
 # Try to import the C++ extension
 try:
@@ -30,6 +60,20 @@ def is_available() -> bool:
     return _HAS_MFA and torch.backends.mps.is_available()
 
 
+def _ensure_contiguous(tensor: torch.Tensor, name: str) -> torch.Tensor:
+    """Ensure tensor is contiguous, with a debug warning if conversion needed."""
+    if tensor.is_contiguous():
+        return tensor
+    # Auto-convert with debug info
+    if os.environ.get("MFA_DEBUG", "0") == "1":
+        warnings.warn(
+            f"MFA: {name} tensor was not contiguous (stride={tensor.stride()}), "
+            f"auto-converting. For best performance, ensure inputs are contiguous.",
+            UserWarning
+        )
+    return tensor.contiguous()
+
+
 def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
     """
     Convert attention mask to MFA's boolean format.
@@ -54,6 +98,98 @@ def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
         return attn_mask <= -1e3
 
 
+def _validate_and_expand_mask(
+    attn_mask: Optional[torch.Tensor],
+    B: int,
+    H: int,
+    N_q: int,
+    N_kv: int,
+) -> Optional[torch.Tensor]:
+    """
+    Validate attention mask shape and expand broadcast dimensions.
+
+    Args:
+        attn_mask: Optional mask of shape (B, H, N_q, N_kv) or broadcastable
+        B: Batch size
+        H: Number of heads
+        N_q: Query sequence length
+        N_kv: Key/Value sequence length
+
+    Returns:
+        Expanded mask of shape (mb, mh, N_q, N_kv) or None
+    """
+    if attn_mask is None:
+        return None
+
+    attn_mask = _ensure_contiguous(attn_mask, "attn_mask")
+
+    if attn_mask.dim() != 4:
+        raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
+
+    mb, mh, mq, mk = attn_mask.shape
+
+    # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
+    if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
+        raise ValueError(
+            f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
+        )
+
+    # Expand broadcast mask to full shape for Metal kernel
+    if mq == 1 and N_q > 1:
+        attn_mask = attn_mask.expand(mb, mh, N_q, mk)
+    if mk == 1 and N_kv > 1:
+        attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
+
+    if mb != 1 and mb != B:
+        raise ValueError(f"attn_mask batch size must be 1 or {B}, got {mb}")
+    if mh != 1 and mh != H:
+        raise ValueError(f"attn_mask head count must be 1 or {H}, got {mh}")
+
+    return attn_mask
+
+
+class FlashAttentionWithBiasFunction(torch.autograd.Function):
+    """Autograd function for Flash Attention with bias - native C++ backward."""
+
+    @staticmethod
+    def forward(ctx, query, key, value, attn_bias, is_causal, scale, window_size, bias_repeat_count):
+        # Apply scale if provided
+        scale_factor = 1.0
+        if scale is not None:
+            default_scale = 1.0 / math.sqrt(query.shape[-1])
+            if abs(scale - default_scale) > 1e-6:
+                scale_factor = scale / default_scale
+                query = query * scale_factor
+
+        # Call C++ forward with bias (returns output and logsumexp)
+        output, logsumexp = _C.forward_with_bias_lse(query, key, value, attn_bias, is_causal, window_size, bias_repeat_count)
+
+        # Save for backward
+        ctx.save_for_backward(query, key, value, output, logsumexp, attn_bias)
+        ctx.is_causal = is_causal
+        ctx.scale_factor = scale_factor
+        ctx.window_size = window_size
+        ctx.bias_repeat_count = bias_repeat_count
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        query, key, value, output, logsumexp, attn_bias = ctx.saved_tensors
+
+        # Call native C++ backward with bias
+        dQ, dK, dV = _C.backward_with_bias(
+            grad_output, query, key, value, output, logsumexp, attn_bias,
+            ctx.is_causal, ctx.window_size, ctx.bias_repeat_count
+        )
+
+        # Scale dQ back if we scaled query
+        if ctx.scale_factor != 1.0:
+            dQ = dQ * ctx.scale_factor
+
+        return dQ, dK, dV, None, None, None, None, None
+
+
 class FlashAttentionFunction(torch.autograd.Function):
     """Autograd function for Flash Attention with backward pass support."""
 
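Note on the new `_validate_and_expand_mask` helper above: masks may be passed with singleton query/key dimensions and are expanded to the full shape before they reach the Metal kernel, while singleton batch/head dimensions are only validated. A minimal illustrative sketch in plain PyTorch on CPU (not calling the extension), with made-up shapes:

```python
import torch

B, H, N_q, N_kv = 2, 8, 128, 128

# A per-key padding mask, broadcast over batch, heads, and query positions.
pad_mask = torch.zeros(1, 1, 1, N_kv, dtype=torch.bool)
pad_mask[..., N_kv // 2:] = True  # True = masked in MFA's convention

# The helper expands the singleton query dimension so the kernel sees one row per query.
expanded = pad_mask.expand(1, 1, N_q, N_kv)
assert expanded.shape == (1, 1, N_q, N_kv)

# A 3D mask, or a 4D mask whose q/k sizes are neither 1 nor (N_q, N_kv), raises ValueError.
```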
@@ -176,6 +312,20 @@ def flash_attention(
     if not torch.backends.mps.is_available():
         raise RuntimeError("MPS not available")
 
+    # Validate scale parameter
+    if scale is not None:
+        if scale <= 0:
+            raise ValueError(f"scale must be positive, got {scale}")
+        # Warn about extreme scale values that could cause numerical issues
+        default_scale = 1.0 / math.sqrt(query.shape[-1])
+        if scale < default_scale * 0.01 or scale > default_scale * 100:
+            warnings.warn(
+                f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
+                "this may cause numerical issues",
+                UserWarning,
+                stacklevel=2
+            )
+
     # Validate device
     if query.device.type != 'mps':
         raise ValueError("query must be on MPS device")
@@ -186,6 +336,24 @@ def flash_attention(
     if attn_mask is not None and attn_mask.device.type != 'mps':
         raise ValueError("attn_mask must be on MPS device")
 
+    # Ensure contiguous (auto-convert with debug warning)
+    query = _ensure_contiguous(query, "query")
+    key = _ensure_contiguous(key, "key")
+    value = _ensure_contiguous(value, "value")
+
+    # Validate tensor dimensions
+    if query.dim() != 4:
+        raise RuntimeError(f"query must be 4D (B, H, N, D), got {query.dim()}D")
+    if key.dim() != 4:
+        raise RuntimeError(f"key must be 4D (B, H, N, D), got {key.dim()}D")
+    if value.dim() != 4:
+        raise RuntimeError(f"value must be 4D (B, H, N, D), got {value.dim()}D")
+
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
     if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
         # Apply scale if provided
@@ -213,6 +381,10 @@ def replace_sdpa():
     import torch.nn.functional as F
 
     original_sdpa = F.scaled_dot_product_attention
+    _debug = os.environ.get("MFA_DEBUG", "0") == "1"
+    _call_count = [0]  # mutable for closure
+    _fallback_count = [0]  # track fallbacks for warning
+    _last_fallback_error = [None]
 
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None, enable_gqa=False, **kwargs):
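For orientation, `replace_sdpa()` swaps `torch.nn.functional.scaled_dot_product_attention` for the wrapper built here; the routing logic itself is in the next hunk. A hedged usage sketch, assuming an Apple Silicon machine with the package installed and PyTorch MPS available:

```python
import torch
import torch.nn.functional as F
import mps_flash_attn

# Patch F.scaled_dot_product_attention in place; existing model code stays unchanged.
mps_flash_attn.replace_sdpa()

q = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)

# MPS tensors, no dropout, seq_len >= 512: routed to the Metal Flash Attention path.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Set MFA_DEBUG=1 in the environment to print which calls use MFA and why others fall back.
```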
@@ -224,28 +396,92 @@ def replace_sdpa():
         # seq=1024: 2.3-3.7x (MFA much faster)
         # seq=2048: 2.2-3.9x (MFA much faster)
         # seq=4096: 2.1-3.7x (MFA much faster)
+        # Determine seq_len based on tensor dimensionality
+        # 4D: (B, H, S, D) -> seq_len = shape[2]
+        # 3D: (B, S, D) -> seq_len = shape[1] (single-head attention, e.g., VAE)
+        is_3d = query.ndim == 3
+        seq_len = query.shape[1] if is_3d else query.shape[2]
+
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.
+            query.ndim >= 3 and
+            seq_len >= 512):
             try:
+                q, k, v = query, key, value
+
+                # Handle 3D tensors (B, S, D) - treat as single-head attention
+                # Unsqueeze to (B, 1, S, D) for MFA, squeeze back after
+                if is_3d:
+                    q = q.unsqueeze(1)  # (B, S, D) -> (B, 1, S, D)
+                    k = k.unsqueeze(1)
+                    v = v.unsqueeze(1)
+
+                # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
+                # Common in Llama 2/3, Mistral, Qwen, etc.
+                # NOTE: Always expand when heads mismatch, not just when enable_gqa=True
+                # Transformers may pass enable_gqa=True on MPS (torch>=2.5, no mask) even though
+                # MPS SDPA doesn't support native GQA - we handle it here
+                if q.shape[1] != k.shape[1]:
+                    # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
+                    n_rep = q.shape[1] // k.shape[1]
+                    k = k.repeat_interleave(n_rep, dim=1)
+                    v = v.repeat_interleave(n_rep, dim=1)
+
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                 # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                 mfa_mask = None
                 if attn_mask is not None:
+                    if _debug:
+                        print(f"[MFA MASK] dtype={attn_mask.dtype} shape={tuple(attn_mask.shape)} min={attn_mask.min().item():.2f} max={attn_mask.max().item():.2f}")
                     if attn_mask.dtype == torch.bool:
-                        #
-
+                        # PyTorch SDPA bool mask: True = ATTEND, False = MASKED
+                        # MFA bool mask: True = MASKED, False = ATTEND
+                        # They're opposite! Invert it.
+                        mfa_mask = ~attn_mask
                     else:
                         # Float mask: typically -inf for masked positions, 0 for unmasked
                         # Convert: positions with large negative values -> True (masked)
                         # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                         mfa_mask = attn_mask <= -1e3
-
-
-
-
+                    if _debug:
+                        print(f"[MFA MASK] converted: True(masked)={mfa_mask.sum().item()} False(attend)={(~mfa_mask).sum().item()}")
+
+                out = flash_attention(q, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+
+                # Squeeze back for 3D input
+                if is_3d:
+                    out = out.squeeze(1)  # (B, 1, S, D) -> (B, S, D)
+
+                if _debug:
+                    _call_count[0] += 1
+                    print(f"[MFA #{_call_count[0]}] shape={tuple(query.shape)} is_3d={is_3d} gqa={enable_gqa} mask={attn_mask is not None} causal={is_causal}")
+
+                return out
+            except Exception as e:
+                # Fall back to original on any error, but track it
+                _fallback_count[0] += 1
+                _last_fallback_error[0] = str(e)
+                if _debug:
+                    import traceback
+                    print(f"[MFA FALLBACK #{_fallback_count[0]}] shape={tuple(query.shape)}\n{traceback.format_exc()}")
+                # Warn user after repeated fallbacks (likely a real problem)
+                if _fallback_count[0] == 10:
+                    warnings.warn(
+                        f"MFA has fallen back to native SDPA {_fallback_count[0]} times. "
+                        f"Last error: {_last_fallback_error[0]}. "
+                        f"Set MFA_DEBUG=1 for details.",
+                        UserWarning
+                    )
+
+        if _debug and query.device.type == 'mps':
+            _call_count[0] += 1
+            reason = []
+            if dropout_p != 0.0: reason.append(f"dropout={dropout_p}")
+            if query.ndim < 3: reason.append(f"ndim={query.ndim}")
+            if seq_len < 512: reason.append(f"seq={seq_len}<512")
+            print(f"[NATIVE #{_call_count[0]}] shape={tuple(query.shape)} reason={','.join(reason) or 'unknown'}")
 
         return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)
 
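The mask handling above bridges two opposite conventions: PyTorch SDPA bool masks mean True = attend and its float masks are additive with -inf for blocked positions, whereas MFA expects True = masked. A small standalone sketch of the same conversions in plain PyTorch on CPU:

```python
import torch

# SDPA-style boolean mask: True = attend, False = masked.
sdpa_bool = torch.tensor([[True, True, False, False]])
mfa_from_bool = ~sdpa_bool           # MFA: True = masked
assert mfa_from_bool.tolist() == [[False, False, True, True]]

# SDPA-style additive float mask: 0 = attend, -inf (or very negative) = masked.
sdpa_float = torch.tensor([[0.0, 0.0, float("-inf"), -10000.0]])
mfa_from_float = sdpa_float <= -1e3  # threshold catches -inf, -10000, -1000, ...
assert mfa_from_float.tolist() == [[False, False, True, True]]
```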
@@ -302,6 +538,7 @@ def flash_attention_with_bias(
     value: torch.Tensor,
     attn_bias: torch.Tensor,
     is_causal: bool = False,
+    scale: Optional[float] = None,
     window_size: int = 0,
     bias_repeat_count: int = 0,
 ) -> torch.Tensor:
@@ -309,13 +546,15 @@ def flash_attention_with_bias(
     Compute scaled dot-product attention with additive attention bias.
 
     This function supports additive attention bias (like relative position encodings
-    or ALiBi) which is added to the attention scores
+    or ALiBi) which is added to the attention scores:
+
+        Attention(Q, K, V) = softmax((Q @ K.T + bias) * scale) @ V
 
-
+    IMPORTANT: The bias is added to UNSCALED scores, then the sum is scaled.
+    This differs from PyTorch SDPA which does: softmax((Q @ K.T) * scale + bias).
 
-
-
-    pre-scale it by multiplying by sqrt(head_dim).
+    To convert from SDPA-style bias to MFA-style:
+        bias_mfa = bias_sdpa * sqrt(head_dim)  # when using default scale
 
     Args:
         query: Query tensor of shape (B, H, N_q, D)
@@ -326,6 +565,7 @@ def flash_attention_with_bias(
             - (1, H, N_q, N_kv): Broadcast across batch
             - (H, N_q, N_kv): Broadcast across batch (3D)
         is_causal: If True, applies causal masking
+        scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
         window_size: Sliding window attention size (0 = full attention)
         bias_repeat_count: If > 0, the bias tensor repeats every N batches.
             Useful for window attention where multiple windows share the same
@@ -343,10 +583,13 @@ def flash_attention_with_bias(
         >>> v = torch.randn(4, 8, 64, 64, device='mps', dtype=torch.float16)
         >>> # Position bias: (1, num_heads, seq_len, seq_len)
         >>> bias = torch.randn(1, 8, 64, 64, device='mps', dtype=torch.float16)
-        >>> # Pre-scale bias since
+        >>> # Pre-scale bias since default scale is 1/sqrt(head_dim)
         >>> scaled_bias = bias * math.sqrt(64)  # sqrt(head_dim)
         >>> out = flash_attention_with_bias(q, k, v, scaled_bias)
 
+        >>> # With custom scale
+        >>> out = flash_attention_with_bias(q, k, v, bias, scale=0.1)
+
         >>> # Window attention with repeating bias pattern
         >>> n_windows = 16
         >>> q = torch.randn(n_windows * 4, 8, 49, 64, device='mps', dtype=torch.float16)
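The updated docstring pins down the bias convention: MFA adds the bias to the unscaled scores and then scales the sum, while PyTorch SDPA scales first and then adds. A standalone reference check of the stated conversion (`bias_mfa = bias_sdpa * sqrt(head_dim)` under the default scale), done in plain PyTorch on CPU rather than through the extension:

```python
import math
import torch

B, H, N, D = 1, 2, 16, 64
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)
bias_sdpa = torch.randn(B, H, N, N)

scale = 1.0 / math.sqrt(D)
scores = q @ k.transpose(-2, -1)

# PyTorch SDPA convention: scale the scores first, then add the bias.
ref = torch.softmax(scores * scale + bias_sdpa, dim=-1) @ v

# MFA convention: add a pre-scaled bias, then scale the sum.
bias_mfa = bias_sdpa * math.sqrt(D)
mfa_style = torch.softmax((scores + bias_mfa) * scale, dim=-1) @ v

assert torch.allclose(ref, mfa_style, atol=1e-5)
```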
@@ -363,6 +606,20 @@ def flash_attention_with_bias(
     if not torch.backends.mps.is_available():
         raise RuntimeError("MPS not available")
 
+    # Validate scale parameter
+    if scale is not None:
+        if scale <= 0:
+            raise ValueError(f"scale must be positive, got {scale}")
+        # Warn about extreme scale values
+        default_scale = 1.0 / math.sqrt(query.shape[-1])
+        if scale < default_scale * 0.01 or scale > default_scale * 100:
+            warnings.warn(
+                f"scale={scale:.6g} is very different from default {default_scale:.6g}, "
+                "this may cause numerical issues",
+                UserWarning,
+                stacklevel=2
+            )
+
     # Validate device
     if query.device.type != 'mps':
         raise ValueError("query must be on MPS device")
@@ -373,7 +630,10 @@ def flash_attention_with_bias(
     if attn_bias.device.type != 'mps':
         raise ValueError("attn_bias must be on MPS device")
 
-
+    # Use autograd Function for backward support
+    return FlashAttentionWithBiasFunction.apply(
+        query, key, value, attn_bias, is_causal, scale, window_size, bias_repeat_count
+    )
 
 
 def flash_attention_chunked(
@@ -452,13 +712,13 @@ def flash_attention_chunked(
         return _C.forward(query, key, value, is_causal, None, 0)
 
     # Initialize running statistics for online softmax
-    # m = running max, l = running sum of exp, acc = accumulated output
     device = query.device
     dtype = query.dtype
 
     # Use float32 for numerical stability of softmax statistics
-
-
+    # running_L: base-2 logsumexp of all attention scores seen so far (-inf means no data yet)
+    # output_acc: weighted combination of outputs (weights sum to 1 after each update)
+    running_L = torch.full((B, H, seq_len_q, 1), float('-inf'), device=device, dtype=torch.float32)
     output_acc = torch.zeros((B, H, seq_len_q, D), device=device, dtype=torch.float32)
 
     # Process K/V in chunks
@@ -479,51 +739,81 @@ def flash_attention_chunked(
         # - Partial chunk (up to q) if start_idx <= q < end_idx
         # - None of chunk if q < start_idx
 
-
+        if is_causal:
+            # Create explicit causal mask for this chunk
+            # Query positions: 0 to seq_len_q-1
+            # Key positions in chunk: start_idx to end_idx-1
+            chunk_len = end_idx - start_idx
 
-
-
-
+            # Build mask: mask[q, k_local] = True means DON'T attend
+            # We want to attend when global_k_pos <= q
+            # global_k_pos = start_idx + k_local
+            # So: attend when start_idx + k_local <= q
+            # mask = start_idx + k_local > q
 
-
-
-
+            q_pos = torch.arange(seq_len_q, device=device).view(1, 1, seq_len_q, 1)
+            k_pos = torch.arange(chunk_len, device=device).view(1, 1, 1, chunk_len) + start_idx
+            causal_mask = k_pos > q_pos  # True = masked (don't attend)
 
-
-
-
-        # Online softmax update:
-        # new_max = max(running_max, chunk_max)
-        # For flash attention, chunk_lse ≈ chunk_max + log(chunk_sum)
-        # We approximate chunk_max ≈ chunk_lse (valid when exp sum dominates)
+            # Expand to batch and heads
+            causal_mask = causal_mask.expand(B, H, seq_len_q, chunk_len)
 
-
-
-
-
+            # Call forward with explicit mask (is_causal=False since we handle it)
+            chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, causal_mask, 0)
+        else:
+            # Non-causal: just process the chunk directly
+            chunk_out, chunk_lse = _C.forward_with_lse(query, k_chunk, v_chunk, False, None, 0)
 
-        #
-        #
-
-        #
-
+        # chunk_L shape: (B, H, seq_len_q)
+        # The kernel returns L = m + log2(l) where:
+        #   m = max(scores * log2(e) / sqrt(D))
+        #   l = sum(exp2(scores * log2(e) / sqrt(D) - m))
+        # This is a base-2 logsumexp: L = log2(sum(exp2(scaled_scores)))
+        chunk_L = chunk_lse.unsqueeze(-1).float()  # (B, H, seq_len_q, 1)
 
-        #
-
-        correction_new = torch.exp(chunk_max - new_max)
+        # Convert chunk output to float32 for accumulation
+        chunk_out = chunk_out.float()
 
-        #
-        #
-        #
-
+        # Online softmax algorithm using base-2 representation
+        #
+        # Flash attention returns: chunk_out = softmax(scores) @ V
+        # The output is already normalized. For online combination:
+        #   new_L = log2(2^running_L + 2^chunk_L)
+        #         = max(running_L, chunk_L) + log2(2^(running_L - max) + 2^(chunk_L - max))
+        #
+        # The weights for combining outputs are:
+        #   old_weight = 2^(running_L - new_L)
+        #   new_weight = 2^(chunk_L - new_L)
+        # These weights sum to 1, so: output = old_weight * old_out + new_weight * new_out
+
+        # Compute new base-2 logsumexp
+        max_L = torch.maximum(running_L, chunk_L)
+
+        # Handle -inf case (no previous data)
+        # Use exp2 for base-2 (matches kernel's internal representation)
+        running_exp2 = torch.where(
+            running_L == float('-inf'),
+            torch.zeros_like(running_L),
+            torch.exp2(running_L - max_L)
+        )
+        chunk_exp2 = torch.exp2(chunk_L - max_L)
+        new_L = max_L + torch.log2(running_exp2 + chunk_exp2)
+
+        # Compute correction factors using base-2 exp
+        old_weight = torch.where(
+            running_L == float('-inf'),
+            torch.zeros_like(running_L),
+            torch.exp2(running_L - new_L)
+        )
+        new_weight = torch.exp2(chunk_L - new_L)
 
         # Update accumulator
-
-
-
+        # Update accumulator
+        output_acc = output_acc * old_weight + chunk_out * new_weight
+        running_L = new_L
 
-    #
-    output = output_acc
+    # No final normalization needed - weights already sum to 1
+    output = output_acc
 
     # Convert back to original dtype
     return output.to(dtype)
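The rewritten chunk loop replaces the earlier max/logsumexp approximation with an exact base-2 online-softmax combination whose weights sum to 1. A standalone numerical sketch of that combination rule in plain PyTorch on CPU; `chunk_forward` here only mimics what the kernel returns (per-chunk output plus base-2 logsumexp) and does not call the extension:

```python
import math
import torch

torch.manual_seed(0)
N_q, N_kv, D = 8, 16, 4
scale = 1.0 / math.sqrt(D)
q, k, v = torch.randn(N_q, D), torch.randn(N_kv, D), torch.randn(N_kv, D)

# Reference: ordinary softmax attention over all keys at once.
full = torch.softmax((q @ k.T) * scale, dim=-1) @ v

def chunk_forward(k_chunk, v_chunk):
    """Per-chunk normalized output plus base-2 logsumexp of the scaled scores."""
    s = (q @ k_chunk.T) * scale
    out = torch.softmax(s, dim=-1) @ v_chunk
    lse2 = torch.logsumexp(s, dim=-1, keepdim=True) / math.log(2.0)  # convert ln to log2
    return out, lse2

out1, L1 = chunk_forward(k[:8], v[:8])
out2, L2 = chunk_forward(k[8:], v[8:])

# Online combination in base 2: the weights 2^(L_i - new_L) sum to 1,
# so no final normalization pass is needed.
m = torch.maximum(L1, L2)
new_L = m + torch.log2(torch.exp2(L1 - m) + torch.exp2(L2 - m))
combined = torch.exp2(L1 - new_L) * out1 + torch.exp2(L2 - new_L) * out2

assert torch.allclose(combined, full, atol=1e-5)
```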
@@ -745,6 +1035,11 @@ def flash_attention_fp8(
             scale_factor = scale / default_scale
             query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     quant_type = QUANT_FP8_E5M2 if use_e5m2 else QUANT_FP8_E4M3
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
@@ -798,6 +1093,11 @@ def flash_attention_int8(
             scale_factor = scale / default_scale
             query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_INT8, is_causal, attn_mask, window_size
@@ -854,6 +1154,11 @@ def flash_attention_nf4(
             scale_factor = scale / default_scale
             query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_NF4, is_causal, attn_mask, window_size
@@ -908,6 +1213,11 @@ def flash_attention_quantized(
             scale_factor = scale / default_scale
             query = query * scale_factor
 
+    # Validate and expand broadcast mask
+    B, H, N_q, D = query.shape
+    N_kv = key.shape[2]
+    attn_mask = _validate_and_expand_mask(attn_mask, B, H, N_q, N_kv)
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         quant_type, is_causal, attn_mask, window_size