mps-flash-attn 0.3.0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mps-flash-attn has been flagged as possibly problematic.

Files changed (40)
  1. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/PKG-INFO +1 -1
  2. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/__init__.py +9 -3
  3. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn.egg-info/PKG-INFO +1 -1
  4. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/pyproject.toml +1 -1
  5. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/LICENSE +0 -0
  6. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/README.md +0 -0
  7. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/benchmark.py +0 -0
  8. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
  9. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  10. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  11. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  12. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  13. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  14. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  15. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  16. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  17. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  18. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  19. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  20. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  21. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  22. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  23. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  24. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  25. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  26. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  27. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  28. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  29. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  30. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  31. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/manifest.json +0 -0
  32. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  33. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
  34. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  35. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn.egg-info/requires.txt +0 -0
  36. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn.egg-info/top_level.txt +0 -0
  37. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/setup.cfg +0 -0
  38. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/setup.py +0 -0
  39. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/tests/test_issues.py +0 -0
  40. {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/tests/test_mfa_v2.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.3.0
+ Version: 0.3.1
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
mps_flash_attn/__init__.py
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
  This package provides memory-efficient attention using Metal Flash Attention kernels.
  """

- __version__ = "0.3.0"
+ __version__ = "0.3.1"

  __all__ = [
      # Core functions
@@ -255,10 +255,16 @@ def flash_attention(
      if attn_mask.dim() != 4:
          raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
      mb, mh, mq, mk = attn_mask.shape
-     if mq != N_q or mk != N_kv:
+     # Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
+     if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
          raise ValueError(
-             f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv})"
+             f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
          )
+     # Expand broadcast mask to full shape for Metal kernel
+     if mq == 1 and N_q > 1:
+         attn_mask = attn_mask.expand(mb, mh, N_q, mk)
+     if mk == 1 and N_kv > 1:
+         attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
      if mb != 1 and mb != B:
          raise ValueError(
              f"attn_mask batch size must be 1 or {B}, got {mb}"
mps_flash_attn.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.3.0
+ Version: 0.3.1
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "mps-flash-attn"
- version = "0.3.0"
+ version = "0.3.1"
  description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
  readme = "README.md"
  license = "MIT"