mps-flash-attn 0.2.8.tar.gz → 0.2.9.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mps-flash-attn might be problematic.
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/PKG-INFO +1 -1
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/__init__.py +58 -11
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/PKG-INFO +1 -1
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/pyproject.toml +1 -1
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/LICENSE +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/README.md +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/benchmark.py +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/setup.cfg +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/setup.py +0 -0
- {mps_flash_attn-0.2.8 → mps_flash_attn-0.2.9}/tests/test_mfa_v2.py +0 -0
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.2.8"
+__version__ = "0.2.9"
 
 import torch
 from typing import Optional
@@ -213,6 +213,8 @@ def replace_sdpa():
     import torch.nn.functional as F
 
     original_sdpa = F.scaled_dot_product_attention
+    _debug = os.environ.get("MFA_DEBUG", "0") == "1"
+    _call_count = [0]  # mutable for closure
 
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None, enable_gqa=False, **kwargs):
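The two added lines gate new per-call logging behind an MFA_DEBUG environment variable and keep the call counter in a one-element list so the nested patched_sdpa closure can mutate it. A minimal usage sketch, assuming replace_sdpa() takes no arguments (it is defined in this __init__.py, so the import is real, but the exact call signature is not shown in the diff) and that an Apple Silicon MPS device is available:

    import os
    os.environ["MFA_DEBUG"] = "1"   # must be set before replace_sdpa() reads it

    import torch
    import torch.nn.functional as F
    from mps_flash_attn import replace_sdpa

    replace_sdpa()  # monkey-patches F.scaled_dot_product_attention

    q = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
    out = F.scaled_dot_product_attention(q, q, q, is_causal=True)
    # With MFA_DEBUG=1 each call prints "[MFA #n] ..." when the Metal kernel handles it,
    # or "[NATIVE #n] ..." with a reason when it falls back to the stock implementation.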
@@ -224,38 +226,83 @@ def replace_sdpa():
         # seq=1024: 2.3-3.7x (MFA much faster)
         # seq=2048: 2.2-3.9x (MFA much faster)
         # seq=4096: 2.1-3.7x (MFA much faster)
+        # Determine seq_len based on tensor dimensionality
+        # 4D: (B, H, S, D) -> seq_len = shape[2]
+        # 3D: (B, S, D) -> seq_len = shape[1] (single-head attention, e.g., VAE)
+        is_3d = query.ndim == 3
+        seq_len = query.shape[1] if is_3d else query.shape[2]
+
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.
+            query.ndim >= 3 and
+            seq_len >= 512):
             try:
+                q, k, v = query, key, value
+
+                # Handle 3D tensors (B, S, D) - treat as single-head attention
+                # Unsqueeze to (B, 1, S, D) for MFA, squeeze back after
+                if is_3d:
+                    q = q.unsqueeze(1)  # (B, S, D) -> (B, 1, S, D)
+                    k = k.unsqueeze(1)
+                    v = v.unsqueeze(1)
+
                 # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
                 # Common in Llama 2/3, Mistral, Qwen, etc.
-
-
+                # NOTE: Always expand when heads mismatch, not just when enable_gqa=True
+                # Transformers may pass enable_gqa=True on MPS (torch>=2.5, no mask) even though
+                # MPS SDPA doesn't support native GQA - we handle it here
+                if q.shape[1] != k.shape[1]:
                     # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
-                    n_rep =
-                    k =
-                    v =
+                    n_rep = q.shape[1] // k.shape[1]
+                    k = k.repeat_interleave(n_rep, dim=1)
+                    v = v.repeat_interleave(n_rep, dim=1)
 
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                 # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                 mfa_mask = None
                 if attn_mask is not None:
+                    if _debug:
+                        print(f"[MFA MASK] dtype={attn_mask.dtype} shape={tuple(attn_mask.shape)} min={attn_mask.min().item():.2f} max={attn_mask.max().item():.2f}")
                     if attn_mask.dtype == torch.bool:
-                        #
-
+                        # PyTorch SDPA bool mask: True = ATTEND, False = MASKED
+                        # MFA bool mask: True = MASKED, False = ATTEND
+                        # They're opposite! Invert it.
+                        mfa_mask = ~attn_mask
                     else:
                         # Float mask: typically -inf for masked positions, 0 for unmasked
                         # Convert: positions with large negative values -> True (masked)
                         # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                         mfa_mask = attn_mask <= -1e3
-
-
+                    if _debug:
+                        print(f"[MFA MASK] converted: True(masked)={mfa_mask.sum().item()} False(attend)={(~mfa_mask).sum().item()}")
+
+                out = flash_attention(q, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+
+                # Squeeze back for 3D input
+                if is_3d:
+                    out = out.squeeze(1)  # (B, 1, S, D) -> (B, S, D)
+
+                if _debug:
+                    _call_count[0] += 1
+                    print(f"[MFA #{_call_count[0]}] shape={tuple(query.shape)} is_3d={is_3d} gqa={enable_gqa} mask={attn_mask is not None} causal={is_causal}")
+
+                return out
+            except Exception as e:
                 # Fall back to original on any error
+                if _debug:
+                    print(f"[MFA FALLBACK] shape={tuple(query.shape)} error={e}")
                 pass
 
+        if _debug and query.device.type == 'mps':
+            _call_count[0] += 1
+            reason = []
+            if dropout_p != 0.0: reason.append(f"dropout={dropout_p}")
+            if query.ndim < 3: reason.append(f"ndim={query.ndim}")
+            if seq_len < 512: reason.append(f"seq={seq_len}<512")
+            print(f"[NATIVE #{_call_count[0]}] shape={tuple(query.shape)} reason={','.join(reason) or 'unknown'}")
+
         return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)
 
     F.scaled_dot_product_attention = patched_sdpa
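For the mask conversion added in the hunk above, here is a small self-contained check of the two branches that runs on CPU; the helper name to_mfa_mask is illustrative only and not part of the package:

    import torch

    def to_mfa_mask(attn_mask):
        # PyTorch SDPA: bool mask True = attend; float mask is additive (0 = attend, -inf = masked).
        # MFA convention (per the diff): bool mask True = masked, False = attend.
        if attn_mask.dtype == torch.bool:
            return ~attn_mask            # invert: True now means "masked"
        return attn_mask <= -1e3         # large negative additive values become "masked"

    float_mask = torch.tensor([[0.0, float("-inf")], [0.0, 0.0]])
    bool_mask = torch.tensor([[True, False], [True, True]])   # SDPA convention: True = attend
    assert torch.equal(to_mfa_mask(float_mask), to_mfa_mask(bool_mask))
    print(to_mfa_mask(float_mask))   # tensor([[False,  True], [False, False]])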
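The GQA branch expands K/V heads with repeat_interleave so that each query head gets its own copy of the shared key/value head. A shape-only sketch of that expansion, with sizes chosen purely for illustration:

    import torch

    B, q_heads, kv_heads, S, D = 1, 8, 2, 1024, 64   # illustrative sizes only
    q = torch.randn(B, q_heads, S, D)
    k = torch.randn(B, kv_heads, S, D)
    v = torch.randn(B, kv_heads, S, D)

    # Same expansion as the patched SDPA above: each KV head is repeated
    # q_heads // kv_heads times along the head dimension.
    n_rep = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    assert k.shape == (B, q_heads, S, D) and v.shape == (B, q_heads, S, D)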
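The new 3D path treats (B, S, D) inputs as single-head attention by adding a head dimension before the kernel call and removing it afterwards. The sketch below shows only that reshaping, using the stock scaled_dot_product_attention as a stand-in for the package's flash_attention so it runs anywhere:

    import torch
    import torch.nn.functional as F

    # 3D input (B, S, D), e.g. a single-head VAE attention block.
    x = torch.randn(2, 512, 64)
    q = x.unsqueeze(1)                     # (B, S, D) -> (B, 1, S, D)
    out = F.scaled_dot_product_attention(q, q, q)
    out = out.squeeze(1)                   # (B, 1, S, D) -> (B, S, D)
    assert out.shape == x.shape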