ezmsg-learn 1.0-py3-none-any.whl → 1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,12 +5,11 @@ import typing
 import ezmsg.core as ez
 import numpy as np
 import pandas as pd
-from ezmsg.sigproc.base import (
+from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
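
Note: the headline change in this release is that the processor base classes now come from the new ezmsg.baseproc package rather than ezmsg.sigproc.base, and the SampleMessage import is gone. Downstream code that must run against both layouts can guard the import; a minimal sketch (the shim is not part of the package; only the two module paths are taken from this diff):

    try:
        # Import path used by ezmsg-learn 1.2.0
        from ezmsg.baseproc import BaseAdaptiveTransformer, processor_state
    except ImportError:
        # Import path used by ezmsg-learn 1.0
        from ezmsg.sigproc.base import BaseAdaptiveTransformer, processor_state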
@@ -45,11 +44,7 @@ class SklearnModelState:
     chan_ax: AxisArray.CoordinateAxis | None = None
 
 
-class SklearnModelProcessor(
-    BaseAdaptiveTransformer[
-        SklearnModelSettings, AxisArray, AxisArray, SklearnModelState
-    ]
-):
+class SklearnModelProcessor(BaseAdaptiveTransformer[SklearnModelSettings, AxisArray, AxisArray, SklearnModelState]):
     """
     Processor that wraps a scikit-learn, River, or HMMLearn model for use in the ezmsg framework.
 
@@ -115,40 +110,30 @@ class SklearnModelProcessor(
         if hasattr(self._state.model, "n_features_in_"):
             expected = self._state.model.n_features_in_
             if expected != n_input:
-                raise ValueError(
-                    f"Model expects {expected} features, but got {n_input}"
-                )
+                raise ValueError(f"Model expects {expected} features, but got {n_input}")
         else:
             # No checkpoint, initialize from scratch
             self._init_model()
 
-    def partial_fit(self, message: SampleMessage) -> None:
-        X = message.sample.data
-        y = message.trigger.value
+    def partial_fit(self, message: AxisArray) -> None:
+        X = message.data
+        y = message.attrs["trigger"].value
         if self._state.model is None:
-            self._reset_state(message.sample)
+            self._reset_state(message)
         if hasattr(self._state.model, "partial_fit"):
             kwargs = {}
             if self.settings.partial_fit_classes is not None:
                 kwargs["classes"] = self.settings.partial_fit_classes
             self._state.model.partial_fit(X, y, **kwargs)
         elif hasattr(self._state.model, "learn_many"):
-            df_X = pd.DataFrame(
-                {
-                    k: v
-                    for k, v in zip(
-                        message.sample.axes["ch"].data, message.sample.data.T
-                    )
-                }
-            )
+            df_X = pd.DataFrame({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
             name = (
-                message.trigger.value.axes["ch"].data[0]
-                if hasattr(message.trigger.value, "axes")
-                and "ch" in message.trigger.value.axes
+                message.attrs["trigger"].value.axes["ch"].data[0]
+                if hasattr(message.attrs["trigger"].value, "axes") and "ch" in message.attrs["trigger"].value.axes
                 else "target"
             )
             ser_y = pd.Series(
-                data=np.asarray(message.trigger.value.data).flatten(),
+                data=np.asarray(message.attrs["trigger"].value.data).flatten(),
                 name=name,
             )
             self._state.model.learn_many(df_X, ser_y)
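
Note the signature change above: partial_fit now takes a plain AxisArray whose training label rides in attrs["trigger"] (any object exposing a .value attribute), replacing the old SampleMessage wrapper. A hypothetical caller-side sketch, assuming AxisArray accepts an attrs mapping as the new code implies; SimpleNamespace stands in for whatever trigger type the upstream sampler actually emits, and proc is an already-configured SklearnModelProcessor:

    import numpy as np
    from types import SimpleNamespace
    from ezmsg.util.messages.axisarray import AxisArray

    # 128 samples x 8 channels of training data; per-sample labels are
    # attached under attrs["trigger"] as an object exposing .value.
    train_msg = AxisArray(
        data=np.random.randn(128, 8),
        dims=["time", "ch"],
        axes={"ch": AxisArray.CoordinateAxis(data=np.arange(8), dims=["ch"])},
        attrs={"trigger": SimpleNamespace(value=np.zeros(128, dtype=int))},
    )
    proc.partial_fit(train_msg)  # proc: a configured SklearnModelProcessor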
@@ -158,9 +143,7 @@ class SklearnModelProcessor(
                 features = {f"f{i}": xi[i] for i in range(len(xi))}
                 self._state.model.learn_one(features, yi)
         else:
-            raise NotImplementedError(
-                "Model does not support partial_fit or learn_many"
-            )
+            raise NotImplementedError("Model does not support partial_fit or learn_many")
 
     def fit(self, X: np.ndarray, y: np.ndarray) -> None:
         if self._state.model is None:
@@ -192,9 +175,7 @@
 
     def _process(self, message: AxisArray) -> AxisArray:
         if self._state.model is None:
-            raise RuntimeError(
-                "Model has not been fit yet. Call `fit()` or `partial_fit()` before processing."
-            )
+            raise RuntimeError("Model has not been fit yet. Call `fit()` or `partial_fit()` before processing.")
         X = message.data
         original_shape = X.shape
         n_input = X.shape[message.get_axis_idx("ch")]
@@ -204,9 +185,7 @@
         if hasattr(self._state.model, "n_features_in_"):
             expected = self._state.model.n_features_in_
             if expected != n_input:
-                raise ValueError(
-                    f"Model expects {expected} features, but got {n_input}"
-                )
+                raise ValueError(f"Model expects {expected} features, but got {n_input}")
 
         if hasattr(self._state.model, "predict"):
             y_pred = self._state.model.predict(X)
@@ -216,14 +195,7 @@
             y_pred = np.array(list(y_pred))
         elif hasattr(self._state.model, "predict_one"):
             # river's random forest does not support predict_many
-            y_pred = np.array(
-                [
-                    self._state.model.predict_one(
-                        {f"f{i}": xi[i] for i in range(len(xi))}
-                    )
-                    for xi in X
-                ]
-            )
+            y_pred = np.array([self._state.model.predict_one({f"f{i}": xi[i] for i in range(len(xi))}) for xi in X])
         else:
             raise NotImplementedError("Model does not support predict or predict_many")
 
@@ -235,9 +207,7 @@
         y_pred = y_pred.reshape(output_shape)
 
         if self._state.chan_ax is None:
-            self._state.chan_ax = AxisArray.CoordinateAxis(
-                data=np.arange(output_shape[1]), dims=["ch"]
-            )
+            self._state.chan_ax = AxisArray.CoordinateAxis(data=np.arange(output_shape[1]), dims=["ch"])
 
         return replace(
             message,
@@ -246,11 +216,7 @@
         )
 
 
-class SklearnModelUnit(
-    BaseAdaptiveTransformerUnit[
-        SklearnModelSettings, AxisArray, AxisArray, SklearnModelProcessor
-    ]
-):
+class SklearnModelUnit(BaseAdaptiveTransformerUnit[SklearnModelSettings, AxisArray, AxisArray, SklearnModelProcessor]):
     """
     Unit wrapper for the `SklearnModelProcessor`.
 
@@ -2,10 +2,10 @@ import typing
 
 import ezmsg.core as ez
 import numpy as np
-from ezmsg.sigproc.base import (
+from ezmsg.baseproc import (
     BaseStatefulTransformer,
-    processor_state,
     BaseTransformerUnit,
+    processor_state,
 )
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
@@ -25,9 +25,7 @@ class SLDAState:
     out_template: typing.Optional[ClassifierMessage] = None
 
 
-class SLDATransformer(
-    BaseStatefulTransformer[SLDASettings, AxisArray, ClassifierMessage, SLDAState]
-):
+class SLDATransformer(BaseStatefulTransformer[SLDASettings, AxisArray, ClassifierMessage, SLDAState]):
     def _reset_state(self, message: AxisArray) -> None:
         if self.settings.settings_path[-4:] == ".mat":
             # Expects a very specific format from a specific project. Not for general use.
@@ -67,9 +65,7 @@ class SLDATransformer(
             dims=[self.settings.axis, "classes"],
             axes={
                 self.settings.axis: message.axes[self.settings.axis],
-                "classes": AxisArray.CoordinateAxis(
-                    data=np.array(out_labels), dims=["classes"]
-                ),
+                "classes": AxisArray.CoordinateAxis(data=np.array(out_labels), dims=["classes"]),
             },
             labels=out_labels,
             key=message.key,
@@ -80,10 +76,7 @@
         X = np.moveaxis(message.data, samp_ax_idx, 0)
 
         if X.shape[0]:
-            if (
-                isinstance(self.settings.settings_path, str)
-                and self.settings.settings_path[-4:] == ".mat"
-            ):
+            if isinstance(self.settings.settings_path, str) and self.settings.settings_path[-4:] == ".mat":
                 # Assumes F-contiguous weights
                 pred_probas = []
                 for samp in X:
@@ -113,7 +106,5 @@
         return self.state.out_template
 
 
-class SLDA(
-    BaseTransformerUnit[SLDASettings, AxisArray, ClassifierMessage, SLDATransformer]
-):
+class SLDA(BaseTransformerUnit[SLDASettings, AxisArray, ClassifierMessage, SLDATransformer]):
     SETTINGS = SLDASettings
@@ -0,0 +1,374 @@
+"""Self-supervised regression framework and LRR implementation.
+
+This module provides a general framework for self-supervised channel
+regression via :class:`SelfSupervisedRegressionTransformer`, and a
+concrete implementation — Linear Regression Rereferencing (LRR) — via
+:class:`LRRTransformer`.
+
+**Framework.** The base class accumulates the channel covariance
+``C = X^T X`` and solves per-cluster ridge regressions to obtain a weight
+matrix *W*. Subclasses define what to *do* with *W* by implementing
+:meth:`~SelfSupervisedRegressionTransformer._on_weights_updated` and
+:meth:`~SelfSupervisedRegressionTransformer._process`.
+
+**LRR.** For each channel *c*, predict it from the other channels in its
+cluster via ridge regression, then subtract the prediction::
+
+    y = X - X @ W = X @ (I - W)
+
+The effective weight matrix ``I - W`` is passed to
+:class:`~ezmsg.sigproc.affinetransform.AffineTransformTransformer`, which
+automatically exploits block-diagonal structure when ``channel_clusters``
+are provided.
+
+**Fitting.** Given data matrix *X* of shape ``(samples, channels)``, the
+sufficient statistic is the channel covariance ``C = X^T X``. When
+``incremental=True`` (default), *C* is accumulated across
+:meth:`~SelfSupervisedRegressionTransformer.partial_fit` calls.
+
+**Solving.** Within each cluster the weight matrix *W* is obtained from
+the inverse of the (ridge-regularised) cluster covariance
+``C_inv = (C_cluster + lambda * I)^{-1}`` using the block-inverse identity::
+
+    W[:, c] = -C_inv[:, c] / C_inv[c, c], diag(W) = 0
+
+This replaces the naive per-channel Cholesky loop with a single matrix
+inverse per cluster, keeping the linear algebra in the source array
+namespace so that GPU-backed arrays benefit from device-side computation.
+"""
+
+from __future__ import annotations
+
+import os
+import typing
+from abc import abstractmethod
+from pathlib import Path
+
+import ezmsg.core as ez
+import numpy as np
+from array_api_compat import get_namespace
+from ezmsg.baseproc import (
+    BaseAdaptiveTransformer,
+    BaseAdaptiveTransformerUnit,
+    processor_state,
+)
+from ezmsg.baseproc.protocols import SettingsType, StateType
+from ezmsg.sigproc.affinetransform import (
+    AffineTransformSettings,
+    AffineTransformTransformer,
+)
+from ezmsg.sigproc.util.array import array_device, xp_create
+from ezmsg.util.messages.axisarray import AxisArray
+
+# ---------------------------------------------------------------------------
+# Base: Self-supervised regression
+# ---------------------------------------------------------------------------
+
+
+class SelfSupervisedRegressionSettings(ez.Settings):
+    """Settings common to all self-supervised regression modes."""
+
+    weights: np.ndarray | str | Path | None = None
+    """Pre-calculated weight matrix *W* or path to a CSV file (``np.loadtxt``
+    compatible). If provided, the transformer is ready immediately."""
+
+    axis: str | None = None
+    """Channel axis name. ``None`` defaults to the last dimension."""
+
+    channel_clusters: list[list[int]] | None = None
+    """Per-cluster regression. ``None`` treats all channels as one cluster."""
+
+    ridge_lambda: float = 0.0
+    """Ridge (L2) regularisation parameter."""
+
+    incremental: bool = True
+    """When ``True``, accumulate ``X^T X`` across :meth:`partial_fit` calls.
+    When ``False``, each call replaces the previous statistics."""
+
+
+@processor_state
+class SelfSupervisedRegressionState:
+    cxx: object | None = None  # Array API; namespace matches source data.
+    n_samples: int = 0
+    weights: object | None = None  # Array API; namespace matches cxx.
+
+
+class SelfSupervisedRegressionTransformer(
+    BaseAdaptiveTransformer[SettingsType, AxisArray, AxisArray, StateType],
+    typing.Generic[SettingsType, StateType],
+):
+    """Abstract base for self-supervised regression transformers.
+
+    Subclasses must implement:
+
+    * :meth:`_on_weights_updated` — called whenever the weight matrix *W* is
+      (re)computed, so the subclass can build whatever internal transform it
+      needs (e.g. ``I - W`` for LRR).
+    * :meth:`_process` — the per-message transform step.
+    """
+
+    # -- message hash / state management ------------------------------------
+
+    def _hash_message(self, message: AxisArray) -> int:
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        return hash((message.key, message.data.shape[axis_idx]))
+
+    def _reset_state(self, message: AxisArray) -> None:
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        n_channels = message.data.shape[axis_idx]
+
+        self._validate_clusters(n_channels)
+        self._state.cxx = None
+        self._state.n_samples = 0
+        self._state.weights = None
+
+        # If pre-calculated weights are provided, load and go.
+        weights = self.settings.weights
+        if weights is not None:
+            if isinstance(weights, str):
+                weights = Path(os.path.abspath(os.path.expanduser(weights)))
+            if isinstance(weights, Path):
+                weights = np.loadtxt(weights, delimiter=",")
+            weights = np.asarray(weights, dtype=np.float64)
+            self._state.weights = weights
+            self._on_weights_updated()
+
+    # -- cluster validation --------------------------------------------------
+
+    def _validate_clusters(self, n_channels: int) -> None:
+        """Raise if any cluster index is out of range."""
+        clusters = self.settings.channel_clusters
+        if clusters is None:
+            return
+        all_indices = np.concatenate([np.asarray(g) for g in clusters])
+        if np.any((all_indices < 0) | (all_indices >= n_channels)):
+            raise ValueError(f"channel_clusters contains out-of-range indices (valid range: 0..{n_channels - 1})")
+
+    # -- weight solving ------------------------------------------------------
+
+    def _solve_weights(self, cxx):
+        """Solve all per-channel ridge regressions via matrix inverse.
+
+        Uses the block-inverse identity: for target channel *c* with
+        references *r*, ``w_c = -C_inv[r, c] / C_inv[c, c]`` where
+        ``C_inv = (C_cluster + λI)⁻¹``. This replaces the per-channel
+        Cholesky loop with one matrix inverse per cluster.
+
+        All computation stays in the source array namespace so that
+        GPU-backed arrays benefit from device-side execution. Cluster
+        results are scattered into the full matrix via a selection-matrix
+        multiply (``S @ W_cluster @ S^T``) to avoid numpy fancy indexing.
+
+        Returns weight matrix *W* in the same namespace as *cxx*, with
+        ``diag(W) == 0``.
+        """
+        xp = get_namespace(cxx)
+        dev = array_device(cxx)
+        n = cxx.shape[0]
+
+        clusters = self.settings.channel_clusters
+        if clusters is None:
+            clusters = [list(range(n))]
+
+        W = xp_create(xp.zeros, (n, n), dtype=cxx.dtype, device=dev)
+        eye_n = xp_create(xp.eye, n, dtype=cxx.dtype, device=dev)
+
+        for cluster in clusters:
+            k = len(cluster)
+            if k <= 1:
+                continue
+
+            idx_xp = xp.asarray(cluster) if dev is None else xp.asarray(cluster, device=dev)
+            eye_k = xp_create(xp.eye, k, dtype=cxx.dtype, device=dev)
+
+            # Extract cluster sub-covariance (stays on device)
+            sub = xp.take(xp.take(cxx, idx_xp, axis=0), idx_xp, axis=1)
+
+            if self.settings.ridge_lambda > 0:
+                sub = sub + self.settings.ridge_lambda * eye_k
+
+            # One inverse per cluster
+            try:
+                sub_inv = xp.linalg.inv(sub)
+            except Exception:
+                sub_inv = xp.linalg.pinv(sub)
+
+            # Diagonal via element-wise product with identity
+            diag_vals = xp.sum(sub_inv * eye_k, axis=0)
+
+            # w_c = -C_inv[:, c] / C_inv[c, c], vectorised over all c
+            W_cluster = -(sub_inv / xp.reshape(diag_vals, (1, k)))
+
+            # Zero the diagonal
+            W_cluster = W_cluster * (1.0 - eye_k)
+
+            # Scatter into full W
+            if k == n:
+                W = W + W_cluster
+            else:
+                # Selection matrix: columns of eye(n) at cluster indices
+                S = xp.take(eye_n, idx_xp, axis=1)  # (n, k)
+                W = W + xp.matmul(S, xp.matmul(W_cluster, xp.permute_dims(S, (1, 0))))
+
+        return W
+
+    # -- partial_fit (self-supervised, accepts AxisArray) --------------------
+
+    def partial_fit(self, message: AxisArray) -> None:  # type: ignore[override]
+        xp = get_namespace(message.data)
+
+        if xp.any(xp.isnan(message.data)):
+            return
+
+        # Hash check / state reset
+        msg_hash = self._hash_message(message)
+        if self._hash != msg_hash:
+            self._reset_state(message)
+            self._hash = msg_hash
+
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        data = message.data
+
+        # Move channel axis to last, flatten to 2-D
+        if axis_idx != data.ndim - 1:
+            perm = list(range(data.ndim))
+            perm.append(perm.pop(axis_idx))
+            data = xp.permute_dims(data, perm)
+
+        n_channels = data.shape[-1]
+        X = xp.reshape(data, (-1, n_channels))
+
+        # Covariance stays in the source namespace for accumulation.
+        cxx_new = xp.matmul(xp.permute_dims(X, (1, 0)), X)
+
+        if self.settings.incremental and self._state.cxx is not None:
+            self._state.cxx = self._state.cxx + cxx_new
+        else:
+            self._state.cxx = cxx_new
+        self._state.n_samples += int(X.shape[0])
+
+        self._state.weights = self._solve_weights(self._state.cxx)
+        self._on_weights_updated()
+
+    # -- convenience APIs ----------------------------------------------------
+
+    def fit(self, X: np.ndarray) -> None:
+        """Batch fit from a raw numpy array (samples x channels)."""
+        n_channels = X.shape[-1]
+        self._validate_clusters(n_channels)
+        X = np.asarray(X, dtype=np.float64).reshape(-1, n_channels)
+        self._state.cxx = X.T @ X
+        self._state.n_samples = X.shape[0]
+        self._state.weights = self._solve_weights(self._state.cxx)
+        self._on_weights_updated()
+
+    def fit_transform(self, message: AxisArray) -> AxisArray:
+        """Convenience: ``partial_fit`` then ``_process``."""
+        self.partial_fit(message)
+        return self._process(message)
+
+    # -- abstract hooks for subclasses ---------------------------------------
+
+    @abstractmethod
+    def _on_weights_updated(self) -> None:
+        """Called after ``self._state.weights`` has been set/updated.
+
+        Subclasses should build or refresh whatever internal transform
+        object they need for :meth:`_process`.
+        """
+        ...
+
+    @abstractmethod
+    def _process(self, message: AxisArray) -> AxisArray: ...
+
+
+# ---------------------------------------------------------------------------
+# Concrete: Linear Regression Rereferencing (LRR)
+# ---------------------------------------------------------------------------
+
+
+class LRRSettings(SelfSupervisedRegressionSettings):
+    """Settings for :class:`LRRTransformer`."""
+
+    min_cluster_size: int = 32
+    """Passed to :class:`AffineTransformTransformer` for the block-diagonal
+    merge threshold."""
+
+
+@processor_state
+class LRRState(SelfSupervisedRegressionState):
+    affine: AffineTransformTransformer | None = None
+
+
+class LRRTransformer(
+    SelfSupervisedRegressionTransformer[LRRSettings, LRRState],
+):
+    """Adaptive LRR transformer.
+
+    ``partial_fit`` accepts a plain :class:`AxisArray` (self-supervised),
+    and the transform step is delegated to an internal :class:`AffineTransformTransformer`.
+    """
+
+    # -- state management (clear own state, then delegate to base) ----------
+
+    def _reset_state(self, message: AxisArray) -> None:
+        self._state.affine = None
+        super()._reset_state(message)
+
+    # -- weights → affine transform -----------------------------------------
+
+    def _on_weights_updated(self) -> None:
+        xp = get_namespace(self._state.weights)
+        dev = array_device(self._state.weights)
+        n = self._state.weights.shape[0]
+        effective = xp_create(xp.eye, n, dtype=self._state.weights.dtype, device=dev) - self._state.weights
+
+        # Prefer in-place weight update when the affine transformer supports
+        # it (avoids a full _reset_state round-trip on every partial_fit).
+        if self._state.affine is not None:
+            self._state.affine.set_weights(effective)
+        else:
+            self._state.affine = AffineTransformTransformer(
+                AffineTransformSettings(
+                    weights=effective,
+                    axis=self.settings.axis,
+                    channel_clusters=self.settings.channel_clusters,
+                    min_cluster_size=self.settings.min_cluster_size,
+                )
+            )
+
+    # -- transform -----------------------------------------------------------
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        if self._state.affine is None:
+            raise RuntimeError(
+                "LRRTransformer has not been fitted. Call partial_fit() or provide pre-calculated weights."
+            )
+        return self._state.affine(message)
+
+
+class LRRUnit(
+    BaseAdaptiveTransformerUnit[
+        LRRSettings,
+        AxisArray,
+        AxisArray,
+        LRRTransformer,
+    ],
+):
+    """ezmsg Unit wrapping :class:`LRRTransformer`.
+
+    Follows the :class:`BaseAdaptiveDecompUnit` pattern — accepts
+    :class:`AxisArray` (not :class:`SampleMessage`) for self-supervised
+    training via ``INPUT_SAMPLE``.
+    """
+
+    SETTINGS = LRRSettings
+
+    INPUT_SAMPLE = ez.InputStream(AxisArray)
+
+    @ez.subscriber(INPUT_SAMPLE)
+    async def on_sample(self, msg: AxisArray) -> None:
+        await self.processor.apartial_fit(msg)
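
The block-inverse identity at the heart of _solve_weights can be checked independently of the framework. A standalone numpy verification (not part of the package) that a single inverse of the regularised cluster covariance reproduces the per-channel ridge regressions:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((1000, 4))  # (samples, channels)
    lam = 0.1
    C = X.T @ X                         # channel covariance C = X^T X

    # Naive route: one ridge solve per target channel c, predicting c
    # from the remaining channels r.
    n = C.shape[0]
    W_naive = np.zeros((n, n))
    for c in range(n):
        r = [i for i in range(n) if i != c]
        A_rr = C[np.ix_(r, r)] + lam * np.eye(n - 1)
        W_naive[r, c] = np.linalg.solve(A_rr, C[r, c])

    # Block-inverse route: one inverse for the whole cluster, then
    # W[:, c] = -C_inv[:, c] / C_inv[c, c] with a zeroed diagonal.
    C_inv = np.linalg.inv(C + lam * np.eye(n))
    W_fast = -C_inv / np.diag(C_inv)
    np.fill_diagonal(W_fast, 0.0)

    assert np.allclose(W_naive, W_fast)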
@@ -4,15 +4,14 @@ import typing
 import ezmsg.core as ez
 import numpy as np
 import torch
-from ezmsg.sigproc.base import (
+from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
     BaseStatefulTransformer,
     BaseTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
-from ezmsg.sigproc.util.profile import profile_subpub
+from ezmsg.baseproc.util.profile import profile_subpub
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -113,9 +112,7 @@ class TorchProcessorMixin:
         module = importlib.import_module(module_path)
         return getattr(module, class_name)
 
-    def _infer_output_sizes(
-        self: P, model: torch.nn.Module, n_input: int
-    ) -> dict[str, int]:
+    def _infer_output_sizes(self: P, model: torch.nn.Module, n_input: int) -> dict[str, int]:
         """Simple inference to get output channel size. Override if needed."""
         dummy_input = torch.zeros(1, 1, n_input, device=self._state.device)
         with torch.no_grad():
@@ -133,9 +130,7 @@
             weight_decay=self.settings.weight_decay,
         )
         self._state.scheduler = (
-            torch.optim.lr_scheduler.ExponentialLR(
-                self._state.optimizer, gamma=self.settings.scheduler_gamma
-            )
+            torch.optim.lr_scheduler.ExponentialLR(self._state.optimizer, gamma=self.settings.scheduler_gamma)
             if self.settings.scheduler_gamma > 0.0
             else None
         )
@@ -191,9 +186,7 @@
         output_messages = [
             replace(
                 message,
-                data=value.cpu().numpy().squeeze(0)
-                if added_batch_dim
-                else value.cpu().numpy(),
+                data=value.cpu().numpy().squeeze(0) if added_batch_dim else value.cpu().numpy(),
                 axes={
                     **message.axes,
                     "ch": self._state.chan_ax[key],
@@ -207,9 +200,7 @@
         return [
             replace(
                 message,
-                data=output.cpu().numpy().squeeze(0)
-                if added_batch_dim
-                else output.cpu().numpy(),
+                data=output.cpu().numpy().squeeze(0) if added_batch_dim else output.cpu().numpy(),
                 axes={
                     **message.axes,
                     "ch": self._state.chan_ax["output"],
@@ -229,11 +220,7 @@
         else:
             model_kwargs["input_size"] = n_input
 
-        device = (
-            "cuda"
-            if torch.cuda.is_available()
-            else ("mps" if torch.mps.is_available() else "cpu")
-        )
+        device = "cuda" if torch.cuda.is_available() else ("mps" if torch.mps.is_available() else "cpu")
         device = self.settings.device or device
         self._state.device = torch.device(device)
 
@@ -260,9 +247,7 @@
 
 
 class TorchSimpleProcessor(
-    BaseStatefulTransformer[
-        TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState
-    ],
+    BaseStatefulTransformer[TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState],
     TorchProcessorMixin,
     ModelInitMixin,
 ):
@@ -308,13 +293,13 @@ class TorchModelProcessor(
     def _process(self, message: AxisArray) -> list[AxisArray]:
         return self._common_process(message)
 
-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()
 
-        X = self._to_tensor(message.sample.data)
+        X = self._to_tensor(message.data)
         X, batched = self._ensure_batched(X)
 
-        y_targ = message.trigger.value
+        y_targ = message.attrs["trigger"].value
         if not isinstance(y_targ, dict):
             y_targ = {"output": y_targ}
         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
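
As with the sklearn processor, the torch partial_fit now reads its target from attrs["trigger"], with one extra wrinkle: the trigger value may be a single target array or a dict of per-head targets, and a bare array is normalised to {"output": ...}. A hypothetical sketch (replace is the ezmsg.util.messages.util helper already imported in this file; last_msg and torch_proc are assumed to exist):

    import numpy as np
    from types import SimpleNamespace

    # Single-head and multi-head trigger payloads; both reach the same
    # per-key loss loop after normalisation inside partial_fit.
    single = SimpleNamespace(value=np.zeros(128, dtype=np.int64))
    multi = SimpleNamespace(value={"output": np.zeros(128, dtype=np.int64)})

    train_msg = replace(last_msg, attrs={**last_msg.attrs, "trigger": multi})
    torch_proc.partial_fit(train_msg)  # torch_proc: a configured TorchModelProcessor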
@@ -339,9 +324,7 @@
         for key in y_targ.keys():
             loss_fn = loss_fns.get(key)
             if loss_fn is None:
-                raise ValueError(
-                    f"Loss function for key '{key}' is not defined in settings."
-                )
+                raise ValueError(f"Loss function for key '{key}' is not defined in settings.")
             if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
                 loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long())
             else: