ezmsg_learn-1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/learn/__init__.py +2 -0
- ezmsg/learn/__version__.py +34 -0
- ezmsg/learn/dim_reduce/__init__.py +0 -0
- ezmsg/learn/dim_reduce/adaptive_decomp.py +274 -0
- ezmsg/learn/dim_reduce/incremental_decomp.py +173 -0
- ezmsg/learn/linear_model/__init__.py +1 -0
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +12 -0
- ezmsg/learn/linear_model/cca.py +1 -0
- ezmsg/learn/linear_model/linear_regressor.py +9 -0
- ezmsg/learn/linear_model/sgd.py +9 -0
- ezmsg/learn/linear_model/slda.py +12 -0
- ezmsg/learn/model/__init__.py +0 -0
- ezmsg/learn/model/cca.py +122 -0
- ezmsg/learn/model/mlp.py +127 -0
- ezmsg/learn/model/mlp_old.py +49 -0
- ezmsg/learn/model/refit_kalman.py +369 -0
- ezmsg/learn/model/rnn.py +160 -0
- ezmsg/learn/model/transformer.py +175 -0
- ezmsg/learn/nlin_model/__init__.py +1 -0
- ezmsg/learn/nlin_model/mlp.py +10 -0
- ezmsg/learn/process/__init__.py +0 -0
- ezmsg/learn/process/adaptive_linear_regressor.py +142 -0
- ezmsg/learn/process/base.py +154 -0
- ezmsg/learn/process/linear_regressor.py +95 -0
- ezmsg/learn/process/mlp_old.py +188 -0
- ezmsg/learn/process/refit_kalman.py +403 -0
- ezmsg/learn/process/rnn.py +245 -0
- ezmsg/learn/process/sgd.py +117 -0
- ezmsg/learn/process/sklearn.py +241 -0
- ezmsg/learn/process/slda.py +110 -0
- ezmsg/learn/process/ssr.py +374 -0
- ezmsg/learn/process/torch.py +362 -0
- ezmsg/learn/process/transformer.py +215 -0
- ezmsg/learn/util.py +67 -0
- ezmsg_learn-1.1.0.dist-info/METADATA +30 -0
- ezmsg_learn-1.1.0.dist-info/RECORD +38 -0
- ezmsg_learn-1.1.0.dist-info/WHEEL +4 -0
- ezmsg_learn-1.1.0.dist-info/licenses/LICENSE +21 -0
ezmsg/learn/process/rnn.py
@@ -0,0 +1,245 @@
import typing

import ezmsg.core as ez
import numpy as np
import torch
from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
from ezmsg.baseproc.util.profile import profile_subpub
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .base import ModelInitMixin
from .torch import (
    TorchModelSettings,
    TorchModelState,
    TorchProcessorMixin,
)


class RNNSettings(TorchModelSettings):
    model_class: str = "ezmsg.learn.model.rnn.RNNModel"
    """
    Fully qualified class path of the model to be used.
    This should remain "ezmsg.learn.model.rnn.RNNModel" for this processor.
    """
    reset_hidden_on_fit: bool = True
    """
    Whether to reset the hidden state on each fit call.
    If True, the hidden state is reset to zero after each fit.
    If False, the hidden state is maintained across fit calls.
    """
    preserve_state_across_windows: bool | typing.Literal["auto"] = "auto"
    """
    Whether to preserve the hidden state across windows.
    If True, the hidden state is preserved across windows.
    If False, the hidden state is reset at the start of each window.
    If "auto", preserve when time windows do not overlap; otherwise reset.
    """


class RNNState(TorchModelState):
    # hx is a single tensor for RNN/GRU cells and an (h, c) tuple for LSTM cells.
    hx: torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None = None


class RNNProcessor(
    BaseAdaptiveTransformer[RNNSettings, AxisArray, AxisArray, RNNState],
    TorchProcessorMixin,
    ModelInitMixin,
):
    def _infer_output_sizes(self, model: torch.nn.Module, n_input: int) -> dict[str, int]:
        """Run a dummy forward pass to infer the output channel size(s)."""
        dummy_input = torch.zeros(1, 50, n_input, device=self._state.device)
        with torch.no_grad():
            output, _ = model(dummy_input)

        if isinstance(output, dict):
            return {k: v.shape[-1] for k, v in output.items()}
        else:
            return {"output": output.shape[-1]}

    def _reset_state(self, message: AxisArray) -> None:
        model_kwargs = dict(self.settings.model_kwargs or {})
        self._common_reset_state(message, model_kwargs)
        self._init_optimizer()
        self._validate_loss_keys(list(self._state.chan_ax.keys()))

        batch_size = 1 if message.data.ndim == 2 else message.data.shape[0]
        self.reset_hidden(batch_size)

    def _maybe_reset_state(self, message: AxisArray, batch_size: int) -> bool:
        preserve_state = self.settings.preserve_state_across_windows
        if preserve_state == "auto":
            axes = message.axes
            if batch_size < 2:
                # Single window, so preserve
                preserve_state = True
            elif "time" not in axes or "win" not in axes:
                # Default fallback
                ez.logger.warning(
                    "Missing 'time' or 'win' axis for auto preserve-state logic. Defaulting to reset."
                )
                preserve_state = False
            else:
                # Calculate stride between windows (assuming uniform spacing)
                win_stride = axes["win"].gain
                # Calculate window length from time axis length and gain
                time_len = message.data.shape[message.get_axis_idx("time")]
                gain = getattr(axes["time"], "gain", None)
                if gain is None:
                    ez.logger.warning("Time axis gain not found, using default gain of 1.0.")
                    gain = 1.0  # fallback default
                win_len = time_len * gain
                # Preserve only if windows do NOT overlap: stride >= window length
                preserve_state = win_stride >= win_len

        if not preserve_state:
            self.reset_hidden(batch_size)
        else:
            # If preserving state, only reset when the stored hidden state's batch size isn't 1
            hx_batch_size = (
                self._state.hx[0].shape[1]
                if isinstance(self._state.hx, tuple)
                else self._state.hx.shape[1]
            )
            if hx_batch_size != 1:
                ez.logger.debug(
                    f"Resetting hidden state due to batch size mismatch (hx: {hx_batch_size}, new: 1)"
                )
                self.reset_hidden(1)
        return preserve_state

    def _process(self, message: AxisArray) -> list[AxisArray]:
        x = message.data
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(
                x,
                dtype=torch.float32 if self.settings.single_precision else torch.float64,
                device=self._state.device,
            )

        # Add batch dimension if missing
        x, added_batch_dim = self._ensure_batched(x)

        batch_size = x.shape[0]
        preserve_state = self._maybe_reset_state(message, batch_size)

        with torch.no_grad():
            # If we are preserving state and have multiple batches, process sequentially
            if preserve_state and batch_size > 1:
                y_data = {}
                for x_batch in x:
                    x_batch = x_batch.unsqueeze(0)
                    y, self._state.hx = self._state.model(x_batch, hx=self._state.hx)
                    if not isinstance(y, dict):
                        y = {"output": y}
                    for key, out in y.items():
                        if key not in y_data:
                            y_data[key] = []
                        y_data[key].append(out.cpu().numpy())
                # Concatenate outputs for each key
                y_data = {key: np.concatenate(outputs, axis=0) for key, outputs in y_data.items()}
            else:
                y, self._state.hx = self._state.model(x, hx=self._state.hx)
                if not isinstance(y, dict):
                    y = {"output": y}
                y_data = {
                    key: (out.cpu().numpy().squeeze(0) if added_batch_dim else out.cpu().numpy())
                    for key, out in y.items()
                }

        return [
            replace(
                message,
                data=out,
                axes={**message.axes, "ch": self._state.chan_ax[key]},
                key=key,
            )
            for key, out in y_data.items()
        ]

    def reset_hidden(self, batch_size: int) -> None:
        self._state.hx = self._state.model.init_hidden(batch_size, self._state.device)

    def _train_step(
        self,
        X: torch.Tensor,
        y_targ: dict[str, torch.Tensor],
        loss_fns: dict[str, torch.nn.Module],
    ) -> None:
        y_pred, self._state.hx = self._state.model(X, hx=self._state.hx)
        if not isinstance(y_pred, dict):
            y_pred = {"output": y_pred}

        loss_weights = self.settings.loss_weights or {}
        losses = []
        for key in y_targ.keys():
            loss_fn = loss_fns.get(key)
            if loss_fn is None:
                raise ValueError(f"Loss function for key '{key}' is not defined.")
            if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
                # CrossEntropyLoss expects (batch, classes, time) logits and integer targets.
                loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long())
            else:
                loss = loss_fn(y_pred[key], y_targ[key])
            weight = loss_weights.get(key, 1.0)
            losses.append(loss * weight)

        total_loss = sum(losses)
        ez.logger.debug(
            f"Training step loss: {total_loss.item()} (individual losses: {[loss.item() for loss in losses]})"
        )

        self._state.optimizer.zero_grad()
        total_loss.backward()
        self._state.optimizer.step()
        if self._state.scheduler is not None:
            self._state.scheduler.step()

    def partial_fit(self, message: SampleMessage) -> None:
        self._state.model.train()

        X = self._to_tensor(message.sample.data)

        # Add batch dimension if missing
        X, added_batch_dim = self._ensure_batched(X)

        batch_size = X.shape[0]
        preserve_state = self._maybe_reset_state(message.sample, batch_size)

        y_targ = message.trigger.value
        if not isinstance(y_targ, dict):
            y_targ = {"output": y_targ}
        y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
        # Add batch dimension to y_targ values if one was added to X
        if added_batch_dim:
            for key in y_targ:
                y_targ[key] = y_targ[key].unsqueeze(0)

        loss_fns = self.settings.loss_fn
        if loss_fns is None:
            raise ValueError("loss_fn must be provided in settings to use partial_fit")
        if not isinstance(loss_fns, dict):
            loss_fns = {k: loss_fns for k in y_targ.keys()}

        with torch.set_grad_enabled(True):
            if preserve_state:
                self._train_step(X, y_targ, loss_fns)
            else:
                for i in range(batch_size):
                    self._train_step(
                        X[i].unsqueeze(0),
                        {key: value[i].unsqueeze(0) for key, value in y_targ.items()},
                        loss_fns,
                    )

        self._state.model.eval()
        if self.settings.reset_hidden_on_fit:
            self.reset_hidden(X.shape[0])


class RNNUnit(
    BaseAdaptiveTransformerUnit[
        RNNSettings,
        AxisArray,
        AxisArray,
        RNNProcessor,
    ]
):
    SETTINGS = RNNSettings

    @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        results = await self.processor.__acall__(message)
        for result in results:
            yield self.OUTPUT_SIGNAL, result
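For reference, the "auto" branch of `RNNProcessor._maybe_reset_state` reduces to a comparison between the window hop and the window duration. A minimal numeric sketch with illustrative axis values (no actual ezmsg message construction):

```python
# Hypothetical axis values: a new window every 1.0 s ("win" axis gain),
# each window holding 500 samples at 0.01 s/sample ("time" axis gain).
win_stride = 1.0            # axes["win"].gain
time_len = 500              # samples along the "time" axis
gain = 0.01                 # axes["time"].gain, seconds per sample

win_len = time_len * gain   # 5.0 s of data per window
# Preserve hidden state only when consecutive windows do not overlap.
preserve_state = win_stride >= win_len  # False here: 1.0 < 5.0, so the state resets
```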
ezmsg/learn/process/sgd.py
@@ -0,0 +1,117 @@
import typing

import ezmsg.core as ez
import numpy as np
from ezmsg.baseproc import (
    BaseAdaptiveTransformer,
    BaseAdaptiveTransformerUnit,
    SampleMessage,
    processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import SGDClassifier

from ..util import ClassifierMessage


class SGDDecoderSettings(ez.Settings):
    alpha: float = 1e-5
    eta0: float = 3e-4
    loss: str = "hinge"
    label_weights: dict[str, float] | None = None
    settings_path: str | None = None


@processor_state
class SGDDecoderState:
    model: typing.Any = None
    b_first_train: bool = True


class SGDDecoderTransformer(
    BaseAdaptiveTransformer[SGDDecoderSettings, AxisArray, ClassifierMessage, SGDDecoderState]
):
    """
    SGD-based online classifier.

    Online Passive-Aggressive Algorithms
    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshet, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
    """

    def _refreshed_model(self):
        if self.settings.settings_path is not None:
            import pickle

            with open(self.settings.settings_path, "rb") as f:
                model = pickle.load(f)
            if self.settings.label_weights is not None:
                model.class_weight = self.settings.label_weights
            model.eta0 = self.settings.eta0
        else:
            model = SGDClassifier(
                loss=self.settings.loss,
                alpha=self.settings.alpha,
                penalty="elasticnet",
                learning_rate="adaptive",
                eta0=self.settings.eta0,
                early_stopping=False,
                class_weight=self.settings.label_weights,
            )
        return model

    def _reset_state(self, message: AxisArray) -> None:
        self._state.model = self._refreshed_model()

    def _process(self, message: AxisArray) -> ClassifierMessage | None:
        if self._state.model is None or not message.data.size:
            return None
        if np.any(np.isnan(message.data)):
            return None
        try:
            X = message.data.reshape((message.data.shape[0], -1))
            # _predict_proba_lr is a private sklearn helper that maps decision scores to probabilities.
            result = self._state.model._predict_proba_lr(X)
        except NotFittedError:
            return None
        out_axes = {}
        if message.dims[0] in message.axes:
            out_axes[message.dims[0]] = replace(
                message.axes[message.dims[0]],
                offset=message.axes[message.dims[0]].offset,
            )
        return ClassifierMessage(
            data=result,
            dims=message.dims[:1] + ["labels"],
            axes=out_axes,
            labels=list(self._state.model.class_weight.keys()),
            key=message.key,
        )

    def partial_fit(self, message: SampleMessage) -> None:
        if self._hash != 0:
            self._reset_state(message.sample)
            self._hash = 0

        if np.any(np.isnan(message.sample.data)):
            return
        train_sample = message.sample.data.reshape(1, -1)
        if self._state.b_first_train:
            # The first partial_fit call must enumerate all classes.
            self._state.model.partial_fit(
                train_sample,
                [message.trigger.value],
                classes=list(self.settings.label_weights.keys()),
            )
            self._state.b_first_train = False
        else:
            self._state.model.partial_fit(train_sample, [message.trigger.value])


class SGDDecoder(
    BaseAdaptiveTransformerUnit[
        SGDDecoderSettings,
        AxisArray,
        ClassifierMessage,
        SGDDecoderTransformer,
    ]
):
    SETTINGS = SGDDecoderSettings
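To make the decoder's behavior concrete, here is a hedged sketch of the plain scikit-learn workflow that `SGDDecoderTransformer` wraps. The constructor kwargs mirror `_refreshed_model` and the settings defaults; the label names, feature width, and random data are hypothetical:

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

label_weights = {"left": 1.0, "right": 1.0}  # hypothetical class weights
model = SGDClassifier(
    loss="hinge",
    alpha=1e-5,
    penalty="elasticnet",
    learning_rate="adaptive",
    eta0=3e-4,
    early_stopping=False,
    class_weight=label_weights,
)

# As in partial_fit above, the first call must enumerate every class label.
model.partial_fit(np.random.randn(1, 16), ["left"], classes=list(label_weights.keys()))
model.partial_fit(np.random.randn(1, 16), ["right"])  # later calls omit `classes`

# _process maps decision scores to probabilities via sklearn's private helper.
probs = model._predict_proba_lr(np.random.randn(1, 16))
```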
ezmsg/learn/process/sklearn.py
@@ -0,0 +1,241 @@
import importlib
import pickle
import typing

import ezmsg.core as ez
import numpy as np
import pandas as pd
from ezmsg.baseproc import (
    BaseAdaptiveTransformer,
    BaseAdaptiveTransformerUnit,
    processor_state,
)
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace


class SklearnModelSettings(ez.Settings):
    model_class: str
    """
    Full path to the sklearn model class.
    Example: 'sklearn.linear_model.LinearRegression'
    """
    model_kwargs: dict[str, typing.Any] | None = None
    """
    Additional keyword arguments to pass to the model constructor.
    Example: {'fit_intercept': True, 'normalize': False}
    """
    checkpoint_path: str | None = None
    """
    Path to a checkpoint file to load the model from.
    If provided, the model will be initialized from this checkpoint.
    Example: 'path/to/checkpoint.pkl'
    """
    partial_fit_classes: np.ndarray | None = None
    """
    The full list of classes to use for partial_fit calls.
    This must be provided to use `partial_fit` with classifiers.
    """


@processor_state
class SklearnModelState:
    model: typing.Any = None
    chan_ax: AxisArray.CoordinateAxis | None = None


class SklearnModelProcessor(
    BaseAdaptiveTransformer[SklearnModelSettings, AxisArray, AxisArray, SklearnModelState]
):
    """
    Processor that wraps a scikit-learn, River, or hmmlearn model for use in the ezmsg framework.

    This processor supports:
    - `fit`, `partial_fit`, or River's `learn_many`/`learn_one` for training.
    - `predict`, River's `predict_many`, or `predict_one` for inference.
    - Optional model checkpoint loading and saving.

    The processor expects and outputs `AxisArray` messages with a `"ch"` (channel) axis.

    Settings
    --------
    model_class : str
        Full path to the sklearn or River model class to use.
        Example: "sklearn.linear_model.SGDClassifier" or "river.linear_model.LogisticRegression"

    model_kwargs : dict[str, typing.Any], optional
        Additional keyword arguments passed to the model constructor.

    checkpoint_path : str, optional
        Path to a pickle file to load a previously saved model. If provided, the model will
        be restored from this path at startup.

    partial_fit_classes : np.ndarray, optional
        For classifiers that require all class labels to be specified during `partial_fit`.

    Example
    -------
    ```python
    processor = SklearnModelProcessor(
        settings=SklearnModelSettings(
            model_class='sklearn.linear_model.SGDClassifier',
            model_kwargs={'loss': 'log_loss'},
            partial_fit_classes=np.array([0, 1]),
        )
    )
    ```
    """

    def _init_model(self) -> None:
        module_path, class_name = self.settings.model_class.rsplit(".", 1)
        model_cls = getattr(importlib.import_module(module_path), class_name)
        kwargs = self.settings.model_kwargs or {}
        self._state.model = model_cls(**kwargs)

    def save_checkpoint(self, path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(self._state.model, f)

    def load_checkpoint(self, path: str) -> None:
        try:
            with open(path, "rb") as f:
                self._state.model = pickle.load(f)
        except Exception as e:
            ez.logger.error(f"Failed to load model from {path}: {str(e)}")
            raise RuntimeError(f"Failed to load model from {path}: {str(e)}") from e

    def _reset_state(self, message: AxisArray) -> None:
        # Try loading from checkpoint first
        if self.settings.checkpoint_path:
            self.load_checkpoint(self.settings.checkpoint_path)
            n_input = message.data.shape[message.get_axis_idx("ch")]
            if hasattr(self._state.model, "n_features_in_"):
                expected = self._state.model.n_features_in_
                if expected != n_input:
                    raise ValueError(f"Model expects {expected} features, but got {n_input}")
        else:
            # No checkpoint, initialize from scratch
            self._init_model()

    def partial_fit(self, message: SampleMessage) -> None:
        X = message.sample.data
        y = message.trigger.value
        if self._state.model is None:
            self._reset_state(message.sample)
        if hasattr(self._state.model, "partial_fit"):
            kwargs = {}
            if self.settings.partial_fit_classes is not None:
                kwargs["classes"] = self.settings.partial_fit_classes
            self._state.model.partial_fit(X, y, **kwargs)
        elif hasattr(self._state.model, "learn_many"):
            df_X = pd.DataFrame({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
            name = (
                message.trigger.value.axes["ch"].data[0]
                if hasattr(message.trigger.value, "axes") and "ch" in message.trigger.value.axes
                else "target"
            )
            ser_y = pd.Series(
                data=np.asarray(message.trigger.value.data).flatten(),
                name=name,
            )
            self._state.model.learn_many(df_X, ser_y)
        elif hasattr(self._state.model, "learn_one"):
            # River's random forest does not support learn_many
            for xi, yi in zip(X, y):
                features = {f"f{i}": xi[i] for i in range(len(xi))}
                self._state.model.learn_one(features, yi)
        else:
            raise NotImplementedError("Model does not support partial_fit or learn_many")

    def fit(self, X: np.ndarray, y: np.ndarray) -> None:
        if self._state.model is None:
            dummy_msg = AxisArray(
                data=X,
                dims=["time", "ch"],
                axes={
                    "time": AxisArray.TimeAxis(fs=1.0),
                    "ch": AxisArray.CoordinateAxis(
                        data=np.array([f"ch_{i}" for i in range(X.shape[1])]),
                        dims=["ch"],
                    ),
                },
            )
            self._reset_state(dummy_msg)
        if hasattr(self._state.model, "fit"):
            self._state.model.fit(X, y)
        elif hasattr(self._state.model, "learn_many"):
            df_X = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])])
            ser_y = pd.Series(y.flatten(), name="target")
            self._state.model.learn_many(df_X, ser_y)
        elif hasattr(self._state.model, "learn_one"):
            # River's random forest does not support learn_many
            for xi, yi in zip(X, y):
                features = {f"f{i}": xi[i] for i in range(len(xi))}
                self._state.model.learn_one(features, yi)
        else:
            raise NotImplementedError("Model does not support fit or learn_many")

    def _process(self, message: AxisArray) -> AxisArray:
        if self._state.model is None:
            raise RuntimeError("Model has not been fit yet. Call `fit()` or `partial_fit()` before processing.")
        X = message.data
        original_shape = X.shape
        n_input = X.shape[message.get_axis_idx("ch")]

        # Ensure X is 2D; assumes "ch" is the last dimension
        X = X.reshape(-1, n_input)
        if hasattr(self._state.model, "n_features_in_"):
            expected = self._state.model.n_features_in_
            if expected != n_input:
                raise ValueError(f"Model expects {expected} features, but got {n_input}")

        if hasattr(self._state.model, "predict"):
            y_pred = self._state.model.predict(X)
        elif hasattr(self._state.model, "predict_many"):
            df_X = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])])
            y_pred = self._state.model.predict_many(df_X)
            y_pred = np.array(list(y_pred))
        elif hasattr(self._state.model, "predict_one"):
            # River's random forest does not support predict_many
            y_pred = np.array([self._state.model.predict_one({f"f{i}": xi[i] for i in range(len(xi))}) for xi in X])
        else:
            raise NotImplementedError("Model does not support predict or predict_many")

        # For scalar outputs, ensure the output is 2D
        if y_pred.ndim == 1:
            y_pred = y_pred[:, np.newaxis]

        output_shape = original_shape[:-1] + (y_pred.shape[-1],)
        y_pred = y_pred.reshape(output_shape)

        if self._state.chan_ax is None:
            self._state.chan_ax = AxisArray.CoordinateAxis(data=np.arange(output_shape[-1]), dims=["ch"])

        return replace(
            message,
            data=y_pred,
            axes={**message.axes, "ch": self._state.chan_ax},
        )


class SklearnModelUnit(
    BaseAdaptiveTransformerUnit[SklearnModelSettings, AxisArray, AxisArray, SklearnModelProcessor]
):
    """
    Unit wrapper for the `SklearnModelProcessor`.

    This unit provides a plug-and-play interface for using a scikit-learn or River model
    in an ezmsg graph-based system. It takes in `AxisArray` inputs and outputs predictions
    in the same format, optionally performing training via `partial_fit` or `fit`.

    Example
    -------
    ```python
    unit = SklearnModelUnit(
        settings=SklearnModelSettings(
            model_class='sklearn.linear_model.SGDClassifier',
            model_kwargs={'loss': 'log_loss'},
            partial_fit_classes=np.array([0, 1]),
        )
    )
    ```
    """

    SETTINGS = SklearnModelSettings
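The docstring examples above cover construction; the following hedged sketch exercises the offline `fit` path and the checkpoint round-trip. The file name and random data are hypothetical, and message-driven inference is left to the ezmsg runtime:

```python
import numpy as np

proc = SklearnModelProcessor(
    settings=SklearnModelSettings(
        model_class="sklearn.linear_model.SGDClassifier",
        model_kwargs={"loss": "log_loss"},
        partial_fit_classes=np.array([0, 1]),
    )
)

X = np.random.randn(100, 8)                 # 100 samples x 8 "ch" features
y = np.random.randint(0, 2, size=100)
proc.fit(X, y)                              # builds the model via a dummy AxisArray, then fits
proc.save_checkpoint("sgd_checkpoint.pkl")  # hypothetical path

# A fresh processor pointed at the checkpoint restores the fitted model in
# _reset_state instead of constructing a new one.
warm = SklearnModelProcessor(
    settings=SklearnModelSettings(
        model_class="sklearn.linear_model.SGDClassifier",
        checkpoint_path="sgd_checkpoint.pkl",
    )
)
```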