mps-flash-attn 0.1.5__tar.gz → 0.1.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (39)
  1. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/PKG-INFO +13 -3
  2. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/README.md +12 -2
  3. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/__init__.py +87 -46
  4. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/csrc/mps_flash_attn.mm +191 -61
  5. mps_flash_attn-0.1.13/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  6. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/PKG-INFO +13 -3
  7. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/pyproject.toml +1 -1
  8. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/setup.py +21 -0
  9. mps_flash_attn-0.1.5/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  10. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/LICENSE +0 -0
  11. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  12. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  13. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  14. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  15. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  16. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  17. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  18. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  19. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  20. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  21. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  22. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  23. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  24. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  25. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  26. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  27. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  28. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  29. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  30. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  31. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  32. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  33. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/manifest.json +0 -0
  34. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
  35. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  36. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/requires.txt +0 -0
  37. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/top_level.txt +0 -0
  38. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/setup.cfg +0 -0
  39. {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/tests/test_attention.py +0 -0

{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.1.5
+ Version: 0.1.13
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
@@ -32,8 +32,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
  ## Features

  - **Forward pass**: 2-5x faster than PyTorch SDPA
- - **Backward pass**: Full gradient support for training
+ - **Backward pass**: Full gradient support for training (fp32 precision)
  - **Causal masking**: Native kernel support (only 5% overhead)
+ - **Attention masks**: Full boolean mask support for arbitrary masking patterns
  - **FP16/FP32**: Native fp16 output (no conversion overhead)
  - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -98,6 +99,16 @@ out = flash_attention(q, k, v)
  out = flash_attention(q, k, v, is_causal=True)
  ```

+ ### Attention masks (for custom masking patterns)
+
+ ```python
+ # Boolean mask: True = masked (don't attend), False = attend
+ mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+ mask[:, :, :, 512:] = True # Mask out positions after 512
+
+ out = flash_attention(q, k, v, attn_mask=mask)
+ ```
+
  ### Training with gradients

  ```python
@@ -247,7 +258,6 @@ python scripts/build_metallibs.py
  **Known limitations:**
  - Sequence length must be divisible by block size (typically 64)
  - Head dimension: Best with 32, 64, 96, 128
- - No arbitrary attention masks (only causal or none)
  - No dropout

  ## Credits

{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/README.md

@@ -7,8 +7,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
  ## Features

  - **Forward pass**: 2-5x faster than PyTorch SDPA
- - **Backward pass**: Full gradient support for training
+ - **Backward pass**: Full gradient support for training (fp32 precision)
  - **Causal masking**: Native kernel support (only 5% overhead)
+ - **Attention masks**: Full boolean mask support for arbitrary masking patterns
  - **FP16/FP32**: Native fp16 output (no conversion overhead)
  - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -73,6 +74,16 @@ out = flash_attention(q, k, v)
  out = flash_attention(q, k, v, is_causal=True)
  ```

+ ### Attention masks (for custom masking patterns)
+
+ ```python
+ # Boolean mask: True = masked (don't attend), False = attend
+ mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+ mask[:, :, :, 512:] = True # Mask out positions after 512
+
+ out = flash_attention(q, k, v, attn_mask=mask)
+ ```
+
  ### Training with gradients

  ```python
@@ -222,7 +233,6 @@ python scripts/build_metallibs.py
  **Known limitations:**
  - Sequence length must be divisible by block size (typically 64)
  - Head dimension: Best with 32, 64, 96, 128
- - No arbitrary attention masks (only causal or none)
  - No dropout

  ## Credits

{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
  This package provides memory-efficient attention using Metal Flash Attention kernels.
  """

- __version__ = "0.1.5"
+ __version__ = "0.1.13"

  import torch
  from typing import Optional
@@ -20,39 +20,9 @@ except ImportError as e:
      _HAS_MFA = False
      _IMPORT_ERROR = str(e)

- # Set up shipped kernels directory for zero-compilation loading
- def _init_shipped_kernels():
-     """Point the Swift bridge to pre-shipped kernel binaries."""
-     try:
-         import ctypes
-         bridge_path = os.environ.get("MFA_BRIDGE_PATH")
-         if not bridge_path:
-             module_dir = os.path.dirname(__file__)
-             candidates = [
-                 os.path.join(module_dir, "lib", "libMFABridge.dylib"), # Bundled in wheel
-                 os.path.join(module_dir, "..", "swift-bridge", ".build", "release", "libMFABridge.dylib"),
-                 os.path.join(module_dir, "libMFABridge.dylib"),
-             ]
-             for path in candidates:
-                 if os.path.exists(path):
-                     bridge_path = path
-                     break
-
-         if bridge_path and os.path.exists(bridge_path):
-             lib = ctypes.CDLL(bridge_path)
-
-             # Set shipped kernels directory (pre-compiled metallibs + pipeline binaries)
-             kernels_dir = os.path.join(os.path.dirname(__file__), "kernels")
-             if os.path.exists(kernels_dir):
-                 lib.mfa_set_kernels_dir(kernels_dir.encode('utf-8'))
-
-             lib.mfa_init()
-     except Exception:
-         pass # Init is optional, will fall back to runtime compilation
-
- # Initialize shipped kernels on import
- if _HAS_MFA:
-     _init_shipped_kernels()
+ # Note: The C++ extension handles loading libMFABridge.dylib via dlopen.
+ # Set MFA_BRIDGE_PATH environment variable to specify the library location.
+ # Do NOT load the library here via ctypes - that causes duplicate class warnings.


  def is_available() -> bool:
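
The replacement comment above delegates dylib loading to the compiled extension and documents MFA_BRIDGE_PATH as the override knob. A minimal sketch of using that override, assuming the variable is read before the extension first loads the bridge; the path shown is illustrative, not shipped with the package:

```python
import os

# Illustrative path; point this at your own build of libMFABridge.dylib.
# Set it before importing/using mps_flash_attn so the C++ extension's dlopen sees it.
os.environ["MFA_BRIDGE_PATH"] = "/opt/mfa/libMFABridge.dylib"

import mps_flash_attn
print(mps_flash_attn.is_available())
```
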
@@ -60,11 +30,35 @@ def is_available() -> bool:
      return _HAS_MFA and torch.backends.mps.is_available()


+ def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+     """
+     Convert attention mask to MFA's boolean format.
+
+     MFA uses boolean masks where True = masked (don't attend).
+     PyTorch SDPA uses additive float masks where -inf/large negative = masked.
+
+     Args:
+         attn_mask: Optional mask, either:
+             - None: no mask
+             - bool tensor: already in MFA format (True = masked)
+             - float tensor: additive mask (large negative = masked)
+
+     Returns:
+         Boolean mask suitable for flash_attention(), or None
+     """
+     if attn_mask is None:
+         return None
+     if attn_mask.dtype == torch.bool:
+         return attn_mask
+     # Float mask: large negative values indicate masked positions
+     return attn_mask <= -1e3
+
+
  class FlashAttentionFunction(torch.autograd.Function):
      """Autograd function for Flash Attention with backward pass support."""

      @staticmethod
-     def forward(ctx, query, key, value, is_causal, scale):
+     def forward(ctx, query, key, value, is_causal, scale, attn_mask):
          # Apply scale if provided (MFA uses 1/sqrt(D) internally)
          scale_factor = 1.0
          if scale is not None:
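
For reference, a hedged usage sketch of the convert_mask helper added above: it passes boolean masks through unchanged and maps additive float masks (large negative = masked) onto MFA's True-means-masked convention. Shapes and values below are illustrative.

```python
import torch
from mps_flash_attn import convert_mask, flash_attention

B, H, N, D = 2, 8, 1024, 64
q = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# SDPA-style additive mask: 0 = attend, -inf = masked
float_mask = torch.zeros(B, 1, N, N, device='mps')
float_mask[:, :, :, N // 2:] = float('-inf')

bool_mask = convert_mask(float_mask)      # True wherever the float mask was <= -1e3
out = flash_attention(q, k, v, attn_mask=bool_mask)
```
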
@@ -74,10 +68,15 @@ class FlashAttentionFunction(torch.autograd.Function):
                  query = query * scale_factor

          # Forward with logsumexp for backward
-         output, logsumexp = _C.forward_with_lse(query, key, value, is_causal)
+         output, logsumexp = _C.forward_with_lse(query, key, value, is_causal, attn_mask)

          # Save for backward
-         ctx.save_for_backward(query, key, value, output, logsumexp)
+         if attn_mask is not None:
+             ctx.save_for_backward(query, key, value, output, logsumexp, attn_mask)
+             ctx.has_mask = True
+         else:
+             ctx.save_for_backward(query, key, value, output, logsumexp)
+             ctx.has_mask = False
          ctx.is_causal = is_causal
          ctx.scale_factor = scale_factor

@@ -85,19 +84,23 @@

      @staticmethod
      def backward(ctx, grad_output):
-         query, key, value, output, logsumexp = ctx.saved_tensors
+         if ctx.has_mask:
+             query, key, value, output, logsumexp, attn_mask = ctx.saved_tensors
+         else:
+             query, key, value, output, logsumexp = ctx.saved_tensors
+             attn_mask = None

          # Compute gradients
          dQ, dK, dV = _C.backward(
-             grad_output, query, key, value, output, logsumexp, ctx.is_causal
+             grad_output, query, key, value, output, logsumexp, ctx.is_causal, attn_mask
          )

          # If we scaled the query in forward, scale the gradient back
          if ctx.scale_factor != 1.0:
              dQ = dQ * ctx.scale_factor

-         # Return gradients (None for is_causal and scale since they're not tensors)
-         return dQ, dK, dV, None, None
+         # Return gradients (None for is_causal, scale, and attn_mask since they're not tensors or don't need grad)
+         return dQ, dK, dV, None, None, None


  def flash_attention(
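
Putting the autograd pieces above together, a short sketch of a training step with a boolean mask (sizes are illustrative): gradients flow to q, k, and v, while the mask itself, like is_causal and scale, receives no gradient.

```python
import torch
from mps_flash_attn import flash_attention

B, H, N, D = 1, 4, 512, 64
q = torch.randn(B, H, N, D, device='mps', requires_grad=True)
k = torch.randn(B, H, N, D, device='mps', requires_grad=True)
v = torch.randn(B, H, N, D, device='mps', requires_grad=True)

# True = masked; block attention to the last 128 key positions
mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
mask[:, :, :, -128:] = True

out = flash_attention(q, k, v, attn_mask=mask)
out.sum().backward()   # populates q.grad, k.grad, v.grad
```
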
@@ -106,6 +109,7 @@
      value: torch.Tensor,
      is_causal: bool = False,
      scale: Optional[float] = None,
+     attn_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """
      Compute scaled dot-product attention using Flash Attention on MPS.
@@ -121,6 +125,9 @@
          value: Value tensor of shape (B, num_heads, seq_len, head_dim)
          is_causal: If True, applies causal masking (for autoregressive models)
          scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
+         attn_mask: Optional boolean attention mask of shape (B, 1, seq_len_q, seq_len_kv)
+             or (B, num_heads, seq_len_q, seq_len_kv). True values indicate
+             positions to be masked (not attended to).

      Returns:
          Output tensor of shape (B, num_heads, seq_len, head_dim)
@@ -137,6 +144,11 @@
          >>> q.requires_grad = True
          >>> out = flash_attention(q, k, v)
          >>> out.sum().backward() # Computes dQ
+
+         # With attention mask:
+         >>> mask = torch.zeros(2, 1, 4096, 4096, dtype=torch.bool, device='mps')
+         >>> mask[:, :, :, 2048:] = True # mask out second half of keys
+         >>> out = flash_attention(q, k, v, attn_mask=mask)
      """
      if not _HAS_MFA:
          raise RuntimeError(
@@ -154,9 +166,23 @@
          raise ValueError("key must be on MPS device")
      if value.device.type != 'mps':
          raise ValueError("value must be on MPS device")
+     if attn_mask is not None and attn_mask.device.type != 'mps':
+         raise ValueError("attn_mask must be on MPS device")
+
+     # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
+     if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
+         # Apply scale if provided
+         if scale is not None:
+             default_scale = 1.0 / math.sqrt(query.shape[-1])
+             if abs(scale - default_scale) > 1e-6:
+                 scale_factor = scale / default_scale
+                 query = query * scale_factor
+
+         # Forward only - no logsumexp needed, no tensors saved
+         return _C.forward(query, key, value, is_causal, attn_mask)

      # Use autograd function for gradient support
-     return FlashAttentionFunction.apply(query, key, value, is_causal, scale)
+     return FlashAttentionFunction.apply(query, key, value, is_causal, scale, attn_mask)


  def replace_sdpa():
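
A quick, hedged way to sanity-check the masked path above against PyTorch's reference SDPA. Note the inverted boolean convention: flash_attention treats True as "do not attend", while torch's scaled_dot_product_attention treats a True boolean mask entry as "attend", so the reference call uses the negated mask.

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

B, H, N, D = 1, 2, 512, 64
q = torch.randn(B, H, N, D, device='mps')
k = torch.randn(B, H, N, D, device='mps')
v = torch.randn(B, H, N, D, device='mps')

mfa_mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
mfa_mask[:, :, :, 256:] = True   # MFA convention: True = masked out

out_mfa = flash_attention(q, k, v, attn_mask=mfa_mask)
out_ref = F.scaled_dot_product_attention(q, k, v, attn_mask=~mfa_mask)  # SDPA: True = attend

print((out_mfa - out_ref).abs().max())   # expected to be small in fp32
```
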
@@ -173,13 +199,28 @@

      def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                       is_causal=False, scale=None):
-         # Use MFA for MPS tensors without mask/dropout
+         # Use MFA for MPS tensors without dropout
+         # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+         # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
          if (query.device.type == 'mps' and
-             attn_mask is None and
              dropout_p == 0.0 and
-             _HAS_MFA):
+             _HAS_MFA and
+             query.shape[2] >= 1024):
              try:
-                 return flash_attention(query, key, value, is_causal=is_causal, scale=scale)
+                 # Convert float mask to bool mask if needed
+                 # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
+                 # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
+                 mfa_mask = None
+                 if attn_mask is not None:
+                     if attn_mask.dtype == torch.bool:
+                         # Boolean mask: True means masked (don't attend)
+                         mfa_mask = attn_mask
+                     else:
+                         # Float mask: typically -inf for masked positions, 0 for unmasked
+                         # Convert: positions with large negative values -> True (masked)
+                         # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
+                         mfa_mask = attn_mask <= -1e3
+                 return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
              except Exception:
                  # Fall back to original on any error
                  pass
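
A hedged sketch of the patched path above in use. It assumes replace_sdpa() swaps patched_sdpa in for torch.nn.functional.scaled_dot_product_attention, which is what its name and the surrounding code suggest; with the patch active, MPS calls with seq_len >= 1024 and no dropout should route through MFA, float masks are converted via the -1e3 threshold, and everything else falls back to the stock implementation.

```python
import torch
import torch.nn.functional as F
import mps_flash_attn

mps_flash_attn.replace_sdpa()   # assumed to monkey-patch F.scaled_dot_product_attention

q = torch.randn(1, 8, 2048, 64, device='mps', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Ordinary SDPA-style additive mask; the patched path converts it to a bool mask
bias = torch.zeros(1, 1, 2048, 2048, device='mps', dtype=torch.float16)
bias[:, :, :, 1024:] = float('-inf')

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)  # seq_len >= 1024 -> MFA path
```
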

{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/csrc/mps_flash_attn.mm

@@ -12,28 +12,37 @@
  #import <Foundation/Foundation.h>

  #include <dlfcn.h>
+ #include <string>
+ #include <vector>

  // ============================================================================
  // MFA Bridge Function Types
  // ============================================================================

  typedef bool (*mfa_init_fn)();
- typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool);
+ typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool);
+ typedef void* (*mfa_create_kernel_v2_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool);
  // New zero-sync encode functions that take PyTorch's command encoder
- typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
- typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
-                                        int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+ // Added mask_ptr and mask_offset parameters
+ typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
+                                       int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                       int32_t, int32_t);
+ typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                        int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                         int32_t, int32_t);
  // Legacy sync functions (fallback)
- typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
- typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
-                                 int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+ typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, void*,
+                                int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                int32_t, int32_t);
+ typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                 int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                  int32_t, int32_t);
  typedef void (*mfa_release_kernel_fn)(void*);

  // Global function pointers
  static mfa_init_fn g_mfa_init = nullptr;
  static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
+ static mfa_create_kernel_v2_fn g_mfa_create_kernel_v2 = nullptr;
  static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
  static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
  static mfa_forward_fn g_mfa_forward = nullptr;
@@ -46,31 +55,44 @@ static bool g_initialized = false;
  // Load MFA Bridge Library
  // ============================================================================

+ // Get the directory containing this shared library
+ static std::string get_module_dir() {
+     Dl_info info;
+     if (dladdr((void*)get_module_dir, &info) && info.dli_fname) {
+         std::string path(info.dli_fname);
+         size_t last_slash = path.rfind('/');
+         if (last_slash != std::string::npos) {
+             return path.substr(0, last_slash);
+         }
+     }
+     return ".";
+ }
+
  static bool load_mfa_bridge() {
      if (g_dylib_handle) return true;

-     // Try to find the dylib relative to this extension
-     // First try the standard location
-     const char* paths[] = {
-         "libMFABridge.dylib",
-         "./libMFABridge.dylib",
-         "../swift-bridge/.build/release/libMFABridge.dylib",
-         nullptr
+     // First check environment variable (highest priority)
+     const char* mfa_path = getenv("MFA_BRIDGE_PATH");
+     if (mfa_path) {
+         g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
+         if (g_dylib_handle) return true;
+     }
+
+     // Get the directory containing this extension module
+     std::string module_dir = get_module_dir();
+
+     // Try paths relative to the module directory
+     std::vector<std::string> paths = {
+         module_dir + "/lib/libMFABridge.dylib", // Bundled in wheel
+         module_dir + "/../swift-bridge/.build/release/libMFABridge.dylib", // Dev build
+         "libMFABridge.dylib", // Current directory fallback
      };

-     for (int i = 0; paths[i] != nullptr; i++) {
-         g_dylib_handle = dlopen(paths[i], RTLD_NOW);
+     for (const auto& path : paths) {
+         g_dylib_handle = dlopen(path.c_str(), RTLD_NOW);
          if (g_dylib_handle) break;
      }

-     if (!g_dylib_handle) {
-         // Try with absolute path from environment
-         const char* mfa_path = getenv("MFA_BRIDGE_PATH");
-         if (mfa_path) {
-             g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
-         }
-     }
-
      if (!g_dylib_handle) {
          throw std::runtime_error(
              "Failed to load libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.");
@@ -79,6 +101,7 @@ static bool load_mfa_bridge() {
      // Load function pointers
      g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
      g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
+     g_mfa_create_kernel_v2 = (mfa_create_kernel_v2_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v2");
      g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
      g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
      g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
@@ -127,6 +150,8 @@ struct KernelCacheKey {
      bool low_precision;
      bool low_precision_outputs;
      bool causal;
+     bool has_mask;
+     bool use_bf16;

      bool operator==(const KernelCacheKey& other) const {
          return seq_len_q == other.seq_len_q &&
@@ -134,7 +159,9 @@
                 head_dim == other.head_dim &&
                 low_precision == other.low_precision &&
                 low_precision_outputs == other.low_precision_outputs &&
-                causal == other.causal;
+                causal == other.causal &&
+                has_mask == other.has_mask &&
+                use_bf16 == other.use_bf16;
      }
  };

@@ -145,28 +172,47 @@
             (std::hash<int64_t>()(k.head_dim) << 2) ^
             (std::hash<bool>()(k.low_precision) << 3) ^
             (std::hash<bool>()(k.low_precision_outputs) << 4) ^
-            (std::hash<bool>()(k.causal) << 5);
+            (std::hash<bool>()(k.causal) << 5) ^
+            (std::hash<bool>()(k.has_mask) << 6) ^
+            (std::hash<bool>()(k.use_bf16) << 7);
      }
  };

  static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;

- static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal) {
-     KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal};
+ static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask, bool use_bf16 = false) {
+     KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16};

      auto it = g_kernel_cache.find(key);
      if (it != g_kernel_cache.end()) {
          return it->second;
      }

-     void* kernel = g_mfa_create_kernel(
-         static_cast<int32_t>(seq_q),
-         static_cast<int32_t>(seq_kv),
-         static_cast<int32_t>(head_dim),
-         low_prec,
-         low_prec_outputs,
-         causal
-     );
+     void* kernel = nullptr;
+     if (use_bf16 && g_mfa_create_kernel_v2) {
+         // Use v2 API with BF16 support
+         kernel = g_mfa_create_kernel_v2(
+             static_cast<int32_t>(seq_q),
+             static_cast<int32_t>(seq_kv),
+             static_cast<int32_t>(head_dim),
+             low_prec,
+             low_prec_outputs,
+             causal,
+             has_mask,
+             use_bf16
+         );
+     } else {
+         // Legacy API
+         kernel = g_mfa_create_kernel(
+             static_cast<int32_t>(seq_q),
+             static_cast<int32_t>(seq_kv),
+             static_cast<int32_t>(head_dim),
+             low_prec,
+             low_prec_outputs,
+             causal,
+             has_mask
+         );
+     }

      if (!kernel) {
          throw std::runtime_error("Failed to create MFA kernel");
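
The cache above keys a compiled pipeline on every parameter that changes the generated kernel, including the new has_mask and use_bf16 flags, so masked and unmasked variants never collide. A rough Python analogue of the same idea (names here are illustrative, not part of the package):

```python
from typing import Callable, NamedTuple

class KernelKey(NamedTuple):
    seq_q: int
    seq_kv: int
    head_dim: int
    low_precision: bool
    low_precision_outputs: bool
    causal: bool
    has_mask: bool
    use_bf16: bool

_kernel_cache: dict[KernelKey, object] = {}

def get_or_create_kernel(key: KernelKey, create: Callable[[KernelKey], object]) -> object:
    # Reuse a pipeline only when every relevant parameter matches;
    # e.g. a masked kernel is distinct from an unmasked one at the same sizes.
    if key not in _kernel_cache:
        _kernel_cache[key] = create(key)
    return _kernel_cache[key]
```
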
@@ -184,7 +230,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
      const at::Tensor& query, // (B, H, N, D)
      const at::Tensor& key, // (B, H, N, D)
      const at::Tensor& value, // (B, H, N, D)
-     bool is_causal
+     bool is_causal,
+     const c10::optional<at::Tensor>& attn_mask // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
  ) {
      // Initialize MFA on first call
      if (!g_initialized) {
@@ -236,25 +283,63 @@
          v_expanded = value;
      }

-     // Determine precision
-     bool low_precision = (query.scalar_type() == at::kHalf ||
-                           query.scalar_type() == at::kBFloat16);
+     // Determine precision - MFA kernel supports FP16, BF16, and FP32
+     bool is_bfloat16 = (query.scalar_type() == at::kBFloat16);
+     bool is_fp16 = (query.scalar_type() == at::kHalf);

-     // For fp16 inputs, we can now output directly to fp16 (no extra conversion needed!)
-     bool low_precision_outputs = low_precision;
+     // Use native BF16 kernel if available, otherwise fall back to FP32
+     bool use_bf16_kernel = is_bfloat16 && g_mfa_create_kernel_v2;
+     bool low_precision = is_fp16; // FP16 path
+     bool low_precision_outputs = is_fp16 || use_bf16_kernel;

      // Make inputs contiguous
      auto q = query.contiguous();
      auto k = k_expanded.contiguous();
      auto v = v_expanded.contiguous();

+     // For BF16 without native kernel support, convert to FP32 (avoids FP16 overflow)
+     if (is_bfloat16 && !use_bf16_kernel) {
+         q = q.to(at::kFloat);
+         k = k.to(at::kFloat);
+         v = v.to(at::kFloat);
+     }
+
+     // Handle attention mask
+     bool has_mask = attn_mask.has_value();
+     at::Tensor mask;
+     if (has_mask) {
+         mask = attn_mask.value();
+         TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
+         TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
+         // Convert to bool/uint8 if needed - kernel expects uchar (0 = attend, non-0 = mask out)
+         if (mask.scalar_type() == at::kBool) {
+             // Convert bool to uint8 for Metal compatibility
+             mask = mask.to(at::kByte);
+         }
+         TORCH_CHECK(mask.scalar_type() == at::kByte,
+                     "Attention mask must be bool or uint8");
+         // Expand mask dimensions if needed -> (B, H, N_q, N_kv)
+         // Handle (B, 1, N_q, N_kv) -> expand heads
+         // Handle (B, H, 1, N_kv) -> expand query dim (1D key mask)
+         // Handle (B, 1, 1, N_kv) -> expand both
+         if (mask.size(1) == 1 || mask.size(2) == 1) {
+             mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
+         }
+         mask = mask.contiguous();
+     }
+
      // Allocate output in the appropriate precision
-     // With lowPrecisionOutputs=true, MFA writes FP16 directly
      at::Tensor output;
-     if (low_precision_outputs) {
+     if (use_bf16_kernel) {
+         // Native BF16 kernel outputs BF16
+         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                            query.options().dtype(at::kBFloat16));
+     } else if (low_precision_outputs) {
+         // FP16 kernel outputs FP16
          output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                             query.options().dtype(at::kHalf));
      } else {
+         // FP32 kernel outputs FP32
          output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                             query.options().dtype(at::kFloat));
      }
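
Based on the checks and expansion above, the forward path should accept any 4-D bool (or uint8) mask and broadcast singleton head or query dimensions itself. A hedged sketch of the shapes this implies (sizes illustrative):

```python
import torch
from mps_flash_attn import flash_attention

B, H, N, D = 2, 8, 1024, 64
q = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# (B, 1, 1, N_kv): per-key padding mask, broadcast over heads and query positions
pad_mask = torch.zeros(B, 1, 1, N, dtype=torch.bool, device='mps')
pad_mask[:, :, :, 896:] = True

# (B, H, N_q, N_kv): fully general per-head mask
full_mask = torch.zeros(B, H, N, N, dtype=torch.bool, device='mps')

out_pad = flash_attention(q, k, v, attn_mask=pad_mask)
out_full = flash_attention(q, k, v, attn_mask=full_mask)
```
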
@@ -263,8 +348,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
      auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
                                 query.options().dtype(at::kFloat));

-     // Get or create kernel with matching output precision and causal mode
-     void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+     // Get or create kernel with matching precision and causal mode
+     void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask, use_bf16_kernel);

      // Get Metal buffers with byte offsets
      auto q_info = getBufferInfo(q);
@@ -273,6 +358,12 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
      auto o_info = getBufferInfo(output);
      auto l_info = getBufferInfo(logsumexp);

+     // Mask buffer info (may be nullptr if no mask)
+     BufferInfo mask_info = {nil, 0};
+     if (has_mask) {
+         mask_info = getBufferInfo(mask);
+     }
+
      // Use PyTorch's MPS stream command encoder for zero-sync integration
      @autoreleasepool {
          auto stream = at::mps::getCurrentMPSStream();
@@ -290,11 +381,13 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
              (__bridge void*)v_info.buffer,
              (__bridge void*)o_info.buffer,
              (__bridge void*)l_info.buffer,
+             has_mask ? (__bridge void*)mask_info.buffer : nullptr,
              q_info.byte_offset,
              k_info.byte_offset,
              v_info.byte_offset,
              o_info.byte_offset,
              l_info.byte_offset,
+             mask_info.byte_offset,
              static_cast<int32_t>(batch_size),
              static_cast<int32_t>(num_heads)
          );
@@ -307,7 +400,12 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
          // The encoder stays open for coalescing more kernels
      }

-     // Output is already in the correct dtype (fp16 or fp32)
+     // Convert output back to BF16 if input was BF16 and we used FP32 fallback
+     // (native BF16 kernel already outputs BF16, no conversion needed)
+     if (is_bfloat16 && !use_bf16_kernel) {
+         output = output.to(at::kBFloat16);
+     }
+
      // Return both output and logsumexp (needed for backward pass)
      return std::make_tuple(output, logsumexp);
  }
@@ -317,9 +415,10 @@ at::Tensor mps_flash_attention_forward(
      const at::Tensor& query,
      const at::Tensor& key,
      const at::Tensor& value,
-     bool is_causal
+     bool is_causal,
+     const c10::optional<at::Tensor>& attn_mask
  ) {
-     auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal);
+     auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal, attn_mask);
      return output;
  }

@@ -334,7 +433,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
      const at::Tensor& value, // (B, H, N, D)
      const at::Tensor& output, // (B, H, N, D)
      const at::Tensor& logsumexp, // (B, H, N)
-     bool is_causal
+     bool is_causal,
+     const c10::optional<at::Tensor>& attn_mask // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
  ) {
      // Initialize MFA on first call
      if (!g_initialized) {
@@ -364,16 +464,35 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
                            query.scalar_type() == at::kBFloat16);
      bool low_precision_outputs = low_precision;

-     // Make inputs contiguous
-     auto q = query.contiguous();
-     auto k = key.contiguous();
-     auto v = value.contiguous();
-     auto o = output.contiguous();
-     auto dO = grad_output.contiguous();
+     // Handle attention mask
+     bool has_mask = attn_mask.has_value();
+     at::Tensor mask;
+     if (has_mask) {
+         mask = attn_mask.value();
+         TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
+         TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
+         if (mask.scalar_type() == at::kBool) {
+             mask = mask.to(at::kByte);
+         }
+         TORCH_CHECK(mask.scalar_type() == at::kByte,
+                     "Attention mask must be bool or uint8");
+         if (mask.size(1) == 1 && num_heads > 1) {
+             mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
+         }
+         mask = mask.contiguous();
+     }
+
+     // Make inputs contiguous and upcast to fp32 for numerical stability
+     // The backward pass accumulates many small values, so fp32 precision is critical
+     auto q = query.contiguous().to(at::kFloat);
+     auto k = key.contiguous().to(at::kFloat);
+     auto v = value.contiguous().to(at::kFloat);
+     auto o = output.contiguous().to(at::kFloat);
+     auto dO = grad_output.contiguous().to(at::kFloat);
      auto lse = logsumexp.contiguous();

-     // Get or create kernel (with causal mode)
-     void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+     // Get or create kernel - always use fp32 for backward pass
+     void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, false, false, is_causal, has_mask);

      // Allocate D buffer (dO * O reduction, always fp32)
      auto D = at::empty({batch_size, num_heads, seq_len_q},
@@ -399,6 +518,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
      auto dk_info = getBufferInfo(dK);
      auto dv_info = getBufferInfo(dV);

+     // Mask buffer info (may be nullptr if no mask)
+     BufferInfo mask_info = {nil, 0};
+     if (has_mask) {
+         mask_info = getBufferInfo(mask);
+     }
+
      // Use PyTorch's MPS stream command encoder for zero-sync integration
      @autoreleasepool {
          auto stream = at::mps::getCurrentMPSStream();
@@ -419,6 +544,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
              (__bridge void*)dq_info.buffer,
              (__bridge void*)dk_info.buffer,
              (__bridge void*)dv_info.buffer,
+             has_mask ? (__bridge void*)mask_info.buffer : nullptr,
              q_info.byte_offset,
              k_info.byte_offset,
              v_info.byte_offset,
@@ -429,6 +555,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
              dq_info.byte_offset,
              dk_info.byte_offset,
              dv_info.byte_offset,
+             mask_info.byte_offset,
              static_cast<int32_t>(batch_size),
              static_cast<int32_t>(num_heads)
          );
@@ -462,14 +589,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("query"),
        py::arg("key"),
        py::arg("value"),
-       py::arg("is_causal") = false);
+       py::arg("is_causal") = false,
+       py::arg("attn_mask") = py::none());

    m.def("forward_with_lse", &mps_flash_attention_forward_with_lse,
        "Flash Attention forward pass (returns output and logsumexp for backward)",
        py::arg("query"),
        py::arg("key"),
        py::arg("value"),
-       py::arg("is_causal") = false);
+       py::arg("is_causal") = false,
+       py::arg("attn_mask") = py::none());

    m.def("backward", &mps_flash_attention_backward,
        "Flash Attention backward pass",
@@ -479,5 +608,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("value"),
        py::arg("output"),
        py::arg("logsumexp"),
-       py::arg("is_causal") = false);
+       py::arg("is_causal") = false,
+       py::arg("attn_mask") = py::none());
  }
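
For completeness, a hedged sketch of calling the raw bindings defined above directly; the module is referred to as _C here because that is how the Python wrapper imports it, and arguments follow the py::arg order (is_causal, then attn_mask).

```python
import torch
from mps_flash_attn import _C   # compiled extension; name assumed from the wrapper code

B, H, N, D = 1, 2, 1024, 64
q = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
mask[:, :, :, 768:] = True

out = _C.forward(q, k, v, False, mask)                   # forward only
out2, lse = _C.forward_with_lse(q, k, v, False, mask)    # also returns logsumexp for backward
```
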

{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.1.5
+ Version: 0.1.13
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
@@ -32,8 +32,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
  ## Features

  - **Forward pass**: 2-5x faster than PyTorch SDPA
- - **Backward pass**: Full gradient support for training
+ - **Backward pass**: Full gradient support for training (fp32 precision)
  - **Causal masking**: Native kernel support (only 5% overhead)
+ - **Attention masks**: Full boolean mask support for arbitrary masking patterns
  - **FP16/FP32**: Native fp16 output (no conversion overhead)
  - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -98,6 +99,16 @@ out = flash_attention(q, k, v)
  out = flash_attention(q, k, v, is_causal=True)
  ```

+ ### Attention masks (for custom masking patterns)
+
+ ```python
+ # Boolean mask: True = masked (don't attend), False = attend
+ mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+ mask[:, :, :, 512:] = True # Mask out positions after 512
+
+ out = flash_attention(q, k, v, attn_mask=mask)
+ ```
+
  ### Training with gradients

  ```python
@@ -247,7 +258,6 @@ python scripts/build_metallibs.py
  **Known limitations:**
  - Sequence length must be divisible by block size (typically 64)
  - Head dimension: Best with 32, 64, 96, 128
- - No arbitrary attention masks (only causal or none)
  - No dropout

  ## Credits

{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "mps-flash-attn"
- version = "0.1.5"
+ version = "0.1.13"
  description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
  readme = "README.md"
  license = "MIT"

{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/setup.py

@@ -4,6 +4,7 @@ Setup script for MPS Flash Attention

  import os
  import sys
+ import shutil
  from setuptools import setup, find_packages, Extension
  from setuptools.command.build_ext import build_ext

@@ -36,6 +37,26 @@ class ObjCppBuildExt(build_ext):

          super().build_extensions()

+         # Copy libMFABridge.dylib to lib/ after building
+         self._copy_swift_bridge()
+
+     def _copy_swift_bridge(self):
+         """Copy Swift bridge dylib to package lib/ directory."""
+         src_path = os.path.join(
+             os.path.dirname(__file__),
+             "swift-bridge", ".build", "release", "libMFABridge.dylib"
+         )
+         dst_dir = os.path.join(os.path.dirname(__file__), "mps_flash_attn", "lib")
+         dst_path = os.path.join(dst_dir, "libMFABridge.dylib")
+
+         if os.path.exists(src_path):
+             os.makedirs(dst_dir, exist_ok=True)
+             shutil.copy2(src_path, dst_path)
+             print(f"Copied libMFABridge.dylib to {dst_path}")
+         else:
+             print(f"Warning: {src_path} not found. Build swift-bridge first with:")
+             print(" cd swift-bridge && swift build -c release")
+

  def get_extensions():
      if sys.platform != "darwin":