mps-flash-attn 0.1.13__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mps-flash-attn might be problematic.
- mps_flash_attn-0.2.4/PKG-INFO +218 -0
- mps_flash_attn-0.2.4/README.md +193 -0
- mps_flash_attn-0.2.4/mps_flash_attn/__init__.py +1028 -0
- mps_flash_attn-0.2.4/mps_flash_attn/benchmark.py +666 -0
- mps_flash_attn-0.2.4/mps_flash_attn/csrc/mps_flash_attn.mm +1240 -0
- mps_flash_attn-0.2.4/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- mps_flash_attn-0.2.4/mps_flash_attn.egg-info/PKG-INFO +218 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn.egg-info/SOURCES.txt +2 -1
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/pyproject.toml +1 -1
- mps_flash_attn-0.2.4/tests/test_mfa_v2.py +701 -0
- mps_flash_attn-0.1.13/PKG-INFO +0 -270
- mps_flash_attn-0.1.13/README.md +0 -245
- mps_flash_attn-0.1.13/mps_flash_attn/__init__.py +0 -289
- mps_flash_attn-0.1.13/mps_flash_attn/csrc/mps_flash_attn.mm +0 -613
- mps_flash_attn-0.1.13/mps_flash_attn/lib/libMFABridge.dylib +0 -0
- mps_flash_attn-0.1.13/mps_flash_attn.egg-info/PKG-INFO +0 -270
- mps_flash_attn-0.1.13/tests/test_attention.py +0 -145
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/LICENSE +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn/kernels/manifest.json +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn.egg-info/requires.txt +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/mps_flash_attn.egg-info/top_level.txt +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/setup.cfg +0 -0
- {mps_flash_attn-0.1.13 → mps_flash_attn-0.2.4}/setup.py +0 -0
@@ -0,0 +1,218 @@
Metadata-Version: 2.4
Name: mps-flash-attn
Version: 0.2.4
Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
Author: imperatormk
License-Expression: MIT
Project-URL: Homepage, https://github.com/mpsops/mps-flash-attention
Project-URL: Repository, https://github.com/mpsops/mps-flash-attention
Project-URL: Issues, https://github.com/mpsops/mps-flash-attention/issues
Keywords: flash-attention,apple-silicon,pytorch,mps,metal,transformer,attention
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Dynamic: license-file

# MPS Flash Attention

Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).

**O(N) memory** instead of O(N²): the full N×N attention score matrix is never materialized, enabling 100K+ sequence lengths on unified memory.
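
A back-of-the-envelope check of what the O(N²) term costs: the snippet compares the bytes needed to materialize the full (B, H, N, N) score matrix with the bytes of q, k, v and the output, for the shapes used in the examples (batch 1, 8 heads, head dimension 64, 2-byte FP16 elements). It is plain arithmetic and ignores softmax temporaries and the small per-row statistics a real kernel keeps.

```python
def score_matrix_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for a fully materialized (B, H, N, N) attention score matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


def io_tensor_bytes(batch: int, heads: int, seq_len: int, head_dim: int,
                    bytes_per_elem: int = 2) -> int:
    """Bytes for one (B, H, N, D) tensor such as q, k, v or the output."""
    return batch * heads * seq_len * head_dim * bytes_per_elem


GB = 10**9
for n in (4096, 16384, 100_000):
    scores = score_matrix_bytes(1, 8, n)        # grows as N^2
    qkvo = 4 * io_tensor_bytes(1, 8, n, 64)     # grows as N
    print(f"N={n:>7,}: scores {scores / GB:8.2f} GB | q, k, v, out {qkvo / GB:6.3f} GB")
```

At N = 100,000 the score matrix alone would need 160 GB, while the four (B, H, N, D) tensors together need about 0.41 GB.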

## Performance

Benchmarked on Apple Silicon (M1/M2/M3/M4):

| Seq Length | vs PyTorch SDPA | Notes |
|------------|-----------------|-------|
| 1024 | 1.1-2.0x faster | Crossover point |
| 2048 | 1.7-3.7x faster | Sweet spot |
| 4096 | 2.0-3.9x faster | Peak performance |
| 8192+ | 3-4x faster | SDPA often OOMs |

Average speedup: **1.8x** across all configurations.

## Installation

```bash
pip install mps-flash-attn
```

### Build from source

```bash
git clone --recursive https://github.com/mpsops/mps-flash-attention.git
cd mps-flash-attention

# Build Swift bridge
cd swift-bridge && swift build -c release && cd ..

# Install
pip install -e .

# Set bridge path
export MFA_BRIDGE_PATH=$PWD/swift-bridge/.build/release/libMFABridge.dylib
```

## Usage

### Basic Attention

```python
import torch
from mps_flash_attn import flash_attention

# (B, H, N, D) layout: batch, heads, sequence length, head dimension
q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)

out = flash_attention(q, k, v)
```
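
For intuition about what a Flash Attention kernel computes, here is a pure-Python sketch of the online-softmax recurrence from the Flash Attention paper, for a single query vector. Keys and values are consumed in blocks while only a running maximum, a running normalizer and an accumulator are kept, so a full row of scores is never stored. It illustrates the algorithm only and is not this package's Metal kernel; `block_size` and the helper names are hypothetical.

```python
import math
from typing import List, Sequence

Vector = List[float]


def naive_attention_row(q: Sequence[float], keys: Sequence[Sequence[float]],
                        values: Sequence[Sequence[float]]) -> Vector:
    """softmax(q . k_j / sqrt(d)) weighted sum of v_j, with all scores held in memory."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def online_attention_row(q: Sequence[float], keys: Sequence[Sequence[float]],
                         values: Sequence[Sequence[float]], block_size: int = 2) -> Vector:
    """Same result, streaming over key/value blocks with O(1) statistics per query row."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf          # largest score seen so far
    running_sum = 0.0                # sum of exp(score - running_max) so far
    acc = [0.0] * len(values[0])     # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        block_k = keys[start:start + block_size]
        block_v = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in block_k]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum (exp(-inf) == 0.0 on the first block).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, block_v):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.1, 0.2]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, -0.2]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(naive_attention_row(q, keys, values))
print(online_attention_row(q, keys, values, block_size=2))
```

Both calls print the same vector up to floating-point rounding, whatever block size is used.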

### Causal Masking

```python
out = flash_attention(q, k, v, is_causal=True)
```

### Sliding Window (Mistral/Llama 3.2)

```python
# Only attend to last 4096 tokens
out = flash_attention(q, k, v, is_causal=True, window_size=4096)
```
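
To make `is_causal` and `window_size` concrete, the plain-Python sketch below lists which key positions each query position may attend to and renders the resulting boolean mask. It assumes the window counts the query's own position (at most `window_size` keys per row); the exact boundary convention of `flash_attention` is not documented here and may differ by one. The helper names are hypothetical.

```python
from typing import List, Optional


def causal_key_range(query_pos: int, window_size: Optional[int] = None) -> range:
    """Key positions a causal query may attend to.

    Assumption: the window includes the query's own position, so keys run
    from query_pos - window_size + 1 up to query_pos (clipped at 0).
    """
    start = 0 if window_size is None else max(0, query_pos - window_size + 1)
    return range(start, query_pos + 1)


def mask_rows(seq_len: int, window_size: Optional[int] = None) -> List[str]:
    """Render the mask as strings: '1' = attend, '.' = masked out."""
    rows = []
    for i in range(seq_len):
        allowed = causal_key_range(i, window_size)
        rows.append("".join("1" if j in allowed else "." for j in range(seq_len)))
    return rows


for row in mask_rows(5, window_size=3):
    print(row)
```

With `window_size=None` the mask is plain lower-triangular causal; with a window, each row keeps only a diagonal band, so work per query stays bounded as the sequence grows.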

### Quantized KV Cache (2-4x memory savings)

```python
from mps_flash_attn import flash_attention_fp8, quantize_kv_fp8

# Quantize K/V to FP8
k_quant, k_scale = quantize_kv_fp8(k)
v_quant, v_scale = quantize_kv_fp8(v)

# Run attention with quantized KV
out = flash_attention_fp8(q, k_quant, v_quant, k_scale, v_scale)
```
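
The 2-4x figure follows from bit widths: storing FP16 K/V in 8 bits (FP8 or INT8) halves the memory and 4-bit NF4 quarters it, before counting the stored scales. The sketch below shows the general scale-and-round idea with symmetric per-row INT8 absmax quantization, a simpler stand-in chosen for illustration; it is not the FP8 encoding or the scale layout that `quantize_kv_fp8` actually uses, and the helper names are hypothetical.

```python
from typing import List, Tuple


def quantize_row_int8(row: List[float]) -> Tuple[List[int], float]:
    """Symmetric absmax quantization: x is approximated by code * scale, code in [-127, 127]."""
    absmax = max(abs(x) for x in row)
    scale = absmax / 127.0 if absmax > 0 else 1.0
    # Python's round() uses round-half-to-even; clamp guards against float overshoot.
    codes = [max(-127, min(127, round(x / scale))) for x in row]
    return codes, scale


def dequantize_row(codes: List[int], scale: float) -> List[float]:
    return [c * scale for c in codes]


def kv_cache_bytes(num_elems: int, bits: int) -> float:
    """Storage for num_elems values at the given bit width, ignoring scales."""
    return num_elems * bits / 8


row = [0.5, -1.27, 0.01, 1.0]
codes, scale = quantize_row_int8(row)
restored = dequantize_row(codes, scale)
max_err = max(abs(a - b) for a, b in zip(row, restored))
print(codes, scale, max_err)

# K and V for one (1, 8, 4096, 64) layer: 16-bit vs 8-bit vs 4-bit storage.
elems = 2 * 8 * 4096 * 64
print(kv_cache_bytes(elems, 16), kv_cache_bytes(elems, 8), kv_cache_bytes(elems, 4))
```

The per-element error is bounded by half a quantization step (`scale / 2`), which is why outlier-heavy rows benefit from finer-grained scales.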

### 100K+ Long Sequences

```python
from mps_flash_attn import flash_attention_chunked

# Process 100K tokens without OOM
q = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
k = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
v = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)

out = flash_attention_chunked(q, k, v, chunk_size=8192)
```
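
How `flash_attention_chunked` splits the work is not described here. One standard way to split attention over key/value chunks exactly is to compute each chunk's normalized partial output together with its log-sum-exp, then merge the chunks with weights `exp(lse_chunk - lse_total)`. The pure-Python sketch below shows that merge for a single query row; the helper names are hypothetical, and the package may chunk differently (for example over queries).

```python
import math
from typing import List, Sequence, Tuple


def partial_attention(q: Sequence[float], keys: Sequence[Sequence[float]],
                      values: Sequence[Sequence[float]]) -> Tuple[List[float], float]:
    """Attention over one key/value chunk: (normalized output, log-sum-exp of scores)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge_chunks(parts: Sequence[Tuple[List[float], float]]) -> List[float]:
    """Combine chunk outputs; chunk c contributes with weight Z_c / Z = exp(lse_c - lse_total)."""
    lse_max = max(lse for _, lse in parts)
    lse_total = lse_max + math.log(sum(math.exp(lse - lse_max) for _, lse in parts))
    dim = len(parts[0][0])
    return [sum(math.exp(lse - lse_total) * out[d] for out, lse in parts) for d in range(dim)]


def chunked_attention_row(q, keys, values, chunk_size: int) -> List[float]:
    parts = [partial_attention(q, keys[i:i + chunk_size], values[i:i + chunk_size])
             for i in range(0, len(keys), chunk_size)]
    return merge_chunks(parts)


q = [0.3, -0.1, 0.5]
keys = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [-1.0, 2.0, 0.5]]
values = [[1.0], [2.0], [3.0], [4.0], [5.0]]
full, _ = partial_attention(q, keys, values)      # a single chunk is ordinary attention
chunked = chunked_attention_row(q, keys, values, chunk_size=2)
print(full, chunked)
```

Only one chunk's scores exist at a time, yet the merged result matches unchunked attention up to rounding.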

### Drop-in SDPA Replacement

```python
from mps_flash_attn import replace_sdpa

replace_sdpa()  # Patches F.scaled_dot_product_attention

# From here on, F.scaled_dot_product_attention calls on MPS use Flash Attention
```
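
The internals of `replace_sdpa` are not shown in this README. As a generic illustration of how a drop-in patch of this kind usually works, the sketch below swaps a function on a module-like namespace for a wrapper that takes a fast path only when its preconditions hold and otherwise defers to the saved original. Everything here (the `functional` namespace, `fast_attention`, the `supported` check) is hypothetical and stands in for `torch.nn.functional` and the MPS kernel.

```python
import types
from typing import Callable

# Hypothetical stand-in for torch.nn.functional.
functional = types.SimpleNamespace(
    scaled_dot_product_attention=lambda q, k, v, **kw: ("original", q, k, v))


def fast_attention(q, k, v, **kwargs):
    """Hypothetical fast path (the real package would call its Metal kernel)."""
    return ("fast", q, k, v)


def install_patch(namespace, supported: Callable[..., bool]) -> Callable[[], None]:
    """Replace namespace.scaled_dot_product_attention and return a function that undoes it."""
    original = namespace.scaled_dot_product_attention

    def patched(q, k, v, **kwargs):
        if supported(q, k, v, **kwargs):
            return fast_attention(q, k, v, **kwargs)
        return original(q, k, v, **kwargs)  # fall back for unsupported arguments

    namespace.scaled_dot_product_attention = patched

    def restore() -> None:
        namespace.scaled_dot_product_attention = original

    return restore


# Example precondition: dropout is listed as unsupported, so only dropout_p == 0 takes the fast path.
restore = install_patch(functional, lambda q, k, v, dropout_p=0.0, **kw: dropout_p == 0.0)
print(functional.scaled_dot_product_attention(1, 2, 3)[0])
print(functional.scaled_dot_product_attention(1, 2, 3, dropout_p=0.1)[0])
```

Keeping the original function around is what makes the fallback and a later restore possible.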

### torch.compile() Support

```python
import torch
from mps_flash_attn import register_custom_op

register_custom_op()  # registers torch.ops.mfa.flash_attention for use under torch.compile

@torch.compile
def my_attention(q, k, v):
    return torch.ops.mfa.flash_attention(q, k, v, False, None, None)
```

### Training with BF16 Backward

```python
q, k, v = (t.detach().requires_grad_() for t in (q, k, v))  # leaf tensors that track gradients
out = flash_attention(q, k, v, bf16_backward=True)  # 2x faster backward
loss = out.sum()
loss.backward()
```
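
BF16 keeps float32's 8-bit exponent but only 8 bits of significand precision (7 stored), so it trades accuracy for range and throughput. The stdlib-only sketch below rounds values to the nearest bfloat16 to show roughly how much precision is lost; it illustrates the number format only and says nothing about which tensors `bf16_backward=True` actually computes in BF16. The helper names are hypothetical.

```python
import struct


def float32_bits(x: float) -> int:
    """Raw IEEE-754 bits of x after rounding it to float32 (raises OverflowError if out of range)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def round_to_bfloat16(x: float) -> float:
    """Round to the nearest bfloat16 value, ties to even. NaN and inf are not special-cased."""
    bits = float32_bits(x)
    # Add 0x7FFF plus the lowest bit that survives, then drop the low 16 bits.
    rounded = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", rounded))[0]


for value in (1.0, 1 / 3, 3.14159, 1000.1):
    print(value, "->", round_to_bfloat16(value))
```

The relative rounding error stays within 2^-8 (about 0.4%), roughly two to three significant decimal digits.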

### Benchmarking

```bash
# Quick benchmark
python -m mps_flash_attn.benchmark --suite quick

# Full suite with report
python -m mps_flash_attn.benchmark --suite full --output report.html
```

```python
from mps_flash_attn.benchmark import run_suite, compare_vs_sdpa

results = run_suite(seq_lengths=[1024, 2048, 4096])
compare_vs_sdpa()
```
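
The bundled benchmark module is the intended tool. If you want to time your own call sites, the stdlib sketch below shows the usual pattern: warm up first, then take the median of several runs, calling a synchronization hook before reading the clock because GPU work is queued asynchronously. `median_runtime` is a hypothetical helper, not part of `mps_flash_attn.benchmark`.

```python
import statistics
import time
from typing import Callable


def median_runtime(fn: Callable[[], object], sync: Callable[[], None] = lambda: None,
                   warmup: int = 3, iters: int = 10) -> float:
    """Median wall-clock seconds per call of fn()."""
    for _ in range(warmup):   # let pipeline compilation, caches and allocators settle
        fn()
    sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()                # wait for queued GPU work before stopping the clock
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


print(median_runtime(lambda: sum(range(100_000)), iters=5))
```

On MPS you would pass `sync=torch.mps.synchronize`, for example `median_runtime(lambda: flash_attention(q, k, v), sync=torch.mps.synchronize)`; without it, the timer may stop before the kernel has run.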

## Features

| Feature | Status | Notes |
|---------|--------|-------|
| Forward pass | ✅ | FP16/BF16/FP32 |
| Backward pass | ✅ | Full gradient support |
| Causal masking | ✅ | Native kernel support |
| Attention masks | ✅ | Boolean masks |
| Sliding window | ✅ | For local attention models |
| GQA/MQA | ✅ | Grouped-query attention |
| Quantized KV | ✅ | FP8, INT8, NF4 |
| Chunked attention | ✅ | 100K+ tokens |
| torch.compile() | ✅ | Custom op backend |
| Dropout | ❌ | Not supported |
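
In GQA several query heads share one K/V head, and MQA is the special case of a single K/V head. The sketch below shows the consecutive-group mapping, which matches what `torch.repeat_interleave` along the head dimension produces; whether this package expects exactly this grouping, and how it wants K/V shaped, is an assumption to check against its docstrings. The helper name is hypothetical.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Index of the K/V head used by a query head under grouped-query attention."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads   # query heads per K/V head
    return q_head // group_size


print([kv_head_for_query_head(h, 8, 2) for h in range(8)])   # GQA: 8 query heads, 2 K/V heads
print([kv_head_for_query_head(h, 8, 1) for h in range(8)])   # MQA: one shared K/V head
```

Sharing K/V heads shrinks the K/V tensors (and any KV cache) by the group size without changing the number of query heads.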

## Architecture

```
Python API (mps_flash_attn)
│
C++ Extension (mps_flash_attn.mm)
│ dlopen
Swift Bridge (MFABridge.swift)
│
Metal Flash Attention (kernel generation)
│
Metal GPU Shaders
```
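
The C++ extension loads the Swift bridge with `dlopen` at runtime. Judging from the build steps (`MFA_BRIDGE_PATH`) and the bundled `mps_flash_attn/lib/libMFABridge.dylib` in the file list, a plausible lookup order is the environment variable first and the packaged library second. The sketch below resolves a path that way; the order and the helper name are assumptions, not the extension's documented behavior.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_bridge_path(package_dir: Path,
                        env: Mapping[str, str] = os.environ) -> Optional[Path]:
    """Return the libMFABridge.dylib to load, or None if nothing is found (assumed order)."""
    override = env.get("MFA_BRIDGE_PATH")
    if override:
        # An explicit override wins; let the dynamic loader report a bad path.
        return Path(override)
    bundled = package_dir / "lib" / "libMFABridge.dylib"
    return bundled if bundled.is_file() else None


print(resolve_bridge_path(Path("/tmp/does-not-exist"), env={}))
```

Setting `MFA_BRIDGE_PATH` after a source build, as in the installation steps, would then take precedence over any bundled copy.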

## Requirements

- macOS 14 (Sonoma) or later, including macOS 15 (Sequoia)
- Apple Silicon (M1/M2/M3/M4)
- Python 3.10+
- PyTorch 2.0+
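
A quick way to check these requirements from Python is sketched below using only the standard library. The thresholds mirror the list (macOS 14+, arm64, Python 3.10+); the PyTorch version and MPS availability are not checked (with PyTorch installed, `torch.backends.mps.is_available()` reports the latter). A Python interpreter running under Rosetta reports `x86_64` and fails the check. `meets_requirements` is a hypothetical helper.

```python
import platform
import sys
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """'14.5' -> (14, 5); an empty string (non-macOS) -> ()."""
    return tuple(int(part) for part in text.split(".") if part.isdigit())


def meets_requirements(system: str, machine: str, macos_version: str,
                       python_version: Tuple[int, int]) -> bool:
    """Platform checks from the requirements list; PyTorch itself is not checked."""
    return (system == "Darwin"
            and machine == "arm64"
            and parse_version(macos_version)[:1] >= (14,)
            and python_version >= (3, 10))


if __name__ == "__main__":
    ok = meets_requirements(platform.system(), platform.machine(),
                            platform.mac_ver()[0], sys.version_info[:2])
    print("requirements met" if ok else "requirements not met")
```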

## TODO / Future Optimizations

- [ ] **Batched kernel dispatch** - Currently dispatches B×H separate kernels per attention call; should use a 3D grid to handle all batches and heads in one dispatch (a major performance win for small sequences such as Swin Transformer windows)
- [ ] **Fused QKV projection + attention** - Single kernel from input to output, avoiding intermediate buffers
- [ ] **Pre-scaled bias option** - Allow passing a pre-scaled bias to avoid per-call scaling overhead
- [ ] **LoRA fusion** - Fuse adapter weights into the attention computation
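
The batched-dispatch item can be made concrete by counting: one dispatch per (batch, head) pair today, versus a single dispatch over a 3D grid whose z and y coordinates select batch and head and whose x coordinate selects a query tile. The sketch below enumerates such a grid for hypothetical sizes (64 windows of 7×7 = 49 tokens, 4 heads, 32-row query tiles); the axis order, tile size and helper names are illustrative choices, not the planned kernel design.

```python
import math
from itertools import product
from typing import Iterator, Tuple


def per_head_dispatches(batch: int, heads: int) -> int:
    """Dispatches needed when every (batch, head) pair gets its own kernel launch."""
    return batch * heads


def grid_3d(batch: int, heads: int, seq_len: int, tile: int) -> Tuple[int, int, int]:
    """(x, y, z) threadgroup counts for one dispatch: query tiles, heads, batch."""
    return (math.ceil(seq_len / tile), heads, batch)


def grid_coordinates(grid: Tuple[int, int, int]) -> Iterator[Tuple[int, int, int]]:
    """Yield (batch, head, query_tile) for every threadgroup in the single dispatch."""
    tiles, heads, batch = grid
    for b, h, t in product(range(batch), range(heads), range(tiles)):
        yield b, h, t


grid = grid_3d(batch=64, heads=4, seq_len=49, tile=32)
print(per_head_dispatches(64, 4), "dispatches vs 1 dispatch with grid", grid)
```

For short sequences each per-head launch does little work, so folding batch and head into the grid amortizes launch overhead across all threadgroups.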

## Credits

- [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
- [Flash Attention](https://arxiv.org/abs/2205.14135) paper by Tri Dao et al.

## License

MIT