ezmsg-learn 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/learn/__init__.py +2 -0
- ezmsg/learn/__version__.py +34 -0
- ezmsg/learn/dim_reduce/__init__.py +0 -0
- ezmsg/learn/dim_reduce/adaptive_decomp.py +274 -0
- ezmsg/learn/dim_reduce/incremental_decomp.py +173 -0
- ezmsg/learn/linear_model/__init__.py +1 -0
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +12 -0
- ezmsg/learn/linear_model/cca.py +1 -0
- ezmsg/learn/linear_model/linear_regressor.py +9 -0
- ezmsg/learn/linear_model/sgd.py +9 -0
- ezmsg/learn/linear_model/slda.py +12 -0
- ezmsg/learn/model/__init__.py +0 -0
- ezmsg/learn/model/cca.py +122 -0
- ezmsg/learn/model/mlp.py +127 -0
- ezmsg/learn/model/mlp_old.py +49 -0
- ezmsg/learn/model/refit_kalman.py +369 -0
- ezmsg/learn/model/rnn.py +160 -0
- ezmsg/learn/model/transformer.py +175 -0
- ezmsg/learn/nlin_model/__init__.py +1 -0
- ezmsg/learn/nlin_model/mlp.py +10 -0
- ezmsg/learn/process/__init__.py +0 -0
- ezmsg/learn/process/adaptive_linear_regressor.py +142 -0
- ezmsg/learn/process/base.py +154 -0
- ezmsg/learn/process/linear_regressor.py +95 -0
- ezmsg/learn/process/mlp_old.py +188 -0
- ezmsg/learn/process/refit_kalman.py +403 -0
- ezmsg/learn/process/rnn.py +245 -0
- ezmsg/learn/process/sgd.py +117 -0
- ezmsg/learn/process/sklearn.py +241 -0
- ezmsg/learn/process/slda.py +110 -0
- ezmsg/learn/process/ssr.py +374 -0
- ezmsg/learn/process/torch.py +362 -0
- ezmsg/learn/process/transformer.py +215 -0
- ezmsg/learn/util.py +67 -0
- ezmsg_learn-1.1.0.dist-info/METADATA +30 -0
- ezmsg_learn-1.1.0.dist-info/RECORD +38 -0
- ezmsg_learn-1.1.0.dist-info/WHEEL +4 -0
- ezmsg_learn-1.1.0.dist-info/licenses/LICENSE +21 -0
ezmsg/learn/process/torch.py
ADDED
@@ -0,0 +1,362 @@

import importlib
import typing

import ezmsg.core as ez
import numpy as np
import torch
from ezmsg.baseproc import (
    BaseAdaptiveTransformer,
    BaseAdaptiveTransformerUnit,
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)
from ezmsg.baseproc.util.profile import profile_subpub
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .base import ModelInitMixin


class TorchSimpleSettings(ez.Settings):
    model_class: str
    """
    Fully qualified class path of the model to be used.
    Example: "my_module.MyModelClass"
    This class should inherit from `torch.nn.Module`.
    """

    checkpoint_path: str | None = None
    """
    Path to a checkpoint file containing model weights.
    If None, the model will be initialized with random weights.
    If parameters inferred from the weight sizes conflict with the settings / config,
    then the inferred parameters will take priority and a warning will be logged.
    """

    config_path: str | None = None
    """
    Path to a config file containing model parameters.
    Parameters loaded from the config file will take priority over settings.
    If settings differ from config parameters then a warning will be logged.
    If `checkpoint_path` is provided then any parameters inferred from the weights
    will take priority over the config parameters.
    """

    single_precision: bool = True
    """Use single precision (float32) instead of double precision (float64)"""

    device: str | None = None
    """
    Device to use for the model. If None, the device will be determined automatically,
    with preference for cuda > mps > cpu.
    """

    model_kwargs: dict[str, typing.Any] | None = None
    """
    Additional keyword arguments to pass to the model constructor.
    This can include parameters like `input_size`, `output_size`, etc.
    If a config file is provided, these parameters will be updated with the config values.
    If a checkpoint file is provided, these parameters will be updated with the inferred values
    from the model weights.
    """


class TorchModelSettings(TorchSimpleSettings):
    learning_rate: float = 0.001
    """Learning rate for the optimizer"""

    weight_decay: float = 0.0001
    """Weight decay for the optimizer"""

    loss_fn: torch.nn.Module | dict[str, torch.nn.Module] | None = None
    """
    Loss function(s) for the decoder. If using multiple heads, this should be a dictionary
    mapping head names to loss functions. The keys must match the output head names.
    """

    loss_weights: dict[str, float] | None = None
    """
    Weights for each loss function if using multiple heads.
    The keys must match the output head names.
    If None or missing/mismatched keys, losses will be unweighted.
    """

    scheduler_gamma: float = 0.999
    """Learning scheduler decay rate. Set to 0.0 to disable the scheduler."""


@processor_state
class TorchSimpleState:
    model: torch.nn.Module | None = None
    device: torch.device | None = None
    chan_ax: dict[str, AxisArray.CoordinateAxis] | None = None


class TorchModelState(TorchSimpleState):
    optimizer: torch.optim.Optimizer | None = None
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None


P = typing.TypeVar("P", bound=BaseStatefulTransformer)


class TorchProcessorMixin:
    """Mixin with shared functionality for torch processors."""

    def _import_model(self, class_path: str) -> type[torch.nn.Module]:
        """Dynamically import model class from string."""
        if class_path is None:
            raise ValueError("Model class path must be provided in settings.")
        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)

    def _infer_output_sizes(self: P, model: torch.nn.Module, n_input: int) -> dict[str, int]:
        """Simple inference to get output channel size. Override if needed."""
        dummy_input = torch.zeros(1, 1, n_input, device=self._state.device)
        with torch.no_grad():
            output = model(dummy_input)

        if isinstance(output, dict):
            return {k: v.shape[-1] for k, v in output.items()}
        else:
            return {"output": output.shape[-1]}

    def _init_optimizer(self) -> None:
        self._state.optimizer = torch.optim.AdamW(
            self._state.model.parameters(),
            lr=self.settings.learning_rate,
            weight_decay=self.settings.weight_decay,
        )
        self._state.scheduler = (
            torch.optim.lr_scheduler.ExponentialLR(self._state.optimizer, gamma=self.settings.scheduler_gamma)
            if self.settings.scheduler_gamma > 0.0
            else None
        )

    def _validate_loss_keys(self, output_keys: list[str]):
        if isinstance(self.settings.loss_fn, dict):
            missing = [k for k in output_keys if k not in self.settings.loss_fn]
            if missing:
                raise ValueError(f"Missing loss function(s) for output keys: {missing}")

    def _to_tensor(self: P, data: np.ndarray) -> torch.Tensor:
        dtype = torch.float32 if self.settings.single_precision else torch.float64
        if isinstance(data, torch.Tensor):
            return data.detach().clone().to(device=self._state.device, dtype=dtype)
        return torch.tensor(data, dtype=dtype, device=self._state.device)

    def save_checkpoint(self: P, path: str) -> None:
        """Save the current model state to a checkpoint file."""
        if self._state.model is None:
            raise RuntimeError("Model must be initialized before saving a checkpoint.")

        checkpoint = {
            "model_state_dict": self._state.model.state_dict(),
            "config": self.settings.model_kwargs or {},
        }

        # Add optimizer state if available
        if hasattr(self._state, "optimizer") and self._state.optimizer is not None:
            checkpoint["optimizer_state_dict"] = self._state.optimizer.state_dict()

        torch.save(checkpoint, path)

    def _ensure_batched(self, tensor: torch.Tensor) -> tuple[torch.Tensor, bool]:
        """
        Ensure tensor has a batch dimension.
        Returns the potentially modified tensor and a flag indicating whether a dimension was added.
        """
        if tensor.ndim == 2:
            return tensor.unsqueeze(0), True
        return tensor, False

    def _common_process(self: P, message: AxisArray) -> list[AxisArray]:
        data = message.data
        data = self._to_tensor(data)

        # Add batch dimension if missing
        data, added_batch_dim = self._ensure_batched(data)

        with torch.no_grad():
            output = self._state.model(data)

        if isinstance(output, dict):
            output_messages = [
                replace(
                    message,
                    data=value.cpu().numpy().squeeze(0) if added_batch_dim else value.cpu().numpy(),
                    axes={
                        **message.axes,
                        "ch": self._state.chan_ax[key],
                    },
                    key=key,
                )
                for key, value in output.items()
            ]
            return output_messages

        return [
            replace(
                message,
                data=output.cpu().numpy().squeeze(0) if added_batch_dim else output.cpu().numpy(),
                axes={
                    **message.axes,
                    "ch": self._state.chan_ax["output"],
                },
            )
        ]

    def _common_reset_state(self: P, message: AxisArray, model_kwargs: dict) -> None:
        n_input = message.data.shape[message.get_axis_idx("ch")]

        if "input_size" in model_kwargs:
            if model_kwargs["input_size"] != n_input:
                raise ValueError(
                    f"Mismatch between model_kwargs['input_size']={model_kwargs['input_size']} "
                    f"and input data channels={n_input}"
                )
        else:
            model_kwargs["input_size"] = n_input

        device = "cuda" if torch.cuda.is_available() else ("mps" if torch.mps.is_available() else "cpu")
        device = self.settings.device or device
        self._state.device = torch.device(device)

        model_class = self._import_model(self.settings.model_class)

        self._state.model = self._init_model(
            model_class=model_class,
            params=model_kwargs,
            config_path=self.settings.config_path,
            checkpoint_path=self.settings.checkpoint_path,
            device=device,
        )

        self._state.model.eval()

        output_sizes = self._infer_output_sizes(self._state.model, n_input)
        self._state.chan_ax = {
            head: AxisArray.CoordinateAxis(
                data=np.array([f"{head}_ch{_}" for _ in range(size)]),
                dims=["ch"],
            )
            for head, size in output_sizes.items()
        }


class TorchSimpleProcessor(
    BaseStatefulTransformer[TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState],
    TorchProcessorMixin,
    ModelInitMixin,
):
    def _reset_state(self, message: AxisArray) -> None:
        model_kwargs = dict(self.settings.model_kwargs or {})
        self._common_reset_state(message, model_kwargs)

    def _process(self, message: AxisArray) -> list[AxisArray]:
        """Process the input message and return the output messages."""
        return self._common_process(message)


class TorchSimpleUnit(
    BaseTransformerUnit[
        TorchSimpleSettings,
        AxisArray,
        AxisArray,
        TorchSimpleProcessor,
    ]
):
    SETTINGS = TorchSimpleSettings

    @ez.subscriber(BaseTransformerUnit.INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseTransformerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        results = await self.processor.__acall__(message)
        for result in results:
            yield self.OUTPUT_SIGNAL, result


class TorchModelProcessor(
    BaseAdaptiveTransformer[TorchModelSettings, AxisArray, AxisArray, TorchModelState],
    TorchProcessorMixin,
    ModelInitMixin,
):
    def _reset_state(self, message: AxisArray) -> None:
        model_kwargs = dict(self.settings.model_kwargs or {})
        self._common_reset_state(message, model_kwargs)
        self._init_optimizer()
        self._validate_loss_keys(list(self._state.chan_ax.keys()))

    def _process(self, message: AxisArray) -> list[AxisArray]:
        return self._common_process(message)

    def partial_fit(self, message: SampleMessage) -> None:
        self._state.model.train()

        X = self._to_tensor(message.sample.data)
        X, batched = self._ensure_batched(X)

        y_targ = message.trigger.value
        if not isinstance(y_targ, dict):
            y_targ = {"output": y_targ}
        y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
        if batched:
            for key in y_targ:
                y_targ[key] = y_targ[key].unsqueeze(0)

        loss_fns = self.settings.loss_fn
        if loss_fns is None:
            raise ValueError("loss_fn must be provided in settings to use partial_fit")
        if not isinstance(loss_fns, dict):
            loss_fns = {k: loss_fns for k in y_targ.keys()}

        weights = self.settings.loss_weights or {}

        with torch.set_grad_enabled(True):
            y_pred = self._state.model(X)
            if not isinstance(y_pred, dict):
                y_pred = {"output": y_pred}

            losses = []
            for key in y_targ.keys():
                loss_fn = loss_fns.get(key)
                if loss_fn is None:
                    raise ValueError(f"Loss function for key '{key}' is not defined in settings.")
                if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
                    loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long())
                else:
                    loss = loss_fn(y_pred[key], y_targ[key])
                weight = weights.get(key, 1.0)
                losses.append(loss * weight)
            total_loss = sum(losses)

            self._state.optimizer.zero_grad()
            total_loss.backward()
            self._state.optimizer.step()
            if self._state.scheduler is not None:
                self._state.scheduler.step()

        self._state.model.eval()


class TorchModelUnit(
    BaseAdaptiveTransformerUnit[
        TorchModelSettings,
        AxisArray,
        AxisArray,
        TorchModelProcessor,
    ]
):
    SETTINGS = TorchModelSettings

    @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        results = await self.processor.__acall__(message)
        for result in results:
            yield self.OUTPUT_SIGNAL, result
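
The `model_class` docstring above implies a simple contract: the processor imports the class by splitting the dotted path and calling `importlib.import_module`, instantiates it with `model_kwargs` (with `input_size` filled in from the message's "ch" axis), and probes it with a `(1, 1, input_size)` dummy forward pass to size the output channel axis. Below is a minimal sketch of a compatible model and its settings; the module path `my_models.LinearDecoder` and its `output_size` parameter are hypothetical illustrations, not part of ezmsg-learn.

```python
# Hypothetical sketch; `my_models.LinearDecoder` is not shipped with ezmsg-learn.
import torch

from ezmsg.learn.process.torch import TorchModelSettings


class LinearDecoder(torch.nn.Module):
    """Maps (batch, time, input_size) tensors to (batch, time, output_size)."""

    def __init__(self, input_size: int, output_size: int = 2):
        super().__init__()
        self.linear = torch.nn.Linear(input_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# `input_size` is deliberately omitted from model_kwargs; _common_reset_state
# fills it in from the channel count of the first message it sees.
settings = TorchModelSettings(
    model_class="my_models.LinearDecoder",  # must be importable at runtime
    model_kwargs={"output_size": 2},
    loss_fn=torch.nn.MSELoss(),  # only required if partial_fit will be used
    learning_rate=1e-3,
)
```
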
ezmsg/learn/process/transformer.py
ADDED
@@ -0,0 +1,215 @@

import typing

import ezmsg.core as ez
import torch
from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
from ezmsg.baseproc.util.profile import profile_subpub
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .base import ModelInitMixin
from .torch import (
    TorchModelSettings,
    TorchModelState,
    TorchProcessorMixin,
)


class TransformerSettings(TorchModelSettings):
    model_class: str = "ezmsg.learn.model.transformer.TransformerModel"
    """
    Fully qualified class path of the model to be used.
    This should be "ezmsg.learn.model.transformer.TransformerModel" for this processor.
    """
    autoregressive_head: str | None = None
    """
    The name of the output head used for autoregressive decoding.
    This should match one of the keys in the model's output dictionary.
    If None, the first output head will be used.
    """
    max_cache_len: int | None = 128
    """
    Maximum length of the target sequence cache for autoregressive decoding.
    This limits the context length during decoding to prevent excessive memory usage.
    If set to None, the cache will grow indefinitely.
    """


class TransformerState(TorchModelState):
    ar_head: str | None = None
    """
    The name of the autoregressive head used for decoding.
    This is set based on the `autoregressive_head` setting.
    If None, the first output head will be used.
    """
    tgt_cache: typing.Optional[torch.Tensor] = None
    """
    Cache for the target sequence used in autoregressive decoding.
    This is updated with each processed message to maintain context.
    """


class TransformerProcessor(
    BaseAdaptiveTransformer[TransformerSettings, AxisArray, AxisArray, TransformerState],
    TorchProcessorMixin,
    ModelInitMixin,
):
    @property
    def has_decoder(self) -> bool:
        return self.settings.model_kwargs.get("decoder_layers", 0) > 0

    def reset_cache(self) -> None:
        self._state.tgt_cache = None

    def _reset_state(self, message: AxisArray) -> None:
        model_kwargs = dict(self.settings.model_kwargs or {})
        self._common_reset_state(message, model_kwargs)
        self._init_optimizer()
        self._validate_loss_keys(list(self._state.chan_ax.keys()))

        self._state.tgt_cache = None
        if (
            self.settings.autoregressive_head is not None
            and self.settings.autoregressive_head not in self._state.chan_ax
        ):
            raise ValueError(
                f"Autoregressive head '{self.settings.autoregressive_head}' not found in target "
                f"dictionary keys: {list(self._state.chan_ax.keys())}"
            )
        self._state.ar_head = (
            self.settings.autoregressive_head
            if self.settings.autoregressive_head is not None
            else list(self._state.chan_ax.keys())[0]
        )

    def _process(self, message: AxisArray) -> list[AxisArray]:
        # If has_decoder is False, fall back to regular processing
        if not self.has_decoder:
            return self._common_process(message)

        x = self._to_tensor(message.data)
        x, _ = self._ensure_batched(x)
        if x.shape[0] > 1:
            raise ValueError("Autoregressive decoding only supports batch size 1.")

        with torch.no_grad():
            y_pred = self._state.model(x, tgt=self._state.tgt_cache)

        pred = y_pred[self._state.ar_head]
        if self._state.tgt_cache is None:
            self._state.tgt_cache = pred[:, -1:, :]
        else:
            self._state.tgt_cache = torch.cat([self._state.tgt_cache, pred[:, -1:, :]], dim=1)
        if self.settings.max_cache_len is not None:
            if self._state.tgt_cache.shape[1] > self.settings.max_cache_len:
                # Trim the cache to the maximum length
                self._state.tgt_cache = self._state.tgt_cache[:, -self.settings.max_cache_len :, :]

        if isinstance(y_pred, dict):
            return [
                replace(
                    message,
                    data=out.squeeze(0).cpu().numpy(),
                    axes={**message.axes, "ch": self._state.chan_ax[key]},
                    key=key,
                )
                for key, out in y_pred.items()
            ]
        else:
            return [
                replace(
                    message,
                    data=y_pred.squeeze(0).cpu().numpy(),
                    axes={**message.axes, "ch": self._state.chan_ax["output"]},
                )
            ]

    def partial_fit(self, message: SampleMessage) -> None:
        self._state.model.train()

        X = self._to_tensor(message.sample.data)
        X, batched = self._ensure_batched(X)

        y_targ = message.trigger.value
        if not isinstance(y_targ, dict):
            y_targ = {"output": y_targ}
        y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
        # Add batch dimension to y_targ values if missing
        if batched:
            for key in y_targ:
                y_targ[key] = y_targ[key].unsqueeze(0)

        loss_fns = self.settings.loss_fn
        if loss_fns is None:
            raise ValueError("loss_fn must be provided in settings to use partial_fit")
        if not isinstance(loss_fns, dict):
            loss_fns = {k: loss_fns for k in y_targ.keys()}

        weights = self.settings.loss_weights or {}

        if self.has_decoder:
            if X.shape[0] != 1:
                raise ValueError("Autoregressive decoding only supports batch size 1.")

            # Create shifted target for autoregressive head
            tgt_tensor = y_targ[self._state.ar_head]
            tgt = torch.cat(
                [
                    torch.zeros(
                        (1, 1, tgt_tensor.shape[-1]),
                        dtype=tgt_tensor.dtype,
                        device=tgt_tensor.device,
                    ),
                    tgt_tensor[:, :-1, :],
                ],
                dim=1,
            )

            # Reset tgt_cache at start of partial_fit to avoid stale context
            self.reset_cache()
            y_pred = self._state.model(X, tgt=tgt)
        else:
            # For non-autoregressive models, use the model directly
            y_pred = self._state.model(X)

        if not isinstance(y_pred, dict):
            y_pred = {"output": y_pred}

        with torch.set_grad_enabled(True):
            losses = []
            for key in y_targ.keys():
                loss_fn = loss_fns.get(key)
                if loss_fn is None:
                    raise ValueError(f"Loss function for key '{key}' is not defined in settings.")
                loss = loss_fn(y_pred[key], y_targ[key])
                weight = weights.get(key, 1.0)
                losses.append(loss * weight)
            total_loss = sum(losses)

            self._state.optimizer.zero_grad()
            total_loss.backward()
            self._state.optimizer.step()
            if self._state.scheduler is not None:
                self._state.scheduler.step()

        self._state.model.eval()


class TransformerUnit(
    BaseAdaptiveTransformerUnit[
        TransformerSettings,
        AxisArray,
        AxisArray,
        TransformerProcessor,
    ]
):
    SETTINGS = TransformerSettings

    @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        results = await self.processor.__acall__(message)
        for result in results:
            yield self.OUTPUT_SIGNAL, result
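
When the configured model has decoder layers (`model_kwargs["decoder_layers"] > 0`), `_process` above appends each new prediction from the selected head to `tgt_cache`, feeds the cache back in as `tgt`, and trims it to `max_cache_len` steps. A hedged configuration sketch follows; the head name `"position"` and any `TransformerModel` constructor arguments beyond `decoder_layers` are assumptions, not taken from this diff.

```python
# Configuration sketch only; head names and extra model kwargs are assumptions.
import torch

from ezmsg.learn.process.transformer import TransformerSettings

settings = TransformerSettings(
    # model_class already defaults to "ezmsg.learn.model.transformer.TransformerModel".
    model_kwargs={"decoder_layers": 2},  # > 0 enables the autoregressive path
    loss_fn={"position": torch.nn.MSELoss()},  # keys must match the model's output heads
    autoregressive_head="position",  # head whose predictions feed the tgt cache
    max_cache_len=128,  # keep only the most recent 128 decoded steps
)
```
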
ezmsg/learn/util.py
ADDED
@@ -0,0 +1,67 @@

import typing
from dataclasses import dataclass, field
from enum import Enum

import river.linear_model
import sklearn.linear_model
from ezmsg.util.messages.axisarray import AxisArray

# from sklearn.neural_network import MLPClassifier


class RegressorType(str, Enum):
    ADAPTIVE = "adaptive"
    STATIC = "static"


class AdaptiveLinearRegressor(str, Enum):
    LINEAR = "linear"
    LOGISTIC = "logistic"
    SGD = "sgd"
    PAR = "par"  # passive-aggressive
    # MLP = "mlp"


class StaticLinearRegressor(str, Enum):
    LINEAR = "linear"
    RIDGE = "ridge"


ADAPTIVE_REGRESSORS = {
    AdaptiveLinearRegressor.LINEAR: river.linear_model.LinearRegression,
    AdaptiveLinearRegressor.LOGISTIC: river.linear_model.LogisticRegression,
    AdaptiveLinearRegressor.SGD: sklearn.linear_model.SGDRegressor,
    AdaptiveLinearRegressor.PAR: sklearn.linear_model.PassiveAggressiveRegressor,
    # AdaptiveLinearRegressor.MLP: MLPClassifier,
}


# Function to get a regressor by type and name
def get_regressor(
    regressor_type: typing.Union[RegressorType, str],
    regressor_name: typing.Union[AdaptiveLinearRegressor, StaticLinearRegressor, str],
):
    if isinstance(regressor_type, str):
        regressor_type = RegressorType(regressor_type)

    if regressor_type == RegressorType.ADAPTIVE:
        if isinstance(regressor_name, str):
            regressor_name = AdaptiveLinearRegressor(regressor_name)
        return ADAPTIVE_REGRESSORS[regressor_name]
    elif regressor_type == RegressorType.STATIC:
        if isinstance(regressor_name, str):
            regressor_name = StaticLinearRegressor(regressor_name)
        return STATIC_REGRESSORS[regressor_name]
    else:
        raise ValueError(f"Unknown regressor type: {regressor_type}")


STATIC_REGRESSORS = {
    StaticLinearRegressor.LINEAR: sklearn.linear_model.LinearRegression,
    StaticLinearRegressor.RIDGE: sklearn.linear_model.Ridge,
}


@dataclass
class ClassifierMessage(AxisArray):
    labels: list[str] = field(default_factory=list)
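
`get_regressor` resolves a `(type, name)` pair to a regressor class, drawing adaptive models from `river` and static models from `scikit-learn`; note that it returns the class itself and leaves instantiation to the caller. A short usage sketch, assuming only what the mappings above define:

```python
from ezmsg.learn.util import get_regressor

# Enum members or their string values are both accepted.
adaptive_cls = get_regressor("adaptive", "logistic")  # river.linear_model.LogisticRegression
static_cls = get_regressor("static", "ridge")  # sklearn.linear_model.Ridge

# The caller constructs and fits the model itself.
model = static_cls(alpha=1.0)
```
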
ezmsg_learn-1.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,30 @@

Metadata-Version: 2.4
Name: ezmsg-learn
Version: 1.1.0
Summary: ezmsg namespace package for machine learning
Author-email: Chadwick Boulay <chadwick.boulay@gmail.com>
License-Expression: MIT
License-File: LICENSE
Requires-Python: >=3.10.15
Requires-Dist: ezmsg-baseproc>=1.0.2
Requires-Dist: ezmsg-sigproc>=2.14.0
Requires-Dist: river>=0.22.0
Requires-Dist: scikit-learn>=1.6.0
Requires-Dist: torch>=2.6.0
Description-Content-Type: text/markdown

# ezmsg-learn

This repository contains a Python package with modules for machine learning (ML)-related processing in the [`ezmsg`](https://www.ezmsg.org) framework. As ezmsg is intended primarily for processing unbounded streaming signals, so are the modules in this repo.

> If you are only interested in offline analysis without concern for reproducibility in online applications, then you should probably look elsewhere.

Processing units include dimensionality reduction, linear regression, and classification that can be initialized with known weights, or adapted on-the-fly with incoming (labeled) data. Machine-learning code depends on `river`, `scikit-learn`, `numpy`, and `torch`.

## Getting Started

This ezmsg namespace package is still highly experimental and under active development. It is not yet available on PyPI, so you will need to install it from source. The easiest way to do this is to use the `pip` command to install the package directly from GitHub:

```bash
pip install git+https://github.com/ezmsg-org/ezmsg-learn
```