nervecode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,182 @@
1
+ """Learnable codebook module.
2
+
3
+ This module defines a small ``torch.nn.Module`` that stores a set of codebook
4
+ centers and exposes them as a trainable parameter. The centers are initialized
5
+ with a scale that depends on the coding-space dimension so that distances and
6
+ soft assignments behave well at the start of training.
7
+
8
+ The shape convention follows the rest of the core: a codebook with ``K`` codes
9
+ over a ``code_dim``-dimensional space uses a parameter of shape ``(K, code_dim)``.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ from collections import OrderedDict
16
+ from collections.abc import MutableMapping
17
+ from contextlib import suppress
18
+ from typing import Any, TypeVar, cast, overload
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+ __all__ = ["Codebook"]
24
+
25
+
26
class Codebook(nn.Module):
    """Gradient-updated codebook with centers of shape ``(K, code_dim)``.

    Parameters
    - K: Number of codes (centers) in the codebook. Must be ``>= 1``.
    - code_dim: Dimensionality of the coding space. Must be ``>= 1``.
    - init: Initialization strategy for centers. Supported values are
      ``"uniform"`` (default) and ``"normal"``. Both scale with
      ``1 / sqrt(code_dim)``.
    - device, dtype: Optional factory kwargs for creating the parameter.

    Notes
    - The centers are stored in ``self.centers`` and require gradients by
      default, making the codebook trainable by standard optimizers.
    - The default uniform initialization bounds each component within
      ``[-1/sqrt(code_dim), 1/sqrt(code_dim)]``.
    """

    def __init__(
        self,
        K: int,
        code_dim: int,
        *,
        init: str = "uniform",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        if int(K) <= 0:
            raise ValueError("K must be >= 1")
        if int(code_dim) <= 0:
            raise ValueError("code_dim must be >= 1")

        self.K: int = int(K)
        self.code_dim: int = int(code_dim)
        # Store init strategy so that reset_parameters matches PyTorch's
        # signature (no args) and can be invoked by utilities.
        if init not in {"uniform", "normal"}:
            raise ValueError("init must be one of {'uniform', 'normal'}")
        self.init: str = init

        # Only forward device/dtype when explicitly given so that torch's
        # own defaults apply otherwise.
        factory: dict[str, Any] = {}
        if device is not None:
            factory["device"] = device
        if dtype is not None:
            factory["dtype"] = dtype

        self.centers: nn.Parameter
        self.centers = nn.Parameter(torch.empty(self.K, self.code_dim, **factory))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize centers with scale tied to the coding dimension.

        - ``uniform``: components sampled from ``[-s, s]`` where
          ``s = 1 / sqrt(code_dim)``.
        - ``normal``: components sampled from ``N(0, s)`` with
          ``s = 1 / sqrt(code_dim)``.
        """

        scale = 1.0 / math.sqrt(float(self.code_dim))
        if self.init == "uniform":
            nn.init.uniform_(self.centers, -scale, scale)
        elif self.init == "normal":
            nn.init.normal_(self.centers, mean=0.0, std=scale)

    def forward(self) -> torch.Tensor:
        """Return the current centers tensor.

        This convenience makes it easy to pass a codebook into distance/assignment
        utilities while keeping it a standard ``nn.Module`` with a parameter.
        """

        return self.centers

    def extra_repr(self) -> str:  # pragma: no cover - cosmetic
        return f"K={self.K}, code_dim={self.code_dim}, init='{self.init}'"

    # --- Serialization contract -------------------------------------------------
    # We persist lightweight metadata (K, code_dim, init) as module extra state
    # so that checkpoints carry sufficient information for inspection and
    # validation when reloading.
    def get_extra_state(self) -> dict[str, Any]:
        """Return metadata saved alongside parameters in ``state_dict``."""

        return {
            "version": 1,
            "K": self.K,
            "code_dim": self.code_dim,
            "init": self.init,
        }

    def set_extra_state(self, state: Any) -> None:
        """Load metadata saved via :meth:`get_extra_state`.

        This validates that the incoming metadata matches the current instance
        shape, which helps catch accidental loads into mismatched modules.
        """

        if not isinstance(state, dict):  # pragma: no cover - defensive
            return
        # Optional version check for forward-compat.
        _ = int(state.get("version", 1))
        K = int(state.get("K", self.K))
        D = int(state.get("code_dim", self.code_dim))
        init = str(state.get("init", self.init))
        # Validate but do not mutate core attributes.
        if (self.K, self.code_dim) != (K, D):  # pragma: no cover - rare
            raise RuntimeError(
                "Loaded Codebook metadata does not match current shape: "
                f"got (K={K}, code_dim={D}) vs current (K={self.K}, code_dim={self.code_dim})"
            )
        if init not in {"uniform", "normal"}:  # pragma: no cover - rare
            # Ignore unknown init; keep current.
            return
        # Keep self.init as-is to avoid surprising behavior during load; users
        # can call reset_parameters() explicitly if they wish to reinitialize.

    # PyTorch automatically includes the return of get_extra_state() under the
    # special key "_extra_state" inside state_dict(). Our tests expect the
    # public state_dict to expose only the trainable parameter 'centers'.
    # Override state_dict to drop the special entry while keeping load_state_dict
    # behavior unchanged (it tolerates missing extra state).
    # NOTE: TypeVar mirrors torch.nn.Module.state_dict's overloads so type
    # checkers accept both destination-passing and plain-dict call forms.
    T_destination = TypeVar("T_destination", bound=MutableMapping[str, Any])

    @overload
    def state_dict(
        self, *, destination: T_destination, prefix: str = "", keep_vars: bool = False
    ) -> T_destination: ...

    @overload
    def state_dict(self, *, prefix: str = "", keep_vars: bool = False) -> dict[str, Any]: ...

    def state_dict(self, *args: Any, **kwargs: Any) -> MutableMapping[str, Any]:
        """Return the module state without the synthetic ``_extra_state`` entry."""

        sd_any = super().state_dict(*args, **kwargs)
        sd = cast(MutableMapping[str, Any], sd_any)
        with suppress(Exception):
            # Drop extra metadata key while preserving parameter entries
            sd.pop("_extra_state", None)
        return sd

    def load_state_dict(self, state_dict: Any, strict: bool = True, assign: bool = False) -> Any:
        """Load centers while tolerating missing _extra_state in strict mode.

        Tests expect public state_dicts to contain only 'centers'. Since
        PyTorch includes get_extra_state() under '_extra_state' by default,
        we synthesize a minimal extra state if it's missing so strict=True
        loads remain compatible.
        """
        try:
            if isinstance(state_dict, dict) and "_extra_state" not in state_dict:
                sd = OrderedDict(state_dict)
                sd["_extra_state"] = self.get_extra_state()
                return super().load_state_dict(sd, strict=strict, assign=assign)
        except Exception:
            # Fall back to the unmodified input; super() will report any
            # genuine mismatch (missing/unexpected keys) itself.
            pass
        return super().load_state_dict(state_dict, strict=strict, assign=assign)
