hiera-optim 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. hiera_optim-0.1.0/LICENSE +21 -0
  2. hiera_optim-0.1.0/PKG-INFO +135 -0
  3. hiera_optim-0.1.0/README.md +101 -0
  4. hiera_optim-0.1.0/hiera_optim/__init__.py +44 -0
  5. hiera_optim-0.1.0/hiera_optim/adapters/__init__.py +28 -0
  6. hiera_optim-0.1.0/hiera_optim/adapters/hiera.py +142 -0
  7. hiera_optim-0.1.0/hiera_optim/attention/__init__.py +6 -0
  8. hiera_optim-0.1.0/hiera_optim/attention/mask_unit.py +116 -0
  9. hiera_optim-0.1.0/hiera_optim/checkpoint.py +113 -0
  10. hiera_optim-0.1.0/hiera_optim/kernels/__init__.py +22 -0
  11. hiera_optim-0.1.0/hiera_optim/kernels/flash_qpool.py +220 -0
  12. hiera_optim-0.1.0/hiera_optim/kernels/mask_gather.py +148 -0
  13. hiera_optim-0.1.0/hiera_optim/ops/__init__.py +18 -0
  14. hiera_optim-0.1.0/hiera_optim/ops/mask_gather.py +112 -0
  15. hiera_optim-0.1.0/hiera_optim/patch.py +321 -0
  16. hiera_optim-0.1.0/hiera_optim.egg-info/PKG-INFO +135 -0
  17. hiera_optim-0.1.0/hiera_optim.egg-info/SOURCES.txt +27 -0
  18. hiera_optim-0.1.0/hiera_optim.egg-info/dependency_links.txt +1 -0
  19. hiera_optim-0.1.0/hiera_optim.egg-info/requires.txt +14 -0
  20. hiera_optim-0.1.0/hiera_optim.egg-info/top_level.txt +1 -0
  21. hiera_optim-0.1.0/pyproject.toml +53 -0
  22. hiera_optim-0.1.0/setup.cfg +4 -0
  23. hiera_optim-0.1.0/tests/test_e2e_equivalence.py +111 -0
  24. hiera_optim-0.1.0/tests/test_flash_qpool.py +122 -0
  25. hiera_optim-0.1.0/tests/test_mask_gather.py +159 -0
  26. hiera_optim-0.1.0/tests/test_mask_unit_attention.py +97 -0
  27. hiera_optim-0.1.0/tests/test_matrix.py +235 -0
  28. hiera_optim-0.1.0/tests/test_sdpa_backend.py +90 -0
  29. hiera_optim-0.1.0/tests/test_selective_checkpoint.py +153 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Maxi Kalcher
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,135 @@
1
+ Metadata-Version: 2.4
2
+ Name: hiera-optim
3
+ Version: 0.1.0
4
+ Summary: Drop-in throughput and memory optimisations for FAIR Hiera (4D-SDPA, gather/scatter, Triton kernels).
5
+ Author: Maxi Kalcher
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/avocardio/hiera-optim
8
+ Project-URL: Repository, https://github.com/avocardio/hiera-optim
9
+ Project-URL: Issues, https://github.com/avocardio/hiera-optim/issues
10
+ Keywords: pytorch,transformer,vision,hiera,mae,flash-attention,triton,hopper,h100,gh200
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
18
+ Classifier: Operating System :: POSIX :: Linux
19
+ Requires-Python: >=3.10
20
+ Description-Content-Type: text/markdown
21
+ License-File: LICENSE
22
+ Requires-Dist: torch>=2.5.0
23
+ Requires-Dist: triton>=2.3.0
24
+ Provides-Extra: hiera
25
+ Requires-Dist: hiera-transformer>=0.1.4; extra == "hiera"
26
+ Provides-Extra: test
27
+ Requires-Dist: pytest>=7.0; extra == "test"
28
+ Provides-Extra: dev
29
+ Requires-Dist: pytest>=7.0; extra == "dev"
30
+ Requires-Dist: ruff; extra == "dev"
31
+ Requires-Dist: build; extra == "dev"
32
+ Requires-Dist: twine; extra == "dev"
33
+ Dynamic: license-file
34
+
35
+ # hiera-optim
36
+
37
+ Drop-in throughput and memory optimisations for [FAIR's Hiera](https://github.com/facebookresearch/hiera) and its MAE variant. Two lines:
38
+
39
+ ```python
40
+ from hiera_optim import optimize
41
+ optimize(model)
42
+ ```
43
+
44
+ These two lines move the model's attention off PyTorch's silent math-backend fallback and back onto FlashAttention / cuDNN attention, replace boolean mask indexing with `torch.gather` / `scatter_`, and unblock `torch.compile`. The result is numerically equivalent to the original within bf16 noise.
45
+
46
+ ## Results
47
+
48
+ H100 (GH200), bf16, full forward + backward.
49
+
50
+ ### Production config: Hiera-Base, 224x224, 8 in-chans, B=128
51
+
52
+ | | ms / step | samples / s | peak mem |
53
+ |---|---|---|---|
54
+ | FAIR baseline + `torch.compile` | 131.7 | 972 | 14.0 GB |
55
+ | **hiera-optim + `torch.compile`** | **70.3** | **1820** | **9.4 GB** |
56
+ | speedup / saving | 1.88x | 1.87x | 33% |
57
+
58
+ ### Across the variant matrix (444 GH200 cells)
59
+
60
+ | | median | mean | best | worst |
61
+ |---|---|---|---|---|
62
+ | speedup | 1.35x | 1.42x | 2.10x | 1.10x |
63
+ | memory ratio | 74% | 73% | 29% | 99% |
64
+
65
+ RTX 4090, Hiera-Base, 8 in-chans, B=32: 1.81x eager, **2.86x with `torch.compile`**.
66
+
67
+ Full matrix and per-cell numbers: [`MATRIX_RESULTS.md`](MATRIX_RESULTS.md).
68
+
69
+ ## Install
70
+
71
+ ```bash
72
+ pip install hiera-optim
73
+ ```
74
+
75
+ From source:
76
+
77
+ ```bash
78
+ git clone https://github.com/avocardio/hiera-optim.git
79
+ cd hiera-optim
80
+ pip install -e .
81
+ ```
82
+
83
+ Requires PyTorch >= 2.5 and Triton >= 2.3. Recognises FAIR Hiera in-tree (`models.hiera`) or via PyPI (`hiera-transformer`).
84
+
85
+ ## Usage
86
+
87
+ ```python
88
+ import torch
89
+ from hiera_optim import optimize
90
+ from hiera import mae_hiera_base_224
91
+
92
+ model = mae_hiera_base_224(pretrained=False, in_chans=3, input_size=(224, 224))
93
+ optimize(model)
94
+ model = torch.compile(model, mode="default", dynamic=False)
95
+
96
+ x = torch.randn(128, 3, 224, 224, device="cuda", dtype=torch.bfloat16)
97
+ loss, *_ = model(x, mask_ratio=0.6)
98
+ loss.backward()
99
+ ```
100
+
101
+ `optimize(model)` makes two changes, in place and with the existing weights preserved:
102
+
103
+ 1. Swap every `MaskUnitAttention` for a 4-D reshape variant so PyTorch SDPA dispatches to FlashAttention / cuDNN attention / memory-efficient attention instead of the math backend. FAIR's original feeds SDPA 5-D tensors, which the fused kernels reject; the resulting math fallback is ~13x slower per call on Ada and ~6x slower on Hopper.
104
+ 2. Swap `x[mask.tile(...)]` and `x_dec[mask] = ...` for explicit `torch.gather` / `scatter_`. Removes a slow `indexing_backward_kernel` and the `aten::nonzero` graph break that stops `torch.compile`.
105
+
106
+ ## Optional
107
+
108
+ ```python
109
+ from hiera_optim import optimize, enable_stage_checkpointing
110
+
111
+ optimize(model, sdpa_backend="auto") # per-block SDPA hint
112
+ enable_stage_checkpointing(model, stages=(2,)) # OOM lever
113
+ ```
114
+
115
+ ## GPU support
116
+
117
+ | Architecture | SM | Status |
118
+ |---|---|---|
119
+ | Ada (RTX 4090, L40) | SM89 | Tested |
120
+ | Hopper (H100, GH200) | SM90 | Tested |
121
+ | Ampere (A100) | SM80 | Should work |
122
+ | Blackwell (B200) | SM100 | Should work |
123
+
124
+ ## Tests
125
+
126
+ ```bash
127
+ pip install -e .[test]
128
+ pytest
129
+ ```
130
+
131
+ 112 tests cover all 5 Hiera variants x q_pool {1, 2, 3} x mask ratios x bf16/fp16/fp32 x 1D/2D/3D inputs x classification + MAE.
132
+
133
+ ## License
134
+
135
+ MIT.
@@ -0,0 +1,101 @@
1
+ # hiera-optim
2
+
3
+ Drop-in throughput and memory optimisations for [FAIR's Hiera](https://github.com/facebookresearch/hiera) and its MAE variant. Two lines:
4
+
5
+ ```python
6
+ from hiera_optim import optimize
7
+ optimize(model)
8
+ ```
9
+
10
+ These two lines move the model's attention off PyTorch's silent math-backend fallback and back onto FlashAttention / cuDNN attention, replace boolean mask indexing with `torch.gather` / `scatter_`, and unblock `torch.compile`. The result is numerically equivalent to the original within bf16 noise.
11
+
12
+ ## Results
13
+
14
+ H100 (GH200), bf16, full forward + backward.
15
+
16
+ ### Production config: Hiera-Base, 224x224, 8 in-chans, B=128
17
+
18
+ | | ms / step | samples / s | peak mem |
19
+ |---|---|---|---|
20
+ | FAIR baseline + `torch.compile` | 131.7 | 972 | 14.0 GB |
21
+ | **hiera-optim + `torch.compile`** | **70.3** | **1820** | **9.4 GB** |
22
+ | speedup / saving | 1.88x | 1.87x | 33% |
23
+
24
+ ### Across the variant matrix (444 GH200 cells)
25
+
26
+ | | median | mean | best | worst |
27
+ |---|---|---|---|---|
28
+ | speedup | 1.35x | 1.42x | 2.10x | 1.10x |
29
+ | memory ratio | 74% | 73% | 29% | 99% |
30
+
31
+ RTX 4090, Hiera-Base, 8 in-chans, B=32: 1.81x eager, **2.86x with `torch.compile`**.
32
+
33
+ Full matrix and per-cell numbers: [`MATRIX_RESULTS.md`](MATRIX_RESULTS.md).
34
+
35
+ ## Install
36
+
37
+ ```bash
38
+ pip install hiera-optim
39
+ ```
40
+
41
+ From source:
42
+
43
+ ```bash
44
+ git clone https://github.com/avocardio/hiera-optim.git
45
+ cd hiera-optim
46
+ pip install -e .
47
+ ```
48
+
49
+ Requires PyTorch >= 2.5 and Triton >= 2.3. Recognises FAIR Hiera in-tree (`models.hiera`) or via PyPI (`hiera-transformer`).
50
+
51
+ ## Usage
52
+
53
+ ```python
54
+ import torch
55
+ from hiera_optim import optimize
56
+ from hiera import mae_hiera_base_224
57
+
58
+ model = mae_hiera_base_224(pretrained=False, in_chans=3, input_size=(224, 224))
59
+ optimize(model)
60
+ model = torch.compile(model, mode="default", dynamic=False)
61
+
62
+ x = torch.randn(128, 3, 224, 224, device="cuda", dtype=torch.bfloat16)
63
+ loss, *_ = model(x, mask_ratio=0.6)
64
+ loss.backward()
65
+ ```
66
+
67
+ `optimize(model)` makes two changes, in place and with the existing weights preserved:
68
+
69
+ 1. Swap every `MaskUnitAttention` for a 4-D reshape variant so PyTorch SDPA dispatches to FlashAttention / cuDNN attention / memory-efficient attention instead of the math backend. FAIR's original feeds SDPA 5-D tensors, which the fused kernels reject; the resulting math fallback is ~13x slower per call on Ada and ~6x slower on Hopper.
70
+ 2. Swap `x[mask.tile(...)]` and `x_dec[mask] = ...` for explicit `torch.gather` / `scatter_`. Removes a slow `indexing_backward_kernel` and the `aten::nonzero` graph break that stops `torch.compile`.
71
+
72
+ ## Optional
73
+
74
+ ```python
75
+ from hiera_optim import optimize, enable_stage_checkpointing
76
+
77
+ optimize(model, sdpa_backend="auto") # per-block SDPA hint
78
+ enable_stage_checkpointing(model, stages=(2,)) # OOM lever
79
+ ```
80
+
81
+ ## GPU support
82
+
83
+ | Architecture | SM | Status |
84
+ |---|---|---|
85
+ | Ada (RTX 4090, L40) | SM89 | Tested |
86
+ | Hopper (H100, GH200) | SM90 | Tested |
87
+ | Ampere (A100) | SM80 | Should work |
88
+ | Blackwell (B200) | SM100 | Should work |
89
+
90
+ ## Tests
91
+
92
+ ```bash
93
+ pip install -e .[test]
94
+ pytest
95
+ ```
96
+
97
+ 112 tests cover all 5 Hiera variants x q_pool {1, 2, 3} x mask ratios x bf16/fp16/fp32 x 1D/2D/3D inputs x classification + MAE.
98
+
99
+ ## License
100
+
101
+ MIT.
@@ -0,0 +1,44 @@
1
+ """hiera_optim — training-throughput optimisations for FAIR's Hiera.
2
+
3
+ Quick start::
4
+
5
+ from models.hiera import mae_hiera_base_224 # FAIR Hiera
6
+ from hiera_optim import optimize
7
+
8
+ model = mae_hiera_base_224(pretrained=False, in_chans=8)
9
+ optimize(model) # in-place
10
+ # optionally: model = torch.compile(model, mode="default", dynamic=False)
11
+
12
+ What `optimize()` does:
13
+ 1. Swaps every `MaskUnitAttention` for a FlashAttention/cuDNN-friendly
14
+ 4-D variant (`MaskUnitAttentionFast`). Restores math-fallback SDPA to
15
+ fused kernel paths — 5-12× per-call attention speedup.
16
+ 2. Replaces the boolean `x[mask.tile(...)]` and `x_dec[mask] = ...`
17
+ indexing patterns with explicit `torch.gather` / `scatter_`. Removes
18
+ the `aten::nonzero` graph break (compile-friendly).
19
+
20
+ Optional add-ons (opt-in, not invoked by default):
21
+ - `optimize(model, sdpa_backend="auto" | "cudnn" | ...)`: pin the SDPA
22
+ backend per-block. Sometimes helps, sometimes hurts — see RESULTS.md.
23
+ - `enable_stage_checkpointing(model, stages=(2,))`: trade compute for
24
+ activation memory at chosen stages. OOM lever, not a throughput tool.
25
+ """
26
+ from .patch import optimize, swap_mask_unit_attention, recommended_backend
27
+ from .attention import MaskUnitAttentionFast, BACKEND_NAMES
28
+ from .checkpoint import enable_stage_checkpointing, disable_stage_checkpointing
29
+ from .adapters import HieraAdapter, get_hiera_adapter, auto_detect
30
+
31
__version__ = "0.1.0"

# Public API, grouped by the submodule each name is imported from above.
__all__ = [
    # .patch
    "optimize",
    "swap_mask_unit_attention",
    "recommended_backend",
    # .attention
    "MaskUnitAttentionFast",
    "BACKEND_NAMES",
    # .checkpoint
    "enable_stage_checkpointing",
    "disable_stage_checkpointing",
    # .adapters
    "HieraAdapter",
    "get_hiera_adapter",
    "auto_detect",
]
@@ -0,0 +1,28 @@
1
+ """Model adapters.
2
+
3
+ The rest of the package is intentionally model-name-free: it operates on
4
+ PyTorch `nn.Module` graphs and looks up attention/block classes through
5
+ adapters. Each adapter teaches `optimize()` how to find the right submodules
6
+ on a specific model family.
7
+
8
+ Currently bundled:
9
+ - hiera: FAIR's Hiera / MaskedAutoencoderHiera (https://github.com/facebookresearch/hiera)
10
+
11
+ Other adapters (Swin, MViTv2, JEPA-Hiera, custom architectures) can plug in by
12
+ implementing the same `ModelAdapter` protocol.
13
+ """
14
+ from __future__ import annotations
15
+ from typing import Optional
16
+ from .hiera import HieraAdapter, get_hiera_adapter
17
+
18
__all__ = ["HieraAdapter", "get_hiera_adapter", "auto_detect"]


def auto_detect(model) -> Optional["HieraAdapter"]:
    """Return the bundled adapter that understands ``model``, or None.

    Only the FAIR Hiera adapter is bundled, so this delegates to
    `get_hiera_adapter`. That returns None when FAIR Hiera is not importable,
    or when ``model`` is not a Hiera / MaskedAutoencoderHiera instance.

    The return type names the concrete adapter class. No separate
    ``ModelAdapter`` type is defined in this module.
    """
    return get_hiera_adapter(model)
@@ -0,0 +1,142 @@
1
+ """Adapter for FAIR Hiera (https://github.com/facebookresearch/hiera).
2
+
3
+ This is the ONE module in `hiera_optim` allowed to import FAIR classes by name.
4
+ Everything else works through this adapter's protocol so the package is
5
+ trivially portable to derivative architectures.
6
+
7
+ The adapter resolves at import time and degrades gracefully if FAIR Hiera
8
+ isn't installed — `get_hiera_adapter(model)` returns None instead of crashing.
9
+ """
10
+ from __future__ import annotations
11
+ import importlib
12
+ from dataclasses import dataclass
13
+ from typing import Any, Optional, Type
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
@dataclass(frozen=True)
class _HieraSymbols:
    """Resolved references to the FAIR Hiera classes and helpers we touch."""
    Hiera: Type[nn.Module]
    MaskedAutoencoderHiera: Type[nn.Module]
    HieraBlock: Type[nn.Module]
    MaskUnitAttention: Type[nn.Module]
    apply_fusion_head: callable  # called as apply_fusion_head(head, x)
    undo_windowing: callable  # called as undo_windowing(x, shape, mu_shape)


# Symbols looked up on the model module(s). `undo_windowing` is taken from
# the separate utils module of the same candidate.
_MODEL_SYMBOL_NAMES = (
    "Hiera",
    "MaskedAutoencoderHiera",
    "HieraBlock",
    "MaskUnitAttention",
    "apply_fusion_head",
)

# Each entry: (model modules searched in order, utils module). Every symbol is
# taken from the first importable model module that defines it, so a library
# that splits its classes over several files still resolves.
_CANDIDATES = (
    # In-tree (brain_atlas, the development environment): one module holds everything.
    (("models.hiera",), "utils.hiera_utils"),
    # PyPI package (`pip install hiera-transformer`). The MAE model and
    # `apply_fusion_head` live in `hiera.hiera_mae` rather than `hiera.hiera`,
    # and the package root is not relied on to re-export `apply_fusion_head`.
    (("hiera.hiera", "hiera.hiera_mae", "hiera"), "hiera.hiera_utils"),
)


def _try_import(name: str):
    """Import module ``name``; return None instead of raising ImportError."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


def _resolve() -> Optional[_HieraSymbols]:
    """Locate FAIR Hiera and return its symbols, or None if it isn't available.

    Candidates in `_CANDIDATES` are tried in order. A candidate succeeds only if
    every name in `_MODEL_SYMBOL_NAMES` is found on one of its importable model
    modules, and its utils module provides `undo_windowing`.
    """
    for model_mod_names, util_mod_name in _CANDIDATES:
        model_mods = [m for m in map(_try_import, model_mod_names) if m is not None]
        utils = _try_import(util_mod_name)
        if not model_mods or utils is None or not hasattr(utils, "undo_windowing"):
            continue
        found = {}
        for attr in _MODEL_SYMBOL_NAMES:
            owner = next((m for m in model_mods if hasattr(m, attr)), None)
            if owner is None:
                break  # this candidate is incomplete; try the next one
            found[attr] = getattr(owner, attr)
        else:
            return _HieraSymbols(undo_windowing=utils.undo_windowing, **found)
    return None
54
+
55
+
56
# Resolved once, at import time. None means FAIR Hiera could not be imported.
_SYMS = _resolve()


def is_available() -> bool:
    """Report whether FAIR Hiera's symbols were resolved at import time."""
    resolved = _SYMS
    return resolved is not None


def symbols() -> "_HieraSymbols":
    """Return the FAIR symbols resolved at import time.

    Raises:
        ImportError: when no known import location for FAIR Hiera worked.
    """
    resolved = _SYMS
    if resolved is not None:
        return resolved
    raise ImportError(
        "FAIR Hiera not installed. Add `pip install hiera-transformer`, "
        "or ensure `models.hiera` is importable in this project."
    )
71
+
72
+
73
class HieraAdapter:
    """Bridge between FAIR Hiera's class hierarchy and the model-agnostic patch code.

    An instance holds the resolved FAIR symbols. It answers questions about a
    model (is it Hiera-family? the MAE variant? where are its blocks?). It also
    forwards to the two FAIR helpers the patched forwards call.

    Raises:
        ImportError: on construction, when FAIR Hiera could not be resolved.
    """

    # Token layout produced by FAIR's `Unroll`: flat token n maps to
    # (t = n // nw, w = n % nw), i.e. T is the outer axis and the window index
    # is the inner one. FAIR's `do_pool(x, stride)` (stride as the OUTER dim)
    # and the MaskUnitAttention reshape pattern both rely on this.
    layout = "T_outer_nw_inner"

    def __init__(self):
        if is_available():
            self._syms = symbols()
            return
        raise ImportError(
            "HieraAdapter requires FAIR Hiera to be importable."
        )

    # ---- Introspection -----------------------------------------------------

    def matches(self, model: nn.Module) -> bool:
        """Return True when ``model`` is a FAIR Hiera or MaskedAutoencoderHiera."""
        syms = self._syms
        family = (syms.Hiera, syms.MaskedAutoencoderHiera)
        return isinstance(model, family)

    def is_mae(self, model: nn.Module) -> bool:
        """Return True when ``model`` is the MAE variant."""
        mae_cls = self._syms.MaskedAutoencoderHiera
        return isinstance(model, mae_cls)

    def block_class(self) -> Type[nn.Module]:
        """Return FAIR's ``HieraBlock`` class."""
        return self._syms.HieraBlock

    def attention_class(self) -> Type[nn.Module]:
        """Return FAIR's ``MaskUnitAttention`` class."""
        return self._syms.MaskUnitAttention

    def encoder_blocks(self, model: nn.Module) -> nn.ModuleList:
        """Return ``model.blocks`` (the encoder block list)."""
        blocks = model.blocks
        return blocks

    def decoder_blocks(self, model: nn.Module) -> Optional[nn.ModuleList]:
        """Return ``model.decoder_blocks``, or None when the model has none (non-MAE)."""
        try:
            return model.decoder_blocks
        except AttributeError:
            return None

    def stage_ends(self, model: nn.Module) -> list[int]:
        """Return a fresh list copy of ``model.stage_ends``."""
        return [*model.stage_ends]

    # ---- Helpers used by the patched forwards -------------------------------

    def apply_fusion_head(self, head: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Forward to FAIR's ``apply_fusion_head(head, x)``."""
        fuse = self._syms.apply_fusion_head
        return fuse(head, x)

    def undo_windowing(self, x, shape, mu_shape):
        """Forward to FAIR's ``undo_windowing(x, shape, mu_shape)``."""
        undo = self._syms.undo_windowing
        return undo(x, shape, mu_shape)
129
+
130
+
131
# Lazily built, shared adapter. Deciding whether a model matches is only an
# isinstance check against the resolved FAIR classes, so reuse is cheap.
_DEFAULT_ADAPTER: Optional["HieraAdapter"] = None


def get_hiera_adapter(model: nn.Module) -> Optional["HieraAdapter"]:
    """Return the shared HieraAdapter when ``model`` is Hiera-family, else None.

    Returns None without building anything when FAIR Hiera isn't importable.
    The adapter is created on the first call that gets past that check and is
    reused afterwards.
    """
    global _DEFAULT_ADAPTER
    if not is_available():
        return None
    adapter = _DEFAULT_ADAPTER
    if adapter is None:
        adapter = _DEFAULT_ADAPTER = HieraAdapter()
    return adapter if adapter.matches(model) else None
@@ -0,0 +1,6 @@
1
+ """Attention modules. Currently exports MaskUnitAttentionFast — the
2
+ FlashAttention/cuDNN-friendly 4-D variant of FAIR's MaskUnitAttention.
3
+ """
4
+ from .mask_unit import MaskUnitAttentionFast, copy_weights_from_orig, BACKEND_NAMES
5
+
6
# Names re-exported from `.mask_unit`.
__all__ = [
    "MaskUnitAttentionFast",
    "copy_weights_from_orig",
    "BACKEND_NAMES",
]
@@ -0,0 +1,116 @@
1
+ """Optimized MaskUnitAttention — drop-in replacement.
2
+
3
+ Key fixes vs FAIR original (models/hiera.py):
4
+ - Reshape Q/K/V to 4D (B*num_windows, heads, T, D) so SDPA dispatches to
5
+ FlashAttention (cuDNN/Flash). Original feeds 5D tensors which fall back
6
+ to the math backend (12-13x slower on stage-0 shapes per microbench).
7
+ - Match FAIR's N-axis layout exactly: token n in input has n = t*nw + w,
8
+ so within-window positions are SLOW-varying and num_windows is FAST.
9
+ That's because the upstream Unroll module produces this interleaved
10
+ layout (and do_pool relies on the stride axis being outer).
11
+ - Skip per-tensor .contiguous() — the permute lands flash-friendly.
12
+ - Optional per-stage SDPA backend hint via `sdpa_backend` attribute. Set to
13
+ one of {"cudnn", "flash", "mem_efficient", "math", None}; None lets the
14
+ PyTorch dispatcher pick. Per-stage tuning is useful because Hiera's small
15
+ Tq stages (stage-1 post q_pool: T=16) often favor mem-efficient while the
16
+ long-seq global-attention stages favor cuDNN-attn or flash on Hopper.
17
+
18
+ Numerically identical to FAIR baseline up to bf16 noise (~1e-2 max abs diff,
19
+ ~1e-3 rel RMS) — the noise is the difference between SDPA math vs flash.
20
+ """
21
+ from __future__ import annotations
22
+
23
+ from typing import Optional
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch.nn.attention import SDPBackend, sdpa_kernel
29
+
30
+
31
+ _BACKEND_MAP = {
32
+ "cudnn": SDPBackend.CUDNN_ATTENTION,
33
+ "flash": SDPBackend.FLASH_ATTENTION,
34
+ "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
35
+ "math": SDPBackend.MATH,
36
+ }
37
+
38
+ BACKEND_NAMES = tuple(_BACKEND_MAP.keys())
39
+
40
+
41
+ class MaskUnitAttentionFast(nn.Module):
42
+ """Drop-in replacement for models.hiera.MaskUnitAttention with 4D SDPA.
43
+
44
+ Args:
45
+ sdpa_backend: optional SDPA backend hint. One of {"cudnn", "flash",
46
+ "mem_efficient", "math", None}. None = default dispatcher.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ dim: int,
52
+ dim_out: int,
53
+ heads: int,
54
+ q_stride: int = 1,
55
+ window_size: int = 0,
56
+ use_mask_unit_attn: bool = False,
57
+ sdpa_backend: Optional[str] = None,
58
+ ):
59
+ super().__init__()
60
+ self.dim = dim
61
+ self.dim_out = dim_out
62
+ self.heads = heads
63
+ self.q_stride = q_stride
64
+ self.head_dim = dim_out // heads
65
+ self.scale = self.head_dim ** -0.5
66
+ self.qkv = nn.Linear(dim, 3 * dim_out)
67
+ self.proj = nn.Linear(dim_out, dim_out)
68
+ self.window_size = window_size
69
+ self.use_mask_unit_attn = use_mask_unit_attn
70
+ if sdpa_backend is not None and sdpa_backend not in _BACKEND_MAP:
71
+ raise ValueError(f"sdpa_backend must be one of {list(_BACKEND_MAP)} or None; got {sdpa_backend!r}")
72
+ self.sdpa_backend = sdpa_backend
73
+
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ B, N, _ = x.shape
76
+ H, D = self.heads, self.head_dim
77
+ nw = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
78
+ T = N // nw # tokens per window before q-pool
79
+
80
+ # qkv: (B, N, 3*dim_out)
81
+ # FAIR layout: N is interpreted as (T, nw) with nw fast-varying.
82
+ # We want (3, B*nw, H, T, D) so SDPA gets a 4D Q/K/V.
83
+ qkv = self.qkv(x).view(B, T, nw, 3, H, D)
84
+ # permute -> (3, B, nw, H, T, D), then flatten (B, nw) into batch
85
+ qkv = qkv.permute(3, 0, 2, 4, 1, 5).reshape(3, B * nw, H, T, D)
86
+ q, k, v = qkv[0], qkv[1], qkv[2]
87
+
88
+ if self.q_stride > 1:
89
+ # Max-pool Q over the q_stride flat axis. T = q_stride * Tq.
90
+ # FAIR's do_pool: view(B, stride, -1, C).max(dim=1) — stride is OUTER.
91
+ # So inside T, q_stride is slow-varying. Mirror: view (.., q_stride, Tq, ..)
92
+ Tq = T // self.q_stride
93
+ q = q.view(B * nw, H, self.q_stride, Tq, D).amax(dim=2)
94
+
95
+ # 4D SDPA → FlashAttention/cuDNN. Optionally pin a backend.
96
+ if self.sdpa_backend is None:
97
+ out = F.scaled_dot_product_attention(q, k, v)
98
+ else:
99
+ with sdpa_kernel([_BACKEND_MAP[self.sdpa_backend]]):
100
+ out = F.scaled_dot_product_attention(q, k, v)
101
+ # out: (B*nw, H, Tq, D)
102
+ Tq_out = out.shape[2]
103
+
104
+ # Back to (B, N_out, dim_out) with N_out indexed as (tq, w) tq slow / w fast
105
+ # out: (B*nw, H, Tq, D) -> (B, nw, H, Tq, D) -> (B, Tq, nw, H, D) -> reshape
106
+ out = out.view(B, nw, H, Tq_out, D).permute(0, 3, 1, 2, 4).reshape(B, Tq_out * nw, self.dim_out)
107
+ return self.proj(out)
108
+
109
+
110
@torch.no_grad()
def copy_weights_from_orig(fast: "MaskUnitAttentionFast", orig) -> None:
    """Copy the qkv/proj weights and biases of a FAIR ``MaskUnitAttention`` into ``fast``.

    The copies happen in place via ``Tensor.copy_`` under ``torch.no_grad()``,
    so autograd does not record them. Order: qkv.weight, qkv.bias,
    proj.weight, proj.bias.
    """
    for layer_name in ("qkv", "proj"):
        src = getattr(orig, layer_name)
        dst = getattr(fast, layer_name)
        dst.weight.copy_(src.weight)
        dst.bias.copy_(src.bias)