hiera-optim 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,44 @@
1
+ """hiera_optim — training-throughput optimisations for FAIR's Hiera.
2
+
3
+ Quick start::
4
+
5
+ from models.hiera import mae_hiera_base_224 # FAIR Hiera
6
+ from hiera_optim import optimize
7
+
8
+ model = mae_hiera_base_224(pretrained=False, in_chans=8)
9
+ optimize(model) # in-place
10
+ # optionally: model = torch.compile(model, mode="default", dynamic=False)
11
+
12
+ What `optimize()` does:
13
+ 1. Swaps every `MaskUnitAttention` for a FlashAttention/cuDNN-friendly
14
+ 4-D variant (`MaskUnitAttentionFast`). Restores math-fallback SDPA to
15
+ fused kernel paths — 5-12× per-call attention speedup.
16
+ 2. Replaces the boolean `x[mask.tile(...)]` and `x_dec[mask] = ...`
17
+ indexing patterns with explicit `torch.gather` / `scatter_`. Removes
18
+ the `aten::nonzero` graph break (compile-friendly).
19
+
20
+ Optional add-ons (opt-in, not invoked by default):
21
+ - `optimize(model, sdpa_backend="auto" | "cudnn" | ...)`: pin the SDPA
22
+ backend per-block. Sometimes helps, sometimes hurts — see RESULTS.md.
23
+ - `enable_stage_checkpointing(model, stages=(2,))`: trade compute for
24
+ activation memory at chosen stages. OOM lever, not a throughput tool.
25
+ """
26
+ from .patch import optimize, swap_mask_unit_attention, recommended_backend
27
+ from .attention import MaskUnitAttentionFast, BACKEND_NAMES
28
+ from .checkpoint import enable_stage_checkpointing, disable_stage_checkpointing
29
+ from .adapters import HieraAdapter, get_hiera_adapter, auto_detect
30
+
31
+ __version__ = "0.1.0"
32
+
33
+ __all__ = [
34
+ "optimize",
35
+ "swap_mask_unit_attention",
36
+ "recommended_backend",
37
+ "MaskUnitAttentionFast",
38
+ "BACKEND_NAMES",
39
+ "enable_stage_checkpointing",
40
+ "disable_stage_checkpointing",
41
+ "HieraAdapter",
42
+ "get_hiera_adapter",
43
+ "auto_detect",
44
+ ]
@@ -0,0 +1,28 @@
1
+ """Model adapters.
2
+
3
+ The rest of the package is intentionally model-name-free: it operates on
4
+ PyTorch `nn.Module` graphs and looks up attention/block classes through
5
+ adapters. Each adapter teaches `optimize()` how to find the right submodules
6
+ on a specific model family.
7
+
8
+ Currently bundled:
9
+ - hiera: FAIR's Hiera / MaskedAutoencoderHiera (https://github.com/facebookresearch/hiera)
10
+
11
+ Other adapters (Swin, MViTv2, JEPA-Hiera, custom architectures) can plug in by
12
+ implementing the same `ModelAdapter` protocol.
13
+ """
14
+ from __future__ import annotations
15
+ from typing import Optional
16
+ from .hiera import HieraAdapter, get_hiera_adapter
17
+
18
+ __all__ = ["HieraAdapter", "get_hiera_adapter", "auto_detect"]
19
+
20
+
21
def auto_detect(model) -> Optional["ModelAdapter"]:  # type: ignore[name-defined]
    """Pick the bundled adapter that recognises ``model``.

    Every bundled adapter factory is probed in turn and the first non-None
    adapter is returned. Returns None when none of them accepts the model.
    """
    for probe in (get_hiera_adapter,):
        adapter = probe(model)
        if adapter is not None:
            return adapter
    return None
