ezmsg_learn-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ezmsg/learn/__init__.py +2 -0
  2. ezmsg/learn/__version__.py +34 -0
  3. ezmsg/learn/dim_reduce/__init__.py +0 -0
  4. ezmsg/learn/dim_reduce/adaptive_decomp.py +274 -0
  5. ezmsg/learn/dim_reduce/incremental_decomp.py +173 -0
  6. ezmsg/learn/linear_model/__init__.py +1 -0
  7. ezmsg/learn/linear_model/adaptive_linear_regressor.py +12 -0
  8. ezmsg/learn/linear_model/cca.py +1 -0
  9. ezmsg/learn/linear_model/linear_regressor.py +9 -0
  10. ezmsg/learn/linear_model/sgd.py +9 -0
  11. ezmsg/learn/linear_model/slda.py +12 -0
  12. ezmsg/learn/model/__init__.py +0 -0
  13. ezmsg/learn/model/cca.py +122 -0
  14. ezmsg/learn/model/mlp.py +127 -0
  15. ezmsg/learn/model/mlp_old.py +49 -0
  16. ezmsg/learn/model/refit_kalman.py +369 -0
  17. ezmsg/learn/model/rnn.py +160 -0
  18. ezmsg/learn/model/transformer.py +175 -0
  19. ezmsg/learn/nlin_model/__init__.py +1 -0
  20. ezmsg/learn/nlin_model/mlp.py +10 -0
  21. ezmsg/learn/process/__init__.py +0 -0
  22. ezmsg/learn/process/adaptive_linear_regressor.py +142 -0
  23. ezmsg/learn/process/base.py +154 -0
  24. ezmsg/learn/process/linear_regressor.py +95 -0
  25. ezmsg/learn/process/mlp_old.py +188 -0
  26. ezmsg/learn/process/refit_kalman.py +403 -0
  27. ezmsg/learn/process/rnn.py +245 -0
  28. ezmsg/learn/process/sgd.py +117 -0
  29. ezmsg/learn/process/sklearn.py +241 -0
  30. ezmsg/learn/process/slda.py +110 -0
  31. ezmsg/learn/process/ssr.py +374 -0
  32. ezmsg/learn/process/torch.py +362 -0
  33. ezmsg/learn/process/transformer.py +215 -0
  34. ezmsg/learn/util.py +67 -0
  35. ezmsg_learn-1.1.0.dist-info/METADATA +30 -0
  36. ezmsg_learn-1.1.0.dist-info/RECORD +38 -0
  37. ezmsg_learn-1.1.0.dist-info/WHEEL +4 -0
  38. ezmsg_learn-1.1.0.dist-info/licenses/LICENSE +21 -0
ezmsg/learn/process/torch.py ADDED
@@ -0,0 +1,362 @@
+ import importlib
+ import typing
+
+ import ezmsg.core as ez
+ import numpy as np
+ import torch
+ from ezmsg.baseproc import (
+     BaseAdaptiveTransformer,
+     BaseAdaptiveTransformerUnit,
+     BaseStatefulTransformer,
+     BaseTransformerUnit,
+     processor_state,
+ )
+ from ezmsg.baseproc.util.profile import profile_subpub
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from .base import ModelInitMixin
+
+
+ class TorchSimpleSettings(ez.Settings):
+     model_class: str
+     """
+     Fully qualified class path of the model to be used.
+     Example: "my_module.MyModelClass"
+     This class should inherit from `torch.nn.Module`.
+     """
+
+     checkpoint_path: str | None = None
+     """
+     Path to a checkpoint file containing model weights.
+     If None, the model will be initialized with random weights.
+     If parameters inferred from the weight sizes conflict with the settings / config,
+     then the inferred parameters will take priority and a warning will be logged.
+     """
+
+     config_path: str | None = None
+     """
+     Path to a config file containing model parameters.
+     Parameters loaded from the config file will take priority over settings.
+     If settings differ from config parameters then a warning will be logged.
+     If `checkpoint_path` is provided then any parameters inferred from the weights
+     will take priority over the config parameters.
+     """
+
+     single_precision: bool = True
+     """Use single precision (float32) instead of double precision (float64)"""
+
+     device: str | None = None
+     """
+     Device to use for the model. If None, the device will be determined automatically,
+     with preference for cuda > mps > cpu.
+     """
+
+     model_kwargs: dict[str, typing.Any] | None = None
+     """
+     Additional keyword arguments to pass to the model constructor.
+     This can include parameters like `input_size`, `output_size`, etc.
+     If a config file is provided, these parameters will be updated with the config values.
+     If a checkpoint file is provided, these parameters will be updated with the inferred values
+     from the model weights.
+     """
+
+
+ class TorchModelSettings(TorchSimpleSettings):
+     learning_rate: float = 0.001
+     """Learning rate for the optimizer"""
+
+     weight_decay: float = 0.0001
+     """Weight decay for the optimizer"""
+
+     loss_fn: torch.nn.Module | dict[str, torch.nn.Module] | None = None
+     """
+     Loss function(s) for the decoder. If using multiple heads, this should be a dictionary
+     mapping head names to loss functions. The keys must match the output head names.
+     """
+
+     loss_weights: dict[str, float] | None = None
+     """
+     Weights for each loss function if using multiple heads.
+     The keys must match the output head names.
+     If None or missing/mismatched keys, losses will be unweighted.
+     """
+
+     scheduler_gamma: float = 0.999
+     """Learning scheduler decay rate. Set to 0.0 to disable the scheduler."""
+
+
+ @processor_state
+ class TorchSimpleState:
+     model: torch.nn.Module | None = None
+     device: torch.device | None = None
+     chan_ax: dict[str, AxisArray.CoordinateAxis] | None = None
+
+
+ class TorchModelState(TorchSimpleState):
+     optimizer: torch.optim.Optimizer | None = None
+     scheduler: torch.optim.lr_scheduler.LRScheduler | None = None
+
+
+ P = typing.TypeVar("P", bound=BaseStatefulTransformer)
+
+
+ class TorchProcessorMixin:
+     """Mixin with shared functionality for torch processors."""
+
+     def _import_model(self, class_path: str) -> type[torch.nn.Module]:
+         """Dynamically import model class from string."""
+         if class_path is None:
+             raise ValueError("Model class path must be provided in settings.")
+         module_path, class_name = class_path.rsplit(".", 1)
+         module = importlib.import_module(module_path)
+         return getattr(module, class_name)
+
+     def _infer_output_sizes(self: P, model: torch.nn.Module, n_input: int) -> dict[str, int]:
+         """Simple inference to get output channel size. Override if needed."""
+         dummy_input = torch.zeros(1, 1, n_input, device=self._state.device)
+         with torch.no_grad():
+             output = model(dummy_input)
+
+         if isinstance(output, dict):
+             return {k: v.shape[-1] for k, v in output.items()}
+         else:
+             return {"output": output.shape[-1]}
+
+     def _init_optimizer(self) -> None:
+         self._state.optimizer = torch.optim.AdamW(
+             self._state.model.parameters(),
+             lr=self.settings.learning_rate,
+             weight_decay=self.settings.weight_decay,
+         )
+         self._state.scheduler = (
+             torch.optim.lr_scheduler.ExponentialLR(self._state.optimizer, gamma=self.settings.scheduler_gamma)
+             if self.settings.scheduler_gamma > 0.0
+             else None
+         )
+
+     def _validate_loss_keys(self, output_keys: list[str]):
+         if isinstance(self.settings.loss_fn, dict):
+             missing = [k for k in output_keys if k not in self.settings.loss_fn]
+             if missing:
+                 raise ValueError(f"Missing loss function(s) for output keys: {missing}")
+
+     def _to_tensor(self: P, data: np.ndarray) -> torch.Tensor:
+         dtype = torch.float32 if self.settings.single_precision else torch.float64
+         if isinstance(data, torch.Tensor):
+             return data.detach().clone().to(device=self._state.device, dtype=dtype)
+         return torch.tensor(data, dtype=dtype, device=self._state.device)
+
+     def save_checkpoint(self: P, path: str) -> None:
+         """Save the current model state to a checkpoint file."""
+         if self._state.model is None:
+             raise RuntimeError("Model must be initialized before saving a checkpoint.")
+
+         checkpoint = {
+             "model_state_dict": self._state.model.state_dict(),
+             "config": self.settings.model_kwargs or {},
+         }
+
+         # Add optimizer state if available
+         if hasattr(self._state, "optimizer") and self._state.optimizer is not None:
+             checkpoint["optimizer_state_dict"] = self._state.optimizer.state_dict()
+
+         torch.save(checkpoint, path)
+
+     def _ensure_batched(self, tensor: torch.Tensor) -> tuple[torch.Tensor, bool]:
+         """
+         Ensure tensor has a batch dimension.
+         Returns the potentially modified tensor and a flag indicating whether a dimension was added.
+         """
+         if tensor.ndim == 2:
+             return tensor.unsqueeze(0), True
+         return tensor, False
+
+     def _common_process(self: P, message: AxisArray) -> list[AxisArray]:
+         data = message.data
+         data = self._to_tensor(data)
+
+         # Add batch dimension if missing
+         data, added_batch_dim = self._ensure_batched(data)
+
+         with torch.no_grad():
+             output = self._state.model(data)
+
+         if isinstance(output, dict):
+             output_messages = [
+                 replace(
+                     message,
+                     data=value.cpu().numpy().squeeze(0) if added_batch_dim else value.cpu().numpy(),
+                     axes={
+                         **message.axes,
+                         "ch": self._state.chan_ax[key],
+                     },
+                     key=key,
+                 )
+                 for key, value in output.items()
+             ]
+             return output_messages
+
+         return [
+             replace(
+                 message,
+                 data=output.cpu().numpy().squeeze(0) if added_batch_dim else output.cpu().numpy(),
+                 axes={
+                     **message.axes,
+                     "ch": self._state.chan_ax["output"],
+                 },
+             )
+         ]
+
+     def _common_reset_state(self: P, message: AxisArray, model_kwargs: dict) -> None:
+         n_input = message.data.shape[message.get_axis_idx("ch")]
+
+         if "input_size" in model_kwargs:
+             if model_kwargs["input_size"] != n_input:
+                 raise ValueError(
+                     f"Mismatch between model_kwargs['input_size']={model_kwargs['input_size']} "
+                     f"and input data channels={n_input}"
+                 )
+         else:
+             model_kwargs["input_size"] = n_input
+
+         device = "cuda" if torch.cuda.is_available() else ("mps" if torch.mps.is_available() else "cpu")
+         device = self.settings.device or device
+         self._state.device = torch.device(device)
+
+         model_class = self._import_model(self.settings.model_class)
+
+         self._state.model = self._init_model(
+             model_class=model_class,
+             params=model_kwargs,
+             config_path=self.settings.config_path,
+             checkpoint_path=self.settings.checkpoint_path,
+             device=device,
+         )
+
+         self._state.model.eval()
+
+         output_sizes = self._infer_output_sizes(self._state.model, n_input)
+         self._state.chan_ax = {
+             head: AxisArray.CoordinateAxis(
+                 data=np.array([f"{head}_ch{_}" for _ in range(size)]),
+                 dims=["ch"],
+             )
+             for head, size in output_sizes.items()
+         }
+
+
+ class TorchSimpleProcessor(
+     BaseStatefulTransformer[TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState],
+     TorchProcessorMixin,
+     ModelInitMixin,
+ ):
+     def _reset_state(self, message: AxisArray) -> None:
+         model_kwargs = dict(self.settings.model_kwargs or {})
+         self._common_reset_state(message, model_kwargs)
+
+     def _process(self, message: AxisArray) -> list[AxisArray]:
+         """Process the input message and return the output messages."""
+         return self._common_process(message)
+
+
+ class TorchSimpleUnit(
+     BaseTransformerUnit[
+         TorchSimpleSettings,
+         AxisArray,
+         AxisArray,
+         TorchSimpleProcessor,
+     ]
+ ):
+     SETTINGS = TorchSimpleSettings
+
+     @ez.subscriber(BaseTransformerUnit.INPUT_SIGNAL, zero_copy=True)
+     @ez.publisher(BaseTransformerUnit.OUTPUT_SIGNAL)
+     @profile_subpub(trace_oldest=False)
+     async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
+         results = await self.processor.__acall__(message)
+         for result in results:
+             yield self.OUTPUT_SIGNAL, result
+
+
+ class TorchModelProcessor(
+     BaseAdaptiveTransformer[TorchModelSettings, AxisArray, AxisArray, TorchModelState],
+     TorchProcessorMixin,
+     ModelInitMixin,
+ ):
+     def _reset_state(self, message: AxisArray) -> None:
+         model_kwargs = dict(self.settings.model_kwargs or {})
+         self._common_reset_state(message, model_kwargs)
+         self._init_optimizer()
+         self._validate_loss_keys(list(self._state.chan_ax.keys()))
+
+     def _process(self, message: AxisArray) -> list[AxisArray]:
+         return self._common_process(message)
+
+     def partial_fit(self, message: SampleMessage) -> None:
+         self._state.model.train()
+
+         X = self._to_tensor(message.sample.data)
+         X, batched = self._ensure_batched(X)
+
+         y_targ = message.trigger.value
+         if not isinstance(y_targ, dict):
+             y_targ = {"output": y_targ}
+         y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
+         if batched:
+             for key in y_targ:
+                 y_targ[key] = y_targ[key].unsqueeze(0)
+
+         loss_fns = self.settings.loss_fn
+         if loss_fns is None:
+             raise ValueError("loss_fn must be provided in settings to use partial_fit")
+         if not isinstance(loss_fns, dict):
+             loss_fns = {k: loss_fns for k in y_targ.keys()}
+
+         weights = self.settings.loss_weights or {}
+
+         with torch.set_grad_enabled(True):
+             y_pred = self._state.model(X)
+             if not isinstance(y_pred, dict):
+                 y_pred = {"output": y_pred}
+
+             losses = []
+             for key in y_targ.keys():
+                 loss_fn = loss_fns.get(key)
+                 if loss_fn is None:
+                     raise ValueError(f"Loss function for key '{key}' is not defined in settings.")
+                 if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
+                     loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long())
+                 else:
+                     loss = loss_fn(y_pred[key], y_targ[key])
+                 weight = weights.get(key, 1.0)
+                 losses.append(loss * weight)
+             total_loss = sum(losses)
+
+             self._state.optimizer.zero_grad()
+             total_loss.backward()
+             self._state.optimizer.step()
+             if self._state.scheduler is not None:
+                 self._state.scheduler.step()
+
+         self._state.model.eval()
+
+
+ class TorchModelUnit(
+     BaseAdaptiveTransformerUnit[
+         TorchModelSettings,
+         AxisArray,
+         AxisArray,
+         TorchModelProcessor,
+     ]
+ ):
+     SETTINGS = TorchModelSettings
+
+     @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
+     @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
+     @profile_subpub(trace_oldest=False)
+     async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
+         results = await self.processor.__acall__(message)
+         for result in results:
+             yield self.OUTPUT_SIGNAL, result
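Editor's note: the `TorchSimpleSettings` docstrings above describe how a model is located (`model_class`), loaded (`checkpoint_path` / `config_path`), and sized from the incoming signal's `"ch"` axis. Below is a minimal sketch of how that might look in practice; the model path, checkpoint path, and `output_size` are hypothetical, and it assumes the base processor accepts its settings via the constructor and exposes a synchronous call mirroring `__acall__`.

```python
# Hypothetical sketch (not part of this diff): configure the simple torch
# processor for inference. "my_models.MyEEGNet", the checkpoint path, and
# output_size are placeholders for a user-supplied torch.nn.Module.
import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.learn.process.torch import TorchSimpleProcessor, TorchSimpleSettings

settings = TorchSimpleSettings(
    model_class="my_models.MyEEGNet",        # fully qualified class path
    checkpoint_path="weights/my_eegnet.pt",  # None -> random initialization
    model_kwargs={"output_size": 4},         # input_size is inferred from the "ch" axis
    device=None,                             # auto-select: cuda > mps > cpu
)
proc = TorchSimpleProcessor(settings=settings)  # assumes a `settings` kwarg on the base class

# 30 samples x 8 channels; _common_reset_state reads n_input from the "ch" dim.
msg = AxisArray(data=np.random.randn(30, 8).astype(np.float32), dims=["time", "ch"])
out_msgs = proc(msg)  # assumes a synchronous __call__ alongside __acall__
```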
ezmsg/learn/process/transformer.py ADDED
@@ -0,0 +1,215 @@
+ import typing
+
+ import ezmsg.core as ez
+ import torch
+ from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
+ from ezmsg.baseproc.util.profile import profile_subpub
+ from ezmsg.sigproc.sampler import SampleMessage
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from .base import ModelInitMixin
+ from .torch import (
+     TorchModelSettings,
+     TorchModelState,
+     TorchProcessorMixin,
+ )
+
+
+ class TransformerSettings(TorchModelSettings):
+     model_class: str = "ezmsg.learn.model.transformer.TransformerModel"
+     """
+     Fully qualified class path of the model to be used.
+     This should be "ezmsg.learn.model.transformer.TransformerModel" for this processor.
+     """
+     autoregressive_head: str | None = None
+     """
+     The name of the output head used for autoregressive decoding.
+     This should match one of the keys in the model's output dictionary.
+     If None, the first output head will be used.
+     """
+     max_cache_len: int | None = 128
+     """
+     Maximum length of the target sequence cache for autoregressive decoding.
+     This limits the context length during decoding to prevent excessive memory usage.
+     If set to None, the cache will grow indefinitely.
+     """
+
+
+ class TransformerState(TorchModelState):
+     ar_head: str | None = None
+     """
+     The name of the autoregressive head used for decoding.
+     This is set based on the `autoregressive_head` setting.
+     If None, the first output head will be used.
+     """
+     tgt_cache: typing.Optional[torch.Tensor] = None
+     """
+     Cache for the target sequence used in autoregressive decoding.
+     This is updated with each processed message to maintain context.
+     """
+
+
+
53
+ class TransformerProcessor(
54
+ BaseAdaptiveTransformer[TransformerSettings, AxisArray, AxisArray, TransformerState],
55
+ TorchProcessorMixin,
56
+ ModelInitMixin,
57
+ ):
58
+ @property
59
+ def has_decoder(self) -> bool:
60
+ return self.settings.model_kwargs.get("decoder_layers", 0) > 0
61
+
62
+ def reset_cache(self) -> None:
63
+ self._state.tgt_cache = None
64
+
65
+ def _reset_state(self, message: AxisArray) -> None:
66
+ model_kwargs = dict(self.settings.model_kwargs or {})
67
+ self._common_reset_state(message, model_kwargs)
68
+ self._init_optimizer()
69
+ self._validate_loss_keys(list(self._state.chan_ax.keys()))
70
+
71
+ self._state.tgt_cache = None
72
+ if (
73
+ self.settings.autoregressive_head is not None
74
+ and self.settings.autoregressive_head not in self._state.chan_ax
75
+ ):
76
+ raise ValueError(
77
+ f"Autoregressive head '{self.settings.autoregressive_head}' not found in target"
78
+ f"dictionary keys: {list(self._state.chan_ax.keys())}"
79
+ )
80
+ self._state.ar_head = (
81
+ self.settings.autoregressive_head
82
+ if self.settings.autoregressive_head is not None
83
+ else list(self._state.chan_ax.keys())[0]
84
+ )
85
+
86
+ def _process(self, message: AxisArray) -> list[AxisArray]:
87
+ # If has_decoder is False, fallback to regular processing
88
+ if not self.has_decoder:
89
+ return self._common_process(message)
90
+
91
+ x = self._to_tensor(message.data)
92
+ x, _ = self._ensure_batched(x)
93
+ if x.shape[0] > 1:
94
+ raise ValueError("Autoregressive decoding only supports batch size 1.")
95
+
96
+ with torch.no_grad():
97
+ y_pred = self._state.model(x, tgt=self._state.tgt_cache)
98
+
99
+ pred = y_pred[self._state.ar_head]
100
+ if self._state.tgt_cache is None:
101
+ self._state.tgt_cache = pred[:, -1:, :]
102
+ else:
103
+ self._state.tgt_cache = torch.cat([self._state.tgt_cache, pred[:, -1:, :]], dim=1)
104
+ if self.settings.max_cache_len is not None:
105
+ if self._state.tgt_cache.shape[1] > self.settings.max_cache_len:
106
+ # Trim the cache to the maximum length
107
+ self._state.tgt_cache = self._state.tgt_cache[:, -self.settings.max_cache_len :, :]
108
+
109
+ if isinstance(y_pred, dict):
110
+ return [
111
+ replace(
112
+ message,
113
+ data=out.squeeze(0).cpu().numpy(),
114
+ axes={**message.axes, "ch": self._state.chan_ax[key]},
115
+ key=key,
116
+ )
117
+ for key, out in y_pred.items()
118
+ ]
119
+ else:
120
+ return [
121
+ replace(
122
+ message,
123
+ data=y_pred.squeeze(0).cpu().numpy(),
124
+ axes={**message.axes, "ch": self._state.chan_ax["output"]},
125
+ )
126
+ ]
127
+
128
+ def partial_fit(self, message: SampleMessage) -> None:
129
+ self._state.model.train()
130
+
131
+ X = self._to_tensor(message.sample.data)
132
+ X, batched = self._ensure_batched(X)
133
+
134
+ y_targ = message.trigger.value
135
+ if not isinstance(y_targ, dict):
136
+ y_targ = {"output": y_targ}
137
+ y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
138
+ # Add batch dimension to y_targ values if missing
139
+ if batched:
140
+ for key in y_targ:
141
+ y_targ[key] = y_targ[key].unsqueeze(0)
142
+
143
+ loss_fns = self.settings.loss_fn
144
+ if loss_fns is None:
145
+ raise ValueError("loss_fn must be provided in settings to use partial_fit")
146
+ if not isinstance(loss_fns, dict):
147
+ loss_fns = {k: loss_fns for k in y_targ.keys()}
148
+
149
+ weights = self.settings.loss_weights or {}
150
+
151
+ if self.has_decoder:
152
+ if X.shape[0] != 1:
153
+ raise ValueError("Autoregressive decoding only supports batch size 1.")
154
+
155
+ # Create shifted target for autoregressive head
156
+ tgt_tensor = y_targ[self._state.ar_head]
157
+ tgt = torch.cat(
158
+ [
159
+ torch.zeros(
160
+ (1, 1, tgt_tensor.shape[-1]),
161
+ dtype=tgt_tensor.dtype,
162
+ device=tgt_tensor.device,
163
+ ),
164
+ tgt_tensor[:, :-1, :],
165
+ ],
166
+ dim=1,
167
+ )
168
+
169
+ # Reset tgt_cache at start of partial_fit to avoid stale context
170
+ self.reset_cache()
171
+ y_pred = self._state.model(X, tgt=tgt)
172
+ else:
173
+ # For non-autoregressive models, use the model directly
174
+ y_pred = self._state.model(X)
175
+
176
+ if not isinstance(y_pred, dict):
177
+ y_pred = {"output": y_pred}
178
+
179
+ with torch.set_grad_enabled(True):
180
+ losses = []
181
+ for key in y_targ.keys():
182
+ loss_fn = loss_fns.get(key)
183
+ if loss_fn is None:
184
+ raise ValueError(f"Loss function for key '{key}' is not defined in settings.")
185
+ loss = loss_fn(y_pred[key], y_targ[key])
186
+ weight = weights.get(key, 1.0)
187
+ losses.append(loss * weight)
188
+ total_loss = sum(losses)
189
+
190
+ self._state.optimizer.zero_grad()
191
+ total_loss.backward()
192
+ self._state.optimizer.step()
193
+ if self._state.scheduler is not None:
194
+ self._state.scheduler.step()
195
+
196
+ self._state.model.eval()
197
+
198
+
199
+ class TransformerUnit(
200
+ BaseAdaptiveTransformerUnit[
201
+ TransformerSettings,
202
+ AxisArray,
203
+ AxisArray,
204
+ TransformerProcessor,
205
+ ]
206
+ ):
207
+ SETTINGS = TransformerSettings
208
+
209
+ @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
210
+ @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
211
+ @profile_subpub(trace_oldest=False)
212
+ async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
213
+ results = await self.processor.__acall__(message)
214
+ for result in results:
215
+ yield self.OUTPUT_SIGNAL, result
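Editor's note: the `max_cache_len` and `tgt_cache` docstrings above describe a bounded autoregressive context: each new prediction's last time step is appended to the cache and the oldest steps are dropped once the limit is reached. A small, self-contained illustration of that trimming policy (not part of the package) is:

```python
# Illustration of the tgt_cache trimming policy: the cache keeps at most
# `max_cache_len` of the most recent autoregressive outputs along dim=1.
import torch

max_cache_len = 4
cache = None
for step in range(6):
    pred = torch.full((1, 1, 2), float(step))  # stand-in for the model's pred[:, -1:, :]
    cache = pred if cache is None else torch.cat([cache, pred], dim=1)
    if max_cache_len is not None and cache.shape[1] > max_cache_len:
        cache = cache[:, -max_cache_len:, :]

print(cache.shape)     # torch.Size([1, 4, 2])
print(cache[0, :, 0])  # tensor([2., 3., 4., 5.]) -- the oldest steps were dropped
```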
ezmsg/learn/util.py ADDED
@@ -0,0 +1,67 @@
+ import typing
+ from dataclasses import dataclass, field
+ from enum import Enum
+
+ import river.linear_model
+ import sklearn.linear_model
+ from ezmsg.util.messages.axisarray import AxisArray
+
+ # from sklearn.neural_network import MLPClassifier
+
+
+ class RegressorType(str, Enum):
+     ADAPTIVE = "adaptive"
+     STATIC = "static"
+
+
+ class AdaptiveLinearRegressor(str, Enum):
+     LINEAR = "linear"
+     LOGISTIC = "logistic"
+     SGD = "sgd"
+     PAR = "par"  # passive-aggressive
+     # MLP = "mlp"
+
+
+ class StaticLinearRegressor(str, Enum):
+     LINEAR = "linear"
+     RIDGE = "ridge"
+
+
+ ADAPTIVE_REGRESSORS = {
+     AdaptiveLinearRegressor.LINEAR: river.linear_model.LinearRegression,
+     AdaptiveLinearRegressor.LOGISTIC: river.linear_model.LogisticRegression,
+     AdaptiveLinearRegressor.SGD: sklearn.linear_model.SGDRegressor,
+     AdaptiveLinearRegressor.PAR: sklearn.linear_model.PassiveAggressiveRegressor,
+     # AdaptiveLinearRegressor.MLP: MLPClassifier,
+ }
+
+
+ # Function to get a regressor by type and name
+ def get_regressor(
+     regressor_type: typing.Union[RegressorType, str],
+     regressor_name: typing.Union[AdaptiveLinearRegressor, StaticLinearRegressor, str],
+ ):
+     if isinstance(regressor_type, str):
+         regressor_type = RegressorType(regressor_type)
+
+     if regressor_type == RegressorType.ADAPTIVE:
+         if isinstance(regressor_name, str):
+             regressor_name = AdaptiveLinearRegressor(regressor_name)
+         return ADAPTIVE_REGRESSORS[regressor_name]
+     elif regressor_type == RegressorType.STATIC:
+         if isinstance(regressor_name, str):
+             regressor_name = StaticLinearRegressor(regressor_name)
+         return STATIC_REGRESSORS[regressor_name]
+     else:
+         raise ValueError(f"Unknown regressor type: {regressor_type}")
+
+
+ STATIC_REGRESSORS = {
+     StaticLinearRegressor.LINEAR: sklearn.linear_model.LinearRegression,
+     StaticLinearRegressor.RIDGE: sklearn.linear_model.Ridge,
+ }
+
+
+ @dataclass
+ class ClassifierMessage(AxisArray):
+     labels: list[str] = field(default_factory=list)
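Editor's note: `get_regressor` above accepts either the enums or their plain-string values and returns an estimator class (not an instance). A hedged usage sketch, not part of the package, under that reading:

```python
# Hypothetical usage of get_regressor: enum values and their plain-string
# equivalents resolve to the same underlying estimator classes.
from ezmsg.learn.util import (
    AdaptiveLinearRegressor,
    RegressorType,
    StaticLinearRegressor,
    get_regressor,
)

ridge_cls = get_regressor(RegressorType.STATIC, StaticLinearRegressor.RIDGE)
sgd_cls = get_regressor("adaptive", "sgd")  # strings are coerced to the enums

ridge = ridge_cls(alpha=1.0)  # sklearn.linear_model.Ridge
sgd = sgd_cls()               # sklearn.linear_model.SGDRegressor
```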
ezmsg_learn-1.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,30 @@
+ Metadata-Version: 2.4
+ Name: ezmsg-learn
+ Version: 1.1.0
+ Summary: ezmsg namespace package for machine learning
+ Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
+ License-Expression: MIT
+ License-File: LICENSE
+ Requires-Python: >=3.10.15
+ Requires-Dist: ezmsg-baseproc>=1.0.2
+ Requires-Dist: ezmsg-sigproc>=2.14.0
+ Requires-Dist: river>=0.22.0
+ Requires-Dist: scikit-learn>=1.6.0
+ Requires-Dist: torch>=2.6.0
+ Description-Content-Type: text/markdown
+
+ # ezmsg-learn
+
+ This repository contains a Python package with modules for machine learning (ML)-related processing in the [`ezmsg`](https://www.ezmsg.org) framework. ezmsg is intended primarily for processing unbounded streaming signals, and so are the modules in this repo.
+
+ > If you are only interested in offline analysis without concern for reproducibility in online applications, then you should probably look elsewhere.
+
+ Processing units include dimensionality reduction, linear regression, and classification that can be initialized with known weights or adapted on-the-fly with incoming (labeled) data. Machine-learning code depends on `river`, `scikit-learn`, `numpy`, and `torch`.
+
+ ## Getting Started
+
+ This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way is to use `pip` to install the package directly from GitHub:
+
+ ```bash
+ pip install git+https://github.com/ezmsg-org/ezmsg-learn
+ ```
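Editor's note: a quick way to confirm the install succeeded is to import the namespace package and one of the modules added in this release; the snippet below is a sanity check, assuming the dependencies declared above resolved during installation.

```python
# Minimal post-install check: the namespace package and the torch processor
# module from this wheel import cleanly.
import ezmsg.learn
from ezmsg.learn.process.torch import TorchModelUnit, TorchSimpleUnit

print(TorchSimpleUnit.__name__, TorchModelUnit.__name__)
```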