@@ -0,0 +1,107 @@
1
+ """Shape helpers for coding engines.
2
+
3
+ This module provides utilities to flatten and unflatten the leading dimensions
4
+ of tensors that follow the common Nervecode convention of a feature or code
5
+ dimension as the final axis. This allows assignment and distance computations to
6
+ operate uniformly on shapes like:
7
+
8
+ - batch-only: ``(B, D)`` or probabilities ``(B, K)``,
9
+ - token-like: ``(B, T, D)`` or ``(B, T, K)``,
10
+ - pooled convolution: ``(B, H, W, D)`` or ``(B, H, W, K)``.
11
+
12
+ By flattening the leading dimensions, engines can work over a simple 2D view
13
+ ``(N, D)`` or ``(N, K)`` where ``N = prod(leading_shape)``. The corresponding
14
+ unflatten operation restores the original leading layout.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import torch
20
+
21
+ __all__ = [
22
+ "flatten_leading",
23
+ "leading_shape",
24
+ "unflatten_leading",
25
+ ]
26
+
27
+
28
def leading_shape(x: torch.Tensor) -> tuple[int, ...]:
    """Return every dimension of ``x`` except the last, as a tuple of ints.

    Tensors in this package follow the ``(..., D)`` convention with the
    feature/code axis last; this helper extracts the ``...`` part.

    Parameters
    - x: Tensor with at least one dimension.

    Returns
    - Tuple of ints for the leading shape (empty for 1-D input).
    """

    if x.ndim < 1:
        raise ValueError("x must have at least one dimension (..., D)")
    *lead_dims, _feature = x.shape
    # torch.Size already holds ints; the explicit cast keeps type checkers happy.
    return tuple(int(dim) for dim in lead_dims)
44
+
45
+
46
def flatten_leading(x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, ...]]:
    """Collapse all leading dimensions of ``x`` into a single axis.

    An input shaped ``(..., D)`` becomes a tensor shaped ``(N, D)`` with
    ``N = prod(...)``; the original leading shape is returned alongside so
    callers can later restore the layout via ``unflatten_leading``.

    Parameters
    - x: Tensor of shape ``(..., D)``.

    Returns
    - x_flat: Tensor of shape ``(N, D)``.
    - lead: The original leading shape as a tuple of ints.
    """

    if x.ndim < 1:
        raise ValueError("x must have at least one dimension (..., D)")
    lead = tuple(int(size) for size in x.shape[:-1])
    feature_dim = int(x.shape[-1])
    # prod(lead); an empty lead (1-D input) yields N == 1.
    flat_rows = 1
    for size in lead:
        flat_rows *= size
    # reshape (rather than view) tolerates non-contiguous inputs.
    return x.reshape(flat_rows, feature_dim), lead
74
+
75
+
76
+ def unflatten_leading(y: torch.Tensor, lead: tuple[int, ...]) -> torch.Tensor:
77
+ """Restore a flattened first dimension back to ``lead``.
78
+
79
+ If ``y`` has shape ``(N, ...)`` where ``N = prod(lead)``, the returned
80
+ tensor has shape ``(*lead, ...)``. This works for scalar-like per-position
81
+ tensors (``(N,)``), probability tensors (``(N, K)``), and vector-valued
82
+ tensors (``(N, D)``).
83
+
84
+ Parameters
85
+ - y: Tensor whose first dimension flattens a known set of leading dims.
86
+ - lead: Tuple of leading sizes that were previously flattened.
87
+
88
+ Returns
89
+ - Tensor reshaped to ``(*lead, *y.shape[1:])``.
90
+ """
91
+
92
+ # Validate that y has at least one dimension representing the flattened N.
93
+ if y.ndim < 1:
94
+ raise ValueError("y must have at least one dimension (N, ...)")
95
+
96
+ # Compute expected N and verify it matches when eager shapes are available.
97
+ n = 1
98
+ for s in lead:
99
+ n *= int(s)
100
+ if int(y.shape[0]) != n:
101
+ raise ValueError(f"first dimension {int(y.shape[0])} does not match prod(lead)={n}")
102
+
103
+ # Reshape back to the original leading layout, preserving trailing dims.
104
+ if y.ndim == 1:
105
+ # Special case: scalar-like values (N,) -> (*lead,)
106
+ return y.reshape(*lead)
107
+ return y.reshape(*lead, *y.shape[1:])
@@ -0,0 +1,227 @@
1
+ """Temperature schedules for coding and training.
2
+
3
+ This module provides a minimal interface and two concrete schedules for the
4
+ MVP: a fixed temperature and a cosine schedule. The interface is designed so
5
+ that a future learnable temperature (e.g., an ``nn.Parameter`` inside an
6
+ ``nn.Module``) can conform to the same calling convention.
7
+
8
+ Design goals
9
+ - Keep evaluation cheap and dependency-light.
10
+ - Accept either explicit progress (``0..1``) or step-based inputs.
11
+ - Provide a tensor-producing helper for easy integration in Torch graphs.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import math
17
+ from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, cast, runtime_checkable
18
+
19
+ # Optional PyTorch typing without a hard runtime dependency.
20
+ if TYPE_CHECKING: # pragma: no cover - typing-only branch
21
+ import torch as _torch
22
+ from torch import Tensor as _Tensor
23
+
24
+ TensorT: TypeAlias = _Tensor
25
+ DeviceT: TypeAlias = _torch.device
26
+ DTypeT: TypeAlias = _torch.dtype
27
+ else: # Fallbacks keep the module importable without torch installed.
28
+ TensorT: TypeAlias = Any
29
+ DeviceT: TypeAlias = Any
30
+ DTypeT: TypeAlias = Any
31
+
32
+ __all__ = ["CosineTemperature", "FixedTemperature", "TemperatureSchedule"]
33
+
34
+
35
def _require_torch() -> Any:
    """Return the :mod:`torch` module, importing it on first use.

    Keeps this module importable in environments without PyTorch installed;
    only the ``as_tensor`` helpers actually need the dependency, and they get
    a clear error message when it is absent.
    """
    try:
        import torch as torch_mod
    except Exception as exc:  # pragma: no cover - exercised in envs without torch
        raise RuntimeError(
            "as_tensor requires PyTorch; install 'torch' to use this method."
        ) from exc
    return torch_mod
