mps-flash-attn 0.1.13.tar.gz → 0.1.14.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/PKG-INFO +1 -1
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/__init__.py +4 -3
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/pyproject.toml +1 -1
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/LICENSE +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/README.md +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/setup.cfg +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/setup.py +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/tests/test_attention.py +0 -0
{mps_flash_attn-0.1.13 → mps_flash_attn-0.1.14}/mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.1.13"
+__version__ = "0.1.14"
 
 import torch
 from typing import Optional
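For orientation, here is a minimal usage sketch of how the patched attention path might be exercised. The name `replace_sdpa()` comes from the hunk header further below; that it monkey-patches `torch.nn.functional.scaled_dot_product_attention`, and the `[batch, heads, seq, head_dim]` tensor layout, are assumptions made for illustration rather than behaviour confirmed by this diff.

```python
# Minimal usage sketch (hedged): replace_sdpa() is defined in mps_flash_attn/__init__.py
# per the hunk below, but exactly what it patches is assumed here, not confirmed by the diff.
import torch
import torch.nn.functional as F

import mps_flash_attn

# Assumption: replace_sdpa() routes scaled_dot_product_attention calls on MPS tensors
# through the Metal Flash Attention kernels when the heuristic in the next hunk applies.
mps_flash_attn.replace_sdpa()

# Shapes mirror the benchmark comment in the hunk: heads=30, head_dim=128.
# seq_len (query.shape[2]) is 2048 >= 1536 and dropout_p defaults to 0.0,
# so this call should take the MFA path under the new threshold.
# (bfloat16 on MPS needs a recent PyTorch/macOS; use float16 otherwise.)
q = torch.randn(1, 30, 2048, 128, device="mps", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape, out.dtype)
```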
@@ -200,12 +200,13 @@ def replace_sdpa():
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None):
         # Use MFA for MPS tensors without dropout
-        # Only use MFA for seq_len >=
+        # Only use MFA for seq_len >= 1536 where it outperforms PyTorch's math backend
         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
+        # Benchmark (BF16, heads=30, head_dim=128): crossover is ~1200-1500
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.shape[2] >=
+            query.shape[2] >= 1536):
             try:
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
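The comments in the hunk above mention converting an additive float mask into a boolean mask before dispatching to the MFA kernel. The package's actual conversion code is not visible in this diff; the helper below is a hedged sketch of one plausible rule, following the "0 = attend, -inf = mask" convention the comment describes, with an assumed helper name.

```python
import torch

def additive_to_bool_mask(attn_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convert a PyTorch SDPA additive mask (0 = attend,
    -inf = mask out) into a boolean mask where True means "attend"."""
    if attn_mask.dtype == torch.bool:
        # Already boolean; nothing to convert.
        return attn_mask
    # Additive masks carry 0 at positions to keep and -inf at positions to drop.
    return ~torch.isneginf(attn_mask)

# Example: a causal additive mask for seq_len 4 and its boolean equivalent.
additive = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
print(additive)
print(additive_to_bool_mask(additive))
```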