mps-flash-attn 0.3.0__tar.gz → 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mps-flash-attn might be problematic. Click here for more details.
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/PKG-INFO +1 -1
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/__init__.py +9 -3
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/pyproject.toml +1 -1
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/LICENSE +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/README.md +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/setup.cfg +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/setup.py +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/tests/test_issues.py +0 -0
- {mps_flash_attn-0.3.0 → mps_flash_attn-0.3.1}/tests/test_mfa_v2.py +0 -0
|
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
|
|
|
4
4
|
This package provides memory-efficient attention using Metal Flash Attention kernels.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
-
__version__ = "0.3.0"
|
|
7
|
+
__version__ = "0.3.1"
|
|
8
8
|
|
|
9
9
|
__all__ = [
|
|
10
10
|
# Core functions
|
|
@@ -255,10 +255,16 @@ def flash_attention(
|
|
|
255
255
|
if attn_mask.dim() != 4:
|
|
256
256
|
raise ValueError(f"attn_mask must be 4D (B, H, N_q, N_kv), got {attn_mask.dim()}D")
|
|
257
257
|
mb, mh, mq, mk = attn_mask.shape
|
|
258
|
-
|
|
258
|
+
# Allow broadcast: mq can be 1 (applies same mask to all query positions) or N_q
|
|
259
|
+
if (mq != 1 and mq != N_q) or (mk != 1 and mk != N_kv):
|
|
259
260
|
raise ValueError(
|
|
260
|
-
f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv})"
|
|
261
|
+
f"attn_mask shape mismatch: mask is ({mq}, {mk}) but expected ({N_q}, {N_kv}) or broadcastable (1, {N_kv})"
|
|
261
262
|
)
|
|
263
|
+
# Expand broadcast mask to full shape for Metal kernel
|
|
264
|
+
if mq == 1 and N_q > 1:
|
|
265
|
+
attn_mask = attn_mask.expand(mb, mh, N_q, mk)
|
|
266
|
+
if mk == 1 and N_kv > 1:
|
|
267
|
+
attn_mask = attn_mask.expand(mb, mh, mq if mq > 1 else N_q, N_kv)
|
|
262
268
|
if mb != 1 and mb != B:
|
|
263
269
|
raise ValueError(
|
|
264
270
|
f"attn_mask batch size must be 1 or {B}, got {mb}"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|