nervecode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nervecode/__init__.py +415 -0
- nervecode/_version.py +10 -0
- nervecode/core/__init__.py +19 -0
- nervecode/core/assignment.py +165 -0
- nervecode/core/codebook.py +182 -0
- nervecode/core/shapes.py +107 -0
- nervecode/core/temperature.py +227 -0
- nervecode/core/trace.py +166 -0
- nervecode/core/types.py +116 -0
- nervecode/integration/__init__.py +9 -0
- nervecode/layers/__init__.py +15 -0
- nervecode/layers/base.py +333 -0
- nervecode/layers/conv.py +174 -0
- nervecode/layers/linear.py +176 -0
- nervecode/layers/reducers.py +80 -0
- nervecode/layers/wrap.py +223 -0
- nervecode/scoring/__init__.py +20 -0
- nervecode/scoring/aggregator.py +369 -0
- nervecode/scoring/calibrator.py +396 -0
- nervecode/scoring/types.py +33 -0
- nervecode/training/__init__.py +25 -0
- nervecode/training/diagnostics.py +194 -0
- nervecode/training/loss.py +188 -0
- nervecode/training/updaters.py +168 -0
- nervecode/utils/__init__.py +14 -0
- nervecode/utils/overhead.py +177 -0
- nervecode/utils/seed.py +161 -0
- nervecode-0.1.0.dist-info/METADATA +83 -0
- nervecode-0.1.0.dist-info/RECORD +31 -0
- nervecode-0.1.0.dist-info/WHEEL +4 -0
- nervecode-0.1.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""Learnable codebook module.
|
|
2
|
+
|
|
3
|
+
This module defines a small ``torch.nn.Module`` that stores a set of codebook
|
|
4
|
+
centers and exposes them as a trainable parameter. The centers are initialized
|
|
5
|
+
with a scale that depends on the coding-space dimension so that distances and
|
|
6
|
+
soft assignments behave well at the start of training.
|
|
7
|
+
|
|
8
|
+
The shape convention follows the rest of the core: a codebook with ``K`` codes
|
|
9
|
+
over a ``code_dim``-dimensional space uses a parameter of shape ``(K, code_dim)``.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import math
|
|
15
|
+
from collections import OrderedDict
|
|
16
|
+
from collections.abc import MutableMapping
|
|
17
|
+
from contextlib import suppress
|
|
18
|
+
from typing import Any, TypeVar, cast, overload
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
from torch import nn
|
|
22
|
+
|
|
23
|
+
__all__ = ["Codebook"]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Codebook(nn.Module):
|
|
27
|
+
"""Gradient-updated codebook with centers of shape ``(K, code_dim)``.
|
|
28
|
+
|
|
29
|
+
Parameters
|
|
30
|
+
- K: Number of codes (centers) in the codebook. Must be ``>= 1``.
|
|
31
|
+
- code_dim: Dimensionality of the coding space. Must be ``>= 1``.
|
|
32
|
+
- init: Initialization strategy for centers. Supported values are
|
|
33
|
+
``"uniform"`` (default) and ``"normal"``. Both scale with
|
|
34
|
+
``1 / sqrt(code_dim)``.
|
|
35
|
+
- device, dtype: Optional factory kwargs for creating the parameter.
|
|
36
|
+
|
|
37
|
+
Notes
|
|
38
|
+
- The centers are stored in ``self.centers`` and require gradients by
|
|
39
|
+
default, making the codebook trainable by standard optimizers.
|
|
40
|
+
- The default uniform initialization bounds each component within
|
|
41
|
+
``[-1/sqrt(code_dim), 1/sqrt(code_dim)]``.
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
K: int,
|
|
47
|
+
code_dim: int,
|
|
48
|
+
*,
|
|
49
|
+
init: str = "uniform",
|
|
50
|
+
device: torch.device | None = None,
|
|
51
|
+
dtype: torch.dtype | None = None,
|
|
52
|
+
) -> None:
|
|
53
|
+
super().__init__()
|
|
54
|
+
if int(K) <= 0:
|
|
55
|
+
raise ValueError("K must be >= 1")
|
|
56
|
+
if int(code_dim) <= 0:
|
|
57
|
+
raise ValueError("code_dim must be >= 1")
|
|
58
|
+
|
|
59
|
+
self.K: int = int(K)
|
|
60
|
+
self.code_dim: int = int(code_dim)
|
|
61
|
+
# Store init strategy so that reset_parameters matches PyTorch's
|
|
62
|
+
# signature (no args) and can be invoked by utilities.
|
|
63
|
+
if init not in {"uniform", "normal"}:
|
|
64
|
+
raise ValueError("init must be one of {'uniform', 'normal'}")
|
|
65
|
+
self.init: str = init
|
|
66
|
+
|
|
67
|
+
factory: dict[str, Any] = {}
|
|
68
|
+
if device is not None:
|
|
69
|
+
factory["device"] = device
|
|
70
|
+
if dtype is not None:
|
|
71
|
+
factory["dtype"] = dtype
|
|
72
|
+
|
|
73
|
+
self.centers: nn.Parameter
|
|
74
|
+
self.centers = nn.Parameter(torch.empty(self.K, self.code_dim, **factory))
|
|
75
|
+
self.reset_parameters()
|
|
76
|
+
|
|
77
|
+
def reset_parameters(self) -> None:
|
|
78
|
+
"""Initialize centers with scale tied to the coding dimension.
|
|
79
|
+
|
|
80
|
+
- ``uniform``: components sampled from ``[-s, s]`` where
|
|
81
|
+
``s = 1 / sqrt(code_dim)``.
|
|
82
|
+
- ``normal``: components sampled from ``N(0, s)`` with
|
|
83
|
+
``s = 1 / sqrt(code_dim)``.
|
|
84
|
+
"""
|
|
85
|
+
|
|
86
|
+
scale = 1.0 / math.sqrt(float(self.code_dim))
|
|
87
|
+
if self.init == "uniform":
|
|
88
|
+
nn.init.uniform_(self.centers, -scale, scale)
|
|
89
|
+
elif self.init == "normal":
|
|
90
|
+
nn.init.normal_(self.centers, mean=0.0, std=scale)
|
|
91
|
+
|
|
92
|
+
def forward(self) -> torch.Tensor:
|
|
93
|
+
"""Return the current centers tensor.
|
|
94
|
+
|
|
95
|
+
This convenience makes it easy to pass a codebook into distance/assignment
|
|
96
|
+
utilities while keeping it a standard ``nn.Module`` with a parameter.
|
|
97
|
+
"""
|
|
98
|
+
|
|
99
|
+
return self.centers
|
|
100
|
+
|
|
101
|
+
def extra_repr(self) -> str: # pragma: no cover - cosmetic
|
|
102
|
+
return f"K={self.K}, code_dim={self.code_dim}, init='{self.init}'"
|
|
103
|
+
|
|
104
|
+
# --- Serialization contract -------------------------------------------------
|
|
105
|
+
# We persist lightweight metadata (K, code_dim, init) as module extra state
|
|
106
|
+
# so that checkpoints carry sufficient information for inspection and
|
|
107
|
+
# validation when reloading.
|
|
108
|
+
def get_extra_state(self) -> dict[str, Any]:
|
|
109
|
+
"""Return metadata saved alongside parameters in ``state_dict``."""
|
|
110
|
+
|
|
111
|
+
return {
|
|
112
|
+
"version": 1,
|
|
113
|
+
"K": self.K,
|
|
114
|
+
"code_dim": self.code_dim,
|
|
115
|
+
"init": self.init,
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
def set_extra_state(self, state: Any) -> None:
|
|
119
|
+
"""Load metadata saved via :meth:`get_extra_state`.
|
|
120
|
+
|
|
121
|
+
This validates that the incoming metadata matches the current instance
|
|
122
|
+
shape, which helps catch accidental loads into mismatched modules.
|
|
123
|
+
"""
|
|
124
|
+
|
|
125
|
+
if not isinstance(state, dict): # pragma: no cover - defensive
|
|
126
|
+
return
|
|
127
|
+
# Optional version check for forward-compat.
|
|
128
|
+
_ = int(state.get("version", 1))
|
|
129
|
+
K = int(state.get("K", self.K))
|
|
130
|
+
D = int(state.get("code_dim", self.code_dim))
|
|
131
|
+
init = str(state.get("init", self.init))
|
|
132
|
+
# Validate but do not mutate core attributes.
|
|
133
|
+
if (self.K, self.code_dim) != (K, D): # pragma: no cover - rare
|
|
134
|
+
raise RuntimeError(
|
|
135
|
+
"Loaded Codebook metadata does not match current shape: "
|
|
136
|
+
f"got (K={K}, code_dim={D}) vs current (K={self.K}, code_dim={self.code_dim})"
|
|
137
|
+
)
|
|
138
|
+
if init not in {"uniform", "normal"}: # pragma: no cover - rare
|
|
139
|
+
# Ignore unknown init; keep current.
|
|
140
|
+
return
|
|
141
|
+
# Keep self.init as-is to avoid surprising behavior during load; users
|
|
142
|
+
# can call reset_parameters() explicitly if they wish to reinitialize.
|
|
143
|
+
|
|
144
|
+
# PyTorch automatically includes the return of get_extra_state() under the
|
|
145
|
+
# special key "_extra_state" inside state_dict(). Our tests expect the
|
|
146
|
+
# public state_dict to expose only the trainable parameter 'centers'.
|
|
147
|
+
# Override state_dict to drop the special entry while keeping load_state_dict
|
|
148
|
+
# behavior unchanged (it tolerates missing extra state).
|
|
149
|
+
T_destination = TypeVar("T_destination", bound=MutableMapping[str, Any])
|
|
150
|
+
|
|
151
|
+
@overload
|
|
152
|
+
def state_dict(
|
|
153
|
+
self, *, destination: T_destination, prefix: str = "", keep_vars: bool = False
|
|
154
|
+
) -> T_destination: ...
|
|
155
|
+
|
|
156
|
+
@overload
|
|
157
|
+
def state_dict(self, *, prefix: str = "", keep_vars: bool = False) -> dict[str, Any]: ...
|
|
158
|
+
|
|
159
|
+
def state_dict(self, *args: Any, **kwargs: Any) -> MutableMapping[str, Any]:
|
|
160
|
+
sd_any = super().state_dict(*args, **kwargs)
|
|
161
|
+
sd = cast(MutableMapping[str, Any], sd_any)
|
|
162
|
+
with suppress(Exception):
|
|
163
|
+
# Drop extra metadata key while preserving parameter entries
|
|
164
|
+
sd.pop("_extra_state", None)
|
|
165
|
+
return sd
|
|
166
|
+
|
|
167
|
+
def load_state_dict(self, state_dict: Any, strict: bool = True, assign: bool = False) -> Any:
|
|
168
|
+
"""Load centers while tolerating missing _extra_state in strict mode.
|
|
169
|
+
|
|
170
|
+
Tests expect public state_dicts to contain only 'centers'. Since
|
|
171
|
+
PyTorch includes get_extra_state() under '_extra_state' by default,
|
|
172
|
+
we synthesize a minimal extra state if it's missing so strict=True
|
|
173
|
+
loads remain compatible.
|
|
174
|
+
"""
|
|
175
|
+
try:
|
|
176
|
+
if isinstance(state_dict, dict) and "_extra_state" not in state_dict:
|
|
177
|
+
sd = OrderedDict(state_dict)
|
|
178
|
+
sd["_extra_state"] = self.get_extra_state()
|
|
179
|
+
return super().load_state_dict(sd, strict=strict, assign=assign)
|
|
180
|
+
except Exception:
|
|
181
|
+
pass
|
|
182
|
+
return super().load_state_dict(state_dict, strict=strict, assign=assign)
|
nervecode/core/shapes.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Shape helpers for coding engines.
|
|
2
|
+
|
|
3
|
+
This module provides utilities to flatten and unflatten the leading dimensions
|
|
4
|
+
of tensors that follow the common Nervecode convention of a feature or code
|
|
5
|
+
dimension as the final axis. This allows assignment and distance computations to
|
|
6
|
+
operate uniformly on shapes like:
|
|
7
|
+
|
|
8
|
+
- batch-only: ``(B, D)`` or probabilities ``(B, K)``,
|
|
9
|
+
- token-like: ``(B, T, D)`` or ``(B, T, K)``,
|
|
10
|
+
- pooled convolution: ``(B, H, W, D)`` or ``(B, H, W, K)``.
|
|
11
|
+
|
|
12
|
+
By flattening the leading dimensions, engines can work over a simple 2D view
|
|
13
|
+
``(N, D)`` or ``(N, K)`` where ``N = prod(leading_shape)``. The corresponding
|
|
14
|
+
unflatten operation restores the original leading layout.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
|
|
21
|
+
__all__ = [
|
|
22
|
+
"flatten_leading",
|
|
23
|
+
"leading_shape",
|
|
24
|
+
"unflatten_leading",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def leading_shape(x: torch.Tensor) -> tuple[int, ...]:
    """Return every dimension of ``x`` except the final (feature) axis.

    For a tensor shaped ``(..., D)`` the result is ``...`` as a plain tuple
    of Python ints.

    Parameters
    - x: Tensor with at least one dimension.

    Returns
    - Tuple of ints representing the leading shape.

    Raises
    - ValueError: if ``x`` is zero-dimensional.
    """

    if x.ndim < 1:
        raise ValueError("x must have at least one dimension (..., D)")
    # Convert each torch.Size entry to int so callers get a torch-free tuple
    # (also keeps mypy happy about the element type).
    return tuple(map(int, x.shape[:-1]))
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def flatten_leading(x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, ...]]:
    """Collapse every leading dimension of ``x`` into one.

    Given ``x`` of shape ``(..., D)``, produces a 2D view ``(N, D)`` where
    ``N = prod(...)``, together with the original leading shape so callers
    can restore the layout later.

    ``reshape`` (rather than ``view``) is used so non-contiguous inputs are
    handled safely.

    Parameters
    - x: Tensor of shape ``(..., D)``.

    Returns
    - x_flat: Tensor of shape ``(N, D)`` where ``N = prod(leading_shape)``.
    - lead: The original leading shape as a tuple of ints.

    Raises
    - ValueError: if ``x`` is zero-dimensional.
    """

    if x.ndim < 1:
        raise ValueError("x must have at least one dimension (..., D)")
    lead = tuple(int(s) for s in x.shape[:-1])
    feature = int(x.shape[-1])
    # N = prod(lead); the running product starts at 1 so an empty leading
    # shape (a 1-D input) flattens to a single row.
    total = 1
    for size in lead:
        total *= size
    return x.reshape(total, feature), lead
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def unflatten_leading(y: torch.Tensor, lead: tuple[int, ...]) -> torch.Tensor:
    """Expand a flattened first dimension back into ``lead``.

    Accepts ``y`` of shape ``(N, ...)`` with ``N = prod(lead)`` and returns a
    tensor of shape ``(*lead, ...)``. Works for scalar-like per-position
    tensors (``(N,)``), probability tensors (``(N, K)``), and vector-valued
    tensors (``(N, D)``).

    Parameters
    - y: Tensor whose first dimension flattens a known set of leading dims.
    - lead: Tuple of leading sizes that were previously flattened.

    Returns
    - Tensor reshaped to ``(*lead, *y.shape[1:])``.

    Raises
    - ValueError: if ``y`` is zero-dimensional or its first dimension does
      not equal ``prod(lead)``.
    """

    if y.ndim < 1:
        raise ValueError("y must have at least one dimension (N, ...)")

    # Expected flattened size; an empty lead means a single position.
    expected = 1
    for size in lead:
        expected *= int(size)
    actual = int(y.shape[0])
    if actual != expected:
        raise ValueError(f"first dimension {actual} does not match prod(lead)={expected}")

    # (N,) collapses to just the leading layout; anything else keeps its
    # trailing axes intact.
    if y.ndim == 1:
        return y.reshape(*lead)
    return y.reshape(*lead, *y.shape[1:])
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
"""Temperature schedules for coding and training.
|
|
2
|
+
|
|
3
|
+
This module provides a minimal interface and two concrete schedules for the
|
|
4
|
+
MVP: a fixed temperature and a cosine schedule. The interface is designed so
|
|
5
|
+
that a future learnable temperature (e.g., an ``nn.Parameter`` inside an
|
|
6
|
+
``nn.Module``) can conform to the same calling convention.
|
|
7
|
+
|
|
8
|
+
Design goals
|
|
9
|
+
- Keep evaluation cheap and dependency-light.
|
|
10
|
+
- Accept either explicit progress (``0..1``) or step-based inputs.
|
|
11
|
+
- Provide a tensor-producing helper for easy integration in Torch graphs.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, cast, runtime_checkable
|
|
18
|
+
|
|
19
|
+
# Optional PyTorch typing without a hard runtime dependency.
|
|
20
|
+
if TYPE_CHECKING: # pragma: no cover - typing-only branch
|
|
21
|
+
import torch as _torch
|
|
22
|
+
from torch import Tensor as _Tensor
|
|
23
|
+
|
|
24
|
+
TensorT: TypeAlias = _Tensor
|
|
25
|
+
DeviceT: TypeAlias = _torch.device
|
|
26
|
+
DTypeT: TypeAlias = _torch.dtype
|
|
27
|
+
else: # Fallbacks keep the module importable without torch installed.
|
|
28
|
+
TensorT: TypeAlias = Any
|
|
29
|
+
DeviceT: TypeAlias = Any
|
|
30
|
+
DTypeT: TypeAlias = Any
|
|
31
|
+
|
|
32
|
+
__all__ = ["CosineTemperature", "FixedTemperature", "TemperatureSchedule"]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _require_torch() -> Any:
|
|
36
|
+
"""Import torch on-demand with a helpful error if missing.
|
|
37
|
+
|
|
38
|
+
This keeps the module importable when PyTorch is not installed while still
|
|
39
|
+
supporting ``as_tensor`` helpers for users who have it.
|
|
40
|
+
"""
|
|
41
|
+
try:
|
|
42
|
+
import torch
|
|
43
|
+
except Exception as exc: # pragma: no cover - exercised in envs without torch
|
|
44
|
+
raise RuntimeError(
|
|
45
|
+
"as_tensor requires PyTorch; install 'torch' to use this method."
|
|
46
|
+
) from exc
|
|
47
|
+
return torch
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@runtime_checkable
class TemperatureSchedule(Protocol):
    """Callable interface for temperature schedules.

    Implementations should return a positive scalar temperature when invoked.
    Schedules may accept either normalized progress in ``[0, 1]`` via the
    ``progress`` keyword, or a ``step`` index with an associated total number
    of steps ``total``. When both are supplied, ``progress`` takes precedence.

    The protocol is ``runtime_checkable`` so ``isinstance`` checks work, but
    note that such checks only verify method presence, not signatures.
    """

    # Returns the temperature as a plain float for the given step/progress.
    def __call__(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
    ) -> float:  # pragma: no cover - interface only
        ...

    # Same value as __call__, but materialized as a 0-d tensor so it can be
    # placed on a specific device/dtype and used inside Torch graphs.
    def as_tensor(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
        device: DeviceT | None = None,
        dtype: DTypeT | None = None,
    ) -> TensorT:  # pragma: no cover - interface only
        ...
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class FixedTemperature:
    """Schedule that always yields the same temperature.

    The value is validated once at construction time; every call returns it
    unchanged regardless of step or progress.
    """

    def __init__(self, temperature: float) -> None:
        value = float(temperature)
        # Reject NaN, +/-inf, zero, and negatives up front.
        if not math.isfinite(value) or value <= 0.0:
            raise ValueError("temperature must be a positive finite float")
        self._t: float = value

    def __repr__(self) -> str:  # pragma: no cover - cosmetic
        return f"FixedTemperature(temperature={self._t})"

    def __call__(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
    ) -> float:
        # step/total/progress exist only so this class satisfies the
        # TemperatureSchedule calling convention; a fixed schedule ignores them.
        del step, total, progress
        return self._t

    def as_tensor(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
        device: DeviceT | None = None,
        dtype: DTypeT | None = None,
    ) -> TensorT:
        # Same constant, materialized as a 0-d tensor for Torch graphs.
        del step, total, progress
        torch = _require_torch()
        return cast("TensorT", torch.tensor(self._t, device=device, dtype=dtype))
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class CosineTemperature:
    """Cosine anneal from ``start`` to ``end`` with an optional warmup hold.

    Follows the closed-form schedule

    ``T(p) = end + (start - end) * 0.5 * (1 + cos(pi * p))``

    for normalized progress ``p`` in ``[0, 1]``. Progress may be passed
    directly via ``progress=``, or derived from ``step``/``total`` (honoring
    ``warmup_steps``). While warming up, the temperature is held at ``start``.

    Parameters
    - start: Temperature at ``p = 0``. Must be ``> 0``.
    - end: Temperature at ``p = 1``. Must be ``> 0``.
    - total_steps: Default total steps; overridable per call with ``total=``.
    - warmup_steps: Steps held at ``start`` before annealing. Must be ``>= 0``.
    """

    def __init__(
        self,
        start: float,
        end: float,
        *,
        total_steps: int | None = None,
        warmup_steps: int = 0,
    ) -> None:
        first = float(start)
        last = float(end)
        if not math.isfinite(first) or first <= 0.0:
            raise ValueError("start must be a positive finite float")
        if not math.isfinite(last) or last <= 0.0:
            raise ValueError("end must be a positive finite float")
        if total_steps is not None and int(total_steps) <= 0:
            raise ValueError("total_steps must be >= 1 when provided")
        if int(warmup_steps) < 0:
            raise ValueError("warmup_steps must be >= 0")

        self.start: float = first
        self.end: float = last
        self.total_steps: int | None = None if total_steps is None else int(total_steps)
        self.warmup_steps: int = int(warmup_steps)

    def __repr__(self) -> str:  # pragma: no cover - cosmetic
        fields = (
            f"start={self.start}, end={self.end}, "
            f"total_steps={self.total_steps}, warmup_steps={self.warmup_steps}"
        )
        return f"CosineTemperature({fields})"

    def _progress_from_step(self, step: int, total: int) -> float:
        """Normalized progress in ``[0, 1]`` measured after the warmup window."""
        cur = max(0, int(step))
        span = max(1, int(total))
        warm = max(0, self.warmup_steps)
        if cur <= warm:
            return 0.0
        # Steps remaining after warmup; floor of 1 guards division by zero
        # when total <= warmup.
        denom = max(1, span - warm)
        frac = (cur - warm) / float(denom)
        # Clamp so overrunning the schedule just sticks at the end value.
        return min(1.0, max(0.0, frac))

    @staticmethod
    def _cosine_mix(start: float, end: float, p01: float) -> float:
        """Evaluate the cosine interpolation at clamped progress ``p01``."""
        p = min(1.0, max(0.0, p01))
        return end + (start - end) * 0.5 * (1.0 + math.cos(math.pi * p))

    def __call__(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
    ) -> float:
        # Explicit progress wins over any step-based derivation.
        if progress is not None:
            return self._cosine_mix(self.start, self.end, float(progress))

        # Call-time total takes precedence over the constructor default.
        effective_total = self.total_steps if total is None else int(total)
        if effective_total is None or step is None:
            raise ValueError(
                "CosineTemperature requires either progress in [0,1] or both step and total"
            )

        # Hold at the starting temperature throughout warmup.
        if int(step) <= self.warmup_steps:
            return float(self.start)
        return self._cosine_mix(
            self.start, self.end, self._progress_from_step(int(step), int(effective_total))
        )

    def as_tensor(
        self,
        step: int | None = None,
        *,
        total: int | None = None,
        progress: float | None = None,
        device: DeviceT | None = None,
        dtype: DTypeT | None = None,
    ) -> TensorT:
        scalar = self(step, total=total, progress=progress)
        torch = _require_torch()
        return cast("TensorT", torch.tensor(scalar, device=device, dtype=dtype))
