mps-flash-attn 0.2.7__tar.gz → 0.2.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of mps-flash-attn might be problematic.
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/PKG-INFO +1 -1
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/__init__.py +11 -2
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/pyproject.toml +1 -1
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/LICENSE +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/README.md +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/setup.cfg +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/setup.py +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/tests/test_mfa_v2.py +0 -0
{mps_flash_attn-0.2.7 → mps_flash_attn-0.2.8}/mps_flash_attn/__init__.py (+11 -2)

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """

-__version__ = "0.2.7"
+__version__ = "0.2.8"

 import torch
 from typing import Optional
@@ -229,6 +229,15 @@ def replace_sdpa():
                 _HAS_MFA and
                 query.shape[2] >= 512):
             try:
+                # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
+                # Common in Llama 2/3, Mistral, Qwen, etc.
+                k, v = key, value
+                if enable_gqa and query.shape[1] != key.shape[1]:
+                    # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
+                    n_rep = query.shape[1] // key.shape[1]
+                    k = key.repeat_interleave(n_rep, dim=1)
+                    v = value.repeat_interleave(n_rep, dim=1)
+
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                 # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
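For context, the expansion added above is the standard way to reduce grouped-query attention to plain multi-head attention: each of the kv_heads key/value heads is repeated q_heads // kv_heads times along the head dimension. Below is a minimal standalone sketch of that idea in plain PyTorch, not a call into mps-flash-attn; the tensor shapes are illustrative, and the enable_gqa reference path requires a recent PyTorch (roughly 2.5+).

import torch
import torch.nn.functional as F

# Illustrative GQA shapes: 8 query heads sharing 2 KV heads, Llama-style.
B, q_heads, kv_heads, S, D = 2, 8, 2, 16, 64
q = torch.randn(B, q_heads, S, D)
k = torch.randn(B, kv_heads, S, D)
v = torch.randn(B, kv_heads, S, D)

# Expand KV heads the same way the diff does: (B, kv_heads, S, D) -> (B, q_heads, S, D)
n_rep = q_heads // kv_heads
k_exp = k.repeat_interleave(n_rep, dim=1)
v_exp = v.repeat_interleave(n_rep, dim=1)

# Reference: SDPA's built-in GQA handling (enable_gqa is available in recent PyTorch).
ref = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
# Same attention computed on the explicitly expanded heads.
out = F.scaled_dot_product_attention(q, k_exp, v_exp)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

The equivalence check is the point: once K/V are expanded, an attention kernel that only understands equal head counts (such as the MFA path this package dispatches to) can serve GQA models unchanged.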
@@ -242,7 +251,7 @@ def replace_sdpa():
                 # Convert: positions with large negative values -> True (masked)
                 # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                 mfa_mask = attn_mask <= -1e3
-                return flash_attention(query,
+                return flash_attention(query, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
             except Exception:
                 # Fall back to original on any error
                 pass
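For readers unfamiliar with the two mask conventions involved, here is a minimal standalone sketch of the same conversion; the example mask values are made up, and only the <= -1e3 threshold comes from the diff above.

import torch

# Additive convention (PyTorch SDPA): 0.0 = attend, -inf or a large negative value = masked.
additive_mask = torch.tensor([[0.0, float("-inf"), -10000.0, 0.0]])

# Boolean convention described in the diff (MFA): True/non-zero = masked, False = attend.
# The -1e3 cutoff catches -1000, -10000 and -inf alike while leaving 0.0 untouched.
mfa_mask = additive_mask <= -1e3
print(mfa_mask)  # tensor([[False, True, True, False]])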