ezmsg-learn 1.0__py3-none-any.whl

@@ -0,0 +1,266 @@
+ import typing
+
+ import ezmsg.core as ez
+ import numpy as np
+ import torch
+ from ezmsg.sigproc.base import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.sigproc.util.profile import profile_subpub
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from .base import ModelInitMixin
+ from .torch import (
+     TorchModelSettings,
+     TorchModelState,
+     TorchProcessorMixin,
+ )
+
+
+ class RNNSettings(TorchModelSettings):
+     model_class: str = "ezmsg.learn.model.rnn.RNNModel"
+     """
+     Fully qualified class path of the model to be used.
+     This should remain "ezmsg.learn.model.rnn.RNNModel" for this processor.
+     """
+     reset_hidden_on_fit: bool = True
+     """
+     Whether to reset the hidden state on each fit call.
+     If True, the hidden state is reset to zero after each fit.
+     If False, the hidden state is maintained across fit calls.
+     """
+     preserve_state_across_windows: bool | typing.Literal["auto"] = "auto"
+     """
+     Whether to preserve the hidden state across windows.
+     If True, the hidden state is preserved across windows.
+     If False, the hidden state is reset at the start of each window.
+     If "auto", the state is preserved when time windows do not overlap, and reset otherwise.
+     """
+
+
+ class RNNState(TorchModelState):
+     # LSTM-style models carry an (h, c) tuple; plain RNN/GRU models carry a single tensor.
+     hx: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None = None
+
+
+ class RNNProcessor(
+     BaseAdaptiveTransformer[RNNSettings, AxisArray, AxisArray, RNNState],
+     TorchProcessorMixin,
+     ModelInitMixin,
+ ):
+     def _infer_output_sizes(
+         self, model: torch.nn.Module, n_input: int
+     ) -> dict[str, int]:
+         """Run a dummy forward pass to infer the output channel size for each output key."""
+         # Dummy input shaped (batch=1, time=50, channels=n_input).
+         dummy_input = torch.zeros(1, 50, n_input, device=self._state.device)
+         with torch.no_grad():
+             output, _ = model(dummy_input)
+
+         if isinstance(output, dict):
+             return {k: v.shape[-1] for k, v in output.items()}
+         else:
+             return {"output": output.shape[-1]}
+
+     def _reset_state(self, message: AxisArray) -> None:
+         model_kwargs = dict(self.settings.model_kwargs or {})
+         self._common_reset_state(message, model_kwargs)
+         self._init_optimizer()
+         self._validate_loss_keys(list(self._state.chan_ax.keys()))
+
+         batch_size = 1 if message.data.ndim == 2 else message.data.shape[0]
+         self.reset_hidden(batch_size)
+
+     def _maybe_reset_state(self, message: AxisArray, batch_size: int) -> bool:
+         preserve_state = self.settings.preserve_state_across_windows
+         if preserve_state == "auto":
+             axes = message.axes
+             if batch_size < 2:
+                 # Single window, so preserve
+                 preserve_state = True
+             elif "time" not in axes or "win" not in axes:
+                 # Default fallback
+                 ez.logger.warning(
+                     "Missing 'time' or 'win' axis for auto preserve-state logic. Defaulting to reset."
+                 )
+                 preserve_state = False
+             else:
+                 # Stride between windows (assuming uniform spacing)
+                 win_stride = axes["win"].gain
+                 # Window length from time axis length and gain
+                 time_len = message.data.shape[message.get_axis_idx("time")]
+                 gain = getattr(axes["time"], "gain", None)
+                 if gain is None:
+                     ez.logger.warning(
+                         "Time axis gain not found, using default gain of 1.0."
+                     )
+                     gain = 1.0  # fallback default
+                 win_len = time_len * gain
+                 # Preserve only if windows do NOT overlap: stride >= window length
+                 preserve_state = win_stride >= win_len
+
+         if not preserve_state:
+             self.reset_hidden(batch_size)
+         else:
+             # When preserving state, windows are processed sequentially with batch size 1,
+             # so reset the hidden state if its batch dimension does not match.
+             hx_batch_size = (
+                 self._state.hx[0].shape[1]
+                 if isinstance(self._state.hx, tuple)
+                 else self._state.hx.shape[1]
+             )
+             if hx_batch_size != 1:
+                 ez.logger.debug(
+                     f"Resetting hidden state due to batch size mismatch (hx: {hx_batch_size}, new: 1)"
+                 )
+                 self.reset_hidden(1)
+         return preserve_state
+
+     def _process(self, message: AxisArray) -> list[AxisArray]:
+         x = message.data
+         if not isinstance(x, torch.Tensor):
+             x = torch.tensor(
+                 x,
+                 dtype=torch.float32
+                 if self.settings.single_precision
+                 else torch.float64,
+                 device=self._state.device,
+             )
+
+         # Add batch dimension if missing
+         x, added_batch_dim = self._ensure_batched(x)
+
+         batch_size = x.shape[0]
+         preserve_state = self._maybe_reset_state(message, batch_size)
+
+         with torch.no_grad():
+             # If we are preserving state and have multiple batches, process sequentially
+             if preserve_state and batch_size > 1:
+                 y_data = {}
+                 for x_batch in x:
+                     x_batch = x_batch.unsqueeze(0)
+                     y, self._state.hx = self._state.model(x_batch, hx=self._state.hx)
+                     for key, out in y.items():
+                         if key not in y_data:
+                             y_data[key] = []
+                         y_data[key].append(out.cpu().numpy())
+                 # Concatenate outputs for each key
+                 y_data = {
+                     key: np.concatenate(outputs, axis=0)
+                     for key, outputs in y_data.items()
+                 }
+             else:
+                 y, self._state.hx = self._state.model(x, hx=self._state.hx)
+                 y_data = {
+                     key: (
+                         out.cpu().numpy().squeeze(0)
+                         if added_batch_dim
+                         else out.cpu().numpy()
+                     )
+                     for key, out in y.items()
+                 }
+
+         return [
+             replace(
+                 message,
+                 data=out,
+                 axes={**message.axes, "ch": self._state.chan_ax[key]},
+                 key=key,
+             )
+             for key, out in y_data.items()
+         ]
+
+     def reset_hidden(self, batch_size: int) -> None:
+         self._state.hx = self._state.model.init_hidden(batch_size, self._state.device)
+
+     def _train_step(
+         self,
+         X: torch.Tensor,
+         y_targ: dict[str, torch.Tensor],
+         loss_fns: dict[str, torch.nn.Module],
+     ) -> None:
+         y_pred, self._state.hx = self._state.model(X, hx=self._state.hx)
+         if not isinstance(y_pred, dict):
+             y_pred = {"output": y_pred}
+
+         loss_weights = self.settings.loss_weights or {}
+         losses = []
+         for key in y_targ.keys():
+             loss_fn = loss_fns.get(key)
+             if loss_fn is None:
+                 raise ValueError(f"Loss function for key '{key}' is not defined.")
+             if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
+                 # CrossEntropyLoss expects (batch, classes, time), so move the class dim forward.
+                 loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long())
+             else:
+                 loss = loss_fn(y_pred[key], y_targ[key])
+             weight = loss_weights.get(key, 1.0)
+             losses.append(loss * weight)
+
+         total_loss = sum(losses)
+         ez.logger.debug(
+             f"Training step loss: {total_loss.item()} (individual losses: {[loss.item() for loss in losses]})"
+         )
+
+         self._state.optimizer.zero_grad()
+         total_loss.backward()
+         self._state.optimizer.step()
+         if self._state.scheduler is not None:
+             self._state.scheduler.step()
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         self._state.model.train()
+
+         X = self._to_tensor(message.sample.data)
+
+         # Add batch dimension if missing
+         X, added_batch_dim = self._ensure_batched(X)
+
+         batch_size = X.shape[0]
+         preserve_state = self._maybe_reset_state(message.sample, batch_size)
+
+         y_targ = message.trigger.value
+         if not isinstance(y_targ, dict):
+             y_targ = {"output": y_targ}
+         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
+         # If a batch dimension was added to X, add one to the y_targ values too
+         if added_batch_dim:
+             for key in y_targ:
+                 y_targ[key] = y_targ[key].unsqueeze(0)
+
+         loss_fns = self.settings.loss_fn
+         if loss_fns is None:
+             raise ValueError("loss_fn must be provided in settings to use partial_fit")
+         if not isinstance(loss_fns, dict):
+             loss_fns = {k: loss_fns for k in y_targ.keys()}
+
+         with torch.set_grad_enabled(True):
+             if preserve_state:
+                 self._train_step(X, y_targ, loss_fns)
+             else:
+                 for i in range(batch_size):
+                     self._train_step(
+                         X[i].unsqueeze(0),
+                         {key: value[i].unsqueeze(0) for key, value in y_targ.items()},
+                         loss_fns,
+                     )
+
+         self._state.model.eval()
+         if self.settings.reset_hidden_on_fit:
+             self.reset_hidden(X.shape[0])
+
+
+ class RNNUnit(
+     BaseAdaptiveTransformerUnit[
+         RNNSettings,
+         AxisArray,
+         AxisArray,
+         RNNProcessor,
+     ]
+ ):
+     SETTINGS = RNNSettings
+
+     @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
+     @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
+     @profile_subpub(trace_oldest=False)
+     async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
+         results = await self.processor.__acall__(message)
+         for result in results:
+             yield self.OUTPUT_SIGNAL, result
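For orientation, here is a minimal usage sketch for the processor above. It is not part of the package diff; the `(time, ch)` message layout and synchronous invocation via `__call__` are assumptions (the unit itself uses the async `__acall__`), and it presumes the inherited `TorchModelSettings` defaults suffice for inference.

```python
# Hypothetical sketch, not part of the diff: drive RNNProcessor directly.
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray

proc = RNNProcessor(
    settings=RNNSettings(
        model_class="ezmsg.learn.model.rnn.RNNModel",  # default, shown for clarity
        preserve_state_across_windows="auto",
    )
)
msg = AxisArray(
    data=np.random.randn(200, 8).astype(np.float32),  # (time, ch)
    dims=["time", "ch"],
    axes={"time": AxisArray.TimeAxis(fs=100.0)},
)
outputs = proc(msg)  # assumed sync entry point; list[AxisArray], one per output key
```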
@@ -0,0 +1,131 @@
+ import typing
+
+ import ezmsg.core as ez
+ import numpy as np
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.sigproc.base import GenAxisArray
+ from ezmsg.util.generator import consumer
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+ from sklearn.exceptions import NotFittedError
+ from sklearn.linear_model import SGDClassifier
+
+ from ..util import ClassifierMessage
+
+
+ @consumer
+ def sgd_decoder(
+     alpha: float = 1.5e-5,
+     eta0: float = 1e-7,  # Lower than what you'd use for offline training.
+     loss: str = "squared_hinge",
+     label_weights: dict[str, float] | None = None,
+     settings_path: str | None = None,
+ ) -> typing.Generator[ClassifierMessage | None, AxisArray | SampleMessage, None]:
+     """
+     Online decoder built on `SGDClassifier`, following the Passive-Aggressive family:
+     Online Passive-Aggressive Algorithms <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
+     K. Crammer, O. Dekel, J. Keshet, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
+
+     Args:
+         alpha: Maximum step size (regularization).
+         eta0: The initial learning rate for the 'adaptive' schedule.
+         loss: The loss function to be used:
+             hinge: equivalent to PA-I in the reference paper.
+             squared_hinge: equivalent to PA-II in the reference paper.
+         label_weights: An optional dictionary of label names and their relative weight,
+             e.g., {'Go': 31.0, 'Stop': 0.5}.
+             If this is None then settings_path must be provided and the pre-trained
+             model's class weights are used.
+         settings_path: Path to the stored sklearn model pkl file.
+
+     Returns:
+         Generator that accepts `SampleMessage` for incremental training (`partial_fit`) and yields None,
+         or `AxisArray` for inference and yields a `ClassifierMessage` of per-class probabilities.
+     """
+     # pre-init inputs and outputs
+     msg_out = ClassifierMessage(data=np.array([]), dims=[""])
+
+     # State variables:
+
+     if settings_path is not None:
+         import pickle
+
+         with open(settings_path, "rb") as f:
+             model = pickle.load(f)
+         if label_weights is not None:
+             model.class_weight = label_weights
+         # Overwrite eta0, probably with a value lower than what was used for offline training.
+         model.eta0 = eta0
+     else:
+         model = SGDClassifier(
+             loss=loss,
+             alpha=alpha,
+             penalty="elasticnet",
+             learning_rate="adaptive",
+             eta0=eta0,
+             early_stopping=False,
+             class_weight=label_weights,
+         )
+
+     b_first_train = True
+     # TODO: template_out
+
+     while True:
+         msg_in: AxisArray | SampleMessage = yield msg_out
+
+         msg_out = None
+         if type(msg_in) is SampleMessage:
+             # SampleMessage used for training.
+             if not np.any(np.isnan(msg_in.sample.data)):
+                 train_sample = msg_in.sample.data.reshape(1, -1)
+                 if b_first_train:
+                     # First call must enumerate all classes for partial_fit.
+                     model.partial_fit(
+                         train_sample,
+                         [msg_in.trigger.value],
+                         classes=list(label_weights.keys()),
+                     )
+                     b_first_train = False
+                 else:
+                     model.partial_fit(train_sample, [msg_in.trigger.value])
+         elif msg_in.data.size:
+             # AxisArray used for inference
+             if not np.any(np.isnan(msg_in.data)):
+                 try:
+                     X = msg_in.data.reshape((msg_in.data.shape[0], -1))
+                     # _predict_proba_lr is sklearn's private logistic mapping of decision_function.
+                     result = model._predict_proba_lr(X)
+                 except NotFittedError:
+                     result = None
+                 if result is not None:
+                     out_axes = {}
+                     if msg_in.dims[0] in msg_in.axes:
+                         out_axes[msg_in.dims[0]] = replace(
+                             msg_in.axes[msg_in.dims[0]],
+                             offset=msg_in.axes[msg_in.dims[0]].offset,
+                         )
+                     msg_out = ClassifierMessage(
+                         data=result,
+                         dims=msg_in.dims[:1] + ["labels"],
+                         axes=out_axes,
+                         labels=list(model.class_weight.keys()),
+                         key=msg_in.key,
+                     )
+
+
+ class SGDDecoderSettings(ez.Settings):
+     alpha: float = 1e-5
+     eta0: float = 3e-4
+     loss: str = "hinge"
+     label_weights: dict[str, float] | None = None
+     settings_path: str | None = None
+
+
+ class SGDDecoder(GenAxisArray):
+     SETTINGS = SGDDecoderSettings
+     INPUT_SAMPLE = ez.InputStream(SampleMessage)
+
+     # Construct the specific generator for this unit (hook defined by GenAxisArray).
+     def construct_generator(self):
+         self.STATE.gen = sgd_decoder(**self.SETTINGS.__dict__)
+
+     @ez.subscriber(INPUT_SAMPLE)
+     async def on_sample(self, msg: SampleMessage) -> None:
+         _ = self.STATE.gen.send(msg)
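A minimal sketch of driving the `sgd_decoder` generator directly, outside an ezmsg graph. This is not part of the diff; the feature shape and label names are illustrative, and it assumes `@consumer` primes the generator so `.send()` can be called immediately, as `SGDDecoder.on_sample` does.

```python
# Hypothetical sketch, not part of the diff: use the generator without a Unit.
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray

gen = sgd_decoder(label_weights={"Go": 1.0, "Stop": 1.0})

# Inference: send an AxisArray of features. Until the model has seen at least
# one training SampleMessage, NotFittedError is caught and None comes back.
feats = AxisArray(data=np.random.randn(1, 16), dims=["win", "ch"])
result = gen.send(feats)  # ClassifierMessage of probabilities, or None
```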
@@ -0,0 +1,274 @@
+ import importlib
+ import pickle
+ import typing
+
+ import ezmsg.core as ez
+ import numpy as np
+ import pandas as pd
+ from ezmsg.sigproc.base import (
+     BaseAdaptiveTransformer,
+     BaseAdaptiveTransformerUnit,
+     processor_state,
+ )
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+
+ class SklearnModelSettings(ez.Settings):
+     model_class: str
+     """
+     Full path to the sklearn model class.
+     Example: 'sklearn.linear_model.LinearRegression'
+     """
+     model_kwargs: dict[str, typing.Any] | None = None
+     """
+     Additional keyword arguments to pass to the model constructor.
+     Example: {'fit_intercept': True, 'normalize': False}
+     """
+     checkpoint_path: str | None = None
+     """
+     Path to a checkpoint file to load the model from.
+     If provided, the model will be initialized from this checkpoint.
+     Example: 'path/to/checkpoint.pkl'
+     """
+     partial_fit_classes: np.ndarray | None = None
+     """
+     The full list of classes to use for partial_fit calls.
+     This must be provided to use `partial_fit` with classifiers.
+     """
+
+
+ @processor_state
+ class SklearnModelState:
+     model: typing.Any = None
+     chan_ax: AxisArray.CoordinateAxis | None = None
+
+
+ class SklearnModelProcessor(
+     BaseAdaptiveTransformer[
+         SklearnModelSettings, AxisArray, AxisArray, SklearnModelState
+     ]
+ ):
+     """
+     Processor that wraps a scikit-learn, River, or hmmlearn model for use in the ezmsg framework.
+
+     This processor supports:
+     - `fit`, `partial_fit`, or River's `learn_many`/`learn_one` for training.
+     - `predict`, River's `predict_many`, or `predict_one` for inference.
+     - Optional model checkpoint loading and saving.
+
+     The processor expects and outputs `AxisArray` messages with a `"ch"` (channel) axis.
+
+     Settings:
+     ---------
+     model_class : str
+         Full path to the sklearn or River model class to use.
+         Example: "sklearn.linear_model.SGDClassifier" or "river.linear_model.LogisticRegression"
+
+     model_kwargs : dict[str, typing.Any], optional
+         Additional keyword arguments passed to the model constructor.
+
+     checkpoint_path : str, optional
+         Path to a pickle file to load a previously saved model. If provided, the model will
+         be restored from this path at startup.
+
+     partial_fit_classes : np.ndarray, optional
+         For classifiers that require all class labels to be specified during `partial_fit`.
+
+     Example:
+     --------
+     ```python
+     processor = SklearnModelProcessor(
+         settings=SklearnModelSettings(
+             model_class='sklearn.linear_model.SGDClassifier',
+             model_kwargs={'loss': 'log_loss'},
+             partial_fit_classes=np.array([0, 1]),
+         )
+     )
+     ```
+     """
+
+     def _init_model(self) -> None:
+         module_path, class_name = self.settings.model_class.rsplit(".", 1)
+         model_cls = getattr(importlib.import_module(module_path), class_name)
+         kwargs = self.settings.model_kwargs or {}
+         self._state.model = model_cls(**kwargs)
+
+     def save_checkpoint(self, path: str) -> None:
+         with open(path, "wb") as f:
+             pickle.dump(self._state.model, f)
+
+     def load_checkpoint(self, path: str) -> None:
+         try:
+             with open(path, "rb") as f:
+                 self._state.model = pickle.load(f)
+         except Exception as e:
+             ez.logger.error(f"Failed to load model from {path}: {str(e)}")
+             raise RuntimeError(f"Failed to load model from {path}: {str(e)}") from e
+
+     def _reset_state(self, message: AxisArray) -> None:
+         # Try loading from checkpoint first
+         if self.settings.checkpoint_path:
+             self.load_checkpoint(self.settings.checkpoint_path)
+             n_input = message.data.shape[message.get_axis_idx("ch")]
+             if hasattr(self._state.model, "n_features_in_"):
+                 expected = self._state.model.n_features_in_
+                 if expected != n_input:
+                     raise ValueError(
+                         f"Model expects {expected} features, but got {n_input}"
+                     )
+         else:
+             # No checkpoint, initialize from scratch
+             self._init_model()
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         X = message.sample.data
+         y = message.trigger.value
+         if self._state.model is None:
+             self._reset_state(message.sample)
+         if hasattr(self._state.model, "partial_fit"):
+             kwargs = {}
+             if self.settings.partial_fit_classes is not None:
+                 kwargs["classes"] = self.settings.partial_fit_classes
+             self._state.model.partial_fit(X, y, **kwargs)
+         elif hasattr(self._state.model, "learn_many"):
+             # Build a DataFrame with one column per channel for River's learn_many.
+             df_X = pd.DataFrame(
+                 {
+                     k: v
+                     for k, v in zip(
+                         message.sample.axes["ch"].data, message.sample.data.T
+                     )
+                 }
+             )
+             name = (
+                 message.trigger.value.axes["ch"].data[0]
+                 if hasattr(message.trigger.value, "axes")
+                 and "ch" in message.trigger.value.axes
+                 else "target"
+             )
+             ser_y = pd.Series(
+                 data=np.asarray(message.trigger.value.data).flatten(),
+                 name=name,
+             )
+             self._state.model.learn_many(df_X, ser_y)
+         elif hasattr(self._state.model, "learn_one"):
+             # river's random forest does not support learn_many
+             for xi, yi in zip(X, y):
+                 features = {f"f{i}": xi[i] for i in range(len(xi))}
+                 self._state.model.learn_one(features, yi)
+         else:
+             raise NotImplementedError(
+                 "Model does not support partial_fit, learn_many, or learn_one"
+             )
+
+     def fit(self, X: np.ndarray, y: np.ndarray) -> None:
+         if self._state.model is None:
+             # Build a dummy message so _reset_state can infer the channel count.
+             dummy_msg = AxisArray(
+                 data=X,
+                 dims=["time", "ch"],
+                 axes={
+                     "time": AxisArray.TimeAxis(fs=1.0),
+                     "ch": AxisArray.CoordinateAxis(
+                         data=np.array([f"ch_{i}" for i in range(X.shape[1])]),
+                         dims=["ch"],
+                     ),
+                 },
+             )
+             self._reset_state(dummy_msg)
+         if hasattr(self._state.model, "fit"):
+             self._state.model.fit(X, y)
+         elif hasattr(self._state.model, "learn_many"):
+             df_X = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])])
+             ser_y = pd.Series(y.flatten(), name="target")
+             self._state.model.learn_many(df_X, ser_y)
+         elif hasattr(self._state.model, "learn_one"):
+             # river's random forest does not support learn_many
+             for xi, yi in zip(X, y):
+                 features = {f"f{i}": xi[i] for i in range(len(xi))}
+                 self._state.model.learn_one(features, yi)
+         else:
+             raise NotImplementedError("Model does not support fit, learn_many, or learn_one")
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         if self._state.model is None:
+             raise RuntimeError(
+                 "Model has not been fit yet. Call `fit()` or `partial_fit()` before processing."
+             )
+         X = message.data
+         original_shape = X.shape
+         n_input = X.shape[message.get_axis_idx("ch")]
+
+         # Ensure X is 2D
+         X = X.reshape(-1, n_input)
+         if hasattr(self._state.model, "n_features_in_"):
+             expected = self._state.model.n_features_in_
+             if expected != n_input:
+                 raise ValueError(
+                     f"Model expects {expected} features, but got {n_input}"
+                 )
+
+         if hasattr(self._state.model, "predict"):
+             y_pred = self._state.model.predict(X)
+         elif hasattr(self._state.model, "predict_many"):
+             df_X = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])])
+             y_pred = self._state.model.predict_many(df_X)
+             y_pred = np.array(list(y_pred))
+         elif hasattr(self._state.model, "predict_one"):
+             # river's random forest does not support predict_many
+             y_pred = np.array(
+                 [
+                     self._state.model.predict_one(
+                         {f"f{i}": xi[i] for i in range(len(xi))}
+                     )
+                     for xi in X
+                 ]
+             )
+         else:
+             raise NotImplementedError("Model does not support predict, predict_many, or predict_one")
+
+         # For scalar outputs, ensure the output is 2D
+         if y_pred.ndim == 1:
+             y_pred = y_pred[:, np.newaxis]
+
+         # Replace the channel (last) dimension with the prediction dimension.
+         output_shape = original_shape[:-1] + (y_pred.shape[-1],)
+         y_pred = y_pred.reshape(output_shape)
+
+         if self._state.chan_ax is None:
+             self._state.chan_ax = AxisArray.CoordinateAxis(
+                 data=np.arange(output_shape[-1]), dims=["ch"]
+             )
+
+         return replace(
+             message,
+             data=y_pred,
+             axes={**message.axes, "ch": self._state.chan_ax},
+         )
+
+
+ class SklearnModelUnit(
+     BaseAdaptiveTransformerUnit[
+         SklearnModelSettings, AxisArray, AxisArray, SklearnModelProcessor
+     ]
+ ):
+     """
+     Unit wrapper for the `SklearnModelProcessor`.
+
+     This unit provides a plug-and-play interface for using a scikit-learn or River model
+     in an ezmsg graph-based system. It takes in `AxisArray` inputs and outputs predictions
+     in the same format, optionally performing training via `partial_fit` or `fit`.
+
+     Example:
+     --------
+     ```python
+     unit = SklearnModelUnit(
+         settings=SklearnModelSettings(
+             model_class='sklearn.linear_model.SGDClassifier',
+             model_kwargs={'loss': 'log_loss'},
+             partial_fit_classes=np.array([0, 1]),
+         )
+     )
+     ```
+     """
+
+     SETTINGS = SklearnModelSettings
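To make the batch path concrete, a short sketch of fitting and then calling the processor directly. This is not part of the diff; the data shapes are illustrative, and it assumes the `BaseAdaptiveTransformer` base class exposes a synchronous `__call__` that routes to `_process`.

```python
# Hypothetical sketch, not part of the diff: fit, then run inference.
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray

proc = SklearnModelProcessor(
    settings=SklearnModelSettings(
        model_class="sklearn.linear_model.SGDClassifier",
        model_kwargs={"loss": "log_loss"},
        partial_fit_classes=np.array([0, 1]),
    )
)

X = np.random.randn(100, 4)
y = np.random.randint(0, 2, size=100)
proc.fit(X, y)  # instantiates the model on first use, then trains it

msg = AxisArray(
    data=np.random.randn(10, 4),  # (time, ch)
    dims=["time", "ch"],
    axes={"time": AxisArray.TimeAxis(fs=1.0)},
)
pred = proc(msg)  # assumed sync entry point; AxisArray with a rebuilt "ch" axis
```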