mps-flash-attn 0.2.5__tar.gz → 0.2.7__tar.gz
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/PKG-INFO +1 -1
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/__init__.py +7 -7
- mps_flash_attn-0.2.7/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/pyproject.toml +1 -1
- mps_flash_attn-0.2.5/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/LICENSE +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/README.md +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/setup.cfg +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/setup.py +0 -0
- {mps_flash_attn-0.2.5 → mps_flash_attn-0.2.7}/tests/test_mfa_v2.py +0 -0
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """

-__version__ = "0.2.5"
+__version__ = "0.2.7"

 import torch
 from typing import Optional
@@ -217,17 +217,17 @@ def replace_sdpa():
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None, enable_gqa=False, **kwargs):
         # Use MFA for MPS tensors without dropout
-        # Only use MFA for seq_len >=
+        # Only use MFA for seq_len >= 512 where it outperforms PyTorch's math backend
         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
         # Benchmark results (B=1-4, H=8, D=64-128, fp16/bf16):
-        # seq=512:
-        # seq=1024:
-        # seq=2048:
-        # seq=4096: 2.
+        # seq=512: 1.2-1.6x (MFA faster)
+        # seq=1024: 2.3-3.7x (MFA much faster)
+        # seq=2048: 2.2-3.9x (MFA much faster)
+        # seq=4096: 2.1-3.7x (MFA much faster)
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.shape[2] >=
+            query.shape[2] >= 512):
             try:
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
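The last two comment lines of this hunk describe converting PyTorch's additive attention mask into a boolean one before the tensors reach the Metal kernels. Below is a minimal sketch of that conversion, assuming the additive mask follows SDPA's convention (0.0 = attend, -inf = masked out); the helper name is hypothetical and the actual handling inside mps_flash_attn may differ.

import torch

def additive_to_bool_mask(attn_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper for illustration only: in PyTorch SDPA's additive
    # convention, 0.0 means "attend" and -inf means "mask out"; a boolean
    # mask marks attendable positions with True.
    return attn_mask == 0

additive = torch.tensor([[0.0, float("-inf")],
                         [0.0, 0.0]])
print(additive_to_bool_mask(additive))
# tensor([[ True, False],
#         [ True,  True]])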
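For context, here is a minimal usage sketch of the code path this hunk changes. Only replace_sdpa(), which appears in the hunk header, is taken from the package itself; the assumption that it patches torch.nn.functional.scaled_dot_product_attention, the tensor shapes, and the MPS device availability are illustrative.

import torch
import torch.nn.functional as F
import mps_flash_attn

# Presumably monkey-patches torch's scaled_dot_product_attention with the
# MFA-aware wrapper shown in the diff above.
mps_flash_attn.replace_sdpa()

# (batch, heads, seq_len, head_dim) tensors on the MPS device, no dropout.
q = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# seq_len (query.shape[2]) is 1024 >= 512 and dropout_p == 0.0, so the patched
# function routes to the Metal Flash Attention kernels; shorter sequences or
# dropout fall back to PyTorch's built-in SDPA.
out = F.scaled_dot_product_attention(q, k, v)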