nervecode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nervecode/__init__.py +415 -0
- nervecode/_version.py +10 -0
- nervecode/core/__init__.py +19 -0
- nervecode/core/assignment.py +165 -0
- nervecode/core/codebook.py +182 -0
- nervecode/core/shapes.py +107 -0
- nervecode/core/temperature.py +227 -0
- nervecode/core/trace.py +166 -0
- nervecode/core/types.py +116 -0
- nervecode/integration/__init__.py +9 -0
- nervecode/layers/__init__.py +15 -0
- nervecode/layers/base.py +333 -0
- nervecode/layers/conv.py +174 -0
- nervecode/layers/linear.py +176 -0
- nervecode/layers/reducers.py +80 -0
- nervecode/layers/wrap.py +223 -0
- nervecode/scoring/__init__.py +20 -0
- nervecode/scoring/aggregator.py +369 -0
- nervecode/scoring/calibrator.py +396 -0
- nervecode/scoring/types.py +33 -0
- nervecode/training/__init__.py +25 -0
- nervecode/training/diagnostics.py +194 -0
- nervecode/training/loss.py +188 -0
- nervecode/training/updaters.py +168 -0
- nervecode/utils/__init__.py +14 -0
- nervecode/utils/overhead.py +177 -0
- nervecode/utils/seed.py +161 -0
- nervecode-0.1.0.dist-info/METADATA +83 -0
- nervecode-0.1.0.dist-info/RECORD +31 -0
- nervecode-0.1.0.dist-info/WHEEL +4 -0
- nervecode-0.1.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
"""Surprise aggregation utilities.
|
|
2
|
+
|
|
3
|
+
MVP provides simple, explicit aggregation strategies across wrapped layers:
|
|
4
|
+
- mean, max, and fixed weighted combinations of per-layer surprise signals.
|
|
5
|
+
|
|
6
|
+
The primary entry point, ``mean_surprise(...)``, accepts a collection of
|
|
7
|
+
per-layer traces or soft-code objects and returns a per-sample aggregated
|
|
8
|
+
surprise tensor. Inputs are interpreted best-effort to preserve fail-open
|
|
9
|
+
behavior:
|
|
10
|
+
|
|
11
|
+
- ``CodingTrace``: use ``soft_code.combined_surprise`` if available, otherwise
|
|
12
|
+
fall back to ``soft_code.best_length``.
|
|
13
|
+
- ``SoftCode``: use ``combined_surprise`` if available, otherwise ``best_length``.
|
|
14
|
+
- ``torch.Tensor``: treated as an already-computed per-sample surprise signal.
|
|
15
|
+
|
|
16
|
+
Entries without an interpretable per-sample signal are skipped. When no valid
|
|
17
|
+
signals remain, ``None`` is returned. When multiple layers are provided,
|
|
18
|
+
signals are first reduced to a sample-level view so that wrappers with
|
|
19
|
+
different leading dimensions (e.g., ``(B, T)`` vs. ``(B,)``) can be combined.
|
|
20
|
+
Layers that cannot provide a sample-level view are skipped.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
from collections.abc import Iterable, Mapping
|
|
26
|
+
from typing import Any, cast
|
|
27
|
+
|
|
28
|
+
try: # Keep import-time behavior tolerant in environments without torch
|
|
29
|
+
import torch
|
|
30
|
+
except Exception: # pragma: no cover - torch is a project dependency in tests
|
|
31
|
+
torch = cast(Any, None)
|
|
32
|
+
|
|
33
|
+
try:
|
|
34
|
+
from nervecode.core import CodingTrace, SoftCode # re-exported types
|
|
35
|
+
except Exception: # pragma: no cover - available during normal package use
|
|
36
|
+
CodingTrace = object # type: ignore[misc,assignment]
|
|
37
|
+
SoftCode = object # type: ignore[misc,assignment]
|
|
38
|
+
|
|
39
|
+
from .types import AggregatedSurprise
|
|
40
|
+
|
|
41
|
+
__all__ = ["AggregatedSurprise", "max_surprise", "mean_surprise", "weighted_surprise"]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def mean_surprise(
    traces: Mapping[str, Any] | Iterable[Any],
) -> AggregatedSurprise | None:
    """Return the mean-aggregated surprise across layers.

    Parameters
    - traces: Mapping from layer names to trace-like objects or an iterable of
      such entries. Each entry may be a ``CodingTrace``, ``SoftCode``, or a
      per-sample ``torch.Tensor``.

    Returns
    - ``AggregatedSurprise`` with the aggregated per-sample signal and basic
      metadata, or ``None`` when no valid per-sample signals are available.
    """

    if torch is None:  # pragma: no cover - defensive
        return None

    # Normalize input to an iterable of values.
    values: Iterable[Any] = traces.values() if isinstance(traces, Mapping) else traces

    signals: list[torch.Tensor] = []
    ref_shape: tuple[int, ...] | None = None

    for obj in values:
        t = _as_sample_surprise_tensor(obj)
        if t is None:
            continue
        shape = tuple(int(s) for s in t.shape)
        if ref_shape is None:
            # Use the shape of the first valid tensor as a reference.
            ref_shape = shape
        elif shape != ref_shape:
            # Only aggregate signals with the same per-sample shape
            # (fail-open: mismatched layers are skipped, not an error).
            continue
        signals.append(t)

    if not signals:
        return None

    # Bring all tensors to the device/dtype of the first entry for stacking.
    first = signals[0]
    device = getattr(first, "device", None)
    dtype = first.dtype if hasattr(first, "dtype") else None
    # ``mean`` is undefined for integer/bool tensors in torch, and the inputs
    # may be integer-valued (e.g. ``best_length`` code lengths). Promote to
    # float32 in that case so the aggregation still succeeds.
    if dtype is not None and not (dtype.is_floating_point or dtype.is_complex):
        dtype = torch.float32
    aligned: list[torch.Tensor] = []
    for s in signals:
        s2 = s
        if dtype is not None and getattr(s2, "dtype", None) is not dtype:
            s2 = s2.to(dtype=dtype)
        if device is not None and getattr(s2, "device", None) != device:
            s2 = s2.to(device=device)
        aligned.append(s2)

    stacked = torch.stack(aligned, dim=0)
    agg = stacked.mean(dim=0)
    return AggregatedSurprise(
        surprise=agg,
        method="mean",
        num_layers=len(aligned),
        details=None,
    )
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def max_surprise(
    traces: Mapping[str, Any] | Iterable[Any],
) -> AggregatedSurprise | None:
    """Return the max-aggregated surprise across layers.

    Input handling and sample-level reduction mirror ``mean_surprise``; the
    participating layer signals are combined with an element-wise max across
    layers instead of a mean. Returns ``None`` when no valid per-sample
    signals are available.
    """

    if torch is None:  # pragma: no cover - defensive
        return None

    # Accept either a name-keyed mapping or a plain iterable of entries.
    if isinstance(traces, Mapping):
        entries: Iterable[Any] = traces.values()
    else:
        entries = traces

    collected: list[torch.Tensor] = []
    expected_shape: tuple[int, ...] | None = None

    for entry in entries:
        tensor = _as_sample_surprise_tensor(entry)
        if tensor is None:
            continue
        current = tuple(int(d) for d in tensor.shape)
        if expected_shape is None:
            # First valid signal fixes the reference shape.
            expected_shape = current
        elif current != expected_shape:
            # Mismatched sample shapes are skipped (fail-open).
            continue
        collected.append(tensor)

    if not collected:
        return None

    # Move every participating tensor onto the dtype/device of the first one
    # so that stacking is well-defined.
    anchor = collected[0]
    target_device = getattr(anchor, "device", None)
    target_dtype = anchor.dtype if hasattr(anchor, "dtype") else None
    moved: list[torch.Tensor] = []
    for tensor in collected:
        if target_dtype is not None and getattr(tensor, "dtype", None) is not target_dtype:
            tensor = tensor.to(dtype=target_dtype)
        if target_device is not None and getattr(tensor, "device", None) != target_device:
            tensor = tensor.to(device=target_device)
        moved.append(tensor)

    layerwise = torch.stack(moved, dim=0)
    return AggregatedSurprise(
        surprise=layerwise.max(dim=0).values,
        method="max",
        num_layers=len(moved),
        details=None,
    )
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def weighted_surprise(
    traces: Mapping[str, Any] | Iterable[Any],
    *,
    weights: Mapping[str, float] | Iterable[float] | None = None,
    normalize: bool = True,
) -> AggregatedSurprise | None:
    """Return a fixed weighted aggregation across layers.

    The function mirrors ``mean_surprise`` for input handling and sample-level
    reduction, but combines participating layer signals using explicit fixed
    weights instead of a uniform mean. This function does not learn weights.

    Parameters
    - traces: Mapping from layer names to trace-like objects or an iterable of
      such entries. Each entry may be a ``CodingTrace``, ``SoftCode``, or a
      per-sample ``torch.Tensor``.
    - weights: Either a mapping from layer name to weight (preferred when
      ``traces`` is a mapping) or an iterable of weights aligned with the order
      of ``traces``. When omitted, all included layers use weight ``1.0``.
    - normalize: When ``True`` (default), divides by the sum of the included
      weights so that outputs are comparable to a mean when weights are equal.

    Returns
    - ``AggregatedSurprise`` with aggregation metadata including the effective
      weights used, or ``None`` when no valid per-sample signals are available
      or the total weight after filtering is zero.
    """

    if torch is None:  # pragma: no cover - defensive
        return None

    # Normalize to lists of (name, value) to preserve a stable order for
    # iterable inputs and enable name-based weight lookup when available.
    items: list[tuple[str | None, Any]]
    if isinstance(traces, Mapping):
        items = [(str(k), v) for k, v in traces.items()]
    else:
        items = [(None, v) for v in traces]

    # Prepare a parallel list of sample-level tensors alongside resolved weights.
    signals: list[torch.Tensor] = []
    resolved_weights: list[float] = []
    names: list[str | None] = []
    ref_shape: tuple[int, ...] | None = None

    # Helper to get a weight for an item given its index and (optional) name.
    weights_map: Mapping[str, float] | None
    weights_seq: list[float] | None
    if weights is None:
        weights_map = None
        weights_seq = None
    elif isinstance(weights, Mapping):
        # Use a plain mapping lookup for name-keyed inputs.
        weights_map = {str(k): float(v) for k, v in weights.items()}
        weights_seq = None
    else:
        # Materialize iterable weights to support index-aligned access.
        try:
            weights_seq = [float(v) for v in list(weights)]
        except Exception:
            weights_seq = None
        weights_map = None

    def _weight_for(idx: int, name: str | None) -> float:
        if weights_map is not None and name is not None:
            return float(weights_map.get(name, 1.0))
        if weights_seq is not None and 0 <= idx < len(weights_seq):
            return float(weights_seq[idx])
        return 1.0

    for idx, (name, obj) in enumerate(items):
        t = _as_sample_surprise_tensor(obj)
        if t is None:
            continue
        shape = tuple(int(s) for s in t.shape)
        if ref_shape is None:
            ref_shape = shape
        if shape != ref_shape:
            # Skip mismatched sample shapes to preserve fail-open behavior.
            continue
        w_i = _weight_for(idx, name)
        # Skip strictly non-positive weights to avoid divide-by-zero surprises
        # and preserve user intent (e.g., selectively disabling layers).
        if not (w_i > 0.0):
            continue
        names.append(name)
        signals.append(t)
        resolved_weights.append(float(w_i))

    if not signals:
        return None

    # Align device/dtype across participating tensors.
    first = signals[0]
    device = getattr(first, "device", None)
    dtype = first.dtype if hasattr(first, "dtype") else None
    aligned: list[torch.Tensor] = []
    for s in signals:
        s2 = s
        if dtype is not None and getattr(s2, "dtype", None) is not dtype:
            s2 = s2.to(dtype=dtype)
        if device is not None and getattr(s2, "device", None) != device:
            s2 = s2.to(device=device)
        aligned.append(s2)

    # Convert weights to a tensor on the target device for safe math.
    # BUG FIX: the weight vector must be floating point. Reusing an integer
    # signal dtype here silently truncated fractional weights (e.g. 0.5 -> 0),
    # corrupting the weighted sum. Use the signal dtype only when it is
    # floating/complex; otherwise fall back to float32.
    if dtype is not None and (dtype.is_floating_point or dtype.is_complex):
        w_dtype = dtype
    else:
        w_dtype = torch.float32
    w_vec = torch.tensor(resolved_weights, dtype=w_dtype, device=device)
    total = float(w_vec.sum().item())
    if total <= 0.0:
        return None

    stacked = torch.stack(aligned, dim=0)
    # Weighted sum along layer dimension.
    # Shape: (layers, ...) -> (...) via broadcasted mul then sum.
    weighted = stacked * w_vec.view(-1, *([1] * (stacked.ndim - 1)))
    agg = weighted.sum(dim=0)
    if normalize:
        agg = agg / total

    # Provide simple details: per-layer names and weights actually used.
    used: list[tuple[str | None, float]] = list(zip(names, resolved_weights))
    details: dict[str, Any] = {
        "weights": used,
        "normalized": bool(normalize),
        "sum_weights": total,
    }

    return AggregatedSurprise(
        surprise=agg,
        method="weighted",
        num_layers=len(aligned),
        details=details,
    )
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _as_sample_surprise_tensor(obj: Any) -> torch.Tensor | None:
|
|
301
|
+
"""Return a per-sample surprise tensor for a supported input object.
|
|
302
|
+
|
|
303
|
+
All supported inputs are reduced to a sample-level vector by collapsing
|
|
304
|
+
leading dimensions after the first (using mean) when necessary. This allows
|
|
305
|
+
combining wrappers that produce position- or token-level signals with
|
|
306
|
+
batch-only signals.
|
|
307
|
+
|
|
308
|
+
Supported inputs:
|
|
309
|
+
- CodingTrace: use ``sample_reduced_surprise()`` when available; otherwise
|
|
310
|
+
prefer ``soft_code.combined_surprise``, else ``soft_code.best_length``.
|
|
311
|
+
- SoftCode: prefer combined_surprise, else best_length; reduce to ``(B,)``
|
|
312
|
+
when rank > 1.
|
|
313
|
+
- torch.Tensor: reduce to ``(B,)`` when rank > 1.
|
|
314
|
+
"""
|
|
315
|
+
|
|
316
|
+
if torch is None: # pragma: no cover - defensive
|
|
317
|
+
return None
|
|
318
|
+
|
|
319
|
+
# Torch tensor path
|
|
320
|
+
if hasattr(torch, "Tensor") and isinstance(obj, torch.Tensor):
|
|
321
|
+
return _reduce_to_sample(obj)
|
|
322
|
+
|
|
323
|
+
# SoftCode path
|
|
324
|
+
if isinstance(obj, SoftCode):
|
|
325
|
+
cand = getattr(obj, "combined_surprise", None)
|
|
326
|
+
if cand is None:
|
|
327
|
+
cand = getattr(obj, "best_length", None)
|
|
328
|
+
if isinstance(cand, torch.Tensor):
|
|
329
|
+
return _reduce_to_sample(cand)
|
|
330
|
+
return None
|
|
331
|
+
|
|
332
|
+
# CodingTrace path
|
|
333
|
+
if isinstance(obj, CodingTrace):
|
|
334
|
+
# Prefer a sample-level view provided by the trace itself.
|
|
335
|
+
view = getattr(obj, "sample_reduced_surprise", None)
|
|
336
|
+
if callable(view):
|
|
337
|
+
try:
|
|
338
|
+
out = view()
|
|
339
|
+
if isinstance(out, torch.Tensor):
|
|
340
|
+
return out
|
|
341
|
+
except Exception:
|
|
342
|
+
pass
|
|
343
|
+
sc = getattr(obj, "soft_code", None)
|
|
344
|
+
if isinstance(sc, SoftCode):
|
|
345
|
+
return _as_sample_surprise_tensor(sc)
|
|
346
|
+
return None
|
|
347
|
+
|
|
348
|
+
# Mapping path: allow passing a dict-like with a direct 'surprise' tensor
|
|
349
|
+
if isinstance(obj, Mapping):
|
|
350
|
+
# Common keys we might accept in future; keep minimal for MVP.
|
|
351
|
+
cand = obj.get("surprise") if hasattr(obj, "get") else None
|
|
352
|
+
if isinstance(cand, torch.Tensor):
|
|
353
|
+
return _reduce_to_sample(cand)
|
|
354
|
+
|
|
355
|
+
return None
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def _reduce_to_sample(t: torch.Tensor) -> torch.Tensor:
|
|
359
|
+
"""Reduce a tensor to a per-sample vector by averaging extra leading dims.
|
|
360
|
+
|
|
361
|
+
- ``(B,)`` -> returned unchanged
|
|
362
|
+
- ``(B, T, ...)`` -> mean over dims 1..N-1
|
|
363
|
+
- ``()`` (scalar) -> returned as-is (caller may choose to skip)
|
|
364
|
+
"""
|
|
365
|
+
|
|
366
|
+
if getattr(t, "ndim", 0) <= 1:
|
|
367
|
+
return t
|
|
368
|
+
dims = tuple(range(1, int(t.ndim)))
|
|
369
|
+
return t.mean(dim=dims)
|