nervecode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nervecode/__init__.py +415 -0
- nervecode/_version.py +10 -0
- nervecode/core/__init__.py +19 -0
- nervecode/core/assignment.py +165 -0
- nervecode/core/codebook.py +182 -0
- nervecode/core/shapes.py +107 -0
- nervecode/core/temperature.py +227 -0
- nervecode/core/trace.py +166 -0
- nervecode/core/types.py +116 -0
- nervecode/integration/__init__.py +9 -0
- nervecode/layers/__init__.py +15 -0
- nervecode/layers/base.py +333 -0
- nervecode/layers/conv.py +174 -0
- nervecode/layers/linear.py +176 -0
- nervecode/layers/reducers.py +80 -0
- nervecode/layers/wrap.py +223 -0
- nervecode/scoring/__init__.py +20 -0
- nervecode/scoring/aggregator.py +369 -0
- nervecode/scoring/calibrator.py +396 -0
- nervecode/scoring/types.py +33 -0
- nervecode/training/__init__.py +25 -0
- nervecode/training/diagnostics.py +194 -0
- nervecode/training/loss.py +188 -0
- nervecode/training/updaters.py +168 -0
- nervecode/utils/__init__.py +14 -0
- nervecode/utils/overhead.py +177 -0
- nervecode/utils/seed.py +161 -0
- nervecode-0.1.0.dist-info/METADATA +83 -0
- nervecode-0.1.0.dist-info/RECORD +31 -0
- nervecode-0.1.0.dist-info/WHEEL +4 -0
- nervecode-0.1.0.dist-info/licenses/LICENSE +22 -0
nervecode/core/types.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Core type definitions.
|
|
2
|
+
|
|
3
|
+
This module defines lightweight dataclasses used across the core coding
|
|
4
|
+
pipeline. Shapes follow the convention that probability-like tensors carry
|
|
5
|
+
arbitrary leading dimensions with a final code dimension of size ``K``.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
__all__ = ["SoftCode"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True)
class SoftCode:
    """Soft assignment over a codebook.

    The ``probs`` tensor encodes probabilities over ``K`` codes for each
    position in the (possibly batched) input. It supports arbitrary leading
    dimensions followed by a final code dimension of shape ``(..., K)``.

    Fields
    - `probs` (Tensor): probabilities over codes with shape ``(..., K)``.
    - `best_length` (Tensor, optional): best-code codelength per position
      with shape matching the leading dimensions ``...``.
    - `entropy` (Tensor, optional): assignment entropy per position with
      shape matching the leading dimensions ``...``.
    - `best_indices` (Tensor, optional): argmax code indices per position with
      shape matching the leading dimensions ``...``.
    - `combined_surprise` (Tensor, optional): a scalar-like per-position
      surprise that may combine multiple terms; shape matches ``...``.

    All optional scalar-like fields must match the leading dimensions of
    ``probs`` exactly (i.e., ``tensor.shape == probs.shape[:-1]``). When the
    leading shape is empty, scalar-like fields must be 0-dim tensors.
    """

    # Fields are typed ``Any`` to avoid a hard torch dependency at import time.
    probs: Any
    best_length: Any | None = None
    entropy: Any | None = None
    best_indices: Any | None = None
    combined_surprise: Any | None = None

    def __post_init__(self) -> None:
        """Validate field shapes and dtypes.

        Raises:
            TypeError: if a field is not a tensor, a scalar-like field is
                bool-typed, or ``best_indices`` is not integer-typed.
            ValueError: if ``probs`` has no code dimension or an optional
                field's shape does not match ``probs.shape[:-1]``.

        When torch is unavailable, only basic ``probs`` checks run.
        """
        from typing import cast as _cast

        try:
            import torch  # runtime import; typing deferred
        except Exception:  # pragma: no cover - torch absent in some test envs
            torch = _cast(Any, None)

        # Validate probs tensor
        if torch is not None and not isinstance(
            self.probs, torch.Tensor
        ):  # pragma: no cover - defensive
            raise TypeError("probs must be a torch.Tensor")
        if getattr(self.probs, "ndim", None) is None or self.probs.ndim < 1:
            raise ValueError("probs must have at least one dimension (..., K)")
        if int(self.probs.shape[-1]) <= 0:
            raise ValueError("final code dimension K must be >= 1")

        # Optional fields shape validation (only when torch is available)
        if torch is None:  # pragma: no cover - validation relies on torch
            return

        lead_shape = tuple(int(s) for s in self.probs.shape[:-1])

        def _check_shape(name: str, value: Any) -> None:
            # Optional fields must carry exactly the leading shape of probs.
            if not isinstance(value, torch.Tensor):
                raise TypeError(f"{name} must be a torch.Tensor if provided")
            if tuple(int(s) for s in value.shape) != lead_shape:
                raise ValueError(
                    f"{name} must have shape {lead_shape} to match probs.leading_shape"
                )

        # Scalar-like fields share identical validation: matching leading
        # shape and any numeric (non-bool) dtype. Check order preserved.
        for name in ("best_length", "entropy", "combined_surprise"):
            value = getattr(self, name)
            if value is None:
                continue
            _check_shape(name, value)
            # float-like recommended; do not hard-enforce dtype beyond non-bool
            if value.dtype == torch.bool:  # pragma: no cover - defensive
                raise TypeError(f"{name} must be numeric, not bool")

        if self.best_indices is not None:
            _check_shape("best_indices", self.best_indices)
            # Require integer dtype for indices
            if torch.is_floating_point(self.best_indices) or self.best_indices.dtype == torch.bool:
                raise TypeError("best_indices must use an integer dtype")

    @property
    def leading_shape(self) -> tuple[int, ...]:
        """Return leading shape before the final code dimension.

        For ``probs`` with shape ``(..., K)``, this returns ``...`` as a tuple.
        """

        # mypy struggles with torch.Size as a tuple of ints; cast explicitly.
        return tuple(int(s) for s in self.probs.shape[:-1])

    @property
    def code_dim(self) -> int:
        """Size of the final code dimension ``K``."""

        return int(self.probs.shape[-1])
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Layer wrappers and instrumentation shims.
|
|
2
|
+
|
|
3
|
+
This subpackage will contain observe-only wrappers around selected PyTorch
|
|
4
|
+
layers and utilities to apply them to user models.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from .base import BaseCodingWrapper, ReducerFn, identity_reduction
|
|
10
|
+
|
|
11
|
+
__all__: list[str] = [
|
|
12
|
+
"BaseCodingWrapper",
|
|
13
|
+
"ReducerFn",
|
|
14
|
+
"identity_reduction",
|
|
15
|
+
]
|
nervecode/layers/base.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
"""Common wrapper base: bypass, trace cache, reduction, diagnostics hooks.
|
|
2
|
+
|
|
3
|
+
This module defines a small, torch-agnostic base class and companion typing to
|
|
4
|
+
standardize behavior shared by coding wrappers:
|
|
5
|
+
|
|
6
|
+
- Bypass and fail-open controls (enable_coding(), disable_coding(), and the
|
|
7
|
+
bypass() context manager) with correct nesting semantics.
|
|
8
|
+
- Latest per-layer trace caching (last_trace) with explicit update/clear.
|
|
9
|
+
- A light reduction callable type compatible with identity or learned reducers.
|
|
10
|
+
- Diagnostics hook registration that runs safely on trace updates.
|
|
11
|
+
|
|
12
|
+
Wrappers built on top of this class remain observe-only: the base makes no
|
|
13
|
+
assumptions about forward semantics beyond providing helpers that wrappers can
|
|
14
|
+
use after computing the original layer output.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from collections.abc import Callable, Iterator
|
|
20
|
+
from contextlib import contextmanager, suppress
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
try: # Optional torch import for typing only; keep runtime torch-agnostic.
|
|
24
|
+
from nervecode.core import CodingTrace # re-exported type
|
|
25
|
+
except Exception: # pragma: no cover - available in normal package use
|
|
26
|
+
CodingTrace = Any # type: ignore[misc,assignment]
|
|
27
|
+
|
|
28
|
+
__all__ = ["BaseCodingWrapper", "ReducerFn", "identity_reduction"]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Public type alias used by all coding wrappers; any callable matching this
# signature can be plugged in via BaseCodingWrapper(reducer=...).
ReducerFn = Callable[[Any], tuple[Any, dict[str, Any]]]
"""Callable that reduces an activation and returns metadata.

