ezmsg-learn 1.1.0 (ezmsg_learn-1.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ezmsg/learn/__init__.py +2 -0
  2. ezmsg/learn/__version__.py +34 -0
  3. ezmsg/learn/dim_reduce/__init__.py +0 -0
  4. ezmsg/learn/dim_reduce/adaptive_decomp.py +274 -0
  5. ezmsg/learn/dim_reduce/incremental_decomp.py +173 -0
  6. ezmsg/learn/linear_model/__init__.py +1 -0
  7. ezmsg/learn/linear_model/adaptive_linear_regressor.py +12 -0
  8. ezmsg/learn/linear_model/cca.py +1 -0
  9. ezmsg/learn/linear_model/linear_regressor.py +9 -0
  10. ezmsg/learn/linear_model/sgd.py +9 -0
  11. ezmsg/learn/linear_model/slda.py +12 -0
  12. ezmsg/learn/model/__init__.py +0 -0
  13. ezmsg/learn/model/cca.py +122 -0
  14. ezmsg/learn/model/mlp.py +127 -0
  15. ezmsg/learn/model/mlp_old.py +49 -0
  16. ezmsg/learn/model/refit_kalman.py +369 -0
  17. ezmsg/learn/model/rnn.py +160 -0
  18. ezmsg/learn/model/transformer.py +175 -0
  19. ezmsg/learn/nlin_model/__init__.py +1 -0
  20. ezmsg/learn/nlin_model/mlp.py +10 -0
  21. ezmsg/learn/process/__init__.py +0 -0
  22. ezmsg/learn/process/adaptive_linear_regressor.py +142 -0
  23. ezmsg/learn/process/base.py +154 -0
  24. ezmsg/learn/process/linear_regressor.py +95 -0
  25. ezmsg/learn/process/mlp_old.py +188 -0
  26. ezmsg/learn/process/refit_kalman.py +403 -0
  27. ezmsg/learn/process/rnn.py +245 -0
  28. ezmsg/learn/process/sgd.py +117 -0
  29. ezmsg/learn/process/sklearn.py +241 -0
  30. ezmsg/learn/process/slda.py +110 -0
  31. ezmsg/learn/process/ssr.py +374 -0
  32. ezmsg/learn/process/torch.py +362 -0
  33. ezmsg/learn/process/transformer.py +215 -0
  34. ezmsg/learn/util.py +67 -0
  35. ezmsg_learn-1.1.0.dist-info/METADATA +30 -0
  36. ezmsg_learn-1.1.0.dist-info/RECORD +38 -0
  37. ezmsg_learn-1.1.0.dist-info/WHEEL +4 -0
  38. ezmsg_learn-1.1.0.dist-info/licenses/LICENSE +21 -0
ezmsg/learn/process/slda.py
@@ -0,0 +1,110 @@
+import typing
+
+import ezmsg.core as ez
+import numpy as np
+from ezmsg.baseproc import (
+    BaseStatefulTransformer,
+    BaseTransformerUnit,
+    processor_state,
+)
+from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
+
+from ..util import ClassifierMessage
+
+
+class SLDASettings(ez.Settings):
+    settings_path: str
+    axis: str = "time"
+
+
+@processor_state
+class SLDAState:
+    lda: LDA
+    out_template: typing.Optional[ClassifierMessage] = None
+
+
+class SLDATransformer(BaseStatefulTransformer[SLDASettings, AxisArray, ClassifierMessage, SLDAState]):
+    def _reset_state(self, message: AxisArray) -> None:
+        if self.settings.settings_path[-4:] == ".mat":
+            # Expects a very specific format from a specific project. Not for general use.
+            import scipy.io as sio
+
+            matlab_sLDA = sio.loadmat(self.settings.settings_path, squeeze_me=True)
+            temp_weights = matlab_sLDA["weights"][1, 1:]
+            temp_intercept = matlab_sLDA["weights"][1, 0]
+
+            # Create weights and use zeros for channels we do not keep.
+            channels = matlab_sLDA["channels"] - 4
+            channels -= channels[0]  # Offsets are wrong somehow.
+            n_channels = message.data.shape[message.dims.index("ch")]
+            valid_indices = [ch for ch in channels if ch < n_channels]
+            full_weights = np.zeros(n_channels)
+            full_weights[valid_indices] = temp_weights[: len(valid_indices)]
+
+            lda = LDA(solver="lsqr", shrinkage="auto")
+            lda.classes_ = np.asarray([0, 1])
+            lda.coef_ = np.expand_dims(full_weights, axis=0)
+            lda.intercept_ = temp_intercept  # TODO: Is this supposed to be per-channel? Why the [1, 0]?
+            self.state.lda = lda
+            # mean = matlab_sLDA['mXtrain']
+            # std = matlab_sLDA['sXtrain']
+            # lags = matlab_sLDA['lags'] + 1
+        else:
+            import pickle
+
+            with open(self.settings.settings_path, "rb") as f:
+                self.state.lda = pickle.load(f)
+
+        # Create template ClassifierMessage using lda.classes_
+        out_labels = self.state.lda.classes_.tolist()
+        zero_shape = (0, len(out_labels))
+        self.state.out_template = ClassifierMessage(
+            data=np.zeros(zero_shape, dtype=message.data.dtype),
+            dims=[self.settings.axis, "classes"],
+            axes={
+                self.settings.axis: message.axes[self.settings.axis],
+                "classes": AxisArray.CoordinateAxis(data=np.array(out_labels), dims=["classes"]),
+            },
+            labels=out_labels,
+            key=message.key,
+        )
+
+    def _process(self, message: AxisArray) -> ClassifierMessage:
+        samp_ax_idx = message.dims.index(self.settings.axis)
+        X = np.moveaxis(message.data, samp_ax_idx, 0)
+
+        if X.shape[0]:
+            if isinstance(self.settings.settings_path, str) and self.settings.settings_path[-4:] == ".mat":
+                # Assumes F-contiguous weights
+                pred_probas = []
+                for samp in X:
+                    tmp = samp.flatten(order="F") * 1e-6
+                    tmp = np.expand_dims(tmp, axis=0)
+                    probas = self.state.lda.predict_proba(tmp)
+                    pred_probas.append(probas)
+                pred_probas = np.concatenate(pred_probas, axis=0)
+            else:
+                # This creates a copy.
+                X = X.reshape(X.shape[0], -1)
+                pred_probas = self.state.lda.predict_proba(X)
+
+            update_ax = self.state.out_template.axes[self.settings.axis]
+            update_ax.offset = message.axes[self.settings.axis].offset
+
+            return replace(
+                self.state.out_template,
+                data=pred_probas,
+                axes={
+                    **self.state.out_template.axes,
+                    # `replace` will copy the minimal set of fields
+                    self.settings.axis: replace(update_ax, offset=update_ax.offset),
+                },
+            )
+        else:
+            return self.state.out_template
+
+
+class SLDA(BaseTransformerUnit[SLDASettings, AxisArray, ClassifierMessage, SLDATransformer]):
+    SETTINGS = SLDASettings
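
For orientation, here is a minimal offline sketch of driving the SLDA transformer above. It assumes a fitted, pickled scikit-learn LDA at the hypothetical path ``my_lda.pkl``, and that ``BaseStatefulTransformer`` instances are callable with a message (the usual ezmsg-baseproc convention); treat it as illustration, not part of the package::

    import numpy as np
    from ezmsg.util.messages.axisarray import AxisArray

    from ezmsg.learn.process.slda import SLDASettings, SLDATransformer

    # Hypothetical model path; any pickled, fitted sklearn LDA with classes_ works.
    xf = SLDATransformer(settings=SLDASettings(settings_path="my_lda.pkl", axis="time"))

    msg = AxisArray(
        data=np.random.randn(10, 8),  # 10 time samples x 8 channels
        dims=["time", "ch"],
        axes={"time": AxisArray.TimeAxis(fs=100.0)},
        key="demo",
    )
    out = xf(msg)  # ClassifierMessage with (time, classes) probabilities
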
ezmsg/learn/process/ssr.py
@@ -0,0 +1,374 @@
+"""Self-supervised regression framework and LRR implementation.
+
+This module provides a general framework for self-supervised channel
+regression via :class:`SelfSupervisedRegressionTransformer`, and a
+concrete implementation — Linear Regression Rereferencing (LRR) — via
+:class:`LRRTransformer`.
+
+**Framework.** The base class accumulates the channel covariance
+``C = X^T X`` and solves per-cluster ridge regressions to obtain a weight
+matrix *W*. Subclasses define what to *do* with *W* by implementing
+:meth:`~SelfSupervisedRegressionTransformer._on_weights_updated` and
+:meth:`~SelfSupervisedRegressionTransformer._process`.
+
+**LRR.** For each channel *c*, predict it from the other channels in its
+cluster via ridge regression, then subtract the prediction::
+
+    y = X - X @ W = X @ (I - W)
+
+The effective weight matrix ``I - W`` is passed to
+:class:`~ezmsg.sigproc.affinetransform.AffineTransformTransformer`, which
+automatically exploits block-diagonal structure when ``channel_clusters``
+are provided.
+
+**Fitting.** Given data matrix *X* of shape ``(samples, channels)``, the
+sufficient statistic is the channel covariance ``C = X^T X``. When
+``incremental=True`` (default), *C* is accumulated across
+:meth:`~SelfSupervisedRegressionTransformer.partial_fit` calls.
+
+**Solving.** Within each cluster the weight matrix *W* is obtained from
+the inverse of the (ridge-regularised) cluster covariance
+``C_inv = (C_cluster + lambda * I)^{-1}`` using the block-inverse identity::
+
+    W[:, c] = -C_inv[:, c] / C_inv[c, c], diag(W) = 0
+
+This replaces the naive per-channel Cholesky loop with a single matrix
+inverse per cluster, keeping the linear algebra in the source array
+namespace so that GPU-backed arrays benefit from device-side computation.
+"""
+
+from __future__ import annotations
+
+import os
+import typing
+from abc import abstractmethod
+from pathlib import Path
+
+import ezmsg.core as ez
+import numpy as np
+from array_api_compat import get_namespace
+from ezmsg.baseproc import (
+    BaseAdaptiveTransformer,
+    BaseAdaptiveTransformerUnit,
+    processor_state,
+)
+from ezmsg.baseproc.protocols import SettingsType, StateType
+from ezmsg.sigproc.affinetransform import (
+    AffineTransformSettings,
+    AffineTransformTransformer,
+)
+from ezmsg.sigproc.util.array import array_device, xp_create
+from ezmsg.util.messages.axisarray import AxisArray
+
+# ---------------------------------------------------------------------------
+# Base: Self-supervised regression
+# ---------------------------------------------------------------------------
+
+
+class SelfSupervisedRegressionSettings(ez.Settings):
+    """Settings common to all self-supervised regression modes."""
+
+    weights: np.ndarray | str | Path | None = None
+    """Pre-calculated weight matrix *W* or path to a CSV file (``np.loadtxt``
+    compatible). If provided, the transformer is ready immediately."""
+
+    axis: str | None = None
+    """Channel axis name. ``None`` defaults to the last dimension."""
+
+    channel_clusters: list[list[int]] | None = None
+    """Per-cluster regression. ``None`` treats all channels as one cluster."""
+
+    ridge_lambda: float = 0.0
+    """Ridge (L2) regularisation parameter."""
+
+    incremental: bool = True
+    """When ``True``, accumulate ``X^T X`` across :meth:`partial_fit` calls.
+    When ``False``, each call replaces the previous statistics."""
+
+
+@processor_state
+class SelfSupervisedRegressionState:
+    cxx: object | None = None  # Array API; namespace matches source data.
+    n_samples: int = 0
+    weights: object | None = None  # Array API; namespace matches cxx.
+
+
+class SelfSupervisedRegressionTransformer(
+    BaseAdaptiveTransformer[SettingsType, AxisArray, AxisArray, StateType],
+    typing.Generic[SettingsType, StateType],
+):
+    """Abstract base for self-supervised regression transformers.
+
+    Subclasses must implement:
+
+    * :meth:`_on_weights_updated` — called whenever the weight matrix *W* is
+      (re)computed, so the subclass can build whatever internal transform it
+      needs (e.g. ``I - W`` for LRR).
+    * :meth:`_process` — the per-message transform step.
+    """
+
+    # -- message hash / state management ------------------------------------
+
+    def _hash_message(self, message: AxisArray) -> int:
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        return hash((message.key, message.data.shape[axis_idx]))
+
+    def _reset_state(self, message: AxisArray) -> None:
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        n_channels = message.data.shape[axis_idx]
+
+        self._validate_clusters(n_channels)
+        self._state.cxx = None
+        self._state.n_samples = 0
+        self._state.weights = None
+
+        # If pre-calculated weights are provided, load and go.
+        weights = self.settings.weights
+        if weights is not None:
+            if isinstance(weights, str):
+                weights = Path(os.path.abspath(os.path.expanduser(weights)))
+            if isinstance(weights, Path):
+                weights = np.loadtxt(weights, delimiter=",")
+            weights = np.asarray(weights, dtype=np.float64)
+            self._state.weights = weights
+            self._on_weights_updated()
+
+    # -- cluster validation --------------------------------------------------
+
+    def _validate_clusters(self, n_channels: int) -> None:
+        """Raise if any cluster index is out of range."""
+        clusters = self.settings.channel_clusters
+        if clusters is None:
+            return
+        all_indices = np.concatenate([np.asarray(g) for g in clusters])
+        if np.any((all_indices < 0) | (all_indices >= n_channels)):
+            raise ValueError(f"channel_clusters contains out-of-range indices (valid range: 0..{n_channels - 1})")
+
+    # -- weight solving ------------------------------------------------------
+
+    def _solve_weights(self, cxx):
+        """Solve all per-channel ridge regressions via matrix inverse.
+
+        Uses the block-inverse identity: for target channel *c* with
+        references *r*, ``w_c = -C_inv[r, c] / C_inv[c, c]`` where
+        ``C_inv = (C_cluster + λI)⁻¹``. This replaces the per-channel
+        Cholesky loop with one matrix inverse per cluster.
+
+        All computation stays in the source array namespace so that
+        GPU-backed arrays benefit from device-side execution. Cluster
+        results are scattered into the full matrix via a selection-matrix
+        multiply (``S @ W_cluster @ S^T``) to avoid numpy fancy indexing.
+
+        Returns weight matrix *W* in the same namespace as *cxx*, with
+        ``diag(W) == 0``.
+        """
+        xp = get_namespace(cxx)
+        dev = array_device(cxx)
+        n = cxx.shape[0]
+
+        clusters = self.settings.channel_clusters
+        if clusters is None:
+            clusters = [list(range(n))]
+
+        W = xp_create(xp.zeros, (n, n), dtype=cxx.dtype, device=dev)
+        eye_n = xp_create(xp.eye, n, dtype=cxx.dtype, device=dev)
+
+        for cluster in clusters:
+            k = len(cluster)
+            if k <= 1:
+                continue
+
+            idx_xp = xp.asarray(cluster) if dev is None else xp.asarray(cluster, device=dev)
+            eye_k = xp_create(xp.eye, k, dtype=cxx.dtype, device=dev)
+
+            # Extract cluster sub-covariance (stays on device)
+            sub = xp.take(xp.take(cxx, idx_xp, axis=0), idx_xp, axis=1)
+
+            if self.settings.ridge_lambda > 0:
+                sub = sub + self.settings.ridge_lambda * eye_k
+
+            # One inverse per cluster
+            try:
+                sub_inv = xp.linalg.inv(sub)
+            except Exception:
+                sub_inv = xp.linalg.pinv(sub)
+
+            # Diagonal via element-wise product with identity
+            diag_vals = xp.sum(sub_inv * eye_k, axis=0)
+
+            # w_c = -C_inv[:, c] / C_inv[c, c], vectorised over all c
+            W_cluster = -(sub_inv / xp.reshape(diag_vals, (1, k)))
+
+            # Zero the diagonal
+            W_cluster = W_cluster * (1.0 - eye_k)
+
+            # Scatter into full W
+            if k == n:
+                W = W + W_cluster
+            else:
+                # Selection matrix: columns of eye(n) at cluster indices
+                S = xp.take(eye_n, idx_xp, axis=1)  # (n, k)
+                W = W + xp.matmul(S, xp.matmul(W_cluster, xp.permute_dims(S, (1, 0))))
+
+        return W
+
+    # -- partial_fit (self-supervised, accepts AxisArray) --------------------
+
+    def partial_fit(self, message: AxisArray) -> None:  # type: ignore[override]
+        xp = get_namespace(message.data)
+
+        if xp.any(xp.isnan(message.data)):
+            return
+
+        # Hash check / state reset
+        msg_hash = self._hash_message(message)
+        if self._hash != msg_hash:
+            self._reset_state(message)
+            self._hash = msg_hash
+
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        data = message.data
+
+        # Move channel axis to last, flatten to 2-D
+        if axis_idx != data.ndim - 1:
+            perm = list(range(data.ndim))
+            perm.append(perm.pop(axis_idx))
+            data = xp.permute_dims(data, perm)
+
+        n_channels = data.shape[-1]
+        X = xp.reshape(data, (-1, n_channels))
+
+        # Covariance stays in the source namespace for accumulation.
+        cxx_new = xp.matmul(xp.permute_dims(X, (1, 0)), X)
+
+        if self.settings.incremental and self._state.cxx is not None:
+            self._state.cxx = self._state.cxx + cxx_new
+        else:
+            self._state.cxx = cxx_new
+        self._state.n_samples += int(X.shape[0])
+
+        self._state.weights = self._solve_weights(self._state.cxx)
+        self._on_weights_updated()
+
+    # -- convenience APIs ----------------------------------------------------
+
+    def fit(self, X: np.ndarray) -> None:
+        """Batch fit from a raw numpy array (samples x channels)."""
+        n_channels = X.shape[-1]
+        self._validate_clusters(n_channels)
+        X = np.asarray(X, dtype=np.float64).reshape(-1, n_channels)
+        self._state.cxx = X.T @ X
+        self._state.n_samples = X.shape[0]
+        self._state.weights = self._solve_weights(self._state.cxx)
+        self._on_weights_updated()
+
+    def fit_transform(self, message: AxisArray) -> AxisArray:
+        """Convenience: ``partial_fit`` then ``_process``."""
+        self.partial_fit(message)
+        return self._process(message)
+
+    # -- abstract hooks for subclasses ---------------------------------------
+
+    @abstractmethod
+    def _on_weights_updated(self) -> None:
+        """Called after ``self._state.weights`` has been set/updated.
+
+        Subclasses should build or refresh whatever internal transform
+        object they need for :meth:`_process`.
+        """
+        ...
+
+    @abstractmethod
+    def _process(self, message: AxisArray) -> AxisArray: ...
+
+
+# ---------------------------------------------------------------------------
+# Concrete: Linear Regression Rereferencing (LRR)
+# ---------------------------------------------------------------------------
+
+
+class LRRSettings(SelfSupervisedRegressionSettings):
+    """Settings for :class:`LRRTransformer`."""
+
+    min_cluster_size: int = 32
+    """Passed to :class:`AffineTransformTransformer` for the block-diagonal
+    merge threshold."""
+
+
+@processor_state
+class LRRState(SelfSupervisedRegressionState):
+    affine: AffineTransformTransformer | None = None
+
+
+class LRRTransformer(
+    SelfSupervisedRegressionTransformer[LRRSettings, LRRState],
+):
+    """Adaptive LRR transformer.
+
+    ``partial_fit`` accepts a plain :class:`AxisArray` (self-supervised),
+    and the transform step is delegated to an internal :class:`AffineTransformTransformer`.
+    """
+
+    # -- state management (clear own state, then delegate to base) ----------
+
+    def _reset_state(self, message: AxisArray) -> None:
+        self._state.affine = None
+        super()._reset_state(message)
+
+    # -- weights → affine transform -----------------------------------------
+
+    def _on_weights_updated(self) -> None:
+        xp = get_namespace(self._state.weights)
+        dev = array_device(self._state.weights)
+        n = self._state.weights.shape[0]
+        effective = xp_create(xp.eye, n, dtype=self._state.weights.dtype, device=dev) - self._state.weights
+
+        # Prefer in-place weight update when the affine transformer supports
+        # it (avoids a full _reset_state round-trip on every partial_fit).
+        if self._state.affine is not None:
+            self._state.affine.set_weights(effective)
+        else:
+            self._state.affine = AffineTransformTransformer(
+                AffineTransformSettings(
+                    weights=effective,
+                    axis=self.settings.axis,
+                    channel_clusters=self.settings.channel_clusters,
+                    min_cluster_size=self.settings.min_cluster_size,
+                )
+            )
+
+    # -- transform -----------------------------------------------------------
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        if self._state.affine is None:
+            raise RuntimeError(
+                "LRRTransformer has not been fitted. Call partial_fit() or provide pre-calculated weights."
+            )
+        return self._state.affine(message)
+
+
+class LRRUnit(
+    BaseAdaptiveTransformerUnit[
+        LRRSettings,
+        AxisArray,
+        AxisArray,
+        LRRTransformer,
+    ],
+):
+    """ezmsg Unit wrapping :class:`LRRTransformer`.
+
+    Follows the :class:`BaseAdaptiveDecompUnit` pattern — accepts
+    :class:`AxisArray` (not :class:`SampleMessage`) for self-supervised
+    training via ``INPUT_SAMPLE``.
+    """
+
+    SETTINGS = LRRSettings
+
+    INPUT_SAMPLE = ez.InputStream(AxisArray)
+
+    @ez.subscriber(INPUT_SAMPLE)
+    async def on_sample(self, msg: AxisArray) -> None:
+        await self.processor.apartial_fit(msg)
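
To see why ``_solve_weights`` is valid, the block-inverse identity quoted in the module docstring can be checked against explicit per-channel ridge solves in plain NumPy. A standalone sketch with synthetic data and a single cluster (illustration only, not part of the package)::

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((1000, 6))  # (samples, channels)
    C = X.T @ X                         # channel covariance: the sufficient statistic
    lam = 1e-3                          # plays the role of ridge_lambda
    C_inv = np.linalg.inv(C + lam * np.eye(6))

    # Identity used by _solve_weights: W[:, c] = -C_inv[:, c] / C_inv[c, c], diag(W) = 0
    W = -(C_inv / np.diag(C_inv)[None, :])
    np.fill_diagonal(W, 0.0)

    # Explicit per-channel ridge: regress channel c on the remaining channels r.
    for c in range(6):
        r = [i for i in range(6) if i != c]
        w_c = np.linalg.solve(C[np.ix_(r, r)] + lam * np.eye(5), C[r, c])
        assert np.allclose(W[r, c], w_c)

    Y = X - X @ W  # LRR output: X @ (I - W)

The assertion holds because, for the partitioned matrix ``A = C + lam*I``, the column of ``A^{-1}`` belonging to channel *c* is proportional to the per-channel ridge solution, so one inverse per cluster recovers every channel's regression weights at once.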