nervecode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,116 @@
1
+ """Core type definitions.
2
+
3
+ This module defines lightweight dataclasses used across the core coding
4
+ pipeline. Shapes follow the convention that probability-like tensors carry
5
+ arbitrary leading dimensions with a final code dimension of size ``K``.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+ from typing import Any
12
+
13
+ __all__ = ["SoftCode"]
14
+
15
+
16
@dataclass(frozen=True)
class SoftCode:
    """Immutable soft assignment over a ``K``-way codebook.

    ``probs`` holds a probability distribution over ``K`` codes for every
    position of a possibly batched input, i.e. it has shape ``(..., K)`` with
    arbitrary leading dimensions.

    Fields
    - `probs` (Tensor): probabilities over codes with shape ``(..., K)``.
    - `best_length` (Tensor, optional): best-code codelength per position
      with shape matching the leading dimensions ``...``.
    - `entropy` (Tensor, optional): assignment entropy per position with
      shape matching the leading dimensions ``...``.
    - `best_indices` (Tensor, optional): argmax code indices per position with
      shape matching the leading dimensions ``...``.
    - `combined_surprise` (Tensor, optional): a scalar-like per-position
      surprise that may combine multiple terms; shape matches ``...``.

    Every optional scalar-like field must match the leading dimensions of
    ``probs`` exactly (``tensor.shape == probs.shape[:-1]``); with no leading
    dimensions they must be 0-dim tensors. Validation runs in
    ``__post_init__`` and is skipped (beyond basic shape checks on ``probs``)
    when torch is unavailable.
    """

    probs: Any
    best_length: Any | None = None
    entropy: Any | None = None
    best_indices: Any | None = None
    combined_surprise: Any | None = None

    def __post_init__(self) -> None:
        from typing import Any
        from typing import cast as _cast

        try:
            import torch  # runtime import; typing deferred
        except Exception:  # pragma: no cover - torch absent in some test envs
            torch = _cast(Any, None)

        probs = self.probs

        # probs must be a tensor (when torch is importable) of rank >= 1.
        if torch is not None and not isinstance(
            probs, torch.Tensor
        ):  # pragma: no cover - defensive
            raise TypeError("probs must be a torch.Tensor")
        ndim = getattr(probs, "ndim", None)
        if ndim is None or ndim < 1:
            raise ValueError("probs must have at least one dimension (..., K)")
        if int(probs.shape[-1]) <= 0:
            raise ValueError("final code dimension K must be >= 1")

        # Optional-field validation requires torch for isinstance checks.
        if torch is None:  # pragma: no cover - validation relies on torch
            return

        lead_shape = tuple(int(dim) for dim in probs.shape[:-1])

        def _require(name: str, value: Any) -> None:
            # Shared type/shape check for every optional field.
            if not isinstance(value, torch.Tensor):
                raise TypeError(f"{name} must be a torch.Tensor if provided")
            if tuple(int(dim) for dim in value.shape) != lead_shape:
                raise ValueError(
                    f"{name} must have shape {lead_shape} to match probs.leading_shape"
                )

        # Numeric (non-bool) scalar-like fields share one validation path;
        # the order matches field declaration order.
        for field_name in ("best_length", "entropy", "combined_surprise"):
            value = getattr(self, field_name)
            if value is None:
                continue
            _require(field_name, value)
            if value.dtype == torch.bool:  # pragma: no cover - defensive
                raise TypeError(f"{field_name} must be numeric, not bool")

        if self.best_indices is not None:
            _require("best_indices", self.best_indices)
            # Indices must be integral: neither float nor bool.
            if torch.is_floating_point(self.best_indices) or self.best_indices.dtype == torch.bool:
                raise TypeError("best_indices must use an integer dtype")

    @property
    def leading_shape(self) -> tuple[int, ...]:
        """Leading shape ``...`` of ``probs``, i.e. everything before ``K``."""

        # Convert explicitly: torch.Size confuses mypy's tuple-of-int typing.
        return tuple(int(dim) for dim in self.probs.shape[:-1])

    @property
    def code_dim(self) -> int:
        """Size ``K`` of the final code dimension."""

        return int(self.probs.shape[-1])
@@ -0,0 +1,9 @@
1
"""Integration points and adapters.

