ezmsg-learn 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/learn/__init__.py +2 -0
- ezmsg/learn/__version__.py +34 -0
- ezmsg/learn/dim_reduce/__init__.py +0 -0
- ezmsg/learn/dim_reduce/adaptive_decomp.py +274 -0
- ezmsg/learn/dim_reduce/incremental_decomp.py +173 -0
- ezmsg/learn/linear_model/__init__.py +1 -0
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +12 -0
- ezmsg/learn/linear_model/cca.py +1 -0
- ezmsg/learn/linear_model/linear_regressor.py +9 -0
- ezmsg/learn/linear_model/sgd.py +9 -0
- ezmsg/learn/linear_model/slda.py +12 -0
- ezmsg/learn/model/__init__.py +0 -0
- ezmsg/learn/model/cca.py +122 -0
- ezmsg/learn/model/mlp.py +127 -0
- ezmsg/learn/model/mlp_old.py +49 -0
- ezmsg/learn/model/refit_kalman.py +369 -0
- ezmsg/learn/model/rnn.py +160 -0
- ezmsg/learn/model/transformer.py +175 -0
- ezmsg/learn/nlin_model/__init__.py +1 -0
- ezmsg/learn/nlin_model/mlp.py +10 -0
- ezmsg/learn/process/__init__.py +0 -0
- ezmsg/learn/process/adaptive_linear_regressor.py +142 -0
- ezmsg/learn/process/base.py +154 -0
- ezmsg/learn/process/linear_regressor.py +95 -0
- ezmsg/learn/process/mlp_old.py +188 -0
- ezmsg/learn/process/refit_kalman.py +403 -0
- ezmsg/learn/process/rnn.py +245 -0
- ezmsg/learn/process/sgd.py +117 -0
- ezmsg/learn/process/sklearn.py +241 -0
- ezmsg/learn/process/slda.py +110 -0
- ezmsg/learn/process/ssr.py +374 -0
- ezmsg/learn/process/torch.py +362 -0
- ezmsg/learn/process/transformer.py +215 -0
- ezmsg/learn/util.py +67 -0
- ezmsg_learn-1.1.0.dist-info/METADATA +30 -0
- ezmsg_learn-1.1.0.dist-info/RECORD +38 -0
- ezmsg_learn-1.1.0.dist-info/WHEEL +4 -0
- ezmsg_learn-1.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
import numpy as np
|
|
5
|
+
from ezmsg.baseproc import (
|
|
6
|
+
BaseStatefulTransformer,
|
|
7
|
+
BaseTransformerUnit,
|
|
8
|
+
processor_state,
|
|
9
|
+
)
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
+
from ezmsg.util.messages.util import replace
|
|
12
|
+
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
|
|
13
|
+
|
|
14
|
+
from ..util import ClassifierMessage
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SLDASettings(ez.Settings):
    """Settings for :class:`SLDATransformer`."""

    # Path to the trained model. A path ending in ".mat" is parsed as a
    # project-specific MATLAB export; anything else is unpickled as an LDA.
    settings_path: str
    # Axis whose entries are classified independently; each entry along it
    # is flattened into one feature vector and yields one output row.
    axis: str = "time"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@processor_state
