ezmsg-learn 1.1.0__tar.gz → 1.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/PKG-INFO +3 -3
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/pyproject.toml +2 -2
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/__version__.py +2 -2
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/adaptive_linear_regressor.py +12 -13
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/linear_regressor.py +6 -7
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/mlp_old.py +3 -4
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/refit_kalman.py +5 -6
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/rnn.py +4 -5
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/sgd.py +6 -7
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/sklearn.py +8 -9
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/torch.py +3 -4
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/transformer.py +3 -4
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_adaptive_linear_regressor.py +2 -2
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_linear_regressor.py +2 -2
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_mlp.py +9 -9
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_mlp_old.py +10 -5
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_refit_kalman.py +7 -8
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_rnn.py +19 -25
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_sgd.py +5 -4
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_sklearn.py +12 -13
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_torch.py +9 -14
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_transformer.py +17 -19
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/.github/workflows/docs.yml +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/.github/workflows/python-publish.yml +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/.github/workflows/python-tests.yml +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/.gitignore +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/.pre-commit-config.yaml +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/LICENSE +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/README.md +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/Makefile +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/make.bat +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/source/_templates/autosummary/module.rst +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/source/api/index.rst +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/source/conf.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/source/guides/classification.rst +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/docs/source/index.rst +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/__init__.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/dim_reduce/__init__.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/dim_reduce/adaptive_decomp.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/dim_reduce/incremental_decomp.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/__init__.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/adaptive_linear_regressor.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/cca.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/linear_regressor.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/sgd.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/slda.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/__init__.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/cca.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/mlp.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/mlp_old.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/refit_kalman.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/rnn.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/model/transformer.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/nlin_model/__init__.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/nlin_model/mlp.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/__init__.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/base.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/slda.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/ssr.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/util.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/benchmark/bench_lrr.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/dim_reduce/test_adaptive_decomp.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/dim_reduce/test_incremental_decomp.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/conftest.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_mlp_system.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_refit_kalman_system.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_rnn_system.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_sklearn_system.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_torch_system.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/integration/test_transformer_system.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_slda.py +0 -0
- {ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/tests/unit/test_ssr.py +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ezmsg-learn
|
|
3
|
-
Version: 1.1.0
|
|
3
|
+
Version: 1.2.0
|
|
4
4
|
Summary: ezmsg namespace package for machine learning
|
|
5
5
|
Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
|
|
6
6
|
License-Expression: MIT
|
|
7
7
|
License-File: LICENSE
|
|
8
8
|
Requires-Python: >=3.10.15
|
|
9
|
-
Requires-Dist: ezmsg-baseproc>=1.0
|
|
10
|
-
Requires-Dist: ezmsg-sigproc>=2.
|
|
9
|
+
Requires-Dist: ezmsg-baseproc>=1.3.0
|
|
10
|
+
Requires-Dist: ezmsg-sigproc>=2.15.0
|
|
11
11
|
Requires-Dist: river>=0.22.0
|
|
12
12
|
Requires-Dist: scikit-learn>=1.6.0
|
|
13
13
|
Requires-Dist: torch>=2.6.0
|
|
@@ -9,8 +9,8 @@ license = "MIT"
|
|
|
9
9
|
requires-python = ">=3.10.15"
|
|
10
10
|
dynamic = ["version"]
|
|
11
11
|
dependencies = [
|
|
12
|
-
"ezmsg-baseproc>=1.0
|
|
13
|
-
"ezmsg-sigproc>=2.
|
|
12
|
+
"ezmsg-baseproc>=1.3.0",
|
|
13
|
+
"ezmsg-sigproc>=2.15.0",
|
|
14
14
|
"river>=0.22.0",
|
|
15
15
|
"scikit-learn>=1.6.0",
|
|
16
16
|
"torch>=2.6.0",
|
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '1.1.0'
|
|
32
|
-
__version_tuple__ = version_tuple = (1, 1, 0)
|
|
31
|
+
__version__ = version = '1.2.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (1, 2, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/process/adaptive_linear_regressor.py
RENAMED
|
@@ -11,7 +11,6 @@ from ezmsg.baseproc import (
|
|
|
11
11
|
BaseAdaptiveTransformerUnit,
|
|
12
12
|
processor_state,
|
|
13
13
|
)
|
|
14
|
-
from ezmsg.sigproc.sampler import SampleMessage
|
|
15
14
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
16
15
|
|
|
17
16
|
from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor
|
|
@@ -78,30 +77,30 @@ class AdaptiveLinearRegressorTransformer(
|
|
|
78
77
|
# .template is updated in partial_fit
|
|
79
78
|
pass
|
|
80
79
|
|
|
81
|
-
def partial_fit(self, message:
|
|
82
|
-
if np.any(np.isnan(message.
|
|
80
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
81
|
+
if np.any(np.isnan(message.data)):
|
|
83
82
|
return
|
|
84
83
|
|
|
85
84
|
if self.settings.model_type in [
|
|
86
85
|
AdaptiveLinearRegressor.LINEAR,
|
|
87
86
|
AdaptiveLinearRegressor.LOGISTIC,
|
|
88
87
|
]:
|
|
89
|
-
x = pd.DataFrame.from_dict({k: v for k, v in zip(message.
|
|
88
|
+
x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
|
|
90
89
|
y = pd.Series(
|
|
91
|
-
data=message.trigger.value.data[:, 0],
|
|
92
|
-
name=message.trigger.value.axes["ch"].data[0],
|
|
90
|
+
data=message.attrs["trigger"].value.data[:, 0],
|
|
91
|
+
name=message.attrs["trigger"].value.axes["ch"].data[0],
|
|
93
92
|
)
|
|
94
93
|
self.state.model.learn_many(x, y)
|
|
95
94
|
else:
|
|
96
|
-
X = message.
|
|
97
|
-
if message.
|
|
98
|
-
X = np.moveaxis(X, message.
|
|
99
|
-
self.state.model.partial_fit(X, message.trigger.value.data)
|
|
95
|
+
X = message.data
|
|
96
|
+
if message.get_axis_idx("time") != 0:
|
|
97
|
+
X = np.moveaxis(X, message.get_axis_idx("time"), 0)
|
|
98
|
+
self.state.model.partial_fit(X, message.attrs["trigger"].value.data)
|
|
100
99
|
|
|
101
100
|
self.state.template = replace(
|
|
102
|
-
message.trigger.value,
|
|
103
|
-
data=np.empty_like(message.trigger.value.data),
|
|
104
|
-
key=message.trigger.value.key + "_pred",
|
|
101
|
+
message.attrs["trigger"].value,
|
|
102
|
+
data=np.empty_like(message.attrs["trigger"].value.data),
|
|
103
|
+
key=message.attrs["trigger"].value.key + "_pred",
|
|
105
104
|
)
|
|
106
105
|
|
|
107
106
|
def _process(self, message: AxisArray) -> AxisArray | None:
|
|
@@ -7,7 +7,6 @@ from ezmsg.baseproc import (
|
|
|
7
7
|
BaseAdaptiveTransformerUnit,
|
|
8
8
|
processor_state,
|
|
9
9
|
)
|
|
10
|
-
from ezmsg.sigproc.sampler import SampleMessage
|
|
11
10
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
12
11
|
from sklearn.linear_model._base import LinearModel
|
|
13
12
|
|
|
@@ -53,18 +52,18 @@ class LinearRegressorTransformer(
|
|
|
53
52
|
# .model and .template are initialized in __init__
|
|
54
53
|
pass
|
|
55
54
|
|
|
56
|
-
def partial_fit(self, message:
|
|
57
|
-
if np.any(np.isnan(message.
|
|
55
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
56
|
+
if np.any(np.isnan(message.data)):
|
|
58
57
|
return
|
|
59
58
|
|
|
60
|
-
X = message.
|
|
61
|
-
y = message.trigger.value.data
|
|
59
|
+
X = message.data
|
|
60
|
+
y = message.attrs["trigger"].value.data
|
|
62
61
|
# TODO: Resample should provide identical durations.
|
|
63
62
|
self.state.model = self.state.model.fit(X[: y.shape[0]], y[: X.shape[0]])
|
|
64
63
|
self.state.template = replace(
|
|
65
|
-
message.trigger.value,
|
|
64
|
+
message.attrs["trigger"].value,
|
|
66
65
|
data=np.array([[]]),
|
|
67
|
-
key=message.trigger.value.key + "_pred",
|
|
66
|
+
key=message.attrs["trigger"].value.key + "_pred",
|
|
68
67
|
)
|
|
69
68
|
|
|
70
69
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
@@ -9,7 +9,6 @@ from ezmsg.baseproc import (
|
|
|
9
9
|
BaseAdaptiveTransformerUnit,
|
|
10
10
|
processor_state,
|
|
11
11
|
)
|
|
12
|
-
from ezmsg.sigproc.sampler import SampleMessage
|
|
13
12
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
14
13
|
from ezmsg.util.messages.util import replace
|
|
15
14
|
|
|
@@ -134,14 +133,14 @@ class MLPProcessor(BaseAdaptiveTransformer[MLPSettings, AxisArray, AxisArray, ML
|
|
|
134
133
|
dtype = torch.float32 if self.settings.single_precision else torch.float64
|
|
135
134
|
return torch.tensor(data, dtype=dtype, device=self._state.device)
|
|
136
135
|
|
|
137
|
-
def partial_fit(self, message:
|
|
136
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
138
137
|
self._state.model.train()
|
|
139
138
|
|
|
140
139
|
# TODO: loss_fn should be determined by setting
|
|
141
140
|
loss_fn = torch.nn.functional.mse_loss
|
|
142
141
|
|
|
143
|
-
X = self._to_tensor(message.
|
|
144
|
-
y_targ = self._to_tensor(message.trigger.value)
|
|
142
|
+
X = self._to_tensor(message.data)
|
|
143
|
+
y_targ = self._to_tensor(message.attrs["trigger"].value)
|
|
145
144
|
|
|
146
145
|
with torch.set_grad_enabled(True):
|
|
147
146
|
self._state.model.train()
|
|
@@ -8,7 +8,6 @@ from ezmsg.baseproc import (
|
|
|
8
8
|
BaseAdaptiveTransformerUnit,
|
|
9
9
|
processor_state,
|
|
10
10
|
)
|
|
11
|
-
from ezmsg.sigproc.sampler import SampleMessage
|
|
12
11
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
13
12
|
from ezmsg.util.messages.util import replace
|
|
14
13
|
|
|
@@ -284,22 +283,22 @@ class RefitKalmanFilterProcessor(
|
|
|
284
283
|
key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered",
|
|
285
284
|
)
|
|
286
285
|
|
|
287
|
-
def partial_fit(self, message:
|
|
286
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
288
287
|
"""
|
|
289
288
|
Perform refitting using externally provided data.
|
|
290
289
|
|
|
291
|
-
Expects message.
|
|
290
|
+
Expects message.data (neural input) and message.attrs["trigger"].value as a dict with:
|
|
292
291
|
- Y_state: (n_samples, n_states) array
|
|
293
292
|
- intention_velocity_indices: Optional[int]
|
|
294
293
|
- target_positions: Optional[np.ndarray]
|
|
295
294
|
- cursor_positions: Optional[np.ndarray]
|
|
296
295
|
- hold_flags: Optional[list[bool]]
|
|
297
296
|
"""
|
|
298
|
-
if
|
|
297
|
+
if "trigger" not in message.attrs:
|
|
299
298
|
raise ValueError("Invalid message format for partial_fit.")
|
|
300
299
|
|
|
301
|
-
X = np.array(message.
|
|
302
|
-
values = message.trigger.value
|
|
300
|
+
X = np.array(message.data)
|
|
301
|
+
values = message.attrs["trigger"].value
|
|
303
302
|
|
|
304
303
|
if not isinstance(values, dict) or "Y_state" not in values:
|
|
305
304
|
raise ValueError("partial_fit expects trigger.value to include at least 'Y_state'.")
|
|
@@ -5,7 +5,6 @@ import numpy as np
|
|
|
5
5
|
import torch
|
|
6
6
|
from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
|
|
7
7
|
from ezmsg.baseproc.util.profile import profile_subpub
|
|
8
|
-
from ezmsg.sigproc.sampler import SampleMessage
|
|
9
8
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
9
|
from ezmsg.util.messages.util import replace
|
|
11
10
|
|
|
@@ -184,18 +183,18 @@ class RNNProcessor(
|
|
|
184
183
|
if self._state.scheduler is not None:
|
|
185
184
|
self._state.scheduler.step()
|
|
186
185
|
|
|
187
|
-
def partial_fit(self, message:
|
|
186
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
188
187
|
self._state.model.train()
|
|
189
188
|
|
|
190
|
-
X = self._to_tensor(message.
|
|
189
|
+
X = self._to_tensor(message.data)
|
|
191
190
|
|
|
192
191
|
# Add batch dimension if missing
|
|
193
192
|
X, batched = self._ensure_batched(X)
|
|
194
193
|
|
|
195
194
|
batch_size = X.shape[0]
|
|
196
|
-
preserve_state = self._maybe_reset_state(message
|
|
195
|
+
preserve_state = self._maybe_reset_state(message, batch_size)
|
|
197
196
|
|
|
198
|
-
y_targ = message.trigger.value
|
|
197
|
+
y_targ = message.attrs["trigger"].value
|
|
199
198
|
if not isinstance(y_targ, dict):
|
|
200
199
|
y_targ = {"output": y_targ}
|
|
201
200
|
y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
|
|
@@ -5,7 +5,6 @@ import numpy as np
|
|
|
5
5
|
from ezmsg.baseproc import (
|
|
6
6
|
BaseAdaptiveTransformer,
|
|
7
7
|
BaseAdaptiveTransformerUnit,
|
|
8
|
-
SampleMessage,
|
|
9
8
|
processor_state,
|
|
10
9
|
)
|
|
11
10
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
@@ -87,23 +86,23 @@ class SGDDecoderTransformer(BaseAdaptiveTransformer[SGDDecoderSettings, AxisArra
|
|
|
87
86
|
key=message.key,
|
|
88
87
|
)
|
|
89
88
|
|
|
90
|
-
def partial_fit(self, message:
|
|
89
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
91
90
|
if self._hash != 0:
|
|
92
|
-
self._reset_state(message
|
|
91
|
+
self._reset_state(message)
|
|
93
92
|
self._hash = 0
|
|
94
93
|
|
|
95
|
-
if np.any(np.isnan(message.
|
|
94
|
+
if np.any(np.isnan(message.data)):
|
|
96
95
|
return
|
|
97
|
-
train_sample = message.
|
|
96
|
+
train_sample = message.data.reshape(1, -1)
|
|
98
97
|
if self._state.b_first_train:
|
|
99
98
|
self._state.model.partial_fit(
|
|
100
99
|
train_sample,
|
|
101
|
-
[message.trigger.value],
|
|
100
|
+
[message.attrs["trigger"].value],
|
|
102
101
|
classes=list(self.settings.label_weights.keys()),
|
|
103
102
|
)
|
|
104
103
|
self._state.b_first_train = False
|
|
105
104
|
else:
|
|
106
|
-
self._state.model.partial_fit(train_sample, [message.trigger.value])
|
|
105
|
+
self._state.model.partial_fit(train_sample, [message.attrs["trigger"].value])
|
|
107
106
|
|
|
108
107
|
|
|
109
108
|
class SGDDecoder(
|
|
@@ -10,7 +10,6 @@ from ezmsg.baseproc import (
|
|
|
10
10
|
BaseAdaptiveTransformerUnit,
|
|
11
11
|
processor_state,
|
|
12
12
|
)
|
|
13
|
-
from ezmsg.sigproc.sampler import SampleMessage
|
|
14
13
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
15
14
|
from ezmsg.util.messages.util import replace
|
|
16
15
|
|
|
@@ -116,25 +115,25 @@ class SklearnModelProcessor(BaseAdaptiveTransformer[SklearnModelSettings, AxisAr
|
|
|
116
115
|
# No checkpoint, initialize from scratch
|
|
117
116
|
self._init_model()
|
|
118
117
|
|
|
119
|
-
def partial_fit(self, message:
|
|
120
|
-
X = message.
|
|
121
|
-
y = message.trigger.value
|
|
118
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
119
|
+
X = message.data
|
|
120
|
+
y = message.attrs["trigger"].value
|
|
122
121
|
if self._state.model is None:
|
|
123
|
-
self._reset_state(message
|
|
122
|
+
self._reset_state(message)
|
|
124
123
|
if hasattr(self._state.model, "partial_fit"):
|
|
125
124
|
kwargs = {}
|
|
126
125
|
if self.settings.partial_fit_classes is not None:
|
|
127
126
|
kwargs["classes"] = self.settings.partial_fit_classes
|
|
128
127
|
self._state.model.partial_fit(X, y, **kwargs)
|
|
129
128
|
elif hasattr(self._state.model, "learn_many"):
|
|
130
|
-
df_X = pd.DataFrame({k: v for k, v in zip(message.
|
|
129
|
+
df_X = pd.DataFrame({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
|
|
131
130
|
name = (
|
|
132
|
-
message.trigger.value.axes["ch"].data[0]
|
|
133
|
-
if hasattr(message.trigger.value, "axes") and "ch" in message.trigger.value.axes
|
|
131
|
+
message.attrs["trigger"].value.axes["ch"].data[0]
|
|
132
|
+
if hasattr(message.attrs["trigger"].value, "axes") and "ch" in message.attrs["trigger"].value.axes
|
|
134
133
|
else "target"
|
|
135
134
|
)
|
|
136
135
|
ser_y = pd.Series(
|
|
137
|
-
data=np.asarray(message.trigger.value.data).flatten(),
|
|
136
|
+
data=np.asarray(message.attrs["trigger"].value.data).flatten(),
|
|
138
137
|
name=name,
|
|
139
138
|
)
|
|
140
139
|
self._state.model.learn_many(df_X, ser_y)
|
|
@@ -12,7 +12,6 @@ from ezmsg.baseproc import (
|
|
|
12
12
|
processor_state,
|
|
13
13
|
)
|
|
14
14
|
from ezmsg.baseproc.util.profile import profile_subpub
|
|
15
|
-
from ezmsg.sigproc.sampler import SampleMessage
|
|
16
15
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
17
16
|
from ezmsg.util.messages.util import replace
|
|
18
17
|
|
|
@@ -294,13 +293,13 @@ class TorchModelProcessor(
|
|
|
294
293
|
def _process(self, message: AxisArray) -> list[AxisArray]:
|
|
295
294
|
return self._common_process(message)
|
|
296
295
|
|
|
297
|
-
def partial_fit(self, message:
|
|
296
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
298
297
|
self._state.model.train()
|
|
299
298
|
|
|
300
|
-
X = self._to_tensor(message.
|
|
299
|
+
X = self._to_tensor(message.data)
|
|
301
300
|
X, batched = self._ensure_batched(X)
|
|
302
301
|
|
|
303
|
-
y_targ = message.trigger.value
|
|
302
|
+
y_targ = message.attrs["trigger"].value
|
|
304
303
|
if not isinstance(y_targ, dict):
|
|
305
304
|
y_targ = {"output": y_targ}
|
|
306
305
|
y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
|
|
@@ -4,7 +4,6 @@ import ezmsg.core as ez
|
|
|
4
4
|
import torch
|
|
5
5
|
from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
|
|
6
6
|
from ezmsg.baseproc.util.profile import profile_subpub
|
|
7
|
-
from ezmsg.sigproc.sampler import SampleMessage
|
|
8
7
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
8
|
from ezmsg.util.messages.util import replace
|
|
10
9
|
|
|
@@ -125,13 +124,13 @@ class TransformerProcessor(
|
|
|
125
124
|
)
|
|
126
125
|
]
|
|
127
126
|
|
|
128
|
-
def partial_fit(self, message:
|
|
127
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
129
128
|
self._state.model.train()
|
|
130
129
|
|
|
131
|
-
X = self._to_tensor(message.
|
|
130
|
+
X = self._to_tensor(message.data)
|
|
132
131
|
X, batched = self._ensure_batched(X)
|
|
133
132
|
|
|
134
|
-
y_targ = message.trigger.value
|
|
133
|
+
y_targ = message.attrs["trigger"].value
|
|
135
134
|
if not isinstance(y_targ, dict):
|
|
136
135
|
y_targ = {"output": y_targ}
|
|
137
136
|
y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
import pytest
|
|
3
|
-
from ezmsg.
|
|
3
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
4
4
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
5
5
|
|
|
6
6
|
from ezmsg.learn.process.adaptive_linear_regressor import (
|
|
@@ -42,7 +42,7 @@ def test_adaptive_linear_regressor(model_type: str):
|
|
|
42
42
|
period=(0.0, dur),
|
|
43
43
|
value=value_axarr,
|
|
44
44
|
)
|
|
45
|
-
samp =
|
|
45
|
+
samp = replace(sig_axarr, attrs={"trigger": samp_trig})
|
|
46
46
|
|
|
47
47
|
proc = AdaptiveLinearRegressorTransformer(model_type=model_type)
|
|
48
48
|
_ = proc.send(samp)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
import pytest
|
|
3
|
-
from ezmsg.
|
|
3
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
4
4
|
from ezmsg.util.messages.axisarray import AxisArray, replace
|
|
5
5
|
|
|
6
6
|
from ezmsg.learn.process.linear_regressor import LinearRegressorTransformer
|
|
@@ -40,7 +40,7 @@ def test_linear_regressor(model_type: str):
|
|
|
40
40
|
period=(0.0, dur),
|
|
41
41
|
value=value_axarr,
|
|
42
42
|
)
|
|
43
|
-
samp =
|
|
43
|
+
samp = replace(sig_axarr, attrs={"trigger": samp_trig})
|
|
44
44
|
|
|
45
45
|
gen = LinearRegressorTransformer(model_type=model_type)
|
|
46
46
|
_ = gen.send(samp)
|
|
@@ -77,7 +77,8 @@ def test_mlp_checkpoint_io(tmp_path, sample_input, mlp_settings):
|
|
|
77
77
|
|
|
78
78
|
|
|
79
79
|
def test_mlp_partial_fit_learns(sample_input, mlp_settings):
|
|
80
|
-
from ezmsg.
|
|
80
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
81
|
+
from ezmsg.util.messages.util import replace
|
|
81
82
|
|
|
82
83
|
proc = TorchModelProcessor(
|
|
83
84
|
model_class="ezmsg.learn.model.mlp.MLP",
|
|
@@ -88,13 +89,12 @@ def test_mlp_partial_fit_learns(sample_input, mlp_settings):
|
|
|
88
89
|
)
|
|
89
90
|
proc(sample_input)
|
|
90
91
|
|
|
91
|
-
sample = AxisArray(
|
|
92
|
-
data=sample_input.data[:1], dims=["time", "ch"], axes=sample_input.axes
|
|
93
|
-
)
|
|
92
|
+
sample = AxisArray(data=sample_input.data[:1], dims=["time", "ch"], axes=sample_input.axes)
|
|
94
93
|
target = np.random.randn(1, 5)
|
|
95
94
|
|
|
96
|
-
msg =
|
|
97
|
-
sample
|
|
95
|
+
msg = replace(
|
|
96
|
+
sample,
|
|
97
|
+
attrs={**sample.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target)},
|
|
98
98
|
)
|
|
99
99
|
|
|
100
100
|
before = [p.detach().clone() for p in proc.state.model.parameters()]
|
|
@@ -135,9 +135,9 @@ def test_mlp_hidden_size_integer(sample_input):
|
|
|
135
135
|
device="cpu",
|
|
136
136
|
)
|
|
137
137
|
proc(sample_input)
|
|
138
|
-
hidden_layers = [
|
|
139
|
-
|
|
140
|
-
]
|
|
138
|
+
hidden_layers = [m for m in proc._state.model.modules() if isinstance(m, torch.nn.Linear)][
|
|
139
|
+
:-1
|
|
140
|
+
] # Exclude the output head
|
|
141
141
|
assert len(hidden_layers) == 3 # num_layers = 3
|
|
142
142
|
assert hidden_layers[0].in_features == 8
|
|
143
143
|
assert all(layer.out_features == 32 for layer in hidden_layers[:-1])
|
|
@@ -4,8 +4,9 @@ import numpy as np
|
|
|
4
4
|
import pytest
|
|
5
5
|
import torch
|
|
6
6
|
import torch.nn
|
|
7
|
-
from ezmsg.
|
|
7
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
8
8
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
9
|
+
from ezmsg.util.messages.util import replace
|
|
9
10
|
from sklearn.model_selection import train_test_split
|
|
10
11
|
|
|
11
12
|
from ezmsg.learn.process.mlp_old import MLPProcessor
|
|
@@ -146,7 +147,10 @@ def test_mlp_process():
|
|
|
146
147
|
template.data[:] = X # This would fail if n_samps / batch_size had a remainder.
|
|
147
148
|
template.axes["time"].offset = ts
|
|
148
149
|
if set == 0:
|
|
149
|
-
yield
|
|
150
|
+
yield replace(
|
|
151
|
+
template,
|
|
152
|
+
attrs={**template.attrs, "trigger": SampleTriggerMessage(timestamp=ts, value=y)},
|
|
153
|
+
)
|
|
150
154
|
else:
|
|
151
155
|
yield template, y
|
|
152
156
|
|
|
@@ -167,14 +171,15 @@ def test_mlp_process():
|
|
|
167
171
|
result = []
|
|
168
172
|
train_loss = []
|
|
169
173
|
for sample_msg in xy_gen(set=0):
|
|
170
|
-
# Naive closed-loop inference
|
|
171
|
-
|
|
174
|
+
# Naive closed-loop inference — strip trigger attrs before inference
|
|
175
|
+
plain_msg = replace(sample_msg, attrs={})
|
|
176
|
+
result.append(proc(plain_msg))
|
|
172
177
|
|
|
173
178
|
# Collect the loss to see if it decreases with training.
|
|
174
179
|
train_loss.append(
|
|
175
180
|
torch.nn.MSELoss()(
|
|
176
181
|
torch.tensor(result[-1].data),
|
|
177
|
-
torch.tensor(sample_msg.trigger.value.reshape(-1, 1), dtype=torch.float32),
|
|
182
|
+
torch.tensor(sample_msg.attrs["trigger"].value.reshape(-1, 1), dtype=torch.float32),
|
|
178
183
|
).item()
|
|
179
184
|
)
|
|
180
185
|
|
|
@@ -4,6 +4,7 @@ from pathlib import Path
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pytest
|
|
7
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
7
8
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
9
|
|
|
9
10
|
from ezmsg.learn.process.refit_kalman import (
|
|
@@ -299,12 +300,6 @@ def test_partial_fit_functionality(create_test_message, checkpoint_file):
|
|
|
299
300
|
H_initial = checkpoint_data["H_observation_matrix"]
|
|
300
301
|
Q_initial = checkpoint_data["Q_measurement_noise_covariance"]
|
|
301
302
|
|
|
302
|
-
# Create a mock SampleMessage with the expected structure
|
|
303
|
-
class MockSampleMessage:
|
|
304
|
-
def __init__(self, neural_data, trigger_value):
|
|
305
|
-
self.sample = type("obj", (object,), {"data": neural_data})()
|
|
306
|
-
self.trigger = type("obj", (object,), {"value": trigger_value})()
|
|
307
|
-
|
|
308
303
|
# Create test data
|
|
309
304
|
neural_data = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) # 3 samples, 2 channels
|
|
310
305
|
trigger_value = {
|
|
@@ -315,8 +310,12 @@ def test_partial_fit_functionality(create_test_message, checkpoint_file):
|
|
|
315
310
|
"hold_flags": [False, False, False],
|
|
316
311
|
}
|
|
317
312
|
|
|
318
|
-
|
|
319
|
-
|
|
313
|
+
sample_msg = AxisArray(
|
|
314
|
+
data=neural_data,
|
|
315
|
+
dims=["time", "ch"],
|
|
316
|
+
attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=trigger_value)},
|
|
317
|
+
)
|
|
318
|
+
processor.partial_fit(sample_msg)
|
|
320
319
|
|
|
321
320
|
assert not np.allclose(H_initial, processor._state.model.H_observation_matrix)
|
|
322
321
|
assert not np.allclose(Q_initial, processor._state.model.Q_measurement_noise_covariance)
|
|
@@ -5,8 +5,9 @@ import numpy as np
|
|
|
5
5
|
import pytest
|
|
6
6
|
import torch
|
|
7
7
|
import torch.nn
|
|
8
|
-
from ezmsg.
|
|
8
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
9
9
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
10
11
|
|
|
11
12
|
from ezmsg.learn.process.rnn import RNNProcessor
|
|
12
13
|
|
|
@@ -107,9 +108,7 @@ def test_rnn_process(rnn_type, simple_message):
|
|
|
107
108
|
# We don't pass in the hx state so it should be initialized to zeros, same as in the first call to proc.
|
|
108
109
|
in_tensor = torch.tensor(simple_message.data[None, ...], dtype=torch.float32)
|
|
109
110
|
with torch.no_grad():
|
|
110
|
-
expected_result = (
|
|
111
|
-
proc.state.model(in_tensor)[0]["output"].cpu().numpy().squeeze(0)
|
|
112
|
-
)
|
|
111
|
+
expected_result = proc.state.model(in_tensor)[0]["output"].cpu().numpy().squeeze(0)
|
|
113
112
|
assert np.allclose(output.data, expected_result)
|
|
114
113
|
|
|
115
114
|
|
|
@@ -139,9 +138,9 @@ def test_rnn_partial_fit(simple_message):
|
|
|
139
138
|
|
|
140
139
|
target_shape = (simple_message.data.shape[0], output_size)
|
|
141
140
|
target_value = np.ones(target_shape, dtype=np.float32)
|
|
142
|
-
sample_message =
|
|
143
|
-
|
|
144
|
-
|
|
141
|
+
sample_message = replace(
|
|
142
|
+
simple_message,
|
|
143
|
+
attrs={**simple_message.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target_value)},
|
|
145
144
|
)
|
|
146
145
|
|
|
147
146
|
proc(sample_message)
|
|
@@ -149,9 +148,7 @@ def test_rnn_partial_fit(simple_message):
|
|
|
149
148
|
assert not proc.state.model.training
|
|
150
149
|
updated_weights = [p.detach() for p in proc.state.model.parameters()]
|
|
151
150
|
|
|
152
|
-
assert any(
|
|
153
|
-
not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
|
|
154
|
-
)
|
|
151
|
+
assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))
|
|
155
152
|
|
|
156
153
|
|
|
157
154
|
def test_rnn_checkpoint_save_load(simple_message):
|
|
@@ -201,9 +198,7 @@ def test_rnn_checkpoint_save_load(simple_message):
|
|
|
201
198
|
|
|
202
199
|
for key in state_dict1:
|
|
203
200
|
assert key in state_dict2, f"Missing key {key} in loaded state_dict"
|
|
204
|
-
assert torch.equal(state_dict1[key], state_dict2[key]),
|
|
205
|
-
f"Mismatch in parameter {key}"
|
|
206
|
-
)
|
|
201
|
+
assert torch.equal(state_dict1[key], state_dict2[key]), f"Mismatch in parameter {key}"
|
|
207
202
|
|
|
208
203
|
finally:
|
|
209
204
|
# Ensure the temporary file is deleted
|
|
@@ -244,20 +239,21 @@ def test_rnn_partial_fit_multiloss(simple_message):
|
|
|
244
239
|
dtype=torch.long,
|
|
245
240
|
)
|
|
246
241
|
|
|
247
|
-
sample_message =
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
242
|
+
sample_message = replace(
|
|
243
|
+
simple_message,
|
|
244
|
+
attrs={
|
|
245
|
+
**simple_message.attrs,
|
|
246
|
+
"trigger": SampleTriggerMessage(
|
|
247
|
+
timestamp=0.0,
|
|
248
|
+
value={"traj": traj_target, "state": state_target},
|
|
249
|
+
),
|
|
250
|
+
},
|
|
253
251
|
)
|
|
254
252
|
|
|
255
253
|
proc.partial_fit(sample_message)
|
|
256
254
|
|
|
257
255
|
updated_weights = [p.detach() for p in proc.state.model.parameters()]
|
|
258
|
-
assert any(
|
|
259
|
-
not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
|
|
260
|
-
)
|
|
256
|
+
assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))
|
|
261
257
|
|
|
262
258
|
|
|
263
259
|
@pytest.mark.parametrize(
|
|
@@ -269,9 +265,7 @@ def test_rnn_partial_fit_multiloss(simple_message):
|
|
|
269
265
|
("auto", 0.05, 0.1, False), # overlapping → reset
|
|
270
266
|
],
|
|
271
267
|
)
|
|
272
|
-
def test_rnn_preserve_state(
|
|
273
|
-
preserve_state_across_windows, win_stride, win_len, should_preserve
|
|
274
|
-
):
|
|
268
|
+
def test_rnn_preserve_state(preserve_state_across_windows, win_stride, win_len, should_preserve):
|
|
275
269
|
hidden_size = 16
|
|
276
270
|
num_layers = 1
|
|
277
271
|
output_size = 2
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
-
from ezmsg.
|
|
2
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
3
3
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
4
4
|
|
|
5
5
|
from ezmsg.learn.process.sgd import SGDDecoderSettings, SGDDecoderTransformer
|
|
@@ -13,9 +13,10 @@ def test_sgd():
|
|
|
13
13
|
data = np.random.normal(scale=0.05, size=(3, 2, 1))
|
|
14
14
|
data[time_idx[label] : time_idx[label] + 1, 0, 0] += 1.0
|
|
15
15
|
samples.append(
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
16
|
+
AxisArray(
|
|
17
|
+
data=data,
|
|
18
|
+
dims=["time", "ch", "freq"],
|
|
19
|
+
attrs={"trigger": SampleTriggerMessage(timestamp=len(samples), period=None, value=label)},
|
|
19
20
|
)
|
|
20
21
|
)
|
|
21
22
|
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
import pytest
|
|
3
|
-
from ezmsg.
|
|
3
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
4
4
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
5
|
+
from ezmsg.util.messages.util import replace
|
|
5
6
|
|
|
6
7
|
from ezmsg.learn.process.sklearn import SklearnModelProcessor
|
|
7
8
|
|
|
@@ -83,9 +84,9 @@ def test_partial_fit_supported_models(
|
|
|
83
84
|
proc = SklearnModelProcessor(**settings_kwargs)
|
|
84
85
|
proc._reset_state(input_axisarray)
|
|
85
86
|
|
|
86
|
-
sample_msg =
|
|
87
|
-
|
|
88
|
-
trigger
|
|
87
|
+
sample_msg = replace(
|
|
88
|
+
input_axisarray,
|
|
89
|
+
attrs={**input_axisarray.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=labels)},
|
|
89
90
|
)
|
|
90
91
|
|
|
91
92
|
proc.partial_fit(sample_msg)
|
|
@@ -96,9 +97,9 @@ def test_partial_fit_supported_models(
|
|
|
96
97
|
def test_partial_fit_unsupported_model(input_axisarray, labels_regression):
|
|
97
98
|
proc = SklearnModelProcessor(model_class="sklearn.linear_model.Ridge")
|
|
98
99
|
proc._reset_state(input_axisarray)
|
|
99
|
-
sample_msg =
|
|
100
|
-
|
|
101
|
-
trigger
|
|
100
|
+
sample_msg = replace(
|
|
101
|
+
input_axisarray,
|
|
102
|
+
attrs={**input_axisarray.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=labels_regression)},
|
|
102
103
|
)
|
|
103
104
|
with pytest.raises(NotImplementedError, match="partial_fit"):
|
|
104
105
|
proc.partial_fit(sample_msg)
|
|
@@ -108,9 +109,9 @@ def test_partial_fit_changes_model(input_axisarray, labels_regression):
|
|
|
108
109
|
proc = SklearnModelProcessor(model_class="sklearn.linear_model.SGDRegressor")
|
|
109
110
|
proc._reset_state(input_axisarray)
|
|
110
111
|
|
|
111
|
-
sample_msg =
|
|
112
|
-
|
|
113
|
-
trigger
|
|
112
|
+
sample_msg = replace(
|
|
113
|
+
input_axisarray,
|
|
114
|
+
attrs={**input_axisarray.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=labels_regression)},
|
|
114
115
|
)
|
|
115
116
|
|
|
116
117
|
proc.partial_fit(sample_msg)
|
|
@@ -127,9 +128,7 @@ def test_model_save_and_load(tmp_path, input_axisarray):
|
|
|
127
128
|
checkpoint_path = tmp_path / "model_checkpoint.pkl"
|
|
128
129
|
proc.save_checkpoint(str(checkpoint_path))
|
|
129
130
|
|
|
130
|
-
new_proc = SklearnModelProcessor(
|
|
131
|
-
model_class="sklearn.linear_model.Ridge", checkpoint_path=str(checkpoint_path)
|
|
132
|
-
)
|
|
131
|
+
new_proc = SklearnModelProcessor(model_class="sklearn.linear_model.Ridge", checkpoint_path=str(checkpoint_path))
|
|
133
132
|
new_proc._reset_state(input_axisarray)
|
|
134
133
|
assert new_proc._state.model is not None
|
|
135
134
|
|
|
@@ -5,8 +5,9 @@ from pathlib import Path
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pytest
|
|
7
7
|
import torch
|
|
8
|
-
from ezmsg.
|
|
8
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
9
9
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
10
11
|
|
|
11
12
|
from ezmsg.learn.process.torch import TorchModelProcessor
|
|
12
13
|
|
|
@@ -185,9 +186,9 @@ def test_partial_fit_changes_weights(batch_message, device):
|
|
|
185
186
|
},
|
|
186
187
|
)
|
|
187
188
|
|
|
188
|
-
msg =
|
|
189
|
-
sample
|
|
190
|
-
trigger
|
|
189
|
+
msg = replace(
|
|
190
|
+
sample,
|
|
191
|
+
attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=y)},
|
|
191
192
|
)
|
|
192
193
|
|
|
193
194
|
proc(sample) # run forward pass once to init model
|
|
@@ -318,14 +319,11 @@ def test_multihead_partial_fit_with_loss_dict(batch_message, device):
|
|
|
318
319
|
"head_a": np.random.randn(1, 2),
|
|
319
320
|
"head_b": np.random.randn(1, 3),
|
|
320
321
|
}
|
|
321
|
-
|
|
322
|
+
msg = AxisArray(
|
|
322
323
|
data=batch_message.data[:1],
|
|
323
324
|
dims=["time", "ch"],
|
|
324
325
|
axes=batch_message.axes,
|
|
325
|
-
|
|
326
|
-
msg = SampleMessage(
|
|
327
|
-
sample=sample,
|
|
328
|
-
trigger=SampleTriggerMessage(timestamp=0.0, value=y_targ),
|
|
326
|
+
attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=y_targ)},
|
|
329
327
|
)
|
|
330
328
|
|
|
331
329
|
before_a = proc._state.model.head_a.weight.clone()
|
|
@@ -360,14 +358,11 @@ def test_partial_fit_with_loss_weights(batch_message, device):
|
|
|
360
358
|
"head_a": np.random.randn(1, 2),
|
|
361
359
|
"head_b": np.random.randn(1, 3),
|
|
362
360
|
}
|
|
363
|
-
|
|
361
|
+
msg = AxisArray(
|
|
364
362
|
data=batch_message.data[:1],
|
|
365
363
|
dims=["time", "ch"],
|
|
366
364
|
axes=batch_message.axes,
|
|
367
|
-
|
|
368
|
-
msg = SampleMessage(
|
|
369
|
-
sample=sample,
|
|
370
|
-
trigger=SampleTriggerMessage(timestamp=0.0, value=y_targ),
|
|
365
|
+
attrs={"trigger": SampleTriggerMessage(timestamp=0.0, value=y_targ)},
|
|
371
366
|
)
|
|
372
367
|
|
|
373
368
|
# Expect no error, and just run once
|
|
@@ -5,8 +5,9 @@ import numpy as np
|
|
|
5
5
|
import pytest
|
|
6
6
|
import torch
|
|
7
7
|
import torch.nn
|
|
8
|
-
from ezmsg.
|
|
8
|
+
from ezmsg.baseproc import SampleTriggerMessage
|
|
9
9
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
|
+
from ezmsg.util.messages.util import replace
|
|
10
11
|
|
|
11
12
|
from ezmsg.learn.process.transformer import TransformerProcessor
|
|
12
13
|
|
|
@@ -138,9 +139,9 @@ def test_transformer_partial_fit(simple_message, decoder_layers):
|
|
|
138
139
|
|
|
139
140
|
target_shape = (simple_message.data.shape[0], output_size)
|
|
140
141
|
target_value = np.ones(target_shape, dtype=np.float32)
|
|
141
|
-
sample_message =
|
|
142
|
-
|
|
143
|
-
|
|
142
|
+
sample_message = replace(
|
|
143
|
+
simple_message,
|
|
144
|
+
attrs={**simple_message.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target_value)},
|
|
144
145
|
)
|
|
145
146
|
|
|
146
147
|
proc.partial_fit(sample_message)
|
|
@@ -149,9 +150,7 @@ def test_transformer_partial_fit(simple_message, decoder_layers):
|
|
|
149
150
|
assert proc.state.tgt_cache is None
|
|
150
151
|
updated_weights = [p.detach() for p in proc.state.model.parameters()]
|
|
151
152
|
|
|
152
|
-
assert any(
|
|
153
|
-
not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
|
|
154
|
-
)
|
|
153
|
+
assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))
|
|
155
154
|
|
|
156
155
|
|
|
157
156
|
def test_transformer_checkpoint_save_load(simple_message):
|
|
@@ -201,9 +200,7 @@ def test_transformer_checkpoint_save_load(simple_message):
|
|
|
201
200
|
|
|
202
201
|
for key in state_dict1:
|
|
203
202
|
assert key in state_dict2, f"Missing key {key} in loaded state_dict"
|
|
204
|
-
assert torch.equal(state_dict1[key], state_dict2[key]),
|
|
205
|
-
f"Mismatch in parameter {key}"
|
|
206
|
-
)
|
|
203
|
+
assert torch.equal(state_dict1[key], state_dict2[key]), f"Mismatch in parameter {key}"
|
|
207
204
|
|
|
208
205
|
finally:
|
|
209
206
|
# Ensure the temporary file is deleted
|
|
@@ -244,20 +241,21 @@ def test_transformer_partial_fit_multiloss(simple_message):
|
|
|
244
241
|
dtype=torch.long,
|
|
245
242
|
)
|
|
246
243
|
|
|
247
|
-
sample_message =
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
244
|
+
sample_message = replace(
|
|
245
|
+
simple_message,
|
|
246
|
+
attrs={
|
|
247
|
+
**simple_message.attrs,
|
|
248
|
+
"trigger": SampleTriggerMessage(
|
|
249
|
+
timestamp=0.0,
|
|
250
|
+
value={"traj": traj_target, "state": state_target},
|
|
251
|
+
),
|
|
252
|
+
},
|
|
253
253
|
)
|
|
254
254
|
|
|
255
255
|
proc.partial_fit(sample_message)
|
|
256
256
|
|
|
257
257
|
updated_weights = [p.detach() for p in proc.state.model.parameters()]
|
|
258
|
-
assert any(
|
|
259
|
-
not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights)
|
|
260
|
-
)
|
|
258
|
+
assert any(not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights))
|
|
261
259
|
|
|
262
260
|
|
|
263
261
|
def test_autoregressive_cache_behavior(simple_message):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ezmsg_learn-1.1.0 → ezmsg_learn-1.2.0}/src/ezmsg/learn/linear_model/adaptive_linear_regressor.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|