mps-flash-attn 0.1.4__tar.gz → 0.1.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/PKG-INFO +13 -3
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/README.md +12 -2
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/__init__.py +87 -46
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/csrc/mps_flash_attn.mm +244 -72
- mps_flash_attn-0.1.13/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/PKG-INFO +13 -3
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/pyproject.toml +1 -1
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/setup.py +23 -2
- mps_flash_attn-0.1.4/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/LICENSE +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/SOURCES.txt +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/setup.cfg +0 -0
- {mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/tests/test_attention.py +0 -0

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.1.4
+Version: 0.1.13
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
@@ -32,8 +32,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
 ## Features
 
 - **Forward pass**: 2-5x faster than PyTorch SDPA
-- **Backward pass**: Full gradient support for training
+- **Backward pass**: Full gradient support for training (fp32 precision)
 - **Causal masking**: Native kernel support (only 5% overhead)
+- **Attention masks**: Full boolean mask support for arbitrary masking patterns
 - **FP16/FP32**: Native fp16 output (no conversion overhead)
 - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)
 
@@ -98,6 +99,16 @@ out = flash_attention(q, k, v)
 out = flash_attention(q, k, v, is_causal=True)
 ```
 
+### Attention masks (for custom masking patterns)
+
+```python
+# Boolean mask: True = masked (don't attend), False = attend
+mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+mask[:, :, :, 512:] = True  # Mask out positions after 512
+
+out = flash_attention(q, k, v, attn_mask=mask)
+```
+
 ### Training with gradients
 
 ```python
@@ -247,7 +258,6 @@ python scripts/build_metallibs.py
 **Known limitations:**
 - Sequence length must be divisible by block size (typically 64)
 - Head dimension: Best with 32, 64, 96, 128
-- No arbitrary attention masks (only causal or none)
 - No dropout
 
 ## Credits
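
The mask convention added in 0.1.13 is the inverse of a "keep" mask: `True` marks positions that must *not* be attended. Below is a minimal sketch of building such a mask from per-sequence valid lengths; the helper name `lengths_to_mfa_mask` and the shapes are illustrative only (not part of the package), and the example assumes an Apple-silicon machine with an MPS device.

```python
import torch

def lengths_to_mfa_mask(lengths, seq_len):
    """Build a (B, 1, N, N) boolean mask where True = do NOT attend.

    `lengths` holds the number of valid key positions per batch element;
    everything at or beyond that index is masked out.
    """
    B = len(lengths)
    positions = torch.arange(seq_len)                                   # (N,)
    keep = positions.unsqueeze(0) < torch.tensor(lengths).unsqueeze(1)  # (B, N), True = valid
    # MFA wants True = masked, so invert, then broadcast to (B, 1, N_q, N_kv)
    return (~keep)[:, None, None, :].expand(B, 1, seq_len, seq_len).to('mps')

# Example: batch of 2, padded to 1024 keys, with 700 and 1024 real tokens
mask = lengths_to_mfa_mask([700, 1024], 1024)   # dtype=torch.bool, device='mps'
# out = flash_attention(q, k, v, attn_mask=mask)
```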

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/README.md

@@ -7,8 +7,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
 ## Features
 
 - **Forward pass**: 2-5x faster than PyTorch SDPA
-- **Backward pass**: Full gradient support for training
+- **Backward pass**: Full gradient support for training (fp32 precision)
 - **Causal masking**: Native kernel support (only 5% overhead)
+- **Attention masks**: Full boolean mask support for arbitrary masking patterns
 - **FP16/FP32**: Native fp16 output (no conversion overhead)
 - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)
 
@@ -73,6 +74,16 @@ out = flash_attention(q, k, v)
 out = flash_attention(q, k, v, is_causal=True)
 ```
 
+### Attention masks (for custom masking patterns)
+
+```python
+# Boolean mask: True = masked (don't attend), False = attend
+mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+mask[:, :, :, 512:] = True  # Mask out positions after 512
+
+out = flash_attention(q, k, v, attn_mask=mask)
+```
+
 ### Training with gradients
 
 ```python
@@ -222,7 +233,6 @@ python scripts/build_metallibs.py
 **Known limitations:**
 - Sequence length must be divisible by block size (typically 64)
 - Head dimension: Best with 32, 64, 96, 128
-- No arbitrary attention masks (only causal or none)
 - No dropout
 
 ## Credits

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/__init__.py

@@ -4,7 +4,7 @@ MPS Flash Attention - Flash Attention for PyTorch on Apple Silicon
 This package provides memory-efficient attention using Metal Flash Attention kernels.
 """
 
-__version__ = "0.1.4"
+__version__ = "0.1.13"
 
 import torch
 from typing import Optional
@@ -20,39 +20,9 @@ except ImportError as e:
     _HAS_MFA = False
     _IMPORT_ERROR = str(e)
 
-#
-
-
-try:
-    import ctypes
-    bridge_path = os.environ.get("MFA_BRIDGE_PATH")
-    if not bridge_path:
-        module_dir = os.path.dirname(__file__)
-        candidates = [
-            os.path.join(module_dir, "lib", "libMFABridge.dylib"),  # Bundled in wheel
-            os.path.join(module_dir, "..", "swift-bridge", ".build", "release", "libMFABridge.dylib"),
-            os.path.join(module_dir, "libMFABridge.dylib"),
-        ]
-        for path in candidates:
-            if os.path.exists(path):
-                bridge_path = path
-                break
-
-    if bridge_path and os.path.exists(bridge_path):
-        lib = ctypes.CDLL(bridge_path)
-
-        # Set shipped kernels directory (pre-compiled metallibs + pipeline binaries)
-        kernels_dir = os.path.join(os.path.dirname(__file__), "kernels")
-        if os.path.exists(kernels_dir):
-            lib.mfa_set_kernels_dir(kernels_dir.encode('utf-8'))
-
-        lib.mfa_init()
-except Exception:
-    pass  # Init is optional, will fall back to runtime compilation
-
-# Initialize shipped kernels on import
-if _HAS_MFA:
-    _init_shipped_kernels()
+# Note: The C++ extension handles loading libMFABridge.dylib via dlopen.
+# Set MFA_BRIDGE_PATH environment variable to specify the library location.
+# Do NOT load the library here via ctypes - that causes duplicate class warnings.
 
 
 def is_available() -> bool:
@@ -60,11 +30,35 @@ def is_available() -> bool:
     return _HAS_MFA and torch.backends.mps.is_available()
 
 
+def convert_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+    """
+    Convert attention mask to MFA's boolean format.
+
+    MFA uses boolean masks where True = masked (don't attend).
+    PyTorch SDPA uses additive float masks where -inf/large negative = masked.
+
+    Args:
+        attn_mask: Optional mask, either:
+            - None: no mask
+            - bool tensor: already in MFA format (True = masked)
+            - float tensor: additive mask (large negative = masked)
+
+    Returns:
+        Boolean mask suitable for flash_attention(), or None
+    """
+    if attn_mask is None:
+        return None
+    if attn_mask.dtype == torch.bool:
+        return attn_mask
+    # Float mask: large negative values indicate masked positions
+    return attn_mask <= -1e3
+
+
 class FlashAttentionFunction(torch.autograd.Function):
     """Autograd function for Flash Attention with backward pass support."""
 
     @staticmethod
-    def forward(ctx, query, key, value, is_causal, scale):
+    def forward(ctx, query, key, value, is_causal, scale, attn_mask):
         # Apply scale if provided (MFA uses 1/sqrt(D) internally)
         scale_factor = 1.0
         if scale is not None:
@@ -74,10 +68,15 @@ class FlashAttentionFunction(torch.autograd.Function):
             query = query * scale_factor
 
         # Forward with logsumexp for backward
-        output, logsumexp = _C.forward_with_lse(query, key, value, is_causal)
+        output, logsumexp = _C.forward_with_lse(query, key, value, is_causal, attn_mask)
 
         # Save for backward
-        ctx.save_for_backward(query, key, value, output, logsumexp)
+        if attn_mask is not None:
+            ctx.save_for_backward(query, key, value, output, logsumexp, attn_mask)
+            ctx.has_mask = True
+        else:
+            ctx.save_for_backward(query, key, value, output, logsumexp)
+            ctx.has_mask = False
         ctx.is_causal = is_causal
         ctx.scale_factor = scale_factor
 
@@ -85,19 +84,23 @@ class FlashAttentionFunction(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_output):
-        query, key, value, output, logsumexp = ctx.saved_tensors
+        if ctx.has_mask:
+            query, key, value, output, logsumexp, attn_mask = ctx.saved_tensors
+        else:
+            query, key, value, output, logsumexp = ctx.saved_tensors
+            attn_mask = None
 
         # Compute gradients
         dQ, dK, dV = _C.backward(
-            grad_output, query, key, value, output, logsumexp, ctx.is_causal
+            grad_output, query, key, value, output, logsumexp, ctx.is_causal, attn_mask
         )
 
         # If we scaled the query in forward, scale the gradient back
        if ctx.scale_factor != 1.0:
            dQ = dQ * ctx.scale_factor
 
-        # Return gradients (None for is_causal and scale)
-        return dQ, dK, dV, None, None
+        # Return gradients (None for is_causal, scale, and attn_mask since they're not tensors or don't need grad)
+        return dQ, dK, dV, None, None, None
 
 
 def flash_attention(
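
The `convert_mask` helper added above treats any additive value at or below -1e3 as a masked position and passes boolean masks through unchanged. A quick sketch of what that means in practice (assuming the helper is importable from the installed package):

```python
import torch
from mps_flash_attn import convert_mask

# Additive SDPA-style mask: 0.0 = attend, -inf (or any large negative) = masked
float_mask = torch.zeros(1, 1, 4, 4)
float_mask[..., 2:] = float('-inf')

bool_mask = convert_mask(float_mask)
print(bool_mask[0, 0, 0])   # tensor([False, False,  True,  True]) -> last two keys masked

# Boolean masks pass through untouched, and None stays None
assert convert_mask(None) is None
assert convert_mask(bool_mask) is bool_mask
```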
@@ -106,6 +109,7 @@ def flash_attention(
     value: torch.Tensor,
     is_causal: bool = False,
     scale: Optional[float] = None,
+    attn_mask: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Compute scaled dot-product attention using Flash Attention on MPS.
@@ -121,6 +125,9 @@
         value: Value tensor of shape (B, num_heads, seq_len, head_dim)
         is_causal: If True, applies causal masking (for autoregressive models)
         scale: Scaling factor for attention scores. Default: 1/sqrt(head_dim)
+        attn_mask: Optional boolean attention mask of shape (B, 1, seq_len_q, seq_len_kv)
+            or (B, num_heads, seq_len_q, seq_len_kv). True values indicate
+            positions to be masked (not attended to).
 
     Returns:
         Output tensor of shape (B, num_heads, seq_len, head_dim)
@@ -137,6 +144,11 @@
         >>> q.requires_grad = True
         >>> out = flash_attention(q, k, v)
         >>> out.sum().backward()  # Computes dQ
+
+        # With attention mask:
+        >>> mask = torch.zeros(2, 1, 4096, 4096, dtype=torch.bool, device='mps')
+        >>> mask[:, :, :, 2048:] = True  # mask out second half of keys
+        >>> out = flash_attention(q, k, v, attn_mask=mask)
     """
     if not _HAS_MFA:
         raise RuntimeError(
@@ -154,9 +166,23 @@
         raise ValueError("key must be on MPS device")
     if value.device.type != 'mps':
         raise ValueError("value must be on MPS device")
+    if attn_mask is not None and attn_mask.device.type != 'mps':
+        raise ValueError("attn_mask must be on MPS device")
+
+    # Fast path: inference mode (no grad) - skip autograd overhead and don't save tensors
+    if not torch.is_grad_enabled() or (not query.requires_grad and not key.requires_grad and not value.requires_grad):
+        # Apply scale if provided
+        if scale is not None:
+            default_scale = 1.0 / math.sqrt(query.shape[-1])
+            if abs(scale - default_scale) > 1e-6:
+                scale_factor = scale / default_scale
+                query = query * scale_factor
+
+        # Forward only - no logsumexp needed, no tensors saved
+        return _C.forward(query, key, value, is_causal, attn_mask)
 
     # Use autograd function for gradient support
-    return FlashAttentionFunction.apply(query, key, value, is_causal, scale)
+    return FlashAttentionFunction.apply(query, key, value, is_causal, scale, attn_mask)
 
 
 def replace_sdpa():
@@ -173,13 +199,28 @@ def replace_sdpa():
 
     def patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                      is_causal=False, scale=None):
-        # Use MFA for MPS tensors without masks or dropout
+        # Use MFA for MPS tensors without dropout
+        # Only use MFA for seq_len >= 1024 where it outperforms PyTorch's math backend
+        # For shorter sequences, PyTorch's simpler matmul+softmax approach is faster
         if (query.device.type == 'mps' and
-            attn_mask is None and
             dropout_p == 0.0 and
-            _HAS_MFA):
+            _HAS_MFA and
+            query.shape[2] >= 1024):
             try:
-                return flash_attention(query, key, value, is_causal=is_causal, scale=scale)
+                # Convert float mask to bool mask if needed
+                # PyTorch SDPA uses additive masks (0 = attend, -inf = mask)
+                # MFA uses boolean masks (False/0 = attend, True/non-zero = mask)
+                mfa_mask = None
+                if attn_mask is not None:
+                    if attn_mask.dtype == torch.bool:
+                        # Boolean mask: True means masked (don't attend)
+                        mfa_mask = attn_mask
+                    else:
+                        # Float mask: typically -inf for masked positions, 0 for unmasked
+                        # Convert: positions with large negative values -> True (masked)
+                        # Use -1e3 threshold to catch -1000, -10000, -inf, etc.
+                        mfa_mask = attn_mask <= -1e3
+                return flash_attention(query, key, value, is_causal=is_causal, scale=scale, attn_mask=mfa_mask)
             except Exception:
                 # Fall back to original on any error
                 pass
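
After `replace_sdpa()` is called, `F.scaled_dot_product_attention` on MPS is routed to the MFA kernel only when there is no dropout and the query length is at least 1024; shorter sequences and any failure inside the patched path fall back to PyTorch's built-in implementation. A hedged usage sketch, with illustrative shapes, assuming an Apple-silicon machine with the package installed:

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import replace_sdpa

replace_sdpa()  # monkey-patches F.scaled_dot_product_attention

q = torch.randn(1, 8, 2048, 64, device='mps', dtype=torch.float16)
k = torch.randn(1, 8, 2048, 64, device='mps', dtype=torch.float16)
v = torch.randn(1, 8, 2048, 64, device='mps', dtype=torch.float16)

# seq_len 2048 >= 1024 and dropout_p == 0.0 -> dispatched to the MFA kernel
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# seq_len 256 < 1024 -> falls through to PyTorch's own attention path
q_s = torch.randn(1, 8, 256, 64, device='mps', dtype=torch.float16)
out_s = F.scaled_dot_product_attention(q_s, q_s, q_s)
```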

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn/csrc/mps_flash_attn.mm

@@ -12,22 +12,39 @@
 #import <Foundation/Foundation.h>
 
 #include <dlfcn.h>
+#include <string>
+#include <vector>
 
 // ============================================================================
 // MFA Bridge Function Types
 // ============================================================================
 
 typedef bool (*mfa_init_fn)();
-typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool);
-typedef
-
-
+typedef void* (*mfa_create_kernel_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool);
+typedef void* (*mfa_create_kernel_v2_fn)(int32_t, int32_t, int32_t, bool, bool, bool, bool, bool);
+// New zero-sync encode functions that take PyTorch's command encoder
+// Added mask_ptr and mask_offset parameters
+typedef bool (*mfa_forward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*,
+                                      int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                      int32_t, int32_t);
+typedef bool (*mfa_backward_encode_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                       int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                                       int32_t, int32_t);
+// Legacy sync functions (fallback)
+typedef bool (*mfa_forward_fn)(void*, void*, void*, void*, void*, void*, void*,
+                               int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
+                               int32_t, int32_t);
+typedef bool (*mfa_backward_fn)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*,
+                                int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t,
                                 int32_t, int32_t);
 typedef void (*mfa_release_kernel_fn)(void*);
 
 // Global function pointers
 static mfa_init_fn g_mfa_init = nullptr;
 static mfa_create_kernel_fn g_mfa_create_kernel = nullptr;
+static mfa_create_kernel_v2_fn g_mfa_create_kernel_v2 = nullptr;
+static mfa_forward_encode_fn g_mfa_forward_encode = nullptr;
+static mfa_backward_encode_fn g_mfa_backward_encode = nullptr;
 static mfa_forward_fn g_mfa_forward = nullptr;
 static mfa_backward_fn g_mfa_backward = nullptr;
 static mfa_release_kernel_fn g_mfa_release_kernel = nullptr;
@@ -38,31 +55,44 @@ static bool g_initialized = false;
 // Load MFA Bridge Library
 // ============================================================================
 
+// Get the directory containing this shared library
+static std::string get_module_dir() {
+    Dl_info info;
+    if (dladdr((void*)get_module_dir, &info) && info.dli_fname) {
+        std::string path(info.dli_fname);
+        size_t last_slash = path.rfind('/');
+        if (last_slash != std::string::npos) {
+            return path.substr(0, last_slash);
+        }
+    }
+    return ".";
+}
+
 static bool load_mfa_bridge() {
     if (g_dylib_handle) return true;
 
-    //
-
-
-
-
-
-
+    // First check environment variable (highest priority)
+    const char* mfa_path = getenv("MFA_BRIDGE_PATH");
+    if (mfa_path) {
+        g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
+        if (g_dylib_handle) return true;
+    }
+
+    // Get the directory containing this extension module
+    std::string module_dir = get_module_dir();
+
+    // Try paths relative to the module directory
+    std::vector<std::string> paths = {
+        module_dir + "/lib/libMFABridge.dylib",                             // Bundled in wheel
+        module_dir + "/../swift-bridge/.build/release/libMFABridge.dylib",  // Dev build
+        "libMFABridge.dylib",                                               // Current directory fallback
     };
 
-    for (
-    g_dylib_handle = dlopen(
+    for (const auto& path : paths) {
+        g_dylib_handle = dlopen(path.c_str(), RTLD_NOW);
         if (g_dylib_handle) break;
     }
 
-    if (!g_dylib_handle) {
-        // Try with absolute path from environment
-        const char* mfa_path = getenv("MFA_BRIDGE_PATH");
-        if (mfa_path) {
-            g_dylib_handle = dlopen(mfa_path, RTLD_NOW);
-        }
-    }
-
     if (!g_dylib_handle) {
         throw std::runtime_error(
             "Failed to load libMFABridge.dylib. Set MFA_BRIDGE_PATH environment variable.");
@@ -71,11 +101,15 @@ static bool load_mfa_bridge() {
     // Load function pointers
     g_mfa_init = (mfa_init_fn)dlsym(g_dylib_handle, "mfa_init");
     g_mfa_create_kernel = (mfa_create_kernel_fn)dlsym(g_dylib_handle, "mfa_create_kernel");
+    g_mfa_create_kernel_v2 = (mfa_create_kernel_v2_fn)dlsym(g_dylib_handle, "mfa_create_kernel_v2");
+    g_mfa_forward_encode = (mfa_forward_encode_fn)dlsym(g_dylib_handle, "mfa_forward_encode");
+    g_mfa_backward_encode = (mfa_backward_encode_fn)dlsym(g_dylib_handle, "mfa_backward_encode");
     g_mfa_forward = (mfa_forward_fn)dlsym(g_dylib_handle, "mfa_forward");
     g_mfa_backward = (mfa_backward_fn)dlsym(g_dylib_handle, "mfa_backward");
     g_mfa_release_kernel = (mfa_release_kernel_fn)dlsym(g_dylib_handle, "mfa_release_kernel");
 
-
+    // Require at least init, create_kernel, forward_encode (for zero-sync path)
+    if (!g_mfa_init || !g_mfa_create_kernel || !g_mfa_forward_encode) {
         throw std::runtime_error("Failed to load MFA bridge functions");
     }
 
@@ -116,6 +150,8 @@ struct KernelCacheKey {
     bool low_precision;
     bool low_precision_outputs;
     bool causal;
+    bool has_mask;
+    bool use_bf16;
 
     bool operator==(const KernelCacheKey& other) const {
         return seq_len_q == other.seq_len_q &&
@@ -123,7 +159,9 @@ struct KernelCacheKey {
                head_dim == other.head_dim &&
                low_precision == other.low_precision &&
                low_precision_outputs == other.low_precision_outputs &&
-               causal == other.causal;
+               causal == other.causal &&
+               has_mask == other.has_mask &&
+               use_bf16 == other.use_bf16;
     }
 };
 
@@ -134,28 +172,47 @@ struct KernelCacheKeyHash {
                (std::hash<int64_t>()(k.head_dim) << 2) ^
                (std::hash<bool>()(k.low_precision) << 3) ^
                (std::hash<bool>()(k.low_precision_outputs) << 4) ^
-               (std::hash<bool>()(k.causal) << 5);
+               (std::hash<bool>()(k.causal) << 5) ^
+               (std::hash<bool>()(k.has_mask) << 6) ^
+               (std::hash<bool>()(k.use_bf16) << 7);
     }
 };
 
 static std::unordered_map<KernelCacheKey, void*, KernelCacheKeyHash> g_kernel_cache;
 
-static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal) {
-    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal};
+static void* get_or_create_kernel(int64_t seq_q, int64_t seq_kv, int64_t head_dim, bool low_prec, bool low_prec_outputs, bool causal, bool has_mask, bool use_bf16 = false) {
+    KernelCacheKey key{seq_q, seq_kv, head_dim, low_prec, low_prec_outputs, causal, has_mask, use_bf16};
 
     auto it = g_kernel_cache.find(key);
     if (it != g_kernel_cache.end()) {
         return it->second;
     }
 
-    void* kernel =
-
-
-
-
-
-
+    void* kernel = nullptr;
+    if (use_bf16 && g_mfa_create_kernel_v2) {
+        // Use v2 API with BF16 support
+        kernel = g_mfa_create_kernel_v2(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask,
+            use_bf16
+        );
+    } else {
+        // Legacy API
+        kernel = g_mfa_create_kernel(
+            static_cast<int32_t>(seq_q),
+            static_cast<int32_t>(seq_kv),
+            static_cast<int32_t>(head_dim),
+            low_prec,
+            low_prec_outputs,
+            causal,
+            has_mask
+        );
+    }
 
     if (!kernel) {
         throw std::runtime_error("Failed to create MFA kernel");
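
The loader above checks `MFA_BRIDGE_PATH` before any bundled or dev-build location, so a custom build of the Swift bridge can be swapped in without reinstalling the wheel. A small sketch of using the override (the dylib path is an example, not a shipped default):

```python
import os

# The bridge is loaded lazily on the first kernel creation, so the override
# must be in place before the first flash_attention() call in this process.
os.environ["MFA_BRIDGE_PATH"] = "/opt/mfa/libMFABridge.dylib"  # hypothetical path

import torch
from mps_flash_attn import flash_attention, is_available

print(is_available())  # True only on an MPS machine where the extension imported cleanly
```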
@@ -173,7 +230,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     const at::Tensor& query,   // (B, H, N, D)
     const at::Tensor& key,     // (B, H, N, D)
     const at::Tensor& value,   // (B, H, N, D)
-    bool is_causal
+    bool is_causal,
+    const c10::optional<at::Tensor>& attn_mask  // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
 ) {
     // Initialize MFA on first call
     if (!g_initialized) {
@@ -193,37 +251,95 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     TORCH_CHECK(value.device().is_mps(), "Value must be on MPS device");
 
     const int64_t batch_size = query.size(0);
-    const int64_t num_heads = query.size(1);
+    const int64_t num_heads_q = query.size(1);
+    const int64_t num_heads_kv = key.size(1);
     const int64_t seq_len_q = query.size(2);
     const int64_t head_dim = query.size(3);
     const int64_t seq_len_kv = key.size(2);
 
     TORCH_CHECK(key.size(0) == batch_size && value.size(0) == batch_size,
                 "Batch size mismatch");
-    TORCH_CHECK(key.size(1) == num_heads && value.size(1) == num_heads,
-                "Number of heads mismatch");
     TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
                 "Head dimension mismatch");
+    TORCH_CHECK(key.size(1) == value.size(1),
+                "K and V must have same number of heads");
+
+    // Handle GQA (Grouped Query Attention): expand K/V if fewer heads than Q
+    const int64_t num_heads = num_heads_q;
+    at::Tensor k_expanded, v_expanded;
+
+    if (num_heads_kv != num_heads_q) {
+        // GQA: num_heads_q must be divisible by num_heads_kv
+        TORCH_CHECK(num_heads_q % num_heads_kv == 0,
+                    "num_heads_q (", num_heads_q, ") must be divisible by num_heads_kv (", num_heads_kv, ")");
+        int64_t repeat_factor = num_heads_q / num_heads_kv;
+
+        // Expand K and V to match Q's head count: (B, H_kv, S, D) -> (B, H_q, S, D)
+        // Use repeat_interleave for proper GQA expansion
+        k_expanded = key.repeat_interleave(repeat_factor, /*dim=*/1);
+        v_expanded = value.repeat_interleave(repeat_factor, /*dim=*/1);
+    } else {
+        k_expanded = key;
+        v_expanded = value;
+    }
 
-    // Determine precision
-    bool
-
+    // Determine precision - MFA kernel supports FP16, BF16, and FP32
+    bool is_bfloat16 = (query.scalar_type() == at::kBFloat16);
+    bool is_fp16 = (query.scalar_type() == at::kHalf);
 
-    //
-    bool
+    // Use native BF16 kernel if available, otherwise fall back to FP32
+    bool use_bf16_kernel = is_bfloat16 && g_mfa_create_kernel_v2;
+    bool low_precision = is_fp16;  // FP16 path
+    bool low_precision_outputs = is_fp16 || use_bf16_kernel;
 
     // Make inputs contiguous
     auto q = query.contiguous();
-    auto k = key.contiguous();
-    auto v = value.contiguous();
+    auto k = k_expanded.contiguous();
+    auto v = v_expanded.contiguous();
+
+    // For BF16 without native kernel support, convert to FP32 (avoids FP16 overflow)
+    if (is_bfloat16 && !use_bf16_kernel) {
+        q = q.to(at::kFloat);
+        k = k.to(at::kFloat);
+        v = v.to(at::kFloat);
+    }
+
+    // Handle attention mask
+    bool has_mask = attn_mask.has_value();
+    at::Tensor mask;
+    if (has_mask) {
+        mask = attn_mask.value();
+        TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
+        TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
+        // Convert to bool/uint8 if needed - kernel expects uchar (0 = attend, non-0 = mask out)
+        if (mask.scalar_type() == at::kBool) {
+            // Convert bool to uint8 for Metal compatibility
+            mask = mask.to(at::kByte);
+        }
+        TORCH_CHECK(mask.scalar_type() == at::kByte,
+                    "Attention mask must be bool or uint8");
+        // Expand mask dimensions if needed -> (B, H, N_q, N_kv)
+        // Handle (B, 1, N_q, N_kv) -> expand heads
+        // Handle (B, H, 1, N_kv) -> expand query dim (1D key mask)
+        // Handle (B, 1, 1, N_kv) -> expand both
+        if (mask.size(1) == 1 || mask.size(2) == 1) {
+            mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
+        }
+        mask = mask.contiguous();
+    }
 
     // Allocate output in the appropriate precision
-    // With lowPrecisionOutputs=true, MFA writes FP16 directly
     at::Tensor output;
-    if (
+    if (use_bf16_kernel) {
+        // Native BF16 kernel outputs BF16
+        output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
+                           query.options().dtype(at::kBFloat16));
+    } else if (low_precision_outputs) {
+        // FP16 kernel outputs FP16
         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                            query.options().dtype(at::kHalf));
     } else {
+        // FP32 kernel outputs FP32
         output = at::empty({batch_size, num_heads, seq_len_q, head_dim},
                            query.options().dtype(at::kFloat));
     }
@@ -232,8 +348,8 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     auto logsumexp = at::empty({batch_size, num_heads, seq_len_q},
                                query.options().dtype(at::kFloat));
 
-    // Get or create kernel with matching
-    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+    // Get or create kernel with matching precision and causal mode
+    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal, has_mask, use_bf16_kernel);
 
     // Get Metal buffers with byte offsets
     auto q_info = getBufferInfo(q);
@@ -242,25 +358,36 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
     auto o_info = getBufferInfo(output);
     auto l_info = getBufferInfo(logsumexp);
 
-    //
+    // Mask buffer info (may be nullptr if no mask)
+    BufferInfo mask_info = {nil, 0};
+    if (has_mask) {
+        mask_info = getBufferInfo(mask);
+    }
+
+    // Use PyTorch's MPS stream command encoder for zero-sync integration
     @autoreleasepool {
-        // Wait for PyTorch operations to complete
         auto stream = at::mps::getCurrentMPSStream();
-        stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
 
-        //
-
+        // Get PyTorch's shared command encoder - this is the key for zero-sync!
+        // All our dispatches go onto the same encoder that PyTorch uses.
+        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+        // Execute MFA using the shared encoder (no sync needed!)
+        bool success = g_mfa_forward_encode(
             kernel,
+            (__bridge void*)encoder,  // PyTorch's shared command encoder
             (__bridge void*)q_info.buffer,
             (__bridge void*)k_info.buffer,
             (__bridge void*)v_info.buffer,
             (__bridge void*)o_info.buffer,
             (__bridge void*)l_info.buffer,
+            has_mask ? (__bridge void*)mask_info.buffer : nullptr,
             q_info.byte_offset,
             k_info.byte_offset,
             v_info.byte_offset,
             o_info.byte_offset,
             l_info.byte_offset,
+            mask_info.byte_offset,
             static_cast<int32_t>(batch_size),
             static_cast<int32_t>(num_heads)
         );
@@ -268,9 +395,17 @@ std::tuple<at::Tensor, at::Tensor> mps_flash_attention_forward_with_lse(
         if (!success) {
             throw std::runtime_error("MFA forward pass failed");
         }
+
+        // No commit needed - PyTorch will commit when it needs the results
+        // The encoder stays open for coalescing more kernels
+    }
+
+    // Convert output back to BF16 if input was BF16 and we used FP32 fallback
+    // (native BF16 kernel already outputs BF16, no conversion needed)
+    if (is_bfloat16 && !use_bf16_kernel) {
+        output = output.to(at::kBFloat16);
     }
 
-    // Output is already in the correct dtype (fp16 or fp32)
     // Return both output and logsumexp (needed for backward pass)
     return std::make_tuple(output, logsumexp);
 }
@@ -280,9 +415,10 @@ at::Tensor mps_flash_attention_forward(
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value,
-    bool is_causal
+    bool is_causal,
+    const c10::optional<at::Tensor>& attn_mask
 ) {
-    auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal);
+    auto [output, logsumexp] = mps_flash_attention_forward_with_lse(query, key, value, is_causal, attn_mask);
     return output;
 }
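
The GQA handling above expands K/V heads with `repeat_interleave` so that each group of query heads sees its own copy of the shared KV head. The same expansion, written out in plain PyTorch with illustrative shapes:

```python
import torch

B, H_q, H_kv, S, D = 2, 8, 2, 128, 64            # 4 query heads per KV head
key = torch.randn(B, H_kv, S, D)
value = torch.randn(B, H_kv, S, D)

repeat_factor = H_q // H_kv                       # must divide evenly, as the kernel checks
k_expanded = key.repeat_interleave(repeat_factor, dim=1)    # (B, H_q, S, D)
v_expanded = value.repeat_interleave(repeat_factor, dim=1)

# Query head i attends against KV head i // repeat_factor
assert torch.equal(k_expanded[:, 5], key[:, 5 // repeat_factor])
```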
@@ -297,7 +433,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
     const at::Tensor& value,      // (B, H, N, D)
     const at::Tensor& output,     // (B, H, N, D)
     const at::Tensor& logsumexp,  // (B, H, N)
-    bool is_causal
+    bool is_causal,
+    const c10::optional<at::Tensor>& attn_mask  // Optional (B, 1, N_q, N_kv) or (B, H, N_q, N_kv)
 ) {
     // Initialize MFA on first call
     if (!g_initialized) {
@@ -327,16 +464,35 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
                           query.scalar_type() == at::kBFloat16);
     bool low_precision_outputs = low_precision;
 
-    // Make inputs contiguous
-    auto q = query.contiguous();
-    auto k = key.contiguous();
-    auto v = value.contiguous();
-    auto o = output.contiguous();
-    auto dO = grad_output.contiguous();
+    // Handle attention mask
+    bool has_mask = attn_mask.has_value();
+    at::Tensor mask;
+    if (has_mask) {
+        mask = attn_mask.value();
+        TORCH_CHECK(mask.dim() == 4, "Attention mask must be 4D (B, H or 1, N_q, N_kv)");
+        TORCH_CHECK(mask.device().is_mps(), "Attention mask must be on MPS device");
+        if (mask.scalar_type() == at::kBool) {
+            mask = mask.to(at::kByte);
+        }
+        TORCH_CHECK(mask.scalar_type() == at::kByte,
+                    "Attention mask must be bool or uint8");
+        if (mask.size(1) == 1 && num_heads > 1) {
+            mask = mask.expand({batch_size, num_heads, seq_len_q, seq_len_kv});
+        }
+        mask = mask.contiguous();
+    }
+
+    // Make inputs contiguous and upcast to fp32 for numerical stability
+    // The backward pass accumulates many small values, so fp32 precision is critical
+    auto q = query.contiguous().to(at::kFloat);
+    auto k = key.contiguous().to(at::kFloat);
+    auto v = value.contiguous().to(at::kFloat);
+    auto o = output.contiguous().to(at::kFloat);
+    auto dO = grad_output.contiguous().to(at::kFloat);
     auto lse = logsumexp.contiguous();
 
-    // Get or create kernel
-    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, low_precision, low_precision_outputs, is_causal);
+    // Get or create kernel - always use fp32 for backward pass
+    void* kernel = get_or_create_kernel(seq_len_q, seq_len_kv, head_dim, false, false, is_causal, has_mask);
 
     // Allocate D buffer (dO * O reduction, always fp32)
     auto D = at::empty({batch_size, num_heads, seq_len_q},
@@ -362,13 +518,22 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
     auto dk_info = getBufferInfo(dK);
     auto dv_info = getBufferInfo(dV);
 
-    //
+    // Mask buffer info (may be nullptr if no mask)
+    BufferInfo mask_info = {nil, 0};
+    if (has_mask) {
+        mask_info = getBufferInfo(mask);
+    }
+
+    // Use PyTorch's MPS stream command encoder for zero-sync integration
     @autoreleasepool {
         auto stream = at::mps::getCurrentMPSStream();
-        stream->synchronize(at::mps::SyncType::COMMIT_AND_WAIT);
 
-        bool success = g_mfa_backward(
+        // Get PyTorch's shared command encoder
+        id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
+
+        bool success = g_mfa_backward_encode(
             kernel,
+            (__bridge void*)encoder,  // PyTorch's shared command encoder
             (__bridge void*)q_info.buffer,
             (__bridge void*)k_info.buffer,
             (__bridge void*)v_info.buffer,
@@ -379,6 +544,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
             (__bridge void*)dq_info.buffer,
             (__bridge void*)dk_info.buffer,
             (__bridge void*)dv_info.buffer,
+            has_mask ? (__bridge void*)mask_info.buffer : nullptr,
             q_info.byte_offset,
             k_info.byte_offset,
             v_info.byte_offset,
@@ -389,6 +555,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
             dq_info.byte_offset,
             dk_info.byte_offset,
             dv_info.byte_offset,
+            mask_info.byte_offset,
             static_cast<int32_t>(batch_size),
             static_cast<int32_t>(num_heads)
         );
@@ -396,6 +563,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_flash_attention_backward(
         if (!success) {
            throw std::runtime_error("MFA backward pass failed");
         }
+
+        // No commit needed - PyTorch will commit when it needs the results
     }
 
     // Convert gradients back to input dtype if needed
@@ -420,14 +589,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("query"),
           py::arg("key"),
           py::arg("value"),
-          py::arg("is_causal") = false);
+          py::arg("is_causal") = false,
+          py::arg("attn_mask") = py::none());
 
     m.def("forward_with_lse", &mps_flash_attention_forward_with_lse,
           "Flash Attention forward pass (returns output and logsumexp for backward)",
           py::arg("query"),
           py::arg("key"),
           py::arg("value"),
-          py::arg("is_causal") = false);
+          py::arg("is_causal") = false,
+          py::arg("attn_mask") = py::none());
 
     m.def("backward", &mps_flash_attention_backward,
           "Flash Attention backward pass",
@@ -437,5 +608,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("value"),
           py::arg("output"),
           py::arg("logsumexp"),
-          py::arg("is_causal") = false);
+          py::arg("is_causal") = false,
+          py::arg("attn_mask") = py::none());
 }
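
With the pybind signatures above, the autograd path accepts the same `attn_mask` as the forward call, and the backward kernel always runs in fp32 regardless of the input dtype, converting gradients back afterwards. A short training-style sketch with illustrative shapes, assuming an MPS device and the package installed:

```python
import torch
from mps_flash_attn import flash_attention

B, H, N, D = 1, 4, 1024, 64
q = torch.randn(B, H, N, D, device='mps', dtype=torch.float16, requires_grad=True)
k = torch.randn(B, H, N, D, device='mps', dtype=torch.float16, requires_grad=True)
v = torch.randn(B, H, N, D, device='mps', dtype=torch.float16, requires_grad=True)

# Mask out the last 256 key positions for every query (True = don't attend)
mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
mask[..., N - 256:] = True

out = flash_attention(q, k, v, attn_mask=mask)
out.float().sum().backward()        # dQ/dK/dV computed by the fp32 backward kernel

print(q.grad.dtype, q.grad.shape)   # gradients are returned in the input dtype
```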

Binary file: mps_flash_attn-0.1.13/mps_flash_attn/lib/libMFABridge.dylib (contents not shown)

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/mps_flash_attn.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mps-flash-attn
-Version: 0.1.4
+Version: 0.1.13
 Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
 Author: imperatormk
 License-Expression: MIT
@@ -32,8 +32,9 @@ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
 ## Features
 
 - **Forward pass**: 2-5x faster than PyTorch SDPA
-- **Backward pass**: Full gradient support for training
+- **Backward pass**: Full gradient support for training (fp32 precision)
 - **Causal masking**: Native kernel support (only 5% overhead)
+- **Attention masks**: Full boolean mask support for arbitrary masking patterns
 - **FP16/FP32**: Native fp16 output (no conversion overhead)
 - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)
 
@@ -98,6 +99,16 @@ out = flash_attention(q, k, v)
 out = flash_attention(q, k, v, is_causal=True)
 ```
 
+### Attention masks (for custom masking patterns)
+
+```python
+# Boolean mask: True = masked (don't attend), False = attend
+mask = torch.zeros(B, 1, N, N, dtype=torch.bool, device='mps')
+mask[:, :, :, 512:] = True  # Mask out positions after 512
+
+out = flash_attention(q, k, v, attn_mask=mask)
+```
+
 ### Training with gradients
 
 ```python
@@ -247,7 +258,6 @@ python scripts/build_metallibs.py
 **Known limitations:**
 - Sequence length must be divisible by block size (typically 64)
 - Head dimension: Best with 32, 64, 96, 128
-- No arbitrary attention masks (only causal or none)
 - No dropout
 
 ## Credits

{mps_flash_attn-0.1.4 → mps_flash_attn-0.1.13}/setup.py

@@ -4,6 +4,7 @@ Setup script for MPS Flash Attention
 
 import os
 import sys
+import shutil
 from setuptools import setup, find_packages, Extension
 from setuptools.command.build_ext import build_ext
 
@@ -36,6 +37,26 @@ class ObjCppBuildExt(build_ext):
 
         super().build_extensions()
 
+        # Copy libMFABridge.dylib to lib/ after building
+        self._copy_swift_bridge()
+
+    def _copy_swift_bridge(self):
+        """Copy Swift bridge dylib to package lib/ directory."""
+        src_path = os.path.join(
+            os.path.dirname(__file__),
+            "swift-bridge", ".build", "release", "libMFABridge.dylib"
+        )
+        dst_dir = os.path.join(os.path.dirname(__file__), "mps_flash_attn", "lib")
+        dst_path = os.path.join(dst_dir, "libMFABridge.dylib")
+
+        if os.path.exists(src_path):
+            os.makedirs(dst_dir, exist_ok=True)
+            shutil.copy2(src_path, dst_path)
+            print(f"Copied libMFABridge.dylib to {dst_path}")
+        else:
+            print(f"Warning: {src_path} not found. Build swift-bridge first with:")
+            print("  cd swift-bridge && swift build -c release")
+
 
 def get_extensions():
     if sys.platform != "darwin":
@@ -44,14 +65,14 @@ def get_extensions():
     return [Extension(
         name="mps_flash_attn._C",
         sources=["mps_flash_attn/csrc/mps_flash_attn.mm"],
-        extra_compile_args=["-std=c++17", "-O3"],
+        extra_compile_args=["-std=c++17", "-O3", "-DTORCH_EXTENSION_NAME=_C"],
         extra_link_args=["-framework", "Metal", "-framework", "Foundation"],
     )]
 
 
 setup(
     name="mps-flash-attn",
-    version="0.1.4",
+    version="0.1.5",
     packages=find_packages(),
     package_data={
         "mps_flash_attn": [
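
The build step above bundles `libMFABridge.dylib` under `mps_flash_attn/lib/`, which is the first location the extension's loader probes after the environment variable. A quick way to check what an installed copy actually shipped (a diagnostic sketch, not part of the package):

```python
import os
import mps_flash_attn

pkg_dir = os.path.dirname(mps_flash_attn.__file__)
bridge = os.path.join(pkg_dir, "lib", "libMFABridge.dylib")
kernels = os.path.join(pkg_dir, "kernels")

print("bridge dylib bundled:", os.path.exists(bridge))
print("pre-compiled kernels:", len(os.listdir(kernels)), "files")
print("version:", mps_flash_attn.__version__)
```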

Binary file: mps_flash_attn-0.1.4/mps_flash_attn/lib/libMFABridge.dylib (contents not shown)

All remaining files listed above (LICENSE, pre-compiled kernels, egg-info metadata, setup.cfg, tests) are unchanged between 0.1.4 and 0.1.13.