mps-flash-attn 0.2.0.tar.gz → 0.2.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mps-flash-attn might be problematic.
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/PKG-INFO +1 -1
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/__init__.py +3 -3
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/pyproject.toml +1 -1
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/LICENSE +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/README.md +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/setup.cfg +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/setup.py +0 -0
- {mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/tests/test_mfa_v2.py +0 -0
{mps_flash_attn-0.2.0 → mps_flash_attn-0.2.1}/mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.2.0"
+__version__ = "0.2.1"
 
 import torch
 from typing import Optional
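For a quick sanity check of the upgrade, the attribute bumped above can be read back at import time. A minimal sketch, assuming the sdist installs an importable mps_flash_attn package, as its __init__.py above indicates:

import mps_flash_attn

# __version__ is set at the top of __init__.py, per the hunk above.
print(mps_flash_attn.__version__)  # "0.2.1" after upgrading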
@@ -215,7 +215,7 @@ def replace_sdpa():
     original_sdpa = F.scaled_dot_product_attention
 
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
-                     is_causal=False, scale=None):
+                     is_causal=False, scale=None, enable_gqa=False, **kwargs):
         # Use MFA for MPS tensors without dropout
         # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
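The signature change matters because newer PyTorch releases (2.5 and later) added an enable_gqa keyword to F.scaled_dot_product_attention, and callers may pass it; a wrapper with the old 0.2.0 signature would reject the keyword with a TypeError. Below is a minimal sketch of the same forwarding pattern, not the package's actual implementation, assuming a PyTorch build whose scaled_dot_product_attention already accepts enable_gqa:

import torch.nn.functional as F

# Keep a reference to the real implementation before patching.
_original_sdpa = F.scaled_dot_product_attention

def _patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                  is_causal=False, scale=None, enable_gqa=False, **kwargs):
    # Accept the newer keyword (and any future ones) instead of raising
    # TypeError, then hand everything on to the original implementation.
    return _original_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal,
                          scale=scale, enable_gqa=enable_gqa, **kwargs)

F.scaled_dot_product_attention = _patched_sdpa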
@@ -247,7 +247,7 @@ def replace_sdpa():
             # Fall back to original on any error
             pass
 
-        return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale)
+        return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)
 
     F.scaled_dot_product_attention = patched_sdpa
     print("MPS Flash Attention: Patched F.scaled_dot_product_attention")
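With both changes in place, an attention call that passes enable_gqa no longer trips over the wrapper. A usage sketch, assuming PyTorch 2.5+ on an Apple-silicon MPS device; the tensor shapes are illustrative, and whether the Metal kernels or the fallback path serves the call is decided inside patched_sdpa as shown above:

import torch
import torch.nn.functional as F
import mps_flash_attn

mps_flash_attn.replace_sdpa()  # installs patched_sdpa over F.scaled_dot_product_attention

# Grouped-query attention: 32 query heads sharing 8 key/value heads.
q = torch.randn(1, 32, 2048, 64, device="mps", dtype=torch.float16)
k = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)
v = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)

# Under 0.2.0 the extra keyword hit the narrower wrapper signature and raised
# TypeError; under 0.2.1 it is accepted and forwarded.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([1, 32, 2048, 64])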
All other files listed above are unchanged between 0.2.0 and 0.2.1.