mps-flash-attn 0.2.1__tar.gz → 0.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mps-flash-attn might be problematic.
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/PKG-INFO +8 -1
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/README.md +7 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/__init__.py +241 -1
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/csrc/mps_flash_attn.mm +246 -7
- mps_flash_attn-0.2.5/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/PKG-INFO +8 -1
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/pyproject.toml +1 -1
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/tests/test_mfa_v2.py +91 -0
- mps_flash_attn-0.2.1/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/LICENSE +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/setup.cfg +0 -0
- {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/setup.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.2.1
+Version: 0.2.5
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
@@ -201,6 +201,13 @@ Python API (mps_flash_attn)
 - Python 3.10+
 - PyTorch 2.0+
 
+## TODO / Future Optimizations
+
+- [ ] **Batched kernel dispatch** - Currently dispatches B×H separate kernels per attention call. Should use 3D grid to handle all batch/heads in one dispatch (major perf win for small sequences like Swin Transformer windows)
+- [ ] **Fused QKV projection + attention** - Single kernel from input to output, avoid intermediate buffers
+- [ ] **Pre-scaled bias option** - Allow passing pre-scaled bias to avoid per-call scaling overhead
+- [ ] **LoRA fusion** - Fuse adapter weights into attention computation
+
 ## Credits
 
 - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
README.md
@@ -176,6 +176,13 @@ Python API (mps_flash_attn)
 - Python 3.10+
 - PyTorch 2.0+
 
+## TODO / Future Optimizations
+
+- [ ] **Batched kernel dispatch** - Currently dispatches B×H separate kernels per attention call. Should use 3D grid to handle all batch/heads in one dispatch (major perf win for small sequences like Swin Transformer windows)
+- [ ] **Fused QKV projection + attention** - Single kernel from input to output, avoid intermediate buffers
+- [ ] **Pre-scaled bias option** - Allow passing pre-scaled bias to avoid per-call scaling overhead
+- [ ] **LoRA fusion** - Fuse adapter weights into attention computation
+
 ## Credits
 
 - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
mps_flash_attn/__init__.py
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.2.1"
+__version__ = "0.2.5"
 
 import torch
 from typing import Optional
@@ -296,6 +296,86 @@ def precompile():
     print("\nPre-compilation complete! Kernels cached to disk.")
 
 
+def flash_attention_with_bias(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: torch.Tensor,
+    is_causal: bool = False,
+    window_size: int = 0,
+    bias_repeat_count: int = 0,
+) -> torch.Tensor:
+    """
+    Compute scaled dot-product attention with additive attention bias.
+
+    This function supports additive attention bias (like relative position encodings
+    or ALiBi) which is added to the attention scores before softmax:
+
+        Attention(Q, K, V) = softmax((Q @ K.T) / sqrt(d) + bias) @ V
+
+    IMPORTANT: MFA adds bias to UNSCALED scores internally and scales during softmax.
+    If your bias was computed for scaled scores (like PyTorch SDPA), you need to
+    pre-scale it by multiplying by sqrt(head_dim).
+
+    Args:
+        query: Query tensor of shape (B, H, N_q, D)
+        key: Key tensor of shape (B, H, N_kv, D)
+        value: Value tensor of shape (B, H, N_kv, D)
+        attn_bias: Additive attention bias of shape:
+            - (B, H, N_q, N_kv): Full bias for each batch/head
+            - (1, H, N_q, N_kv): Broadcast across batch
+            - (H, N_q, N_kv): Broadcast across batch (3D)
+        is_causal: If True, applies causal masking
+        window_size: Sliding window attention size (0 = full attention)
+        bias_repeat_count: If > 0, the bias tensor repeats every N batches.
+            Useful for window attention where multiple windows share the same
+            position bias pattern. E.g., for Swin Transformer with 4 windows,
+            set bias_repeat_count=num_windows so bias[batch_idx % num_windows]
+            is used.
+
+    Returns:
+        Output tensor of shape (B, H, N_q, D)
+
+    Example:
+        >>> # Relative position bias (Swin Transformer style)
+        >>> q = torch.randn(4, 8, 64, 64, device='mps', dtype=torch.float16)
+        >>> k = torch.randn(4, 8, 64, 64, device='mps', dtype=torch.float16)
+        >>> v = torch.randn(4, 8, 64, 64, device='mps', dtype=torch.float16)
+        >>> # Position bias: (1, num_heads, seq_len, seq_len)
+        >>> bias = torch.randn(1, 8, 64, 64, device='mps', dtype=torch.float16)
+        >>> # Pre-scale bias since MFA uses unscaled scores
+        >>> scaled_bias = bias * math.sqrt(64)  # sqrt(head_dim)
+        >>> out = flash_attention_with_bias(q, k, v, scaled_bias)
+
+        >>> # Window attention with repeating bias pattern
+        >>> n_windows = 16
+        >>> q = torch.randn(n_windows * 4, 8, 49, 64, device='mps', dtype=torch.float16)
+        >>> bias = torch.randn(n_windows, 8, 49, 49, device='mps', dtype=torch.float16)
+        >>> scaled_bias = bias * math.sqrt(64)
+        >>> out = flash_attention_with_bias(q, k, v, scaled_bias, bias_repeat_count=n_windows)
+    """
+    if not _HAS_MFA:
+        raise RuntimeError(
+            f"MPS Flash Attention C++ extension not available: {_IMPORT_ERROR}\n"
+            "Please rebuild with: pip install -e ."
+        )
+
+    if not torch.backends.mps.is_available():
+        raise RuntimeError("MPS not available")
+
+    # Validate device
+    if query.device.type != 'mps':
+        raise ValueError("query must be on MPS device")
+    if key.device.type != 'mps':
+        raise ValueError("key must be on MPS device")
+    if value.device.type != 'mps':
+        raise ValueError("value must be on MPS device")
+    if attn_bias.device.type != 'mps':
+        raise ValueError("attn_bias must be on MPS device")
+
+    return _C.forward_with_bias(query, key, value, attn_bias, is_causal, window_size, bias_repeat_count)
+
+
 def flash_attention_chunked(
     query: torch.Tensor,
     key: torch.Tensor,
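The pre-scaling rule in the docstring above can be checked directly against PyTorch SDPA. A minimal sketch, assuming the built extension is importable as `mps_flash_attn`; shapes and the printed tolerance are illustrative, not part of the package:

```python
import math
import torch
import torch.nn.functional as F
import mps_flash_attn as mfa  # assumes the compiled extension is available

B, H, N, D = 2, 8, 64, 64
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
bias = torch.randn(1, H, N, N, device="mps", dtype=torch.float16)

# PyTorch SDPA adds a float attn_mask to the already-scaled scores ...
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# ... while MFA adds the bias to unscaled scores and divides by sqrt(D) in the
# softmax, so pre-scale the bias by sqrt(D) to get the same result.
out = mfa.flash_attention_with_bias(q, k, v, bias * math.sqrt(D))

print((ref - out).abs().max())  # expect small fp16-level differences
```

The identity behind it: softmax((QKᵀ + bias·√D) / √D) equals softmax(QKᵀ/√D + bias), which is exactly SDPA's convention.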
@@ -541,6 +621,80 @@ def quantize_kv_int8(
     return k_quant, v_quant, k_scale, v_scale
 
 
+# NF4 codebook (must match Metal shader's NF4_CODEBOOK exactly)
+_NF4_CODEBOOK = torch.tensor([
+    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
+    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
+    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
+    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
+], dtype=torch.float32)
+
+
+def quantize_kv_nf4(
+    key: torch.Tensor,
+    value: torch.Tensor,
+) -> tuple:
+    """
+    Quantize Key and Value tensors to NF4 (NormalFloat 4-bit) format.
+
+    NF4 quantization provides 4x memory reduction using a 16-value codebook
+    optimized for normally distributed weights. Two values are packed per byte.
+
+    Args:
+        key: Key tensor of shape (B, H, N, D) where D must be even
+        value: Value tensor of shape (B, H, N, D) where D must be even
+
+    Returns:
+        Tuple of (key_quant, value_quant, k_scale, v_scale) where:
+        - key_quant, value_quant: uint8 tensors of shape (B, H, N, D//2) with packed values
+        - k_scale, v_scale: float32 tensors with per-head scale factors (B, H)
+
+    Example:
+        >>> k_q, v_q, k_s, v_s = quantize_kv_nf4(key, value)
+        >>> out = flash_attention_nf4(query, k_q, v_q, k_s, v_s)
+    """
+    if not _HAS_MFA:
+        raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
+
+    def _quantize_nf4(tensor: torch.Tensor) -> tuple:
+        """Quantize a single tensor to NF4 format."""
+        B, H, N, D = tensor.shape
+        if D % 2 != 0:
+            raise ValueError(f"Head dimension D must be even for NF4 quantization, got D={D}")
+
+        # Convert to float32 for quantization
+        t = tensor.float()
+
+        # Compute per-head absmax for scale
+        abs_max = t.abs().amax(dim=(2, 3))  # (B, H)
+        scale = abs_max.clamp_min(1e-12)
+
+        # Normalize to [-1, 1] range
+        scale_expanded = scale.unsqueeze(-1).unsqueeze(-1)  # (B, H, 1, 1)
+        normalized = t / scale_expanded
+
+        # Find nearest NF4 codebook entry for each value
+        codebook = _NF4_CODEBOOK.to(tensor.device)  # (16,)
+        # Reshape for broadcasting: normalized is (B, H, N, D), codebook is (16,)
+        # Compute distances to all codebook entries
+        flat = normalized.reshape(-1, 1)  # (B*H*N*D, 1)
+        distances = (flat - codebook.unsqueeze(0)).abs()  # (B*H*N*D, 16)
+        indices = distances.argmin(dim=1)  # (B*H*N*D,)
+        indices = indices.reshape(B, H, N, D)  # (B, H, N, D)
+
+        # Pack two 4-bit indices per byte
+        # Even indices go to low nibble, odd indices go to high nibble
+        indices_even = indices[:, :, :, 0::2]  # (B, H, N, D//2)
+        indices_odd = indices[:, :, :, 1::2]   # (B, H, N, D//2)
+        packed = (indices_even | (indices_odd << 4)).to(torch.uint8)  # (B, H, N, D//2)
+
+        return packed, scale
+
+    k_quant, k_scale = _quantize_nf4(key)
+    v_quant, v_scale = _quantize_nf4(value)
+    return k_quant, v_quant, k_scale, v_scale
+
+
 def flash_attention_fp8(
     query: torch.Tensor,
     key: torch.Tensor,
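For reference, the packing above (low nibble = even position, high nibble = odd position, per-head absmax scale) can be inverted on the host. A minimal dequantization sketch that mirrors the codebook in the diff; `dequantize_nf4` is a hypothetical helper for illustration, not part of the package:

```python
import torch

# Same 16-entry codebook as in the diff (must match the Metal shader).
NF4_CODEBOOK = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
], dtype=torch.float32)

def dequantize_nf4(packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Invert quantize_kv_nf4 for one tensor: (B, H, N, D//2) uint8 -> (B, H, N, D) float32."""
    lo = (packed & 0x0F).long()   # even positions live in the low nibble
    hi = (packed >> 4).long()     # odd positions live in the high nibble
    B, H, N, Dh = packed.shape
    cb = NF4_CODEBOOK.to(packed.device)
    vals = torch.empty(B, H, N, Dh * 2, dtype=torch.float32, device=packed.device)
    vals[..., 0::2] = cb[lo]
    vals[..., 1::2] = cb[hi]
    # Undo the per-head absmax normalization.
    return vals * scale.unsqueeze(-1).unsqueeze(-1)
```

A quick round-trip check is `(k.float() - dequantize_nf4(k_q, k_s)).abs().max()`, which should stay small relative to `k_s` since the error is bounded by the scale times half the widest codebook gap.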
@@ -551,6 +705,7 @@ def flash_attention_fp8(
     attn_mask: Optional[torch.Tensor] = None,
     window_size: int = 0,
     use_e5m2: bool = False,
+    scale: Optional[float] = None,
 ) -> torch.Tensor:
     """
     Compute attention with FP8 quantized Key/Value tensors.
@@ -568,6 +723,7 @@ def flash_attention_fp8(
         attn_mask: Optional boolean attention mask
         window_size: Sliding window size (0 = full attention)
         use_e5m2: If True, use E5M2 format. Default: False (E4M3)
+        scale: Softmax scale factor. If None, uses 1/sqrt(head_dim)
 
     Returns:
         Output tensor of shape (B, H, N, D)
@@ -581,6 +737,14 @@ def flash_attention_fp8(
     if not _HAS_MFA:
         raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
 
+    # Apply custom scale by pre-scaling Q
+    if scale is not None:
+        head_dim = query.shape[-1]
+        default_scale = 1.0 / math.sqrt(head_dim)
+        if abs(scale - default_scale) > 1e-9:
+            scale_factor = scale / default_scale
+            query = query * scale_factor
+
     quant_type = QUANT_FP8_E5M2 if use_e5m2 else QUANT_FP8_E4M3
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
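The `scale` handling above relies on the kernel always applying 1/sqrt(D): folding the ratio `scale / default_scale` into Q reproduces an arbitrary softmax scale without touching the kernel. A small self-contained check of that identity, using plain CPU tensors and no MFA at all:

```python
import math
import torch

B, H, N, D = 1, 2, 16, 32
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)

custom_scale = 0.05                 # whatever the caller passes as `scale`
default_scale = 1.0 / math.sqrt(D)  # what the kernel always applies

# Scores the caller wants:
want = (q @ k.transpose(-1, -2)) * custom_scale

# What the wrapper does: fold the ratio into Q, then let the kernel scale by 1/sqrt(D).
q_scaled = q * (custom_scale / default_scale)
got = (q_scaled @ k.transpose(-1, -2)) * default_scale

print(torch.allclose(want, got, atol=1e-5))  # True
```

The same pre-scaling block is repeated in the INT8, NF4, and generic quantized wrappers below.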
@@ -597,6 +761,7 @@ def flash_attention_int8(
     is_causal: bool = False,
     attn_mask: Optional[torch.Tensor] = None,
     window_size: int = 0,
+    scale: Optional[float] = None,
 ) -> torch.Tensor:
     """
     Compute attention with INT8 quantized Key/Value tensors.
@@ -613,6 +778,7 @@ def flash_attention_int8(
         is_causal: If True, applies causal masking
         attn_mask: Optional boolean attention mask
         window_size: Sliding window size (0 = full attention)
+        scale: Softmax scale factor. If None, uses 1/sqrt(head_dim)
 
     Returns:
         Output tensor of shape (B, H, N, D)
@@ -624,12 +790,76 @@ def flash_attention_int8(
     if not _HAS_MFA:
         raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
 
+    # Apply custom scale by pre-scaling Q
+    if scale is not None:
+        head_dim = query.shape[-1]
+        default_scale = 1.0 / math.sqrt(head_dim)
+        if abs(scale - default_scale) > 1e-9:
+            scale_factor = scale / default_scale
+            query = query * scale_factor
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_INT8, is_causal, attn_mask, window_size
     )
 
 
+def flash_attention_nf4(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    is_causal: bool = False,
+    attn_mask: Optional[torch.Tensor] = None,
+    window_size: int = 0,
+    scale: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Compute attention with NF4 (NormalFloat 4-bit) quantized Key/Value tensors.
+
+    This function provides 4x memory reduction for K/V cache using NF4 quantization
+    with a 16-value codebook optimized for normally distributed weights.
+
+    NF4 packs two 4-bit values per byte, so the key/value tensors have shape
+    (B, H, N, D//2) where D is the original head dimension.
+
+    Args:
+        query: Query tensor (B, H, N, D) in FP16/BF16/FP32
+        key: Quantized Key tensor (B, H, N, D//2) as uint8 (packed NF4)
+        value: Quantized Value tensor (B, H, N, D//2) as uint8 (packed NF4)
+        k_scale: Per-head scale for K (B, H) or (H,)
+        v_scale: Per-head scale for V (B, H) or (H,)
+        is_causal: If True, applies causal masking
+        attn_mask: Optional boolean attention mask
+        window_size: Sliding window size (0 = full attention)
+        scale: Softmax scale factor. If None, uses 1/sqrt(head_dim)
+
+    Returns:
+        Output tensor of shape (B, H, N, D)
+
+    Example:
+        >>> k_q, v_q, k_s, v_s = quantize_kv_nf4(key, value)
+        >>> out = flash_attention_nf4(query, k_q, v_q, k_s, v_s)
+    """
+    if not _HAS_MFA:
+        raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
+
+    # Apply custom scale by pre-scaling Q
+    # Kernel uses 1/sqrt(D), so we adjust Q to achieve desired scale
+    if scale is not None:
+        head_dim = query.shape[-1]
+        default_scale = 1.0 / math.sqrt(head_dim)
+        if abs(scale - default_scale) > 1e-9:
+            scale_factor = scale / default_scale
+            query = query * scale_factor
+
+    return _C.forward_quantized(
+        query, key, value, k_scale, v_scale,
+        QUANT_NF4, is_causal, attn_mask, window_size
+    )
+
+
 def flash_attention_quantized(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -640,6 +870,7 @@ def flash_attention_quantized(
     is_causal: bool = False,
     attn_mask: Optional[torch.Tensor] = None,
     window_size: int = 0,
+    scale: Optional[float] = None,
 ) -> torch.Tensor:
     """
     Generic quantized attention with configurable quantization type.
@@ -661,6 +892,7 @@ def flash_attention_quantized(
         is_causal: If True, applies causal masking
         attn_mask: Optional boolean attention mask
         window_size: Sliding window size (0 = full attention)
+        scale: Softmax scale factor. If None, uses 1/sqrt(head_dim)
 
     Returns:
         Output tensor of shape (B, H, N, D)
@@ -668,6 +900,14 @@ def flash_attention_quantized(
     if not _HAS_MFA:
        raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
 
+    # Apply custom scale by pre-scaling Q
+    if scale is not None:
+        head_dim = query.shape[-1]
+        default_scale = 1.0 / math.sqrt(head_dim)
+        if abs(scale - default_scale) > 1e-9:
+            scale_factor = scale / default_scale
+            query = query * scale_factor
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         quant_type, is_causal, attn_mask, window_size
mps_flash_attn/csrc/mps_flash_attn.mm
@@ -25,11 +25,16 @@ typedef void* (*mfa_create_kernel_v2_fn)(int32_t, int32_t, int32_t, bool, bool,
 typedef void* (*mfa_create_kernel_v3_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool, uint32_t);
 typedef void* (*mfa_create_kernel_v4_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool, uint32_t, uint16_t);
 typedef void* (*mfa_create_kernel_v5_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool, uint32_t, uint16_t, bool);
+typedef void* (*mfa_create_kernel_v6_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool, uint32_t, uint16_t, bool, bool, uint32_t, uint32_t);
+typedef void* (*mfa_create_kernel_v7_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool, uint32_t, uint16_t, bool, bool, uint32_t, uint32_t, uint32_t);
 // New zero-sync encode functions that take PyTorch's command encoder
 // Added mask_ptr and mask_offset parameters
 typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
                                       int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                       int32_t, int32_t);
+typedef bool (*mfa_forward_encode_bias_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                           int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                           int32_t, int32_t);
 typedef bool (*mfa_forward_encode_quantized_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
                                                 int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                                 int32_t, int32_t);
@@ -52,7 +57,10 @@ static mfa_create_kernel_v2_fn g_mfa_create_kernel_v2 = nullptr;
 static mfa_create_kernel_v3_fn g_mfa_create_kernel_v3 = nullptr;
 static mfa_create_kernel_v4_fn g_mfa_create_kernel_v4 = nullptr;
 static mfa_create_kernel_v5_fn g_mfa_create_kernel_v5 = nullptr;
+static mfa_create_kernel_v6_fn g_mfa_create_kernel_v6 = nullptr;
+static mfa_create_kernel_v7_fn g_mfa_create_kernel_v7 = nullptr;
 static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
+static mfa_forward_encode_bias_fn g_mfa_forward_encode_bias = nullptr;
 static mfa_forward_encode_quantized_fn g_mfa_forward_encode_quantized = nullptr;
 static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
 static mfa_forward_fn g_mfa_forward = nullptr;
@@ -115,8 +123,11 @@ static bool load_mfa_bridge() {
     g_mfa_create_kernel_v3 = (mfa_create_kernel_v3_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v3");
     g_mfa_create_kernel_v4 = (mfa_create_kernel_v4_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v4");
     g_mfa_create_kernel_v5 = (mfa_create_kernel_v5_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v5");
+    g_mfa_create_kernel_v6 = (mfa_create_kernel_v6_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v6");
+    g_mfa_create_kernel_v7 = (mfa_create_kernel_v7_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v7");
     g_mfa_forward_encode_quantized = (mfa_forward_encode_quantized_fn)dlsym(g_dylib_handle, "mfa_forward_encode_quantized");
     g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
+    g_mfa_forward_encode_bias = (mfa_forward_encode_bias_fn)dlsym(g_dylib_handle, "mfa_forward_encode_bias");
     g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
     g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
     g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
@@ -169,6 +180,10 @@ struct KernelCacheKey {
     uint32_t window_size;
     uint16_t quantized_kv;  // 0 = none, 3 = FP8_E4M3, 4 = FP8_E5M2, 5 = INT8, 6 = NF4
     bool bf16_backward;     // true = use BF16 for backward intermediates (faster)
+    bool has_attn_bias;     // true = additive attention bias
+    uint32_t bias_batch_stride;
+    uint32_t bias_head_stride;
+    uint32_t bias_repeat_count;
 
     bool operator==(const KernelCacheKey& other) const {
         return seq_len_q == other.seq_len_q &&
@@ -181,7 +196,11 @@ struct KernelCacheKey {
                use_bf16 == other.use_bf16 &&
                window_size == other.window_size &&
                quantized_kv == other.quantized_kv &&
-               bf16_backward == other.bf16_backward;
+               bf16_backward == other.bf16_backward &&
+               has_attn_bias == other.has_attn_bias &&
+               bias_batch_stride == other.bias_batch_stride &&
+               bias_head_stride == other.bias_head_stride &&
+               bias_repeat_count == other.bias_repeat_count;
     }
 };
 
@@ -197,14 +216,18 @@ struct KernelCacheKeyHash {
                (std::hash<bool>()(k.use_bf16) << 7) ^
                (std::hash<uint32_t>()(k.window_size) << 8) ^
                (std::hash<uint16_t>()(k.quantized_kv) << 9) ^
-               (std::hash<bool>()(k.bf16_backward) << 10);
+               (std::hash<bool>()(k.bf16_backward) << 10) ^
+               (std::hash<bool>()(k.has_attn_bias) << 11) ^
+               (std::hash<uint32_t>()(k.bias_batch_stride) << 12) ^
+               (std::hash<uint32_t>()(k.bias_head_stride) << 13) ^
+               (std::hash<uint32_t>()(k.bias_repeat_count) << 14);
     }
 };
 
 static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;
 
-static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask, bool use_bf16 = false, uint32_t window_size = 0, uint16_t quantized_kv = 0, bool bf16_backward = false) {
-    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16, window_size, quantized_kv, bf16_backward};
+static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask, bool use_bf16 = false, uint32_t window_size = 0, uint16_t quantized_kv = 0, bool bf16_backward = false, bool has_attn_bias = false, uint32_t bias_batch_stride = 0, uint32_t bias_head_stride = 0, uint32_t bias_repeat_count = 0) {
+    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16, window_size, quantized_kv, bf16_backward, has_attn_bias, bias_batch_stride, bias_head_stride, bias_repeat_count};
 
     auto it = g_kernel_cache.find(key);
     if (it != g_kernel_cache.end()) {
@@ -212,7 +235,44 @@ static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_di
     }
 
     void* kernel = nullptr;
-    if (g_mfa_create_kernel_v5) {
+    if (has_attn_bias && g_mfa_create_kernel_v7) {
+        // Use v7 API with additive attention bias and repeat support
+        kernel = g_mfa_create_kernel_v7(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask,
+            use_bf16,
+            window_size,
+            quantized_kv,
+            bf16_backward,
+            has_attn_bias,
+            bias_batch_stride,
+            bias_head_stride,
+            bias_repeat_count
+        );
+    } else if (has_attn_bias && g_mfa_create_kernel_v6) {
+        // Use v6 API with additive attention bias (no repeat)
+        kernel = g_mfa_create_kernel_v6(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask,
+            use_bf16,
+            window_size,
+            quantized_kv,
+            bf16_backward,
+            has_attn_bias,
+            bias_batch_stride,
+            bias_head_stride
+        );
+    } else if (g_mfa_create_kernel_v5) {
         // Use v5 API with all features including mixed-precision backward
         kernel = g_mfa_create_kernel_v5(
             static_cast<int32_t>(seq_q),
@@ -489,6 +549,168 @@ at::Tensor mps_flash_attention_forward(
     return output;
 }
 
+// ============================================================================
+// Flash Attention Forward with Additive Bias
+// ============================================================================
+
+at::Tensor mps_flash_attention_forward_with_bias(
+    const at::Tensor& query,      // (B, H, N, D)
+    const at::Tensor& key,        // (B, H, N, D)
+    const at::Tensor& value,      // (B, H, N, D)
+    const at::Tensor& attn_bias,  // (B, H, N_q, N_kv) or (1, H, N_q, N_kv) or (H, N_q, N_kv)
+    bool is_causal,
+    int64_t window_size,
+    int64_t bias_repeat_count     // >0 means bias repeats every N batches (for window attention)
+) {
+    // Initialize MFA on first call
+    if (!g_initialized) {
+        load_mfa_bridge();
+        if (!g_mfa_init()) {
+            throw std::runtime_error("Failed to initialize MFA");
+        }
+        g_initialized = true;
+    }
+
+    // Check that v6/v7 API is available
+    TORCH_CHECK(g_mfa_create_kernel_v6 || g_mfa_create_kernel_v7,
+                "Attention bias requires MFA v6+ API (update libMFABridge.dylib)");
+    TORCH_CHECK(g_mfa_forward_encode_bias,
+                "Attention bias requires mfa_forward_encode_bias");
+
+    // Validate inputs
+    TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
+    TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
+    TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
+    TORCH_CHECK(query.device().is_mps(), "Query must be on MPS device");
+    TORCH_CHECK(key.device().is_mps(), "Key must be on MPS device");
+    TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");
+    TORCH_CHECK(attn_bias.device().is_mps(), "Attention bias must be on MPS device");
+
+    const int64_t batch_size = query.size(0);
+    const int64_t num_heads = query.size(1);
+    const int64_t seq_len_q = query.size(2);
+    const int64_t head_dim = query.size(3);
+    const int64_t seq_len_kv = key.size(2);
+
+    // Determine bias strides for broadcasting
+    // Bias can be: (B, H, N_q, N_kv), (1, H, N_q, N_kv), or (H, N_q, N_kv)
+    uint32_t bias_batch_stride = 0;
+    uint32_t bias_head_stride = static_cast<uint32_t>(seq_len_q * seq_len_kv);
+
+    if (attn_bias.dim() == 4) {
+        if (attn_bias.size(0) > 1) {
+            bias_batch_stride = static_cast<uint32_t>(attn_bias.size(1) * seq_len_q * seq_len_kv);
+        }
+        if (attn_bias.size(1) == 1) {
+            bias_head_stride = 0;  // Broadcast across heads
+        }
+    } else if (attn_bias.dim() == 3) {
+        // (H, N_q, N_kv) - broadcast across batch
+        bias_batch_stride = 0;
+        if (attn_bias.size(0) == 1) {
+            bias_head_stride = 0;
+        }
+    }
+
+    // Determine precision
+    bool is_bfloat16 = (query.scalar_type() == at::kBFloat16);
+    bool is_fp16 = (query.scalar_type() == at::kHalf);
+    bool use_bf16_kernel = is_bfloat16 && g_mfa_create_kernel_v2;
+    bool low_precision = is_fp16;
+    bool low_precision_outputs = is_fp16 || use_bf16_kernel;
+
+    // Make inputs contiguous
+    auto q = query.contiguous();
+    auto k = key.contiguous();
+    auto v = value.contiguous();
+    auto bias = attn_bias.contiguous();
+
+    // For BF16 without native kernel, convert to FP32
+    if (is_bfloat16 && !use_bf16_kernel) {
+        q = q.to(at::kFloat);
+        k = k.to(at::kFloat);
+        v = v.to(at::kFloat);
+        bias = bias.to(at::kFloat);
+    }
+
+    // Allocate output
+    at::Tensor output;
+    if (use_bf16_kernel) {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kBFloat16));
+    } else if (low_precision_outputs) {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kHalf));
+    } else {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kFloat));
+    }
+
+    // Allocate logsumexp (always fp32)
+    auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
+                               query.options().dtype(at::kFloat));
+
+    // Get or create kernel with bias support
+    void* kernel = get_or_create_kernel(
+        seq_len_q, seq_len_kv, head_dim,
+        low_precision, low_precision_outputs, is_causal, false,  // has_mask = false
+        use_bf16_kernel,
+        static_cast<uint32_t>(window_size > 0 ? window_size : 0),
+        0,      // no quantization
+        false,  // bf16_backward
+        true,   // has_attn_bias
+        bias_batch_stride,
+        bias_head_stride,
+        static_cast<uint32_t>(bias_repeat_count > 0 ? bias_repeat_count : 0)
+    );
+
+    // Get Metal buffers
+    auto q_info = getBufferInfo(q);
+    auto k_info = getBufferInfo(k);
+    auto v_info = getBufferInfo(v);
+    auto o_info = getBufferInfo(output);
+    auto l_info = getBufferInfo(logsumexp);
+    auto bias_info = getBufferInfo(bias);
+
+    // Execute using bias forward encode
+    @autoreleasepool {
+        auto stream = at::mps::getCurrentMPSStream();
+        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+        bool success = g_mfa_forward_encode_bias(
+            kernel,
+            (__bridge void*)encoder,
+            (__bridge void*)q_info.buffer,
+            (__bridge void*)k_info.buffer,
+            (__bridge void*)v_info.buffer,
+            (__bridge void*)o_info.buffer,
+            (__bridge void*)l_info.buffer,
+            nullptr,  // no boolean mask
+            (__bridge void*)bias_info.buffer,
+            q_info.byte_offset,
+            k_info.byte_offset,
+            v_info.byte_offset,
+            o_info.byte_offset,
+            l_info.byte_offset,
+            0,  // mask_offset
+            bias_info.byte_offset,
+            static_cast<int32_t>(batch_size),
+            static_cast<int32_t>(num_heads)
+        );
+
+        if (!success) {
+            throw std::runtime_error("MFA forward with bias failed");
+        }
+    }
+
+    // Convert output back to BF16 if needed
+    if (is_bfloat16 && !use_bf16_kernel) {
+        output = output.to(at::kBFloat16);
+    }
+
+    return output;
+}
+
 // ============================================================================
 // Quantized Flash Attention Forward (FP8, INT8, NF4)
 // ============================================================================
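The broadcasting rules encoded by `bias_batch_stride` and `bias_head_stride` above reduce to a small table. A Python restatement for illustration only; `bias_strides` is a hypothetical helper, and the values are element strides, not byte strides:

```python
def bias_strides(bias_shape, seq_len_q, seq_len_kv):
    """Mirror of the C++ stride selection: returns (batch_stride, head_stride) in elements."""
    batch_stride = 0
    head_stride = seq_len_q * seq_len_kv
    if len(bias_shape) == 4:
        b, h = bias_shape[0], bias_shape[1]
        if b > 1:                       # (B, H, N_q, N_kv): one bias slice per batch entry
            batch_stride = h * seq_len_q * seq_len_kv
        if h == 1:                      # (B, 1, N_q, N_kv): broadcast one bias across heads
            head_stride = 0
    elif len(bias_shape) == 3:          # (H, N_q, N_kv): broadcast across batch
        if bias_shape[0] == 1:
            head_stride = 0
    return batch_stride, head_stride

# e.g. a Swin-style (1, H, 49, 49) bias broadcasts over the batch dimension:
# bias_strides((1, 8, 49, 49), 49, 49) -> (0, 2401)
```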
@@ -550,6 +772,9 @@ at::Tensor mps_flash_attention_forward_quantized(
     TORCH_CHECK(v_scale.scalar_type() == at::kFloat,
                 "V scale must be float32");
 
+    // For NF4, K/V have packed head dimension (D//2) since 2 values per byte
+    bool is_nf4 = (quant_type == static_cast<int64_t>(QuantizationType::NF4));
+
     const int64_t batch_size = query.size(0);
     const int64_t num_heads_q = query.size(1);
     const int64_t num_heads_kv = key.size(1);
@@ -557,10 +782,13 @@ at::Tensor mps_flash_attention_forward_quantized(
     const int64_t head_dim = query.size(3);
     const int64_t seq_len_kv = key.size(2);
 
+    // For NF4, expected K/V head dim is D//2 (packed)
+    int64_t expected_kv_head_dim = is_nf4 ? (head_dim / 2) : head_dim;
+
     TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
                 "Batch size mismatch");
-    TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
-                "Head dimension mismatch");
+    TORCH_CHECK(key.size(3) == expected_kv_head_dim && value.size(3) == expected_kv_head_dim,
+                is_nf4 ? "Head dimension mismatch for NF4 (expected D//2)" : "Head dimension mismatch");
     TORCH_CHECK(key.size(1) == value.size(1),
                 "K and V must have same number of heads");
 
@@ -970,6 +1198,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("window_size") = 0,
           py::arg("bf16_backward") = false);
 
+    // Forward with additive attention bias (e.g., relative position encoding)
+    m.def("forward_with_bias", &mps_flash_attention_forward_with_bias,
+          "Flash Attention forward with additive attention bias",
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("attn_bias"),
+          py::arg("is_causal") = false,
+          py::arg("window_size") = 0,
+          py::arg("bias_repeat_count") = 0);
+
     // Quantized attention (forward only - no gradients through quantized weights)
     m.def("forward_quantized", &mps_flash_attention_forward_quantized,
           "Quantized Flash Attention forward (FP8/INT8/NF4 K/V)",

mps_flash_attn-0.2.5/mps_flash_attn/lib/libMFABridge.dylib: binary file
mps_flash_attn.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.2.1
+Version: 0.2.5
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
@@ -201,6 +201,13 @@ Python API (mps_flash_attn)
 - Python 3.10+
 - PyTorch 2.0+
 
+## TODO / Future Optimizations
+
+- [ ] **Batched kernel dispatch** - Currently dispatches B×H separate kernels per attention call. Should use 3D grid to handle all batch/heads in one dispatch (major perf win for small sequences like Swin Transformer windows)
+- [ ] **Fused QKV projection + attention** - Single kernel from input to output, avoid intermediate buffers
+- [ ] **Pre-scaled bias option** - Allow passing pre-scaled bias to avoid per-call scaling overhead
+- [ ] **LoRA fusion** - Fuse adapter weights into attention computation
+
 ## Credits
 
 - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
tests/test_mfa_v2.py
@@ -342,6 +342,97 @@ class TestQuantized:
         assert output.shape == (B, H, N, D)
         assert not torch.isnan(output).any()
 
+    def test_quantize_nf4(self, mfa):
+        """Test NF4 quantization helper.
+
+        NF4 packs 2 values per byte along head dimension, so output shape is (B,H,N,D//2).
+        """
+        B, H, N, D = 1, 4, 128, 64
+        dtype = torch.float16
+
+        k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+        k_q, v_q, k_s, v_s = mfa.quantize_kv_nf4(k, v)
+
+        assert k_q.dtype == torch.uint8
+        assert v_q.dtype == torch.uint8
+        assert k_s.dtype == torch.float32
+        assert v_s.dtype == torch.float32
+        # NF4 packs 2 values per byte, so D dimension is halved
+        assert k_q.shape == (B, H, N, D // 2)
+        assert v_q.shape == (B, H, N, D // 2)
+
+    def test_flash_attention_nf4(self, mfa):
+        """Test NF4 quantized attention forward.
+
+        NF4 uses a 16-value codebook for 4-bit quantization, packing 2 values per byte.
+        This provides 4x memory reduction for K/V cache with acceptable accuracy loss.
+        """
+        B, H, N, D = 2, 8, 128, 64
+        dtype = torch.float16
+
+        torch.manual_seed(42)
+        q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+        k_q, v_q, k_s, v_s = mfa.quantize_kv_nf4(k, v)
+        output = mfa.flash_attention_nf4(q, k_q, v_q, k_s, v_s)
+
+        assert output.shape == (B, H, N, D)
+        assert not torch.isnan(output).any()
+        assert not torch.isinf(output).any()
+
+    def test_flash_attention_nf4_correctness(self, mfa):
+        """Test NF4 attention correctness against reference.
+
+        NF4 is 4-bit so we expect larger error than FP8/INT8, but output
+        should still be in a reasonable range (max diff < 0.5).
+        """
+        B, H, N, D = 1, 4, 64, 64
+        dtype = torch.float16
+
+        torch.manual_seed(42)
+        q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+        # Reference
+        ref = reference_attention(q, k, v)
+
+        # NF4
+        k_q, v_q, k_s, v_s = mfa.quantize_kv_nf4(k, v)
+        output = mfa.flash_attention_nf4(q, k_q, v_q, k_s, v_s)
+
+        max_diff = (ref - output).abs().max().item()
+        mean_diff = (ref - output).abs().mean().item()
+
+        # 4-bit quantization has larger error, but should be bounded
+        assert max_diff < 0.5, f"NF4 max diff {max_diff} exceeds threshold 0.5"
+        assert mean_diff < 0.1, f"NF4 mean diff {mean_diff} exceeds threshold 0.1"
+
+    @pytest.mark.parametrize("config", [
+        (1, 1, 32, 32),    # Small
+        (1, 8, 128, 64),   # Medium
+        (2, 16, 256, 64),  # Large
+        (1, 8, 512, 128),  # Large head dim
+    ])
+    def test_flash_attention_nf4_various_sizes(self, mfa, config):
+        """Test NF4 attention with various tensor sizes."""
+        B, H, N, D = config
+        dtype = torch.float16
+
+        q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+        k_q, v_q, k_s, v_s = mfa.quantize_kv_nf4(k, v)
+        output = mfa.flash_attention_nf4(q, k_q, v_q, k_s, v_s)
+
+        assert output.shape == (B, H, N, D)
+        assert not torch.isnan(output).any()
+
 
 # =============================================================================
 # Chunked/Streaming Attention Tests
mps_flash_attn-0.2.1/mps_flash_attn/lib/libMFABridge.dylib: binary file

All remaining files in the list above (LICENSE, benchmark.py, the kernel .metallib/.bin caches, the egg-info metadata files, setup.cfg, setup.py) are unchanged between 0.2.1 and 0.2.5.
|