careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Convenience function to create algorithm configurations."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, Any, Literal, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import Field, TypeAdapter
|
|
6
|
+
|
|
7
|
+
from careamics.config.algorithms import (
|
|
8
|
+
CAREAlgorithm,
|
|
9
|
+
N2NAlgorithm,
|
|
10
|
+
N2VAlgorithm,
|
|
11
|
+
# PN2VAlgorithm, # TODO not yet compatible with NG Dataset
|
|
12
|
+
)
|
|
13
|
+
from careamics.config.architectures import UNetConfig
|
|
14
|
+
from careamics.config.support.supported_architectures import SupportedArchitecture
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# TODO rename so that it does not bear the same name as the module?
|
|
18
|
+
def algorithm_factory(
    algorithm: dict[str, Any],
) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]:
    """
    Validate an algorithm dictionary into the matching algorithm model.

    The value stored under the ``"algorithm"`` key is used as the discriminator
    that selects which of the supported algorithm models is instantiated.

    Parameters
    ----------
    algorithm : dict
        Algorithm dictionary.

    Returns
    -------
    N2VAlgorithm or N2NAlgorithm or CAREAlgorithm
        Algorithm model for training CAREamics.

    Raises
    ------
    pydantic.ValidationError
        If the dictionary does not validate against any of the supported
        algorithm models.
    """
    # discriminated union: pydantic dispatches on the "algorithm" field instead
    # of trying each member of the union in turn
    algorithm_union = Annotated[
        Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm],
        Field(discriminator="algorithm"),
    ]
    validator: TypeAdapter = TypeAdapter(algorithm_union)
    return validator.validate_python(algorithm)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def create_algorithm_configuration(
    dimensions: Literal[2, 3],
    algorithm: Literal["n2v", "care", "n2n"],
    loss: Literal["n2v", "mae", "mse"],
    independent_channels: bool,
    n_channels_in: int,
    n_channels_out: int,
    use_n2v2: bool = False,
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
) -> dict:
    """
    Create a dictionary with the parameters of the algorithm model.

    Parameters
    ----------
    dimensions : {2, 3}
        Dimension of the model, either 2D or 3D.
    algorithm : {"n2v", "care", "n2n"}
        Algorithm to use.
    loss : {"n2v", "mae", "mse"}
        Loss function to use.
    independent_channels : bool
        Whether to train all channels independently.
    n_channels_in : int
        Number of input channels.
    n_channels_out : int
        Number of output channels.
    use_n2v2 : bool, default=False
        Whether to use N2V2.
    model_params : dict, default=None
        UNetModel parameters. Entries that conflict with the explicit arguments
        of this function (`use_n2v2`, `dimensions`, `n_channels_in`,
        `n_channels_out`, `independent_channels`) or with the architecture name
        are overridden. The dictionary passed by the caller is not modified.
    optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.

    Returns
    -------
    dict
        Algorithm model as dictionary with the specified parameters.
    """
    # Work on a copy so that the caller's `model_params` dictionary is left
    # untouched, then let the explicit parameters take priority over any
    # conflicting entries. Putting every keyword (including `architecture`)
    # into a single dictionary prevents the same parameter from being passed
    # twice to UNetConfig.
    unet_params: dict[str, Any] = {} if model_params is None else dict(model_params)
    unet_params["architecture"] = SupportedArchitecture.UNET.value
    unet_params["n2v2"] = use_n2v2
    unet_params["conv_dims"] = dimensions
    unet_params["in_channels"] = n_channels_in
    unet_params["num_classes"] = n_channels_out
    unet_params["independent_channels"] = independent_channels

    unet_model = UNetConfig(**unet_params)

    return {
        "algorithm": algorithm,
        "loss": loss,
        "model": unet_model,
        "optimizer": {
            "name": optimizer,
            # shallow copies avoid aliasing the caller's dictionaries
            "parameters": {} if optimizer_params is None else dict(optimizer_params),
        },
        "lr_scheduler": {
            "name": lr_scheduler,
            "parameters": (
                {} if lr_scheduler_params is None else dict(lr_scheduler_params)
            ),
        },
    }
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""Convenience functions to create NG data configurations."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Any, Literal
|
|
5
|
+
|
|
6
|
+
from careamics.config.data import NGDataConfig
|
|
7
|
+
from careamics.config.transformations import (
|
|
8
|
+
SPATIAL_TRANSFORMS_UNION,
|
|
9
|
+
XYFlipConfig,
|
|
10
|
+
XYRandomRotate90Config,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def list_spatial_augmentations(
    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
) -> list[SPATIAL_TRANSFORMS_UNION]:
    """
    Return the list of spatial augmentations to apply.

    When no augmentations are provided, a default list made of an XY flip and an
    XY random 90 degree rotation is returned. Otherwise, the provided list is
    validated and returned as is.

    Parameters
    ----------
    augmentations : list of transforms, optional
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config.

    Returns
    -------
    list of transforms
        List of transforms to apply.

    Raises
    ------
    ValueError
        If the transforms are not XYFlipConfig or XYRandomRotate90Config.
    ValueError
        If there are duplicate transforms.
    """
    # default augmentations
    if augmentations is None:
        return [XYFlipConfig(), XYRandomRotate90Config()]

    # every entry must be one of the accepted spatial transform configurations
    accepted = (XYFlipConfig, XYRandomRotate90Config)
    if any(not isinstance(aug, accepted) for aug in augmentations):
        raise ValueError(
            "Accepted transforms are either XYFlipConfig or "
            "XYRandomRotate90Config."
        )

    # each transform class may appear at most once
    classes = [aug.__class__ for aug in augmentations]
    if len(classes) != len(set(classes)):
        raise ValueError("Duplicate transforms are not allowed.")

    return augmentations
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def create_ng_data_configuration(
    data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
    channels: Sequence[int] | None = None,
    in_memory: bool | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    pred_dataloader_params: dict[str, Any] | None = None,
    seed: int | None = None,
) -> NGDataConfig:
    """
    Create a training NGDataConfig.

    The resulting configuration is in "training" mode and uses random patching
    with the given patch size.

    Parameters
    ----------
    data_type : {"array", "tiff", "zarr", "czi", "custom"}
        Type of the data.
    axes : str
        Axes of the data.
    patch_size : Sequence of int
        Size of the patches along the spatial dimensions.
    batch_size : int
        Batch size.
    augmentations : list of transforms or None, default=None
        List of transforms to apply. If `None`, default augmentations are applied
        (flip in X and Y, rotations by 90 degrees in the XY plane).
    channels : Sequence of int, default=None
        List of channels to use. If `None`, all channels are used.
    in_memory : bool, default=None
        Whether to load all data into memory. This is only supported for 'array',
        'tiff' and 'custom' data types. If `None`, defaults to `True` for 'array',
        'tiff' and `custom`, and `False` for 'zarr' and 'czi' data types. Must be `True`
        for `array`.
    train_dataloader_params : dict, default=None
        Parameters for the training dataloader, see PyTorch notes. If provided
        without a `shuffle` key, `shuffle` is set to `True`. The dictionary passed
        by the caller is not modified. If `None`, the NGDataConfig defaults are
        used.
    val_dataloader_params : dict, default=None
        Parameters for the validation dataloader, see PyTorch notes. If `None`,
        the NGDataConfig defaults are used.
    pred_dataloader_params : dict, default=None
        Parameters for the prediction dataloader, see PyTorch notes. If `None`,
        the NGDataConfig defaults are used.
    seed : int, default=None
        Random seed for reproducibility. If `None`, no seed is set.

    Returns
    -------
    NGDataConfig
        Next-Generation Data model with the specified parameters.
    """
    if augmentations is None:
        augmentations = list_spatial_augmentations()

    # data model
    data: dict[str, Any] = {
        "mode": "training",
        "data_type": data_type,
        "axes": axes,
        "batch_size": batch_size,
        "channels": channels,
        "transforms": augmentations,
        "seed": seed,
    }

    if in_memory is not None:
        data["in_memory"] = in_memory

    # don't override defaults set in DataConfig class
    if train_dataloader_params is not None:
        # copy before adding keys so that the caller's dictionary is not mutated;
        # the presence of `shuffle` key in the dataloader parameters is enforced
        # by the NGDataConfig class
        train_params = dict(train_dataloader_params)
        train_params.setdefault("shuffle", True)
        data["train_dataloader_params"] = train_params

    if val_dataloader_params is not None:
        data["val_dataloader_params"] = val_dataloader_params

    if pred_dataloader_params is not None:
        data["pred_dataloader_params"] = pred_dataloader_params

    # add training patching
    data["patching"] = {
        "name": "random",
        "patch_size": patch_size,
    }

    return NGDataConfig(**data)
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""Convenience function to create N2V configurations."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Any, Literal
|
|
5
|
+
|
|
6
|
+
from careamics.config.ng_configs import N2VConfiguration
|
|
7
|
+
from careamics.config.support import (
|
|
8
|
+
SupportedPixelManipulation,
|
|
9
|
+
SupportedTransform,
|
|
10
|
+
)
|
|
11
|
+
from careamics.config.transformations import (
|
|
12
|
+
N2VManipulateConfig,
|
|
13
|
+
XYFlipConfig,
|
|
14
|
+
XYRandomRotate90Config,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from .algorithm_factory import create_algorithm_configuration
|
|
18
|
+
from .data_factory import create_ng_data_configuration, list_spatial_augmentations
|
|
19
|
+
from .training_factory import create_training_configuration, update_trainer_params
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def create_n2v_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    augmentations: list[XYFlipConfig | XYRandomRotate90Config] | None = None,
    channels: Sequence[int] | None = None,
    in_memory: bool | None = None,
    independent_channels: bool = True,
    use_n2v2: bool = False,
    n_channels: int | None = None,
    roi_size: int = 11,
    masked_pixel_percentage: float = 0.2,
    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
    struct_n2v_span: int = 5,
    trainer_params: dict | None = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    checkpoint_params: dict[str, Any] | None = None,
) -> N2VConfiguration:
    """
    Create a configuration for training Noise2Void.

    N2V uses a UNet model to denoise images in a self-supervised manner. To use its
    variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
    (structN2V) parameters, or set `use_n2v2` to True (N2V2).

    N2V2 modifies the UNet architecture by adding blur pool layers and removes the skip
    connections, thus removing checkerboard artefacts. StructN2V is used when vertical
    or horizontal correlations are present in the noise; it applies an additional mask
    to the manipulated pixel neighbors.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
    2.

    If "C" is present in `axes`, then you need to set `n_channels` (or `channels`) to
    specify the number of channels.

    By default, all channels are trained independently. To train all channels together,
    set `independent_channels` to False.

    By default, the transformations applied are a random flip along X or Y, and a random
    90 degrees rotation in the XY plane. Normalization is always applied, as well as the
    N2V manipulation.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    The `roi_size` parameter specifies the size of the area around each pixel that will
    be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
    pixels per patch will be manipulated.

    The parameters of the UNet can be specified in the `model_params` (passed as a
    parameter-value dictionary). Note that `use_n2v2` and `n_channels` override the
    corresponding parameters passed in `model_params`.

    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
    will be applied to each manipulated pixel.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "zarr", "czi", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence of int
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. If provided, this will be added to
        trainer_params (as `max_epochs`).
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    augmentations : list of transforms, default=None
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
        and XYRandomRotate90 (in XY) to the images.
    channels : Sequence of int, optional
        List of channels to use. If `None`, all channels are used.
    in_memory : bool, optional
        Whether to load all data into memory. This is only supported for 'array',
        'tiff' and 'custom' data types. If `None`, defaults to `True` for 'array',
        'tiff' and 'custom', and `False` for 'zarr' and 'czi' data types. Must be
        `True` for 'array'.
    independent_channels : bool, optional
        Whether to train the channels independently, by default True. If False, all
        channels are trained together.
    use_n2v2 : bool, optional
        Whether to use N2V2, by default False.
    n_channels : int or None, default=None
        Number of channels (in and out). If `channels` is specified, then the number of
        channels is inferred from its length.
    roi_size : int, optional
        N2V pixel manipulation area, by default 11.
    masked_pixel_percentage : float, optional
        Percentage of pixels masked in each patch, by default 0.2.
    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
        Axis along which to apply structN2V mask, by default "none".
    struct_n2v_span : int, optional
        Span of the structN2V mask, by default 5.
    trainer_params : dict, optional
        Parameters for the trainer, see the relevant documentation.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_params : dict, default=None
        UNetModel parameters.
    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the defaults of the data configuration (`NGDataConfig`) are
        used. If provided without a `shuffle` key, `shuffle=True` is added. The
        dictionary passed by the caller is not modified.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see the PyTorch docs for
        `DataLoader`. If left as `None`, the defaults of the data configuration
        (`NGDataConfig`) are used.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters.

    Returns
    -------
    N2VConfiguration
        Configuration for training N2V.

    Raises
    ------
    ValueError
        If "C" is in `axes` but neither `n_channels` nor `channels` is given, if "C"
        is not in `axes` but `n_channels` > 1, or if `n_channels` does not match the
        length of `channels`.
    """
    # if there are channels, we need to specify their number
    channels_present = "C" in axes

    if channels_present and (n_channels is None and channels is None):
        raise ValueError(
            "`n_channels` or `channels` must be specified when using channels."
        )
    elif not channels_present and (n_channels is not None and n_channels > 1):
        raise ValueError(
            "C is not present in the axes, but number of channels is specified "
            f"(got {n_channels} channel)."
        )

    if n_channels is not None and channels is not None:
        if n_channels != len(channels):
            raise ValueError(
                f"Number of channels ({n_channels}) does not match length of "
                f"`channels` ({len(channels)}). Only specify `channels`."
            )

    if n_channels is None:
        n_channels = 1 if channels is None else len(channels)

    # `create_ng_data_configuration` inserts a default `shuffle` key into the
    # training dataloader parameters in place; hand it a copy so the caller's
    # dictionary is left untouched (and can safely be reused across calls).
    if train_dataloader_params is not None:
        train_dataloader_params = dict(train_dataloader_params)

    # augmentations
    spatial_transforms = list_spatial_augmentations(augmentations)

    # data
    data_config = create_ng_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        augmentations=spatial_transforms,
        channels=channels,
        in_memory=in_memory,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # algorithm
    algorithm_params = create_algorithm_configuration(
        dimensions=3 if data_config.is_3D() else 2,
        algorithm="n2v",
        loss="n2v",
        independent_channels=independent_channels,
        n_channels_in=n_channels,
        n_channels_out=n_channels,
        use_n2v2=use_n2v2,
        model_params=model_params,
        optimizer=optimizer,
        optimizer_params=optimizer_params,
        lr_scheduler=lr_scheduler,
        lr_scheduler_params=lr_scheduler_params,
    )

    # create the N2VManipulate transform using the supplied parameters; N2V2 uses
    # the median replacement strategy, plain N2V the uniform one
    n2v_transform = N2VManipulateConfig(
        name=SupportedTransform.N2V_MANIPULATE.value,
        strategy=(
            SupportedPixelManipulation.MEDIAN.value
            if use_n2v2
            else SupportedPixelManipulation.UNIFORM.value
        ),
        roi_size=roi_size,
        masked_pixel_percentage=masked_pixel_percentage,
        struct_mask_axis=struct_n2v_axis,
        struct_mask_span=struct_n2v_span,
    )
    algorithm_params["n2v_config"] = n2v_transform

    # training
    final_trainer_params = update_trainer_params(
        trainer_params=trainer_params,
        num_epochs=num_epochs,
        num_steps=num_steps,
    )
    training_params = create_training_configuration(
        trainer_params=final_trainer_params,
        logger=logger,
        checkpoint_params=checkpoint_params,
    )

    return N2VConfiguration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_params,
        data_config=data_config,
        training_config=training_params,
    )
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Convenience functions to create training configurations."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
from careamics.config.lightning.training_config import TrainingConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def create_training_configuration(
    trainer_params: dict,
    logger: Literal["wandb", "tensorboard", "none"],
    checkpoint_params: dict[str, Any] | None = None,
) -> TrainingConfig:
    """
    Build a `TrainingConfig` from trainer, logger and checkpoint settings.

    Parameters
    ----------
    trainer_params : dict
        Parameters for Lightning Trainer class, see PyTorch Lightning documentation.
        Passed to `TrainingConfig` as `lightning_trainer_config`.
    logger : {"wandb", "tensorboard", "none"}
        Logger to use. The string "none" is translated to `None` (no logger).
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters. When `None`, an
        empty dictionary is passed as `checkpoint_callback`.

    Returns
    -------
    TrainingConfig
        Training configuration holding the specified parameters.
    """
    # "none" is the user-facing spelling for "no logger"
    selected_logger = None if logger == "none" else logger
    checkpoint_callback = checkpoint_params if checkpoint_params is not None else {}

    return TrainingConfig(
        lightning_trainer_config=trainer_params,
        logger=selected_logger,
        checkpoint_callback=checkpoint_callback,
    )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def update_trainer_params(
|
|
39
|
+
trainer_params: dict[str, Any] | None = None,
|
|
40
|
+
num_epochs: int | None = None,
|
|
41
|
+
num_steps: int | None = None,
|
|
42
|
+
) -> dict[str, Any]:
|
|
43
|
+
"""
|
|
44
|
+
Update trainer parameters with num_epochs and num_steps.
|
|
45
|
+
|
|
46
|
+
Parameters
|
|
47
|
+
----------
|
|
48
|
+
trainer_params : dict, optional
|
|
49
|
+
Parameters for Lightning Trainer class, by default None.
|
|
50
|
+
num_epochs : int, optional
|
|
51
|
+
Number of epochs to train for. If provided, this will be added as max_epochs
|
|
52
|
+
to trainer_params, by default None.
|
|
53
|
+
num_steps : int, optional
|
|
54
|
+
Number of batches in 1 epoch. If provided, this will be added as
|
|
55
|
+
limit_train_batches to trainer_params, by default None.
|
|
56
|
+
|
|
57
|
+
Returns
|
|
58
|
+
-------
|
|
59
|
+
dict
|
|
60
|
+
Updated trainer parameters dictionary.
|
|
61
|
+
"""
|
|
62
|
+
final_trainer_params = {} if trainer_params is None else trainer_params.copy()
|
|
63
|
+
|
|
64
|
+
if num_epochs is not None:
|
|
65
|
+
final_trainer_params["max_epochs"] = num_epochs
|
|
66
|
+
if num_steps is not None:
|
|
67
|
+
final_trainer_params["limit_train_batches"] = num_steps
|
|
68
|
+
|
|
69
|
+
return final_trainer_params
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Noise models Pydantic configurations."""
|
|
2
|
+
|
|
3
|
+
# Public API of the noise-model configuration package; each name is imported
# from the `likelihood_config` / `noise_model_config` submodules below.
__all__ = [
    "GaussianLikelihoodConfig",
    "GaussianMixtureNMConfig",
    "MultiChannelNMConfig",
    "NMLikelihoodConfig",
]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
from .likelihood_config import GaussianLikelihoodConfig, NMLikelihoodConfig
|
|
12
|
+
from .noise_model_config import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Likelihood model."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, Literal, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator
|
|
8
|
+
|
|
9
|
+
from careamics.models.lvae.noise_models import (
|
|
10
|
+
GaussianMixtureNoiseModel,
|
|
11
|
+
MultiChannelNoiseModel,
|
|
12
|
+
)
|
|
13
|
+
from careamics.utils.serializers import _array_to_json, _to_torch
|
|
14
|
+
|
|
15
|
+
# Either of the two noise model classes imported from
# `careamics.models.lvae.noise_models`.
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]

# TODO: this is a temporary solution to serialize and deserialize tensor fields
# in pydantic models. Specifically, the aim is to enable saving and loading configs
# with such tensors to/from JSON files during, resp., training and evaluation.
# Accepted values are numpy arrays or torch tensors. Serialization is delegated to
# `_array_to_json` (declared to return `str`), and `PlainValidator` replaces
# pydantic's own validation with `_to_torch`.
Tensor = Annotated[
    Union[np.ndarray, torch.Tensor],
    PlainSerializer(_array_to_json, return_type=str),
    PlainValidator(_to_torch),
]
"""Annotated tensor type, used to serialize arrays or tensors to JSON strings
and deserialize them back to tensors."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class GaussianLikelihoodConfig(BaseModel):
    """Gaussian likelihood configuration.

    Both fields are optional and default to `None`; assignments are validated.
    """

    model_config = ConfigDict(validate_assignment=True)

    predict_logvar: Literal["pixelwise"] | None = None
    """When `"pixelwise"`, a log-variance is computed for every pixel; when `None`,
    no log-variance is computed."""

    logvar_lowerbound: float | None = None
    """Lower bound for the log-variance."""
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class NMLikelihoodConfig(BaseModel):
    """Noise model likelihood configuration.

    NOTE: the data mean and standard deviation are stored here because the noise
    model is trained on data that was not normalized. The model output therefore has
    to be unnormalized before the noise model likelihood can be computed.
    """

    # `arbitrary_types_allowed` is needed for the array/tensor-typed fields below
    model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

    # TODO remove and use as parameters to the likelihood functions?
    data_mean: Tensor | None = None
    """Mean of the data, used to unnormalize data for noise model evaluation.
    Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""

    # TODO remove and use as parameters to the likelihood functions?
    data_std: Tensor | None = None
    """Standard deviation of the data, used to unnormalize data for noise model
    evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
|