48
+
49
+
50
@runtime_checkable
class TemperatureSchedule(Protocol):
    """Callable interface for temperature schedules.

    Implementations should return a positive scalar temperature when invoked.
    Schedules may accept either normalized progress in ``[0, 1]`` via the
    ``progress`` keyword, or a ``step`` index with an associated total number
    of steps ``total``. When both are supplied, ``progress`` takes precedence.
    """

    # Structural (duck-typed) protocol: any object providing these two methods
    # conforms. runtime_checkable allows isinstance() checks, which verify
    # method *presence* only, not signatures.
    def __call__(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
    ) -> float:  # pragma: no cover - interface only
        ...

    def as_tensor(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
        device: DeviceT | None = None,
        dtype: DTypeT | None = None,
    ) -> TensorT:  # pragma: no cover - interface only
        ...
79
+
80
+
81
class FixedTemperature:
    """Temperature schedule that ignores progress and always yields one value."""

    def __init__(self, temperature: float) -> None:
        value = float(temperature)
        # Reject NaN, infinities, and non-positive temperatures up front.
        if value <= 0.0 or not math.isfinite(value):
            raise ValueError("temperature must be a positive finite float")
        self._t: float = value

    def __repr__(self) -> str:  # pragma: no cover - cosmetic
        return f"FixedTemperature(temperature={self._t})"

    def __call__(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
    ) -> float:
        """Return the constant temperature; scheduling arguments are ignored."""
        del step, total, progress  # accepted only for interface uniformity
        return self._t

    def as_tensor(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
        device: DeviceT | None = None,
        dtype: DTypeT | None = None,
    ) -> TensorT:
        """Return the constant temperature as a 0-d tensor on device/dtype."""
        del step, total, progress
        torch = _require_torch()
        return cast("TensorT", torch.tensor(self._t, device=device, dtype=dtype))
