nervecode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nervecode/__init__.py +415 -0
- nervecode/_version.py +10 -0
- nervecode/core/__init__.py +19 -0
- nervecode/core/assignment.py +165 -0
- nervecode/core/codebook.py +182 -0
- nervecode/core/shapes.py +107 -0
- nervecode/core/temperature.py +227 -0
- nervecode/core/trace.py +166 -0
- nervecode/core/types.py +116 -0
- nervecode/integration/__init__.py +9 -0
- nervecode/layers/__init__.py +15 -0
- nervecode/layers/base.py +333 -0
- nervecode/layers/conv.py +174 -0
- nervecode/layers/linear.py +176 -0
- nervecode/layers/reducers.py +80 -0
- nervecode/layers/wrap.py +223 -0
- nervecode/scoring/__init__.py +20 -0
- nervecode/scoring/aggregator.py +369 -0
- nervecode/scoring/calibrator.py +396 -0
- nervecode/scoring/types.py +33 -0
- nervecode/training/__init__.py +25 -0
- nervecode/training/diagnostics.py +194 -0
- nervecode/training/loss.py +188 -0
- nervecode/training/updaters.py +168 -0
- nervecode/utils/__init__.py +14 -0
- nervecode/utils/overhead.py +177 -0
- nervecode/utils/seed.py +161 -0
- nervecode-0.1.0.dist-info/METADATA +83 -0
- nervecode-0.1.0.dist-info/RECORD +31 -0
- nervecode-0.1.0.dist-info/WHEEL +4 -0
- nervecode-0.1.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""Coding wrapper for ``torch.nn.Linear``.
|
|
2
|
+
|
|
3
|
+
This module provides ``CodingLinear`` — the first production wrapper that
|
|
4
|
+
observes a linear layer's output, reduces it (identity in the MVP), and
|
|
5
|
+
computes a coding trace via a learnable codebook and a soft assignment engine.
|
|
6
|
+
|
|
7
|
+
Observe-only contract: the wrapper never modifies the value returned to the
|
|
8
|
+
caller; coding runs side-by-side and updates an internal trace cache.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from .base import BaseCodingWrapper, ReducerFn
|
|
16
|
+
|
|
17
|
+
# Import torch lazily but with typing support. Keep runtime import optional so the
|
|
18
|
+
# package remains importable in environments without torch (tests will skip).
|
|
19
|
+
try: # pragma: no cover - exercised in torch-enabled environments
|
|
20
|
+
import torch
|
|
21
|
+
from torch import nn
|
|
22
|
+
except Exception: # pragma: no cover - allow import without torch
|
|
23
|
+
from typing import Any
|
|
24
|
+
from typing import cast as _cast
|
|
25
|
+
|
|
26
|
+
torch = _cast(Any, None)
|
|
27
|
+
nn = _cast(Any, None)
|
|
28
|
+
|
|
29
|
+
# Local core imports guarded by torch availability at use sites.
|
|
30
|
+
from nervecode.core import CodingTrace
|
|
31
|
+
from nervecode.core.assignment import SoftAssignment
|
|
32
|
+
from nervecode.core.codebook import Codebook
|
|
33
|
+
|
|
34
|
+
__all__ = ["CodingLinear"]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class CodingLinear(nn.Module, BaseCodingWrapper):
    """Observe-only wrapper around ``nn.Linear``.

    The wrapper delegates computation to the underlying ``nn.Linear`` and, when
    coding is active, computes a ``CodingTrace`` over the visible activation.
    The visible forward output is never altered.

    Parameters
    - layer: The ``nn.Linear`` module to wrap.
    - K: Number of codebook centers. Defaults to 16.
    - coding_dim: Target coding dimension ``D``. When ``None`` (default),
      uses ``layer.out_features``. If provided, must satisfy
      ``1 <= coding_dim <= layer.out_features``. A learned linear projection
      reducer is created automatically when ``layer.out_features > coding_dim``;
      otherwise the identity reduction is used.
    - reducer: Optional reducer callable. Overrides ``coding_dim`` behavior
      when provided. Defaults to identity (no reduction).
    - temperature: Soft-assignment temperature. Higher is softer.
    - eps: Numerical epsilon to clamp divisions and logs.

    Notes
    - When no explicit ``reducer`` is given and ``coding_dim`` is strictly
      smaller than ``layer.out_features``, a bias-free ``nn.Linear`` projection
      is registered as a submodule and applied only on the coding path.
    - ``repr(wrapper)`` includes ``code_dim``, codebook size ``K``, reducer type,
      and whether coding is enabled to make printed models self-descriptive.
    """

    def __init__(
        self,
        layer: Any,  # use Any to avoid a hard torch dependency in type hints
        *,
        K: int = 16,
        coding_dim: int | None = None,
        reducer: ReducerFn | None = None,
        temperature: float = 1.0,
        eps: float = 1e-12,
    ) -> None:
        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("CodingLinear requires PyTorch to be installed")

        nn.Module.__init__(self)
        BaseCodingWrapper.__init__(self, reducer=reducer)

        if not isinstance(layer, nn.Linear):
            raise TypeError("layer must be an instance of torch.nn.Linear")

        self.layer: Any = layer

        # Decide coding dimension and set up reducer if not provided.
        out_features = int(self.layer.out_features)
        if coding_dim is None:
            code_dim = out_features
        else:
            code_dim = int(coding_dim)
            if code_dim <= 0:
                raise ValueError("coding_dim must be >= 1")
            if code_dim > out_features:
                raise ValueError("coding_dim must be <= layer.out_features for reduction")

        # Optional learned projection reducer when output width exceeds coding dim.
        self._projection: Any | None = None
        if reducer is None and code_dim < out_features:
            # Register a submodule to learn the reduction; keep observe-only semantics
            # by applying it only for coding. Bias is unnecessary for dimension change.
            self._projection = nn.Linear(out_features, code_dim, bias=False)

            def _project_reducer(x: Any) -> tuple[Any, dict[str, Any]]:
                y_proj = self._projection(x)  # type: ignore[misc]
                meta = {"method": "linear_projection", "in": out_features, "out": code_dim}
                return y_proj, meta

            self.set_reducer(_project_reducer)
        # else: keep provided reducer or identity_reduction

        # Codebook operates over the last (feature) dimension using the decided code_dim.
        self.codebook = Codebook(K=int(K), code_dim=code_dim)
        self.assignment = SoftAssignment(temperature=float(temperature), eps=float(eps))

    def forward(self, x: Any) -> Any:  # torch.Tensor in torch-enabled envs
        """Return the wrapped layer's output and update the cached trace.

        This method preserves the observe-only contract: the visible value is
        the underlying layer output. For robust programmatic use that avoids
        reading from mutable state, use ``forward_with_trace`` to retrieve the
        explicit ``CodingTrace`` alongside the output.
        """

        y, _ = self.forward_with_trace(x)
        return y

    def forward_with_trace(self, x: Any) -> tuple[Any, CodingTrace | None]:
        """Return the wrapped output and, when active, the explicit trace.

        When coding is disabled or bypassed, returns ``(y, None)`` while
        clearing any previously cached trace to avoid confusion. When active,
        computes a fresh ``CodingTrace``, updates the internal cache, and
        returns ``(y, trace)``.
        """

        y = self.layer(x)

        # Fast path: skip all coding work when disabled or bypassed.
        if not self.is_coding_active:
            # Clear stale trace to avoid confusion when toggling modes.
            self.clear_last_trace()
            return y, None

        # Apply the configured reducer (identity by default in MVP).
        reduced, reduction_meta = self.maybe_reduce(y)

        # Compute soft assignments and gather intermediates for the trace.
        soft, inter = self.assignment(reduced, self.codebook)

        trace = CodingTrace(
            reduced=reduced,
            reduction_meta=reduction_meta,
            nearest_center_distances=inter["nearest_center_distances"],
            chosen_center_indices=inter["chosen_center_indices"],
            commitment_distances=inter["commitment_distances"],
            soft_code=soft,
        )
        self.update_last_trace(trace)
        return y, trace

    def extra_repr(self) -> str:  # pragma: no cover - cosmetic
        """Human-friendly summary shown inside ``nn.Module.__repr__``.

        Includes the coding-space dimension (``code_dim``), codebook size
        (``K``), the configured reducer type, and whether coding is currently
        enabled. Also echoes the wrapped linear layer's ``out_features`` for
        quick visual alignment.
        """

        status = "on" if self.coding_enabled else "off"
        reducer_fn = self.get_reducer()
        # Not every valid reducer exposes __name__ (e.g., functools.partial);
        # fall back to the callable's type name instead of raising.
        reducer_name = getattr(reducer_fn, "__name__", type(reducer_fn).__name__)
        return (
            f"layer=Linear(out_features={int(self.layer.out_features)}), "
            f"code_dim={int(self.codebook.code_dim)}, "
            f"K={int(self.codebook.K)}, reducer={reducer_name}, "
            f"coding={status}"
        )
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Experimental reducers for convolutional wrappers.
|
|
2
|
+
|
|
3
|
+
This module provides optional, opt-in reducers that can be passed to
|
|
4
|
+
``CodingConv2d`` (or any wrapper built on ``BaseCodingWrapper``) to prototype
|
|
5
|
+
alternative spatial reductions beyond the default global average pooling.
|
|
6
|
+
|
|
7
|
+
The helpers return callables that match ``ReducerFn`` so they can be supplied
|
|
8
|
+
via the existing ``reducer=...`` parameter without changing stable defaults.
|
|
9
|
+
|
|
10
|
+
Notes
|
|
11
|
+
- These reducers are experimental and off by default. They preserve the
|
|
12
|
+
observe-only contract: reductions are applied only on the coding path.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from .base import ReducerFn
|
|
20
|
+
|
|
21
|
+
__all__ = [
|
|
22
|
+
"global_max_pool2d",
|
|
23
|
+
"spatial_tokens_identity_channels",
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def global_max_pool2d() -> ReducerFn:
    """Build a reducer performing global max pooling over the spatial axes.

    The returned callable accepts a 4D NCHW tensor (``(B, C, H, W)``) and
    yields a pooled tensor of shape ``(B, C)`` together with metadata that
    records the reduction method and the reduced axes.
    """

    def _pool(activation: Any) -> tuple[Any, dict[str, Any]]:
        # Accept only 4D tensor-like inputs that support the `amax` reduction.
        is_4d = getattr(activation, "ndim", 0) == 4
        if not (is_4d and hasattr(activation, "amax")):
            raise ValueError("global_max_pool2d expects a 4D NCHW tensor from Conv2d (B, C, H, W)")
        # Collapse the spatial dims (H, W) via max, leaving (B, C).
        pooled = activation.amax(dim=(-1, -2))
        return pooled, {
            "method": "global_max_pool2d",
            "spatial_reduction": True,
            "reduction_axes": [-2, -1],  # H, W in NCHW
        }

    # Hand back the configured reducer callable.
    return _pool
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def spatial_tokens_identity_channels() -> ReducerFn:
    """Build a reducer that reorders Conv2d output into channels-last tokens.

    The returned callable transforms a standard Conv2d output ``(B, C, H, W)``
    into a token-like layout ``(B, H, W, C)`` without reducing channels, so
    assignments can be computed per spatial position while keeping
    observe-only semantics. The metadata clarifies that the view is token-like
    (no spatial reduction).
    """

    def _tokenize(activation: Any) -> tuple[Any, dict[str, Any]]:
        # Only 4D tensor-like inputs exposing `permute` are acceptable here.
        if not (getattr(activation, "ndim", 0) == 4 and hasattr(activation, "permute")):
            raise ValueError(
                "spatial_tokens_identity_channels expects a 4D NCHW tensor (B, C, H, W)"
            )
        # Move channels to the last axis so the feature dim is last for coding.
        tokens = activation.permute(0, 2, 3, 1).contiguous()
        return tokens, {
            "method": "spatial_tokens_identity",
            "token_like": True,
            "channels_last": True,
            "spatial_reduction": False,
            "source_layout": "NCHW",
            "target_layout": "NHWC",
        }

    return _tokenize
|
nervecode/layers/wrap.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
"""Model instrumentation helpers for coding wrappers.
|
|
2
|
+
|
|
3
|
+
This module provides a small `wrap()` function that can instrument a PyTorch
|
|
4
|
+
`nn.Module` in-place based on either:
|
|
5
|
+
|
|
6
|
+
- explicit dotted module names (e.g., ["encoder.fc", "head.1"]), or
|
|
7
|
+
- supported selection shortcuts (MVP: "all_linear").
|
|
8
|
+
|
|
9
|
+
Only supported layer types are wrapped; others are left unchanged to preserve
|
|
10
|
+
fail-open behavior. The function returns the same model instance for
|
|
11
|
+
convenience.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from collections.abc import Iterable
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
try: # Keep runtime import optional for environments without torch
|
|
20
|
+
import torch
|
|
21
|
+
from torch import nn
|
|
22
|
+
except Exception: # pragma: no cover - import-time optionality
|
|
23
|
+
from typing import Any
|
|
24
|
+
from typing import cast as _cast
|
|
25
|
+
|
|
26
|
+
torch = _cast(Any, None)
|
|
27
|
+
nn = _cast(Any, None)
|
|
28
|
+
|
|
29
|
+
from .conv import CodingConv2d
|
|
30
|
+
from .linear import CodingLinear
|
|
31
|
+
|
|
32
|
+
__all__ = ["wrap"]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def wrap(
    model: Any,
    *,
    layers: str | Iterable[str] = "all_linear",
    K: int | None = None,
    coding_dim: int | None = None,
    temperature: float = 1.0,
    eps: float = 1e-12,
) -> Any:
    """Instrument ``model`` in-place with coding wrappers.

    Parameters
    - model: A PyTorch ``nn.Module`` to instrument. The function modifies the
      module tree in-place and returns the same instance for convenience.
    - layers: Either a selection shortcut string (MVP supports only
      ``"all_linear"``) or an iterable of explicit dotted module names to
      wrap. Names must match those returned by ``model.named_modules()``.
    - K: Optional codebook size; when ``None`` the wrapper's own default
      applies (the key is simply omitted from the constructor kwargs).
    - coding_dim: Optional coding dimension; clamped per layer to the
      wrapped layer's width before construction.
    - temperature: Soft-assignment temperature forwarded to each wrapper.
    - eps: Numerical epsilon forwarded to each wrapper.

    Behavior
    - Only supported layer types are wrapped. Currently this includes
      ``nn.Linear`` via ``CodingLinear``. Unsupported selections are ignored to
      preserve fail-open behavior.
    - Missing names are ignored.
    - On environments without PyTorch, the function is a no-op and returns the
      model unchanged.
    """

    if nn is None:  # pragma: no cover - defensive; torch is a dependency in tests
        return model

    # Validate basic interface; fall back to no-op when model is not an nn.Module
    if not isinstance(model, nn.Module):
        return model

    # Resolve target names to instrument
    if isinstance(layers, str):
        target_names = _select_by_shortcut(model, layers)
    else:
        # Deduplicate while preserving order
        seen: set[str] = set()
        target_names = []
        for name in layers:
            # Non-string entries are silently dropped (fail-open).
            if not isinstance(name, str):
                continue
            if name not in seen:
                seen.add(name)
                target_names.append(name)

    # Map types to wrapper factories; extend as new wrappers are added.
    # Returns the wrapped replacement module, or None for unsupported types.
    def _wrap_module(m: nn.Module) -> nn.Module | None:
        if isinstance(m, nn.Linear):
            kwargs: dict[str, Any] = {
                "temperature": float(temperature),
                "eps": float(eps),
            }
            # Omit K when None so the wrapper's own default takes effect.
            if K is not None:
                kwargs["K"] = int(K)
            if coding_dim is not None:
                # Clamp coding_dim per layer to avoid invalid configuration
                try:
                    cd_req = int(coding_dim)
                    cd_eff = max(1, min(int(m.out_features), cd_req))
                except Exception:
                    cd_eff = int(m.out_features)
                kwargs["coding_dim"] = cd_eff
            # Defensive: if downstream validation still rejects coding_dim,
            # fall back to using the layer's out_features to avoid crashes.
            # NOTE: the match is on the exact substring of CodingLinear's
            # ValueError message; other exceptions are re-raised unchanged.
            try:
                return CodingLinear(m, **kwargs)
            except Exception as exc:
                msg = str(exc)
                if "coding_dim must be <= layer.out_features" in msg:
                    safe_kwargs = dict(kwargs)
                    safe_kwargs["coding_dim"] = int(getattr(m, "out_features", 1) or 1)
                    return CodingLinear(m, **safe_kwargs)
                raise
        if isinstance(m, nn.Conv2d):
            # Mirrors the Linear branch above, keyed on out_channels instead.
            kwargs_c: dict[str, Any] = {
                "temperature": float(temperature),
                "eps": float(eps),
            }
            if K is not None:
                kwargs_c["K"] = int(K)
            if coding_dim is not None:
                try:
                    cd_req = int(coding_dim)
                    cd_eff = max(1, min(int(m.out_channels), cd_req))
                except Exception:
                    cd_eff = int(m.out_channels)
                kwargs_c["coding_dim"] = cd_eff
            try:
                return CodingConv2d(m, **kwargs_c)
            except Exception as exc:
                msg = str(exc)
                if "coding_dim must be <=" in msg:
                    safe_kwargs_c = dict(kwargs_c)
                    safe_kwargs_c["coding_dim"] = int(getattr(m, "out_channels", 1) or 1)
                    return CodingConv2d(m, **safe_kwargs_c)
                raise
        return None

    # Instrument selected modules by replacing them on their parent.
    for name in target_names:
        try:
            # Skip the top-level module itself
            if name == "":
                continue
            # get_submodule raises on unknown names; ignore gracefully
            sub = model.get_submodule(name)
        except Exception:
            continue
        wrapped = _wrap_module(sub)
        if wrapped is None:
            continue  # unsupported type
        parent_path, leaf = _split_parent(name)
        try:
            parent = model.get_submodule(parent_path) if parent_path else model
        except Exception:
            continue
        # Replace on the parent; handle non-identifier names (e.g., Sequential indices)
        if leaf.isidentifier():
            setattr(parent, leaf, wrapped)
        else:
            # Fall back to the internal registry for names like "1" in Sequential
            parent._modules[leaf] = wrapped

    return model
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _split_parent(path: str) -> tuple[str, str]:
|
|
165
|
+
"""Return (parent_path, leaf) for a dotted module path.
|
|
166
|
+
|
|
167
|
+
Examples
|
|
168
|
+
- "fc1" -> ("", "fc1")
|
|
169
|
+
- "block.1" -> ("block", "1")
|
|
170
|
+
- "encoder.layer.0.attn" -> ("encoder.layer.0", "attn")
|
|
171
|
+
"""
|
|
172
|
+
|
|
173
|
+
if "." not in path:
|
|
174
|
+
return "", path
|
|
175
|
+
parent, leaf = path.rsplit(".", 1)
|
|
176
|
+
return parent, leaf
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _select_by_shortcut(model: nn.Module, key: str) -> list[str]:
|
|
180
|
+
"""Return module names selected by a supported shortcut key.
|
|
181
|
+
|
|
182
|
+
Supported keys (MVP):
|
|
183
|
+
- "all_linear": select every ``nn.Linear`` in ``model``.
|
|
184
|
+
- "all_conv": select every ``nn.Conv2d`` in ``model``.
|
|
185
|
+
Unrecognized keys select nothing (fail-open).
|
|
186
|
+
"""
|
|
187
|
+
|
|
188
|
+
key_norm = key.strip().lower()
|
|
189
|
+
|
|
190
|
+
# all_* shortcuts: filter across named_modules
|
|
191
|
+
if key_norm == "all_linear":
|
|
192
|
+
return [name for name, m in model.named_modules() if isinstance(m, nn.Linear)]
|
|
193
|
+
if key_norm == "all_conv":
|
|
194
|
+
return [name for name, m in model.named_modules() if isinstance(m, nn.Conv2d)]
|
|
195
|
+
|
|
196
|
+
# first/last convenience: select a single Linear
|
|
197
|
+
if key_norm == "first_linear":
|
|
198
|
+
for name, m in model.named_modules():
|
|
199
|
+
if isinstance(m, nn.Linear) and name:
|
|
200
|
+
return [name]
|
|
201
|
+
return []
|
|
202
|
+
if key_norm == "last_linear":
|
|
203
|
+
last: str | None = None
|
|
204
|
+
for name, m in model.named_modules():
|
|
205
|
+
if isinstance(m, nn.Linear) and name:
|
|
206
|
+
last = name
|
|
207
|
+
return [last] if last else []
|
|
208
|
+
|
|
209
|
+
# first/last convenience for Conv2d
|
|
210
|
+
if key_norm == "first_conv":
|
|
211
|
+
for name, m in model.named_modules():
|
|
212
|
+
if isinstance(m, nn.Conv2d) and name:
|
|
213
|
+
return [name]
|
|
214
|
+
return []
|
|
215
|
+
if key_norm == "last_conv":
|
|
216
|
+
last_c: str | None = None
|
|
217
|
+
for name, m in model.named_modules():
|
|
218
|
+
if isinstance(m, nn.Conv2d) and name:
|
|
219
|
+
last_c = name
|
|
220
|
+
return [last_c] if last_c else []
|
|
221
|
+
|
|
222
|
+
# Unknown shortcut -> empty selection (fail-open)
|
|
223
|
+
return []
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Scoring, aggregation, and calibration.
|
|
2
|
+
|
|
3
|
+
Scoring functions turn per-layer codes and traces into useful signals such as
|
|
4
|
+
surprise estimates. Calibration utilities live here as well.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from .aggregator import max_surprise, mean_surprise, weighted_surprise
|
|
10
|
+
from .calibrator import CalibrationState, EmpiricalPercentileCalibrator
|
|
11
|
+
from .types import AggregatedSurprise
|
|
12
|
+
|
|
13
|
+
__all__: list[str] = [
|
|
14
|
+
"AggregatedSurprise",
|
|
15
|
+
"CalibrationState",
|
|
16
|
+
"EmpiricalPercentileCalibrator",
|
|
17
|
+
"max_surprise",
|
|
18
|
+
"mean_surprise",
|
|
19
|
+
"weighted_surprise",
|
|
20
|
+
]
|