nervecode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nervecode/__init__.py ADDED
@@ -0,0 +1,415 @@
1
+ """Nervecode package root and public API stubs.
2
+
3
+ The real implementation will land incrementally following the MVP plan. For now,
4
+ we expose the intended top-level API so users and tooling can rely on the package
5
+ surface existing from the start.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from collections.abc import Iterable, Iterator
11
+ from contextlib import ExitStack, contextmanager, suppress
12
+ from typing import Any
13
+
14
+ # Import base wrapper type for lightweight wrapper tracking without a hard
15
+ # torch dependency at import time.
16
+ try: # pragma: no cover - available in normal package use
17
+ from nervecode.layers.base import BaseCodingWrapper
18
+ except Exception: # pragma: no cover - fallback for environments without package context
19
+ BaseCodingWrapper = object # type: ignore[misc,assignment]
20
+
21
+
22
class WrappedModel:
    """Thin container around a base model that tracks coding layers.

    Goals
    - Preserve the original model API: attribute access and ``__call__``
      delegate to the wrapped model so existing code keeps working.
    - Track inserted coding layers (wrappers) to support convenience
      iterators and aggregations implemented in subsequent tasks.
    """

    # Keep the constructor torch-agnostic and accept any object. When the
    # object is a PyTorch ``nn.Module``, wrapper discovery uses
    # ``named_modules()``; otherwise, tracking stays empty.
    def __init__(self, model: object) -> None:
        # Store the wrapped model reference for external access and tests.
        self._model = model
        # Internal registry of discovered coding wrappers keyed by dotted name.
        self._wrapped_layers: dict[str, BaseCodingWrapper] = {}
        # Convenience caches populated on forward() when coding is active.
        # Latest visible model output from the most recent forward.
        self._last_output: Any | None = None
        # Latest per-layer traces keyed by dotted module name.
        # Values are intentionally typed as Any to keep this module torch-agnostic;
        # concrete type is nervecode.core.CodingTrace in torch-enabled envs.
        self._last_layer_traces: dict[str, Any] = {}
        # Optional calibrator attached after a calibration run.
        self._calibrator: Any | None = None
        self._refresh_wrapped_layers()

    # --- Delegation to preserve the base model API -----------------------------
    def __getattr__(self, name: str) -> Any:  # pragma: no cover - exercised via tests
        """Delegate attribute access to the wrapped model when not found on self.

        Reads ``_model`` straight out of ``__dict__``: going through
        ``self._model`` would re-enter ``__getattr__`` and recurse forever
        whenever ``_model`` has not been set yet (e.g. while unpickling,
        before ``__init__`` has run).
        """
        d = self.__dict__
        if "_model" not in d:
            raise AttributeError(name)
        return getattr(d["_model"], name)

    def _call_model(self, *args: Any, **kwargs: Any) -> Any:
        """Dispatch ``args``/``kwargs`` to the wrapped model.

        Prefers normal call semantics; falls back to a ``forward`` attribute.
        Shared by ``__call__`` and ``forward`` so the dispatch rules and the
        error message stay in one place.
        """
        model = self._model
        # Fast path: call the model if it is callable.
        if callable(model):
            return model(*args, **kwargs)
        # Fallback: attempt a direct ``forward`` attribute.
        fwd = getattr(model, "forward", None)
        if callable(fwd):
            return fwd(*args, **kwargs)
        raise TypeError("Wrapped base model is not callable and has no forward()")

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate calls to the wrapped model's ``__call__``.

        This keeps the container usable anywhere the base model is expected.
        """
        return self._call_model(*args, **kwargs)

    # --- Common nn.Module-style helpers that preserve chaining semantics -------
    def train(self, mode: bool = True) -> WrappedModel:  # pragma: no cover - exercised indirectly
        """Set underlying model to train/eval mode and return self for chaining.

        Mirrors ``torch.nn.Module.train`` semantics so that patterns like
        ``WrappedModel(model).train()`` keep returning a wrapper rather than the
        bare model (which would drop wrapper methods).
        """
        m = getattr(self._model, "train", None)
        if callable(m):
            with suppress(Exception):
                m(bool(mode))
        return self

    def eval(self) -> WrappedModel:  # pragma: no cover - exercised indirectly
        """Switch to evaluation mode and return self for chaining."""
        return self.train(False)

    def forward(self, *args: Any, **kwargs: Any) -> object:
        """Run a forward pass through the wrapped base model.

        Returns the base model output, and refreshes the convenience caches
        (``_last_output`` and ``_last_layer_traces``) collected from wrapped
        layers when coding is active.
        """
        # Ensure our registry reflects any instrumentation that may have been
        # applied after construction.
        self._refresh_wrapped_layers()

        # Clear last-forward convenience caches before computing a new result.
        self._last_output = None
        self._last_layer_traces = {}

        # Delegate to the underlying model using its normal call semantics.
        out = self._call_model(*args, **kwargs)

        # Update convenience caches: store output and collect latest per-layer
        # traces from wrapped modules when coding is enabled at the layer level.
        self._last_output = out
        try:
            traces: dict[str, Any] = {}
            for name, wrapper in self._wrapped_layers.items():
                # Each wrapper manages its own fail-open behavior and cache.
                trace = getattr(wrapper, "last_trace", None)
                if trace is not None:
                    traces[name] = trace
            self._last_layer_traces = traces
        except Exception:
            # Fail-open: keep caches best-effort without affecting user forwards.
            self._last_layer_traces = {}

        return out

    # --- Model-level fail-open controls ----------------------------------------
    def enable_coding(self) -> None:
        """Enable coding on all discovered wrapped layers.

        Safe to call at any time; silently does nothing when no wrappers are
        present or when the wrapped object is not a PyTorch module.
        """
        self._refresh_wrapped_layers()
        try:
            for wrapper in self._wrapped_layers.values():
                wrapper.enable_coding()
        except Exception:
            # Fail-open: best-effort delegation only
            return

    def disable_coding(self) -> None:
        """Disable coding on all discovered wrapped layers (fail-open).

        Clearing any cached per-layer traces ensures that callers observing
        model-level caches see the disabled state immediately, without having
        to run another forward first.
        """
        self._refresh_wrapped_layers()
        try:
            for wrapper in self._wrapped_layers.values():
                wrapper.disable_coding()
                # Ensure stale traces are not exposed after disabling.
                if hasattr(wrapper, "clear_last_trace"):
                    wrapper.clear_last_trace()
        except Exception:
            # Fail-open: best-effort delegation only
            pass
        # Clear model-level convenience cache as well.
        self._last_layer_traces = {}

    @contextmanager
    def bypass(self) -> Iterator[None]:
        """Temporarily bypass coding across all wrapped layers.

        The context nests correctly and delegates to every discovered wrapper's
        own bypass context. Safe when no wrappers are present.
        """
        self._refresh_wrapped_layers()
        with ExitStack() as stack:
            for wrapper in self._wrapped_layers.values():
                # Each wrapper manages its own nesting depth.
                stack.enter_context(wrapper.bypass())
            yield

    # Training-time auxiliary loss used to learn codebooks.
    def coding_loss(self) -> Any:
        """Compute coding loss from the latest per-layer traces.

        Uses the most recent per-layer traces collected during ``forward`` and
        computes a scalar training-time loss via ``nervecode.training.CodingLoss``.
        When no traces are available, raises a helpful error explaining how to
        proceed (e.g., instrument the model, enable coding, and run a forward
        pass, or use the explicit trace path and call ``CodingLoss`` directly).
        """
        # Lazy import to keep package importable without torch in non-training contexts
        try:
            from nervecode.training import CodingLoss
        except Exception as exc:  # pragma: no cover - defensive
            raise RuntimeError(
                "CodingLoss is unavailable. Ensure PyTorch and training components are installed."
            ) from exc

        traces = getattr(self, "_last_layer_traces", {}) or {}
        if not traces:
            # Diagnose common causes and provide actionable guidance.
            wrappers = getattr(self, "_wrapped_layers", {}) or {}
            if not wrappers:
                raise RuntimeError(
                    "No coding layers discovered on this model. Instrument your model "
                    "(e.g., nervecode.layers.wrap.wrap(model, layers='all_linear')) "
                    "before calling WrappedModel.coding_loss()."
                )
            any_enabled = any(getattr(w, "coding_enabled", False) for w in wrappers.values())
            if not any_enabled:
                raise RuntimeError(
                    "Coding is disabled on all wrapped layers. Call wrapped.enable_coding() "
                    "and run a forward pass before calling WrappedModel.coding_loss()."
                )
            raise RuntimeError(
                "No per-layer traces are available. Run wrapped.forward(...) first to "
                "populate the latest traces, or use forward_with_trace(...) and pass explicit "
                "traces to nervecode.training.CodingLoss."
            )

        loss_mod = CodingLoss()
        return loss_mod(traces)

    # Calibrate surprise thresholds/percentiles on in-distribution data.
    def calibrate(self, data_loader: Iterable[Any]) -> None:
        """Calibrate empirical percentiles on aggregated surprise.

        Runs the wrapped model over a data loader (held-out in-distribution),
        collects per-sample aggregated surprise values, and fits an
        EmpiricalPercentileCalibrator. The fitted calibrator is attached to the
        instance for later use (e.g., percentiles, is_ood) via ``self._calibrator``.

        Notes
        - The data loader may yield either ``inputs`` or ``(inputs, targets)``.
        - Coding must be enabled on wrapped layers for traces to be produced.
        """
        try:
            from nervecode.scoring import EmpiricalPercentileCalibrator, mean_surprise
        except Exception as exc:  # pragma: no cover - defensive
            raise RuntimeError(
                "Calibration utilities are unavailable. Ensure nervecode.scoring is importable."
            ) from exc

        # Collect per-sample surprise vectors across the loader
        scores: list[Any] = []
        # Prefer torch.no_grad if available, but avoid hard dependency at import time.
        try:
            import torch

            no_grad_ctx: Any = torch.no_grad()
        except Exception:  # pragma: no cover - environments without torch

            class _NullCtx:
                def __enter__(self, *a: Any, **k: Any) -> None:
                    return None

                def __exit__(self, *a: Any, **k: Any) -> None:
                    return None

            no_grad_ctx = _NullCtx()

        with no_grad_ctx:
            for batch in data_loader:
                # Common patterns: (inputs, targets) or inputs directly
                if isinstance(batch, (tuple, list)) and len(batch) >= 1:
                    inputs = batch[0]
                elif isinstance(batch, dict) and "inputs" in batch:
                    inputs = batch["inputs"]
                else:
                    inputs = batch

                _ = self.forward(inputs)
                agg = self.surprise()
                if agg is None:
                    # Fall back to explicit aggregation over the latest traces
                    traces = getattr(self, "_last_layer_traces", {}) or {}
                    agg = mean_surprise(traces)
                if agg is not None and hasattr(agg, "surprise"):
                    scores.append(agg.surprise)

        if not scores:
            raise RuntimeError(
                "Calibration collected no surprise scores; ensure coding is enabled."
            )

        # Fit calibrator on concatenated scores when possible (torch),
        # otherwise fit per-chunk (rare path when tensor ops are unavailable).
        try:
            import torch

            vec = torch.cat([s.detach().to(device="cpu").reshape(-1) for s in scores], dim=0)
            calib = EmpiricalPercentileCalibrator(threshold_quantiles=(0.95,))
            _ = calib.fit(vec, aggregation="mean")
        except Exception:  # pragma: no cover - defensive minimal fallback
            calib = EmpiricalPercentileCalibrator(threshold_quantiles=(0.95,))
            # Let calibrator normalize iterable inputs internally
            flat: list[float] = []
            for s in scores:
                with suppress(Exception):
                    resh = getattr(s, "reshape", None)
                    if callable(resh):
                        arr = resh(-1).tolist()
                    else:
                        arr = s.tolist() if hasattr(s, "tolist") else [s]
                    flat.extend([float(v) for v in list(arr)])
            _ = calib.fit(flat, aggregation="mean")

        self._calibrator = calib
        return None

    # Inference-time surprise signal (aggregated over wrapped layers).
    def surprise(self) -> object:
        """Return the latest aggregated surprise across wrapped layers.

        This convenience method aggregates the most recent per-layer traces
        collected during a standard ``forward`` into a per-sample surprise
        signal using the default mean strategy. When no valid per-layer
        signals are available, returns ``None``.
        """
        try:
            from nervecode.scoring import mean_surprise  # lazy import
        except Exception:
            return None

        traces = getattr(self, "_last_layer_traces", {}) or {}
        return mean_surprise(traces)

    # Explicit trace-returning path for robust integrations.
    def forward_with_trace(self, *args: Any, **kwargs: Any) -> tuple[object, dict[str, Any]]:
        """Return (model_output, per-layer_traces) for a single forward.

        The first element mirrors the wrapped model's normal return value. The
        second is a mapping from dotted layer names to ``CodingTrace`` objects
        produced during this forward by wrapped layers (those with coding
        active). When coding is disabled, the mapping is empty.
        """
        out = self.forward(*args, **kwargs)
        traces = dict(getattr(self, "_last_layer_traces", {}) or {})
        return out, traces

    # --- Internal: discover coding wrappers on the model -----------------------
    def _refresh_wrapped_layers(self) -> None:
        """Discover and cache coding wrappers present on the wrapped model.

        Populates ``_wrapped_layers`` with a mapping from dotted module names to
        wrapper instances. Non-PyTorch objects result in an empty registry.
        The method never raises; it is safe to call after in-place
        instrumentation.
        """
        registry: dict[str, BaseCodingWrapper] = {}
        model = self._model
        try:
            named_modules = getattr(model, "named_modules", None)
            if callable(named_modules):
                for name, module in named_modules():
                    # isinstance check remains torch-agnostic because
                    # BaseCodingWrapper is torch-free.
                    if isinstance(module, BaseCodingWrapper):
                        registry[str(name)] = module
        except Exception:
            # Fail-open: keep tracking best-effort only.
            registry = {}
        self._wrapped_layers = registry