This subpackage will provide adapters and convenience utilities for integrating
Nervecode into applications, scripts, and logging pipelines.
"""

from __future__ import annotations

# No public names yet: adapters are re-exported here (and added to __all__)
# as they land, so `from nervecode.integrations import *` stays well-defined.
__all__: list[str] = []
@@ -0,0 +1,15 @@
1
"""Layer wrappers and instrumentation shims.

This subpackage will contain observe-only wrappers around selected PyTorch
layers and utilities to apply them to user models.
"""

from __future__ import annotations

# Re-export the wrapper base plus the reducer typing helpers so callers can
# import them directly from this package rather than from `.base`.
from .base import BaseCodingWrapper, ReducerFn, identity_reduction

__all__: list[str] = [
    "BaseCodingWrapper",
    "ReducerFn",
    "identity_reduction",
]
@@ -0,0 +1,333 @@
1
+ """Common wrapper base: bypass, trace cache, reduction, diagnostics hooks.
2
+
3
+ This module defines a small, torch-agnostic base class and companion typing to
4
+ standardize behavior shared by coding wrappers:
5
+
6
+ - Bypass and fail-open controls (enable_coding(), disable_coding(), and the
7
+ bypass() context manager) with correct nesting semantics.
8
+ - Latest per-layer trace caching (last_trace) with explicit update/clear.
9
+ - A light reduction callable type compatible with identity or learned reducers.
10
+ - Diagnostics hook registration that runs safely on trace updates.
11
+
12
+ Wrappers built on top of this class remain observe-only: the base makes no
13
+ assumptions about forward semantics beyond providing helpers that wrappers can
14
+ use after computing the original layer output.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from collections.abc import Callable, Iterator
20
+ from contextlib import contextmanager, suppress
21
+ from typing import Any
22
+
23
+ try: # Optional torch import for typing only; keep runtime torch-agnostic.
24
+ from nervecode.core import CodingTrace # re-exported type
25
+ except Exception: # pragma: no cover - available in normal package use
26
+ CodingTrace = Any # type: ignore[misc,assignment]
27
+
28
+ __all__ = ["BaseCodingWrapper", "ReducerFn", "identity_reduction"]
29
+
30
+
31
ReducerFn = Callable[[Any], tuple[Any, dict[str, Any]]]
"""Callable that reduces an activation and returns metadata.

Accepts a single activation-like object (typically a tensor with shape
``(..., D_out)``) and returns ``(reduced, reduction_meta)`` where ``reduced``
has shape ``(..., D)`` and ``reduction_meta`` is a JSON-serializable dict that
identifies the applied reduction (e.g., {"method": "identity"}).

Reducers may raise; ``BaseCodingWrapper.maybe_reduce`` catches any exception
and falls back to returning the input unchanged (fail-open behavior).
"""
39
+
40
+
41
def identity_reduction(x: Any) -> tuple[Any, dict[str, Any]]:
    """Pass ``x`` through untouched, tagging the reduction as identity."""

    meta: dict[str, Any] = {"method": "identity"}
    return x, meta
