hiera-optim 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hiera_optim/__init__.py +44 -0
- hiera_optim/adapters/__init__.py +28 -0
- hiera_optim/adapters/hiera.py +142 -0
- hiera_optim/attention/__init__.py +6 -0
- hiera_optim/attention/mask_unit.py +116 -0
- hiera_optim/checkpoint.py +113 -0
- hiera_optim/kernels/__init__.py +22 -0
- hiera_optim/kernels/flash_qpool.py +220 -0
- hiera_optim/kernels/mask_gather.py +148 -0
- hiera_optim/ops/__init__.py +18 -0
- hiera_optim/ops/mask_gather.py +112 -0
- hiera_optim/patch.py +321 -0
- hiera_optim-0.1.0.dist-info/METADATA +135 -0
- hiera_optim-0.1.0.dist-info/RECORD +17 -0
- hiera_optim-0.1.0.dist-info/WHEEL +5 -0
- hiera_optim-0.1.0.dist-info/licenses/LICENSE +21 -0
- hiera_optim-0.1.0.dist-info/top_level.txt +1 -0
hiera_optim/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""hiera_optim — training-throughput optimisations for FAIR's Hiera.
|
|
2
|
+
|
|
3
|
+
Quick start::
|
|
4
|
+
|
|
5
|
+
from models.hiera import mae_hiera_base_224 # FAIR Hiera
|
|
6
|
+
from hiera_optim import optimize
|
|
7
|
+
|
|
8
|
+
model = mae_hiera_base_224(pretrained=False, in_chans=8)
|
|
9
|
+
optimize(model) # in-place
|
|
10
|
+
# optionally: model = torch.compile(model, mode="default", dynamic=False)
|
|
11
|
+
|
|
12
|
+
What `optimize()` does:
|
|
13
|
+
1. Swaps every `MaskUnitAttention` for a FlashAttention/cuDNN-friendly
|
|
14
|
+
4-D variant (`MaskUnitAttentionFast`). Restores math-fallback SDPA to
|
|
15
|
+
fused kernel paths — 5-12× per-call attention speedup.
|
|
16
|
+
2. Replaces the boolean `x[mask.tile(...)]` and `x_dec[mask] = ...`
|
|
17
|
+
indexing patterns with explicit `torch.gather` / `scatter_`. Removes
|
|
18
|
+
the `aten::nonzero` graph break (compile-friendly).
|
|
19
|
+
|
|
20
|
+
Optional add-ons (opt-in, not invoked by default):
|
|
21
|
+
- `optimize(model, sdpa_backend="auto" | "cudnn" | ...)`: pin the SDPA
|
|
22
|
+
backend per-block. Sometimes helps, sometimes hurts — see RESULTS.md.
|
|
23
|
+
- `enable_stage_checkpointing(model, stages=(2,))`: trade compute for
|
|
24
|
+
activation memory at chosen stages. OOM lever, not a throughput tool.
|
|
25
|
+
"""
|
|
26
|
+
from .patch import optimize, swap_mask_unit_attention, recommended_backend
|
|
27
|
+
from .attention import MaskUnitAttentionFast, BACKEND_NAMES
|
|
28
|
+
from .checkpoint import enable_stage_checkpointing, disable_stage_checkpointing
|
|
29
|
+
from .adapters import HieraAdapter, get_hiera_adapter, auto_detect
|
|
30
|
+
|
|
31
|
+
__version__ = "0.1.0"
|
|
32
|
+
|
|
33
|
+
__all__ = [
|
|
34
|
+
"optimize",
|
|
35
|
+
"swap_mask_unit_attention",
|
|
36
|
+
"recommended_backend",
|
|
37
|
+
"MaskUnitAttentionFast",
|
|
38
|
+
"BACKEND_NAMES",
|
|
39
|
+
"enable_stage_checkpointing",
|
|
40
|
+
"disable_stage_checkpointing",
|
|
41
|
+
"HieraAdapter",
|
|
42
|
+
"get_hiera_adapter",
|
|
43
|
+
"auto_detect",
|
|
44
|
+
]
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Model adapters.
|
|
2
|
+
|
|
3
|
+
The rest of the package is intentionally model-name-free: it operates on
|
|
4
|
+
PyTorch `nn.Module` graphs and looks up attention/block classes through
|
|
5
|
+
adapters. Each adapter teaches `optimize()` how to find the right submodules
|
|
6
|
+
on a specific model family.
|
|
7
|
+
|
|
8
|
+
Currently bundled:
|
|
9
|
+
- hiera: FAIR's Hiera / MaskedAutoencoderHiera (https://github.com/facebookresearch/hiera)
|
|
10
|
+
|
|
11
|
+
Other adapters (Swin, MViTv2, JEPA-Hiera, custom architectures) can plug in by
|
|
12
|
+
implementing the same `ModelAdapter` protocol.
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
from typing import Optional
|
|
16
|
+
from .hiera import HieraAdapter, get_hiera_adapter
|
|
17
|
+
|
|
18
|
+
__all__ = ["HieraAdapter", "get_hiera_adapter", "auto_detect"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def auto_detect(model) -> Optional["ModelAdapter"]:  # type: ignore[name-defined]
    """Pick the bundled adapter that recognises `model`, if any.

    Each bundled finder is asked in turn; the first non-None answer wins.
    Only the Hiera finder is bundled at the moment.

    Returns:
        The matching adapter, or None when no bundled adapter matches.
    """
    for finder in (get_hiera_adapter,):
        adapter = finder(model)
        if adapter is not None:
            return adapter
    return None
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""Adapter for FAIR Hiera (https://github.com/facebookresearch/hiera).
|
|
2
|
+
|
|
3
|
+
This is the ONE module in `hiera_optim` allowed to import FAIR classes by name.
|
|
4
|
+
Everything else works through this adapter's protocol so the package is
|
|
5
|
+
trivially portable to derivative architectures.
|
|
6
|
+
|
|
7
|
+
The adapter resolves at import time and degrades gracefully if FAIR Hiera
|
|
8
|
+
isn't installed — `get_hiera_adapter(model)` returns None instead of crashing.
|
|
9
|
+
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
import importlib
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from typing import Any, Optional, Type
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
class _HieraSymbols:
    """Immutable bundle of the FAIR Hiera classes and helpers this package uses.

    Instances are built by keyword from whichever module path resolved, so
    the rest of the package never has to import FAIR names directly.
    """

    # Model classes.
    Hiera: Type[nn.Module]
    MaskedAutoencoderHiera: Type[nn.Module]
    # Building-block classes.
    HieraBlock: Type[nn.Module]
    MaskUnitAttention: Type[nn.Module]
    # Free functions taken from the Hiera model module and its utils module.
    apply_fusion_head: callable
    undo_windowing: callable
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _resolve() -> Optional[_HieraSymbols]:
    """Locate FAIR Hiera under one of several known module paths.

    Each candidate is a (model module, utils module) pair. A candidate is
    skipped when either import fails or when any expected attribute is
    missing from the imported modules.

    Returns:
        The symbols from the first candidate that resolves fully, or None
        when no candidate does.
    """
    import_paths = (
        # In-tree (brain_atlas, the development environment)
        ("models.hiera", "utils.hiera_utils"),
        # PyPI package (`pip install hiera-transformer`)
        ("hiera.hiera", "hiera.hiera_utils"),
        ("hiera", "hiera.hiera_utils"),
    )
    for model_path, utils_path in import_paths:
        try:
            model_mod = importlib.import_module(model_path)
            utils_mod = importlib.import_module(utils_path)
            resolved = _HieraSymbols(
                Hiera=model_mod.Hiera,
                MaskedAutoencoderHiera=model_mod.MaskedAutoencoderHiera,
                HieraBlock=model_mod.HieraBlock,
                MaskUnitAttention=model_mod.MaskUnitAttention,
                apply_fusion_head=model_mod.apply_fusion_head,
                undo_windowing=utils_mod.undo_windowing,
            )
        except (ImportError, AttributeError):
            continue
        return resolved
    return None


# Resolved once, at import time; stays None when FAIR Hiera is unavailable.
_SYMS = _resolve()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def is_available() -> bool:
    """Report whether the import-time lookup found FAIR Hiera."""
    found = _SYMS is not None
    return found
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def symbols() -> _HieraSymbols:
    """Hand back the FAIR symbols resolved at import time.

    Raises:
        ImportError: if no FAIR Hiera installation was found.
    """
    if _SYMS is not None:
        return _SYMS
    raise ImportError(
        "FAIR Hiera not installed. Add `pip install hiera-transformer`, "
        "or ensure `models.hiera` is importable in this project."
    )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class HieraAdapter:
    """Teaches the patching code how to inspect a FAIR Hiera model.

    All FAIR-specific knowledge (which classes count as Hiera, where the
    block lists live, which helper functions to call) is reached through
    this object, using the symbols resolved at import time.
    """

    # ---- Layout convention ------------------------------------------------
    #
    # FAIR's `Unroll` produces a (T outer, nw inner) token layout: token n in
    # the N axis maps to (t = n // nw, w = n % nw). This is what FAIR's
    # `do_pool(x, stride)` (which expects stride as the OUTER dim) and the
    # MaskUnitAttention reshape pattern both rely on.
    layout = "T_outer_nw_inner"

    def __init__(self):
        """Bind the resolved FAIR symbols.

        Raises:
            ImportError: if FAIR Hiera could not be imported.
        """
        if is_available():
            self._syms = symbols()
        else:
            raise ImportError(
                "HieraAdapter requires FAIR Hiera to be importable."
            )

    # ---- Introspection -----------------------------------------------------

    def matches(self, model: nn.Module) -> bool:
        """Tell whether `model` is an instance of a Hiera-family class."""
        hiera_types = (self._syms.Hiera, self._syms.MaskedAutoencoderHiera)
        return isinstance(model, hiera_types)

    def is_mae(self, model: nn.Module) -> bool:
        """Tell whether `model` is a MaskedAutoencoderHiera instance."""
        mae_type = self._syms.MaskedAutoencoderHiera
        return isinstance(model, mae_type)

    def block_class(self) -> Type[nn.Module]:
        """Return the resolved `HieraBlock` class."""
        return self._syms.HieraBlock

    def attention_class(self) -> Type[nn.Module]:
        """Return the resolved `MaskUnitAttention` class."""
        return self._syms.MaskUnitAttention

    def encoder_blocks(self, model: nn.Module) -> nn.ModuleList:
        """Return `model.blocks`, the encoder block list."""
        return model.blocks

    def decoder_blocks(self, model: nn.Module) -> Optional[nn.ModuleList]:
        """Return `model.decoder_blocks`, or None when the attribute is absent."""
        return getattr(model, "decoder_blocks", None)

    def stage_ends(self, model: nn.Module) -> list[int]:
        """Return a fresh list copy of `model.stage_ends`."""
        return list(model.stage_ends)

    # ---- Helpers used by the patched forwards -----------------------------

    def apply_fusion_head(self, head: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the resolved `apply_fusion_head(head, x)`."""
        fuse = self._syms.apply_fusion_head
        return fuse(head, x)

    def undo_windowing(self, x, shape, mu_shape):
        """Delegate to the resolved `undo_windowing(x, shape, mu_shape)`."""
        unwindow = self._syms.undo_windowing
        return unwindow(x, shape, mu_shape)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# Shared adapter instance, created lazily on the first successful lookup.
