nervecode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,396 @@
1
+ """Empirical percentile calibrator and state.
2
+
3
+ This module provides a minimal, torch-agnostic-at-import calibrator that
4
+ collects an empirical surprise score distribution on held-out
5
+ in-distribution data. The calibrator stores the sorted distribution, the
6
+ chosen threshold quantiles, and lightweight metadata sufficient to reproduce
7
+ calibration later (e.g., index rule, sample count, dtype).
8
+
9
+ Threshold comparison, percentile lookup, and persistence are implemented in
10
+ subsequent tasks; the goal here is to make calibration runs produce a durable
11
+ state snapshot that downstream code can consume.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from collections.abc import Iterable, Mapping, Sequence
17
+ from dataclasses import dataclass
18
+ from typing import Any, cast
19
+
20
+ try: # Keep import-time behavior tolerant in environments without torch
21
+ import torch
22
+ except Exception: # pragma: no cover - torch is a project dependency in tests
23
+ torch = cast(Any, None)
24
+
25
+ __all__ = [
26
+ "CalibrationState",
27
+ "EmpiricalPercentileCalibrator",
28
+ ]
29
+
30
+
31
@dataclass(frozen=True)
class CalibrationState:
    """Stored result of a calibration run.

    Fields
    - ``sorted_surprise``: 1D tensor of per-sample surprise scores sorted in
      ascending order. Typically a ``torch.Tensor`` on CPU.
    - ``threshold_quantiles``: tuple of quantiles in ``[0, 1]`` that were
      used to choose thresholds.
    - ``threshold_values``: 1D tensor of threshold values aligned with
      ``threshold_quantiles`` using a deterministic index rule (see metadata).
    - ``metadata``: free-form details needed to reproduce calibration later
      (e.g., index rule, number of samples, dtype, aggregation hint).
    """

    # Sorted (ascending) 1D score distribution; typed Any so the module stays
    # importable without torch.
    sorted_surprise: Any
    # Quantiles in [0, 1] used when thresholds were chosen.
    threshold_quantiles: tuple[float, ...]
    # Threshold values aligned element-wise with ``threshold_quantiles``.
    threshold_values: Any
    # Reproduction details: index rule, sample count, dtype, aggregation hint.
    metadata: dict[str, Any]
50
+
51
+
52
# Attempt to allowlist CalibrationState for PyTorch 2.6+ safe loading
try:  # pragma: no cover - environment-dependent
    if torch is not None:
        from torch import serialization as _serialization

        # ``add_safe_globals`` lets ``torch.load(..., weights_only=True)``
        # reconstruct this dataclass; resolve it via getattr because older
        # torch releases do not expose the API.
        _add = getattr(_serialization, "add_safe_globals", None)
        if callable(_add):
            _add([CalibrationState])
except Exception:
    # Older torch versions or environments without the API: skip registration.
    pass
63
+
64
+
65
class EmpiricalPercentileCalibrator:
    """Collect an empirical surprise distribution and store threshold info.

    The calibrator accepts a per-sample surprise vector, sorts it, and computes
    threshold values at configured quantiles using a deterministic index rule
    based on ``floor(q * (n - 1))``. The full sorted distribution, the chosen
    quantiles, and metadata sufficient for reproduction are stored in a
    ``CalibrationState``.
    """

    def __init__(self, *, threshold_quantiles: Sequence[float] = (0.95,)) -> None:
        self._threshold_quantiles: tuple[float, ...] = tuple(float(q) for q in threshold_quantiles)
        self._state: CalibrationState | None = None

    @property
    def state(self) -> CalibrationState | None:
        """Return the last computed calibration state, if available."""

        return self._state

    @property
    def threshold_quantiles(self) -> tuple[float, ...]:
        """Quantiles configured at construction (or via ``load_state_dict``)."""

        return self._threshold_quantiles

    def fit(self, scores: Any, *, aggregation: str | None = None) -> CalibrationState:
        """Calibrate from a per-sample surprise vector and store state.

        Parameters
        - scores: Per-sample surprise signal. Accepts a ``torch.Tensor`` of any
          shape (flattened to 1D), an iterable of numbers, or an object with a
          ``surprise`` tensor attribute.
        - aggregation: Optional hint describing how the surprise vector was
          produced (e.g., ``"mean"``). Stored in metadata for later reference.

        Returns
        - ``CalibrationState`` with sorted distribution, chosen quantiles,
          threshold values, and metadata.

        Raises
        - RuntimeError: when PyTorch is unavailable.
        - ValueError: when a quantile is outside [0, 1] or ``scores`` is empty.
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Calibration requires PyTorch to be installed")

        # Validate quantiles up front so bad configuration fails before any
        # tensor work is done. An empty configuration falls back to 0.95.
        qs = list(self._threshold_quantiles) or [0.95]
        for q in qs:
            if not (0.0 <= q <= 1.0):  # strict validation to avoid silent misuse
                raise ValueError(f"Quantile {q} is outside [0, 1]")

        vec = _as_1d_tensor(scores)
        if vec.numel() == 0:
            raise ValueError("Calibration requires at least one surprise score")

        # Defensive copy to CPU float32 for stable sorting and storage.
        vec = vec.detach().to(device="cpu", dtype=torch.float32).contiguous().view(-1)
        sorted_vec, _ = torch.sort(vec, stable=True)

        # Deterministic index rule: floor(q * (n-1)), clamped into [0, n-1].
        n = int(sorted_vec.numel())
        idxs = [min(n - 1, max(0, int((n - 1) * q))) for q in qs]
        thr = sorted_vec[torch.tensor(idxs, dtype=torch.long)]

        meta: dict[str, Any] = {
            "method": "empirical_percentile",
            "index_rule": "floor_q_times_n_minus_1",
            "num_scores": n,
            "aggregation": aggregation or "unknown",
            "dtype": str(sorted_vec.dtype).replace("torch.", ""),
            "device": str(sorted_vec.device),
        }

        state = CalibrationState(
            sorted_surprise=sorted_vec,
            threshold_quantiles=tuple(qs),
            threshold_values=thr,
            metadata=meta,
        )
        self._state = state
        return state

    # --- Inference-time utilities -------------------------------------------------
    def threshold_for(self, quantile: float | None = None) -> Any:
        """Return the calibrated threshold value for a given quantile.

        If ``quantile`` is omitted and the fitted state holds exactly one
        quantile, that single threshold is returned. When a quantile is
        provided that was not part of the original configuration, the value is
        computed on the fly using the stored sorted distribution and the same
        deterministic index rule used during fitting (``floor(q * (n-1))``).
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Calibration requires PyTorch to be installed")
        if self._state is None:
            raise RuntimeError("Calibrator has no state; call fit(...) first")

        state = self._state
        if quantile is None:
            # Consult the fit-time quantiles rather than the constructor tuple:
            # they include the 0.95 fallback applied when the calibrator was
            # built with an empty tuple, keeping the default path usable.
            if len(state.threshold_quantiles) == 1:
                return state.threshold_values[0]
            raise ValueError(
                "Multiple quantiles configured; specify 'quantile' explicitly",
            )

        q = float(quantile)
        if not (0.0 <= q <= 1.0):
            raise ValueError(f"Quantile {q} is outside [0, 1]")

        # If the quantile matches one from fit(), return the stored value.
        try:
            idx = state.threshold_quantiles.index(q)
            return state.threshold_values[idx]
        except ValueError:
            pass  # compute ad-hoc below

        # Compute ad-hoc using the stored sorted distribution and index rule.
        sorted_vec = cast("torch.Tensor", state.sorted_surprise)
        n = int(sorted_vec.numel())
        idx = min(n - 1, max(0, int((n - 1) * q)))
        return sorted_vec[idx]

    def percentiles(self, scores: Any) -> Any:
        """Map surprise values to empirical percentiles in [0, 1].

        Computes an empirical CDF against the stored ID distribution and
        returns a percentile per input. To be robust across scoring variants
        where either tail may correspond to increased OOD-ness, this method
        automatically selects the upper or lower tail based on the median of
        the provided ``scores`` relative to the ID median:
        - If ``median(scores) >= median(ID)``: use upper-tail CDF
          ``p(x) = # {v <= x} / n`` (higher scores → higher percentiles).
        - Else: use lower-tail complement ``p(x) = 1 - # {v <= x} / n``
          (smaller scores → higher percentiles).

        Supports scalar and batched inputs; tensor inputs preserve their shape
        in the result. The returned percentiles live on the calibration
        distribution's device (CPU after ``fit``).
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Calibration requires PyTorch to be installed")
        if self._state is None:
            raise RuntimeError("Calibrator has no state; call fit(...) first")

        x = _as_tensor(scores)
        sorted_vec = cast("torch.Tensor", self._state.sorted_surprise)

        # torch.searchsorted requires the probe values to share the sorted
        # sequence's dtype and device (CPU float32 after fit()); align the
        # input first so float64 or GPU inputs do not raise.
        flat = x.reshape(-1).to(device=sorted_vec.device, dtype=sorted_vec.dtype)
        # Counts of calibration values <= x via searchsorted(..., right=True).
        counts = torch.searchsorted(sorted_vec, flat, right=True)
        n = float(sorted_vec.numel())
        cdf = counts.to(dtype=torch.float32) / n
        # Tail selection: compare input median vs ID median (best-effort).
        try:
            id_median = sorted_vec[sorted_vec.numel() // 2]
            x_median = torch.median(flat)
            use_upper = bool(x_median >= id_median)
        except Exception:
            use_upper = True
        p = cdf if use_upper else (1.0 - cdf)
        p = torch.clamp(p, 0.0, 1.0)
        return p.reshape(x.shape)

    def is_ood(self, scores: Any, *, quantile: float | None = None) -> Any:
        """Return a boolean mask where inputs are marked OOD.

        By default, compares the empirical percentile of each input against a
        configured threshold quantile from calibration, marking OOD when
        ``percentile >= quantile``. If ``quantile`` is omitted:
        - with a single configured threshold, that quantile is used;
        - with multiple thresholds, the most conservative choice (max quantile)
          is used to avoid ambiguity and keep a sensible default.
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Calibration requires PyTorch to be installed")
        if self._state is None:
            raise RuntimeError("Calibrator has no state; call fit(...) first")

        q: float
        if quantile is None:
            if len(self._threshold_quantiles) == 0:
                q = 0.95
            elif len(self._threshold_quantiles) == 1:
                q = float(self._threshold_quantiles[0])
            else:
                # Default to the highest configured quantile for a conservative threshold
                q = float(max(self._threshold_quantiles))
        else:
            q = float(quantile)
            if not (0.0 <= q <= 1.0):
                raise ValueError(f"Quantile {q} is outside [0, 1]")

        p = self.percentiles(scores)
        # 'p' is a tensor shaped like the input; compare against q
        return cast("torch.Tensor", p) >= q

    # --- Persistence helpers -----------------------------------------------------
    def get_state(self) -> CalibrationState | None:
        """Return the current CalibrationState (or None if not fitted)."""

        return self._state

    def set_state(self, state: CalibrationState) -> None:
        """Assign an externally provided CalibrationState."""

        if not isinstance(state, CalibrationState):  # pragma: no cover - defensive
            raise TypeError("state must be a CalibrationState")
        self._state = state

    def state_dict(self) -> dict[str, Any]:
        """Return a versioned dictionary suitable for checkpointing.

        Contains the configured threshold quantiles and, when available, the
        fitted CalibrationState with its tensors and metadata.
        """

        out: dict[str, Any] = {
            "__nervecode__": "EmpiricalPercentileCalibrator",
            "version": 1,
            "threshold_quantiles": self._threshold_quantiles,
        }
        if self._state is not None:
            out["state"] = {
                "sorted_surprise": self._state.sorted_surprise,
                "threshold_quantiles": self._state.threshold_quantiles,
                "threshold_values": self._state.threshold_values,
                "metadata": self._state.metadata,
            }
        return out

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """Load from a dictionary produced by state_dict().

        Unknown keys are ignored; the embedded calibration state is only
        restored when both tensors it needs are present.
        """

        if not isinstance(state_dict, Mapping):  # pragma: no cover - defensive
            raise TypeError("state_dict must be a mapping")
        # Threshold quantiles configuration
        tq = state_dict.get("threshold_quantiles")
        if tq is not None:
            self._threshold_quantiles = tuple(float(q) for q in list(tq))
        # Optional embedded calibration state
        st = state_dict.get("state")
        if isinstance(st, Mapping) and "sorted_surprise" in st and "threshold_values" in st:
            self._state = CalibrationState(
                sorted_surprise=st["sorted_surprise"],
                threshold_quantiles=tuple(
                    float(q) for q in list(st.get("threshold_quantiles", self._threshold_quantiles))
                ),
                threshold_values=st["threshold_values"],
                metadata=dict(st.get("metadata", {})),
            )

    def save(self, path: str) -> None:
        """Save state_dict() to a file using torch.save if available."""

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Saving calibrator state requires PyTorch to be installed")
        torch.save(self.state_dict(), path)

    @classmethod
    def load(cls, path: str) -> EmpiricalPercentileCalibrator:
        """Load a calibrator from a file produced by save().

        The checkpoint contains only tensors and plain containers, so it is
        compatible with ``weights_only`` loading on recent torch versions.
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("Loading calibrator state requires PyTorch to be installed")
        sd = torch.load(path, map_location="cpu")
        # Attempt to recover configured quantiles; default to embedded state
        tq = sd.get("threshold_quantiles", (0.95,)) if isinstance(sd, Mapping) else (0.95,)
        calib = cls(threshold_quantiles=tq)
        if isinstance(sd, Mapping):
            calib.load_state_dict(sd)
        return calib
337
+
338
+
339
+ def _as_1d_tensor(scores: Any) -> Any:
340
+ """Return a 1D tensor view of per-sample surprise values.
341
+
342
+ Supports common inputs without importing project-internal types here.
343
+ - ``torch.Tensor``: flattened to 1D (keeps dtype/device for intermediate ops)
344
+ - object with ``surprise`` attribute: treat it as a tensor and flatten
345
+ - sequence/iterable of numbers: converted to a CPU float tensor
346
+ """
347
+
348
+ if torch is None: # pragma: no cover - defensive
349
+ raise RuntimeError("Calibration requires PyTorch to be installed")
350
+
351
+ if isinstance(scores, torch.Tensor):
352
+ return scores.reshape(-1)
353
+
354
+ cand = getattr(scores, "surprise", None)
355
+ if isinstance(cand, torch.Tensor):
356
+ return cand.reshape(-1)
357
+
358
+ if isinstance(scores, Iterable):
359
+ try:
360
+ return torch.tensor(list(scores), dtype=torch.float32)
361
+ except Exception as exc: # pragma: no cover - defensive
362
+ raise TypeError("Unsupported iterable for calibration scores") from exc
363
+
364
+ raise TypeError("Unsupported input type for calibration scores")
365
+
366
+
367
+ def _as_tensor(scores: Any) -> Any:
368
+ """Return a tensor view of input scores preserving shape when possible.
369
+
370
+ Accepted inputs mirror ``_as_1d_tensor`` but keep the original shape for
371
+ tensor-like objects and scalars. Iterables are materialized into a 1D
372
+ tensor. The returned tensor uses ``float32`` when constructed from Python
373
+ values; otherwise the input dtype/device is preserved.
374
+ """
375
+
376
+ if torch is None: # pragma: no cover - defensive
377
+ raise RuntimeError("Calibration requires PyTorch to be installed")
378
+
379
+ if isinstance(scores, torch.Tensor):
380
+ return scores
381
+
382
+ cand = getattr(scores, "surprise", None)
383
+ if isinstance(cand, torch.Tensor):
384
+ return cand
385
+
386
+ # Common scalar path
387
+ if isinstance(scores, (int, float)):
388
+ return torch.tensor(scores, dtype=torch.float32)
389
+
390
+ if isinstance(scores, Iterable):
391
+ try:
392
+ return torch.tensor(list(scores), dtype=torch.float32)
393
+ except Exception as exc: # pragma: no cover - defensive
394
+ raise TypeError("Unsupported iterable for calibration scores") from exc
395
+
396
+ raise TypeError("Unsupported input type for calibration scores")
@@ -0,0 +1,33 @@
1
+ """Scoring/aggregation result types.
2
+
3
+ This module defines lightweight dataclasses used by the scoring/aggregation
4
+ APIs. Keep imports torch-agnostic so the package remains importable in minimal
5
+ environments; fields that are typically ``torch.Tensor`` are typed as ``Any``.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+ from typing import Any
12
+
13
+ __all__ = ["AggregatedSurprise"]
14
+
15
+
16
@dataclass(frozen=True)
class AggregatedSurprise:
    """Result of aggregating per-layer surprise signals.

    Fields
    - ``surprise``: per-sample aggregated surprise signal (e.g., shape ``(B,)``
      or ``(...,)``). Typically a ``torch.Tensor``.
    - ``method``: aggregation method identifier such as ``"mean"``. Future
      values may include ``"max"`` and ``"weighted"`` without changing the
      public result type.
    - ``num_layers``: number of layer signals included in the aggregation.
    - ``details``: optional aggregation-specific details (e.g., weights).
    """

    # Aggregated per-sample signal; typed Any to keep the module importable
    # without torch.
    surprise: Any
    # Identifier of the aggregation method (e.g., "mean").
    method: str
    # Count of per-layer signals that were combined.
    num_layers: int
    # Optional method-specific extras (e.g., per-layer weights).
    details: dict[str, Any] | None = None
@@ -0,0 +1,25 @@
1
+ """Training-time helpers and losses.
2
+
3
+ This subpackage will include auxiliary losses, schedulers, and small training
4
+ helpers used to optimize codebooks and calibrate scores.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from .diagnostics import (
10
+ CsvDiagnosticsLogger,
11
+ DiagnosticsCollector,
12
+ DiagnosticsRow,
13
+ JsonlDiagnosticsLogger,
14
+ )
15
+ from .loss import CodingLoss
16
+ from .updaters import EmaCodebookUpdater
17
+
18
+ __all__ = [
19
+ "CodingLoss",
20
+ "CsvDiagnosticsLogger",
21
+ "DiagnosticsCollector",
22
+ "DiagnosticsRow",
23
+ "EmaCodebookUpdater",
24
+ "JsonlDiagnosticsLogger",
25
+ ]
@@ -0,0 +1,194 @@
1
+ """Lightweight diagnostics collection and logging backends.
2
+
3
+ This module provides utilities to summarize per-layer CodingTrace objects into
4
+ simple, machine-readable rows and to write those rows to CSV or JSONL without
5
+ bringing extra dependencies. It complements wrapper-level convenience metrics
6
+ by producing explicit records during training, calibration, or evaluation.
7
+
8
+ Design goals
9
+ - Minimal surface; importable without heavy dependencies.
10
+ - Fail-open behavior: best-effort extraction; skip rows on missing tensors.
11
+ - CSV and JSONL outputs for easy downstream processing.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import csv
17
+ import json
18
+ from collections.abc import Iterable, Mapping
19
+ from dataclasses import dataclass
20
+ from pathlib import Path
21
+
22
try:  # optional torch for numeric reductions
    import torch
except Exception:  # pragma: no cover - environments without torch
    from typing import Any
    from typing import cast as _cast

    # Sentinel: downstream code checks ``torch is not None`` before using it.
    torch = _cast(Any, None)

try:
    from nervecode.core import CodingTrace
except Exception:  # pragma: no cover - available during normal package use
    # Fallback keeps the module importable (annotations only reference the
    # name lazily) when the core package is absent.
    CodingTrace = object  # type: ignore[misc,assignment]
34
+
35
+ __all__ = [
36
+ "CsvDiagnosticsLogger",
37
+ "DiagnosticsCollector",
38
+ "DiagnosticsRow",
39
+ "JsonlDiagnosticsLogger",
40
+ ]
41
+
42
+
43
@dataclass(frozen=True)
class DiagnosticsRow:
    """Single row of layer-level diagnostics derived from a CodingTrace."""

    # All fields are optional by design: extraction is best-effort and a
    # missing signal is recorded as None rather than failing the run.
    step: int | None
    layer: str
    K: int | None
    code_dim: int | None
    utilization: float | None
    dead_code_frac: float | None
    mean_entropy: float | None
    mean_best_length: float | None
    mean_commitment: float | None
    reduction: str | None


class DiagnosticsCollector:
    """Summarize per-layer traces into DiagnosticsRow entries."""

    def summarize_trace(
        self, layer_name: str, trace: CodingTrace, *, step: int | None = None
    ) -> DiagnosticsRow:
        """Extract a DiagnosticsRow from a single trace, best-effort.

        Missing attributes or non-tensor values yield ``None`` fields instead
        of raising, per the module's fail-open design goal.
        """

        K: int | None = None
        code_dim: int | None = None
        utilization: float | None = None
        dead_frac: float | None = None
        mean_entropy: float | None = None
        mean_best_len: float | None = None
        mean_commit: float | None = None
        reduction: str | None = None

        # Fail-open: resolve soft_code defensively. The previous direct
        # ``trace.soft_code`` access raised outside any try block for traces
        # without that attribute.
        soft_code = getattr(trace, "soft_code", None)

        # Resolve K and D from the trace's SoftCode and reduced vectors.
        # A missing attribute is reported as None (not 0).
        try:
            _k = getattr(soft_code, "code_dim", None)
            K = int(_k) if isinstance(_k, int) else None
        except Exception:
            K = None
        try:
            _d = getattr(trace, "reduced_dim", None)
            code_dim = int(_d) if isinstance(_d, int) else None
        except Exception:
            code_dim = None

        # Utilization and dead-code fraction
        if torch is not None and hasattr(soft_code, "best_indices"):
            try:
                best = soft_code.best_indices
                if not isinstance(best, torch.Tensor):
                    best = getattr(trace, "chosen_center_indices", None)
                if isinstance(best, torch.Tensor) and K and K > 0:
                    # Count distinct codes directly; no need to round-trip
                    # through a Python list.
                    used = int(torch.unique(best.reshape(-1)).numel())
                    utilization = float(used) / float(K)
                    dead_frac = max(0.0, 1.0 - utilization)
            except Exception:
                pass

        # Entropy, best-length, and commitment distance means
        if torch is not None:
            try:
                ent = getattr(soft_code, "entropy", None)
                if isinstance(ent, torch.Tensor):
                    mean_entropy = float(ent.mean().item())
            except Exception:
                pass
            try:
                bl = getattr(soft_code, "best_length", None)
                if isinstance(bl, torch.Tensor):
                    mean_best_len = float(bl.mean().item())
            except Exception:
                pass
            try:
                cm = getattr(trace, "commitment_distances", None)
                if isinstance(cm, torch.Tensor):
                    mean_commit = float(cm.mean().item())
            except Exception:
                pass

        # Reduction metadata. Only stringify a method that is actually
        # present; ``str(meta.get("method"))`` would leak the literal "None".
        try:
            meta = getattr(trace, "reduction_meta", {}) or {}
            method = meta.get("method") if isinstance(meta, Mapping) else None
            reduction = str(method) if method is not None else None
        except Exception:
            reduction = None

        return DiagnosticsRow(
            step=step,
            layer=layer_name,
            K=K,
            code_dim=code_dim,
            utilization=utilization,
            dead_code_frac=dead_frac,
            mean_entropy=mean_entropy,
            mean_best_length=mean_best_len,
            mean_commitment=mean_commit,
            reduction=reduction,
        )

    def collect(
        self, traces: Mapping[str, CodingTrace], *, step: int | None = None
    ) -> list[DiagnosticsRow]:
        """Summarize every trace in ``traces``; traces that fail are skipped."""

        rows: list[DiagnosticsRow] = []
        for name, trace in traces.items():
            try:
                rows.append(self.summarize_trace(str(name), trace, step=step))
            except Exception:
                continue
        return rows
151
+
152
+
153
class CsvDiagnosticsLogger:
    """Append diagnostics rows to a CSV file with a fixed header."""

    def __init__(self, path: str | Path) -> None:
        self._path = Path(path)
        self._header_written = False
        # Fixed column order so downstream consumers can rely on it.
        self._fieldnames = [
            "step",
            "layer",
            "K",
            "code_dim",
            "utilization",
            "dead_code_frac",
            "mean_entropy",
            "mean_best_length",
            "mean_commitment",
            "reduction",
        ]

    def log(self, rows: Iterable[DiagnosticsRow]) -> None:
        """Append ``rows`` to the CSV file, emitting the header at most once."""

        self._path.parent.mkdir(parents=True, exist_ok=True)
        with self._path.open("a", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=self._fieldnames)
            # Write the header only for a brand-new (empty) file; appending to
            # an existing log must not duplicate it.
            if not self._header_written and self._path.stat().st_size == 0:
                writer.writeheader()
                self._header_written = True
            for row in rows:
                record = {name: getattr(row, name) for name in self._fieldnames}
                writer.writerow(record)
181
+
182
+
183
class JsonlDiagnosticsLogger:
    """Append diagnostics rows as JSONL objects (one per line)."""

    def __init__(self, path: str | Path) -> None:
        self._path = Path(path)

    def log(self, rows: Iterable[DiagnosticsRow]) -> None:
        """Serialize each row as one JSON object per line, appending to the file."""

        self._path.parent.mkdir(parents=True, exist_ok=True)
        with self._path.open("a", encoding="utf-8") as handle:
            for row in rows:
                # Dataclass instances expose their fields via __dict__.
                handle.write(json.dumps(row.__dict__) + "\n")