@@ -0,0 +1,142 @@
1
+ """Adapter for FAIR Hiera (https://github.com/facebookresearch/hiera).
2
+
3
+ This is the ONE module in `hiera_optim` allowed to import FAIR classes by name.
4
+ Everything else works through this adapter's protocol so the package is
5
+ trivially portable to derivative architectures.
6
+
7
+ The adapter resolves at import time and degrades gracefully if FAIR Hiera
8
+ isn't installed — `get_hiera_adapter(model)` returns None instead of crashing.
9
+ """
10
+ from __future__ import annotations
11
+ import importlib
12
+ from dataclasses import dataclass
13
+ from typing import Any, Optional, Type
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
@dataclass(frozen=True)
class _HieraSymbols:
    """Immutable bundle of the FAIR Hiera objects this package refers to.

    Instances are frozen, so a resolved bundle cannot be re-pointed at
    different classes after construction.
    """

    # Model / layer classes, used for isinstance checks and class lookups.
    Hiera: Type[nn.Module]
    MaskedAutoencoderHiera: Type[nn.Module]
    HieraBlock: Type[nn.Module]
    MaskUnitAttention: Type[nn.Module]

    # Free functions forwarded to by the adapter's helper methods.
    apply_fusion_head: callable
    undo_windowing: callable
28
+
29
+
30
def _resolve() -> Optional["_HieraSymbols"]:
    """Locate the FAIR Hiera symbols under the known import layouts.

    Each candidate layout names one or more modules that are searched, in
    order, for the model-side symbols, plus the module expected to provide
    `undo_windowing`. The first layout that yields every symbol wins.

    Returns:
        The resolved `_HieraSymbols`, or None when no layout is complete
        (Hiera is not installed, or its module layout is not recognised).
    """
    model_symbols = (
        "Hiera",
        "MaskedAutoencoderHiera",
        "HieraBlock",
        "MaskUnitAttention",
        "apply_fusion_head",
    )
    candidates = [
        # In-tree (brain_atlas, the development environment)
        (("models.hiera",), "utils.hiera_utils"),
        # PyPI package (`pip install hiera-transformer`). The upstream
        # repository appears to keep the MAE model and `apply_fusion_head`
        # in `hiera.hiera_mae` rather than `hiera.hiera`, so that module is
        # searched as a fallback source for any symbol still missing.
        # NOTE(review): confirm against the installed package version.
        (("hiera.hiera", "hiera.hiera_mae"), "hiera.hiera_utils"),
        (("hiera", "hiera.hiera_mae"), "hiera.hiera_utils"),
    ]
    for model_modules, utils_module in candidates:
        found = {}
        try:
            for module_name in model_modules:
                module = importlib.import_module(module_name)
                for symbol in model_symbols:
                    # Earlier modules take priority; later ones only fill gaps.
                    if symbol not in found and hasattr(module, symbol):
                        found[symbol] = getattr(module, symbol)
                if len(found) == len(model_symbols):
                    break  # nothing left to look for in this layout
            if len(found) != len(model_symbols):
                continue  # incomplete layout; try the next one
            found["undo_windowing"] = importlib.import_module(utils_module).undo_windowing
        except (ImportError, AttributeError):
            continue
        return _HieraSymbols(**found)
    return None
54
+
55
+
56
+ _SYMS = _resolve()
57
+
58
+
59
def is_available() -> bool:
    """Report whether the module-level `_SYMS` lookup produced a result."""
    resolved = _SYMS
    return resolved is not None
61
+
62
+
63
def symbols() -> "_HieraSymbols":
    """Hand back the resolved FAIR symbols.

    Raises:
        ImportError: if the module-level resolution found nothing.
    """
    if _SYMS is not None:
        return _SYMS
    raise ImportError(
        "FAIR Hiera not installed. Add `pip install hiera-transformer`, "
        "or ensure `models.hiera` is importable in this project."
    )
