ezmsg_learn-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registries.
- ezmsg/learn/__init__.py +2 -0
- ezmsg/learn/__version__.py +34 -0
- ezmsg/learn/dim_reduce/__init__.py +0 -0
- ezmsg/learn/dim_reduce/adaptive_decomp.py +284 -0
- ezmsg/learn/dim_reduce/incremental_decomp.py +181 -0
- ezmsg/learn/linear_model/__init__.py +1 -0
- ezmsg/learn/linear_model/adaptive_linear_regressor.py +6 -0
- ezmsg/learn/linear_model/cca.py +1 -0
- ezmsg/learn/linear_model/linear_regressor.py +5 -0
- ezmsg/learn/linear_model/sgd.py +5 -0
- ezmsg/learn/linear_model/slda.py +6 -0
- ezmsg/learn/model/__init__.py +0 -0
- ezmsg/learn/model/cca.py +122 -0
- ezmsg/learn/model/mlp.py +133 -0
- ezmsg/learn/model/mlp_old.py +49 -0
- ezmsg/learn/model/refit_kalman.py +401 -0
- ezmsg/learn/model/rnn.py +160 -0
- ezmsg/learn/model/transformer.py +175 -0
- ezmsg/learn/nlin_model/__init__.py +1 -0
- ezmsg/learn/nlin_model/mlp.py +6 -0
- ezmsg/learn/process/__init__.py +0 -0
- ezmsg/learn/process/adaptive_linear_regressor.py +157 -0
- ezmsg/learn/process/base.py +173 -0
- ezmsg/learn/process/linear_regressor.py +99 -0
- ezmsg/learn/process/mlp_old.py +200 -0
- ezmsg/learn/process/refit_kalman.py +407 -0
- ezmsg/learn/process/rnn.py +266 -0
- ezmsg/learn/process/sgd.py +131 -0
- ezmsg/learn/process/sklearn.py +274 -0
- ezmsg/learn/process/slda.py +119 -0
- ezmsg/learn/process/torch.py +378 -0
- ezmsg/learn/process/transformer.py +222 -0
- ezmsg/learn/util.py +66 -0
- ezmsg_learn-1.0.dist-info/METADATA +34 -0
- ezmsg_learn-1.0.dist-info/RECORD +36 -0
- ezmsg_learn-1.0.dist-info/WHEEL +4 -0
ezmsg/learn/process/slda.py
@@ -0,0 +1,119 @@
import typing

import ezmsg.core as ez
import numpy as np
from ezmsg.sigproc.base import (
    BaseStatefulTransformer,
    processor_state,
    BaseTransformerUnit,
)
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

from ..util import ClassifierMessage


class SLDASettings(ez.Settings):
    settings_path: str
    axis: str = "time"


@processor_state
class SLDAState:
    lda: LDA
    out_template: typing.Optional[ClassifierMessage] = None


class SLDATransformer(
    BaseStatefulTransformer[SLDASettings, AxisArray, ClassifierMessage, SLDAState]
):
    def _reset_state(self, message: AxisArray) -> None:
        if self.settings.settings_path[-4:] == ".mat":
            # Expects a very specific format from a specific project. Not for general use.
            import scipy.io as sio

            matlab_sLDA = sio.loadmat(self.settings.settings_path, squeeze_me=True)
            temp_weights = matlab_sLDA["weights"][1, 1:]
            temp_intercept = matlab_sLDA["weights"][1, 0]

            # Create weights and use zeros for channels we do not keep.
            channels = matlab_sLDA["channels"] - 4
            channels -= channels[0]  # Offsets are wrong somehow.
            n_channels = message.data.shape[message.dims.index("ch")]
            valid_indices = [ch for ch in channels if ch < n_channels]
            full_weights = np.zeros(n_channels)
            full_weights[valid_indices] = temp_weights[: len(valid_indices)]

            lda = LDA(solver="lsqr", shrinkage="auto")
            lda.classes_ = np.asarray([0, 1])
            lda.coef_ = np.expand_dims(full_weights, axis=0)
            lda.intercept_ = temp_intercept  # TODO: Is this supposed to be per-channel? Why the [1, 0]?
            self.state.lda = lda
            # mean = matlab_sLDA['mXtrain']
            # std = matlab_sLDA['sXtrain']
            # lags = matlab_sLDA['lags'] + 1
        else:
            import pickle

            with open(self.settings.settings_path, "rb") as f:
                self.state.lda = pickle.load(f)

        # Create template ClassifierMessage using lda.classes_
        out_labels = self.state.lda.classes_.tolist()
        zero_shape = (0, len(out_labels))
        self.state.out_template = ClassifierMessage(
            data=np.zeros(zero_shape, dtype=message.data.dtype),
            dims=[self.settings.axis, "classes"],
            axes={
                self.settings.axis: message.axes[self.settings.axis],
                "classes": AxisArray.CoordinateAxis(
                    data=np.array(out_labels), dims=["classes"]
                ),
            },
            labels=out_labels,
            key=message.key,
        )

    def _process(self, message: AxisArray) -> ClassifierMessage:
        samp_ax_idx = message.dims.index(self.settings.axis)
        X = np.moveaxis(message.data, samp_ax_idx, 0)

        if X.shape[0]:
            if (
                isinstance(self.settings.settings_path, str)
                and self.settings.settings_path[-4:] == ".mat"
            ):
                # Assumes F-contiguous weights
                pred_probas = []
                for samp in X:
                    tmp = samp.flatten(order="F") * 1e-6
                    tmp = np.expand_dims(tmp, axis=0)
                    probas = self.state.lda.predict_proba(tmp)
                    pred_probas.append(probas)
                pred_probas = np.concatenate(pred_probas, axis=0)
            else:
                # This creates a copy.
                X = X.reshape(X.shape[0], -1)
                pred_probas = self.state.lda.predict_proba(X)

            update_ax = self.state.out_template.axes[self.settings.axis]
            update_ax.offset = message.axes[self.settings.axis].offset

            return replace(
                self.state.out_template,
                data=pred_probas,
                axes={
                    **self.state.out_template.axes,
                    # `replace` will copy the minimal set of fields
                    self.settings.axis: replace(update_ax, offset=update_ax.offset),
                },
            )
        else:
            return self.state.out_template


class SLDA(
    BaseTransformerUnit[SLDASettings, AxisArray, ClassifierMessage, SLDATransformer]
):
    SETTINGS = SLDASettings
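For orientation, a minimal offline sketch of driving this transformer follows. It assumes SLDATransformer can be constructed directly from SLDASettings and called on an AxisArray, as ezmsg-sigproc BaseStatefulTransformer subclasses generally can be; the pickle path, channel count, and axis construction below are illustrative and are not part of this package.

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.learn.process.slda import SLDASettings, SLDATransformer

# "lda.pkl" is a placeholder for a pickled, already-fitted sklearn LDA whose
# feature count matches the flattened non-time dimensions of the input (here, 8 channels).
proc = SLDATransformer(settings=SLDASettings(settings_path="lda.pkl", axis="time"))

msg = AxisArray(
    data=np.random.randn(10, 8),  # 10 time samples x 8 channels
    dims=["time", "ch"],
    axes={"time": AxisArray.LinearAxis(gain=0.01, offset=0.0)},  # 100 Hz time axis
    key="example",
)
out = proc(msg)  # ClassifierMessage of per-class probabilities, one row per time sample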
ezmsg/learn/process/torch.py
@@ -0,0 +1,378 @@
import importlib
import typing

import ezmsg.core as ez
import numpy as np
import torch
from ezmsg.sigproc.base import (
    BaseAdaptiveTransformer,
    BaseAdaptiveTransformerUnit,
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)
from ezmsg.sigproc.sampler import SampleMessage
from ezmsg.sigproc.util.profile import profile_subpub
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace

from .base import ModelInitMixin


class TorchSimpleSettings(ez.Settings):
    model_class: str
    """
    Fully qualified class path of the model to be used.
    Example: "my_module.MyModelClass"
    This class should inherit from `torch.nn.Module`.
    """

    checkpoint_path: str | None = None
    """
    Path to a checkpoint file containing model weights.
    If None, the model will be initialized with random weights.
    If parameters inferred from the weight sizes conflict with the settings / config,
    then the inferred parameters will take priority and a warning will be logged.
    """

    config_path: str | None = None
    """
    Path to a config file containing model parameters.
    Parameters loaded from the config file will take priority over settings.
    If settings differ from config parameters then a warning will be logged.
    If `checkpoint_path` is provided then any parameters inferred from the weights
    will take priority over the config parameters.
    """

    single_precision: bool = True
    """Use single precision (float32) instead of double precision (float64)"""

    device: str | None = None
    """
    Device to use for the model. If None, the device will be determined automatically,
    with preference for cuda > mps > cpu.
    """

    model_kwargs: dict[str, typing.Any] | None = None
    """
    Additional keyword arguments to pass to the model constructor.
    This can include parameters like `input_size`, `output_size`, etc.
    If a config file is provided, these parameters will be updated with the config values.
    If a checkpoint file is provided, these parameters will be updated with the inferred values
    from the model weights.
    """


class TorchModelSettings(TorchSimpleSettings):
    learning_rate: float = 0.001
    """Learning rate for the optimizer"""

    weight_decay: float = 0.0001
    """Weight decay for the optimizer"""

    loss_fn: torch.nn.Module | dict[str, torch.nn.Module] | None = None
    """
    Loss function(s) for the decoder. If using multiple heads, this should be a dictionary
    mapping head names to loss functions. The keys must match the output head names.
    """

    loss_weights: dict[str, float] | None = None
    """
    Weights for each loss function if using multiple heads.
    The keys must match the output head names.
    If None or missing/mismatched keys, losses will be unweighted.
    """

    scheduler_gamma: float = 0.999
    """Learning scheduler decay rate. Set to 0.0 to disable the scheduler."""


@processor_state
class TorchSimpleState:
    model: torch.nn.Module | None = None
    device: torch.device | None = None
    chan_ax: dict[str, AxisArray.CoordinateAxis] | None = None


class TorchModelState(TorchSimpleState):
    optimizer: torch.optim.Optimizer | None = None
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None


P = typing.TypeVar("P", bound=BaseStatefulTransformer)


class TorchProcessorMixin:
    """Mixin with shared functionality for torch processors."""

    def _import_model(self, class_path: str) -> type[torch.nn.Module]:
        """Dynamically import model class from string."""
        if class_path is None:
            raise ValueError("Model class path must be provided in settings.")
        module_path, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)

    def _infer_output_sizes(
        self: P, model: torch.nn.Module, n_input: int
    ) -> dict[str, int]:
        """Simple inference to get output channel size. Override if needed."""
        dummy_input = torch.zeros(1, 1, n_input, device=self._state.device)
        with torch.no_grad():
            output = model(dummy_input)

        if isinstance(output, dict):
            return {k: v.shape[-1] for k, v in output.items()}
        else:
            return {"output": output.shape[-1]}

    def _init_optimizer(self) -> None:
        self._state.optimizer = torch.optim.AdamW(
            self._state.model.parameters(),
            lr=self.settings.learning_rate,
            weight_decay=self.settings.weight_decay,
        )
        self._state.scheduler = (
            torch.optim.lr_scheduler.ExponentialLR(
                self._state.optimizer, gamma=self.settings.scheduler_gamma
            )
            if self.settings.scheduler_gamma > 0.0
            else None
        )

    def _validate_loss_keys(self, output_keys: list[str]):
        if isinstance(self.settings.loss_fn, dict):
            missing = [k for k in output_keys if k not in self.settings.loss_fn]
            if missing:
                raise ValueError(f"Missing loss function(s) for output keys: {missing}")

    def _to_tensor(self: P, data: np.ndarray) -> torch.Tensor:
        dtype = torch.float32 if self.settings.single_precision else torch.float64
        if isinstance(data, torch.Tensor):
            return data.detach().clone().to(device=self._state.device, dtype=dtype)
        return torch.tensor(data, dtype=dtype, device=self._state.device)

    def save_checkpoint(self: P, path: str) -> None:
        """Save the current model state to a checkpoint file."""
        if self._state.model is None:
            raise RuntimeError("Model must be initialized before saving a checkpoint.")

        checkpoint = {
            "model_state_dict": self._state.model.state_dict(),
            "config": self.settings.model_kwargs or {},
        }

        # Add optimizer state if available
        if hasattr(self._state, "optimizer") and self._state.optimizer is not None:
            checkpoint["optimizer_state_dict"] = self._state.optimizer.state_dict()

        torch.save(checkpoint, path)

    def _ensure_batched(self, tensor: torch.Tensor) -> tuple[torch.Tensor, bool]:
        """
        Ensure tensor has a batch dimension.
        Returns the potentially modified tensor and a flag indicating whether a dimension was added.
        """
        if tensor.ndim == 2:
            return tensor.unsqueeze(0), True
        return tensor, False

    def _common_process(self: P, message: AxisArray) -> list[AxisArray]:
        data = message.data
        data = self._to_tensor(data)

        # Add batch dimension if missing
        data, added_batch_dim = self._ensure_batched(data)

        with torch.no_grad():
            output = self._state.model(data)

        if isinstance(output, dict):
            output_messages = [
                replace(
                    message,
                    data=value.cpu().numpy().squeeze(0)
                    if added_batch_dim
                    else value.cpu().numpy(),
                    axes={
                        **message.axes,
                        "ch": self._state.chan_ax[key],
                    },
                    key=key,
                )
                for key, value in output.items()
            ]
            return output_messages

        return [
            replace(
                message,
                data=output.cpu().numpy().squeeze(0)
                if added_batch_dim
                else output.cpu().numpy(),
                axes={
                    **message.axes,
                    "ch": self._state.chan_ax["output"],
                },
            )
        ]

    def _common_reset_state(self: P, message: AxisArray, model_kwargs: dict) -> None:
        n_input = message.data.shape[message.get_axis_idx("ch")]

        if "input_size" in model_kwargs:
            if model_kwargs["input_size"] != n_input:
                raise ValueError(
                    f"Mismatch between model_kwargs['input_size']={model_kwargs['input_size']} "
                    f"and input data channels={n_input}"
                )
        else:
            model_kwargs["input_size"] = n_input

        device = (
            "cuda"
            if torch.cuda.is_available()
            else ("mps" if torch.mps.is_available() else "cpu")
        )
        device = self.settings.device or device
        self._state.device = torch.device(device)

        model_class = self._import_model(self.settings.model_class)

        self._state.model = self._init_model(
            model_class=model_class,
            params=model_kwargs,
            config_path=self.settings.config_path,
            checkpoint_path=self.settings.checkpoint_path,
            device=device,
        )

        self._state.model.eval()

        output_sizes = self._infer_output_sizes(self._state.model, n_input)
        self._state.chan_ax = {
            head: AxisArray.CoordinateAxis(
                data=np.array([f"{head}_ch{_}" for _ in range(size)]),
                dims=["ch"],
            )
            for head, size in output_sizes.items()
        }


class TorchSimpleProcessor(
    BaseStatefulTransformer[
        TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState
    ],
    TorchProcessorMixin,
    ModelInitMixin,
):
    def _reset_state(self, message: AxisArray) -> None:
        model_kwargs = dict(self.settings.model_kwargs or {})
        self._common_reset_state(message, model_kwargs)

    def _process(self, message: AxisArray) -> list[AxisArray]:
        """Process the input message and return the output messages."""
        return self._common_process(message)


class TorchSimpleUnit(
    BaseTransformerUnit[
        TorchSimpleSettings,
        AxisArray,
        AxisArray,
        TorchSimpleProcessor,
    ]
):
    SETTINGS = TorchSimpleSettings

    @ez.subscriber(BaseTransformerUnit.INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseTransformerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        results = await self.processor.__acall__(message)
        for result in results:
            yield self.OUTPUT_SIGNAL, result


class TorchModelProcessor(
    BaseAdaptiveTransformer[TorchModelSettings, AxisArray, AxisArray, TorchModelState],
    TorchProcessorMixin,
    ModelInitMixin,
):
    def _reset_state(self, message: AxisArray) -> None:
        model_kwargs = dict(self.settings.model_kwargs or {})
        self._common_reset_state(message, model_kwargs)
        self._init_optimizer()
        self._validate_loss_keys(list(self._state.chan_ax.keys()))

    def _process(self, message: AxisArray) -> list[AxisArray]:
        return self._common_process(message)

    def partial_fit(self, message: SampleMessage) -> None:
        self._state.model.train()

        X = self._to_tensor(message.sample.data)
        X, batched = self._ensure_batched(X)

        y_targ = message.trigger.value
        if not isinstance(y_targ, dict):
            y_targ = {"output": y_targ}
        y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}
        if batched:
            for key in y_targ:
                y_targ[key] = y_targ[key].unsqueeze(0)

        loss_fns = self.settings.loss_fn
        if loss_fns is None:
            raise ValueError("loss_fn must be provided in settings to use partial_fit")
        if not isinstance(loss_fns, dict):
            loss_fns = {k: loss_fns for k in y_targ.keys()}

        weights = self.settings.loss_weights or {}

        with torch.set_grad_enabled(True):
            y_pred = self._state.model(X)
            if not isinstance(y_pred, dict):
                y_pred = {"output": y_pred}

            losses = []
            for key in y_targ.keys():
                loss_fn = loss_fns.get(key)
                if loss_fn is None:
                    raise ValueError(
                        f"Loss function for key '{key}' is not defined in settings."
                    )
                if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
                    loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long())
                else:
                    loss = loss_fn(y_pred[key], y_targ[key])
                weight = weights.get(key, 1.0)
                losses.append(loss * weight)
            total_loss = sum(losses)

            self._state.optimizer.zero_grad()
            total_loss.backward()
            self._state.optimizer.step()
            if self._state.scheduler is not None:
                self._state.scheduler.step()

        self._state.model.eval()


class TorchModelUnit(
    BaseAdaptiveTransformerUnit[
        TorchModelSettings,
        AxisArray,
        AxisArray,
        TorchModelProcessor,
    ]
):
    SETTINGS = TorchModelSettings

    @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        results = await self.processor.__acall__(message)
        for result in results:
            yield self.OUTPUT_SIGNAL, result
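As with the SLDA sketch above, the following configuration example is hypothetical: the model class path "my_project.models.MyDecoder", its kwargs, and the checkpoint filename are placeholders, and direct construction of the processor assumes the usual ezmsg-sigproc processor interface. Inside an ezmsg pipeline, TorchModelUnit (or TorchSimpleUnit for inference-only use) would wrap the same settings.

import torch
from ezmsg.learn.process.torch import TorchModelProcessor, TorchModelSettings

settings = TorchModelSettings(
    model_class="my_project.models.MyDecoder",  # placeholder; must subclass torch.nn.Module
    checkpoint_path=None,  # None -> random init; set a path to load saved weights
    model_kwargs={"hidden_size": 64, "output_size": 2},  # placeholder kwargs; input_size is filled from the data
    learning_rate=1e-3,
    weight_decay=1e-4,
    loss_fn=torch.nn.MSELoss(),  # single head; use a dict keyed by head name for multi-head models
    scheduler_gamma=0.999,  # exponential LR decay applied after each partial_fit; 0.0 disables the scheduler
)

proc = TorchModelProcessor(settings=settings)
# Calling proc(axis_array) runs inference and returns one AxisArray per output head.
# proc.partial_fit(sample_message) performs a single optimizer step, using
# SampleMessage.sample.data as the input and SampleMessage.trigger.value as the target(s).
# proc.save_checkpoint("decoder.pt") writes model (and optimizer) state to disk.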