nervecode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nervecode/__init__.py +415 -0
- nervecode/_version.py +10 -0
- nervecode/core/__init__.py +19 -0
- nervecode/core/assignment.py +165 -0
- nervecode/core/codebook.py +182 -0
- nervecode/core/shapes.py +107 -0
- nervecode/core/temperature.py +227 -0
- nervecode/core/trace.py +166 -0
- nervecode/core/types.py +116 -0
- nervecode/integration/__init__.py +9 -0
- nervecode/layers/__init__.py +15 -0
- nervecode/layers/base.py +333 -0
- nervecode/layers/conv.py +174 -0
- nervecode/layers/linear.py +176 -0
- nervecode/layers/reducers.py +80 -0
- nervecode/layers/wrap.py +223 -0
- nervecode/scoring/__init__.py +20 -0
- nervecode/scoring/aggregator.py +369 -0
- nervecode/scoring/calibrator.py +396 -0
- nervecode/scoring/types.py +33 -0
- nervecode/training/__init__.py +25 -0
- nervecode/training/diagnostics.py +194 -0
- nervecode/training/loss.py +188 -0
- nervecode/training/updaters.py +168 -0
- nervecode/utils/__init__.py +14 -0
- nervecode/utils/overhead.py +177 -0
- nervecode/utils/seed.py +161 -0
- nervecode-0.1.0.dist-info/METADATA +83 -0
- nervecode-0.1.0.dist-info/RECORD +31 -0
- nervecode-0.1.0.dist-info/WHEEL +4 -0
- nervecode-0.1.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
"""Empirical percentile calibrator and state.
|
|
2
|
+
|
|
3
|
+
This module provides a minimal, torch-agnostic-at-import calibrator that
|
|
4
|
+
collects an empirical surprise score distribution on held-out
|
|
5
|
+
in-distribution data. The calibrator stores the sorted distribution, the
|
|
6
|
+
chosen threshold quantiles, and lightweight metadata sufficient to reproduce
|
|
7
|
+
calibration later (e.g., index rule, sample count, dtype).
|
|
8
|
+
|
|
9
|
+
Threshold comparison, percentile lookup, and persistence are implemented in
|
|
10
|
+
subsequent tasks; the goal here is to make calibration runs produce a durable
|
|
11
|
+
state snapshot that downstream code can consume.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from typing import Any, cast
|
|
19
|
+
|
|
20
|
+
try: # Keep import-time behavior tolerant in environments without torch
|
|
21
|
+
import torch
|
|
22
|
+
except Exception: # pragma: no cover - torch is a project dependency in tests
|
|
23
|
+
torch = cast(Any, None)
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"CalibrationState",
|
|
27
|
+
"EmpiricalPercentileCalibrator",
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True)
class CalibrationState:
    """Stored result of a calibration run.

    Fields
    - ``sorted_surprise``: 1D tensor of per-sample surprise scores sorted in
      ascending order. Typically a ``torch.Tensor`` on CPU.
    - ``threshold_quantiles``: tuple of quantiles in ``[0, 1]`` that were
      used to choose thresholds.
    - ``threshold_values``: 1D tensor of threshold values aligned with
      ``threshold_quantiles`` using a deterministic index rule (see metadata).
    - ``metadata``: free-form details needed to reproduce calibration later
      (e.g., index rule, number of samples, dtype, aggregation hint).
    """

    sorted_surprise: Any
    threshold_quantiles: tuple[float, ...]
    threshold_values: Any
    metadata: dict[str, Any]


# Attempt to allowlist CalibrationState for PyTorch 2.6+ safe loading
# (torch.load defaults to weights_only=True there and rejects unknown classes).
try:  # pragma: no cover - environment-dependent
    if torch is not None:
        from torch import serialization as _serialization

        _add = getattr(_serialization, "add_safe_globals", None)
        if callable(_add):
            _add([CalibrationState])
except Exception:
    # Older torch versions or environments without the API: skip registration.
    pass


class EmpiricalPercentileCalibrator:
    """Collect an empirical surprise distribution and store threshold info.

    The calibrator accepts a per-sample surprise vector, sorts it, and computes
    threshold values at configured quantiles using a deterministic index rule
    based on ``floor(q * (n - 1))``. The full sorted distribution, the chosen
    quantiles, and metadata sufficient for reproduction are stored in a
    ``CalibrationState``.
    """

    def __init__(self, *, threshold_quantiles: Sequence[float] = (0.95,)) -> None:
        self._threshold_quantiles: tuple[float, ...] = tuple(float(q) for q in threshold_quantiles)
        self._state: CalibrationState | None = None

    @property
    def state(self) -> CalibrationState | None:
        """Return the last computed calibration state, if available."""

        return self._state

    @property
    def threshold_quantiles(self) -> tuple[float, ...]:
        """Quantiles configured at construction time (possibly empty)."""

        return self._threshold_quantiles

    def fit(self, scores: Any, *, aggregation: str | None = None) -> CalibrationState:
        """Calibrate from a per-sample surprise vector and store state.

        Parameters
        - scores: Per-sample surprise signal. Accepts a ``torch.Tensor`` of any
          shape (flattened to 1D), an iterable of numbers, or an object with a
          ``surprise`` tensor attribute.
        - aggregation: Optional hint describing how the surprise vector was
          produced (e.g., ``"mean"``). Stored in metadata for later reference.

        Returns
        - ``CalibrationState`` with sorted distribution, chosen quantiles,
          threshold values, and metadata.

        Raises
        - RuntimeError: when torch is unavailable.
        - ValueError: on an empty score vector or a quantile outside [0, 1].
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Calibration requires PyTorch to be installed")

        vec = _as_1d_tensor(scores)
        if vec.numel() == 0:
            raise ValueError("Calibration requires at least one surprise score")

        # Defensive copy to CPU float for stable sorting and storage.
        vec = vec.detach().to(device="cpu", dtype=torch.float32).contiguous().view(-1)
        sorted_vec, _ = torch.sort(vec, stable=True)

        # Validate and normalize quantiles; fall back to 0.95 when none set.
        qs = list(self._threshold_quantiles)
        if not qs:
            qs = [0.95]
        for q in qs:
            if not (0.0 <= q <= 1.0):  # strict validation to avoid silent misuse
                raise ValueError(f"Quantile {q} is outside [0, 1]")

        # Deterministic index rule: floor(q * (n-1)).
        n = int(sorted_vec.numel())
        idxs = [min(n - 1, max(0, int((n - 1) * q))) for q in qs]
        thr = sorted_vec[torch.tensor(idxs, dtype=torch.long)]

        meta: dict[str, Any] = {
            "method": "empirical_percentile",
            "index_rule": "floor_q_times_n_minus_1",
            "num_scores": n,
            "aggregation": aggregation or "unknown",
            "dtype": str(sorted_vec.dtype).replace("torch.", ""),
            "device": str(sorted_vec.device),
        }

        state = CalibrationState(
            sorted_surprise=sorted_vec,
            threshold_quantiles=tuple(qs),
            threshold_values=thr,
            metadata=meta,
        )
        self._state = state
        return state

    # --- Inference-time utilities -------------------------------------------------
    def threshold_for(self, quantile: float | None = None) -> Any:
        """Return the calibrated threshold value for a given quantile.

        If ``quantile`` is omitted and exactly one quantile was configured at
        construction time, that single threshold is returned. When a quantile is
        provided that was not part of the original configuration, the value is
        computed on the fly using the stored sorted distribution and the same
        deterministic index rule used during fitting (``floor(q * (n-1))``).
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Calibration requires PyTorch to be installed")
        if self._state is None:
            raise RuntimeError("Calibrator has no state; call fit(...) first")

        state = self._state
        if quantile is None:
            if len(self._threshold_quantiles) == 1:
                return state.threshold_values[0]
            raise ValueError(
                "Multiple quantiles configured; specify 'quantile' explicitly",
            )

        q = float(quantile)
        if not (0.0 <= q <= 1.0):
            raise ValueError(f"Quantile {q} is outside [0, 1]")

        # If the quantile matches one from fit(), return the stored value.
        try:
            idx = state.threshold_quantiles.index(q)
            return state.threshold_values[idx]
        except ValueError:
            pass  # compute ad-hoc below

        # Compute ad-hoc using the stored sorted distribution and index rule.
        sorted_vec = cast("torch.Tensor", state.sorted_surprise)
        n = int(sorted_vec.numel())
        idx = min(n - 1, max(0, int((n - 1) * q)))
        return sorted_vec[idx]

    def percentiles(self, scores: Any) -> Any:
        """Map surprise values to empirical percentiles in [0, 1].

        Computes an empirical CDF against the stored ID distribution and
        returns a percentile per input. To be robust across scoring variants
        where either tail may correspond to increased OOD-ness, this method
        automatically selects the upper or lower tail based on the median of
        the provided ``scores`` relative to the ID median:
        - If ``median(scores) >= median(ID)``: use upper-tail CDF
          ``p(x) = # {v <= x} / n`` (higher scores → higher percentiles).
        - Else: use lower-tail complement ``p(x) = 1 - # {v <= x} / n``
          (smaller scores → higher percentiles).

        Supports scalar and batched inputs; tensor inputs preserve their shape
        in the result.
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Calibration requires PyTorch to be installed")
        if self._state is None:
            raise RuntimeError("Calibrator has no state; call fit(...) first")

        x = _as_tensor(scores)
        sorted_vec = cast("torch.Tensor", self._state.sorted_surprise)

        # Bug fix: torch.searchsorted requires the query tensor to match the
        # sorted sequence's dtype/device, but fit() always stores a CPU
        # float32 distribution while callers may pass float64, integer, or
        # CUDA tensors. Cast (detached) before the lookup.
        flat = x.detach().reshape(-1).to(device=sorted_vec.device, dtype=sorted_vec.dtype)
        # Counts of values <= x via searchsorted(..., right=True).
        counts = torch.searchsorted(sorted_vec, flat, right=True)
        n = float(sorted_vec.numel())
        cdf = counts.to(dtype=torch.float32) / n
        # Tail selection: compare input median vs ID median.
        try:
            id_median = sorted_vec[sorted_vec.numel() // 2]
            x_median = torch.median(flat).to(dtype=id_median.dtype, device=id_median.device)
            use_upper = bool(x_median >= id_median)
        except Exception:
            use_upper = True
        p = cdf if use_upper else (1.0 - cdf)
        p = torch.clamp(p, 0.0, 1.0)
        return p.reshape(x.shape)

    def is_ood(self, scores: Any, *, quantile: float | None = None) -> Any:
        """Return a boolean mask where inputs are marked OOD.

        By default, compares the empirical percentile of each input against a
        configured threshold quantile from calibration, marking OOD when
        ``percentile >= quantile``. If ``quantile`` is omitted:
        - with a single configured threshold, that quantile is used;
        - with multiple thresholds, the most conservative choice (max quantile)
          is used to avoid ambiguity and keep a sensible default.
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Calibration requires PyTorch to be installed")
        if self._state is None:
            raise RuntimeError("Calibrator has no state; call fit(...) first")

        q: float
        if quantile is None:
            if len(self._threshold_quantiles) == 0:
                q = 0.95
            elif len(self._threshold_quantiles) == 1:
                q = float(self._threshold_quantiles[0])
            else:
                # Default to the highest configured quantile for a conservative threshold
                q = float(max(self._threshold_quantiles))
        else:
            q = float(quantile)
            if not (0.0 <= q <= 1.0):
                raise ValueError(f"Quantile {q} is outside [0, 1]")

        p = self.percentiles(scores)
        # 'p' is a tensor shaped like the input; compare against q
        return cast("torch.Tensor", p) >= q

    # --- Persistence helpers -----------------------------------------------------
    def get_state(self) -> CalibrationState | None:
        """Return the current CalibrationState (or None if not fitted)."""

        return self._state

    def set_state(self, state: CalibrationState) -> None:
        """Assign an externally provided CalibrationState."""

        if not isinstance(state, CalibrationState):  # pragma: no cover - defensive
            raise TypeError("state must be a CalibrationState")
        self._state = state

    def state_dict(self) -> dict[str, Any]:
        """Return a versioned dictionary suitable for checkpointing.

        Contains the configured threshold quantiles and, when available, the
        fitted CalibrationState with its tensors and metadata.
        """

        out: dict[str, Any] = {
            "__nervecode__": "EmpiricalPercentileCalibrator",
            "version": 1,
            "threshold_quantiles": self._threshold_quantiles,
        }
        if self._state is not None:
            out["state"] = {
                "sorted_surprise": self._state.sorted_surprise,
                "threshold_quantiles": self._state.threshold_quantiles,
                "threshold_values": self._state.threshold_values,
                "metadata": self._state.metadata,
            }
        return out

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """Load from a dictionary produced by state_dict()."""

        if not isinstance(state_dict, Mapping):  # pragma: no cover - defensive
            raise TypeError("state_dict must be a mapping")
        # Threshold quantiles configuration
        tq = state_dict.get("threshold_quantiles")
        if tq is not None:
            self._threshold_quantiles = tuple(float(q) for q in list(tq))
        # Optional embedded calibration state
        st = state_dict.get("state")
        if isinstance(st, Mapping) and "sorted_surprise" in st and "threshold_values" in st:
            self._state = CalibrationState(
                sorted_surprise=st["sorted_surprise"],
                threshold_quantiles=tuple(
                    float(q) for q in list(st.get("threshold_quantiles", self._threshold_quantiles))
                ),
                threshold_values=st["threshold_values"],
                metadata=dict(st.get("metadata", {})),
            )

    def save(self, path: str) -> None:
        """Save state_dict() to a file using torch.save if available."""

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Saving calibrator state requires PyTorch to be installed")
        torch.save(self.state_dict(), path)

    @classmethod
    def load(cls, path: str) -> EmpiricalPercentileCalibrator:
        """Load a calibrator from a file produced by save()."""

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Loading calibrator state requires PyTorch to be installed")
        sd = torch.load(path, map_location="cpu")
        # Attempt to recover configured quantiles; default to embedded state
        tq = sd.get("threshold_quantiles", (0.95,)) if isinstance(sd, Mapping) else (0.95,)
        calib = cls(threshold_quantiles=tq)
        if isinstance(sd, Mapping):
            calib.load_state_dict(sd)
        return calib


def _as_1d_tensor(scores: Any) -> Any:
    """Return a 1D tensor view of per-sample surprise values.

    Supports common inputs without importing project-internal types here.
    - ``torch.Tensor``: flattened to 1D (keeps dtype/device for intermediate ops)
    - object with ``surprise`` attribute: treat it as a tensor and flatten
    - sequence/iterable of numbers: converted to a CPU float tensor
    """

    if torch is None:  # pragma: no cover - defensive
        raise RuntimeError("Calibration requires PyTorch to be installed")

    if isinstance(scores, torch.Tensor):
        return scores.reshape(-1)

    cand = getattr(scores, "surprise", None)
    if isinstance(cand, torch.Tensor):
        return cand.reshape(-1)

    if isinstance(scores, Iterable):
        try:
            return torch.tensor(list(scores), dtype=torch.float32)
        except Exception as exc:  # pragma: no cover - defensive
            raise TypeError("Unsupported iterable for calibration scores") from exc

    raise TypeError("Unsupported input type for calibration scores")


def _as_tensor(scores: Any) -> Any:
    """Return a tensor view of input scores preserving shape when possible.

    Accepted inputs mirror ``_as_1d_tensor`` but keep the original shape for
    tensor-like objects and scalars. Iterables are materialized into a 1D
    tensor. The returned tensor uses ``float32`` when constructed from Python
    values; otherwise the input dtype/device is preserved.
    """

    if torch is None:  # pragma: no cover - defensive
        raise RuntimeError("Calibration requires PyTorch to be installed")

    if isinstance(scores, torch.Tensor):
        return scores

    cand = getattr(scores, "surprise", None)
    if isinstance(cand, torch.Tensor):
        return cand

    # Common scalar path
    if isinstance(scores, (int, float)):
        return torch.tensor(scores, dtype=torch.float32)

    if isinstance(scores, Iterable):
        try:
            return torch.tensor(list(scores), dtype=torch.float32)
        except Exception as exc:  # pragma: no cover - defensive
            raise TypeError("Unsupported iterable for calibration scores") from exc

    raise TypeError("Unsupported input type for calibration scores")
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Scoring/aggregation result types.
|
|
2
|
+
|
|
3
|
+
This module defines lightweight dataclasses used by the scoring/aggregation
|
|
4
|
+
APIs. Keep imports torch-agnostic so the package remains importable in minimal
|
|
5
|
+
environments; fields that are typically ``torch.Tensor`` are typed as ``Any``.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
__all__ = ["AggregatedSurprise"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True)
class AggregatedSurprise:
    """Immutable record describing one aggregation of per-layer surprise.

    Attributes
    - ``surprise``: the aggregated per-sample signal (typically a
      ``torch.Tensor``, e.g. shape ``(B,)`` or ``(...,)``).
    - ``method``: identifier of the aggregation scheme, e.g. ``"mean"``;
      ``"max"`` / ``"weighted"`` may be added later without changing this
      public result type.
    - ``num_layers``: how many per-layer signals contributed to the result.
    - ``details``: optional method-specific extras (e.g. the weights used).
    """

    surprise: Any
    method: str
    num_layers: int
    details: dict[str, Any] | None = None
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Training-time helpers and losses.
|
|
2
|
+
|
|
3
|
+
This subpackage will include auxiliary losses, schedulers, and small training
|
|
4
|
+
helpers used to optimize codebooks and calibrate scores.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from .diagnostics import (
|
|
10
|
+
CsvDiagnosticsLogger,
|
|
11
|
+
DiagnosticsCollector,
|
|
12
|
+
DiagnosticsRow,
|
|
13
|
+
JsonlDiagnosticsLogger,
|
|
14
|
+
)
|
|
15
|
+
from .loss import CodingLoss
|
|
16
|
+
from .updaters import EmaCodebookUpdater
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"CodingLoss",
|
|
20
|
+
"CsvDiagnosticsLogger",
|
|
21
|
+
"DiagnosticsCollector",
|
|
22
|
+
"DiagnosticsRow",
|
|
23
|
+
"EmaCodebookUpdater",
|
|
24
|
+
"JsonlDiagnosticsLogger",
|
|
25
|
+
]
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""Lightweight diagnostics collection and logging backends.
|
|
2
|
+
|
|
3
|
+
This module provides utilities to summarize per-layer CodingTrace objects into
|
|
4
|
+
simple, machine-readable rows and to write those rows to CSV or JSONL without
|
|
5
|
+
bringing extra dependencies. It complements wrapper-level convenience metrics
|
|
6
|
+
by producing explicit records during training, calibration, or evaluation.
|
|
7
|
+
|
|
8
|
+
Design goals
|
|
9
|
+
- Minimal surface; importable without heavy dependencies.
|
|
10
|
+
- Fail-open behavior: best-effort extraction; skip rows on missing tensors.
|
|
11
|
+
- CSV and JSONL outputs for easy downstream processing.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import csv
|
|
17
|
+
import json
|
|
18
|
+
from collections.abc import Iterable, Mapping
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
|
|
22
|
+
try: # optional torch for numeric reductions
|
|
23
|
+
import torch
|
|
24
|
+
except Exception: # pragma: no cover - environments without torch
|
|
25
|
+
from typing import Any
|
|
26
|
+
from typing import cast as _cast
|
|
27
|
+
|
|
28
|
+
torch = _cast(Any, None)
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
from nervecode.core import CodingTrace
|
|
32
|
+
except Exception: # pragma: no cover - available during normal package use
|
|
33
|
+
CodingTrace = object # type: ignore[misc,assignment]
|
|
34
|
+
|
|
35
|
+
__all__ = [
|
|
36
|
+
"CsvDiagnosticsLogger",
|
|
37
|
+
"DiagnosticsCollector",
|
|
38
|
+
"DiagnosticsRow",
|
|
39
|
+
"JsonlDiagnosticsLogger",
|
|
40
|
+
]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(frozen=True)
|
|
44
|
+
class DiagnosticsRow:
|
|
45
|
+
"""Single row of layer-level diagnostics derived from a CodingTrace."""
|
|
46
|
+
|
|
47
|
+
step: int | None
|
|
48
|
+
layer: str
|
|
49
|
+
K: int | None
|
|
50
|
+
code_dim: int | None
|
|
51
|
+
utilization: float | None
|
|
52
|
+
dead_code_frac: float | None
|
|
53
|
+
mean_entropy: float | None
|
|
54
|
+
mean_best_length: float | None
|
|
55
|
+
mean_commitment: float | None
|
|
56
|
+
reduction: str | None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class DiagnosticsCollector:
|
|
60
|
+
"""Summarize per-layer traces into DiagnosticsRow entries."""
|
|
61
|
+
|
|
62
|
+
def summarize_trace(
|
|
63
|
+
self, layer_name: str, trace: CodingTrace, *, step: int | None = None
|
|
64
|
+
) -> DiagnosticsRow:
|
|
65
|
+
# Default values
|
|
66
|
+
K: int | None = None
|
|
67
|
+
code_dim: int | None = None
|
|
68
|
+
utilization: float | None = None
|
|
69
|
+
dead_frac: float | None = None
|
|
70
|
+
mean_entropy: float | None = None
|
|
71
|
+
mean_best_len: float | None = None
|
|
72
|
+
mean_commit: float | None = None
|
|
73
|
+
reduction: str | None = None
|
|
74
|
+
|
|
75
|
+
# Resolve K and D from the trace's SoftCode and reduced vectors
|
|
76
|
+
try:
|
|
77
|
+
_k = getattr(trace.soft_code, "code_dim", 0)
|
|
78
|
+
K = int(_k) if isinstance(_k, (int,)) else None
|
|
79
|
+
except Exception:
|
|
80
|
+
K = None
|
|
81
|
+
try:
|
|
82
|
+
_d = getattr(trace, "reduced_dim", 0)
|
|
83
|
+
code_dim = int(_d) if isinstance(_d, (int,)) else None
|
|
84
|
+
except Exception:
|
|
85
|
+
code_dim = None
|
|
86
|
+
|
|
87
|
+
# Utilization and dead-code fraction
|
|
88
|
+
if torch is not None and hasattr(trace.soft_code, "best_indices"):
|
|
89
|
+
try:
|
|
90
|
+
best = trace.soft_code.best_indices
|
|
91
|
+
if not isinstance(best, torch.Tensor):
|
|
92
|
+
best = getattr(trace, "chosen_center_indices", None)
|
|
93
|
+
if isinstance(best, torch.Tensor) and K and K > 0:
|
|
94
|
+
used = len(torch.unique(best.reshape(-1)).tolist())
|
|
95
|
+
utilization = float(used) / float(K)
|
|
96
|
+
dead_frac = max(0.0, 1.0 - utilization)
|
|
97
|
+
except Exception:
|
|
98
|
+
pass
|
|
99
|
+
|
|
100
|
+
# Entropy, best-length, and commitment distance means
|
|
101
|
+
if torch is not None:
|
|
102
|
+
try:
|
|
103
|
+
ent = getattr(trace.soft_code, "entropy", None)
|
|
104
|
+
if isinstance(ent, torch.Tensor):
|
|
105
|
+
mean_entropy = float(ent.mean().item())
|
|
106
|
+
except Exception:
|
|
107
|
+
pass
|
|
108
|
+
try:
|
|
109
|
+
bl = getattr(trace.soft_code, "best_length", None)
|
|
110
|
+
if isinstance(bl, torch.Tensor):
|
|
111
|
+
mean_best_len = float(bl.mean().item())
|
|
112
|
+
except Exception:
|
|
113
|
+
pass
|
|
114
|
+
try:
|
|
115
|
+
cm = getattr(trace, "commitment_distances", None)
|
|
116
|
+
if isinstance(cm, torch.Tensor):
|
|
117
|
+
mean_commit = float(cm.mean().item())
|
|
118
|
+
except Exception:
|
|
119
|
+
pass
|
|
120
|
+
|
|
121
|
+
# Reduction metadata
|
|
122
|
+
try:
|
|
123
|
+
meta = getattr(trace, "reduction_meta", {}) or {}
|
|
124
|
+
reduction = str(meta.get("method")) if isinstance(meta, Mapping) else None
|
|
125
|
+
except Exception:
|
|
126
|
+
reduction = None
|
|
127
|
+
|
|
128
|
+
return DiagnosticsRow(
|
|
129
|
+
step=step,
|
|
130
|
+
layer=layer_name,
|
|
131
|
+
K=K,
|
|
132
|
+
code_dim=code_dim,
|
|
133
|
+
utilization=utilization,
|
|
134
|
+
dead_code_frac=dead_frac,
|
|
135
|
+
mean_entropy=mean_entropy,
|
|
136
|
+
mean_best_length=mean_best_len,
|
|
137
|
+
mean_commitment=mean_commit,
|
|
138
|
+
reduction=reduction,
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
def collect(
|
|
142
|
+
self, traces: Mapping[str, CodingTrace], *, step: int | None = None
|
|
143
|
+
) -> list[DiagnosticsRow]:
|
|
144
|
+
rows: list[DiagnosticsRow] = []
|
|
145
|
+
for name, trace in traces.items():
|
|
146
|
+
try:
|
|
147
|
+
rows.append(self.summarize_trace(str(name), trace, step=step))
|
|
148
|
+
except Exception:
|
|
149
|
+
continue
|
|
150
|
+
return rows
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class CsvDiagnosticsLogger:
    """Append diagnostics rows to a CSV file with a fixed header."""

    # Column order is fixed so appended runs stay aligned with the header.
    _COLUMNS: tuple[str, ...] = (
        "step",
        "layer",
        "K",
        "code_dim",
        "utilization",
        "dead_code_frac",
        "mean_entropy",
        "mean_best_length",
        "mean_commitment",
        "reduction",
    )

    def __init__(self, path: str | Path) -> None:
        self._path = Path(path)
        self._header_written = False
        self._fieldnames = list(self._COLUMNS)

    def log(self, rows: Iterable[DiagnosticsRow]) -> None:
        """Append *rows*; the header is written once, only for an empty file."""
        target = self._path
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("a", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=self._fieldnames)
            # Opening in append mode creates the file, so stat() is safe here.
            if not self._header_written and target.stat().st_size == 0:
                writer.writeheader()
                self._header_written = True
            writer.writerows(
                {column: getattr(row, column) for column in self._fieldnames}
                for row in rows
            )
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class JsonlDiagnosticsLogger:
    """Append diagnostics rows as JSONL objects (one per line)."""

    def __init__(self, path: str | Path) -> None:
        self._path = Path(path)

    def log(self, rows: Iterable[DiagnosticsRow]) -> None:
        """Serialize each row's attribute dict as one JSON line (append mode)."""
        destination = self._path
        destination.parent.mkdir(parents=True, exist_ok=True)
        with destination.open("a", encoding="utf-8") as handle:
            handle.writelines(json.dumps(vars(row)) + "\n" for row in rows)