ezmsg-learn 1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,16 @@
1
1
  from dataclasses import field
2
2
 
3
- import numpy as np
4
- from sklearn.linear_model._base import LinearModel
5
3
  import ezmsg.core as ez
6
- from ezmsg.sigproc.base import (
7
- processor_state,
4
+ import numpy as np
5
+ from ezmsg.baseproc import (
8
6
  BaseAdaptiveTransformer,
9
7
  BaseAdaptiveTransformerUnit,
8
+ processor_state,
10
9
  )
11
10
  from ezmsg.util.messages.axisarray import AxisArray, replace
12
- from ezmsg.sigproc.sampler import SampleMessage
11
+ from sklearn.linear_model._base import LinearModel
13
12
 
14
- from ..util import get_regressor, StaticLinearRegressor, RegressorType
13
+ from ..util import RegressorType, StaticLinearRegressor, get_regressor
15
14
 
16
15
 
17
16
  class LinearRegressorSettings(ez.Settings):
@@ -27,9 +26,7 @@ class LinearRegressorState:
27
26
 
28
27
 
29
28
  class LinearRegressorTransformer(
30
- BaseAdaptiveTransformer[
31
- LinearRegressorSettings, AxisArray, AxisArray, LinearRegressorState
32
- ]
29
+ BaseAdaptiveTransformer[LinearRegressorSettings, AxisArray, AxisArray, LinearRegressorState]
33
30
  ):
34
31
  """
35
32
  Linear regressor.
@@ -47,9 +44,7 @@ class LinearRegressorTransformer(
47
44
  with open(self.settings.settings_path, "rb") as f:
48
45
  self.state.model = pickle.load(f)
49
46
  else:
50
- regressor_klass = get_regressor(
51
- RegressorType.STATIC, self.settings.model_type
52
- )
47
+ regressor_klass = get_regressor(RegressorType.STATIC, self.settings.model_type)
53
48
  self.state.model = regressor_klass(**self.settings.model_kwargs)
54
49
 
55
50
  def _reset_state(self, message: AxisArray) -> None:
@@ -57,18 +52,18 @@ class LinearRegressorTransformer(
57
52
  # .model and .template are initialized in __init__
58
53
  pass
59
54
 
60
- def partial_fit(self, message: SampleMessage) -> None:
61
- if np.any(np.isnan(message.sample.data)):
55
+ def partial_fit(self, message: AxisArray) -> None:
56
+ if np.any(np.isnan(message.data)):
62
57
  return
63
58
 
64
- X = message.sample.data
65
- y = message.trigger.value.data
59
+ X = message.data
60
+ y = message.attrs["trigger"].value.data
66
61
  # TODO: Resample should provide identical durations.
67
62
  self.state.model = self.state.model.fit(X[: y.shape[0]], y[: X.shape[0]])
68
63
  self.state.template = replace(
69
- message.trigger.value,
64
+ message.attrs["trigger"].value,
70
65
  data=np.array([[]]),
71
- key=message.trigger.value.key + "_pred",
66
+ key=message.attrs["trigger"].value.key + "_pred",
72
67
  )
73
68
 
74
69
  def _process(self, message: AxisArray) -> AxisArray:
@@ -1,17 +1,16 @@
1
1
  import typing
2
2
 
3
+ import ezmsg.core as ez
3
4
  import numpy as np
4
5
  import torch
5
6
  import torch.nn
6
- import ezmsg.core as ez
7
- from ezmsg.util.messages.axisarray import AxisArray
8
- from ezmsg.sigproc.sampler import SampleMessage
9
- from ezmsg.util.messages.util import replace
10
- from ezmsg.sigproc.base import (
7
+ from ezmsg.baseproc import (
11
8
  BaseAdaptiveTransformer,
12
9
  BaseAdaptiveTransformerUnit,
13
10
  processor_state,
14
11
  )
12
+ from ezmsg.util.messages.axisarray import AxisArray
13
+ from ezmsg.util.messages.util import replace
15
14
 
16
15
  from ..model.mlp_old import MLP
17
16
 
@@ -24,10 +23,12 @@ class MLPSettings(ez.Settings):
24
23
  """Norm layer that will be stacked on top of the linear layer. If None this layer won’t be used."""
25
24
 
26
25
  activation_layer: typing.Callable[..., torch.nn.Module] | None = torch.nn.ReLU
27
- """Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None this layer won’t be used."""
26
+ """Activation function which will be stacked on top of the normalization layer (if not None),
27
+ otherwise on top of the linear layer. If None this layer won’t be used."""
28
28
 
29
29
  inplace: bool | None = None
30
- """Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer."""
30
+ """Parameter for the activation layer, which can optionally do the operation in-place.
31
+ Default is None, which uses the respective default values of the activation_layer and Dropout layer."""
31
32
 
32
33
  bias: bool = True
33
34
  """Whether to use bias in the linear layer."""
@@ -58,9 +59,7 @@ class MLPState:
58
59
  device: object | None = None
59
60
 
60
61
 
61
- class MLPProcessor(
62
- BaseAdaptiveTransformer[MLPSettings, AxisArray, AxisArray, MLPState]
63
- ):
62
+ class MLPProcessor(BaseAdaptiveTransformer[MLPSettings, AxisArray, AxisArray, MLPState]):
64
63
  def _hash_message(self, message: AxisArray) -> int:
65
64
  hash_items = (message.key,)
66
65
  if "ch" in message.dims:
@@ -85,39 +84,29 @@ class MLPProcessor(
85
84
  checkpoint = torch.load(self.settings.checkpoint_path)
86
85
  self._state.model.load_state_dict(checkpoint["model_state_dict"])
87
86
  except Exception as e:
88
- raise RuntimeError(
89
- f"Failed to load checkpoint from {self.settings.checkpoint_path}: {str(e)}"
90
- )
87
+ raise RuntimeError(f"Failed to load checkpoint from {self.settings.checkpoint_path}: {str(e)}")
91
88
 
92
89
  # Set the model to evaluation mode by default
93
90
  self._state.model.eval()
94
91
 
95
92
  # Create the optimizer
96
- self._state.optimizer = torch.optim.Adam(
97
- self._state.model.parameters(), lr=self.settings.learning_rate
98
- )
93
+ self._state.optimizer = torch.optim.Adam(self._state.model.parameters(), lr=self.settings.learning_rate)
99
94
 
100
95
  # Update the optimizer from checkpoint if it exists
101
96
  if self.settings.checkpoint_path is not None:
102
97
  try:
103
98
  checkpoint = torch.load(self.settings.checkpoint_path)
104
99
  if "optimizer_state_dict" in checkpoint:
105
- self._state.optimizer.load_state_dict(
106
- checkpoint["optimizer_state_dict"]
107
- )
100
+ self._state.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
108
101
  except Exception as e:
109
- raise RuntimeError(
110
- f"Failed to load optimizer from {self.settings.checkpoint_path}: {str(e)}"
111
- )
102
+ raise RuntimeError(f"Failed to load optimizer from {self.settings.checkpoint_path}: {str(e)}")
112
103
 
113
104
  # TODO: Should the model be moved to a device before the next line?
114
105
  self._state.device = next(self.state.model.parameters()).device
115
106
 
116
107
  # Optionally create the learning rate scheduler
117
108
  self._state.scheduler = (
118
- torch.optim.lr_scheduler.ExponentialLR(
119
- self._state.optimizer, gamma=self.settings.scheduler_gamma
120
- )
109
+ torch.optim.lr_scheduler.ExponentialLR(self._state.optimizer, gamma=self.settings.scheduler_gamma)
121
110
  if self.settings.scheduler_gamma > 0.0
122
111
  else None
123
112
  )
@@ -144,14 +133,14 @@ class MLPProcessor(
144
133
  dtype = torch.float32 if self.settings.single_precision else torch.float64
145
134
  return torch.tensor(data, dtype=dtype, device=self._state.device)
146
135
 
147
- def partial_fit(self, message: SampleMessage) -> None:
136
+ def partial_fit(self, message: AxisArray) -> None:
148
137
  self._state.model.train()
149
138
 
150
139
  # TODO: loss_fn should be determined by setting
151
140
  loss_fn = torch.nn.functional.mse_loss
152
141
 
153
- X = self._to_tensor(message.sample.data)
154
- y_targ = self._to_tensor(message.trigger.value)
142
+ X = self._to_tensor(message.data)
143
+ y_targ = self._to_tensor(message.attrs["trigger"].value)
155
144
 
156
145
  with torch.set_grad_enabled(True):
157
146
  self._state.model.train()
@@ -171,9 +160,7 @@ class MLPProcessor(
171
160
  if not isinstance(data, torch.Tensor):
172
161
  data = torch.tensor(
173
162
  data,
174
- dtype=torch.float32
175
- if self.settings.single_precision
176
- else torch.float64,
163
+ dtype=torch.float32 if self.settings.single_precision else torch.float64,
177
164
  )
178
165
 
179
166
  with torch.no_grad():
@@ -3,12 +3,11 @@ from pathlib import Path
3
3
 
4
4
  import ezmsg.core as ez
5
5
  import numpy as np
6
- from ezmsg.sigproc.base import (
6
+ from ezmsg.baseproc import (
7
7
  BaseAdaptiveTransformer,
8
8
  BaseAdaptiveTransformerUnit,
9
9
  processor_state,
10
10
  )
11
- from ezmsg.sigproc.sampler import SampleMessage
12
11
  from ezmsg.util.messages.axisarray import AxisArray
13
12
  from ezmsg.util.messages.util import replace
14
13
 
@@ -284,27 +283,25 @@ class RefitKalmanFilterProcessor(
284
283
  key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered",
285
284
  )
286
285
 
287
- def partial_fit(self, message: SampleMessage) -> None:
286
+ def partial_fit(self, message: AxisArray) -> None:
288
287
  """
289
288
  Perform refitting using externally provided data.
290
289
 
291
- Expects message.sample.data (neural input) and message.trigger.value as a dict with:
290
+ Expects message.data (neural input) and message.attrs["trigger"].value as a dict with:
292
291
  - Y_state: (n_samples, n_states) array
293
292
  - intention_velocity_indices: Optional[int]
294
293
  - target_positions: Optional[np.ndarray]
295
294
  - cursor_positions: Optional[np.ndarray]
296
295
  - hold_flags: Optional[list[bool]]
297
296
  """
298
- if not hasattr(message, "sample") or not hasattr(message, "trigger"):
297
+ if "trigger" not in message.attrs:
299
298
  raise ValueError("Invalid message format for partial_fit.")
300
299
 
301
- X = np.array(message.sample.data)
302
- values = message.trigger.value
300
+ X = np.array(message.data)
301
+ values = message.attrs["trigger"].value
303
302
 
304
303
  if not isinstance(values, dict) or "Y_state" not in values:
305
- raise ValueError(
306
- "partial_fit expects trigger.value to include at least 'Y_state'."
307
- )
304
+ raise ValueError("partial_fit expects trigger.value to include at least 'Y_state'.")
308
305
 
309
306
  kwargs = {
310
307
  "X_neural": X,
@@ -319,9 +316,7 @@ class RefitKalmanFilterProcessor(
319
316
  "hold_flags",
320
317
  ]:
321
318
  if key in values and values[key] is not None:
322
- kwargs[key if key != "hold_flags" else "hold_indices"] = np.array(
323
- values[key]
324
- )
319
+ kwargs[key if key != "hold_flags" else "hold_indices"] = np.array(values[key])
325
320
 
326
321
  # Call model refit
327
322
  self._state.model.refit(**kwargs)
@@ -3,9 +3,8 @@ import typing
3
3
  import ezmsg.core as ez
4
4
  import numpy as np
5
5
  import torch
6
- from ezmsg.sigproc.base import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
7
- from ezmsg.sigproc.sampler import SampleMessage
8
- from ezmsg.sigproc.util.profile import profile_subpub
6
+ from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
7
+ from ezmsg.baseproc.util.profile import profile_subpub
9
8
  from ezmsg.util.messages.axisarray import AxisArray
10
9
  from ezmsg.util.messages.util import replace
11
10
 
@@ -47,9 +46,7 @@ class RNNProcessor(
47
46
  TorchProcessorMixin,
48
47
  ModelInitMixin,
49
48
  ):
50
- def _infer_output_sizes(
51
- self, model: torch.nn.Module, n_input: int
52
- ) -> dict[str, int]:
49
+ def _infer_output_sizes(self, model: torch.nn.Module, n_input: int) -> dict[str, int]:
53
50
  """Simple inference to get output channel size."""
54
51
  dummy_input = torch.zeros(1, 50, n_input, device=self._state.device)
55
52
  with torch.no_grad():
@@ -78,9 +75,7 @@ class RNNProcessor(
78
75
  preserve_state = True
79
76
  elif "time" not in axes or "win" not in axes:
80
77
  # Default fallback
81
- ez.logger.warning(
82
- "Missing 'time' or 'win' axis for auto preserve-state logic. Defaulting to reset."
83
- )
78
+ ez.logger.warning("Missing 'time' or 'win' axis for auto preserve-state logic. Defaulting to reset.")
84
79
  preserve_state = False
85
80
  else:
86
81
  # Calculate stride between windows (assuming uniform spacing)
@@ -89,9 +84,7 @@ class RNNProcessor(
89
84
  time_len = message.data.shape[message.get_axis_idx("time")]
90
85
  gain = getattr(axes["time"], "gain", None)
91
86
  if gain is None:
92
- ez.logger.warning(
93
- "Time axis gain not found, using default gain of 1.0."
94
- )
87
+ ez.logger.warning("Time axis gain not found, using default gain of 1.0.")
95
88
  gain = 1.0 # fallback default
96
89
  win_len = time_len * gain
97
90
  # Determine if we should preserve state
@@ -102,15 +95,9 @@ class RNNProcessor(
102
95
  self.reset_hidden(batch_size)
103
96
  else:
104
97
  # If preserving state, only reset if batch size isn't 1
105
- hx_batch_size = (
106
- self._state.hx[0].shape[1]
107
- if isinstance(self._state.hx, tuple)
108
- else self._state.hx.shape[1]
109
- )
98
+ hx_batch_size = self._state.hx[0].shape[1] if isinstance(self._state.hx, tuple) else self._state.hx.shape[1]
110
99
  if hx_batch_size != 1:
111
- ez.logger.debug(
112
- f"Resetting hidden state due to batch size mismatch (hx: {hx_batch_size}, new: 1)"
113
- )
100
+ ez.logger.debug(f"Resetting hidden state due to batch size mismatch (hx: {hx_batch_size}, new: 1)")
114
101
  self.reset_hidden(1)
115
102
  return preserve_state
116
103
 
@@ -119,9 +106,7 @@ class RNNProcessor(
119
106
  if not isinstance(x, torch.Tensor):
120
107
  x = torch.tensor(
121
108
  x,
122
- dtype=torch.float32
123
- if self.settings.single_precision
124
- else torch.float64,
109
+ dtype=torch.float32 if self.settings.single_precision else torch.float64,
125
110
  device=self._state.device,
126
111
  )
127
112
 
@@ -143,18 +128,11 @@ class RNNProcessor(
143
128
  y_data[key] = []
144
129
  y_data[key].append(out.cpu().numpy())
145
130
  # Concatenate outputs for each key
146
- y_data = {
147
- key: np.concatenate(outputs, axis=0)
148
- for key, outputs in y_data.items()
149
- }
131
+ y_data = {key: np.concatenate(outputs, axis=0) for key, outputs in y_data.items()}
150
132
  else:
151
133
  y, self._state.hx = self._state.model(x, hx=self._state.hx)
152
134
  y_data = {
153
- key: (
154
- out.cpu().numpy().squeeze(0)
155
- if added_batch_dim
156
- else out.cpu().numpy()
157
- )
135
+ key: (out.cpu().numpy().squeeze(0) if added_batch_dim else out.cpu().numpy())
158
136
  for key, out in y.items()
159
137
  }
160
138
 
@@ -205,18 +183,18 @@ class RNNProcessor(
205
183
  if self._state.scheduler is not None:
206
184
  self._state.scheduler.step()
207
185
 
208
- def partial_fit(self, message: SampleMessage) -> None:
186
+ def partial_fit(self, message: AxisArray) -> None:
209
187
  self._state.model.train()
210
188
 
211
- X = self._to_tensor(message.sample.data)
189
+ X = self._to_tensor(message.data)
212
190
 
213
191
  # Add batch dimension if missing
214
192
  X, batched = self._ensure_batched(X)
215
193
 
216
194
  batch_size = X.shape[0]
217
- preserve_state = self._maybe_reset_state(message.sample, batch_size)
195
+ preserve_state = self._maybe_reset_state(message, batch_size)
218
196
 
219
- y_targ = message.trigger.value
197
+ y_targ = message.attrs["trigger"].value
220
198
  if not isinstance(y_targ, dict):
221
199
  y_targ = {"output": y_targ}
222
200
  y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
@@ -2,9 +2,11 @@ import typing
2
2
 
3
3
  import ezmsg.core as ez
4
4
  import numpy as np
5
- from ezmsg.sigproc.sampler import SampleMessage
6
- from ezmsg.sigproc.base import GenAxisArray
7
- from ezmsg.util.generator import consumer
5
+ from ezmsg.baseproc import (
6
+ BaseAdaptiveTransformer,
7
+ BaseAdaptiveTransformerUnit,
8
+ processor_state,
9
+ )
8
10
  from ezmsg.util.messages.axisarray import AxisArray
9
11
  from ezmsg.util.messages.util import replace
10
12
  from sklearn.exceptions import NotFittedError
@@ -13,103 +15,6 @@ from sklearn.linear_model import SGDClassifier
13
15
  from ..util import ClassifierMessage
14
16
 
15
17
 
16
- @consumer
17
- def sgd_decoder(
18
- alpha: float = 1.5e-5,
19
- eta0: float = 1e-7, # Lower than what you'd use for offline training.
20
- loss: str = "squared_hinge",
21
- label_weights: dict[str, float] | None = None,
22
- settings_path: str | None = None,
23
- ) -> typing.Generator[AxisArray | SampleMessage, ClassifierMessage | None, None]:
24
- """
25
- Passive Aggressive Classifier
26
- Online Passive-Aggressive Algorithms <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
27
- K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
28
-
29
- Args:
30
- alpha: Maximum step size (regularization)
31
- eta0: The initial learning rate for the 'adaptive’ schedules.
32
- loss: The loss function to be used:
33
- hinge: equivalent to PA-I in the reference paper.
34
- squared_hinge: equivalent to PA-II in the reference paper.
35
- label_weights: An optional dictionary of label names and their relative weight.
36
- e.g., {'Go': 31.0, 'Stop': 0.5}
37
- If this is None then settings_path must be provided and the pre-trained model
38
- settings_path: Path to the stored sklearn model pkl file.
39
-
40
- Returns:
41
- Generator that accepts `SampleMessage` for incremental training (`partial_fit`) and yields None,
42
- or `AxisArray` for inference (`predict`) and yields a `ClassifierMessage`.
43
- """
44
- # pre-init inputs and outputs
45
- msg_out = ClassifierMessage(data=np.array([]), dims=[""])
46
-
47
- # State variables:
48
-
49
- if settings_path is not None:
50
- import pickle
51
-
52
- with open(settings_path, "rb") as f:
53
- model = pickle.load(f)
54
- if label_weights is not None:
55
- model.class_weight = label_weights
56
- # Overwrite eta0, probably with a value lower than what was used online.
57
- model.eta0 = eta0
58
- else:
59
- model = SGDClassifier(
60
- loss=loss,
61
- alpha=alpha,
62
- penalty="elasticnet",
63
- learning_rate="adaptive",
64
- eta0=eta0,
65
- early_stopping=False,
66
- class_weight=label_weights,
67
- )
68
-
69
- b_first_train = True
70
- # TODO: template_out
71
-
72
- while True:
73
- msg_in: AxisArray | SampleMessage = yield msg_out
74
-
75
- msg_out = None
76
- if type(msg_in) is SampleMessage:
77
- # SampleMessage used for training.
78
- if not np.any(np.isnan(msg_in.sample.data)):
79
- train_sample = msg_in.sample.data.reshape(1, -1)
80
- if b_first_train:
81
- model.partial_fit(
82
- train_sample,
83
- [msg_in.trigger.value],
84
- classes=list(label_weights.keys()),
85
- )
86
- b_first_train = False
87
- else:
88
- model.partial_fit(train_sample, [msg_in.trigger.value])
89
- elif msg_in.data.size:
90
- # AxisArray used for inference
91
- if not np.any(np.isnan(msg_in.data)):
92
- try:
93
- X = msg_in.data.reshape((msg_in.data.shape[0], -1))
94
- result = model._predict_proba_lr(X)
95
- except NotFittedError:
96
- result = None
97
- if result is not None:
98
- out_axes = {}
99
- if msg_in.dims[0] in msg_in.axes:
100
- out_axes[msg_in.dims[0]] = replace(
101
- msg_in.axes[msg_in.dims[0]],
102
- offset=msg_in.axes[msg_in.dims[0]].offset,
103
- )
104
- msg_out = ClassifierMessage(
105
- data=result,
106
- dims=msg_in.dims[:1] + ["labels"],
107
- axes=out_axes,
108
- labels=list(model.class_weight.keys()),
109
- key=msg_in.key,
110
- )
111
-
112
-
113
18
  class SGDDecoderSettings(ez.Settings):
114
19
  alpha: float = 1e-5
115
20
  eta0: float = 3e-4
@@ -118,14 +23,94 @@ class SGDDecoderSettings(ez.Settings):
118
23
  settings_path: str | None = None
119
24
 
120
25
 
121
- class SGDDecoder(GenAxisArray):
122
- SETTINGS = SGDDecoderSettings
123
- INPUT_SAMPLE = ez.InputStream(SampleMessage)
26
+ @processor_state
27
+ class SGDDecoderState:
28
+ model: typing.Any = None
29
+ b_first_train: bool = True
124
30
 
125
- # Method to be implemented by subclasses to construct the specific generator
126
- def construct_generator(self):
127
- self.STATE.gen = sgd_decoder(**self.SETTINGS.__dict__)
128
31
 
129
- @ez.subscriber(INPUT_SAMPLE)
130
- async def on_sample(self, msg: SampleMessage) -> None:
131
- _ = self.STATE.gen.send(msg)
32
class SGDDecoderTransformer(
    BaseAdaptiveTransformer[SGDDecoderSettings, AxisArray, ClassifierMessage, SGDDecoderState]
):
    """
    SGD-based online classifier.

    Online Passive-Aggressive Algorithms
    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)

    ``partial_fit`` trains the model incrementally on one flattened sample whose
    label is read from ``message.attrs["trigger"].value``; ``_process`` returns
    per-class probabilities for each entry along the first dimension of the
    incoming message, or ``None`` when no prediction can be made.
    """

    def _refreshed_model(self):
        """
        Build the classifier described by the settings.

        Returns:
            The object unpickled from ``settings.settings_path`` when that path
            is set (with ``class_weight`` replaced if ``label_weights`` is given,
            and ``eta0`` always replaced by the configured value); otherwise a
            newly constructed ``SGDClassifier``.
        """
        if self.settings.settings_path is not None:
            import pickle

            # NOTE(security): pickle.load can execute arbitrary code. Only point
            # settings_path at files from a trusted source.
            with open(self.settings.settings_path, "rb") as f:
                model = pickle.load(f)
            if self.settings.label_weights is not None:
                model.class_weight = self.settings.label_weights
            # The configured learning rate takes precedence over the pickled one.
            model.eta0 = self.settings.eta0
        else:
            model = SGDClassifier(
                loss=self.settings.loss,
                alpha=self.settings.alpha,
                penalty="elasticnet",
                learning_rate="adaptive",
                eta0=self.settings.eta0,
                early_stopping=False,
                class_weight=self.settings.label_weights,
            )
        return model

    def _reset_state(self, message: AxisArray) -> None:
        """Replace the model with a freshly built one and re-arm first-call training."""
        self._state.model = self._refreshed_model()
        # The rebuilt model has not been given this processor's class list yet,
        # so the next partial_fit must pass `classes=` again.
        self._state.b_first_train = True

    def _process(self, message: AxisArray) -> ClassifierMessage | None:
        """
        Predict class probabilities for ``message``.

        Returns:
            A ``ClassifierMessage`` with dims ``[message.dims[0], "labels"]``, or
            ``None`` if there is no model, the message is empty, the data
            contains NaN, or the model has not been fitted yet.
        """
        if self._state.model is None or not message.data.size:
            return None
        if np.any(np.isnan(message.data)):
            return None
        try:
            # Keep the first dimension; flatten everything else into features.
            X = message.data.reshape((message.data.shape[0], -1))
            result = self._state.model._predict_proba_lr(X)
        except NotFittedError:
            return None

        first_dim = message.dims[0]
        out_axes = {}
        if first_dim in message.axes:
            out_axes[first_dim] = replace(
                message.axes[first_dim],
                offset=message.axes[first_dim].offset,
            )
        # Probability columns follow the fitted model's `classes_` order, which
        # need not match the insertion order of the class_weight dict (and
        # class_weight may be None for a pre-trained model), so label from
        # `classes_`.
        labels = np.asarray(self._state.model.classes_).tolist()
        return ClassifierMessage(
            data=result,
            dims=message.dims[:1] + ["labels"],
            axes=out_axes,
            labels=labels,
            key=message.key,
        )

    def partial_fit(self, message: AxisArray) -> None:
        """
        Incrementally train on a single labelled sample.

        ``message.data`` is flattened into one feature row and
        ``message.attrs["trigger"].value`` is used as its label. Messages whose
        data contains NaN are ignored.
        """
        # Training can arrive before any inference message, so make sure a model
        # exists. NOTE(review): pinning `_hash` to 0 presumably matches the hash
        # the base class computes for incoming messages -- confirm.
        if self._hash != 0:
            self._reset_state(message)
            self._hash = 0

        if np.any(np.isnan(message.data)):
            return

        train_sample = message.data.reshape(1, -1)
        target = [message.attrs["trigger"].value]
        if self._state.b_first_train:
            label_weights = self.settings.label_weights
            # Without label_weights, pass classes=None: a pre-trained model
            # already knows its classes, and scikit-learn raises a descriptive
            # ValueError for an unfitted one.
            classes = list(label_weights) if label_weights is not None else None
            self._state.model.partial_fit(train_sample, target, classes=classes)
            self._state.b_first_train = False
        else:
            self._state.model.partial_fit(train_sample, target)
106
+
107
+
108
+ class SGDDecoder(
109
+ BaseAdaptiveTransformerUnit[
110
+ SGDDecoderSettings,
111
+ AxisArray,
112
+ ClassifierMessage,
113
+ SGDDecoderTransformer,
114
+ ]
115
+ ):
116
+ SETTINGS = SGDDecoderSettings