class SLDAState:
    """Mutable state for :class:`SLDATransformer`."""

    # Classifier loaded by SLDATransformer._reset_state (from .mat or pickle).
    lda: LDA
    # Zero-length ClassifierMessage whose dims/axes/labels are reused for
    # every output; None until the first state reset.
    out_template: typing.Optional[ClassifierMessage] = None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SLDATransformer(BaseStatefulTransformer[SLDASettings, AxisArray, ClassifierMessage, SLDAState]):
    """Classify each slice along ``settings.axis`` with a pre-trained LDA model.

    On state reset the model is loaded from ``settings.settings_path``: a
    project-specific MATLAB ``.mat`` export, or otherwise a pickled
    :class:`~sklearn.discriminant_analysis.LinearDiscriminantAnalysis`.
    Every slice along ``settings.axis`` is flattened into one feature vector;
    the output :class:`ClassifierMessage` carries the ``predict_proba``
    results with dims ``[settings.axis, "classes"]``.
    """

    def _is_mat_file(self) -> bool:
        """Return True when ``settings_path`` names a MATLAB ``.mat`` file."""
        # str() so that reset and process agree even if a Path is supplied.
        return str(self.settings.settings_path).endswith(".mat")

    def _load_mat_lda(self, message: AxisArray) -> LDA:
        """Build an LDA from a MATLAB export.

        Expects a very specific format from a specific project. Not for general use.
        """
        import scipy.io as sio

        matlab_sLDA = sio.loadmat(self.settings.settings_path, squeeze_me=True)
        temp_weights = matlab_sLDA["weights"][1, 1:]
        temp_intercept = matlab_sLDA["weights"][1, 0]

        # Create weights and use zeros for channels we do not keep.
        channels = matlab_sLDA["channels"] - 4
        channels -= channels[0]  # Offsets are wrong somehow.
        n_channels = message.data.shape[message.dims.index("ch")]
        # NOTE(review): pairing the in-range channels with the *leading*
        # weights assumes `channels` is sorted ascending — confirm.
        valid_indices = [ch for ch in channels if ch < n_channels]
        full_weights = np.zeros(n_channels)
        full_weights[valid_indices] = temp_weights[: len(valid_indices)]

        lda = LDA(solver="lsqr", shrinkage="auto")
        lda.classes_ = np.asarray([0, 1])
        lda.coef_ = np.expand_dims(full_weights, axis=0)
        lda.intercept_ = temp_intercept  # TODO: Is this supposed to be per-channel? Why the [1, 0]?
        # The export also carries 'mXtrain' (mean), 'sXtrain' (std) and
        # 'lags' fields; they are currently unused.
        return lda

    def _reset_state(self, message: AxisArray) -> None:
        """Load the classifier and build the zero-length output template."""
        if self._is_mat_file():
            self.state.lda = self._load_mat_lda(message)
        else:
            import pickle

            # SECURITY: unpickling can execute arbitrary code — only point
            # `settings_path` at trusted files.
            with open(self.settings.settings_path, "rb") as f:
                self.state.lda = pickle.load(f)

        # Create template ClassifierMessage using lda.classes_.
        # Note: the template's sample axis is the first message's axis
        # object itself, so _process must never mutate it.
        out_labels = self.state.lda.classes_.tolist()
        self.state.out_template = ClassifierMessage(
            data=np.zeros((0, len(out_labels)), dtype=message.data.dtype),
            dims=[self.settings.axis, "classes"],
            axes={
                self.settings.axis: message.axes[self.settings.axis],
                "classes": AxisArray.CoordinateAxis(data=np.array(out_labels), dims=["classes"]),
            },
            labels=out_labels,
            key=message.key,
        )

    def _process(self, message: AxisArray) -> ClassifierMessage:
        """Return per-class probabilities for each entry along ``settings.axis``.

        An input with zero entries along the axis returns the template as-is.
        """
        axis = self.settings.axis
        X = np.moveaxis(message.data, message.dims.index(axis), 0)
        template = self.state.out_template

        if not X.shape[0]:
            return template

        if self._is_mat_file():
            # The MATLAB model expects each sample flattened in Fortran
            # order and scaled by 1e-6. Reversing the non-sample axes and
            # then C-reshaping equals per-sample flatten(order="F").
            f_order = (0, *range(X.ndim - 1, 0, -1))
            features = np.transpose(X, f_order).reshape(X.shape[0], -1) * 1e-6
        else:
            # This creates a copy.
            features = X.reshape(X.shape[0], -1)
        pred_probas = self.state.lda.predict_proba(features)

        # Fresh axis carrying this message's offset; the template axis (shared
        # with an upstream message) is left untouched.
        out_axis = replace(template.axes[axis], offset=message.axes[axis].offset)
        return replace(
            template,
            data=pred_probas,
            axes={**template.axes, axis: out_axis},
        )
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class SLDA(BaseTransformerUnit[SLDASettings, AxisArray, ClassifierMessage, SLDATransformer]):
    """ezmsg Unit wrapping :class:`SLDATransformer` (AxisArray in, ClassifierMessage out)."""

    SETTINGS = SLDASettings
