ezmsg-learn 1.0-py3-none-any.whl → 1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '1.0'
- __version_tuple__ = version_tuple = (1, 0)
+ __version__ = version = '1.2.0'
+ __version_tuple__ = version_tuple = (1, 2, 0)

  __commit_id__ = commit_id = None
@@ -1,14 +1,14 @@
  import typing

- from sklearn.decomposition import IncrementalPCA, MiniBatchNMF
- import numpy as np
  import ezmsg.core as ez
- from ezmsg.sigproc.base import (
-     processor_state,
+ import numpy as np
+ from ezmsg.baseproc import (
      BaseAdaptiveTransformer,
      BaseAdaptiveTransformerUnit,
+     processor_state,
  )
  from ezmsg.util.messages.axisarray import AxisArray, replace
+ from sklearn.decomposition import IncrementalPCA, MiniBatchNMF


  class AdaptiveDecompSettings(ez.Settings):
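Across this release, the base processor classes move out of ezmsg.sigproc.base and into the ezmsg.baseproc package; the same migration recurs in every file below. A minimal before/after sketch for downstream code, assuming only the symbol names visible in this hunk:

    # Before, against ezmsg-learn 1.0 era imports:
    # from ezmsg.sigproc.base import BaseAdaptiveTransformer, processor_state
    # After, against 1.2.0:
    from ezmsg.baseproc import BaseAdaptiveTransformer, processor_state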
@@ -23,15 +23,11 @@ class AdaptiveDecompState:
      estimator: typing.Any = None


- EstimatorType = typing.TypeVar(
-     "EstimatorType", bound=typing.Union[IncrementalPCA, MiniBatchNMF]
- )
+ EstimatorType = typing.TypeVar("EstimatorType", bound=typing.Union[IncrementalPCA, MiniBatchNMF])


  class AdaptiveDecompTransformer(
-     BaseAdaptiveTransformer[
-         AdaptiveDecompSettings, AxisArray, AxisArray, AdaptiveDecompState
-     ],
+     BaseAdaptiveTransformer[AdaptiveDecompSettings, AxisArray, AxisArray, AdaptiveDecompState],
      typing.Generic[EstimatorType],
  ):
      """
@@ -80,9 +76,7 @@ class AdaptiveDecompTransformer(
          it_ax_ix = message.get_axis_idx(iter_axis)
          # Remaining axes are to be treated independently
          off_targ_axes = [
-             _
-             for _ in (message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :])
-             if _ != self.settings.axis
+             _ for _ in (message.dims[:it_ax_ix] + message.dims[it_ax_ix + 1 :]) if _ != self.settings.axis
          ]
          self._state.axis_groups = iter_axis, targ_axes, off_targ_axes

@@ -152,9 +146,7 @@ class AdaptiveDecompTransformer(

          # Transform data
          if hasattr(self._state.estimator, "components_"):
-             decomp_dat = self._state.estimator.transform(in_dat).reshape(
-                 (-1,) + self._state.template.data.shape[1:]
-             )
+             decomp_dat = self._state.estimator.transform(in_dat).reshape((-1,) + self._state.template.data.shape[1:])
              replace_kwargs["data"] = decomp_dat

          return replace(self._state.template, **replace_kwargs)
@@ -241,9 +233,7 @@ class MiniBatchNMFTransformer(AdaptiveDecompTransformer[MiniBatchNMF]):
      pass


- SettingsType = typing.TypeVar(
-     "SettingsType", bound=typing.Union[IncrementalPCASettings, MiniBatchNMFSettings]
- )
+ SettingsType = typing.TypeVar("SettingsType", bound=typing.Union[IncrementalPCASettings, MiniBatchNMFSettings])
  TransformerType = typing.TypeVar(
      "TransformerType",
      bound=typing.Union[IncrementalPCATransformer, MiniBatchNMFTransformer],
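These bounded TypeVars pair with the generic transformer above: a concrete subclass pins the estimator type through the bracket parameter, as the hunk header for MiniBatchNMFTransformer already shows. A sketch of the pattern only; the real subclass bodies live elsewhere in this file:

    class IncrementalPCATransformer(AdaptiveDecompTransformer[IncrementalPCA]):
        ...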
@@ -1,14 +1,14 @@
  import typing

- import numpy as np
  import ezmsg.core as ez
- from ezmsg.util.messages.axisarray import AxisArray, replace
- from ezmsg.sigproc.base import (
-     CompositeProcessor,
+ import numpy as np
+ from ezmsg.baseproc import (
      BaseStatefulProcessor,
      BaseTransformerUnit,
+     CompositeProcessor,
  )
  from ezmsg.sigproc.window import WindowTransformer
+ from ezmsg.util.messages.axisarray import AxisArray, replace

  from .adaptive_decomp import (
      IncrementalPCASettings,
@@ -36,9 +36,7 @@ class IncrementalDecompSettings(ez.Settings):
      forget_factor: float = 0.7


- class IncrementalDecompTransformer(
-     CompositeProcessor[IncrementalDecompSettings, AxisArray, AxisArray]
- ):
+ class IncrementalDecompTransformer(CompositeProcessor[IncrementalDecompSettings, AxisArray, AxisArray]):
      """
      Automates usage of IncrementalPCATransformer and MiniBatchNMFTransformer by using a WindowTransformer
      to extract training samples then calls partial_fit on the decomposition transformer.
@@ -125,15 +123,11 @@ class IncrementalDecompTransformer(
              # If the estimator has not been trained once, train it with the first message
              self._procs["decomp"].partial_fit(message)
          elif "windowing" in self._procs:
-             state["windowing"], train_msg = self._procs["windowing"].stateful_op(
-                 state.get("windowing", None), message
-             )
+             state["windowing"], train_msg = self._procs["windowing"].stateful_op(state.get("windowing", None), message)
              self._partial_fit_windowed(train_msg)

          # Process the incoming message
-         state["decomp"], result = self._procs["decomp"].stateful_op(
-             state.get("decomp", None), message
-         )
+         state["decomp"], result = self._procs["decomp"].stateful_op(state.get("decomp", None), message)

          return state, result

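The stateful_op calls above follow a consistent convention: each sub-processor consumes (state, message) and returns (new_state, result), with None standing in for not-yet-initialized state. A minimal sketch (proc and the messages are hypothetical):

    state, out = proc.stateful_op(None, first_msg)   # None lets the processor build fresh state
    state, out = proc.stateful_op(state, next_msg)   # later calls thread the returned state through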
@@ -174,8 +168,6 @@ class IncrementalDecompTransformer(


  class IncrementalDecompUnit(
-     BaseTransformerUnit[
-         IncrementalDecompSettings, AxisArray, AxisArray, IncrementalDecompTransformer
-     ]
+     BaseTransformerUnit[IncrementalDecompSettings, AxisArray, AxisArray, IncrementalDecompTransformer]
  ):
      SETTINGS = IncrementalDecompSettings
@@ -1,6 +1,12 @@
  from ..process.adaptive_linear_regressor import (
      AdaptiveLinearRegressorSettings as AdaptiveLinearRegressorSettings,
+ )
+ from ..process.adaptive_linear_regressor import (
      AdaptiveLinearRegressorState as AdaptiveLinearRegressorState,
+ )
+ from ..process.adaptive_linear_regressor import (
      AdaptiveLinearRegressorTransformer as AdaptiveLinearRegressorTransformer,
+ )
+ from ..process.adaptive_linear_regressor import (
      AdaptiveLinearRegressorUnit as AdaptiveLinearRegressorUnit,
  )
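This file and the similar rewrites below split each re-export into its own import block while keeping the `Name as Name` idiom, which type checkers such as mypy and pyright treat as an explicit re-export. In isolation (module names hypothetical):

    # shim.py: the redundant-looking alias marks X as public API of this module,
    # so `from shim import X` is not flagged as reaching into a private import.
    from impl import X as X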
@@ -1,5 +1,9 @@
  from ..process.linear_regressor import (
      LinearRegressorSettings as LinearRegressorSettings,
+ )
+ from ..process.linear_regressor import (
      LinearRegressorState as LinearRegressorState,
+ )
+ from ..process.linear_regressor import (
      LinearRegressorTransformer as LinearRegressorTransformer,
  )
@@ -1,5 +1,9 @@
  from ..process.sgd import (
-     sgd_decoder as sgd_decoder,
-     SGDDecoderSettings as SGDDecoderSettings,
      SGDDecoder as SGDDecoder,
  )
+ from ..process.sgd import (
+     SGDDecoderSettings as SGDDecoderSettings,
+ )
+ from ..process.sgd import (
+     sgd_decoder as sgd_decoder,
+ )
@@ -1,6 +1,12 @@
+ from ..process.slda import (
+     SLDA as SLDA,
+ )
  from ..process.slda import (
      SLDASettings as SLDASettings,
+ )
+ from ..process.slda import (
      SLDAState as SLDAState,
+ )
+ from ..process.slda import (
      SLDATransformer as SLDATransformer,
-     SLDA as SLDA,
  )
ezmsg/learn/model/mlp.py CHANGED
@@ -27,11 +27,12 @@ class MLP(torch.nn.Module):
          Initialize the MLP model.
          Args:
              input_size (int): The size of the input features.
-             hidden_size (int | list[int]): The sizes of the hidden layers. If a list, num_layers must be None or the length
-                 of the list. If a single integer, num_layers must be specified and determines the number of hidden layers.
+             hidden_size (int | list[int]): The sizes of the hidden layers. If a list, num_layers must be None or the
+                 length of the list. If a single integer, num_layers must be specified and determines the number of
+                 hidden layers.
              num_layers (int, optional): The number of hidden layers. Length of hidden_size if None. Default is None.
-             output_heads (int | dict[str, int], optional): Number of output features or classes if single head output or a
-                 dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head).
+             output_heads (int | dict[str, int], optional): Number of output features or classes if single head output
+                 or a dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head).
              norm_layer (str, optional): A normalization layer to be applied after each linear layer. Default is None.
                  Common choices are "BatchNorm1d" or "LayerNorm".
              activation_layer (str, optional): An activation function to be applied after each normalization
@@ -43,18 +44,14 @@
          super().__init__()
          if isinstance(hidden_size, int):
              if num_layers is None:
-                 raise ValueError(
-                     "If hidden_size is an integer, num_layers must be specified."
-                 )
+                 raise ValueError("If hidden_size is an integer, num_layers must be specified.")
              hidden_size = [hidden_size] * num_layers
          if len(hidden_size) == 0:
              raise ValueError("hidden_size must have at least one element")
          if any(not isinstance(x, int) for x in hidden_size):
              raise ValueError("hidden_size must contain only integers")
          if num_layers is not None and len(hidden_size) != num_layers:
-             raise ValueError(
-                 "Length of hidden_size must match num_layers if num_layers is specified."
-             )
+             raise ValueError("Length of hidden_size must match num_layers if num_layers is specified.")

          params = {} if inplace is None else {"inplace": inplace}

@@ -84,10 +81,7 @@
          if isinstance(output_heads, int):
              output_heads = {"output": output_heads}
          self.heads = torch.nn.ModuleDict(
-             {
-                 name: torch.nn.Linear(hidden_size[-1], output_size)
-                 for name, output_size in output_heads.items()
-             }
+             {name: torch.nn.Linear(hidden_size[-1], output_size) for name, output_size in output_heads.items()}
          )

      @classmethod
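Taken together, the docstring and validation in these hunks admit constructions like the following (a hypothetical sketch; only the parameter names and rules visible above are assumed):

    MLP(input_size=32, hidden_size=64, num_layers=3)    # hidden sizes become [64, 64, 64]
    MLP(input_size=32, hidden_size=[128, 64, 32])       # num_layers may stay None
    MLP(input_size=32, hidden_size=[64], output_heads={"pos": 2, "vel": 2})  # two output heads
    MLP(input_size=32, hidden_size=64)                  # raises ValueError: num_layers required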
@@ -83,17 +83,12 @@ class RefitKalmanFilter:
          if Y_state.ndim != 2:
              raise ValueError(f"State vector must be 2D, got {Y_state.ndim}D")

-         if (
-             not hasattr(self, "H_observation_matrix")
-             or self.H_observation_matrix is None
-         ):
+         if not hasattr(self, "H_observation_matrix") or self.H_observation_matrix is None:
              raise ValueError("Model must be fitted before refitting")

          expected_states = self.H_observation_matrix.shape[1]
          if Y_state.shape[1] != expected_states:
-             raise ValueError(
-                 f"State vector has {Y_state.shape[1]} dimensions, expected {expected_states}"
-             )
+             raise ValueError(f"State vector has {Y_state.shape[1]} dimensions, expected {expected_states}")

      def fit(self, X_train, y_train):
          """
@@ -121,15 +116,11 @@
          X2 = X[1:, :]  # x_{t+1}
          X1 = X[:-1, :]  # x_t
          A = X2.T @ X1 @ np.linalg.inv(X1.T @ X1)  # Transition matrix
-         W = (
-             (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (n_samples - 1)
-         )  # Covariance of transition matrix
+         W = (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (n_samples - 1)  # Covariance of transition matrix

          # Calculate the measurement matrix (from x_t to z_t) using least-squares
          H = Z.T @ X @ np.linalg.inv(X.T @ X)  # Measurement matrix
-         Q = (
-             (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0]
-         )  # Covariance of measurement matrix
+         Q = (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0]  # Covariance of measurement matrix

          self.A_state_transition_matrix = A
          self.W_process_noise_covariance = W * self.process_noise_scale
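The reformatted estimators are ordinary least-squares fits of a linear dynamical system. A toy, runnable check under assumed shapes (X holds kinematic states per sample; Z is synthetic, noiseless neural data, so the observation fit recovers the mixing matrix exactly):

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((100, 4))         # states, shape (n_samples, n_states)
    M = rng.standard_normal((16, 4))          # hidden mixing matrix
    Z = X @ M.T                               # observations, shape (n_samples, n_units)

    X2, X1 = X[1:], X[:-1]
    A = X2.T @ X1 @ np.linalg.inv(X1.T @ X1)  # least-squares fit of x_{t+1} ~ A x_t (toy data has no real dynamics)
    H = Z.T @ X @ np.linalg.inv(X.T @ X)      # least-squares fit of z_t ~ H x_t
    assert np.allclose(H, M)                  # exact because Z was built without noise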
@@ -179,15 +170,11 @@
          if intention_velocity_indices is None:
              # Assume (x, y, vx, vy)
              vel_idx = 2 if Y_state.shape[1] >= 4 else 0
-             print(
-                 f"[RefitKalmanFilter] No velocity index provided — defaulting to {vel_idx}"
-             )
+             print(f"[RefitKalmanFilter] No velocity index provided — defaulting to {vel_idx}")
          else:
              if isinstance(intention_velocity_indices, (list, tuple)):
                  if len(intention_velocity_indices) != 1:
-                     raise ValueError(
-                         "Only one velocity start index should be provided."
-                     )
+                     raise ValueError("Only one velocity start index should be provided.")
                  vel_idx = intention_velocity_indices[0]
              else:
                  vel_idx = intention_velocity_indices
@@ -198,18 +185,14 @@
          else:
              intended_states = Y_state.copy()
          # Calculate intended velocities for each sample
-         for i, (state, pos, target) in enumerate(
-             zip(Y_state, cursor_positions, target_positions)
-         ):
+         for i, (state, pos, target) in enumerate(zip(Y_state, cursor_positions, target_positions)):
              is_hold = hold_indices[i] if hold_indices is not None else False

              if is_hold:
                  # During hold periods, intended velocity is zero
                  intended_states[i, vel_idx : vel_idx + 2] = 0.0
                  if i > 0:
                      intended_states[i, :2] = intended_states[i - 1, :2]  # Same position as previous
              else:
                  # Calculate direction to target
                  to_target = target - pos
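For context, this loop implements ReFIT-style intention estimation: during holds the intended velocity is zeroed, otherwise it is pointed at the target. The actual velocity computation falls between the shown hunks; the usual ReFIT rule, offered here only as a hedged illustration, keeps the current speed and rotates it toward the target:

    # Hypothetical reconstruction of the elided step, NOT taken from this file:
    speed = np.linalg.norm(state[vel_idx : vel_idx + 2])
    direction = to_target / np.linalg.norm(to_target)
    intended_velocity = speed * direction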
@@ -228,9 +211,7 @@
                      intended_states[i, vel_idx : vel_idx + 2] = intended_velocity
                  # If target is very close, keep original velocity
                  else:
-                     intended_states[i, vel_idx : vel_idx + 2] = state[
-                         vel_idx : vel_idx + 2
-                     ]
+                     intended_states[i, vel_idx : vel_idx + 2] = state[vel_idx : vel_idx + 2]

          intended_states = np.array(intended_states)
          Z = np.array(X_neural)
@@ -258,7 +239,8 @@
          Raises:
              LinAlgError: If the Riccati equation cannot be solved or matrix operations fail.
          """
-         ## TODO: consider removing non-steady-state for compute_gain() - non_steady_state updates will occur during predict() and update()
+         # TODO: consider removing non-steady-state for compute_gain() -
+         # non_steady_state updates will occur during predict() and update()
          # if self.steady_state:
          try:
              # Try with original matrices
@@ -272,9 +254,7 @@
                  self.P_state_covariance
                  @ self.H_observation_matrix.T
                  @ np.linalg.inv(
-                     self.H_observation_matrix
-                     @ self.P_state_covariance
-                     @ self.H_observation_matrix.T
+                     self.H_observation_matrix @ self.P_state_covariance @ self.H_observation_matrix.T
                      + self.Q_measurement_noise_covariance
                  )
              )
@@ -284,9 +264,7 @@
              # W_reg = self.W_process_noise_covariance + 1e-7 * np.eye(
              #     self.W_process_noise_covariance.shape[0]
              # )
-             Q_reg = self.Q_measurement_noise_covariance + 1e-7 * np.eye(
-                 self.Q_measurement_noise_covariance.shape[0]
-             )
+             Q_reg = self.Q_measurement_noise_covariance + 1e-7 * np.eye(self.Q_measurement_noise_covariance.shape[0])

              try:
                  self.P_state_covariance = solve_discrete_are(
@@ -299,19 +277,14 @@
                      self.P_state_covariance
                      @ self.H_observation_matrix.T
                      @ np.linalg.inv(
-                         self.H_observation_matrix
-                         @ self.P_state_covariance
-                         @ self.H_observation_matrix.T
-                         + Q_reg
+                         self.H_observation_matrix @ self.P_state_covariance @ self.H_observation_matrix.T + Q_reg
                      )
                  )
                  print("Warning: Used regularized matrices for DARE solution")
              except LinAlgError:
                  # Fallback to identity or manual initialization
                  print("Warning: DARE failed, using identity covariance")
-                 self.P_state_covariance = np.eye(
-                     self.A_state_transition_matrix.shape[0]
-                 )
+                 self.P_state_covariance = np.eye(self.A_state_transition_matrix.shape[0])

          # else:
          #     n_states = self.A_state_transition_matrix.shape[0]
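These hunks compute the steady-state Kalman gain from a discrete algebraic Riccati equation (DARE) solution. A self-contained sketch of that computation with SciPy; the a=A.T, b=H.T argument mapping is the standard filter/control duality per SciPy's documentation, not something visible in this diff:

    import numpy as np
    from scipy.linalg import solve_discrete_are

    A = 0.95 * np.eye(4)                                   # toy state transition
    H = np.random.default_rng(1).standard_normal((16, 4))  # toy observation matrix
    W = 0.01 * np.eye(4)                                   # process noise covariance
    Q = 0.10 * np.eye(16)                                  # measurement noise covariance

    P = solve_discrete_are(A.T, H.T, W, Q)                 # steady-state predicted covariance
    K = P @ H.T @ np.linalg.inv(H @ P @ H.T + Q)           # gain, matching the expression above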
@@ -349,9 +322,7 @@
              return x_predicted, None
          else:
              P_predicted = self.alpha_fading_memory**2 * (
-                 self.A_state_transition_matrix
-                 @ self.P_state_covariance
-                 @ self.A_state_transition_matrix.T
+                 self.A_state_transition_matrix @ self.P_state_covariance @ self.A_state_transition_matrix.T
                  + self.W_process_noise_covariance
              )
              return x_predicted, P_predicted
@@ -376,10 +347,7 @@

          # Non-steady-state mode
          # System uncertainty
-         S = (
-             self.H_observation_matrix @ P_predicted @ self.H_observation_matrix.T
-             + self.Q_measurement_noise_covariance
-         )
+         S = self.H_observation_matrix @ P_predicted @ self.H_observation_matrix.T + self.Q_measurement_noise_covariance

          # Kalman gain
          K = P_predicted @ self.H_observation_matrix.T @ np.linalg.pinv(S)
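Assembled from the pieces these hunks touch, one predict/update cycle takes the textbook form below. Variable roles follow this file; the state-update line itself sits outside the shown hunks and is the standard formula, not a quotation:

    import numpy as np

    def kf_step(x, P, z, A, W, H, Q, alpha=1.0):
        """One fading-memory Kalman predict/update cycle (sketch)."""
        x_pred = A @ x
        P_pred = alpha**2 * (A @ P @ A.T + W)    # fading-memory covariance predict
        S = H @ P_pred @ H.T + Q                 # innovation (system) uncertainty
        K = P_pred @ H.T @ np.linalg.pinv(S)     # Kalman gain, as above
        x_new = x_pred + K @ (z - H @ x_pred)    # fold in measurement z
        P_new = (np.eye(len(x_new)) - K @ H) @ P_pred
        return x_new, P_new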
@@ -1,6 +1,10 @@
  from ..model.mlp_old import MLP as MLP
+ from ..process.mlp_old import (
+     MLPProcessor as MLPProcessor,
+ )
  from ..process.mlp_old import (
      MLPSettings as MLPSettings,
+ )
+ from ..process.mlp_old import (
      MLPState as MLPState,
-     MLPProcessor as MLPProcessor,
  )
@@ -1,16 +1,15 @@
  from dataclasses import field

+ import ezmsg.core as ez
  import numpy as np
  import pandas as pd
- import river.optim
  import river.linear_model
+ import river.optim
  import sklearn.base
- import ezmsg.core as ez
- from ezmsg.sigproc.sampler import SampleMessage
- from ezmsg.sigproc.base import (
-     processor_state,
+ from ezmsg.baseproc import (
      BaseAdaptiveTransformer,
      BaseAdaptiveTransformerUnit,
+     processor_state,
  )
  from ezmsg.util.messages.axisarray import AxisArray, replace

@@ -39,9 +38,7 @@ class AdaptiveLinearRegressorTransformer(
  ):
      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
-         self.settings = replace(
-             self.settings, model_type=AdaptiveLinearRegressor(self.settings.model_type)
-         )
+         self.settings = replace(self.settings, model_type=AdaptiveLinearRegressor(self.settings.model_type))
          b_river = self.settings.model_type in [
              AdaptiveLinearRegressor.LINEAR,
              AdaptiveLinearRegressor.LOGISTIC,
@@ -49,9 +46,7 @@ class AdaptiveLinearRegressorTransformer(
          if b_river:
              self.settings.model_kwargs["l2"] = self.settings.model_kwargs.get("l2", 0.0)
              if "learn_rate" in self.settings.model_kwargs:
-                 self.settings.model_kwargs["optimizer"] = river.optim.SGD(
-                     self.settings.model_kwargs.pop("learn_rate")
-                 )
+                 self.settings.model_kwargs["optimizer"] = river.optim.SGD(self.settings.model_kwargs.pop("learn_rate"))

          if self.settings.settings_path is not None:
              # Load model from file
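For reference, the learn_rate handling above amounts to this transformation of model_kwargs (values hypothetical):

    import river.optim

    model_kwargs = {"learn_rate": 0.01}
    model_kwargs["l2"] = model_kwargs.get("l2", 0.0)
    model_kwargs["optimizer"] = river.optim.SGD(model_kwargs.pop("learn_rate"))
    # -> {"l2": 0.0, "optimizer": river.optim.SGD(lr=0.01)}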
@@ -69,9 +64,7 @@ class AdaptiveLinearRegressorTransformer(
                  print("TODO: Override sklearn model with kwargs")
          else:
              # Build model from scratch.
-             regressor_klass = get_regressor(
-                 RegressorType.ADAPTIVE, self.settings.model_type
-             )
+             regressor_klass = get_regressor(RegressorType.ADAPTIVE, self.settings.model_type)
              self.state.model = regressor_klass(**self.settings.model_kwargs)

      def _hash_message(self, message: AxisArray) -> int:
@@ -84,37 +77,30 @@ class AdaptiveLinearRegressorTransformer(
          # .template is updated in partial_fit
          pass

-     def partial_fit(self, message: SampleMessage) -> None:
-         if np.any(np.isnan(message.sample.data)):
+     def partial_fit(self, message: AxisArray) -> None:
+         if np.any(np.isnan(message.data)):
              return

          if self.settings.model_type in [
              AdaptiveLinearRegressor.LINEAR,
              AdaptiveLinearRegressor.LOGISTIC,
          ]:
-             x = pd.DataFrame.from_dict(
-                 {
-                     k: v
-                     for k, v in zip(
-                         message.sample.axes["ch"].data, message.sample.data.T
-                     )
-                 }
-             )
+             x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
              y = pd.Series(
-                 data=message.trigger.value.data[:, 0],
-                 name=message.trigger.value.axes["ch"].data[0],
+                 data=message.attrs["trigger"].value.data[:, 0],
+                 name=message.attrs["trigger"].value.axes["ch"].data[0],
              )
              self.state.model.learn_many(x, y)
          else:
-             X = message.sample.data
-             if message.sample.get_axis_idx("time") != 0:
-                 X = np.moveaxis(X, message.sample.get_axis_idx("time"), 0)
-             self.state.model.partial_fit(X, message.trigger.value.data)
+             X = message.data
+             if message.get_axis_idx("time") != 0:
+                 X = np.moveaxis(X, message.get_axis_idx("time"), 0)
+             self.state.model.partial_fit(X, message.attrs["trigger"].value.data)

          self.state.template = replace(
-             message.trigger.value,
-             data=np.empty_like(message.trigger.value.data),
-             key=message.trigger.value.key + "_pred",
+             message.attrs["trigger"].value,
+             data=np.empty_like(message.attrs["trigger"].value.data),
+             key=message.attrs["trigger"].value.key + "_pred",
          )

      def _process(self, message: AxisArray) -> AxisArray | None:
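This is the most consequential change in the file: partial_fit no longer takes a SampleMessage wrapper but a plain AxisArray whose training target travels in attrs["trigger"]. A hedged summary of the field mapping, assuming only the accesses visible in this hunk:

    # 1.0:    features at message.sample.data,  labels at message.trigger.value
    # 1.2.0:  features at message.data,         labels at message.attrs["trigger"].value
    transformer.partial_fit(msg)   # msg: AxisArray carrying its trigger in msg.attrs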
@@ -127,9 +113,7 @@ class AdaptiveLinearRegressorTransformer(
              AdaptiveLinearRegressor.LOGISTIC,
          ]:
              # convert msg_in.data to something appropriate for river
-             x = pd.DataFrame.from_dict(
-                 {k: v for k, v in zip(message.axes["ch"].data, message.data.T)}
-             )
+             x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
              preds = self.state.model.predict_many(x).values
          else:
              preds = self.state.model.predict(message.data)
@@ -1,7 +1,7 @@
  import inspect
  import json
- from pathlib import Path
  import typing
+ from pathlib import Path

  import ezmsg.core as ez
  import torch
@@ -32,9 +32,7 @@ class ModelInitMixin:
          for key, value in config.items():
              if key in model_kwargs:
                  if model_kwargs[key] != value:
-                     ez.logger.warning(
-                         f"Config parameter {key} ({value}) differs from settings ({model_kwargs[key]})."
-                     )
+                     ez.logger.warning(f"Config parameter {key} ({value}) differs from settings ({model_kwargs[key]}).")
              else:
                  ez.logger.warning(f"Config parameter {key} is not in model_kwargs.")
                  model_kwargs[key] = value
@@ -44,7 +42,8 @@ class ModelInitMixin:
          filtered_out = set(kwargs.keys()) - {k for k in valid_params if k != "self"}
          if filtered_out:
              ez.logger.warning(
-                 f"Ignoring unexpected model parameters not accepted by {model_class.__name__} constructor: {sorted(filtered_out)}"
+                 "Ignoring unexpected model parameters not accepted by"
+                 f"{model_class.__name__} constructor: {sorted(filtered_out)}"
              )
          # Keep all valid parameters, including None values, so checkpoint-inferred values can overwrite them
          return {k: v for k, v in kwargs.items() if k in valid_params and k != "self"}
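Note that the new two-literal form concatenates without a separating space, so the rendered message reads "...not accepted by<ClassName> constructor..." with the class name fused to "by". A corrected split, offered only as a suggestion:

    ez.logger.warning(
        "Ignoring unexpected model parameters not accepted by "
        f"{model_class.__name__} constructor: {sorted(filtered_out)}"
    )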
@@ -92,22 +91,16 @@ class ModelInitMixin:
                      config = json.load(f)
                  self._merge_config(model_kwargs, config)
              except Exception as e:
-                 raise RuntimeError(
-                     f"Failed to load config from {config_path}: {str(e)}"
-                 )
+                 raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")

          # If a checkpoint file is provided, load it.
          if checkpoint_path:
              checkpoint_path = Path(checkpoint_path)
              if not checkpoint_path.exists():
                  ez.logger.error(f"Checkpoint path {checkpoint_path} does not exist.")
-                 raise FileNotFoundError(
-                     f"Checkpoint path {checkpoint_path} does not exist."
-                 )
+                 raise FileNotFoundError(f"Checkpoint path {checkpoint_path} does not exist.")
              try:
-                 checkpoint = torch.load(
-                     checkpoint_path, map_location=device, weights_only=weights_only
-                 )
+                 checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)

                  if "config" in checkpoint:
                      config = checkpoint["config"]
@@ -126,20 +119,14 @@ class ModelInitMixin:
                      "infer_config_from_state_dict",
                      lambda _state_dict: {},  # Default to empty dict if not defined
                  )
-                 infer_kwargs = (
-                     {"rnn_type": model_kwargs["rnn_type"]}
-                     if "rnn_type" in model_kwargs
-                     else {}
-                 )
+                 infer_kwargs = {"rnn_type": model_kwargs["rnn_type"]} if "rnn_type" in model_kwargs else {}
                  self._merge_config(
                      model_kwargs,
                      infer_config(state_dict, **infer_kwargs),
                  )

              except Exception as e:
-                 raise RuntimeError(
-                     f"Failed to load checkpoint from {checkpoint_path}: {str(e)}"
-                 )
+                 raise RuntimeError(f"Failed to load checkpoint from {checkpoint_path}: {str(e)}")

          # Filter model_kwargs to only include valid parameters for the model class
          filtered_kwargs = self._filter_model_kwargs(model_class, model_kwargs)
@@ -156,18 +143,12 @@ class ModelInitMixin:
          if state_dict_prefix:
              # If a prefix is provided, filter the state_dict keys
              state_dict = {
-                 k[len(state_dict_prefix) :]: v
-                 for k, v in state_dict.items()
-                 if k.startswith(state_dict_prefix)
+                 k[len(state_dict_prefix) :]: v for k, v in state_dict.items() if k.startswith(state_dict_prefix)
              }
          # Load the model weights
-         missing, unexpected = model.load_state_dict(
-             state_dict, strict=False, assign=True
-         )
+         missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)
          if missing or unexpected:
-             ez.logger.warning(
-                 f"Partial load: missing keys: {missing}, unexpected keys: {unexpected}"
-             )
+             ez.logger.warning(f"Partial load: missing keys: {missing}, unexpected keys: {unexpected}")

          model.to(device)
          return model
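The prefix filter above is easiest to see in isolation; a runnable toy (prefix and keys hypothetical, mirroring checkpoints saved from a wrapper module):

    state_dict = {"model.fc.weight": 1, "model.fc.bias": 2, "optimizer.step": 3}
    prefix = "model."
    state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    assert state_dict == {"fc.weight": 1, "fc.bias": 2}   # non-matching keys are dropped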