ezmsg-learn 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,119 @@
1
+ import typing
2
+
3
+ import ezmsg.core as ez
4
+ import numpy as np
5
+ from ezmsg.sigproc.base import (
6
+ BaseStatefulTransformer,
7
+ processor_state,
8
+ BaseTransformerUnit,
9
+ )
10
+ from ezmsg.util.messages.axisarray import AxisArray
11
+ from ezmsg.util.messages.util import replace
12
+ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
13
+
14
+ from ..util import ClassifierMessage
15
+
16
+
17
class SLDASettings(ez.Settings):
    """Settings for the sLDA classifier transformer.

    Attributes:
        settings_path: Path to the classifier weights — either a
            project-specific MATLAB ``.mat`` file or a pickled sklearn LDA.
        axis: Name of the sample axis to classify along.
    """

    settings_path: str
    axis: str = "time"
20
+
21
+
22
@processor_state
class SLDAState:
    # Loaded sklearn LDA classifier instance.
    lda: LDA
    # Pre-built output message reused (via `replace`) for every processed chunk.
    out_template: typing.Optional[ClassifierMessage] = None
26
+
27
+
28
class SLDATransformer(
    BaseStatefulTransformer[SLDASettings, AxisArray, ClassifierMessage, SLDAState]
):
    """Apply a pre-trained (stepwise) LDA classifier to streaming AxisArray data.

    Classifier weights are loaded lazily on the first message from
    ``settings.settings_path`` — either a project-specific MATLAB ``.mat``
    file or a pickled sklearn ``LinearDiscriminantAnalysis`` instance.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """Load classifier weights and build the output-message template.

        Args:
            message: First incoming message; used to determine channel count,
                dtype, and axis metadata for the output template.
        """
        # str(...) generalizes to pathlib.Path; the original `[-4:]` slice
        # would raise TypeError on a Path.
        path = str(self.settings.settings_path)
        if path.endswith(".mat"):
            # Expects a very specific format from a specific project. Not for general use.
            import scipy.io as sio

            matlab_sLDA = sio.loadmat(self.settings.settings_path, squeeze_me=True)
            temp_weights = matlab_sLDA["weights"][1, 1:]
            temp_intercept = matlab_sLDA["weights"][1, 0]

            # Create weights and use zeros for channels we do not keep.
            channels = matlab_sLDA["channels"] - 4
            channels -= channels[0]  # Offsets are wrong somehow.
            n_channels = message.data.shape[message.dims.index("ch")]
            # Drop any stored channel indices beyond the incoming channel count.
            valid_indices = [ch for ch in channels if ch < n_channels]
            full_weights = np.zeros(n_channels)
            full_weights[valid_indices] = temp_weights[: len(valid_indices)]

            # Build a fitted-looking LDA directly from the imported weights.
            lda = LDA(solver="lsqr", shrinkage="auto")
            lda.classes_ = np.asarray([0, 1])
            lda.coef_ = np.expand_dims(full_weights, axis=0)
            lda.intercept_ = temp_intercept  # TODO: Is this supposed to be per-channel? Why the [1, 0]?
            self.state.lda = lda
            # mean = matlab_sLDA['mXtrain']
            # std = matlab_sLDA['sXtrain']
            # lags = matlab_sLDA['lags'] + 1
        else:
            import pickle

            # NOTE(review): pickle.load executes arbitrary code on load;
            # only use classifier files from trusted sources.
            with open(self.settings.settings_path, "rb") as f:
                self.state.lda = pickle.load(f)

        # Create template ClassifierMessage using lda.classes_
        out_labels = self.state.lda.classes_.tolist()
        zero_shape = (0, len(out_labels))
        self.state.out_template = ClassifierMessage(
            data=np.zeros(zero_shape, dtype=message.data.dtype),
            dims=[self.settings.axis, "classes"],
            axes={
                self.settings.axis: message.axes[self.settings.axis],
                "classes": AxisArray.CoordinateAxis(
                    data=np.array(out_labels), dims=["classes"]
                ),
            },
            labels=out_labels,
            key=message.key,
        )

    def _process(self, message: AxisArray) -> ClassifierMessage:
        """Classify each sample along ``settings.axis``.

        Returns:
            ClassifierMessage of per-class probabilities with shape
            (n_samples, n_classes); the empty template when the message
            carries no samples.
        """
        samp_ax_idx = message.dims.index(self.settings.axis)
        X = np.moveaxis(message.data, samp_ax_idx, 0)

        if not X.shape[0]:
            # No samples in this chunk: return the (empty-data) template.
            return self.state.out_template

        if str(self.settings.settings_path).endswith(".mat"):
            # Assumes F-contiguous weights.
            # Samples are scaled by 1e-6 (presumably uV -> V — confirm).
            pred_probas = []
            for samp in X:
                tmp = samp.flatten(order="F") * 1e-6
                tmp = np.expand_dims(tmp, axis=0)
                pred_probas.append(self.state.lda.predict_proba(tmp))
            pred_probas = np.concatenate(pred_probas, axis=0)
        else:
            # This creates a copy.
            X = X.reshape(X.shape[0], -1)
            pred_probas = self.state.lda.predict_proba(X)

        # Fix: do not mutate the shared template's axis in place (the original
        # assigned update_ax.offset before copying). `replace` produces a new
        # axis carrying the incoming message's offset.
        template_ax = self.state.out_template.axes[self.settings.axis]
        return replace(
            self.state.out_template,
            data=pred_probas,
            axes={
                **self.state.out_template.axes,
                self.settings.axis: replace(
                    template_ax, offset=message.axes[self.settings.axis].offset
                ),
            },
        )
114
+
115
+
116
class SLDA(
    BaseTransformerUnit[SLDASettings, AxisArray, ClassifierMessage, SLDATransformer]
):
    """ezmsg Unit wrapping :class:`SLDATransformer`."""

    SETTINGS = SLDASettings
@@ -0,0 +1,378 @@
1
+ import importlib
2
+ import typing
3
+
4
+ import ezmsg.core as ez
5
+ import numpy as np
6
+ import torch
7
+ from ezmsg.sigproc.base import (
8
+ BaseAdaptiveTransformer,
9
+ BaseAdaptiveTransformerUnit,
10
+ BaseStatefulTransformer,
11
+ BaseTransformerUnit,
12
+ processor_state,
13
+ )
14
+ from ezmsg.sigproc.sampler import SampleMessage
15
+ from ezmsg.sigproc.util.profile import profile_subpub
16
+ from ezmsg.util.messages.axisarray import AxisArray
17
+ from ezmsg.util.messages.util import replace
18
+
19
+ from .base import ModelInitMixin
20
+
21
+
22
class TorchSimpleSettings(ez.Settings):
    """Settings for an inference-only torch model processor."""

    model_class: str
    """
    Fully qualified class path of the model to use,
    e.g. ``"my_module.MyModelClass"``.
    The class must inherit from ``torch.nn.Module``.
    """

    checkpoint_path: str | None = None
    """
    Optional path to a checkpoint file containing model weights.
    When None, the model starts with random weights.
    If parameters inferred from the weight sizes conflict with the
    settings / config, the inferred parameters take priority and a
    warning is logged.
    """

    config_path: str | None = None
    """
    Optional path to a config file of model parameters.
    Config values take priority over settings (a warning is logged when
    they differ), and are in turn overridden by any parameters inferred
    from checkpoint weights when ``checkpoint_path`` is provided.
    """

    single_precision: bool = True
    """Use single precision (float32) instead of double precision (float64)."""

    device: str | None = None
    """
    Device for the model. When None, the device is chosen automatically
    with preference cuda > mps > cpu.
    """

    model_kwargs: dict[str, typing.Any] | None = None
    """
    Extra keyword arguments for the model constructor (e.g. ``input_size``,
    ``output_size``). Updated from the config file if one is provided, and
    then from values inferred from checkpoint weights if a checkpoint is
    provided.
    """
64
+
65
+
66
class TorchModelSettings(TorchSimpleSettings):
    """Settings for an adaptive torch processor (inference plus online training)."""

    learning_rate: float = 0.001
    """Learning rate for the optimizer."""

    weight_decay: float = 0.0001
    """Weight decay for the optimizer."""

    loss_fn: torch.nn.Module | dict[str, torch.nn.Module] | None = None
    """
    Loss function(s) for the decoder. With multiple output heads, supply a
    dict mapping head names to loss functions; the keys must match the
    output head names.
    """

    loss_weights: dict[str, float] | None = None
    """
    Per-head loss weights. Keys must match the output head names.
    When None, or for missing/mismatched keys, losses are unweighted.
    """

    scheduler_gamma: float = 0.999
    """Learning scheduler decay rate. Set to 0.0 to disable the scheduler."""
88
+
89
+
90
@processor_state
class TorchSimpleState:
    # Instantiated torch model (created lazily on the first message).
    model: torch.nn.Module | None = None
    # Device the model and tensors live on.
    device: torch.device | None = None
    # Per-output-head channel coordinate axes for outgoing messages.
    chan_ax: dict[str, AxisArray.CoordinateAxis] | None = None
95
+
96
+
97
class TorchModelState(TorchSimpleState):
    # Optimizer used by partial_fit for online updates.
    optimizer: torch.optim.Optimizer | None = None
    # Exponential LR scheduler; None when scheduler_gamma <= 0.
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None
100
+
101
+
102
P = typing.TypeVar("P", bound=BaseStatefulTransformer)


class TorchProcessorMixin:
    """Mixin with shared functionality for torch processors."""

    def _import_model(self, class_path: str) -> type[torch.nn.Module]:
        """Dynamically import a model class from a fully qualified path string."""
        if class_path is None:
            raise ValueError("Model class path must be provided in settings.")
        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)

    def _infer_output_sizes(
        self: P, model: torch.nn.Module, n_input: int
    ) -> dict[str, int]:
        """Infer output channel size(s) with a dummy forward pass. Override if needed.

        Returns:
            Mapping of head name to last-dimension size; single-output models
            are reported under the key ``"output"``.
        """
        # NOTE(review): dummy input is (1, 1, n_input) — assumes the model
        # accepts (batch, time, channels); confirm for custom models.
        dummy_input = torch.zeros(1, 1, n_input, device=self._state.device)
        with torch.no_grad():
            output = model(dummy_input)

        if isinstance(output, dict):
            return {k: v.shape[-1] for k, v in output.items()}
        else:
            return {"output": output.shape[-1]}

    def _init_optimizer(self) -> None:
        """Create the AdamW optimizer and (optionally) an exponential LR scheduler."""
        self._state.optimizer = torch.optim.AdamW(
            self._state.model.parameters(),
            lr=self.settings.learning_rate,
            weight_decay=self.settings.weight_decay,
        )
        # scheduler_gamma <= 0.0 disables the scheduler entirely.
        self._state.scheduler = (
            torch.optim.lr_scheduler.ExponentialLR(
                self._state.optimizer, gamma=self.settings.scheduler_gamma
            )
            if self.settings.scheduler_gamma > 0.0
            else None
        )

    def _validate_loss_keys(self, output_keys: list[str]):
        """Raise ValueError if any model output head lacks a configured loss function."""
        if isinstance(self.settings.loss_fn, dict):
            missing = [k for k in output_keys if k not in self.settings.loss_fn]
            if missing:
                raise ValueError(f"Missing loss function(s) for output keys: {missing}")

    def _to_tensor(self: P, data: np.ndarray) -> torch.Tensor:
        """Convert array-like (or tensor) input to a tensor on the model's device/dtype."""
        dtype = torch.float32 if self.settings.single_precision else torch.float64
        if isinstance(data, torch.Tensor):
            # Detach/clone so training state never leaks through shared storage.
            return data.detach().clone().to(device=self._state.device, dtype=dtype)
        return torch.tensor(data, dtype=dtype, device=self._state.device)

    def save_checkpoint(self: P, path: str) -> None:
        """Save the current model (and optimizer, if present) state to a checkpoint file.

        Raises:
            RuntimeError: If called before the model has been initialized.
        """
        if self._state.model is None:
            raise RuntimeError("Model must be initialized before saving a checkpoint.")

        checkpoint = {
            "model_state_dict": self._state.model.state_dict(),
            "config": self.settings.model_kwargs or {},
        }

        # Add optimizer state if available (TorchSimpleState has no optimizer).
        if hasattr(self._state, "optimizer") and self._state.optimizer is not None:
            checkpoint["optimizer_state_dict"] = self._state.optimizer.state_dict()

        torch.save(checkpoint, path)

    def _ensure_batched(self, tensor: torch.Tensor) -> tuple[torch.Tensor, bool]:
        """
        Ensure tensor has a batch dimension.
        Returns the potentially modified tensor and a flag indicating whether a dimension was added.
        """
        if tensor.ndim == 2:
            return tensor.unsqueeze(0), True
        return tensor, False

    def _common_process(self: P, message: AxisArray) -> list[AxisArray]:
        """Run inference on a message; one output AxisArray per model head."""
        data = message.data
        data = self._to_tensor(data)

        # Add batch dimension if missing
        data, added_batch_dim = self._ensure_batched(data)

        with torch.no_grad():
            output = self._state.model(data)

        if isinstance(output, dict):
            # Multi-head model: one message per head, keyed by the head name.
            output_messages = [
                replace(
                    message,
                    data=value.cpu().numpy().squeeze(0)
                    if added_batch_dim
                    else value.cpu().numpy(),
                    axes={
                        **message.axes,
                        "ch": self._state.chan_ax[key],
                    },
                    key=key,
                )
                for key, value in output.items()
            ]
            return output_messages

        # Single-head model: keep the incoming message's key.
        return [
            replace(
                message,
                data=output.cpu().numpy().squeeze(0)
                if added_batch_dim
                else output.cpu().numpy(),
                axes={
                    **message.axes,
                    "ch": self._state.chan_ax["output"],
                },
            )
        ]

    def _common_reset_state(self: P, message: AxisArray, model_kwargs: dict) -> None:
        """Resolve device, instantiate the model, and build per-head channel axes.

        Args:
            message: First incoming message; channel count is read from its "ch" axis.
            model_kwargs: Constructor kwargs, mutated in place to include ``input_size``.

        Raises:
            ValueError: If ``model_kwargs["input_size"]`` conflicts with the
                incoming data's channel count.
        """
        n_input = message.data.shape[message.get_axis_idx("ch")]

        if "input_size" in model_kwargs:
            if model_kwargs["input_size"] != n_input:
                raise ValueError(
                    f"Mismatch between model_kwargs['input_size']={model_kwargs['input_size']} "
                    f"and input data channels={n_input}"
                )
        else:
            model_kwargs["input_size"] = n_input

        # Prefer cuda > mps > cpu. Fix: use the documented
        # torch.backends.mps.is_available(); torch.mps.is_available() is not
        # present on all torch builds/versions and can raise AttributeError.
        device = (
            "cuda"
            if torch.cuda.is_available()
            else ("mps" if torch.backends.mps.is_available() else "cpu")
        )
        device = self.settings.device or device
        self._state.device = torch.device(device)

        model_class = self._import_model(self.settings.model_class)

        self._state.model = self._init_model(
            model_class=model_class,
            params=model_kwargs,
            config_path=self.settings.config_path,
            checkpoint_path=self.settings.checkpoint_path,
            device=device,
        )

        # Inference-only by default; training code flips to train() temporarily.
        self._state.model.eval()

        output_sizes = self._infer_output_sizes(self._state.model, n_input)
        self._state.chan_ax = {
            head: AxisArray.CoordinateAxis(
                data=np.array([f"{head}_ch{_}" for _ in range(size)]),
                dims=["ch"],
            )
            for head, size in output_sizes.items()
        }
