careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,632 @@
|
|
|
1
|
+
"""CAREamics Lightning module."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable, Literal, Optional, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pytorch_lightning as L
|
|
7
|
+
from torch import Tensor, nn
|
|
8
|
+
|
|
9
|
+
from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig
|
|
10
|
+
from careamics.config.support import (
|
|
11
|
+
SupportedAlgorithm,
|
|
12
|
+
SupportedArchitecture,
|
|
13
|
+
SupportedLoss,
|
|
14
|
+
SupportedOptimizer,
|
|
15
|
+
SupportedScheduler,
|
|
16
|
+
)
|
|
17
|
+
from careamics.losses import loss_factory
|
|
18
|
+
from careamics.losses.loss_factory import LVAELossParameters
|
|
19
|
+
from careamics.models.lvae.likelihoods import (
|
|
20
|
+
GaussianLikelihood,
|
|
21
|
+
NoiseModelLikelihood,
|
|
22
|
+
likelihood_factory,
|
|
23
|
+
)
|
|
24
|
+
from careamics.models.lvae.noise_models import (
|
|
25
|
+
GaussianMixtureNoiseModel,
|
|
26
|
+
MultiChannelNoiseModel,
|
|
27
|
+
noise_model_factory,
|
|
28
|
+
)
|
|
29
|
+
from careamics.models.model_factory import model_factory
|
|
30
|
+
from careamics.transforms import Denormalize, ImageRestorationTTA
|
|
31
|
+
from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr
|
|
32
|
+
from careamics.utils.torch_utils import get_optimizer, get_scheduler
|
|
33
|
+
|
|
34
|
+
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class FCNModule(L.LightningModule):
|
|
38
|
+
"""
|
|
39
|
+
CAREamics Lightning module.
|
|
40
|
+
|
|
41
|
+
This class encapsulates the PyTorch model along with the training, validation,
|
|
42
|
+
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
43
|
+
|
|
44
|
+
Parameters
|
|
45
|
+
----------
|
|
46
|
+
algorithm_config : AlgorithmModel or dict
|
|
47
|
+
Algorithm configuration.
|
|
48
|
+
|
|
49
|
+
Attributes
|
|
50
|
+
----------
|
|
51
|
+
model : torch.nn.Module
|
|
52
|
+
PyTorch model.
|
|
53
|
+
loss_func : torch.nn.Module
|
|
54
|
+
Loss function.
|
|
55
|
+
optimizer_name : str
|
|
56
|
+
Optimizer name.
|
|
57
|
+
optimizer_params : dict
|
|
58
|
+
Optimizer parameters.
|
|
59
|
+
lr_scheduler_name : str
|
|
60
|
+
Learning rate scheduler name.
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
def __init__(self, algorithm_config: Union[FCNAlgorithmConfig, dict]) -> None:
|
|
64
|
+
"""Lightning module for CAREamics.
|
|
65
|
+
|
|
66
|
+
This class encapsulates the a PyTorch model along with the training, validation,
|
|
67
|
+
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
68
|
+
|
|
69
|
+
Parameters
|
|
70
|
+
----------
|
|
71
|
+
algorithm_config : AlgorithmModel or dict
|
|
72
|
+
Algorithm configuration.
|
|
73
|
+
"""
|
|
74
|
+
super().__init__()
|
|
75
|
+
# if loading from a checkpoint, AlgorithmModel needs to be instantiated
|
|
76
|
+
if isinstance(algorithm_config, dict):
|
|
77
|
+
algorithm_config = FCNAlgorithmConfig(**algorithm_config)
|
|
78
|
+
|
|
79
|
+
# create model and loss function
|
|
80
|
+
self.model: nn.Module = model_factory(algorithm_config.model)
|
|
81
|
+
self.loss_func = loss_factory(algorithm_config.loss)
|
|
82
|
+
|
|
83
|
+
# save optimizer and lr_scheduler names and parameters
|
|
84
|
+
self.optimizer_name = algorithm_config.optimizer.name
|
|
85
|
+
self.optimizer_params = algorithm_config.optimizer.parameters
|
|
86
|
+
self.lr_scheduler_name = algorithm_config.lr_scheduler.name
|
|
87
|
+
self.lr_scheduler_params = algorithm_config.lr_scheduler.parameters
|
|
88
|
+
|
|
89
|
+
def forward(self, x: Any) -> Any:
|
|
90
|
+
"""Forward pass.
|
|
91
|
+
|
|
92
|
+
Parameters
|
|
93
|
+
----------
|
|
94
|
+
x : Any
|
|
95
|
+
Input tensor.
|
|
96
|
+
|
|
97
|
+
Returns
|
|
98
|
+
-------
|
|
99
|
+
Any
|
|
100
|
+
Output tensor.
|
|
101
|
+
"""
|
|
102
|
+
return self.model(x)
|
|
103
|
+
|
|
104
|
+
def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
105
|
+
"""Training step.
|
|
106
|
+
|
|
107
|
+
Parameters
|
|
108
|
+
----------
|
|
109
|
+
batch : torch.Tensor
|
|
110
|
+
Input batch.
|
|
111
|
+
batch_idx : Any
|
|
112
|
+
Batch index.
|
|
113
|
+
|
|
114
|
+
Returns
|
|
115
|
+
-------
|
|
116
|
+
Any
|
|
117
|
+
Loss value.
|
|
118
|
+
"""
|
|
119
|
+
# TODO can N2V be simplified by returning mask*original_patch
|
|
120
|
+
x, *aux = batch
|
|
121
|
+
out = self.model(x)
|
|
122
|
+
loss = self.loss_func(out, *aux)
|
|
123
|
+
self.log(
|
|
124
|
+
"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
|
|
125
|
+
)
|
|
126
|
+
return loss
|
|
127
|
+
|
|
128
|
+
def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
|
|
129
|
+
"""Validation step.
|
|
130
|
+
|
|
131
|
+
Parameters
|
|
132
|
+
----------
|
|
133
|
+
batch : torch.Tensor
|
|
134
|
+
Input batch.
|
|
135
|
+
batch_idx : Any
|
|
136
|
+
Batch index.
|
|
137
|
+
"""
|
|
138
|
+
x, *aux = batch
|
|
139
|
+
out = self.model(x)
|
|
140
|
+
val_loss = self.loss_func(out, *aux)
|
|
141
|
+
|
|
142
|
+
# log validation loss
|
|
143
|
+
self.log(
|
|
144
|
+
"val_loss",
|
|
145
|
+
val_loss,
|
|
146
|
+
on_step=False,
|
|
147
|
+
on_epoch=True,
|
|
148
|
+
prog_bar=True,
|
|
149
|
+
logger=True,
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
153
|
+
"""Prediction step.
|
|
154
|
+
|
|
155
|
+
Parameters
|
|
156
|
+
----------
|
|
157
|
+
batch : torch.Tensor
|
|
158
|
+
Input batch.
|
|
159
|
+
batch_idx : Any
|
|
160
|
+
Batch index.
|
|
161
|
+
|
|
162
|
+
Returns
|
|
163
|
+
-------
|
|
164
|
+
Any
|
|
165
|
+
Model output.
|
|
166
|
+
"""
|
|
167
|
+
if self._trainer.datamodule.tiled:
|
|
168
|
+
x, *aux = batch
|
|
169
|
+
else:
|
|
170
|
+
x = batch
|
|
171
|
+
aux = []
|
|
172
|
+
|
|
173
|
+
# apply test-time augmentation if available
|
|
174
|
+
# TODO: probably wont work with batch size > 1
|
|
175
|
+
if self._trainer.datamodule.prediction_config.tta_transforms:
|
|
176
|
+
tta = ImageRestorationTTA()
|
|
177
|
+
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
178
|
+
augmented_output = []
|
|
179
|
+
for augmented in augmented_batch:
|
|
180
|
+
augmented_pred = self.model(augmented)
|
|
181
|
+
augmented_output.append(augmented_pred)
|
|
182
|
+
output = tta.backward(augmented_output)
|
|
183
|
+
else:
|
|
184
|
+
output = self.model(x)
|
|
185
|
+
|
|
186
|
+
# Denormalize the output
|
|
187
|
+
denorm = Denormalize(
|
|
188
|
+
image_means=self._trainer.datamodule.predict_dataset.image_means,
|
|
189
|
+
image_stds=self._trainer.datamodule.predict_dataset.image_stds,
|
|
190
|
+
)
|
|
191
|
+
denormalized_output = denorm(patch=output.cpu().numpy())
|
|
192
|
+
|
|
193
|
+
if len(aux) > 0: # aux can be tiling information
|
|
194
|
+
return denormalized_output, *aux
|
|
195
|
+
else:
|
|
196
|
+
return denormalized_output
|
|
197
|
+
|
|
198
|
+
def configure_optimizers(self) -> Any:
|
|
199
|
+
"""Configure optimizers and learning rate schedulers.
|
|
200
|
+
|
|
201
|
+
Returns
|
|
202
|
+
-------
|
|
203
|
+
Any
|
|
204
|
+
Optimizer and learning rate scheduler.
|
|
205
|
+
"""
|
|
206
|
+
# instantiate optimizer
|
|
207
|
+
optimizer_func = get_optimizer(self.optimizer_name)
|
|
208
|
+
optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
|
|
209
|
+
|
|
210
|
+
# and scheduler
|
|
211
|
+
scheduler_func = get_scheduler(self.lr_scheduler_name)
|
|
212
|
+
scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
|
|
213
|
+
|
|
214
|
+
return {
|
|
215
|
+
"optimizer": optimizer,
|
|
216
|
+
"lr_scheduler": scheduler,
|
|
217
|
+
"monitor": "val_loss", # otherwise triggers MisconfigurationException
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class VAEModule(L.LightningModule):
|
|
222
|
+
"""
|
|
223
|
+
CAREamics Lightning module.
|
|
224
|
+
|
|
225
|
+
This class encapsulates the a PyTorch model along with the training, validation,
|
|
226
|
+
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
227
|
+
|
|
228
|
+
Parameters
|
|
229
|
+
----------
|
|
230
|
+
algorithm_config : Union[VAEAlgorithmConfig, dict]
|
|
231
|
+
Algorithm configuration.
|
|
232
|
+
|
|
233
|
+
Attributes
|
|
234
|
+
----------
|
|
235
|
+
model : nn.Module
|
|
236
|
+
PyTorch model.
|
|
237
|
+
loss_func : nn.Module
|
|
238
|
+
Loss function.
|
|
239
|
+
optimizer_name : str
|
|
240
|
+
Optimizer name.
|
|
241
|
+
optimizer_params : dict
|
|
242
|
+
Optimizer parameters.
|
|
243
|
+
lr_scheduler_name : str
|
|
244
|
+
Learning rate scheduler name.
|
|
245
|
+
"""
|
|
246
|
+
|
|
247
|
+
def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None:
|
|
248
|
+
"""Lightning module for CAREamics.
|
|
249
|
+
|
|
250
|
+
This class encapsulates the a PyTorch model along with the training, validation,
|
|
251
|
+
and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
|
|
252
|
+
|
|
253
|
+
Parameters
|
|
254
|
+
----------
|
|
255
|
+
algorithm_config : Union[AlgorithmModel, dict]
|
|
256
|
+
Algorithm configuration.
|
|
257
|
+
"""
|
|
258
|
+
super().__init__()
|
|
259
|
+
# if loading from a checkpoint, AlgorithmModel needs to be instantiated
|
|
260
|
+
self.algorithm_config = (
|
|
261
|
+
VAEAlgorithmConfig(**algorithm_config)
|
|
262
|
+
if isinstance(algorithm_config, dict)
|
|
263
|
+
else algorithm_config
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
# TODO: log algorithm config
|
|
267
|
+
# self.save_hyperparameters(self.algorithm_config.model_dump())
|
|
268
|
+
|
|
269
|
+
# create model and loss function
|
|
270
|
+
self.model: nn.Module = model_factory(self.algorithm_config.model)
|
|
271
|
+
self.noise_model: NoiseModel = noise_model_factory(
|
|
272
|
+
self.algorithm_config.noise_model
|
|
273
|
+
)
|
|
274
|
+
self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
|
|
275
|
+
self.algorithm_config.noise_model_likelihood_model
|
|
276
|
+
)
|
|
277
|
+
self.gaussian_likelihood: GaussianLikelihood = likelihood_factory(
|
|
278
|
+
self.algorithm_config.gaussian_likelihood_model
|
|
279
|
+
)
|
|
280
|
+
self.loss_parameters = LVAELossParameters(
|
|
281
|
+
noise_model_likelihood=self.noise_model_likelihood,
|
|
282
|
+
gaussian_likelihood=self.gaussian_likelihood,
|
|
283
|
+
# TODO: musplit/denoisplit weights ?
|
|
284
|
+
) # type: ignore
|
|
285
|
+
self.loss_func = loss_factory(self.algorithm_config.loss)
|
|
286
|
+
|
|
287
|
+
# save optimizer and lr_scheduler names and parameters
|
|
288
|
+
self.optimizer_name = self.algorithm_config.optimizer.name
|
|
289
|
+
self.optimizer_params = self.algorithm_config.optimizer.parameters
|
|
290
|
+
self.lr_scheduler_name = self.algorithm_config.lr_scheduler.name
|
|
291
|
+
self.lr_scheduler_params = self.algorithm_config.lr_scheduler.parameters
|
|
292
|
+
|
|
293
|
+
# initialize running PSNR
|
|
294
|
+
self.running_psnr = [
|
|
295
|
+
RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
|
|
296
|
+
]
|
|
297
|
+
|
|
298
|
+
def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
|
|
299
|
+
"""Forward pass.
|
|
300
|
+
|
|
301
|
+
Parameters
|
|
302
|
+
----------
|
|
303
|
+
x : Tensor
|
|
304
|
+
Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
305
|
+
number of lateral inputs.
|
|
306
|
+
|
|
307
|
+
Returns
|
|
308
|
+
-------
|
|
309
|
+
tuple[Tensor, dict[str, Any]]
|
|
310
|
+
A tuple with the output tensor and additional data from the top-down pass.
|
|
311
|
+
"""
|
|
312
|
+
return self.model(x) # TODO Different model can have more than one output
|
|
313
|
+
|
|
314
|
+
def training_step(
|
|
315
|
+
self, batch: tuple[Tensor, Tensor], batch_idx: Any
|
|
316
|
+
) -> Optional[dict[str, Tensor]]:
|
|
317
|
+
"""Training step.
|
|
318
|
+
|
|
319
|
+
Parameters
|
|
320
|
+
----------
|
|
321
|
+
batch : tuple[Tensor, Tensor]
|
|
322
|
+
Input batch. It is a tuple with the input tensor and the target tensor.
|
|
323
|
+
The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
324
|
+
number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
|
|
325
|
+
where C is the number of target channels (e.g., 1 in HDN, >1 in
|
|
326
|
+
muSplit/denoiSplit).
|
|
327
|
+
batch_idx : Any
|
|
328
|
+
Batch index.
|
|
329
|
+
|
|
330
|
+
Returns
|
|
331
|
+
-------
|
|
332
|
+
Any
|
|
333
|
+
Loss value.
|
|
334
|
+
"""
|
|
335
|
+
x, target = batch
|
|
336
|
+
|
|
337
|
+
# Forward pass
|
|
338
|
+
out = self.model(x)
|
|
339
|
+
|
|
340
|
+
# Update loss parameters
|
|
341
|
+
# TODO rethink loss parameters
|
|
342
|
+
self.loss_parameters.current_epoch = self.current_epoch
|
|
343
|
+
|
|
344
|
+
# Compute loss
|
|
345
|
+
loss = self.loss_func(out, target, self.loss_parameters) # TODO ugly ?
|
|
346
|
+
|
|
347
|
+
# Logging
|
|
348
|
+
# TODO: implement a separate logging method?
|
|
349
|
+
self.log_dict(loss, on_step=True, on_epoch=True)
|
|
350
|
+
# self.log("lr", self, on_epoch=True)
|
|
351
|
+
return loss
|
|
352
|
+
|
|
353
|
+
def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: Any) -> None:
|
|
354
|
+
"""Validation step.
|
|
355
|
+
|
|
356
|
+
Parameters
|
|
357
|
+
----------
|
|
358
|
+
batch : tuple[Tensor, Tensor]
|
|
359
|
+
Input batch. It is a tuple with the input tensor and the target tensor.
|
|
360
|
+
The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
361
|
+
number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
|
|
362
|
+
where C is the number of target channels (e.g., 1 in HDN, >1 in
|
|
363
|
+
muSplit/denoiSplit).
|
|
364
|
+
batch_idx : Any
|
|
365
|
+
Batch index.
|
|
366
|
+
"""
|
|
367
|
+
x, target = batch
|
|
368
|
+
|
|
369
|
+
# Forward pass
|
|
370
|
+
out = self.model(x)
|
|
371
|
+
|
|
372
|
+
# Compute loss
|
|
373
|
+
loss = self.loss_func(out, target, self.loss_parameters)
|
|
374
|
+
|
|
375
|
+
# Logging
|
|
376
|
+
# Rename val_loss dict
|
|
377
|
+
loss = {"_".join(["val", k]): v for k, v in loss.items()}
|
|
378
|
+
self.log_dict(loss, on_epoch=True, prog_bar=True)
|
|
379
|
+
curr_psnr = self.compute_val_psnr(out, target)
|
|
380
|
+
for i, psnr in enumerate(curr_psnr):
|
|
381
|
+
self.log(f"val_psnr_ch{i+1}_batch", psnr, on_epoch=True)
|
|
382
|
+
|
|
383
|
+
def on_validation_epoch_end(self) -> None:
|
|
384
|
+
"""Validation epoch end."""
|
|
385
|
+
psnr_ = self.reduce_running_psnr()
|
|
386
|
+
if psnr_ is not None:
|
|
387
|
+
self.log("val_psnr", psnr_, on_epoch=True, prog_bar=True)
|
|
388
|
+
else:
|
|
389
|
+
self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
|
|
390
|
+
|
|
391
|
+
def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
392
|
+
"""Prediction step.
|
|
393
|
+
|
|
394
|
+
Parameters
|
|
395
|
+
----------
|
|
396
|
+
batch : Tensor
|
|
397
|
+
Input batch.
|
|
398
|
+
batch_idx : Any
|
|
399
|
+
Batch index.
|
|
400
|
+
|
|
401
|
+
Returns
|
|
402
|
+
-------
|
|
403
|
+
Any
|
|
404
|
+
Model output.
|
|
405
|
+
"""
|
|
406
|
+
if self._trainer.datamodule.tiled:
|
|
407
|
+
x, *aux = batch
|
|
408
|
+
else:
|
|
409
|
+
x = batch
|
|
410
|
+
aux = []
|
|
411
|
+
|
|
412
|
+
# apply test-time augmentation if available
|
|
413
|
+
# TODO: probably wont work with batch size > 1
|
|
414
|
+
if self._trainer.datamodule.prediction_config.tta_transforms:
|
|
415
|
+
tta = ImageRestorationTTA()
|
|
416
|
+
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
417
|
+
augmented_output = []
|
|
418
|
+
for augmented in augmented_batch:
|
|
419
|
+
augmented_pred = self.model(augmented)
|
|
420
|
+
augmented_output.append(augmented_pred)
|
|
421
|
+
output = tta.backward(augmented_output)
|
|
422
|
+
else:
|
|
423
|
+
output = self.model(x)
|
|
424
|
+
|
|
425
|
+
# Denormalize the output
|
|
426
|
+
denorm = Denormalize(
|
|
427
|
+
image_means=self._trainer.datamodule.predict_dataset.image_means,
|
|
428
|
+
image_stds=self._trainer.datamodule.predict_dataset.image_stds,
|
|
429
|
+
)
|
|
430
|
+
denormalized_output = denorm(patch=output.cpu().numpy())
|
|
431
|
+
|
|
432
|
+
if len(aux) > 0: # aux can be tiling information
|
|
433
|
+
return denormalized_output, *aux
|
|
434
|
+
else:
|
|
435
|
+
return denormalized_output
|
|
436
|
+
|
|
437
|
+
def configure_optimizers(self) -> Any:
|
|
438
|
+
"""Configure optimizers and learning rate schedulers.
|
|
439
|
+
|
|
440
|
+
Returns
|
|
441
|
+
-------
|
|
442
|
+
Any
|
|
443
|
+
Optimizer and learning rate scheduler.
|
|
444
|
+
"""
|
|
445
|
+
# instantiate optimizer
|
|
446
|
+
optimizer_func = get_optimizer(self.optimizer_name)
|
|
447
|
+
optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
|
|
448
|
+
|
|
449
|
+
# and scheduler
|
|
450
|
+
scheduler_func = get_scheduler(self.lr_scheduler_name)
|
|
451
|
+
scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
|
|
452
|
+
|
|
453
|
+
return {
|
|
454
|
+
"optimizer": optimizer,
|
|
455
|
+
"lr_scheduler": scheduler,
|
|
456
|
+
"monitor": "val_loss", # otherwise triggers MisconfigurationException
|
|
457
|
+
}
|
|
458
|
+
|
|
459
|
+
# TODO: find a way to move the following methods to a separate module
|
|
460
|
+
# TODO: this same operation is done in many other places, like in loss_func
|
|
461
|
+
# should we refactor LadderVAE so that it already outputs
|
|
462
|
+
# tuple(`mean`, `logvar`, `td_data`)?
|
|
463
|
+
def get_reconstructed_tensor(
|
|
464
|
+
self, model_outputs: tuple[Tensor, dict[str, Any]]
|
|
465
|
+
) -> Tensor:
|
|
466
|
+
"""Get the reconstructed tensor from the LVAE model outputs.
|
|
467
|
+
|
|
468
|
+
Parameters
|
|
469
|
+
----------
|
|
470
|
+
model_outputs : tuple[Tensor, dict[str, Any]]
|
|
471
|
+
Model outputs. It is a tuple with a tensor representing the predicted mean
|
|
472
|
+
and (optionally) logvar, and the top-down data dictionary.
|
|
473
|
+
|
|
474
|
+
Returns
|
|
475
|
+
-------
|
|
476
|
+
Tensor
|
|
477
|
+
Reconstructed tensor, i.e., the predicted mean.
|
|
478
|
+
"""
|
|
479
|
+
predictions, _ = model_outputs
|
|
480
|
+
if self.model.predict_logvar is None:
|
|
481
|
+
return predictions
|
|
482
|
+
elif self.model.predict_logvar == "pixelwise":
|
|
483
|
+
return predictions.chunk(2, dim=1)[0]
|
|
484
|
+
|
|
485
|
+
def compute_val_psnr(
|
|
486
|
+
self,
|
|
487
|
+
model_output: tuple[Tensor, dict[str, Any]],
|
|
488
|
+
target: Tensor,
|
|
489
|
+
psnr_func: Callable = scale_invariant_psnr,
|
|
490
|
+
) -> list[float]:
|
|
491
|
+
"""Compute the PSNR for the current validation batch.
|
|
492
|
+
|
|
493
|
+
Parameters
|
|
494
|
+
----------
|
|
495
|
+
model_output : tuple[Tensor, dict[str, Any]]
|
|
496
|
+
Model output, a tuple with the predicted mean and (optionally) logvar,
|
|
497
|
+
and the top-down data dictionary.
|
|
498
|
+
target : Tensor
|
|
499
|
+
Target tensor.
|
|
500
|
+
psnr_func : Callable, optional
|
|
501
|
+
PSNR function to use, by default `scale_invariant_psnr`.
|
|
502
|
+
|
|
503
|
+
Returns
|
|
504
|
+
-------
|
|
505
|
+
list[float]
|
|
506
|
+
PSNR for each channel in the current batch.
|
|
507
|
+
"""
|
|
508
|
+
out_channels = target.shape[1]
|
|
509
|
+
|
|
510
|
+
# get the reconstructed image
|
|
511
|
+
recons_img = self.get_reconstructed_tensor(model_output)
|
|
512
|
+
|
|
513
|
+
# update running psnr
|
|
514
|
+
for i in range(out_channels):
|
|
515
|
+
self.running_psnr[i].update(rec=recons_img[:, i], tar=target[:, i])
|
|
516
|
+
|
|
517
|
+
# compute psnr for each channel in the current batch
|
|
518
|
+
# TODO: this doesn't need do be a method of this class
|
|
519
|
+
# and hence can be moved to a separate module
|
|
520
|
+
return [
|
|
521
|
+
psnr_func(
|
|
522
|
+
gt=target[:, i].clone().detach().cpu().numpy(),
|
|
523
|
+
pred=recons_img[:, i].clone().detach().cpu().numpy(),
|
|
524
|
+
)
|
|
525
|
+
for i in range(out_channels)
|
|
526
|
+
]
|
|
527
|
+
|
|
528
|
+
def reduce_running_psnr(self) -> Optional[float]:
|
|
529
|
+
"""Reduce the running PSNR statistics and reset the running PSNR.
|
|
530
|
+
|
|
531
|
+
Returns
|
|
532
|
+
-------
|
|
533
|
+
Optional[float]
|
|
534
|
+
Running PSNR averaged over the different output channels.
|
|
535
|
+
"""
|
|
536
|
+
psnr_arr = [] # type: ignore
|
|
537
|
+
for i in range(len(self.running_psnr)):
|
|
538
|
+
psnr = self.running_psnr[i].get()
|
|
539
|
+
if psnr is None:
|
|
540
|
+
psnr_arr = None # type: ignore
|
|
541
|
+
break
|
|
542
|
+
psnr_arr.append(psnr.cpu().numpy())
|
|
543
|
+
self.running_psnr[i].reset()
|
|
544
|
+
# TODO: this line forces it to be a method of this class
|
|
545
|
+
# alternative is returning also the reset `running_psnr`
|
|
546
|
+
if psnr_arr is not None:
|
|
547
|
+
psnr = np.mean(psnr_arr)
|
|
548
|
+
return psnr
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
# TODO: make this LVAE compatible (?)
|
|
552
|
+
def create_careamics_module(
|
|
553
|
+
algorithm_type: Literal["fcn"],
|
|
554
|
+
algorithm: Union[SupportedAlgorithm, str],
|
|
555
|
+
loss: Union[SupportedLoss, str],
|
|
556
|
+
architecture: Union[SupportedArchitecture, str],
|
|
557
|
+
model_parameters: Optional[dict] = None,
|
|
558
|
+
optimizer: Union[SupportedOptimizer, str] = "Adam",
|
|
559
|
+
optimizer_parameters: Optional[dict] = None,
|
|
560
|
+
lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
|
|
561
|
+
lr_scheduler_parameters: Optional[dict] = None,
|
|
562
|
+
) -> Union[FCNModule, VAEModule]:
|
|
563
|
+
"""Create a CAREamics Lightning module.
|
|
564
|
+
|
|
565
|
+
This function exposes parameters used to create an AlgorithmModel instance,
|
|
566
|
+
triggering parameters validation.
|
|
567
|
+
|
|
568
|
+
Parameters
|
|
569
|
+
----------
|
|
570
|
+
algorithm_type : Literal["fcn"]
|
|
571
|
+
Algorithm type to use for training.
|
|
572
|
+
algorithm : SupportedAlgorithm or str
|
|
573
|
+
Algorithm to use for training (see SupportedAlgorithm).
|
|
574
|
+
loss : SupportedLoss or str
|
|
575
|
+
Loss function to use for training (see SupportedLoss).
|
|
576
|
+
architecture : SupportedArchitecture or str
|
|
577
|
+
Model architecture to use for training (see SupportedArchitecture).
|
|
578
|
+
model_parameters : dict, optional
|
|
579
|
+
Model parameters to use for training, by default {}. Model parameters are
|
|
580
|
+
defined in the relevant `torch.nn.Module` class, or Pyddantic model (see
|
|
581
|
+
`careamics.config.architectures`).
|
|
582
|
+
optimizer : SupportedOptimizer or str, optional
|
|
583
|
+
Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
|
|
584
|
+
optimizer_parameters : dict, optional
|
|
585
|
+
Optimizer parameters to use for training, as defined in `torch.optim`, by
|
|
586
|
+
default {}.
|
|
587
|
+
lr_scheduler : SupportedScheduler or str, optional
|
|
588
|
+
Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
|
|
589
|
+
(see SupportedScheduler).
|
|
590
|
+
lr_scheduler_parameters : dict, optional
|
|
591
|
+
Learning rate scheduler parameters to use for training, as defined in
|
|
592
|
+
`torch.optim`, by default {}.
|
|
593
|
+
|
|
594
|
+
Returns
|
|
595
|
+
-------
|
|
596
|
+
CAREamicsModule
|
|
597
|
+
CAREamics Lightning module.
|
|
598
|
+
"""
|
|
599
|
+
# create a AlgorithmModel compatible dictionary
|
|
600
|
+
if lr_scheduler_parameters is None:
|
|
601
|
+
lr_scheduler_parameters = {}
|
|
602
|
+
if optimizer_parameters is None:
|
|
603
|
+
optimizer_parameters = {}
|
|
604
|
+
if model_parameters is None:
|
|
605
|
+
model_parameters = {}
|
|
606
|
+
algorithm_configuration: dict[str, Any] = {
|
|
607
|
+
"algorithm_type": algorithm_type,
|
|
608
|
+
"algorithm": algorithm,
|
|
609
|
+
"loss": loss,
|
|
610
|
+
"optimizer": {
|
|
611
|
+
"name": optimizer,
|
|
612
|
+
"parameters": optimizer_parameters,
|
|
613
|
+
},
|
|
614
|
+
"lr_scheduler": {
|
|
615
|
+
"name": lr_scheduler,
|
|
616
|
+
"parameters": lr_scheduler_parameters,
|
|
617
|
+
},
|
|
618
|
+
}
|
|
619
|
+
model_configuration = {"architecture": architecture}
|
|
620
|
+
model_configuration.update(model_parameters)
|
|
621
|
+
|
|
622
|
+
# add model parameters to algorithm configuration
|
|
623
|
+
algorithm_configuration["model"] = model_configuration
|
|
624
|
+
|
|
625
|
+
# call the parent init using an AlgorithmModel instance
|
|
626
|
+
if algorithm_configuration["algorithm_type"] == "fcn":
|
|
627
|
+
return FCNModule(FCNAlgorithmConfig(**algorithm_configuration))
|
|
628
|
+
else:
|
|
629
|
+
raise NotImplementedError(
|
|
630
|
+
f"Model {algorithm_configuration['model']['architecture']} is not"
|
|
631
|
+
f"implemented or unknown."
|
|
632
|
+
)
|