mps-flash-attn 0.1.4__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mps-flash-attn might be problematic.

Files changed (39)
  1. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/PKG-INFO +1 -1
  2. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/__init__.py +5 -2
  3. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/csrc/mps_flash_attn.mm +57 -15
  4. mps_flash_attn-0.1.6/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  5. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/PKG-INFO +1 -1
  6. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/pyproject.toml +1 -1
  7. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/setup.py +2 -2
  8. mps_flash_attn-0.1.4/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  9. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/LICENSE +0 -0
  10. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/README.md +0 -0
  11. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  12. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  13. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  14. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  15. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  16. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  17. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  18. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  19. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  20. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  21. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  22. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  23. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  24. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  25. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  26. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  27. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  28. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  29. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  30. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  31. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  32. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  33. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn/kernels/manifest.json +0 -0
  34. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
  35. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  36. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/requires.txt +0 -0
  37. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/mps_flash_attn.egg-info/top_level.txt +0 -0
  38. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/setup.cfg +0 -0
  39. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.6}/tests/test_attention.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.1.4
+Version: 0.1.6
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
mps_flash_attn/__init__.py
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.1.4"
+__version__ = "0.1.6"
 
 import torch
 from typing import Optional
@@ -174,10 +174,13 @@ def replace_sdpa():
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None):
         # Use MFA for MPS tensors without mask/dropout
+        # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+        # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
         if (query.device.type == 'mps' and
             attn_mask is None and
             dropout_p == 0.0 and
-            _HAS_MFA):
+            _HAS_MFA and
+            query.shape[2] >= 1024):
             try:
                 return flash_attention(query, key, value, is_causal=is_causal, scale=scale)
             except Exception:
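
The 1024-token cutoff added above is a fixed heuristic. A quick way to see where the crossover lands on a particular machine is to time both paths directly; a minimal sketch, assuming the package is importable as mps_flash_attn and exposes the flash_attention function referenced in the patch (timings are indicative only):

    # Sketch: compare MFA against PyTorch's default SDPA on MPS across sequence lengths.
    import time
    import torch
    import torch.nn.functional as F
    from mps_flash_attn import flash_attention  # assumed top-level export

    def bench(fn, *args, iters=20):
        for _ in range(3):              # warm-up
            fn(*args)
        torch.mps.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        torch.mps.synchronize()         # flush the MPS queue before reading the clock
        return (time.perf_counter() - t0) / iters

    for seq_len in (256, 512, 1024, 2048, 4096):
        q, k, v = (torch.randn(1, 8, seq_len, 64, device="mps", dtype=torch.float16)
                   for _ in range(3))
        t_sdpa = bench(F.scaled_dot_product_attention, q, k, v)
        t_mfa = bench(flash_attention, q, k, v)
        print(f"seq_len={seq_len}: sdpa={t_sdpa * 1e3:.2f} ms  mfa={t_mfa * 1e3:.2f} ms")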
mps_flash_attn/csrc/mps_flash_attn.mm
@@ -18,7 +18,13 @@
 // ============================================================================
 
 typedef bool (*mfa_init_fn)();
-typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool); // Added causal param
+typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool);
+// New zero-sync encode functions that take PyTorch's command encoder
+typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
+typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                       int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                       int32_t, int32_t);
+// Legacy sync functions (fallback)
 typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
 typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
                                 int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
@@ -28,6 +34,8 @@ typedef void (*mfa_release_kernel_fn)(void*);
 // Global function pointers
 static mfa_init_fn g_mfa_init = nullptr;
 static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
+static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
+static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
 static mfa_forward_fn g_mfa_forward = nullptr;
 static mfa_backward_fn g_mfa_backward = nullptr;
 static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
@@ -71,11 +79,14 @@ static bool load_mfa_bridge() {
     // Load function pointers
     g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
    g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
+    g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
+    g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
     g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
     g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
     g_mfa_release_kernel = (mfa_release_kernel_fn)dlsym(g_dylib_handle, "mfa_release_kernel");
 
-    if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward || !g_mfa_backward || !g_mfa_release_kernel) {
+    // Require at least init, create_kernel, forward_encode (for zero-sync path)
+    if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward_encode) {
         throw std::runtime_error("Failed to load MFA bridge functions");
     }
 
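The relaxed check above only hard-requires the zero-sync entry points. A small sketch for probing which of these symbols a given bridge build actually exports, assuming the dylib ships at mps_flash_attn/lib/libMFABridge.dylib as listed in the file table:

    # Sketch: probe the bundled bridge dylib for the symbols the dlsym calls above look up.
    import ctypes
    from importlib.resources import files

    lib_path = files("mps_flash_attn") / "lib" / "libMFABridge.dylib"
    bridge = ctypes.CDLL(str(lib_path))

    for name in ("mfa_init", "mfa_create_kernel",
                 "mfa_forward_encode", "mfa_backward_encode",           # zero-sync path (required)
                 "mfa_forward", "mfa_backward", "mfa_release_kernel"):  # legacy sync fallback
        print(f"{name}: {'found' if hasattr(bridge, name) else 'missing'}")
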
@@ -193,17 +204,37 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");
 
     const int64_t batch_size = query.size(0);
-    const int64_t num_heads = query.size(1);
+    const int64_t num_heads_q = query.size(1);
+    const int64_t num_heads_kv = key.size(1);
     const int64_t seq_len_q = query.size(2);
     const int64_t head_dim = query.size(3);
     const int64_t seq_len_kv = key.size(2);
 
     TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
                 "Batch size mismatch");
-    TORCH_CHECK(key.size(1) == num_heads && value.size(1) == num_heads,
-                "Number of heads mismatch");
     TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
                 "Head dimension mismatch");
+    TORCH_CHECK(key.size(1) == value.size(1),
+                "K and V must have same number of heads");
+
+    // Handle GQA (Grouped Query Attention): expand K/V if fewer heads than Q
+    const int64_t num_heads = num_heads_q;
+    at::Tensor k_expanded, v_expanded;
+
+    if (num_heads_kv != num_heads_q) {
+        // GQA: num_heads_q must be divisible by num_heads_kv
+        TORCH_CHECK(num_heads_q % num_heads_kv == 0,
+                    "num_heads_q (", num_heads_q, ") must be divisible by num_heads_kv (", num_heads_kv, ")");
+        int64_t repeat_factor = num_heads_q / num_heads_kv;
+
+        // Expand K and V to match Q's head count: (B, H_kv, S, D) -> (B, H_q, S, D)
+        // Use repeat_interleave for proper GQA expansion
+        k_expanded = key.repeat_interleave(repeat_factor, /*dim=*/1);
+        v_expanded = value.repeat_interleave(repeat_factor, /*dim=*/1);
+    } else {
+        k_expanded = key;
+        v_expanded = value;
+    }
 
     // Determine precision
     bool low_precision = (query.scalar_type() == at::kHalf ||
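
The repeat_interleave expansion above maps KV head g to the query-head block [g*r, (g+1)*r). A short equivalence check for that layout, written against PyTorch's public SDPA rather than the extension itself:

    # Sketch: expanded-K/V attention should match routing each query head h to KV head h // r.
    import torch
    import torch.nn.functional as F

    B, Hq, Hkv, S, D = 2, 8, 2, 16, 64
    q = torch.randn(B, Hq, S, D)
    k = torch.randn(B, Hkv, S, D)
    v = torch.randn(B, Hkv, S, D)
    r = Hq // Hkv

    # Dense path: expand K/V to H_q heads (same shape transform as the GQA branch above).
    out = F.scaled_dot_product_attention(
        q, k.repeat_interleave(r, dim=1), v.repeat_interleave(r, dim=1))

    # Grouped reference: each query head attends to its own KV group.
    ref = torch.cat([
        F.scaled_dot_product_attention(q[:, h:h + 1], k[:, h // r:h // r + 1], v[:, h // r:h // r + 1])
        for h in range(Hq)
    ], dim=1)

    print(torch.allclose(out, ref, atol=1e-5))  # expected: True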
@@ -214,8 +245,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
 
     // Make inputs contiguous
     auto q = query.contiguous();
-    auto k = key.contiguous();
-    auto v = value.contiguous();
+    auto k = k_expanded.contiguous();
+    auto v = v_expanded.contiguous();
 
     // Allocate output in the appropriate precision
     // With lowPrecisionOutputs=true, MFA writes FP16 directly
@@ -242,15 +273,18 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     auto o_info = getBufferInfo(output);
     auto l_info = getBufferInfo(logsumexp);
 
-    // Synchronize with PyTorch's MPS stream
+    // Use PyTorch's MPS stream command encoder for zero-sync integration
     @autoreleasepool {
-        // Wait for PyTorch operations to complete
         auto stream = at::mps::getCurrentMPSStream();
-        stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
 
-        // Execute MFA with storage byte offsets
-        bool success = g_mfa_forward(
+        // Get PyTorch's shared command encoder - this is the key for zero-sync!
+        // All our dispatches go onto the same encoder that PyTorch uses.
+        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+        // Execute MFA using the shared encoder (no sync needed!)
+        bool success = g_mfa_forward_encode(
             kernel,
+            (__bridge void*)encoder,  // PyTorch's shared command encoder
             (__bridge void*)q_info.buffer,
             (__bridge void*)k_info.buffer,
             (__bridge void*)v_info.buffer,
@@ -268,6 +302,9 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         if (!success) {
            throw std::runtime_error("MFA forward pass failed");
         }
+
+        // No commit needed - PyTorch will commit when it needs the results
+        // The encoder stays open for coalescing more kernels
     }
 
     // Output is already in the correct dtype (fp16 or fp32)
@@ -362,13 +399,16 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
     auto dk_info = getBufferInfo(dK);
     auto dv_info = getBufferInfo(dV);
 
-    // Execute backward pass
+    // Use PyTorch's MPS stream command encoder for zero-sync integration
     @autoreleasepool {
         auto stream = at::mps::getCurrentMPSStream();
-        stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
 
-        bool success = g_mfa_backward(
+        // Get PyTorch's shared command encoder
+        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+        bool success = g_mfa_backward_encode(
             kernel,
+            (__bridge void*)encoder,  // PyTorch's shared command encoder
             (__bridge void*)q_info.buffer,
             (__bridge void*)k_info.buffer,
             (__bridge void*)v_info.buffer,
@@ -396,6 +436,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
         if (!success) {
             throw std::runtime_error("MFA backward pass failed");
         }
+
+        // No commit needed - PyTorch will commit when it needs the results
     }
 
     // Convert gradients back to input dtype if needed
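
With both forward and backward now encoding onto PyTorch's open command encoder, a chain of attention calls no longer pays a COMMIT_AND_WAIT round-trip per call. A rough way to observe the effect from Python, assuming the same flash_attention export as above (numbers are indicative only):

    # Sketch: time a chain of calls with a single synchronize at the end; with the
    # zero-sync path the kernels coalesce instead of stalling the queue per call.
    import time
    import torch
    from mps_flash_attn import flash_attention  # assumed top-level export

    q = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)
    k = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)
    v = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)

    torch.mps.synchronize()
    t0 = time.perf_counter()
    out = q
    for _ in range(50):
        out = flash_attention(out, k, v)   # chained calls, no per-call sync
    torch.mps.synchronize()                # single flush at the end
    print(f"50 chained calls: {(time.perf_counter() - t0) * 1e3:.1f} ms")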
mps_flash_attn.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.1.4
+Version: 0.1.6
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mps-flash-attn"
-version = "0.1.4"
+version = "0.1.6"
 description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
 readme = "README.md"
 license = "MIT"
setup.py
@@ -44,14 +44,14 @@ def get_extensions():
     return [Extension(
         name="mps_flash_attn._C",
         sources=["mps_flash_attn/csrc/mps_flash_attn.mm"],
-        extra_compile_args=["-std=c++17", "-O3"],
+        extra_compile_args=["-std=c++17", "-O3", "-DTORCH_EXTENSION_NAME=_C"],
         extra_link_args=["-framework", "Metal", "-framework", "Foundation"],
     )]
 
 
 setup(
     name="mps-flash-attn",
-    version="0.1.1",
+    version="0.1.5",
     packages=find_packages(),
     package_data={
         "mps_flash_attn": [
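
The added -DTORCH_EXTENSION_NAME=_C define keeps the compiled module's registered name consistent with the mps_flash_attn._C import path declared in the Extension. A post-install smoke test (sketch; the listed attribute names will depend on the build):

    # Sketch: confirm the compiled extension imports under the name the define gives it.
    import importlib

    ext = importlib.import_module("mps_flash_attn._C")
    print([name for name in dir(ext) if not name.startswith("_")])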