mps-flash-attn 0.1.4__tar.gz → 0.1.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the registry flags this version of mps-flash-attn as possibly problematic.
Files changed (39)
  1. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/PKG-INFO +13 -3
  2. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/README.md +12 -2
  3. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/__init__.py +87 -46
  4. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/csrc/mps_flash_attn.mm +244 -72
  5. mps_flash_attn-0.1.13/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  6. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/PKG-INFO +13 -3
  7. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/pyproject.toml +1 -1
  8. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/setup.py +23 -2
  9. mps_flash_attn-0.1.4/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  10. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/LICENSE +0 -0
  11. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  12. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  13. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  14. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  15. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  16. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  17. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  18. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  19. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  20. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  21. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  22. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  23. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  24. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  25. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  26. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  27. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  28. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  29. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  30. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  31. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  32. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  33. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/manifest.json +0 -0
  34. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
  35. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  36. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/requires.txt +0 -0
  37. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/top_level.txt +0 -0
  38. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/setup.cfg +0 -0
  39. {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/tests/test_attention.py +0 -0

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.1.4
+ Version: 0.1.13
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
@@ -32,8 +32,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
  ## Features

  - **Forward pass**: 2-5x faster than PyTorch SDPA
- - **Backward pass**: Full gradient support for training
+ - **Backward pass**: Full gradient support for training (fp32 precision)
  - **Causal masking**: Native kernel support (only 5% overhead)
+ - **Attention masks**: Full boolean mask support for arbitrary masking patterns
  - **FP16/FP32**: Native fp16 output (no conversion overhead)
  - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -98,6 +99,16 @@ out = flash_attention(q, k, v)
  out = flash_attention(q, k, v, is_causal=True)
  ```

+ ### Attention masks (for custom masking patterns)
+
+ ```python
+ # Boolean mask: True = masked (don't attend), False = attend
+ mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+ mask[:, :, :, 512:] = True  # Mask out positions after 512
+
+ out = flash_attention(q, k, v, attn_mask=mask)
+ ```
+
  ### Training with gradients

  ```python
@@ -247,7 +258,6 @@ python scripts/build_metallibs.py
  **Known limitations:**
  - Sequence length must be divisible by block size (typically 64)
  - Head dimension: Best with 32, 64, 96, 128
- - No arbitrary attention masks (only causal or none)
  - No dropout

  ## Credits
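
The limitations kept in this release still require the sequence length to be a multiple of the kernel block size (typically 64). A minimal sketch of one way to combine that constraint with the new `attn_mask` support, assuming only the `flash_attention` signature documented above; `padded_flash_attention` and its `block` argument are hypothetical helpers, not part of the package:

```python
import math
import torch
from mps_flash_attn import flash_attention

def padded_flash_attention(q, k, v, block: int = 64):
    """Zero-pad the sequence up to a multiple of `block`, mask the padded keys,
    then slice the output back to the original length. Illustrative only."""
    B, H, N, D = q.shape
    N_pad = math.ceil(N / block) * block
    if N_pad == N:
        return flash_attention(q, k, v)

    pad = N_pad - N
    # Pad along the sequence dimension (dim=2) with zeros.
    q_p = torch.nn.functional.pad(q, (0, 0, 0, pad))
    k_p = torch.nn.functional.pad(k, (0, 0, 0, pad))
    v_p = torch.nn.functional.pad(v, (0, 0, 0, pad))

    # True = masked: hide the padded key positions from every query.
    mask = torch.zeros(B, 1, N_pad, N_pad, dtype=torch.bool, device=q.device)
    mask[:, :, :, N:] = True

    out = flash_attention(q_p, k_p, v_p, attn_mask=mask)
    return out[:, :, :N, :]  # drop the padded query rows
```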

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/README.md

@@ -7,8 +7,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
  ## Features

  - **Forward pass**: 2-5x faster than PyTorch SDPA
- - **Backward pass**: Full gradient support for training
+ - **Backward pass**: Full gradient support for training (fp32 precision)
  - **Causal masking**: Native kernel support (only 5% overhead)
+ - **Attention masks**: Full boolean mask support for arbitrary masking patterns
  - **FP16/FP32**: Native fp16 output (no conversion overhead)
  - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -73,6 +74,16 @@ out = flash_attention(q, k, v)
  out = flash_attention(q, k, v, is_causal=True)
  ```

+ ### Attention masks (for custom masking patterns)
+
+ ```python
+ # Boolean mask: True = masked (don't attend), False = attend
+ mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+ mask[:, :, :, 512:] = True  # Mask out positions after 512
+
+ out = flash_attention(q, k, v, attn_mask=mask)
+ ```
+
  ### Training with gradients

  ```python
@@ -222,7 +233,6 @@ python scripts/build_metallibs.py
  **Known limitations:**
  - Sequence length must be divisible by block size (typically 64)
  - Head dimension: Best with 32, 64, 96, 128
- - No arbitrary attention masks (only causal or none)
  - No dropout

  ## Credits

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
  This package provides memory-efficient attention using Metal Flash Attention kernels.
  """

- __version__ = "0.1.4"
+ __version__ = "0.1.13"

  import torch
  from typing import Optional
@@ -20,39 +20,9 @@ except ImportError as e:
      _HAS_MFA = False
      _IMPORT_ERROR = str(e)

- # Set up shipped kernels directory for zero-compilation loading
- def _init_shipped_kernels():
-     """Point the Swift bridge to pre-shipped kernel binaries."""
-     try:
-         import ctypes
-         bridge_path = os.environ.get("MFA_BRIDGE_PATH")
-         if not bridge_path:
-             module_dir = os.path.dirname(__file__)
-             candidates = [
-                 os.path.join(module_dir, "lib", "libMFABridge.dylib"),  # Bundled in wheel
-                 os.path.join(module_dir, "..", "swift-bridge", ".build", "release", "libMFABridge.dylib"),
-                 os.path.join(module_dir, "libMFABridge.dylib"),
-             ]
-             for path in candidates:
-                 if os.path.exists(path):
-                     bridge_path = path
-                     break
-
-         if bridge_path and os.path.exists(bridge_path):
-             lib = ctypes.CDLL(bridge_path)
-
-             # Set shipped kernels directory (pre-compiled metallibs + pipeline binaries)
-             kernels_dir = os.path.join(os.path.dirname(__file__), "kernels")
-             if os.path.exists(kernels_dir):
-                 lib.mfa_set_kernels_dir(kernels_dir.encode('utf-8'))
-
-             lib.mfa_init()
-     except Exception:
-         pass  # Init is optional, will fall back to runtime compilation
-
- # Initialize shipped kernels on import
- if _HAS_MFA:
-     _init_shipped_kernels()
+ # Note: The C++ extension handles loading libMFABridge.dylib via dlopen.
+ # Set MFA_BRIDGE_PATH environment variable to specify the library location.
+ # Do NOT load the library here via ctypes - that causes duplicate class warnings.


  def is_available() -> bool:
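
The note that replaces the old ctypes loader says the C++ extension now dlopen()s libMFABridge.dylib itself and honours MFA_BRIDGE_PATH. A small sketch of how a caller might point at a custom bridge build before first use; the path shown is a placeholder for illustration, not something the package ships:

```python
import os

# MFA_BRIDGE_PATH is read by the C++ extension when it dlopen()s the bridge,
# so set it before the first attention call triggers loading.
# "/opt/mfa/libMFABridge.dylib" is a hypothetical example path.
os.environ["MFA_BRIDGE_PATH"] = "/opt/mfa/libMFABridge.dylib"

import mps_flash_attn
print("MFA available:", mps_flash_attn.is_available())
```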
@@ -60,11 +30,35 @@ def is_available() -> bool:
      return _HAS_MFA and torch.backends.mps.is_available()


+ def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+     """
+     Convert attention mask to MFA's boolean format.
+
+     MFA uses boolean masks where True = masked (don't attend).
+     PyTorch SDPA uses additive float masks where -inf/large negative = masked.
+
+     Args:
+         attn_mask: Optional mask, either:
+             - None: no mask
+             - bool tensor: already in MFA format (True = masked)
+             - float tensor: additive mask (large negative = masked)
+
+     Returns:
+         Boolean mask suitable for flash_attention(), or None
+     """
+     if attn_mask is None:
+         return None
+     if attn_mask.dtype == torch.bool:
+         return attn_mask
+     # Float mask: large negative values indicate masked positions
+     return attn_mask <= -1e3
+
+
  class FlashAttentionFunction(torch.autograd.Function):
      """Autograd function for Flash Attention with backward pass support."""

      @staticmethod
-     def forward(ctx, query, key, value, is_causal, scale):
+     def forward(ctx, query, key, value, is_causal, scale, attn_mask):
          # Apply scale if provided (MFA uses 1/sqrt(D) internally)
          scale_factor = 1.0
          if scale is not None:
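
`convert_mask` bridges the two mask conventions: PyTorch SDPA's additive float masks and MFA's boolean masks. A quick illustration of the conversion, assuming `convert_mask` is importable from the package top level as defined in the hunk above:

```python
import torch
from mps_flash_attn import convert_mask

B, N = 2, 128
# Additive mask in SDPA convention: 0.0 = attend, -inf (or a large negative) = masked.
additive = torch.zeros(B, 1, N, N)
additive[:, :, :, 64:] = float("-inf")

bool_mask = convert_mask(additive)  # True = masked, in MFA convention
assert bool_mask.dtype == torch.bool
assert bool_mask[:, :, :, 64:].all() and not bool_mask[:, :, :, :64].any()

# Boolean masks pass through unchanged, and None stays None.
assert convert_mask(bool_mask) is bool_mask
assert convert_mask(None) is None
```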
@@ -74,10 +68,15 @@ class FlashAttentionFunction(torch.autograd.Function):
              query = query * scale_factor

          # Forward with logsumexp for backward
-         output, logsumexp = _C.forward_with_lse(query, key, value, is_causal)
+         output, logsumexp = _C.forward_with_lse(query, key, value, is_causal, attn_mask)

          # Save for backward
-         ctx.save_for_backward(query, key, value, output, logsumexp)
+         if attn_mask is not None:
+             ctx.save_for_backward(query, key, value, output, logsumexp, attn_mask)
+             ctx.has_mask = True
+         else:
+             ctx.save_for_backward(query, key, value, output, logsumexp)
+             ctx.has_mask = False
          ctx.is_causal = is_causal
          ctx.scale_factor = scale_factor

@@ -85,19 +84,23 @@

      @staticmethod
      def backward(ctx, grad_output):
-         query, key, value, output, logsumexp = ctx.saved_tensors
+         if ctx.has_mask:
+             query, key, value, output, logsumexp, attn_mask = ctx.saved_tensors
+         else:
+             query, key, value, output, logsumexp = ctx.saved_tensors
+             attn_mask = None

          # Compute gradients
          dQ, dK, dV = _C.backward(
-             grad_output, query, key, value, output, logsumexp, ctx.is_causal
+             grad_output, query, key, value, output, logsumexp, ctx.is_causal, attn_mask
          )

          # If we scaled the query in forward, scale the gradient back
          if ctx.scale_factor != 1.0:
              dQ = dQ * ctx.scale_factor

-         # Return gradients (None for is_causal and scale since they're not tensors)
-         return dQ, dK, dV, None, None
+         # Return gradients (None for is_causal, scale, and attn_mask since they're not tensors or don't need grad)
+         return dQ, dK, dV, None, None, None


  def flash_attention(
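
With the mask now saved in `ctx` and threaded through `_C.backward`, gradients can be sanity-checked against PyTorch's SDPA. A minimal parity sketch, assuming an MPS device and shapes within the documented limits; the tolerances you see will depend on precision and are illustrative only:

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

torch.manual_seed(0)
B, H, N, D = 1, 2, 128, 64  # N divisible by the 64-wide block
q = torch.randn(B, H, N, D, device="mps", requires_grad=True)
k = torch.randn(B, H, N, D, device="mps", requires_grad=True)
v = torch.randn(B, H, N, D, device="mps", requires_grad=True)

flash_attention(q, k, v, is_causal=True).sum().backward()
grads_mfa = [t.grad.clone() for t in (q, k, v)]

q.grad = k.grad = v.grad = None
F.scaled_dot_product_attention(q, k, v, is_causal=True).sum().backward()

for g_mfa, t in zip(grads_mfa, (q, k, v)):
    print("max |dMFA - dSDPA|:", (g_mfa - t.grad).abs().max().item())
```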
@@ -106,6 +109,7 @@ def flash_attention(
      value: torch.Tensor,
      is_causal: bool = False,
      scale: Optional[float] = None,
+     attn_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """
      Compute scaled dot-product attention using Flash Attention on MPS.
@@ -121,6 +125,9 @@
          value: Value tensor of shape (B, num_heads, seq_len, head_dim)
          is_causal: If True, applies causal masking (for autoregressive models)
          scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
+         attn_mask: Optional boolean attention mask of shape (B, 1, seq_len_q, seq_len_kv)
+             or (B, num_heads, seq_len_q, seq_len_kv). True values indicate
+             positions to be masked (not attended to).

      Returns:
          Output tensor of shape (B, num_heads, seq_len, head_dim)
@@ -137,6 +144,11 @@
          >>> q.requires_grad = True
          >>> out = flash_attention(q, k, v)
          >>> out.sum().backward()  # Computes dQ
+
+         # With attention mask:
+         >>> mask = torch.zeros(2, 1, 4096, 4096, dtype=torch.bool, device='mps')
+         >>> mask[:, :, :, 2048:] = True  # mask out second half of keys
+         >>> out = flash_attention(q, k, v, attn_mask=mask)
      """
      if not _HAS_MFA:
          raise RuntimeError(
@@ -154,9 +166,23 @@
          raise ValueError("key must be on MPS device")
      if value.device.type != 'mps':
          raise ValueError("value must be on MPS device")
+     if attn_mask is not None and attn_mask.device.type != 'mps':
+         raise ValueError("attn_mask must be on MPS device")
+
+     # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
+     if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
+         # Apply scale if provided
+         if scale is not None:
+             default_scale = 1.0 / math.sqrt(query.shape[-1])
+             if abs(scale - default_scale) > 1e-6:
+                 scale_factor = scale / default_scale
+                 query = query * scale_factor
+
+         # Forward only - no logsumexp needed, no tensors saved
+         return _C.forward(query, key, value, is_causal, attn_mask)

      # Use autograd function for gradient support
-     return FlashAttentionFunction.apply(query, key, value, is_causal, scale)
+     return FlashAttentionFunction.apply(query, key, value, is_causal, scale, attn_mask)


  def replace_sdpa():
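
The hunks above add an inference fast path: when autograd is disabled (or no input requires grad), `flash_attention` calls `_C.forward` directly, skipping the logsumexp buffer and the autograd bookkeeping. A short usage sketch of when each path is taken, assuming an MPS device:

```python
import torch
from mps_flash_attn import flash_attention

q = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Fast path: grad disabled, so only _C.forward runs and nothing is saved for backward.
with torch.no_grad():
    out = flash_attention(q, k, v, is_causal=True)

# Plain tensors (requires_grad=False) also take the fast path outside no_grad().
out = flash_attention(q, k, v, is_causal=True)

# Autograd path: any input requiring grad routes through FlashAttentionFunction.apply.
q.requires_grad_(True)
out = flash_attention(q, k, v, is_causal=True)
```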
@@ -173,13 +199,28 @@

      def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                       is_causal=False, scale=None):
-         # Use MFA for MPS tensors without mask/dropout
+         # Use MFA for MPS tensors without dropout
+         # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
          if (query.device.type == 'mps' and
-             attn_mask is None and
              dropout_p == 0.0 and
-             _HAS_MFA):
+             _HAS_MFA and
+             query.shape[2] >= 1024):
              try:
-                 return flash_attention(query, key, value, is_causal=is_causal, scale=scale)
+                 # Convert float mask to bool mask if needed
+                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
+                 # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
+                 mfa_mask = None
+                 if attn_mask is not None:
+                     if attn_mask.dtype == torch.bool:
+                         # Boolean mask: True means masked (don't attend)
+                         mfa_mask = attn_mask
+                     else:
+                         # Float mask: typically -inf for masked positions, 0 for unmasked
+                         # Convert: positions with large negative values -> True (masked)
+                         # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
+                         mfa_mask = attn_mask <= -1e3
+                 return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
              except Exception:
                  # Fall back to original on any error
                  pass
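
The patched SDPA only takes over for MPS tensors with no dropout and seq_len >= 1024, and it converts float masks with the same `<= -1e3` threshold. A hedged usage sketch; whether a given call is routed to MFA or falls back is decided at runtime by those checks:

```python
import torch
import torch.nn.functional as F
import mps_flash_attn

# Monkey-patch F.scaled_dot_product_attention so existing model code can pick up MFA.
mps_flash_attn.replace_sdpa()

q = torch.randn(1, 8, 2048, 64, device="mps", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# seq_len >= 1024, no dropout: eligible for the MFA path (float mask converted internally).
mask = torch.zeros(1, 1, 2048, 2048, device="mps", dtype=q.dtype)
mask[:, :, :, 1024:] = float("-inf")
out_long = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Short sequences stay on PyTorch's own kernels per the seq_len threshold above.
q_s = torch.randn(1, 8, 256, 64, device="mps", dtype=torch.float16)
out_short = F.scaled_dot_product_attention(q_s, q_s, q_s)
```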

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/csrc/mps_flash_attn.mm

@@ -12,22 +12,39 @@
  #import <Foundation/Foundation.h>

  #include <dlfcn.h>
+ #include <string>
+ #include <vector>

  // ============================================================================
  // MFA Bridge Function Types
  // ============================================================================

  typedef bool (*mfa_init_fn)();
- typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool); // Added causal param
- typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
- typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
-                                 int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+ typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool);
+ typedef void* (*mfa_create_kernel_v2_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool);
+ // New zero-sync encode functions that take PyTorch's command encoder
+ // Added mask_ptr and mask_offset parameters
+ typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
+                                       int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                       int32_t, int32_t);
+ typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                        int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                        int32_t, int32_t);
+ // Legacy sync functions (fallback)
+ typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, void*,
+                                int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                int32_t, int32_t);
+ typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                 int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                  int32_t, int32_t);
  typedef void (*mfa_release_kernel_fn)(void*);

  // Global function pointers
  static mfa_init_fn g_mfa_init = nullptr;
  static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
+ static mfa_create_kernel_v2_fn g_mfa_create_kernel_v2 = nullptr;
+ static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
+ static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
  static mfa_forward_fn g_mfa_forward = nullptr;
  static mfa_backward_fn g_mfa_backward = nullptr;
  static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
@@ -38,31 +55,44 @@ static bool g_initialized = false;
  // Load MFA Bridge Library
  // ============================================================================

+ // Get the directory containing this shared library
+ static std::string get_module_dir() {
+     Dl_info info;
+     if (dladdr((void*)get_module_dir, &info) && info.dli_fname) {
+         std::string path(info.dli_fname);
+         size_t last_slash = path.rfind('/');
+         if (last_slash != std::string::npos) {
+             return path.substr(0, last_slash);
+         }
+     }
+     return ".";
+ }
+
  static bool load_mfa_bridge() {
      if (g_dylib_handle) return true;

-     // Try to find the dylib relative to this extension
-     // First try the standard location
-     const char* paths[] = {
-         "libMFABridge.dylib",
-         "./libMFABridge.dylib",
-         "../swift-bridge/.build/release/libMFABridge.dylib",
-         nullptr
+     // First check environment variable (highest priority)
+     const char* mfa_path = getenv("MFA_BRIDGE_PATH");
+     if (mfa_path) {
+         g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
+         if (g_dylib_handle) return true;
+     }
+
+     // Get the directory containing this extension module
+     std::string module_dir = get_module_dir();
+
+     // Try paths relative to the module directory
+     std::vector<std::string> paths = {
+         module_dir + "/lib/libMFABridge.dylib",                              // Bundled in wheel
+         module_dir + "/../swift-bridge/.build/release/libMFABridge.dylib",  // Dev build
+         "libMFABridge.dylib",                                               // Current directory fallback
      };

-     for (int i = 0; paths[i] != nullptr; i++) {
-         g_dylib_handle = dlopen(paths[i], RTLD_NOW);
+     for (const auto& path : paths) {
+         g_dylib_handle = dlopen(path.c_str(), RTLD_NOW);
          if (g_dylib_handle) break;
      }

-     if (!g_dylib_handle) {
-         // Try with absolute path from environment
-         const char* mfa_path = getenv("MFA_BRIDGE_PATH");
-         if (mfa_path) {
-             g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
-         }
-     }
-
      if (!g_dylib_handle) {
          throw std::runtime_error(
              "Failed to load libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.");
@@ -71,11 +101,15 @@ static bool load_mfa_bridge() {
      // Load function pointers
      g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
      g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
+     g_mfa_create_kernel_v2 = (mfa_create_kernel_v2_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v2");
+     g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
+     g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
      g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
      g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
      g_mfa_release_kernel = (mfa_release_kernel_fn)dlsym(g_dylib_handle, "mfa_release_kernel");

-     if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward || !g_mfa_backward || !g_mfa_release_kernel) {
+     // Require at least init, create_kernel, forward_encode (for zero-sync path)
+     if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward_encode) {
          throw std::runtime_error("Failed to load MFA bridge functions");
      }

@@ -116,6 +150,8 @@ struct KernelCacheKey {
      bool low_precision;
      bool low_precision_outputs;
      bool causal;
+     bool has_mask;
+     bool use_bf16;

      bool operator==(const KernelCacheKey& other) const {
          return seq_len_q == other.seq_len_q &&
@@ -123,7 +159,9 @@
                 head_dim == other.head_dim &&
                 low_precision == other.low_precision &&
                 low_precision_outputs == other.low_precision_outputs &&
-                causal == other.causal;
+                causal == other.causal &&
+                has_mask == other.has_mask &&
+                use_bf16 == other.use_bf16;
      }
  };

@@ -134,28 +172,47 @@ struct KernelCacheKeyHash {
             (std::hash<int64_t>()(k.head_dim) << 2) ^
             (std::hash<bool>()(k.low_precision) << 3) ^
             (std::hash<bool>()(k.low_precision_outputs) << 4) ^
-            (std::hash<bool>()(k.causal) << 5);
+            (std::hash<bool>()(k.causal) << 5) ^
+            (std::hash<bool>()(k.has_mask) << 6) ^
+            (std::hash<bool>()(k.use_bf16) << 7);
      }
  };

  static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;

- static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal) {
-     KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal};
+ static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask, bool use_bf16 = false) {
+     KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16};

      auto it = g_kernel_cache.find(key);
      if (it != g_kernel_cache.end()) {
          return it->second;
      }

-     void* kernel = g_mfa_create_kernel(
-         static_cast<int32_t>(seq_q),
-         static_cast<int32_t>(seq_kv),
-         static_cast<int32_t>(head_dim),
-         low_prec,
-         low_prec_outputs,
-         causal
-     );
+     void* kernel = nullptr;
+     if (use_bf16 && g_mfa_create_kernel_v2) {
+         // Use v2 API with BF16 support
+         kernel = g_mfa_create_kernel_v2(
+             static_cast<int32_t>(seq_q),
+             static_cast<int32_t>(seq_kv),
+             static_cast<int32_t>(head_dim),
+             low_prec,
+             low_prec_outputs,
+             causal,
+             has_mask,
+             use_bf16
+         );
+     } else {
+         // Legacy API
+         kernel = g_mfa_create_kernel(
+             static_cast<int32_t>(seq_q),
+             static_cast<int32_t>(seq_kv),
+             static_cast<int32_t>(head_dim),
+             low_prec,
+             low_prec_outputs,
+             causal,
+             has_mask
+         );
+     }

      if (!kernel) {
          throw std::runtime_error("Failed to create MFA kernel");
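
The cache key now carries `has_mask` and `use_bf16`, so a pipeline specialised without mask support (or for FP16) is never reused for masked or BF16 calls. A Python analogue of that lookup, purely illustrative; the real cache is the `std::unordered_map` above, and `create_fn` here stands in for the bridge's kernel constructors:

```python
from typing import Any, Callable, Dict, Tuple

KernelKey = Tuple[int, int, int, bool, bool, bool, bool, bool]
_kernel_cache: Dict[KernelKey, Any] = {}

def get_or_create_kernel(seq_q: int, seq_kv: int, head_dim: int,
                         low_prec: bool, low_prec_out: bool, causal: bool,
                         has_mask: bool, use_bf16: bool,
                         create_fn: Callable[..., Any]) -> Any:
    # has_mask and use_bf16 are part of the key on purpose: without them a
    # cached kernel built for the unmasked / FP16 case would be handed out
    # for masked or BF16 calls, which is exactly what 0.1.13 prevents.
    key: KernelKey = (seq_q, seq_kv, head_dim,
                      low_prec, low_prec_out, causal, has_mask, use_bf16)
    kernel = _kernel_cache.get(key)
    if kernel is None:
        kernel = create_fn(*key)
        _kernel_cache[key] = kernel
    return kernel
```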
@@ -173,7 +230,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
      const at::Tensor& query,   // (B, H, N, D)
      const at::Tensor& key,     // (B, H, N, D)
      const at::Tensor& value,   // (B, H, N, D)
-     bool is_causal
+     bool is_causal,
+     const c10::optional<at::Tensor>& attn_mask  // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
  ) {
      // Initialize MFA on first call
      if (!g_initialized) {
@@ -193,37 +251,95 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
      TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");

      const int64_t batch_size = query.size(0);
-     const int64_t num_heads = query.size(1);
+     const int64_t num_heads_q = query.size(1);
+     const int64_t num_heads_kv = key.size(1);
      const int64_t seq_len_q = query.size(2);
      const int64_t head_dim = query.size(3);
      const int64_t seq_len_kv = key.size(2);

      TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
                  "Batch size mismatch");
-     TORCH_CHECK(key.size(1) == num_heads && value.size(1) == num_heads,
-                 "Number of heads mismatch");
      TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
                  "Head dimension mismatch");
+     TORCH_CHECK(key.size(1) == value.size(1),
+                 "K and V must have same number of heads");
+
+     // Handle GQA (Grouped Query Attention): expand K/V if fewer heads than Q
+     const int64_t num_heads = num_heads_q;
+     at::Tensor k_expanded, v_expanded;
+
+     if (num_heads_kv != num_heads_q) {
+         // GQA: num_heads_q must be divisible by num_heads_kv
+         TORCH_CHECK(num_heads_q % num_heads_kv == 0,
+                     "num_heads_q (", num_heads_q, ") must be divisible by num_heads_kv (", num_heads_kv, ")");
+         int64_t repeat_factor = num_heads_q / num_heads_kv;
+
+         // Expand K and V to match Q's head count: (B, H_kv, S, D) -> (B, H_q, S, D)
+         // Use repeat_interleave for proper GQA expansion
+         k_expanded = key.repeat_interleave(repeat_factor, /*dim=*/1);
+         v_expanded = value.repeat_interleave(repeat_factor, /*dim=*/1);
+     } else {
+         k_expanded = key;
+         v_expanded = value;
+     }

-     // Determine precision
-     bool low_precision = (query.scalar_type() == at::kHalf ||
-                           query.scalar_type() == at::kBFloat16);
+     // Determine precision - MFA kernel supports FP16, BF16, and FP32
+     bool is_bfloat16 = (query.scalar_type() == at::kBFloat16);
+     bool is_fp16 = (query.scalar_type() == at::kHalf);

-     // For fp16 inputs, we can now output directly to fp16 (no extra conversion needed!)
-     bool low_precision_outputs = low_precision;
+     // Use native BF16 kernel if available, otherwise fall back to FP32
+     bool use_bf16_kernel = is_bfloat16 && g_mfa_create_kernel_v2;
+     bool low_precision = is_fp16;  // FP16 path
+     bool low_precision_outputs = is_fp16 || use_bf16_kernel;

      // Make inputs contiguous
      auto q = query.contiguous();
-     auto k = key.contiguous();
-     auto v = value.contiguous();
+     auto k = k_expanded.contiguous();
+     auto v = v_expanded.contiguous();
+
+     // For BF16 without native kernel support, convert to FP32 (avoids FP16 overflow)
+     if (is_bfloat16 && !use_bf16_kernel) {
+         q = q.to(at::kFloat);
+         k = k.to(at::kFloat);
+         v = v.to(at::kFloat);
+     }
+
+     // Handle attention mask
+     bool has_mask = attn_mask.has_value();
+     at::Tensor mask;
+     if (has_mask) {
+         mask = attn_mask.value();
+         TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
+         TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
+         // Convert to bool/uint8 if needed - kernel expects uchar (0 = attend, non-0 = mask out)
+         if (mask.scalar_type() == at::kBool) {
+             // Convert bool to uint8 for Metal compatibility
+             mask = mask.to(at::kByte);
+         }
+         TORCH_CHECK(mask.scalar_type() == at::kByte,
+                     "Attention mask must be bool or uint8");
+         // Expand mask dimensions if needed -> (B, H, N_q, N_kv)
+         // Handle (B, 1, N_q, N_kv) -> expand heads
+         // Handle (B, H, 1, N_kv) -> expand query dim (1D key mask)
+         // Handle (B, 1, 1, N_kv) -> expand both
+         if (mask.size(1) == 1 || mask.size(2) == 1) {
+             mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
+         }
+         mask = mask.contiguous();
+     }

      // Allocate output in the appropriate precision
-     // With lowPrecisionOutputs=true, MFA writes FP16 directly
      at::Tensor output;
-     if (low_precision_outputs) {
+     if (use_bf16_kernel) {
+         // Native BF16 kernel outputs BF16
+         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                            query.options().dtype(at::kBFloat16));
+     } else if (low_precision_outputs) {
+         // FP16 kernel outputs FP16
          output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                             query.options().dtype(at::kHalf));
      } else {
+         // FP32 kernel outputs FP32
          output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                             query.options().dtype(at::kFloat));
      }
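
The forward path now expands grouped-query K/V heads with `repeat_interleave` and normalises masks to a contiguous `(B, H, N_q, N_kv)` uint8 tensor. The same transformations expressed in plain PyTorch, as a shape-level illustration of what the extension does before dispatch (shapes chosen arbitrarily):

```python
import torch

B, Hq, Hkv, N, D = 2, 8, 2, 1024, 64
q = torch.randn(B, Hq, N, D)
k = torch.randn(B, Hkv, N, D)
v = torch.randn(B, Hkv, N, D)

# GQA: each KV head serves Hq // Hkv query heads, so K/V are repeated
# head-wise to match Q before the kernel sees them.
repeat = Hq // Hkv
k_expanded = k.repeat_interleave(repeat, dim=1)  # (B, Hq, N, D)
v_expanded = v.repeat_interleave(repeat, dim=1)

# Mask normalisation: bool -> uint8 (0 = attend, non-zero = masked),
# then broadcast (B, 1, N, N) up to (B, Hq, N, N) and make it contiguous.
mask = torch.zeros(B, 1, N, N, dtype=torch.bool)
mask[:, :, :, 512:] = True
mask_u8 = mask.to(torch.uint8).expand(B, Hq, N, N).contiguous()

print(k_expanded.shape, mask_u8.shape)
# torch.Size([2, 8, 1024, 64]) torch.Size([2, 8, 1024, 1024])
```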
@@ -232,8 +348,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
      auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
                                 query.options().dtype(at::kFloat));

-     // Get or create kernel with matching output precision and causal mode
-     void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+     // Get or create kernel with matching precision and causal mode
+     void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask, use_bf16_kernel);

      // Get Metal buffers with byte offsets
      auto q_info = getBufferInfo(q);
@@ -242,25 +358,36 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
      auto o_info = getBufferInfo(output);
      auto l_info = getBufferInfo(logsumexp);

-     // Synchronize with PyTorch's MPS stream
+     // Mask buffer info (may be nullptr if no mask)
+     BufferInfo mask_info = {nil, 0};
+     if (has_mask) {
+         mask_info = getBufferInfo(mask);
+     }
+
+     // Use PyTorch's MPS stream command encoder for zero-sync integration
      @autoreleasepool {
-         // Wait for PyTorch operations to complete
          auto stream = at::mps::getCurrentMPSStream();
-         stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);

-         // Execute MFA with storage byte offsets
-         bool success = g_mfa_forward(
+         // Get PyTorch's shared command encoder - this is the key for zero-sync!
+         // All our dispatches go onto the same encoder that PyTorch uses.
+         id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+         // Execute MFA using the shared encoder (no sync needed!)
+         bool success = g_mfa_forward_encode(
              kernel,
+             (__bridge void*)encoder,  // PyTorch's shared command encoder
              (__bridge void*)q_info.buffer,
              (__bridge void*)k_info.buffer,
              (__bridge void*)v_info.buffer,
              (__bridge void*)o_info.buffer,
              (__bridge void*)l_info.buffer,
+             has_mask ? (__bridge void*)mask_info.buffer : nullptr,
              q_info.byte_offset,
              k_info.byte_offset,
              v_info.byte_offset,
              o_info.byte_offset,
              l_info.byte_offset,
+             mask_info.byte_offset,
              static_cast<int32_t>(batch_size),
              static_cast<int32_t>(num_heads)
          );
@@ -268,9 +395,17 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
          if (!success) {
              throw std::runtime_error("MFA forward pass failed");
          }
+
+         // No commit needed - PyTorch will commit when it needs the results
+         // The encoder stays open for coalescing more kernels
+     }
+
+     // Convert output back to BF16 if input was BF16 and we used FP32 fallback
+     // (native BF16 kernel already outputs BF16, no conversion needed)
+     if (is_bfloat16 && !use_bf16_kernel) {
+         output = output.to(at::kBFloat16);
      }

-     // Output is already in the correct dtype (fp16 or fp32)
      // Return both output and logsumexp (needed for backward pass)
      return std::make_tuple(output, logsumexp);
  }
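
Because work is now encoded onto PyTorch's shared MPS command encoder and never force-committed, naive wall-clock timing can mislead: kernels may still be queued when the Python call returns. A hedged benchmarking sketch that brackets the timed region with `torch.mps.synchronize()`; absolute numbers will vary by chip, dtype, and shape:

```python
import time
import torch
from mps_flash_attn import flash_attention

q = torch.randn(1, 16, 4096, 64, device="mps", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

def bench(fn, iters=20):
    fn()                     # warm-up (kernel creation / cache fill)
    torch.mps.synchronize()  # drain pending GPU work before timing
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.mps.synchronize()  # encoded work is only guaranteed done after sync
    return (time.perf_counter() - t0) / iters

print("mfa :", bench(lambda: flash_attention(q, k, v, is_causal=True)))
print("sdpa:", bench(lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)))
```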
@@ -280,9 +415,10 @@ at::Tensor mps_flash_attention_forward(
      const at::Tensor& query,
      const at::Tensor& key,
      const at::Tensor& value,
-     bool is_causal
+     bool is_causal,
+     const c10::optional<at::Tensor>& attn_mask
  ) {
-     auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal);
+     auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal, attn_mask);
      return output;
  }

@@ -297,7 +433,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
      const at::Tensor& value,      // (B, H, N, D)
      const at::Tensor& output,     // (B, H, N, D)
      const at::Tensor& logsumexp,  // (B, H, N)
-     bool is_causal
+     bool is_causal,
+     const c10::optional<at::Tensor>& attn_mask  // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
  ) {
      // Initialize MFA on first call
      if (!g_initialized) {
@@ -327,16 +464,35 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
                            query.scalar_type() == at::kBFloat16);
      bool low_precision_outputs = low_precision;

-     // Make inputs contiguous
-     auto q = query.contiguous();
-     auto k = key.contiguous();
-     auto v = value.contiguous();
-     auto o = output.contiguous();
-     auto dO = grad_output.contiguous();
+     // Handle attention mask
+     bool has_mask = attn_mask.has_value();
+     at::Tensor mask;
+     if (has_mask) {
+         mask = attn_mask.value();
+         TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
+         TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
+         if (mask.scalar_type() == at::kBool) {
+             mask = mask.to(at::kByte);
+         }
+         TORCH_CHECK(mask.scalar_type() == at::kByte,
+                     "Attention mask must be bool or uint8");
+         if (mask.size(1) == 1 && num_heads > 1) {
+             mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
+         }
+         mask = mask.contiguous();
+     }
+
+     // Make inputs contiguous and upcast to fp32 for numerical stability
+     // The backward pass accumulates many small values, so fp32 precision is critical
+     auto q = query.contiguous().to(at::kFloat);
+     auto k = key.contiguous().to(at::kFloat);
+     auto v = value.contiguous().to(at::kFloat);
+     auto o = output.contiguous().to(at::kFloat);
+     auto dO = grad_output.contiguous().to(at::kFloat);
      auto lse = logsumexp.contiguous();

-     // Get or create kernel (with causal mode)
-     void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+     // Get or create kernel - always use fp32 for backward pass
+     void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, false, false, is_causal, has_mask);

      // Allocate D buffer (dO * O reduction, always fp32)
      auto D = at::empty({batch_size, num_heads, seq_len_q},
@@ -362,13 +518,22 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
      auto dk_info = getBufferInfo(dK);
      auto dv_info = getBufferInfo(dV);

-     // Execute backward pass
+     // Mask buffer info (may be nullptr if no mask)
+     BufferInfo mask_info = {nil, 0};
+     if (has_mask) {
+         mask_info = getBufferInfo(mask);
+     }
+
+     // Use PyTorch's MPS stream command encoder for zero-sync integration
      @autoreleasepool {
          auto stream = at::mps::getCurrentMPSStream();
-         stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);

-         bool success = g_mfa_backward(
+         // Get PyTorch's shared command encoder
+         id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+         bool success = g_mfa_backward_encode(
              kernel,
+             (__bridge void*)encoder,  // PyTorch's shared command encoder
              (__bridge void*)q_info.buffer,
              (__bridge void*)k_info.buffer,
              (__bridge void*)v_info.buffer,
@@ -379,6 +544,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
              (__bridge void*)dq_info.buffer,
              (__bridge void*)dk_info.buffer,
              (__bridge void*)dv_info.buffer,
+             has_mask ? (__bridge void*)mask_info.buffer : nullptr,
              q_info.byte_offset,
              k_info.byte_offset,
              v_info.byte_offset,
@@ -389,6 +555,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
              dq_info.byte_offset,
              dk_info.byte_offset,
              dv_info.byte_offset,
+             mask_info.byte_offset,
              static_cast<int32_t>(batch_size),
              static_cast<int32_t>(num_heads)
          );
@@ -396,6 +563,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
          if (!success) {
              throw std::runtime_error("MFA backward pass failed");
          }
+
+         // No commit needed - PyTorch will commit when it needs the results
      }

      // Convert gradients back to input dtype if needed
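
The backward path upcasts Q, K, V, O and dO to fp32 because its gradient reductions accumulate many small terms. A toy illustration of that failure mode in isolation (not a measurement of the kernel): repeated fp16 accumulation stalls once the running sum's spacing exceeds the addend.

```python
import torch

# Accumulate 10,000 contributions of 1e-4 step by step (exact answer: 1.0).
# Each intermediate result is stored in the accumulator's own dtype.
acc16 = torch.tensor(0.0, dtype=torch.float16)
acc32 = torch.tensor(0.0, dtype=torch.float32)
step16 = torch.tensor(1e-4, dtype=torch.float16)
step32 = torch.tensor(1e-4, dtype=torch.float32)

for _ in range(10_000):
    acc16 = acc16 + step16  # stalls near 0.25, where fp16 spacing exceeds 1e-4
    acc32 = acc32 + step32  # stays close to 1.0

print("fp16:", acc16.item(), " fp32:", acc32.item())
```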
@@ -420,14 +589,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("query"),
        py::arg("key"),
        py::arg("value"),
-       py::arg("is_causal") = false);
+       py::arg("is_causal") = false,
+       py::arg("attn_mask") = py::none());

  m.def("forward_with_lse", &mps_flash_attention_forward_with_lse,
        "Flash Attention forward pass (returns output and logsumexp for backward)",
        py::arg("query"),
        py::arg("key"),
        py::arg("value"),
-       py::arg("is_causal") = false);
+       py::arg("is_causal") = false,
+       py::arg("attn_mask") = py::none());

  m.def("backward", &mps_flash_attention_backward,
        "Flash Attention backward pass",
@@ -437,5 +608,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("value"),
        py::arg("output"),
        py::arg("logsumexp"),
-       py::arg("is_causal") = false);
+       py::arg("is_causal") = false,
+       py::arg("attn_mask") = py::none());
  }

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.1.4
+ Version: 0.1.13
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
@@ -32,8 +32,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
  ## Features

  - **Forward pass**: 2-5x faster than PyTorch SDPA
- - **Backward pass**: Full gradient support for training
+ - **Backward pass**: Full gradient support for training (fp32 precision)
  - **Causal masking**: Native kernel support (only 5% overhead)
+ - **Attention masks**: Full boolean mask support for arbitrary masking patterns
  - **FP16/FP32**: Native fp16 output (no conversion overhead)
  - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -98,6 +99,16 @@ out = flash_attention(q, k, v)
  out = flash_attention(q, k, v, is_causal=True)
  ```

+ ### Attention masks (for custom masking patterns)
+
+ ```python
+ # Boolean mask: True = masked (don't attend), False = attend
+ mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+ mask[:, :, :, 512:] = True  # Mask out positions after 512
+
+ out = flash_attention(q, k, v, attn_mask=mask)
+ ```
+
  ### Training with gradients

  ```python
@@ -247,7 +258,6 @@ python scripts/build_metallibs.py
  **Known limitations:**
  - Sequence length must be divisible by block size (typically 64)
  - Head dimension: Best with 32, 64, 96, 128
- - No arbitrary attention masks (only causal or none)
  - No dropout

  ## Credits

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "mps-flash-attn"
- version = "0.1.4"
+ version = "0.1.13"
  description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
  readme = "README.md"
  license = "MIT"

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/setup.py

@@ -4,6 +4,7 @@ Setup script for MPS Flash Attention

  import os
  import sys
+ import shutil
  from setuptools import setup, find_packages, Extension
  from setuptools.command.build_ext import build_ext

@@ -36,6 +37,26 @@ class ObjCppBuildExt(build_ext):

          super().build_extensions()

+         # Copy libMFABridge.dylib to lib/ after building
+         self._copy_swift_bridge()
+
+     def _copy_swift_bridge(self):
+         """Copy Swift bridge dylib to package lib/ directory."""
+         src_path = os.path.join(
+             os.path.dirname(__file__),
+             "swift-bridge", ".build", "release", "libMFABridge.dylib"
+         )
+         dst_dir = os.path.join(os.path.dirname(__file__), "mps_flash_attn", "lib")
+         dst_path = os.path.join(dst_dir, "libMFABridge.dylib")
+
+         if os.path.exists(src_path):
+             os.makedirs(dst_dir, exist_ok=True)
+             shutil.copy2(src_path, dst_path)
+             print(f"Copied libMFABridge.dylib to {dst_path}")
+         else:
+             print(f"Warning: {src_path} not found. Build swift-bridge first with:")
+             print("  cd swift-bridge && swift build -c release")
+

  def get_extensions():
      if sys.platform != "darwin":
@@ -44,14 +65,14 @@ def get_extensions():
      return [Extension(
          name="mps_flash_attn._C",
          sources=["mps_flash_attn/csrc/mps_flash_attn.mm"],
-         extra_compile_args=["-std=c++17", "-O3"],
+         extra_compile_args=["-std=c++17", "-O3", "-DTORCH_EXTENSION_NAME=_C"],
          extra_link_args=["-framework", "Metal", "-framework", "Foundation"],
      )]


  setup(
      name="mps-flash-attn",
-     version="0.1.1",
+     version="0.1.5",
      packages=find_packages(),
      package_data={
          "mps_flash_attn": [