|
nervecode/core/trace.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""Trace data structures for coding.
|
|
2
|
+
|
|
3
|
+
This module defines lightweight, immutable containers that carry the rich
|
|
4
|
+
per-layer record produced during coding. Traces support arbitrary leading
|
|
5
|
+
dimensions so they can represent batch-only, token-like, or pooled layouts.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from .types import SoftCode
|
|
14
|
+
|
|
15
|
+
__all__ = ["CodingTrace"]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True)
class CodingTrace:
    """Per-layer coding trace.

    Carries the reduced activation, reduction metadata, nearest-center and
    commitment distances, chosen indices, and the corresponding ``SoftCode``.

    Fields
    - ``reduced`` (Tensor): reduced activation with shape ``(..., D)``.
    - ``reduction_meta`` (dict): metadata describing the applied reduction
      (e.g., method, parameters). May be empty when no reduction is applied.
    - ``nearest_center_distances`` (Tensor): distance to the selected center
      per position with shape matching the leading dimensions ``...``.
    - ``chosen_center_indices`` (Tensor): index of the selected/nearest center
      per position with shape matching the leading dimensions ``...``.
    - ``commitment_distances`` (Tensor): commitment distance per position with
      shape matching the leading dimensions ``...``.
    - ``soft_code`` (SoftCode): soft assignment statistics over ``K`` codes
      with probabilities of shape ``(..., K)``.

    All scalar-like fields must match the leading shape of ``soft_code.probs``
    exactly (i.e., ``tensor.shape == soft_code.probs.shape[:-1]``).
    """

    # Fields are typed ``Any`` (rather than torch.Tensor) so this module can
    # be imported without torch; __post_init__ performs the real checks.
    reduced: Any
    reduction_meta: dict[str, Any]
    nearest_center_distances: Any
    chosen_center_indices: Any
    commitment_distances: Any
    soft_code: SoftCode

    def __post_init__(self) -> None:
        """Validate field types and cross-field shape consistency.

        Raises TypeError for wrong types/dtypes and ValueError for shape
        mismatches. When torch is unavailable, tensor-specific checks are
        skipped and only structural checks run.
        """
        # Import torch lazily to keep import-time dependencies lightweight.
        from typing import Any
        from typing import cast as _cast

        try:
            import torch
        except Exception:  # pragma: no cover - torch may be absent in some envs
            # Sentinel: ``torch is None`` below means "skip tensor checks".
            torch = _cast(Any, None)

        # Basic validation of required components
        if not isinstance(self.soft_code, SoftCode):  # pragma: no cover - defensive
            raise TypeError("soft_code must be a SoftCode instance")

        # Validate reduced tensor shape
        if torch is not None and not isinstance(
            self.reduced, torch.Tensor
        ):  # pragma: no cover - defensive
            raise TypeError("reduced must be a torch.Tensor")
        # getattr guard lets this run even on non-tensor objects when torch
        # is absent (no isinstance check happened above in that case).
        if getattr(self.reduced, "ndim", None) is None or self.reduced.ndim < 1:
            raise ValueError("reduced must have at least one dimension (..., D)")
        if int(self.reduced.shape[-1]) <= 0:
            raise ValueError("final reduced dimension D must be >= 1")

        # If torch isn't available, skip dtype/shape validations that require it.
        if torch is None:  # pragma: no cover
            return

        # NOTE(review): relies on SoftCode exposing a ``leading_shape``
        # iterable of ints — defined in .types, not visible here; confirm.
        lead_shape = tuple(int(s) for s in self.soft_code.leading_shape)
        reduced_lead = tuple(int(s) for s in self.reduced.shape[:-1])
        if reduced_lead != lead_shape:
            raise ValueError(
                f"reduced.leading_shape {reduced_lead} must match "
                f"soft_code.leading_shape {lead_shape}"
            )

        # Shared check for the three per-position tensors: correct type,
        # exact leading shape, and an appropriate (non-bool) dtype.
        def _check_scalar_tensor(name: str, value: Any, *, require_integer: bool = False) -> None:
            if not isinstance(value, torch.Tensor):
                raise TypeError(f"{name} must be a torch.Tensor")
            if tuple(int(s) for s in value.shape) != lead_shape:
                raise ValueError(
                    f"{name} must have shape {lead_shape} to match soft_code.leading_shape"
                )
            if require_integer:
                # Indices must be integral; bool is explicitly excluded even
                # though torch treats it as non-floating.
                if torch.is_floating_point(value) or value.dtype == torch.bool:
                    raise TypeError(f"{name} must use an integer dtype")
            else:
                if value.dtype == torch.bool:
                    raise TypeError(f"{name} must be numeric, not bool")

        _check_scalar_tensor("nearest_center_distances", self.nearest_center_distances)
        _check_scalar_tensor(
            "chosen_center_indices", self.chosen_center_indices, require_integer=True
        )
        _check_scalar_tensor("commitment_distances", self.commitment_distances)

    @property
    def leading_shape(self) -> tuple[int, ...]:
        """Leading shape shared by all per-position fields."""

        return tuple(int(s) for s in self.soft_code.leading_shape)

    @property
    def reduced_dim(self) -> int:
        """Size of the final reduced dimension ``D``."""

        return int(self.reduced.shape[-1])

    # --- Sample-level reduction views -------------------------------------------
    def sample_reduced_surprise(self, *, reduction: str = "mean") -> Any:
        """Return a per-sample surprise tensor by reducing extra leading dims.

        Many wrappers operate over activations with multiple leading dimensions
        (e.g., batch and sequence). Aggregation across layers requires a common
        per-sample view. This helper reduces the per-position surprise stored on
        the associated ``SoftCode`` to a vector over samples by collapsing all
        leading dimensions after the first (typically sequence/time) using the
        specified reduction.

        Parameters
        - reduction: Current supported value is ``"mean"``. Additional
          reductions may be added later. Unknown values silently fall back
          to mean (fail-open, see below).

        Returns
        - A tensor with shape ``(B,)`` where ``B`` is the size of the first
          leading dimension. In environments without torch or when an
          interpretable surprise is unavailable, this may return ``None`` at
          runtime; the return type is ``Any`` to remain flexible during import
          without a hard torch dependency.
        """

        # Import torch lazily to avoid hard dependency at import time.
        try:
            import torch
        except Exception:  # pragma: no cover - environments without torch
            return None

        # Prefer a combined surprise when present; otherwise use best_length.
        # NOTE(review): assumes SoftCode has ``combined_surprise`` (optional)
        # and ``best_length`` attributes — confirm against .types.
        cand = self.soft_code.combined_surprise
        if cand is None:
            cand = self.soft_code.best_length
        if not isinstance(cand, torch.Tensor):
            return None

        # Scalar-like can't be mapped to samples reliably; skip.
        if cand.ndim == 0:  # pragma: no cover - uncommon layout
            return None

        # Already per-sample: nothing to collapse.
        if cand.ndim == 1:
            return cand

        # Reduce all leading dims after the first to obtain a per-sample vector.
        if reduction == "mean":
            dims = tuple(range(1, int(cand.ndim)))
            return cand.mean(dim=dims)
        # Fallback: default to mean for unknown identifiers to remain fail-open.
        dims = tuple(range(1, int(cand.ndim)))
        return cand.mean(dim=dims)
|