mps-flash-attn 0.1.8.tar.gz → 0.1.14.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/PKG-INFO +1 -1
  2. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/__init__.py +30 -4
  3. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/csrc/mps_flash_attn.mm +69 -25
  4. mps_flash_attn-0.1.14/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  5. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/PKG-INFO +1 -1
  6. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/pyproject.toml +1 -1
  7. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/setup.py +21 -0
  8. mps_flash_attn-0.1.8/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  9. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/LICENSE +0 -0
  10. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/README.md +0 -0
  11. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  12. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  13. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  14. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  15. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  16. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  17. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  18. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  19. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  20. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  21. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  22. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  23. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  24. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  25. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  26. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  27. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  28. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  29. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  30. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  31. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  32. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  33. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn/kernels/manifest.json +0 -0
  34. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
  35. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  36. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/requires.txt +0 -0
  37. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/mps_flash_attn.egg-info/top_level.txt +0 -0
  38. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/setup.cfg +0 -0
  39. {mps_flash_attn-0.1.8 → mps_flash_attn-0.1.14}/tests/test_attention.py +0 -0
--- mps_flash_attn-0.1.8/PKG-INFO
+++ mps_flash_attn-0.1.14/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.1.8
+Version: 0.1.14
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
--- mps_flash_attn-0.1.8/mps_flash_attn/__init__.py
+++ mps_flash_attn-0.1.14/mps_flash_attn/__init__.py
@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.1.8"
+__version__ = "0.1.14"
 
 import torch
 from typing import Optional
@@ -30,6 +30,30 @@ def is_available() -> bool:
     return _HAS_MFA and torch.backends.mps.is_available()
 
 
+def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+    """
+    Convert attention mask to MFA's boolean format.
+
+    MFA uses boolean masks where True = masked (don't attend).
+    PyTorch SDPA uses additive float masks where -inf/large negative = masked.
+
+    Args:
+        attn_mask: Optional mask, either:
+            - None: no mask
+            - bool tensor: already in MFA format (True = masked)
+            - float tensor: additive mask (large negative = masked)
+
+    Returns:
+        Boolean mask suitable for flash_attention(), or None
+    """
+    if attn_mask is None:
+        return None
+    if attn_mask.dtype == torch.bool:
+        return attn_mask
+    # Float mask: large negative values indicate masked positions
+    return attn_mask <= -1e3
+
+
 class FlashAttentionFunction(torch.autograd.Function):
     """Autograd function for Flash Attention with backward pass support."""
 
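For context, a minimal usage sketch of the new convert_mask helper (values are illustrative, not taken from the package's tests):

import torch
from mps_flash_attn import convert_mask

# Additive SDPA-style mask: 0.0 = attend, -inf = masked
additive = torch.zeros(1, 1, 4, 4)
additive[..., 2:] = float("-inf")

bool_mask = convert_mask(additive)  # True where masked
assert bool_mask.dtype == torch.bool
assert bool_mask[..., 2:].all() and not bool_mask[..., :2].any()

# Boolean masks pass through unchanged; None stays None
assert convert_mask(None) is None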
@@ -176,12 +200,13 @@ def replace_sdpa():
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None):
         # Use MFA for MPS tensors without dropout
-        # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+        # Only use MFA for seq_len >= 1536 where it outperforms PyTorch's math backend
         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
+        # Benchmark (BF16, heads=30, head_dim=128): crossover is ~1200-1500
         if (query.device.type == 'mps' and
             dropout_p == 0.0 and
             _HAS_MFA and
-            query.shape[2] >= 1024):
+            query.shape[2] >= 1536):
             try:
                 # Convert float mask to bool mask if needed
                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
@@ -194,7 +219,8 @@ def replace_sdpa():
                 else:
                     # Float mask: typically -inf for masked positions, 0 for unmasked
                     # Convert: positions with large negative values -> True (masked)
-                    mfa_mask = attn_mask < -1e4
+                    # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
+                    mfa_mask = attn_mask <= -1e3
                 return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
             except Exception:
                 # Fall back to original on any error
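A usage sketch of the retuned dispatch, assuming replace_sdpa() installs patched_sdpa over torch.nn.functional.scaled_dot_product_attention as its body suggests; shapes and dtypes below are illustrative:

import torch
import torch.nn.functional as F
import mps_flash_attn

mps_flash_attn.replace_sdpa()

q = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# seq_len = q.shape[2] = 2048 >= 1536 and dropout_p == 0.0, so this call
# should be routed to MFA (with fallback to stock SDPA on any error).
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Below the new 1536 threshold the patched function falls through to
# PyTorch's own SDPA, which is faster for short sequences.
q_short = torch.randn(1, 8, 512, 64, device="mps", dtype=torch.float16)
out_short = F.scaled_dot_product_attention(q_short, q_short, q_short)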
--- mps_flash_attn-0.1.8/mps_flash_attn/csrc/mps_flash_attn.mm
+++ mps_flash_attn-0.1.14/mps_flash_attn/csrc/mps_flash_attn.mm
@@ -21,6 +21,7 @@
 
 typedef bool (*mfa_init_fn)();
 typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool);
+typedef void* (*mfa_create_kernel_v2_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool);
 // New zero-sync encode functions that take PyTorch's command encoder
 // Added mask_ptr and mask_offset parameters
 typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
@@ -41,6 +42,7 @@ typedef void (*mfa_release_kernel_fn)(void*);
 // Global function pointers
 static mfa_init_fn g_mfa_init = nullptr;
 static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
+static mfa_create_kernel_v2_fn g_mfa_create_kernel_v2 = nullptr;
 static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
 static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
 static mfa_forward_fn g_mfa_forward = nullptr;
@@ -99,6 +101,7 @@ static bool load_mfa_bridge() {
     // Load function pointers
     g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
     g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
+    g_mfa_create_kernel_v2 = (mfa_create_kernel_v2_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v2");
     g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
     g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
     g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
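For illustration only, the new entry point can be poked at from Python with ctypes, mirroring the mfa_create_kernel_v2 typedef above; the dylib path and the idea of loading it directly are assumptions (the extension normally loads the bridge itself):

import ctypes
import os
import mps_flash_attn

lib_path = os.path.join(os.path.dirname(mps_flash_attn.__file__),
                        "lib", "libMFABridge.dylib")
bridge = ctypes.CDLL(lib_path)

# Per the typedef: (seq_q, seq_kv, head_dim, low_precision,
# low_precision_outputs, causal, has_mask, use_bf16) -> opaque kernel handle
bridge.mfa_create_kernel_v2.restype = ctypes.c_void_p
bridge.mfa_create_kernel_v2.argtypes = [ctypes.c_int32] * 3 + [ctypes.c_bool] * 5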
@@ -148,6 +151,7 @@ struct KernelCacheKey {
     bool low_precision_outputs;
     bool causal;
     bool has_mask;
+    bool use_bf16;
 
     bool operator==(const KernelCacheKey& other) const {
         return seq_len_q == other.seq_len_q &&
@@ -156,7 +160,8 @@ struct KernelCacheKey {
                low_precision == other.low_precision &&
                low_precision_outputs == other.low_precision_outputs &&
                causal == other.causal &&
-               has_mask == other.has_mask;
+               has_mask == other.has_mask &&
+               use_bf16 == other.use_bf16;
     }
 };
 
@@ -168,29 +173,46 @@ struct KernelCacheKeyHash {
                (std::hash<bool>()(k.low_precision) << 3) ^
                (std::hash<bool>()(k.low_precision_outputs) << 4) ^
                (std::hash<bool>()(k.causal) << 5) ^
-               (std::hash<bool>()(k.has_mask) << 6);
+               (std::hash<bool>()(k.has_mask) << 6) ^
+               (std::hash<bool>()(k.use_bf16) << 7);
     }
 };
 
 static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;
 
-static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask) {
-    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask};
+static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask, bool use_bf16 = false) {
+    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16};
 
     auto it = g_kernel_cache.find(key);
     if (it != g_kernel_cache.end()) {
         return it->second;
     }
 
-    void* kernel = g_mfa_create_kernel(
-        static_cast<int32_t>(seq_q),
-        static_cast<int32_t>(seq_kv),
-        static_cast<int32_t>(head_dim),
-        low_prec,
-        low_prec_outputs,
-        causal,
-        has_mask
-    );
+    void* kernel = nullptr;
+    if (use_bf16 && g_mfa_create_kernel_v2) {
+        // Use v2 API with BF16 support
+        kernel = g_mfa_create_kernel_v2(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask,
+            use_bf16
+        );
+    } else {
+        // Legacy API
+        kernel = g_mfa_create_kernel(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask
+        );
+    }
 
     if (!kernel) {
         throw std::runtime_error("Failed to create MFA kernel");
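A Python rendering of the caching logic above, to make the reason for the extra key field explicit (names are illustrative and _create_kernel is a hypothetical stand-in for the v2/legacy C calls):

_kernel_cache = {}

def get_or_create_kernel(seq_q, seq_kv, head_dim, low_prec, low_prec_outputs,
                         causal, has_mask, use_bf16=False):
    # use_bf16 must be part of the cache key: without it, a BF16 request for a
    # shape already compiled as FP16/FP32 would silently reuse the wrong kernel.
    key = (seq_q, seq_kv, head_dim, low_prec, low_prec_outputs,
           causal, has_mask, use_bf16)
    kernel = _kernel_cache.get(key)
    if kernel is None:
        kernel = _create_kernel(*key)
        _kernel_cache[key] = kernel
    return kernel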
@@ -261,18 +283,27 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         v_expanded = value;
     }
 
-    // Determine precision
-    bool low_precision = (query.scalar_type() == at::kHalf ||
-                          query.scalar_type() == at::kBFloat16);
+    // Determine precision - MFA kernel supports FP16, BF16, and FP32
+    bool is_bfloat16 = (query.scalar_type() == at::kBFloat16);
+    bool is_fp16 = (query.scalar_type() == at::kHalf);
 
-    // For fp16 inputs, we can now output directly to fp16 (no extra conversion needed!)
-    bool low_precision_outputs = low_precision;
+    // Use native BF16 kernel if available, otherwise fall back to FP32
+    bool use_bf16_kernel = is_bfloat16 && g_mfa_create_kernel_v2;
+    bool low_precision = is_fp16;  // FP16 path
+    bool low_precision_outputs = is_fp16 || use_bf16_kernel;
 
     // Make inputs contiguous
     auto q = query.contiguous();
     auto k = k_expanded.contiguous();
     auto v = v_expanded.contiguous();
 
+    // For BF16 without native kernel support, convert to FP32 (avoids FP16 overflow)
+    if (is_bfloat16 && !use_bf16_kernel) {
+        q = q.to(at::kFloat);
+        k = k.to(at::kFloat);
+        v = v.to(at::kFloat);
+    }
+
     // Handle attention mask
     bool has_mask = attn_mask.has_value();
     at::Tensor mask;
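A minimal Python sketch of the dtype-to-flag mapping this hunk introduces (illustrative; the authoritative logic is the C++ above):

import torch

def choose_precision_flags(dtype, bf16_kernel_available):
    is_bf16 = dtype == torch.bfloat16
    is_fp16 = dtype == torch.float16
    use_bf16_kernel = is_bf16 and bf16_kernel_available
    return {
        "use_bf16_kernel": use_bf16_kernel,
        "low_precision": is_fp16,
        "low_precision_outputs": is_fp16 or use_bf16_kernel,
        "convert_inputs_to_fp32": is_bf16 and not use_bf16_kernel,
    }

assert choose_precision_flags(torch.bfloat16, False)["convert_inputs_to_fp32"]
assert choose_precision_flags(torch.bfloat16, True)["low_precision_outputs"]
assert choose_precision_flags(torch.float16, True)["low_precision"]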
@@ -287,20 +318,28 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         }
         TORCH_CHECK(mask.scalar_type() == at::kByte,
                     "Attention mask must be bool or uint8");
-        // Expand mask heads if needed (B, 1, N_q, N_kv) -> (B, H, N_q, N_kv)
-        if (mask.size(1) == 1 && num_heads > 1) {
+        // Expand mask dimensions if needed -> (B, H, N_q, N_kv)
+        // Handle (B, 1, N_q, N_kv) -> expand heads
+        // Handle (B, H, 1, N_kv) -> expand query dim (1D key mask)
+        // Handle (B, 1, 1, N_kv) -> expand both
+        if (mask.size(1) == 1 || mask.size(2) == 1) {
             mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
         }
         mask = mask.contiguous();
     }
 
     // Allocate output in the appropriate precision
-    // With lowPrecisionOutputs=true, MFA writes FP16 directly
     at::Tensor output;
-    if (low_precision_outputs) {
+    if (use_bf16_kernel) {
+        // Native BF16 kernel outputs BF16
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kBFloat16));
+    } else if (low_precision_outputs) {
+        // FP16 kernel outputs FP16
         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                            query.options().dtype(at::kHalf));
     } else {
+        // FP32 kernel outputs FP32
         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                            query.options().dtype(at::kFloat));
     }
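The broadened expansion accepts any mask with a singleton head or query dimension; the broadcast it performs matches torch.Tensor.expand, e.g. (shapes illustrative):

import torch

B, H, N_q, N_kv = 2, 8, 1024, 1024
for shape in [(B, 1, N_q, N_kv),   # per-batch mask, shared across heads
              (B, H, 1, N_kv),     # key-padding mask, shared across queries
              (B, 1, 1, N_kv)]:    # key-padding mask, shared across heads and queries
    mask = torch.zeros(shape, dtype=torch.bool)
    assert mask.expand(B, H, N_q, N_kv).shape == (B, H, N_q, N_kv)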
@@ -309,8 +348,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
                                query.options().dtype(at::kFloat));
 
-    // Get or create kernel with matching output precision and causal mode
-    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask);
+    // Get or create kernel with matching precision and causal mode
+    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask, use_bf16_kernel);
 
     // Get Metal buffers with byte offsets
     auto q_info = getBufferInfo(q);
@@ -361,7 +400,12 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         // The encoder stays open for coalescing more kernels
     }
 
-    // Output is already in the correct dtype (fp16 or fp32)
+    // Convert output back to BF16 if input was BF16 and we used FP32 fallback
+    // (native BF16 kernel already outputs BF16, no conversion needed)
+    if (is_bfloat16 && !use_bf16_kernel) {
+        output = output.to(at::kBFloat16);
+    }
+
     // Return both output and logsumexp (needed for backward pass)
     return std::make_tuple(output, logsumexp);
 }
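Taken together with the allocation change above, BF16 inputs should now come back as BF16 whether the native BF16 kernel or the FP32 fallback ran; a quick check (illustrative, assuming an MPS device is available):

import torch
from mps_flash_attn import flash_attention

q = torch.randn(1, 8, 2048, 128, device="mps", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attention(q, k, v, is_causal=True)
# Native BF16 kernel writes BF16 directly; the FP32 fallback converts back.
assert out.dtype == torch.bfloat16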
--- mps_flash_attn-0.1.8/mps_flash_attn.egg-info/PKG-INFO
+++ mps_flash_attn-0.1.14/mps_flash_attn.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.1.8
+Version: 0.1.14
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
--- mps_flash_attn-0.1.8/pyproject.toml
+++ mps_flash_attn-0.1.14/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mps-flash-attn"
-version = "0.1.8"
+version = "0.1.14"
 description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
 readme = "README.md"
 license = "MIT"
--- mps_flash_attn-0.1.8/setup.py
+++ mps_flash_attn-0.1.14/setup.py
@@ -4,6 +4,7 @@ Setup script for MPS Flash Attention
 
 import os
 import sys
+import shutil
 from setuptools import setup, find_packages, Extension
 from setuptools.command.build_ext import build_ext
 
@@ -36,6 +37,26 @@ class ObjCppBuildExt(build_ext):
 
         super().build_extensions()
 
+        # Copy libMFABridge.dylib to lib/ after building
+        self._copy_swift_bridge()
+
+    def _copy_swift_bridge(self):
+        """Copy Swift bridge dylib to package lib/ directory."""
+        src_path = os.path.join(
+            os.path.dirname(__file__),
+            "swift-bridge", ".build", "release", "libMFABridge.dylib"
+        )
+        dst_dir = os.path.join(os.path.dirname(__file__), "mps_flash_attn", "lib")
+        dst_path = os.path.join(dst_dir, "libMFABridge.dylib")
+
+        if os.path.exists(src_path):
+            os.makedirs(dst_dir, exist_ok=True)
+            shutil.copy2(src_path, dst_path)
+            print(f"Copied libMFABridge.dylib to {dst_path}")
+        else:
+            print(f"Warning: {src_path} not found. Build swift-bridge first with:")
+            print(" cd swift-bridge && swift build -c release")
+
 
 def get_extensions():
     if sys.platform != "darwin":
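The build step now bundles the Swift bridge under mps_flash_attn/lib/ (the libMFABridge.dylib entries in the file list above); a quick post-install sanity check (illustrative):

import os
import mps_flash_attn

dylib = os.path.join(os.path.dirname(mps_flash_attn.__file__),
                     "lib", "libMFABridge.dylib")
print(dylib, "exists:", os.path.exists(dylib))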