mps-flash-attn 0.2.7__tar.gz → 0.2.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mps-flash-attn might be problematic.
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/PKG-INFO +1 -1
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/__init__.py +62 -6
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/pyproject.toml +1 -1
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/LICENSE +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/README.md +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/setup.cfg +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/setup.py +0 -0
- {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/tests/test_mfa_v2.py +0 -0
mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """

-__version__ = "0.2.7"
+__version__ = "0.2.9"

 import torch
 from typing import Optional
@@ -213,6 +213,8 @@ def replace_sdpa():
     import torch.nn.functional as F

     original_sdpa = F.scaled_dot_product_attention
+    _debug = os.environ.get("MFA_DEBUG", "0") == "1"
+    _call_count = [0]  # mutable for closure

     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None, enable_gqa=False, **kwargs):
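The two additions above are evaluated once, when replace_sdpa() runs: _debug comes from the MFA_DEBUG environment variable, and _call_count is a one-element list so the nested patched_sdpa closure can update the counter without a nonlocal declaration. A minimal standalone sketch of the same pattern (the names here are illustrative, not part of the package):

import os

def make_logger():
    # Read the switch once at install time, like the MFA_DEBUG check above.
    debug = os.environ.get("MFA_DEBUG", "0") == "1"
    call_count = [0]  # one-element list: the closure mutates call_count[0] in place

    def log(msg):
        call_count[0] += 1
        if debug:
            print(f"[#{call_count[0]}] {msg}")

    return log

log = make_logger()
log("first call")
log("second call")  # numbered output appears only if MFA_DEBUG=1 was set beforehand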
@@ -224,29 +226,83 @@ def replace_sdpa():
         # seq=1024: 2.3-3.7x (MFA much faster)
         # seq=2048: 2.2-3.9x (MFA much faster)
         # seq=4096: 2.1-3.7x (MFA much faster)
+        # Determine seq_len based on tensor dimensionality
+        # 4D: (B, H, S, D) -> seq_len = shape[2]
+        # 3D: (B, S, D) -> seq_len = shape[1] (single-head attention, e.g., VAE)
+        is_3d = query.ndim == 3
+        seq_len = query.shape[1] if is_3d else query.shape[2]
+
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.
+            query.ndim >= 3 and
+            seq_len >= 512):
             try:
+                q, k, v = query, key, value
+
+                # Handle 3D tensors (B, S, D) - treat as single-head attention
+                # Unsqueeze to (B, 1, S, D) for MFA, squeeze back after
+                if is_3d:
+                    q = q.unsqueeze(1)  # (B, S, D) -> (B, 1, S, D)
+                    k = k.unsqueeze(1)
+                    v = v.unsqueeze(1)
+
+                # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
+                # Common in Llama 2/3, Mistral, Qwen, etc.
+                # NOTE: Always expand when heads mismatch, not just when enable_gqa=True
+                # Transformers may pass enable_gqa=True on MPS (torch>=2.5, no mask) even though
+                # MPS SDPA doesn't support native GQA - we handle it here
+                if q.shape[1] != k.shape[1]:
+                    # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
+                    n_rep = q.shape[1] // k.shape[1]
+                    k = k.repeat_interleave(n_rep, dim=1)
+                    v = v.repeat_interleave(n_rep, dim=1)
+
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                 # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                 mfa_mask = None
                 if attn_mask is not None:
+                    if _debug:
+                        print(f"[MFA MASK] dtype={attn_mask.dtype} shape={tuple(attn_mask.shape)} min={attn_mask.min().item():.2f} max={attn_mask.max().item():.2f}")
                     if attn_mask.dtype == torch.bool:
-                        #
-
+                        # PyTorch SDPA bool mask: True = ATTEND, False = MASKED
+                        # MFA bool mask: True = MASKED, False = ATTEND
+                        # They're opposite! Invert it.
+                        mfa_mask = ~attn_mask
                     else:
                         # Float mask: typically -inf for masked positions, 0 for unmasked
                         # Convert: positions with large negative values -> True (masked)
                         # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                         mfa_mask = attn_mask <= -1e3
-
-
+                    if _debug:
+                        print(f"[MFA MASK] converted: True(masked)={mfa_mask.sum().item()} False(attend)={(~mfa_mask).sum().item()}")
+
+                out = flash_attention(q, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+
+                # Squeeze back for 3D input
+                if is_3d:
+                    out = out.squeeze(1)  # (B, 1, S, D) -> (B, S, D)
+
+                if _debug:
+                    _call_count[0] += 1
+                    print(f"[MFA #{_call_count[0]}] shape={tuple(query.shape)} is_3d={is_3d} gqa={enable_gqa} mask={attn_mask is not None} causal={is_causal}")
+
+                return out
+            except Exception as e:
                 # Fall back to original on any error
+                if _debug:
+                    print(f"[MFA FALLBACK] shape={tuple(query.shape)} error={e}")
                 pass

+        if _debug and query.device.type == 'mps':
+            _call_count[0] += 1
+            reason = []
+            if dropout_p != 0.0: reason.append(f"dropout={dropout_p}")
+            if query.ndim < 3: reason.append(f"ndim={query.ndim}")
+            if seq_len < 512: reason.append(f"seq={seq_len}<512")
+            print(f"[NATIVE #{_call_count[0]}] shape={tuple(query.shape)} reason={','.join(reason) or 'unknown'}")
+
         return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)

     F.scaled_dot_product_attention = patched_sdpa
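The conversions inside the patched path can be checked in isolation against stock PyTorch SDPA on CPU. The sketch below is illustrative only (shapes and names are not from the package) and assumes the documented SDPA semantics: bool attn_mask True = attend, additive float mask -inf = masked, default scale 1/sqrt(head_dim). It exercises the mask inversion, the -1e3 float-mask threshold, the repeat_interleave GQA expansion, and the 3D unsqueeze/squeeze round trip.

import torch
import torch.nn.functional as F

B, Hq, Hkv, S, D = 1, 8, 2, 16, 32
q = torch.randn(B, Hq, S, D)
k = torch.randn(B, Hkv, S, D)
v = torch.randn(B, Hkv, S, D)

# PyTorch SDPA bool mask: True = attend, False = masked.
keep = (torch.rand(S, S) > 0.2) | torch.eye(S, dtype=torch.bool)  # keep diagonal: no fully-masked rows
mfa_style_mask = ~keep                                            # opposite convention: True = masked

# A float additive mask (-inf = masked) reduces to the same bool mask via the -1e3 threshold.
float_mask = torch.zeros(S, S).masked_fill(mfa_style_mask, float("-inf"))
assert torch.equal(float_mask <= -1e3, mfa_style_mask)

# GQA: repeat each KV head so a kernel without native GQA sees matching head counts.
n_rep = Hq // Hkv
k_rep = k.repeat_interleave(n_rep, dim=1)
v_rep = v.repeat_interleave(n_rep, dim=1)

# Manual reference using the MFA-style mask (True = masked) matches SDPA fed the
# PyTorch-style mask (True = attend) on the expanded heads.
scale = D ** -0.5
scores = (q @ k_rep.transpose(-2, -1)) * scale
scores = scores.masked_fill(mfa_style_mask, float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v_rep
out = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=keep)
print(torch.allclose(ref, out, atol=1e-5))  # expected: True

# 3D (B, S, D) inputs: unsqueeze to a single head and squeeze back, as the patch does.
q3, k3, v3 = (t[:, 0] for t in (q, k_rep, v_rep))
out3 = F.scaled_dot_product_attention(q3.unsqueeze(1), k3.unsqueeze(1), v3.unsqueeze(1)).squeeze(1)
print(torch.allclose(out3, F.scaled_dot_product_attention(q3, k3, v3), atol=1e-5))  # expected: True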
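For completeness, a hedged usage sketch of the debug path added in this release. It assumes an Apple Silicon machine with an MPS-enabled PyTorch build and relies only on names visible in the hunks above (replace_sdpa, the MFA_DEBUG switch, the seq_len >= 512 routing condition):

import os
os.environ["MFA_DEBUG"] = "1"   # read inside replace_sdpa(), so set it before that call

import torch
import torch.nn.functional as F
import mps_flash_attn

mps_flash_attn.replace_sdpa()   # installs patched_sdpa over F.scaled_dot_product_attention

# seq_len >= 512 on an MPS device routes through the MFA kernel and prints [MFA #n] lines;
# shorter sequences, dropout > 0, or a kernel error print [NATIVE #n] / [MFA FALLBACK] instead.
q = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, q, q, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 1024, 64])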