# Matching a model against it is just an isinstance check.
_DEFAULT_ADAPTER: Optional[HieraAdapter] = None


def get_hiera_adapter(model: nn.Module) -> Optional[HieraAdapter]:
    """Look up the shared HieraAdapter for `model`.

    Returns:
        The shared adapter when FAIR Hiera is importable and `model` is a
        Hiera-family instance; otherwise None.
    """
    global _DEFAULT_ADAPTER
    if is_available():
        if _DEFAULT_ADAPTER is None:
            _DEFAULT_ADAPTER = HieraAdapter()
        adapter = _DEFAULT_ADAPTER
        if adapter.matches(model):
            return adapter
    return None
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
"""Attention modules. Currently exports MaskUnitAttentionFast — the
|
|
2
|
+
FlashAttention/cuDNN-friendly 4-D variant of FAIR's MaskUnitAttention.
|
|
3
|
+
"""
|
|
4
|
+
from .mask_unit import MaskUnitAttentionFast, copy_weights_from_orig, BACKEND_NAMES
|
|
5
|
+
|
|
6
|
+
__all__ = ["MaskUnitAttentionFast", "copy_weights_from_orig", "BACKEND_NAMES"]
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Optimized MaskUnitAttention — drop-in replacement.
|
|
2
|
+
|
|
3
|
+
Key fixes vs FAIR original (models/hiera.py):
|
|
4
|
+
- Reshape Q/K/V to 4D (B*num_windows, heads, T, D) so SDPA dispatches to
|
|
5
|
+
FlashAttention (cuDNN/Flash). Original feeds 5D tensors which fall back
|
|
6
|
+
to the math backend (12-13x slower on stage-0 shapes per microbench).
|
|
7
|
+
- Match FAIR's N-axis layout exactly: token n in input has n = t*nw + w,
|
|
8
|
+
so within-window positions are SLOW-varying and num_windows is FAST.
|
|
9
|
+
That's because the upstream Unroll module produces this interleaved
|
|
10
|
+
layout (and do_pool relies on the stride axis being outer).
|
|
11
|
+
- Skip per-tensor .contiguous() — the permute lands flash-friendly.
|
|
12
|
+
- Optional per-stage SDPA backend hint via `sdpa_backend` attribute. Set to
|
|
13
|
+
one of {"cudnn", "flash", "mem_efficient", "math", None}; None lets the
|
|
14
|
+
PyTorch dispatcher pick. Per-stage tuning is useful because Hiera's small
|
|
15
|
+
Tq stages (stage-1 post q_pool: T=16) often favor mem-efficient while the
|
|
16
|
+
long-seq global-attention stages favor cuDNN-attn or flash on Hopper.
|
|
17
|
+
|
|
18
|
+
Numerically identical to FAIR baseline up to bf16 noise (~1e-2 max abs diff,
|
|
19
|
+
~1e-3 rel RMS) — the noise is the difference between SDPA math vs flash.
|
|
20
|
+
"""
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from typing import Optional
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
import torch.nn as nn
|
|
27
|
+
import torch.nn.functional as F
|
|
28
|
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
_BACKEND_MAP = {
|
|
32
|
+
"cudnn": SDPBackend.CUDNN_ATTENTION,
|
|
33
|
+
"flash": SDPBackend.FLASH_ATTENTION,
|
|
34
|
+
"mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
|
|
35
|
+
"math": SDPBackend.MATH,
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
BACKEND_NAMES = tuple(_BACKEND_MAP.keys())
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class MaskUnitAttentionFast(nn.Module):
|
|
42
|
+
"""Drop-in replacement for models.hiera.MaskUnitAttention with 4D SDPA.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
sdpa_backend: optional SDPA backend hint. One of {"cudnn", "flash",
|
|
46
|
+
"mem_efficient", "math", None}. None = default dispatcher.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
dim: int,
|
|
52
|
+
dim_out: int,
|
|
53
|
+
heads: int,
|
|
54
|
+
q_stride: int = 1,
|
|
55
|
+
window_size: int = 0,
|
|
56
|
+
use_mask_unit_attn: bool = False,
|
|
57
|
+
sdpa_backend: Optional[str] = None,
|
|
58
|
+
):
|
|
59
|
+
super().__init__()
|
|
60
|
+
self.dim = dim
|
|
61
|
+
self.dim_out = dim_out
|
|
62
|
+
self.heads = heads
|
|
63
|
+
self.q_stride = q_stride
|
|
64
|
+
self.head_dim = dim_out // heads
|
|
65
|
+
self.scale = self.head_dim ** -0.5
|
|
66
|
+
self.qkv = nn.Linear(dim, 3 * dim_out)
|
|
67
|
+
self.proj = nn.Linear(dim_out, dim_out)
|
|
68
|
+
self.window_size = window_size
|
|
69
|
+
self.use_mask_unit_attn = use_mask_unit_attn
|
|
70
|
+
if sdpa_backend is not None and sdpa_backend not in _BACKEND_MAP:
|
|
71
|
+
raise ValueError(f"sdpa_backend must be one of {list(_BACKEND_MAP)} or None; got {sdpa_backend!r}")
|
|
72
|
+
self.sdpa_backend = sdpa_backend
|
|
73
|
+
|
|
74
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
75
|
+
B, N, _ = x.shape
|
|
76
|
+
H, D = self.heads, self.head_dim
|
|
77
|
+
nw = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
|
|
78
|
+
T = N // nw # tokens per window before q-pool
|
|
79
|
+
|
|
80
|
+
# qkv: (B, N, 3*dim_out)
|
|
81
|
+
# FAIR layout: N is interpreted as (T, nw) with nw fast-varying.
|
|
82
|
+
# We want (3, B*nw, H, T, D) so SDPA gets a 4D Q/K/V.
|
|
83
|
+
qkv = self.qkv(x).view(B, T, nw, 3, H, D)
|
|
84
|
+
# permute -> (3, B, nw, H, T, D), then flatten (B, nw) into batch
|
|
85
|
+
qkv = qkv.permute(3, 0, 2, 4, 1, 5).reshape(3, B * nw, H, T, D)
|
|
86
|
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
|
87
|
+
|
|
88
|
+
if self.q_stride > 1:
|
|
89
|
+
# Max-pool Q over the q_stride flat axis. T = q_stride * Tq.
|
|
90
|
+
# FAIR's do_pool: view(B, stride, -1, C).max(dim=1) — stride is OUTER.
|
|
91
|
+
# So inside T, q_stride is slow-varying. Mirror: view (.., q_stride, Tq, ..)
|
|
92
|
+
Tq = T // self.q_stride
|
|
93
|
+
q = q.view(B * nw, H, self.q_stride, Tq, D).amax(dim=2)
|
|
94
|
+
|
|
95
|
+
# 4D SDPA → FlashAttention/cuDNN. Optionally pin a backend.
|
|
96
|
+
if self.sdpa_backend is None:
|
|
97
|
+
out = F.scaled_dot_product_attention(q, k, v)
|
|
98
|
+
else:
|
|
99
|
+
with sdpa_kernel([_BACKEND_MAP[self.sdpa_backend]]):
|
|
100
|
+
out = F.scaled_dot_product_attention(q, k, v)
|
|
101
|
+
# out: (B*nw, H, Tq, D)
|
|
102
|
+
Tq_out = out.shape[2]
|
|
103
|
+
|
|
104
|
+
# Back to (B, N_out, dim_out) with N_out indexed as (tq, w) tq slow / w fast
|
|
105
|
+
# out: (B*nw, H, Tq, D) -> (B, nw, H, Tq, D) -> (B, Tq, nw, H, D) -> reshape
|
|
106
|
+
out = out.view(B, nw, H, Tq_out, D).permute(0, 3, 1, 2, 4).reshape(B, Tq_out * nw, self.dim_out)
|
|
107
|
+
return self.proj(out)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@torch.no_grad()
def copy_weights_from_orig(fast: MaskUnitAttentionFast, orig) -> None:
    """Overwrite `fast`'s projection parameters with those of `orig`.

    Copies weight and bias of the `qkv` and `proj` layers in place; both
    modules must expose those two layers with matching shapes.
    """
    for layer_name in ("qkv", "proj"):
        target = getattr(fast, layer_name)
        source = getattr(orig, layer_name)
        target.weight.copy_(source.weight)
        target.bias.copy_(source.bias)
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Selective gradient checkpointing for Hiera.
|
|
2
|
+
|
|
3
|
+
Wraps specific blocks (default: stage-2) with `torch.utils.checkpoint` so that
|
|
4
|
+
their activations are *not* stored during forward but recomputed on backward.
|
|
5
|
+
Costs ~33% extra compute on the checkpointed blocks; saves ~50% of activation
|
|
6
|
+
memory total in Hiera-Base (stage-2 holds 16 of 24 enc blocks).
|
|
7
|
+
|
|
8
|
+
Use case: fit larger per-GPU batch → climb out of launch-overhead regime →
|
|
9
|
+
higher MFU. Especially useful when going from B=128 → B=256 per GPU on H100.
|
|
10
|
+
|
|
11
|
+
Compatible with `torch.compile(mode="default")`. Use `use_reentrant=False`
|
|
12
|
+
(the modern checkpoint variant) to avoid known autograd issues with reentrant
|
|
13
|
+
checkpointing + compile.
|
|
14
|
+
"""
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
from typing import Iterable, Optional, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
from torch.utils.checkpoint import checkpoint
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class _CheckpointedBlock(nn.Module):
    """Wraps a block so its forward goes through `torch.utils.checkpoint`.

    The wrapped module is registered as the submodule `block`, so its
    parameters stay visible to optimizers. Note that this also means they
    appear under an extra `block.` prefix in `state_dict()`; unwrap the
    blocks before saving or loading if the original key layout is required.

    Attributes that the wrapper itself does not define are looked up on the
    wrapped block. Works for any single-input single-output `nn.Module`.

    Args:
        block: the module whose forward should be checkpointed.
    """

    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped block on `x` under non-reentrant checkpointing."""
        return checkpoint(self.block, x, use_reentrant=False)

    # Pass-through for `.attn`, `.dim`, etc. that downstream code may inspect
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If `block` itself is missing (e.g. an instance created through
            # `__new__` before `__init__` ran), delegating would look up
            # `self.block` again and recurse without bound. Report a normal
            # AttributeError instead.
            if name == "block":
                raise
            return getattr(self.block, name)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _block_to_stage(model: nn.Module) -> list[int]:
    """Build a per-block list of stage indices from `model.stage_ends`.

    `stage_ends[s]` is read as the index of the last block in stage `s`.
    The result has one entry for every block index from 0 up to and
    including `stage_ends[-1]`: the first stage whose end index is not
    below that block index.
    """
    ends = model.stage_ends
    n_blocks = ends[-1] + 1
    return [
        next(stage for stage, last in enumerate(ends) if block_idx <= last)
        for block_idx in range(n_blocks)
    ]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def enable_stage_checkpointing(
    model: nn.Module,
    stages: Union[Iterable[int], str] = (2,),
    decoder: bool = False,
) -> int:
    """Wrap the blocks of selected stages with checkpointing, in place.

    Blocks that are already wrapped are left alone, so calling this more
    than once does not wrap a block twice.

    Args:
        model: a built model exposing `blocks` and `stage_ends`.
        stages: stage indices whose blocks should be wrapped (default: only
            stage 2), or the string "all" for every stage.
        decoder: when True and the model has `decoder_blocks`, wrap every
            decoder block as well.

    Returns:
        The number of blocks wrapped by this call.

    Raises:
        ValueError: if `stages` is a string other than "all".
    """
    if not isinstance(stages, str):
        wanted = set(stages)
    elif stages == "all":
        wanted = set(range(len(model.stage_ends)))
    else:
        raise ValueError(f"stages must be an iterable or 'all'; got {stages!r}")

    stage_of = _block_to_stage(model)
    wrapped = 0
    for idx, blk in enumerate(model.blocks):
        # Skip already-wrapped blocks first so the call stays idempotent.
        if isinstance(blk, _CheckpointedBlock) or stage_of[idx] not in wanted:
            continue
        model.blocks[idx] = _CheckpointedBlock(blk)
        wrapped += 1

    if decoder and hasattr(model, "decoder_blocks"):
        for idx, blk in enumerate(model.decoder_blocks):
            if not isinstance(blk, _CheckpointedBlock):
                model.decoder_blocks[idx] = _CheckpointedBlock(blk)
                wrapped += 1

    return wrapped
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def disable_stage_checkpointing(model: nn.Module) -> int:
    """Replace every checkpoint wrapper with the block it wraps, in place.

    Covers `model.blocks` and, when the attribute exists,
    `model.decoder_blocks`.

    Returns:
        The number of wrappers removed.
    """

    def _unwrap_all(block_list) -> int:
        # Swap each wrapper in this list for its inner block; count the swaps.
        removed = 0
        for idx, blk in enumerate(block_list):
            if isinstance(blk, _CheckpointedBlock):
                block_list[idx] = blk.block
                removed += 1
        return removed

    total = _unwrap_all(model.blocks)
    if hasattr(model, "decoder_blocks"):
        total += _unwrap_all(model.decoder_blocks)
    return total
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Triton kernels.
|
|
2
|
+
|
|
3
|
+
These are pure shape-parametric kernels with no model dependencies. Each
|
|
4
|
+
provides a PyTorch reference implementation alongside the Triton kernel for
|
|
5
|
+
correctness testing.
|
|
6
|
+
|
|
7
|
+
Bundled:
|
|
8
|
+
- mask_gather: Triton mask-unit gather. Correct, but `torch.gather` is
|
|
9
|
+
faster at the shapes we tested — kept for fusion experiments.
|
|
10
|
+
- flash_qpool: Triton fused Q-pool + FlashAttention. Correct, but loses
|
|
11
|
+
to PyTorch's flash on RTX 4090 short-seq workloads. May be useful on
|
|
12
|
+
Hopper with hand-tuned TMA paths.
|
|
13
|
+
"""
|
|
14
|
+
from .mask_gather import gather_mu_groups_triton, keep_idx_from_bool_mask
|
|
15
|
+
from .flash_qpool import flash_qpool_attention, qpool_attention_ref
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"gather_mu_groups_triton",
|
|
19
|
+
"keep_idx_from_bool_mask",
|
|
20
|
+
"flash_qpool_attention",
|
|
21
|
+
"qpool_attention_ref",
|
|
22
|
+
]
|