ezmsg-learn 1.0-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/learn/__version__.py +2 -2
- ezmsg/learn/dim_reduce/adaptive_decomp.py +9 -19
- ezmsg/learn/dim_reduce/incremental_decomp.py +8 -16
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
- ezmsg/learn/linear_model/linear_regressor.py +4 -0
- ezmsg/learn/linear_model/sgd.py +6 -2
- ezmsg/learn/linear_model/slda.py +7 -1
- ezmsg/learn/model/mlp.py +8 -14
- ezmsg/learn/model/refit_kalman.py +17 -49
- ezmsg/learn/nlin_model/mlp.py +5 -1
- ezmsg/learn/process/adaptive_linear_regressor.py +20 -36
- ezmsg/learn/process/base.py +12 -31
- ezmsg/learn/process/linear_regressor.py +13 -18
- ezmsg/learn/process/mlp_old.py +18 -31
- ezmsg/learn/process/refit_kalman.py +8 -13
- ezmsg/learn/process/rnn.py +14 -36
- ezmsg/learn/process/sgd.py +94 -109
- ezmsg/learn/process/sklearn.py +17 -51
- ezmsg/learn/process/slda.py +6 -15
- ezmsg/learn/process/ssr.py +374 -0
- ezmsg/learn/process/torch.py +12 -29
- ezmsg/learn/process/transformer.py +11 -19
- ezmsg/learn/util.py +5 -4
- {ezmsg_learn-1.0.dist-info → ezmsg_learn-1.2.0.dist-info}/METADATA +5 -9
- ezmsg_learn-1.2.0.dist-info/RECORD +38 -0
- {ezmsg_learn-1.0.dist-info → ezmsg_learn-1.2.0.dist-info}/WHEEL +1 -1
- ezmsg_learn-1.2.0.dist-info/licenses/LICENSE +21 -0
- ezmsg_learn-1.0.dist-info/RECORD +0 -36
ezmsg/learn/process/sklearn.py
CHANGED
@@ -5,12 +5,11 @@ import typing
 import ezmsg.core as ez
 import numpy as np
 import pandas as pd
-from ezmsg.
+from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -45,11 +44,7 @@ class SklearnModelState:
     chan_ax: AxisArray.CoordinateAxis | None = None
 
 
-class SklearnModelProcessor(
-    BaseAdaptiveTransformer[
-        SklearnModelSettings, AxisArray, AxisArray, SklearnModelState
-    ]
-):
+class SklearnModelProcessor(BaseAdaptiveTransformer[SklearnModelSettings, AxisArray, AxisArray, SklearnModelState]):
     """
     Processor that wraps a scikit-learn, River, or HMMLearn model for use in the ezmsg framework.
 
@@ -115,40 +110,30 @@ class SklearnModelProcessor(
             if hasattr(self._state.model, "n_features_in_"):
                 expected = self._state.model.n_features_in_
                 if expected != n_input:
-                    raise ValueError(
-                        f"Model expects {expected} features, but got {n_input}"
-                    )
+                    raise ValueError(f"Model expects {expected} features, but got {n_input}")
         else:
             # No checkpoint, initialize from scratch
             self._init_model()
 
-    def partial_fit(self, message:
-        X = message.
-        y = message.trigger.value
+    def partial_fit(self, message: AxisArray) -> None:
+        X = message.data
+        y = message.attrs["trigger"].value
         if self._state.model is None:
-            self._reset_state(message
+            self._reset_state(message)
         if hasattr(self._state.model, "partial_fit"):
             kwargs = {}
             if self.settings.partial_fit_classes is not None:
                 kwargs["classes"] = self.settings.partial_fit_classes
             self._state.model.partial_fit(X, y, **kwargs)
         elif hasattr(self._state.model, "learn_many"):
-            df_X = pd.DataFrame(
-                {
-                    k: v
-                    for k, v in zip(
-                        message.sample.axes["ch"].data, message.sample.data.T
-                    )
-                }
-            )
+            df_X = pd.DataFrame({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
             name = (
-                message.trigger.value.axes["ch"].data[0]
-                if hasattr(message.trigger.value, "axes")
-                and "ch" in message.trigger.value.axes
+                message.attrs["trigger"].value.axes["ch"].data[0]
+                if hasattr(message.attrs["trigger"].value, "axes") and "ch" in message.attrs["trigger"].value.axes
                 else "target"
             )
             ser_y = pd.Series(
-                data=np.asarray(message.trigger.value.data).flatten(),
+                data=np.asarray(message.attrs["trigger"].value.data).flatten(),
                 name=name,
             )
             self._state.model.learn_many(df_X, ser_y)
@@ -158,9 +143,7 @@ class SklearnModelProcessor(
                 features = {f"f{i}": xi[i] for i in range(len(xi))}
                 self._state.model.learn_one(features, yi)
         else:
-            raise NotImplementedError(
-                "Model does not support partial_fit or learn_many"
-            )
+            raise NotImplementedError("Model does not support partial_fit or learn_many")
 
     def fit(self, X: np.ndarray, y: np.ndarray) -> None:
         if self._state.model is None:
@@ -192,9 +175,7 @@ class SklearnModelProcessor(
 
     def _process(self, message: AxisArray) -> AxisArray:
         if self._state.model is None:
-            raise RuntimeError(
-                "Model has not been fit yet. Call `fit()` or `partial_fit()` before processing."
-            )
+            raise RuntimeError("Model has not been fit yet. Call `fit()` or `partial_fit()` before processing.")
         X = message.data
         original_shape = X.shape
         n_input = X.shape[message.get_axis_idx("ch")]
@@ -204,9 +185,7 @@ class SklearnModelProcessor(
         if hasattr(self._state.model, "n_features_in_"):
            expected = self._state.model.n_features_in_
            if expected != n_input:
-                raise ValueError(
-                    f"Model expects {expected} features, but got {n_input}"
-                )
+                raise ValueError(f"Model expects {expected} features, but got {n_input}")
 
        if hasattr(self._state.model, "predict"):
            y_pred = self._state.model.predict(X)
@@ -216,14 +195,7 @@ class SklearnModelProcessor(
            y_pred = np.array(list(y_pred))
        elif hasattr(self._state.model, "predict_one"):
            # river's random forest does not support predict_many
-            y_pred = np.array(
-                [
-                    self._state.model.predict_one(
-                        {f"f{i}": xi[i] for i in range(len(xi))}
-                    )
-                    for xi in X
-                ]
-            )
+            y_pred = np.array([self._state.model.predict_one({f"f{i}": xi[i] for i in range(len(xi))}) for xi in X])
        else:
            raise NotImplementedError("Model does not support predict or predict_many")
 
@@ -235,9 +207,7 @@ class SklearnModelProcessor(
        y_pred = y_pred.reshape(output_shape)
 
        if self._state.chan_ax is None:
-            self._state.chan_ax = AxisArray.CoordinateAxis(
-                data=np.arange(output_shape[1]), dims=["ch"]
-            )
+            self._state.chan_ax = AxisArray.CoordinateAxis(data=np.arange(output_shape[1]), dims=["ch"])
 
        return replace(
            message,
@@ -246,11 +216,7 @@ class SklearnModelProcessor(
        )
 
 
-class SklearnModelUnit(
-    BaseAdaptiveTransformerUnit[
-        SklearnModelSettings, AxisArray, AxisArray, SklearnModelProcessor
-    ]
-):
+class SklearnModelUnit(BaseAdaptiveTransformerUnit[SklearnModelSettings, AxisArray, AxisArray, SklearnModelProcessor]):
     """
     Unit wrapper for the `SklearnModelProcessor`.
 
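The practical upshot of this change is that `SklearnModelProcessor.partial_fit` now takes a plain `AxisArray` whose `attrs["trigger"]` holds the training target, rather than a `SampleMessage`. The sketch below shows one way such a training message might be assembled. It is illustrative only: `Trigger` is a stand-in dataclass (the real object only needs a `.value` attribute), `processor` is assumed to be an already-constructed `SklearnModelProcessor`, and it assumes `AxisArray` can be constructed with an `attrs` dict as the updated code reads it.

```python
from dataclasses import dataclass

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray


@dataclass
class Trigger:
    """Placeholder for the object stored under attrs["trigger"]; only .value is read."""
    value: object


# 16 time samples x 8 channels of features, with per-sample class labels.
X = np.random.randn(16, 8)
msg = AxisArray(
    data=X,
    dims=["time", "ch"],
    axes={"ch": AxisArray.CoordinateAxis(data=np.arange(8), dims=["ch"])},
    attrs={"trigger": Trigger(value=np.zeros(16, dtype=int))},
)

# processor.partial_fit(msg)  # adapts the wrapped sklearn/River model in place
# out = processor(msg)        # inference path; returns predictions as an AxisArray
```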
ezmsg/learn/process/slda.py
CHANGED
@@ -2,10 +2,10 @@ import typing
 
 import ezmsg.core as ez
 import numpy as np
-from ezmsg.
+from ezmsg.baseproc import (
     BaseStatefulTransformer,
-    processor_state,
     BaseTransformerUnit,
+    processor_state,
 )
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
@@ -25,9 +25,7 @@ class SLDAState:
     out_template: typing.Optional[ClassifierMessage] = None
 
 
-class SLDATransformer(
-    BaseStatefulTransformer[SLDASettings, AxisArray, ClassifierMessage, SLDAState]
-):
+class SLDATransformer(BaseStatefulTransformer[SLDASettings, AxisArray, ClassifierMessage, SLDAState]):
     def _reset_state(self, message: AxisArray) -> None:
         if self.settings.settings_path[-4:] == ".mat":
             # Expects a very specific format from a specific project. Not for general use.
@@ -67,9 +65,7 @@ class SLDATransformer(
             dims=[self.settings.axis, "classes"],
             axes={
                 self.settings.axis: message.axes[self.settings.axis],
-                "classes": AxisArray.CoordinateAxis(
-                    data=np.array(out_labels), dims=["classes"]
-                ),
+                "classes": AxisArray.CoordinateAxis(data=np.array(out_labels), dims=["classes"]),
             },
             labels=out_labels,
             key=message.key,
@@ -80,10 +76,7 @@ class SLDATransformer(
         X = np.moveaxis(message.data, samp_ax_idx, 0)
 
         if X.shape[0]:
-            if (
-                isinstance(self.settings.settings_path, str)
-                and self.settings.settings_path[-4:] == ".mat"
-            ):
+            if isinstance(self.settings.settings_path, str) and self.settings.settings_path[-4:] == ".mat":
                 # Assumes F-contiguous weights
                 pred_probas = []
                 for samp in X:
@@ -113,7 +106,5 @@ class SLDATransformer(
         return self.state.out_template
 
 
-class SLDA(
-    BaseTransformerUnit[SLDASettings, AxisArray, ClassifierMessage, SLDATransformer]
-):
+class SLDA(BaseTransformerUnit[SLDASettings, AxisArray, ClassifierMessage, SLDATransformer]):
     SETTINGS = SLDASettings
ezmsg/learn/process/ssr.py
ADDED
@@ -0,0 +1,374 @@
+"""Self-supervised regression framework and LRR implementation.
+
+This module provides a general framework for self-supervised channel
+regression via :class:`SelfSupervisedRegressionTransformer`, and a
+concrete implementation — Linear Regression Rereferencing (LRR) — via
+:class:`LRRTransformer`.
+
+**Framework.** The base class accumulates the channel covariance
+``C = X^T X`` and solves per-cluster ridge regressions to obtain a weight
+matrix *W*. Subclasses define what to *do* with *W* by implementing
+:meth:`~SelfSupervisedRegressionTransformer._on_weights_updated` and
+:meth:`~SelfSupervisedRegressionTransformer._process`.
+
+**LRR.** For each channel *c*, predict it from the other channels in its
+cluster via ridge regression, then subtract the prediction::
+
+    y = X - X @ W = X @ (I - W)
+
+The effective weight matrix ``I - W`` is passed to
+:class:`~ezmsg.sigproc.affinetransform.AffineTransformTransformer`, which
+automatically exploits block-diagonal structure when ``channel_clusters``
+are provided.
+
+**Fitting.** Given data matrix *X* of shape ``(samples, channels)``, the
+sufficient statistic is the channel covariance ``C = X^T X``. When
+``incremental=True`` (default), *C* is accumulated across
+:meth:`~SelfSupervisedRegressionTransformer.partial_fit` calls.
+
+**Solving.** Within each cluster the weight matrix *W* is obtained from
+the inverse of the (ridge-regularised) cluster covariance
+``C_inv = (C_cluster + lambda * I)^{-1}`` using the block-inverse identity::
+
+    W[:, c] = -C_inv[:, c] / C_inv[c, c], diag(W) = 0
+
+This replaces the naive per-channel Cholesky loop with a single matrix
+inverse per cluster, keeping the linear algebra in the source array
+namespace so that GPU-backed arrays benefit from device-side computation.
+"""
+
+from __future__ import annotations
+
+import os
+import typing
+from abc import abstractmethod
+from pathlib import Path
+
+import ezmsg.core as ez
+import numpy as np
+from array_api_compat import get_namespace
+from ezmsg.baseproc import (
+    BaseAdaptiveTransformer,
+    BaseAdaptiveTransformerUnit,
+    processor_state,
+)
+from ezmsg.baseproc.protocols import SettingsType, StateType
+from ezmsg.sigproc.affinetransform import (
+    AffineTransformSettings,
+    AffineTransformTransformer,
+)
+from ezmsg.sigproc.util.array import array_device, xp_create
+from ezmsg.util.messages.axisarray import AxisArray
+
+# ---------------------------------------------------------------------------
+# Base: Self-supervised regression
+# ---------------------------------------------------------------------------
+
+
+class SelfSupervisedRegressionSettings(ez.Settings):
+    """Settings common to all self-supervised regression modes."""
+
+    weights: np.ndarray | str | Path | None = None
+    """Pre-calculated weight matrix *W* or path to a CSV file (``np.loadtxt``
+    compatible). If provided, the transformer is ready immediately."""
+
+    axis: str | None = None
+    """Channel axis name. ``None`` defaults to the last dimension."""
+
+    channel_clusters: list[list[int]] | None = None
+    """Per-cluster regression. ``None`` treats all channels as one cluster."""
+
+    ridge_lambda: float = 0.0
+    """Ridge (L2) regularisation parameter."""
+
+    incremental: bool = True
+    """When ``True``, accumulate ``X^T X`` across :meth:`partial_fit` calls.
+    When ``False``, each call replaces the previous statistics."""
+
+
+@processor_state
+class SelfSupervisedRegressionState:
+    cxx: object | None = None  # Array API; namespace matches source data.
+    n_samples: int = 0
+    weights: object | None = None  # Array API; namespace matches cxx.
+
+
+class SelfSupervisedRegressionTransformer(
+    BaseAdaptiveTransformer[SettingsType, AxisArray, AxisArray, StateType],
+    typing.Generic[SettingsType, StateType],
+):
+    """Abstract base for self-supervised regression transformers.
+
+    Subclasses must implement:
+
+    * :meth:`_on_weights_updated` — called whenever the weight matrix *W* is
+      (re)computed, so the subclass can build whatever internal transform it
+      needs (e.g. ``I - W`` for LRR).
+    * :meth:`_process` — the per-message transform step.
+    """
+
+    # -- message hash / state management ------------------------------------
+
+    def _hash_message(self, message: AxisArray) -> int:
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        return hash((message.key, message.data.shape[axis_idx]))
+
+    def _reset_state(self, message: AxisArray) -> None:
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        n_channels = message.data.shape[axis_idx]
+
+        self._validate_clusters(n_channels)
+        self._state.cxx = None
+        self._state.n_samples = 0
+        self._state.weights = None
+
+        # If pre-calculated weights are provided, load and go.
+        weights = self.settings.weights
+        if weights is not None:
+            if isinstance(weights, str):
+                weights = Path(os.path.abspath(os.path.expanduser(weights)))
+            if isinstance(weights, Path):
+                weights = np.loadtxt(weights, delimiter=",")
+            weights = np.asarray(weights, dtype=np.float64)
+            self._state.weights = weights
+            self._on_weights_updated()
+
+    # -- cluster validation --------------------------------------------------
+
+    def _validate_clusters(self, n_channels: int) -> None:
+        """Raise if any cluster index is out of range."""
+        clusters = self.settings.channel_clusters
+        if clusters is None:
+            return
+        all_indices = np.concatenate([np.asarray(g) for g in clusters])
+        if np.any((all_indices < 0) | (all_indices >= n_channels)):
+            raise ValueError(f"channel_clusters contains out-of-range indices (valid range: 0..{n_channels - 1})")
+
+    # -- weight solving ------------------------------------------------------
+
+    def _solve_weights(self, cxx):
+        """Solve all per-channel ridge regressions via matrix inverse.
+
+        Uses the block-inverse identity: for target channel *c* with
+        references *r*, ``w_c = -C_inv[r, c] / C_inv[c, c]`` where
+        ``C_inv = (C_cluster + λI)⁻¹``. This replaces the per-channel
+        Cholesky loop with one matrix inverse per cluster.
+
+        All computation stays in the source array namespace so that
+        GPU-backed arrays benefit from device-side execution. Cluster
+        results are scattered into the full matrix via a selection-matrix
+        multiply (``S @ W_cluster @ S^T``) to avoid numpy fancy indexing.
+
+        Returns weight matrix *W* in the same namespace as *cxx*, with
+        ``diag(W) == 0``.
+        """
+        xp = get_namespace(cxx)
+        dev = array_device(cxx)
+        n = cxx.shape[0]
+
+        clusters = self.settings.channel_clusters
+        if clusters is None:
+            clusters = [list(range(n))]
+
+        W = xp_create(xp.zeros, (n, n), dtype=cxx.dtype, device=dev)
+        eye_n = xp_create(xp.eye, n, dtype=cxx.dtype, device=dev)
+
+        for cluster in clusters:
+            k = len(cluster)
+            if k <= 1:
+                continue
+
+            idx_xp = xp.asarray(cluster) if dev is None else xp.asarray(cluster, device=dev)
+            eye_k = xp_create(xp.eye, k, dtype=cxx.dtype, device=dev)
+
+            # Extract cluster sub-covariance (stays on device)
+            sub = xp.take(xp.take(cxx, idx_xp, axis=0), idx_xp, axis=1)
+
+            if self.settings.ridge_lambda > 0:
+                sub = sub + self.settings.ridge_lambda * eye_k
+
+            # One inverse per cluster
+            try:
+                sub_inv = xp.linalg.inv(sub)
+            except Exception:
+                sub_inv = xp.linalg.pinv(sub)
+
+            # Diagonal via element-wise product with identity
+            diag_vals = xp.sum(sub_inv * eye_k, axis=0)
+
+            # w_c = -C_inv[:, c] / C_inv[c, c], vectorised over all c
+            W_cluster = -(sub_inv / xp.reshape(diag_vals, (1, k)))
+
+            # Zero the diagonal
+            W_cluster = W_cluster * (1.0 - eye_k)
+
+            # Scatter into full W
+            if k == n:
+                W = W + W_cluster
+            else:
+                # Selection matrix: columns of eye(n) at cluster indices
+                S = xp.take(eye_n, idx_xp, axis=1)  # (n, k)
+                W = W + xp.matmul(S, xp.matmul(W_cluster, xp.permute_dims(S, (1, 0))))
+
+        return W
+
+    # -- partial_fit (self-supervised, accepts AxisArray) --------------------
+
+    def partial_fit(self, message: AxisArray) -> None:  # type: ignore[override]
+        xp = get_namespace(message.data)
+
+        if xp.any(xp.isnan(message.data)):
+            return
+
+        # Hash check / state reset
+        msg_hash = self._hash_message(message)
+        if self._hash != msg_hash:
+            self._reset_state(message)
+            self._hash = msg_hash
+
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        data = message.data
+
+        # Move channel axis to last, flatten to 2-D
+        if axis_idx != data.ndim - 1:
+            perm = list(range(data.ndim))
+            perm.append(perm.pop(axis_idx))
+            data = xp.permute_dims(data, perm)
+
+        n_channels = data.shape[-1]
+        X = xp.reshape(data, (-1, n_channels))
+
+        # Covariance stays in the source namespace for accumulation.
+        cxx_new = xp.matmul(xp.permute_dims(X, (1, 0)), X)
+
+        if self.settings.incremental and self._state.cxx is not None:
+            self._state.cxx = self._state.cxx + cxx_new
+        else:
+            self._state.cxx = cxx_new
+        self._state.n_samples += int(X.shape[0])
+
+        self._state.weights = self._solve_weights(self._state.cxx)
+        self._on_weights_updated()
+
+    # -- convenience APIs ----------------------------------------------------
+
+    def fit(self, X: np.ndarray) -> None:
+        """Batch fit from a raw numpy array (samples x channels)."""
+        n_channels = X.shape[-1]
+        self._validate_clusters(n_channels)
+        X = np.asarray(X, dtype=np.float64).reshape(-1, n_channels)
+        self._state.cxx = X.T @ X
+        self._state.n_samples = X.shape[0]
+        self._state.weights = self._solve_weights(self._state.cxx)
+        self._on_weights_updated()
+
+    def fit_transform(self, message: AxisArray) -> AxisArray:
+        """Convenience: ``partial_fit`` then ``_process``."""
+        self.partial_fit(message)
+        return self._process(message)
+
+    # -- abstract hooks for subclasses ---------------------------------------
+
+    @abstractmethod
+    def _on_weights_updated(self) -> None:
+        """Called after ``self._state.weights`` has been set/updated.
+
+        Subclasses should build or refresh whatever internal transform
+        object they need for :meth:`_process`.
+        """
+        ...
+
+    @abstractmethod
+    def _process(self, message: AxisArray) -> AxisArray: ...
+
+
+# ---------------------------------------------------------------------------
+# Concrete: Linear Regression Rereferencing (LRR)
+# ---------------------------------------------------------------------------
+
+
+class LRRSettings(SelfSupervisedRegressionSettings):
+    """Settings for :class:`LRRTransformer`."""
+
+    min_cluster_size: int = 32
+    """Passed to :class:`AffineTransformTransformer` for the block-diagonal
+    merge threshold."""
+
+
+@processor_state
+class LRRState(SelfSupervisedRegressionState):
+    affine: AffineTransformTransformer | None = None
+
+
+class LRRTransformer(
+    SelfSupervisedRegressionTransformer[LRRSettings, LRRState],
+):
+    """Adaptive LRR transformer.
+
+    ``partial_fit`` accepts a plain :class:`AxisArray` (self-supervised),
+    and the transform step is delegated to an internal :class:`AffineTransformTransformer`.
+    """
+
+    # -- state management (clear own state, then delegate to base) ----------
+
+    def _reset_state(self, message: AxisArray) -> None:
+        self._state.affine = None
+        super()._reset_state(message)
+
+    # -- weights → affine transform -----------------------------------------
+
+    def _on_weights_updated(self) -> None:
+        xp = get_namespace(self._state.weights)
+        dev = array_device(self._state.weights)
+        n = self._state.weights.shape[0]
+        effective = xp_create(xp.eye, n, dtype=self._state.weights.dtype, device=dev) - self._state.weights
+
+        # Prefer in-place weight update when the affine transformer supports
+        # it (avoids a full _reset_state round-trip on every partial_fit).
+        if self._state.affine is not None:
+            self._state.affine.set_weights(effective)
+        else:
+            self._state.affine = AffineTransformTransformer(
+                AffineTransformSettings(
+                    weights=effective,
+                    axis=self.settings.axis,
+                    channel_clusters=self.settings.channel_clusters,
+                    min_cluster_size=self.settings.min_cluster_size,
+                )
+            )
+
+    # -- transform -----------------------------------------------------------
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        if self._state.affine is None:
+            raise RuntimeError(
+                "LRRTransformer has not been fitted. Call partial_fit() or provide pre-calculated weights."
+            )
+        return self._state.affine(message)
+
+
+class LRRUnit(
+    BaseAdaptiveTransformerUnit[
+        LRRSettings,
+        AxisArray,
+        AxisArray,
+        LRRTransformer,
+    ],
+):
+    """ezmsg Unit wrapping :class:`LRRTransformer`.
+
+    Follows the :class:`BaseAdaptiveDecompUnit` pattern — accepts
+    :class:`AxisArray` (not :class:`SampleMessage`) for self-supervised
+    training via ``INPUT_SAMPLE``.
+    """
+
+    SETTINGS = LRRSettings
+
+    INPUT_SAMPLE = ez.InputStream(AxisArray)
+
+    @ez.subscriber(INPUT_SAMPLE)
+    async def on_sample(self, msg: AxisArray) -> None:
+        await self.processor.apartial_fit(msg)
ezmsg/learn/process/torch.py
CHANGED
@@ -4,15 +4,14 @@ import typing
 import ezmsg.core as ez
 import numpy as np
 import torch
-from ezmsg.
+from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
     BaseStatefulTransformer,
     BaseTransformerUnit,
     processor_state,
 )
-from ezmsg.
-from ezmsg.sigproc.util.profile import profile_subpub
+from ezmsg.baseproc.util.profile import profile_subpub
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -113,9 +112,7 @@ class TorchProcessorMixin:
         module = importlib.import_module(module_path)
         return getattr(module, class_name)
 
-    def _infer_output_sizes(
-        self: P, model: torch.nn.Module, n_input: int
-    ) -> dict[str, int]:
+    def _infer_output_sizes(self: P, model: torch.nn.Module, n_input: int) -> dict[str, int]:
         """Simple inference to get output channel size. Override if needed."""
         dummy_input = torch.zeros(1, 1, n_input, device=self._state.device)
         with torch.no_grad():
@@ -133,9 +130,7 @@ class TorchProcessorMixin:
             weight_decay=self.settings.weight_decay,
         )
         self._state.scheduler = (
-            torch.optim.lr_scheduler.ExponentialLR(
-                self._state.optimizer, gamma=self.settings.scheduler_gamma
-            )
+            torch.optim.lr_scheduler.ExponentialLR(self._state.optimizer, gamma=self.settings.scheduler_gamma)
             if self.settings.scheduler_gamma > 0.0
             else None
         )
@@ -191,9 +186,7 @@ class TorchProcessorMixin:
         output_messages = [
             replace(
                 message,
-                data=value.cpu().numpy().squeeze(0)
-                if added_batch_dim
-                else value.cpu().numpy(),
+                data=value.cpu().numpy().squeeze(0) if added_batch_dim else value.cpu().numpy(),
                 axes={
                     **message.axes,
                     "ch": self._state.chan_ax[key],
@@ -207,9 +200,7 @@ class TorchProcessorMixin:
         return [
             replace(
                 message,
-                data=output.cpu().numpy().squeeze(0)
-                if added_batch_dim
-                else output.cpu().numpy(),
+                data=output.cpu().numpy().squeeze(0) if added_batch_dim else output.cpu().numpy(),
                 axes={
                     **message.axes,
                     "ch": self._state.chan_ax["output"],
@@ -229,11 +220,7 @@ class TorchProcessorMixin:
         else:
             model_kwargs["input_size"] = n_input
 
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else ("mps" if torch.mps.is_available() else "cpu")
-        )
+        device = "cuda" if torch.cuda.is_available() else ("mps" if torch.mps.is_available() else "cpu")
         device = self.settings.device or device
         self._state.device = torch.device(device)
 
@@ -260,9 +247,7 @@ class TorchProcessorMixin:
 
 
 class TorchSimpleProcessor(
-    BaseStatefulTransformer[
-        TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState
-    ],
+    BaseStatefulTransformer[TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState],
     TorchProcessorMixin,
     ModelInitMixin,
 ):
@@ -308,13 +293,13 @@ class TorchModelProcessor(
     def _process(self, message: AxisArray) -> list[AxisArray]:
         return self._common_process(message)
 
-    def partial_fit(self, message:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()
 
-        X = self._to_tensor(message.
+        X = self._to_tensor(message.data)
         X, batched = self._ensure_batched(X)
 
-        y_targ = message.trigger.value
+        y_targ = message.attrs["trigger"].value
         if not isinstance(y_targ, dict):
             y_targ = {"output": y_targ}
         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
@@ -339,9 +324,7 @@ class TorchModelProcessor(
         for key in y_targ.keys():
             loss_fn = loss_fns.get(key)
             if loss_fn is None:
-                raise ValueError(
-                    f"Loss function for key '{key}' is not defined in settings."
-                )
+                raise ValueError(f"Loss function for key '{key}' is not defined in settings.")
             if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
                 loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long())
             else: