ezmsg-learn 1.0-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/learn/__version__.py +2 -2
- ezmsg/learn/dim_reduce/adaptive_decomp.py +9 -19
- ezmsg/learn/dim_reduce/incremental_decomp.py +8 -16
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
- ezmsg/learn/linear_model/linear_regressor.py +4 -0
- ezmsg/learn/linear_model/sgd.py +6 -2
- ezmsg/learn/linear_model/slda.py +7 -1
- ezmsg/learn/model/mlp.py +8 -14
- ezmsg/learn/model/refit_kalman.py +17 -49
- ezmsg/learn/nlin_model/mlp.py +5 -1
- ezmsg/learn/process/adaptive_linear_regressor.py +20 -36
- ezmsg/learn/process/base.py +12 -31
- ezmsg/learn/process/linear_regressor.py +13 -18
- ezmsg/learn/process/mlp_old.py +18 -31
- ezmsg/learn/process/refit_kalman.py +8 -13
- ezmsg/learn/process/rnn.py +14 -36
- ezmsg/learn/process/sgd.py +94 -109
- ezmsg/learn/process/sklearn.py +17 -51
- ezmsg/learn/process/slda.py +6 -15
- ezmsg/learn/process/ssr.py +374 -0
- ezmsg/learn/process/torch.py +12 -29
- ezmsg/learn/process/transformer.py +11 -19
- ezmsg/learn/util.py +5 -4
- {ezmsg_learn-1.0.dist-info → ezmsg_learn-1.2.0.dist-info}/METADATA +5 -9
- ezmsg_learn-1.2.0.dist-info/RECORD +38 -0
- {ezmsg_learn-1.0.dist-info → ezmsg_learn-1.2.0.dist-info}/WHEEL +1 -1
- ezmsg_learn-1.2.0.dist-info/licenses/LICENSE +21 -0
- ezmsg_learn-1.0.dist-info/RECORD +0 -36
ezmsg/learn/__version__.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '1.0'
-__version_tuple__ = version_tuple = (1, 0)
+__version__ = version = '1.2.0'
+__version_tuple__ = version_tuple = (1, 2, 0)
 
 __commit_id__ = commit_id = None
ezmsg/learn/dim_reduce/adaptive_decomp.py
CHANGED
@@ -1,14 +1,14 @@
 import typing
 
-from sklearn.decomposition import IncrementalPCA, MiniBatchNMF
-import numpy as np
 import ezmsg.core as ez
-
-
+import numpy as np
+from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
+    processor_state,
 )
 from ezmsg.util.messages.axisarray import AxisArray, replace
+from sklearn.decomposition import IncrementalPCA, MiniBatchNMF
 
 
 class AdaptiveDecompSettings(ez.Settings):
@@ -23,15 +23,11 @@ class AdaptiveDecompState:
     estimator: typing.Any = None
 
 
-EstimatorType = typing.TypeVar(
-    "EstimatorType", bound=typing.Union[IncrementalPCA, MiniBatchNMF]
-)
+EstimatorType = typing.TypeVar("EstimatorType", bound=typing.Union[IncrementalPCA, MiniBatchNMF])
 
 
 class AdaptiveDecompTransformer(
-    BaseAdaptiveTransformer[
-        AdaptiveDecompSettings, AxisArray, AxisArray, AdaptiveDecompState
-    ],
+    BaseAdaptiveTransformer[AdaptiveDecompSettings, AxisArray, AxisArray, AdaptiveDecompState],
     typing.Generic[EstimatorType],
 ):
     """
@@ -80,9 +76,7 @@ class AdaptiveDecompTransformer(
         it_ax_ix = message.get_axis_idx(iter_axis)
         # Remaining axes are to be treated independently
         off_targ_axes = [
-            _
-            for _ in (message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :])
-            if _ != self.settings.axis
+            _ for _ in (message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :]) if _ != self.settings.axis
         ]
         self._state.axis_groups = iter_axis, targ_axes, off_targ_axes
 
@@ -152,9 +146,7 @@ class AdaptiveDecompTransformer(
 
         # Transform data
         if hasattr(self._state.estimator, "components_"):
-            decomp_dat = self._state.estimator.transform(in_dat).reshape(
-                (-1,) + self._state.template.data.shape[1:]
-            )
+            decomp_dat = self._state.estimator.transform(in_dat).reshape((-1,) + self._state.template.data.shape[1:])
             replace_kwargs["data"] = decomp_dat
 
         return replace(self._state.template, **replace_kwargs)
@@ -241,9 +233,7 @@ class MiniBatchNMFTransformer(AdaptiveDecompTransformer[MiniBatchNMF]):
     pass
 
 
-SettingsType = typing.TypeVar(
-    "SettingsType", bound=typing.Union[IncrementalPCASettings, MiniBatchNMFSettings]
-)
+SettingsType = typing.TypeVar("SettingsType", bound=typing.Union[IncrementalPCASettings, MiniBatchNMFSettings])
 TransformerType = typing.TypeVar(
     "TransformerType",
     bound=typing.Union[IncrementalPCATransformer, MiniBatchNMFTransformer],
ezmsg/learn/dim_reduce/incremental_decomp.py
CHANGED
@@ -1,14 +1,14 @@
 import typing
 
-import numpy as np
 import ezmsg.core as ez
-
-from ezmsg.
-    CompositeProcessor,
+import numpy as np
+from ezmsg.baseproc import (
     BaseStatefulProcessor,
     BaseTransformerUnit,
+    CompositeProcessor,
 )
 from ezmsg.sigproc.window import WindowTransformer
+from ezmsg.util.messages.axisarray import AxisArray, replace
 
 
 from .adaptive_decomp import (
@@ -36,9 +36,7 @@ class IncrementalDecompSettings(ez.Settings):
     forget_factor: float = 0.7
 
 
-class IncrementalDecompTransformer(
-    CompositeProcessor[IncrementalDecompSettings, AxisArray, AxisArray]
-):
+class IncrementalDecompTransformer(CompositeProcessor[IncrementalDecompSettings, AxisArray, AxisArray]):
     """
     Automates usage of IncrementalPCATransformer and MiniBatchNMFTransformer by using a WindowTransformer
     to extract training samples then calls partial_fit on the decomposition transformer.
@@ -125,15 +123,11 @@ class IncrementalDecompTransformer(
             # If the estimator has not been trained once, train it with the first message
             self._procs["decomp"].partial_fit(message)
         elif "windowing" in self._procs:
-            state["windowing"], train_msg = self._procs["windowing"].stateful_op(
-                state.get("windowing", None), message
-            )
+            state["windowing"], train_msg = self._procs["windowing"].stateful_op(state.get("windowing", None), message)
            self._partial_fit_windowed(train_msg)
 
         # Process the incoming message
-        state["decomp"], result = self._procs["decomp"].stateful_op(
-            state.get("decomp", None), message
-        )
+        state["decomp"], result = self._procs["decomp"].stateful_op(state.get("decomp", None), message)
 
         return state, result
 
@@ -174,8 +168,6 @@ class IncrementalDecompTransformer(
 
 
 class IncrementalDecompUnit(
-    BaseTransformerUnit[
-        IncrementalDecompSettings, AxisArray, AxisArray, IncrementalDecompTransformer
-    ]
+    BaseTransformerUnit[IncrementalDecompSettings, AxisArray, AxisArray, IncrementalDecompTransformer]
 ):
     SETTINGS = IncrementalDecompSettings
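For orientation, the docstring above describes a window-then-partial_fit pattern. Below is a standalone sketch of that idea using scikit-learn's IncrementalPCA directly (illustrative only; it does not use the ezmsg transformers, and the array shapes and random data are assumptions, not part of the package):

import numpy as np
from sklearn.decomposition import IncrementalPCA

ipca = IncrementalPCA(n_components=4)
rng = np.random.default_rng(0)
for _ in range(10):
    window = rng.normal(size=(64, 16))  # one window of training samples (samples x channels)
    ipca.partial_fit(window)            # incremental update of the decomposition
    reduced = ipca.transform(window)    # project the incoming data onto the learned components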
ezmsg/learn/linear_model/adaptive_linear_regressor.py
CHANGED
@@ -1,6 +1,12 @@
 from ..process.adaptive_linear_regressor import (
     AdaptiveLinearRegressorSettings as AdaptiveLinearRegressorSettings,
+)
+from ..process.adaptive_linear_regressor import (
     AdaptiveLinearRegressorState as AdaptiveLinearRegressorState,
+)
+from ..process.adaptive_linear_regressor import (
     AdaptiveLinearRegressorTransformer as AdaptiveLinearRegressorTransformer,
+)
+from ..process.adaptive_linear_regressor import (
     AdaptiveLinearRegressorUnit as AdaptiveLinearRegressorUnit,
 )
ezmsg/learn/linear_model/linear_regressor.py
CHANGED
@@ -1,5 +1,9 @@
 from ..process.linear_regressor import (
     LinearRegressorSettings as LinearRegressorSettings,
+)
+from ..process.linear_regressor import (
     LinearRegressorState as LinearRegressorState,
+)
+from ..process.linear_regressor import (
     LinearRegressorTransformer as LinearRegressorTransformer,
 )
ezmsg/learn/linear_model/sgd.py
CHANGED
@@ -1,5 +1,9 @@
 from ..process.sgd import (
-    sgd_decoder as sgd_decoder,
-    SGDDecoderSettings as SGDDecoderSettings,
     SGDDecoder as SGDDecoder,
 )
+from ..process.sgd import (
+    SGDDecoderSettings as SGDDecoderSettings,
+)
+from ..process.sgd import (
+    sgd_decoder as sgd_decoder,
+)
ezmsg/learn/linear_model/slda.py
CHANGED
@@ -1,6 +1,12 @@
+from ..process.slda import (
+    SLDA as SLDA,
+)
 from ..process.slda import (
     SLDASettings as SLDASettings,
+)
+from ..process.slda import (
     SLDAState as SLDAState,
+)
+from ..process.slda import (
     SLDATransformer as SLDATransformer,
-    SLDA as SLDA,
 )
ezmsg/learn/model/mlp.py
CHANGED
@@ -27,11 +27,12 @@ class MLP(torch.nn.Module):
         Initialize the MLP model.
         Args:
             input_size (int): The size of the input features.
-            hidden_size (int | list[int]): The sizes of the hidden layers. If a list, num_layers must be None or the
-                of the list. If a single integer, num_layers must be specified and determines the number of
+            hidden_size (int | list[int]): The sizes of the hidden layers. If a list, num_layers must be None or the
+                length of the list. If a single integer, num_layers must be specified and determines the number of
+                hidden layers.
             num_layers (int, optional): The number of hidden layers. Length of hidden_size if None. Default is None.
-            output_heads (int | dict[str, int], optional): Number of output features or classes if single head output
-                dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head).
+            output_heads (int | dict[str, int], optional): Number of output features or classes if single head output
+                or a dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head).
             norm_layer (str, optional): A normalization layer to be applied after each linear layer. Default is None.
                 Common choices are "BatchNorm1d" or "LayerNorm".
             activation_layer (str, optional): An activation function to be applied after each normalization
@@ -43,18 +44,14 @@
         super().__init__()
         if isinstance(hidden_size, int):
             if num_layers is None:
-                raise ValueError(
-                    "If hidden_size is an integer, num_layers must be specified."
-                )
+                raise ValueError("If hidden_size is an integer, num_layers must be specified.")
             hidden_size = [hidden_size] * num_layers
         if len(hidden_size) == 0:
             raise ValueError("hidden_size must have at least one element")
         if any(not isinstance(x, int) for x in hidden_size):
             raise ValueError("hidden_size must contain only integers")
         if num_layers is not None and len(hidden_size) != num_layers:
-            raise ValueError(
-                "Length of hidden_size must match num_layers if num_layers is specified."
-            )
+            raise ValueError("Length of hidden_size must match num_layers if num_layers is specified.")
 
         params = {} if inplace is None else {"inplace": inplace}
 
@@ -84,10 +81,7 @@
         if isinstance(output_heads, int):
             output_heads = {"output": output_heads}
         self.heads = torch.nn.ModuleDict(
-            {
-                name: torch.nn.Linear(hidden_size[-1], output_size)
-                for name, output_size in output_heads.items()
-            }
+            {name: torch.nn.Linear(hidden_size[-1], output_size) for name, output_size in output_heads.items()}
         )
 
     @classmethod
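For orientation, the hidden_size / num_layers contract described in the docstring and validation code above amounts to the check sketched below (a standalone illustration, not the packaged MLP class; resolve_hidden_sizes is a hypothetical name):

def resolve_hidden_sizes(hidden_size, num_layers=None):
    # int + num_layers -> repeat the width; list -> use as-is (length must match num_layers if given)
    if isinstance(hidden_size, int):
        if num_layers is None:
            raise ValueError("If hidden_size is an integer, num_layers must be specified.")
        hidden_size = [hidden_size] * num_layers
    if num_layers is not None and len(hidden_size) != num_layers:
        raise ValueError("Length of hidden_size must match num_layers if num_layers is specified.")
    return hidden_size

assert resolve_hidden_sizes(64, num_layers=3) == [64, 64, 64]
assert resolve_hidden_sizes([32, 16]) == [32, 16]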
ezmsg/learn/model/refit_kalman.py
CHANGED
@@ -83,17 +83,12 @@ class RefitKalmanFilter:
         if Y_state.ndim != 2:
             raise ValueError(f"State vector must be 2D, got {Y_state.ndim}D")
 
-        if (
-            not hasattr(self, "H_observation_matrix")
-            or self.H_observation_matrix is None
-        ):
+        if not hasattr(self, "H_observation_matrix") or self.H_observation_matrix is None:
             raise ValueError("Model must be fitted before refitting")
 
         expected_states = self.H_observation_matrix.shape[1]
         if Y_state.shape[1] != expected_states:
-            raise ValueError(
-                f"State vector has {Y_state.shape[1]} dimensions, expected {expected_states}"
-            )
+            raise ValueError(f"State vector has {Y_state.shape[1]} dimensions, expected {expected_states}")
 
     def fit(self, X_train, y_train):
         """
@@ -121,15 +116,11 @@ class RefitKalmanFilter:
         X2 = X[1:, :]  # x_{t+1}
         X1 = X[:-1, :]  # x_t
         A = X2.T @ X1 @ np.linalg.inv(X1.T @ X1)  # Transition matrix
-        W = (
-            (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (n_samples - 1)
-        )  # Covariance of transition matrix
+        W = (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (n_samples - 1)  # Covariance of transition matrix
 
         # Calculate the measurement matrix (from x_t to z_t) using least-squares
         H = Z.T @ X @ np.linalg.inv(X.T @ X)  # Measurement matrix
-        Q = (
-            (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0]
-        )  # Covariance of measurement matrix
+        Q = (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0]  # Covariance of measurement matrix
 
         self.A_state_transition_matrix = A
         self.W_process_noise_covariance = W * self.process_noise_scale
@@ -179,15 +170,11 @@ class RefitKalmanFilter:
         if intention_velocity_indices is None:
             # Assume (x, y, vx, vy)
             vel_idx = 2 if Y_state.shape[1] >= 4 else 0
-            print(
-                f"[RefitKalmanFilter] No velocity index provided — defaulting to {vel_idx}"
-            )
+            print(f"[RefitKalmanFilter] No velocity index provided — defaulting to {vel_idx}")
         else:
             if isinstance(intention_velocity_indices, (list, tuple)):
                 if len(intention_velocity_indices) != 1:
-                    raise ValueError(
-                        "Only one velocity start index should be provided."
-                    )
+                    raise ValueError("Only one velocity start index should be provided.")
                 vel_idx = intention_velocity_indices[0]
             else:
                 vel_idx = intention_velocity_indices
@@ -198,18 +185,14 @@ class RefitKalmanFilter:
         else:
             intended_states = Y_state.copy()
             # Calculate intended velocities for each sample
-            for i, (state, pos, target) in enumerate(
-                zip(Y_state, cursor_positions, target_positions)
-            ):
+            for i, (state, pos, target) in enumerate(zip(Y_state, cursor_positions, target_positions)):
                 is_hold = hold_indices[i] if hold_indices is not None else False
 
                 if is_hold:
                     # During hold periods, intended velocity is zero
                     intended_states[i, vel_idx : vel_idx + 2] = 0.0
                     if i > 0:
-                        intended_states[i, :2] = intended_states[
-                            i - 1, :2
-                        ]  # Same position as previous
+                        intended_states[i, :2] = intended_states[i - 1, :2]  # Same position as previous
                 else:
                     # Calculate direction to target
                     to_target = target - pos
@@ -228,9 +211,7 @@ class RefitKalmanFilter:
                         intended_states[i, vel_idx : vel_idx + 2] = intended_velocity
                     # If target is very close, keep original velocity
                     else:
-                        intended_states[i, vel_idx : vel_idx + 2] = state[
-                            vel_idx : vel_idx + 2
-                        ]
+                        intended_states[i, vel_idx : vel_idx + 2] = state[vel_idx : vel_idx + 2]
 
         intended_states = np.array(intended_states)
         Z = np.array(X_neural)
@@ -258,7 +239,8 @@ class RefitKalmanFilter:
         Raises:
             LinAlgError: If the Riccati equation cannot be solved or matrix operations fail.
         """
-
+        # TODO: consider removing non-steady-state for compute_gain() -
+        # non_steady_state updates will occur during predict() and update()
         # if self.steady_state:
         try:
             # Try with original matrices
@@ -272,9 +254,7 @@ class RefitKalmanFilter:
                 self.P_state_covariance
                 @ self.H_observation_matrix.T
                 @ np.linalg.inv(
-                    self.H_observation_matrix
-                    @ self.P_state_covariance
-                    @ self.H_observation_matrix.T
+                    self.H_observation_matrix @ self.P_state_covariance @ self.H_observation_matrix.T
                     + self.Q_measurement_noise_covariance
                 )
             )
@@ -284,9 +264,7 @@ class RefitKalmanFilter:
             # W_reg = self.W_process_noise_covariance + 1e-7 * np.eye(
             #     self.W_process_noise_covariance.shape[0]
             # )
-            Q_reg = self.Q_measurement_noise_covariance + 1e-7 * np.eye(
-                self.Q_measurement_noise_covariance.shape[0]
-            )
+            Q_reg = self.Q_measurement_noise_covariance + 1e-7 * np.eye(self.Q_measurement_noise_covariance.shape[0])
 
             try:
                 self.P_state_covariance = solve_discrete_are(
@@ -299,19 +277,14 @@ class RefitKalmanFilter:
                     self.P_state_covariance
                     @ self.H_observation_matrix.T
                     @ np.linalg.inv(
-                        self.H_observation_matrix
-                        @ self.P_state_covariance
-                        @ self.H_observation_matrix.T
-                        + Q_reg
+                        self.H_observation_matrix @ self.P_state_covariance @ self.H_observation_matrix.T + Q_reg
                     )
                 )
                 print("Warning: Used regularized matrices for DARE solution")
             except LinAlgError:
                 # Fallback to identity or manual initialization
                 print("Warning: DARE failed, using identity covariance")
-                self.P_state_covariance = np.eye(
-                    self.A_state_transition_matrix.shape[0]
-                )
+                self.P_state_covariance = np.eye(self.A_state_transition_matrix.shape[0])
 
         # else:
         #     n_states = self.A_state_transition_matrix.shape[0]
@@ -349,9 +322,7 @@ class RefitKalmanFilter:
             return x_predicted, None
         else:
             P_predicted = self.alpha_fading_memory**2 * (
-                self.A_state_transition_matrix
-                @ self.P_state_covariance
-                @ self.A_state_transition_matrix.T
+                self.A_state_transition_matrix @ self.P_state_covariance @ self.A_state_transition_matrix.T
                 + self.W_process_noise_covariance
             )
             return x_predicted, P_predicted
@@ -376,10 +347,7 @@ class RefitKalmanFilter:
 
         # Non-steady-state mode
         # System uncertainty
-        S = (
-            self.H_observation_matrix @ P_predicted @ self.H_observation_matrix.T
-            + self.Q_measurement_noise_covariance
-        )
+        S = self.H_observation_matrix @ P_predicted @ self.H_observation_matrix.T + self.Q_measurement_noise_covariance
 
         # Kalman gain
         K = P_predicted @ self.H_observation_matrix.T @ np.linalg.pinv(S)
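For orientation, the least-squares system identification performed in fit() above can be reproduced standalone as sketched below (illustrative only; the array shapes and random data are assumptions, not part of the package):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 4))   # state trajectory, e.g. (x, y, vx, vy)
Z = rng.normal(size=(100, 16))  # neural observations aligned with X

X1, X2 = X[:-1, :], X[1:, :]
A = X2.T @ X1 @ np.linalg.inv(X1.T @ X1)                    # state transition matrix
W = (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (X.shape[0] - 1)  # process noise covariance
H = Z.T @ X @ np.linalg.inv(X.T @ X)                        # observation (measurement) matrix
Q = (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0]            # measurement noise covariance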
ezmsg/learn/nlin_model/mlp.py
CHANGED
@@ -1,6 +1,10 @@
 from ..model.mlp_old import MLP as MLP
+from ..process.mlp_old import (
+    MLPProcessor as MLPProcessor,
+)
 from ..process.mlp_old import (
     MLPSettings as MLPSettings,
+)
+from ..process.mlp_old import (
     MLPState as MLPState,
-    MLPProcessor as MLPProcessor,
 )
ezmsg/learn/process/adaptive_linear_regressor.py
CHANGED
@@ -1,16 +1,15 @@
 from dataclasses import field
 
+import ezmsg.core as ez
 import numpy as np
 import pandas as pd
-import river.optim
 import river.linear_model
+import river.optim
 import sklearn.base
-
-from ezmsg.sigproc.sampler import SampleMessage
-from ezmsg.sigproc.base import (
-    processor_state,
+from ezmsg.baseproc import (
     BaseAdaptiveTransformer,
     BaseAdaptiveTransformerUnit,
+    processor_state,
 )
 from ezmsg.util.messages.axisarray import AxisArray, replace
 
@@ -39,9 +38,7 @@ class AdaptiveLinearRegressorTransformer(
 ):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.settings = replace(
-            self.settings, model_type=AdaptiveLinearRegressor(self.settings.model_type)
-        )
+        self.settings = replace(self.settings, model_type=AdaptiveLinearRegressor(self.settings.model_type))
         b_river = self.settings.model_type in [
             AdaptiveLinearRegressor.LINEAR,
             AdaptiveLinearRegressor.LOGISTIC,
@@ -49,9 +46,7 @@ class AdaptiveLinearRegressorTransformer(
         if b_river:
             self.settings.model_kwargs["l2"] = self.settings.model_kwargs.get("l2", 0.0)
             if "learn_rate" in self.settings.model_kwargs:
-                self.settings.model_kwargs["optimizer"] = river.optim.SGD(
-                    self.settings.model_kwargs.pop("learn_rate")
-                )
+                self.settings.model_kwargs["optimizer"] = river.optim.SGD(self.settings.model_kwargs.pop("learn_rate"))
 
         if self.settings.settings_path is not None:
             # Load model from file
@@ -69,9 +64,7 @@ class AdaptiveLinearRegressorTransformer(
             print("TODO: Override sklearn model with kwargs")
         else:
             # Build model from scratch.
-            regressor_klass = get_regressor(
-                RegressorType.ADAPTIVE, self.settings.model_type
-            )
+            regressor_klass = get_regressor(RegressorType.ADAPTIVE, self.settings.model_type)
             self.state.model = regressor_klass(**self.settings.model_kwargs)
 
     def _hash_message(self, message: AxisArray) -> int:
@@ -84,37 +77,30 @@ class AdaptiveLinearRegressorTransformer(
         # .template is updated in partial_fit
         pass
 
-    def partial_fit(self, message:
-        if np.any(np.isnan(message.
+    def partial_fit(self, message: AxisArray) -> None:
+        if np.any(np.isnan(message.data)):
             return
 
         if self.settings.model_type in [
             AdaptiveLinearRegressor.LINEAR,
             AdaptiveLinearRegressor.LOGISTIC,
         ]:
-            x = pd.DataFrame.from_dict(
-                {
-                    k: v
-                    for k, v in zip(
-                        message.sample.axes["ch"].data, message.sample.data.T
-                    )
-                }
-            )
+            x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
             y = pd.Series(
-                data=message.trigger.value.data[:, 0],
-                name=message.trigger.value.axes["ch"].data[0],
+                data=message.attrs["trigger"].value.data[:, 0],
+                name=message.attrs["trigger"].value.axes["ch"].data[0],
             )
             self.state.model.learn_many(x, y)
         else:
-            X = message.
-            if message.
-            X = np.moveaxis(X, message.
-            self.state.model.partial_fit(X, message.trigger.value.data)
+            X = message.data
+            if message.get_axis_idx("time") != 0:
+                X = np.moveaxis(X, message.get_axis_idx("time"), 0)
+            self.state.model.partial_fit(X, message.attrs["trigger"].value.data)
 
         self.state.template = replace(
-            message.trigger.value,
-            data=np.empty_like(message.trigger.value.data),
-            key=message.trigger.value.key + "_pred",
+            message.attrs["trigger"].value,
+            data=np.empty_like(message.attrs["trigger"].value.data),
+            key=message.attrs["trigger"].value.key + "_pred",
         )
 
     def _process(self, message: AxisArray) -> AxisArray | None:
@@ -127,9 +113,7 @@ class AdaptiveLinearRegressorTransformer(
             AdaptiveLinearRegressor.LOGISTIC,
         ]:
             # convert msg_in.data to something appropriate for river
-            x = pd.DataFrame.from_dict(
-                {k: v for k, v in zip(message.axes["ch"].data, message.data.T)}
-            )
+            x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
             preds = self.state.model.predict_many(x).values
         else:
             preds = self.state.model.predict(message.data)
ezmsg/learn/process/base.py
CHANGED
@@ -1,7 +1,7 @@
 import inspect
 import json
-from pathlib import Path
 import typing
+from pathlib import Path
 
 import ezmsg.core as ez
 import torch
@@ -32,9 +32,7 @@ class ModelInitMixin:
         for key, value in config.items():
             if key in model_kwargs:
                 if model_kwargs[key] != value:
-                    ez.logger.warning(
-                        f"Config parameter {key} ({value}) differs from settings ({model_kwargs[key]})."
-                    )
+                    ez.logger.warning(f"Config parameter {key} ({value}) differs from settings ({model_kwargs[key]}).")
             else:
                 ez.logger.warning(f"Config parameter {key} is not in model_kwargs.")
                 model_kwargs[key] = value
@@ -44,7 +42,8 @@ class ModelInitMixin:
         filtered_out = set(kwargs.keys()) - {k for k in valid_params if k != "self"}
         if filtered_out:
             ez.logger.warning(
-
+                "Ignoring unexpected model parameters not accepted by"
+                f"{model_class.__name__} constructor: {sorted(filtered_out)}"
             )
         # Keep all valid parameters, including None values, so checkpoint-inferred values can overwrite them
         return {k: v for k, v in kwargs.items() if k in valid_params and k != "self"}
@@ -92,22 +91,16 @@ class ModelInitMixin:
                 config = json.load(f)
                 self._merge_config(model_kwargs, config)
             except Exception as e:
-                raise RuntimeError(
-                    f"Failed to load config from {config_path}: {str(e)}"
-                )
+                raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
 
         # If a checkpoint file is provided, load it.
         if checkpoint_path:
             checkpoint_path = Path(checkpoint_path)
             if not checkpoint_path.exists():
                 ez.logger.error(f"Checkpoint path {checkpoint_path} does not exist.")
-                raise FileNotFoundError(
-                    f"Checkpoint path {checkpoint_path} does not exist."
-                )
+                raise FileNotFoundError(f"Checkpoint path {checkpoint_path} does not exist.")
             try:
-                checkpoint = torch.load(
-                    checkpoint_path, map_location=device, weights_only=weights_only
-                )
+                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
 
                 if "config" in checkpoint:
                     config = checkpoint["config"]
@@ -126,20 +119,14 @@ class ModelInitMixin:
                         "infer_config_from_state_dict",
                         lambda _state_dict: {},  # Default to empty dict if not defined
                     )
-                    infer_kwargs = (
-                        {"rnn_type": model_kwargs["rnn_type"]}
-                        if "rnn_type" in model_kwargs
-                        else {}
-                    )
+                    infer_kwargs = {"rnn_type": model_kwargs["rnn_type"]} if "rnn_type" in model_kwargs else {}
                     self._merge_config(
                         model_kwargs,
                         infer_config(state_dict, **infer_kwargs),
                     )
 
             except Exception as e:
-                raise RuntimeError(
-                    f"Failed to load checkpoint from {checkpoint_path}: {str(e)}"
-                )
+                raise RuntimeError(f"Failed to load checkpoint from {checkpoint_path}: {str(e)}")
 
         # Filter model_kwargs to only include valid parameters for the model class
         filtered_kwargs = self._filter_model_kwargs(model_class, model_kwargs)
@@ -156,18 +143,12 @@ class ModelInitMixin:
         if state_dict_prefix:
             # If a prefix is provided, filter the state_dict keys
             state_dict = {
-                k[len(state_dict_prefix) :]: v
-                for k, v in state_dict.items()
-                if k.startswith(state_dict_prefix)
+                k[len(state_dict_prefix) :]: v for k, v in state_dict.items() if k.startswith(state_dict_prefix)
             }
         # Load the model weights
-        missing, unexpected = model.load_state_dict(
-            state_dict, strict=False, assign=True
-        )
+        missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
         if missing or unexpected:
-            ez.logger.warning(
-                f"Partial load: missing keys: {missing}, unexpected keys: {unexpected}"
-            )
+            ez.logger.warning(f"Partial load: missing keys: {missing}, unexpected keys: {unexpected}")
 
         model.to(device)
         return model