mps-flash-attn 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mps-flash-attn might be problematic.

Files changed (39)
  1. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/PKG-INFO +1 -1
  2. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/__init__.py +3 -3
  3. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn.egg-info/PKG-INFO +1 -1
  4. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/pyproject.toml +1 -1
  5. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/LICENSE +0 -0
  6. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/README.md +0 -0
  7. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/benchmark.py +0 -0
  8. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
  9. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  10. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  11. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  12. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  13. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  14. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  15. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  16. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  17. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  18. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  19. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  20. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  21. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  22. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  23. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  24. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  25. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  26. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  27. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  28. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  29. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  30. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  31. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/manifest.json +0 -0
  32. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  33. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
  34. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  35. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn.egg-info/requires.txt +0 -0
  36. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn.egg-info/top_level.txt +0 -0
  37. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/setup.cfg +0 -0
  38. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/setup.py +0 -0
  39. {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/tests/test_mfa_v2.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.2.0
+ Version: 0.2.1
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
mps_flash_attn/__init__.py
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
  This package provides memory-efficient attention using Metal Flash Attention kernels.
  """

- __version__ = "0.2.0"
+ __version__ = "0.2.1"

  import torch
  from typing import Optional
@@ -215,7 +215,7 @@ def replace_sdpa():
      original_sdpa = F.scaled_dot_product_attention

      def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
-                      is_causal=False, scale=None):
+                      is_causal=False, scale=None, enable_gqa=False, **kwargs):
          # Use MFA for MPS tensors without dropout
          # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
          # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
@@ -247,7 +247,7 @@ def replace_sdpa():
              # Fall back to original on any error
              pass

-         return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale)
+         return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)

      F.scaled_dot_product_attention = patched_sdpa
      print("MPS Flash Attention: Patched F.scaled_dot_product_attention")
mps_flash_attn.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.2.0
+ Version: 0.2.1
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "mps-flash-attn"
- version = "0.2.0"
+ version = "0.2.1"
  description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
  readme = "README.md"
  license = "MIT"