mps-flash-attn 0.1.5.tar.gz → 0.1.6.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mps-flash-attn might be problematic.
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/PKG-INFO +1 -1
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/__init__.py +5 -2
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/pyproject.toml +1 -1
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/LICENSE +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/README.md +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/setup.cfg +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/setup.py +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/tests/test_attention.py +0 -0
{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.6}/mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.1.5"
+__version__ = "0.1.6"
 
 import torch
 from typing import Optional

@@ -174,10 +174,13 @@ def replace_sdpa():
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None):
         # Use MFA for MPS tensors without mask/dropout
+        # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+        # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
         if (query.device.type == 'mps' and
             attn_mask is None and
             dropout_p == 0.0 and
-            _HAS_MFA):
+            _HAS_MFA and
+            query.shape[2] >= 1024):
             try:
                 return flash_attention(query, key, value, is_causal=is_causal, scale=scale)
             except Exception:
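The functional change in 0.1.6 is the second hunk: patched_sdpa now routes to the Metal Flash Attention kernels only when the query's sequence length (dimension 2 of a (batch, heads, seq_len, head_dim) tensor) is at least 1024; shorter sequences fall through to PyTorch's stock backend. The sketch below is a hedged illustration of how that dispatch behaves, not part of the package's documentation: it assumes replace_sdpa() is importable from the package top level and monkey-patches torch.nn.functional.scaled_dot_product_attention, as the hunk suggests, and the tensor shapes are illustrative only.

import torch
import torch.nn.functional as F
from mps_flash_attn import replace_sdpa  # defined in __init__.py per the diff

# Patch F.scaled_dot_product_attention with patched_sdpa (assumed behavior).
replace_sdpa()

device = "mps" if torch.backends.mps.is_available() else "cpu"

# seq_len = 2048 (>= 1024), no attn_mask, no dropout: on MPS the patched SDPA
# tries the flash_attention path first and falls back on any exception.
q = torch.randn(1, 8, 2048, 64, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)
out_long = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# seq_len = 256 (< 1024): as of 0.1.6 this skips MFA and uses PyTorch's own
# matmul+softmax path, which the added comments say is faster for short sequences.
q_short = torch.randn(1, 8, 256, 64, device=device)
out_short = F.scaled_dot_product_attention(q_short, torch.randn_like(q_short),
                                           torch.randn_like(q_short))

In 0.1.5 the only gate was device type, mask, dropout, and _HAS_MFA, so even very short MPS sequences paid the kernel-dispatch overhead; the new query.shape[2] >= 1024 check keeps the fast path only where the Metal kernels actually win.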