mps-flash-attn 0.2.7.tar.gz → 0.2.9.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

The registry flags this as a potentially problematic release of mps-flash-attn.

Files changed (39)
  1. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/PKG-INFO +1 -1
  2. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/__init__.py +62 -6
  3. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/PKG-INFO +1 -1
  4. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/pyproject.toml +1 -1
  5. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/LICENSE +0 -0
  6. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/README.md +0 -0
  7. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/benchmark.py +0 -0
  8. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/csrc/mps_flash_attn.mm +0 -0
  9. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  10. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  11. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  12. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  13. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  14. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  15. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  16. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  17. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  18. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  19. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  20. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  21. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  22. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  23. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  24. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  25. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  26. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  27. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  28. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  29. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  30. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  31. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/kernels/manifest.json +0 -0
  32. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  33. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
  34. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  35. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/requires.txt +0 -0
  36. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/mps_flash_attn.egg-info/top_level.txt +0 -0
  37. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/setup.cfg +0 -0
  38. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/setup.py +0 -0
  39. {mps_flash_attn-0.2.7 → mps_flash_attn-0.2.9}/tests/test_mfa_v2.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.2.7
+ Version: 0.2.9
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
mps_flash_attn/__init__.py
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
  This package provides memory-efficient attention using Metal Flash Attention kernels.
  """

- __version__ = "0.2.7"
+ __version__ = "0.2.9"

  import torch
  from typing import Optional
@@ -213,6 +213,8 @@ def replace_sdpa():
      import torch.nn.functional as F

      original_sdpa = F.scaled_dot_product_attention
+     _debug = os.environ.get("MFA_DEBUG", "0") == "1"
+     _call_count = [0] # mutable for closure

      def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                       is_causal=False, scale=None, enable_gqa=False, **kwargs):
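The _call_count = [0] added above is a one-element list so that the nested patched_sdpa closure can update the counter without a nonlocal declaration. A minimal standalone sketch of that pattern (illustrative only, not package code):

# Illustrative sketch (not package code): a one-element list works as a
# counter inside a closure because mutating the list does not rebind the
# name, so no "nonlocal" declaration is needed.
def make_counter():
    count = [0]              # mutable container captured by the inner function
    def bump():
        count[0] += 1        # mutation, not rebinding
        return count[0]
    return bump

bump = make_counter()
assert bump() == 1
assert bump() == 2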
@@ -224,29 +226,83 @@ def replace_sdpa():
          # seq=1024: 2.3-3.7x (MFA much faster)
          # seq=2048: 2.2-3.9x (MFA much faster)
          # seq=4096: 2.1-3.7x (MFA much faster)
+         # Determine seq_len based on tensor dimensionality
+         # 4D: (B, H, S, D) -> seq_len = shape[2]
+         # 3D: (B, S, D) -> seq_len = shape[1] (single-head attention, e.g., VAE)
+         is_3d = query.ndim == 3
+         seq_len = query.shape[1] if is_3d else query.shape[2]
+
          if (query.device.type == 'mps' and
              dropout_p == 0.0 and
              _HAS_MFA and
-             query.shape[2] >= 512):
+             query.ndim >= 3 and
+             seq_len >= 512):
              try:
+                 q, k, v = query, key, value
+
+                 # Handle 3D tensors (B, S, D) - treat as single-head attention
+                 # Unsqueeze to (B, 1, S, D) for MFA, squeeze back after
+                 if is_3d:
+                     q = q.unsqueeze(1) # (B, S, D) -> (B, 1, S, D)
+                     k = k.unsqueeze(1)
+                     v = v.unsqueeze(1)
+
+                 # Handle GQA (Grouped Query Attention) - expand K/V heads to match Q heads
+                 # Common in Llama 2/3, Mistral, Qwen, etc.
+                 # NOTE: Always expand when heads mismatch, not just when enable_gqa=True
+                 # Transformers may pass enable_gqa=True on MPS (torch>=2.5, no mask) even though
+                 # MPS SDPA doesn't support native GQA - we handle it here
+                 if q.shape[1] != k.shape[1]:
+                     # Expand KV heads: (B, kv_heads, S, D) -> (B, q_heads, S, D)
+                     n_rep = q.shape[1] // k.shape[1]
+                     k = k.repeat_interleave(n_rep, dim=1)
+                     v = v.repeat_interleave(n_rep, dim=1)
+
                  # Convert float mask to bool mask if needed
                  # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
                  # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
                  mfa_mask = None
                  if attn_mask is not None:
+                     if _debug:
+                         print(f"[MFA MASK] dtype={attn_mask.dtype} shape={tuple(attn_mask.shape)} min={attn_mask.min().item():.2f} max={attn_mask.max().item():.2f}")
                      if attn_mask.dtype == torch.bool:
-                         # Boolean mask: True means masked (don't attend)
-                         mfa_mask = attn_mask
+                         # PyTorch SDPA bool mask: True = ATTEND, False = MASKED
+                         # MFA bool mask: True = MASKED, False = ATTEND
+                         # They're opposite! Invert it.
+                         mfa_mask = ~attn_mask
                      else:
                          # Float mask: typically -inf for masked positions, 0 for unmasked
                          # Convert: positions with large negative values -> True (masked)
                          # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
                          mfa_mask = attn_mask <= -1e3
-                 return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
-             except Exception:
+                     if _debug:
+                         print(f"[MFA MASK] converted: True(masked)={mfa_mask.sum().item()} False(attend)={(~mfa_mask).sum().item()}")
+
+                 out = flash_attention(q, k, v, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
+
+                 # Squeeze back for 3D input
+                 if is_3d:
+                     out = out.squeeze(1) # (B, 1, S, D) -> (B, S, D)
+
+                 if _debug:
+                     _call_count[0] += 1
+                     print(f"[MFA #{_call_count[0]}] shape={tuple(query.shape)} is_3d={is_3d} gqa={enable_gqa} mask={attn_mask is not None} causal={is_causal}")
+
+                 return out
+             except Exception as e:
                  # Fall back to original on any error
+                 if _debug:
+                     print(f"[MFA FALLBACK] shape={tuple(query.shape)} error={e}")
                  pass

+         if _debug and query.device.type == 'mps':
+             _call_count[0] += 1
+             reason = []
+             if dropout_p != 0.0: reason.append(f"dropout={dropout_p}")
+             if query.ndim < 3: reason.append(f"ndim={query.ndim}")
+             if seq_len < 512: reason.append(f"seq={seq_len}<512")
+             print(f"[NATIVE #{_call_count[0]}] shape={tuple(query.shape)} reason={','.join(reason) or 'unknown'}")
+
          return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale=scale, enable_gqa=enable_gqa, **kwargs)

      F.scaled_dot_product_attention = patched_sdpa
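The behavioral changes in this hunk (GQA key/value head expansion, 3D tensor handling, and the boolean-mask inversion) can be checked in isolation. The sketch below is illustrative only and is not part of the package; the shapes and the mask construction are assumptions chosen for the example, but the expansion and conversion mirror the added lines.

# Illustrative sketch (not package code): the GQA expansion and mask
# conversion that the patched SDPA now performs before calling the MFA kernel.
import torch

B, q_heads, kv_heads, S, D = 2, 8, 2, 16, 64   # assumed toy shapes
q = torch.randn(B, q_heads, S, D)
k = torch.randn(B, kv_heads, S, D)
v = torch.randn(B, kv_heads, S, D)

# GQA: repeat K/V heads so they match the query head count.
if q.shape[1] != k.shape[1]:
    n_rep = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
assert k.shape[1] == v.shape[1] == q_heads

# Mask semantics: SDPA bool masks use True = attend, while the MFA kernel
# expects True = masked, so the bool mask is inverted.
sdpa_bool_mask = torch.tril(torch.ones(S, S)).bool()   # causal, True = attend
mfa_mask = ~sdpa_bool_mask                              # True = masked

# Additive float masks (0 = attend, -inf = masked) are thresholded instead.
float_mask = torch.zeros(S, S).masked_fill(~sdpa_bool_mask, float("-inf"))
assert torch.equal(float_mask <= -1e3, mfa_mask)

Note that, per the condition above, the patched path only engages for MPS tensors with seq_len >= 512 and dropout_p == 0.0; anything else falls through to the original SDPA.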
mps_flash_attn.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.2.7
+ Version: 0.2.9
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "mps-flash-attn"
- version = "0.2.7"
+ version = "0.2.9"
  description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
  readme = "README.md"
  license = "MIT"
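Based on the names visible in this diff (replace_sdpa, flash_attention, and the new MFA_DEBUG environment variable), a plausible way to exercise the patched path and its debug output is sketched below; treat the exact import and call pattern as an assumption to be confirmed against the package README.

# Hypothetical usage sketch based on names in this diff; confirm against the
# package documentation before relying on it.
import os
os.environ["MFA_DEBUG"] = "1"        # read when replace_sdpa() runs

import torch
import torch.nn.functional as F
from mps_flash_attn import replace_sdpa

replace_sdpa()                       # monkey-patches F.scaled_dot_product_attention

if torch.backends.mps.is_available():
    q = torch.randn(1, 8, 1024, 64, device="mps")
    k = torch.randn(1, 8, 1024, 64, device="mps")
    v = torch.randn(1, 8, 1024, 64, device="mps")
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # With MFA_DEBUG=1, the patched function prints [MFA ...] or [NATIVE ...]
    # lines showing whether the Metal flash-attention path was taken and why.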