nervecode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,188 @@
1
+ """Auxiliary coding loss that consumes CodingTrace objects.
2
+
3
+ This module implements a lightweight training-time loss that reads existing
4
+ ``CodingTrace`` objects produced by wrapped layers. It avoids recomputing
5
+ distances or other intermediates from raw activations by relying on the
6
+ per-position surprise already stored on the associated ``SoftCode`` inside each
7
+ trace. When multiple layers are provided, their per-sample surprise signals are
8
+ aggregated using the default mean strategy from ``nervecode.scoring`` and then
9
+ reduced across the batch to yield a scalar loss.
10
+
11
+ The loss intentionally remains minimal in the MVP: it performs a mean reduction
12
+ over samples. Follow-up tasks extend it with configurable term weights and
13
+ clear per-layer breakdowns.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from collections.abc import Iterable, Mapping
19
+ from typing import Any, cast
20
+
21
+ try: # Keep import-time behavior tolerant in environments without torch
22
+ import torch
23
+ from torch import nn
24
+ except Exception: # pragma: no cover - torch is a project dependency in tests
25
+ torch = cast(Any, None)
26
+ nn = cast(Any, None)
27
+
28
+ try:
29
+ from nervecode.core import CodingTrace, SoftCode # re-exported types
30
+ except Exception: # pragma: no cover - available during normal package use
31
+ CodingTrace = object # type: ignore[misc,assignment]
32
+ SoftCode = object # type: ignore[misc,assignment]
33
+
34
+ from nervecode.scoring import mean_surprise
35
+
36
+ __all__ = ["CodingLoss"]
37
+
38
+
39
class _StubModule:  # pragma: no cover - used only when torch is unavailable
    """Placeholder base class substituted when ``torch.nn`` cannot be imported."""


# Real ``nn.Module`` base when torch is present; inert stub otherwise.
_Base = nn.Module if nn is not None else _StubModule
44
+
45
+
46
class CodingLoss(_Base):  # type: ignore[misc,valid-type]
    """Scalar coding loss built from existing traces, without recomputation.

    Inputs may be a single ``CodingTrace`` (or ``SoftCode``/tensor), an
    iterable of such objects, or a mapping from layer names to objects.

    Two conceptual terms are combined:
    - an assignment/sharpness term from the per-position surprise stored on
      ``SoftCode`` (aggregated via ``mean_surprise``), and
    - an optional commitment term from the nearest-center distance already
      stored on each ``CodingTrace``.

    Both terms are reduced to a per-sample view, averaged across layers, and
    summed with weights ``assign_weight`` and ``commit_weight``. The
    commitment weight defaults to zero, preserving earlier behavior.
    """

    def __init__(self, *, assign_weight: float = 1.0, commit_weight: float = 0.0) -> None:
        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("CodingLoss requires PyTorch to be installed")
        super().__init__()
        self.assign_weight: float = float(assign_weight)
        self.commit_weight: float = float(commit_weight)

    def forward(self, traces: Any) -> torch.Tensor:
        """Return the weighted, mean-aggregated coding loss across samples.

        Parameters
        - traces: A ``CodingTrace`` object (or ``SoftCode``/tensor), an
          iterable of such objects, or a mapping from layer names to objects.

        Returns
        - A scalar ``torch.Tensor``: the weighted sum of the per-sample
          aggregated assignment term and, when enabled, the aggregated
          commitment term.
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("CodingLoss requires PyTorch to be installed")

        # Normalize the input: wrap a lone trace/code/tensor in a list so the
        # aggregators downstream always receive a collection.
        is_single = not isinstance(traces, Mapping) and (
            not isinstance(traces, Iterable)
            or isinstance(traces, (torch.Tensor, SoftCode, CodingTrace))
        )
        payload: Iterable[Any] | Mapping[str, Any] = [traces] if is_single else traces

        # Assignment/sharpness term via mean_surprise (SoftCode-based).
        agg = mean_surprise(payload)
        if agg is None or not hasattr(agg, "surprise"):
            raise ValueError("No valid surprise signals found in provided traces")
        assign_term = cast(torch.Tensor, agg.surprise).mean()

        # Commitment term via nearest-center distances stored on traces.
        commit_vec: torch.Tensor | None = _mean_commitment_across_layers(payload)
        if self.commit_weight != 0.0 and commit_vec is not None:
            commit_mean = commit_vec.mean()
            weighted_assign = (
                assign_term.to(dtype=commit_mean.dtype, device=commit_mean.device)
                * self.assign_weight
            )
            return weighted_assign + commit_mean * self.commit_weight

        # Commitment disabled (or unavailable): assignment term only.
        return assign_term * self.assign_weight
117
+
118
+
119
def _mean_commitment_across_layers(
    traces: Iterable[Any] | Mapping[str, Any],
) -> torch.Tensor | None:
    """Return the per-sample commitment distance averaged across layers.

    Scans the collection for ``CodingTrace`` objects, reduces each trace's
    per-position ``commitment_distances`` to a per-sample vector (averaging
    extra leading dimensions), and averages those vectors across layers.
    Returns ``None`` when no layer contributes commitment distances.
    """

    if torch is None:  # pragma: no cover - defensive
        return None

    items: Iterable[Any] = traces.values() if isinstance(traces, Mapping) else traces

    per_layer: list[torch.Tensor] = []
    expected_shape: tuple[int, ...] | None = None

    for candidate in items:
        # In the MVP only CodingTrace objects carry commitment distances.
        if not isinstance(candidate, CodingTrace):
            continue
        dist = getattr(candidate, "commitment_distances", None)
        if not isinstance(dist, torch.Tensor):
            continue
        reduced = _reduce_to_sample(dist)
        shape = tuple(int(n) for n in reduced.shape)
        if expected_shape is None:
            expected_shape = shape
        elif shape != expected_shape:
            # Fail-open: drop layers whose per-sample shape disagrees.
            continue
        per_layer.append(reduced)

    if not per_layer:
        return None

    # Normalize every signal to the first one's dtype/device before stacking.
    ref = per_layer[0]
    ref_device = getattr(ref, "device", None)
    ref_dtype = ref.dtype if hasattr(ref, "dtype") else None
    aligned: list[torch.Tensor] = []
    for sig in per_layer:
        if ref_dtype is not None and getattr(sig, "dtype", None) is not ref_dtype:
            sig = sig.to(dtype=ref_dtype)
        if ref_device is not None and getattr(sig, "device", None) != ref_device:
            sig = sig.to(device=ref_device)
        aligned.append(sig)

    # Averaging over the layer axis yields the per-sample vector.
    return torch.stack(aligned, dim=0).mean(dim=0)
