natten-mps 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. natten_mps-0.3.0/LICENSE +21 -0
  2. natten_mps-0.3.0/PKG-INFO +331 -0
  3. natten_mps-0.3.0/README.md +301 -0
  4. natten_mps-0.3.0/pyproject.toml +49 -0
  5. natten_mps-0.3.0/setup.cfg +4 -0
  6. natten_mps-0.3.0/src/natten_mps/__init__.py +53 -0
  7. natten_mps-0.3.0/src/natten_mps/_core/__init__.py +3 -0
  8. natten_mps-0.3.0/src/natten_mps/_core/_metal_shaders.py +7286 -0
  9. natten_mps-0.3.0/src/natten_mps/_core/inverse_maps.py +428 -0
  10. natten_mps-0.3.0/src/natten_mps/_core/metal.py +1605 -0
  11. natten_mps-0.3.0/src/natten_mps/_core/ops.py +159 -0
  12. natten_mps-0.3.0/src/natten_mps/_core/pure.py +696 -0
  13. natten_mps-0.3.0/src/natten_mps/_torch_ops.py +763 -0
  14. natten_mps-0.3.0/src/natten_mps/autograd/__init__.py +15 -0
  15. natten_mps-0.3.0/src/natten_mps/autograd/_factory.py +186 -0
  16. natten_mps-0.3.0/src/natten_mps/autograd/na1d.py +9 -0
  17. natten_mps-0.3.0/src/natten_mps/autograd/na2d.py +9 -0
  18. natten_mps-0.3.0/src/natten_mps/autograd/na3d.py +9 -0
  19. natten_mps-0.3.0/src/natten_mps/compat/__init__.py +26 -0
  20. natten_mps-0.3.0/src/natten_mps/compat/v014.py +205 -0
  21. natten_mps-0.3.0/src/natten_mps/compat/v015.py +97 -0
  22. natten_mps-0.3.0/src/natten_mps/compat/v017.py +1 -0
  23. natten_mps-0.3.0/src/natten_mps/compat/v020.py +37 -0
  24. natten_mps-0.3.0/src/natten_mps/extras/__init__.py +0 -0
  25. natten_mps-0.3.0/src/natten_mps/extras/allin1/__init__.py +127 -0
  26. natten_mps-0.3.0/src/natten_mps/extras/allin1/_metal_shaders.py +803 -0
  27. natten_mps-0.3.0/src/natten_mps/extras/allin1/functional.py +570 -0
  28. natten_mps-0.3.0/src/natten_mps/extras/allin1/metal.py +331 -0
  29. natten_mps-0.3.0/src/natten_mps/extras/allin1/reference_impl.py +41 -0
  30. natten_mps-0.3.0/src/natten_mps/functional.py +1033 -0
  31. natten_mps-0.3.0/src/natten_mps/merge.py +159 -0
  32. natten_mps-0.3.0/src/natten_mps/nn/__init__.py +5 -0
  33. natten_mps-0.3.0/src/natten_mps/nn/na1d.py +130 -0
  34. natten_mps-0.3.0/src/natten_mps/nn/na2d.py +127 -0
  35. natten_mps-0.3.0/src/natten_mps/nn/na3d.py +127 -0
  36. natten_mps-0.3.0/src/natten_mps/support_matrix.py +37 -0
  37. natten_mps-0.3.0/src/natten_mps/utils/__init__.py +27 -0
  38. natten_mps-0.3.0/src/natten_mps/utils/params.py +82 -0
  39. natten_mps-0.3.0/src/natten_mps/utils/window.py +135 -0
  40. natten_mps-0.3.0/src/natten_mps/version.py +1 -0
  41. natten_mps-0.3.0/src/natten_mps.egg-info/PKG-INFO +331 -0
  42. natten_mps-0.3.0/src/natten_mps.egg-info/SOURCES.txt +74 -0
  43. natten_mps-0.3.0/src/natten_mps.egg-info/dependency_links.txt +1 -0
  44. natten_mps-0.3.0/src/natten_mps.egg-info/requires.txt +7 -0
  45. natten_mps-0.3.0/src/natten_mps.egg-info/top_level.txt +1 -0
  46. natten_mps-0.3.0/tests/test_autograd.py +83 -0
  47. natten_mps-0.3.0/tests/test_autograd_3d.py +37 -0
  48. natten_mps-0.3.0/tests/test_backends.py +31 -0
  49. natten_mps-0.3.0/tests/test_backward_parity.py +89 -0
  50. natten_mps-0.3.0/tests/test_compat_for_version.py +42 -0
  51. natten_mps-0.3.0/tests/test_compat_gradients.py +63 -0
  52. natten_mps-0.3.0/tests/test_compat_v014.py +103 -0
  53. natten_mps-0.3.0/tests/test_compat_v015.py +27 -0
  54. natten_mps-0.3.0/tests/test_compat_v017.py +12 -0
  55. natten_mps-0.3.0/tests/test_compat_v020.py +25 -0
  56. natten_mps-0.3.0/tests/test_extras_allin1.py +85 -0
  57. natten_mps-0.3.0/tests/test_extras_backward.py +296 -0
  58. natten_mps-0.3.0/tests/test_functional.py +304 -0
  59. natten_mps-0.3.0/tests/test_functional_3d.py +120 -0
  60. natten_mps-0.3.0/tests/test_functional_3d_expanded.py +86 -0
  61. natten_mps-0.3.0/tests/test_fused_simd.py +384 -0
  62. natten_mps-0.3.0/tests/test_gradcheck_functional.py +121 -0
  63. natten_mps-0.3.0/tests/test_metal_backend.py +1058 -0
  64. natten_mps-0.3.0/tests/test_na1d.py +24 -0
  65. natten_mps-0.3.0/tests/test_na2d.py +24 -0
  66. natten_mps-0.3.0/tests/test_na3d.py +45 -0
  67. natten_mps-0.3.0/tests/test_new_features.py +1181 -0
  68. natten_mps-0.3.0/tests/test_nn_module_grad.py +178 -0
  69. natten_mps-0.3.0/tests/test_shift_semantics.py +99 -0
  70. natten_mps-0.3.0/tests/test_split_ops_grad.py +70 -0
  71. natten_mps-0.3.0/tests/test_support_matrix.py +36 -0
  72. natten_mps-0.3.0/tests/test_upstream_parity.py +66 -0
  73. natten_mps-0.3.0/tests/test_validation_errors.py +83 -0
  74. natten_mps-0.3.0/tests/test_varlen_na1d.py +363 -0
  75. natten_mps-0.3.0/tests/test_varlen_na2d.py +255 -0
  76. natten_mps-0.3.0/tests/test_varlen_na3d.py +203 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,331 @@
