nervecode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nervecode/__init__.py +415 -0
- nervecode/_version.py +10 -0
- nervecode/core/__init__.py +19 -0
- nervecode/core/assignment.py +165 -0
- nervecode/core/codebook.py +182 -0
- nervecode/core/shapes.py +107 -0
- nervecode/core/temperature.py +227 -0
- nervecode/core/trace.py +166 -0
- nervecode/core/types.py +116 -0
- nervecode/integration/__init__.py +9 -0
- nervecode/layers/__init__.py +15 -0
- nervecode/layers/base.py +333 -0
- nervecode/layers/conv.py +174 -0
- nervecode/layers/linear.py +176 -0
- nervecode/layers/reducers.py +80 -0
- nervecode/layers/wrap.py +223 -0
- nervecode/scoring/__init__.py +20 -0
- nervecode/scoring/aggregator.py +369 -0
- nervecode/scoring/calibrator.py +396 -0
- nervecode/scoring/types.py +33 -0
- nervecode/training/__init__.py +25 -0
- nervecode/training/diagnostics.py +194 -0
- nervecode/training/loss.py +188 -0
- nervecode/training/updaters.py +168 -0
- nervecode/utils/__init__.py +14 -0
- nervecode/utils/overhead.py +177 -0
- nervecode/utils/seed.py +161 -0
- nervecode-0.1.0.dist-info/METADATA +83 -0
- nervecode-0.1.0.dist-info/RECORD +31 -0
- nervecode-0.1.0.dist-info/WHEEL +4 -0
- nervecode-0.1.0.dist-info/licenses/LICENSE +22 -0
nervecode/__init__.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
1
|
+
"""Nervecode package root and public API stubs.
|
|
2
|
+
|
|
3
|
+
The real implementation will land incrementally following the MVP plan. For now,
|
|
4
|
+
we expose the intended top-level API so users and tooling can rely on the package
|
|
5
|
+
surface existing from the start.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections.abc import Iterable, Iterator
|
|
11
|
+
from contextlib import ExitStack, contextmanager, suppress
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
# Import base wrapper type for lightweight wrapper tracking without a hard
|
|
15
|
+
# torch dependency at import time.
|
|
16
|
+
try: # pragma: no cover - available in normal package use
|
|
17
|
+
from nervecode.layers.base import BaseCodingWrapper
|
|
18
|
+
except Exception: # pragma: no cover - fallback for environments without package context
|
|
19
|
+
BaseCodingWrapper = object # type: ignore[misc,assignment]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class WrappedModel:
    """Thin container around a base model that tracks coding layers.

    Goals
    - Preserve the original model API: attribute access and ``__call__``
      delegate to the wrapped model so existing code keeps working.
    - Track inserted coding layers (wrappers) to support convenience
      iterators and aggregations.
    """

    # Keep the constructor torch-agnostic and accept any object. When the
    # object is a PyTorch ``nn.Module``, wrapper discovery uses
    # ``named_modules()``; otherwise, tracking stays empty.
    def __init__(self, model: object) -> None:
        # Store the wrapped model reference for external access and tests.
        self._model = model
        # Internal registry of discovered coding wrappers keyed by dotted name.
        self._wrapped_layers: dict[str, BaseCodingWrapper] = {}
        # Convenience caches populated on forward() when coding is active.
        # Latest visible model output from the most recent forward.
        self._last_output: Any | None = None
        # Latest per-layer traces keyed by dotted module name. Values are
        # intentionally typed as Any to keep this module torch-agnostic; the
        # concrete type is nervecode.core.CodingTrace in torch-enabled envs.
        self._last_layer_traces: dict[str, Any] = {}
        # Optional calibrator attached after a calibration run.
        self._calibrator: Any | None = None
        self._refresh_wrapped_layers()

    # --- Delegation to preserve the base model API -----------------------------
    def __getattr__(self, name: str) -> Any:  # pragma: no cover - exercised via tests
        # Only invoked when normal attribute lookup fails on self; delegate to
        # the wrapped model. Any AttributeError from getattr propagates
        # naturally, preserving standard attribute-access semantics.
        return getattr(self._model, name)

    def _invoke_model(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped model: prefer ``__call__``, fall back to ``forward``.

        Raises TypeError when the wrapped object supports neither.
        """
        model = self._model
        # Fast path: call the model if it is callable.
        if callable(model):
            return model(*args, **kwargs)
        # Fallback: attempt a direct ``forward`` attribute.
        fwd = getattr(model, "forward", None)
        if callable(fwd):
            return fwd(*args, **kwargs)
        raise TypeError("Wrapped base model is not callable and has no forward()")

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate calls to the wrapped model's ``__call__``.

        This keeps the container usable anywhere the base model is expected.
        """
        return self._invoke_model(*args, **kwargs)

    # --- Common nn.Module-style helpers that preserve chaining semantics -------
    def train(self, mode: bool = True) -> WrappedModel:  # pragma: no cover - exercised indirectly
        """Set underlying model to train/eval mode and return self for chaining.

        Mirrors ``torch.nn.Module.train`` semantics so that patterns like
        ``WrappedModel(model).train()`` keep returning a wrapper rather than the
        bare model (which would drop wrapper methods).
        """
        m = getattr(self._model, "train", None)
        if callable(m):
            with suppress(Exception):
                m(bool(mode))
        return self

    def eval(self) -> WrappedModel:  # pragma: no cover - exercised indirectly
        """Switch to evaluation mode and return self for chaining."""
        return self.train(False)

    def forward(self, *args: Any, **kwargs: Any) -> object:
        """Run a forward pass through the wrapped base model.

        Returns the base model output and refreshes the convenience caches
        (``_last_output`` and ``_last_layer_traces``) when coding is enabled
        on wrapped layers.
        """
        # Ensure our registry reflects any instrumentation that may have been
        # applied after construction.
        self._refresh_wrapped_layers()

        # Clear last-forward convenience caches before computing a new result.
        self._last_output = None
        self._last_layer_traces = {}

        # Delegate to the underlying model using its normal call semantics.
        out = self._invoke_model(*args, **kwargs)

        # Update convenience caches: store output and collect latest per-layer
        # traces from wrapped modules when coding is enabled at the layer level.
        self._last_output = out
        try:
            traces: dict[str, Any] = {}
            for name, wrapper in self._wrapped_layers.items():
                # Each wrapper manages its own fail-open behavior and cache.
                trace = getattr(wrapper, "last_trace", None)
                if trace is not None:
                    traces[name] = trace
            self._last_layer_traces = traces
        except Exception:
            # Fail-open: keep caches best-effort without affecting user forwards.
            self._last_layer_traces = {}

        return out

    # --- Model-level fail-open controls ----------------------------------------
    def enable_coding(self) -> None:
        """Enable coding on all discovered wrapped layers.

        Safe to call at any time; silently does nothing when no wrappers are
        present or when the wrapped object is not a PyTorch module.
        """
        self._refresh_wrapped_layers()
        for wrapper in self._wrapped_layers.values():
            # Per-wrapper fail-open so a single failing wrapper does not
            # prevent the remaining wrappers from being enabled.
            with suppress(Exception):
                wrapper.enable_coding()

    def disable_coding(self) -> None:
        """Disable coding on all discovered wrapped layers (fail-open).

        Clearing any cached per-layer traces ensures that callers observing
        model-level caches see the disabled state immediately, without having
        to run another forward first.
        """
        self._refresh_wrapped_layers()
        for wrapper in self._wrapped_layers.values():
            # Per-wrapper fail-open so a single failing wrapper does not
            # prevent the remaining wrappers from being disabled.
            with suppress(Exception):
                wrapper.disable_coding()
                # Ensure stale traces are not exposed after disabling.
                if hasattr(wrapper, "clear_last_trace"):
                    wrapper.clear_last_trace()
        # Clear model-level convenience cache as well.
        self._last_layer_traces = {}

    @contextmanager
    def bypass(self) -> Iterator[None]:
        """Temporarily bypass coding across all wrapped layers.

        The context nests correctly and delegates to every discovered wrapper's
        own bypass context. Safe when no wrappers are present.
        """
        self._refresh_wrapped_layers()
        with ExitStack() as stack:
            for wrapper in self._wrapped_layers.values():
                # Each wrapper manages its own nesting depth.
                stack.enter_context(wrapper.bypass())
            yield

    # Training-time auxiliary loss used to learn codebooks.
    def coding_loss(self) -> Any:
        """Compute coding loss from the latest per-layer traces.

        Uses the most recent per-layer traces collected during ``forward`` and
        computes a scalar training-time loss via ``nervecode.training.CodingLoss``.
        When no traces are available, raises a helpful error explaining how to
        proceed (e.g., instrument the model, enable coding, and run a forward
        pass, or use the explicit trace path and call ``CodingLoss`` directly).

        Raises
        - RuntimeError: when CodingLoss is unavailable, no wrappers are
          discovered, coding is disabled everywhere, or no traces exist yet.
        """
        # Lazy import to keep package importable without torch in non-training contexts
        try:
            from nervecode.training import CodingLoss
        except Exception as exc:  # pragma: no cover - defensive
            raise RuntimeError(
                "CodingLoss is unavailable. Ensure PyTorch and training components are installed."
            ) from exc

        traces = getattr(self, "_last_layer_traces", {}) or {}
        if not traces:
            # Diagnose common causes and provide actionable guidance.
            wrappers = getattr(self, "_wrapped_layers", {}) or {}
            if not wrappers:
                raise RuntimeError(
                    "No coding layers discovered on this model. Instrument your model "
                    "(e.g., nervecode.layers.wrap.wrap(model, layers='all_linear')) "
                    "before calling WrappedModel.coding_loss()."
                )
            any_enabled = any(getattr(w, "coding_enabled", False) for w in wrappers.values())
            if not any_enabled:
                raise RuntimeError(
                    "Coding is disabled on all wrapped layers. Call wrapped.enable_coding() "
                    "and run a forward pass before calling WrappedModel.coding_loss()."
                )
            raise RuntimeError(
                "No per-layer traces are available. Run wrapped.forward(...) first to "
                "populate the latest traces, or use forward_with_trace(...) and pass explicit "
                "traces to nervecode.training.CodingLoss."
            )

        loss_mod = CodingLoss()
        return loss_mod(traces)

    @staticmethod
    def _extract_inputs(batch: Any) -> Any:
        """Pull model inputs out of a loader batch.

        Supports the common patterns ``(inputs, targets)`` tuples/lists,
        ``{"inputs": ...}`` dicts, and raw input batches.
        """
        if isinstance(batch, (tuple, list)) and len(batch) >= 1:
            return batch[0]
        if isinstance(batch, dict) and "inputs" in batch:
            return batch["inputs"]
        return batch

    # Calibrate surprise thresholds/percentiles on in-distribution data.
    def calibrate(self, data_loader: Iterable[Any]) -> None:
        """Calibrate empirical percentiles on aggregated surprise.

        Runs the wrapped model over a data loader (held-out in-distribution),
        collects per-sample aggregated surprise values, and fits an
        EmpiricalPercentileCalibrator. The fitted calibrator is attached to the
        instance for later use (e.g., percentiles, is_ood) via ``self._calibrator``.

        Notes
        - The data loader may yield either ``inputs`` or ``(inputs, targets)``.
        - Coding must be enabled on wrapped layers for traces to be produced.
        """
        try:
            from nervecode.scoring import EmpiricalPercentileCalibrator, mean_surprise
        except Exception as exc:  # pragma: no cover - defensive
            raise RuntimeError(
                "Calibration utilities are unavailable. Ensure nervecode.scoring is importable."
            ) from exc

        # Collect per-sample surprise vectors across the loader.
        scores: list[Any] = []
        # Prefer torch.no_grad if available, but avoid a hard torch dependency
        # at import time; without torch, fall back to a do-nothing context.
        try:
            import torch

            no_grad_ctx: Any = torch.no_grad()
        except Exception:  # pragma: no cover - environments without torch
            from contextlib import nullcontext

            no_grad_ctx = nullcontext()

        with no_grad_ctx:
            for batch in data_loader:
                inputs = self._extract_inputs(batch)
                _ = self.forward(inputs)
                agg = self.surprise()
                if agg is None:
                    # Fall back to explicit aggregation over the latest traces
                    traces = getattr(self, "_last_layer_traces", {}) or {}
                    agg = mean_surprise(traces)
                if agg is not None and hasattr(agg, "surprise"):
                    scores.append(agg.surprise)

        if not scores:
            raise RuntimeError(
                "Calibration collected no surprise scores; ensure coding is enabled."
            )

        # Fit calibrator on concatenated scores when possible (torch),
        # otherwise fit per-chunk (rare path when tensor ops are unavailable).
        try:
            import torch

            vec = torch.cat([s.detach().to(device="cpu").reshape(-1) for s in scores], dim=0)
            calib = EmpiricalPercentileCalibrator(threshold_quantiles=(0.95,))
            _ = calib.fit(vec, aggregation="mean")
        except Exception:  # pragma: no cover - defensive minimal fallback
            calib = EmpiricalPercentileCalibrator(threshold_quantiles=(0.95,))
            # Let calibrator normalize iterable inputs internally
            flat: list[float] = []
            for s in scores:
                with suppress(Exception):
                    resh = getattr(s, "reshape", None)
                    if callable(resh):
                        arr = resh(-1).tolist()
                    else:
                        arr = s.tolist() if hasattr(s, "tolist") else [s]
                    flat.extend(float(v) for v in list(arr))
            _ = calib.fit(flat, aggregation="mean")

        self._calibrator = calib
        return None

    # Inference-time surprise signal (aggregated over wrapped layers).
    def surprise(self) -> object:
        """Return the latest aggregated surprise across wrapped layers.

        This convenience method aggregates the most recent per-layer traces
        collected during a standard ``forward`` into a per-sample surprise
        signal using the default mean strategy. When no valid per-layer
        signals are available, returns ``None``.
        """
        try:
            from nervecode.scoring import mean_surprise  # lazy import
        except Exception:
            return None

        traces = getattr(self, "_last_layer_traces", {}) or {}
        return mean_surprise(traces)

    # Explicit trace-returning path for robust integrations.
    def forward_with_trace(self, *args: Any, **kwargs: Any) -> tuple[object, dict[str, Any]]:
        """Return (model_output, per-layer_traces) for a single forward.

        The first element mirrors the wrapped model's normal return value. The
        second is a mapping from dotted layer names to ``CodingTrace`` objects
        produced during this forward by wrapped layers (those with coding
        active). When coding is disabled, the mapping is empty.
        """
        out = self.forward(*args, **kwargs)
        traces = dict(getattr(self, "_last_layer_traces", {}) or {})
        return out, traces

    # --- Internal: discover coding wrappers on the model -----------------------
    def _refresh_wrapped_layers(self) -> None:
        """Discover and cache coding wrappers present on the wrapped model.

        Populates ``_wrapped_layers`` with a mapping from dotted module names to
        wrapper instances. Non-PyTorch objects result in an empty registry.
        The method never raises; it is safe to call after in-place
        instrumentation.
        """
        registry: dict[str, BaseCodingWrapper] = {}
        try:
            named_modules = getattr(self._model, "named_modules", None)
            if callable(named_modules):
                for name, module in named_modules():
                    # isinstance check remains torch-agnostic because
                    # BaseCodingWrapper is torch-free.
                    if isinstance(module, BaseCodingWrapper):
                        registry[str(name)] = module
        except Exception:
            # Fail-open: keep tracking best-effort only.
            registry = {}
        self._wrapped_layers = registry
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def wrap(model: object, *, layers: str | Iterable[str] = "all_linear") -> WrappedModel:
    """Wrap a PyTorch module to produce coding-based surprise signals.

    Parameters
    - model: Base `torch.nn.Module` to wrap.
    - layers: Strategy or explicit layer names to wrap. The MVP will start with
      the convenience value "all_linear".
    """
    # Resolve the in-place instrumentation helper lazily so importing this
    # module stays cheap in environments without torch.
    try:
        from nervecode.layers.wrap import wrap as _instrument
    except Exception as exc:  # pragma: no cover - defensive
        raise RuntimeError(
            "Layer instrumentation is unavailable. Import nervecode.layers.wrap directly."
        ) from exc

    # Best-effort instrumentation: even when it cannot be applied, we still
    # hand back a container so attribute and call delegation keep working for
    # non-torch objects.
    with suppress(Exception):
        _instrument(model, layers=layers)

    return WrappedModel(model)
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
__all__ = [
|
|
407
|
+
"WrappedModel",
|
|
408
|
+
"wrap",
|
|
409
|
+
]
|
|
410
|
+
|
|
411
|
+
# Runtime-accessible package version
|
|
412
|
+
try: # pragma: no cover - trivial import
|
|
413
|
+
from ._version import __version__ as __version__
|
|
414
|
+
except Exception: # pragma: no cover - fallback if version file unavailable
|
|
415
|
+
__version__ = "0.0.0"
|
nervecode/_version.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Package version for Nervecode.
|
|
2
|
+
|
|
3
|
+
The version is the single source of truth for both runtime access via
|
|
4
|
+
``nervecode.__version__`` and build-time metadata via Hatch's version plugin.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
__all__ = ["__version__"]
|
|
8
|
+
|
|
9
|
+
# Follow PEP 440; bump as releases evolve.
|
|
10
|
+
__version__ = "0.1.0"
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Core data model, primitives, and engines.
|
|
2
|
+
|
|
3
|
+
This subpackage will host the core types, codebook modules, assignment
|
|
4
|
+
algorithms, and trace representations used across Nervecode.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from .temperature import CosineTemperature, FixedTemperature, TemperatureSchedule
|
|
10
|
+
from .trace import CodingTrace
|
|
11
|
+
from .types import SoftCode
|
|
12
|
+
|
|
13
|
+
__all__: list[str] = [
|
|
14
|
+
"CodingTrace",
|
|
15
|
+
"CosineTemperature",
|
|
16
|
+
"FixedTemperature",
|
|
17
|
+
"SoftCode",
|
|
18
|
+
"TemperatureSchedule",
|
|
19
|
+
]
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Soft assignment engine.
|
|
2
|
+
|
|
3
|
+
This module computes squared Euclidean distances from reduced activations to a
|
|
4
|
+
codebook, converts them into soft assignments over codes, and returns both a
|
|
5
|
+
``SoftCode`` object and the intermediate tensors needed to construct a
|
|
6
|
+
``CodingTrace`` without recomputing distances.
|
|
7
|
+
|
|
8
|
+
Shapes follow the core convention:
|
|
9
|
+
- Inputs have shape ``(..., D)`` where ``...`` are arbitrary leading dims.
|
|
10
|
+
- Probabilities have shape ``(..., K)`` where ``K`` is the number of codes.
|
|
11
|
+
- Per-position scalar values have shape matching the leading dims ``...``.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from torch import nn
|
|
18
|
+
|
|
19
|
+
from .codebook import Codebook
|
|
20
|
+
from .shapes import flatten_leading, unflatten_leading
|
|
21
|
+
from .types import SoftCode
|
|
22
|
+
|
|
23
|
+
__all__ = ["SoftAssignment"]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SoftAssignment(nn.Module):
    """Compute soft assignments and return trace-ready intermediates.

    Parameters
    - temperature: Positive scalar controlling assignment sharpness. Larger
      values produce softer assignments. This parameter is not learnable; use a
      schedule or an external tensor if you need dynamic temperature.
    - eps: Small positive constant used to clamp divisions and logarithms to
      keep values finite.
    - beta_length: Weight of the best-code description length term in the
      combined surprise signal.
    - beta_entropy: Weight of the assignment-entropy term in the combined
      surprise signal.
    - beta_distance: Weight of the (log1p-transformed) nearest-center distance
      term in the combined surprise signal.
    """

    @staticmethod
    def _require_positive_finite(value: float, name: str) -> float:
        """Validate that ``value`` is a positive finite float and return it.

        Raises ValueError with the original error message on failure. Uses
        ``math.isfinite`` instead of constructing a throwaway tensor.
        """
        import math  # local import keeps the module's import surface unchanged

        v = float(value)
        if not math.isfinite(v) or v <= 0.0:
            raise ValueError(f"{name} must be a positive finite float")
        return v

    def __init__(
        self,
        *,
        temperature: float = 1.0,
        eps: float = 1e-12,
        beta_length: float = 1.0,
        beta_entropy: float = 1.0,
        beta_distance: float = 0.0,
    ) -> None:
        super().__init__()
        t = self._require_positive_finite(temperature, "temperature")
        e = self._require_positive_finite(eps, "eps")
        # Explicit attribute annotation helps static type checkers understand
        # the buffer's type (avoids Tensor|Module unions on .to/.clamp paths).
        self._temperature: torch.Tensor
        self.register_buffer("_temperature", torch.tensor(t), persistent=False)
        self.eps: float = e
        # Weights for the combined surprise signal S = bL * L + bH * H + bD * D_norm
        self.beta_length: float = float(beta_length)
        self.beta_entropy: float = float(beta_entropy)
        self.beta_distance: float = float(beta_distance)

    @property
    def temperature(self) -> float:
        """Current temperature as a plain Python float."""
        return float(self._temperature.item())

    def forward(
        self,
        reduced: torch.Tensor,
        codebook: Codebook | torch.Tensor,
    ) -> tuple[SoftCode, dict[str, torch.Tensor]]:
        """Return soft code and intermediates for ``CodingTrace``.

        Inputs
        - reduced: Tensor of shape ``(..., D)`` representing reduced activations.
        - codebook: ``Codebook`` module or a centers tensor of shape
          ``(K, D)``.

        Returns
        - soft_code: ``SoftCode`` with probabilities ``(..., K)`` and basic
          scalar statistics (best length, entropy, best indices, combined surprise).
        - intermediates: dict with keys
          - ``nearest_center_distances``: ``(...,)``
          - ``chosen_center_indices``: integer ``(...,)``
          - ``commitment_distances``: ``(...,)``
          which are sufficient to construct a ``CodingTrace`` without
          recomputation.

        Raises
        - ValueError: when centers are not 2-D, ``reduced`` has no dimensions,
          or the feature dimensions disagree.
        """
        centers = codebook() if isinstance(codebook, Codebook) else codebook

        if centers.ndim != 2:
            raise ValueError("centers must have shape (K, D)")

        x = reduced
        if x.ndim < 1:
            raise ValueError("reduced must have at least one dimension (..., D)")

        x_flat, lead = flatten_leading(x)  # (N, D)
        d = int(centers.shape[1])
        if int(x_flat.shape[1]) != d:
            raise ValueError(f"reduced last dim {int(x_flat.shape[1])} must match centers D={d}")

        # Compute squared Euclidean distances using the expansion
        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x·c
        x_norm = (x_flat * x_flat).sum(dim=1, keepdim=True)  # (N, 1)
        c_norm = (centers * centers).sum(dim=1).unsqueeze(0)  # (1, K)
        # x @ centers^T yields (N, K)
        dot = x_flat @ centers.t()  # (N, K)
        distances = x_norm + c_norm - 2.0 * dot  # (N, K)
        # For numerical safety, clamp at zero from below (squared distances).
        distances = distances.clamp_min(0.0)

        # Convert distances into softmax probabilities over codes using
        # negative distances as logits, scaled by temperature.
        t = self._temperature.to(dtype=distances.dtype, device=distances.device)
        logits = -distances / torch.clamp(t, min=self.eps)
        probs_flat = torch.softmax(logits, dim=-1)

        # Best indices from min distance (equivalently max probability).
        best_idx_flat = torch.argmin(distances, dim=-1)  # (N,)
        # Nearest-center distances and commitment distances per position.
        nearest_flat = distances.gather(-1, best_idx_flat.unsqueeze(-1)).squeeze(-1)  # (N,)
        commitment_flat = nearest_flat  # identical for squared Euclidean

        # Information-theoretic scalars
        eps_t = torch.tensor(self.eps, dtype=probs_flat.dtype, device=probs_flat.device)
        best_prob_flat = probs_flat.gather(-1, best_idx_flat.unsqueeze(-1)).squeeze(-1)
        best_length_flat = -(best_prob_flat + eps_t).log()
        entropy_flat = -(probs_flat * (probs_flat + eps_t).log()).sum(dim=-1)
        # Distance-based component: use a stable, monotonically increasing
        # transform of the nearest-center squared Euclidean distance. ``log1p``
        # reduces dynamic range without erasing separation. Avoid per-forward
        # normalization to preserve absolute contrast between ID and OOD.
        d_comp_flat = torch.log1p(nearest_flat)
        # Weighted combination including optional distance contribution
        combined_surprise_flat = (
            self.beta_length * best_length_flat
            + self.beta_entropy * entropy_flat
            + self.beta_distance * d_comp_flat
        )

        # Restore original leading layout
        probs = unflatten_leading(probs_flat, lead)  # (..., K)
        best_idx = unflatten_leading(best_idx_flat, lead)  # (...,)
        nearest = unflatten_leading(nearest_flat, lead)  # (...,)
        commitment = unflatten_leading(commitment_flat, lead)  # (...,)
        best_length = unflatten_leading(best_length_flat, lead)  # (...,)
        entropy = unflatten_leading(entropy_flat, lead)  # (...,)
        combined_surprise = unflatten_leading(combined_surprise_flat, lead)  # (...,)

        soft = SoftCode(
            probs=probs,
            best_length=best_length,
            entropy=entropy,
            best_indices=best_idx,
            combined_surprise=combined_surprise,
        )

        intermediates: dict[str, torch.Tensor] = {
            "nearest_center_distances": nearest,
            "chosen_center_indices": best_idx,
            "commitment_distances": commitment,
        }
        return soft, intermediates
|