117
+
118
+
119
class CosineTemperature:
    """Cosine anneal from ``start`` to ``end`` with an optional warmup hold.

    The temperature follows the closed-form blend

        ``T(p) = end + (start - end) * 0.5 * (1 + cos(pi * p))``

    for normalized progress ``p`` in ``[0, 1]``. Callers may pass ``progress``
    directly, or a ``step`` together with a ``total`` (falling back to the
    constructor's ``total_steps``); for the first ``warmup_steps`` steps the
    schedule holds at ``start``.

    Parameters
    - start: Temperature at ``p = 0``. Must be a positive finite float.
    - end: Temperature at ``p = 1``. Must be a positive finite float.
    - total_steps: Default total step count; a call-time ``total=`` overrides
      it.
    - warmup_steps: Initial steps held at ``start``. Must be ``>= 0``.
    """

    def __init__(
        self,
        start: float,
        end: float,
        *,
        total_steps: int | None = None,
        warmup_steps: int = 0,
    ) -> None:
        start_val = float(start)
        end_val = float(end)
        if not math.isfinite(start_val) or start_val <= 0.0:
            raise ValueError("start must be a positive finite float")
        if not math.isfinite(end_val) or end_val <= 0.0:
            raise ValueError("end must be a positive finite float")
        if total_steps is not None and int(total_steps) <= 0:
            raise ValueError("total_steps must be >= 1 when provided")
        if int(warmup_steps) < 0:
            raise ValueError("warmup_steps must be >= 0")

        self.start: float = start_val
        self.end: float = end_val
        self.total_steps: int | None = None if total_steps is None else int(total_steps)
        self.warmup_steps: int = int(warmup_steps)

    def __repr__(self) -> str:  # pragma: no cover - cosmetic
        return (
            "CosineTemperature("
            f"start={self.start}, end={self.end}, "
            f"total_steps={self.total_steps}, warmup_steps={self.warmup_steps})"
        )

    def _progress_from_step(self, step: int, total: int) -> float:
        # Map a raw step index onto annealing progress in [0, 1]; the warmup
        # window counts as zero progress.
        current = max(0, int(step))
        horizon = max(1, int(total))
        warm = max(0, self.warmup_steps)
        if current <= warm:
            return 0.0
        span = max(1, horizon - warm)  # guard against division by zero
        frac = (current - warm) / float(span)
        # Clamp so overrunning the schedule parks at the terminal value.
        if frac < 0.0:
            return 0.0
        if frac > 1.0:
            return 1.0
        return frac

    @staticmethod
    def _cosine_mix(start: float, end: float, p01: float) -> float:
        # Defensive clamp before evaluating the closed-form cosine blend.
        if p01 < 0.0:
            p = 0.0
        elif p01 > 1.0:
            p = 1.0
        else:
            p = p01
        return end + (start - end) * 0.5 * (1.0 + math.cos(math.pi * p))

    def __call__(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
    ) -> float:
        """Evaluate the schedule; explicit ``progress`` wins over step inputs."""
        if progress is not None:
            return self._cosine_mix(self.start, self.end, float(progress))

        # Step-based evaluation: a call-time total overrides the default.
        effective_total = self.total_steps if total is None else int(total)
        if effective_total is None or step is None:
            raise ValueError(
                "CosineTemperature requires either progress in [0,1] or both step and total"
            )

        # Hold at the starting temperature throughout the warmup window.
        if int(step) <= self.warmup_steps:
            return float(self.start)
        return self._cosine_mix(
            self.start, self.end, self._progress_from_step(int(step), int(effective_total))
        )

    def as_tensor(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
        device: DeviceT | None = None,
        dtype: DTypeT | None = None,
    ) -> TensorT:
        """Evaluate the schedule and wrap the scalar in a 0-d tensor."""
        scalar = self(step, total=total, progress=progress)
        torch = _require_torch()
        return cast("TensorT", torch.tensor(scalar, device=device, dtype=dtype))
