mps-flash-attn 0.2.1.tar.gz → 0.2.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mps-flash-attn might be problematic.

Files changed (40)
  1. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/PKG-INFO +8 -1
  2. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/README.md +7 -0
  3. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/__init__.py +241 -1
  4. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/csrc/mps_flash_attn.mm +246 -7
  5. mps_flash_attn-0.2.5/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  6. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/PKG-INFO +8 -1
  7. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/pyproject.toml +1 -1
  8. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/tests/test_mfa_v2.py +91 -0
  9. mps_flash_attn-0.2.1/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  10. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/LICENSE +0 -0
  11. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/benchmark.py +0 -0
  12. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  13. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  14. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  15. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  16. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  17. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  18. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  19. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  20. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  21. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  22. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  23. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  24. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  25. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  26. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  27. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  28. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  29. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  30. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  31. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  32. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  33. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  34. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/kernels/manifest.json +0 -0
  35. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
  36. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  37. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/requires.txt +0 -0
  38. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/top_level.txt +0 -0
  39. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/setup.cfg +0 -0
  40. {mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/setup.py +0 -0
{mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.2.1
+Version: 0.2.5
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
@@ -201,6 +201,13 @@ Python API (mps_flash_attn)
 - Python 3.10+
 - PyTorch 2.0+
 
+## TODO / Future Optimizations
+
+- [ ] **Batched kernel dispatch** - Currently dispatches B×H separate kernels per attention call. Should use 3D grid to handle all batch/heads in one dispatch (major perf win for small sequences like Swin Transformer windows)
+- [ ] **Fused QKV projection + attention** - Single kernel from input to output, avoid intermediate buffers
+- [ ] **Pre-scaled bias option** - Allow passing pre-scaled bias to avoid per-call scaling overhead
+- [ ] **LoRA fusion** - Fuse adapter weights into attention computation
+
 ## Credits
 
 - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner

{mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/README.md

@@ -176,6 +176,13 @@ Python API (mps_flash_attn)
 - Python 3.10+
 - PyTorch 2.0+
 
+## TODO / Future Optimizations
+
+- [ ] **Batched kernel dispatch** - Currently dispatches B×H separate kernels per attention call. Should use 3D grid to handle all batch/heads in one dispatch (major perf win for small sequences like Swin Transformer windows)
+- [ ] **Fused QKV projection + attention** - Single kernel from input to output, avoid intermediate buffers
+- [ ] **Pre-scaled bias option** - Allow passing pre-scaled bias to avoid per-call scaling overhead
+- [ ] **LoRA fusion** - Fuse adapter weights into attention computation
+
 ## Credits
 
 - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner

{mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.2.1"
+__version__ = "0.2.5"
 
 import torch
 from typing import Optional
@@ -296,6 +296,86 @@ def precompile():
     print("\nPre-compilation complete! Kernels cached to disk.")
 
 
+def flash_attention_with_bias(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: torch.Tensor,
+    is_causal: bool = False,
+    window_size: int = 0,
+    bias_repeat_count: int = 0,
+) -> torch.Tensor:
+    """
+    Compute scaled dot-product attention with additive attention bias.
+
+    This function supports additive attention bias (like relative position encodings
+    or ALiBi) which is added to the attention scores before softmax:
+
+        Attention(Q, K, V) = softmax((Q @ K.T) / sqrt(d) + bias) @ V
+
+    IMPORTANT: MFA adds bias to UNSCALED scores internally and scales during softmax.
+    If your bias was computed for scaled scores (like PyTorch SDPA), you need to
+    pre-scale it by multiplying by sqrt(head_dim).
+
+    Args:
+        query: Query tensor of shape (B, H, N_q, D)
+        key: Key tensor of shape (B, H, N_kv, D)
+        value: Value tensor of shape (B, H, N_kv, D)
+        attn_bias: Additive attention bias of shape:
+            - (B, H, N_q, N_kv): Full bias for each batch/head
+            - (1, H, N_q, N_kv): Broadcast across batch
+            - (H, N_q, N_kv): Broadcast across batch (3D)
+        is_causal: If True, applies causal masking
+        window_size: Sliding window attention size (0 = full attention)
+        bias_repeat_count: If > 0, the bias tensor repeats every N batches.
+            Useful for window attention where multiple windows share the same
+            position bias pattern. E.g., for Swin Transformer with 4 windows,
+            set bias_repeat_count=num_windows so bias[batch_idx % num_windows]
+            is used.
+
+    Returns:
+        Output tensor of shape (B, H, N_q, D)
+
+    Example:
+        >>> # Relative position bias (Swin Transformer style)
+        >>> q = torch.randn(4, 8, 64, 64, device='mps', dtype=torch.float16)
+        >>> k = torch.randn(4, 8, 64, 64, device='mps', dtype=torch.float16)
+        >>> v = torch.randn(4, 8, 64, 64, device='mps', dtype=torch.float16)
+        >>> # Position bias: (1, num_heads, seq_len, seq_len)
+        >>> bias = torch.randn(1, 8, 64, 64, device='mps', dtype=torch.float16)
+        >>> # Pre-scale bias since MFA uses unscaled scores
+        >>> scaled_bias = bias * math.sqrt(64)  # sqrt(head_dim)
+        >>> out = flash_attention_with_bias(q, k, v, scaled_bias)
+
+        >>> # Window attention with repeating bias pattern
+        >>> n_windows = 16
+        >>> q = torch.randn(n_windows * 4, 8, 49, 64, device='mps', dtype=torch.float16)
+        >>> bias = torch.randn(n_windows, 8, 49, 49, device='mps', dtype=torch.float16)
+        >>> scaled_bias = bias * math.sqrt(64)
+        >>> out = flash_attention_with_bias(q, k, v, scaled_bias, bias_repeat_count=n_windows)
+    """
+    if not _HAS_MFA:
+        raise RuntimeError(
+            f"MPS Flash Attention C++ extension not available: {_IMPORT_ERROR}\n"
+            "Please rebuild with: pip install -e ."
+        )
+
+    if not torch.backends.mps.is_available():
+        raise RuntimeError("MPS not available")
+
+    # Validate device
+    if query.device.type != 'mps':
+        raise ValueError("query must be on MPS device")
+    if key.device.type != 'mps':
+        raise ValueError("key must be on MPS device")
+    if value.device.type != 'mps':
+        raise ValueError("value must be on MPS device")
+    if attn_bias.device.type != 'mps':
+        raise ValueError("attn_bias must be on MPS device")
+
+    return _C.forward_with_bias(query, key, value, attn_bias, is_causal, window_size, bias_repeat_count)
+
+
 def flash_attention_chunked(
     query: torch.Tensor,
     key: torch.Tensor,
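Note on the bias convention documented above: MFA adds `attn_bias` to the *unscaled* Q·Kᵀ scores and applies 1/sqrt(D) during softmax, while `torch.nn.functional.scaled_dot_product_attention` adds its `attn_mask` to already-scaled scores. A minimal caller-side adapter is sketched below; the wrapper name `sdpa_style_attention_with_bias` and the tolerance expectation are ours, not part of the package.

```python
import math
import torch
import torch.nn.functional as F
import mps_flash_attn as mfa

def sdpa_style_attention_with_bias(q, k, v, bias, **kwargs):
    """Accept a bias in the SDPA convention (added after 1/sqrt(D) scaling)
    and rescale it to the convention flash_attention_with_bias expects."""
    head_dim = q.shape[-1]
    return mfa.flash_attention_with_bias(q, k, v, bias * math.sqrt(head_dim), **kwargs)

# Sanity check against the PyTorch reference in float32.
B, H, N, D = 2, 4, 64, 64
q = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="mps", dtype=torch.float16)
bias = torch.randn(1, H, N, N, device="mps", dtype=torch.float16)

out = sdpa_style_attention_with_bias(q, k, v, bias)
ref = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), attn_mask=bias.float())
print((out.float() - ref).abs().max())  # expect fp16-level error
```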
@@ -541,6 +621,80 @@ def quantize_kv_int8(
     return k_quant, v_quant, k_scale, v_scale
 
 
+# NF4 codebook (must match Metal shader's NF4_CODEBOOK exactly)
+_NF4_CODEBOOK = torch.tensor([
+    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
+    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
+    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
+    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
+], dtype=torch.float32)
+
+
+def quantize_kv_nf4(
+    key: torch.Tensor,
+    value: torch.Tensor,
+) -> tuple:
+    """
+    Quantize Key and Value tensors to NF4 (NormalFloat 4-bit) format.
+
+    NF4 quantization provides 4x memory reduction using a 16-value codebook
+    optimized for normally distributed weights. Two values are packed per byte.
+
+    Args:
+        key: Key tensor of shape (B, H, N, D) where D must be even
+        value: Value tensor of shape (B, H, N, D) where D must be even
+
+    Returns:
+        Tuple of (key_quant, value_quant, k_scale, v_scale) where:
+        - key_quant, value_quant: uint8 tensors of shape (B, H, N, D//2) with packed values
+        - k_scale, v_scale: float32 tensors with per-head scale factors (B, H)
+
+    Example:
+        >>> k_q, v_q, k_s, v_s = quantize_kv_nf4(key, value)
+        >>> out = flash_attention_nf4(query, k_q, v_q, k_s, v_s)
+    """
+    if not _HAS_MFA:
+        raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
+
+    def _quantize_nf4(tensor: torch.Tensor) -> tuple:
+        """Quantize a single tensor to NF4 format."""
+        B, H, N, D = tensor.shape
+        if D % 2 != 0:
+            raise ValueError(f"Head dimension D must be even for NF4 quantization, got D={D}")
+
+        # Convert to float32 for quantization
+        t = tensor.float()
+
+        # Compute per-head absmax for scale
+        abs_max = t.abs().amax(dim=(2, 3))  # (B, H)
+        scale = abs_max.clamp_min(1e-12)
+
+        # Normalize to [-1, 1] range
+        scale_expanded = scale.unsqueeze(-1).unsqueeze(-1)  # (B, H, 1, 1)
+        normalized = t / scale_expanded
+
+        # Find nearest NF4 codebook entry for each value
+        codebook = _NF4_CODEBOOK.to(tensor.device)  # (16,)
+        # Reshape for broadcasting: normalized is (B, H, N, D), codebook is (16,)
+        # Compute distances to all codebook entries
+        flat = normalized.reshape(-1, 1)  # (B*H*N*D, 1)
+        distances = (flat - codebook.unsqueeze(0)).abs()  # (B*H*N*D, 16)
+        indices = distances.argmin(dim=1)  # (B*H*N*D,)
+        indices = indices.reshape(B, H, N, D)  # (B, H, N, D)
+
+        # Pack two 4-bit indices per byte
+        # Even indices go to low nibble, odd indices go to high nibble
+        indices_even = indices[:, :, :, 0::2]  # (B, H, N, D//2)
+        indices_odd = indices[:, :, :, 1::2]  # (B, H, N, D//2)
+        packed = (indices_even | (indices_odd << 4)).to(torch.uint8)  # (B, H, N, D//2)
+
+        return packed, scale
+
+    k_quant, k_scale = _quantize_nf4(key)
+    v_quant, v_scale = _quantize_nf4(value)
+    return k_quant, v_quant, k_scale, v_scale
+
+
 def flash_attention_fp8(
     query: torch.Tensor,
     key: torch.Tensor,
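The packing layout above (even positions in the low nibble, odd positions in the high nibble, one absmax scale per head) can be inverted in Python to inspect the quantization error. The dequantizer below is an illustrative sketch written against that layout and assumes the module-level `_NF4_CODEBOOK` stays importable; the package itself decodes NF4 inside the Metal kernel.

```python
import torch
import mps_flash_attn as mfa

def dequantize_nf4(packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Inverse of quantize_kv_nf4 (for inspection only)."""
    low = (packed & 0x0F).long()           # even positions along D
    high = ((packed >> 4) & 0x0F).long()   # odd positions along D
    codebook = mfa._NF4_CODEBOOK.to(packed.device)
    vals = torch.stack((codebook[low], codebook[high]), dim=-1)  # (B, H, N, D//2, 2)
    B, H, N, half_d = packed.shape
    vals = vals.reshape(B, H, N, half_d * 2)      # re-interleave to the original D order
    return vals * scale[..., None, None]          # undo per-head absmax normalization

k = torch.randn(1, 4, 128, 64, device="mps", dtype=torch.float16)
v = torch.randn(1, 4, 128, 64, device="mps", dtype=torch.float16)
k_q, v_q, k_s, v_s = mfa.quantize_kv_nf4(k, v)
print((k.float() - dequantize_nf4(k_q, k_s)).abs().mean())  # NF4 round-trip error
```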
@@ -551,6 +705,7 @@ def flash_attention_fp8(
     attn_mask: Optional[torch.Tensor] = None,
     window_size: int = 0,
     use_e5m2: bool = False,
+    scale: Optional[float] = None,
 ) -> torch.Tensor:
     """
     Compute attention with FP8 quantized Key/Value tensors.
@@ -568,6 +723,7 @@ def flash_attention_fp8(
         attn_mask: Optional boolean attention mask
         window_size: Sliding window size (0 = full attention)
        use_e5m2: If True, use E5M2 format. Default: False (E4M3)
+        scale: Softmax scale factor. If None, uses 1/sqrt(head_dim)
 
     Returns:
         Output tensor of shape (B, H, N, D)
@@ -581,6 +737,14 @@ def flash_attention_fp8(
     if not _HAS_MFA:
         raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
 
+    # Apply custom scale by pre-scaling Q
+    if scale is not None:
+        head_dim = query.shape[-1]
+        default_scale = 1.0 / math.sqrt(head_dim)
+        if abs(scale - default_scale) > 1e-9:
+            scale_factor = scale / default_scale
+            query = query * scale_factor
+
     quant_type = QUANT_FP8_E5M2 if use_e5m2 else QUANT_FP8_E4M3
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
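All of the quantized wrappers implement the new `scale` argument the same way: the kernel always applies the default 1/sqrt(D), so the wrapper folds the requested scale into Q via `scale_factor = scale / default_scale`. The identity relied on is softmax((Q·s/s_d)Kᵀ·s_d) = softmax(QKᵀ·s) with s_d = 1/sqrt(D). A standalone check of that identity (plain PyTorch, not package code):

```python
import math
import torch

B, H, N, D = 1, 2, 16, 8
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)
custom_scale = 0.05                 # anything other than 1/sqrt(D)
default_scale = 1.0 / math.sqrt(D)

# What the caller asks for: scores scaled by custom_scale.
want = torch.softmax(q @ k.transpose(-1, -2) * custom_scale, dim=-1)

# What the kernel computes once Q is pre-multiplied by custom_scale / default_scale.
q_pre = q * (custom_scale / default_scale)
got = torch.softmax(q_pre @ k.transpose(-1, -2) * default_scale, dim=-1)

print(torch.allclose(want, got, atol=1e-6))  # True
```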
@@ -597,6 +761,7 @@ def flash_attention_int8(
     is_causal: bool = False,
     attn_mask: Optional[torch.Tensor] = None,
     window_size: int = 0,
+    scale: Optional[float] = None,
 ) -> torch.Tensor:
     """
     Compute attention with INT8 quantized Key/Value tensors.
@@ -613,6 +778,7 @@ def flash_attention_int8(
         is_causal: If True, applies causal masking
         attn_mask: Optional boolean attention mask
         window_size: Sliding window size (0 = full attention)
+        scale: Softmax scale factor. If None, uses 1/sqrt(head_dim)
 
     Returns:
         Output tensor of shape (B, H, N, D)
@@ -624,12 +790,76 @@ def flash_attention_int8(
     if not _HAS_MFA:
         raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
 
+    # Apply custom scale by pre-scaling Q
+    if scale is not None:
+        head_dim = query.shape[-1]
+        default_scale = 1.0 / math.sqrt(head_dim)
+        if abs(scale - default_scale) > 1e-9:
+            scale_factor = scale / default_scale
+            query = query * scale_factor
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         QUANT_INT8, is_causal, attn_mask, window_size
     )
 
 
+def flash_attention_nf4(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    is_causal: bool = False,
+    attn_mask: Optional[torch.Tensor] = None,
+    window_size: int = 0,
+    scale: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Compute attention with NF4 (NormalFloat 4-bit) quantized Key/Value tensors.
+
+    This function provides 4x memory reduction for K/V cache using NF4 quantization
+    with a 16-value codebook optimized for normally distributed weights.
+
+    NF4 packs two 4-bit values per byte, so the key/value tensors have shape
+    (B, H, N, D//2) where D is the original head dimension.
+
+    Args:
+        query: Query tensor (B, H, N, D) in FP16/BF16/FP32
+        key: Quantized Key tensor (B, H, N, D//2) as uint8 (packed NF4)
+        value: Quantized Value tensor (B, H, N, D//2) as uint8 (packed NF4)
+        k_scale: Per-head scale for K (B, H) or (H,)
+        v_scale: Per-head scale for V (B, H) or (H,)
+        is_causal: If True, applies causal masking
+        attn_mask: Optional boolean attention mask
+        window_size: Sliding window size (0 = full attention)
+        scale: Softmax scale factor. If None, uses 1/sqrt(head_dim)
+
+    Returns:
+        Output tensor of shape (B, H, N, D)
+
+    Example:
+        >>> k_q, v_q, k_s, v_s = quantize_kv_nf4(key, value)
+        >>> out = flash_attention_nf4(query, k_q, v_q, k_s, v_s)
+    """
+    if not _HAS_MFA:
+        raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
+
+    # Apply custom scale by pre-scaling Q
+    # Kernel uses 1/sqrt(D), so we adjust Q to achieve desired scale
+    if scale is not None:
+        head_dim = query.shape[-1]
+        default_scale = 1.0 / math.sqrt(head_dim)
+        if abs(scale - default_scale) > 1e-9:
+            scale_factor = scale / default_scale
+            query = query * scale_factor
+
+    return _C.forward_quantized(
+        query, key, value, k_scale, v_scale,
+        QUANT_NF4, is_causal, attn_mask, window_size
+    )
+
+
 def flash_attention_quantized(
     query: torch.Tensor,
     key: torch.Tensor,
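The "4x memory reduction" claimed in the NF4 docstrings follows directly from the packing: two 4-bit codes per byte plus a handful of float32 per-head scales. Illustrative arithmetic with made-up shapes (not a measurement from the package):

```python
# K/V cache for one layer, B=1, H=32, N=8192, D=128
B, H, N, D = 1, 32, 8192, 128

fp16_bytes = 2 * (B * H * N * D) * 2       # K and V, 2 bytes per element
nf4_bytes = 2 * (B * H * N * D // 2)       # packed uint8, two values per byte
nf4_bytes += 2 * (B * H) * 4               # per-head float32 scales

print(fp16_bytes / 2**20)  # 128.0 MiB
print(nf4_bytes / 2**20)   # ~32.0 MiB, i.e. ~4x smaller
```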
@@ -640,6 +870,7 @@ def flash_attention_quantized(
     is_causal: bool = False,
     attn_mask: Optional[torch.Tensor] = None,
     window_size: int = 0,
+    scale: Optional[float] = None,
 ) -> torch.Tensor:
     """
     Generic quantized attention with configurable quantization type.
@@ -661,6 +892,7 @@ def flash_attention_quantized(
         is_causal: If True, applies causal masking
         attn_mask: Optional boolean attention mask
         window_size: Sliding window size (0 = full attention)
+        scale: Softmax scale factor. If None, uses 1/sqrt(head_dim)
 
     Returns:
         Output tensor of shape (B, H, N, D)
@@ -668,6 +900,14 @@ def flash_attention_quantized(
     if not _HAS_MFA:
         raise RuntimeError(f"MPS Flash Attention not available: {_IMPORT_ERROR}")
 
+    # Apply custom scale by pre-scaling Q
+    if scale is not None:
+        head_dim = query.shape[-1]
+        default_scale = 1.0 / math.sqrt(head_dim)
+        if abs(scale - default_scale) > 1e-9:
+            scale_factor = scale / default_scale
+            query = query * scale_factor
+
     return _C.forward_quantized(
         query, key, value, k_scale, v_scale,
         quant_type, is_causal, attn_mask, window_size

{mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn/csrc/mps_flash_attn.mm

@@ -25,11 +25,16 @@ typedef void* (*mfa_create_kernel_v2_fn)(int32_t, int32_t, int32_t, bool, bool,
 typedef void* (*mfa_create_kernel_v3_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool, uint32_t);
 typedef void* (*mfa_create_kernel_v4_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool, uint32_t, uint16_t);
 typedef void* (*mfa_create_kernel_v5_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool, uint32_t, uint16_t, bool);
+typedef void* (*mfa_create_kernel_v6_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool, uint32_t, uint16_t, bool, bool, uint32_t, uint32_t);
+typedef void* (*mfa_create_kernel_v7_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool, uint32_t, uint16_t, bool, bool, uint32_t, uint32_t, uint32_t);
 // New zero-sync encode functions that take PyTorch's command encoder
 // Added mask_ptr and mask_offset parameters
 typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
                                       int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                       int32_t, int32_t);
+typedef bool (*mfa_forward_encode_bias_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                           int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                           int32_t, int32_t);
 typedef bool (*mfa_forward_encode_quantized_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
                                                 int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                                 int32_t, int32_t);
@@ -52,7 +57,10 @@ static mfa_create_kernel_v2_fn g_mfa_create_kernel_v2 = nullptr;
 static mfa_create_kernel_v3_fn g_mfa_create_kernel_v3 = nullptr;
 static mfa_create_kernel_v4_fn g_mfa_create_kernel_v4 = nullptr;
 static mfa_create_kernel_v5_fn g_mfa_create_kernel_v5 = nullptr;
+static mfa_create_kernel_v6_fn g_mfa_create_kernel_v6 = nullptr;
+static mfa_create_kernel_v7_fn g_mfa_create_kernel_v7 = nullptr;
 static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
+static mfa_forward_encode_bias_fn g_mfa_forward_encode_bias = nullptr;
 static mfa_forward_encode_quantized_fn g_mfa_forward_encode_quantized = nullptr;
 static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
 static mfa_forward_fn g_mfa_forward = nullptr;
@@ -115,8 +123,11 @@ static bool load_mfa_bridge() {
     g_mfa_create_kernel_v3 = (mfa_create_kernel_v3_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v3");
     g_mfa_create_kernel_v4 = (mfa_create_kernel_v4_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v4");
     g_mfa_create_kernel_v5 = (mfa_create_kernel_v5_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v5");
+    g_mfa_create_kernel_v6 = (mfa_create_kernel_v6_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v6");
+    g_mfa_create_kernel_v7 = (mfa_create_kernel_v7_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v7");
     g_mfa_forward_encode_quantized = (mfa_forward_encode_quantized_fn)dlsym(g_dylib_handle, "mfa_forward_encode_quantized");
     g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
+    g_mfa_forward_encode_bias = (mfa_forward_encode_bias_fn)dlsym(g_dylib_handle, "mfa_forward_encode_bias");
     g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
     g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
     g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
@@ -169,6 +180,10 @@ struct KernelCacheKey {
     uint32_t window_size;
     uint16_t quantized_kv;  // 0 = none, 3 = FP8_E4M3, 4 = FP8_E5M2, 5 = INT8, 6 = NF4
     bool bf16_backward;     // true = use BF16 for backward intermediates (faster)
+    bool has_attn_bias;     // true = additive attention bias
+    uint32_t bias_batch_stride;
+    uint32_t bias_head_stride;
+    uint32_t bias_repeat_count;
 
     bool operator==(const KernelCacheKey& other) const {
         return seq_len_q == other.seq_len_q &&
@@ -181,7 +196,11 @@ struct KernelCacheKey {
                use_bf16 == other.use_bf16 &&
                window_size == other.window_size &&
                quantized_kv == other.quantized_kv &&
-               bf16_backward == other.bf16_backward;
+               bf16_backward == other.bf16_backward &&
+               has_attn_bias == other.has_attn_bias &&
+               bias_batch_stride == other.bias_batch_stride &&
+               bias_head_stride == other.bias_head_stride &&
+               bias_repeat_count == other.bias_repeat_count;
     }
 };
 
@@ -197,14 +216,18 @@ struct KernelCacheKeyHash {
                (std::hash<bool>()(k.use_bf16) << 7) ^
                (std::hash<uint32_t>()(k.window_size) << 8) ^
                (std::hash<uint16_t>()(k.quantized_kv) << 9) ^
-               (std::hash<bool>()(k.bf16_backward) << 10);
+               (std::hash<bool>()(k.bf16_backward) << 10) ^
+               (std::hash<bool>()(k.has_attn_bias) << 11) ^
+               (std::hash<uint32_t>()(k.bias_batch_stride) << 12) ^
+               (std::hash<uint32_t>()(k.bias_head_stride) << 13) ^
+               (std::hash<uint32_t>()(k.bias_repeat_count) << 14);
     }
 };
 
 static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;
 
-static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask, bool use_bf16 = false, uint32_t window_size = 0, uint16_t quantized_kv = 0, bool bf16_backward = false) {
-    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16, window_size, quantized_kv, bf16_backward};
+static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask, bool use_bf16 = false, uint32_t window_size = 0, uint16_t quantized_kv = 0, bool bf16_backward = false, bool has_attn_bias = false, uint32_t bias_batch_stride = 0, uint32_t bias_head_stride = 0, uint32_t bias_repeat_count = 0) {
+    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16, window_size, quantized_kv, bf16_backward, has_attn_bias, bias_batch_stride, bias_head_stride, bias_repeat_count};
 
     auto it = g_kernel_cache.find(key);
     if (it != g_kernel_cache.end()) {
@@ -212,7 +235,44 @@ static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_di
     }
 
     void* kernel = nullptr;
-    if (g_mfa_create_kernel_v5) {
+    if (has_attn_bias && g_mfa_create_kernel_v7) {
+        // Use v7 API with additive attention bias and repeat support
+        kernel = g_mfa_create_kernel_v7(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask,
+            use_bf16,
+            window_size,
+            quantized_kv,
+            bf16_backward,
+            has_attn_bias,
+            bias_batch_stride,
+            bias_head_stride,
+            bias_repeat_count
+        );
+    } else if (has_attn_bias && g_mfa_create_kernel_v6) {
+        // Use v6 API with additive attention bias (no repeat)
+        kernel = g_mfa_create_kernel_v6(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask,
+            use_bf16,
+            window_size,
+            quantized_kv,
+            bf16_backward,
+            has_attn_bias,
+            bias_batch_stride,
+            bias_head_stride
+        );
+    } else if (g_mfa_create_kernel_v5) {
         // Use v5 API with all features including mixed-precision backward
         kernel = g_mfa_create_kernel_v5(
             static_cast<int32_t>(seq_q),
@@ -489,6 +549,168 @@ at::Tensor mps_flash_attention_forward(
     return output;
 }
 
+// ============================================================================
+// Flash Attention Forward with Additive Bias
+// ============================================================================
+
+at::Tensor mps_flash_attention_forward_with_bias(
+    const at::Tensor& query,      // (B, H, N, D)
+    const at::Tensor& key,        // (B, H, N, D)
+    const at::Tensor& value,      // (B, H, N, D)
+    const at::Tensor& attn_bias,  // (B, H, N_q, N_kv) or (1, H, N_q, N_kv) or (H, N_q, N_kv)
+    bool is_causal,
+    int64_t window_size,
+    int64_t bias_repeat_count     // >0 means bias repeats every N batches (for window attention)
+) {
+    // Initialize MFA on first call
+    if (!g_initialized) {
+        load_mfa_bridge();
+        if (!g_mfa_init()) {
+            throw std::runtime_error("Failed to initialize MFA");
+        }
+        g_initialized = true;
+    }
+
+    // Check that v6/v7 API is available
+    TORCH_CHECK(g_mfa_create_kernel_v6 || g_mfa_create_kernel_v7,
+                "Attention bias requires MFA v6+ API (update libMFABridge.dylib)");
+    TORCH_CHECK(g_mfa_forward_encode_bias,
+                "Attention bias requires mfa_forward_encode_bias");
+
+    // Validate inputs
+    TORCH_CHECK(query.dim() == 4, "Query must be 4D (B, H, N, D)");
+    TORCH_CHECK(key.dim() == 4, "Key must be 4D (B, H, N, D)");
+    TORCH_CHECK(value.dim() == 4, "Value must be 4D (B, H, N, D)");
+    TORCH_CHECK(query.device().is_mps(), "Query must be on MPS device");
+    TORCH_CHECK(key.device().is_mps(), "Key must be on MPS device");
+    TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");
+    TORCH_CHECK(attn_bias.device().is_mps(), "Attention bias must be on MPS device");
+
+    const int64_t batch_size = query.size(0);
+    const int64_t num_heads = query.size(1);
+    const int64_t seq_len_q = query.size(2);
+    const int64_t head_dim = query.size(3);
+    const int64_t seq_len_kv = key.size(2);
+
+    // Determine bias strides for broadcasting
+    // Bias can be: (B, H, N_q, N_kv), (1, H, N_q, N_kv), or (H, N_q, N_kv)
+    uint32_t bias_batch_stride = 0;
+    uint32_t bias_head_stride = static_cast<uint32_t>(seq_len_q * seq_len_kv);
+
+    if (attn_bias.dim() == 4) {
+        if (attn_bias.size(0) > 1) {
+            bias_batch_stride = static_cast<uint32_t>(attn_bias.size(1) * seq_len_q * seq_len_kv);
+        }
+        if (attn_bias.size(1) == 1) {
+            bias_head_stride = 0;  // Broadcast across heads
+        }
+    } else if (attn_bias.dim() == 3) {
+        // (H, N_q, N_kv) - broadcast across batch
+        bias_batch_stride = 0;
+        if (attn_bias.size(0) == 1) {
+            bias_head_stride = 0;
+        }
+    }
+
+    // Determine precision
+    bool is_bfloat16 = (query.scalar_type() == at::kBFloat16);
+    bool is_fp16 = (query.scalar_type() == at::kHalf);
+    bool use_bf16_kernel = is_bfloat16 && g_mfa_create_kernel_v2;
+    bool low_precision = is_fp16;
+    bool low_precision_outputs = is_fp16 || use_bf16_kernel;
+
+    // Make inputs contiguous
+    auto q = query.contiguous();
+    auto k = key.contiguous();
+    auto v = value.contiguous();
+    auto bias = attn_bias.contiguous();
+
+    // For BF16 without native kernel, convert to FP32
+    if (is_bfloat16 && !use_bf16_kernel) {
+        q = q.to(at::kFloat);
+        k = k.to(at::kFloat);
+        v = v.to(at::kFloat);
+        bias = bias.to(at::kFloat);
+    }
+
+    // Allocate output
+    at::Tensor output;
+    if (use_bf16_kernel) {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kBFloat16));
+    } else if (low_precision_outputs) {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kHalf));
+    } else {
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kFloat));
+    }
+
+    // Allocate logsumexp (always fp32)
+    auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
+                               query.options().dtype(at::kFloat));
+
+    // Get or create kernel with bias support
+    void* kernel = get_or_create_kernel(
+        seq_len_q, seq_len_kv, head_dim,
+        low_precision, low_precision_outputs, is_causal, false,  // has_mask = false
+        use_bf16_kernel,
+        static_cast<uint32_t>(window_size > 0 ? window_size : 0),
+        0,      // no quantization
+        false,  // bf16_backward
+        true,   // has_attn_bias
+        bias_batch_stride,
+        bias_head_stride,
+        static_cast<uint32_t>(bias_repeat_count > 0 ? bias_repeat_count : 0)
+    );
+
+    // Get Metal buffers
+    auto q_info = getBufferInfo(q);
+    auto k_info = getBufferInfo(k);
+    auto v_info = getBufferInfo(v);
+    auto o_info = getBufferInfo(output);
+    auto l_info = getBufferInfo(logsumexp);
+    auto bias_info = getBufferInfo(bias);
+
+    // Execute using bias forward encode
+    @autoreleasepool {
+        auto stream = at::mps::getCurrentMPSStream();
+        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+        bool success = g_mfa_forward_encode_bias(
+            kernel,
+            (__bridge void*)encoder,
+            (__bridge void*)q_info.buffer,
+            (__bridge void*)k_info.buffer,
+            (__bridge void*)v_info.buffer,
+            (__bridge void*)o_info.buffer,
+            (__bridge void*)l_info.buffer,
+            nullptr,  // no boolean mask
+            (__bridge void*)bias_info.buffer,
+            q_info.byte_offset,
+            k_info.byte_offset,
+            v_info.byte_offset,
+            o_info.byte_offset,
+            l_info.byte_offset,
+            0,  // mask_offset
+            bias_info.byte_offset,
+            static_cast<int32_t>(batch_size),
+            static_cast<int32_t>(num_heads)
+        );
+
+        if (!success) {
+            throw std::runtime_error("MFA forward with bias failed");
+        }
+    }
+
+    // Convert output back to BF16 if needed
+    if (is_bfloat16 && !use_bf16_kernel) {
+        output = output.to(at::kBFloat16);
+    }
+
+    return output;
+}
+
 // ============================================================================
 // Quantized Flash Attention Forward (FP8, INT8, NF4)
 // ============================================================================
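The stride computation in this hunk reduces the supported bias shapes to two element strides plus an optional repeat count that get baked into the kernel. As a reading aid only, the Python restatement below spells out which bias slice those parameters are meant to select for a given (batch, head) pair; the actual lookup happens inside the Metal kernel via `mfa_forward_encode_bias`, so treat this as documentation of the intent, not package code.

```python
def bias_slice_for(batch_idx, head_idx, bias, bias_repeat_count=0):
    """Illustrative: the (N_q, N_kv) bias slice addressed for one batch/head."""
    if bias.dim() == 3:                        # (H, N_q, N_kv): shared across the batch
        return bias[0 if bias.shape[0] == 1 else head_idx]
    # 4D: (B, H, N_q, N_kv) or (1, H, N_q, N_kv)
    if bias.shape[0] == 1:                     # bias_batch_stride == 0
        b = 0
    elif bias_repeat_count > 0:                # repeating pattern, e.g. Swin windows
        b = batch_idx % bias_repeat_count
    else:                                      # full per-batch bias
        b = batch_idx
    h = 0 if bias.shape[1] == 1 else head_idx  # bias_head_stride == 0 broadcasts over heads
    return bias[b, h]
```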
@@ -550,6 +772,9 @@ at::Tensor mps_flash_attention_forward_quantized(
     TORCH_CHECK(v_scale.scalar_type() == at::kFloat,
                 "V scale must be float32");
 
+    // For NF4, K/V have packed head dimension (D//2) since 2 values per byte
+    bool is_nf4 = (quant_type == static_cast<int64_t>(QuantizationType::NF4));
+
     const int64_t batch_size = query.size(0);
     const int64_t num_heads_q = query.size(1);
     const int64_t num_heads_kv = key.size(1);
@@ -557,10 +782,13 @@ at::Tensor mps_flash_attention_forward_quantized(
     const int64_t head_dim = query.size(3);
     const int64_t seq_len_kv = key.size(2);
 
+    // For NF4, expected K/V head dim is D//2 (packed)
+    int64_t expected_kv_head_dim = is_nf4 ? (head_dim / 2) : head_dim;
+
     TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
                 "Batch size mismatch");
-    TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
-                "Head dimension mismatch");
+    TORCH_CHECK(key.size(3) == expected_kv_head_dim && value.size(3) == expected_kv_head_dim,
+                is_nf4 ? "Head dimension mismatch for NF4 (expected D//2)" : "Head dimension mismatch");
     TORCH_CHECK(key.size(1) == value.size(1),
                 "K and V must have same number of heads");
 
@@ -970,6 +1198,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("window_size") = 0,
           py::arg("bf16_backward") = false);
 
+    // Forward with additive attention bias (e.g., relative position encoding)
+    m.def("forward_with_bias", &mps_flash_attention_forward_with_bias,
+          "Flash Attention forward with additive attention bias",
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("attn_bias"),
+          py::arg("is_causal") = false,
+          py::arg("window_size") = 0,
+          py::arg("bias_repeat_count") = 0);
+
     // Quantized attention (forward only - no gradients through quantized weights)
     m.def("forward_quantized", &mps_flash_attention_forward_quantized,
           "Quantized Flash Attention forward (FP8/INT8/NF4 K/V)",

{mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/mps_flash_attn.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.2.1
+Version: 0.2.5
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
@@ -201,6 +201,13 @@ Python API (mps_flash_attn)
 - Python 3.10+
 - PyTorch 2.0+
 
+## TODO / Future Optimizations
+
+- [ ] **Batched kernel dispatch** - Currently dispatches B×H separate kernels per attention call. Should use 3D grid to handle all batch/heads in one dispatch (major perf win for small sequences like Swin Transformer windows)
+- [ ] **Fused QKV projection + attention** - Single kernel from input to output, avoid intermediate buffers
+- [ ] **Pre-scaled bias option** - Allow passing pre-scaled bias to avoid per-call scaling overhead
+- [ ] **LoRA fusion** - Fuse adapter weights into attention computation
+
 ## Credits
 
 - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner

{mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mps-flash-attn"
-version = "0.2.1"
+version = "0.2.5"
 description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
 readme = "README.md"
 license = "MIT"
{mps_flash_attn-0.2.1 → mps_flash_attn-0.2.5}/tests/test_mfa_v2.py

@@ -342,6 +342,97 @@ class TestQuantized:
         assert output.shape == (B, H, N, D)
         assert not torch.isnan(output).any()
 
+    def test_quantize_nf4(self, mfa):
+        """Test NF4 quantization helper.
+
+        NF4 packs 2 values per byte along head dimension, so output shape is (B,H,N,D//2).
+        """
+        B, H, N, D = 1, 4, 128, 64
+        dtype = torch.float16
+
+        k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+        k_q, v_q, k_s, v_s = mfa.quantize_kv_nf4(k, v)
+
+        assert k_q.dtype == torch.uint8
+        assert v_q.dtype == torch.uint8
+        assert k_s.dtype == torch.float32
+        assert v_s.dtype == torch.float32
+        # NF4 packs 2 values per byte, so D dimension is halved
+        assert k_q.shape == (B, H, N, D // 2)
+        assert v_q.shape == (B, H, N, D // 2)
+
+    def test_flash_attention_nf4(self, mfa):
+        """Test NF4 quantized attention forward.
+
+        NF4 uses a 16-value codebook for 4-bit quantization, packing 2 values per byte.
+        This provides 4x memory reduction for K/V cache with acceptable accuracy loss.
+        """
+        B, H, N, D = 2, 8, 128, 64
+        dtype = torch.float16
+
+        torch.manual_seed(42)
+        q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+        k_q, v_q, k_s, v_s = mfa.quantize_kv_nf4(k, v)
+        output = mfa.flash_attention_nf4(q, k_q, v_q, k_s, v_s)
+
+        assert output.shape == (B, H, N, D)
+        assert not torch.isnan(output).any()
+        assert not torch.isinf(output).any()
+
+    def test_flash_attention_nf4_correctness(self, mfa):
+        """Test NF4 attention correctness against reference.
+
+        NF4 is 4-bit so we expect larger error than FP8/INT8, but output
+        should still be in a reasonable range (max diff < 0.5).
+        """
+        B, H, N, D = 1, 4, 64, 64
+        dtype = torch.float16
+
+        torch.manual_seed(42)
+        q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+        # Reference
+        ref = reference_attention(q, k, v)
+
+        # NF4
+        k_q, v_q, k_s, v_s = mfa.quantize_kv_nf4(k, v)
+        output = mfa.flash_attention_nf4(q, k_q, v_q, k_s, v_s)
+
+        max_diff = (ref - output).abs().max().item()
+        mean_diff = (ref - output).abs().mean().item()
+
+        # 4-bit quantization has larger error, but should be bounded
+        assert max_diff < 0.5, f"NF4 max diff {max_diff} exceeds threshold 0.5"
+        assert mean_diff < 0.1, f"NF4 mean diff {mean_diff} exceeds threshold 0.1"
+
+    @pytest.mark.parametrize("config", [
+        (1, 1, 32, 32),     # Small
+        (1, 8, 128, 64),    # Medium
+        (2, 16, 256, 64),   # Large
+        (1, 8, 512, 128),   # Large head dim
+    ])
+    def test_flash_attention_nf4_various_sizes(self, mfa, config):
+        """Test NF4 attention with various tensor sizes."""
+        B, H, N, D = config
+        dtype = torch.float16
+
+        q = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        k = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+        v = torch.randn(B, H, N, D, device='mps', dtype=dtype)
+
+        k_q, v_q, k_s, v_s = mfa.quantize_kv_nf4(k, v)
+        output = mfa.flash_attention_nf4(q, k_q, v_q, k_s, v_s)
+
+        assert output.shape == (B, H, N, D)
+        assert not torch.isnan(output).any()
+
 
 # =============================================================================
 # Chunked/Streaming Attention Tests
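The new tests exercise the NF4 path but not the new bias path. A companion test in the same style is sketched below; it is not part of the released test file, and the `mfa` fixture plus the fp16 tolerance are assumptions. It compares `flash_attention_with_bias` against PyTorch SDPA using the sqrt(D) pre-scaling rule from the docstring.

```python
import math
import torch
import torch.nn.functional as F

class TestAttentionBias:
    def test_flash_attention_with_bias_matches_sdpa(self, mfa):
        B, H, N, D = 2, 4, 64, 64
        torch.manual_seed(0)
        q = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
        k = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
        v = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
        bias = torch.randn(1, H, N, N, device='mps', dtype=torch.float16)

        # SDPA adds attn_mask to already-scaled scores; MFA adds the bias
        # before scaling, so express the same bias by multiplying by sqrt(D).
        out = mfa.flash_attention_with_bias(q, k, v, bias * math.sqrt(D))
        ref = F.scaled_dot_product_attention(
            q.float(), k.float(), v.float(), attn_mask=bias.float()
        )

        assert out.shape == (B, H, N, D)
        max_diff = (out.float() - ref).abs().max().item()
        assert max_diff < 2e-2, f"bias attention max diff {max_diff}"  # tolerance is a guess
```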