260
+
261
+
262
class TorchSimpleProcessor(
    BaseStatefulTransformer[
        TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState
    ],
    TorchProcessorMixin,
    ModelInitMixin,
):
    """Inference-only torch processor: no optimizer and no online training."""

    def _reset_state(self, message: AxisArray) -> None:
        """Initialize device, model, and channel axes from the first message."""
        kwargs = dict(self.settings.model_kwargs or {})
        self._common_reset_state(message, kwargs)

    def _process(self, message: AxisArray) -> list[AxisArray]:
        """Run the model on the message; returns one message per model head."""
        return self._common_process(message)
276
+
277
+
278
class TorchSimpleUnit(
    BaseTransformerUnit[
        TorchSimpleSettings,
        AxisArray,
        AxisArray,
        TorchSimpleProcessor,
    ]
):
    """ezmsg Unit wrapping :class:`TorchSimpleProcessor`."""

    SETTINGS = TorchSimpleSettings

    @ez.subscriber(BaseTransformerUnit.INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseTransformerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        # Multi-head models may produce several output messages per input.
        for out_msg in await self.processor.__acall__(message):
            yield self.OUTPUT_SIGNAL, out_msg
295
+
296
+
297
class TorchModelProcessor(
    BaseAdaptiveTransformer[TorchModelSettings, AxisArray, AxisArray, TorchModelState],
    TorchProcessorMixin,
    ModelInitMixin,
):
    """Torch processor supporting inference plus online (partial) training."""

    def _reset_state(self, message: AxisArray) -> None:
        """Initialize model/device/optimizer and validate the loss configuration."""
        kwargs = dict(self.settings.model_kwargs or {})
        self._common_reset_state(message, kwargs)
        self._init_optimizer()
        self._validate_loss_keys(list(self._state.chan_ax.keys()))

    def _process(self, message: AxisArray) -> list[AxisArray]:
        """Run inference on a message; returns one message per model head."""
        return self._common_process(message)

    def partial_fit(self, message: SampleMessage) -> None:
        """Perform a single optimization step on one labelled sample.

        The model is put in train mode for the step and restored to eval
        mode afterwards.

        Raises:
            ValueError: If ``settings.loss_fn`` is None, or a head has no
                matching loss function.
        """
        self._state.model.train()

        inputs = self._to_tensor(message.sample.data)
        inputs, added_batch = self._ensure_batched(inputs)

        # Normalize targets to a dict of tensors keyed by output head name.
        targets = message.trigger.value
        if not isinstance(targets, dict):
            targets = {"output": targets}
        targets = {name: self._to_tensor(val) for name, val in targets.items()}
        if added_batch:
            # Keep target batching consistent with the input.
            targets = {name: t.unsqueeze(0) for name, t in targets.items()}

        criteria = self.settings.loss_fn
        if criteria is None:
            raise ValueError("loss_fn must be provided in settings to use partial_fit")
        if not isinstance(criteria, dict):
            # Single loss function applied to every head.
            criteria = {name: criteria for name in targets}

        loss_weights = self.settings.loss_weights or {}

        with torch.set_grad_enabled(True):
            predictions = self._state.model(inputs)
            if not isinstance(predictions, dict):
                predictions = {"output": predictions}

            weighted_losses = []
            for name in targets:
                criterion = criteria.get(name)
                if criterion is None:
                    raise ValueError(
                        f"Loss function for key '{name}' is not defined in settings."
                    )
                if isinstance(criterion, torch.nn.CrossEntropyLoss):
                    # CrossEntropyLoss wants (batch, classes, ...) logits
                    # and integer class targets.
                    head_loss = criterion(
                        predictions[name].permute(0, 2, 1), targets[name].long()
                    )
                else:
                    head_loss = criterion(predictions[name], targets[name])
                weighted_losses.append(head_loss * loss_weights.get(name, 1.0))
            total_loss = sum(weighted_losses)

            self._state.optimizer.zero_grad()
            total_loss.backward()
            self._state.optimizer.step()
            if self._state.scheduler is not None:
                self._state.scheduler.step()

        self._state.model.eval()
360
+
361
+
362
class TorchModelUnit(
    BaseAdaptiveTransformerUnit[
        TorchModelSettings,
        AxisArray,
        AxisArray,
        TorchModelProcessor,
    ]
):
    """ezmsg Unit wrapping :class:`TorchModelProcessor`."""

    SETTINGS = TorchModelSettings

    @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        # Multi-head models may produce several output messages per input.
        for out_msg in await self.processor.__acall__(message):
            yield self.OUTPUT_SIGNAL, out_msg