45
+
46
+
47
class BaseCodingWrapper:
    """Torch-agnostic base for coding wrappers.

    Responsibilities
    - Maintain an enable/disable switch and a nestable bypass context to
      implement fail-open behavior consistently across wrappers.
    - Cache the most recent per-layer trace and run registered diagnostics
      hooks safely when it updates.
    - Provide a reducer slot with an identity default so wrappers can apply
      a dimension reduction before coding without duplicating boilerplate.
    """

    def __init__(self, *, reducer: ReducerFn | None = None) -> None:
        self._coding_enabled: bool = True
        self._bypass_depth: int = 0
        self._last_trace: CodingTrace | None = None
        # Default to identity reduction when none is provided.
        self._reducer: ReducerFn = reducer or identity_reduction
        self._diagnostic_hooks: list[Callable[[CodingTrace], None]] = []

    # --- Bypass and enable/disable -------------------------------------------------
    @property
    def coding_enabled(self) -> bool:
        """Return whether coding is globally enabled for this wrapper."""

        return self._coding_enabled

    @property
    def is_coding_active(self) -> bool:
        """True when coding is enabled and not currently bypassed."""

        return self._coding_enabled and self._bypass_depth == 0

    def enable_coding(self) -> None:
        """Enable coding instrumentation for subsequent forwards."""

        self._coding_enabled = True

    def disable_coding(self) -> None:
        """Disable coding instrumentation (fail-open) and clear caches.

        Disabling coding should restore base-model behavior immediately. To
        avoid stale reads from convenience accessors, drop any cached latest
        trace at the time of disabling rather than waiting for the next
        forward pass.
        """

        self._coding_enabled = False
        # Clear any previously cached trace so that model-level aggregations
        # observe the disabled state immediately.
        self.clear_last_trace()

    @contextmanager
    def bypass(self) -> Iterator[None]:
        """Temporarily bypass coding instrumentation.

        The context is nestable; coding is considered bypassed when the bypass
        depth is greater than zero.
        """

        self._bypass_depth += 1
        try:
            yield
        finally:
            self._bypass_depth -= 1
            if self._bypass_depth < 0:  # pragma: no cover - defensive
                self._bypass_depth = 0

    # --- Reduction ----------------------------------------------------------------
    def set_reducer(self, reducer: ReducerFn) -> None:
        """Set the reducer callable used by ``maybe_reduce``."""

        self._reducer = reducer

    def get_reducer(self) -> ReducerFn:
        """Return the current reducer callable."""

        return self._reducer

    def maybe_reduce(self, activation: Any) -> tuple[Any, dict[str, Any]]:
        """Apply the configured reducer when coding is active.

        When coding is not active (disabled or bypassed), return the input
        unchanged with a metadata dictionary indicating bypass. This function
        never raises; wrappers should keep observe-only semantics.
        """

        if not self.is_coding_active:
            return activation, {"method": "bypass"}
        try:
            return self._reducer(activation)
        except Exception:  # Fail-open on reducer errors
            # Leave a recognizable marker to help downstream logging while
            # guaranteeing observe-only behavior.
            return activation, {"method": "bypass", "reason": "reduction_error"}

    # --- Trace cache and diagnostics ----------------------------------------------
    @property
    def last_trace(self) -> CodingTrace | None:
        """Return the most recently cached trace, if any."""

        return self._last_trace

    def clear_last_trace(self) -> None:
        """Drop the cached trace."""

        self._last_trace = None

    def update_last_trace(self, trace: CodingTrace) -> None:
        """Cache the latest trace and run diagnostics hooks safely."""

        self._last_trace = trace
        self._run_diagnostic_hooks(trace)

    def add_diagnostics_hook(self, hook: Callable[[CodingTrace], None]) -> None:
        """Register a hook invoked on every trace update.

        Hooks must be side-effect safe; exceptions are suppressed to preserve
        fail-open behavior.
        """

        self._diagnostic_hooks.append(hook)

    def remove_diagnostics_hook(self, hook: Callable[[CodingTrace], None]) -> None:
        """Remove a previously registered diagnostics hook, if present."""

        with suppress(ValueError):
            self._diagnostic_hooks.remove(hook)

    def _run_diagnostic_hooks(self, trace: CodingTrace) -> None:
        # Snapshot so hooks may register/unregister hooks during dispatch.
        for hook in tuple(self._diagnostic_hooks):
            try:
                hook(trace)
            except Exception:
                # Suppress hook exceptions to preserve user forwards; wrappers
                # are observe-only. Hooks are best-effort diagnostics.
                continue

    # --- Layer-level diagnostics (latest-trace) ---------------------------------
    def utilization(self) -> float | None:
        """Fraction of codes used in the latest trace.

        Returns the number of unique chosen indices divided by ``K`` from the
        most recent trace, or ``None`` when unavailable. The result is a
        Python ``float`` in ``[0, 1]``.
        """

        trace = self.last_trace
        if trace is None:  # no trace cached
            return None
        try:
            K = int(trace.soft_code.code_dim)
            if K <= 0:
                return None
            # Prefer SoftCode.best_indices when present; otherwise fall back to
            # the hard choice stored on the trace.
            best = trace.soft_code.best_indices
            if best is None:
                best = trace.chosen_center_indices
            values = self._flatten_to_list(best)
            # Note: K > 0 is already guaranteed above, so only the flatten
            # result needs checking here (previous code re-tested K == 0,
            # which was unreachable).
            if values is None:
                return None
            used = len({int(v) for v in values})
            return float(used) / float(K)
        except Exception:
            return None

    def mean_entropy(self) -> float | None:
        """Mean assignment entropy from the latest trace, or ``None``.

        Returns ``soft_code.entropy.mean()`` as a Python ``float`` when a
        cached trace is available and contains entropy; otherwise ``None``.
        """

        trace = self.last_trace
        if trace is None:
            return None
        entropy = trace.soft_code.entropy
        if entropy is None:
            return None
        return self._mean_as_float(entropy)

    def mean_code_length(self) -> float | None:
        """Mean best-code length from the latest trace, or ``None``.

        Returns ``soft_code.best_length.mean()`` as a Python ``float`` when a
        cached trace is available and contains best-code lengths; otherwise
        ``None``.
        """

        trace = self.last_trace
        if trace is None:
            return None
        best_len = trace.soft_code.best_length
        if best_len is None:
            return None
        return self._mean_as_float(best_len)

    def mean_commitment_distance(self) -> float | None:
        """Mean commitment distance from the latest trace, or ``None``.

        Uses the per-position ``commitment_distances`` stored on the trace.
        Returns a Python ``float`` when available, else ``None``.
        """

        trace = self.last_trace
        if trace is None:
            return None
        d = trace.commitment_distances
        if d is None:
            return None
        return self._mean_as_float(d)

    @staticmethod
    def _flatten_to_list(x: Any) -> list[float] | None:
        """Best-effort flatten to a Python list of numbers.

        Supports torch-like objects exposing ``reshape``/``view`` and ``tolist``,
        as well as nested Python sequences. Returns ``None`` when ``x`` cannot
        be interpreted numerically at all.
        """

        # Torch-like path: prefer reshape, then view, then raw tolist.
        try:
            if hasattr(x, "reshape") and hasattr(x, "tolist"):
                return list(x.reshape(-1).tolist())
            if hasattr(x, "view") and hasattr(x, "tolist"):
                return list(x.view(-1).tolist())
            if hasattr(x, "tolist"):
                out = x.tolist()
                if isinstance(out, list):
                    # tolist() may return nested lists; flatten recursively.
                    def _flatten(v: Any) -> list[float]:
                        if isinstance(v, list):
                            res: list[float] = []
                            for item in v:
                                res.extend(_flatten(item))
                            return res
                        try:
                            return [float(v)]
                        except Exception:
                            return []

                    return _flatten(out)
        except Exception:
            pass

        # Python sequence fallback
        if isinstance(x, (list, tuple)):
            res2: list[float] = []
            for item in x:
                sub = BaseCodingWrapper._flatten_to_list(item)
                if sub is not None:
                    res2.extend(sub)
            return res2
        # Scalar fallback
        try:
            return [float(x)]
        except Exception:
            return None

    @staticmethod
    def _mean_as_float(x: Any) -> float | None:
        """Return mean(x) as a float for torch-like or Python sequences.

        - If ``x`` has a ``mean()`` method, call it and extract ``item()`` if
          available.
        - Otherwise, attempt to flatten to a list and compute an arithmetic
          mean. Returns ``None`` for empty/unusable input.
        """

        # Torch-like: x.mean().item()
        try:
            if hasattr(x, "mean"):
                m = x.mean()
                if hasattr(m, "item"):
                    return float(m.item())
                return float(m)
        except Exception:
            pass

        values = BaseCodingWrapper._flatten_to_list(x)
        if not values:
            return None
        try:
            return float(sum(values) / float(len(values)))
        except Exception:
            return None
