ezmsg_learn-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/learn/__init__.py +2 -0
- ezmsg/learn/__version__.py +34 -0
- ezmsg/learn/dim_reduce/__init__.py +0 -0
- ezmsg/learn/dim_reduce/adaptive_decomp.py +284 -0
- ezmsg/learn/dim_reduce/incremental_decomp.py +181 -0
- ezmsg/learn/linear_model/__init__.py +1 -0
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
- ezmsg/learn/linear_model/cca.py +1 -0
- ezmsg/learn/linear_model/linear_regressor.py +5 -0
- ezmsg/learn/linear_model/sgd.py +5 -0
- ezmsg/learn/linear_model/slda.py +6 -0
- ezmsg/learn/model/__init__.py +0 -0
- ezmsg/learn/model/cca.py +122 -0
- ezmsg/learn/model/mlp.py +133 -0
- ezmsg/learn/model/mlp_old.py +49 -0
- ezmsg/learn/model/refit_kalman.py +401 -0
- ezmsg/learn/model/rnn.py +160 -0
- ezmsg/learn/model/transformer.py +175 -0
- ezmsg/learn/nlin_model/__init__.py +1 -0
- ezmsg/learn/nlin_model/mlp.py +6 -0
- ezmsg/learn/process/__init__.py +0 -0
- ezmsg/learn/process/adaptive_linear_regressor.py +157 -0
- ezmsg/learn/process/base.py +173 -0
- ezmsg/learn/process/linear_regressor.py +99 -0
- ezmsg/learn/process/mlp_old.py +200 -0
- ezmsg/learn/process/refit_kalman.py +407 -0
- ezmsg/learn/process/rnn.py +266 -0
- ezmsg/learn/process/sgd.py +131 -0
- ezmsg/learn/process/sklearn.py +274 -0
- ezmsg/learn/process/slda.py +119 -0
- ezmsg/learn/process/torch.py +378 -0
- ezmsg/learn/process/transformer.py +222 -0
- ezmsg/learn/util.py +66 -0
- ezmsg_learn-1.0.dist-info/METADATA +34 -0
- ezmsg_learn-1.0.dist-info/RECORD +36 -0
- ezmsg_learn-1.0.dist-info/WHEEL +4 -0
ezmsg/learn/process/linear_regressor.py
@@ -0,0 +1,99 @@
+from dataclasses import field
+
+import numpy as np
+from sklearn.linear_model._base import LinearModel
+import ezmsg.core as ez
+from ezmsg.sigproc.base import (
+    processor_state,
+    BaseAdaptiveTransformer,
+    BaseAdaptiveTransformerUnit,
+)
+from ezmsg.util.messages.axisarray import AxisArray, replace
+from ezmsg.sigproc.sampler import SampleMessage
+
+from ..util import get_regressor, StaticLinearRegressor, RegressorType
+
+
+class LinearRegressorSettings(ez.Settings):
+    model_type: StaticLinearRegressor = StaticLinearRegressor.LINEAR
+    settings_path: str | None = None
+    model_kwargs: dict = field(default_factory=dict)
+
+
+@processor_state
+class LinearRegressorState:
+    template: AxisArray | None = None
+    model: LinearModel | None = None
+
+
+class LinearRegressorTransformer(
+    BaseAdaptiveTransformer[
+        LinearRegressorSettings, AxisArray, AxisArray, LinearRegressorState
+    ]
+):
+    """
+    Linear regressor.
+
+    Note: `partial_fit` is not 'partial'. It fully resets the model using the entirety of the SampleMessage provided.
+    If you require adaptive fitting, try using the adaptive_linear_regressor module.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.settings.settings_path is not None:
+            # Load model from file
+            import pickle
+
+            with open(self.settings.settings_path, "rb") as f:
+                self.state.model = pickle.load(f)
+        else:
+            regressor_klass = get_regressor(
+                RegressorType.STATIC, self.settings.model_type
+            )
+            self.state.model = regressor_klass(**self.settings.model_kwargs)
+
+    def _reset_state(self, message: AxisArray) -> None:
+        # So far, there is nothing to reset.
+        # .model and .template are initialized in __init__
+        pass
+
+    def partial_fit(self, message: SampleMessage) -> None:
+        if np.any(np.isnan(message.sample.data)):
+            return
+
+        X = message.sample.data
+        y = message.trigger.value.data
+        # TODO: Resample should provide identical durations.
+        self.state.model = self.state.model.fit(X[: y.shape[0]], y[: X.shape[0]])
+        self.state.template = replace(
+            message.trigger.value,
+            data=np.array([[]]),
+            key=message.trigger.value.key + "_pred",
+        )
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        if self.state.template is None:
+            return AxisArray(np.array([[]]), dims=["time", "ch"])
+        preds = self.state.model.predict(message.data)
+        return replace(
+            self.state.template,
+            data=preds,
+            axes={
+                **self.state.template.axes,
+                "time": replace(
+                    message.axes["time"],
+                    offset=message.axes["time"].offset,
+                ),
+            },
+        )
+
+
+class AdaptiveLinearRegressorUnit(
+    BaseAdaptiveTransformerUnit[
+        LinearRegressorSettings,
+        AxisArray,
+        AxisArray,
+        LinearRegressorTransformer,
+    ]
+):
+    SETTINGS = LinearRegressorSettings
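
In `partial_fit` above, the slicing `X[: y.shape[0]], y[: X.shape[0]]` truncates features and targets to their overlapping length before a full refit (see the TODO about Resample durations). A minimal standalone sketch of the same logic in plain scikit-learn, assuming `StaticLinearRegressor.LINEAR` resolves to `sklearn.linear_model.LinearRegression` via `get_regressor` (that mapping lives in `ezmsg/learn/util.py` and is not shown in this diff):

import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 8))   # features, shape (time, ch)
y = rng.standard_normal((102, 2))   # targets, shape (time, dims); durations may disagree

# Truncate both arrays to the overlapping samples, then do a full (not partial) fit.
model = LinearRegression().fit(X[: y.shape[0]], y[: X.shape[0]])
preds = model.predict(X)  # subsequent messages are mapped through predict()
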
ezmsg/learn/process/mlp_old.py
@@ -0,0 +1,200 @@
+import typing
+
+import numpy as np
+import torch
+import torch.nn
+import ezmsg.core as ez
+from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.sigproc.sampler import SampleMessage
+from ezmsg.util.messages.util import replace
+from ezmsg.sigproc.base import (
+    BaseAdaptiveTransformer,
+    BaseAdaptiveTransformerUnit,
+    processor_state,
+)
+
+from ..model.mlp_old import MLP
+
+
+class MLPSettings(ez.Settings):
+    hidden_channels: list[int]
+    """List of the hidden channel dimensions"""
+
+    norm_layer: typing.Callable[..., torch.nn.Module] | None = None
+    """Norm layer that will be stacked on top of the linear layer. If None this layer won't be used."""
+
+    activation_layer: typing.Callable[..., torch.nn.Module] | None = torch.nn.ReLU
+    """Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None this layer won't be used."""
+
+    inplace: bool | None = None
+    """Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer."""
+
+    bias: bool = True
+    """Whether to use bias in the linear layer."""
+
+    dropout: float = 0.0
+    """The probability for the dropout layer."""
+
+    single_precision: bool = True
+
+    learning_rate: float = 0.001
+
+    scheduler_gamma: float = 0.999
+    """Learning scheduler decay rate. Set to 0.0 to disable the scheduler."""
+
+    checkpoint_path: str | None = None
+    """
+    Path to a checkpoint file containing model weights.
+    If None, the model will be initialized with random weights.
+    """
+
+
+@processor_state
+class MLPState:
+    model: MLP | None = None
+    optimizer: torch.optim.Optimizer | None = None
+    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None
+    template: AxisArray | None = None
+    device: object | None = None
+
+
+class MLPProcessor(
+    BaseAdaptiveTransformer[MLPSettings, AxisArray, AxisArray, MLPState]
+):
+    def _hash_message(self, message: AxisArray) -> int:
+        hash_items = (message.key,)
+        if "ch" in message.dims:
+            hash_items += (message.data.shape[message.get_axis_idx("ch")],)
+        return hash(hash_items)
+
+    def _reset_state(self, message: AxisArray) -> None:
+        # Create the model
+        self._state.model = MLP(
+            in_channels=message.data.shape[message.get_axis_idx("ch")],
+            hidden_channels=self.settings.hidden_channels,
+            norm_layer=self.settings.norm_layer,
+            activation_layer=self.settings.activation_layer,
+            inplace=self.settings.inplace,
+            bias=self.settings.bias,
+            dropout=self.settings.dropout,
+        )
+
+        # Load model weights from checkpoint if specified
+        if self.settings.checkpoint_path is not None:
+            try:
+                checkpoint = torch.load(self.settings.checkpoint_path)
+                self._state.model.load_state_dict(checkpoint["model_state_dict"])
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to load checkpoint from {self.settings.checkpoint_path}: {str(e)}"
+                )
+
+        # Set the model to evaluation mode by default
+        self._state.model.eval()
+
+        # Create the optimizer
+        self._state.optimizer = torch.optim.Adam(
+            self._state.model.parameters(), lr=self.settings.learning_rate
+        )
+
+        # Update the optimizer from checkpoint if it exists
+        if self.settings.checkpoint_path is not None:
+            try:
+                checkpoint = torch.load(self.settings.checkpoint_path)
+                if "optimizer_state_dict" in checkpoint:
+                    self._state.optimizer.load_state_dict(
+                        checkpoint["optimizer_state_dict"]
+                    )
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to load optimizer from {self.settings.checkpoint_path}: {str(e)}"
+                )
+
+        # TODO: Should the model be moved to a device before the next line?
+        self._state.device = next(self.state.model.parameters()).device
+
+        # Optionally create the learning rate scheduler
+        self._state.scheduler = (
+            torch.optim.lr_scheduler.ExponentialLR(
+                self._state.optimizer, gamma=self.settings.scheduler_gamma
+            )
+            if self.settings.scheduler_gamma > 0.0
+            else None
+        )
+
+        # Create the output channel axis for reuse in each output.
+        n_output_channels = self.settings.hidden_channels[-1]
+        self._state.chan_ax = AxisArray.CoordinateAxis(
+            data=np.array([f"ch{_}" for _ in range(n_output_channels)]), dims=["ch"]
+        )
+
+    def save_checkpoint(self, path: str) -> None:
+        """Save the current model state to a checkpoint file.
+
+        Args:
+            path: Path where the checkpoint will be saved
+        """
+        checkpoint = {
+            "model_state_dict": self._state.model.state_dict(),
+            "optimizer_state_dict": self._state.optimizer.state_dict(),
+        }
+        torch.save(checkpoint, path)
+
+    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
+        dtype = torch.float32 if self.settings.single_precision else torch.float64
+        return torch.tensor(data, dtype=dtype, device=self._state.device)
+
+    def partial_fit(self, message: SampleMessage) -> None:
+        self._state.model.train()
+
+        # TODO: loss_fn should be determined by setting
+        loss_fn = torch.nn.functional.mse_loss
+
+        X = self._to_tensor(message.sample.data)
+        y_targ = self._to_tensor(message.trigger.value)
+
+        with torch.set_grad_enabled(True):
+            self._state.model.train()
+            y_pred = self.state.model(X)
+            loss = loss_fn(y_pred, y_targ)
+
+            self.state.optimizer.zero_grad()
+            loss.backward()
+            self.state.optimizer.step()  # Update weights
+            if self.state.scheduler is not None:
+                self.state.scheduler.step()  # Update learning rate
+
+        self._state.model.eval()
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        data = message.data
+        if not isinstance(data, torch.Tensor):
+            data = torch.tensor(
+                data,
+                dtype=torch.float32
+                if self.settings.single_precision
+                else torch.float64,
+            )
+
+        with torch.no_grad():
+            output = self.state.model(data.to(self.state.device))
+
+        return replace(
+            message,
+            data=output.cpu().numpy(),
+            axes={
+                **message.axes,
+                "ch": self.state.chan_ax,
+            },
+        )
+
+
+class MLPUnit(
+    BaseAdaptiveTransformerUnit[
+        MLPSettings,
+        AxisArray,
+        AxisArray,
+        MLPProcessor,
+    ]
+):
+    SETTINGS = MLPSettings
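
`MLPProcessor.partial_fit` performs one online gradient step per SampleMessage: forward pass, MSE loss, backprop, optimizer step, optional scheduler step, then back to eval mode. A self-contained sketch of that update with a plain `torch.nn.Sequential` standing in for the package's `MLP` (the stand-in architecture and shapes here are illustrative assumptions, not the module defined in `ezmsg/learn/model/mlp_old.py`):

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2)
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

X = torch.randn(32, 8)       # one message worth of features (time, ch)
y_targ = torch.randn(32, 2)  # aligned targets from trigger.value

model.train()
y_pred = model(X)
loss = torch.nn.functional.mse_loss(y_pred, y_targ)
optimizer.zero_grad()
loss.backward()
optimizer.step()   # update weights
scheduler.step()   # decay the learning rate
model.eval()       # return to inference mode for _process
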
ezmsg/learn/process/refit_kalman.py
@@ -0,0 +1,407 @@
+import pickle
+from pathlib import Path
+
+import ezmsg.core as ez
+import numpy as np
+from ezmsg.sigproc.base import (
+    BaseAdaptiveTransformer,
+    BaseAdaptiveTransformerUnit,
+    processor_state,
+)
+from ezmsg.sigproc.sampler import SampleMessage
+from ezmsg.util.messages.axisarray import AxisArray
+from ezmsg.util.messages.util import replace
+
+from ..model.refit_kalman import RefitKalmanFilter
+
+
+class RefitKalmanFilterSettings(ez.Settings):
+    """
+    Settings for the Refit Kalman filter processor.
+
+    This class defines the configuration parameters for the Refit Kalman filter processor.
+    The RefitKalmanFilter is designed for online processing and playback.
+
+    Attributes:
+        checkpoint_path: Path to saved model parameters (optional).
+            If provided, loads pre-trained parameters instead of learning from data.
+        steady_state: Whether to use steady-state Kalman filter.
+            If True, uses pre-computed Kalman gain; if False, updates dynamically.
+    """
+
+    checkpoint_path: str | None = None
+    steady_state: bool = False
+    velocity_indices: tuple[int, int] = (2, 3)
+
+
+@processor_state
+class RefitKalmanFilterState:
+    """
+    State management for the Refit Kalman filter processor.
+
+    This class manages the persistent state of the Refit Kalman filter processor,
+    including the model instance, current state estimates, and data buffers for refitting.
+
+    Attributes:
+        model: The RefitKalmanFilter model instance.
+        x: Current state estimate (n_states,).
+        P: Current state covariance matrix (n_states x n_states).
+        buffer_neural: Buffer for storing neural activity data for refitting.
+        buffer_state: Buffer for storing state estimates for refitting.
+        buffer_cursor_positions: Buffer for storing cursor positions for refitting.
+        buffer_target_positions: Buffer for storing target positions for refitting.
+        buffer_hold_flags: Buffer for storing hold flags for refitting.
+        current_position: Current cursor position estimate (2,).
+    """
+
+    model: RefitKalmanFilter | None = None
+    x: np.ndarray | None = None
+    P: np.ndarray | None = None
+
+    buffer_neural: list | None = None
+    buffer_state: list | None = None
+    buffer_cursor_positions: list | None = None
+    buffer_target_positions: list | None = None
+    buffer_hold_flags: list | None = None
+
+
+class RefitKalmanFilterProcessor(
+    BaseAdaptiveTransformer[
+        RefitKalmanFilterSettings,
+        AxisArray,
+        AxisArray,
+        RefitKalmanFilterState,
+    ]
+):
+    """
+    Processor for implementing a Refit Kalman filter in the ezmsg framework.
+
+    This processor integrates the RefitKalmanFilter model into the ezmsg
+    message passing system. It handles the conversion between AxisArray messages
+    and the internal Refit Kalman filter operations.
+
+    The processor performs the following operations:
+    1. Configures the Refit Kalman filter model with provided settings
+    2. Processes incoming measurement messages
+    3. Performs prediction and update steps
+    4. Logs data for potential refitting
+    5. Supports online refitting of the observation model
+    6. Returns filtered state estimates as AxisArray messages
+    7. Maintains state between message processing calls
+
+    The processor can operate in two modes:
+    1. Pre-trained mode: Loads parameters from checkpoint_path
+    2. Learning mode: Collects data and fits the model when buffer is full
+
+    Key features:
+    - Online refitting capability for adaptive neural decoding
+    - Data logging for retrospective analysis
+    - Position tracking for cursor control applications
+    - Hold period detection and handling
+
+    Attributes:
+        settings: Configuration settings for the Refit Kalman filter.
+        _state: Internal state management object.
+
+    Example:
+        >>> # Create settings with checkpoint path
+        >>> settings = RefitKalmanFilterSettings(
+        ...     checkpoint_path="path/to/checkpoint.pkl",
+        ...     steady_state=True
+        ... )
+        >>>
+        >>> # Create processor
+        >>> processor = RefitKalmanFilterProcessor(settings)
+        >>>
+        >>> # Process measurement message
+        >>> result = processor(measurement_message)
+        >>>
+        >>> # Log data for refitting
+        >>> processor.log_for_refit(message, target_pos, hold_flag)
+        >>>
+        >>> # Refit the model
+        >>> processor.refit_model()
+    """
+
+    def _config_from_settings(self) -> dict:
+        """
+        Returns:
+            dict: Dictionary containing configuration parameters for model initialization.
+        """
+        return {
+            "steady_state": self.settings.steady_state,
+        }
+
+    def _init_model(self, **kwargs):
+        """
+        Initialize a new RefitKalmanFilter model with current settings.
+
+        Args:
+            **kwargs: Keyword arguments for model initialization.
+        """
+        config = self._config_from_settings()
+        config.update(kwargs)
+        self._state.model = RefitKalmanFilter(**config)
+
+    def fit(self, X: np.ndarray, y: np.ndarray) -> None:
+        if self._state.model is None:
+            self._init_model()
+        if hasattr(self._state.model, "fit"):
+            self._state.model.fit(X, y)
+
+    def load_from_checkpoint(self, checkpoint_path: str) -> None:
+        """
+        Load model parameters from a serialized checkpoint file.
+
+        Args:
+            checkpoint_path (str): Path to the saved checkpoint file.
+
+        Side Effects:
+            - Initializes a new model if not already set.
+            - Sets model matrices A, W, H, Q from the checkpoint.
+            - Computes Kalman gain based on restored parameters.
+        """
+        checkpoint_file = Path(checkpoint_path)
+        with open(checkpoint_file, "rb") as f:
+            checkpoint_data = pickle.load(f)
+        self._init_model(**checkpoint_data)
+        self._state.model._compute_gain()
+        self._state.model.is_fitted = True
+
+    def save_checkpoint(self, checkpoint_path: str) -> None:
+        """
+        Save current model parameters to a checkpoint file.
+
+        Args:
+            checkpoint_path (str): Destination file path for saving model parameters.
+
+        Raises:
+            ValueError: If the model is not initialized or has not been fitted.
+        """
+        if not self._state.model or not self._state.model.is_fitted:
+            raise ValueError("Cannot save checkpoint: model not fitted")
+        checkpoint_data = {
+            "A_state_transition_matrix": self._state.model.A_state_transition_matrix,
+            "W_process_noise_covariance": self._state.model.W_process_noise_covariance,
+            "H_observation_matrix": self._state.model.H_observation_matrix,
+            "Q_measurement_noise_covariance": self._state.model.Q_measurement_noise_covariance,
+        }
+        checkpoint_file = Path(checkpoint_path)
+        checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
+        with open(checkpoint_file, "wb") as f:
+            pickle.dump(checkpoint_data, f)
+
+    def _reset_state(
+        self,
+        message: AxisArray = None,
+    ):
+        """
+        This method initializes or reinitializes the state vector (x), state covariance (P),
+        and cursor position. If a checkpoint path is specified in the settings, the model
+        is loaded from the checkpoint.
+
+        Args:
+            message (AxisArray): Time-series message containing neural measurements.
+            x_init (np.ndarray): Initial state vector.
+            P_init (np.ndarray): Initial state covariance matrix.
+        """
+        if not self._state.model:
+            if self.settings.checkpoint_path:
+                self.load_from_checkpoint(self.settings.checkpoint_path)
+            else:
+                self._init_model()
+                ## TODO: fit the model - how to do this given expected inputs X and y?
+        # for unit test purposes only, given a known kinematic state size
+        state_dim = 2
+
+        # # If A is None, the model has not been fitted or loaded from checkpoint
+        # if self._state.model.A_state_transition_matrix is None:
+        #     raise RuntimeError(
+        #         "Cannot reset state — model has not been fitted or loaded from checkpoint."
+        #     )
+
+        if self._state.model.A_state_transition_matrix is not None:
+            state_dim = self._state.model.A_state_transition_matrix.shape[0]
+
+        self._state.x = np.zeros(state_dim)
+        self._state.P = np.eye(state_dim)
+
+        self._state.buffer_neural = []
+        self._state.buffer_state = []
+        self._state.buffer_cursor_positions = []
+        self._state.buffer_target_positions = []
+        self._state.buffer_hold_flags = []
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        """
+        Process an incoming message using the Kalman filter.
+
+        For each time point in the message:
+        - Predict the next state
+        - Update the estimate using the current measurement
+        - Track the velocity and estimate position
+
+        Args:
+            message (AxisArray): Time-series message containing neural measurements.
+
+        Returns:
+            AxisArray: Filtered message containing updated state estimates.
+        """
+        # If checkpoint, load the model from the checkpoint
+        if not self._state.model and self.settings.checkpoint_path:
+            self.load_from_checkpoint(self.settings.checkpoint_path)
+        # No checkpoint means you need to initialize and fit the model
+        elif not self._state.model:
+            self._init_model()
+        state_dim = self._state.model.A_state_transition_matrix.shape[0]
+        if self._state.x is None:
+            self._state.x = np.zeros(state_dim)
+
+        filtered_data = np.zeros(
+            (
+                message.data.shape[0],
+                self._state.model.A_state_transition_matrix.shape[0],
+            )
+        )
+
+        for i in range(message.data.shape[0]):
+            measurement = message.data[i]
+            # Predict
+            x_pred, P_pred = self._state.model.predict(self._state.x)
+
+            # Update
+            x_updated = self._state.model.update(measurement, x_pred, P_pred)
+
+            # Store
+            self._state.x = x_updated.copy()
+            self._state.P = self._state.model.P_state_covariance.copy()
+            filtered_data[i] = self._state.x
+
+        return replace(
+            message,
+            data=filtered_data,
+            dims=["time", "state"],
+            key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered",
+        )
+
+    def partial_fit(self, message: SampleMessage) -> None:
+        """
+        Perform refitting using externally provided data.
+
+        Expects message.sample.data (neural input) and message.trigger.value as a dict with:
+        - Y_state: (n_samples, n_states) array
+        - intention_velocity_indices: Optional[int]
+        - target_positions: Optional[np.ndarray]
+        - cursor_positions: Optional[np.ndarray]
+        - hold_flags: Optional[list[bool]]
+        """
+        if not hasattr(message, "sample") or not hasattr(message, "trigger"):
+            raise ValueError("Invalid message format for partial_fit.")
+
+        X = np.array(message.sample.data)
+        values = message.trigger.value
+
+        if not isinstance(values, dict) or "Y_state" not in values:
+            raise ValueError(
+                "partial_fit expects trigger.value to include at least 'Y_state'."
+            )
+
+        kwargs = {
+            "X_neural": X,
+            "Y_state": np.array(values["Y_state"]),
+        }
+
+        # Optional fields
+        for key in [
+            "intention_velocity_indices",
+            "target_positions",
+            "cursor_positions",
+            "hold_flags",
+        ]:
+            if key in values and values[key] is not None:
+                kwargs[key if key != "hold_flags" else "hold_indices"] = np.array(
+                    values[key]
+                )
+
+        # Call model refit
+        self._state.model.refit(**kwargs)
+
+    def log_for_refit(
+        self,
+        message: AxisArray,
+        target_position: np.ndarray | None = None,
+        hold_flag: bool | None = None,
+    ):
+        """
+        Log data for potential refitting of the model.
+
+        This method stores measurement data, state estimates, and contextual
+        information (target positions, cursor positions, hold flags) in buffers
+        for later use in refitting the observation model. This data is used
+        to adapt the model to changing neural-to-behavioral relationships.
+
+        Args:
+            message: AxisArray message containing measurement data.
+            target_position: Target position for the current time point (2,).
+            hold_flag: Boolean flag indicating if this is a hold period.
+        """
+        if target_position is not None:
+            self._state.buffer_target_positions.append(target_position.copy())
+        if hold_flag is not None:
+            self._state.buffer_hold_flags.append(hold_flag)
+
+        measurement = message.data[-1]
+        self._state.buffer_neural.append(measurement.copy())
+        self._state.buffer_state.append(self._state.x.copy())
+
+    def refit_model(self):
+        """
+        Refit the observation model (H, Q) using buffered measurements and contextual data.
+
+        This method updates the model's understanding of the neural-to-state mapping
+        by calculating a new observation matrix and noise covariance, based on:
+        - Logged neural data
+        - Cursor state estimates
+        - Hold flags and target positions
+
+        Args:
+            velocity_indices (tuple): Indices in the state vector corresponding to velocity components.
+                Default assumes 2D velocity at indices (0, 1).
+
+        Raises:
+            ValueError: If no buffered data exists.
+        """
+        if not self._state.buffer_neural:
+            print("No buffered data to refit")
+            return
+
+        kwargs = {
+            "X_neural": np.array(self._state.buffer_neural),
+            "Y_state": np.array(self._state.buffer_state),
+            "intention_velocity_indices": self.settings.velocity_indices[0],
+        }
+
+        if self._state.buffer_target_positions and self._state.buffer_cursor_positions:
+            kwargs["target_positions"] = np.array(self._state.buffer_target_positions)
+            kwargs["cursor_positions"] = np.array(self._state.buffer_cursor_positions)
+        if self._state.buffer_hold_flags:
+            kwargs["hold_indices"] = np.array(self._state.buffer_hold_flags)
+
+        self._state.model.refit(**kwargs)
+
+        self._state.buffer_neural.clear()
+        self._state.buffer_state.clear()
+        self._state.buffer_cursor_positions.clear()
+        self._state.buffer_target_positions.clear()
+        self._state.buffer_hold_flags.clear()
+
+
+class RefitKalmanFilterUnit(
+    BaseAdaptiveTransformerUnit[
+        RefitKalmanFilterSettings,
+        AxisArray,
+        AxisArray,
+        RefitKalmanFilterProcessor,
+    ]
+):
+    SETTINGS = RefitKalmanFilterSettings
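
The per-sample loop in `_process` delegates to `RefitKalmanFilter.predict` and `RefitKalmanFilter.update`, which are defined in `ezmsg/learn/model/refit_kalman.py` (not part of this hunk). For orientation, the textbook Kalman predict/update equations in numpy, using the same matrix naming as `save_checkpoint` (`A` state transition, `W` process noise covariance, `H` observation, `Q` measurement noise covariance); this is an illustrative sketch, not the package's implementation:

import numpy as np

def kf_predict(x, P, A, W):
    x_pred = A @ x            # propagate the state through the transition model
    P_pred = A @ P @ A.T + W  # inflate covariance by the process noise
    return x_pred, P_pred

def kf_update(z, x_pred, P_pred, H, Q):
    S = H @ P_pred @ H.T + Q              # innovation covariance
    K = P_pred @ H.T @ np.linalg.inv(S)   # Kalman gain
    x = x_pred + K @ (z - H @ x_pred)     # correct the prediction with measurement z
    P = (np.eye(x.shape[0]) - K @ H) @ P_pred  # updated state covariance
    return x, P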