377
+
378
+
379
def wrap(model: object, *, layers: str | Iterable[str] = "all_linear") -> WrappedModel:
    """Wrap a PyTorch module to produce coding-based surprise signals.

    Parameters
    - model: Base `torch.nn.Module` to wrap.
    - layers: Strategy or explicit layer names to wrap. The MVP will start with
      the convenience value "all_linear".
    """
    # Pull in the in-place instrumentation helper lazily so importing this
    # module stays cheap in environments without torch.
    try:
        from nervecode.layers.wrap import wrap as _wrap_layers
    except Exception as exc:  # pragma: no cover - defensive
        raise RuntimeError(
            "Layer instrumentation is unavailable. Import nervecode.layers.wrap directly."
        ) from exc

    # Fail-open instrumentation: if wrapping the layers blows up, we still
    # hand back a container so attribute/call delegation keeps working for
    # non-torch objects.
    with suppress(Exception):
        _wrap_layers(model, layers=layers)
    return WrappedModel(model)
404
+
405
+
406
# Explicit public API of the package root.
__all__ = [
    "WrappedModel",
    "wrap",
]

# Runtime-accessible package version
try: # pragma: no cover - trivial import
    from ._version import __version__ as __version__
except Exception: # pragma: no cover - fallback if version file unavailable
    # Sentinel version used when the generated version module is absent
    # (e.g. running from a raw source checkout).
    __version__ = "0.0.0"
