nervecode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,176 @@
1
+ """Coding wrapper for ``torch.nn.Linear``.
2
+
3
+ This module provides ``CodingLinear`` — the first production wrapper that
4
+ observes a linear layer's output, reduces it (identity in the MVP), and
5
+ computes a coding trace via a learnable codebook and a soft assignment engine.
6
+
7
+ Observe-only contract: the wrapper never modifies the value returned to the
8
+ caller; coding runs side-by-side and updates an internal trace cache.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Any
14
+
15
+ from .base import BaseCodingWrapper, ReducerFn
16
+
17
+ # Import torch lazily but with typing support. Keep runtime import optional so the
18
+ # package remains importable in environments without torch (tests will skip).
19
+ try: # pragma: no cover - exercised in torch-enabled environments
20
+ import torch
21
+ from torch import nn
22
+ except Exception: # pragma: no cover - allow import without torch
23
+ from typing import Any
24
+ from typing import cast as _cast
25
+
26
+ torch = _cast(Any, None)
27
+ nn = _cast(Any, None)
28
+
29
+ # Local core imports guarded by torch availability at use sites.
30
+ from nervecode.core import CodingTrace
31
+ from nervecode.core.assignment import SoftAssignment
32
+ from nervecode.core.codebook import Codebook
33
+
34
+ __all__ = ["CodingLinear"]
35
+
36
+
37
class CodingLinear(nn.Module, BaseCodingWrapper):
    """Observe-only wrapper around ``nn.Linear``.

    The wrapper delegates computation to the underlying ``nn.Linear`` and, when
    coding is active, computes a ``CodingTrace`` over the visible activation.
    The visible forward output is never altered.

    Parameters
    - layer: The ``nn.Linear`` module to wrap.
    - K: Number of codebook centers. Defaults to 16.
    - coding_dim: Target coding dimension ``D``. When ``None`` (default),
      uses ``layer.out_features``. If provided, must satisfy
      ``1 <= coding_dim <= layer.out_features``. A learned linear projection
      reducer is created automatically when ``layer.out_features > coding_dim``;
      otherwise the identity reduction is used.
    - reducer: Optional reducer callable. Overrides ``coding_dim`` behavior
      when provided. Defaults to identity (no reduction).
    - temperature: Soft-assignment temperature. Higher is softer.
    - eps: Numerical epsilon to clamp divisions and logs.

    Notes
    - When no custom ``reducer`` is supplied and ``coding_dim`` is smaller than
      ``layer.out_features``, a learned linear-projection reducer is registered
      automatically (see ``__init__``); otherwise the configured (or identity)
      reduction applies.
    - ``repr(wrapper)`` includes ``code_dim``, codebook size ``K``, reducer type,
      and whether coding is enabled to make printed models self-descriptive.
    """

    def __init__(
        self,
        layer: Any,  # use Any to avoid a hard torch dependency in type hints
        *,
        K: int = 16,
        coding_dim: int | None = None,
        reducer: ReducerFn | None = None,
        temperature: float = 1.0,
        eps: float = 1e-12,
    ) -> None:
        # Fail fast when torch was unavailable at import time; the module-level
        # fallback binds ``torch`` to None in that case.
        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("CodingLinear requires PyTorch to be installed")

        # Explicit two-base initialization: nn.Module must be initialized before
        # any attribute assignment so submodule registration works.
        nn.Module.__init__(self)
        BaseCodingWrapper.__init__(self, reducer=reducer)

        if not isinstance(layer, nn.Linear):
            raise TypeError("layer must be an instance of torch.nn.Linear")

        self.layer: Any = layer

        # Decide coding dimension and set up reducer if not provided.
        out_features = int(self.layer.out_features)
        if coding_dim is None:
            code_dim = out_features
        else:
            code_dim = int(coding_dim)
            if code_dim <= 0:
                raise ValueError("coding_dim must be >= 1")
            if code_dim > out_features:
                raise ValueError("coding_dim must be <= layer.out_features for reduction")

        # Optional learned projection reducer when output width exceeds coding dim.
        self._projection: Any | None = None
        if reducer is None and code_dim < out_features:
            # Register a submodule to learn the reduction; keep observe-only semantics
            # by applying it only for coding. Bias is unnecessary for dimension change.
            self._projection = nn.Linear(out_features, code_dim, bias=False)

            def _project_reducer(x: Any) -> tuple[Any, dict[str, Any]]:
                # Project the activation into the coding space; the metadata lets
                # trace consumers see how the reduction was performed.
                y_proj = self._projection(x)  # type: ignore[misc]
                meta = {"method": "linear_projection", "in": out_features, "out": code_dim}
                return y_proj, meta

            self.set_reducer(_project_reducer)
        # else: keep provided reducer or identity_reduction

        # Codebook operates over the last (feature) dimension using the decided code_dim.
        self.codebook = Codebook(K=int(K), code_dim=code_dim)
        self.assignment = SoftAssignment(temperature=float(temperature), eps=float(eps))

    def forward(self, x: Any) -> Any:  # torch.Tensor in torch-enabled envs
        """Return the wrapped layer's output and update the cached trace.

        This method preserves the observe-only contract: the visible value is
        the underlying layer output. For robust programmatic use that avoids
        reading from mutable state, use ``forward_with_trace`` to retrieve the
        explicit ``CodingTrace`` alongside the output.
        """

        y, _ = self.forward_with_trace(x)
        return y

    def forward_with_trace(self, x: Any) -> tuple[Any, CodingTrace | None]:
        """Return the wrapped output and, when active, the explicit trace.

        When coding is disabled or bypassed, returns ``(y, None)`` while
        clearing any previously cached trace to avoid confusion. When active,
        computes a fresh ``CodingTrace``, updates the internal cache, and
        returns ``(y, trace)``.
        """

        # The visible output is always the raw wrapped-layer result.
        y = self.layer(x)

        # Fast path: skip all coding work when disabled or bypassed.
        if not self.is_coding_active:
            # Clear stale trace to avoid confusion when toggling modes.
            self.clear_last_trace()
            return y, None

        # Apply the configured reducer (identity unless a projection or custom
        # reducer was installed in __init__).
        reduced, reduction_meta = self.maybe_reduce(y)

        # Compute soft assignments and gather intermediates for the trace.
        soft, inter = self.assignment(reduced, self.codebook)

        trace = CodingTrace(
            reduced=reduced,
            reduction_meta=reduction_meta,
            nearest_center_distances=inter["nearest_center_distances"],
            chosen_center_indices=inter["chosen_center_indices"],
            commitment_distances=inter["commitment_distances"],
            soft_code=soft,
        )
        self.update_last_trace(trace)
        return y, trace

    def extra_repr(self) -> str:  # pragma: no cover - cosmetic
        """Human-friendly summary shown inside ``nn.Module.__repr__``.

        Includes the coding-space dimension (``code_dim``), codebook size
        (``K``), the configured reducer type, and whether coding is currently
        enabled. Also echoes the wrapped linear layer's ``out_features`` for
        quick visual alignment.
        """

        status = "on" if self.coding_enabled else "off"
        # NOTE(review): assumes the active reducer exposes ``__name__`` (true
        # for plain functions/closures, not for e.g. functools.partial) — confirm.
        return (
            f"layer=Linear(out_features={int(self.layer.out_features)}), "
            f"code_dim={int(self.codebook.code_dim)}, "
            f"K={int(self.codebook.K)}, reducer={self.get_reducer().__name__}, "
            f"coding={status}"
        )
@@ -0,0 +1,80 @@
1
+ """Experimental reducers for convolutional wrappers.
2
+
3
+ This module provides optional, opt-in reducers that can be passed to
4
+ ``CodingConv2d`` (or any wrapper built on ``BaseCodingWrapper``) to prototype
5
+ alternative spatial reductions beyond the default global average pooling.
6
+
7
+ The helpers return callables that match ``ReducerFn`` so they can be supplied
8
+ via the existing ``reducer=...`` parameter without changing stable defaults.
9
+
10
+ Notes
11
+ - These reducers are experimental and off by default. They preserve the
12
+ observe-only contract: reductions are applied only on the coding path.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any
18
+
19
+ from .base import ReducerFn
20
+
21
+ __all__ = [
22
+ "global_max_pool2d",
23
+ "spatial_tokens_identity_channels",
24
+ ]
25
+
26
+
27
+ def global_max_pool2d() -> ReducerFn:
28
+ """Return a reducer that applies global max pooling over H and W.
29
+
30
+ Expects a 4D tensor with NCHW layout (``(B, C, H, W)``) and returns a
31
+ pooled tensor of shape ``(B, C)``. Metadata records the reduction method
32
+ and axes.
33
+ """
34
+
35
+ def _reduce(x: Any) -> tuple[Any, dict[str, Any]]:
36
+ # Require a 4D tensor-like object with an `amax` reduction
37
+ if getattr(x, "ndim", 0) != 4 or not hasattr(x, "amax"):
38
+ raise ValueError("global_max_pool2d expects a 4D NCHW tensor from Conv2d (B, C, H, W)")
39
+ # Max over spatial dims (H, W) -> (B, C)
40
+ reduced = x.amax(dim=(-1, -2))
41
+ meta = {
42
+ "method": "global_max_pool2d",
43
+ "spatial_reduction": True,
44
+ "reduction_axes": [-2, -1], # H, W in NCHW
45
+ }
46
+ return reduced, meta
47
+
48
+ # Return the callable reducer
49
+ return _reduce
50
+
51
+
52
+ def spatial_tokens_identity_channels() -> ReducerFn:
53
+ """Return a reducer that exposes per-spatial-position tokens with C last.
54
+
55
+ Transforms a standard Conv2d output ``(B, C, H, W)`` into a token-like
56
+ layout ``(B, H, W, C)`` without reducing channels. This allows computing
57
+ assignments per spatial position while keeping observe-only semantics.
58
+
59
+ The metadata clarifies that the view is token-like (no spatial reduction).
60
+ """
61
+
62
+ def _reduce(x: Any) -> tuple[Any, dict[str, Any]]:
63
+ # Require a 4D tensor-like object with `permute` and `contiguous`
64
+ if getattr(x, "ndim", 0) != 4 or not hasattr(x, "permute"):
65
+ raise ValueError(
66
+ "spatial_tokens_identity_channels expects a 4D NCHW tensor (B, C, H, W)"
67
+ )
68
+ # Permute to BHWC so that the feature dimension is last for coding.
69
+ reduced = x.permute(0, 2, 3, 1).contiguous()
70
+ meta = {
71
+ "method": "spatial_tokens_identity",
72
+ "token_like": True,
73
+ "channels_last": True,
74
+ "spatial_reduction": False,
75
+ "source_layout": "NCHW",
76
+ "target_layout": "NHWC",
77
+ }
78
+ return reduced, meta
79
+
80
+ return _reduce
@@ -0,0 +1,223 @@
1
+ """Model instrumentation helpers for coding wrappers.
2
+
3
+ This module provides a small `wrap()` function that can instrument a PyTorch
4
+ `nn.Module` in-place based on either:
5
+
6
+ - explicit dotted module names (e.g., ["encoder.fc", "head.1"]), or
7
+ - supported selection shortcuts (MVP: "all_linear").
8
+
9
+ Only supported layer types are wrapped; others are left unchanged to preserve
10
+ fail-open behavior. The function returns the same model instance for
11
+ convenience.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from collections.abc import Iterable
17
+ from typing import Any
18
+
19
+ try: # Keep runtime import optional for environments without torch
20
+ import torch
21
+ from torch import nn
22
+ except Exception: # pragma: no cover - import-time optionality
23
+ from typing import Any
24
+ from typing import cast as _cast
25
+
26
+ torch = _cast(Any, None)
27
+ nn = _cast(Any, None)
28
+
29
+ from .conv import CodingConv2d
30
+ from .linear import CodingLinear
31
+
32
+ __all__ = ["wrap"]
33
+
34
+
35
def wrap(
    model: Any,
    *,
    layers: str | Iterable[str] = "all_linear",
    K: int | None = None,
    coding_dim: int | None = None,
    temperature: float = 1.0,
    eps: float = 1e-12,
) -> Any:
    """Instrument ``model`` in-place with coding wrappers.

    Parameters
    - model: A PyTorch ``nn.Module`` to instrument. The function modifies the
      module tree in-place and returns the same instance for convenience.
    - layers: Either a selection shortcut string (e.g. ``"all_linear"``,
      ``"all_conv"``, ``"first_linear"``) or an iterable of explicit dotted
      module names to wrap. Names must match those returned by
      ``model.named_modules()``.
    - K: Optional codebook size forwarded to each wrapper when given.
    - coding_dim: Optional coding dimension; clamped per layer to the layer's
      width so a single value can be applied across heterogeneous layers.
    - temperature, eps: Forwarded to each wrapper's soft-assignment engine.

    Behavior
    - Only supported layer types are wrapped: ``nn.Linear`` via
      ``CodingLinear`` and ``nn.Conv2d`` via ``CodingConv2d``. Unsupported
      selections are ignored to preserve fail-open behavior.
    - Missing names are ignored.
    - On environments without PyTorch, the function is a no-op and returns the
      model unchanged.
    """

    if nn is None:  # pragma: no cover - defensive; torch is a dependency in tests
        return model

    # Validate basic interface; fall back to no-op when model is not an nn.Module
    if not isinstance(model, nn.Module):
        return model

    # Resolve target names to instrument
    if isinstance(layers, str):
        target_names = _select_by_shortcut(model, layers)
    else:
        # Deduplicate while preserving order; silently drop non-string entries
        # (fail-open, consistent with the rest of this function).
        seen: set[str] = set()
        target_names = []
        for name in layers:
            if not isinstance(name, str):
                continue
            if name not in seen:
                seen.add(name)
                target_names.append(name)

    # Map types to wrapper factories; extend as new wrappers are added.
    def _wrap_module(m: nn.Module) -> nn.Module | None:
        if isinstance(m, nn.Linear):
            kwargs: dict[str, Any] = {
                "temperature": float(temperature),
                "eps": float(eps),
            }
            if K is not None:
                kwargs["K"] = int(K)
            if coding_dim is not None:
                # Clamp coding_dim per layer to avoid invalid configuration
                try:
                    cd_req = int(coding_dim)
                    cd_eff = max(1, min(int(m.out_features), cd_req))
                except Exception:
                    cd_eff = int(m.out_features)
                kwargs["coding_dim"] = cd_eff
            # Defensive: if downstream validation still rejects coding_dim,
            # fall back to using the layer's out_features to avoid crashes.
            # NOTE(review): matching on the exception *message* is fragile —
            # it breaks silently if CodingLinear rewords its error; consider a
            # dedicated exception type.
            try:
                return CodingLinear(m, **kwargs)
            except Exception as exc:
                msg = str(exc)
                if "coding_dim must be <= layer.out_features" in msg:
                    safe_kwargs = dict(kwargs)
                    safe_kwargs["coding_dim"] = int(getattr(m, "out_features", 1) or 1)
                    return CodingLinear(m, **safe_kwargs)
                raise
        if isinstance(m, nn.Conv2d):
            # Mirrors the Linear branch, keyed on out_channels instead of
            # out_features.
            kwargs_c: dict[str, Any] = {
                "temperature": float(temperature),
                "eps": float(eps),
            }
            if K is not None:
                kwargs_c["K"] = int(K)
            if coding_dim is not None:
                try:
                    cd_req = int(coding_dim)
                    cd_eff = max(1, min(int(m.out_channels), cd_req))
                except Exception:
                    cd_eff = int(m.out_channels)
                kwargs_c["coding_dim"] = cd_eff
            try:
                return CodingConv2d(m, **kwargs_c)
            except Exception as exc:
                msg = str(exc)
                if "coding_dim must be <=" in msg:
                    safe_kwargs_c = dict(kwargs_c)
                    safe_kwargs_c["coding_dim"] = int(getattr(m, "out_channels", 1) or 1)
                    return CodingConv2d(m, **safe_kwargs_c)
                raise
        return None

    # Instrument selected modules by replacing them on their parent.
    for name in target_names:
        try:
            # Skip the top-level module itself
            if name == "":
                continue
            # get_submodule raises on unknown names; ignore gracefully
            sub = model.get_submodule(name)
        except Exception:
            continue
        wrapped = _wrap_module(sub)
        if wrapped is None:
            continue  # unsupported type
        parent_path, leaf = _split_parent(name)
        try:
            parent = model.get_submodule(parent_path) if parent_path else model
        except Exception:
            continue
        # Replace on the parent; handle non-identifier names (e.g., Sequential indices)
        if leaf.isidentifier():
            setattr(parent, leaf, wrapped)
        else:
            # Fall back to the internal registry for names like "1" in Sequential
            parent._modules[leaf] = wrapped

    return model
162
+
163
+
164
+ def _split_parent(path: str) -> tuple[str, str]:
165
+ """Return (parent_path, leaf) for a dotted module path.
166
+
167
+ Examples
168
+ - "fc1" -> ("", "fc1")
169
+ - "block.1" -> ("block", "1")
170
+ - "encoder.layer.0.attn" -> ("encoder.layer.0", "attn")
171
+ """
172
+
173
+ if "." not in path:
174
+ return "", path
175
+ parent, leaf = path.rsplit(".", 1)
176
+ return parent, leaf
177
+
178
+
179
+ def _select_by_shortcut(model: nn.Module, key: str) -> list[str]:
180
+ """Return module names selected by a supported shortcut key.
181
+
182
+ Supported keys (MVP):
183
+ - "all_linear": select every ``nn.Linear`` in ``model``.
184
+ - "all_conv": select every ``nn.Conv2d`` in ``model``.
185
+ Unrecognized keys select nothing (fail-open).
186
+ """
187
+
188
+ key_norm = key.strip().lower()
189
+
190
+ # all_* shortcuts: filter across named_modules
191
+ if key_norm == "all_linear":
192
+ return [name for name, m in model.named_modules() if isinstance(m, nn.Linear)]
193
+ if key_norm == "all_conv":
194
+ return [name for name, m in model.named_modules() if isinstance(m, nn.Conv2d)]
195
+
196
+ # first/last convenience: select a single Linear
197
+ if key_norm == "first_linear":
198
+ for name, m in model.named_modules():
199
+ if isinstance(m, nn.Linear) and name:
200
+ return [name]
201
+ return []
202
+ if key_norm == "last_linear":
203
+ last: str | None = None
204
+ for name, m in model.named_modules():
205
+ if isinstance(m, nn.Linear) and name:
206
+ last = name
207
+ return [last] if last else []
208
+
209
+ # first/last convenience for Conv2d
210
+ if key_norm == "first_conv":
211
+ for name, m in model.named_modules():
212
+ if isinstance(m, nn.Conv2d) and name:
213
+ return [name]
214
+ return []
215
+ if key_norm == "last_conv":
216
+ last_c: str | None = None
217
+ for name, m in model.named_modules():
218
+ if isinstance(m, nn.Conv2d) and name:
219
+ last_c = name
220
+ return [last_c] if last_c else []
221
+
222
+ # Unknown shortcut -> empty selection (fail-open)
223
+ return []
@@ -0,0 +1,20 @@
1
+ """Scoring, aggregation, and calibration.
2
+
3
+ Scoring functions turn per-layer codes and traces into useful signals such as
4
+ surprise estimates. Calibration utilities live here as well.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from .aggregator import max_surprise, mean_surprise, weighted_surprise
10
+ from .calibrator import CalibrationState, EmpiricalPercentileCalibrator
11
+ from .types import AggregatedSurprise
12
+
13
+ __all__: list[str] = [
14
+ "AggregatedSurprise",
15
+ "CalibrationState",
16
+ "EmpiricalPercentileCalibrator",
17
+ "max_surprise",
18
+ "mean_surprise",
19
+ "weighted_surprise",
20
+ ]