|
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
"""Self-supervised regression framework and LRR implementation.
|
|
2
|
+
|
|
3
|
+
This module provides a general framework for self-supervised channel
|
|
4
|
+
regression via :class:`SelfSupervisedRegressionTransformer`, and a
|
|
5
|
+
concrete implementation — Linear Regression Rereferencing (LRR) — via
|
|
6
|
+
:class:`LRRTransformer`.
|
|
7
|
+
|
|
8
|
+
**Framework.** The base class accumulates the channel covariance
|
|
9
|
+
``C = X^T X`` and solves per-cluster ridge regressions to obtain a weight
|
|
10
|
+
matrix *W*. Subclasses define what to *do* with *W* by implementing
|
|
11
|
+
:meth:`~SelfSupervisedRegressionTransformer._on_weights_updated` and
|
|
12
|
+
:meth:`~SelfSupervisedRegressionTransformer._process`.
|
|
13
|
+
|
|
14
|
+
**LRR.** For each channel *c*, predict it from the other channels in its
|
|
15
|
+
cluster via ridge regression, then subtract the prediction::
|
|
16
|
+
|
|
17
|
+
y = X - X @ W = X @ (I - W)
|
|
18
|
+
|
|
19
|
+
The effective weight matrix ``I - W`` is passed to
|
|
20
|
+
:class:`~ezmsg.sigproc.affinetransform.AffineTransformTransformer`, which
|
|
21
|
+
automatically exploits block-diagonal structure when ``channel_clusters``
|
|
22
|
+
are provided.
|
|
23
|
+
|
|
24
|
+
**Fitting.** Given data matrix *X* of shape ``(samples, channels)``, the
|
|
25
|
+
sufficient statistic is the channel covariance ``C = X^T X``. When
|
|
26
|
+
``incremental=True`` (default), *C* is accumulated across
|
|
27
|
+
:meth:`~SelfSupervisedRegressionTransformer.partial_fit` calls.
|
|
28
|
+
|
|
29
|
+
**Solving.** Within each cluster the weight matrix *W* is obtained from
|
|
30
|
+
the inverse of the (ridge-regularised) cluster covariance
|
|
31
|
+
``C_inv = (C_cluster + lambda * I)^{-1}`` using the block-inverse identity::
|
|
32
|
+
|
|
33
|
+
W[:, c] = -C_inv[:, c] / C_inv[c, c], diag(W) = 0
|
|
34
|
+
|
|
35
|
+
This replaces the naive per-channel Cholesky loop with a single matrix
|
|
36
|
+
inverse per cluster, keeping the linear algebra in the source array
|
|
37
|
+
namespace so that GPU-backed arrays benefit from device-side computation.
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
from __future__ import annotations
|
|
41
|
+
|
|
42
|
+
import os
|
|
43
|
+
import typing
|
|
44
|
+
from abc import abstractmethod
|
|
45
|
+
from pathlib import Path
|
|
46
|
+
|
|
47
|
+
import ezmsg.core as ez
|
|
48
|
+
import numpy as np
|
|
49
|
+
from array_api_compat import get_namespace
|
|
50
|
+
from ezmsg.baseproc import (
|
|
51
|
+
BaseAdaptiveTransformer,
|
|
52
|
+
BaseAdaptiveTransformerUnit,
|
|
53
|
+
processor_state,
|
|
54
|
+
)
|
|
55
|
+
from ezmsg.baseproc.protocols import SettingsType, StateType
|
|
56
|
+
from ezmsg.sigproc.affinetransform import (
|
|
57
|
+
AffineTransformSettings,
|
|
58
|
+
AffineTransformTransformer,
|
|
59
|
+
)
|
|
60
|
+
from ezmsg.sigproc.util.array import array_device, xp_create
|
|
61
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
62
|
+
|
|
63
|
+
# ---------------------------------------------------------------------------
|
|
64
|
+
# Base: Self-supervised regression
|
|
65
|
+
# ---------------------------------------------------------------------------
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SelfSupervisedRegressionSettings(ez.Settings):
    """Settings common to all self-supervised regression modes."""

    # A str path is user-expanded and made absolute, then read with
    # comma-delimited ``np.loadtxt``; the result is converted to float64
    # whenever the transformer state is reset.
    weights: np.ndarray | str | Path | None = None
    """Pre-calculated weight matrix *W* or path to a CSV file (``np.loadtxt``
    compatible). If provided, the transformer is ready immediately."""

    axis: str | None = None
    """Channel axis name. ``None`` defaults to the last dimension."""

    # Indices are checked against the channel count on state reset; clusters
    # with fewer than two channels receive no regression weights.
    channel_clusters: list[list[int]] | None = None
    """Per-cluster regression. ``None`` treats all channels as one cluster."""

    # Added to the cluster covariance diagonal only when strictly positive.
    ridge_lambda: float = 0.0
    """Ridge (L2) regularisation parameter."""

    incremental: bool = True
    """When ``True``, accumulate ``X^T X`` across :meth:`partial_fit` calls.
    When ``False``, each call replaces the previous statistics."""
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@processor_state
class SelfSupervisedRegressionState:
    """Mutable state shared by all self-supervised regression transformers."""

    # Accumulated channel covariance X^T X; None until the first fit.
    cxx: object | None = None  # Array API; namespace matches source data.
    # Number of samples folded into ``cxx``.
    n_samples: int = 0
    # Regression weight matrix W (solved W has a zero diagonal); None until
    # fitted or loaded from settings.
    weights: object | None = None  # Array API; namespace matches cxx.
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class SelfSupervisedRegressionTransformer(
    BaseAdaptiveTransformer[SettingsType, AxisArray, AxisArray, StateType],
    typing.Generic[SettingsType, StateType],
):
    """Abstract base for self-supervised regression transformers.

    Subclasses must implement:

    * :meth:`_on_weights_updated` — called whenever the weight matrix *W* is
      (re)computed, so the subclass can build whatever internal transform it
      needs (e.g. ``I - W`` for LRR).
    * :meth:`_process` — the per-message transform step.
    """

    # -- message hash / state management ------------------------------------

    def _hash_message(self, message: AxisArray) -> int:
        """Hash on (key, channel count); a change triggers :meth:`_reset_state`."""
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        return hash((message.key, message.data.shape[axis_idx]))

    def _reset_state(self, message: AxisArray) -> None:
        """Clear accumulated statistics and (re)load pre-calculated weights.

        Raises:
            ValueError: if ``channel_clusters`` is invalid for the message's
                channel count, or if pre-calculated weights are not an
                ``(n_channels, n_channels)`` matrix.
        """
        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        n_channels = message.data.shape[axis_idx]

        self._validate_clusters(n_channels)
        self._state.cxx = None
        self._state.n_samples = 0
        self._state.weights = None

        # If pre-calculated weights are provided, load and go.
        weights = self.settings.weights
        if weights is None:
            return
        if isinstance(weights, str):
            weights = Path(os.path.abspath(os.path.expanduser(weights)))
        if isinstance(weights, Path):
            # ndmin=2 keeps a 1x1 file 2-D instead of collapsing to a scalar.
            weights = np.loadtxt(weights, delimiter=",", ndmin=2)
        weights = np.asarray(weights, dtype=np.float64)
        if weights.shape != (n_channels, n_channels):
            raise ValueError(
                f"Pre-calculated weights have shape {weights.shape}; expected ({n_channels}, {n_channels})."
            )
        self._state.weights = weights
        self._on_weights_updated()

    # -- cluster validation --------------------------------------------------

    def _validate_clusters(self, n_channels: int) -> None:
        """Raise ``ValueError`` if ``channel_clusters`` is invalid for *n_channels*.

        Every index must lie in ``[0, n_channels)`` and no channel may appear
        more than once: :meth:`_solve_weights` sums per-cluster results into a
        single matrix, so a repeated channel would mix several regressions.
        An empty cluster list is accepted (nothing is regressed).
        """
        clusters = self.settings.channel_clusters
        if clusters is None:
            return
        indices = [idx for group in clusters for idx in group]
        if any(idx < 0 or idx >= n_channels for idx in indices):
            raise ValueError(f"channel_clusters contains out-of-range indices (valid range: 0..{n_channels - 1})")
        if len(set(indices)) != len(indices):
            raise ValueError("channel_clusters must not repeat a channel index")

    # -- weight solving ------------------------------------------------------

    def _solve_weights(self, cxx):
        """Solve all per-channel ridge regressions via matrix inverse.

        Uses the block-inverse identity: for target channel *c* with
        references *r*, ``w_c = -C_inv[r, c] / C_inv[c, c]`` where
        ``C_inv = (C_cluster + λI)⁻¹``. This replaces the per-channel
        Cholesky loop with one matrix inverse per cluster.

        All computation stays in the source array namespace so that
        GPU-backed arrays benefit from device-side execution. Cluster
        results are scattered into the full matrix via a selection-matrix
        multiply (``S @ W_cluster @ S^T``) to avoid numpy fancy indexing.

        Channels whose ``C_inv[c, c]`` is zero or non-finite (e.g. a flat
        channel under the pseudo-inverse fallback) are left unregressed
        (``W[:, c] = 0``) rather than producing NaN.

        Returns weight matrix *W* in the same namespace as *cxx*, with
        ``diag(W) == 0``.
        """
        xp = get_namespace(cxx)
        dev = array_device(cxx)
        n = cxx.shape[0]

        clusters = self.settings.channel_clusters
        if clusters is None:
            clusters = [list(range(n))]

        W = xp_create(xp.zeros, (n, n), dtype=cxx.dtype, device=dev)
        eye_n = xp_create(xp.eye, n, dtype=cxx.dtype, device=dev)

        for cluster in clusters:
            k = len(cluster)
            if k <= 1:
                continue

            idx_xp = xp.asarray(cluster) if dev is None else xp.asarray(cluster, device=dev)
            eye_k = xp_create(xp.eye, k, dtype=cxx.dtype, device=dev)

            # Extract cluster sub-covariance (stays on device)
            sub = xp.take(xp.take(cxx, idx_xp, axis=0), idx_xp, axis=1)

            if self.settings.ridge_lambda > 0:
                sub = sub + self.settings.ridge_lambda * eye_k

            # One inverse per cluster; fall back to the pseudo-inverse when
            # the backend rejects a singular matrix.
            try:
                sub_inv = xp.linalg.inv(sub)
            except Exception:
                sub_inv = xp.linalg.pinv(sub)

            # Diagonal via element-wise product with identity
            diag_vals = xp.sum(sub_inv * eye_k, axis=0)

            # Guard the division: a NaN here would survive the diagonal
            # zeroing (NaN * 0 == NaN) and the 0 * NaN terms of the
            # selection-matrix products would spread it into all of W.
            usable = xp.isfinite(diag_vals) & (diag_vals != 0)
            safe_diag = xp.where(usable, diag_vals, xp.ones_like(diag_vals))

            # w_c = -C_inv[:, c] / C_inv[c, c], vectorised over all c
            W_cluster = -(sub_inv / xp.reshape(safe_diag, (1, k)))
            W_cluster = xp.where(xp.reshape(usable, (1, k)), W_cluster, xp.zeros_like(W_cluster))

            # Zero the diagonal
            W_cluster = W_cluster * (1.0 - eye_k)

            # Scatter into full W. The direct add is only valid when the
            # cluster is exactly 0..n-1 in order; any other ordering needs
            # the permuting selection matrix.
            if list(cluster) == list(range(n)):
                W = W + W_cluster
            else:
                # Selection matrix: columns of eye(n) at cluster indices
                S = xp.take(eye_n, idx_xp, axis=1)  # (n, k)
                W = W + xp.matmul(S, xp.matmul(W_cluster, xp.permute_dims(S, (1, 0))))

        return W

    # -- partial_fit (self-supervised, accepts AxisArray) --------------------

    def partial_fit(self, message: AxisArray) -> None:  # type: ignore[override]
        """Update ``X^T X`` from an unlabelled *message* and re-solve *W*.

        Messages containing any NaN are ignored entirely. Messages with no
        samples only trigger the key/channel-count state reset.
        """
        xp = get_namespace(message.data)

        if xp.any(xp.isnan(message.data)):
            return

        # Hash check / state reset
        msg_hash = self._hash_message(message)
        if self._hash != msg_hash:
            self._reset_state(message)
            self._hash = msg_hash

        axis = self.settings.axis or message.dims[-1]
        axis_idx = message.get_axis_idx(axis)
        data = message.data

        # Move channel axis to last, flatten to 2-D
        if axis_idx != data.ndim - 1:
            perm = list(range(data.ndim))
            perm.append(perm.pop(axis_idx))
            data = xp.permute_dims(data, perm)

        n_channels = data.shape[-1]
        X = xp.reshape(data, (-1, n_channels))
        if X.shape[0] == 0:
            # No samples: an all-zero covariance would either change nothing
            # (incremental) or wipe the statistics and yield a degenerate solve.
            return

        # Covariance stays in the source namespace for accumulation.
        cxx_new = xp.matmul(xp.permute_dims(X, (1, 0)), X)

        if self.settings.incremental and self._state.cxx is not None:
            self._state.cxx = self._state.cxx + cxx_new
        else:
            self._state.cxx = cxx_new
        self._state.n_samples += int(X.shape[0])

        self._state.weights = self._solve_weights(self._state.cxx)
        self._on_weights_updated()

    # -- convenience APIs ----------------------------------------------------

    def fit(self, X: np.ndarray) -> None:
        """Batch fit from a raw array-like (samples x channels), as float64 numpy.

        Replaces any previously accumulated statistics.

        NOTE(review): this does not update the message hash, so a later
        ``partial_fit`` whose message hash differs will reset this state.
        """
        X = np.asarray(X, dtype=np.float64)
        n_channels = X.shape[-1]
        self._validate_clusters(n_channels)
        X = X.reshape(-1, n_channels)
        self._state.cxx = X.T @ X
        self._state.n_samples = X.shape[0]
        self._state.weights = self._solve_weights(self._state.cxx)
        self._on_weights_updated()

    def fit_transform(self, message: AxisArray) -> AxisArray:
        """Convenience: ``partial_fit`` then ``_process``."""
        self.partial_fit(message)
        return self._process(message)

    # -- abstract hooks for subclasses ---------------------------------------

    @abstractmethod
    def _on_weights_updated(self) -> None:
        """Called after ``self._state.weights`` has been set/updated.

        Subclasses should build or refresh whatever internal transform
        object they need for :meth:`_process`.
        """
        ...

    @abstractmethod
    def _process(self, message: AxisArray) -> AxisArray: ...
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# ---------------------------------------------------------------------------
|
|
289
|
+
# Concrete: Linear Regression Rereferencing (LRR)
|
|
290
|
+
# ---------------------------------------------------------------------------
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class LRRSettings(SelfSupervisedRegressionSettings):
    """Settings for :class:`LRRTransformer`."""

    # Forwarded unchanged to AffineTransformSettings when LRRTransformer
    # builds its internal affine stage.
    min_cluster_size: int = 32
    """Passed to :class:`AffineTransformTransformer` for the block-diagonal
    merge threshold."""
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
@processor_state
class LRRState(SelfSupervisedRegressionState):
    """State for :class:`LRRTransformer`: regression state plus the affine stage."""

    # Built or updated from ``I - W`` by LRRTransformer._on_weights_updated;
    # None until weights are available (cleared on every state reset).
    affine: AffineTransformTransformer | None = None
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
class LRRTransformer(
    SelfSupervisedRegressionTransformer[LRRSettings, LRRState],
):
    """Linear Regression Rereferencing (LRR) as an adaptive transformer.

    Training via ``partial_fit`` consumes plain, unlabelled
    :class:`AxisArray` messages. Whenever the regression weights *W* change,
    the effective matrix ``I - W`` is pushed into an internal
    :class:`AffineTransformTransformer`, which performs the actual transform.
    """

    # -- state management ----------------------------------------------------

    def _reset_state(self, message: AxisArray) -> None:
        # Drop the affine stage *before* the base reset: the base may reload
        # pre-calculated weights and call _on_weights_updated, which must then
        # construct a fresh transformer instead of updating a stale one.
        self._state.affine = None
        super()._reset_state(message)

    # -- weights -> affine transform -----------------------------------------

    def _on_weights_updated(self) -> None:
        W = self._state.weights
        xp = get_namespace(W)
        identity = xp_create(xp.eye, W.shape[0], dtype=W.dtype, device=array_device(W))
        effective = identity - W

        current = self._state.affine
        if current is None:
            # First weights since the last reset: build the affine stage.
            self._state.affine = AffineTransformTransformer(
                AffineTransformSettings(
                    weights=effective,
                    axis=self.settings.axis,
                    channel_clusters=self.settings.channel_clusters,
                    min_cluster_size=self.settings.min_cluster_size,
                )
            )
            return

        # Reuse the existing stage; avoids a full reset on every partial_fit.
        current.set_weights(effective)

    # -- transform -----------------------------------------------------------

    def _process(self, message: AxisArray) -> AxisArray:
        affine = self._state.affine
        if affine is not None:
            return affine(message)
        raise RuntimeError(
            "LRRTransformer has not been fitted. Call partial_fit() or provide pre-calculated weights."
        )
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
class LRRUnit(
    BaseAdaptiveTransformerUnit[
        LRRSettings,
        AxisArray,
        AxisArray,
        LRRTransformer,
    ],
):
    """ezmsg Unit wrapping :class:`LRRTransformer`.

    Follows the :class:`BaseAdaptiveDecompUnit` pattern — accepts
    :class:`AxisArray` (not :class:`SampleMessage`) for self-supervised
    training via ``INPUT_SAMPLE``.
    """

    SETTINGS = LRRSettings

    # Training input: plain AxisArray messages, no labels required.
    INPUT_SAMPLE = ez.InputStream(AxisArray)

    @ez.subscriber(INPUT_SAMPLE)
    async def on_sample(self, msg: AxisArray) -> None:
        # Each training message updates the processor's regression weights;
        # nothing is published from this handler.
        await self.processor.apartial_fit(msg)
|