careamics 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +55 -61
- careamics/cli/conf.py +24 -9
- careamics/cli/main.py +8 -8
- careamics/cli/utils.py +2 -4
- careamics/config/__init__.py +8 -0
- careamics/config/algorithms/__init__.py +4 -0
- careamics/config/algorithms/hdn_algorithm_model.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
- careamics/config/algorithms/n2v_algorithm_model.py +1 -2
- careamics/config/algorithms/vae_algorithm_model.py +53 -18
- careamics/config/architectures/lvae_model.py +12 -8
- careamics/config/callback_model.py +15 -11
- careamics/config/configuration.py +9 -8
- careamics/config/configuration_factories.py +892 -78
- careamics/config/data/data_model.py +7 -14
- careamics/config/data/ng_data_model.py +8 -15
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
- careamics/config/inference_model.py +6 -11
- careamics/config/likelihood_model.py +4 -4
- careamics/config/loss_model.py +6 -2
- careamics/config/nm_model.py +30 -7
- careamics/config/optimizer_models.py +1 -2
- careamics/config/support/supported_algorithms.py +5 -3
- careamics/config/support/supported_losses.py +5 -2
- careamics/config/training_model.py +8 -38
- careamics/config/transformations/normalize_model.py +3 -4
- careamics/config/transformations/xy_flip_model.py +2 -2
- careamics/config/transformations/xy_random_rotate90_model.py +2 -2
- careamics/config/validators/validator_utils.py +1 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
- careamics/dataset/in_memory_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +1 -2
- careamics/dataset/patching/random_patching.py +6 -6
- careamics/dataset/patching/sequential_patching.py +4 -4
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
- careamics/dataset_ng/dataset.py +3 -3
- careamics/dataset_ng/factory.py +19 -19
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
- careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
- careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
- careamics/file_io/read/__init__.py +0 -1
- careamics/lightning/__init__.py +16 -2
- careamics/lightning/callbacks/__init__.py +2 -0
- careamics/lightning/callbacks/data_stats_callback.py +23 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
- careamics/lightning/dataset_ng/data_module.py +43 -43
- careamics/lightning/lightning_module.py +166 -68
- careamics/lightning/microsplit_data_module.py +631 -0
- careamics/lightning/predict_data_module.py +16 -9
- careamics/lightning/train_data_module.py +29 -18
- careamics/losses/__init__.py +7 -1
- careamics/losses/loss_factory.py +9 -1
- careamics/losses/lvae/losses.py +94 -9
- careamics/lvae_training/dataset/__init__.py +8 -8
- careamics/lvae_training/dataset/config.py +56 -44
- careamics/lvae_training/dataset/lc_dataset.py +18 -12
- careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
- careamics/lvae_training/dataset/multich_dataset.py +24 -18
- careamics/lvae_training/dataset/multifile_dataset.py +6 -6
- careamics/model_io/bioimage/model_description.py +12 -11
- careamics/model_io/bmz_io.py +12 -8
- careamics/models/layers.py +5 -5
- careamics/models/lvae/likelihoods.py +30 -14
- careamics/models/lvae/lvae.py +2 -2
- careamics/models/lvae/noise_models.py +20 -14
- careamics/prediction_utils/__init__.py +8 -2
- careamics/prediction_utils/lvae_prediction.py +5 -5
- careamics/prediction_utils/prediction_outputs.py +48 -3
- careamics/prediction_utils/stitch_prediction.py +71 -0
- careamics/transforms/compose.py +9 -9
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/n2v_manipulate_torch.py +4 -4
- careamics/transforms/normalize.py +4 -6
- careamics/transforms/pixel_manipulation.py +6 -8
- careamics/transforms/pixel_manipulation_torch.py +5 -7
- careamics/transforms/xy_flip.py +3 -5
- careamics/transforms/xy_random_rotate90.py +4 -6
- careamics/utils/logging.py +8 -8
- careamics/utils/metrics.py +2 -2
- careamics/utils/plotting.py +1 -3
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
- careamics/dataset/zarr_dataset.py +0 -151
- careamics/file_io/read/zarr.py +0 -60
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from collections.abc import Callable
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Union, overload
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import pytorch_lightning as L
|
|
@@ -124,14 +124,14 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
124
124
|
self,
|
|
125
125
|
data_config: NGDataConfig,
|
|
126
126
|
*,
|
|
127
|
-
train_data:
|
|
128
|
-
train_data_target:
|
|
129
|
-
val_data:
|
|
130
|
-
val_data_target:
|
|
131
|
-
pred_data:
|
|
132
|
-
pred_data_target:
|
|
127
|
+
train_data: InputType | None = None,
|
|
128
|
+
train_data_target: InputType | None = None,
|
|
129
|
+
val_data: InputType | None = None,
|
|
130
|
+
val_data_target: InputType | None = None,
|
|
131
|
+
pred_data: InputType | None = None,
|
|
132
|
+
pred_data_target: InputType | None = None,
|
|
133
133
|
extension_filter: str = "",
|
|
134
|
-
val_percentage:
|
|
134
|
+
val_percentage: float | None = None,
|
|
135
135
|
val_minimum_split: int = 5,
|
|
136
136
|
use_in_memory: bool = True,
|
|
137
137
|
) -> None: ...
|
|
@@ -142,16 +142,16 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
142
142
|
self,
|
|
143
143
|
data_config: NGDataConfig,
|
|
144
144
|
*,
|
|
145
|
-
train_data:
|
|
146
|
-
train_data_target:
|
|
147
|
-
val_data:
|
|
148
|
-
val_data_target:
|
|
149
|
-
pred_data:
|
|
150
|
-
pred_data_target:
|
|
145
|
+
train_data: InputType | None = None,
|
|
146
|
+
train_data_target: InputType | None = None,
|
|
147
|
+
val_data: InputType | None = None,
|
|
148
|
+
val_data_target: InputType | None = None,
|
|
149
|
+
pred_data: InputType | None = None,
|
|
150
|
+
pred_data_target: InputType | None = None,
|
|
151
151
|
read_source_func: Callable,
|
|
152
|
-
read_kwargs:
|
|
152
|
+
read_kwargs: dict[str, Any] | None = None,
|
|
153
153
|
extension_filter: str = "",
|
|
154
|
-
val_percentage:
|
|
154
|
+
val_percentage: float | None = None,
|
|
155
155
|
val_minimum_split: int = 5,
|
|
156
156
|
use_in_memory: bool = True,
|
|
157
157
|
) -> None: ...
|
|
@@ -161,16 +161,16 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
161
161
|
self,
|
|
162
162
|
data_config: NGDataConfig,
|
|
163
163
|
*,
|
|
164
|
-
train_data:
|
|
165
|
-
train_data_target:
|
|
166
|
-
val_data:
|
|
167
|
-
val_data_target:
|
|
168
|
-
pred_data:
|
|
169
|
-
pred_data_target:
|
|
164
|
+
train_data: Any | None = None,
|
|
165
|
+
train_data_target: Any | None = None,
|
|
166
|
+
val_data: Any | None = None,
|
|
167
|
+
val_data_target: Any | None = None,
|
|
168
|
+
pred_data: Any | None = None,
|
|
169
|
+
pred_data_target: Any | None = None,
|
|
170
170
|
image_stack_loader: ImageStackLoader,
|
|
171
|
-
image_stack_loader_kwargs:
|
|
171
|
+
image_stack_loader_kwargs: dict[str, Any] | None = None,
|
|
172
172
|
extension_filter: str = "",
|
|
173
|
-
val_percentage:
|
|
173
|
+
val_percentage: float | None = None,
|
|
174
174
|
val_minimum_split: int = 5,
|
|
175
175
|
use_in_memory: bool = True,
|
|
176
176
|
) -> None: ...
|
|
@@ -179,18 +179,18 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
179
179
|
self,
|
|
180
180
|
data_config: NGDataConfig,
|
|
181
181
|
*,
|
|
182
|
-
train_data:
|
|
183
|
-
train_data_target:
|
|
184
|
-
val_data:
|
|
185
|
-
val_data_target:
|
|
186
|
-
pred_data:
|
|
187
|
-
pred_data_target:
|
|
188
|
-
read_source_func:
|
|
189
|
-
read_kwargs:
|
|
190
|
-
image_stack_loader:
|
|
191
|
-
image_stack_loader_kwargs:
|
|
182
|
+
train_data: Any | None = None,
|
|
183
|
+
train_data_target: Any | None = None,
|
|
184
|
+
val_data: Any | None = None,
|
|
185
|
+
val_data_target: Any | None = None,
|
|
186
|
+
pred_data: Any | None = None,
|
|
187
|
+
pred_data_target: Any | None = None,
|
|
188
|
+
read_source_func: Callable | None = None,
|
|
189
|
+
read_kwargs: dict[str, Any] | None = None,
|
|
190
|
+
image_stack_loader: ImageStackLoader | None = None,
|
|
191
|
+
image_stack_loader_kwargs: dict[str, Any] | None = None,
|
|
192
192
|
extension_filter: str = "",
|
|
193
|
-
val_percentage:
|
|
193
|
+
val_percentage: float | None = None,
|
|
194
194
|
val_minimum_split: int = 5,
|
|
195
195
|
use_in_memory: bool = True,
|
|
196
196
|
) -> None:
|
|
@@ -280,7 +280,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
280
280
|
def _validate_input_target_type_consistency(
|
|
281
281
|
self,
|
|
282
282
|
input_data: InputType,
|
|
283
|
-
target_data:
|
|
283
|
+
target_data: InputType | None,
|
|
284
284
|
) -> None:
|
|
285
285
|
"""Validate if the input and target data types are consistent.
|
|
286
286
|
|
|
@@ -314,7 +314,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
314
314
|
self,
|
|
315
315
|
input_data,
|
|
316
316
|
target_data=None,
|
|
317
|
-
) -> tuple[list[Path],
|
|
317
|
+
) -> tuple[list[Path], list[Path] | None]:
|
|
318
318
|
"""List files from input and target directories.
|
|
319
319
|
|
|
320
320
|
Parameters
|
|
@@ -347,7 +347,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
347
347
|
self,
|
|
348
348
|
input_data,
|
|
349
349
|
target_data=None,
|
|
350
|
-
) -> tuple[list[Path],
|
|
350
|
+
) -> tuple[list[Path], list[Path] | None]:
|
|
351
351
|
"""Create a list of file paths from the input and target data.
|
|
352
352
|
|
|
353
353
|
Parameters
|
|
@@ -379,7 +379,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
379
379
|
def _validate_array_input(
|
|
380
380
|
self,
|
|
381
381
|
input_data: InputType,
|
|
382
|
-
target_data:
|
|
382
|
+
target_data: InputType | None,
|
|
383
383
|
) -> tuple[Any, Any]:
|
|
384
384
|
"""Validate if the input data is a numpy array.
|
|
385
385
|
|
|
@@ -408,8 +408,8 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
408
408
|
)
|
|
409
409
|
|
|
410
410
|
def _validate_path_input(
|
|
411
|
-
self, input_data: InputType, target_data:
|
|
412
|
-
) -> tuple[list[Path],
|
|
411
|
+
self, input_data: InputType, target_data: InputType | None
|
|
412
|
+
) -> tuple[list[Path], list[Path] | None]:
|
|
413
413
|
"""Validate if the input data is a path or a list of paths.
|
|
414
414
|
|
|
415
415
|
Parameters
|
|
@@ -488,8 +488,8 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
488
488
|
|
|
489
489
|
def _initialize_data_pair(
|
|
490
490
|
self,
|
|
491
|
-
input_data:
|
|
492
|
-
target_data:
|
|
491
|
+
input_data: InputType | None,
|
|
492
|
+
target_data: InputType | None,
|
|
493
493
|
) -> tuple[Any, Any]:
|
|
494
494
|
"""
|
|
495
495
|
Initialize a pair of input and target data.
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
"""CAREamics Lightning module."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import Callable
|
|
4
|
-
from typing import Any, Literal,
|
|
4
|
+
from typing import Any, Literal, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import pytorch_lightning as L
|
|
8
|
-
|
|
8
|
+
import torch
|
|
9
9
|
|
|
10
10
|
from careamics.config import (
|
|
11
11
|
N2VAlgorithm,
|
|
@@ -71,7 +71,9 @@ class FCNModule(L.LightningModule):
|
|
|
71
71
|
Learning rate scheduler name.
|
|
72
72
|
"""
|
|
73
73
|
|
|
74
|
-
def __init__(
|
|
74
|
+
def __init__(
|
|
75
|
+
self, algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm, dict]
|
|
76
|
+
) -> None:
|
|
75
77
|
"""Lightning module for CAREamics.
|
|
76
78
|
|
|
77
79
|
This class encapsulates a PyTorch model along with the training, validation,
|
|
@@ -90,7 +92,7 @@ class FCNModule(L.LightningModule):
|
|
|
90
92
|
# create preprocessing, model and loss function
|
|
91
93
|
if isinstance(algorithm_config, N2VAlgorithm):
|
|
92
94
|
self.use_n2v = True
|
|
93
|
-
self.n2v_preprocess:
|
|
95
|
+
self.n2v_preprocess: N2VManipulateTorch | None = N2VManipulateTorch(
|
|
94
96
|
n2v_manipulate_config=algorithm_config.n2v_config
|
|
95
97
|
)
|
|
96
98
|
else:
|
|
@@ -98,7 +100,7 @@ class FCNModule(L.LightningModule):
|
|
|
98
100
|
self.n2v_preprocess = None
|
|
99
101
|
|
|
100
102
|
self.algorithm = algorithm_config.algorithm
|
|
101
|
-
self.model: nn.Module = model_factory(algorithm_config.model)
|
|
103
|
+
self.model: torch.nn.Module = model_factory(algorithm_config.model)
|
|
102
104
|
self.loss_func = loss_factory(algorithm_config.loss)
|
|
103
105
|
|
|
104
106
|
# save optimizer and lr_scheduler names and parameters
|
|
@@ -122,12 +124,12 @@ class FCNModule(L.LightningModule):
|
|
|
122
124
|
"""
|
|
123
125
|
return self.model(x)
|
|
124
126
|
|
|
125
|
-
def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
127
|
+
def training_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
|
|
126
128
|
"""Training step.
|
|
127
129
|
|
|
128
130
|
Parameters
|
|
129
131
|
----------
|
|
130
|
-
batch : torch.Tensor
|
|
132
|
+
batch : torch.Tensor
|
|
131
133
|
Input batch.
|
|
132
134
|
batch_idx : Any
|
|
133
135
|
Batch index.
|
|
@@ -154,12 +156,12 @@ class FCNModule(L.LightningModule):
|
|
|
154
156
|
self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
|
|
155
157
|
return loss
|
|
156
158
|
|
|
157
|
-
def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
|
|
159
|
+
def validation_step(self, batch: torch.Tensor, batch_idx: Any) -> None:
|
|
158
160
|
"""Validation step.
|
|
159
161
|
|
|
160
162
|
Parameters
|
|
161
163
|
----------
|
|
162
|
-
batch : torch.Tensor
|
|
164
|
+
batch : torch.Tensor
|
|
163
165
|
Input batch.
|
|
164
166
|
batch_idx : Any
|
|
165
167
|
Batch index.
|
|
@@ -184,12 +186,12 @@ class FCNModule(L.LightningModule):
|
|
|
184
186
|
logger=True,
|
|
185
187
|
)
|
|
186
188
|
|
|
187
|
-
def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
189
|
+
def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
|
|
188
190
|
"""Prediction step.
|
|
189
191
|
|
|
190
192
|
Parameters
|
|
191
193
|
----------
|
|
192
|
-
batch : torch.Tensor
|
|
194
|
+
batch : torch.Tensor
|
|
193
195
|
Input batch.
|
|
194
196
|
batch_idx : Any
|
|
195
197
|
Batch index.
|
|
@@ -330,21 +332,23 @@ class VAEModule(L.LightningModule):
|
|
|
330
332
|
# self.save_hyperparameters(self.algorithm_config.model_dump())
|
|
331
333
|
|
|
332
334
|
# create model
|
|
333
|
-
self.model: nn.Module = model_factory(self.algorithm_config.model)
|
|
335
|
+
self.model: torch.nn.Module = model_factory(self.algorithm_config.model)
|
|
334
336
|
|
|
337
|
+
# supervised_mode
|
|
338
|
+
self.supervised_mode = self.algorithm_config.is_supervised
|
|
335
339
|
# create loss function
|
|
336
|
-
self.noise_model:
|
|
340
|
+
self.noise_model: NoiseModel | None = noise_model_factory(
|
|
337
341
|
self.algorithm_config.noise_model
|
|
338
342
|
)
|
|
339
343
|
|
|
340
|
-
self.noise_model_likelihood:
|
|
341
|
-
|
|
344
|
+
self.noise_model_likelihood: NoiseModelLikelihood | None = None
|
|
345
|
+
if self.algorithm_config.noise_model_likelihood is not None:
|
|
346
|
+
self.noise_model_likelihood = likelihood_factory(
|
|
342
347
|
config=self.algorithm_config.noise_model_likelihood,
|
|
343
348
|
noise_model=self.noise_model,
|
|
344
349
|
)
|
|
345
|
-
)
|
|
346
350
|
|
|
347
|
-
self.gaussian_likelihood:
|
|
351
|
+
self.gaussian_likelihood: GaussianLikelihood | None = likelihood_factory(
|
|
348
352
|
self.algorithm_config.gaussian_likelihood
|
|
349
353
|
)
|
|
350
354
|
|
|
@@ -362,30 +366,43 @@ class VAEModule(L.LightningModule):
|
|
|
362
366
|
RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
|
|
363
367
|
]
|
|
364
368
|
|
|
365
|
-
def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
|
|
369
|
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
|
|
366
370
|
"""Forward pass.
|
|
367
371
|
|
|
368
372
|
Parameters
|
|
369
373
|
----------
|
|
370
|
-
x : Tensor
|
|
374
|
+
x : torch.Tensor
|
|
371
375
|
Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
372
376
|
number of lateral inputs.
|
|
373
377
|
|
|
374
378
|
Returns
|
|
375
379
|
-------
|
|
376
|
-
tuple[Tensor, dict[str, Any]]
|
|
380
|
+
tuple[torch.Tensor, dict[str, Any]]
|
|
377
381
|
A tuple with the output tensor and additional data from the top-down pass.
|
|
378
382
|
"""
|
|
379
383
|
return self.model(x) # TODO Different model can have more than one output
|
|
380
384
|
|
|
385
|
+
def set_data_stats(self, data_mean, data_std):
|
|
386
|
+
"""Set data mean and std for the noise model likelihood.
|
|
387
|
+
|
|
388
|
+
Parameters
|
|
389
|
+
----------
|
|
390
|
+
data_mean : float
|
|
391
|
+
Mean of the data.
|
|
392
|
+
data_std : float
|
|
393
|
+
Standard deviation of the data.
|
|
394
|
+
"""
|
|
395
|
+
if self.noise_model_likelihood is not None:
|
|
396
|
+
self.noise_model_likelihood.set_data_stats(data_mean, data_std)
|
|
397
|
+
|
|
381
398
|
def training_step(
|
|
382
|
-
self, batch: tuple[Tensor, Tensor], batch_idx: Any
|
|
383
|
-
) ->
|
|
399
|
+
self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
|
|
400
|
+
) -> dict[str, torch.Tensor] | None:
|
|
384
401
|
"""Training step.
|
|
385
402
|
|
|
386
403
|
Parameters
|
|
387
404
|
----------
|
|
388
|
-
batch : tuple[Tensor, Tensor]
|
|
405
|
+
batch : tuple[torch.Tensor, torch.Tensor]
|
|
389
406
|
Input batch. It is a tuple with the input tensor and the target tensor.
|
|
390
407
|
The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
391
408
|
number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
|
|
@@ -399,15 +416,29 @@ class VAEModule(L.LightningModule):
|
|
|
399
416
|
Any
|
|
400
417
|
Loss value.
|
|
401
418
|
"""
|
|
402
|
-
x, target = batch
|
|
419
|
+
x, *target = batch
|
|
403
420
|
|
|
404
421
|
# Forward pass
|
|
405
422
|
out = self.model(x)
|
|
423
|
+
if not self.supervised_mode:
|
|
424
|
+
target = x
|
|
425
|
+
else:
|
|
426
|
+
target = target[
|
|
427
|
+
0
|
|
428
|
+
] # hacky way to unpack. #TODO maybe should be fixed on the dataset level
|
|
406
429
|
|
|
407
430
|
# Update loss parameters
|
|
408
431
|
self.loss_parameters.kl_params.current_epoch = self.current_epoch
|
|
409
432
|
|
|
410
433
|
# Compute loss
|
|
434
|
+
if self.noise_model_likelihood is not None:
|
|
435
|
+
if (
|
|
436
|
+
self.noise_model_likelihood.data_mean is None
|
|
437
|
+
or self.noise_model_likelihood.data_std is None
|
|
438
|
+
):
|
|
439
|
+
raise RuntimeError(
|
|
440
|
+
"NoiseModelLikelihood: data_mean and data_std must be set before training."
|
|
441
|
+
)
|
|
411
442
|
loss = self.loss_func(
|
|
412
443
|
model_outputs=out,
|
|
413
444
|
targets=target,
|
|
@@ -419,15 +450,26 @@ class VAEModule(L.LightningModule):
|
|
|
419
450
|
# Logging
|
|
420
451
|
# TODO: implement a separate logging method?
|
|
421
452
|
self.log_dict(loss, on_step=True, on_epoch=True)
|
|
422
|
-
|
|
453
|
+
|
|
454
|
+
try:
|
|
455
|
+
optimizer = self.optimizers()
|
|
456
|
+
current_lr = optimizer.param_groups[0]["lr"]
|
|
457
|
+
self.log(
|
|
458
|
+
"learning_rate", current_lr, on_step=False, on_epoch=True, logger=True
|
|
459
|
+
)
|
|
460
|
+
except RuntimeError:
|
|
461
|
+
# This happens when the module is not attached to a trainer, e.g., in tests
|
|
462
|
+
pass
|
|
423
463
|
return loss
|
|
424
464
|
|
|
425
|
-
def validation_step(
|
|
465
|
+
def validation_step(
|
|
466
|
+
self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
|
|
467
|
+
) -> None:
|
|
426
468
|
"""Validation step.
|
|
427
469
|
|
|
428
470
|
Parameters
|
|
429
471
|
----------
|
|
430
|
-
batch : tuple[Tensor, Tensor]
|
|
472
|
+
batch : tuple[torch.Tensor, torch.Tensor]
|
|
431
473
|
Input batch. It is a tuple with the input tensor and the target tensor.
|
|
432
474
|
The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
|
|
433
475
|
number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
|
|
@@ -436,11 +478,16 @@ class VAEModule(L.LightningModule):
|
|
|
436
478
|
batch_idx : Any
|
|
437
479
|
Batch index.
|
|
438
480
|
"""
|
|
439
|
-
x, target = batch
|
|
481
|
+
x, *target = batch
|
|
440
482
|
|
|
441
483
|
# Forward pass
|
|
442
484
|
out = self.model(x)
|
|
443
|
-
|
|
485
|
+
if not self.supervised_mode:
|
|
486
|
+
target = x
|
|
487
|
+
else:
|
|
488
|
+
target = target[
|
|
489
|
+
0
|
|
490
|
+
] # hacky way to unpack. #TODO maybe should be fixed on the dataset level
|
|
444
491
|
# Compute loss
|
|
445
492
|
loss = self.loss_func(
|
|
446
493
|
model_outputs=out,
|
|
@@ -466,12 +513,12 @@ class VAEModule(L.LightningModule):
|
|
|
466
513
|
else:
|
|
467
514
|
self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
|
|
468
515
|
|
|
469
|
-
def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
|
|
516
|
+
def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
|
|
470
517
|
"""Prediction step.
|
|
471
518
|
|
|
472
519
|
Parameters
|
|
473
520
|
----------
|
|
474
|
-
batch : Tensor
|
|
521
|
+
batch : torch.Tensor
|
|
475
522
|
Input batch.
|
|
476
523
|
batch_idx : Any
|
|
477
524
|
Batch index.
|
|
@@ -481,36 +528,86 @@ class VAEModule(L.LightningModule):
|
|
|
481
528
|
Any
|
|
482
529
|
Model output.
|
|
483
530
|
"""
|
|
484
|
-
if self.
|
|
531
|
+
if self.algorithm_config.algorithm == "microsplit":
|
|
485
532
|
x, *aux = batch
|
|
486
|
-
|
|
487
|
-
x
|
|
488
|
-
aux = []
|
|
533
|
+
# Reset model for inference with spatial dimensions only (H, W)
|
|
534
|
+
self.model.reset_for_inference(x.shape[-2:])
|
|
489
535
|
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
495
|
-
augmented_output = []
|
|
496
|
-
for augmented in augmented_batch:
|
|
497
|
-
augmented_pred = self.model(augmented)
|
|
498
|
-
augmented_output.append(augmented_pred)
|
|
499
|
-
output = tta.backward(augmented_output)
|
|
500
|
-
else:
|
|
501
|
-
output = self.model(x)
|
|
536
|
+
rec_img_list = []
|
|
537
|
+
for _ in range(self.algorithm_config.mmse_count):
|
|
538
|
+
# get model output
|
|
539
|
+
rec, _ = self.model(x)
|
|
502
540
|
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
541
|
+
# get reconstructed img
|
|
542
|
+
if self.model.predict_logvar is None:
|
|
543
|
+
rec_img = rec
|
|
544
|
+
logvar = torch.tensor([-1])
|
|
545
|
+
else:
|
|
546
|
+
rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
|
|
547
|
+
rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim
|
|
548
|
+
|
|
549
|
+
# aggregate results
|
|
550
|
+
samples = torch.cat(rec_img_list, dim=0)
|
|
551
|
+
mmse_imgs = torch.mean(samples, dim=0) # avg over MMSE dim
|
|
552
|
+
std_imgs = torch.std(samples, dim=0) # std over MMSE dim
|
|
553
|
+
|
|
554
|
+
tile_prediction = mmse_imgs.cpu().numpy()
|
|
555
|
+
tile_std = std_imgs.cpu().numpy()
|
|
556
|
+
|
|
557
|
+
return tile_prediction, tile_std
|
|
509
558
|
|
|
510
|
-
if len(aux) > 0: # aux can be tiling information
|
|
511
|
-
return denormalized_output, *aux
|
|
512
559
|
else:
|
|
513
|
-
|
|
560
|
+
# Regular prediction logic
|
|
561
|
+
if self._trainer.datamodule.tiled:
|
|
562
|
+
# TODO tile_size should match model input size
|
|
563
|
+
x, *aux = batch
|
|
564
|
+
x = (
|
|
565
|
+
x[0] if isinstance(x, list | tuple) else x
|
|
566
|
+
) # TODO ugly, so far i don't know why x might be a list
|
|
567
|
+
self.model.reset_for_inference(x.shape) # TODO should it be here ?
|
|
568
|
+
else:
|
|
569
|
+
x = batch[0] if isinstance(batch, list | tuple) else batch
|
|
570
|
+
aux = []
|
|
571
|
+
self.model.reset_for_inference(x.shape)
|
|
572
|
+
|
|
573
|
+
mmse_list = []
|
|
574
|
+
for _ in range(self.algorithm_config.mmse_count):
|
|
575
|
+
# apply test-time augmentation if available
|
|
576
|
+
if self._trainer.datamodule.prediction_config.tta_transforms:
|
|
577
|
+
tta = ImageRestorationTTA()
|
|
578
|
+
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
579
|
+
augmented_output = []
|
|
580
|
+
for augmented in augmented_batch:
|
|
581
|
+
augmented_pred = self.model(augmented)
|
|
582
|
+
augmented_output.append(augmented_pred)
|
|
583
|
+
output = tta.backward(augmented_output)
|
|
584
|
+
else:
|
|
585
|
+
output = self.model(x)
|
|
586
|
+
|
|
587
|
+
# taking the 1st element of the output, 2nd is std if
|
|
588
|
+
# predict_logvar=="pixelwise"
|
|
589
|
+
output = (
|
|
590
|
+
output[0]
|
|
591
|
+
if self.model.predict_logvar is None
|
|
592
|
+
else output[0][:, 0:1, ...]
|
|
593
|
+
)
|
|
594
|
+
mmse_list.append(output)
|
|
595
|
+
|
|
596
|
+
mmse = torch.stack(mmse_list).mean(0)
|
|
597
|
+
std = torch.stack(mmse_list).std(0) # TODO why?
|
|
598
|
+
# TODO better way to unpack if pred logvar
|
|
599
|
+
# Denormalize the output
|
|
600
|
+
denorm = Denormalize(
|
|
601
|
+
image_means=self._trainer.datamodule.predict_dataset.image_means,
|
|
602
|
+
image_stds=self._trainer.datamodule.predict_dataset.image_stds,
|
|
603
|
+
)
|
|
604
|
+
|
|
605
|
+
denormalized_output = denorm(patch=mmse.cpu().numpy())
|
|
606
|
+
|
|
607
|
+
if len(aux) > 0: # aux can be tiling information
|
|
608
|
+
return denormalized_output, std, *aux
|
|
609
|
+
else:
|
|
610
|
+
return denormalized_output, std
|
|
514
611
|
|
|
515
612
|
def configure_optimizers(self) -> Any:
|
|
516
613
|
"""Configure optimizers and learning rate schedulers.
|
|
@@ -539,19 +636,19 @@ class VAEModule(L.LightningModule):
|
|
|
539
636
|
# should we refactor LadderVAE so that it already outputs
|
|
540
637
|
# tuple(`mean`, `logvar`, `td_data`)?
|
|
541
638
|
def get_reconstructed_tensor(
|
|
542
|
-
self, model_outputs: tuple[Tensor, dict[str, Any]]
|
|
543
|
-
) -> Tensor:
|
|
639
|
+
self, model_outputs: tuple[torch.Tensor, dict[str, Any]]
|
|
640
|
+
) -> torch.Tensor:
|
|
544
641
|
"""Get the reconstructed tensor from the LVAE model outputs.
|
|
545
642
|
|
|
546
643
|
Parameters
|
|
547
644
|
----------
|
|
548
|
-
model_outputs : tuple[Tensor, dict[str, Any]]
|
|
645
|
+
model_outputs : tuple[torch.Tensor, dict[str, Any]]
|
|
549
646
|
Model outputs. It is a tuple with a tensor representing the predicted mean
|
|
550
647
|
and (optionally) logvar, and the top-down data dictionary.
|
|
551
648
|
|
|
552
649
|
Returns
|
|
553
650
|
-------
|
|
554
|
-
Tensor
|
|
651
|
+
torch.Tensor
|
|
555
652
|
Reconstructed tensor, i.e., the predicted mean.
|
|
556
653
|
"""
|
|
557
654
|
predictions, _ = model_outputs
|
|
@@ -562,18 +659,18 @@ class VAEModule(L.LightningModule):
|
|
|
562
659
|
|
|
563
660
|
def compute_val_psnr(
|
|
564
661
|
self,
|
|
565
|
-
model_output: tuple[Tensor, dict[str, Any]],
|
|
566
|
-
target: Tensor,
|
|
662
|
+
model_output: tuple[torch.Tensor, dict[str, Any]],
|
|
663
|
+
target: torch.Tensor,
|
|
567
664
|
psnr_func: Callable = scale_invariant_psnr,
|
|
568
665
|
) -> list[float]:
|
|
569
666
|
"""Compute the PSNR for the current validation batch.
|
|
570
667
|
|
|
571
668
|
Parameters
|
|
572
669
|
----------
|
|
573
|
-
model_output : tuple[Tensor, dict[str, Any]]
|
|
670
|
+
model_output : tuple[torch.Tensor, dict[str, Any]]
|
|
574
671
|
Model output, a tuple with the predicted mean and (optionally) logvar,
|
|
575
672
|
and the top-down data dictionary.
|
|
576
|
-
target : Tensor
|
|
673
|
+
target : torch.Tensor
|
|
577
674
|
Target tensor.
|
|
578
675
|
psnr_func : Callable, optional
|
|
579
676
|
PSNR function to use, by default `scale_invariant_psnr`.
|
|
@@ -583,6 +680,7 @@ class VAEModule(L.LightningModule):
|
|
|
583
680
|
list[float]
|
|
584
681
|
PSNR for each channel in the current batch.
|
|
585
682
|
"""
|
|
683
|
+
# TODO check this! Related to is_supervised which is also wacky
|
|
586
684
|
out_channels = target.shape[1]
|
|
587
685
|
|
|
588
686
|
# get the reconstructed image
|
|
@@ -603,7 +701,7 @@ class VAEModule(L.LightningModule):
|
|
|
603
701
|
for i in range(out_channels)
|
|
604
702
|
]
|
|
605
703
|
|
|
606
|
-
def reduce_running_psnr(self) ->
|
|
704
|
+
def reduce_running_psnr(self) -> float | None:
|
|
607
705
|
"""Reduce the running PSNR statistics and reset the running PSNR.
|
|
608
706
|
|
|
609
707
|
Returns
|
|
@@ -634,11 +732,11 @@ def create_careamics_module(
|
|
|
634
732
|
use_n2v2: bool = False,
|
|
635
733
|
struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
636
734
|
struct_n2v_span: int = 5,
|
|
637
|
-
model_parameters:
|
|
735
|
+
model_parameters: dict | None = None,
|
|
638
736
|
optimizer: Union[SupportedOptimizer, str] = "Adam",
|
|
639
|
-
optimizer_parameters:
|
|
737
|
+
optimizer_parameters: dict | None = None,
|
|
640
738
|
lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
|
|
641
|
-
lr_scheduler_parameters:
|
|
739
|
+
lr_scheduler_parameters: dict | None = None,
|
|
642
740
|
) -> Union[FCNModule, VAEModule]:
|
|
643
741
|
"""Create a CAREamics Lightning module.
|
|
644
742
|
|