ezmsg_learn-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ezmsg/learn/__init__.py +2 -0
  2. ezmsg/learn/__version__.py +34 -0
  3. ezmsg/learn/dim_reduce/__init__.py +0 -0
  4. ezmsg/learn/dim_reduce/adaptive_decomp.py +274 -0
  5. ezmsg/learn/dim_reduce/incremental_decomp.py +173 -0
  6. ezmsg/learn/linear_model/__init__.py +1 -0
  7. ezmsg/learn/linear_model/adaptive_linear_regressor.py +12 -0
  8. ezmsg/learn/linear_model/cca.py +1 -0
  9. ezmsg/learn/linear_model/linear_regressor.py +9 -0
  10. ezmsg/learn/linear_model/sgd.py +9 -0
  11. ezmsg/learn/linear_model/slda.py +12 -0
  12. ezmsg/learn/model/__init__.py +0 -0
  13. ezmsg/learn/model/cca.py +122 -0
  14. ezmsg/learn/model/mlp.py +127 -0
  15. ezmsg/learn/model/mlp_old.py +49 -0
  16. ezmsg/learn/model/refit_kalman.py +369 -0
  17. ezmsg/learn/model/rnn.py +160 -0
  18. ezmsg/learn/model/transformer.py +175 -0
  19. ezmsg/learn/nlin_model/__init__.py +1 -0
  20. ezmsg/learn/nlin_model/mlp.py +10 -0
  21. ezmsg/learn/process/__init__.py +0 -0
  22. ezmsg/learn/process/adaptive_linear_regressor.py +142 -0
  23. ezmsg/learn/process/base.py +154 -0
  24. ezmsg/learn/process/linear_regressor.py +95 -0
  25. ezmsg/learn/process/mlp_old.py +188 -0
  26. ezmsg/learn/process/refit_kalman.py +403 -0
  27. ezmsg/learn/process/rnn.py +245 -0
  28. ezmsg/learn/process/sgd.py +117 -0
  29. ezmsg/learn/process/sklearn.py +241 -0
  30. ezmsg/learn/process/slda.py +110 -0
  31. ezmsg/learn/process/ssr.py +374 -0
  32. ezmsg/learn/process/torch.py +362 -0
  33. ezmsg/learn/process/transformer.py +215 -0
  34. ezmsg/learn/util.py +67 -0
  35. ezmsg_learn-1.1.0.dist-info/METADATA +30 -0
  36. ezmsg_learn-1.1.0.dist-info/RECORD +38 -0
  37. ezmsg_learn-1.1.0.dist-info/WHEEL +4 -0
  38. ezmsg_learn-1.1.0.dist-info/licenses/LICENSE +21 -0
ezmsg/learn/process/linear_regressor.py
@@ -0,0 +1,95 @@
+ from dataclasses import field
+
+ import ezmsg.core as ez
+ import numpy as np
+ from ezmsg.baseproc import (
+     BaseAdaptiveTransformer,
+     BaseAdaptiveTransformerUnit,
+     processor_state,
+ )
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.util.messages.axisarray import AxisArray, replace
+ from sklearn.linear_model._base import LinearModel
+
+ from ..util import RegressorType, StaticLinearRegressor, get_regressor
+
+
+ class LinearRegressorSettings(ez.Settings):
+     model_type: StaticLinearRegressor = StaticLinearRegressor.LINEAR
+     settings_path: str | None = None
+     model_kwargs: dict = field(default_factory=dict)
+
+
+ @processor_state
+ class LinearRegressorState:
+     template: AxisArray | None = None
+     model: LinearModel | None = None
+
+
+ class LinearRegressorTransformer(
+     BaseAdaptiveTransformer[LinearRegressorSettings, AxisArray, AxisArray, LinearRegressorState]
+ ):
+     """
+     Linear regressor.
+
+     Note: `partial_fit` is not actually partial: it fully refits the model on the entire
+     SampleMessage provided. If you require adaptive fitting, use the
+     adaptive_linear_regressor module instead.
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         if self.settings.settings_path is not None:
+             # Load a pickled model from file
+             import pickle
+
+             with open(self.settings.settings_path, "rb") as f:
+                 self.state.model = pickle.load(f)
+         else:
+             regressor_klass = get_regressor(RegressorType.STATIC, self.settings.model_type)
+             self.state.model = regressor_klass(**self.settings.model_kwargs)
+
+     def _reset_state(self, message: AxisArray) -> None:
+         # So far, there is nothing to reset.
+         # .model and .template are initialized in __init__
+         pass
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         if np.any(np.isnan(message.sample.data)):
+             return
+
+         X = message.sample.data
+         y = message.trigger.value.data
+         # TODO: Resample should provide identical durations.
+         # Until then, truncate X and y to their common length before fitting.
+         self.state.model = self.state.model.fit(X[: y.shape[0]], y[: X.shape[0]])
+         self.state.template = replace(
+             message.trigger.value,
+             data=np.array([[]]),
+             key=message.trigger.value.key + "_pred",
+         )
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         if self.state.template is None:
+             return AxisArray(np.array([[]]), dims=["time", "ch"])
+         preds = self.state.model.predict(message.data)
+         return replace(
+             self.state.template,
+             data=preds,
+             axes={
+                 **self.state.template.axes,
+                 # Carry the incoming message's time axis so the output offset tracks the input.
+                 "time": message.axes["time"],
+             },
+         )
+
+
+ class AdaptiveLinearRegressorUnit(
+     BaseAdaptiveTransformerUnit[
+         LinearRegressorSettings,
+         AxisArray,
+         AxisArray,
+         LinearRegressorTransformer,
+     ]
+ ):
+     SETTINGS = LinearRegressorSettings
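
Not part of the diff: a minimal usage sketch of the transformer above. It assumes the processor is callable on messages (as the class docstring examples elsewhere in this package suggest) and that `SampleMessage`/`SampleTriggerMessage` carry the target `AxisArray` in `trigger.value`, exactly as `partial_fit` reads it; the `AxisArray.TimeAxis` construction is illustrative.

    import numpy as np
    from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
    from ezmsg.util.messages.axisarray import AxisArray

    proc = LinearRegressorTransformer(LinearRegressorSettings())

    # One full (re)fit: X comes from sample.data, y from trigger.value.data.
    t_ax = {"time": AxisArray.TimeAxis(fs=100.0)}  # illustrative time axis
    X = AxisArray(np.random.randn(100, 8), dims=["time", "ch"], axes=t_ax, key="feats")
    y = AxisArray(np.random.randn(100, 2), dims=["time", "ch"], axes=t_ax, key="kin")
    proc.partial_fit(SampleMessage(trigger=SampleTriggerMessage(value=y), sample=X))

    # Subsequent chunks are transformed into predictions keyed "kin_pred".
    out = proc(AxisArray(np.random.randn(20, 8), dims=["time", "ch"], axes=t_ax, key="feats"))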
ezmsg/learn/process/mlp_old.py
@@ -0,0 +1,188 @@
+ import typing
+
+ import ezmsg.core as ez
+ import numpy as np
+ import torch
+ import torch.nn
+ from ezmsg.baseproc import (
+     BaseAdaptiveTransformer,
+     BaseAdaptiveTransformerUnit,
+     processor_state,
+ )
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from ..model.mlp_old import MLP
+
+
+ class MLPSettings(ez.Settings):
+     hidden_channels: list[int]
+     """List of the hidden channel dimensions."""
+
+     norm_layer: typing.Callable[..., torch.nn.Module] | None = None
+     """Norm layer stacked on top of each linear layer. If None, no norm layer is used."""
+
+     activation_layer: typing.Callable[..., torch.nn.Module] | None = torch.nn.ReLU
+     """Activation function stacked on top of the norm layer (if any), otherwise on top of
+     the linear layer. If None, no activation layer is used."""
+
+     inplace: bool | None = None
+     """Parameter for the activation and dropout layers, which can optionally do the
+     operation in-place. Default is None, which uses the respective default values of the
+     activation_layer and Dropout layer."""
+
+     bias: bool = True
+     """Whether to use bias in the linear layers."""
+
+     dropout: float = 0.0
+     """The probability for the dropout layers."""
+
+     single_precision: bool = True
+     """If True, tensors use float32; otherwise float64."""
+
+     learning_rate: float = 0.001
+     """Learning rate for the Adam optimizer."""
+
+     scheduler_gamma: float = 0.999
+     """Learning-rate scheduler decay rate. Set to 0.0 to disable the scheduler."""
+
+     checkpoint_path: str | None = None
+     """
+     Path to a checkpoint file containing model weights.
+     If None, the model will be initialized with random weights.
+     """
+
+
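The `MLP` model these settings configure is defined in ezmsg/learn/model/mlp_old.py (+49, listed above but not shown in this hunk). As a rough, hypothetical sketch of the stack the settings describe, assuming a torchvision-style MLP: with `hidden_channels=[64, 2]` and the defaults above, the network would expand to roughly:

    import torch.nn as nn

    in_features = 8  # inferred at runtime from the "ch" dimension of the first message
    sketch = nn.Sequential(
        nn.Linear(in_features, 64, bias=True),  # hidden layer
        nn.ReLU(),                              # activation_layer (after norm_layer, if any)
        nn.Dropout(p=0.0),                      # dropout
        nn.Linear(64, 2, bias=True),            # final layer: no norm/activation
        nn.Dropout(p=0.0),
    )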
+ @processor_state
+ class MLPState:
+     model: MLP | None = None
+     optimizer: torch.optim.Optimizer | None = None
+     scheduler: torch.optim.lr_scheduler.LRScheduler | None = None
+     template: AxisArray | None = None
+     device: object | None = None
+     chan_ax: AxisArray.CoordinateAxis | None = None
+
+
+ class MLPProcessor(BaseAdaptiveTransformer[MLPSettings, AxisArray, AxisArray, MLPState]):
+     def _hash_message(self, message: AxisArray) -> int:
+         hash_items = (message.key,)
+         if "ch" in message.dims:
+             hash_items += (message.data.shape[message.get_axis_idx("ch")],)
+         return hash(hash_items)
+
+     def _reset_state(self, message: AxisArray) -> None:
+         # Create the model
+         self._state.model = MLP(
+             in_channels=message.data.shape[message.get_axis_idx("ch")],
+             hidden_channels=self.settings.hidden_channels,
+             norm_layer=self.settings.norm_layer,
+             activation_layer=self.settings.activation_layer,
+             inplace=self.settings.inplace,
+             bias=self.settings.bias,
+             dropout=self.settings.dropout,
+         )
+
+         # Create the optimizer
+         self._state.optimizer = torch.optim.Adam(self._state.model.parameters(), lr=self.settings.learning_rate)
+
+         # Load model (and, if present, optimizer) state from checkpoint if specified
+         if self.settings.checkpoint_path is not None:
+             try:
+                 checkpoint = torch.load(self.settings.checkpoint_path)
+                 self._state.model.load_state_dict(checkpoint["model_state_dict"])
+                 if "optimizer_state_dict" in checkpoint:
+                     self._state.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+             except Exception as e:
+                 raise RuntimeError(f"Failed to load checkpoint from {self.settings.checkpoint_path}: {e}")
+
+         # Set the model to evaluation mode by default
+         self._state.model.eval()
+
+         # TODO: Should the model be moved to a device before the next line?
+         self._state.device = next(self._state.model.parameters()).device
+
+         # Optionally create the learning rate scheduler
+         self._state.scheduler = (
+             torch.optim.lr_scheduler.ExponentialLR(self._state.optimizer, gamma=self.settings.scheduler_gamma)
+             if self.settings.scheduler_gamma > 0.0
+             else None
+         )
+
+         # Create the output channel axis for reuse in each output.
+         n_output_channels = self.settings.hidden_channels[-1]
+         self._state.chan_ax = AxisArray.CoordinateAxis(
+             data=np.array([f"ch{_}" for _ in range(n_output_channels)]), dims=["ch"]
+         )
+
+     def save_checkpoint(self, path: str) -> None:
+         """Save the current model and optimizer state to a checkpoint file.
+
+         Args:
+             path: Path where the checkpoint will be saved.
+         """
+         checkpoint = {
+             "model_state_dict": self._state.model.state_dict(),
+             "optimizer_state_dict": self._state.optimizer.state_dict(),
+         }
+         torch.save(checkpoint, path)
+
+     def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
+         dtype = torch.float32 if self.settings.single_precision else torch.float64
+         return torch.tensor(data, dtype=dtype, device=self._state.device)
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         # TODO: loss_fn should be determined by a setting
+         loss_fn = torch.nn.functional.mse_loss
+
+         X = self._to_tensor(message.sample.data)
+         y_targ = self._to_tensor(message.trigger.value.data)
+
+         with torch.set_grad_enabled(True):
+             self._state.model.train()
+             y_pred = self._state.model(X)
+             loss = loss_fn(y_pred, y_targ)
+
+             self._state.optimizer.zero_grad()
+             loss.backward()
+             self._state.optimizer.step()  # Update weights
+             if self._state.scheduler is not None:
+                 self._state.scheduler.step()  # Update learning rate
+
+         self._state.model.eval()
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         data = message.data
+         if not isinstance(data, torch.Tensor):
+             data = torch.tensor(
+                 data,
+                 dtype=torch.float32 if self.settings.single_precision else torch.float64,
+             )
+
+         with torch.no_grad():
+             output = self._state.model(data.to(self._state.device))
+
+         return replace(
+             message,
+             data=output.cpu().numpy(),
+             axes={
+                 **message.axes,
+                 "ch": self._state.chan_ax,
+             },
+         )
+
+
+ class MLPUnit(
+     BaseAdaptiveTransformerUnit[
+         MLPSettings,
+         AxisArray,
+         AxisArray,
+         MLPProcessor,
+     ]
+ ):
+     SETTINGS = MLPSettings
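
Again not part of the diff: a sketch tying the pieces above together, assuming the processor is callable on messages and that `trigger.value` holds the training target `AxisArray`, as `partial_fit` reads it.

    import numpy as np
    from ezmsg.util.messages.axisarray import AxisArray

    proc = MLPProcessor(MLPSettings(hidden_channels=[32, 2]))

    msg = AxisArray(np.random.randn(50, 8), dims=["time", "ch"], key="feats")
    out = proc(msg)  # builds the model from the "ch" dim on first use, then runs inference

    # ...after some partial_fit() calls, persist model + optimizer state:
    proc.save_checkpoint("mlp_ckpt.pt")

    # Warm-start a fresh processor from the same checkpoint:
    proc2 = MLPProcessor(MLPSettings(hidden_channels=[32, 2], checkpoint_path="mlp_ckpt.pt"))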
ezmsg/learn/process/refit_kalman.py
@@ -0,0 +1,403 @@
+ import pickle
+ from pathlib import Path
+
+ import ezmsg.core as ez
+ import numpy as np
+ from ezmsg.baseproc import (
+     BaseAdaptiveTransformer,
+     BaseAdaptiveTransformerUnit,
+     processor_state,
+ )
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from ..model.refit_kalman import RefitKalmanFilter
+
+
+ class RefitKalmanFilterSettings(ez.Settings):
+     """
+     Settings for the Refit Kalman filter processor.
+
+     This class defines the configuration parameters for the Refit Kalman filter processor.
+     The RefitKalmanFilter is designed for online processing and playback.
+
+     Attributes:
+         checkpoint_path: Path to saved model parameters (optional).
+             If provided, loads pre-trained parameters instead of learning from data.
+         steady_state: Whether to use a steady-state Kalman filter.
+             If True, uses a pre-computed Kalman gain; if False, updates the gain dynamically.
+         velocity_indices: Indices of the velocity components in the state vector,
+             used when refitting the observation model.
+     """
+
+     checkpoint_path: str | None = None
+     steady_state: bool = False
+     velocity_indices: tuple[int, int] = (2, 3)
+
+
+ @processor_state
+ class RefitKalmanFilterState:
+     """
+     State management for the Refit Kalman filter processor.
+
+     This class manages the persistent state of the Refit Kalman filter processor,
+     including the model instance, current state estimates, and data buffers for refitting.
+
+     Attributes:
+         model: The RefitKalmanFilter model instance.
+         x: Current state estimate (n_states,).
+         P: Current state covariance matrix (n_states x n_states).
+         buffer_neural: Buffer for storing neural activity data for refitting.
+         buffer_state: Buffer for storing state estimates for refitting.
+         buffer_cursor_positions: Buffer for storing cursor positions for refitting.
+         buffer_target_positions: Buffer for storing target positions for refitting.
+         buffer_hold_flags: Buffer for storing hold flags for refitting.
+     """
+
+     model: RefitKalmanFilter | None = None
+     x: np.ndarray | None = None
+     P: np.ndarray | None = None
+
+     buffer_neural: list | None = None
+     buffer_state: list | None = None
+     buffer_cursor_positions: list | None = None
+     buffer_target_positions: list | None = None
+     buffer_hold_flags: list | None = None
+
+
+ class RefitKalmanFilterProcessor(
+     BaseAdaptiveTransformer[
+         RefitKalmanFilterSettings,
+         AxisArray,
+         AxisArray,
+         RefitKalmanFilterState,
+     ]
+ ):
+     """
+     Processor for implementing a Refit Kalman filter in the ezmsg framework.
+
+     This processor integrates the RefitKalmanFilter model into the ezmsg
+     message passing system. It handles the conversion between AxisArray messages
+     and the internal Refit Kalman filter operations.
+
+     The processor performs the following operations:
+     1. Configures the Refit Kalman filter model with provided settings
+     2. Processes incoming measurement messages
+     3. Performs prediction and update steps
+     4. Logs data for potential refitting
+     5. Supports online refitting of the observation model
+     6. Returns filtered state estimates as AxisArray messages
+     7. Maintains state between message processing calls
+
+     The processor can operate in two modes:
+     1. Pre-trained mode: Loads parameters from checkpoint_path
+     2. Learning mode: Collects data and fits the model when the buffer is full
+
+     Key features:
+     - Online refitting capability for adaptive neural decoding
+     - Data logging for retrospective analysis
+     - Position tracking for cursor control applications
+     - Hold period detection and handling
+
+     Attributes:
+         settings: Configuration settings for the Refit Kalman filter.
+         _state: Internal state management object.
+
+     Example:
+         >>> # Create settings with checkpoint path
+         >>> settings = RefitKalmanFilterSettings(
+         ...     checkpoint_path="path/to/checkpoint.pkl",
+         ...     steady_state=True
+         ... )
+         >>>
+         >>> # Create processor
+         >>> processor = RefitKalmanFilterProcessor(settings)
+         >>>
+         >>> # Process measurement message
+         >>> result = processor(measurement_message)
+         >>>
+         >>> # Log data for refitting
+         >>> processor.log_for_refit(message, target_pos, hold_flag)
+         >>>
+         >>> # Refit the model
+         >>> processor.refit_model()
+     """
+
+     def _config_from_settings(self) -> dict:
+         """
+         Returns:
+             dict: Dictionary containing configuration parameters for model initialization.
+         """
+         return {
+             "steady_state": self.settings.steady_state,
+         }
+
+     def _init_model(self, **kwargs):
+         """
+         Initialize a new RefitKalmanFilter model with current settings.
+
+         Args:
+             **kwargs: Keyword arguments for model initialization.
+         """
+         config = self._config_from_settings()
+         config.update(kwargs)
+         self._state.model = RefitKalmanFilter(**config)
+
+     def fit(self, X: np.ndarray, y: np.ndarray) -> None:
+         if self._state.model is None:
+             self._init_model()
+         if hasattr(self._state.model, "fit"):
+             self._state.model.fit(X, y)
+
+     def load_from_checkpoint(self, checkpoint_path: str) -> None:
+         """
+         Load model parameters from a serialized checkpoint file.
+
+         Args:
+             checkpoint_path (str): Path to the saved checkpoint file.
+
+         Side Effects:
+             - Initializes a new model with the matrices A, W, H, Q from the checkpoint.
+             - Computes the Kalman gain based on the restored parameters.
+         """
+         checkpoint_file = Path(checkpoint_path)
+         with open(checkpoint_file, "rb") as f:
+             checkpoint_data = pickle.load(f)
+         self._init_model(**checkpoint_data)
+         self._state.model._compute_gain()
+         self._state.model.is_fitted = True
+
+     def save_checkpoint(self, checkpoint_path: str) -> None:
+         """
+         Save current model parameters to a checkpoint file.
+
+         Args:
+             checkpoint_path (str): Destination file path for saving model parameters.
+
+         Raises:
+             ValueError: If the model is not initialized or has not been fitted.
+         """
+         if not self._state.model or not self._state.model.is_fitted:
+             raise ValueError("Cannot save checkpoint: model not fitted")
+         checkpoint_data = {
+             "A_state_transition_matrix": self._state.model.A_state_transition_matrix,
+             "W_process_noise_covariance": self._state.model.W_process_noise_covariance,
+             "H_observation_matrix": self._state.model.H_observation_matrix,
+             "Q_measurement_noise_covariance": self._state.model.Q_measurement_noise_covariance,
+         }
+         checkpoint_file = Path(checkpoint_path)
+         checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
+         with open(checkpoint_file, "wb") as f:
+             pickle.dump(checkpoint_data, f)
+
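A checkpoint compatible with `load_from_checkpoint()` is just a pickled dict of the four matrices that `save_checkpoint()` writes, since the loader forwards them to the `RefitKalmanFilter` constructor via `_init_model(**checkpoint_data)`. A sketch of producing one offline (dimensions are illustrative; `H` maps the state into the neural observation space):

    import pickle
    import numpy as np

    n_states, n_neurons = 4, 96  # illustrative sizes
    checkpoint = {
        "A_state_transition_matrix": np.eye(n_states),
        "W_process_noise_covariance": 0.01 * np.eye(n_states),
        "H_observation_matrix": np.zeros((n_neurons, n_states)),
        "Q_measurement_noise_covariance": np.eye(n_neurons),
    }
    with open("refit_kf.pkl", "wb") as f:
        pickle.dump(checkpoint, f)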
+     def _reset_state(
+         self,
+         message: AxisArray | None = None,
+     ):
+         """
+         Initialize or reinitialize the state vector (x), state covariance (P),
+         and the refit buffers. If a checkpoint path is specified in the settings,
+         the model is loaded from the checkpoint.
+
+         Args:
+             message (AxisArray): Time-series message containing neural measurements.
+         """
+         if not self._state.model:
+             if self.settings.checkpoint_path:
+                 self.load_from_checkpoint(self.settings.checkpoint_path)
+             else:
+                 self._init_model()
+                 ## TODO: fit the model - how to do this given expected inputs X and y?
+
+         # If A is None, the model has not been fitted or loaded from checkpoint;
+         # fall back to a known kinematic state size (for unit test purposes only).
+         state_dim = 2
+         if self._state.model.A_state_transition_matrix is not None:
+             state_dim = self._state.model.A_state_transition_matrix.shape[0]
+
+         self._state.x = np.zeros(state_dim)
+         self._state.P = np.eye(state_dim)
+
+         self._state.buffer_neural = []
+         self._state.buffer_state = []
+         self._state.buffer_cursor_positions = []
+         self._state.buffer_target_positions = []
+         self._state.buffer_hold_flags = []
+
+     def _process(self, message: AxisArray) -> AxisArray:
+         """
+         Process an incoming message using the Kalman filter.
+
+         For each time point in the message:
+         - Predict the next state
+         - Update the estimate using the current measurement
+         - Track the velocity and estimate position
+
+         Args:
+             message (AxisArray): Time-series message containing neural measurements.
+
+         Returns:
+             AxisArray: Filtered message containing updated state estimates.
+         """
+         # If there is a checkpoint, load the model from it
+         if not self._state.model and self.settings.checkpoint_path:
+             self.load_from_checkpoint(self.settings.checkpoint_path)
+         # No checkpoint means the model must be initialized and fitted here
+         elif not self._state.model:
+             self._init_model()
+         state_dim = self._state.model.A_state_transition_matrix.shape[0]
+         if self._state.x is None:
+             self._state.x = np.zeros(state_dim)
+
+         filtered_data = np.zeros((message.data.shape[0], state_dim))
+
+         for i in range(message.data.shape[0]):
+             measurement = message.data[i]
+             # Predict
+             x_pred, P_pred = self._state.model.predict(self._state.x)
+
+             # Update
+             x_updated = self._state.model.update(measurement, x_pred, P_pred)
+
+             # Store
+             self._state.x = x_updated.copy()
+             self._state.P = self._state.model.P_state_covariance.copy()
+             filtered_data[i] = self._state.x
+
+         return replace(
+             message,
+             data=filtered_data,
+             dims=["time", "state"],
+             key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered",
+         )
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         """
+         Perform refitting using externally provided data.
+
+         Expects message.sample.data (neural input) and message.trigger.value as a dict with:
+         - Y_state: (n_samples, n_states) array
+         - intention_velocity_indices: Optional[int]
+         - target_positions: Optional[np.ndarray]
+         - cursor_positions: Optional[np.ndarray]
+         - hold_flags: Optional[list[bool]]
+         """
+         if not hasattr(message, "sample") or not hasattr(message, "trigger"):
+             raise ValueError("Invalid message format for partial_fit.")
+
+         X = np.array(message.sample.data)
+         values = message.trigger.value
+
+         if not isinstance(values, dict) or "Y_state" not in values:
+             raise ValueError("partial_fit expects trigger.value to include at least 'Y_state'.")
+
+         kwargs = {
+             "X_neural": X,
+             "Y_state": np.array(values["Y_state"]),
+         }
+
+         # Optional fields; hold_flags is forwarded to the model as hold_indices
+         for key in [
+             "intention_velocity_indices",
+             "target_positions",
+             "cursor_positions",
+             "hold_flags",
+         ]:
+             if key in values and values[key] is not None:
+                 kwargs[key if key != "hold_flags" else "hold_indices"] = np.array(values[key])
+
+         # Call model refit
+         self._state.model.refit(**kwargs)
+
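A sketch of the `trigger.value` payload that `partial_fit` expects, using the field names from the docstring above (array contents and sizes are placeholders):

    import numpy as np

    n_samples, n_states = 200, 4  # placeholder sizes
    refit_payload = {
        "Y_state": np.zeros((n_samples, n_states)),    # required
        "intention_velocity_indices": 2,               # optional
        "target_positions": np.zeros((n_samples, 2)),  # optional
        "cursor_positions": np.zeros((n_samples, 2)),  # optional
        "hold_flags": [False] * n_samples,             # optional; forwarded to refit() as hold_indices
    }
    # Delivered as message.trigger.value, alongside message.sample.data
    # holding the (n_samples, n_neurons) neural input.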
+     def log_for_refit(
+         self,
+         message: AxisArray,
+         target_position: np.ndarray | None = None,
+         hold_flag: bool | None = None,
+     ):
+         """
+         Log data for potential refitting of the model.
+
+         This method stores measurement data, state estimates, and contextual
+         information (target positions, cursor positions, hold flags) in buffers
+         for later use in refitting the observation model. This data is used
+         to adapt the model to changing neural-to-behavioral relationships.
+
+         Args:
+             message: AxisArray message containing measurement data.
+             target_position: Target position for the current time point (2,).
+             hold_flag: Boolean flag indicating whether this is a hold period.
+         """
+         if target_position is not None:
+             self._state.buffer_target_positions.append(target_position.copy())
+         if hold_flag is not None:
+             self._state.buffer_hold_flags.append(hold_flag)
+
+         measurement = message.data[-1]
+         self._state.buffer_neural.append(measurement.copy())
+         self._state.buffer_state.append(self._state.x.copy())
+
+     def refit_model(self):
+         """
+         Refit the observation model (H, Q) using buffered measurements and contextual data.
+
+         This method updates the model's understanding of the neural-to-state mapping
+         by calculating a new observation matrix and noise covariance, based on:
+         - Logged neural data
+         - Cursor state estimates
+         - Hold flags and target positions
+
+         The velocity components of the state vector are located via
+         settings.velocity_indices (default (2, 3)); the first element is passed
+         to the model as intention_velocity_indices.
+
+         If no buffered data exists, this method logs a message and returns without refitting.
+         """
+         if not self._state.buffer_neural:
+             print("No buffered data to refit")
+             return
+
+         kwargs = {
+             "X_neural": np.array(self._state.buffer_neural),
+             "Y_state": np.array(self._state.buffer_state),
+             "intention_velocity_indices": self.settings.velocity_indices[0],
+         }
+
+         if self._state.buffer_target_positions and self._state.buffer_cursor_positions:
+             kwargs["target_positions"] = np.array(self._state.buffer_target_positions)
+             kwargs["cursor_positions"] = np.array(self._state.buffer_cursor_positions)
+         if self._state.buffer_hold_flags:
+             kwargs["hold_indices"] = np.array(self._state.buffer_hold_flags)
+
+         self._state.model.refit(**kwargs)
+
+         self._state.buffer_neural.clear()
+         self._state.buffer_state.clear()
+         self._state.buffer_cursor_positions.clear()
+         self._state.buffer_target_positions.clear()
+         self._state.buffer_hold_flags.clear()
+
+
+ class RefitKalmanFilterUnit(
+     BaseAdaptiveTransformerUnit[
+         RefitKalmanFilterSettings,
+         AxisArray,
+         AxisArray,
+         RefitKalmanFilterProcessor,
+     ]
+ ):
+     SETTINGS = RefitKalmanFilterSettings
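
Finally, a sketch of the intended online refit cycle using the methods above, following the pattern in the class docstring's example; `neural_stream`, `current_target`, and `in_hold` are placeholders for the task plumbing:

    import numpy as np

    settings = RefitKalmanFilterSettings(checkpoint_path="refit_kf.pkl", steady_state=True)
    proc = RefitKalmanFilterProcessor(settings)

    current_target, in_hold = np.zeros(2), False  # placeholder task context
    for msg in neural_stream:  # placeholder: iterable of AxisArray (time x ch)
        state_msg = proc(msg)  # predict + update once per time sample
        proc.log_for_refit(msg, target_position=current_target, hold_flag=in_hold)

    proc.refit_model()  # recompute H and Q from the buffers, then clear them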