71
+
72
+
73
class HieraAdapter:
    """Bridge between FAIR's Hiera classes and the generic patching code.

    All FAIR-specific knowledge is reached through the symbol bundle stored
    on the instance at construction; the methods below only read attributes
    of the model they are given or delegate to that bundle.
    """

    # Layout convention.
    #
    # FAIR's `Unroll` produces a (T outer, nw inner) token layout: token n in
    # the N axis maps to (t = n // nw, w = n % nw). FAIR's `do_pool(x, stride)`
    # (which expects stride as the OUTER dim) and the MaskUnitAttention
    # reshape pattern both rely on this.
    layout = "T_outer_nw_inner"

    def __init__(self):
        """Bind the resolved FAIR symbols; raises ImportError if unavailable."""
        if not is_available():
            raise ImportError("HieraAdapter requires FAIR Hiera to be importable.")
        self._syms = symbols()

    # ---- Introspection -----------------------------------------------------

    def matches(self, model: nn.Module) -> bool:
        """Tell whether `model` is an instance of a Hiera-family class."""
        family = (self._syms.Hiera, self._syms.MaskedAutoencoderHiera)
        return isinstance(model, family)

    def is_mae(self, model: nn.Module) -> bool:
        """Tell whether `model` is a MaskedAutoencoderHiera instance."""
        mae_cls = self._syms.MaskedAutoencoderHiera
        return isinstance(model, mae_cls)

    def block_class(self) -> Type[nn.Module]:
        """The resolved HieraBlock class."""
        return self._syms.HieraBlock

    def attention_class(self) -> Type[nn.Module]:
        """The resolved MaskUnitAttention class."""
        return self._syms.MaskUnitAttention

    def encoder_blocks(self, model: nn.Module) -> nn.ModuleList:
        """The encoder block list, i.e. `model.blocks`."""
        return model.blocks

    def decoder_blocks(self, model: nn.Module) -> Optional[nn.ModuleList]:
        """`model.decoder_blocks` if the attribute exists, otherwise None."""
        try:
            return model.decoder_blocks
        except AttributeError:
            return None

    def stage_ends(self, model: nn.Module) -> list[int]:
        """A fresh list copy of `model.stage_ends`."""
        return [*model.stage_ends]

    # ---- Helpers used by the patched forwards -----------------------------

    def apply_fusion_head(self, head: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the resolved `apply_fusion_head` function."""
        fuse = self._syms.apply_fusion_head
        return fuse(head, x)

    def undo_windowing(self, x, shape, mu_shape):
        """Delegate to the resolved `undo_windowing` function."""
        unwindow = self._syms.undo_windowing
        return unwindow(x, shape, mu_shape)
129
+
130
+
131
+ # Module-level singleton resolution (cheap; just an isinstance check)
132
+ _DEFAULT_ADAPTER: Optional[HieraAdapter] = None
133
+
134
+
135
def get_hiera_adapter(model: nn.Module) -> Optional["HieraAdapter"]:
    """Return the shared HieraAdapter when `model` is Hiera-family, else None.

    The adapter instance is created on first use and cached in the
    module-level `_DEFAULT_ADAPTER`. None is also returned when FAIR Hiera
    could not be resolved at all.
    """
    global _DEFAULT_ADAPTER
    if not is_available():
        return None
    adapter = _DEFAULT_ADAPTER
    if adapter is None:
        adapter = _DEFAULT_ADAPTER = HieraAdapter()
    if adapter.matches(model):
        return adapter
    return None
@@ -0,0 +1,6 @@
1
+ """Attention modules. Currently exports MaskUnitAttentionFast — the
2
+ FlashAttention/cuDNN-friendly 4-D variant of FAIR's MaskUnitAttention.
3
+ """
4
+ from .mask_unit import MaskUnitAttentionFast, copy_weights_from_orig, BACKEND_NAMES
5
+
6
+ __all__ = ["MaskUnitAttentionFast", "copy_weights_from_orig", "BACKEND_NAMES"]
@@ -0,0 +1,116 @@
1
+ """Optimized MaskUnitAttention — drop-in replacement.
2
+
3
+ Key fixes vs FAIR original (models/hiera.py):
4
+ - Reshape Q/K/V to 4D (B*num_windows, heads, T, D) so SDPA dispatches to
5
+ FlashAttention (cuDNN/Flash). Original feeds 5D tensors which fall back
6
+ to the math backend (12-13x slower on stage-0 shapes per microbench).
7
+ - Match FAIR's N-axis layout exactly: token n in input has n = t*nw + w,
8
+ so within-window positions are SLOW-varying and num_windows is FAST.
9
+ That's because the upstream Unroll module produces this interleaved
10
+ layout (and do_pool relies on the stride axis being outer).
11
+ - Skip per-tensor .contiguous() — the permute lands flash-friendly.
12
+ - Optional per-stage SDPA backend hint via `sdpa_backend` attribute. Set to
13
+ one of {"cudnn", "flash", "mem_efficient", "math", None}; None lets the
14
+ PyTorch dispatcher pick. Per-stage tuning is useful because Hiera's small
15
+ Tq stages (stage-1 post q_pool: T=16) often favor mem-efficient while the
16
+ long-seq global-attention stages favor cuDNN-attn or flash on Hopper.
17
+
18
+ Numerically identical to FAIR baseline up to bf16 noise (~1e-2 max abs diff,
19
+ ~1e-3 rel RMS) — the noise is the difference between SDPA math vs flash.
20
+ """
21
+ from __future__ import annotations
22
+
23
+ from typing import Optional
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch.nn.attention import SDPBackend, sdpa_kernel
29
+
30
+
31
+ _BACKEND_MAP = {
32
+ "cudnn": SDPBackend.CUDNN_ATTENTION,
33
+ "flash": SDPBackend.FLASH_ATTENTION,
34
+ "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
35
+ "math": SDPBackend.MATH,
36
+ }
37
+
38
+ BACKEND_NAMES = tuple(_BACKEND_MAP.keys())
39
+
40
+
41
+ class MaskUnitAttentionFast(nn.Module):
42
+ """Drop-in replacement for models.hiera.MaskUnitAttention with 4D SDPA.
43
+
44
+ Args:
45
+ sdpa_backend: optional SDPA backend hint. One of {"cudnn", "flash",
46
+ "mem_efficient", "math", None}. None = default dispatcher.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ dim: int,
52
+ dim_out: int,
53
+ heads: int,
54
+ q_stride: int = 1,
55
+ window_size: int = 0,
56
+ use_mask_unit_attn: bool = False,
57
+ sdpa_backend: Optional[str] = None,
58
+ ):
59
+ super().__init__()
60
+ self.dim = dim
61
+ self.dim_out = dim_out
62
+ self.heads = heads
63
+ self.q_stride = q_stride
64
+ self.head_dim = dim_out // heads
65
+ self.scale = self.head_dim ** -0.5
66
+ self.qkv = nn.Linear(dim, 3 * dim_out)
67
+ self.proj = nn.Linear(dim_out, dim_out)
68
+ self.window_size = window_size
69
+ self.use_mask_unit_attn = use_mask_unit_attn
70
+ if sdpa_backend is not None and sdpa_backend not in _BACKEND_MAP:
71
+ raise ValueError(f"sdpa_backend must be one of {list(_BACKEND_MAP)} or None; got {sdpa_backend!r}")
72
+ self.sdpa_backend = sdpa_backend
73
+
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ B, N, _ = x.shape
76
+ H, D = self.heads, self.head_dim
77
+ nw = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
78
+ T = N // nw # tokens per window before q-pool
79
+
80
+ # qkv: (B, N, 3*dim_out)
81
+ # FAIR layout: N is interpreted as (T, nw) with nw fast-varying.
82
+ # We want (3, B*nw, H, T, D) so SDPA gets a 4D Q/K/V.
83
+ qkv = self.qkv(x).view(B, T, nw, 3, H, D)
84
+ # permute -> (3, B, nw, H, T, D), then flatten (B, nw) into batch
85
+ qkv = qkv.permute(3, 0, 2, 4, 1, 5).reshape(3, B * nw, H, T, D)
86
+ q, k, v = qkv[0], qkv[1], qkv[2]
87
+
88
+ if self.q_stride > 1:
89
+ # Max-pool Q over the q_stride flat axis. T = q_stride * Tq.
90
+ # FAIR's do_pool: view(B, stride, -1, C).max(dim=1) — stride is OUTER.
91
+ # So inside T, q_stride is slow-varying. Mirror: view (.., q_stride, Tq, ..)
92
+ Tq = T // self.q_stride
93
+ q = q.view(B * nw, H, self.q_stride, Tq, D).amax(dim=2)
94
+
95
+ # 4D SDPA → FlashAttention/cuDNN. Optionally pin a backend.
96
+ if self.sdpa_backend is None:
97
+ out = F.scaled_dot_product_attention(q, k, v)
98
+ else:
99
+ with sdpa_kernel([_BACKEND_MAP[self.sdpa_backend]]):
100
+ out = F.scaled_dot_product_attention(q, k, v)
101
+ # out: (B*nw, H, Tq, D)
102
+ Tq_out = out.shape[2]
103
+
104
+ # Back to (B, N_out, dim_out) with N_out indexed as (tq, w) tq slow / w fast
105
+ # out: (B*nw, H, Tq, D) -> (B, nw, H, Tq, D) -> (B, Tq, nw, H, D) -> reshape
106
+ out = out.view(B, nw, H, Tq_out, D).permute(0, 3, 1, 2, 4).reshape(B, Tq_out * nw, self.dim_out)
107
+ return self.proj(out)
108
+
109
+
110
@torch.no_grad()
def copy_weights_from_orig(fast: "MaskUnitAttentionFast", orig) -> None:
    """Copy the `qkv` and `proj` weights and biases from `orig` into `fast`.

    The copy is done in place on `fast`'s existing parameters, with autograd
    disabled by the decorator.
    """
    for layer_name in ("qkv", "proj"):
        target = getattr(fast, layer_name)
        source = getattr(orig, layer_name)
        target.weight.copy_(source.weight)
        target.bias.copy_(source.bias)
