mps-flash-attn 0.1.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mps-flash-attn has been flagged for review.

Files changed (38)
  1. mps_flash_attn-0.1.4/LICENSE +27 -0
  2. mps_flash_attn-0.1.4/PKG-INFO +260 -0
  3. mps_flash_attn-0.1.4/README.md +235 -0
  4. mps_flash_attn-0.1.4/mps_flash_attn/__init__.py +248 -0
  5. mps_flash_attn-0.1.4/mps_flash_attn/csrc/mps_flash_attn.mm +441 -0
  6. mps_flash_attn-0.1.4/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  7. mps_flash_attn-0.1.4/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  8. mps_flash_attn-0.1.4/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  9. mps_flash_attn-0.1.4/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  10. mps_flash_attn-0.1.4/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  11. mps_flash_attn-0.1.4/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  12. mps_flash_attn-0.1.4/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  13. mps_flash_attn-0.1.4/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  14. mps_flash_attn-0.1.4/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  15. mps_flash_attn-0.1.4/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  16. mps_flash_attn-0.1.4/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  17. mps_flash_attn-0.1.4/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  18. mps_flash_attn-0.1.4/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  19. mps_flash_attn-0.1.4/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  20. mps_flash_attn-0.1.4/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  21. mps_flash_attn-0.1.4/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  22. mps_flash_attn-0.1.4/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  23. mps_flash_attn-0.1.4/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  24. mps_flash_attn-0.1.4/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  25. mps_flash_attn-0.1.4/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  26. mps_flash_attn-0.1.4/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  27. mps_flash_attn-0.1.4/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  28. mps_flash_attn-0.1.4/mps_flash_attn/kernels/manifest.json +27 -0
  29. mps_flash_attn-0.1.4/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  30. mps_flash_attn-0.1.4/mps_flash_attn.egg-info/PKG-INFO +260 -0
  31. mps_flash_attn-0.1.4/mps_flash_attn.egg-info/SOURCES.txt +36 -0
  32. mps_flash_attn-0.1.4/mps_flash_attn.egg-info/dependency_links.txt +1 -0
  33. mps_flash_attn-0.1.4/mps_flash_attn.egg-info/requires.txt +1 -0
  34. mps_flash_attn-0.1.4/mps_flash_attn.egg-info/top_level.txt +1 -0
  35. mps_flash_attn-0.1.4/pyproject.toml +46 -0
  36. mps_flash_attn-0.1.4/setup.cfg +4 -0
  37. mps_flash_attn-0.1.4/setup.py +68 -0
  38. mps_flash_attn-0.1.4/tests/test_attention.py +145 -0
mps_flash_attn-0.1.4/LICENSE (new file, +27 lines)

MIT License

Copyright (c) 2025 imperatormk

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---

This project includes code from metal-flash-attention by Philip Turner,
also licensed under the MIT License:
https://github.com/philipturner/metal-flash-attention
mps_flash_attn-0.1.4/PKG-INFO (new file, +260 lines)

Metadata-Version: 2.4
Name: mps-flash-attn
Version: 0.1.4
Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
Author: imperatormk
License-Expression: MIT
Project-URL: Homepage, https://github.com/mpsops/mps-flash-attention
Project-URL: Repository, https://github.com/mpsops/mps-flash-attention
Project-URL: Issues, https://github.com/mpsops/mps-flash-attention/issues
Keywords: flash-attention,apple-silicon,pytorch,mps,metal,transformer,attention
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Dynamic: license-file
# MPS Flash Attention

Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).

**O(N) memory** instead of O(N²), enabling 8K+ sequence lengths on unified memory.
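To put the O(N) claim in perspective, here is a back-of-the-envelope calculation (illustrative shapes, not a measurement) of what the full N×N score matrix alone would cost if it were materialized:

```python
# Size of the full N x N score matrix that standard attention materializes.
# Illustrative arithmetic only; B, H, N are example values.
B, H, N = 2, 8, 8192
bytes_per_fp16 = 2
score_matrix_bytes = B * H * N * N * bytes_per_fp16
print(f"{score_matrix_bytes / 2**30:.1f} GiB")  # ~2.0 GiB for the scores alone
```

Flash attention never materializes that matrix; it streams over key/value blocks, so the extra memory grows linearly with N.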
## Features

- **Forward pass**: 2-5x faster than PyTorch SDPA
- **Backward pass**: Full gradient support for training
- **Causal masking**: Native kernel support (only 5% overhead)
- **FP16/FP32**: Native fp16 output (no conversion overhead)
- **Pre-compiled kernels**: Zero-compilation cold start (~6ms)

## Performance

Tested on M1 Max, N=2048, B=4, H=8, D=64:

| Operation        | MPS Flash Attn | PyTorch SDPA | Ratio       |
|------------------|----------------|--------------|-------------|
| Forward          | 5.3ms          | 15ms         | 2.8x faster |
| Forward+Backward | 55ms           | 108ms        | 2.0x faster |
| Memory           | 80MB           | 592MB        | 7.4x less   |
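These figures are machine- and version-dependent. A minimal timing loop along the following lines (a sketch; the shapes match the table above, and the warm-up call keeps one-time kernel loading out of the measurement) can reproduce the comparison locally:

```python
import time
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

B, H, N, D = 4, 8, 2048, 64
q, k, v = (torch.randn(B, H, N, D, device="mps", dtype=torch.float16) for _ in range(3))

def bench_ms(fn, iters=20):
    fn()                        # warm-up: the first call may load/compile kernels
    torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.mps.synchronize()     # wait for queued GPU work before stopping the clock
    return (time.perf_counter() - start) / iters * 1e3

print("flash_attention:", bench_ms(lambda: flash_attention(q, k, v)), "ms")
print("sdpa:           ", bench_ms(lambda: F.scaled_dot_product_attention(q, k, v)), "ms")
```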
## Installation

### Prerequisites

- macOS 14 (Sonoma) or later, including macOS 15 (Sequoia)
- Xcode Command Line Tools (`xcode-select --install`)
- Python 3.10+ with PyTorch 2.0+

### Build from source

```bash
# Clone with submodules
git clone --recursive https://github.com/mpsops/mps-flash-attention.git
cd mps-flash-attention

# Build Swift bridge
cd swift-bridge
swift build -c release
cd ..

# Install Python package
pip install -e .
```

### Set environment variable

```bash
export MFA_BRIDGE_PATH=/path/to/mps-flash-attention/swift-bridge/.build/release/libMFABridge.dylib
```
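Before going further, a small sanity check (a sketch, not part of the package) can confirm that the MPS backend is available and that `MFA_BRIDGE_PATH` points at a loadable dylib exposing the bridge entry points listed in the architecture section below:

```python
import ctypes
import os

import torch

# 1. PyTorch must see the Metal backend.
assert torch.backends.mps.is_available(), "MPS backend not available in this PyTorch build"

# 2. The Swift bridge must be loadable from MFA_BRIDGE_PATH.
bridge_path = os.environ.get("MFA_BRIDGE_PATH", "")
assert os.path.isfile(bridge_path), f"MFA_BRIDGE_PATH does not point at a file: {bridge_path!r}"
bridge = ctypes.CDLL(bridge_path)

# 3. The exported entry points named in the architecture diagram should be present.
for symbol in ("mfa_init", "mfa_create_kernel", "mfa_forward", "mfa_backward"):
    getattr(bridge, symbol)  # raises AttributeError if the symbol is missing
print("bridge OK:", bridge_path)
```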
## Usage

### Basic usage

```python
import torch
from mps_flash_attn import flash_attention

# Standard attention (B, H, N, D)
q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)

out = flash_attention(q, k, v)
```

### Causal masking (for autoregressive models)

```python
out = flash_attention(q, k, v, is_causal=True)
```

### Training with gradients

```python
q.requires_grad = True
k.requires_grad = True
v.requires_grad = True

out = flash_attention(q, k, v, is_causal=True)
loss = out.sum()
loss.backward()  # Computes dQ, dK, dV
```

### Drop-in replacement for SDPA

```python
from mps_flash_attn import replace_sdpa

# Monkey-patch F.scaled_dot_product_attention
replace_sdpa()

# Now all attention ops use Flash Attention on MPS
```
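After the patch, code that already calls `F.scaled_dot_product_attention` on MPS tensors needs no further changes. A minimal usage sketch (shapes are illustrative):

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import replace_sdpa

replace_sdpa()

# Existing code that goes through F.scaled_dot_product_attention on MPS tensors
# now runs on the flash kernel, with no other changes required.
q, k, v = (torch.randn(2, 8, 1024, 64, device="mps", dtype=torch.float16) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```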
## Architecture

```
+----------------------------------------------------------+
|                        Python API                         |
|                mps_flash_attn/__init__.py                 |
|           (flash_attention, autograd Function)            |
+----------------------------+-----------------------------+
                             |
+----------------------------v-----------------------------+
|                       C++ Extension                       |
|           mps_flash_attn/csrc/mps_flash_attn.mm           |
|      (PyTorch bindings, MTLBuffer handling, offsets)      |
+----------------------------+-----------------------------+
                             | dlopen + dlsym
+----------------------------v-----------------------------+
|                        Swift Bridge                       |
|              swift-bridge/Sources/MFABridge/              |
|          (MFABridge.swift, MetallibCache.swift)           |
|       @_cdecl exports: mfa_init, mfa_create_kernel,       |
|                 mfa_forward, mfa_backward                 |
+----------------------------+-----------------------------+
                             |
+----------------------------v-----------------------------+
|                   Metal Flash Attention                   |
|       metal-flash-attention/Sources/FlashAttention/       |
|        (AttentionDescriptor, AttentionKernel, etc.)       |
|                                                            |
|         Generates Metal shader source at runtime,         |
|          compiles to .metallib, caches pipelines           |
+----------------------------------------------------------+
```

## Project Structure

```
mps-flash-attention/
├── mps_flash_attn/                  # Python package
│   ├── __init__.py                  # Public API (flash_attention, replace_sdpa)
│   ├── csrc/
│   │   └── mps_flash_attn.mm        # PyTorch C++ extension
│   └── kernels/                     # Pre-compiled metallibs (optional)

├── swift-bridge/                    # Swift -> C bridge
│   ├── Package.swift
│   └── Sources/MFABridge/
│       ├── MFABridge.swift          # C-callable API (@_cdecl)
│       └── MetallibCache.swift      # Disk caching for metallibs

├── metal-flash-attention/           # Upstream (git submodule)
│   └── Sources/FlashAttention/
│       └── Attention/
│           ├── AttentionDescriptor/ # Problem configuration
│           ├── AttentionKernel/     # Metal shader generation
│           └── ...

├── scripts/
│   └── build_metallibs.py           # Pre-compile kernels for distribution

└── setup.py                         # Python package setup
```

## Changes from upstream metal-flash-attention

We made the following modifications to `metal-flash-attention`:

### 1. macOS 15+ compatibility (MTLLibraryCompiler.swift)

Apple restricted `__asm` in runtime-compiled Metal shaders on macOS 15. We added a fallback that uses `xcrun metal` CLI compilation when runtime compilation fails.
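The fallback simply drives the command-line Metal toolchain instead of the runtime compiler. The sketch below restates the idea in Python for illustration (`compile_metallib` is a hypothetical helper; the actual fallback is implemented in Swift in `MTLLibraryCompiler.swift`):

```python
import subprocess
import tempfile
from pathlib import Path

def compile_metallib(metal_source: str) -> bytes:
    """Compile Metal shader source to a .metallib via the CLI toolchain.
    A Python rendition of the fallback idea, not the package's real code."""
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp, "kernel.metal")
        air = Path(tmp, "kernel.air")
        lib = Path(tmp, "kernel.metallib")
        src.write_text(metal_source)
        # Front-end: .metal -> AIR intermediate
        subprocess.run(["xcrun", "-sdk", "macosx", "metal", "-c", str(src), "-o", str(air)], check=True)
        # Back-end: AIR -> loadable .metallib
        subprocess.run(["xcrun", "-sdk", "macosx", "metallib", str(air), "-o", str(lib)], check=True)
        return lib.read_bytes()
```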
### 2. Causal masking support

Added `causal` flag to AttentionDescriptor and kernel generation:

- `AttentionDescriptor.swift`: Added `causal: Bool` property
- `AttentionKernelDescriptor.swift`: Added `causal: Bool` property
- `AttentionKernel.swift`: Added `causal` field
- `AttentionKernel+Softmax.swift`: Added `maskCausal()` function
- `AttentionKernel+Source.swift`: Added causal masking to forward/backward loops
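A quick way to sanity-check the causal path is to compare it against PyTorch's reference SDPA with `is_causal=True` (a sketch; shapes are illustrative and the fp16 difference is expected to be small but nonzero):

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

q, k, v = (torch.randn(2, 4, 256, 64, device="mps", dtype=torch.float16) for _ in range(3))

flash_out = flash_attention(q, k, v, is_causal=True)
ref_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print("max abs diff:", (flash_out - ref_out).abs().max().item())
```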
## Next Steps

### 1. PR to upstream metal-flash-attention

The macOS 15 fix and causal masking should be contributed back:

```bash
cd metal-flash-attention
git checkout -b macos15-causal-support
# Commit changes to:
# - Sources/FlashAttention/Utilities/MTLLibraryCompiler.swift (new file)
# - Sources/FlashAttention/Attention/AttentionDescriptor/*.swift
# - Sources/FlashAttention/Attention/AttentionKernel/*.swift
git push origin macos15-causal-support
# Open PR at https://github.com/philipturner/metal-flash-attention
```

### 2. Publish mps-flash-attention to PyPI

```bash
# Add pyproject.toml with proper metadata
# Build wheel with pre-compiled Swift bridge
python -m build
twine upload dist/*
```

### 3. Pre-compile kernels for zero cold start

```bash
python scripts/build_metallibs.py
# Copies metallibs to mps_flash_attn/kernels/
# These get shipped with the wheel
```

## Current Status (Jan 2025)

**Working:**
- Forward pass (fp16/fp32)
- Backward pass (dQ, dK, dV gradients)
- Causal masking
- Metallib disk caching
- Pipeline binary caching (MTLBinaryArchive)

**Known limitations:**
- Sequence length must be divisible by the block size (typically 64); a padding workaround for the causal case is sketched after this list
- Head dimension: best with 32, 64, 96, 128
- No arbitrary attention masks (only causal or none)
- No dropout
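For the first limitation, one workaround is to pad the sequence dimension up to the next multiple of the block size and slice the result back. The sketch below assumes a block size of 64 and is only valid for causal attention, where the padded key positions fall in the masked "future" and cannot affect real queries:

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

BLOCK = 64  # assumed block size, per the limitations above

def causal_flash_padded(q, k, v):
    """Pad the sequence dim to a multiple of BLOCK, run the causal kernel, slice back.
    Only correct for is_causal=True: padded keys sit beyond every real query's window."""
    n = q.shape[2]
    pad = (-n) % BLOCK
    if pad == 0:
        return flash_attention(q, k, v, is_causal=True)
    q_p, k_p, v_p = (F.pad(t, (0, 0, 0, pad)) for t in (q, k, v))  # pad the N dimension
    return flash_attention(q_p, k_p, v_p, is_causal=True)[:, :, :n, :]

q, k, v = (torch.randn(1, 8, 1000, 64, device="mps", dtype=torch.float16) for _ in range(3))
out = causal_flash_padded(q, k, v)  # works even though 1000 % 64 != 0
print(out.shape)                    # torch.Size([1, 8, 1000, 64])
```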
## Credits

- [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
- [Flash Attention](https://arxiv.org/abs/2205.14135) paper by Tri Dao et al.

## License

MIT
mps_flash_attn-0.1.4/README.md (new file, +235 lines)

The README.md content is identical to the package description embedded in PKG-INFO above.