@@ -0,0 +1,174 @@
1
+ """Coding wrapper for ``torch.nn.Conv2d`` with pooled coding.
2
+
3
+ This module provides ``CodingConv2d`` — an observe-only wrapper around
4
+ ``nn.Conv2d`` that computes coding statistics over a pooled representation of
5
+ the layer output. The visible forward value remains identical to the underlying
6
+ convolution; coding runs side-by-side and updates an internal trace cache.
7
+
8
+ Design constraints for the first convolutional wrapper:
9
+ - Observe-only contract for forwards (fail-open behavior).
10
+ - Coding over a pooled representation (global average pooling over H and W).
11
+ - Default sample-level reducer corresponds to global average pooling semantics.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Any
17
+
18
+ from .base import BaseCodingWrapper, ReducerFn
19
+
20
# Import torch lazily but tolerate environments without it at import time.
try:  # pragma: no cover - exercised in torch-enabled environments
    import torch
    from torch import nn
except Exception:  # pragma: no cover - allow import without torch
    from typing import Any
    from typing import cast as _cast

    torch = _cast(Any, None)

    class _TorchlessNN:
        """Minimal ``nn`` stand-in so this module imports without torch.

        ``CodingConv2d`` subclasses ``nn.Module`` at class-definition time;
        with a bare ``nn = None`` fallback that definition raised
        ``AttributeError`` on import, defeating the torch-optional intent.
        Constructing the wrapper still fails fast via the ``torch is None``
        guard in ``CodingConv2d.__init__``.
        """

        class Module:
            """Placeholder base class used only when torch is unavailable."""

    nn = _cast(Any, _TorchlessNN)
30
+
31
+ from nervecode.core import CodingTrace
32
+ from nervecode.core.assignment import SoftAssignment
33
+ from nervecode.core.codebook import Codebook
34
+
35
+ __all__ = ["CodingConv2d"]
36
+
37
+
38
class CodingConv2d(nn.Module, BaseCodingWrapper):
    """Observe-only wrapper around ``nn.Conv2d`` with pooled coding.

    The wrapper delegates computation to the underlying ``nn.Conv2d`` and,
    when coding is active, computes a ``CodingTrace`` over a globally pooled
    representation of the visible activation. The visible forward output is
    never altered.

    Parameters
    - layer: The ``nn.Conv2d`` module to wrap.
    - K: Number of codebook centers. Defaults to 16.
    - coding_dim: Target coding dimension ``D``. When ``None`` (default), uses
      ``layer.out_channels``. If provided, must satisfy ``1 <= coding_dim <=
      layer.out_channels``; when smaller than ``out_channels``, a learned linear
      projection is applied after global average pooling, affecting only the
      coding path (observe-only contract preserved).
    - reducer: Optional reducer callable. Overrides the default global-average
      pooling behavior when provided.
    - temperature: Soft-assignment temperature. Higher is softer.
    - eps: Numerical epsilon to clamp divisions and logs.

    Raises
    - RuntimeError: when PyTorch is not installed.
    - TypeError: when ``layer`` is not an ``nn.Conv2d``.
    - ValueError: when ``coding_dim`` is out of the valid range.
    """

    def __init__(
        self,
        layer: Any,  # use Any to avoid a hard torch dependency in type hints
        *,
        K: int = 16,
        coding_dim: int | None = None,
        reducer: ReducerFn | None = None,
        temperature: float = 1.0,
        eps: float = 1e-12,
    ) -> None:
        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("CodingConv2d requires PyTorch to be installed")

        nn.Module.__init__(self)
        BaseCodingWrapper.__init__(self, reducer=reducer)

        if not isinstance(layer, nn.Conv2d):
            raise TypeError("layer must be an instance of torch.nn.Conv2d")

        self.layer: Any = layer

        out_channels = int(self.layer.out_channels)
        if coding_dim is None:
            code_dim = out_channels
        else:
            code_dim = int(coding_dim)
            if code_dim <= 0:
                raise ValueError("coding_dim must be >= 1")
            if code_dim > out_channels:
                raise ValueError("coding_dim must be <= layer.out_channels for reduction")

        # Optional learned projection after pooling when reducing channels.
        # The projection only feeds the coding path, never the visible output.
        self._projection: Any | None = None
        if reducer is None and code_dim < out_channels:
            # Use a linear projection over the pooled channel vector (B, C) -> (B, D).
            self._projection = nn.Linear(out_channels, code_dim, bias=False)

            def _pool_then_project(x: Any) -> tuple[Any, dict[str, Any]]:
                # Global average pooling over spatial dims (H, W)
                pooled = x.mean(dim=(-1, -2))  # (B, C)
                y_proj = self._projection(pooled)  # type: ignore[misc]
                meta = {
                    "method": "global_avg_pool2d+linear_projection",
                    "in": out_channels,
                    "out": code_dim,
                    # Record that coding view comes from spatial reduction
                    "spatial_reduction": True,
                    "reduction_axes": [-2, -1],  # H, W in NCHW
                }
                return y_proj, meta

            self.set_reducer(_pool_then_project)
        elif reducer is None:
            # Default: global average pooling over spatial dims to (B, C)
            def _global_avg_pool2d(x: Any) -> tuple[Any, dict[str, Any]]:
                reduced = x.mean(dim=(-1, -2))  # (B, C)
                return reduced, {
                    "method": "global_avg_pool2d",
                    "spatial_reduction": True,
                    "reduction_axes": [-2, -1],  # H, W in NCHW
                }

            self.set_reducer(_global_avg_pool2d)
        # else: keep user-provided reducer

        # Codebook over final feature dimension D=code_dim
        self.codebook = Codebook(K=int(K), code_dim=code_dim)
        self.assignment = SoftAssignment(temperature=float(temperature), eps=float(eps))

    def forward(self, x: Any) -> Any:  # torch.Tensor in torch-enabled envs
        """Return the wrapped layer's output and update the cached trace.

        Preserves the observe-only contract: the visible value is the
        underlying layer output. For programmatic use, prefer
        ``forward_with_trace`` to retrieve the explicit ``CodingTrace``
        alongside the output.
        """

        y, _ = self.forward_with_trace(x)
        return y

    def forward_with_trace(self, x: Any) -> tuple[Any, CodingTrace | None]:
        """Return the wrapped output and, when active, the explicit trace."""

        y = self.layer(x)

        if not self.is_coding_active:
            # Drop any stale trace so accessors reflect the bypassed state.
            self.clear_last_trace()
            return y, None

        # Apply configured reducer (default: global average pooling over H, W).
        reduced, reduction_meta = self.maybe_reduce(y)

        # Compute assignments over the pooled representation.
        soft, inter = self.assignment(reduced, self.codebook)

        trace = CodingTrace(
            reduced=reduced,
            reduction_meta=reduction_meta,
            nearest_center_distances=inter["nearest_center_distances"],
            chosen_center_indices=inter["chosen_center_indices"],
            commitment_distances=inter["commitment_distances"],
            soft_code=soft,
        )
        self.update_last_trace(trace)
        return y, trace

    def extra_repr(self) -> str:  # pragma: no cover - cosmetic
        status = "on" if self.coding_enabled else "off"
        # User-provided reducers (e.g. functools.partial or callable objects)
        # may lack __name__; fall back to the type name so repr never raises.
        reducer = self.get_reducer()
        reducer_name = getattr(reducer, "__name__", type(reducer).__name__)
        return (
            f"layer=Conv2d(out_channels={int(self.layer.out_channels)}), "
            f"code_dim={int(self.codebook.code_dim)}, "
            f"K={int(self.codebook.K)}, reducer={reducer_name}, "
            f"coding={status}"
        )