@@ -0,0 +1,113 @@
1
+ """Selective gradient checkpointing for Hiera.
2
+
3
+ Wraps specific blocks (default: stage-2) with `torch.utils.checkpoint` so that
4
+ their activations are *not* stored during forward but recomputed on backward.
5
+ Costs ~33% extra compute on the checkpointed blocks; saves ~50% of activation
6
+ memory total in Hiera-Base (stage-2 holds 16 of 24 enc blocks).
7
+
8
+ Use case: fit larger per-GPU batch → climb out of launch-overhead regime →
9
+ higher MFU. Especially useful when going from B=128 → B=256 per GPU on H100.
10
+
11
+ Compatible with `torch.compile(mode="default")`. Use `use_reentrant=False`
12
+ (the modern checkpoint variant) to avoid known autograd issues with reentrant
13
+ checkpointing + compile.
14
+ """
15
+ from __future__ import annotations
16
+ from typing import Iterable, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.utils.checkpoint import checkpoint
21
+
22
+
23
class _CheckpointedBlock(nn.Module):
    """Wraps a block so its forward goes through `torch.utils.checkpoint`.

    The original block is kept as the `block` submodule and unknown attribute
    lookups are forwarded to it, so code that inspects `.attn`, `.dim`, etc.
    keeps working. Works for any single-input single-output `nn.Module`.

    Note: because the wrapped module is registered under the name `block`,
    parameter names seen through the wrapper (state_dict keys) gain an extra
    `block.` segment compared with the unwrapped model.
    """

    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped block under non-reentrant activation checkpointing."""
        return checkpoint(self.block, x, use_reentrant=False)

    # Pass-through for `.attn`, `.dim`, etc. that downstream code may inspect
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If `block` itself is missing (e.g. a partially constructed
            # instance), delegating would look up `self.block` again and
            # recurse without bound; surface the AttributeError instead.
            if name == "block":
                raise
            return getattr(self.block, name)
44
+
45
+
46
def _block_to_stage(model: nn.Module) -> list[int]:
    """Build the block_idx -> stage_idx lookup from `model.stage_ends`.

    A block belongs to the first stage whose end index is not smaller than
    the block index. The result has `stage_ends[-1] + 1` entries.
    """
    ends = model.stage_ends
    n_blocks = ends[-1] + 1
    return [
        next(stage for stage, last in enumerate(ends) if idx <= last)
        for idx in range(n_blocks)
    ]
56
+
57
+
58
def enable_stage_checkpointing(
    model: nn.Module,
    stages: Union[Iterable[int], str] = (2,),
    decoder: bool = False,
) -> int:
    """Wrap the blocks of the selected stages with checkpointing, in place.

    Args:
        model: object exposing `blocks` and `stage_ends` (a built Hiera /
            MaskedAutoencoderHiera).
        stages: stage indices whose blocks get wrapped (default: stage 2
            only), or the string "all" for every stage.
        decoder: when True and the model has `decoder_blocks`, wrap every
            decoder block as well. Defaults to False.

    Returns:
        How many blocks were newly wrapped. Blocks that are already wrapped
        are skipped, so repeated calls do not double-wrap.

    Raises:
        ValueError: if `stages` is a string other than "all".
    """
    if isinstance(stages, str):
        if stages != "all":
            raise ValueError(f"stages must be an iterable or 'all'; got {stages!r}")
        selected = set(range(len(model.stage_ends)))
    else:
        selected = set(stages)

    stage_of = _block_to_stage(model)
    wrapped = 0

    for idx, blk in enumerate(model.blocks):
        # Skip blocks that are already wrapped or sit outside the selection.
        if isinstance(blk, _CheckpointedBlock) or stage_of[idx] not in selected:
            continue
        model.blocks[idx] = _CheckpointedBlock(blk)
        wrapped += 1

    if decoder and hasattr(model, "decoder_blocks"):
        for idx, blk in enumerate(model.decoder_blocks):
            if not isinstance(blk, _CheckpointedBlock):
                model.decoder_blocks[idx] = _CheckpointedBlock(blk)
                wrapped += 1

    return wrapped
99
+
100
+
101
def disable_stage_checkpointing(model: nn.Module) -> int:
    """Replace every `_CheckpointedBlock` wrapper with the block it wraps.

    Covers `model.blocks` and, when the attribute exists,
    `model.decoder_blocks`. Returns the number of wrappers removed.
    """

    def _strip(module_list) -> int:
        # Swap wrappers for their inner block in place and count the swaps.
        removed = 0
        for pos, item in enumerate(module_list):
            if isinstance(item, _CheckpointedBlock):
                module_list[pos] = item.block
                removed += 1
        return removed

    total = _strip(model.blocks)
    if hasattr(model, "decoder_blocks"):
        total += _strip(model.decoder_blocks)
    return total
@@ -0,0 +1,22 @@
1
+ """Triton kernels.
2
+
3
+ These are pure shape-parametric kernels with no model dependencies. Each
4
+ provides a PyTorch reference implementation alongside the Triton kernel for
5
+ correctness testing.
6
+
7
+ Bundled:
8
+ - mask_gather: Triton mask-unit gather. Correct, but `torch.gather` is
9
+ faster at the shapes we tested — kept for fusion experiments.
10
+ - flash_qpool: Triton fused Q-pool + FlashAttention. Correct, but loses
11
+ to PyTorch's flash on RTX 4090 short-seq workloads. May be useful on
12
+ Hopper with hand-tuned TMA paths.
13
+ """
14
+ from .mask_gather import gather_mu_groups_triton, keep_idx_from_bool_mask
15
+ from .flash_qpool import flash_qpool_attention, qpool_attention_ref
16
+
17
+ __all__ = [
18
+ "gather_mu_groups_triton",
19
+ "keep_idx_from_bool_mask",
20
+ "flash_qpool_attention",
21
+ "qpool_attention_ref",
22
+ ]