1
+ Metadata-Version: 2.4
2
+ Name: natten-mps
3
+ Version: 0.3.0
4
+ Summary: Neighborhood Attention for Apple Silicon — PyTorch MPS backend
5
+ Author: ssmall256
6
+ License: MIT
7
+ Project-URL: Repository, https://github.com/ssmall256/natten-mps
8
+ Project-URL: Issues, https://github.com/ssmall256/natten-mps/issues
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Intended Audience :: Developers
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Operating System :: MacOS
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Programming Language :: Python :: 3.13
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Requires-Python: >=3.10
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: torch>=2.8.0
24
+ Requires-Dist: numpy
25
+ Requires-Dist: packaging
26
+ Provides-Extra: dev
27
+ Requires-Dist: pytest; extra == "dev"
28
+ Requires-Dist: pytest-benchmark; extra == "dev"
29
+ Dynamic: license-file
30
+
31
+ # natten-mps
32
+
33
+ GPU-accelerated Neighborhood Attention for Apple Silicon — built on **PyTorch MPS**.
34
+
35
+ > **Disclaimer (unofficial):** This is an independent, unofficial implementation/port for Apple Silicon.
36
+ > **Not affiliated with** SHI-Labs or the upstream [NATTEN](https://github.com/SHI-Labs/NATTEN) project.
37
+
38
+ This is a focused, Apple-Silicon-first implementation intended to be useful, correct, and easy to install — not a replacement for upstream NATTEN on CUDA.
39
+
40
+ Neighborhood Attention was introduced by the NATTEN authors. If you use Neighborhood Attention in research, please cite the original papers (see [Acknowledgments](#acknowledgments)).
41
+
42
+ > **v0.x** — API may change between minor versions. Pin your dependency for production use.
43
+
44
+ ---
45
+
46
+ ## Why this exists
47
+
48
+ Upstream NATTEN is CUDA-focused and targets NVIDIA GPUs. On Apple Silicon, PyTorch users often want a **GPU-accelerated** neighborhood attention option without requiring CUDA.
49
+
50
+ **natten-mps** provides:
51
+ - **Metal-backed kernels** for PyTorch MPS using `torch.mps.compile_shader`
52
+ - **1D / 2D / 3D** neighborhood attention with **full autograd support**
53
+ - A deployment story that is intentionally simple: **no native extension build step** — install from PyPI and go. Metal shaders are compiled at runtime via `torch.mps.compile_shader` and cached by PyTorch for the process (best effort).
54
+
55
+ For MLX-based workflows, see the sibling project: **[natten-mlx](https://github.com/ssmall256/natten-mlx)**.
56
+
57
+ **Jump to:** [Installation](#installation) | [Quick start](#quick-start) | [Features](#features) | [Backends](#backends) | [Performance](#performance) | [Limitations](#limitations) | [Acknowledgments](#acknowledgments)
58
+
59
+ ---
60
+
61
+ ## Use natten-mps if…
62
+
63
+ - You’re using **PyTorch**
64
+ - You run on **Apple Silicon** and want **MPS (Metal) acceleration**
65
+ - You want a drop-in-ish API (plus optional compatibility shims for historical NATTEN versions)
66
+
67
+ ---
68
+
69
+ ## Installation
70
+
71
+ ```bash
72
+ pip install natten-mps
73
+ ```
74
+
75
+ Requirements:
76
+ - Python 3.10+
77
+ - PyTorch 2.8+ with MPS support
78
+ - macOS 12.3+ for MPS (CPU fallback works anywhere PyTorch runs)
79
+
80
+ ---
81
+
82
+ ## Quick start
83
+
84
+ ### Functional API
85
+
86
+ ```python
87
+ import torch
88
+ from natten_mps import na1d, na2d, na3d
89
+
90
+ # 1D: [B, L, heads, head_dim]
91
+ q = torch.randn(2, 128, 4, 32, device="mps")
92
+ k = torch.randn(2, 128, 4, 32, device="mps")
93
+ v = torch.randn(2, 128, 4, 32, device="mps")
94
+ out = na1d(q, k, v, kernel_size=7)
95
+
96
+ # 2D: [B, H, W, heads, head_dim]
97
+ q2d = torch.randn(2, 32, 32, 4, 32, device="mps")
98
+ k2d = torch.randn(2, 32, 32, 4, 32, device="mps")
99
+ v2d = torch.randn(2, 32, 32, 4, 32, device="mps")
100
+ out2d = na2d(q2d, k2d, v2d, kernel_size=7)
101
+
102
+ # 3D: [B, D, H, W, heads, head_dim]
103
+ q3d = torch.randn(1, 8, 8, 8, 4, 32, device="mps")
104
+ k3d = torch.randn(1, 8, 8, 8, 4, 32, device="mps")
105
+ v3d = torch.randn(1, 8, 8, 8, 4, 32, device="mps")
106
+ out3d = na3d(q3d, k3d, v3d, kernel_size=3)
107
+ ```
108
+
109
+ ### Module API
110
+
111
+ ```python
112
+ import torch
113
+ from natten_mps import NeighborhoodAttention2D
114
+
115
+ layer = NeighborhoodAttention2D(embed_dim=128, num_heads=4, kernel_size=(7, 7)).to("mps")
116
+ x = torch.randn(2, 32, 32, 128, device="mps") # [B, H, W, C]
117
+ y = layer(x)
118
+ ```
119
+
120
+ ### Split QK / AV (access attention weights)
121
+
122
+ ```python
123
+ import torch
124
+ from natten_mps import na1d_qk, na1d_av
125
+
126
+ B, L, H, D = 2, 128, 4, 32
127
+ q = torch.randn(B, L, H, D, device="mps")
128
+ k = torch.randn(B, L, H, D, device="mps")
129
+ v = torch.randn(B, L, H, D, device="mps")
130
+
131
+ logits = na1d_qk(q, k, kernel_size=7, scale=D ** -0.5) # [B, L, H, K]
132
+ attn = torch.softmax(logits, dim=-1)
133
+ out = na1d_av(attn, v, kernel_size=7) # [B, L, H, D]
134
+ ```
135
+
136
+ ---
137
+
138
+ ## Features
139
+
140
+ Core:
141
+ - **1D / 2D / 3D** neighborhood attention (fused and split QK/AV ops)
142
+ - **Causal masking**, including per-axis control (e.g. `is_causal=(True, False)` for 2D)
143
+ - **Strided output** for downsampling (e.g. `stride=2`)
144
+ - **Combined causal + stride** in one kernel
145
+ - **Non-uniform kernels** for 2D/3D (per-axis kernel sizes and dilations)
146
+
147
+ Batching / advanced:
148
+ - **Variable-length (varlen) attention** — padded batches with per-sample spatial sizes, Metal-accelerated for all ranks
149
+ - **GQA / MQA** (`num_kv_heads`) for grouped-query attention patterns
150
+ - **additional_keys / additional_values** — prepend extra global tokens that every query attends to
151
+ - **merge_attentions** — numerically stable sigmoid-based merge of multiple attention outputs
152
+ - **FMHA fast path** — when the kernel covers the full spatial extent, can dispatch to efficient full attention
153
+
154
+ Extras:
155
+ - **`extras/`** namespace for model-specific fused kernels (e.g., DiNAT-style fused QK+RPB paths)
156
+
157
+ Compatibility:
158
+ - Optional **compat shims** for historical upstream API versions (see [Compatibility shims](#compatibility-shims))
159
+
160
+ ---
161
+
162
+ ## Backends
163
+
164
+ Backend dispatch is controlled at runtime and does not require a native extension.
165
+
166
+ | Backend | Status | Description |
167
+ |---|---|---|
168
+ | `pure` | Complete | Pure PyTorch fallback (CPU/MPS) |
169
+ | `metal` | Complete | Metal compute shaders via `torch.mps.compile_shader` |
170
+ | `auto` | Default | Select best available backend for the configuration |
171
+
172
+ ```python
173
+ import natten_mps
174
+
175
+ natten_mps.set_backend("metal") # "auto" (default), "metal", or "pure"
176
+ print(natten_mps.get_backend())
177
+ ```
178
+
179
+ Or via environment variable:
180
+
181
+ ```bash
182
+ NATTEN_BACKEND=metal python my_script.py # "auto" (default), "metal", or "pure"
183
+ ```
184
+
185
+ ---
186
+
187
+ ## Performance
188
+
189
+ Metal kernels vs pure-PyTorch backend on Apple Silicon (M-series), forward pass:
190
+
191
+ | Benchmark | Metal | Pure | Speedup |
192
+ |---|---:|---:|---:|
193
+ | 1D, L=256, K=7 | 0.9 ms | 9.8 ms | **11×** |
194
+ | 1D, L=1024, K=7 | 1.1 ms | 37 ms | **34×** |
195
+ | 2D, 32×32, K=7 | 1.3 ms | 20 ms | **15–17×** |
196
+ | 2D, 64×64, K=7 | 2.9 ms | 84 ms | **29×** |
197
+ | 2D, 32×32, K=7, causal | 1.1 ms | 21 ms | **19×** |
198
+ | 3D, 16³, K=3 | 1.7 ms | 12 ms | **7×** |
199
+
200
+ Run the full suite:
201
+ ```bash
202
+ python benchmarks/bench.py
203
+ # add --backward to time backward pass
204
+ ```
205
+
206
+ ### Cross-framework: natten-mps vs natten-mlx
207
+
208
+ Apple Silicon (M-series), fp32, B=1 H=4 D=32, Metal-accelerated:
209
+
210
+ | Config | natten-mps fwd | natten-mlx fwd | natten-mps bwd | natten-mlx bwd |
211
+ |---|---:|---:|---:|---:|
212
+ | 1D L=256 K=7 | 0.25 ms | 0.21 ms | 0.39 ms | 0.14 ms |
213
+ | 1D L=1024 K=7 | 0.40 ms | 0.27 ms | 0.63 ms | 0.26 ms |
214
+ | 2D 32×32 K=7 | 0.88 ms | 0.65 ms | 1.62 ms | 1.02 ms |
215
+ | 2D 64×64 K=7 | 1.32 ms | 1.13 ms | 1.55 ms | 0.97 ms |
216
+ | 2D 32×32 K=7 causal | 0.37 ms | 0.29 ms | 0.49 ms | 0.31 ms |
217
+ | 3D 16³ K=3 | 0.55 ms | 0.43 ms | 0.89 ms | 0.50 ms |
218
+
219
+ MLX’s compiled primitives tend to have lower dispatch overhead than PyTorch MPS, so natten-mlx is often faster for the same shapes. Both are dramatically faster than pure-framework baselines.
220
+
221
+ ### Variable-length (varlen) attention
222
+
223
+ Metal-accelerated varlen forward, fp32:
224
+
225
+ | Config | natten-mps | natten-mlx | MLX speedup |
226
+ |---|---:|---:|---:|
227
+ | varlen 1D B=4 L=128 K=7 | 1.74 ms | 0.53 ms | 3.3× |
228
+ | varlen 1D B=4 L=256 K=7 | 1.74 ms | 0.51 ms | 3.4× |
229
+ | varlen 2D B=2 16×16 K=3 | 2.39 ms | 0.82 ms | 2.9× |
230
+ | varlen 2D B=2 32×32 K=7 | 3.79 ms | 1.23 ms | 3.1× |
231
+ | varlen 3D B=2 8³ K=3 | 3.82 ms | 1.55 ms | 2.5× |
232
+
233
+ The varlen backward pass is computed per sample by re-differentiating through the standard Metal-accelerated `na*d` kernels with autograd.
234
+
235
+ ### Methodology
236
+
237
+ All timings on **Apple M4 Max**, macOS 26.3, Python 3.11, PyTorch 2.10, float32. Each kernel is warmed up for 5 iterations, then timed for 20 repetitions with `torch.mps.synchronize()` gating; the reported value is the **median**. Reproduce with `python benchmarks/bench.py`.
238
+
239
+ ---
240
+
241
+ ## Compatibility shims
242
+
243
+ If you have downstream code written against historical upstream APIs, natten-mps includes optional shims:
244
+
245
+ ```python
246
+ import natten_mps.compat.v014 as natten014
247
+ import natten_mps.compat.v017 as natten017
248
+ import natten_mps.compat.v020 as natten020
249
+ ```
250
+
251
+ These are best-effort drop-in replacements for common upstream `natten` entry points.
252
+
253
+ ---
254
+
255
+ ## Extras: model-specific fused kernels
256
+
257
+ Example: fused DiNAT-style ops with relative position bias:
258
+
259
+ ```python
260
+ from natten_mps.extras.allin1 import (
261
+ na1d_qk_rpb, na1d_av_fused,
262
+ na2d_qk_rpb, na2d_av_fused,
263
+ )
264
+ ```
265
+
266
+ ---
267
+
268
+ ## Limitations
269
+
270
+ - **Odd kernel sizes only** for accelerated Neighborhood Attention (this matches upstream NATTEN’s neighborhood half-width formulation).
271
+ - Metal kernel acceleration has size caps tuned for performance:
272
+   - 1D: K ≤ 63
272
+   - 2D: K ≤ 13
273
+   - 3D: K ≤ 7
275
+ - Unsupported kernel sizes or configurations automatically fall back to `pure`.
276
+ - **Supported dtypes:** Metal kernels run in float32 and float16. Bfloat16 inputs are accepted but upcast to float32 internally. Other dtypes fall back to `pure`.
277
+ - MPS acceleration is **macOS-only** (CPU fallback works anywhere PyTorch runs).
278
+
279
+ ---
280
+
281
+ ## Differences from upstream NATTEN (high level)
282
+
283
+ - Targets **Apple Silicon** (PyTorch **MPS** + CPU fallback); no CUDA backend
284
+ - Uses **Metal compute shaders** instead of CUDA kernels
285
+ - Includes Apple-Silicon-focused extras (and optional compatibility shims)
286
+
287
+ ---
288
+
289
+ ## Acknowledgments
290
+
291
+ This project implements Neighborhood Attention as introduced by the upstream [NATTEN](https://github.com/SHI-Labs/NATTEN) project (SHI-Labs). The original NATTEN library and research are by Ali Hassani, Steven Walton, Humphrey Shi, and collaborators.
292
+
293
+ If you use Neighborhood Attention in research, please cite the original papers:
294
+
295
+ - Hassani et al., **Neighborhood Attention Transformer** (CVPR 2023)
296
+ - Hassani & Shi, **Dilated Neighborhood Attention Transformer** (2022)
297
+ - Hassani et al., **Faster Neighborhood Attention** (NeurIPS 2024)
298
+
299
+ <details>
300
+ <summary>BibTeX</summary>
301
+
302
+ ```bibtex
303
+ @inproceedings{hassani2023neighborhood,
304
+ title = {Neighborhood Attention Transformer},
305
+ author = {Hassani, Ali and Walton, Steven and Li, Jiachen and Li, Shen and Shi, Humphrey},
306
+ booktitle = {CVPR},
307
+ year = {2023}
308
+ }
309
+
310
+ @article{hassani2022dilated,
311
+ title = {Dilated Neighborhood Attention Transformer},
312
+ author = {Hassani, Ali and Shi, Humphrey},
313
+ journal = {arXiv preprint arXiv:2209.15001},
314
+ year = {2022}
315
+ }
316
+
317
+ @inproceedings{hassani2024faster,
318
+ title = {Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level},
319
+ author = {Hassani, Ali and Hwu, Wen-mei and Shi, Humphrey},
320
+ booktitle = {NeurIPS},
321
+ year = {2024}
322
+ }
323
+ ```
324
+ </details>
325
+
326
+ ---
327
+
328
+ ## License
329
+
330
+ MIT — see [LICENSE](LICENSE) for details.
331
+ Upstream NATTEN is also MIT-licensed.
@@ -0,0 +1,301 @@
1
+ # natten-mps
2
+
3
+ GPU-accelerated Neighborhood Attention for Apple Silicon — built on **PyTorch MPS**.
4
+
5
+ > **Disclaimer (unofficial):** This is an independent, unofficial implementation/port for Apple Silicon.
6
+ > **Not affiliated with** SHI-Labs or the upstream [NATTEN](https://github.com/SHI-Labs/NATTEN) project.
7
+
8
+ This is a focused, Apple-Silicon-first implementation intended to be useful, correct, and easy to install — not a replacement for upstream NATTEN on CUDA.
9
+
10
+ Neighborhood Attention was introduced by the NATTEN authors. If you use Neighborhood Attention in research, please cite the original papers (see [Acknowledgments](#acknowledgments)).
11
+
12
+ > **v0.x** — API may change between minor versions. Pin your dependency for production use.
13
+
14
+ ---
15
+
16
+ ## Why this exists
17
+
18
+ Upstream NATTEN is CUDA-focused and targets NVIDIA GPUs. On Apple Silicon, PyTorch users often want a **GPU-accelerated** neighborhood attention option without requiring CUDA.
19
+
20
+ **natten-mps** provides:
21
+ - **Metal-backed kernels** for PyTorch MPS using `torch.mps.compile_shader`
22
+ - **1D / 2D / 3D** neighborhood attention with **full autograd support**
23
+ - A deployment story that is intentionally simple: **no native extension build step** — install from PyPI and go. Metal shaders are compiled at runtime via `torch.mps.compile_shader` and cached by PyTorch for the process (best effort).
24
+
25
+ For MLX-based workflows, see the sibling project: **[natten-mlx](https://github.com/ssmall256/natten-mlx)**.
26
+
27
+ **Jump to:** [Installation](#installation) | [Quick start](#quick-start) | [Features](#features) | [Backends](#backends) | [Performance](#performance) | [Limitations](#limitations) | [Acknowledgments](#acknowledgments)
28
+
29
+ ---
30
+
31
+ ## Use natten-mps if…
32
+
33
+ - You’re using **PyTorch**
34
+ - You run on **Apple Silicon** and want **MPS (Metal) acceleration**
35
+ - You want a drop-in-ish API (plus optional compatibility shims for historical NATTEN versions)
36
+
37
+ ---
38
+
39
+ ## Installation
40
+
41
+ ```bash
42
+ pip install natten-mps
43
+ ```
44
+
45
+ Requirements:
46
+ - Python 3.10+
47
+ - PyTorch 2.8+ with MPS support
48
+ - macOS 12.3+ for MPS (CPU fallback works anywhere PyTorch runs)
49
+
50
+ ---
51
+
52
+ ## Quick start
53
+
54
+ ### Functional API
55
+
56
+ ```python
57
+ import torch
58
+ from natten_mps import na1d, na2d, na3d
59
+
60
+ # 1D: [B, L, heads, head_dim]
61
+ q = torch.randn(2, 128, 4, 32, device="mps")
62
+ k = torch.randn(2, 128, 4, 32, device="mps")
63
+ v = torch.randn(2, 128, 4, 32, device="mps")
64
+ out = na1d(q, k, v, kernel_size=7)
65
+
66
+ # 2D: [B, H, W, heads, head_dim]
67
+ q2d = torch.randn(2, 32, 32, 4, 32, device="mps")
68
+ k2d = torch.randn(2, 32, 32, 4, 32, device="mps")
69
+ v2d = torch.randn(2, 32, 32, 4, 32, device="mps")
70
+ out2d = na2d(q2d, k2d, v2d, kernel_size=7)
71
+
72
+ # 3D: [B, D, H, W, heads, head_dim]
73
+ q3d = torch.randn(1, 8, 8, 8, 4, 32, device="mps")
74
+ k3d = torch.randn(1, 8, 8, 8, 4, 32, device="mps")
75
+ v3d = torch.randn(1, 8, 8, 8, 4, 32, device="mps")
76
+ out3d = na3d(q3d, k3d, v3d, kernel_size=3)
77
+ ```
78
+
79
+ ### Module API
80
+
81
+ ```python
82
+ import torch
83
+ from natten_mps import NeighborhoodAttention2D
84
+
85
+ layer = NeighborhoodAttention2D(embed_dim=128, num_heads=4, kernel_size=(7, 7)).to("mps")
86
+ x = torch.randn(2, 32, 32, 128, device="mps") # [B, H, W, C]
87
+ y = layer(x)
88
+ ```
89
+
90
+ ### Split QK / AV (access attention weights)
91
+
92
+ ```python
93
+ import torch
94
+ from natten_mps import na1d_qk, na1d_av
95
+
96
+ B, L, H, D = 2, 128, 4, 32
97
+ q = torch.randn(B, L, H, D, device="mps")
98
+ k = torch.randn(B, L, H, D, device="mps")
99
+ v = torch.randn(B, L, H, D, device="mps")
100
+
101
+ logits = na1d_qk(q, k, kernel_size=7, scale=D ** -0.5) # [B, L, H, K]
102
+ attn = torch.softmax(logits, dim=-1)
103
+ out = na1d_av(attn, v, kernel_size=7) # [B, L, H, D]
104
+ ```
105
+
106
+ ---
107
+
108
+ ## Features
109
+
110
+ Core:
111
+ - **1D / 2D / 3D** neighborhood attention (fused and split QK/AV ops)
112
+ - **Causal masking**, including per-axis control (e.g. `is_causal=(True, False)` for 2D)
113
+ - **Strided output** for downsampling (e.g. `stride=2`)
114
+ - **Combined causal + stride** in one kernel
115
+ - **Non-uniform kernels** for 2D/3D (per-axis kernel sizes and dilations)
116
+
117
+ Batching / advanced:
118
+ - **Variable-length (varlen) attention** — padded batches with per-sample spatial sizes, Metal-accelerated for all ranks
119
+ - **GQA / MQA** (`num_kv_heads`) for grouped-query attention patterns
120
+ - **additional_keys / additional_values** — prepend extra global tokens that every query attends to
121
+ - **merge_attentions** — numerically stable sigmoid-based merge of multiple attention outputs
122
+ - **FMHA fast path** — when the kernel covers the full spatial extent, can dispatch to efficient full attention
123
+
124
+ Extras:
125
+ - **`extras/`** namespace for model-specific fused kernels (e.g., DiNAT-style fused QK+RPB paths)
126
+
127
+ Compatibility:
128
+ - Optional **compat shims** for historical upstream API versions (see [Compatibility shims](#compatibility-shims))
129
+
130
+ ---
131
+
132
+ ## Backends
133
+
134
+ Backend dispatch is controlled at runtime and does not require a native extension.
135
+
136
+ | Backend | Status | Description |
137
+ |---|---|---|
138
+ | `pure` | Complete | Pure PyTorch fallback (CPU/MPS) |
139
+ | `metal` | Complete | Metal compute shaders via `torch.mps.compile_shader` |
140
+ | `auto` | Default | Select best available backend for the configuration |
141
+
142
+ ```python
143
+ import natten_mps
144
+
145
+ natten_mps.set_backend("metal") # "auto" (default), "metal", or "pure"
146
+ print(natten_mps.get_backend())
147
+ ```
148
+
149
+ Or via environment variable:
150
+
151
+ ```bash
152
+ NATTEN_BACKEND=metal python my_script.py # "auto" (default), "metal", or "pure"
153
+ ```
154
+
155
+ ---
156
+
157
+ ## Performance
158
+
159
+ Metal kernels vs pure-PyTorch backend on Apple Silicon (M-series), forward pass:
160
+
161
+ | Benchmark | Metal | Pure | Speedup |
162
+ |---|---:|---:|---:|
163
+ | 1D, L=256, K=7 | 0.9 ms | 9.8 ms | **11×** |
164
+ | 1D, L=1024, K=7 | 1.1 ms | 37 ms | **34×** |
165
+ | 2D, 32×32, K=7 | 1.3 ms | 20 ms | **15–17×** |
166
+ | 2D, 64×64, K=7 | 2.9 ms | 84 ms | **29×** |
167
+ | 2D, 32×32, K=7, causal | 1.1 ms | 21 ms | **19×** |
168
+ | 3D, 16³, K=3 | 1.7 ms | 12 ms | **7×** |
169
+
170
+ Run the full suite:
171
+ ```bash
172
+ python benchmarks/bench.py
173
+ # add --backward to time backward pass
174
+ ```
175
+
176
+ ### Cross-framework: natten-mps vs natten-mlx
177
+
178
+ Apple Silicon (M-series), fp32, B=1 H=4 D=32, Metal-accelerated:
179
+
180
+ | Config | natten-mps fwd | natten-mlx fwd | natten-mps bwd | natten-mlx bwd |
181
+ |---|---:|---:|---:|---:|
182
+ | 1D L=256 K=7 | 0.25 ms | 0.21 ms | 0.39 ms | 0.14 ms |
183
+ | 1D L=1024 K=7 | 0.40 ms | 0.27 ms | 0.63 ms | 0.26 ms |
184
+ | 2D 32×32 K=7 | 0.88 ms | 0.65 ms | 1.62 ms | 1.02 ms |
185
+ | 2D 64×64 K=7 | 1.32 ms | 1.13 ms | 1.55 ms | 0.97 ms |
186
+ | 2D 32×32 K=7 causal | 0.37 ms | 0.29 ms | 0.49 ms | 0.31 ms |
187
+ | 3D 16³ K=3 | 0.55 ms | 0.43 ms | 0.89 ms | 0.50 ms |
188
+
189
+ MLX’s compiled primitives tend to have lower dispatch overhead than PyTorch MPS, so natten-mlx is often faster for the same shapes. Both are dramatically faster than pure-framework baselines.
190
+
191
+ ### Variable-length (varlen) attention
192
+
193
+ Metal-accelerated varlen forward, fp32:
194
+
195
+ | Config | natten-mps | natten-mlx | MLX speedup |
196
+ |---|---:|---:|---:|
197
+ | varlen 1D B=4 L=128 K=7 | 1.74 ms | 0.53 ms | 3.3× |
198
+ | varlen 1D B=4 L=256 K=7 | 1.74 ms | 0.51 ms | 3.4× |
199
+ | varlen 2D B=2 16×16 K=3 | 2.39 ms | 0.82 ms | 2.9× |
200
+ | varlen 2D B=2 32×32 K=7 | 3.79 ms | 1.23 ms | 3.1× |
201
+ | varlen 3D B=2 8³ K=3 | 3.82 ms | 1.55 ms | 2.5× |
202
+
203
+ The varlen backward pass is computed per sample by re-differentiating through the standard Metal-accelerated `na*d` kernels with autograd.
204
+
205
+ ### Methodology
206
+
207
+ All timings on **Apple M4 Max**, macOS 26.3, Python 3.11, PyTorch 2.10, float32. Each kernel is warmed up for 5 iterations, then timed for 20 repetitions with `torch.mps.synchronize()` gating; the reported value is the **median**. Reproduce with `python benchmarks/bench.py`.
208
+
209
+ ---
210
+
211
+ ## Compatibility shims
212
+
213
+ If you have downstream code written against historical upstream APIs, natten-mps includes optional shims:
214
+
215
+ ```python
216
+ import natten_mps.compat.v014 as natten014
217
+ import natten_mps.compat.v017 as natten017
218
+ import natten_mps.compat.v020 as natten020
219
+ ```
220
+
221
+ These are best-effort drop-in replacements for common upstream `natten` entry points.
222
+
223
+ ---
224
+
225
+ ## Extras: model-specific fused kernels
226
+
227
+ Example: fused DiNAT-style ops with relative position bias:
228
+
229
+ ```python
230
+ from natten_mps.extras.allin1 import (
231
+ na1d_qk_rpb, na1d_av_fused,
232
+ na2d_qk_rpb, na2d_av_fused,
233
+ )
234
+ ```
235
+
236
+ ---
237
+
238
+ ## Limitations
239
+
240
+ - **Odd kernel sizes only** for accelerated Neighborhood Attention (this matches upstream NATTEN’s neighborhood half-width formulation).
241
+ - Metal kernel acceleration has size caps tuned for performance:
242
+   - 1D: K ≤ 63
242
+   - 2D: K ≤ 13
243
+   - 3D: K ≤ 7
245
+ - Unsupported kernel sizes or configurations automatically fall back to `pure`.
246
+ - **Supported dtypes:** Metal kernels run in float32 and float16. Bfloat16 inputs are accepted but upcast to float32 internally. Other dtypes fall back to `pure`.
247
+ - MPS acceleration is **macOS-only** (CPU fallback works anywhere PyTorch runs).
248
+
249
+ ---
250
+
251
+ ## Differences from upstream NATTEN (high level)
252
+
253
+ - Targets **Apple Silicon** (PyTorch **MPS** + CPU fallback); no CUDA backend
254
+ - Uses **Metal compute shaders** instead of CUDA kernels
255
+ - Includes Apple-Silicon-focused extras (and optional compatibility shims)
256
+
257
+ ---
258
+
259
+ ## Acknowledgments
260
+
261
+ This project implements Neighborhood Attention as introduced by the upstream [NATTEN](https://github.com/SHI-Labs/NATTEN) project (SHI-Labs). The original NATTEN library and research are by Ali Hassani, Steven Walton, Humphrey Shi, and collaborators.
262
+
263
+ If you use Neighborhood Attention in research, please cite the original papers:
264
+
265
+ - Hassani et al., **Neighborhood Attention Transformer** (CVPR 2023)
266
+ - Hassani & Shi, **Dilated Neighborhood Attention Transformer** (2022)
267
+ - Hassani et al., **Faster Neighborhood Attention** (NeurIPS 2024)
268
+
269
+ <details>
270
+ <summary>BibTeX</summary>
271
+
272
+ ```bibtex
273
+ @inproceedings{hassani2023neighborhood,
274
+ title = {Neighborhood Attention Transformer},
275
+ author = {Hassani, Ali and Walton, Steven and Li, Jiachen and Li, Shen and Shi, Humphrey},
276
+ booktitle = {CVPR},
277
+ year = {2023}
278
+ }
279
+
280
+ @article{hassani2022dilated,
281
+ title = {Dilated Neighborhood Attention Transformer},
282
+ author = {Hassani, Ali and Shi, Humphrey},
283
+ journal = {arXiv preprint arXiv:2209.15001},
284
+ year = {2022}
285
+ }
286
+
287
+ @inproceedings{hassani2024faster,
288
+ title = {Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level},
289
+ author = {Hassani, Ali and Hwu, Wen-mei and Shi, Humphrey},
290
+ booktitle = {NeurIPS},
291
+ year = {2024}
292
+ }
293
+ ```
294
+ </details>
295
+
296
+ ---
297
+
298
+ ## License
299
+
300
+ MIT — see [LICENSE](LICENSE) for details.
301
+ Upstream NATTEN is also MIT-licensed.
@@ -0,0 +1,49 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "natten-mps"
7
+ dynamic = ["version"]
8
+ description = "Neighborhood Attention for Apple Silicon — PyTorch MPS backend"
9
+ readme = "README.md"
10
+ license = {text = "MIT"}
11
+ requires-python = ">=3.10"
12
+ authors = [
13
+ {name = "ssmall256"},
14
+ ]
15
+ dependencies = [
16
+ "torch>=2.8.0",
17
+ "numpy",
18
+ "packaging",
19
+ ]
20
+ classifiers = [
21
+ "Development Status :: 4 - Beta",
22
+ "Intended Audience :: Developers",
23
+ "Intended Audience :: Science/Research",
24
+ "License :: OSI Approved :: MIT License",
25
+ "Operating System :: MacOS",
26
+ "Programming Language :: Python :: 3",
27
+ "Programming Language :: Python :: 3.10",
28
+ "Programming Language :: Python :: 3.11",
29
+ "Programming Language :: Python :: 3.12",
30
+ "Programming Language :: Python :: 3.13",
31
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
32
+ ]
33
+
34
+ [project.urls]
35
+ Repository = "https://github.com/ssmall256/natten-mps"
36
+ Issues = "https://github.com/ssmall256/natten-mps/issues"
37
+
38
+ [project.optional-dependencies]
39
+ dev = ["pytest", "pytest-benchmark"]
40
+
41
+ [tool.setuptools.dynamic]
42
+ version = {attr = "natten_mps.version.__version__"}
43
+
44
+ [tool.setuptools.packages.find]
45
+ where = ["src"]
46
+
47
+ [tool.pytest.ini_options]
48
+ pythonpath = ["src"]
49
+ testpaths = ["tests"]