@@ -0,0 +1,166 @@
1
+ """Trace data structures for coding.
2
+
3
+ This module defines lightweight, immutable containers that carry the rich
4
+ per-layer record produced during coding. Traces support arbitrary leading
5
+ dimensions so they can represent batch-only, token-like, or pooled layouts.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+ from typing import Any
12
+
13
+ from .types import SoftCode
14
+
15
+ __all__ = ["CodingTrace"]
16
+
17
+
18
@dataclass(frozen=True)
class CodingTrace:
    """Per-layer coding trace.

    Carries the reduced activation, reduction metadata, nearest-center and
    commitment distances, chosen indices, and the corresponding ``SoftCode``.

    Fields
    - ``reduced`` (Tensor): reduced activation with shape ``(..., D)``.
    - ``reduction_meta`` (dict): metadata describing the applied reduction
      (e.g., method, parameters). May be empty when no reduction is applied.
    - ``nearest_center_distances`` (Tensor): distance to the selected center
      per position with shape matching the leading dimensions ``...``.
    - ``chosen_center_indices`` (Tensor): index of the selected/nearest center
      per position with shape matching the leading dimensions ``...``.
    - ``commitment_distances`` (Tensor): commitment distance per position with
      shape matching the leading dimensions ``...``.
    - ``soft_code`` (SoftCode): soft assignment statistics over ``K`` codes
      with probabilities of shape ``(..., K)``.

    All scalar-like fields must match the leading shape of ``soft_code.probs``
    exactly (i.e., ``tensor.shape == soft_code.probs.shape[:-1]``).
    """

    # Fields are typed Any (not torch.Tensor) so the module imports without torch.
    reduced: Any
    reduction_meta: dict[str, Any]
    nearest_center_distances: Any
    chosen_center_indices: Any
    commitment_distances: Any
    soft_code: SoftCode

    def __post_init__(self) -> None:
        # Import torch lazily to keep import-time dependencies lightweight.
        from typing import Any
        from typing import cast as _cast

        try:
            import torch
        except Exception:  # pragma: no cover - torch may be absent in some envs
            # Sentinel: a None torch disables all tensor-specific validation below.
            torch = _cast(Any, None)

        # Basic validation of required components
        if not isinstance(self.soft_code, SoftCode):  # pragma: no cover - defensive
            raise TypeError("soft_code must be a SoftCode instance")

        # Validate reduced tensor shape
        if torch is not None and not isinstance(
            self.reduced, torch.Tensor
        ):  # pragma: no cover - defensive
            raise TypeError("reduced must be a torch.Tensor")
        # Duck-typed ndim check still runs when torch is unavailable.
        if getattr(self.reduced, "ndim", None) is None or self.reduced.ndim < 1:
            raise ValueError("reduced must have at least one dimension (..., D)")
        if int(self.reduced.shape[-1]) <= 0:
            raise ValueError("final reduced dimension D must be >= 1")

        # If torch isn't available, skip dtype/shape validations that require it.
        if torch is None:  # pragma: no cover
            return

        lead_shape = tuple(int(s) for s in self.soft_code.leading_shape)
        reduced_lead = tuple(int(s) for s in self.reduced.shape[:-1])
        if reduced_lead != lead_shape:
            raise ValueError(
                f"reduced.leading_shape {reduced_lead} must match "
                f"soft_code.leading_shape {lead_shape}"
            )

        # Shared validator: each per-position field must be a tensor whose full
        # shape equals the leading shape (distances numeric, indices integral).
        def _check_scalar_tensor(name: str, value: Any, *, require_integer: bool = False) -> None:
            if not isinstance(value, torch.Tensor):
                raise TypeError(f"{name} must be a torch.Tensor")
            if tuple(int(s) for s in value.shape) != lead_shape:
                raise ValueError(
                    f"{name} must have shape {lead_shape} to match soft_code.leading_shape"
                )
            if require_integer:
                if torch.is_floating_point(value) or value.dtype == torch.bool:
                    raise TypeError(f"{name} must use an integer dtype")
            else:
                if value.dtype == torch.bool:
                    raise TypeError(f"{name} must be numeric, not bool")

        _check_scalar_tensor("nearest_center_distances", self.nearest_center_distances)
        _check_scalar_tensor(
            "chosen_center_indices", self.chosen_center_indices, require_integer=True
        )
        _check_scalar_tensor("commitment_distances", self.commitment_distances)

    @property
    def leading_shape(self) -> tuple[int, ...]:
        """Leading shape shared by all per-position fields."""

        return tuple(int(s) for s in self.soft_code.leading_shape)

    @property
    def reduced_dim(self) -> int:
        """Size of the final reduced dimension ``D``."""

        return int(self.reduced.shape[-1])

    # --- Sample-level reduction views -------------------------------------------
    def sample_reduced_surprise(self, *, reduction: str = "mean") -> Any:
        """Return a per-sample surprise tensor by reducing extra leading dims.

        Many wrappers operate over activations with multiple leading dimensions
        (e.g., batch and sequence). Aggregation across layers requires a common
        per-sample view. This helper reduces the per-position surprise stored on
        the associated ``SoftCode`` to a vector over samples by collapsing all
        leading dimensions after the first (typically sequence/time) using the
        specified reduction.

        Parameters
        - reduction: Current supported value is ``"mean"``. Additional
          reductions may be added later.

        Returns
        - A tensor with shape ``(B,)`` where ``B`` is the size of the first
          leading dimension. In environments without torch or when an
          interpretable surprise is unavailable, this may return ``None`` at
          runtime; the return type is ``Any`` to remain flexible during import
          without a hard torch dependency.
        """

        # Import torch lazily to avoid hard dependency at import time.
        try:
            import torch
        except Exception:  # pragma: no cover - environments without torch
            return None

        # Prefer a combined surprise when present; otherwise use best_length.
        cand = self.soft_code.combined_surprise
        if cand is None:
            cand = self.soft_code.best_length
        if not isinstance(cand, torch.Tensor):
            return None

        # Scalar-like can't be mapped to samples reliably; skip.
        if cand.ndim == 0:  # pragma: no cover - uncommon layout
            return None

        # Already per-sample: nothing to reduce.
        if cand.ndim == 1:
            return cand

        # Reduce all leading dims after the first to obtain a per-sample vector.
        if reduction == "mean":
            dims = tuple(range(1, int(cand.ndim)))
            return cand.mean(dim=dims)
        # Fallback: default to mean for unknown identifiers to remain fail-open.
        dims = tuple(range(1, int(cand.ndim)))
        return cand.mean(dim=dims)