careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,914 @@
|
|
|
1
|
+
"""CAREamics Lightning module."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any, Literal, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pytorch_lightning as L
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from careamics.config import (
|
|
11
|
+
N2VAlgorithm,
|
|
12
|
+
PN2VAlgorithm,
|
|
13
|
+
UNetBasedAlgorithm,
|
|
14
|
+
VAEBasedAlgorithm,
|
|
15
|
+
algorithm_factory,
|
|
16
|
+
)
|
|
17
|
+
from careamics.config.data.tile_information import TileInformation
|
|
18
|
+
from careamics.config.support import (
|
|
19
|
+
SupportedAlgorithm,
|
|
20
|
+
SupportedArchitecture,
|
|
21
|
+
SupportedLoss,
|
|
22
|
+
SupportedOptimizer,
|
|
23
|
+
SupportedScheduler,
|
|
24
|
+
)
|
|
25
|
+
from careamics.losses import loss_factory
|
|
26
|
+
from careamics.models.lvae.likelihoods import (
|
|
27
|
+
GaussianLikelihood,
|
|
28
|
+
NoiseModelLikelihood,
|
|
29
|
+
likelihood_factory,
|
|
30
|
+
)
|
|
31
|
+
from careamics.models.lvae.noise_models import (
|
|
32
|
+
GaussianMixtureNoiseModel,
|
|
33
|
+
MultiChannelNoiseModel,
|
|
34
|
+
multichannel_noise_model_factory,
|
|
35
|
+
noise_model_factory,
|
|
36
|
+
)
|
|
37
|
+
from careamics.models.model_factory import model_factory
|
|
38
|
+
from careamics.transforms import (
|
|
39
|
+
Denormalize,
|
|
40
|
+
ImageRestorationTTA,
|
|
41
|
+
N2VManipulateTorch,
|
|
42
|
+
TrainDenormalize,
|
|
43
|
+
)
|
|
44
|
+
from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr
|
|
45
|
+
from careamics.utils.torch_utils import get_optimizer, get_scheduler
|
|
46
|
+
|
|
47
|
+
# Type alias for the noise-model objects the Lightning modules may hold in
# ``self.noise_model`` (built by ``noise_model_factory`` or
# ``multichannel_noise_model_factory``).
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# TODO rename to UNetModule
|
|
51
|
+
class FCNModule(L.LightningModule):
    """
    CAREamics Lightning module for UNet-based algorithms.

    This class encapsulates the PyTorch model along with the training, validation,
    and prediction logic. It is configured using an algorithm Pydantic
    configuration, or a dictionary that `algorithm_factory` can parse.

    Parameters
    ----------
    algorithm_config : UNetBasedAlgorithm, VAEBasedAlgorithm or dict
        Algorithm configuration.

    Attributes
    ----------
    algorithm_config : UNetBasedAlgorithm or VAEBasedAlgorithm
        Parsed algorithm configuration.
    algorithm : Any
        The `algorithm` field of the configuration.
    use_n2v : bool
        Whether N2V pixel manipulation is applied (N2V and PN2V algorithms).
    n2v_preprocess : N2VManipulateTorch or None
        N2V pixel manipulation, only set when `use_n2v` is True.
    model : torch.nn.Module
        PyTorch model.
    noise_model : NoiseModel or None
        Noise model, only built for PN2V configurations.
    loss_func : Callable
        Loss function. For PN2V with a noise model, it is wrapped so that the
        noise model is passed as fourth positional argument.
    optimizer_name : str
        Optimizer name.
    optimizer_params : dict
        Optimizer parameters.
    lr_scheduler_name : str
        Learning rate scheduler name.
    lr_scheduler_params : dict
        Learning rate scheduler parameters.
    """

    def __init__(
        self, algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm, dict]
    ) -> None:
        """Lightning module for CAREamics.

        This class encapsulates a PyTorch model along with the training, validation,
        and prediction logic. It is configured using an algorithm Pydantic class.

        Parameters
        ----------
        algorithm_config : UNetBasedAlgorithm, VAEBasedAlgorithm or dict
            Algorithm configuration. Dictionaries (e.g. when loading from a
            checkpoint) are parsed with `algorithm_factory`.
        """
        super().__init__()

        if isinstance(algorithm_config, dict):
            algorithm_config = algorithm_factory(algorithm_config)

        self.algorithm_config = algorithm_config

        # create preprocessing, model and loss function
        if isinstance(self.algorithm_config, N2VAlgorithm | PN2VAlgorithm):
            self.use_n2v = True
            self.n2v_preprocess: N2VManipulateTorch | None = N2VManipulateTorch(
                n2v_manipulate_config=self.algorithm_config.n2v_config
            )
        else:
            self.use_n2v = False
            self.n2v_preprocess = None

        self.algorithm = self.algorithm_config.algorithm
        self.model: torch.nn.Module = model_factory(self.algorithm_config.model)

        # only PN2V configurations carry a noise model
        self.noise_model: NoiseModel | None = noise_model_factory(
            self.algorithm_config.noise_model
            if isinstance(self.algorithm_config, PN2VAlgorithm)
            else None
        )

        # Create loss function, pre-configured with the noise model for PN2V
        loss_func = loss_factory(self.algorithm_config.loss)
        if (
            isinstance(self.algorithm_config, PN2VAlgorithm)
            and self.noise_model is not None
        ):
            # For PN2V, keep the first three positional arguments and append the
            # noise model as fourth argument. `self.noise_model` is read at call
            # time, not captured here.
            self.loss_func = lambda *args: loss_func(
                args[0], args[1], args[2], self.noise_model
            )
        else:
            self.loss_func = loss_func

        # save optimizer and lr_scheduler names and parameters
        self.optimizer_name = self.algorithm_config.optimizer.name
        self.optimizer_params = self.algorithm_config.optimizer.parameters
        self.lr_scheduler_name = self.algorithm_config.lr_scheduler.name
        self.lr_scheduler_params = self.algorithm_config.lr_scheduler.parameters

    def forward(self, x: Any) -> Any:
        """Forward pass.

        Parameters
        ----------
        x : Any
            Input tensor.

        Returns
        -------
        Any
            Output tensor.
        """
        return self.model(x)

    def _train_denormalize(self, out: torch.Tensor) -> torch.Tensor:
        """Denormalize output using training dataset statistics.

        Parameters
        ----------
        out : torch.Tensor
            Output tensor to denormalize.

        Returns
        -------
        torch.Tensor
            Denormalized tensor.
        """
        image_stats = self._trainer.datamodule.train_dataset.image_stats
        denorm = TrainDenormalize(
            image_means=image_stats.means,
            image_stds=image_stats.stds,
        )
        return denorm(patch=out)

    def _predict_denormalize(
        self, out: torch.Tensor, from_prediction: bool
    ) -> np.ndarray:
        """Denormalize output for prediction.

        Parameters
        ----------
        out : torch.Tensor
            Output tensor to denormalize. It is moved to the CPU and converted
            to a NumPy array before denormalization.
        from_prediction : bool
            Whether to use the prediction dataset statistics (True) or the
            training dataset statistics (False).

        Returns
        -------
        numpy.ndarray
            Denormalized output, as returned by `Denormalize` applied to a NumPy
            copy of `out`.
        """
        datamodule = self._trainer.datamodule
        if from_prediction:
            means = datamodule.predict_dataset.image_means
            stds = datamodule.predict_dataset.image_stds
        else:
            means = datamodule.train_dataset.image_stats.means
            stds = datamodule.train_dataset.image_stats.stds

        denorm = Denormalize(image_means=means, image_stds=stds)
        return denorm(patch=out.cpu().numpy())

    def _step_loss(self, batch: torch.Tensor) -> Any:
        """Compute the loss of a training or validation batch.

        Shared by `training_step` and `validation_step` so that both apply the
        same N2V preprocessing and PN2V denormalization.

        Parameters
        ----------
        batch : torch.Tensor
            Input batch, unpacked as `(x, *targets)`.

        Returns
        -------
        Any
            Loss value returned by `loss_func`.
        """
        x, *targets = batch
        if self.use_n2v and self.n2v_preprocess is not None:
            x_preprocessed, *aux = self.n2v_preprocess(x)
        else:
            x_preprocessed = x
            aux = []

        out = self.model(x_preprocessed)

        # PN2V needs denormalized output and targets for loss computation
        if isinstance(self.algorithm_config, PN2VAlgorithm):
            out = self._train_denormalize(out)
            aux = [self._train_denormalize(aux[0]), aux[1]]

        # TODO hacky and ugly
        return self.loss_func(out, *aux, *targets)

    def training_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
        """Training step.

        Parameters
        ----------
        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.

        Returns
        -------
        Any
            Loss value.
        """
        loss = self._step_loss(batch)
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

        optimizer = self.optimizers()
        current_lr = optimizer.param_groups[0]["lr"]
        self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch: torch.Tensor, batch_idx: Any) -> None:
        """Validation step.

        Parameters
        ----------
        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.
        """
        val_loss = self._step_loss(batch)

        # log validation loss
        self.log(
            "val_loss",
            val_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

    def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
        """Prediction step.

        Parameters
        ----------
        batch : torch.Tensor
            Input batch. For tiled prediction, the second element is a list of
            `TileInformation`.
        batch_idx : Any
            Batch index.

        Returns
        -------
        Any
            Denormalized model output (for PN2V, a tuple of the mean over axis 1
            and the likelihood-weighted MSE estimate), followed by any auxiliary
            batch elements (e.g. tiling information) when present.

        Raises
        ------
        ValueError
            If the algorithm is PN2V but no noise model is available.
        """
        datamodule = self._trainer.datamodule

        # TODO refactor when redoing datasets
        # hacky way to determine if it is PredictDataModule, otherwise there is a
        # circular import to solve with isinstance
        from_prediction = hasattr(datamodule, "tiled")
        is_tiled = (
            len(batch) > 1
            and isinstance(batch[1], list)
            and len(batch[1]) > 0
            and isinstance(batch[1][0], TileInformation)
        )

        if is_tiled:
            # tiled batches are (input, tile_infos, ...); everything after the
            # input is passed through to the output
            x, *aux = batch
            if isinstance(x, (list, tuple)):
                x = x[0]
        else:
            # TODO change, ugly way to deal with n2v refac
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            aux = []

        # apply test-time augmentation if available
        # TODO: probably wont work with batch size > 1
        if from_prediction and datamodule.prediction_config.tta_transforms:
            tta = ImageRestorationTTA()
            augmented_output = [self.model(augmented) for augmented in tta.forward(x)]
            output = tta.backward(augmented_output)
        else:
            output = self.model(x)

        # Denormalize the output
        # TODO incompatible API between predict and train datasets
        denormalized_output = self._predict_denormalize(
            output, from_prediction=from_prediction
        )

        if isinstance(self.algorithm_config, PN2VAlgorithm):
            if self.noise_model is None:
                raise ValueError("Noise model required for PN2V")

            # the denormalized input is only needed for the noise-model likelihood
            denormalized_input = self._predict_denormalize(
                x, from_prediction=from_prediction
            )
            input_tensor = torch.tensor(denormalized_input)
            output_tensor = torch.tensor(denormalized_output)

            # MSE estimate: likelihood-weighted average over axis 1
            likelihoods = self.noise_model.likelihood(input_tensor, output_tensor)
            mse_estimate = torch.sum(
                likelihoods * output_tensor, dim=1, keepdim=True
            )
            mse_estimate /= torch.sum(likelihoods, dim=1, keepdim=True)

            denormalized_output = (
                np.mean(denormalized_output, axis=1, keepdims=True),
                mse_estimate,
            )

        # TODO: might be ugly but otherwise we need to change the output signature
        if aux:  # aux can be tiling information
            return denormalized_output, *aux
        return denormalized_output

    def configure_optimizers(self) -> Any:
        """Configure optimizers and learning rate schedulers.

        Returns
        -------
        Any
            Optimizer and learning rate scheduler.
        """
        # instantiate optimizer
        optimizer_func = get_optimizer(self.optimizer_name)
        optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)

        # and scheduler
        scheduler_func = get_scheduler(self.lr_scheduler_name)
        scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",  # otherwise triggers MisconfigurationException
        }
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
class VAEModule(L.LightningModule):
    """
    CAREamics Lightning module for VAE-based algorithms.

    This class encapsulates a PyTorch model along with the training, validation,
    and prediction logic. It is configured using a `VAEBasedAlgorithm` Pydantic
    class.

    Parameters
    ----------
    algorithm_config : Union[VAEBasedAlgorithm, dict]
        Algorithm configuration.

    Attributes
    ----------
    model : nn.Module
        PyTorch model.
    supervised_mode : bool
        Whether the algorithm is supervised (targets come from the batch) or not
        (the input is used as target).
    noise_model : NoiseModel or None
        Noise model created from the algorithm configuration.
    noise_model_likelihood : NoiseModelLikelihood or None
        Noise model likelihood, only created if configured.
    gaussian_likelihood : GaussianLikelihood or None
        Gaussian likelihood created from the algorithm configuration.
    loss_parameters
        Loss configuration (`algorithm_config.loss`); its
        `kl_params.current_epoch` is updated at every training step.
    loss_func : Callable
        Loss function.
    optimizer_name : str
        Optimizer name.
    optimizer_params : dict
        Optimizer parameters.
    lr_scheduler_name : str
        Learning rate scheduler name.
    lr_scheduler_params : dict
        Learning rate scheduler parameters.
    running_psnr : list[RunningPSNR]
        One running PSNR accumulator per model output channel.
    """

    def __init__(self, algorithm_config: Union[VAEBasedAlgorithm, dict]) -> None:
        """Lightning module for CAREamics VAE-based algorithms.

        Parameters
        ----------
        algorithm_config : Union[VAEBasedAlgorithm, dict]
            Algorithm configuration. A dictionary (e.g. when loading from a
            checkpoint) is validated into a `VAEBasedAlgorithm`.
        """
        super().__init__()
        # if loading from a checkpoint, the configuration needs to be instantiated
        self.algorithm_config = (
            VAEBasedAlgorithm(**algorithm_config)
            if isinstance(algorithm_config, dict)
            else algorithm_config
        )

        # TODO: log algorithm config
        # self.save_hyperparameters(self.algorithm_config.model_dump())

        # create model
        self.model: torch.nn.Module = model_factory(self.algorithm_config.model)

        # supervised_mode
        self.supervised_mode = self.algorithm_config.is_supervised

        # create noise model (VAE algorithms always use multichannel nm factory)
        self.noise_model: NoiseModel | None = multichannel_noise_model_factory(
            self.algorithm_config.noise_model
        )

        self.noise_model_likelihood: NoiseModelLikelihood | None = None
        if self.algorithm_config.noise_model_likelihood is not None:
            self.noise_model_likelihood = likelihood_factory(
                config=self.algorithm_config.noise_model_likelihood,
                noise_model=self.noise_model,
            )

        self.gaussian_likelihood: GaussianLikelihood | None = likelihood_factory(
            self.algorithm_config.gaussian_likelihood
        )

        self.loss_parameters = self.algorithm_config.loss
        self.loss_func = loss_factory(self.algorithm_config.loss.loss_type)

        # save optimizer and lr_scheduler names and parameters
        self.optimizer_name = self.algorithm_config.optimizer.name
        self.optimizer_params = self.algorithm_config.optimizer.parameters
        self.lr_scheduler_name = self.algorithm_config.lr_scheduler.name
        self.lr_scheduler_params = self.algorithm_config.lr_scheduler.parameters

        # initialize running PSNR, one accumulator per output channel
        self.running_psnr = [
            RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
        ]

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
            number of lateral inputs.

        Returns
        -------
        tuple[torch.Tensor, dict[str, Any]]
            A tuple with the output tensor and additional data from the top-down pass.
        """
        return self.model(x)  # TODO Different model can have more than one output

    def set_data_stats(self, data_mean, data_std):
        """Set data mean and std for the noise model likelihood.

        This is a no-op when no noise model likelihood is configured.

        Parameters
        ----------
        data_mean : float
            Mean of the data.
        data_std : float
            Standard deviation of the data.
        """
        if self.noise_model_likelihood is not None:
            self.noise_model_likelihood.set_data_stats(data_mean, data_std)

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
    ) -> dict[str, torch.Tensor] | None:
        """Training step.

        Parameters
        ----------
        batch : tuple[torch.Tensor, torch.Tensor]
            Input batch. It is a tuple with the input tensor and the target tensor.
            The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
            number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
            where C is the number of target channels (e.g., 1 in HDN, >1 in
            muSplit/denoiSplit).
        batch_idx : Any
            Batch index.

        Returns
        -------
        dict[str, torch.Tensor] or None
            Loss terms as returned by the loss function, or `None` when the loss
            function returns `None`, in which case Lightning skips this batch.

        Raises
        ------
        RuntimeError
            If a noise model likelihood is used and its data mean/std are not set.
        """
        x, *target = batch

        # Forward pass
        out = self.model(x)
        if not self.supervised_mode:
            target = x
        else:
            # hacky way to unpack. #TODO maybe should be fixed on the dataset level
            target = target[0]

        # Update loss parameters
        self.loss_parameters.kl_params.current_epoch = self.current_epoch

        # Compute loss
        if self.noise_model_likelihood is not None:
            if (
                self.noise_model_likelihood.data_mean is None
                or self.noise_model_likelihood.data_std is None
            ):
                raise RuntimeError(
                    "NoiseModelLikelihood: mean and std must be set before training."
                )
        loss = self.loss_func(
            model_outputs=out,
            targets=target,
            config=self.loss_parameters,
            gaussian_likelihood=self.gaussian_likelihood,
            noise_model_likelihood=self.noise_model_likelihood,
        )

        # The loss function may return None (presumably for a non-finite loss;
        # confirm in the loss implementation). Logging a None dict would raise,
        # whereas returning None makes Lightning skip this batch.
        if loss is None:
            return None

        # Logging
        # TODO: implement a separate logging method?
        self.log_dict(loss, on_step=True, on_epoch=True)

        try:
            optimizer = self.optimizers()
            current_lr = optimizer.param_groups[0]["lr"]
            self.log(
                "learning_rate", current_lr, on_step=False, on_epoch=True, logger=True
            )
        except RuntimeError:
            # This happens when the module is not attached to a trainer, e.g., in tests
            pass
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
    ) -> None:
        """Validation step.

        Logs the loss terms prefixed with `val_` and the per-channel PSNR of the
        batch, and updates the running PSNR accumulators.

        Parameters
        ----------
        batch : tuple[torch.Tensor, torch.Tensor]
            Input batch. It is a tuple with the input tensor and the target tensor.
            The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
            number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
            where C is the number of target channels (e.g., 1 in HDN, >1 in
            muSplit/denoiSplit).
        batch_idx : Any
            Batch index.
        """
        x, *target = batch

        # Forward pass
        out = self.model(x)
        if not self.supervised_mode:
            target = x
        else:
            # hacky way to unpack. #TODO maybe should be fixed on the dataset level
            target = target[0]

        # Compute loss
        loss = self.loss_func(
            model_outputs=out,
            targets=target,
            config=self.loss_parameters,
            gaussian_likelihood=self.gaussian_likelihood,
            noise_model_likelihood=self.noise_model_likelihood,
        )
        # same as in training: a None loss cannot be logged, skip this batch
        if loss is None:
            return

        # Logging
        # Rename val_loss dict
        loss = {"_".join(["val", k]): v for k, v in loss.items()}
        self.log_dict(loss, on_epoch=True, prog_bar=True)
        curr_psnr = self.compute_val_psnr(out, target)
        for i, psnr in enumerate(curr_psnr):
            self.log(f"val_psnr_ch{i+1}_batch", psnr, on_epoch=True)

    def on_validation_epoch_end(self) -> None:
        """Validation epoch end.

        Logs the running PSNR averaged over channels as `val_psnr` (0.0 when it is
        not available) and resets the running PSNR accumulators.
        """
        psnr_ = self.reduce_running_psnr()
        if psnr_ is not None:
            self.log("val_psnr", psnr_, on_epoch=True, prog_bar=True)
        else:
            self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)

    def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
        """Prediction step.

        The model is sampled `algorithm_config.mmse_count` times and the samples are
        averaged (MMSE estimate). When the model predicts a log-variance, only the
        mean half of the output channels is kept.

        Parameters
        ----------
        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.

        Returns
        -------
        Any
            For `microsplit`, a tuple `(prediction, std)` of NumPy arrays (MMSE mean
            and standard deviation over the samples). Otherwise, the denormalized
            MMSE prediction and the standard deviation tensor, followed by any
            auxiliary batch items (e.g. tiling information).
        """
        if self.algorithm_config.algorithm == "microsplit":
            x, *_ = batch
            # Reset model for inference with spatial dimensions only (H, W)
            self.model.reset_for_inference(x.shape[-2:])

            rec_img_list = []
            for _ in range(self.algorithm_config.mmse_count):
                # get model output
                rec, _ = self.model(x)

                # get reconstructed img (drop the logvar half if present)
                if self.model.predict_logvar is None:
                    rec_img = rec
                else:
                    rec_img, _ = torch.chunk(rec, chunks=2, dim=1)
                rec_img_list.append(rec_img.cpu().unsqueeze(0))  # add MMSE dim

            # aggregate results
            samples = torch.cat(rec_img_list, dim=0)
            mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
            std_imgs = torch.std(samples, dim=0)  # std over MMSE dim

            tile_prediction = mmse_imgs.cpu().numpy()
            tile_std = std_imgs.cpu().numpy()

            return tile_prediction, tile_std

        # Regular prediction logic
        datamodule = self._trainer.datamodule
        if datamodule.tiled:
            # TODO tile_size should match model input size
            x, *aux = batch
            # TODO ugly, so far i don't know why x might be a list
            x = x[0] if isinstance(x, list | tuple) else x
            self.model.reset_for_inference(x.shape)  # TODO should it be here ?
        else:
            x = batch[0] if isinstance(batch, list | tuple) else batch
            aux = []
            self.model.reset_for_inference(x.shape)

        mmse_list = []
        for _ in range(self.algorithm_config.mmse_count):
            # apply test-time augmentation if available
            if datamodule.prediction_config.tta_transforms:
                tta = ImageRestorationTTA()
                augmented_batch = tta.forward(x)  # list of augmented tensors
                augmented_output = [self.model(augmented) for augmented in augmented_batch]
                output = tta.backward(augmented_output)
            else:
                output = self.model(x)

            # taking the 1st element of the output; with a predicted logvar the
            # channels are split in [mean, logvar] halves (as in the microsplit
            # branch and `get_reconstructed_tensor`), so keep the mean half.
            prediction = output[0]
            if self.model.predict_logvar is not None:
                prediction = prediction.chunk(2, dim=1)[0]
            mmse_list.append(prediction)

        samples = torch.stack(mmse_list)
        mmse = samples.mean(0)
        std = samples.std(0)  # TODO why?
        # TODO better way to unpack if pred logvar
        # Denormalize the output
        denorm = Denormalize(
            image_means=datamodule.predict_dataset.image_means,
            image_stds=datamodule.predict_dataset.image_stds,
        )

        denormalized_output = denorm(patch=mmse.cpu().numpy())

        if aux:  # aux can be tiling information
            return denormalized_output, std, *aux
        return denormalized_output, std

    def configure_optimizers(self) -> Any:
        """Configure optimizers and learning rate schedulers.

        Returns
        -------
        Any
            Dictionary with the optimizer, the learning rate scheduler and the
            monitored metric.
        """
        # instantiate optimizer
        optimizer_func = get_optimizer(self.optimizer_name)
        optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)

        # and scheduler
        scheduler_func = get_scheduler(self.lr_scheduler_name)
        scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",  # otherwise triggers MisconfigurationException
        }

    # TODO: find a way to move the following methods to a separate module
    # TODO: this same operation is done in many other places, like in loss_func
    # should we refactor LadderVAE so that it already outputs
    # tuple(`mean`, `logvar`, `td_data`)?
    def get_reconstructed_tensor(
        self, model_outputs: tuple[torch.Tensor, dict[str, Any]]
    ) -> torch.Tensor:
        """Get the reconstructed tensor from the LVAE model outputs.

        Parameters
        ----------
        model_outputs : tuple[torch.Tensor, dict[str, Any]]
            Model outputs. It is a tuple with a tensor representing the predicted mean
            and (optionally) logvar, and the top-down data dictionary.

        Returns
        -------
        torch.Tensor
            Reconstructed tensor, i.e., the predicted mean.

        Raises
        ------
        ValueError
            If the model's `predict_logvar` is neither `None` nor "pixelwise".
        """
        predictions, _ = model_outputs
        if self.model.predict_logvar is None:
            return predictions
        if self.model.predict_logvar == "pixelwise":
            # channels hold [mean, logvar] halves
            return predictions.chunk(2, dim=1)[0]
        raise ValueError(
            f"Unsupported `predict_logvar` value: {self.model.predict_logvar!r}."
        )

    def compute_val_psnr(
        self,
        model_output: tuple[torch.Tensor, dict[str, Any]],
        target: torch.Tensor,
        psnr_func: Callable = scale_invariant_psnr,
    ) -> list[float]:
        """Compute the PSNR for the current validation batch.

        Also updates the running PSNR accumulator of each target channel.

        Parameters
        ----------
        model_output : tuple[torch.Tensor, dict[str, Any]]
            Model output, a tuple with the predicted mean and (optionally) logvar,
            and the top-down data dictionary.
        target : torch.Tensor
            Target tensor.
        psnr_func : Callable, optional
            PSNR function to use, by default `scale_invariant_psnr`.

        Returns
        -------
        list[float]
            PSNR for each channel in the current batch.
        """
        # TODO check this! Related to is_supervised which is also wacky
        # NOTE(review): assumes target has no more channels than the model output
        # (`running_psnr` has one entry per output channel) — confirm for
        # unsupervised mode with lateral inputs.
        out_channels = target.shape[1]

        # get the reconstructed image
        recons_img = self.get_reconstructed_tensor(model_output)

        # update running psnr
        for i in range(out_channels):
            self.running_psnr[i].update(rec=recons_img[:, i], tar=target[:, i])

        # compute psnr for each channel in the current batch
        # TODO: this doesn't need do be a method of this class
        # and hence can be moved to a separate module
        return [
            psnr_func(
                gt=target[:, i].clone().detach().cpu().numpy(),
                pred=recons_img[:, i].clone().detach().cpu().numpy(),
            )
            for i in range(out_channels)
        ]

    def reduce_running_psnr(self) -> float | None:
        """Reduce the running PSNR statistics and reset the running PSNR.

        Every accumulator is reset, even when some of them have no value.

        Returns
        -------
        float or None
            Running PSNR averaged over the different output channels, or `None` if
            any channel has no accumulated PSNR.
        """
        psnr_values = []
        all_available = True
        for running in self.running_psnr:
            psnr = running.get()
            if psnr is None:
                all_available = False
            else:
                psnr_values.append(psnr.cpu().numpy())
            # TODO: this line forces it to be a method of this class
            # alternative is returning also the reset `running_psnr`
            running.reset()

        if not all_available:
            return None
        return np.mean(psnr_values)
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
# TODO: make this LVAE compatible (?)
|
|
820
|
+
def create_careamics_module(
    algorithm: Union[SupportedAlgorithm, str],
    loss: Union[SupportedLoss, str],
    architecture: Union[SupportedArchitecture, str],
    use_n2v2: bool = False,
    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
    struct_n2v_span: int = 5,
    model_parameters: dict | None = None,
    optimizer: Union[SupportedOptimizer, str] = "Adam",
    optimizer_parameters: dict | None = None,
    lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
    lr_scheduler_parameters: dict | None = None,
) -> Union[FCNModule, VAEModule]:
    """Create a CAREamics Lightning module.

    This function exposes parameters used to create an algorithm configuration,
    triggering parameters validation.

    Parameters
    ----------
    algorithm : SupportedAlgorithm or str
        Algorithm to use for training (see SupportedAlgorithm).
    loss : SupportedLoss or str
        Loss function to use for training (see SupportedLoss).
    architecture : SupportedArchitecture or str
        Model architecture to use for training (see SupportedArchitecture).
    use_n2v2 : bool, default=False
        Whether to use N2V2 or Noise2Void. Only applied to N2V/PN2V algorithms.
    struct_n2v_axis : "horizontal", "vertical", or "none", default="none"
        Axis of the StructN2V mask. Only applied to N2V/PN2V algorithms.
    struct_n2v_span : int, default=5
        Span of the StructN2V mask. Only applied to N2V/PN2V algorithms.
    model_parameters : dict, optional
        Model parameters to use for training, by default {}. Model parameters are
        defined in the relevant `torch.nn.Module` class, or Pydantic model (see
        `careamics.config.architectures`).
    optimizer : SupportedOptimizer or str, optional
        Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
    optimizer_parameters : dict, optional
        Optimizer parameters to use for training, as defined in `torch.optim`, by
        default {}.
    lr_scheduler : SupportedScheduler or str, optional
        Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
        (see SupportedScheduler).
    lr_scheduler_parameters : dict, optional
        Learning rate scheduler parameters to use for training, as defined in
        `torch.optim`, by default {}.

    Returns
    -------
    FCNModule
        CAREamics Lightning module for the UNet-based algorithm.

    Raises
    ------
    NotImplementedError
        If `algorithm` is not one of `UNetBasedAlgorithm.get_compatible_algorithms()`.
    """
    # TODO should use the same functions as in configuration_factory.py
    if lr_scheduler_parameters is None:
        lr_scheduler_parameters = {}
    if optimizer_parameters is None:
        optimizer_parameters = {}
    if model_parameters is None:
        model_parameters = {}

    # create an algorithm configuration compatible dictionary; model parameters
    # override nothing but are merged after the architecture key
    algorithm_dict: dict[str, Any] = {
        "algorithm": algorithm,
        "loss": loss,
        "optimizer": {
            "name": optimizer,
            "parameters": optimizer_parameters,
        },
        "lr_scheduler": {
            "name": lr_scheduler,
            "parameters": lr_scheduler_parameters,
        },
        "model": {"architecture": architecture, **model_parameters},
    }

    if algorithm not in UNetBasedAlgorithm.get_compatible_algorithms():
        raise NotImplementedError(
            f"Algorithm {algorithm} is not implemented or unknown."
        )

    algorithm_cfg = algorithm_factory(algorithm_dict)

    # StructN2V / N2V2 options only exist for N2V-like algorithms
    if isinstance(algorithm_cfg, N2VAlgorithm | PN2VAlgorithm):
        algorithm_cfg.n2v_config.struct_mask_axis = struct_n2v_axis
        algorithm_cfg.n2v_config.struct_mask_span = struct_n2v_span
        algorithm_cfg.set_n2v2(use_n2v2)

    return FCNModule(algorithm_cfg)
|