ezmsg_learn-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ezmsg/learn/__init__.py +2 -0
  2. ezmsg/learn/__version__.py +34 -0
  3. ezmsg/learn/dim_reduce/__init__.py +0 -0
  4. ezmsg/learn/dim_reduce/adaptive_decomp.py +274 -0
  5. ezmsg/learn/dim_reduce/incremental_decomp.py +173 -0
  6. ezmsg/learn/linear_model/__init__.py +1 -0
  7. ezmsg/learn/linear_model/adaptive_linear_regressor.py +12 -0
  8. ezmsg/learn/linear_model/cca.py +1 -0
  9. ezmsg/learn/linear_model/linear_regressor.py +9 -0
  10. ezmsg/learn/linear_model/sgd.py +9 -0
  11. ezmsg/learn/linear_model/slda.py +12 -0
  12. ezmsg/learn/model/__init__.py +0 -0
  13. ezmsg/learn/model/cca.py +122 -0
  14. ezmsg/learn/model/mlp.py +127 -0
  15. ezmsg/learn/model/mlp_old.py +49 -0
  16. ezmsg/learn/model/refit_kalman.py +369 -0
  17. ezmsg/learn/model/rnn.py +160 -0
  18. ezmsg/learn/model/transformer.py +175 -0
  19. ezmsg/learn/nlin_model/__init__.py +1 -0
  20. ezmsg/learn/nlin_model/mlp.py +10 -0
  21. ezmsg/learn/process/__init__.py +0 -0
  22. ezmsg/learn/process/adaptive_linear_regressor.py +142 -0
  23. ezmsg/learn/process/base.py +154 -0
  24. ezmsg/learn/process/linear_regressor.py +95 -0
  25. ezmsg/learn/process/mlp_old.py +188 -0
  26. ezmsg/learn/process/refit_kalman.py +403 -0
  27. ezmsg/learn/process/rnn.py +245 -0
  28. ezmsg/learn/process/sgd.py +117 -0
  29. ezmsg/learn/process/sklearn.py +241 -0
  30. ezmsg/learn/process/slda.py +110 -0
  31. ezmsg/learn/process/ssr.py +374 -0
  32. ezmsg/learn/process/torch.py +362 -0
  33. ezmsg/learn/process/transformer.py +215 -0
  34. ezmsg/learn/util.py +67 -0
  35. ezmsg_learn-1.1.0.dist-info/METADATA +30 -0
  36. ezmsg_learn-1.1.0.dist-info/RECORD +38 -0
  37. ezmsg_learn-1.1.0.dist-info/WHEEL +4 -0
  38. ezmsg_learn-1.1.0.dist-info/licenses/LICENSE +21 -0
ezmsg/learn/process/rnn.py
@@ -0,0 +1,245 @@
+ import typing
+
+ import ezmsg.core as ez
+ import numpy as np
+ import torch
+ from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
+ from ezmsg.baseproc.util.profile import profile_subpub
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from .base import ModelInitMixin
+ from .torch import (
+     TorchModelSettings,
+     TorchModelState,
+     TorchProcessorMixin,
+ )
+
+
+ class RNNSettings(TorchModelSettings):
+     model_class: str = "ezmsg.learn.model.rnn.RNNModel"
+     """
+     Fully qualified class path of the model to be used.
+     This should remain "ezmsg.learn.model.rnn.RNNModel" for this processor.
+     """
+     reset_hidden_on_fit: bool = True
+     """
+     Whether to reset the hidden state on each fit call.
+     If True, the hidden state is reset to zero after each fit.
+     If False, the hidden state is maintained across fit calls.
+     """
+     preserve_state_across_windows: bool | typing.Literal["auto"] = "auto"
+     """
+     Whether to preserve the hidden state across windows.
+     If True, the hidden state is preserved across windows.
+     If False, the hidden state is reset at the start of each window.
+     If "auto", preserve when time windows do not overlap; otherwise reset.
+     """
+
+
+ class RNNState(TorchModelState):
+     hx: typing.Optional[torch.Tensor] = None
+
+
+ class RNNProcessor(
+     BaseAdaptiveTransformer[RNNSettings, AxisArray, AxisArray, RNNState],
+     TorchProcessorMixin,
+     ModelInitMixin,
+ ):
+     def _infer_output_sizes(self, model: torch.nn.Module, n_input: int) -> dict[str, int]:
+         """Run a dummy forward pass to infer each output head's channel size."""
+         # 50 is an arbitrary dummy sequence length; only the channel dim matters here.
+         dummy_input = torch.zeros(1, 50, n_input, device=self._state.device)
+         with torch.no_grad():
+             output, _ = model(dummy_input)
+
+         if isinstance(output, dict):
+             return {k: v.shape[-1] for k, v in output.items()}
+         else:
+             return {"output": output.shape[-1]}
+
+     def _reset_state(self, message: AxisArray) -> None:
+         model_kwargs = dict(self.settings.model_kwargs or {})
+         self._common_reset_state(message, model_kwargs)
+         self._init_optimizer()
+         self._validate_loss_keys(list(self._state.chan_ax.keys()))
+
+         batch_size = 1 if message.data.ndim == 2 else message.data.shape[0]
+         self.reset_hidden(batch_size)
+
+     def _maybe_reset_state(self, message: AxisArray, batch_size: int) -> bool:
+         preserve_state = self.settings.preserve_state_across_windows
+         if preserve_state == "auto":
+             axes = message.axes
+             if batch_size < 2:
+                 # Single window, so preserve
+                 preserve_state = True
+             elif "time" not in axes or "win" not in axes:
+                 # Default fallback
+                 ez.logger.warning(
+                     "Missing 'time' or 'win' axis for auto preserve-state logic. Defaulting to reset."
+                 )
+                 preserve_state = False
+             else:
+                 # Stride between windows (assuming uniform spacing)
+                 win_stride = axes["win"].gain
+                 # Window length from the time axis length and gain
+                 time_len = message.data.shape[message.get_axis_idx("time")]
+                 gain = getattr(axes["time"], "gain", None)
+                 if gain is None:
+                     ez.logger.warning("Time axis gain not found, using default gain of 1.0.")
+                     gain = 1.0  # fallback default
+                 win_len = time_len * gain
+                 # Preserve only if windows do NOT overlap: stride >= window length
+                 preserve_state = win_stride >= win_len
+
+         if not preserve_state:
+             self.reset_hidden(batch_size)
+         else:
+             # When preserving state, collapse the hidden state to batch size 1 if needed
+             hx_batch_size = (
+                 self._state.hx[0].shape[1]
+                 if isinstance(self._state.hx, tuple)
+                 else self._state.hx.shape[1]
+             )
+             if hx_batch_size != 1:
+                 ez.logger.debug(
+                     f"Resetting hidden state due to batch size mismatch (hx: {hx_batch_size}, new: 1)"
+                 )
+                 self.reset_hidden(1)
+         return preserve_state
+
+     def _process(self, message: AxisArray) -> list[AxisArray]:
+         x = message.data
+         if not isinstance(x, torch.Tensor):
+             x = torch.tensor(
+                 x,
+                 dtype=torch.float32 if self.settings.single_precision else torch.float64,
+                 device=self._state.device,
+             )
+
+         # Add batch dimension if missing
+         x, added_batch_dim = self._ensure_batched(x)
+
+         batch_size = x.shape[0]
+         preserve_state = self._maybe_reset_state(message, batch_size)
+
+         with torch.no_grad():
+             # If we are preserving state and have multiple batches, process sequentially
+             if preserve_state and batch_size > 1:
+                 y_data = {}
+                 for x_batch in x:
+                     x_batch = x_batch.unsqueeze(0)
+                     y, self._state.hx = self._state.model(x_batch, hx=self._state.hx)
+                     for key, out in y.items():
+                         if key not in y_data:
+                             y_data[key] = []
+                         y_data[key].append(out.cpu().numpy())
+                 # Concatenate outputs for each key
+                 y_data = {key: np.concatenate(outputs, axis=0) for key, outputs in y_data.items()}
+             else:
+                 y, self._state.hx = self._state.model(x, hx=self._state.hx)
+                 y_data = {
+                     key: (out.cpu().numpy().squeeze(0) if added_batch_dim else out.cpu().numpy())
+                     for key, out in y.items()
+                 }
+
+         return [
+             replace(
+                 message,
+                 data=out,
+                 axes={**message.axes, "ch": self._state.chan_ax[key]},
+                 key=key,
+             )
+             for key, out in y_data.items()
+         ]
+
+     def reset_hidden(self, batch_size: int) -> None:
+         self._state.hx = self._state.model.init_hidden(batch_size, self._state.device)
+
+     def _train_step(
+         self,
+         X: torch.Tensor,
+         y_targ: dict[str, torch.Tensor],
+         loss_fns: dict[str, torch.nn.Module],
+     ) -> None:
+         y_pred, self._state.hx = self._state.model(X, hx=self._state.hx)
+         if not isinstance(y_pred, dict):
+             y_pred = {"output": y_pred}
+
+         loss_weights = self.settings.loss_weights or {}
+         losses = []
+         for key in y_targ.keys():
+             loss_fn = loss_fns.get(key)
+             if loss_fn is None:
+                 raise ValueError(f"Loss function for key '{key}' is not defined.")
+             if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
+                 # CrossEntropyLoss expects class logits in dim 1: (batch, classes, time)
+                 loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long())
+             else:
+                 loss = loss_fn(y_pred[key], y_targ[key])
+             weight = loss_weights.get(key, 1.0)
+             losses.append(loss * weight)
+
+         total_loss = sum(losses)
+         ez.logger.debug(
+             f"Training step loss: {total_loss.item()} (individual losses: {[loss.item() for loss in losses]})"
+         )
+
+         self._state.optimizer.zero_grad()
+         total_loss.backward()
+         self._state.optimizer.step()
+         if self._state.scheduler is not None:
+             self._state.scheduler.step()
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         self._state.model.train()
+
+         X = self._to_tensor(message.sample.data)
+
+         # Add batch dimension if missing
+         X, added_batch_dim = self._ensure_batched(X)
+
+         batch_size = X.shape[0]
+         preserve_state = self._maybe_reset_state(message.sample, batch_size)
+
+         y_targ = message.trigger.value
+         if not isinstance(y_targ, dict):
+             y_targ = {"output": y_targ}
+         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
+         # Add batch dimension to y_targ values if one was added to X
+         if added_batch_dim:
+             for key in y_targ:
+                 y_targ[key] = y_targ[key].unsqueeze(0)
+
+         loss_fns = self.settings.loss_fn
+         if loss_fns is None:
+             raise ValueError("loss_fn must be provided in settings to use partial_fit")
+         if not isinstance(loss_fns, dict):
+             loss_fns = {k: loss_fns for k in y_targ.keys()}
+
+         with torch.set_grad_enabled(True):
+             if preserve_state:
+                 self._train_step(X, y_targ, loss_fns)
+             else:
+                 for i in range(batch_size):
+                     self._train_step(
+                         X[i].unsqueeze(0),
+                         {key: value[i].unsqueeze(0) for key, value in y_targ.items()},
+                         loss_fns,
+                     )
+
+         self._state.model.eval()
+         if self.settings.reset_hidden_on_fit:
+             self.reset_hidden(X.shape[0])
+
+
+ class RNNUnit(
+     BaseAdaptiveTransformerUnit[
+         RNNSettings,
+         AxisArray,
+         AxisArray,
+         RNNProcessor,
+     ]
+ ):
+     SETTINGS = RNNSettings
+
+     @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
+     @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
+     @profile_subpub(trace_oldest=False)
+     async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
+         results = await self.processor.__acall__(message)
+         for result in results:
+             yield self.OUTPUT_SIGNAL, result
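The "auto" mode of `preserve_state_across_windows` above boils down to a single comparison in `_maybe_reset_state`: preserve the hidden state only when `win_stride >= win_len`, i.e. when consecutive windows do not overlap. A minimal sketch of that arithmetic, using illustrative values that are not from the package:

```python
# Hypothetical numbers: a 100-sample window sampled at 1 kHz has a time-axis
# gain of 0.001 s/sample, so win_len = 0.1 s. The "win" axis gain is the
# stride between window starts.
time_len, time_gain = 100, 0.001
win_len = time_len * time_gain  # 0.1 s of data per window

for win_stride in (0.05, 0.1, 0.2):  # 50% overlap, back-to-back, gapped
    preserve = win_stride >= win_len  # the same test _maybe_reset_state applies
    print(f"stride={win_stride}s -> {'preserve' if preserve else 'reset'}")
```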
ezmsg/learn/process/sgd.py
@@ -0,0 +1,117 @@
+ import typing
+
+ import ezmsg.core as ez
+ import numpy as np
+ from ezmsg.baseproc import (
+     BaseAdaptiveTransformer,
+     BaseAdaptiveTransformerUnit,
+     SampleMessage,
+     processor_state,
+ )
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+ from sklearn.exceptions import NotFittedError
+ from sklearn.linear_model import SGDClassifier
+
+ from ..util import ClassifierMessage
+
+
+ class SGDDecoderSettings(ez.Settings):
+     alpha: float = 1e-5
+     eta0: float = 3e-4
+     loss: str = "hinge"
+     label_weights: dict[str, float] | None = None
+     settings_path: str | None = None
+
+
+ @processor_state
+ class SGDDecoderState:
+     model: typing.Any = None
+     b_first_train: bool = True
+
+
+ class SGDDecoderTransformer(
+     BaseAdaptiveTransformer[SGDDecoderSettings, AxisArray, ClassifierMessage, SGDDecoderState]
+ ):
+     """
+     SGD-based online classifier.
+
+     Online Passive-Aggressive Algorithms
+     <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
+     K. Crammer, O. Dekel, J. Keshet, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
+     """
+
+     def _refreshed_model(self):
+         if self.settings.settings_path is not None:
+             import pickle
+
+             with open(self.settings.settings_path, "rb") as f:
+                 model = pickle.load(f)
+             if self.settings.label_weights is not None:
+                 model.class_weight = self.settings.label_weights
+             model.eta0 = self.settings.eta0
+         else:
+             model = SGDClassifier(
+                 loss=self.settings.loss,
+                 alpha=self.settings.alpha,
+                 penalty="elasticnet",
+                 learning_rate="adaptive",
+                 eta0=self.settings.eta0,
+                 early_stopping=False,
+                 class_weight=self.settings.label_weights,
+             )
+         return model
+
+     def _reset_state(self, message: AxisArray) -> None:
+         self._state.model = self._refreshed_model()
+
+     def _process(self, message: AxisArray) -> ClassifierMessage | None:
+         if self._state.model is None or not message.data.size:
+             return None
+         if np.any(np.isnan(message.data)):
+             return None
+         try:
+             X = message.data.reshape((message.data.shape[0], -1))
+             # _predict_proba_lr is a private sklearn helper that squashes
+             # decision_function outputs into pseudo-probabilities.
+             result = self._state.model._predict_proba_lr(X)
+         except NotFittedError:
+             return None
+         out_axes = {}
+         if message.dims[0] in message.axes:
+             # Copy the first axis through unchanged (offset kept as-is)
+             out_axes[message.dims[0]] = replace(
+                 message.axes[message.dims[0]],
+                 offset=message.axes[message.dims[0]].offset,
+             )
+         return ClassifierMessage(
+             data=result,
+             dims=message.dims[:1] + ["labels"],
+             axes=out_axes,
+             labels=list(self._state.model.class_weight.keys()),
+             key=message.key,
+         )
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         # Nonzero hash sentinel (managed by the base class) means the model
+         # has not been (re)built for the current configuration yet.
+         if self._hash != 0:
+             self._reset_state(message.sample)
+             self._hash = 0
+
+         if np.any(np.isnan(message.sample.data)):
+             return
+         train_sample = message.sample.data.reshape(1, -1)
+         if self._state.b_first_train:
+             # The first partial_fit call must enumerate every class, so
+             # label_weights must be provided in the settings.
+             self._state.model.partial_fit(
+                 train_sample,
+                 [message.trigger.value],
+                 classes=list(self.settings.label_weights.keys()),
+             )
+             self._state.b_first_train = False
+         else:
+             self._state.model.partial_fit(train_sample, [message.trigger.value])
+
+
+ class SGDDecoder(
+     BaseAdaptiveTransformerUnit[
+         SGDDecoderSettings,
+         AxisArray,
+         ClassifierMessage,
+         SGDDecoderTransformer,
+     ]
+ ):
+     SETTINGS = SGDDecoderSettings
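As a construction sketch for the decoder above (hypothetical labels, mirroring the `settings=` pattern shown in the `SklearnModelProcessor` docstring later in this diff): `label_weights` does double duty here, supplying the class weighting and, via its keys, both the `classes` list for the first `partial_fit` call and the output `labels`, so it should be set whenever no pickled model is supplied through `settings_path`.

```python
from ezmsg.learn.process.sgd import SGDDecoderSettings, SGDDecoderTransformer

# Hypothetical three-class setup; the keys of label_weights become the
# `classes` argument of the first partial_fit call and the output labels.
settings = SGDDecoderSettings(
    alpha=1e-5,
    eta0=3e-4,
    loss="hinge",
    label_weights={"rest": 1.0, "left": 1.0, "right": 1.0},
)
decoder = SGDDecoderTransformer(settings=settings)
```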
ezmsg/learn/process/sklearn.py
@@ -0,0 +1,241 @@
+ import importlib
+ import pickle
+ import typing
+
+ import ezmsg.core as ez
+ import numpy as np
+ import pandas as pd
+ from ezmsg.baseproc import (
+     BaseAdaptiveTransformer,
+     BaseAdaptiveTransformerUnit,
+     processor_state,
+ )
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+
+ class SklearnModelSettings(ez.Settings):
+     model_class: str
+     """
+     Full path to the sklearn model class.
+     Example: 'sklearn.linear_model.LinearRegression'
+     """
+     model_kwargs: dict[str, typing.Any] | None = None
+     """
+     Additional keyword arguments to pass to the model constructor.
+     Example: {'fit_intercept': True}
+     """
+     checkpoint_path: str | None = None
+     """
+     Path to a checkpoint file to load the model from.
+     If provided, the model will be initialized from this checkpoint.
+     Example: 'path/to/checkpoint.pkl'
+     """
+     partial_fit_classes: np.ndarray | None = None
+     """
+     The full list of classes to use for partial_fit calls.
+     This must be provided to use `partial_fit` with classifiers.
+     """
+
+
+ @processor_state
+ class SklearnModelState:
+     model: typing.Any = None
+     chan_ax: AxisArray.CoordinateAxis | None = None
+
+
+ class SklearnModelProcessor(
+     BaseAdaptiveTransformer[SklearnModelSettings, AxisArray, AxisArray, SklearnModelState]
+ ):
+     """
+     Processor that wraps a scikit-learn, River, or HMMLearn model for use in the ezmsg framework.
+
+     This processor supports:
+     - `fit`, `partial_fit`, or River's `learn_many`/`learn_one` for training.
+     - `predict`, River's `predict_many`, or `predict_one` for inference.
+     - Optional model checkpoint loading and saving.
+
+     The processor expects and outputs `AxisArray` messages with a `"ch"` (channel) axis.
+
+     Settings:
+     ---------
+     model_class : str
+         Full path to the sklearn or River model class to use.
+         Example: "sklearn.linear_model.SGDClassifier" or "river.linear_model.LogisticRegression"
+
+     model_kwargs : dict[str, typing.Any], optional
+         Additional keyword arguments passed to the model constructor.
+
+     checkpoint_path : str, optional
+         Path to a pickle file to load a previously saved model. If provided, the model will
+         be restored from this path at startup.
+
+     partial_fit_classes : np.ndarray, optional
+         For classifiers that require all class labels to be specified during `partial_fit`.
+
+     Example:
+     --------
+     ```python
+     processor = SklearnModelProcessor(
+         settings=SklearnModelSettings(
+             model_class='sklearn.linear_model.SGDClassifier',
+             model_kwargs={'loss': 'log_loss'},
+             partial_fit_classes=np.array([0, 1]),
+         )
+     )
+     ```
+     """
+
+     def _init_model(self) -> None:
+         module_path, class_name = self.settings.model_class.rsplit(".", 1)
+         model_cls = getattr(importlib.import_module(module_path), class_name)
+         kwargs = self.settings.model_kwargs or {}
+         self._state.model = model_cls(**kwargs)
+
+     def save_checkpoint(self, path: str) -> None:
+         with open(path, "wb") as f:
+             pickle.dump(self._state.model, f)
+
+     def load_checkpoint(self, path: str) -> None:
+         try:
+             with open(path, "rb") as f:
+                 self._state.model = pickle.load(f)
+         except Exception as e:
+             ez.logger.error(f"Failed to load model from {path}: {str(e)}")
+             raise RuntimeError(f"Failed to load model from {path}: {str(e)}") from e
+
+     def _reset_state(self, message: AxisArray) -> None:
+         # Try loading from checkpoint first
+         if self.settings.checkpoint_path:
+             self.load_checkpoint(self.settings.checkpoint_path)
+             n_input = message.data.shape[message.get_axis_idx("ch")]
+             if hasattr(self._state.model, "n_features_in_"):
+                 expected = self._state.model.n_features_in_
+                 if expected != n_input:
+                     raise ValueError(f"Model expects {expected} features, but got {n_input}")
+         else:
+             # No checkpoint, initialize from scratch
+             self._init_model()
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         X = message.sample.data
+         y = message.trigger.value
+         if self._state.model is None:
+             self._reset_state(message.sample)
+         if hasattr(self._state.model, "partial_fit"):
+             kwargs = {}
+             if self.settings.partial_fit_classes is not None:
+                 kwargs["classes"] = self.settings.partial_fit_classes
+             self._state.model.partial_fit(X, y, **kwargs)
+         elif hasattr(self._state.model, "learn_many"):
+             df_X = pd.DataFrame(
+                 {k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)}
+             )
+             name = (
+                 message.trigger.value.axes["ch"].data[0]
+                 if hasattr(message.trigger.value, "axes") and "ch" in message.trigger.value.axes
+                 else "target"
+             )
+             ser_y = pd.Series(
+                 data=np.asarray(message.trigger.value.data).flatten(),
+                 name=name,
+             )
+             self._state.model.learn_many(df_X, ser_y)
+         elif hasattr(self._state.model, "learn_one"):
+             # river's random forest does not support learn_many
+             for xi, yi in zip(X, y):
+                 features = {f"f{i}": xi[i] for i in range(len(xi))}
+                 self._state.model.learn_one(features, yi)
+         else:
+             raise NotImplementedError("Model does not support partial_fit or learn_many")
+
+     def fit(self, X: np.ndarray, y: np.ndarray) -> None:
+         if self._state.model is None:
+             dummy_msg = AxisArray(
+                 data=X,
+                 dims=["time", "ch"],
+                 axes={
+                     "time": AxisArray.TimeAxis(fs=1.0),
+                     "ch": AxisArray.CoordinateAxis(
+                         data=np.array([f"ch_{i}" for i in range(X.shape[1])]),
+                         dims=["ch"],
+                     ),
+                 },
+             )
+             self._reset_state(dummy_msg)
+         if hasattr(self._state.model, "fit"):
+             self._state.model.fit(X, y)
+         elif hasattr(self._state.model, "learn_many"):
+             df_X = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])])
+             ser_y = pd.Series(y.flatten(), name="target")
+             self._state.model.learn_many(df_X, ser_y)
+         elif hasattr(self._state.model, "learn_one"):
+             # river's random forest does not support learn_many
+             for xi, yi in zip(X, y):
+                 features = {f"f{i}": xi[i] for i in range(len(xi))}
+                 self._state.model.learn_one(features, yi)
+         else:
+             raise NotImplementedError("Model does not support fit or learn_many")
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         if self._state.model is None:
+             raise RuntimeError("Model has not been fit yet. Call `fit()` or `partial_fit()` before processing.")
+         X = message.data
+         original_shape = X.shape
+         n_input = X.shape[message.get_axis_idx("ch")]
+
+         # Ensure X is 2D
+         X = X.reshape(-1, n_input)
+         if hasattr(self._state.model, "n_features_in_"):
+             expected = self._state.model.n_features_in_
+             if expected != n_input:
+                 raise ValueError(f"Model expects {expected} features, but got {n_input}")
+
+         if hasattr(self._state.model, "predict"):
+             y_pred = self._state.model.predict(X)
+         elif hasattr(self._state.model, "predict_many"):
+             df_X = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])])
+             y_pred = self._state.model.predict_many(df_X)
+             y_pred = np.array(list(y_pred))
+         elif hasattr(self._state.model, "predict_one"):
+             # river's random forest does not support predict_many
+             y_pred = np.array(
+                 [self._state.model.predict_one({f"f{i}": xi[i] for i in range(len(xi))}) for xi in X]
+             )
+         else:
+             raise NotImplementedError("Model does not support predict or predict_many")
+
+         # For scalar outputs, ensure the output is 2D
+         if y_pred.ndim == 1:
+             y_pred = y_pred[:, np.newaxis]
+
+         output_shape = original_shape[:-1] + (y_pred.shape[-1],)
+         y_pred = y_pred.reshape(output_shape)
+
+         if self._state.chan_ax is None:
+             # Label output channels along the last (channel) dimension
+             self._state.chan_ax = AxisArray.CoordinateAxis(data=np.arange(output_shape[-1]), dims=["ch"])
+
+         return replace(
+             message,
+             data=y_pred,
+             axes={**message.axes, "ch": self._state.chan_ax},
+         )
+
+
+ class SklearnModelUnit(
+     BaseAdaptiveTransformerUnit[SklearnModelSettings, AxisArray, AxisArray, SklearnModelProcessor]
+ ):
+     """
+     Unit wrapper for the `SklearnModelProcessor`.
+
+     This unit provides a plug-and-play interface for using a scikit-learn or River model
+     in an ezmsg graph-based system. It takes in `AxisArray` inputs and outputs predictions
+     in the same format, optionally performing training via `partial_fit` or `fit`.
+
+     Example:
+     --------
+     ```python
+     unit = SklearnModelUnit(
+         settings=SklearnModelSettings(
+             model_class='sklearn.linear_model.SGDClassifier',
+             model_kwargs={'loss': 'log_loss'},
+             partial_fit_classes=np.array([0, 1]),
+         )
+     )
+     ```
+     """
+
+     SETTINGS = SklearnModelSettings
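A round-trip sketch for the processor above (assuming the `settings=` constructor shown in its docstring; `_process` is called directly here for illustration, whereas in a live pipeline `SklearnModelUnit` drives it through the base-class machinery):

```python
import numpy as np

from ezmsg.learn.process.sklearn import SklearnModelProcessor, SklearnModelSettings
from ezmsg.util.messages.axisarray import AxisArray

proc = SklearnModelProcessor(
    settings=SklearnModelSettings(model_class="sklearn.linear_model.LogisticRegression")
)

# Offline fit: builds the model via a dummy AxisArray, then calls model.fit().
X = np.random.randn(64, 8)  # 64 samples, 8 channels
y = np.random.randint(0, 2, size=64)
proc.fit(X, y)

# Online inference on an AxisArray message with a "ch" axis.
msg = AxisArray(
    data=np.random.randn(16, 8),
    dims=["time", "ch"],
    axes={
        "time": AxisArray.TimeAxis(fs=100.0),
        "ch": AxisArray.CoordinateAxis(
            data=np.array([f"ch_{i}" for i in range(8)]),
            dims=["ch"],
        ),
    },
)
out = proc._process(msg)  # AxisArray of per-sample class predictions
```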