nervecode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nervecode/__init__.py +415 -0
- nervecode/_version.py +10 -0
- nervecode/core/__init__.py +19 -0
- nervecode/core/assignment.py +165 -0
- nervecode/core/codebook.py +182 -0
- nervecode/core/shapes.py +107 -0
- nervecode/core/temperature.py +227 -0
- nervecode/core/trace.py +166 -0
- nervecode/core/types.py +116 -0
- nervecode/integration/__init__.py +9 -0
- nervecode/layers/__init__.py +15 -0
- nervecode/layers/base.py +333 -0
- nervecode/layers/conv.py +174 -0
- nervecode/layers/linear.py +176 -0
- nervecode/layers/reducers.py +80 -0
- nervecode/layers/wrap.py +223 -0
- nervecode/scoring/__init__.py +20 -0
- nervecode/scoring/aggregator.py +369 -0
- nervecode/scoring/calibrator.py +396 -0
- nervecode/scoring/types.py +33 -0
- nervecode/training/__init__.py +25 -0
- nervecode/training/diagnostics.py +194 -0
- nervecode/training/loss.py +188 -0
- nervecode/training/updaters.py +168 -0
- nervecode/utils/__init__.py +14 -0
- nervecode/utils/overhead.py +177 -0
- nervecode/utils/seed.py +161 -0
- nervecode-0.1.0.dist-info/METADATA +83 -0
- nervecode-0.1.0.dist-info/RECORD +31 -0
- nervecode-0.1.0.dist-info/WHEEL +4 -0
- nervecode-0.1.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Auxiliary coding loss that consumes CodingTrace objects.
|
|
2
|
+
|
|
3
|
+
This module implements a lightweight training-time loss that reads existing
|
|
4
|
+
``CodingTrace`` objects produced by wrapped layers. It avoids recomputing
|
|
5
|
+
distances or other intermediates from raw activations by relying on the
|
|
6
|
+
per-position surprise already stored on the associated ``SoftCode`` inside each
|
|
7
|
+
trace. When multiple layers are provided, their per-sample surprise signals are
|
|
8
|
+
aggregated using the default mean strategy from ``nervecode.scoring`` and then
|
|
9
|
+
reduced across the batch to yield a scalar loss.
|
|
10
|
+
|
|
11
|
+
The loss intentionally remains minimal in the MVP: it performs a mean reduction
|
|
12
|
+
over samples. Follow-up tasks extend it with configurable term weights and
|
|
13
|
+
clear per-layer breakdowns.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from collections.abc import Iterable, Mapping
|
|
19
|
+
from typing import Any, cast
|
|
20
|
+
|
|
21
|
+
try: # Keep import-time behavior tolerant in environments without torch
|
|
22
|
+
import torch
|
|
23
|
+
from torch import nn
|
|
24
|
+
except Exception: # pragma: no cover - torch is a project dependency in tests
|
|
25
|
+
torch = cast(Any, None)
|
|
26
|
+
nn = cast(Any, None)
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from nervecode.core import CodingTrace, SoftCode # re-exported types
|
|
30
|
+
except Exception: # pragma: no cover - available during normal package use
|
|
31
|
+
CodingTrace = object # type: ignore[misc,assignment]
|
|
32
|
+
SoftCode = object # type: ignore[misc,assignment]
|
|
33
|
+
|
|
34
|
+
from nervecode.scoring import mean_surprise
|
|
35
|
+
|
|
36
|
+
__all__ = ["CodingLoss"]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class _StubModule: # pragma: no cover - used only when torch is unavailable
|
|
40
|
+
pass
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
_Base = nn.Module if nn is not None else _StubModule
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class CodingLoss(_Base):  # type: ignore[misc,valid-type]
    """Scalar coding loss computed from existing traces without recomputation.

    Accepts a single ``CodingTrace`` (or ``SoftCode``/tensor), an iterable of
    such objects, or a mapping from layer names to objects.

    Two conceptual terms are combined:

    - an assignment/sharpness term derived from the per-position surprise
      stored on ``SoftCode`` (aggregated via ``mean_surprise``), and
    - an optional commitment term built from the nearest-center distances
      already recorded on each ``CodingTrace``.

    Both terms are reduced to a per-sample view and averaged across layers;
    the final scalar is their weighted sum using ``assign_weight`` and
    ``commit_weight``. The commitment weight defaults to zero to keep backward
    compatibility with earlier behavior.
    """

    def __init__(self, *, assign_weight: float = 1.0, commit_weight: float = 0.0) -> None:
        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("CodingLoss requires PyTorch to be installed")
        super().__init__()
        self.assign_weight: float = float(assign_weight)
        self.commit_weight: float = float(commit_weight)

    def forward(self, traces: Any) -> torch.Tensor:
        """Return the weighted, mean-aggregated coding loss across samples.

        Parameters
        - traces: A ``CodingTrace`` object (or ``SoftCode``/tensor), an
          iterable of such objects, or a mapping from layer names to objects.

        Returns
        - A scalar ``torch.Tensor``: the weighted sum of the per-sample
          aggregated assignment term and, when enabled, the aggregated
          commitment term.
        """

        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("CodingLoss requires PyTorch to be installed")

        # Normalize the input into the collection form expected by
        # mean_surprise and the commitment aggregator below.
        payload: Iterable[Any] | Mapping[str, Any]
        is_single = isinstance(traces, (torch.Tensor, SoftCode, CodingTrace))
        if isinstance(traces, Mapping) or (isinstance(traces, Iterable) and not is_single):
            payload = traces
        else:
            # A lone object is wrapped in a list for aggregation.
            payload = [traces]

        # Assignment/sharpness component via mean_surprise (SoftCode-based).
        agg = mean_surprise(payload)
        if agg is None or not hasattr(agg, "surprise"):
            raise ValueError("No valid surprise signals found in provided traces")
        assign_term = cast(torch.Tensor, agg.surprise).mean()

        # Commitment component via nearest-center distances stored on traces.
        commit_vec: torch.Tensor | None = _mean_commitment_across_layers(payload)

        if commit_vec is None or self.commit_weight == 0.0:
            # Default path: only the assignment term contributes.
            return assign_term * self.assign_weight

        commit_mean = commit_vec.mean()
        weighted_assign = (
            assign_term.to(dtype=commit_mean.dtype, device=commit_mean.device)
            * self.assign_weight
        )
        return weighted_assign + commit_mean * self.commit_weight
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _mean_commitment_across_layers(
|
|
120
|
+
traces: Iterable[Any] | Mapping[str, Any],
|
|
121
|
+
) -> torch.Tensor | None:
|
|
122
|
+
"""Return per-sample mean-aggregated commitment distance across layers.
|
|
123
|
+
|
|
124
|
+
This helper scans the provided collection for ``CodingTrace`` objects,
|
|
125
|
+
extracts their per-position ``commitment_distances``, reduces them to a
|
|
126
|
+
per-sample vector by averaging extra leading dimensions, and averages
|
|
127
|
+
across layers. When no layer provides commitment distances, returns
|
|
128
|
+
``None``.
|
|
129
|
+
"""
|
|
130
|
+
|
|
131
|
+
if torch is None: # pragma: no cover - defensive
|
|
132
|
+
return None
|
|
133
|
+
|
|
134
|
+
values: Iterable[Any] = traces.values() if isinstance(traces, Mapping) else traces
|
|
135
|
+
|
|
136
|
+
signals: list[torch.Tensor] = []
|
|
137
|
+
ref_shape: tuple[int, ...] | None = None
|
|
138
|
+
|
|
139
|
+
for obj in values:
|
|
140
|
+
# Only CodingTrace carries commitment distances in the MVP
|
|
141
|
+
if not isinstance(obj, CodingTrace):
|
|
142
|
+
continue
|
|
143
|
+
d = getattr(obj, "commitment_distances", None)
|
|
144
|
+
if not isinstance(d, torch.Tensor):
|
|
145
|
+
continue
|
|
146
|
+
t = _reduce_to_sample(d)
|
|
147
|
+
shape = tuple(int(s) for s in t.shape)
|
|
148
|
+
if ref_shape is None:
|
|
149
|
+
ref_shape = shape
|
|
150
|
+
signals.append(t)
|
|
151
|
+
else:
|
|
152
|
+
if shape != ref_shape:
|
|
153
|
+
# Skip mismatched sample shapes to preserve fail-open behavior
|
|
154
|
+
continue
|
|
155
|
+
signals.append(t)
|
|
156
|
+
|
|
157
|
+
if not signals:
|
|
158
|
+
return None
|
|
159
|
+
|
|
160
|
+
# Align dtype/device to the first signal for stable aggregation
|
|
161
|
+
first = signals[0]
|
|
162
|
+
device = getattr(first, "device", None)
|
|
163
|
+
dtype = first.dtype if hasattr(first, "dtype") else None
|
|
164
|
+
aligned: list[torch.Tensor] = []
|
|
165
|
+
for s in signals:
|
|
166
|
+
s2 = s
|
|
167
|
+
if dtype is not None and getattr(s2, "dtype", None) is not dtype:
|
|
168
|
+
s2 = s2.to(dtype=dtype)
|
|
169
|
+
if device is not None and getattr(s2, "device", None) != device:
|
|
170
|
+
s2 = s2.to(device=device)
|
|
171
|
+
aligned.append(s2)
|
|
172
|
+
|
|
173
|
+
stacked = torch.stack(aligned, dim=0)
|
|
174
|
+
# Mean over layers yields a per-sample vector
|
|
175
|
+
return stacked.mean(dim=0)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _reduce_to_sample(t: torch.Tensor) -> torch.Tensor:
|
|
179
|
+
"""Reduce a tensor to a per-sample vector by averaging extra leading dims.
|
|
180
|
+
|
|
181
|
+
Mirrors the reduction used by the scoring aggregator so that both
|
|
182
|
+
assignment and commitment terms are combined consistently across layers.
|
|
183
|
+
"""
|
|
184
|
+
|
|
185
|
+
if getattr(t, "ndim", 0) <= 1:
|
|
186
|
+
return t
|
|
187
|
+
dims = tuple(range(1, int(t.ndim)))
|
|
188
|
+
return t.mean(dim=dims)
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""Experimental codebook update helpers (EMA / hybrid).
|
|
2
|
+
|
|
3
|
+
These helpers are optional, off-by-default training utilities that update
|
|
4
|
+
codebook centers using exponential moving averages driven by the latest
|
|
5
|
+
coding trace. They are provided for exploration only — the default and
|
|
6
|
+
recommended baseline remains standard gradient-based optimization via
|
|
7
|
+
``CodingLoss`` and your optimizer.
|
|
8
|
+
|
|
9
|
+
Design goals
|
|
10
|
+
- Observe-only wrappers and existing training code remain unchanged unless an
|
|
11
|
+
updater is explicitly instantiated and called by the user.
|
|
12
|
+
- Safe numerics: clamp divisors, support empty usages gracefully.
|
|
13
|
+
- Torch-first: implemented as small ``nn.Module``s with buffers so they move
|
|
14
|
+
with devices and can be checkpointed when needed.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from typing import Any, Literal, cast
|
|
20
|
+
|
|
21
|
+
try: # Keep import-time failure tolerant when torch isn't installed
|
|
22
|
+
import torch
|
|
23
|
+
from torch import nn
|
|
24
|
+
except Exception: # pragma: no cover - torch is a project dependency in tests
|
|
25
|
+
torch = cast(Any, None)
|
|
26
|
+
nn = cast(Any, None)
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from nervecode.core import CodingTrace
|
|
30
|
+
from nervecode.core.codebook import Codebook
|
|
31
|
+
except Exception: # pragma: no cover - available during normal package use
|
|
32
|
+
CodingTrace = object # type: ignore[misc,assignment]
|
|
33
|
+
Codebook = object # type: ignore[misc,assignment]
|
|
34
|
+
|
|
35
|
+
__all__ = ["EmaCodebookUpdater"]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class _StubModule: # pragma: no cover - used only when torch is unavailable
|
|
39
|
+
pass
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
_Base = nn.Module if nn is not None else _StubModule
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class EmaCodebookUpdater(_Base):  # type: ignore[misc,valid-type]
    """Update codebook centers using an EMA of batch-weighted means.

    This updater maintains per-code EMA counts and summed activations and sets
    centers to ``ema_sum / (ema_count + eps)`` after each update call. The batch
    statistics are computed from a ``CodingTrace`` using either soft
    probabilities or hard assignments.

    Parameters
    - codebook: The ``Codebook`` whose centers are updated in-place.
    - decay: Exponential decay factor in ``[0, 1)``. Higher means slower update
      toward the batch statistics. ``decay=0`` uses the current batch only.
    - assignment: ``"soft"`` (default) to use soft probabilities, or ``"hard"``
      to use one-hot assignments based on argmax indices.
    - eps: Small positive constant to avoid division by zero.

    Notes
    - Updates run under ``torch.no_grad()`` and preserve ``requires_grad`` on
      ``codebook.centers`` so gradient-based training can continue to operate
      (hybrid usage).
    - The updater holds its own buffers (ema_count, ema_sum) sized from the
      attached codebook at construction time. If the codebook is reinitialized
      with a different shape, construct a new updater.
    """

    def __init__(
        self,
        codebook: Codebook,
        *,
        decay: float = 0.99,
        assignment: Literal["soft", "hard"] = "soft",
        eps: float = 1e-12,
    ) -> None:
        if torch is None:  # pragma: no cover - defensive
            raise RuntimeError("EmaCodebookUpdater requires PyTorch to be installed")
        super().__init__()

        # ``float(decay) == decay`` is False for NaN, so this check also
        # rejects NaN decays in addition to out-of-range values.
        if decay < 0.0 or decay >= 1.0 or not float(decay) == decay:
            raise ValueError("decay must be a float in [0, 1)")
        if assignment not in ("soft", "hard"):
            raise ValueError("assignment must be 'soft' or 'hard'")
        if eps <= 0.0:
            raise ValueError("eps must be positive")

        self.codebook = codebook
        self.decay: float = float(decay)
        self.assignment: Literal["soft", "hard"] = assignment
        self.eps: float = float(eps)

        K = int(codebook.K)
        D = int(codebook.code_dim)
        # EMA state mirrors codebook shape: counts per code and sums per code x dim.
        # The explicit annotations keep type-checkers happy about the buffers
        # registered below.
        self._ema_count: torch.Tensor
        self._ema_sum: torch.Tensor
        self.register_buffer("_ema_count", torch.zeros(K), persistent=True)
        self.register_buffer("_ema_sum", torch.zeros(K, D), persistent=True)

    def reset_state(self) -> None:
        """Reset EMA state to zeros (no prior influence)."""
        if torch is None:  # pragma: no cover - defensive
            return
        with torch.no_grad():
            self._ema_count.zero_()
            self._ema_sum.zero_()

    def update_from_trace(self, trace: CodingTrace) -> None:
        """Update EMA statistics and set codebook centers for a batch.

        Expects a ``CodingTrace`` produced by the same codebook. The reduced
        activations and the ``SoftCode`` probabilities/indices are used to
        compute per-code batch counts and sums.

        Raises
        - TypeError: if the trace does not carry tensor probabilities and
          reduced activations.
        - ValueError: if the trace tensors are zero-dimensional or their
          trailing dimensions do not match the attached codebook's K/D.
        """

        # Validate shape alignment in a torch-first way.
        probs = trace.soft_code.probs
        x = trace.reduced

        if not isinstance(probs, torch.Tensor) or not isinstance(x, torch.Tensor):
            raise TypeError("trace must carry torch.Tensor probabilities and reduced activations")
        if probs.ndim < 1 or x.ndim < 1:
            raise ValueError("trace tensors must have at least one dimension")
        if int(probs.shape[-1]) != int(self.codebook.K):
            raise ValueError("trace code dimension K must match attached codebook")
        if int(x.shape[-1]) != int(self.codebook.code_dim):
            raise ValueError("trace reduced dim D must match attached codebook")

        with torch.no_grad():
            # Flatten leading dims to (N, ...). N is derived from ``x``;
            # assumes probs has the same number of leading elements — the
            # reshape below raises otherwise. TODO(review): confirm upstream
            # traces guarantee this.
            N = int(x.numel() // x.shape[-1])
            x_flat = x.reshape(N, -1)  # (N, D)
            p_flat = probs.reshape(N, -1)  # (N, K)

            if self.assignment == "hard":
                # One-hot from argmax indices.
                idx = torch.argmax(p_flat, dim=-1)  # (N,)
                w = torch.zeros_like(p_flat)
                w.scatter_(1, idx.unsqueeze(-1), 1.0)
            else:
                # Soft probabilities (already sum to 1 across codes).
                w = p_flat

            # Per-code batch counts and sums.
            batch_count = w.sum(dim=0)  # (K,)
            batch_sum = w.t() @ x_flat  # (K, D)

            # Align EMA state to the same device/dtype as the incoming batch.
            # Assigning a tensor to a registered-buffer attribute replaces the
            # buffer in-place on the module.
            device = x_flat.device
            dtype = x_flat.dtype
            if self._ema_count.device != device or self._ema_count.dtype != dtype:
                self._ema_count = self._ema_count.to(device=device, dtype=dtype)
            if self._ema_sum.device != device or self._ema_sum.dtype != dtype:
                self._ema_sum = self._ema_sum.to(device=device, dtype=dtype)

            # Classic EMA recurrence: state = d * state + (1 - d) * batch.
            d = self.decay
            one_minus = 1.0 - d
            self._ema_count.mul_(d).add_(batch_count.to(dtype=dtype), alpha=one_minus)
            self._ema_sum.mul_(d).add_(batch_sum.to(dtype=dtype), alpha=one_minus)

            # Update centers: ema_sum / (ema_count + eps).
            denom = (self._ema_count + self.eps).unsqueeze(-1)  # (K, 1)
            new_centers = self._ema_sum / denom

            # In-place copy through ``detach()`` writes the new values while
            # preserving requires_grad on the parameter (hybrid usage).
            self.codebook.centers.detach().copy_(new_centers)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Small utilities shared across the package.
|
|
2
|
+
|
|
3
|
+
Utility helpers should remain lightweight and free of heavy dependencies.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from .seed import seed_everything, seed_from_env, temp_seed
|
|
9
|
+
|
|
10
|
+
__all__: list[str] = [
|
|
11
|
+
"seed_everything",
|
|
12
|
+
"seed_from_env",
|
|
13
|
+
"temp_seed",
|
|
14
|
+
]
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""Overhead estimators for pooled Conv2d coding.
|
|
2
|
+
|
|
3
|
+
This module provides lightweight, deterministic estimators for the compute and
|
|
4
|
+
memory overhead of the pooled-convolutional coding path used by
|
|
5
|
+
``nervecode.layers.conv.CodingConv2d``. The estimates are torch-agnostic and
|
|
6
|
+
intended for documentation, sanity checks, and unit tests — not for replacing
|
|
7
|
+
runtime profiling.
|
|
8
|
+
|
|
9
|
+
Conventions
|
|
10
|
+
- We report multiply-accumulate operations (MACs) rather than FLOPs. A single
|
|
11
|
+
multiply followed by an add counts as one MAC. The baseline Conv2d MACs use
|
|
12
|
+
the common approximation: ``B * H_out * W_out * C_out * (C_in * kH * kW)``.
|
|
13
|
+
- Coding overhead counts are broken down into: global-average pooling over
|
|
14
|
+
spatial dimensions, optional linear projection (when ``coding_dim < C_out``),
|
|
15
|
+
and codebook distance + soft-assignment operations of size ``(B, D, K)``.
|
|
16
|
+
- Memory estimates cover transient activations in the coding path (e.g., the
|
|
17
|
+
pooled vector ``(B, D)`` and soft-code ``(B, K)``) and parameter overhead
|
|
18
|
+
introduced by the wrapper (codebook centers and optional projection matrix).
|
|
19
|
+
|
|
20
|
+
These estimates aim to be simple and stable across environments. Real kernels
|
|
21
|
+
and library implementations may differ by constant factors, fusions, and cache
|
|
22
|
+
effects; treat the results as order-of-magnitude guidance.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from dataclasses import dataclass
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class PooledConvConfig:
    """Configuration for a 2D convolution with pooled coding.

    All dimensions are positive integers. ``kernel_size``, ``stride``,
    ``padding``, and ``dilation`` accept symmetric pairs as ``(h, w)``.

    Fields
    - batch: batch size ``B``
    - in_channels: input channels ``C_in``
    - out_channels: output channels ``C_out``
    - in_size: input spatial size ``(H_in, W_in)``
    - kernel_size: kernel size ``(kH, kW)``
    - stride: stride ``(sH, sW)``
    - padding: padding ``(pH, pW)``
    - dilation: dilation ``(dH, dW)``
    - coding_dim: pooled coding dimension ``D``; defaults to ``C_out``
    - K: codebook size
    - dtype_bytes: bytes per float element for activation estimates (default: 4)
    """

    batch: int
    in_channels: int
    out_channels: int
    in_size: tuple[int, int]
    kernel_size: tuple[int, int] = (3, 3)
    stride: tuple[int, int] = (1, 1)
    padding: tuple[int, int] = (0, 0)
    dilation: tuple[int, int] = (1, 1)
    coding_dim: int | None = None
    K: int = 16
    dtype_bytes: int = 4

    def output_size(self) -> tuple[int, int]:
        """Compute the output spatial size (H_out, W_out) for the conv2d.

        Uses the standard convolution formula with floor division.

        Raises
        - ValueError: when the effective (dilated) kernel does not fit the
          padded input and the formula would yield a non-positive size.
          Previously such configs silently produced non-positive output
          sizes, which propagated negative MAC counts into the overhead
          estimates; real conv implementations reject these configs.
        """

        H_in, W_in = self.in_size
        kH, kW = self.kernel_size
        sH, sW = self.stride
        pH, pW = self.padding
        dH, dW = self.dilation

        H_out = (H_in + 2 * pH - dH * (kH - 1) - 1) // sH + 1
        W_out = (W_in + 2 * pW - dW * (kW - 1) - 1) // sW + 1
        if H_out <= 0 or W_out <= 0:
            raise ValueError(
                "kernel/stride/dilation do not fit the padded input: "
                f"computed output size {(H_out, W_out)} must be positive"
            )
        return int(H_out), int(W_out)

    def resolved_coding_dim(self) -> int:
        """Return effective coding dimension ``D`` (defaults to ``C_out``)."""

        return int(self.out_channels if self.coding_dim is None else self.coding_dim)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass(frozen=True)
class OverheadEstimate:
    """Deterministic estimate of compute and memory costs for pooled coding.

    Fields
    - conv_macs: baseline Conv2d MACs per forward pass.
    - coding_macs_total: total MACs added by pooled coding.
    - coding_macs_breakdown: dict with keys ``pool``, ``projection``,
      ``codebook`` (distance + assignment), and ``softmax`` (small terms).
    - param_overhead: number of parameters introduced by the wrapper
      (codebook centers + optional projection).
    - activation_overhead_bytes: estimated transient activation bytes in the
      coding path (pooled vector and soft-code; float elements only).
    - overhead_ratio: ``coding_macs_total / conv_macs`` when ``conv_macs > 0``
      else ``None``.
    """

    conv_macs: int
    coding_macs_total: int
    coding_macs_breakdown: dict[str, int]
    param_overhead: int
    activation_overhead_bytes: int
    overhead_ratio: float | None


def estimate_pooled_conv_overhead(cfg: PooledConvConfig) -> OverheadEstimate:
    """Estimate compute and memory overhead for pooled Conv2d coding.

    Uses simple closed-form counts for MACs and activation sizes that match
    the pooled-coding path in ``CodingConv2d``. See the module docstring for
    conventions and caveats.
    """

    batch = int(cfg.batch)
    c_in = int(cfg.in_channels)
    c_out = int(cfg.out_channels)
    k_h, k_w = int(cfg.kernel_size[0]), int(cfg.kernel_size[1])
    h_out, w_out = cfg.output_size()
    dim = int(cfg.resolved_coding_dim())
    n_codes = int(cfg.K)
    elem_bytes = max(1, int(cfg.dtype_bytes))

    # Baseline conv MACs (multiply-accumulate count).
    conv_macs = batch * h_out * w_out * c_out * (c_in * k_h * k_w)

    # 1) Global average pool over spatial dims: sum H*W then divide, per channel.
    pool_macs = batch * c_out * h_out * w_out

    # 2) Optional linear projection, present only when D < C_out:
    #    a (C_out x D) matrix-vector multiply per sample.
    has_projection = c_out > dim
    proj_macs = batch * c_out * dim if has_projection else 0
    proj_params = c_out * dim if has_projection else 0

    # 3) Codebook distances + soft assignment on (B, D) against K centers,
    #    approximated as ~2 * B * K * D MACs (diff + square + sum).
    codebook_macs = 2 * batch * n_codes * dim
    # Small softmax/logsumexp overhead per sample; kept deliberately simple.
    softmax_macs = batch * n_codes

    total_macs = pool_macs + proj_macs + codebook_macs + softmax_macs

    # Ratio is undefined for a degenerate baseline with no conv MACs.
    ratio = float(total_macs) / float(conv_macs) if conv_macs > 0 else None

    return OverheadEstimate(
        conv_macs=int(conv_macs),
        coding_macs_total=int(total_macs),
        coding_macs_breakdown={
            "pool": int(pool_macs),
            "projection": int(proj_macs),
            "codebook": int(codebook_macs),
            "softmax": int(softmax_macs),
        },
        # Parameter overhead: K * D codebook centers + optional projection.
        param_overhead=int(n_codes * dim + proj_params),
        # Transient floats in the coding path: pooled (B, D) + soft (B, K).
        activation_overhead_bytes=int((batch * dim + batch * n_codes) * elem_bytes),
        overhead_ratio=ratio,
    )
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
# Public API of this module.
__all__ = [
    "OverheadEstimate",
    "PooledConvConfig",
    "estimate_pooled_conv_overhead",
]
|