nervecode/_version.py ADDED
@@ -0,0 +1,10 @@
1
"""Package version for Nervecode.

The version is the single source of truth for both runtime access via
``nervecode.__version__`` and build-time metadata via Hatch's version plugin.
"""

# This module deliberately exports only the version string.
__all__ = ["__version__"]

# Follow PEP 440; bump as releases evolve.
__version__ = "0.1.0"
@@ -0,0 +1,19 @@
1
"""Core data model, primitives, and engines.

This subpackage will host the core types, codebook modules, assignment
algorithms, and trace representations used across Nervecode.
"""

from __future__ import annotations

# Re-export the core public types so callers can import them directly from
# ``nervecode.core`` instead of reaching into individual submodules.
from .temperature import CosineTemperature, FixedTemperature, TemperatureSchedule
from .trace import CodingTrace
from .types import SoftCode

# Alphabetically sorted public API list.
__all__: list[str] = [
    "CodingTrace",
    "CosineTemperature",
    "FixedTemperature",
    "SoftCode",
    "TemperatureSchedule",
]
@@ -0,0 +1,165 @@
1
+ """Soft assignment engine.
2
+
3
+ This module computes squared Euclidean distances from reduced activations to a
4
+ codebook, converts them into soft assignments over codes, and returns both a
5
+ ``SoftCode`` object and the intermediate tensors needed to construct a
6
+ ``CodingTrace`` without recomputing distances.
7
+
8
+ Shapes follow the core convention:
9
+ - Inputs have shape ``(..., D)`` where ``...`` are arbitrary leading dims.
10
+ - Probabilities have shape ``(..., K)`` where ``K`` is the number of codes.
11
+ - Per-position scalar values have shape matching the leading dims ``...``.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from .codebook import Codebook
20
+ from .shapes import flatten_leading, unflatten_leading
21
+ from .types import SoftCode
22
+
23
+ __all__ = ["SoftAssignment"]
24
+
25
+
26
class SoftAssignment(nn.Module):
    """Compute soft assignments and return trace-ready intermediates.

    Parameters
    - temperature: Positive scalar controlling assignment sharpness. Larger
      values produce softer assignments. This parameter is not learnable; use a
      schedule or an external tensor if you need dynamic temperature.
    - eps: Small positive constant used to clamp divisions and logarithms to
      keep values finite.
    - beta_length / beta_entropy / beta_distance: Weights for the combined
      surprise signal ``S = bL * L + bH * H + bD * D_norm``.

    Raises
    - ValueError: if ``temperature`` or ``eps`` is not a positive finite float.
    """

    def __init__(
        self,
        *,
        temperature: float = 1.0,
        eps: float = 1e-12,
        beta_length: float = 1.0,
        beta_entropy: float = 1.0,
        beta_distance: float = 0.0,
    ) -> None:
        super().__init__()
        t = float(temperature)
        # A plain chained comparison covers every invalid case without
        # allocating throwaway tensors: NaN fails both comparisons, +inf fails
        # the upper bound, and non-positive values fail the lower bound.
        if not (0.0 < t < float("inf")):
            raise ValueError("temperature must be a positive finite float")
        e = float(eps)
        if not (0.0 < e < float("inf")):
            raise ValueError("eps must be a positive finite float")
        # Explicit attribute annotation helps static type checkers understand
        # the buffer's type (avoids Tensor|Module unions on .to/.clamp paths).
        self._temperature: torch.Tensor
        self.register_buffer("_temperature", torch.tensor(t), persistent=False)
        self.eps: float = e
        # Weights for the combined surprise signal S = bL * L + bH * H + bD * D_norm
        self.beta_length: float = float(beta_length)
        self.beta_entropy: float = float(beta_entropy)
        self.beta_distance: float = float(beta_distance)

    @property
    def temperature(self) -> float:
        """Current (fixed) temperature as a Python float."""
        return float(self._temperature.item())

    def forward(
        self,
        reduced: torch.Tensor,
        codebook: Codebook | torch.Tensor,
    ) -> tuple[SoftCode, dict[str, torch.Tensor]]:
        """Return soft code and intermediates for ``CodingTrace``.

        Inputs
        - reduced: Tensor of shape ``(..., D)`` representing reduced activations.
        - codebook: ``Codebook`` module or a centers tensor of shape ``(K, D)``.

        Returns
        - soft_code: ``SoftCode`` with probabilities ``(..., K)`` and basic
          scalar statistics (best length, entropy, best indices, combined surprise).
        - intermediates: dict with keys
          - ``nearest_center_distances``: ``(...,)``
          - ``chosen_center_indices``: integer ``(...,)``
          - ``commitment_distances``: ``(...,)``
          which are sufficient to construct a ``CodingTrace`` without
          recomputation.

        Raises
        - ValueError: on a centers tensor that is not 2-D, a zero-dimensional
          ``reduced``, or a feature-dimension mismatch.
        """

        centers = codebook() if isinstance(codebook, Codebook) else codebook

        if centers.ndim != 2:
            raise ValueError("centers must have shape (K, D)")

        x = reduced
        if x.ndim < 1:
            raise ValueError("reduced must have at least one dimension (..., D)")

        x_flat, lead = flatten_leading(x)  # (N, D)
        d = int(centers.shape[1])
        if int(x_flat.shape[1]) != d:
            raise ValueError(f"reduced last dim {int(x_flat.shape[1])} must match centers D={d}")

        # Compute squared Euclidean distances using the expansion
        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x·c
        x_norm = (x_flat * x_flat).sum(dim=1, keepdim=True)  # (N, 1)
        c_norm = (centers * centers).sum(dim=1).unsqueeze(0)  # (1, K)
        # x @ centers^T yields (N, K)
        dot = x_flat @ centers.t()  # (N, K)
        distances = x_norm + c_norm - 2.0 * dot  # (N, K)
        # For numerical safety, clamp at zero from below (squared distances).
        distances = distances.clamp_min(0.0)

        # Convert distances into softmax probabilities over codes using
        # negative distances as logits, scaled by temperature.
        t = self._temperature.to(dtype=distances.dtype, device=distances.device)
        logits = -distances / torch.clamp(t, min=self.eps)
        probs_flat = torch.softmax(logits, dim=-1)

        # Best indices from min distance (equivalently max probability).
        best_idx_flat = torch.argmin(distances, dim=-1)  # (N,)
        # Nearest-center distances and commitment distances per position.
        nearest_flat = distances.gather(-1, best_idx_flat.unsqueeze(-1)).squeeze(-1)  # (N,)
        commitment_flat = nearest_flat  # identical for squared Euclidean

        # Information-theoretic scalars
        eps_t = torch.tensor(self.eps, dtype=probs_flat.dtype, device=probs_flat.device)
        best_prob_flat = probs_flat.gather(-1, best_idx_flat.unsqueeze(-1)).squeeze(-1)
        best_length_flat = -(best_prob_flat + eps_t).log()
        entropy_flat = -(probs_flat * (probs_flat + eps_t).log()).sum(dim=-1)
        # Distance-based component: use a stable, monotonically increasing
        # transform of the nearest-center squared Euclidean distance. ``log1p``
        # reduces dynamic range without erasing separation. Avoid per-forward
        # normalization to preserve absolute contrast between ID and OOD.
        d_comp_flat = torch.log1p(nearest_flat)
        # Weighted combination including optional distance contribution
        combined_surprise_flat = (
            self.beta_length * best_length_flat
            + self.beta_entropy * entropy_flat
            + self.beta_distance * d_comp_flat
        )

        # Restore original leading layout
        probs = unflatten_leading(probs_flat, lead)  # (..., K)
        best_idx = unflatten_leading(best_idx_flat, lead)  # (...,)
        nearest = unflatten_leading(nearest_flat, lead)  # (...,)
        commitment = unflatten_leading(commitment_flat, lead)  # (...,)
        best_length = unflatten_leading(best_length_flat, lead)  # (...,)
        entropy = unflatten_leading(entropy_flat, lead)  # (...,)
        combined_surprise = unflatten_leading(combined_surprise_flat, lead)  # (...,)

        soft = SoftCode(
            probs=probs,
            best_length=best_length,
            entropy=entropy,
            best_indices=best_idx,
            combined_surprise=combined_surprise,
        )

        intermediates: dict[str, torch.Tensor] = {
            "nearest_center_distances": nearest,
            "chosen_center_indices": best_idx,
            "commitment_distances": commitment,
        }
        return soft, intermediates