hiera-optim 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hiera_optim-0.1.0/LICENSE +21 -0
- hiera_optim-0.1.0/PKG-INFO +135 -0
- hiera_optim-0.1.0/README.md +101 -0
- hiera_optim-0.1.0/hiera_optim/__init__.py +44 -0
- hiera_optim-0.1.0/hiera_optim/adapters/__init__.py +28 -0
- hiera_optim-0.1.0/hiera_optim/adapters/hiera.py +142 -0
- hiera_optim-0.1.0/hiera_optim/attention/__init__.py +6 -0
- hiera_optim-0.1.0/hiera_optim/attention/mask_unit.py +116 -0
- hiera_optim-0.1.0/hiera_optim/checkpoint.py +113 -0
- hiera_optim-0.1.0/hiera_optim/kernels/__init__.py +22 -0
- hiera_optim-0.1.0/hiera_optim/kernels/flash_qpool.py +220 -0
- hiera_optim-0.1.0/hiera_optim/kernels/mask_gather.py +148 -0
- hiera_optim-0.1.0/hiera_optim/ops/__init__.py +18 -0
- hiera_optim-0.1.0/hiera_optim/ops/mask_gather.py +112 -0
- hiera_optim-0.1.0/hiera_optim/patch.py +321 -0
- hiera_optim-0.1.0/hiera_optim.egg-info/PKG-INFO +135 -0
- hiera_optim-0.1.0/hiera_optim.egg-info/SOURCES.txt +27 -0
- hiera_optim-0.1.0/hiera_optim.egg-info/dependency_links.txt +1 -0
- hiera_optim-0.1.0/hiera_optim.egg-info/requires.txt +14 -0
- hiera_optim-0.1.0/hiera_optim.egg-info/top_level.txt +1 -0
- hiera_optim-0.1.0/pyproject.toml +53 -0
- hiera_optim-0.1.0/setup.cfg +4 -0
- hiera_optim-0.1.0/tests/test_e2e_equivalence.py +111 -0
- hiera_optim-0.1.0/tests/test_flash_qpool.py +122 -0
- hiera_optim-0.1.0/tests/test_mask_gather.py +159 -0
- hiera_optim-0.1.0/tests/test_mask_unit_attention.py +97 -0
- hiera_optim-0.1.0/tests/test_matrix.py +235 -0
- hiera_optim-0.1.0/tests/test_sdpa_backend.py +90 -0
- hiera_optim-0.1.0/tests/test_selective_checkpoint.py +153 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Maxi Kalcher
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: hiera-optim
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Drop-in throughput and memory optimisations for FAIR Hiera (4D-SDPA, gather/scatter, Triton kernels).
|
|
5
|
+
Author: Maxi Kalcher
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/avocardio/hiera-optim
|
|
8
|
+
Project-URL: Repository, https://github.com/avocardio/hiera-optim
|
|
9
|
+
Project-URL: Issues, https://github.com/avocardio/hiera-optim/issues
|
|
10
|
+
Keywords: pytorch,transformer,vision,hiera,mae,flash-attention,triton,hopper,h100,gh200
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
19
|
+
Requires-Python: >=3.10
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
License-File: LICENSE
|
|
22
|
+
Requires-Dist: torch>=2.5.0
|
|
23
|
+
Requires-Dist: triton>=2.3.0
|
|
24
|
+
Provides-Extra: hiera
|
|
25
|
+
Requires-Dist: hiera-transformer>=0.1.4; extra == "hiera"
|
|
26
|
+
Provides-Extra: test
|
|
27
|
+
Requires-Dist: pytest>=7.0; extra == "test"
|
|
28
|
+
Provides-Extra: dev
|
|
29
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
30
|
+
Requires-Dist: ruff; extra == "dev"
|
|
31
|
+
Requires-Dist: build; extra == "dev"
|
|
32
|
+
Requires-Dist: twine; extra == "dev"
|
|
33
|
+
Dynamic: license-file
|
|
34
|
+
|
|
35
|
+
# hiera-optim
|
|
36
|
+
|
|
37
|
+
Drop-in throughput and memory optimisations for [FAIR's Hiera](https://github.com/facebookresearch/hiera) and its MAE variant. Two lines:
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
from hiera_optim import optimize
|
|
41
|
+
optimize(model)
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
These two lines restore FlashAttention / cuDNN attention where the model silently fell back to the math SDPA kernel, replace boolean mask indexing with `torch.gather` / `scatter_`, and unblock `torch.compile`. Results are numerically equivalent within bf16 noise.
|
|
45
|
+
|
|
46
|
+
## Results
|
|
47
|
+
|
|
48
|
+
H100 (GH200), bf16, full forward + backward.
|
|
49
|
+
|
|
50
|
+
### Production config: Hiera-Base, 224x224, 8 in-chans, B=128
|
|
51
|
+
|
|
52
|
+
| | ms / step | samples / s | peak mem |
|
|
53
|
+
|---|---|---|---|
|
|
54
|
+
| FAIR baseline + `torch.compile` | 131.7 | 972 | 14.0 GB |
|
|
55
|
+
| **hiera-optim + `torch.compile`** | **70.3** | **1820** | **9.4 GB** |
|
|
56
|
+
| speedup / saving | 1.88x | 1.87x | 33% |
|
|
57
|
+
|
|
58
|
+
### Across the variant matrix (444 GH200 cells)
|
|
59
|
+
|
|
60
|
+
| | median | mean | best | worst |
|
|
61
|
+
|---|---|---|---|---|
|
|
62
|
+
| speedup | 1.35x | 1.42x | 2.10x | 1.10x |
|
|
63
|
+
| memory ratio | 74% | 73% | 29% | 99% |
|
|
64
|
+
|
|
65
|
+
RTX 4090, Hiera-Base, 8 in-chans, B=32: 1.81x eager, **2.86x with `torch.compile`**.
|
|
66
|
+
|
|
67
|
+
Full matrix and per-cell numbers: [`MATRIX_RESULTS.md`](MATRIX_RESULTS.md).
|
|
68
|
+
|
|
69
|
+
## Install
|
|
70
|
+
|
|
71
|
+
```bash
|
|
72
|
+
pip install hiera-optim
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
From source:
|
|
76
|
+
|
|
77
|
+
```bash
|
|
78
|
+
git clone https://github.com/avocardio/hiera-optim.git
|
|
79
|
+
cd hiera-optim
|
|
80
|
+
pip install -e .
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
Requires PyTorch >= 2.5 and Triton >= 2.3. Recognises FAIR Hiera in-tree (`models.hiera`) or via PyPI (`hiera-transformer`).
|
|
84
|
+
|
|
85
|
+
## Usage
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
import torch
|
|
89
|
+
from hiera_optim import optimize
|
|
90
|
+
from hiera import mae_hiera_base_224
|
|
91
|
+
|
|
92
|
+
model = mae_hiera_base_224(pretrained=False, in_chans=3, input_size=(224, 224))
|
|
93
|
+
optimize(model)
|
|
94
|
+
model = torch.compile(model, mode="default", dynamic=False)
|
|
95
|
+
|
|
96
|
+
x = torch.randn(128, 3, 224, 224, device="cuda", dtype=torch.bfloat16)
|
|
97
|
+
loss, *_ = model(x, mask_ratio=0.6)
|
|
98
|
+
loss.backward()
|
|
99
|
+
```
|
|
100
|
+
|
|
101
|
+
`optimize(model)` makes two changes to the model in place, keeping its weights:
|
|
102
|
+
|
|
103
|
+
1. Swap every `MaskUnitAttention` for a 4D-reshape variant so PyTorch SDPA dispatches to FlashAttention / cuDNN-attn / mem-efficient instead of math. FAIR's original feeds SDPA a 5-D tensor, which the fused kernels reject; the resulting math fallback is roughly 13x slower per call on Ada and 6x slower on Hopper.
|
|
104
|
+
2. Swap `x[mask.tile(...)]` and `x_dec[mask] = ...` for explicit `torch.gather` / `scatter_`. Removes a slow `indexing_backward_kernel` and the `aten::nonzero` graph break that stops `torch.compile`.
|
|
105
|
+
|
|
106
|
+
## Optional
|
|
107
|
+
|
|
108
|
+
```python
|
|
109
|
+
from hiera_optim import optimize, enable_stage_checkpointing
|
|
110
|
+
|
|
111
|
+
optimize(model, sdpa_backend="auto") # per-block SDPA hint
|
|
112
|
+
enable_stage_checkpointing(model, stages=(2,)) # OOM lever
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
## GPU support
|
|
116
|
+
|
|
117
|
+
| Architecture | SM | Status |
|
|
118
|
+
|---|---|---|
|
|
119
|
+
| Ada (RTX 4090, L40) | SM89 | Tested |
|
|
120
|
+
| Hopper (H100, GH200) | SM90 | Tested |
|
|
121
|
+
| Ampere (A100) | SM80 | Should work |
|
|
122
|
+
| Blackwell (B200) | SM100 | Should work |
|
|
123
|
+
|
|
124
|
+
## Tests
|
|
125
|
+
|
|
126
|
+
```bash
|
|
127
|
+
pip install -e .[test]
|
|
128
|
+
pytest
|
|
129
|
+
```
|
|
130
|
+
|
|
131
|
+
112 tests cover all 5 Hiera variants x q_pool {1, 2, 3} x mask ratios x bf16/fp16/fp32 x 1D/2D/3D inputs x classification + MAE.
|
|
132
|
+
|
|
133
|
+
## License
|
|
134
|
+
|
|
135
|
+
MIT.
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# hiera-optim
|
|
2
|
+
|
|
3
|
+
Drop-in throughput and memory optimisations for [FAIR's Hiera](https://github.com/facebookresearch/hiera) and its MAE variant. Two lines:
|
|
4
|
+
|
|
5
|
+
```python
|
|
6
|
+
from hiera_optim import optimize
|
|
7
|
+
optimize(model)
|
|
8
|
+
```
|
|
9
|
+
|
|
10
|
+
These two lines restore FlashAttention / cuDNN attention where the model silently fell back to the math SDPA kernel, replace boolean mask indexing with `torch.gather` / `scatter_`, and unblock `torch.compile`. Results are numerically equivalent within bf16 noise.
|
|
11
|
+
|
|
12
|
+
## Results
|
|
13
|
+
|
|
14
|
+
H100 (GH200), bf16, full forward + backward.
|
|
15
|
+
|
|
16
|
+
### Production config: Hiera-Base, 224x224, 8 in-chans, B=128
|
|
17
|
+
|
|
18
|
+
| | ms / step | samples / s | peak mem |
|
|
19
|
+
|---|---|---|---|
|
|
20
|
+
| FAIR baseline + `torch.compile` | 131.7 | 972 | 14.0 GB |
|
|
21
|
+
| **hiera-optim + `torch.compile`** | **70.3** | **1820** | **9.4 GB** |
|
|
22
|
+
| speedup / saving | 1.88x | 1.87x | 33% |
|
|
23
|
+
|
|
24
|
+
### Across the variant matrix (444 GH200 cells)
|
|
25
|
+
|
|
26
|
+
| | median | mean | best | worst |
|
|
27
|
+
|---|---|---|---|---|
|
|
28
|
+
| speedup | 1.35x | 1.42x | 2.10x | 1.10x |
|
|
29
|
+
| memory ratio | 74% | 73% | 29% | 99% |
|
|
30
|
+
|
|
31
|
+
RTX 4090, Hiera-Base, 8 in-chans, B=32: 1.81x eager, **2.86x with `torch.compile`**.
|
|
32
|
+
|
|
33
|
+
Full matrix and per-cell numbers: [`MATRIX_RESULTS.md`](MATRIX_RESULTS.md).
|
|
34
|
+
|
|
35
|
+
## Install
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
pip install hiera-optim
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
From source:
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
git clone https://github.com/avocardio/hiera-optim.git
|
|
45
|
+
cd hiera-optim
|
|
46
|
+
pip install -e .
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
Requires PyTorch >= 2.5 and Triton >= 2.3. Recognises FAIR Hiera in-tree (`models.hiera`) or via PyPI (`hiera-transformer`).
|
|
50
|
+
|
|
51
|
+
## Usage
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
import torch
|
|
55
|
+
from hiera_optim import optimize
|
|
56
|
+
from hiera import mae_hiera_base_224
|
|
57
|
+
|
|
58
|
+
model = mae_hiera_base_224(pretrained=False, in_chans=3, input_size=(224, 224))
|
|
59
|
+
optimize(model)
|
|
60
|
+
model = torch.compile(model, mode="default", dynamic=False)
|
|
61
|
+
|
|
62
|
+
x = torch.randn(128, 3, 224, 224, device="cuda", dtype=torch.bfloat16)
|
|
63
|
+
loss, *_ = model(x, mask_ratio=0.6)
|
|
64
|
+
loss.backward()
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
`optimize(model)` makes two changes to the model in place, keeping its weights:
|
|
68
|
+
|
|
69
|
+
1. Swap every `MaskUnitAttention` for a 4D-reshape variant so PyTorch SDPA dispatches to FlashAttention / cuDNN-attn / mem-efficient instead of math. FAIR's original feeds SDPA a 5-D tensor, which the fused kernels reject; the resulting math fallback is roughly 13x slower per call on Ada and 6x slower on Hopper.
|
|
70
|
+
2. Swap `x[mask.tile(...)]` and `x_dec[mask] = ...` for explicit `torch.gather` / `scatter_`. Removes a slow `indexing_backward_kernel` and the `aten::nonzero` graph break that stops `torch.compile`.
|
|
71
|
+
|
|
72
|
+
## Optional
|
|
73
|
+
|
|
74
|
+
```python
|
|
75
|
+
from hiera_optim import optimize, enable_stage_checkpointing
|
|
76
|
+
|
|
77
|
+
optimize(model, sdpa_backend="auto") # per-block SDPA hint
|
|
78
|
+
enable_stage_checkpointing(model, stages=(2,)) # OOM lever
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
## GPU support
|
|
82
|
+
|
|
83
|
+
| Architecture | SM | Status |
|
|
84
|
+
|---|---|---|
|
|
85
|
+
| Ada (RTX 4090, L40) | SM89 | Tested |
|
|
86
|
+
| Hopper (H100, GH200) | SM90 | Tested |
|
|
87
|
+
| Ampere (A100) | SM80 | Should work |
|
|
88
|
+
| Blackwell (B200) | SM100 | Should work |
|
|
89
|
+
|
|
90
|
+
## Tests
|
|
91
|
+
|
|
92
|
+
```bash
|
|
93
|
+
pip install -e .[test]
|
|
94
|
+
pytest
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
112 tests cover all 5 Hiera variants x q_pool {1, 2, 3} x mask ratios x bf16/fp16/fp32 x 1D/2D/3D inputs x classification + MAE.
|
|
98
|
+
|
|
99
|
+
## License
|
|
100
|
+
|
|
101
|
+
MIT.
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""hiera_optim — training-throughput optimisations for FAIR's Hiera.
|
|
2
|
+
|
|
3
|
+
Quick start::
|
|
4
|
+
|
|
5
|
+
from models.hiera import mae_hiera_base_224 # FAIR Hiera
|
|
6
|
+
from hiera_optim import optimize
|
|
7
|
+
|
|
8
|
+
model = mae_hiera_base_224(pretrained=False, in_chans=8)
|
|
9
|
+
optimize(model) # in-place
|
|
10
|
+
# optionally: model = torch.compile(model, mode="default", dynamic=False)
|
|
11
|
+
|
|
12
|
+
What `optimize()` does:
|
|
13
|
+
1. Swaps every `MaskUnitAttention` for a FlashAttention/cuDNN-friendly
|
|
14
|
+
4-D variant (`MaskUnitAttentionFast`). Restores math-fallback SDPA to
|
|
15
|
+
fused kernel paths — 5-12× per-call attention speedup.
|
|
16
|
+
2. Replaces the boolean `x[mask.tile(...)]` and `x_dec[mask] = ...`
|
|
17
|
+
indexing patterns with explicit `torch.gather` / `scatter_`. Removes
|
|
18
|
+
the `aten::nonzero` graph break (compile-friendly).
|
|
19
|
+
|
|
20
|
+
Optional add-ons (opt-in, not invoked by default):
|
|
21
|
+
- `optimize(model, sdpa_backend="auto" | "cudnn" | ...)`: pin the SDPA
|
|
22
|
+
backend per-block. Sometimes helps, sometimes hurts — see RESULTS.md.
|
|
23
|
+
- `enable_stage_checkpointing(model, stages=(2,))`: trade compute for
|
|
24
|
+
activation memory at chosen stages. OOM lever, not a throughput tool.
|
|
25
|
+
"""
|
|
26
|
+
from .patch import optimize, swap_mask_unit_attention, recommended_backend
|
|
27
|
+
from .attention import MaskUnitAttentionFast, BACKEND_NAMES
|
|
28
|
+
from .checkpoint import enable_stage_checkpointing, disable_stage_checkpointing
|
|
29
|
+
from .adapters import HieraAdapter, get_hiera_adapter, auto_detect
|
|
30
|
+
|
|
31
|
+
__version__ = "0.1.0"

# Public API: every name below is re-exported from the submodule imports above.
__all__ = [
    # .patch: main entry point and helpers
    "optimize", "swap_mask_unit_attention", "recommended_backend",
    # .attention: the 4-D SDPA attention module and valid backend names
    "MaskUnitAttentionFast", "BACKEND_NAMES",
    # .checkpoint: opt-in activation checkpointing
    "enable_stage_checkpointing", "disable_stage_checkpointing",
    # .adapters: model-family adapters
    "HieraAdapter", "get_hiera_adapter", "auto_detect",
]
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Model adapters.
|
|
2
|
+
|
|
3
|
+
The rest of the package is intentionally model-name-free: it operates on
|
|
4
|
+
PyTorch `nn.Module` graphs and looks up attention/block classes through
|
|
5
|
+
adapters. Each adapter teaches `optimize()` how to find the right submodules
|
|
6
|
+
on a specific model family.
|
|
7
|
+
|
|
8
|
+
Currently bundled:
|
|
9
|
+
- hiera: FAIR's Hiera / MaskedAutoencoderHiera (https://github.com/facebookresearch/hiera)
|
|
10
|
+
|
|
11
|
+
Other adapters (Swin, MViTv2, JEPA-Hiera, custom architectures) can plug in by
|
|
12
|
+
implementing the same `ModelAdapter` protocol.
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
from typing import Optional
|
|
16
|
+
from .hiera import HieraAdapter, get_hiera_adapter
|
|
17
|
+
|
|
18
|
+
__all__ = ["HieraAdapter", "get_hiera_adapter", "auto_detect"]


def auto_detect(model) -> Optional[HieraAdapter]:
    """Best-effort: return the right bundled adapter for ``model``.

    Each bundled adapter lookup is tried in order, and the first non-None
    result wins.

    Args:
        model: the model to inspect.

    Returns:
        The matching adapter, or None if no bundled adapter matches.
    """
    # One entry per bundled adapter family. Append new lookups here as
    # adapters are added.
    for lookup in (get_hiera_adapter,):
        adapter = lookup(model)
        if adapter is not None:
            return adapter
    return None
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""Adapter for FAIR Hiera (https://github.com/facebookresearch/hiera).
|
|
2
|
+
|
|
3
|
+
This is the ONE module in `hiera_optim` allowed to import FAIR classes by name.
|
|
4
|
+
Everything else works through this adapter's protocol so the package is
|
|
5
|
+
trivially portable to derivative architectures.
|
|
6
|
+
|
|
7
|
+
The adapter resolves at import time and degrades gracefully if FAIR Hiera
|
|
8
|
+
isn't installed — `get_hiera_adapter(model)` returns None instead of crashing.
|
|
9
|
+
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
import importlib
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from typing import Any, Callable, Optional, Type
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
class _HieraSymbols:
    """Resolved references to the FAIR Hiera classes and helpers we touch.

    Instances are immutable (``frozen=True``). ``_resolve()`` builds one from
    whichever import layout is available.
    """

    Hiera: Type[nn.Module]
    MaskedAutoencoderHiera: Type[nn.Module]
    HieraBlock: Type[nn.Module]
    MaskUnitAttention: Type[nn.Module]
    # Invoked as apply_fusion_head(head, x) by HieraAdapter.apply_fusion_head.
    apply_fusion_head: Callable[..., Any]
    # Invoked as undo_windowing(x, shape, mu_shape) by HieraAdapter.undo_windowing.
    undo_windowing: Callable[..., Any]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _lookup(modules, name: str):
    """Return attribute ``name`` from the first module in ``modules`` defining it.

    Raises:
        AttributeError: if none of ``modules`` defines ``name``.
    """
    missing = object()
    for module in modules:
        value = getattr(module, name, missing)
        if value is not missing:
            return value
    raise AttributeError(
        f"none of {[m.__name__ for m in modules]} defines {name!r}"
    )


def _resolve() -> Optional[_HieraSymbols]:
    """Try a few import layouts. Returns None if Hiera isn't available.

    Each candidate pairs a tuple of model modules, searched in order for every
    class/function we need, with the module that provides ``undo_windowing``.
    The first candidate whose modules all import and which supplies every
    symbol wins. Candidates failing with ImportError/AttributeError are skipped.

    The last candidate covers the split layout of the upstream
    ``hiera-transformer`` package. There, ``MaskedAutoencoderHiera`` and
    ``apply_fusion_head`` live in ``hiera.hiera_mae`` rather than
    ``hiera.hiera``, and ``hiera/__init__.py`` does not re-export
    ``apply_fusion_head``. As a result, the single-module candidates cannot
    resolve it.
    """
    candidates = [
        # In-tree (brain_atlas, the development environment)
        (("models.hiera",), "utils.hiera_utils"),
        # PyPI package (`pip install hiera-transformer`), single-module layouts
        (("hiera.hiera",), "hiera.hiera_utils"),
        (("hiera",), "hiera.hiera_utils"),
        # PyPI package, split layout: encoder in hiera.hiera, MAE in hiera.hiera_mae
        (("hiera.hiera", "hiera.hiera_mae"), "hiera.hiera_utils"),
    ]
    for model_mod_names, util_mod_name in candidates:
        try:
            model_mods = [importlib.import_module(name) for name in model_mod_names]
            util_mod = importlib.import_module(util_mod_name)
            return _HieraSymbols(
                Hiera=_lookup(model_mods, "Hiera"),
                MaskedAutoencoderHiera=_lookup(model_mods, "MaskedAutoencoderHiera"),
                HieraBlock=_lookup(model_mods, "HieraBlock"),
                MaskUnitAttention=_lookup(model_mods, "MaskUnitAttention"),
                apply_fusion_head=_lookup(model_mods, "apply_fusion_head"),
                undo_windowing=util_mod.undo_windowing,
            )
        except (ImportError, AttributeError):
            continue
    return None


# Resolved once at import time; None when no candidate layout is importable.
_SYMS = _resolve()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def is_available() -> bool:
    """Report whether FAIR Hiera was resolved when this module was imported."""
    resolved = _SYMS
    return resolved is not None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def symbols() -> _HieraSymbols:
    """Return the FAIR symbols resolved at import time.

    Raises:
        ImportError: if FAIR Hiera could not be imported.
    """
    if _SYMS is not None:
        return _SYMS
    raise ImportError(
        "FAIR Hiera not installed. Add `pip install hiera-transformer`, "
        "or ensure `models.hiera` is importable in this project."
    )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class HieraAdapter:
    """Introspection / patching bridge for FAIR Hiera models.

    Connects FAIR's class hierarchy to the model-agnostic patching code.
    Every FAIR reference comes from the symbols resolved when this module was
    imported. The methods themselves only inspect models or forward calls.

    Raises:
        ImportError: on construction, if FAIR Hiera is not importable.
    """

    # ---- Layout convention ------------------------------------------------
    #
    # FAIR's `Unroll` yields a (T outer, nw inner) token layout: token n on the
    # N axis maps to (t = n // nw, w = n % nw). FAIR's `do_pool(x, stride)`
    # (which treats stride as the OUTER dim) and the MaskUnitAttention reshape
    # pattern both rely on it.
    layout = "T_outer_nw_inner"

    def __init__(self):
        if is_available():
            self._syms = symbols()
            return
        raise ImportError(
            "HieraAdapter requires FAIR Hiera to be importable."
        )

    # ---- Introspection -----------------------------------------------------

    def matches(self, model: nn.Module) -> bool:
        """Return True if ``model`` is a Hiera or MaskedAutoencoderHiera instance."""
        syms = self._syms
        family = (syms.Hiera, syms.MaskedAutoencoderHiera)
        return isinstance(model, family)

    def is_mae(self, model: nn.Module) -> bool:
        """Return True if ``model`` is a MaskedAutoencoderHiera instance."""
        mae_cls = self._syms.MaskedAutoencoderHiera
        return isinstance(model, mae_cls)

    def block_class(self) -> Type[nn.Module]:
        """Return FAIR's ``HieraBlock`` class."""
        block_cls = self._syms.HieraBlock
        return block_cls

    def attention_class(self) -> Type[nn.Module]:
        """Return FAIR's ``MaskUnitAttention`` class."""
        attn_cls = self._syms.MaskUnitAttention
        return attn_cls

    def encoder_blocks(self, model: nn.Module) -> nn.ModuleList:
        """Return the encoder block list (``model.blocks``)."""
        blocks = model.blocks
        return blocks

    def decoder_blocks(self, model: nn.Module) -> Optional[nn.ModuleList]:
        """Return ``model.decoder_blocks`` if present, else None (non-MAE models)."""
        return getattr(model, "decoder_blocks", None)

    def stage_ends(self, model: nn.Module) -> list[int]:
        """Return ``model.stage_ends`` copied into a new list."""
        return [*model.stage_ends]

    # ---- Helpers used by the patched forwards -----------------------------

    def apply_fusion_head(self, head: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Forward to FAIR's ``apply_fusion_head(head, x)``."""
        fuse = self._syms.apply_fusion_head
        return fuse(head, x)

    def undo_windowing(self, x, shape, mu_shape):
        """Forward to FAIR's ``undo_windowing(x, shape, mu_shape)``."""
        unwindow = self._syms.undo_windowing
        return unwindow(x, shape, mu_shape)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# Lazily-created adapter shared by every `get_hiera_adapter` call. Once it
# exists, matching a model is just an isinstance check.
_DEFAULT_ADAPTER: Optional[HieraAdapter] = None


def get_hiera_adapter(model: nn.Module) -> Optional[HieraAdapter]:
    """Return the shared HieraAdapter if ``model`` is a Hiera-family model.

    The adapter is built on first use. Returns None when FAIR Hiera is not
    importable, or when ``model`` is not a Hiera / MaskedAutoencoderHiera.
    """
    global _DEFAULT_ADAPTER
    if not is_available():
        return None
    adapter = _DEFAULT_ADAPTER
    if adapter is None:
        adapter = _DEFAULT_ADAPTER = HieraAdapter()
    if adapter.matches(model):
        return adapter
    return None
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
"""Attention modules. Currently exports MaskUnitAttentionFast — the
|
|
2
|
+
FlashAttention/cuDNN-friendly 4-D variant of FAIR's MaskUnitAttention.
|
|
3
|
+
"""
|
|
4
|
+
from .mask_unit import MaskUnitAttentionFast, copy_weights_from_orig, BACKEND_NAMES
|
|
5
|
+
|
|
6
|
+
# Public names of the attention subpackage (re-exported from .mask_unit).
__all__ = [
    "MaskUnitAttentionFast",
    "copy_weights_from_orig",
    "BACKEND_NAMES",
]
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Optimized MaskUnitAttention — drop-in replacement.
|
|
2
|
+
|
|
3
|
+
Key fixes vs FAIR original (models/hiera.py):
|
|
4
|
+
- Reshape Q/K/V to 4D (B*num_windows, heads, T, D) so SDPA dispatches to
|
|
5
|
+
FlashAttention (cuDNN/Flash). Original feeds 5D tensors which fall back
|
|
6
|
+
to the math backend (12-13x slower on stage-0 shapes per microbench).
|
|
7
|
+
- Match FAIR's N-axis layout exactly: token n in input has n = t*nw + w,
|
|
8
|
+
so within-window positions are SLOW-varying and num_windows is FAST.
|
|
9
|
+
That's because the upstream Unroll module produces this interleaved
|
|
10
|
+
layout (and do_pool relies on the stride axis being outer).
|
|
11
|
+
- Skip per-tensor .contiguous() — the permute lands flash-friendly.
|
|
12
|
+
- Optional per-stage SDPA backend hint via `sdpa_backend` attribute. Set to
|
|
13
|
+
one of {"cudnn", "flash", "mem_efficient", "math", None}; None lets the
|
|
14
|
+
PyTorch dispatcher pick. Per-stage tuning is useful because Hiera's small
|
|
15
|
+
Tq stages (stage-1 post q_pool: T=16) often favor mem-efficient while the
|
|
16
|
+
long-seq global-attention stages favor cuDNN-attn or flash on Hopper.
|
|
17
|
+
|
|
18
|
+
Numerically identical to FAIR baseline up to bf16 noise (~1e-2 max abs diff,
|
|
19
|
+
~1e-3 rel RMS) — the noise is the difference between SDPA math vs flash.
|
|
20
|
+
"""
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from typing import Optional
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
import torch.nn as nn
|
|
27
|
+
import torch.nn.functional as F
|
|
28
|
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
_BACKEND_MAP = {
|
|
32
|
+
"cudnn": SDPBackend.CUDNN_ATTENTION,
|
|
33
|
+
"flash": SDPBackend.FLASH_ATTENTION,
|
|
34
|
+
"mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
|
|
35
|
+
"math": SDPBackend.MATH,
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
BACKEND_NAMES = tuple(_BACKEND_MAP.keys())
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class MaskUnitAttentionFast(nn.Module):
|
|
42
|
+
"""Drop-in replacement for models.hiera.MaskUnitAttention with 4D SDPA.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
sdpa_backend: optional SDPA backend hint. One of {"cudnn", "flash",
|
|
46
|
+
"mem_efficient", "math", None}. None = default dispatcher.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
dim: int,
|
|
52
|
+
dim_out: int,
|
|
53
|
+
heads: int,
|
|
54
|
+
q_stride: int = 1,
|
|
55
|
+
window_size: int = 0,
|
|
56
|
+
use_mask_unit_attn: bool = False,
|
|
57
|
+
sdpa_backend: Optional[str] = None,
|
|
58
|
+
):
|
|
59
|
+
super().__init__()
|
|
60
|
+
self.dim = dim
|
|
61
|
+
self.dim_out = dim_out
|
|
62
|
+
self.heads = heads
|
|
63
|
+
self.q_stride = q_stride
|
|
64
|
+
self.head_dim = dim_out // heads
|
|
65
|
+
self.scale = self.head_dim ** -0.5
|
|
66
|
+
self.qkv = nn.Linear(dim, 3 * dim_out)
|
|
67
|
+
self.proj = nn.Linear(dim_out, dim_out)
|
|
68
|
+
self.window_size = window_size
|
|
69
|
+
self.use_mask_unit_attn = use_mask_unit_attn
|
|
70
|
+
if sdpa_backend is not None and sdpa_backend not in _BACKEND_MAP:
|
|
71
|
+
raise ValueError(f"sdpa_backend must be one of {list(_BACKEND_MAP)} or None; got {sdpa_backend!r}")
|
|
72
|
+
self.sdpa_backend = sdpa_backend
|
|
73
|
+
|
|
74
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
75
|
+
B, N, _ = x.shape
|
|
76
|
+
H, D = self.heads, self.head_dim
|
|
77
|
+
nw = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
|
|
78
|
+
T = N // nw # tokens per window before q-pool
|
|
79
|
+
|
|
80
|
+
# qkv: (B, N, 3*dim_out)
|
|
81
|
+
# FAIR layout: N is interpreted as (T, nw) with nw fast-varying.
|
|
82
|
+
# We want (3, B*nw, H, T, D) so SDPA gets a 4D Q/K/V.
|
|
83
|
+
qkv = self.qkv(x).view(B, T, nw, 3, H, D)
|
|
84
|
+
# permute -> (3, B, nw, H, T, D), then flatten (B, nw) into batch
|
|
85
|
+
qkv = qkv.permute(3, 0, 2, 4, 1, 5).reshape(3, B * nw, H, T, D)
|
|
86
|
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
|
87
|
+
|
|
88
|
+
if self.q_stride > 1:
|
|
89
|
+
# Max-pool Q over the q_stride flat axis. T = q_stride * Tq.
|
|
90
|
+
# FAIR's do_pool: view(B, stride, -1, C).max(dim=1) — stride is OUTER.
|
|
91
|
+
# So inside T, q_stride is slow-varying. Mirror: view (.., q_stride, Tq, ..)
|
|
92
|
+
Tq = T // self.q_stride
|
|
93
|
+
q = q.view(B * nw, H, self.q_stride, Tq, D).amax(dim=2)
|
|
94
|
+
|
|
95
|
+
# 4D SDPA → FlashAttention/cuDNN. Optionally pin a backend.
|
|
96
|
+
if self.sdpa_backend is None:
|
|
97
|
+
out = F.scaled_dot_product_attention(q, k, v)
|
|
98
|
+
else:
|
|
99
|
+
with sdpa_kernel([_BACKEND_MAP[self.sdpa_backend]]):
|
|
100
|
+
out = F.scaled_dot_product_attention(q, k, v)
|
|
101
|
+
# out: (B*nw, H, Tq, D)
|
|
102
|
+
Tq_out = out.shape[2]
|
|
103
|
+
|
|
104
|
+
# Back to (B, N_out, dim_out) with N_out indexed as (tq, w) tq slow / w fast
|
|
105
|
+
# out: (B*nw, H, Tq, D) -> (B, nw, H, Tq, D) -> (B, Tq, nw, H, D) -> reshape
|
|
106
|
+
out = out.view(B, nw, H, Tq_out, D).permute(0, 3, 1, 2, 4).reshape(B, Tq_out * nw, self.dim_out)
|
|
107
|
+
return self.proj(out)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@torch.no_grad()
def copy_weights_from_orig(fast: MaskUnitAttentionFast, orig) -> None:
    """Copy the ``qkv`` and ``proj`` Linear parameters of FAIR's MaskUnitAttention into ``fast``.

    Copies are in place, in the order qkv.weight, qkv.bias, proj.weight,
    proj.bias. They run under ``torch.no_grad()``, so autograd does not
    record them.
    """
    for layer_name in ("qkv", "proj"):
        src_layer = getattr(orig, layer_name)
        dst_layer = getattr(fast, layer_name)
        dst_layer.weight.copy_(src_layer.weight)
        dst_layer.bias.copy_(src_layer.bias)
|