mps-flash-attn 0.1.2 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mps-flash-attn might be problematic.

Files changed (38)
  1. mps_flash_attn-0.1.2/LICENSE +27 -0
  2. mps_flash_attn-0.1.2/PKG-INFO +261 -0
  3. mps_flash_attn-0.1.2/README.md +235 -0
  4. mps_flash_attn-0.1.2/mps_flash_attn/__init__.py +246 -0
  5. mps_flash_attn-0.1.2/mps_flash_attn/csrc/mps_flash_attn.mm +441 -0
  6. mps_flash_attn-0.1.2/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  7. mps_flash_attn-0.1.2/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  8. mps_flash_attn-0.1.2/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  9. mps_flash_attn-0.1.2/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  10. mps_flash_attn-0.1.2/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  11. mps_flash_attn-0.1.2/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  12. mps_flash_attn-0.1.2/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  13. mps_flash_attn-0.1.2/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  14. mps_flash_attn-0.1.2/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  15. mps_flash_attn-0.1.2/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  16. mps_flash_attn-0.1.2/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  17. mps_flash_attn-0.1.2/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  18. mps_flash_attn-0.1.2/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  19. mps_flash_attn-0.1.2/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  20. mps_flash_attn-0.1.2/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  21. mps_flash_attn-0.1.2/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  22. mps_flash_attn-0.1.2/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  23. mps_flash_attn-0.1.2/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  24. mps_flash_attn-0.1.2/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  25. mps_flash_attn-0.1.2/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  26. mps_flash_attn-0.1.2/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  27. mps_flash_attn-0.1.2/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  28. mps_flash_attn-0.1.2/mps_flash_attn/kernels/manifest.json +27 -0
  29. mps_flash_attn-0.1.2/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  30. mps_flash_attn-0.1.2/mps_flash_attn.egg-info/PKG-INFO +261 -0
  31. mps_flash_attn-0.1.2/mps_flash_attn.egg-info/SOURCES.txt +36 -0
  32. mps_flash_attn-0.1.2/mps_flash_attn.egg-info/dependency_links.txt +1 -0
  33. mps_flash_attn-0.1.2/mps_flash_attn.egg-info/requires.txt +1 -0
  34. mps_flash_attn-0.1.2/mps_flash_attn.egg-info/top_level.txt +1 -0
  35. mps_flash_attn-0.1.2/pyproject.toml +47 -0
  36. mps_flash_attn-0.1.2/setup.cfg +4 -0
  37. mps_flash_attn-0.1.2/setup.py +68 -0
  38. mps_flash_attn-0.1.2/tests/test_attention.py +145 -0
mps_flash_attn-0.1.2/LICENSE
@@ -0,0 +1,27 @@
+ MIT License
+
+ Copyright (c) 2025 imperatormk
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ---
+
+ This project includes code from metal-flash-attention by Philip Turner,
+ also licensed under the MIT License:
+ https://github.com/philipturner/metal-flash-attention
mps_flash_attn-0.1.2/PKG-INFO
@@ -0,0 +1,261 @@
+ Metadata-Version: 2.4
+ Name: mps-flash-attn
+ Version: 0.1.2
+ Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
+ Author: imperatormk
+ License: MIT
+ Project-URL: Homepage, https://github.com/mpsops/mps-flash-attention
+ Project-URL: Repository, https://github.com/mpsops/mps-flash-attention
+ Project-URL: Issues, https://github.com/mpsops/mps-flash-attention/issues
+ Keywords: flash-attention,apple-silicon,pytorch,mps,metal,transformer,attention
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: MacOS :: MacOS X
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=2.0.0
+ Dynamic: license-file
+
+ # MPS Flash Attention
+
+ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
+
+ **O(N) memory** instead of O(N²), enabling 8K+ sequence lengths on unified memory.
+
+ ## Features
+
+ - **Forward pass**: 2-5x faster than PyTorch SDPA
+ - **Backward pass**: Full gradient support for training
+ - **Causal masking**: Native kernel support (only 5% overhead)
+ - **FP16/FP32**: Native fp16 output (no conversion overhead)
+ - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)
+
+ ## Performance
+
+ Tested on M1 Max, N=2048, B=4, H=8, D=64:
+
+ | Operation | MPS Flash Attn | PyTorch SDPA | Speedup |
+ |-----------|----------------|--------------|---------|
+ | Forward | 5.3ms | 15ms | 2.8x |
+ | Forward+Backward | 55ms | 108ms | 2.0x |
+ | Memory | 80MB | 592MB | 7.4x less |
+
+ ## Installation
+
+ ### Prerequisites
+
+ - macOS 14+ (Sonoma) or macOS 15+ (Sequoia)
+ - Xcode Command Line Tools (`xcode-select --install`)
+ - Python 3.10+ with PyTorch 2.0+
+
+ ### Build from source
+
+ ```bash
+ # Clone with submodules
+ git clone --recursive https://github.com/mpsops/mps-flash-attention.git
+ cd mps-flash-attention
+
+ # Build Swift bridge
+ cd swift-bridge
+ swift build -c release
+ cd ..
+
+ # Install Python package
+ pip install -e .
+ ```
+
+ ### Set environment variable
+
+ ```bash
+ export MFA_BRIDGE_PATH=/path/to/mps-flash-attention/swift-bridge/.build/release/libMFABridge.dylib
+ ```
+
+ ## Usage
+
+ ### Basic usage
+
+ ```python
+ from mps_flash_attn import flash_attention
+
+ # Standard attention (B, H, N, D)
+ q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+ k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+ v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+
+ out = flash_attention(q, k, v)
+ ```
+
+ ### Causal masking (for autoregressive models)
+
+ ```python
+ out = flash_attention(q, k, v, is_causal=True)
+ ```
+
+ ### Training with gradients
+
+ ```python
+ q.requires_grad = True
+ k.requires_grad = True
+ v.requires_grad = True
+
+ out = flash_attention(q, k, v, is_causal=True)
+ loss = out.sum()
+ loss.backward()  # Computes dQ, dK, dV
+ ```
+
+ ### Drop-in replacement for SDPA
+
+ ```python
+ from mps_flash_attn import replace_sdpa
+
+ # Monkey-patch F.scaled_dot_product_attention
+ replace_sdpa()
+
+ # Now all attention ops use Flash Attention on MPS
+ ```
+
+ ## Architecture
+
+ ```
+ +----------------------------------------------------------+
+ |                        Python API                        |
+ |                mps_flash_attn/__init__.py                |
+ |           (flash_attention, autograd Function)           |
+ +----------------------------+-----------------------------+
+                              |
+ +----------------------------v-----------------------------+
+ |                       C++ Extension                      |
+ |          mps_flash_attn/csrc/mps_flash_attn.mm           |
+ |     (PyTorch bindings, MTLBuffer handling, offsets)      |
+ +----------------------------+-----------------------------+
+                              | dlopen + dlsym
+ +----------------------------v-----------------------------+
+ |                       Swift Bridge                       |
+ |              swift-bridge/Sources/MFABridge/             |
+ |          (MFABridge.swift, MetallibCache.swift)          |
+ |      @_cdecl exports: mfa_init, mfa_create_kernel,       |
+ |                 mfa_forward, mfa_backward                |
+ +----------------------------+-----------------------------+
+                              |
+ +----------------------------v-----------------------------+
+ |                   Metal Flash Attention                  |
+ |      metal-flash-attention/Sources/FlashAttention/       |
+ |       (AttentionDescriptor, AttentionKernel, etc.)       |
+ |                                                          |
+ |        Generates Metal shader source at runtime,         |
+ |         compiles to .metallib, caches pipelines          |
+ +----------------------------------------------------------+
+ ```
+
+ ## Project Structure
+
+ ```
+ mps-flash-attention/
+ ├── mps_flash_attn/                  # Python package
+ │   ├── __init__.py                  # Public API (flash_attention, replace_sdpa)
+ │   ├── csrc/
+ │   │   └── mps_flash_attn.mm        # PyTorch C++ extension
+ │   └── kernels/                     # Pre-compiled metallibs (optional)
+
+ ├── swift-bridge/                    # Swift -> C bridge
+ │   ├── Package.swift
+ │   └── Sources/MFABridge/
+ │       ├── MFABridge.swift          # C-callable API (@_cdecl)
+ │       └── MetallibCache.swift      # Disk caching for metallibs
+
+ ├── metal-flash-attention/           # Upstream (git submodule)
+ │   └── Sources/FlashAttention/
+ │       └── Attention/
+ │           ├── AttentionDescriptor/ # Problem configuration
+ │           ├── AttentionKernel/     # Metal shader generation
+ │           └── ...
+
+ ├── scripts/
+ │   └── build_metallibs.py           # Pre-compile kernels for distribution
+
+ └── setup.py                         # Python package setup
+ ```
+
+ ## Changes from upstream metal-flash-attention
+
+ We made the following modifications to `metal-flash-attention`:
+
+ ### 1. macOS 15+ compatibility (MTLLibraryCompiler.swift)
+
+ Apple restricted `__asm` in runtime-compiled Metal shaders on macOS 15. We added a fallback that uses `xcrun metal` CLI compilation when runtime compilation fails.
+
+ ### 2. Causal masking support
+
+ Added `causal` flag to AttentionDescriptor and kernel generation:
+
+ - `AttentionDescriptor.swift`: Added `causal: Bool` property
+ - `AttentionKernelDescriptor.swift`: Added `causal: Bool` property
+ - `AttentionKernel.swift`: Added `causal` field
+ - `AttentionKernel+Softmax.swift`: Added `maskCausal()` function
+ - `AttentionKernel+Source.swift`: Added causal masking to forward/backward loops
+
+ ## Next Steps
+
+ ### 1. PR to upstream metal-flash-attention
+
+ The macOS 15 fix and causal masking should be contributed back:
+
+ ```bash
+ cd metal-flash-attention
+ git checkout -b macos15-causal-support
+ # Commit changes to:
+ # - Sources/FlashAttention/Utilities/MTLLibraryCompiler.swift (new file)
+ # - Sources/FlashAttention/Attention/AttentionDescriptor/*.swift
+ # - Sources/FlashAttention/Attention/AttentionKernel/*.swift
+ git push origin macos15-causal-support
+ # Open PR at https://github.com/philipturner/metal-flash-attention
+ ```
+
+ ### 2. Publish mps-flash-attention to PyPI
+
+ ```bash
+ # Add pyproject.toml with proper metadata
+ # Build wheel with pre-compiled Swift bridge
+ python -m build
+ twine upload dist/*
+ ```
+
+ ### 3. Pre-compile kernels for zero cold start
+
+ ```bash
+ python scripts/build_metallibs.py
+ # Copies metallibs to mps_flash_attn/kernels/
+ # These get shipped with the wheel
+ ```
+
+ ## Current Status (Jan 2025)
+
+ **Working:**
+ - Forward pass (fp16/fp32)
+ - Backward pass (dQ, dK, dV gradients)
+ - Causal masking
+ - Metallib disk caching
+ - Pipeline binary caching (MTLBinaryArchive)
+
+ **Known limitations:**
+ - Sequence length must be divisible by block size (typically 64)
+ - Head dimension: Best with 32, 64, 96, 128
+ - No arbitrary attention masks (only causal or none)
+ - No dropout
+
+ ## Credits
+
+ - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
+ - [Flash Attention](https://arxiv.org/abs/2205.14135) paper by Tri Dao et al.
+
+ ## License
+
+ MIT
mps_flash_attn-0.1.2/README.md
@@ -0,0 +1,235 @@
+ # MPS Flash Attention
+
+ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
+
+ **O(N) memory** instead of O(N²), enabling 8K+ sequence lengths on unified memory.
+
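To put the O(N) vs. O(N²) claim in concrete terms, here is a back-of-envelope estimate (an illustration, not a measurement of this package) of what the materialized N×N score matrix alone costs in standard attention at an 8K sequence length:

```python
# Standard attention materializes an N x N score matrix per (batch, head) pair.
B, H, N, bytes_per_fp16 = 2, 8, 8192, 2
score_bytes = B * H * N * N * bytes_per_fp16
print(f"{score_bytes / 2**30:.1f} GiB")  # ~2.0 GiB for the score matrices alone
```

A flash-style kernel never materializes that matrix, which is where the O(N) memory behaviour comes from.
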
+ ## Features
+
+ - **Forward pass**: 2-5x faster than PyTorch SDPA
+ - **Backward pass**: Full gradient support for training
+ - **Causal masking**: Native kernel support (only 5% overhead)
+ - **FP16/FP32**: Native fp16 output (no conversion overhead)
+ - **Pre-compiled kernels**: Zero-compilation cold start (~6ms)
+
+ ## Performance
+
+ Tested on M1 Max, N=2048, B=4, H=8, D=64:
+
+ | Operation | MPS Flash Attn | PyTorch SDPA | Speedup |
+ |-----------|----------------|--------------|---------|
+ | Forward | 5.3ms | 15ms | 2.8x |
+ | Forward+Backward | 55ms | 108ms | 2.0x |
+ | Memory | 80MB | 592MB | 7.4x less |
+
+ ## Installation
+
+ ### Prerequisites
+
+ - macOS 14 (Sonoma) or later, including macOS 15 (Sequoia)
+ - Xcode Command Line Tools (`xcode-select --install`)
+ - Python 3.10+ with PyTorch 2.0+
+
+ ### Build from source
+
+ ```bash
+ # Clone with submodules
+ git clone --recursive https://github.com/mpsops/mps-flash-attention.git
+ cd mps-flash-attention
+
+ # Build Swift bridge
+ cd swift-bridge
+ swift build -c release
+ cd ..
+
+ # Install Python package
+ pip install -e .
+ ```
+
+ ### Set environment variable
+
+ ```bash
+ export MFA_BRIDGE_PATH=/path/to/mps-flash-attention/swift-bridge/.build/release/libMFABridge.dylib
+ ```
+
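A quick sanity check before importing the package can catch a missing or stale path. This is a generic snippet, not part of the package's API; whether the wheel falls back to its bundled `lib/libMFABridge.dylib` when the variable is unset is an assumption here:

```python
import os
from pathlib import Path

# Hypothetical pre-flight check for the bridge library path.
bridge = os.environ.get("MFA_BRIDGE_PATH")
if bridge is None:
    print("MFA_BRIDGE_PATH is not set; a bundled dylib may (or may not) be used instead")
elif not Path(bridge).is_file():
    raise FileNotFoundError(f"MFA_BRIDGE_PATH points at a missing file: {bridge}")
```
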
+ ## Usage
+
+ ### Basic usage
+
+ ```python
+ import torch
+ from mps_flash_attn import flash_attention
+
+ # Standard attention (B, H, N, D)
+ q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+ k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+ v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+
+ out = flash_attention(q, k, v)
+ ```
+
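A quick way to convince yourself the kernel is wired up correctly is to compare its output against PyTorch's reference SDPA on the same inputs. A minimal sketch; the loose tolerances are an assumption for fp16 accumulation differences on MPS:

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

q, k, v = (torch.randn(2, 8, 2048, 64, device="mps", dtype=torch.float16) for _ in range(3))

out_flash = flash_attention(q, k, v)
out_ref = F.scaled_dot_product_attention(q, k, v)  # PyTorch reference path

# Different kernels accumulate in different orders, so compare loosely in fp16.
print(torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-2))
```
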
+ ### Causal masking (for autoregressive models)
+
+ ```python
+ out = flash_attention(q, k, v, is_causal=True)
+ ```
+
+ ### Training with gradients
+
+ ```python
+ q.requires_grad = True
+ k.requires_grad = True
+ v.requires_grad = True
+
+ out = flash_attention(q, k, v, is_causal=True)
+ loss = out.sum()
+ loss.backward()  # Computes dQ, dK, dV
+ ```
+
+ ### Drop-in replacement for SDPA
+
+ ```python
+ from mps_flash_attn import replace_sdpa
+
+ # Monkey-patch F.scaled_dot_product_attention
+ replace_sdpa()
+
+ # Now all attention ops use Flash Attention on MPS
+ ```
+
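After patching, code that already calls `F.scaled_dot_product_attention` on MPS tensors should pick up the flash kernel without further changes. A minimal sketch; how unsupported shapes or non-MPS tensors are handled (e.g. falling back to the original implementation) is an assumption, not something this README specifies:

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import replace_sdpa

replace_sdpa()  # monkey-patches F.scaled_dot_product_attention

q = torch.randn(1, 8, 1024, 64, device="mps", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# This call is now expected to route to the flash kernel on MPS tensors.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```
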
+ ## Architecture
+
+ ```
+ +----------------------------------------------------------+
+ |                        Python API                        |
+ |                mps_flash_attn/__init__.py                |
+ |           (flash_attention, autograd Function)           |
+ +----------------------------+-----------------------------+
+                              |
+ +----------------------------v-----------------------------+
+ |                       C++ Extension                      |
+ |          mps_flash_attn/csrc/mps_flash_attn.mm           |
+ |     (PyTorch bindings, MTLBuffer handling, offsets)      |
+ +----------------------------+-----------------------------+
+                              | dlopen + dlsym
+ +----------------------------v-----------------------------+
+ |                       Swift Bridge                       |
+ |              swift-bridge/Sources/MFABridge/             |
+ |          (MFABridge.swift, MetallibCache.swift)          |
+ |      @_cdecl exports: mfa_init, mfa_create_kernel,       |
+ |                 mfa_forward, mfa_backward                |
+ +----------------------------+-----------------------------+
+                              |
+ +----------------------------v-----------------------------+
+ |                   Metal Flash Attention                  |
+ |      metal-flash-attention/Sources/FlashAttention/       |
+ |       (AttentionDescriptor, AttentionKernel, etc.)       |
+ |                                                          |
+ |        Generates Metal shader source at runtime,         |
+ |         compiles to .metallib, caches pipelines          |
+ +----------------------------------------------------------+
+ ```
+
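The C++ extension reaches the Swift layer through plain `dlopen`/`dlsym`, so the `@_cdecl` exports named in the diagram can also be inspected from Python with `ctypes`. A minimal sketch for checking that the bridge exposes the expected entry points (their argument and return types are defined by the bridge and are deliberately not assumed here):

```python
import ctypes
import os

# Load the Swift bridge the same way the extension ultimately does (dlopen).
lib = ctypes.CDLL(os.environ["MFA_BRIDGE_PATH"])

# The @_cdecl exports from the architecture diagram should resolve as symbols.
for name in ("mfa_init", "mfa_create_kernel", "mfa_forward", "mfa_backward"):
    print(name, "->", "found" if hasattr(lib, name) else "missing")
```
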
+ ## Project Structure
+
+ ```
+ mps-flash-attention/
+ ├── mps_flash_attn/                  # Python package
+ │   ├── __init__.py                  # Public API (flash_attention, replace_sdpa)
+ │   ├── csrc/
+ │   │   └── mps_flash_attn.mm        # PyTorch C++ extension
+ │   └── kernels/                     # Pre-compiled metallibs (optional)
+
+ ├── swift-bridge/                    # Swift -> C bridge
+ │   ├── Package.swift
+ │   └── Sources/MFABridge/
+ │       ├── MFABridge.swift          # C-callable API (@_cdecl)
+ │       └── MetallibCache.swift      # Disk caching for metallibs
+
+ ├── metal-flash-attention/           # Upstream (git submodule)
+ │   └── Sources/FlashAttention/
+ │       └── Attention/
+ │           ├── AttentionDescriptor/ # Problem configuration
+ │           ├── AttentionKernel/     # Metal shader generation
+ │           └── ...
+
+ ├── scripts/
+ │   └── build_metallibs.py           # Pre-compile kernels for distribution
+
+ └── setup.py                         # Python package setup
+ ```
+
+ ## Changes from upstream metal-flash-attention
+
+ We made the following modifications to `metal-flash-attention`:
+
+ ### 1. macOS 15+ compatibility (MTLLibraryCompiler.swift)
+
+ Apple restricted `__asm` in runtime-compiled Metal shaders on macOS 15. We added a fallback that uses `xcrun metal` CLI compilation when runtime compilation fails.
+
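The real fallback lives in Swift (`MTLLibraryCompiler.swift`), but the CLI path it uses is the standard two-step Metal toolchain invocation. A rough Python sketch of the same idea; the file names, temporary-directory handling, and function name are illustrative, not the bridge's actual implementation:

```python
import subprocess
import tempfile
from pathlib import Path

def compile_metallib_via_cli(source: str) -> bytes:
    """Compile Metal shader source to a .metallib using the xcrun CLI toolchain."""
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp) / "kernel.metal"
        air = Path(tmp) / "kernel.air"
        lib = Path(tmp) / "kernel.metallib"
        src.write_text(source)
        # Source -> AIR, then AIR -> metallib (the same steps Xcode performs).
        subprocess.run(["xcrun", "-sdk", "macosx", "metal", "-c", str(src), "-o", str(air)], check=True)
        subprocess.run(["xcrun", "-sdk", "macosx", "metallib", str(air), "-o", str(lib)], check=True)
        return lib.read_bytes()
```
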
+ ### 2. Causal masking support
+
+ Added a `causal` flag to AttentionDescriptor and kernel generation (a conceptual sketch of the masking follows this list):
+
+ - `AttentionDescriptor.swift`: Added `causal: Bool` property
+ - `AttentionKernelDescriptor.swift`: Added `causal: Bool` property
+ - `AttentionKernel.swift`: Added `causal` field
+ - `AttentionKernel+Softmax.swift`: Added `maskCausal()` function
+ - `AttentionKernel+Source.swift`: Added causal masking to forward/backward loops
+
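For reference, the flag implements the standard causal mask: a query at position i may only attend to keys at positions j ≤ i. A dense PyTorch equivalent is shown below, purely as an illustration of the semantics; the Metal kernel applies the mask inside its blocked softmax loops rather than materializing an N×N mask:

```python
import torch

def causal_attention_reference(q, k, v):
    """Dense O(N^2) reference for what is_causal=True computes."""
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale                      # (B, H, N, N)
    n = scores.shape[-1]
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))                # j > i is masked out
    return torch.softmax(scores, dim=-1) @ v
```
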
+ ## Next Steps
+
+ ### 1. PR to upstream metal-flash-attention
+
+ The macOS 15 fix and causal masking should be contributed back:
+
+ ```bash
+ cd metal-flash-attention
+ git checkout -b macos15-causal-support
+ # Commit changes to:
+ # - Sources/FlashAttention/Utilities/MTLLibraryCompiler.swift (new file)
+ # - Sources/FlashAttention/Attention/AttentionDescriptor/*.swift
+ # - Sources/FlashAttention/Attention/AttentionKernel/*.swift
+ git push origin macos15-causal-support
+ # Open PR at https://github.com/philipturner/metal-flash-attention
+ ```
+
+ ### 2. Publish mps-flash-attention to PyPI
+
+ ```bash
+ # Add pyproject.toml with proper metadata
+ # Build wheel with pre-compiled Swift bridge
+ python -m build
+ twine upload dist/*
+ ```
+
+ ### 3. Pre-compile kernels for zero cold start
+
+ ```bash
+ python scripts/build_metallibs.py
+ # Copies metallibs to mps_flash_attn/kernels/
+ # These get shipped with the wheel
+ ```
+
+ ## Current Status (Jan 2025)
+
+ **Working:**
+ - Forward pass (fp16/fp32)
+ - Backward pass (dQ, dK, dV gradients)
+ - Causal masking
+ - Metallib disk caching
+ - Pipeline binary caching (MTLBinaryArchive)
+
+ **Known limitations:**
+ - Sequence length must be divisible by block size (typically 64); see the padding sketch after this list
+ - Head dimension: Best with 32, 64, 96, 128
+ - No arbitrary attention masks (only causal or none)
+ - No dropout
+
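Until arbitrary lengths are supported, the causal case can be handled exactly by padding the sequence dimension up to the next block multiple and slicing the result back, because the causal mask prevents real query positions from attending to the padded tail. A sketch under that assumption; the helper name is hypothetical and the block size of 64 comes from the limitation above:

```python
import torch
import torch.nn.functional as F
from mps_flash_attn import flash_attention

def flash_attention_causal_padded(q, k, v, block=64):
    """Pad (B, H, N, D) inputs along N to a multiple of `block`, then slice back.

    Exact only for is_causal=True: padded keys sit after every real query, so the
    causal mask hides them. Zero-padded keys would leak into a non-causal softmax.
    """
    n = q.shape[-2]
    pad = (-n) % block
    if pad:
        q, k, v = (F.pad(t, (0, 0, 0, pad)) for t in (q, k, v))  # pad the N dimension
    out = flash_attention(q, k, v, is_causal=True)
    return out[..., :n, :]
```
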
+ ## Credits
+
+ - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
+ - [Flash Attention](https://arxiv.org/abs/2205.14135) paper by Tri Dao et al.
+
+ ## License
+
+ MIT