careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
"""N2V Algorithm configuration."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, Literal, Self
|
|
4
|
+
|
|
5
|
+
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
6
|
+
from pydantic import AfterValidator, ConfigDict, model_validator
|
|
7
|
+
|
|
8
|
+
from careamics.config.architectures import UNetConfig
|
|
9
|
+
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
|
|
10
|
+
from careamics.config.transformations import N2VManipulateConfig
|
|
11
|
+
from careamics.config.validators import (
|
|
12
|
+
model_matching_in_out_channels,
|
|
13
|
+
model_without_final_activation,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from .unet_algorithm_config import UNetBasedAlgorithm
|
|
17
|
+
|
|
18
|
+
# Friendly names of the supported N2V variants.
N2V = "Noise2Void"
N2V2 = "N2V2"
STRUCT_N2V = "StructN2V"
STRUCT_N2V2 = "StructN2V2"

# Citation entries, used for the BioImage Model Zoo export.
N2V_REF = CiteEntry(
    text='Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning '
    'denoising from single noisy images". In Proceedings of the IEEE/CVF '
    "conference on computer vision and pattern recognition (pp. 2129-2137).",
    doi="10.1109/cvpr.2019.00223",
)

N2V2_REF = CiteEntry(
    text="Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., "
    '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified '
    'sampling strategies and a tweaked network architecture". In European '
    "Conference on Computer Vision (pp. 503-518).",
    doi="10.1007/978-3-031-25069-9_33",
)

# NOTE: the trailing space after "2020." matters: implicit string
# concatenation would otherwise glue the year to the title.
STRUCTN2V_REF = CiteEntry(
    text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020. "
    '"Removing structured noise with self-supervised blind-spot '
    'networks". In 2020 IEEE 17th International Symposium on Biomedical '
    "Imaging (ISBI) (pp. 159-163).",
    doi="10.1109/isbi45749.2020.9098336",
)

# Algorithm descriptions, used to generate the BioImage Model Zoo README.
N2V_DESCRIPTION = (
    "Noise2Void is a UNet-based self-supervised algorithm that "
    "uses blind-spot training to denoise images. In short, in every "
    "patch during training, random pixels are selected and their "
    "value replaced by a neighboring pixel value. The network is then "
    "trained to predict the original pixel value. The algorithm "
    "relies on the continuity of the signal (neighboring pixels have "
    "similar values) and the pixel-wise independence of the noise "
    "(the noise in a pixel is not correlated with the noise in "
    "neighboring pixels)."
)

# Shared explanation fragments, factored out so that the variant descriptions
# below cannot drift apart. Each is prefixed with the variant's name.
_N2V2_DETAILS = (
    "introduces blur-pool layers and removes skip connections in the UNet "
    "architecture to remove checkerboard artefacts, a common artefact "
    "occurring in Noise2Void."
)
_STRUCT_N2V_DETAILS = (
    "uses a linear mask (horizontal or vertical) to replace "
    "the pixel values of neighbors of the masked pixels by a random "
    "value. Such masking allows removing 1D structured noise from "
    "the images, the main failure case of the original N2V."
)

N2V2_DESCRIPTION = (
    "N2V2 is a variant of Noise2Void. "
    + N2V_DESCRIPTION
    + "\nN2V2 "
    + _N2V2_DETAILS
)

STR_N2V_DESCRIPTION = (
    "StructN2V is a variant of Noise2Void. "
    + N2V_DESCRIPTION
    + "\nStructN2V "
    + _STRUCT_N2V_DETAILS
)

STR_N2V2_DESCRIPTION = (
    "StructN2V2 is a variant of Noise2Void that uses both "
    "structN2V and N2V2. "
    + N2V_DESCRIPTION
    + "\nStructN2V2 "
    + _STRUCT_N2V_DETAILS
    + "\nN2V2 "
    + _N2V2_DETAILS
)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class N2VAlgorithm(UNetBasedAlgorithm):
    """N2V Algorithm configuration.

    Covers vanilla Noise2Void and its N2V2 and structN2V variants. The active
    variant is derived from `model.n2v2` and `n2v_config.struct_mask_axis`.
    """

    model_config = ConfigDict(validate_assignment=True)

    algorithm: Literal["n2v"] = "n2v"
    """N2V Algorithm name."""

    loss: Literal["n2v"] = "n2v"
    """N2V loss function."""

    # Pixel manipulation (blind-spot masking) parameters.
    n2v_config: N2VManipulateConfig = N2VManipulateConfig()

    # UNet architecture, checked by `model_matching_in_out_channels` and
    # `model_without_final_activation` after standard validation.
    model: Annotated[
        UNetConfig,
        AfterValidator(model_matching_in_out_channels),
        AfterValidator(model_without_final_activation),
    ]

    @model_validator(mode="after")
    def validate_n2v2(self) -> Self:
        """Validate that the N2V2 strategy and models are set correctly.

        N2V2 (`model.n2v2` set) requires the median pixel manipulation
        strategy, while vanilla N2V requires the uniform strategy.

        Returns
        -------
        Self
            The validated configuration.

        Raises
        ------
        ValueError
            If N2V2 is used with the wrong pixel manipulation strategy.
        """
        # `.value` is interpolated so the message shows the string the user has
        # to set; formatting the enum member itself may render as
        # `ClassName.MEMBER` on Python >= 3.11.
        median = SupportedPixelManipulation.MEDIAN.value
        uniform = SupportedPixelManipulation.UNIFORM.value

        if self.model.n2v2:
            if self.n2v_config.strategy != median:
                raise ValueError(
                    f"N2V2 can only be used with the '{median}' pixel "
                    f"manipulation strategy. Change the `strategy` parameters "
                    f"in `n2v_config` to '{median}'."
                )
        elif self.n2v_config.strategy != uniform:
            raise ValueError(
                f"N2V can only be used with the '{uniform}' pixel "
                f"manipulation strategy. Change the `strategy` parameters "
                f"in `n2v_config` to '{uniform}'."
            )
        return self

    def set_n2v2(self, use_n2v2: bool) -> None:
        """
        Set the configuration to use N2V2 or the vanilla Noise2Void.

        This method ensures that N2V2 is set correctly and remains coherent, as
        opposed to setting the different parameters individually: the pixel
        manipulation strategy and `model.n2v2` are always updated together.

        Parameters
        ----------
        use_n2v2 : bool
            Whether to use N2V2.
        """
        if use_n2v2:
            self.n2v_config.strategy = SupportedPixelManipulation.MEDIAN.value
            self.model.n2v2 = True
        else:
            self.n2v_config.strategy = SupportedPixelManipulation.UNIFORM.value
            self.model.n2v2 = False

    def is_struct_n2v(self) -> bool:
        """Check if the configuration is using structN2V.

        Returns
        -------
        bool
            Whether the configuration is using structN2V, i.e. whether
            `n2v_config.struct_mask_axis` differs from the "none" axis.
        """
        return self.n2v_config.struct_mask_axis != SupportedStructAxis.NONE.value

    def get_algorithm_friendly_name(self) -> str:
        """
        Get the friendly name of the algorithm.

        Returns
        -------
        str
            Friendly name.
        """
        use_n2v2 = self.model.n2v2
        use_structN2V = self.is_struct_n2v()

        if use_n2v2 and use_structN2V:
            return STRUCT_N2V2
        elif use_n2v2:
            return N2V2
        elif use_structN2V:
            return STRUCT_N2V
        else:
            return N2V

    def get_algorithm_keywords(self) -> list[str]:
        """
        Get algorithm keywords.

        Returns
        -------
        list[str]
            List of keywords.
        """
        keywords = [
            "denoising",
            "restoration",
            "UNet",
            "3D" if self.model.is_3D() else "2D",
            "CAREamics",
            "pytorch",
            N2V,
        ]

        if self.model.n2v2:
            keywords.append(N2V2)
        if self.is_struct_n2v():
            keywords.append(STRUCT_N2V)

        return keywords

    def get_algorithm_references(self) -> str:
        """
        Get the algorithm references.

        This is used to generate the README of the BioImage Model Zoo export.
        The references are derived from `get_algorithm_citations`, so both
        always list the same works in the same order.

        Returns
        -------
        str
            Algorithm references, one "<text> doi: <doi>" entry per line.
        """
        return "\n".join(
            citation.text + " doi: " + citation.doi
            for citation in self.get_algorithm_citations()
        )

    def get_algorithm_citations(self) -> list[CiteEntry]:
        """
        Return a list of citation entries of the current algorithm.

        This is used to generate the model description for the BioImage Model Zoo.
        The N2V reference is always included, followed by the N2V2 and/or
        structN2V references when those variants are active.

        Returns
        -------
        list[CiteEntry]
            List of citation entries.
        """
        references = [N2V_REF]

        if self.model.n2v2:
            references.append(N2V2_REF)

        if self.is_struct_n2v():
            references.append(STRUCTN2V_REF)

        return references

    def get_algorithm_description(self) -> str:
        """
        Return a description of the algorithm.

        This method is used to generate the README of the BioImage Model Zoo export.

        Returns
        -------
        str
            Description of the algorithm.
        """
        use_n2v2 = self.model.n2v2
        use_structN2V = self.is_struct_n2v()

        if use_n2v2 and use_structN2V:
            return STR_N2V2_DESCRIPTION
        elif use_n2v2:
            return N2V2_DESCRIPTION
        elif use_structN2V:
            return STR_N2V_DESCRIPTION
        else:
            return N2V_DESCRIPTION
|
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
"""PN2V Algorithm configuration."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, Literal, Self
|
|
4
|
+
|
|
5
|
+
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
6
|
+
from pydantic import AfterValidator, ConfigDict, model_validator
|
|
7
|
+
|
|
8
|
+
from careamics.config.architectures import UNetConfig
|
|
9
|
+
from careamics.config.noise_model import GaussianMixtureNMConfig
|
|
10
|
+
from careamics.config.support import SupportedPixelManipulation, SupportedStructAxis
|
|
11
|
+
from careamics.config.transformations import N2VManipulateConfig
|
|
12
|
+
from careamics.config.validators import (
|
|
13
|
+
model_without_final_activation,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from .unet_algorithm_config import UNetBasedAlgorithm
|
|
17
|
+
|
|
18
|
+
# Friendly names of the supported PN2V variants.
PN2V = "Probabilistic Noise2Void"
PN2V2 = "PN2V2"
STRUCT_PN2V = "StructPN2V"
STRUCT_PN2V2 = "StructPN2V2"

# Citation entries, used for the BioImage Model Zoo export.
PN2V_REF = CiteEntry(
    text="Krull, A., Vicar, T., Prakash, M., Lalit, M. and Jug, F., 2020. "
    '"Probabilistic noise2void: Unsupervised content-aware denoising". '
    "Frontiers in Computer Science, 2, p.5.",
    doi="10.3389/fcomp.2020.00005",
)

N2V2_REF = CiteEntry(
    text="Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., "
    '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified '
    'sampling strategies and a tweaked network architecture". In European '
    "Conference on Computer Vision (pp. 503-518).",
    doi="10.1007/978-3-031-25069-9_33",
)

# NOTE: the trailing space after "2020." matters: implicit string
# concatenation would otherwise glue the year to the title.
STRUCTN2V_REF = CiteEntry(
    text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020. "
    '"Removing structured noise with self-supervised blind-spot '
    'networks". In 2020 IEEE 17th International Symposium on Biomedical '
    "Imaging (ISBI) (pp. 159-163).",
    doi="10.1109/isbi45749.2020.9098336",
)

# Algorithm descriptions, used to generate the BioImage Model Zoo README.
PN2V_DESCRIPTION = (
    "Probabilistic Noise2Void (PN2V) is a UNet-based self-supervised algorithm that "
    "extends Noise2Void by incorporating a probabilistic noise model to estimate the "
    "posterior distribution of each pixel more precisely. Like N2V, it uses blind-spot "
    "training where random pixels are selected in patches during training and their "
    "value replaced by a neighboring pixel value. The model is then trained to predict "
    "the original pixel value. PN2V additionally uses a noise model to better capture "
    "the noise characteristics and provide more accurate denoising by modeling the "
    "likelihood of pixel intensities."
)

# Shared explanation fragments, factored out so that the variant descriptions
# below cannot drift apart. Each is prefixed with the variant's name.
_PN2V2_DETAILS = (
    "introduces blur-pool layers and removes skip connections in the UNet "
    "architecture to remove checkerboard artefacts, a common artefact "
    "occurring in Noise2Void variants."
)
_STRUCT_PN2V_DETAILS = (
    "uses a linear mask (horizontal or vertical) to replace "
    "the pixel values of neighbors of the masked pixels by a random "
    "value. Such masking allows removing 1D structured noise from "
    "the images, the main failure case of the original N2V variants."
)

PN2V2_DESCRIPTION = (
    "PN2V2 is a variant of Probabilistic Noise2Void. "
    + PN2V_DESCRIPTION
    + "\nPN2V2 "
    + _PN2V2_DETAILS
)

STRUCT_PN2V_DESCRIPTION = (
    "StructPN2V is a variant of Probabilistic Noise2Void. "
    + PN2V_DESCRIPTION
    + "\nStructPN2V "
    + _STRUCT_PN2V_DETAILS
)

STRUCT_PN2V2_DESCRIPTION = (
    "StructPN2V2 is a variant of Probabilistic Noise2Void that uses both "
    "structN2V and N2V2. "
    + PN2V_DESCRIPTION
    + "\nStructPN2V2 "
    + _STRUCT_PN2V_DETAILS
    + "\nPN2V2 "
    + _PN2V2_DETAILS
)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class PN2VAlgorithm(UNetBasedAlgorithm):
    """Probabilistic Noise2Void (PN2V) algorithm configuration.

    PN2V pairs N2V-style pixel manipulation (`n2v_config`) with a Gaussian
    mixture noise model (`noise_model`). The N2V2 variant is selected through
    `model.n2v2` and the structN2V variant through
    `n2v_config.struct_mask_axis`.
    """

    model_config = ConfigDict(validate_assignment=True)

    algorithm: Literal["pn2v"] = "pn2v"
    """PN2V Algorithm name."""

    loss: Literal["pn2v"] = "pn2v"
    """PN2V loss function (uses N2V loss with noise model)."""

    # Pixel manipulation parameters (strategy, struct mask axis, ...).
    n2v_config: N2VManipulateConfig = N2VManipulateConfig()

    noise_model: GaussianMixtureNMConfig
    """Noise model configuration for probabilistic denoising."""

    model: Annotated[
        UNetConfig,
        # AfterValidator(model_matching_in_out_channels),
        # TODO for pn2v channel handling needs to be changed
        AfterValidator(model_without_final_activation),
    ]

    @model_validator(mode="after")
    def validate_n2v2(self) -> Self:
        """Check that the pixel manipulation strategy matches the N2V2 flag.

        When `model.n2v2` is set, the median strategy is required. Otherwise,
        the uniform strategy is required.

        Returns
        -------
        Self
            The configuration itself, unchanged.

        Raises
        ------
        ValueError
            If `n2v_config.strategy` is not the strategy required by the
            current value of `model.n2v2`.
        """
        if self.model.n2v2:
            variant, expected = "PN2V2", SupportedPixelManipulation.MEDIAN
        else:
            variant, expected = "PN2V", SupportedPixelManipulation.UNIFORM

        if self.n2v_config.strategy == expected.value:
            return self

        raise ValueError(
            f"{variant} can only be used with the {expected} pixel "
            f"manipulation strategy. Change the `strategy` parameters in "
            f"`n2v_config` to {expected}."
        )

    def set_n2v2(self, use_n2v2: bool) -> None:
        """
        Switch between PN2V2 and vanilla Probabilistic Noise2Void.

        The pixel manipulation strategy and the UNet `n2v2` flag are updated
        together. This keeps them in the combination `validate_n2v2` expects,
        rather than relying on callers to set each field individually.

        Parameters
        ----------
        use_n2v2 : bool
            If truthy, select PN2V2 (median strategy, `model.n2v2 = True`).
            Otherwise, select vanilla PN2V (uniform strategy,
            `model.n2v2 = False`).
        """
        enabled = bool(use_n2v2)
        strategy = (
            SupportedPixelManipulation.MEDIAN
            if enabled
            else SupportedPixelManipulation.UNIFORM
        )
        # Update the strategy before the model flag.
        self.n2v_config.strategy = strategy.value
        self.model.n2v2 = enabled

    def is_struct_n2v(self) -> bool:
        """Tell whether structPN2V masking is enabled.

        Returns
        -------
        bool
            True when `n2v_config.struct_mask_axis` is anything other than
            the "none" axis.
        """
        axis = self.n2v_config.struct_mask_axis
        return axis != SupportedStructAxis.NONE.value

    def get_algorithm_friendly_name(self) -> str:
        """
        Return the human-readable name of the configured PN2V variant.

        Returns
        -------
        str
            One of the StructPN2V2, PN2V2, StructPN2V or PN2V names.
        """
        struct = self.is_struct_n2v()
        if self.model.n2v2:
            return STRUCT_PN2V2 if struct else PN2V2
        return STRUCT_PN2V if struct else PN2V

    def get_algorithm_keywords(self) -> list[str]:
        """
        Build the keyword list describing this configuration.

        Returns
        -------
        list[str]
            Generic keywords, followed by the PN2V2 and/or StructPN2V names
            when those variants are active.
        """
        dimensionality = "3D" if self.model.is_3D() else "2D"
        variant_tags = [
            tag
            for tag, active in (
                (PN2V2, self.model.n2v2),
                (STRUCT_PN2V, self.is_struct_n2v()),
            )
            if active
        ]
        return [
            "denoising",
            "restoration",
            "UNet",
            dimensionality,
            "CAREamics",
            "pytorch",
            PN2V,
            "probabilistic",
            "noise model",
            *variant_tags,
        ]

    def get_algorithm_references(self) -> str:
        """
        Format the references of the configured variant as text.

        This is used to generate the README of the BioImage Model Zoo export.

        Returns
        -------
        str
            One "<text> doi: <doi>" line per reference. The PN2V reference
            comes first, followed by N2V2 and/or structN2V when active.
        """
        cited = [PN2V_REF]
        if self.model.n2v2:
            cited.append(N2V2_REF)
        if self.is_struct_n2v():
            cited.append(STRUCTN2V_REF)
        return "\n".join(ref.text + " doi: " + ref.doi for ref in cited)

    def get_algorithm_citations(self) -> list[CiteEntry]:
        """
        Collect the citation entries of the configured variant.

        This is used to generate the model description for the BioImage Model Zoo.

        Returns
        -------
        list[CiteEntry]
            The PN2V entry, followed by the N2V2 and/or structN2V entries when
            those variants are active.
        """
        optional_refs = (N2V2_REF, STRUCTN2V_REF)
        active_flags = (self.model.n2v2, self.is_struct_n2v())
        return [PN2V_REF] + [
            ref for ref, active in zip(optional_refs, active_flags) if active
        ]

    def get_algorithm_description(self) -> str:
        """
        Return the long-form description of the configured variant.

        This method is used to generate the README of the BioImage Model Zoo export.

        Returns
        -------
        str
            The description matching the active N2V2 / structN2V combination.
        """
        struct = self.is_struct_n2v()
        if self.model.n2v2:
            return STRUCT_PN2V2_DESCRIPTION if struct else PN2V2_DESCRIPTION
        return STRUCT_PN2V_DESCRIPTION if struct else PN2V_DESCRIPTION
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""UNet-based algorithm Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from pprint import pformat
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
|
7
|
+
|
|
8
|
+
from careamics.config.architectures import UNetConfig
|
|
9
|
+
from careamics.config.lightning.optimizer_configs import (
|
|
10
|
+
LrSchedulerConfig,
|
|
11
|
+
OptimizerConfig,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class UNetBasedAlgorithm(BaseModel):
    """General UNet-based algorithm configuration.

    This Pydantic model validates the parameters governing the components of the
    training algorithm: which algorithm, loss function, model architecture, optimizer,
    and learning rate scheduler to use.

    Currently, we only support N2V, CARE, N2N, and PN2V algorithms. In order to train
    these algorithms, use the corresponding configuration child classes (e.g.
    `N2VAlgorithm`) to ensure coherent parameters (e.g. specific losses).

    Attributes
    ----------
    algorithm : {"n2v", "care", "n2n", "pn2v"}
        Algorithm to use.
    loss : {"n2v", "mae", "mse", "pn2v"}
        Loss function to use.
    model : UNetConfig
        Model architecture to use.
    optimizer : OptimizerConfig, optional
        Optimizer to use. Defaults to `OptimizerConfig()`.
    lr_scheduler : LrSchedulerConfig, optional
        Learning rate scheduler to use. Defaults to `LrSchedulerConfig()`.

    Raises
    ------
    ValueError
        If a parameter fails type validation, or if the algorithm, loss and
        model are not compatible.
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        protected_namespaces=(),  # allows to use model_* as a field name
        validate_assignment=True,
        extra="allow",
    )

    # Mandatory fields
    algorithm: Literal["n2v", "care", "n2n", "pn2v"]
    """Algorithm name, as defined in SupportedAlgorithm."""

    loss: Literal["n2v", "mae", "mse", "pn2v"]
    """Loss function to use, as defined in SupportedLoss."""

    model: UNetConfig
    """UNet model configuration."""

    # Optional fields
    optimizer: OptimizerConfig = OptimizerConfig()
    """Optimizer to use, defined in SupportedOptimizer."""

    lr_scheduler: LrSchedulerConfig = LrSchedulerConfig()
    """Learning rate scheduler to use, defined in SupportedLrScheduler."""

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty-printed (`pprint.pformat`) dump of `model_dump()`.
        """
        return pformat(self.model_dump())

    @classmethod
    def get_compatible_algorithms(cls) -> list[str]:
        """Get the list of compatible algorithms.

        The list mirrors the values allowed by the `algorithm` field of this
        base class. Keep both in sync when adding an algorithm.

        Returns
        -------
        list of str
            List of compatible algorithms.
        """
        return ["n2v", "care", "n2n", "pn2v"]
|