mps-flash-attn 0.1.5__tar.gz → 0.1.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mps-flash-attn might be problematic.
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/PKG-INFO +13 -3
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/README.md +12 -2
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/__init__.py +87 -46
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/csrc/mps_flash_attn.mm +191 -61
- mps_flash_attn-0.1.13/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/PKG-INFO +13 -3
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/pyproject.toml +1 -1
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/setup.py +21 -0
- mps_flash_attn-0.1.5/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/LICENSE +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/setup.cfg +0 -0
- {mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/tests/test_attention.py +0 -0
{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.1.5
+Version: 0.1.13
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
@@ -32,8 +32,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
 ## Features

 - **Forward pass**: 2-5x faster than PyTorch SDPA
-- **Backward pass**: Full gradient support for training
+- **Backward pass**: Full gradient support for training (fp32 precision)
 - **Causal masking**: Native kernel support (only 5% overhead)
+- **Attention masks**: Full boolean mask support for arbitrary masking patterns
 - **FP16/FP32**: Native fp16 output (no conversion overhead)
 - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -98,6 +99,16 @@ out = flash_attention(q, k, v)
 out = flash_attention(q, k, v, is_causal=True)
 ```

+### Attention masks (for custom masking patterns)
+
+```python
+# Boolean mask: True = masked (don't attend), False = attend
+mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+mask[:, :, :, 512:] = True  # Mask out positions after 512
+
+out = flash_attention(q, k, v, attn_mask=mask)
+```
+
 ### Training with gradients

 ```python
@@ -247,7 +258,6 @@ python scripts/build_metallibs.py
 **Known limitations:**
 - Sequence length must be divisible by block size (typically 64)
 - Head dimension: Best with 32, 64, 96, 128
-- No arbitrary attention masks (only causal or none)
 - No dropout

 ## Credits

{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/README.md

@@ -7,8 +7,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
 ## Features

 - **Forward pass**: 2-5x faster than PyTorch SDPA
-- **Backward pass**: Full gradient support for training
+- **Backward pass**: Full gradient support for training (fp32 precision)
 - **Causal masking**: Native kernel support (only 5% overhead)
+- **Attention masks**: Full boolean mask support for arbitrary masking patterns
 - **FP16/FP32**: Native fp16 output (no conversion overhead)
 - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -73,6 +74,16 @@ out = flash_attention(q, k, v)
 out = flash_attention(q, k, v, is_causal=True)
 ```

+### Attention masks (for custom masking patterns)
+
+```python
+# Boolean mask: True = masked (don't attend), False = attend
+mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+mask[:, :, :, 512:] = True  # Mask out positions after 512
+
+out = flash_attention(q, k, v, attn_mask=mask)
+```
+
 ### Training with gradients

 ```python
@@ -222,7 +233,6 @@ python scripts/build_metallibs.py
 **Known limitations:**
 - Sequence length must be divisible by block size (typically 64)
 - Head dimension: Best with 32, 64, 96, 128
-- No arbitrary attention masks (only causal or none)
 - No dropout

 ## Credits

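Beyond calling flash_attention() directly with a boolean mask, as the README example above shows, this release also routes masked calls made through torch.nn.functional.scaled_dot_product_attention once replace_sdpa() is installed (see the patched_sdpa changes in the __init__.py diff below). The sketch that follows is illustrative only: it assumes replace_sdpa() monkey-patches F.scaled_dot_product_attention, as its name and the wrapper's signature suggest, and the tensor sizes are made up.

```python
import torch
import torch.nn.functional as F
import mps_flash_attn

mps_flash_attn.replace_sdpa()  # route eligible SDPA calls through the MFA kernel

B, H, N, D = 2, 8, 2048, 64    # seq_len >= 1024 so the patched path is taken
q = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
k = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
v = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)

# Standard SDPA-style additive mask: 0 = attend, -inf = masked.
bias = torch.zeros(B, 1, N, N, device='mps', dtype=torch.float16)
bias[:, :, :, N // 2:] = float('-inf')

# With the patch installed, this call is forwarded to flash_attention(), and the
# additive mask is converted to MFA's boolean convention internally.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```

Because patched_sdpa only takes the MFA path for MPS tensors with dropout_p == 0.0 and seq_len >= 1024, shorter sequences quietly fall back to PyTorch's own backend.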
{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """

-__version__ = "0.1.5"
+__version__ = "0.1.13"

 import torch
 from typing import Optional
@@ -20,39 +20,9 @@ except ImportError as e:
     _HAS_MFA = False
     _IMPORT_ERROR = str(e)

-#
-
-
-try:
-    import ctypes
-    bridge_path = os.environ.get("MFA_BRIDGE_PATH")
-    if not bridge_path:
-        module_dir = os.path.dirname(__file__)
-        candidates = [
-            os.path.join(module_dir, "lib", "libMFABridge.dylib"),  # Bundled in wheel
-            os.path.join(module_dir, "..", "swift-bridge", ".build", "release", "libMFABridge.dylib"),
-            os.path.join(module_dir, "libMFABridge.dylib"),
-        ]
-        for path in candidates:
-            if os.path.exists(path):
-                bridge_path = path
-                break
-
-    if bridge_path and os.path.exists(bridge_path):
-        lib = ctypes.CDLL(bridge_path)
-
-        # Set shipped kernels directory (pre-compiled metallibs + pipeline binaries)
-        kernels_dir = os.path.join(os.path.dirname(__file__), "kernels")
-        if os.path.exists(kernels_dir):
-            lib.mfa_set_kernels_dir(kernels_dir.encode('utf-8'))
-
-        lib.mfa_init()
-except Exception:
-    pass  # Init is optional, will fall back to runtime compilation
-
-# Initialize shipped kernels on import
-if _HAS_MFA:
-    _init_shipped_kernels()
+# Note: The C++ extension handles loading libMFABridge.dylib via dlopen.
+# Set MFA_BRIDGE_PATH environment variable to specify the library location.
+# Do NOT load the library here via ctypes - that causes duplicate class warnings.


 def is_available() -> bool:
@@ -60,11 +30,35 @@ def is_available() -> bool:
     return _HAS_MFA and torch.backends.mps.is_available()


+def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+    """
+    Convert attention mask to MFA's boolean format.
+
+    MFA uses boolean masks where True = masked (don't attend).
+    PyTorch SDPA uses additive float masks where -inf/large negative = masked.
+
+    Args:
+        attn_mask: Optional mask, either:
+            - None: no mask
+            - bool tensor: already in MFA format (True = masked)
+            - float tensor: additive mask (large negative = masked)
+
+    Returns:
+        Boolean mask suitable for flash_attention(), or None
+    """
+    if attn_mask is None:
+        return None
+    if attn_mask.dtype == torch.bool:
+        return attn_mask
+    # Float mask: large negative values indicate masked positions
+    return attn_mask <= -1e3
+
+
 class FlashAttentionFunction(torch.autograd.Function):
     """Autograd function for Flash Attention with backward pass support."""

     @staticmethod
-    def forward(ctx, query, key, value, is_causal, scale):
+    def forward(ctx, query, key, value, is_causal, scale, attn_mask):
         # Apply scale if provided (MFA uses 1/sqrt(D) internally)
         scale_factor = 1.0
         if scale is not None:
@@ -74,10 +68,15 @@ class FlashAttentionFunction(torch.autograd.Function):
             query = query * scale_factor

         # Forward with logsumexp for backward
-        output, logsumexp = _C.forward_with_lse(query, key, value, is_causal)
+        output, logsumexp = _C.forward_with_lse(query, key, value, is_causal, attn_mask)

         # Save for backward
-
+        if attn_mask is not None:
+            ctx.save_for_backward(query, key, value, output, logsumexp, attn_mask)
+            ctx.has_mask = True
+        else:
+            ctx.save_for_backward(query, key, value, output, logsumexp)
+            ctx.has_mask = False
         ctx.is_causal = is_causal
         ctx.scale_factor = scale_factor

@@ -85,19 +84,23 @@

     @staticmethod
     def backward(ctx, grad_output):
-
+        if ctx.has_mask:
+            query, key, value, output, logsumexp, attn_mask = ctx.saved_tensors
+        else:
+            query, key, value, output, logsumexp = ctx.saved_tensors
+            attn_mask = None

         # Compute gradients
         dQ, dK, dV = _C.backward(
-            grad_output, query, key, value, output, logsumexp, ctx.is_causal
+            grad_output, query, key, value, output, logsumexp, ctx.is_causal, attn_mask
         )

         # If we scaled the query in forward, scale the gradient back
         if ctx.scale_factor != 1.0:
             dQ = dQ * ctx.scale_factor

-        # Return gradients (None for is_causal and
-        return dQ, dK, dV, None, None
+        # Return gradients (None for is_causal, scale, and attn_mask since they're not tensors or don't need grad)
+        return dQ, dK, dV, None, None, None


 def flash_attention(
@@ -106,6 +109,7 @@ def flash_attention(
     value: torch.Tensor,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    attn_mask: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Compute scaled dot-product attention using Flash Attention on MPS.
@@ -121,6 +125,9 @@
         value: Value tensor of shape (B, num_heads, seq_len, head_dim)
         is_causal: If True, applies causal masking (for autoregressive models)
         scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
+        attn_mask: Optional boolean attention mask of shape (B, 1, seq_len_q, seq_len_kv)
+            or (B, num_heads, seq_len_q, seq_len_kv). True values indicate
+            positions to be masked (not attended to).

     Returns:
         Output tensor of shape (B, num_heads, seq_len, head_dim)
@@ -137,6 +144,11 @@
         >>> q.requires_grad = True
         >>> out = flash_attention(q, k, v)
         >>> out.sum().backward()  # Computes dQ
+
+        # With attention mask:
+        >>> mask = torch.zeros(2, 1, 4096, 4096, dtype=torch.bool, device='mps')
+        >>> mask[:, :, :, 2048:] = True  # mask out second half of keys
+        >>> out = flash_attention(q, k, v, attn_mask=mask)
     """
     if not _HAS_MFA:
         raise RuntimeError(
@@ -154,9 +166,23 @@
         raise ValueError("key must be on MPS device")
     if value.device.type != 'mps':
         raise ValueError("value must be on MPS device")
+    if attn_mask is not None and attn_mask.device.type != 'mps':
+        raise ValueError("attn_mask must be on MPS device")
+
+    # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
+    if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
+        # Apply scale if provided
+        if scale is not None:
+            default_scale = 1.0 / math.sqrt(query.shape[-1])
+            if abs(scale - default_scale) > 1e-6:
+                scale_factor = scale / default_scale
+                query = query * scale_factor
+
+        # Forward only - no logsumexp needed, no tensors saved
+        return _C.forward(query, key, value, is_causal, attn_mask)

     # Use autograd function for gradient support
-    return FlashAttentionFunction.apply(query, key, value, is_causal, scale)
+    return FlashAttentionFunction.apply(query, key, value, is_causal, scale, attn_mask)


 def replace_sdpa():
@@ -173,13 +199,28 @@ def replace_sdpa():

     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None):
-        # Use MFA for MPS tensors without
+        # Use MFA for MPS tensors without dropout
+        # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+        # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
         if (query.device.type == 'mps' and
-            attn_mask is None and
             dropout_p == 0.0 and
-            _HAS_MFA
+            _HAS_MFA and
+            query.shape[2] >= 1024):
             try:
-
+                # Convert float mask to bool mask if needed
+                # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
+                # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
+                mfa_mask = None
+                if attn_mask is not None:
+                    if attn_mask.dtype == torch.bool:
+                        # Boolean mask: True means masked (don't attend)
+                        mfa_mask = attn_mask
+                    else:
+                        # Float mask: typically -inf for masked positions, 0 for unmasked
+                        # Convert: positions with large negative values -> True (masked)
+                        # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
+                        mfa_mask = attn_mask <= -1e3
+                return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
             except Exception:
                 # Fall back to original on any error
                 pass

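To tie the Python-side changes together, here is a minimal sketch of the new mask plumbing: an SDPA-style additive float mask is converted with the convert_mask() helper added in this version and passed to flash_attention() through the new attn_mask argument. Shapes, dtypes, and the explicit scale are illustrative, not taken from the package's own docs.

```python
import math
import torch
from mps_flash_attn import convert_mask, flash_attention

B, H, N, D = 2, 8, 2048, 64   # illustrative; seq_len must divide evenly by the block size
q = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
k = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)
v = torch.randn(B, H, N, D, device='mps', dtype=torch.float16)

# Additive mask in PyTorch-SDPA convention: 0 = attend, -inf = do not attend.
additive = torch.zeros(B, 1, N, N, device='mps')
additive[:, :, :, N // 2:] = float('-inf')

# convert_mask() maps it to MFA's boolean convention (True = masked),
# which is what the attn_mask argument expects.
bool_mask = convert_mask(additive)
out = flash_attention(q, k, v, attn_mask=bool_mask, scale=1.0 / math.sqrt(D))
```

Note that convert_mask() treats any value at or below -1e3 as masked, so both -inf and the common -10000.0 style of additive bias convert correctly.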
{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn/csrc/mps_flash_attn.mm

@@ -12,28 +12,37 @@
 #import <Foundation/Foundation.h>

 #include <dlfcn.h>
+#include <string>
+#include <vector>

 // ============================================================================
 // MFA Bridge Function Types
 // ============================================================================

 typedef bool (*mfa_init_fn)();
-typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool);
+typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool);
+typedef void* (*mfa_create_kernel_v2_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool);
 // New zero-sync encode functions that take PyTorch's command encoder
-
-typedef bool (*
-
+// Added mask_ptr and mask_offset parameters
+typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
+                                      int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                      int32_t, int32_t);
+typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                       int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                        int32_t, int32_t);
 // Legacy sync functions (fallback)
-typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*,
-
-
+typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, void*,
+                               int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                               int32_t, int32_t);
+typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                 int32_t, int32_t);
 typedef void (*mfa_release_kernel_fn)(void*);

 // Global function pointers
 static mfa_init_fn g_mfa_init = nullptr;
 static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
+static mfa_create_kernel_v2_fn g_mfa_create_kernel_v2 = nullptr;
 static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
 static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
 static mfa_forward_fn g_mfa_forward = nullptr;
@@ -46,31 +55,44 @@ static bool g_initialized = false;
 // Load MFA Bridge Library
 // ============================================================================

+// Get the directory containing this shared library
+static std::string get_module_dir() {
+    Dl_info info;
+    if (dladdr((void*)get_module_dir, &info) && info.dli_fname) {
+        std::string path(info.dli_fname);
+        size_t last_slash = path.rfind('/');
+        if (last_slash != std::string::npos) {
+            return path.substr(0, last_slash);
+        }
+    }
+    return ".";
+}
+
 static bool load_mfa_bridge() {
     if (g_dylib_handle) return true;

-    //
-
-
-
-
-
-
+    // First check environment variable (highest priority)
+    const char* mfa_path = getenv("MFA_BRIDGE_PATH");
+    if (mfa_path) {
+        g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
+        if (g_dylib_handle) return true;
+    }
+
+    // Get the directory containing this extension module
+    std::string module_dir = get_module_dir();
+
+    // Try paths relative to the module directory
+    std::vector<std::string> paths = {
+        module_dir + "/lib/libMFABridge.dylib",  // Bundled in wheel
+        module_dir + "/../swift-bridge/.build/release/libMFABridge.dylib",  // Dev build
+        "libMFABridge.dylib",  // Current directory fallback
     };

-    for (
-    g_dylib_handle = dlopen(
+    for (const auto& path : paths) {
+        g_dylib_handle = dlopen(path.c_str(), RTLD_NOW);
         if (g_dylib_handle) break;
     }

-    if (!g_dylib_handle) {
-        // Try with absolute path from environment
-        const char* mfa_path = getenv("MFA_BRIDGE_PATH");
-        if (mfa_path) {
-            g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
-        }
-    }
-
     if (!g_dylib_handle) {
         throw std::runtime_error(
             "Failed to load libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.");
@@ -79,6 +101,7 @@ static bool load_mfa_bridge() {
     // Load function pointers
     g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
     g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
+    g_mfa_create_kernel_v2 = (mfa_create_kernel_v2_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v2");
     g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
     g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
     g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
@@ -127,6 +150,8 @@ struct KernelCacheKey {
     bool low_precision;
     bool low_precision_outputs;
     bool causal;
+    bool has_mask;
+    bool use_bf16;

     bool operator==(const KernelCacheKey& other) const {
         return seq_len_q == other.seq_len_q &&
@@ -134,7 +159,9 @@
                head_dim == other.head_dim &&
                low_precision == other.low_precision &&
                low_precision_outputs == other.low_precision_outputs &&
-               causal == other.causal
+               causal == other.causal &&
+               has_mask == other.has_mask &&
+               use_bf16 == other.use_bf16;
     }
 };

@@ -145,28 +172,47 @@ struct KernelCacheKeyHash {
                (std::hash<int64_t>()(k.head_dim) << 2) ^
                (std::hash<bool>()(k.low_precision) << 3) ^
                (std::hash<bool>()(k.low_precision_outputs) << 4) ^
-               (std::hash<bool>()(k.causal) << 5)
+               (std::hash<bool>()(k.causal) << 5) ^
+               (std::hash<bool>()(k.has_mask) << 6) ^
+               (std::hash<bool>()(k.use_bf16) << 7);
     }
 };

 static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;

-static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal) {
-    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal};
+static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask, bool use_bf16 = false) {
+    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16};

     auto it = g_kernel_cache.find(key);
     if (it != g_kernel_cache.end()) {
         return it->second;
     }

-    void* kernel =
-
-
-
-
-
-
-
+    void* kernel = nullptr;
+    if (use_bf16 && g_mfa_create_kernel_v2) {
+        // Use v2 API with BF16 support
+        kernel = g_mfa_create_kernel_v2(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask,
+            use_bf16
+        );
+    } else {
+        // Legacy API
+        kernel = g_mfa_create_kernel(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask
+        );
+    }

     if (!kernel) {
         throw std::runtime_error("Failed to create MFA kernel");
@@ -184,7 +230,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     const at::Tensor& query,    // (B, H, N, D)
     const at::Tensor& key,      // (B, H, N, D)
     const at::Tensor& value,    // (B, H, N, D)
-    bool is_causal
+    bool is_causal,
+    const c10::optional<at::Tensor>& attn_mask  // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
 ) {
     // Initialize MFA on first call
     if (!g_initialized) {
@@ -236,25 +283,63 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         v_expanded = value;
     }

-    // Determine precision
-    bool
-
+    // Determine precision - MFA kernel supports FP16, BF16, and FP32
+    bool is_bfloat16 = (query.scalar_type() == at::kBFloat16);
+    bool is_fp16 = (query.scalar_type() == at::kHalf);

-    //
-    bool
+    // Use native BF16 kernel if available, otherwise fall back to FP32
+    bool use_bf16_kernel = is_bfloat16 && g_mfa_create_kernel_v2;
+    bool low_precision = is_fp16;  // FP16 path
+    bool low_precision_outputs = is_fp16 || use_bf16_kernel;

     // Make inputs contiguous
     auto q = query.contiguous();
     auto k = k_expanded.contiguous();
     auto v = v_expanded.contiguous();

+    // For BF16 without native kernel support, convert to FP32 (avoids FP16 overflow)
+    if (is_bfloat16 && !use_bf16_kernel) {
+        q = q.to(at::kFloat);
+        k = k.to(at::kFloat);
+        v = v.to(at::kFloat);
+    }
+
+    // Handle attention mask
+    bool has_mask = attn_mask.has_value();
+    at::Tensor mask;
+    if (has_mask) {
+        mask = attn_mask.value();
+        TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
+        TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
+        // Convert to bool/uint8 if needed - kernel expects uchar (0 = attend, non-0 = mask out)
+        if (mask.scalar_type() == at::kBool) {
+            // Convert bool to uint8 for Metal compatibility
+            mask = mask.to(at::kByte);
+        }
+        TORCH_CHECK(mask.scalar_type() == at::kByte,
+                    "Attention mask must be bool or uint8");
+        // Expand mask dimensions if needed -> (B, H, N_q, N_kv)
+        // Handle (B, 1, N_q, N_kv) -> expand heads
+        // Handle (B, H, 1, N_kv) -> expand query dim (1D key mask)
+        // Handle (B, 1, 1, N_kv) -> expand both
+        if (mask.size(1) == 1 || mask.size(2) == 1) {
+            mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
+        }
+        mask = mask.contiguous();
+    }
+
     // Allocate output in the appropriate precision
-    // With lowPrecisionOutputs=true, MFA writes FP16 directly
     at::Tensor output;
-    if (
+    if (use_bf16_kernel) {
+        // Native BF16 kernel outputs BF16
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kBFloat16));
+    } else if (low_precision_outputs) {
+        // FP16 kernel outputs FP16
         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                            query.options().dtype(at::kHalf));
     } else {
+        // FP32 kernel outputs FP32
         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                            query.options().dtype(at::kFloat));
     }
@@ -263,8 +348,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
                                query.options().dtype(at::kFloat));

-    // Get or create kernel with matching
-    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+    // Get or create kernel with matching precision and causal mode
+    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask, use_bf16_kernel);

     // Get Metal buffers with byte offsets
     auto q_info = getBufferInfo(q);
@@ -273,6 +358,12 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     auto o_info = getBufferInfo(output);
     auto l_info = getBufferInfo(logsumexp);

+    // Mask buffer info (may be nullptr if no mask)
+    BufferInfo mask_info = {nil, 0};
+    if (has_mask) {
+        mask_info = getBufferInfo(mask);
+    }
+
     // Use PyTorch's MPS stream command encoder for zero-sync integration
     @autoreleasepool {
         auto stream = at::mps::getCurrentMPSStream();
@@ -290,11 +381,13 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
             (__bridge void*)v_info.buffer,
             (__bridge void*)o_info.buffer,
             (__bridge void*)l_info.buffer,
+            has_mask ? (__bridge void*)mask_info.buffer : nullptr,
             q_info.byte_offset,
             k_info.byte_offset,
             v_info.byte_offset,
             o_info.byte_offset,
             l_info.byte_offset,
+            mask_info.byte_offset,
             static_cast<int32_t>(batch_size),
             static_cast<int32_t>(num_heads)
         );
@@ -307,7 +400,12 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         // The encoder stays open for coalescing more kernels
     }

-    //
+    // Convert output back to BF16 if input was BF16 and we used FP32 fallback
+    // (native BF16 kernel already outputs BF16, no conversion needed)
+    if (is_bfloat16 && !use_bf16_kernel) {
+        output = output.to(at::kBFloat16);
+    }
+
     // Return both output and logsumexp (needed for backward pass)
     return std::make_tuple(output, logsumexp);
 }
@@ -317,9 +415,10 @@ at::Tensor mps_flash_attention_forward(
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value,
-    bool is_causal
+    bool is_causal,
+    const c10::optional<at::Tensor>& attn_mask
 ) {
-    auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal);
+    auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal, attn_mask);
     return output;
 }

@@ -334,7 +433,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
     const at::Tensor& value,      // (B, H, N, D)
     const at::Tensor& output,     // (B, H, N, D)
     const at::Tensor& logsumexp,  // (B, H, N)
-    bool is_causal
+    bool is_causal,
+    const c10::optional<at::Tensor>& attn_mask  // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
 ) {
     // Initialize MFA on first call
     if (!g_initialized) {
@@ -364,16 +464,35 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
                           query.scalar_type() == at::kBFloat16);
     bool low_precision_outputs = low_precision;

-    //
-
-
-
-
-
+    // Handle attention mask
+    bool has_mask = attn_mask.has_value();
+    at::Tensor mask;
+    if (has_mask) {
+        mask = attn_mask.value();
+        TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
+        TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
+        if (mask.scalar_type() == at::kBool) {
+            mask = mask.to(at::kByte);
+        }
+        TORCH_CHECK(mask.scalar_type() == at::kByte,
+                    "Attention mask must be bool or uint8");
+        if (mask.size(1) == 1 && num_heads > 1) {
+            mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
+        }
+        mask = mask.contiguous();
+    }
+
+    // Make inputs contiguous and upcast to fp32 for numerical stability
+    // The backward pass accumulates many small values, so fp32 precision is critical
+    auto q = query.contiguous().to(at::kFloat);
+    auto k = key.contiguous().to(at::kFloat);
+    auto v = value.contiguous().to(at::kFloat);
+    auto o = output.contiguous().to(at::kFloat);
+    auto dO = grad_output.contiguous().to(at::kFloat);
     auto lse = logsumexp.contiguous();

-    // Get or create kernel
-    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim,
+    // Get or create kernel - always use fp32 for backward pass
+    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, false, false, is_causal, has_mask);

     // Allocate D buffer (dO * O reduction, always fp32)
     auto D = at::empty({batch_size, num_heads, seq_len_q},
@@ -399,6 +518,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
     auto dk_info = getBufferInfo(dK);
     auto dv_info = getBufferInfo(dV);

+    // Mask buffer info (may be nullptr if no mask)
+    BufferInfo mask_info = {nil, 0};
+    if (has_mask) {
+        mask_info = getBufferInfo(mask);
+    }
+
     // Use PyTorch's MPS stream command encoder for zero-sync integration
     @autoreleasepool {
         auto stream = at::mps::getCurrentMPSStream();
@@ -419,6 +544,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
             (__bridge void*)dq_info.buffer,
             (__bridge void*)dk_info.buffer,
             (__bridge void*)dv_info.buffer,
+            has_mask ? (__bridge void*)mask_info.buffer : nullptr,
             q_info.byte_offset,
             k_info.byte_offset,
             v_info.byte_offset,
@@ -429,6 +555,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
             dq_info.byte_offset,
             dk_info.byte_offset,
             dv_info.byte_offset,
+            mask_info.byte_offset,
             static_cast<int32_t>(batch_size),
             static_cast<int32_t>(num_heads)
         );
@@ -462,14 +589,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("query"),
           py::arg("key"),
           py::arg("value"),
-          py::arg("is_causal") = false
+          py::arg("is_causal") = false,
+          py::arg("attn_mask") = py::none());

     m.def("forward_with_lse", &mps_flash_attention_forward_with_lse,
           "Flash Attention forward pass (returns output and logsumexp for backward)",
           py::arg("query"),
           py::arg("key"),
           py::arg("value"),
-          py::arg("is_causal") = false
+          py::arg("is_causal") = false,
+          py::arg("attn_mask") = py::none());

     m.def("backward", &mps_flash_attention_backward,
           "Flash Attention backward pass",
@@ -479,5 +608,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("value"),
           py::arg("output"),
           py::arg("logsumexp"),
-          py::arg("is_causal") = false
+          py::arg("is_causal") = false,
+          py::arg("attn_mask") = py::none());
 }

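For readers following the mask handling added to the forward and backward encoders above, the sketch below mirrors in plain PyTorch what the extension does before launching the kernel: enforce a 4D mask, convert bool to uint8 (the kernel consumes uchar, 0 = attend, non-zero = mask out), broadcast (B, 1, N_q, N_kv), (B, H, 1, N_kv), or (B, 1, 1, N_kv) shapes up to (B, H, N_q, N_kv), and make the result contiguous. The helper name normalize_mask_like_extension is hypothetical and exists only to illustrate the behavior shown in the diff.

```python
import torch

def normalize_mask_like_extension(mask: torch.Tensor, B: int, H: int, N_q: int, N_kv: int) -> torch.Tensor:
    # Hypothetical illustration of the C++ mask preprocessing in the diff above.
    assert mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)"
    if mask.dtype == torch.bool:
        mask = mask.to(torch.uint8)          # kernel expects uchar: 0 = attend, non-zero = mask out
    if mask.size(1) == 1 or mask.size(2) == 1:
        mask = mask.expand(B, H, N_q, N_kv)  # broadcast the head and/or query dimensions
    return mask.contiguous()

# Example: a key-padding-style mask that hides the last 128 key positions for every query.
B, H, N = 2, 8, 1024
pad = torch.zeros(B, 1, 1, N, dtype=torch.bool)
pad[:, :, :, -128:] = True
print(normalize_mask_like_extension(pad, B, H, N, N).shape)  # torch.Size([2, 8, 1024, 1024])
```

The expand-then-contiguous step means a broadcast mask is materialized at full (B, H, N_q, N_kv) size before it reaches the kernel, which is worth keeping in mind for very long sequences.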
mps_flash_attn-0.1.13/mps_flash_attn/lib/libMFABridge.dylib: binary file (contents not shown).

{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.1.5
+Version: 0.1.13
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
@@ -32,8 +32,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
 ## Features

 - **Forward pass**: 2-5x faster than PyTorch SDPA
-- **Backward pass**: Full gradient support for training
+- **Backward pass**: Full gradient support for training (fp32 precision)
 - **Causal masking**: Native kernel support (only 5% overhead)
+- **Attention masks**: Full boolean mask support for arbitrary masking patterns
 - **FP16/FP32**: Native fp16 output (no conversion overhead)
 - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

@@ -98,6 +99,16 @@ out = flash_attention(q, k, v)
 out = flash_attention(q, k, v, is_causal=True)
 ```

+### Attention masks (for custom masking patterns)
+
+```python
+# Boolean mask: True = masked (don't attend), False = attend
+mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+mask[:, :, :, 512:] = True  # Mask out positions after 512
+
+out = flash_attention(q, k, v, attn_mask=mask)
+```
+
 ### Training with gradients

 ```python
@@ -247,7 +258,6 @@ python scripts/build_metallibs.py
 **Known limitations:**
 - Sequence length must be divisible by block size (typically 64)
 - Head dimension: Best with 32, 64, 96, 128
-- No arbitrary attention masks (only causal or none)
 - No dropout

 ## Credits

{mps_flash_attn-0.1.5 → mps_flash_attn-0.1.13}/setup.py

@@ -4,6 +4,7 @@ Setup script for MPS Flash Attention

 import os
 import sys
+import shutil
 from setuptools import setup, find_packages, Extension
 from setuptools.command.build_ext import build_ext

@@ -36,6 +37,26 @@ class ObjCppBuildExt(build_ext):

         super().build_extensions()

+        # Copy libMFABridge.dylib to lib/ after building
+        self._copy_swift_bridge()
+
+    def _copy_swift_bridge(self):
+        """Copy Swift bridge dylib to package lib/ directory."""
+        src_path = os.path.join(
+            os.path.dirname(__file__),
+            "swift-bridge", ".build", "release", "libMFABridge.dylib"
+        )
+        dst_dir = os.path.join(os.path.dirname(__file__), "mps_flash_attn", "lib")
+        dst_path = os.path.join(dst_dir, "libMFABridge.dylib")
+
+        if os.path.exists(src_path):
+            os.makedirs(dst_dir, exist_ok=True)
+            shutil.copy2(src_path, dst_path)
+            print(f"Copied libMFABridge.dylib to {dst_path}")
+        else:
+            print(f"Warning: {src_path} not found. Build swift-bridge first with:")
+            print("  cd swift-bridge && swift build -c release")
+

 def get_extensions():
     if sys.platform != "darwin":

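The setup.py change above copies the Swift bridge into mps_flash_attn/lib/ so wheels ship it alongside the extension. A small, hedged way to check that an installed copy actually bundled the bridge and that the backend is usable, assuming the package layout shown in this diff:

```python
import os
import mps_flash_attn

pkg_dir = os.path.dirname(mps_flash_attn.__file__)
bridge = os.path.join(pkg_dir, "lib", "libMFABridge.dylib")

print("bundled bridge present:", os.path.exists(bridge))   # True if the wheel shipped the dylib
print("MFA backend available:", mps_flash_attn.is_available())

# If the bridge lives elsewhere (e.g. a local swift-bridge dev build), point the
# loader at it explicitly before the first attention call:
# os.environ["MFA_BRIDGE_PATH"] = "/path/to/libMFABridge.dylib"
```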
mps_flash_attn-0.1.5/mps_flash_attn/lib/libMFABridge.dylib: binary file (contents not shown).

All remaining files listed above with "+0 -0" (LICENSE, the pre-compiled kernels under mps_flash_attn/kernels/, the egg-info metadata files, setup.cfg, and tests/test_attention.py) are unchanged between 0.1.5 and 0.1.13.