nervecode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,369 @@
1
+ """Surprise aggregation utilities.
2
+
3
+ MVP provides simple, explicit aggregation strategies across wrapped layers:
4
+ - mean, max, and fixed weighted combinations of per-layer surprise signals.
5
+
6
+ The primary entry point, ``mean_surprise(...)``, accepts a collection of
7
+ per-layer traces or soft-code objects and returns a per-sample aggregated
8
+ surprise tensor. Inputs are interpreted best-effort to preserve fail-open
9
+ behavior:
10
+
11
+ - ``CodingTrace``: use ``soft_code.combined_surprise`` if available, otherwise
12
+ fall back to ``soft_code.best_length``.
13
+ - ``SoftCode``: use ``combined_surprise`` if available, otherwise ``best_length``.
14
+ - ``torch.Tensor``: treated as an already-computed per-sample surprise signal.
15
+
16
+ Entries without an interpretable per-sample signal are skipped. When no valid
17
+ signals remain, ``None`` is returned. When multiple layers are provided,
18
+ signals are first reduced to a sample-level view so that wrappers with
19
+ different leading dimensions (e.g., ``(B, T)`` vs. ``(B,)``) can be combined.
20
+ Layers that cannot provide a sample-level view are skipped.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from collections.abc import Iterable, Mapping
26
+ from typing import Any, cast
27
+
28
+ try: # Keep import-time behavior tolerant in environments without torch
29
+ import torch
30
+ except Exception: # pragma: no cover - torch is a project dependency in tests
31
+ torch = cast(Any, None)
32
+
33
+ try:
34
+ from nervecode.core import CodingTrace, SoftCode # re-exported types
35
+ except Exception: # pragma: no cover - available during normal package use
36
+ CodingTrace = object # type: ignore[misc,assignment]
37
+ SoftCode = object # type: ignore[misc,assignment]
38
+
39
+ from .types import AggregatedSurprise
40
+
41
+ __all__ = ["AggregatedSurprise", "max_surprise", "mean_surprise", "weighted_surprise"]
42
+
43
+
44
def mean_surprise(
    traces: Mapping[str, Any] | Iterable[Any],
) -> AggregatedSurprise | None:
    """Return the mean-aggregated surprise across layers.

    Parameters
    - traces: Mapping from layer names to trace-like objects or an iterable of
      such entries. Each entry may be a ``CodingTrace``, ``SoftCode``, or a
      per-sample ``torch.Tensor``.

    Returns
    - ``AggregatedSurprise`` with the aggregated per-sample signal and basic
      metadata, or ``None`` when no valid per-sample signals are available.
    """

    if torch is None:  # pragma: no cover - defensive
        return None

    # Normalize input to an iterable of values.
    values: Iterable[Any] = traces.values() if isinstance(traces, Mapping) else traces

    signals: list[torch.Tensor] = []
    ref_shape: tuple[int, ...] | None = None

    for obj in values:
        t = _as_sample_surprise_tensor(obj)
        if t is None:
            continue
        shape = tuple(int(s) for s in t.shape)
        if ref_shape is None:
            # The first valid tensor fixes the reference per-sample shape.
            ref_shape = shape
        elif shape != ref_shape:
            # Only aggregate signals sharing the reference per-sample shape;
            # mismatched layers are skipped (fail-open).
            continue
        signals.append(t)

    if not signals:
        return None

    # Align all tensors to the device/dtype of the first entry so they can be
    # stacked. ``Tensor.to`` is a no-op when the tensor already matches, which
    # preserves the original conditional-conversion behavior.
    first = signals[0]
    aligned = [s.to(device=first.device, dtype=first.dtype) for s in signals]

    stacked = torch.stack(aligned, dim=0)
    agg = stacked.mean(dim=0)
    return AggregatedSurprise(
        surprise=agg,
        method="mean",
        num_layers=len(aligned),
        details=None,
    )
107
+
108
+
109
def max_surprise(
    traces: Mapping[str, Any] | Iterable[Any],
) -> AggregatedSurprise | None:
    """Return the max-aggregated surprise across layers.

    Mirrors ``mean_surprise`` for input handling and sample-level reduction but
    combines participating layer signals using a max across layers instead of a
    mean. Returns ``None`` when no valid per-sample signals are available.
    """

    if torch is None:  # pragma: no cover - defensive
        return None

    # Normalize input to an iterable of values.
    values: Iterable[Any] = traces.values() if isinstance(traces, Mapping) else traces

    signals: list[torch.Tensor] = []
    ref_shape: tuple[int, ...] | None = None

    for obj in values:
        t = _as_sample_surprise_tensor(obj)
        if t is None:
            continue
        shape = tuple(int(s) for s in t.shape)
        if ref_shape is None:
            # The first valid tensor fixes the reference per-sample shape.
            ref_shape = shape
        elif shape != ref_shape:
            # Skip layers whose per-sample shape does not match (fail-open).
            continue
        signals.append(t)

    if not signals:
        return None

    # Align all tensors to the device/dtype of the first entry; ``Tensor.to``
    # is a no-op when the tensor already matches.
    first = signals[0]
    aligned = [s.to(device=first.device, dtype=first.dtype) for s in signals]

    stacked = torch.stack(aligned, dim=0)
    agg = torch.max(stacked, dim=0).values
    return AggregatedSurprise(
        surprise=agg,
        method="max",
        num_layers=len(aligned),
        details=None,
    )
163
+
164
+
165
def weighted_surprise(
    traces: Mapping[str, Any] | Iterable[Any],
    *,
    weights: Mapping[str, float] | Iterable[float] | None = None,
    normalize: bool = True,
) -> AggregatedSurprise | None:
    """Return a fixed weighted aggregation across layers.

    The function mirrors ``mean_surprise`` for input handling and sample-level
    reduction, but combines participating layer signals using explicit fixed
    weights instead of a uniform mean. This function does not learn weights.

    Parameters
    - traces: Mapping from layer names to trace-like objects or an iterable of
      such entries. Each entry may be a ``CodingTrace``, ``SoftCode``, or a
      per-sample ``torch.Tensor``.
    - weights: Either a mapping from layer name to weight (preferred when
      ``traces`` is a mapping) or an iterable of weights aligned with the order
      of ``traces``. When omitted, all included layers use weight ``1.0``.
    - normalize: When ``True`` (default), divides by the sum of the included
      weights so that outputs are comparable to a mean when weights are equal.

    Returns
    - ``AggregatedSurprise`` with aggregation metadata including the effective
      weights used, or ``None`` when no valid per-sample signals are available
      or the total weight after filtering is zero.
    """

    if torch is None:  # pragma: no cover - defensive
        return None

    # Normalize to (name, value) pairs: preserves a stable order for iterable
    # inputs and enables name-based weight lookup for mapping inputs.
    items: list[tuple[str | None, Any]]
    if isinstance(traces, Mapping):
        items = [(str(k), v) for k, v in traces.items()]
    else:
        items = [(None, v) for v in traces]

    # Resolve the weights argument into either a name-keyed map or an
    # index-aligned sequence; anything uninterpretable falls back to uniform
    # 1.0 weights (fail-open).
    weights_map: Mapping[str, float] | None = None
    weights_seq: list[float] | None = None
    if isinstance(weights, Mapping):
        weights_map = {str(k): float(v) for k, v in weights.items()}
    elif weights is not None:
        try:
            weights_seq = [float(v) for v in list(weights)]
        except Exception:
            weights_seq = None

    def _weight_for(idx: int, name: str | None) -> float:
        # Weight for an item given its position and (optional) layer name.
        if weights_map is not None and name is not None:
            return float(weights_map.get(name, 1.0))
        if weights_seq is not None and 0 <= idx < len(weights_seq):
            return float(weights_seq[idx])
        return 1.0

    signals: list[torch.Tensor] = []
    resolved_weights: list[float] = []
    names: list[str | None] = []
    ref_shape: tuple[int, ...] | None = None

    for idx, (name, obj) in enumerate(items):
        t = _as_sample_surprise_tensor(obj)
        if t is None:
            continue
        shape = tuple(int(s) for s in t.shape)
        if ref_shape is None:
            ref_shape = shape
        if shape != ref_shape:
            # Skip mismatched sample shapes to preserve fail-open behavior.
            continue
        w_i = _weight_for(idx, name)
        # Skip strictly non-positive weights to avoid divide-by-zero surprises
        # and preserve user intent (e.g., selectively disabling layers).
        if not (w_i > 0.0):
            continue
        names.append(name)
        signals.append(t)
        resolved_weights.append(float(w_i))

    if not signals:
        return None

    # Align device/dtype across participating tensors (no-op when matching).
    first = signals[0]
    aligned = [s.to(device=first.device, dtype=first.dtype) for s in signals]

    # BUGFIX: build the weight vector in a *floating* dtype. Previously the
    # weights inherited the signal dtype, so integer-valued signals (e.g.
    # best_length counts) truncated fractional weights — 0.5 became 0 —
    # silently corrupting the aggregation or returning None. Broadcasting a
    # float weight vector against integer signals promotes the result safely.
    w_dtype = first.dtype if first.dtype.is_floating_point else torch.float32
    w_vec = torch.tensor(resolved_weights, dtype=w_dtype, device=first.device)
    total = float(w_vec.sum().item())
    if total <= 0.0:
        return None

    # Weighted sum along the layer dimension via a broadcasted multiply:
    # (layers, ...) * (layers, 1, ..., 1) -> sum over layers -> (...).
    stacked = torch.stack(aligned, dim=0)
    weighted = stacked * w_vec.view(-1, *([1] * (stacked.ndim - 1)))
    agg = weighted.sum(dim=0)
    if normalize:
        agg = agg / total

    # Simple details: per-layer names and the weights actually used.
    details: dict[str, Any] = {
        "weights": list(zip(names, resolved_weights)),
        "normalized": bool(normalize),
        "sum_weights": total,
    }

    return AggregatedSurprise(
        surprise=agg,
        method="weighted",
        num_layers=len(aligned),
        details=details,
    )
298
+
299
+
300
def _as_sample_surprise_tensor(obj: Any) -> torch.Tensor | None:
    """Return a per-sample surprise tensor for a supported input object.

    Supported inputs are reduced to a sample-level vector by averaging all
    dimensions *after* the first (batch) dimension when necessary. This allows
    combining wrappers that produce position- or token-level signals with
    batch-only signals.

    Supported inputs:
    - ``torch.Tensor``: treated as an already-computed surprise signal;
      reduced to ``(B,)`` when rank > 1.
    - ``SoftCode``: prefer ``combined_surprise``, else ``best_length``.
    - ``CodingTrace``: use ``sample_reduced_surprise()`` when available;
      otherwise defer to its ``soft_code``.
    - ``Mapping``: a direct ``'surprise'`` tensor value, if present.

    Anything else yields ``None`` (fail-open).
    """

    if torch is None:  # pragma: no cover - defensive
        return None

    # Tensor path: already a per-sample signal, just reduce rank.
    if isinstance(obj, torch.Tensor):
        return _reduce_to_sample(obj)

    # SoftCode path. BUGFIX: when the ``nervecode.core`` import failed,
    # ``SoftCode`` degrades to plain ``object`` and would match *every*
    # non-tensor input, making the CodingTrace and Mapping paths below
    # unreachable; guard against that degraded fallback explicitly.
    if SoftCode is not object and isinstance(obj, SoftCode):
        cand = getattr(obj, "combined_surprise", None)
        if cand is None:
            cand = getattr(obj, "best_length", None)
        if isinstance(cand, torch.Tensor):
            return _reduce_to_sample(cand)
        return None

    # CodingTrace path (same degraded-fallback guard as above).
    if CodingTrace is not object and isinstance(obj, CodingTrace):
        # Prefer a sample-level view provided by the trace itself.
        view = getattr(obj, "sample_reduced_surprise", None)
        if callable(view):
            try:
                out = view()
                if isinstance(out, torch.Tensor):
                    return out
            except Exception:
                # Fail-open: fall through to the trace's soft code.
                pass
        sc = getattr(obj, "soft_code", None)
        if SoftCode is not object and isinstance(sc, SoftCode):
            return _as_sample_surprise_tensor(sc)
        return None

    # Mapping path: allow passing a dict-like with a direct 'surprise' tensor.
    # (``Mapping`` guarantees ``get``, so no hasattr check is needed.)
    if isinstance(obj, Mapping):
        cand = obj.get("surprise")
        if isinstance(cand, torch.Tensor):
            return _reduce_to_sample(cand)

    return None
356
+
357
+
358
+ def _reduce_to_sample(t: torch.Tensor) -> torch.Tensor:
359
+ """Reduce a tensor to a per-sample vector by averaging extra leading dims.
360
+
361
+ - ``(B,)`` -> returned unchanged
362
+ - ``(B, T, ...)`` -> mean over dims 1..N-1
363
+ - ``()`` (scalar) -> returned as-is (caller may choose to skip)
364
+ """
365
+
366
+ if getattr(t, "ndim", 0) <= 1:
367
+ return t
368
+ dims = tuple(range(1, int(t.ndim)))
369
+ return t.mean(dim=dims)