Accepts a single activation-like object (typically a tensor with shape
``(..., D_out)``) and returns ``(reduced, reduction_meta)`` where ``reduced``
has shape ``(..., D)`` and ``reduction_meta`` is a JSON-serializable dict that
identifies the applied reduction (e.g., {"method": "identity"}).
"""
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def identity_reduction(x: Any) -> tuple[Any, dict[str, Any]]:
    """No-op reducer: hand back *x* untouched together with a marker dict."""

    meta: dict[str, Any] = {"method": "identity"}
    return x, meta
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class BaseCodingWrapper:
    """Torch-agnostic base for coding wrappers.

    Responsibilities
    - Maintain an enable/disable switch and a nestable bypass context to
      implement fail-open behavior consistently across wrappers.
    - Cache the most recent per-layer trace and run registered diagnostics
      hooks safely when it updates.
    - Provide a reducer slot with an identity default so wrappers can apply
      a dimension reduction before coding without duplicating boilerplate.
    """

    def __init__(self, *, reducer: ReducerFn | None = None) -> None:
        # Whether coding instrumentation is globally on for this wrapper.
        self._coding_enabled: bool = True
        # Nesting counter for bypass(); > 0 means temporarily bypassed.
        self._bypass_depth: int = 0
        # Most recent per-layer trace (None until a coded forward runs).
        self._last_trace: CodingTrace | None = None
        # Default to identity reduction when none is provided
        self._reducer: ReducerFn = reducer or identity_reduction
        # Hooks invoked (best-effort) on every trace update.
        self._diagnostic_hooks: list[Callable[[CodingTrace], None]] = []

    # --- Bypass and enable/disable -------------------------------------------------
    @property
    def coding_enabled(self) -> bool:
        """Return whether coding is globally enabled for this wrapper."""

        return self._coding_enabled

    @property
    def is_coding_active(self) -> bool:
        """True when coding is enabled and not currently bypassed."""

        return self._coding_enabled and self._bypass_depth == 0

    def enable_coding(self) -> None:
        """Enable coding instrumentation for subsequent forwards."""

        self._coding_enabled = True

    def disable_coding(self) -> None:
        """Disable coding instrumentation (fail-open) and clear caches.

        Disabling coding should restore base-model behavior immediately. To
        avoid stale reads from convenience accessors, drop any cached latest
        trace at the time of disabling rather than waiting for the next
        forward pass.
        """

        self._coding_enabled = False
        # Clear any previously cached trace so that model-level aggregations
        # observe the disabled state immediately.
        self.clear_last_trace()

    @contextmanager
    def bypass(self) -> Iterator[None]:
        """Temporarily bypass coding instrumentation.

        The context is nestable; coding is considered bypassed when the bypass
        depth is greater than zero.
        """

        self._bypass_depth += 1
        try:
            yield
        finally:
            self._bypass_depth -= 1
            if self._bypass_depth < 0:  # pragma: no cover - defensive
                self._bypass_depth = 0

    # --- Reduction ----------------------------------------------------------------
    def set_reducer(self, reducer: ReducerFn) -> None:
        """Set the reducer callable used by ``maybe_reduce``."""

        self._reducer = reducer

    def get_reducer(self) -> ReducerFn:
        """Return the current reducer callable."""

        return self._reducer

    def maybe_reduce(self, activation: Any) -> tuple[Any, dict[str, Any]]:
        """Apply the configured reducer when coding is active.

        When coding is not active (disabled or bypassed), return the input
        unchanged with a metadata dictionary indicating bypass. This function
        never raises; wrappers should keep observe-only semantics.
        """

        if not self.is_coding_active:
            return activation, {"method": "bypass"}
        try:
            return self._reducer(activation)
        except Exception:  # Fail-open on reducer errors
            # Leave a recognizable marker to help downstream logging while
            # guaranteeing observe-only behavior.
            return activation, {"method": "bypass", "reason": "reduction_error"}

    # --- Trace cache and diagnostics ----------------------------------------------
    @property
    def last_trace(self) -> CodingTrace | None:
        """Return the most recently cached trace, if any."""

        return self._last_trace

    def clear_last_trace(self) -> None:
        """Drop the cached trace."""

        self._last_trace = None

    def update_last_trace(self, trace: CodingTrace) -> None:
        """Cache the latest trace and run diagnostics hooks safely."""

        self._last_trace = trace
        self._run_diagnostic_hooks(trace)

    def add_diagnostics_hook(self, hook: Callable[[CodingTrace], None]) -> None:
        """Register a hook invoked on every trace update.

        Hooks must be side-effect safe; exceptions are suppressed to preserve
        fail-open behavior.
        """

        self._diagnostic_hooks.append(hook)

    def remove_diagnostics_hook(self, hook: Callable[[CodingTrace], None]) -> None:
        """Remove a previously registered diagnostics hook, if present."""

        with suppress(ValueError):
            self._diagnostic_hooks.remove(hook)

    def _run_diagnostic_hooks(self, trace: CodingTrace) -> None:
        # Iterate a snapshot so hooks may add/remove hooks while running.
        for hook in tuple(self._diagnostic_hooks):
            try:
                hook(trace)
            except Exception:
                # Suppress hook exceptions to preserve user forwards; wrappers
                # are observe-only. Hooks are best-effort diagnostics.
                continue

    # --- Layer-level diagnostics (latest-trace) ---------------------------------
    def utilization(self) -> float | None:
        """Fraction of codes used in the latest trace.

        Returns the number of unique chosen indices divided by ``K`` from the
        most recent trace, or ``None`` when unavailable. The result is a
        Python ``float`` in ``[0, 1]``.
        """

        trace = self.last_trace
        if trace is None:  # no trace cached
            return None
        try:
            K = int(trace.soft_code.code_dim)
            if K <= 0:
                return None
            # Prefer SoftCode.best_indices when present; otherwise fall back to
            # the hard choice stored on the trace.
            best = trace.soft_code.best_indices
            if best is None:
                best = trace.chosen_center_indices
            values = self._flatten_to_list(best)
            if values is None:
                return None
            used = len({int(v) for v in values})
            return float(used) / float(K)
        except Exception:
            # Fail-open: diagnostics never propagate errors.
            return None

    def mean_entropy(self) -> float | None:
        """Mean assignment entropy from the latest trace, or ``None``.

        Returns ``soft_code.entropy.mean()`` as a Python ``float`` when a
        cached trace is available and contains entropy; otherwise ``None``.
        """

        trace = self.last_trace
        if trace is None:
            return None
        entropy = trace.soft_code.entropy
        if entropy is None:
            return None
        return self._mean_as_float(entropy)

    def mean_code_length(self) -> float | None:
        """Mean best-code length from the latest trace, or ``None``.

        Returns ``soft_code.best_length.mean()`` as a Python ``float`` when a
        cached trace is available and contains best-code lengths; otherwise
        ``None``.
        """

        trace = self.last_trace
        if trace is None:
            return None
        best_len = trace.soft_code.best_length
        if best_len is None:
            return None
        return self._mean_as_float(best_len)

    def mean_commitment_distance(self) -> float | None:
        """Mean commitment distance from the latest trace, or ``None``.

        Uses the per-position ``commitment_distances`` stored on the trace.
        Returns a Python ``float`` when available, else ``None``.
        """

        trace = self.last_trace
        if trace is None:
            return None
        d = trace.commitment_distances
        if d is None:
            return None
        return self._mean_as_float(d)

    @staticmethod
    def _flatten_to_list(x: Any) -> list[float] | None:
        """Best-effort flatten to a Python list of numbers.

        Supports torch-like objects exposing ``reshape``/``view`` and ``tolist``,
        as well as nested Python sequences. Returns ``None`` when ``x`` cannot
        be interpreted numerically.
        """

        # Torch-like path
        try:
            if hasattr(x, "reshape") and hasattr(x, "tolist"):
                return list(x.reshape(-1).tolist())
            if hasattr(x, "view") and hasattr(x, "tolist"):
                return list(x.view(-1).tolist())
            if hasattr(x, "tolist"):
                out = x.tolist()
                if isinstance(out, list):
                    # Shallow list; attempt to flatten one level if needed
                    def _flatten(v: Any) -> list[float]:
                        if isinstance(v, list):
                            res: list[float] = []
                            for item in v:
                                res.extend(_flatten(item))
                            return res
                        try:
                            return [float(v)]
                        except Exception:
                            return []

                    return _flatten(out)
        except Exception:
            pass

        # Python sequence fallback
        if isinstance(x, (list, tuple)):
            res2: list[float] = []
            for item in x:
                sub = BaseCodingWrapper._flatten_to_list(item)
                if sub is not None:
                    res2.extend(sub)
            return res2
        # Scalar fallback
        try:
            return [float(x)]
        except Exception:
            return None

    @staticmethod
    def _mean_as_float(x: Any) -> float | None:
        """Return mean(x) as a float for torch-like or Python sequences.

        - If ``x`` has a ``mean()`` method, call it and extract ``item()`` if
          available.
        - Otherwise, attempt to flatten to a list and compute an arithmetic
          mean.
        """

        # Torch-like: x.mean().item()
        try:
            if hasattr(x, "mean"):
                m = x.mean()
                if hasattr(m, "item"):
                    return float(m.item())
                return float(m)
        except Exception:
            pass

        values = BaseCodingWrapper._flatten_to_list(x)
        if not values:
            return None
        try:
            return float(sum(values) / float(len(values)))
        except Exception:
            return None
|
nervecode/layers/conv.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""Coding wrapper for ``torch.nn.Conv2d`` with pooled coding.
|
|
2
|
+
|
|
3
|
+
This module provides ``CodingConv2d`` — an observe-only wrapper around
|
|
4
|
+
``nn.Conv2d`` that computes coding statistics over a pooled representation of
|
|
5
|
+
the layer output. The visible forward value remains identical to the underlying
|
|
6
|
+
convolution; coding runs side-by-side and updates an internal trace cache.
|
|
7
|
+
|
|
8
|
+
Design constraints for the first convolutional wrapper:
|
|
9
|
+
- Observe-only contract for forwards (fail-open behavior).
|
|
10
|
+
- Coding over a pooled representation (global average pooling over H and W).
|
|
11
|
+
- Default sample-level reducer corresponds to global average pooling semantics.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
from .base import BaseCodingWrapper, ReducerFn
|
|
19
|
+
|
|
20
|
+
# Import torch lazily but tolerate environments without it at import time.
|
|
21
|
+
try: # pragma: no cover - exercised in torch-enabled environments
|
|
22
|
+
import torch
|
|
23
|
+
from torch import nn
|
|
24
|
+
except Exception: # pragma: no cover - allow import without torch
|
|
25
|
+
from typing import Any
|
|
26
|
+
from typing import cast as _cast
|
|
27
|
+
|
|
28
|
+
torch = _cast(Any, None)
|
|
29
|
+
nn = _cast(Any, None)
|
|
30
|
+
|
|
31
|
+
from nervecode.core import CodingTrace
|
|
32
|
+
from nervecode.core.assignment import SoftAssignment
|
|
33
|
+
from nervecode.core.codebook import Codebook
|
|
34
|
+
|
|
35
|
+
__all__ = ["CodingConv2d"]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class CodingConv2d(nn.Module, BaseCodingWrapper):
    """Observe-only wrapper around ``nn.Conv2d`` with pooled coding.

    The wrapper delegates computation to the underlying ``nn.Conv2d`` and,
    when coding is active, computes a ``CodingTrace`` over a globally pooled
    representation of the visible activation. The visible forward output is
    never altered.

    Parameters
    - layer: The ``nn.Conv2d`` module to wrap.
    - K: Number of codebook centers. Defaults to 16.
    - coding_dim: Target coding dimension ``D``. When ``None`` (default), uses
      ``layer.out_channels``. If provided, must satisfy ``1 <= coding_dim <=
      layer.out_channels``; when smaller than ``out_channels``, a learned linear
      projection is applied after global average pooling, affecting only the
      coding path (observe-only contract preserved).
    - reducer: Optional reducer callable. Overrides the default global-average
      pooling behavior when provided.
    - temperature: Soft-assignment temperature. Higher is softer.
    - eps: Numerical epsilon to clamp divisions and logs.
    """

    def __init__(
        self,
        layer: Any,  # use Any to avoid a hard torch dependency in type hints
        *,
        K: int = 16,
        coding_dim: int | None = None,
        reducer: ReducerFn | None = None,
        temperature: float = 1.0,
        eps: float = 1e-12,
    ) -> None:
        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("CodingConv2d requires PyTorch to be installed")

        # Both bases need explicit init: nn.Module registers parameter state,
        # BaseCodingWrapper sets up bypass/trace/reducer bookkeeping.
        nn.Module.__init__(self)
        BaseCodingWrapper.__init__(self, reducer=reducer)

        if not isinstance(layer, nn.Conv2d):
            raise TypeError("layer must be an instance of torch.nn.Conv2d")

        self.layer: Any = layer

        out_channels = int(self.layer.out_channels)
        if coding_dim is None:
            code_dim = out_channels
        else:
            code_dim = int(coding_dim)
            if code_dim <= 0:
                raise ValueError("coding_dim must be >= 1")
            if code_dim > out_channels:
                raise ValueError("coding_dim must be <= layer.out_channels for reduction")

        # Optional learned projection after pooling when reducing channels.
        self._projection: Any | None = None
        if reducer is None and code_dim < out_channels:
            # Use a linear projection over the pooled channel vector (B, C) -> (B, D).
            self._projection = nn.Linear(out_channels, code_dim, bias=False)

            def _pool_then_project(x: Any) -> tuple[Any, dict[str, Any]]:
                # Global average pooling over spatial dims (H, W)
                # NOTE(review): assumes NCHW layout so the last two dims are
                # spatial — confirm against upstream callers.
                pooled = x.mean(dim=(-1, -2))  # (B, C)
                y_proj = self._projection(pooled)  # type: ignore[misc]
                meta = {
                    "method": "global_avg_pool2d+linear_projection",
                    "in": out_channels,
                    "out": code_dim,
                    # Record that coding view comes from spatial reduction
                    "spatial_reduction": True,
                    "reduction_axes": [-2, -1],  # H, W in NCHW
                }
                return y_proj, meta

            self.set_reducer(_pool_then_project)
        elif reducer is None:
            # Default: global average pooling over spatial dims to (B, C)
            def _global_avg_pool2d(x: Any) -> tuple[Any, dict[str, Any]]:
                reduced = x.mean(dim=(-1, -2))  # (B, C)
                return reduced, {
                    "method": "global_avg_pool2d",
                    "spatial_reduction": True,
                    "reduction_axes": [-2, -1],  # H, W in NCHW
                }

            self.set_reducer(_global_avg_pool2d)
        # else: keep user-provided reducer

        # Codebook over final feature dimension D=code_dim
        self.codebook = Codebook(K=int(K), code_dim=code_dim)
        self.assignment = SoftAssignment(temperature=float(temperature), eps=float(eps))

    def forward(self, x: Any) -> Any:  # torch.Tensor in torch-enabled envs
        """Return the wrapped layer's output and update the cached trace.

        Preserves the observe-only contract: the visible value is the
        underlying layer output. For programmatic use, prefer
        ``forward_with_trace`` to retrieve the explicit ``CodingTrace``
        alongside the output.
        """

        y, _ = self.forward_with_trace(x)
        return y

    def forward_with_trace(self, x: Any) -> tuple[Any, CodingTrace | None]:
        """Return the wrapped output and, when active, the explicit trace."""

        y = self.layer(x)

        if not self.is_coding_active:
            # Drop any stale trace so accessors reflect the bypassed state.
            self.clear_last_trace()
            return y, None

        # Apply configured reducer (default: global average pooling over H, W).
        reduced, reduction_meta = self.maybe_reduce(y)

        # Compute assignments over the pooled representation.
        soft, inter = self.assignment(reduced, self.codebook)

        trace = CodingTrace(
            reduced=reduced,
            reduction_meta=reduction_meta,
            nearest_center_distances=inter["nearest_center_distances"],
            chosen_center_indices=inter["chosen_center_indices"],
            commitment_distances=inter["commitment_distances"],
            soft_code=soft,
        )
        self.update_last_trace(trace)
        return y, trace

    def extra_repr(self) -> str:  # pragma: no cover - cosmetic
        """One-line module summary; robust to reducers lacking ``__name__``."""

        status = "on" if self.coding_enabled else "off"
        reducer = self.get_reducer()
        # User-supplied reducers may be partials or callable objects without
        # a ``__name__`` attribute; fall back to the type name instead of
        # raising from a cosmetic method.
        reducer_name = getattr(reducer, "__name__", type(reducer).__name__)
        return (
            f"layer=Conv2d(out_channels={int(self.layer.out_channels)}), "
            f"code_dim={int(self.codebook.code_dim)}, "
            f"K={int(self.codebook.K)}, reducer={reducer_name}, "
            f"coding={status}"
        )
|