mps-flash-attn 0.1.6__tar.gz → 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of mps-flash-attn has been flagged as potentially problematic.

Files changed (38)
  1. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/PKG-INFO +13 -3
  2. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/README.md +12 -2
  3. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/__init__.py +58 -45
  4. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/csrc/mps_flash_attn.mm +131 -45
  5. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn.egg-info/PKG-INFO +13 -3
  6. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn.egg-info/SOURCES.txt +0 -1
  7. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/pyproject.toml +1 -1
  8. mps_flash_attn-0.1.6/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  9. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/LICENSE +0 -0
  10. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  11. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  12. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  13. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  14. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  15. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  16. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  17. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  18. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  19. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  20. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  21. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  22. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  23. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  24. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  25. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  26. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  27. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  28. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  29. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  30. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  31. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  32. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/kernels/manifest.json +0 -0
  33. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  34. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn.egg-info/requires.txt +0 -0
  35. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn.egg-info/top_level.txt +0 -0
  36. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/setup.cfg +0 -0
  37. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/setup.py +0 -0
  38. {mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/tests/test_attention.py +0 -0

{mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.1.6
+ Version: 0.1.7
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
@@ -32,8 +32,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
  ## Features

  - **Forward pass**: 2-5x faster than PyTorch SDPA
- - **Backward pass**: Full gradient support for training
+ - **Backward pass**: Full gradient support for training (fp32 precision)
  - **Causal masking**: Native kernel support (only 5% overhead)
+ - **Attention masks**: Full boolean mask support for arbitrary masking patterns
  - **FP16/FP32**: Native fp16 output (no conversion overhead)
  - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -98,6 +99,16 @@ out = flash_attention(q, k, v)
  out = flash_attention(q, k, v, is_causal=True)
  ```

+ ### Attention masks (for custom masking patterns)
+
+ ```python
+ # Boolean mask: True = masked (don't attend), False = attend
+ mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+ mask[:, :, :, 512:] = True # Mask out positions after 512
+
+ out = flash_attention(q, k, v, attn_mask=mask)
+ ```
+
  ### Training with gradients

  ```python
@@ -247,7 +258,6 @@ python scripts/build_metallibs.py
  **Known limitations:**
  - Sequence length must be divisible by block size (typically 64)
  - Head dimension: Best with 32, 64, 96, 128
- - No arbitrary attention masks (only causal or none)
  - No dropout

  ## Credits

{mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/README.md

@@ -7,8 +7,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
  ## Features

  - **Forward pass**: 2-5x faster than PyTorch SDPA
- - **Backward pass**: Full gradient support for training
+ - **Backward pass**: Full gradient support for training (fp32 precision)
  - **Causal masking**: Native kernel support (only 5% overhead)
+ - **Attention masks**: Full boolean mask support for arbitrary masking patterns
  - **FP16/FP32**: Native fp16 output (no conversion overhead)
  - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -73,6 +74,16 @@ out = flash_attention(q, k, v)
  out = flash_attention(q, k, v, is_causal=True)
  ```

+ ### Attention masks (for custom masking patterns)
+
+ ```python
+ # Boolean mask: True = masked (don't attend), False = attend
+ mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+ mask[:, :, :, 512:] = True # Mask out positions after 512
+
+ out = flash_attention(q, k, v, attn_mask=mask)
+ ```
+
  ### Training with gradients

  ```python
@@ -222,7 +233,6 @@ python scripts/build_metallibs.py
  **Known limitations:**
  - Sequence length must be divisible by block size (typically 64)
  - Head dimension: Best with 32, 64, 96, 128
- - No arbitrary attention masks (only causal or none)
  - No dropout

  ## Credits
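
Both README hunks above adopt the same boolean-mask convention: True marks a position as masked out, which is the inverse of torch's scaled_dot_product_attention boolean convention (True = attend). A minimal sketch of checking the masked call against a PyTorch reference; the shapes, dtype, and the negated mask for the reference call are illustrative assumptions, not part of the package's documentation.

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

B, H, N, D = 1, 4, 1024, 64  # illustrative shapes (seq_len divisible by the 64-wide block)
q, k, v = (torch.randn(B, H, N, D, device='mps', dtype=torch.float16) for _ in range(3))

# Package convention (per the README hunk above): True = masked, do not attend.
mfa_mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
mfa_mask[:, :, :, 512:] = True

out = flash_attention(q, k, v, attn_mask=mfa_mask)

# torch SDPA's boolean convention is inverted (True = attend), so negate for the reference.
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=~mfa_mask)
print((out - ref).abs().max())  # expected to be small (fp16 tolerance)
```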

{mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
  This package provides memory-efficient attention using Metal Flash Attention kernels.
  """

- __version__ = "0.1.6"
+ __version__ = "0.1.7"

  import torch
  from typing import Optional
@@ -20,39 +20,9 @@ except ImportError as e:
  _HAS_MFA = False
  _IMPORT_ERROR = str(e)

- # Set up shipped kernels directory for zero-compilation loading
- def _init_shipped_kernels():
- """Point the Swift bridge to pre-shipped kernel binaries."""
- try:
- import ctypes
- bridge_path = os.environ.get("MFA_BRIDGE_PATH")
- if not bridge_path:
- module_dir = os.path.dirname(__file__)
- candidates = [
- os.path.join(module_dir, "lib", "libMFABridge.dylib"), # Bundled in wheel
- os.path.join(module_dir, "..", "swift-bridge", ".build", "release", "libMFABridge.dylib"),
- os.path.join(module_dir, "libMFABridge.dylib"),
- ]
- for path in candidates:
- if os.path.exists(path):
- bridge_path = path
- break
-
- if bridge_path and os.path.exists(bridge_path):
- lib = ctypes.CDLL(bridge_path)
-
- # Set shipped kernels directory (pre-compiled metallibs + pipeline binaries)
- kernels_dir = os.path.join(os.path.dirname(__file__), "kernels")
- if os.path.exists(kernels_dir):
- lib.mfa_set_kernels_dir(kernels_dir.encode('utf-8'))
-
- lib.mfa_init()
- except Exception:
- pass # Init is optional, will fall back to runtime compilation
-
- # Initialize shipped kernels on import
- if _HAS_MFA:
- _init_shipped_kernels()
+ # Note: The C++ extension handles loading libMFABridge.dylib via dlopen.
+ # Set MFA_BRIDGE_PATH environment variable to specify the library location.
+ # Do NOT load the library here via ctypes - that causes duplicate class warnings.


  def is_available() -> bool:
@@ -64,7 +34,7 @@ class FlashAttentionFunction(torch.autograd.Function):
  """Autograd function for Flash Attention with backward pass support."""

  @staticmethod
- def forward(ctx, query, key, value, is_causal, scale):
+ def forward(ctx, query, key, value, is_causal, scale, attn_mask):
  # Apply scale if provided (MFA uses 1/sqrt(D) internally)
  scale_factor = 1.0
  if scale is not None:
@@ -74,10 +44,15 @@ class FlashAttentionFunction(torch.autograd.Function):
  query = query * scale_factor

  # Forward with logsumexp for backward
- output, logsumexp = _C.forward_with_lse(query, key, value, is_causal)
+ output, logsumexp = _C.forward_with_lse(query, key, value, is_causal, attn_mask)

  # Save for backward
- ctx.save_for_backward(query, key, value, output, logsumexp)
+ if attn_mask is not None:
+ ctx.save_for_backward(query, key, value, output, logsumexp, attn_mask)
+ ctx.has_mask = True
+ else:
+ ctx.save_for_backward(query, key, value, output, logsumexp)
+ ctx.has_mask = False
  ctx.is_causal = is_causal
  ctx.scale_factor = scale_factor

@@ -85,19 +60,23 @@

  @staticmethod
  def backward(ctx, grad_output):
- query, key, value, output, logsumexp = ctx.saved_tensors
+ if ctx.has_mask:
+ query, key, value, output, logsumexp, attn_mask = ctx.saved_tensors
+ else:
+ query, key, value, output, logsumexp = ctx.saved_tensors
+ attn_mask = None

  # Compute gradients
  dQ, dK, dV = _C.backward(
- grad_output, query, key, value, output, logsumexp, ctx.is_causal
+ grad_output, query, key, value, output, logsumexp, ctx.is_causal, attn_mask
  )

  # If we scaled the query in forward, scale the gradient back
  if ctx.scale_factor != 1.0:
  dQ = dQ * ctx.scale_factor

- # Return gradients (None for is_causal and scale since they're not tensors)
- return dQ, dK, dV, None, None
+ # Return gradients (None for is_causal, scale, and attn_mask since they're not tensors or don't need grad)
+ return dQ, dK, dV, None, None, None


  def flash_attention(
@@ -106,6 +85,7 @@ def flash_attention(
  value: torch.Tensor,
  is_causal: bool = False,
  scale: Optional[float] = None,
+ attn_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  """
  Compute scaled dot-product attention using Flash Attention on MPS.
@@ -121,6 +101,9 @@ def flash_attention(
  value: Value tensor of shape (B, num_heads, seq_len, head_dim)
  is_causal: If True, applies causal masking (for autoregressive models)
  scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
+ attn_mask: Optional boolean attention mask of shape (B, 1, seq_len_q, seq_len_kv)
+ or (B, num_heads, seq_len_q, seq_len_kv). True values indicate
+ positions to be masked (not attended to).

  Returns:
  Output tensor of shape (B, num_heads, seq_len, head_dim)
@@ -137,6 +120,11 @@ def flash_attention(
  >>> q.requires_grad = True
  >>> out = flash_attention(q, k, v)
  >>> out.sum().backward() # Computes dQ
+
+ # With attention mask:
+ >>> mask = torch.zeros(2, 1, 4096, 4096, dtype=torch.bool, device='mps')
+ >>> mask[:, :, :, 2048:] = True # mask out second half of keys
+ >>> out = flash_attention(q, k, v, attn_mask=mask)
  """
  if not _HAS_MFA:
  raise RuntimeError(
@@ -154,9 +142,23 @@ def flash_attention(
  raise ValueError("key must be on MPS device")
  if value.device.type != 'mps':
  raise ValueError("value must be on MPS device")
+ if attn_mask is not None and attn_mask.device.type != 'mps':
+ raise ValueError("attn_mask must be on MPS device")
+
+ # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
+ if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
+ # Apply scale if provided
+ if scale is not None:
+ default_scale = 1.0 / math.sqrt(query.shape[-1])
+ if abs(scale - default_scale) > 1e-6:
+ scale_factor = scale / default_scale
+ query = query * scale_factor
+
+ # Forward only - no logsumexp needed, no tensors saved
+ return _C.forward(query, key, value, is_causal, attn_mask)

  # Use autograd function for gradient support
- return FlashAttentionFunction.apply(query, key, value, is_causal, scale)
+ return FlashAttentionFunction.apply(query, key, value, is_causal, scale, attn_mask)


  def replace_sdpa():
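
The new fast path added in the hunk above bypasses FlashAttentionFunction entirely when no gradient is needed, calling _C.forward without allocating the logsumexp buffer or saving tensors. A small sketch of the two call styles, with illustrative shapes:

```python
import torch
from mps_flash_attn import flash_attention

q = torch.randn(1, 8, 2048, 64, device='mps', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Inference: grad is disabled, so the wrapper takes the fast path
# (forward only, no logsumexp, nothing saved for backward).
with torch.no_grad():
    out = flash_attention(q, k, v, is_causal=True)

# Training: an input that requires grad routes through the autograd Function,
# which keeps the logsumexp needed by the backward kernel.
q.requires_grad_(True)
out = flash_attention(q, k, v, is_causal=True)
out.sum().backward()  # populates q.grad
```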

@@ -173,16 +175,27 @@ def replace_sdpa():

  def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
  is_causal=False, scale=None):
- # Use MFA for MPS tensors without mask/dropout
+ # Use MFA for MPS tensors without dropout
  # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
  # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
  if (query.device.type == 'mps' and
- attn_mask is None and
  dropout_p == 0.0 and
  _HAS_MFA and
  query.shape[2] >= 1024):
  try:
- return flash_attention(query, key, value, is_causal=is_causal, scale=scale)
+ # Convert float mask to bool mask if needed
+ # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
+ # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
+ mfa_mask = None
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ # Boolean mask: True means masked (don't attend)
+ mfa_mask = attn_mask
+ else:
+ # Float mask: typically -inf for masked positions, 0 for unmasked
+ # Convert: positions with large negative values -> True (masked)
+ mfa_mask = attn_mask < -1e4
+ return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
  except Exception:
  # Fall back to original on any error
  pass
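
patched_sdpa above now bridges torch-style additive float masks (0 = attend, -inf = masked) to the package's boolean convention by thresholding at -1e4, and passes boolean masks through unchanged. A standalone sketch of the same conversion for code that calls flash_attention directly; the helper name to_mfa_mask is ours, not part of the package API.

```python
import torch

def to_mfa_mask(attn_mask: torch.Tensor) -> torch.Tensor:
    """Convert an SDPA-style mask to the boolean layout used here (True = masked out)."""
    if attn_mask.dtype == torch.bool:
        return attn_mask        # passed through, mirroring patched_sdpa above
    return attn_mask < -1e4     # additive float mask: very negative means "masked"

# Example: an additive mask that blocks the second half of the keys.
float_mask = torch.zeros(1, 1, 2048, 2048, device='mps')
float_mask[:, :, :, 1024:] = float('-inf')
bool_mask = to_mfa_mask(float_mask)  # True where the float mask is very negative
```

The resulting bool_mask can then be passed as attn_mask=bool_mask to flash_attention.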

{mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn/csrc/mps_flash_attn.mm

@@ -12,22 +12,29 @@
  #import <Foundation/Foundation.h>

  #include <dlfcn.h>
+ #include <string>
+ #include <vector>

  // ============================================================================
  // MFA Bridge Function Types
  // ============================================================================

  typedef bool (*mfa_init_fn)();
- typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool);
+ typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool);
  // New zero-sync encode functions that take PyTorch's command encoder
- typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
- typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
- int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+ // Added mask_ptr and mask_offset parameters
+ typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
+ int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+ int32_t, int32_t);
+ typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+ int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
  int32_t, int32_t);
  // Legacy sync functions (fallback)
- typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, int64_t, int64_t, int64_t, int64_t, int64_t, int32_t, int32_t);
- typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
- int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+ typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, void*,
+ int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+ int32_t, int32_t);
+ typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+ int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
  int32_t, int32_t);
  typedef void (*mfa_release_kernel_fn)(void*);

@@ -46,31 +53,44 @@ static bool g_initialized = false;
  // Load MFA Bridge Library
  // ============================================================================

+ // Get the directory containing this shared library
+ static std::string get_module_dir() {
+ Dl_info info;
+ if (dladdr((void*)get_module_dir, &info) && info.dli_fname) {
+ std::string path(info.dli_fname);
+ size_t last_slash = path.rfind('/');
+ if (last_slash != std::string::npos) {
+ return path.substr(0, last_slash);
+ }
+ }
+ return ".";
+ }
+
  static bool load_mfa_bridge() {
  if (g_dylib_handle) return true;

- // Try to find the dylib relative to this extension
- // First try the standard location
- const char* paths[] = {
- "libMFABridge.dylib",
- "./libMFABridge.dylib",
- "../swift-bridge/.build/release/libMFABridge.dylib",
- nullptr
+ // First check environment variable (highest priority)
+ const char* mfa_path = getenv("MFA_BRIDGE_PATH");
+ if (mfa_path) {
+ g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
+ if (g_dylib_handle) return true;
+ }
+
+ // Get the directory containing this extension module
+ std::string module_dir = get_module_dir();
+
+ // Try paths relative to the module directory
+ std::vector<std::string> paths = {
+ module_dir + "/lib/libMFABridge.dylib", // Bundled in wheel
+ module_dir + "/../swift-bridge/.build/release/libMFABridge.dylib", // Dev build
+ "libMFABridge.dylib", // Current directory fallback
  };

- for (int i = 0; paths[i] != nullptr; i++) {
- g_dylib_handle = dlopen(paths[i], RTLD_NOW);
+ for (const auto& path : paths) {
+ g_dylib_handle = dlopen(path.c_str(), RTLD_NOW);
  if (g_dylib_handle) break;
  }

- if (!g_dylib_handle) {
- // Try with absolute path from environment
- const char* mfa_path = getenv("MFA_BRIDGE_PATH");
- if (mfa_path) {
- g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
- }
- }
-
  if (!g_dylib_handle) {
  throw std::runtime_error(
  "Failed to load libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.");

@@ -127,6 +147,7 @@ struct KernelCacheKey {
  bool low_precision;
  bool low_precision_outputs;
  bool causal;
+ bool has_mask;

  bool operator==(const KernelCacheKey& other) const {
  return seq_len_q == other.seq_len_q &&
@@ -134,7 +155,8 @@ struct KernelCacheKey {
  head_dim == other.head_dim &&
  low_precision == other.low_precision &&
  low_precision_outputs == other.low_precision_outputs &&
- causal == other.causal;
+ causal == other.causal &&
+ has_mask == other.has_mask;
  }
  };

@@ -145,14 +167,15 @@ struct KernelCacheKeyHash {
  (std::hash<int64_t>()(k.head_dim) << 2) ^
  (std::hash<bool>()(k.low_precision) << 3) ^
  (std::hash<bool>()(k.low_precision_outputs) << 4) ^
- (std::hash<bool>()(k.causal) << 5);
+ (std::hash<bool>()(k.causal) << 5) ^
+ (std::hash<bool>()(k.has_mask) << 6);
  }
  };

  static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;

- static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal) {
- KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal};
+ static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask) {
+ KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask};

  auto it = g_kernel_cache.find(key);
  if (it != g_kernel_cache.end()) {
@@ -165,7 +188,8 @@ static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_di
  static_cast<int32_t>(head_dim),
  low_prec,
  low_prec_outputs,
- causal
+ causal,
+ has_mask
  );

  if (!kernel) {
@@ -184,7 +208,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
  const at::Tensor& query, // (B, H, N, D)
  const at::Tensor& key, // (B, H, N, D)
  const at::Tensor& value, // (B, H, N, D)
- bool is_causal
+ bool is_causal,
+ const c10::optional<at::Tensor>& attn_mask // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
  ) {
  // Initialize MFA on first call
  if (!g_initialized) {
@@ -248,6 +273,27 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
  auto k = k_expanded.contiguous();
  auto v = v_expanded.contiguous();

+ // Handle attention mask
+ bool has_mask = attn_mask.has_value();
+ at::Tensor mask;
+ if (has_mask) {
+ mask = attn_mask.value();
+ TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
+ TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
+ // Convert to bool/uint8 if needed - kernel expects uchar (0 = attend, non-0 = mask out)
+ if (mask.scalar_type() == at::kBool) {
+ // Convert bool to uint8 for Metal compatibility
+ mask = mask.to(at::kByte);
+ }
+ TORCH_CHECK(mask.scalar_type() == at::kByte,
+ "Attention mask must be bool or uint8");
+ // Expand mask heads if needed (B, 1, N_q, N_kv) -> (B, H, N_q, N_kv)
+ if (mask.size(1) == 1 && num_heads > 1) {
+ mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
+ }
+ mask = mask.contiguous();
+ }
+
  // Allocate output in the appropriate precision
  // With lowPrecisionOutputs=true, MFA writes FP16 directly
  at::Tensor output;
@@ -264,7 +310,7 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
  query.options().dtype(at::kFloat));

  // Get or create kernel with matching output precision and causal mode
- void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+ void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask);

  // Get Metal buffers with byte offsets
  auto q_info = getBufferInfo(q);
@@ -273,6 +319,12 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
  auto o_info = getBufferInfo(output);
  auto l_info = getBufferInfo(logsumexp);

+ // Mask buffer info (may be nullptr if no mask)
+ BufferInfo mask_info = {nil, 0};
+ if (has_mask) {
+ mask_info = getBufferInfo(mask);
+ }
+
  // Use PyTorch's MPS stream command encoder for zero-sync integration
  @autoreleasepool {
  auto stream = at::mps::getCurrentMPSStream();
@@ -290,11 +342,13 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
  (__bridge void*)v_info.buffer,
  (__bridge void*)o_info.buffer,
  (__bridge void*)l_info.buffer,
+ has_mask ? (__bridge void*)mask_info.buffer : nullptr,
  q_info.byte_offset,
  k_info.byte_offset,
  v_info.byte_offset,
  o_info.byte_offset,
  l_info.byte_offset,
+ mask_info.byte_offset,
  static_cast<int32_t>(batch_size),
  static_cast<int32_t>(num_heads)
  );
@@ -317,9 +371,10 @@ at::Tensor mps_flash_attention_forward(
  const at::Tensor& query,
  const at::Tensor& key,
  const at::Tensor& value,
- bool is_causal
+ bool is_causal,
+ const c10::optional<at::Tensor>& attn_mask
  ) {
- auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal);
+ auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal, attn_mask);
  return output;
  }

@@ -334,7 +389,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
  const at::Tensor& value, // (B, H, N, D)
  const at::Tensor& output, // (B, H, N, D)
  const at::Tensor& logsumexp, // (B, H, N)
- bool is_causal
+ bool is_causal,
+ const c10::optional<at::Tensor>& attn_mask // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
  ) {
  // Initialize MFA on first call
  if (!g_initialized) {
@@ -364,16 +420,35 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
  query.scalar_type() == at::kBFloat16);
  bool low_precision_outputs = low_precision;

- // Make inputs contiguous
- auto q = query.contiguous();
- auto k = key.contiguous();
- auto v = value.contiguous();
- auto o = output.contiguous();
- auto dO = grad_output.contiguous();
+ // Handle attention mask
+ bool has_mask = attn_mask.has_value();
+ at::Tensor mask;
+ if (has_mask) {
+ mask = attn_mask.value();
+ TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
+ TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
+ if (mask.scalar_type() == at::kBool) {
+ mask = mask.to(at::kByte);
+ }
+ TORCH_CHECK(mask.scalar_type() == at::kByte,
+ "Attention mask must be bool or uint8");
+ if (mask.size(1) == 1 && num_heads > 1) {
+ mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
+ }
+ mask = mask.contiguous();
+ }
+
+ // Make inputs contiguous and upcast to fp32 for numerical stability
+ // The backward pass accumulates many small values, so fp32 precision is critical
+ auto q = query.contiguous().to(at::kFloat);
+ auto k = key.contiguous().to(at::kFloat);
+ auto v = value.contiguous().to(at::kFloat);
+ auto o = output.contiguous().to(at::kFloat);
+ auto dO = grad_output.contiguous().to(at::kFloat);
  auto lse = logsumexp.contiguous();

- // Get or create kernel (with causal mode)
- void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+ // Get or create kernel - always use fp32 for backward pass
+ void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, false, false, is_causal, has_mask);

  // Allocate D buffer (dO * O reduction, always fp32)
  auto D = at::empty({batch_size, num_heads, seq_len_q},
@@ -399,6 +474,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
  auto dk_info = getBufferInfo(dK);
  auto dv_info = getBufferInfo(dV);

+ // Mask buffer info (may be nullptr if no mask)
+ BufferInfo mask_info = {nil, 0};
+ if (has_mask) {
+ mask_info = getBufferInfo(mask);
+ }
+
  // Use PyTorch's MPS stream command encoder for zero-sync integration
  @autoreleasepool {
  auto stream = at::mps::getCurrentMPSStream();
@@ -419,6 +500,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
  (__bridge void*)dq_info.buffer,
  (__bridge void*)dk_info.buffer,
  (__bridge void*)dv_info.buffer,
+ has_mask ? (__bridge void*)mask_info.buffer : nullptr,
  q_info.byte_offset,
  k_info.byte_offset,
  v_info.byte_offset,
@@ -429,6 +511,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
  dq_info.byte_offset,
  dk_info.byte_offset,
  dv_info.byte_offset,
+ mask_info.byte_offset,
  static_cast<int32_t>(batch_size),
  static_cast<int32_t>(num_heads)
  );
@@ -462,14 +545,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::arg("query"),
  py::arg("key"),
  py::arg("value"),
- py::arg("is_causal") = false);
+ py::arg("is_causal") = false,
+ py::arg("attn_mask") = py::none());

  m.def("forward_with_lse", &mps_flash_attention_forward_with_lse,
  "Flash Attention forward pass (returns output and logsumexp for backward)",
  py::arg("query"),
  py::arg("key"),
  py::arg("value"),
- py::arg("is_causal") = false);
+ py::arg("is_causal") = false,
+ py::arg("attn_mask") = py::none());

  m.def("backward", &mps_flash_attention_backward,
  "Flash Attention backward pass",
@@ -479,5 +564,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::arg("value"),
  py::arg("output"),
  py::arg("logsumexp"),
- py::arg("is_causal") = false);
+ py::arg("is_causal") = false,
+ py::arg("attn_mask") = py::none());
  }
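
The backward changes above upcast q, k, v, the saved output, and the incoming gradient to fp32 and always request an fp32 kernel, so gradients are accumulated in full precision even when the forward ran in fp16. A rough sanity check along those lines, comparing gradients from fp16 inputs against an all-fp32 run; the shapes and the expectation of a small difference are illustrative assumptions, not documented guarantees.

```python
import torch
from mps_flash_attn import flash_attention

torch.manual_seed(0)
base = [torch.randn(1, 4, 1024, 64, device='mps') for _ in range(3)]

def grads(dtype):
    # Same underlying values, cast to the dtype under test.
    q, k, v = (t.to(dtype).detach().requires_grad_(True) for t in base)
    flash_attention(q, k, v, is_causal=True).float().sum().backward()
    return [t.grad.float() for t in (q, k, v)]

for g16, g32 in zip(grads(torch.float16), grads(torch.float32)):
    print((g16 - g32).abs().max())  # expected to stay small with the fp32 backward
```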

{mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mps-flash-attn
- Version: 0.1.6
+ Version: 0.1.7
  Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
  Author: imperatormk
  License-Expression: MIT
@@ -32,8 +32,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
  ## Features

  - **Forward pass**: 2-5x faster than PyTorch SDPA
- - **Backward pass**: Full gradient support for training
+ - **Backward pass**: Full gradient support for training (fp32 precision)
  - **Causal masking**: Native kernel support (only 5% overhead)
+ - **Attention masks**: Full boolean mask support for arbitrary masking patterns
  - **FP16/FP32**: Native fp16 output (no conversion overhead)
  - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -98,6 +99,16 @@ out = flash_attention(q, k, v)
  out = flash_attention(q, k, v, is_causal=True)
  ```

+ ### Attention masks (for custom masking patterns)
+
+ ```python
+ # Boolean mask: True = masked (don't attend), False = attend
+ mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+ mask[:, :, :, 512:] = True # Mask out positions after 512
+
+ out = flash_attention(q, k, v, attn_mask=mask)
+ ```
+
  ### Training with gradients

  ```python
@@ -247,7 +258,6 @@ python scripts/build_metallibs.py
  **Known limitations:**
  - Sequence length must be divisible by block size (typically 64)
  - Head dimension: Best with 32, 64, 96, 128
- - No arbitrary attention masks (only causal or none)
  - No dropout

  ## Credits

{mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/mps_flash_attn.egg-info/SOURCES.txt

@@ -32,5 +32,4 @@ mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e
  mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin
  mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib
  mps_flash_attn/kernels/manifest.json
- mps_flash_attn/lib/libMFABridge.dylib
  tests/test_attention.py

{mps_flash_attn-0.1.6 → mps_flash_attn-0.1.7}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "mps-flash-attn"
- version = "0.1.6"
+ version = "0.1.7"
  description = "Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)"
  readme = "README.md"
  license = "MIT"