mps-flash-attn 0.2.6.tar.gz → 0.2.8.tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in a supported public registry. It is provided for informational purposes only.

Note: this release of mps-flash-attn has been flagged as potentially problematic.
Files changed (39)
  1. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/PKG-INFO +1 -1
  2. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/__init__.py +17 -8
  3. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn.egg-info/PKG-INFO +1 -1
  4. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/pyproject.toml +1 -1
  5. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/LICENSE +0 -0
  6. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/README.md +0 -0
  7. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/benchmark.py +0 -0
  8. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
  9. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  10. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  11. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  12. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  13. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  14. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  15. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  16. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  17. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  18. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  19. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  20. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  21. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  22. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  23. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  24. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  25. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  26. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  27. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  28. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  29. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  30. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  31. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/manifest.json +0 -0
  32. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  33. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
  34. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  35. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn.egg-info/requires.txt +0 -0
  36. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/mps_flash_attn.egg-info/top_level.txt +0 -0
  37. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/setup.cfg +0 -0
  38. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/setup.py +0 -0
  39. {mps_flash_attn-0.2.6 → mps_flash_attn-0.2.8}/tests/test_mfa_v2.py +0 -0
--- mps_flash_attn-0.2.6/PKG-INFO
+++ mps_flash_attn-0.2.8/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.2.6
+Version: 0.2.8
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
--- mps_flash_attn-0.2.6/mps_flash_attn/__init__.py
+++ mps_flash_attn-0.2.8/mps_flash_attn/__init__.py
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.2.5"
+__version__ = "0.2.8"
 
 import torch
 from typing import Optional
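
Worth noting: in 0.2.6 the package metadata said 0.2.6 while __init__.py still carried __version__ = "0.2.5"; 0.2.8 brings the two back in sync. A minimal sanity check after upgrading (it assumes only that the package is installed):

    import importlib.metadata

    import mps_flash_attn

    # Both should print 0.2.8; in 0.2.6 they disagreed (0.2.6 vs 0.2.5).
    print(mps_flash_attn.__version__)
    print(importlib.metadata.version("mps-flash-attn"))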
--- mps_flash_attn-0.2.6/mps_flash_attn/__init__.py
+++ mps_flash_attn-0.2.8/mps_flash_attn/__init__.py
@@ -217,18 +217,27 @@ def replace_sdpa():
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None, enable_gqa=False, **kwargs):
         # Use MFA for MPS tensors without dropout
-        # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+        # Only use MFA for seq_len >= 512 where it outperforms PyTorch's math backend
         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
         # Benchmark results (B=1-4, H=8, D=64-128, fp16/bf16):
-        #   seq=512:  0.3-0.5x (MFA slower)
-        #   seq=1024: 1.1-2.0x (MFA faster)
-        #   seq=2048: 1.7-3.7x (MFA much faster)
-        #   seq=4096: 2.0-3.9x (MFA much faster)
+        #   seq=512:  1.2-1.6x (MFA faster)
+        #   seq=1024: 2.3-3.7x (MFA much faster)
+        #   seq=2048: 2.2-3.9x (MFA much faster)
+        #   seq=4096: 2.1-3.7x (MFA much faster)
         if (query.device.type == 'mps' and
                 dropout_p == 0.0 and
                 _HAS_MFA and
-                query.shape[2] >= 1024):
+                query.shape[2] >= 512):
             try:
+                # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
+                # Common in Llama 2/3, Mistral, Qwen, etc.
+                k, v = key, value
+                if enable_gqa and query.shape[1] != key.shape[1]:
+                    # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
+                    n_rep = query.shape[1] // key.shape[1]
+                    k = key.repeat_interleave(n_rep, dim=1)
+                    v = value.repeat_interleave(n_rep, dim=1)
+
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                 # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
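
The new GQA branch replicates each key/value head until the head count matches the query tensor, so grouped-query models work without any kernel changes. A minimal sketch of the same expansion on hypothetical Llama-style shapes (32 query heads sharing 8 KV heads; the shapes are illustrative, not taken from the package):

    import torch

    B, q_heads, kv_heads, S, D = 1, 32, 8, 1024, 64  # hypothetical shapes
    q = torch.randn(B, q_heads, S, D)
    k = torch.randn(B, kv_heads, S, D)
    v = torch.randn(B, kv_heads, S, D)

    # Same expansion as the patch: (B, kv_heads, S, D) -> (B, q_heads, S, D)
    n_rep = q.shape[1] // k.shape[1]  # 32 // 8 = 4
    k_expanded = k.repeat_interleave(n_rep, dim=1)
    v_expanded = v.repeat_interleave(n_rep, dim=1)

    assert k_expanded.shape == (B, q_heads, S, D)
    assert v_expanded.shape == (B, q_heads, S, D)

Note that repeat_interleave materializes the copies, so peak K/V memory grows by the repetition factor; that is the trade-off for supporting GQA without touching the Metal kernels.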
--- mps_flash_attn-0.2.6/mps_flash_attn/__init__.py
+++ mps_flash_attn-0.2.8/mps_flash_attn/__init__.py
@@ -242,7 +251,7 @@ def replace_sdpa():
                 # Convert: positions with large negative values -> True (masked)
                 # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                 mfa_mask = attn_mask <= -1e3
-                return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+                return flash_attention(query, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
             except Exception:
                 # Fall back to original on any error
                 pass
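
The mask conversion that precedes this call can be checked in isolation. A small self-contained sketch (shapes hypothetical) that builds a causal additive mask in the convention PyTorch's SDPA expects, then applies the same -1e3 threshold as the patch:

    import torch

    S = 4
    # Additive mask in SDPA convention: 0.0 = attend, -inf = masked.
    causal = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
    additive = torch.zeros(S, S).masked_fill(causal, float("-inf"))

    # Same conversion as the patch: large negative values -> True (masked).
    # The -1e3 threshold also catches finite stand-ins like -1000 or -10000.
    mfa_mask = additive <= -1e3
    assert torch.equal(mfa_mask, causal)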
--- mps_flash_attn-0.2.6/mps_flash_attn.egg-info/PKG-INFO
+++ mps_flash_attn-0.2.8/mps_flash_attn.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.2.6
+Version: 0.2.8
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
--- mps_flash_attn-0.2.6/pyproject.toml
+++ mps_flash_attn-0.2.8/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mps-flash-attn"
-version = "0.2.6"
+version = "0.2.8"
 description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
 readme = "README.md"
 license = "MIT"
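
For downstream projects that want the lower 512-token threshold and the GQA handling, the natural move is to raise the minimum version. The constraint below is illustrative, not taken from the package's own docs:

    # requirements.txt
    mps-flash-attn>=0.2.8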