ezmsg-learn 1.0-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/learn/__version__.py +2 -2
- ezmsg/learn/dim_reduce/adaptive_decomp.py +9 -19
- ezmsg/learn/dim_reduce/incremental_decomp.py +8 -16
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
- ezmsg/learn/linear_model/linear_regressor.py +4 -0
- ezmsg/learn/linear_model/sgd.py +6 -2
- ezmsg/learn/linear_model/slda.py +7 -1
- ezmsg/learn/model/mlp.py +8 -14
- ezmsg/learn/model/refit_kalman.py +17 -49
- ezmsg/learn/nlin_model/mlp.py +5 -1
- ezmsg/learn/process/adaptive_linear_regressor.py +20 -36
- ezmsg/learn/process/base.py +12 -31
- ezmsg/learn/process/linear_regressor.py +13 -18
- ezmsg/learn/process/mlp_old.py +18 -31
- ezmsg/learn/process/refit_kalman.py +8 -13
- ezmsg/learn/process/rnn.py +14 -36
- ezmsg/learn/process/sgd.py +94 -109
- ezmsg/learn/process/sklearn.py +17 -51
- ezmsg/learn/process/slda.py +6 -15
- ezmsg/learn/process/ssr.py +374 -0
- ezmsg/learn/process/torch.py +12 -29
- ezmsg/learn/process/transformer.py +11 -19
- ezmsg/learn/util.py +5 -4
- {ezmsg_learn-1.0.dist-info → ezmsg_learn-1.2.0.dist-info}/METADATA +5 -9
- ezmsg_learn-1.2.0.dist-info/RECORD +38 -0
- {ezmsg_learn-1.0.dist-info → ezmsg_learn-1.2.0.dist-info}/WHEEL +1 -1
- ezmsg_learn-1.2.0.dist-info/licenses/LICENSE +21 -0
- ezmsg_learn-1.0.dist-info/RECORD +0 -36
ezmsg/learn/process/linear_regressor.py CHANGED

@@ -1,17 +1,16 @@
 from dataclasses import field
 
-import numpy as np
-from sklearn.linear_model._base import LinearModel
 import ezmsg.core as ez
-
-from ezmsg.sigproc.base import (
+import numpy as np
+from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
+    processor_state,
 )
 from ezmsg.util.messages.axisarray import AxisArray, replace
-from ezmsg.sigproc.sampler import SampleMessage
+from sklearn.linear_model._base import LinearModel
 
-from ..util import
+from ..util import RegressorType, StaticLinearRegressor, get_regressor
 
 
 class LinearRegressorSettings(ez.Settings):
@@ -27,9 +26,7 @@ class LinearRegressorState:
 
 
 class LinearRegressorTransformer(
-    BaseAdaptiveTransformer[
-        LinearRegressorSettings, AxisArray, AxisArray, LinearRegressorState
-    ]
+    BaseAdaptiveTransformer[LinearRegressorSettings, AxisArray, AxisArray, LinearRegressorState]
 ):
     """
     Linear regressor.
@@ -47,9 +44,7 @@ class LinearRegressorTransformer(
             with open(self.settings.settings_path, "rb") as f:
                 self.state.model = pickle.load(f)
         else:
-            regressor_klass = get_regressor(
-                RegressorType.STATIC, self.settings.model_type
-            )
+            regressor_klass = get_regressor(RegressorType.STATIC, self.settings.model_type)
             self.state.model = regressor_klass(**self.settings.model_kwargs)
 
     def _reset_state(self, message: AxisArray) -> None:
@@ -57,18 +52,18 @@ class LinearRegressorTransformer(
         # .model and .template are initialized in __init__
         pass
 
-    def partial_fit(self, message: SampleMessage) -> None:
-        if np.any(np.isnan(message.sample.data)):
+    def partial_fit(self, message: AxisArray) -> None:
+        if np.any(np.isnan(message.data)):
             return
 
-        X = message.sample.data
-        y = message.trigger.value.data
+        X = message.data
+        y = message.attrs["trigger"].value.data
         # TODO: Resample should provide identical durations.
         self.state.model = self.state.model.fit(X[: y.shape[0]], y[: X.shape[0]])
         self.state.template = replace(
-            message.trigger.value,
+            message.attrs["trigger"].value,
             data=np.array([[]]),
-            key=message.trigger.value.key + "_pred",
+            key=message.attrs["trigger"].value.key + "_pred",
         )
 
     def _process(self, message: AxisArray) -> AxisArray:
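The rewritten `partial_fit` above takes a plain `AxisArray` whose training target rides in `message.attrs["trigger"].value` (itself an `AxisArray`, since `.data` and `.key` are read from it), replacing the old `SampleMessage` contract. A minimal sketch of a conforming message, assuming the `AxisArray` in use exposes the `attrs` mapping these methods read; `SimpleNamespace` is a stand-in for whatever trigger object the upstream sampler actually attaches:

from types import SimpleNamespace

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray

# Target the regressor should learn to predict: 10 samples x 2 outputs.
target = AxisArray(data=np.zeros((10, 2)), dims=["time", "ch"], key="target")
# Feature message: partial_fit reads X from .data and y from attrs["trigger"].value.data.
features = AxisArray(
    data=np.zeros((10, 32)),
    dims=["time", "ch"],
    key="features",
    attrs={"trigger": SimpleNamespace(value=target)},
)
# transformer.partial_fit(features)  # fits the model on (features.data, target.data)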
ezmsg/learn/process/mlp_old.py CHANGED

@@ -1,17 +1,16 @@
 import typing
 
+import ezmsg.core as ez
 import numpy as np
 import torch
 import torch.nn
-
-from ezmsg.util.messages.axisarray import AxisArray
-from ezmsg.sigproc.sampler import SampleMessage
-from ezmsg.util.messages.util import replace
-from ezmsg.sigproc.base import (
+from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
+from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace
 
 from ..model.mlp_old import MLP
 
@@ -24,10 +23,12 @@ class MLPSettings(ez.Settings):
     """Norm layer that will be stacked on top of the linear layer. If None this layer won't be used."""
 
     activation_layer: typing.Callable[..., torch.nn.Module] | None = torch.nn.ReLU
-    """Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None this layer won't be used."""
+    """Activation function which will be stacked on top of the normalization layer (if not None),
+    otherwise on top of the linear layer. If None this layer won't be used."""
 
     inplace: bool | None = None
-    """Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer."""
+    """Parameter for the activation layer, which can optionally do the operation in-place.
+    Default is None, which uses the respective default values of the activation_layer and Dropout layer."""
 
     bias: bool = True
     """Whether to use bias in the linear layer."""
@@ -58,9 +59,7 @@ class MLPState:
     device: object | None = None
 
 
-class MLPProcessor(
-    BaseAdaptiveTransformer[MLPSettings, AxisArray, AxisArray, MLPState]
-):
+class MLPProcessor(BaseAdaptiveTransformer[MLPSettings, AxisArray, AxisArray, MLPState]):
    def _hash_message(self, message: AxisArray) -> int:
        hash_items = (message.key,)
        if "ch" in message.dims:
@@ -85,39 +84,29 @@ class MLPProcessor(
                 checkpoint = torch.load(self.settings.checkpoint_path)
                 self._state.model.load_state_dict(checkpoint["model_state_dict"])
             except Exception as e:
-                raise RuntimeError(
-                    f"Failed to load checkpoint from {self.settings.checkpoint_path}: {str(e)}"
-                )
+                raise RuntimeError(f"Failed to load checkpoint from {self.settings.checkpoint_path}: {str(e)}")
 
         # Set the model to evaluation mode by default
         self._state.model.eval()
 
         # Create the optimizer
-        self._state.optimizer = torch.optim.Adam(
-            self._state.model.parameters(), lr=self.settings.learning_rate
-        )
+        self._state.optimizer = torch.optim.Adam(self._state.model.parameters(), lr=self.settings.learning_rate)
 
         # Update the optimizer from checkpoint if it exists
         if self.settings.checkpoint_path is not None:
             try:
                 checkpoint = torch.load(self.settings.checkpoint_path)
                 if "optimizer_state_dict" in checkpoint:
-                    self._state.optimizer.load_state_dict(
-                        checkpoint["optimizer_state_dict"]
-                    )
+                    self._state.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
             except Exception as e:
-                raise RuntimeError(
-                    f"Failed to load optimizer from {self.settings.checkpoint_path}: {str(e)}"
-                )
+                raise RuntimeError(f"Failed to load optimizer from {self.settings.checkpoint_path}: {str(e)}")
 
         # TODO: Should the model be moved to a device before the next line?
         self._state.device = next(self.state.model.parameters()).device
 
         # Optionally create the learning rate scheduler
         self._state.scheduler = (
-            torch.optim.lr_scheduler.ExponentialLR(
-                self._state.optimizer, gamma=self.settings.scheduler_gamma
-            )
+            torch.optim.lr_scheduler.ExponentialLR(self._state.optimizer, gamma=self.settings.scheduler_gamma)
             if self.settings.scheduler_gamma > 0.0
             else None
         )
@@ -144,14 +133,14 @@ class MLPProcessor(
         dtype = torch.float32 if self.settings.single_precision else torch.float64
         return torch.tensor(data, dtype=dtype, device=self._state.device)
 
-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()
 
         # TODO: loss_fn should be determined by setting
         loss_fn = torch.nn.functional.mse_loss
 
-        X = self._to_tensor(message.sample.data)
-        y_targ = self._to_tensor(message.trigger.value)
+        X = self._to_tensor(message.data)
+        y_targ = self._to_tensor(message.attrs["trigger"].value)
 
         with torch.set_grad_enabled(True):
             self._state.model.train()
@@ -171,9 +160,7 @@ class MLPProcessor(
         if not isinstance(data, torch.Tensor):
             data = torch.tensor(
                 data,
-                dtype=torch.float32
-                if self.settings.single_precision
-                else torch.float64,
+                dtype=torch.float32 if self.settings.single_precision else torch.float64,
             )
 
         with torch.no_grad():
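`MLPProcessor` reads its checkpoint as a dict with a required `model_state_dict` key and an optional `optimizer_state_dict` key. A sketch of writing a compatible file with plain PyTorch; the two-layer network here is illustrative, not the package's `MLP`:

import torch

model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Key names match what the processor reads: "model_state_dict" is required;
# "optimizer_state_dict" is applied only when present in the checkpoint.
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "mlp_checkpoint.pt",
)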
ezmsg/learn/process/refit_kalman.py CHANGED

@@ -3,12 +3,11 @@ from pathlib import Path
 
 import ezmsg.core as ez
 import numpy as np
-from ezmsg.sigproc.base import (
+from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
     processor_state,
 )
-from ezmsg.sigproc.sampler import SampleMessage
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -284,27 +283,25 @@ class RefitKalmanFilterProcessor(
             key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered",
         )
 
-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         """
         Perform refitting using externally provided data.
 
-        Expects message.sample.data (neural input) and message.trigger.value as a dict with:
+        Expects message.data (neural input) and message.attrs["trigger"].value as a dict with:
         - Y_state: (n_samples, n_states) array
         - intention_velocity_indices: Optional[int]
         - target_positions: Optional[np.ndarray]
         - cursor_positions: Optional[np.ndarray]
         - hold_flags: Optional[list[bool]]
         """
-        if
+        if "trigger" not in message.attrs:
             raise ValueError("Invalid message format for partial_fit.")
 
-        X = np.array(message.sample.data)
-        values = message.trigger.value
+        X = np.array(message.data)
+        values = message.attrs["trigger"].value
 
         if not isinstance(values, dict) or "Y_state" not in values:
-            raise ValueError(
-                "partial_fit expects trigger.value to include at least 'Y_state'."
-            )
+            raise ValueError("partial_fit expects trigger.value to include at least 'Y_state'.")
 
         kwargs = {
             "X_neural": X,
@@ -319,9 +316,7 @@ class RefitKalmanFilterProcessor(
             "hold_flags",
         ]:
             if key in values and values[key] is not None:
-                kwargs[key if key != "hold_flags" else "hold_indices"] = np.array(
-                    values[key]
-                )
+                kwargs[key if key != "hold_flags" else "hold_indices"] = np.array(values[key])
 
         # Call model refit
         self._state.model.refit(**kwargs)
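The docstring above defines the dict that `RefitKalmanFilterProcessor.partial_fit` now expects in `message.attrs["trigger"].value`. A sketch of assembling that payload; shapes are illustrative and `SimpleNamespace` again stands in for the real trigger object:

from types import SimpleNamespace

import numpy as np

n_samples, n_states = 100, 4
refit_payload = {
    "Y_state": np.zeros((n_samples, n_states)),    # required
    "intention_velocity_indices": 2,               # optional
    "target_positions": np.zeros((n_samples, 2)),  # optional
    "cursor_positions": np.zeros((n_samples, 2)),  # optional
    "hold_flags": [False] * n_samples,             # optional; forwarded to refit() as hold_indices
}
trigger = SimpleNamespace(value=refit_payload)
# Attach as message.attrs["trigger"] before calling processor.partial_fit(message).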
ezmsg/learn/process/rnn.py CHANGED

@@ -3,9 +3,8 @@ import typing
 import ezmsg.core as ez
 import numpy as np
 import torch
-from ezmsg.sigproc.base import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
-from ezmsg.sigproc.sampler import SampleMessage
-from ezmsg.sigproc.util.profile import profile_subpub
+from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
+from ezmsg.baseproc.util.profile import profile_subpub
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 
@@ -47,9 +46,7 @@ class RNNProcessor(
     TorchProcessorMixin,
     ModelInitMixin,
 ):
-    def _infer_output_sizes(
-        self, model: torch.nn.Module, n_input: int
-    ) -> dict[str, int]:
+    def _infer_output_sizes(self, model: torch.nn.Module, n_input: int) -> dict[str, int]:
         """Simple inference to get output channel size."""
         dummy_input = torch.zeros(1, 50, n_input, device=self._state.device)
         with torch.no_grad():
@@ -78,9 +75,7 @@ class RNNProcessor(
             preserve_state = True
         elif "time" not in axes or "win" not in axes:
             # Default fallback
-            ez.logger.warning(
-                "Missing 'time' or 'win' axis for auto preserve-state logic. Defaulting to reset."
-            )
+            ez.logger.warning("Missing 'time' or 'win' axis for auto preserve-state logic. Defaulting to reset.")
             preserve_state = False
         else:
             # Calculate stride between windows (assuming uniform spacing)
@@ -89,9 +84,7 @@ class RNNProcessor(
             time_len = message.data.shape[message.get_axis_idx("time")]
             gain = getattr(axes["time"], "gain", None)
             if gain is None:
-                ez.logger.warning(
-                    "Time axis gain not found, using default gain of 1.0."
-                )
+                ez.logger.warning("Time axis gain not found, using default gain of 1.0.")
                 gain = 1.0  # fallback default
             win_len = time_len * gain
             # Determine if we should preserve state
@@ -102,15 +95,9 @@ class RNNProcessor(
             self.reset_hidden(batch_size)
         else:
             # If preserving state, only reset if batch size isn't 1
-            hx_batch_size = (
-                self._state.hx[0].shape[1]
-                if isinstance(self._state.hx, tuple)
-                else self._state.hx.shape[1]
-            )
+            hx_batch_size = self._state.hx[0].shape[1] if isinstance(self._state.hx, tuple) else self._state.hx.shape[1]
             if hx_batch_size != 1:
-                ez.logger.debug(
-                    f"Resetting hidden state due to batch size mismatch (hx: {hx_batch_size}, new: 1)"
-                )
+                ez.logger.debug(f"Resetting hidden state due to batch size mismatch (hx: {hx_batch_size}, new: 1)")
                 self.reset_hidden(1)
         return preserve_state
 
@@ -119,9 +106,7 @@ class RNNProcessor(
         if not isinstance(x, torch.Tensor):
             x = torch.tensor(
                 x,
-                dtype=torch.float32
-                if self.settings.single_precision
-                else torch.float64,
+                dtype=torch.float32 if self.settings.single_precision else torch.float64,
                 device=self._state.device,
             )
 
@@ -143,18 +128,11 @@ class RNNProcessor(
                     y_data[key] = []
                 y_data[key].append(out.cpu().numpy())
             # Concatenate outputs for each key
-            y_data = {
-                key: np.concatenate(outputs, axis=0)
-                for key, outputs in y_data.items()
-            }
+            y_data = {key: np.concatenate(outputs, axis=0) for key, outputs in y_data.items()}
         else:
             y, self._state.hx = self._state.model(x, hx=self._state.hx)
             y_data = {
-                key: (
-                    out.cpu().numpy().squeeze(0)
-                    if added_batch_dim
-                    else out.cpu().numpy()
-                )
+                key: (out.cpu().numpy().squeeze(0) if added_batch_dim else out.cpu().numpy())
                 for key, out in y.items()
             }
 
@@ -205,18 +183,18 @@ class RNNProcessor(
         if self._state.scheduler is not None:
             self._state.scheduler.step()
 
-    def partial_fit(self, message: SampleMessage) -> None:
+    def partial_fit(self, message: AxisArray) -> None:
         self._state.model.train()
 
-        X = self._to_tensor(message.sample.data)
+        X = self._to_tensor(message.data)
 
         # Add batch dimension if missing
         X, batched = self._ensure_batched(X)
 
         batch_size = X.shape[0]
-        preserve_state = self._maybe_reset_state(message.sample, batch_size)
+        preserve_state = self._maybe_reset_state(message, batch_size)
 
-        y_targ = message.trigger.value
+        y_targ = message.attrs["trigger"].value
         if not isinstance(y_targ, dict):
             y_targ = {"output": y_targ}
         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
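`RNNProcessor` carries its hidden state in `self._state.hx` between messages and resets it when the stored batch size no longer matches. A toy illustration of that carry-over with a bare `torch.nn.GRU`, independent of ezmsg:

import torch

gru = torch.nn.GRU(input_size=8, hidden_size=16, batch_first=True)
hx = None  # hidden state carried across chunks, like self._state.hx
for _ in range(3):
    chunk = torch.zeros(1, 50, 8)  # (batch, time, features)
    out, hx = gru(chunk, hx)       # passing hx back in preserves state between chunks
# hx has shape (num_layers, batch, hidden); a batch-size mismatch invalidates it.
if hx.shape[1] != 1:
    hx = None  # analogous to the processor's reset_hidden()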
ezmsg/learn/process/sgd.py CHANGED

@@ -2,9 +2,11 @@ import typing
 
 import ezmsg.core as ez
 import numpy as np
-from ezmsg.
-
-
+from ezmsg.baseproc import (
+    BaseAdaptiveTransformer,
+    BaseAdaptiveTransformerUnit,
+    processor_state,
+)
 from ezmsg.util.messages.axisarray import AxisArray
 from ezmsg.util.messages.util import replace
 from sklearn.exceptions import NotFittedError
@@ -13,103 +15,6 @@ from sklearn.linear_model import SGDClassifier
 from ..util import ClassifierMessage
 
 
-@consumer
-def sgd_decoder(
-    alpha: float = 1.5e-5,
-    eta0: float = 1e-7,  # Lower than what you'd use for offline training.
-    loss: str = "squared_hinge",
-    label_weights: dict[str, float] | None = None,
-    settings_path: str | None = None,
-) -> typing.Generator[AxisArray | SampleMessage, ClassifierMessage | None, None]:
-    """
-    Passive Aggressive Classifier
-    Online Passive-Aggressive Algorithms <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
-    K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
-
-    Args:
-        alpha: Maximum step size (regularization)
-        eta0: The initial learning rate for the 'adaptive' schedules.
-        loss: The loss function to be used:
-            hinge: equivalent to PA-I in the reference paper.
-            squared_hinge: equivalent to PA-II in the reference paper.
-        label_weights: An optional dictionary of label names and their relative weight.
-            e.g., {'Go': 31.0, 'Stop': 0.5}
-            If this is None then settings_path must be provided and the pre-trained model
-        settings_path: Path to the stored sklearn model pkl file.
-
-    Returns:
-        Generator that accepts `SampleMessage` for incremental training (`partial_fit`) and yields None,
-        or `AxisArray` for inference (`predict`) and yields a `ClassifierMessage`.
-    """
-    # pre-init inputs and outputs
-    msg_out = ClassifierMessage(data=np.array([]), dims=[""])
-
-    # State variables:
-
-    if settings_path is not None:
-        import pickle
-
-        with open(settings_path, "rb") as f:
-            model = pickle.load(f)
-        if label_weights is not None:
-            model.class_weight = label_weights
-        # Overwrite eta0, probably with a value lower than what was used online.
-        model.eta0 = eta0
-    else:
-        model = SGDClassifier(
-            loss=loss,
-            alpha=alpha,
-            penalty="elasticnet",
-            learning_rate="adaptive",
-            eta0=eta0,
-            early_stopping=False,
-            class_weight=label_weights,
-        )
-
-    b_first_train = True
-    # TODO: template_out
-
-    while True:
-        msg_in: AxisArray | SampleMessage = yield msg_out
-
-        msg_out = None
-        if type(msg_in) is SampleMessage:
-            # SampleMessage used for training.
-            if not np.any(np.isnan(msg_in.sample.data)):
-                train_sample = msg_in.sample.data.reshape(1, -1)
-                if b_first_train:
-                    model.partial_fit(
-                        train_sample,
-                        [msg_in.trigger.value],
-                        classes=list(label_weights.keys()),
-                    )
-                    b_first_train = False
-                else:
-                    model.partial_fit(train_sample, [msg_in.trigger.value])
-        elif msg_in.data.size:
-            # AxisArray used for inference
-            if not np.any(np.isnan(msg_in.data)):
-                try:
-                    X = msg_in.data.reshape((msg_in.data.shape[0], -1))
-                    result = model._predict_proba_lr(X)
-                except NotFittedError:
-                    result = None
-                if result is not None:
-                    out_axes = {}
-                    if msg_in.dims[0] in msg_in.axes:
-                        out_axes[msg_in.dims[0]] = replace(
-                            msg_in.axes[msg_in.dims[0]],
-                            offset=msg_in.axes[msg_in.dims[0]].offset,
-                        )
-                    msg_out = ClassifierMessage(
-                        data=result,
-                        dims=msg_in.dims[:1] + ["labels"],
-                        axes=out_axes,
-                        labels=list(model.class_weight.keys()),
-                        key=msg_in.key,
-                    )
-
-
 class SGDDecoderSettings(ez.Settings):
     alpha: float = 1e-5
     eta0: float = 3e-4
@@ -118,14 +23,94 @@ class SGDDecoderSettings(ez.Settings):
     settings_path: str | None = None
 
 
-
-
-
+@processor_state
+class SGDDecoderState:
+    model: typing.Any = None
+    b_first_train: bool = True
 
-    # Method to be implemented by subclasses to construct the specific generator
-    def construct_generator(self):
-        self.STATE.gen = sgd_decoder(**self.SETTINGS.__dict__)
 
-
-
-
+class SGDDecoderTransformer(BaseAdaptiveTransformer[SGDDecoderSettings, AxisArray, ClassifierMessage, SGDDecoderState]):
+    """
+    SGD-based online classifier.
+
+    Online Passive-Aggressive Algorithms
+    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
+    K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
+    """
+
+    def _refreshed_model(self):
+        if self.settings.settings_path is not None:
+            import pickle
+
+            with open(self.settings.settings_path, "rb") as f:
+                model = pickle.load(f)
+            if self.settings.label_weights is not None:
+                model.class_weight = self.settings.label_weights
+            model.eta0 = self.settings.eta0
+        else:
+            model = SGDClassifier(
+                loss=self.settings.loss,
+                alpha=self.settings.alpha,
+                penalty="elasticnet",
+                learning_rate="adaptive",
+                eta0=self.settings.eta0,
+                early_stopping=False,
+                class_weight=self.settings.label_weights,
+            )
+        return model
+
+    def _reset_state(self, message: AxisArray) -> None:
+        self._state.model = self._refreshed_model()
+
+    def _process(self, message: AxisArray) -> ClassifierMessage | None:
+        if self._state.model is None or not message.data.size:
+            return None
+        if np.any(np.isnan(message.data)):
+            return None
+        try:
+            X = message.data.reshape((message.data.shape[0], -1))
+            result = self._state.model._predict_proba_lr(X)
+        except NotFittedError:
+            return None
+        out_axes = {}
+        if message.dims[0] in message.axes:
+            out_axes[message.dims[0]] = replace(
+                message.axes[message.dims[0]],
+                offset=message.axes[message.dims[0]].offset,
+            )
+        return ClassifierMessage(
+            data=result,
+            dims=message.dims[:1] + ["labels"],
+            axes=out_axes,
+            labels=list(self._state.model.class_weight.keys()),
+            key=message.key,
+        )
+
+    def partial_fit(self, message: AxisArray) -> None:
+        if self._hash != 0:
+            self._reset_state(message)
+            self._hash = 0
+
+        if np.any(np.isnan(message.data)):
+            return
+        train_sample = message.data.reshape(1, -1)
+        if self._state.b_first_train:
+            self._state.model.partial_fit(
+                train_sample,
+                [message.attrs["trigger"].value],
+                classes=list(self.settings.label_weights.keys()),
+            )
+            self._state.b_first_train = False
+        else:
+            self._state.model.partial_fit(train_sample, [message.attrs["trigger"].value])
+
+
+class SGDDecoder(
+    BaseAdaptiveTransformerUnit[
+        SGDDecoderSettings,
+        AxisArray,
+        ClassifierMessage,
+        SGDDecoderTransformer,
+    ]
+):
+    SETTINGS = SGDDecoderSettings
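The `SGDClassifier` configuration built by `_refreshed_model` maps directly onto scikit-learn; a standalone sketch of the same first-call/subsequent-call `partial_fit` pattern, with example labels borrowed from the removed docstring:

import numpy as np
from sklearn.linear_model import SGDClassifier

label_weights = {"Go": 31.0, "Stop": 0.5}
model = SGDClassifier(
    loss="squared_hinge",       # PA-II-like, per the referenced Crammer et al. paper
    alpha=1e-5,
    penalty="elasticnet",
    learning_rate="adaptive",
    eta0=3e-4,
    early_stopping=False,
    class_weight=label_weights,
)
X = np.zeros((1, 64))
# The first partial_fit call must enumerate all classes, as the transformer does.
model.partial_fit(X, ["Go"], classes=list(label_weights.keys()))
model.partial_fit(X, ["Stop"])  # later calls omit classes
proba = model._predict_proba_lr(X)  # same private helper _process uses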
|