mps-flash-attn 0.1.15__tar.gz → 0.2.0__tar.gz

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes exactly as they appear in the public registry.

This version of mps-flash-attn has been flagged as a potentially problematic release.

Files changed (46)
  1. mps_flash_attn-0.2.0/PKG-INFO +211 -0
  2. mps_flash_attn-0.2.0/README.md +186 -0
  3. mps_flash_attn-0.2.0/mps_flash_attn/__init__.py +788 -0
  4. mps_flash_attn-0.2.0/mps_flash_attn/benchmark.py +666 -0
  5. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/csrc/mps_flash_attn.mm +417 -29
  6. mps_flash_attn-0.2.0/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  7. mps_flash_attn-0.2.0/mps_flash_attn.egg-info/PKG-INFO +211 -0
  8. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn.egg-info/SOURCES.txt +2 -2
  9. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/pyproject.toml +1 -1
  10. mps_flash_attn-0.2.0/tests/test_mfa_v2.py +610 -0
  11. mps_flash_attn-0.1.15/PKG-INFO +0 -270
  12. mps_flash_attn-0.1.15/README.md +0 -245
  13. mps_flash_attn-0.1.15/mps_flash_attn/__init__.py +0 -290
  14. mps_flash_attn-0.1.15/mps_flash_attn/lib/libMFABridge.dylib +0 -0
  15. mps_flash_attn-0.1.15/mps_flash_attn.egg-info/PKG-INFO +0 -270
  16. mps_flash_attn-0.1.15/tests/test_attention.py +0 -145
  17. mps_flash_attn-0.1.15/tests/test_flash_attn.py +0 -255
  18. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/LICENSE +0 -0
  19. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/06c421e7a01418cf64aafa07f6b1df0558148583959c596d9a7ce260987f89f0.metallib +0 -0
  20. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/09b9615289be632fdf05444004a0b3b67fb1b70b05a7e0fce8e0ba3a95e3921c.metallib +0 -0
  21. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/0c36461301fb52cbad786d0642b020ad2bfc7229b487ccb5dff44d198423b347.metallib +0 -0
  22. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/2ca9312d1151f792e1a95617db9186928300e3d0ffbe016f0ad53b62ab840bac.metallib +0 -0
  23. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2.metallib +0 -0
  24. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_1024_1024.bin +0 -0
  25. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_2048_2048.bin +0 -0
  26. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_4096_4096.bin +0 -0
  27. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_512_512.bin +0 -0
  28. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/73254c55475c6b7f7b009f095398994b1f9ae8215beafcf810f100357ccc99b2_8192_8192.bin +0 -0
  29. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/771935bf47d248650e287da82bc82e04bff7c4c52964823e7a12462ccd23408e.metallib +0 -0
  30. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/975aece2b4d3d78035be08a0735a7deacf2e544adee5af81c9c0a3a42a926129.metallib +0 -0
  31. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/a5e2c5c401e3872af0899c1fb3e30b5f52a6070fc49c9dac02982cc1c2f25849.metallib +0 -0
  32. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/ac4573fb201e92310867c59bd569a8ae68f859d60a9352d9d4d5d41c1547c83c.metallib +0 -0
  33. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/adc0f77ff05156bbda8fe78afd9ba8a8d3c890ba8fea0902ae79a6ae8c4f04c3.metallib +0 -0
  34. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f.metallib +0 -0
  35. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_1024_1024.bin +0 -0
  36. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_2048_2048.bin +0 -0
  37. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_4096_4096.bin +0 -0
  38. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_512_512.bin +0 -0
  39. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/eab4f40de4b0ebd2765b41c25dba7ccab5db4abf6a6eb87d76fff7b5e0829b2f_8192_8192.bin +0 -0
  40. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/f08fe0efd72e055177e068154dae01e08c4d52d3cb883330a04f1431d274aece.metallib +0 -0
  41. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn/kernels/manifest.json +0 -0
  42. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn.egg-info/dependency_links.txt +0 -0
  43. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn.egg-info/requires.txt +0 -0
  44. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/mps_flash_attn.egg-info/top_level.txt +0 -0
  45. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/setup.cfg +0 -0
  46. {mps_flash_attn-0.1.15 → mps_flash_attn-0.2.0}/setup.py +0 -0
@@ -0,0 +1,211 @@
+ Metadata-Version: 2.4
+ Name: mps-flash-attn
+ Version: 0.2.0
+ Summary: Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4)
+ Author: imperatormk
+ License-Expression: MIT
+ Project-URL: Homepage, https://github.com/mpsops/mps-flash-attention
+ Project-URL: Repository, https://github.com/mpsops/mps-flash-attention
+ Project-URL: Issues, https://github.com/mpsops/mps-flash-attention/issues
+ Keywords: flash-attention,apple-silicon,pytorch,mps,metal,transformer,attention
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Operating System :: MacOS :: MacOS X
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=2.0.0
+ Dynamic: license-file
+
+ # MPS Flash Attention
+
+ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
+
+ **O(N) memory** instead of O(N²), enabling 100K+ sequence lengths on unified memory.
+
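+ As a rough illustration of the gap (back-of-the-envelope numbers, not measurements from this package), standard attention pays for materializing the full N × N score matrix, while flash attention only keeps fixed-size tiles plus the O(N) output in flight:
+
+ ```python
+ # Back-of-the-envelope memory for the attention score matrix alone (fp16 = 2 bytes/value).
+ N, H, bytes_per_val = 100_000, 8, 2
+
+ scores_gib = N * N * bytes_per_val * H / 2**30
+ print(f"Naive O(N^2) scores, {H} heads: ~{scores_gib:.0f} GiB")   # ~149 GiB
+
+ output_gib = N * 64 * bytes_per_val * H / 2**30   # O(N) output, head_dim = 64
+ print(f"O(N) output tensor: ~{output_gib:.2f} GiB")               # ~0.10 GiB
+ ```
+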
+ ## Performance
+
+ Benchmarked on Apple Silicon (M1/M2/M3/M4):
+
+ | Seq Length | vs PyTorch SDPA | Notes |
+ |------------|-----------------|-------|
+ | 1024 | 1.1-2.0x faster | Crossover point |
+ | 2048 | 1.7-3.7x faster | Sweet spot |
+ | 4096 | 2.0-3.9x faster | Peak performance |
+ | 8192+ | 3-4x faster | SDPA often OOMs |
+
+ Average speedup: **1.8x** across the benchmarked configurations.
+
+ ## Installation
+
+ ```bash
+ pip install mps-flash-attn
+ ```
+
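+ The kernels require the MPS backend, so it is worth confirming that MPS is available in your PyTorch build before going further (standard PyTorch checks, not part of this package):
+
+ ```python
+ import torch
+
+ # Both checks are stock PyTorch and do not depend on mps-flash-attn.
+ print("MPS built:", torch.backends.mps.is_built())
+ print("MPS available:", torch.backends.mps.is_available())
+ ```
+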
+ ### Build from source
+
+ ```bash
+ git clone --recursive https://github.com/mpsops/mps-flash-attention.git
+ cd mps-flash-attention
+
+ # Build Swift bridge
+ cd swift-bridge && swift build -c release && cd ..
+
+ # Install
+ pip install -e .
+
+ # Set bridge path
+ export MFA_BRIDGE_PATH=$PWD/swift-bridge/.build/release/libMFABridge.dylib
+ ```
+
+ ## Usage
+
+ ### Basic Attention
+
+ ```python
+ import torch
+ from mps_flash_attn import flash_attention
+
+ # (B, H, N, D) format
+ q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+ k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+ v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+
+ out = flash_attention(q, k, v)
+ ```
+
+ ### Causal Masking
+
+ ```python
+ out = flash_attention(q, k, v, is_causal=True)
+ ```
+
+ ### Sliding Window (Mistral/Llama 3.2)
+
+ ```python
+ # Only attend to the last 4096 tokens
+ out = flash_attention(q, k, v, is_causal=True, window_size=4096)
+ ```
+
+ ### Quantized KV Cache (2-4x memory savings)
+
+ ```python
+ from mps_flash_attn import flash_attention_fp8, quantize_kv_fp8
+
+ # Quantize K/V to FP8
+ k_quant, k_scale = quantize_kv_fp8(k)
+ v_quant, v_scale = quantize_kv_fp8(v)
+
+ # Run attention with quantized KV
+ out = flash_attention_fp8(q, k_quant, v_quant, k_scale, v_scale)
+ ```
+
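+ The 2-4x figure follows from the element widths alone (a rough estimate that ignores the comparatively small scale tensors): FP16 K/V spend 2 bytes per value, FP8/INT8 spend 1 (~2x), and NF4 packs two values per byte (~4x):
+
+ ```python
+ # KV-cache footprint for B=1, H=8, N=8192, D=64, ignoring the (small) scale tensors.
+ B, H, N, D = 1, 8, 8192, 64
+ elems = 2 * B * H * N * D   # K and V together
+
+ for name, bytes_per_elem in [("fp16", 2), ("fp8/int8", 1), ("nf4", 0.5)]:
+     print(f"{name:8s} ~{elems * bytes_per_elem / 2**20:.0f} MiB")
+ # fp16 ~16 MiB, fp8/int8 ~8 MiB, nf4 ~4 MiB
+ ```
+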
+ ### 100K+ Long Sequences
+
+ ```python
+ import torch
+ from mps_flash_attn import flash_attention_chunked
+
+ # Process 100K tokens without OOM
+ q = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
+ k = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
+ v = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
+
+ out = flash_attention_chunked(q, k, v, chunk_size=8192)
+ ```
+
+ ### Drop-in SDPA Replacement
+
+ ```python
+ from mps_flash_attn import replace_sdpa
+
+ replace_sdpa()  # Patches F.scaled_dot_product_attention
+
+ # Now all PyTorch attention uses Flash Attention on MPS
+ ```
+
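+ After the patch, existing call sites that go through `F.scaled_dot_product_attention` on MPS tensors should pick up the flash kernel without any code changes, for example:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+ from mps_flash_attn import replace_sdpa
+
+ replace_sdpa()
+
+ q = torch.randn(2, 8, 2048, 64, device='mps', dtype=torch.float16)
+ k = torch.randn(2, 8, 2048, 64, device='mps', dtype=torch.float16)
+ v = torch.randn(2, 8, 2048, 64, device='mps', dtype=torch.float16)
+
+ # Unchanged call site; on MPS this now routes through Flash Attention.
+ out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+ ```
+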
+ ### torch.compile() Support
+
+ ```python
+ import torch
+ from mps_flash_attn import register_custom_op
+
+ register_custom_op()
+
+ @torch.compile
+ def my_attention(q, k, v):
+     return torch.ops.mfa.flash_attention(q, k, v, False, None, None)
+ ```
+
+ ### Training with BF16 Backward
+
+ ```python
+ out = flash_attention(q, k, v, bf16_backward=True)  # 2x faster backward
+ loss = out.sum()
+ loss.backward()
+ ```
+
+ ### Benchmarking
+
+ ```bash
+ # Quick benchmark
+ python -m mps_flash_attn.benchmark --suite quick
+
+ # Full suite with report
+ python -m mps_flash_attn.benchmark --suite full --output report.html
+ ```
+
+ ```python
+ from mps_flash_attn.benchmark import run_suite, compare_vs_sdpa
+
+ results = run_suite(seq_lengths=[1024, 2048, 4096])
+ compare_vs_sdpa()
+ ```
+
+ ## Features
+
+ | Feature | Status | Notes |
+ |---------|--------|-------|
+ | Forward pass | ✅ | FP16/BF16/FP32 |
+ | Backward pass | ✅ | Full gradient support |
+ | Causal masking | ✅ | Native kernel support |
+ | Attention masks | ✅ | Boolean masks |
+ | Sliding window | ✅ | For local attention models |
+ | GQA/MQA | ✅ | Grouped-query attention |
+ | Quantized KV | ✅ | FP8, INT8, NF4 |
+ | Chunked attention | ✅ | 100K+ tokens |
+ | torch.compile() | ✅ | Custom op backend |
+ | Dropout | ❌ | Not supported |
+
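+ The table lists GQA/MQA, but this README does not spell out the calling convention. A minimal sketch, assuming `flash_attention` accepts K/V with fewer heads than Q and broadcasts each KV head across its query group (check the package's tests for the exact signature):
+
+ ```python
+ import torch
+ from mps_flash_attn import flash_attention
+
+ # Grouped-query attention: 32 query heads sharing 8 KV heads (4 queries per group).
+ q = torch.randn(1, 32, 2048, 64, device='mps', dtype=torch.float16)
+ k = torch.randn(1, 8, 2048, 64, device='mps', dtype=torch.float16)
+ v = torch.randn(1, 8, 2048, 64, device='mps', dtype=torch.float16)
+
+ # Assumed convention: mismatched head counts signal GQA; MQA is the single-KV-head case.
+ out = flash_attention(q, k, v, is_causal=True)
+ ```
+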
+ ## Architecture
+
+ ```
+ Python API (mps_flash_attn)
+          │
+ C++ Extension (mps_flash_attn.mm)
+          │ dlopen
+ Swift Bridge (MFABridge.swift)
+          │
+ Metal Flash Attention (kernel generation)
+          │
+ Metal GPU Shaders
+ ```
+
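+ The dlopen itself happens inside the C++ extension, but the path it resolves can be illustrated from Python (an illustrative sketch, not the package's internal code): `MFA_BRIDGE_PATH` from the build-from-source step takes precedence, otherwise the dylib bundled at `mps_flash_attn/lib/libMFABridge.dylib` is used.
+
+ ```python
+ import ctypes
+ import os
+ from pathlib import Path
+
+ import mps_flash_attn
+
+ # Environment override from the build instructions, else the wheel's bundled bridge.
+ bundled = Path(mps_flash_attn.__file__).parent / "lib" / "libMFABridge.dylib"
+ bridge_path = os.environ.get("MFA_BRIDGE_PATH", str(bundled))
+
+ bridge = ctypes.CDLL(bridge_path)   # the same dylib the extension loads via dlopen
+ print("Loaded MFA bridge from:", bridge_path)
+ ```
+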
+ ## Requirements
+
+ - macOS 14 (Sonoma) or newer, including macOS 15 (Sequoia)
+ - Apple Silicon (M1/M2/M3/M4)
+ - Python 3.10+
+ - PyTorch 2.0+
+
+ ## Credits
+
+ - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
+ - [Flash Attention](https://arxiv.org/abs/2205.14135) paper by Tri Dao et al.
+
+ ## License
+
+ MIT
@@ -0,0 +1,186 @@
+ # MPS Flash Attention
+
+ Flash Attention for PyTorch on Apple Silicon (M1/M2/M3/M4).
+
+ **O(N) memory** instead of O(N²), enabling 100K+ sequence lengths on unified memory.
+
+ ## Performance
+
+ Benchmarked on Apple Silicon (M1/M2/M3/M4):
+
+ | Seq Length | vs PyTorch SDPA | Notes |
+ |------------|-----------------|-------|
+ | 1024 | 1.1-2.0x faster | Crossover point |
+ | 2048 | 1.7-3.7x faster | Sweet spot |
+ | 4096 | 2.0-3.9x faster | Peak performance |
+ | 8192+ | 3-4x faster | SDPA often OOMs |
+
+ Average speedup: **1.8x** across the benchmarked configurations.
+
+ ## Installation
+
+ ```bash
+ pip install mps-flash-attn
+ ```
+
+ ### Build from source
+
+ ```bash
+ git clone --recursive https://github.com/mpsops/mps-flash-attention.git
+ cd mps-flash-attention
+
+ # Build Swift bridge
+ cd swift-bridge && swift build -c release && cd ..
+
+ # Install
+ pip install -e .
+
+ # Set bridge path
+ export MFA_BRIDGE_PATH=$PWD/swift-bridge/.build/release/libMFABridge.dylib
+ ```
+
+ ## Usage
+
+ ### Basic Attention
+
+ ```python
+ import torch
+ from mps_flash_attn import flash_attention
+
+ # (B, H, N, D) format
+ q = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+ k = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+ v = torch.randn(2, 8, 4096, 64, device='mps', dtype=torch.float16)
+
+ out = flash_attention(q, k, v)
+ ```
+
+ ### Causal Masking
+
+ ```python
+ out = flash_attention(q, k, v, is_causal=True)
+ ```
+
+ ### Sliding Window (Mistral/Llama 3.2)
+
+ ```python
+ # Only attend to the last 4096 tokens
+ out = flash_attention(q, k, v, is_causal=True, window_size=4096)
+ ```
+
+ ### Quantized KV Cache (2-4x memory savings)
+
+ ```python
+ from mps_flash_attn import flash_attention_fp8, quantize_kv_fp8
+
+ # Quantize K/V to FP8
+ k_quant, k_scale = quantize_kv_fp8(k)
+ v_quant, v_scale = quantize_kv_fp8(v)
+
+ # Run attention with quantized KV
+ out = flash_attention_fp8(q, k_quant, v_quant, k_scale, v_scale)
+ ```
+
+ ### 100K+ Long Sequences
+
+ ```python
+ import torch
+ from mps_flash_attn import flash_attention_chunked
+
+ # Process 100K tokens without OOM
+ q = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
+ k = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
+ v = torch.randn(1, 8, 100000, 64, device='mps', dtype=torch.float16)
+
+ out = flash_attention_chunked(q, k, v, chunk_size=8192)
+ ```
+
+ ### Drop-in SDPA Replacement
+
+ ```python
+ from mps_flash_attn import replace_sdpa
+
+ replace_sdpa()  # Patches F.scaled_dot_product_attention
+
+ # Now all PyTorch attention uses Flash Attention on MPS
+ ```
+
+ ### torch.compile() Support
+
+ ```python
+ import torch
+ from mps_flash_attn import register_custom_op
+
+ register_custom_op()
+
+ @torch.compile
+ def my_attention(q, k, v):
+     return torch.ops.mfa.flash_attention(q, k, v, False, None, None)
+ ```
+
+ ### Training with BF16 Backward
+
+ ```python
+ out = flash_attention(q, k, v, bf16_backward=True)  # 2x faster backward
+ loss = out.sum()
+ loss.backward()
+ ```
+
+ ### Benchmarking
+
+ ```bash
+ # Quick benchmark
+ python -m mps_flash_attn.benchmark --suite quick
+
+ # Full suite with report
+ python -m mps_flash_attn.benchmark --suite full --output report.html
+ ```
+
+ ```python
+ from mps_flash_attn.benchmark import run_suite, compare_vs_sdpa
+
+ results = run_suite(seq_lengths=[1024, 2048, 4096])
+ compare_vs_sdpa()
+ ```
+
+ ## Features
+
+ | Feature | Status | Notes |
+ |---------|--------|-------|
+ | Forward pass | ✅ | FP16/BF16/FP32 |
+ | Backward pass | ✅ | Full gradient support |
+ | Causal masking | ✅ | Native kernel support |
+ | Attention masks | ✅ | Boolean masks |
+ | Sliding window | ✅ | For local attention models |
+ | GQA/MQA | ✅ | Grouped-query attention |
+ | Quantized KV | ✅ | FP8, INT8, NF4 |
+ | Chunked attention | ✅ | 100K+ tokens |
+ | torch.compile() | ✅ | Custom op backend |
+ | Dropout | ❌ | Not supported |
+
+ ## Architecture
+
+ ```
+ Python API (mps_flash_attn)
+          │
+ C++ Extension (mps_flash_attn.mm)
+          │ dlopen
+ Swift Bridge (MFABridge.swift)
+          │
+ Metal Flash Attention (kernel generation)
+          │
+ Metal GPU Shaders
+ ```
+
+ ## Requirements
+
+ - macOS 14 (Sonoma) or newer, including macOS 15 (Sequoia)
+ - Apple Silicon (M1/M2/M3/M4)
+ - Python 3.10+
+ - PyTorch 2.0+
+
+ ## Credits
+
+ - [metal-flash-attention](https://github.com/philipturner/metal-flash-attention) by Philip Turner
+ - [Flash Attention](https://arxiv.org/abs/2205.14135) paper by Tri Dao et al.
+
+ ## License
+
+ MIT