176
+
177
+
178
+ def _reduce_to_sample(t: torch.Tensor) -> torch.Tensor:
179
+ """Reduce a tensor to a per-sample vector by averaging extra leading dims.
180
+
181
+ Mirrors the reduction used by the scoring aggregator so that both
182
+ assignment and commitment terms are combined consistently across layers.
183
+ """
184
+
185
+ if getattr(t, "ndim", 0) <= 1:
186
+ return t
187
+ dims = tuple(range(1, int(t.ndim)))
188
+ return t.mean(dim=dims)
@@ -0,0 +1,168 @@
1
+ """Experimental codebook update helpers (EMA / hybrid).
2
+
3
+ These helpers are optional, off-by-default training utilities that update
4
+ codebook centers using exponential moving averages driven by the latest
5
+ coding trace. They are provided for exploration only — the default and
6
+ recommended baseline remains standard gradient-based optimization via
7
+ ``CodingLoss`` and your optimizer.
8
+
9
+ Design goals
10
+ - Observe-only wrappers and existing training code remain unchanged unless an
11
+ updater is explicitly instantiated and called by the user.
12
+ - Safe numerics: clamp divisors, support empty usages gracefully.
13
+ - Torch-first: implemented as small ``nn.Module``s with buffers so they move
14
+ with devices and can be checkpointed when needed.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Any, Literal, cast
20
+
21
+ try: # Keep import-time failure tolerant when torch isn't installed
22
+ import torch
23
+ from torch import nn
24
+ except Exception: # pragma: no cover - torch is a project dependency in tests
25
+ torch = cast(Any, None)
26
+ nn = cast(Any, None)
27
+
28
+ try:
29
+ from nervecode.core import CodingTrace
30
+ from nervecode.core.codebook import Codebook
31
+ except Exception: # pragma: no cover - available during normal package use
32
+ CodingTrace = object # type: ignore[misc,assignment]
33
+ Codebook = object # type: ignore[misc,assignment]
34
+
35
+ __all__ = ["EmaCodebookUpdater"]
36
+
37
+
38
class _StubModule:  # pragma: no cover - used only when torch is unavailable
    """Placeholder base class substituted when ``torch.nn`` cannot be imported."""


# Real ``nn.Module`` base when torch is present; inert stub otherwise.
_Base = nn.Module if nn is not None else _StubModule
43
+
44
+
45
class EmaCodebookUpdater(_Base):  # type: ignore[misc,valid-type]
    """EMA-driven codebook center updates (experimental, opt-in).

    Maintains per-code EMA counts and summed activations, then rewrites the
    centers as ``ema_sum / (ema_count + eps)`` after each update call. Batch
    statistics come from a ``CodingTrace`` using either soft probabilities
    or hard (argmax one-hot) assignments.

    Parameters
    - codebook: The ``Codebook`` whose centers are updated in-place.
    - decay: Exponential decay factor in ``[0, 1)``. Higher tracks the batch
      statistics more slowly; ``decay=0`` uses the current batch only.
    - assignment: ``"soft"`` (default) for soft probabilities, ``"hard"``
      for one-hot assignments derived from argmax indices.
    - eps: Small positive constant guarding against division by zero.

    Notes
    - Updates run under ``torch.no_grad()`` and write through a detached
      view of ``codebook.centers``, so ``requires_grad`` is preserved and
      hybrid gradient-based training keeps working.
    - The EMA buffers (``_ema_count``, ``_ema_sum``) are sized from the
      attached codebook at construction. Build a new updater if the
      codebook is reinitialized with a different shape.
    """

    def __init__(
        self,
        codebook: Codebook,
        *,
        decay: float = 0.99,
        assignment: Literal["soft", "hard"] = "soft",
        eps: float = 1e-12,
    ) -> None:
        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("EmaCodebookUpdater requires PyTorch to be installed")
        super().__init__()

        # Reject out-of-range and NaN decay values up front.
        if not (0.0 <= decay < 1.0) or float(decay) != decay:
            raise ValueError("decay must be a float in [0, 1)")
        if assignment not in ("soft", "hard"):
            raise ValueError("assignment must be 'soft' or 'hard'")
        if eps <= 0.0:
            raise ValueError("eps must be positive")

        self.codebook = codebook
        self.decay: float = float(decay)
        self.assignment: Literal["soft", "hard"] = assignment
        self.eps: float = float(eps)

        num_codes = int(codebook.K)
        code_dim = int(codebook.code_dim)
        # EMA state mirrors the codebook: one count per code plus a
        # (K, D) running sum of assigned activations.
        self._ema_count: torch.Tensor
        self._ema_sum: torch.Tensor
        self.register_buffer("_ema_count", torch.zeros(num_codes), persistent=True)
        self.register_buffer("_ema_sum", torch.zeros(num_codes, code_dim), persistent=True)

    def reset_state(self) -> None:
        """Zero the EMA buffers, discarding all prior influence."""
        if torch is None:  # pragma: no cover - defensive
            return
        with torch.no_grad():
            self._ema_count.zero_()
            self._ema_sum.zero_()

    def update_from_trace(self, trace: CodingTrace) -> None:
        """Fold one batch into the EMA state and refresh the centers.

        Expects a ``CodingTrace`` produced by the attached codebook. The
        reduced activations and the ``SoftCode`` probabilities/indices are
        used to compute per-code batch counts and sums.
        """

        probs = trace.soft_code.probs
        x = trace.reduced

        # Torch-first validation of tensor types and last-dim alignment.
        if not isinstance(probs, torch.Tensor) or not isinstance(x, torch.Tensor):
            raise TypeError("trace must carry torch.Tensor probabilities and reduced activations")
        if probs.ndim < 1 or x.ndim < 1:
            raise ValueError("trace tensors must have at least one dimension")
        if int(probs.shape[-1]) != int(self.codebook.K):
            raise ValueError("trace code dimension K must match attached codebook")
        if int(x.shape[-1]) != int(self.codebook.code_dim):
            raise ValueError("trace reduced dim D must match attached codebook")

        with torch.no_grad():
            # Collapse leading dims: (N, D) activations, (N, K) probabilities.
            n_rows = int(x.numel() // x.shape[-1])
            flat_x = x.reshape(n_rows, -1)
            flat_p = probs.reshape(n_rows, -1)

            if self.assignment == "hard":
                # One-hot weights from the winning code of each row.
                winners = torch.argmax(flat_p, dim=-1)  # (N,)
                weights = torch.zeros_like(flat_p)
                weights.scatter_(1, winners.unsqueeze(-1), 1.0)
            else:
                # Soft weights: probabilities already sum to 1 across codes.
                weights = flat_p

            # Per-code batch statistics.
            batch_count = weights.sum(dim=0)  # (K,)
            batch_sum = weights.t() @ flat_x  # (K, D)

            # Move EMA buffers onto the batch's device/dtype when they differ.
            device = flat_x.device
            dtype = flat_x.dtype
            if self._ema_count.device != device or self._ema_count.dtype != dtype:
                self._ema_count = self._ema_count.to(device=device, dtype=dtype)
            if self._ema_sum.device != device or self._ema_sum.dtype != dtype:
                self._ema_sum = self._ema_sum.to(device=device, dtype=dtype)

            blend = 1.0 - self.decay
            self._ema_count.mul_(self.decay).add_(batch_count.to(dtype=dtype), alpha=blend)
            self._ema_sum.mul_(self.decay).add_(batch_sum.to(dtype=dtype), alpha=blend)

            # centers <- ema_sum / (ema_count + eps); copying into the
            # detached view preserves requires_grad on the parameter.
            denom = (self._ema_count + self.eps).unsqueeze(-1)  # (K, 1)
            self.codebook.centers.detach().copy_(self._ema_sum / denom)
@@ -0,0 +1,14 @@
1
+ """Small utilities shared across the package.
2
+
3
+ Utility helpers should remain lightweight and free of heavy dependencies.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from .seed import seed_everything, seed_from_env, temp_seed
9
+
10
+ __all__: list[str] = [
11
+ "seed_everything",
12
+ "seed_from_env",
13
+ "temp_seed",
14
+ ]
@@ -0,0 +1,177 @@
1
+ """Overhead estimators for pooled Conv2d coding.
2
+
3
+ This module provides lightweight, deterministic estimators for the compute and
4
+ memory overhead of the pooled-convolutional coding path used by
5
+ ``nervecode.layers.conv.CodingConv2d``. The estimates are torch-agnostic and
6
+ intended for documentation, sanity checks, and unit tests — not for replacing
7
+ runtime profiling.
8
+
9
+ Conventions
10
+ - We report multiply-accumulate operations (MACs) rather than FLOPs. A single
11
+ multiply followed by an add counts as one MAC. The baseline Conv2d MACs use
12
+ the common approximation: ``B * H_out * W_out * C_out * (C_in * kH * kW)``.
13
+ - Coding overhead counts are broken down into: global-average pooling over
14
+ spatial dimensions, optional linear projection (when ``coding_dim < C_out``),
15
+ and codebook distance + soft-assignment operations of size ``(B, D, K)``.
16
+ - Memory estimates cover transient activations in the coding path (e.g., the
17
+ pooled vector ``(B, D)`` and soft-code ``(B, K)``) and parameter overhead
18
+ introduced by the wrapper (codebook centers and optional projection matrix).
19
+
20
+ These estimates aim to be simple and stable across environments. Real kernels
21
+ and library implementations may differ by constant factors, fusions, and cache
22
+ effects; treat the results as order-of-magnitude guidance.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ from dataclasses import dataclass
28
+
29
+
30
@dataclass(frozen=True)
class PooledConvConfig:
    """Configuration for a 2D convolution with pooled coding.

    All dimensions are positive integers. ``kernel_size``, ``stride``,
    ``padding``, and ``dilation`` accept symmetric pairs as ``(h, w)``.

    Fields
    - batch: batch size ``B``
    - in_channels: input channels ``C_in``
    - out_channels: output channels ``C_out``
    - in_size: input spatial size ``(H_in, W_in)``
    - kernel_size: kernel size ``(kH, kW)``
    - stride: stride ``(sH, sW)``
    - padding: padding ``(pH, pW)``
    - dilation: dilation ``(dH, dW)``
    - coding_dim: pooled coding dimension ``D``; defaults to ``C_out``
    - K: codebook size
    - dtype_bytes: bytes per float element for activation estimates (default: 4)
    """

    batch: int
    in_channels: int
    out_channels: int
    in_size: tuple[int, int]
    kernel_size: tuple[int, int] = (3, 3)
    stride: tuple[int, int] = (1, 1)
    padding: tuple[int, int] = (0, 0)
    dilation: tuple[int, int] = (1, 1)
    coding_dim: int | None = None
    K: int = 16
    dtype_bytes: int = 4

    def output_size(self) -> tuple[int, int]:
        """Return the conv2d output spatial size ``(H_out, W_out)``.

        Applies the standard convolution formula with floor division.
        """

        def extent_out(extent: int, kernel: int, stride: int, pad: int, dil: int) -> int:
            # Standard conv arithmetic: effective kernel = dil*(kernel-1)+1.
            return (extent + 2 * pad - dil * (kernel - 1) - 1) // stride + 1

        height = extent_out(
            self.in_size[0], self.kernel_size[0], self.stride[0], self.padding[0], self.dilation[0]
        )
        width = extent_out(
            self.in_size[1], self.kernel_size[1], self.stride[1], self.padding[1], self.dilation[1]
        )
        return int(height), int(width)

    def resolved_coding_dim(self) -> int:
        """Return the effective coding dimension ``D`` (``C_out`` when unset)."""

        if self.coding_dim is None:
            return int(self.out_channels)
        return int(self.coding_dim)
83
+
84
+
85
@dataclass(frozen=True)
class OverheadEstimate:
    """Deterministic estimate of compute and memory costs for pooled coding.

    Fields
    - conv_macs: baseline Conv2d MACs per forward pass.
    - coding_macs_total: total MACs added by pooled coding.
    - coding_macs_breakdown: dict keyed by ``pool``, ``projection``,
      ``codebook`` (distance + assignment), and ``softmax`` (small terms).
    - param_overhead: parameters introduced by the wrapper (codebook
      centers plus the optional projection matrix).
    - activation_overhead_bytes: estimated transient activation bytes in
      the coding path (pooled vector and soft-code; float elements only).
    - overhead_ratio: ``coding_macs_total / conv_macs`` when ``conv_macs``
      is positive, otherwise ``None``.
    """

    conv_macs: int
    coding_macs_total: int
    coding_macs_breakdown: dict[str, int]
    param_overhead: int
    activation_overhead_bytes: int
    overhead_ratio: float | None
108
+
109
+
110
def estimate_pooled_conv_overhead(cfg: PooledConvConfig) -> OverheadEstimate:
    """Estimate compute and memory overhead for pooled Conv2d coding.

    Uses simple closed-form counts for MACs and activation sizes matching
    the pooled-coding path in ``CodingConv2d``. See the module docstring for
    conventions and caveats; results are order-of-magnitude guidance only.
    """

    batch = int(cfg.batch)
    c_in = int(cfg.in_channels)
    c_out = int(cfg.out_channels)
    k_h, k_w = int(cfg.kernel_size[0]), int(cfg.kernel_size[1])
    h_out, w_out = cfg.output_size()
    dim = int(cfg.resolved_coding_dim())
    codes = int(cfg.K)
    elem_bytes = max(1, int(cfg.dtype_bytes))

    # Baseline conv cost: one MAC per (output pixel, output channel, kernel tap).
    conv_macs = batch * h_out * w_out * c_out * (c_in * k_h * k_w)

    # 1) Global average pool: sum H*W values per channel, then divide.
    pool_macs = batch * c_out * h_out * w_out

    # 2) Linear projection, only present when the coding dim is smaller
    #    than the channel count.
    if c_out > dim:
        proj_macs = batch * c_out * dim  # one matrix-vector product per sample
        proj_params = c_out * dim
    else:
        proj_macs = 0
        proj_params = 0

    # 3) Squared-distance + soft assignment against K centers:
    #    ~2 MACs per (sample, code, dim) element (diff + square + sum).
    codebook_macs = 2 * batch * codes * dim
    # Softmax/logsumexp is a small K-sized pass per sample.
    softmax_macs = batch * codes

    coding_total = pool_macs + proj_macs + codebook_macs + softmax_macs

    # Ratio is undefined for a degenerate zero-MAC baseline.
    ratio = float(coding_total) / float(conv_macs) if conv_macs > 0 else None

    return OverheadEstimate(
        conv_macs=int(conv_macs),
        coding_macs_total=int(coding_total),
        coding_macs_breakdown={
            "pool": int(pool_macs),
            "projection": int(proj_macs),
            "codebook": int(codebook_macs),
            "softmax": int(softmax_macs),
        },
        # Parameter overhead: K*D codebook centers + optional projection.
        param_overhead=int(codes * dim + proj_params),
        # Transient floats: pooled (B, D) vector plus soft-code (B, K).
        activation_overhead_bytes=int((batch * dim + batch * codes) * elem_bytes),
        overhead_ratio=None if ratio is None else float(ratio),
    )
171
+
172
+
173
+ __all__ = [
174
+ "OverheadEstimate",
175
+ "PooledConvConfig",
176
+ "estimate_pooled_conv_overhead",
177
+ ]