careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,2400 @@
|
|
|
1
|
+
"""Convenience functions to create configurations for training and inference."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Annotated, Any, Literal, Union
|
|
5
|
+
|
|
6
|
+
from pydantic import Field, TypeAdapter
|
|
7
|
+
|
|
8
|
+
from careamics.config.algorithms import (
|
|
9
|
+
CAREAlgorithm,
|
|
10
|
+
MicroSplitAlgorithm,
|
|
11
|
+
N2NAlgorithm,
|
|
12
|
+
N2VAlgorithm,
|
|
13
|
+
PN2VAlgorithm,
|
|
14
|
+
)
|
|
15
|
+
from careamics.config.architectures import LVAEConfig, UNetConfig
|
|
16
|
+
from careamics.config.data import DataConfig
|
|
17
|
+
from careamics.config.lightning.training_config import TrainingConfig
|
|
18
|
+
from careamics.config.losses.loss_config import LVAELossConfig
|
|
19
|
+
from careamics.config.noise_model.likelihood_config import (
|
|
20
|
+
GaussianLikelihoodConfig,
|
|
21
|
+
NMLikelihoodConfig,
|
|
22
|
+
)
|
|
23
|
+
from careamics.config.noise_model.noise_model_config import (
|
|
24
|
+
GaussianMixtureNMConfig,
|
|
25
|
+
MultiChannelNMConfig,
|
|
26
|
+
)
|
|
27
|
+
from careamics.config.support import (
|
|
28
|
+
SupportedArchitecture,
|
|
29
|
+
SupportedPixelManipulation,
|
|
30
|
+
SupportedTransform,
|
|
31
|
+
)
|
|
32
|
+
from careamics.config.transformations import (
|
|
33
|
+
SPATIAL_TRANSFORMS_UNION,
|
|
34
|
+
N2VManipulateConfig,
|
|
35
|
+
XYFlipConfig,
|
|
36
|
+
XYRandomRotate90Config,
|
|
37
|
+
)
|
|
38
|
+
from careamics.lvae_training.dataset.config import MicroSplitDataConfig
|
|
39
|
+
|
|
40
|
+
from .configuration import Configuration
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def algorithm_factory(
    algorithm: dict[str, Any],
) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm, PN2VAlgorithm]:
    """
    Validate an algorithm dictionary into the matching algorithm model.

    The value stored under the ``"algorithm"`` key is used as the pydantic
    discriminator, selecting which of the supported algorithm models the
    dictionary is validated against.

    Parameters
    ----------
    algorithm : dict
        Algorithm dictionary.

    Returns
    -------
    N2VAlgorithm or N2NAlgorithm or CAREAlgorithm or PN2VAlgorithm
        Algorithm model for training CAREamics.
    """
    # Union of the candidate models, dispatched on the "algorithm" field.
    candidate_models = Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm, PN2VAlgorithm]
    discriminated_union = Annotated[
        candidate_models,
        Field(discriminator="algorithm"),
    ]

    validator: TypeAdapter = TypeAdapter(discriminated_union)
    return validator.validate_python(algorithm)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _list_spatial_augmentations(
    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
) -> list[SPATIAL_TRANSFORMS_UNION]:
    """
    List the augmentations to apply.

    When no augmentations are given, both an `XYFlipConfig` and an
    `XYRandomRotate90Config` with default parameters are returned. Otherwise the
    provided list is validated and returned as-is (the same list object, not a
    copy). An empty list is accepted and results in no augmentations.

    Parameters
    ----------
    augmentations : list of transforms, optional
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config.

    Returns
    -------
    list of transforms
        List of transforms to apply.

    Raises
    ------
    ValueError
        If the transforms are not XYFlipConfig or XYRandomRotate90Config.
    ValueError
        If there are duplicate transforms.
    """
    if augmentations is None:
        # Default: apply both spatial augmentations with default parameters.
        return [XYFlipConfig(), XYRandomRotate90Config()]

    # Only the pydantic transform models are accepted (e.g. not raw dicts).
    accepted_types = (XYFlipConfig, XYRandomRotate90Config)
    if not all(isinstance(t, accepted_types) for t in augmentations):
        raise ValueError(
            "Accepted transforms are either XYFlipConfig or "
            "XYRandomRotate90Config."
        )

    # Each transform type may appear at most once.
    aug_types = [type(t) for t in augmentations]
    if len(set(aug_types)) != len(aug_types):
        raise ValueError("Duplicate transforms are not allowed.")

    return augmentations
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _create_unet_configuration(
    axes: str,
    n_channels_in: int,
    n_channels_out: int,
    independent_channels: bool,
    use_n2v2: bool,
    model_params: dict[str, Any] | None = None,
) -> UNetConfig:
    """
    Create the UNet model configuration.

    The N2V2 flag, convolution dimensionality, channel numbers and channel
    independence are derived from the arguments and override any identical keys
    present in `model_params`. The caller's `model_params` dictionary is copied
    and left unmodified.

    Parameters
    ----------
    axes : str
        Axes of the data. If it contains "Z", 3D convolutions are used,
        otherwise 2D convolutions.
    n_channels_in : int
        Number of input channels.
    n_channels_out : int
        Number of output channels.
    independent_channels : bool
        Whether to train all channels independently.
    use_n2v2 : bool
        Whether to use N2V2.
    model_params : dict, optional
        Additional UNetConfig parameters.

    Returns
    -------
    UNetConfig
        UNet model with the specified parameters.
    """
    # Work on a copy so the caller's dictionary is not mutated by the updates
    # below (it may be reused to build other configurations).
    params: dict[str, Any] = {} if model_params is None else dict(model_params)

    params["n2v2"] = use_n2v2
    params["conv_dims"] = 3 if "Z" in axes else 2
    params["in_channels"] = n_channels_in
    params["num_classes"] = n_channels_out
    params["independent_channels"] = independent_channels

    return UNetConfig(
        architecture=SupportedArchitecture.UNET.value,
        **params,
    )
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _create_algorithm_configuration(
    axes: str,
    algorithm: Literal["n2v", "care", "n2n", "pn2v"],
    loss: Literal["n2v", "mae", "mse", "pn2v"],
    independent_channels: bool,
    n_channels_in: int,
    n_channels_out: int,
    use_n2v2: bool = False,
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
) -> dict:
    """
    Assemble the algorithm configuration as a dictionary.

    The returned dictionary holds the algorithm name, the loss, a UNet model
    configuration, and the optimizer and learning rate scheduler settings.
    Missing optimizer or scheduler parameters are replaced by empty dictionaries.

    Parameters
    ----------
    axes : str
        Axes of the data.
    algorithm : {"n2v", "care", "n2n", "pn2v"}
        Algorithm to use.
    loss : {"n2v", "mae", "mse", "pn2v"}
        Loss function to use.
    independent_channels : bool
        Whether to train all channels independently.
    n_channels_in : int
        Number of input channels.
    n_channels_out : int
        Number of output channels.
    use_n2v2 : bool, default=False
        Whether to use N2V2.
    model_params : dict, default=None
        UNetModel parameters.
    optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for
        more details.

    Returns
    -------
    dict
        Algorithm model as a dictionary with the specified parameters.
    """
    model = _create_unet_configuration(
        axes=axes,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        independent_channels=independent_channels,
        use_n2v2=use_n2v2,
        model_params=model_params,
    )

    optimizer_section = {
        "name": optimizer,
        "parameters": optimizer_params if optimizer_params is not None else {},
    }
    scheduler_section = {
        "name": lr_scheduler,
        "parameters": (
            lr_scheduler_params if lr_scheduler_params is not None else {}
        ),
    }

    algorithm_section = {
        "algorithm": algorithm,
        "loss": loss,
        "model": model,
        "optimizer": optimizer_section,
        "lr_scheduler": scheduler_section,
    }
    return algorithm_section
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _create_data_configuration(
    data_type: Literal["array", "tiff", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    augmentations: list[SPATIAL_TRANSFORMS_UNION],
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
) -> DataConfig:
    """
    Create the data configuration.

    Dataloader parameters are only forwarded when provided, so that the defaults
    defined in `DataConfig` are otherwise kept. When training dataloader
    parameters are given without a `shuffle` entry, `shuffle=True` is added to
    the forwarded copy; the caller's dictionary is not modified.

    Parameters
    ----------
    data_type : {"array", "tiff", "czi", "custom"}
        Type of the data.
    axes : str
        Axes of the data.
    patch_size : list of int
        Size of the patches along the spatial dimensions.
    batch_size : int
        Batch size.
    augmentations : list of transforms
        List of transforms to apply.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see PyTorch notes, by default None.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see PyTorch notes, by default None.

    Returns
    -------
    DataConfig
        Data model with the specified parameters.
    """
    data: dict[str, Any] = {
        "data_type": data_type,
        "axes": axes,
        "patch_size": patch_size,
        "batch_size": batch_size,
        "transforms": augmentations,
    }

    # Don't override defaults set in DataConfig class
    if train_dataloader_params is not None:
        # DataConfig enforces the presence of `shuffle` key in the dataloader
        # parameters. Build a new dict so the caller's dictionary is not mutated;
        # a user-provided `shuffle` takes precedence over the default.
        data["train_dataloader_params"] = {"shuffle": True, **train_dataloader_params}

    if val_dataloader_params is not None:
        data["val_dataloader_params"] = val_dataloader_params

    return DataConfig(**data)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _create_microsplit_data_configuration(
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: Sequence[int],
    grid_size: int,
    multiscale_count: int,
    batch_size: int,
    augmentations: list[SPATIAL_TRANSFORMS_UNION],
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
) -> MicroSplitDataConfig:
    """
    Create the MicroSplit data configuration.

    `patch_size` is forwarded as `image_size` and `multiscale_count` as
    `multiscale_lowres_count`. Dataloader parameters are only forwarded when
    provided. When training dataloader parameters are given without a `shuffle`
    entry, `shuffle=True` is added to the forwarded copy; the caller's dictionary
    is not modified.

    Parameters
    ----------
    data_type : {"array", "tiff", "custom"}
        Type of the data.
    axes : str
        Axes of the data.
    patch_size : list of int
        Size of the patches along the spatial dimensions.
    grid_size : int
        Size of the grid for multiscale data configuration.
    multiscale_count : int
        Number of multiscale levels.
    batch_size : int
        Batch size.
    augmentations : list of transforms
        List of transforms to apply.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see PyTorch notes, by default None.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see PyTorch notes, by default None.

    Returns
    -------
    MicroSplitDataConfig
        Data model with the specified parameters.
    """
    data: dict[str, Any] = {
        "data_type": data_type,
        "axes": axes,
        "image_size": patch_size,
        "grid_size": grid_size,
        "multiscale_lowres_count": multiscale_count,
        "batch_size": batch_size,
        "transforms": augmentations,
    }

    # Don't override defaults set in the data configuration class
    if train_dataloader_params is not None:
        # The `shuffle` key is required in the dataloader parameters. Build a new
        # dict so the caller's dictionary is not mutated; a user-provided
        # `shuffle` takes precedence over the default.
        data["train_dataloader_params"] = {"shuffle": True, **train_dataloader_params}

    if val_dataloader_params is not None:
        data["val_dataloader_params"] = val_dataloader_params

    return MicroSplitDataConfig(**data)
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def _create_training_configuration(
    trainer_params: dict,
    logger: Literal["wandb", "tensorboard", "none"],
    checkpoint_params: dict[str, Any] | None = None,
) -> TrainingConfig:
    """
    Build the training configuration.

    The string ``"none"`` for `logger` is translated into ``None`` before being
    passed to `TrainingConfig`, and missing checkpoint parameters become an
    empty dictionary.

    Parameters
    ----------
    trainer_params : dict
        Parameters for Lightning Trainer class, see PyTorch Lightning documentation.
    logger : {"wandb", "tensorboard", "none"}
        Logger to use.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters.

    Returns
    -------
    TrainingConfig
        Training model with the specified parameters.
    """
    # "none" is the user-facing value for disabling logging; pass None instead.
    selected_logger = logger if logger != "none" else None
    checkpoint_callback = checkpoint_params if checkpoint_params is not None else {}

    return TrainingConfig(
        lightning_trainer_config=trainer_params,
        logger=selected_logger,
        checkpoint_callback=checkpoint_callback,
    )
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def update_trainer_params(
|
|
391
|
+
trainer_params: dict[str, Any] | None = None,
|
|
392
|
+
num_epochs: int | None = None,
|
|
393
|
+
num_steps: int | None = None,
|
|
394
|
+
) -> dict[str, Any]:
|
|
395
|
+
"""
|
|
396
|
+
Update trainer parameters with num_epochs and num_steps.
|
|
397
|
+
|
|
398
|
+
Parameters
|
|
399
|
+
----------
|
|
400
|
+
trainer_params : dict, optional
|
|
401
|
+
Parameters for Lightning Trainer class, by default None.
|
|
402
|
+
num_epochs : int, optional
|
|
403
|
+
Number of epochs to train for. If provided, this will be added as max_epochs
|
|
404
|
+
to trainer_params, by default None.
|
|
405
|
+
num_steps : int, optional
|
|
406
|
+
Number of batches in 1 epoch. If provided, this will be added as
|
|
407
|
+
limit_train_batches to trainer_params, by default None.
|
|
408
|
+
|
|
409
|
+
Returns
|
|
410
|
+
-------
|
|
411
|
+
dict
|
|
412
|
+
Updated trainer parameters dictionary.
|
|
413
|
+
"""
|
|
414
|
+
final_trainer_params = {} if trainer_params is None else trainer_params.copy()
|
|
415
|
+
|
|
416
|
+
if num_epochs is not None:
|
|
417
|
+
final_trainer_params["max_epochs"] = num_epochs
|
|
418
|
+
if num_steps is not None:
|
|
419
|
+
final_trainer_params["limit_train_batches"] = num_steps
|
|
420
|
+
|
|
421
|
+
return final_trainer_params
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
# TODO reconsider naming once we officially support LVAE approaches
|
|
425
|
+
def _create_supervised_config_dict(
|
|
426
|
+
algorithm: Literal["care", "n2n"],
|
|
427
|
+
experiment_name: str,
|
|
428
|
+
data_type: Literal["array", "tiff", "czi", "custom"],
|
|
429
|
+
axes: str,
|
|
430
|
+
patch_size: Sequence[int],
|
|
431
|
+
batch_size: int,
|
|
432
|
+
trainer_params: dict | None = None,
|
|
433
|
+
augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
|
|
434
|
+
independent_channels: bool = True,
|
|
435
|
+
loss: Literal["mae", "mse"] = "mae",
|
|
436
|
+
n_channels_in: int | None = None,
|
|
437
|
+
n_channels_out: int | None = None,
|
|
438
|
+
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
439
|
+
model_params: dict | None = None,
|
|
440
|
+
optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
|
|
441
|
+
optimizer_params: dict[str, Any] | None = None,
|
|
442
|
+
lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
|
|
443
|
+
lr_scheduler_params: dict[str, Any] | None = None,
|
|
444
|
+
train_dataloader_params: dict[str, Any] | None = None,
|
|
445
|
+
val_dataloader_params: dict[str, Any] | None = None,
|
|
446
|
+
checkpoint_params: dict[str, Any] | None = None,
|
|
447
|
+
num_epochs: int | None = None,
|
|
448
|
+
num_steps: int | None = None,
|
|
449
|
+
) -> dict:
|
|
450
|
+
"""
|
|
451
|
+
Create a configuration for training CARE or Noise2Noise.
|
|
452
|
+
|
|
453
|
+
Parameters
|
|
454
|
+
----------
|
|
455
|
+
algorithm : Literal["care", "n2n"]
|
|
456
|
+
Algorithm to use.
|
|
457
|
+
experiment_name : str
|
|
458
|
+
Name of the experiment.
|
|
459
|
+
data_type : Literal["array", "tiff", "czi", "custom"]
|
|
460
|
+
Type of the data.
|
|
461
|
+
axes : str
|
|
462
|
+
Axes of the data (e.g. SYX).
|
|
463
|
+
patch_size : List[int]
|
|
464
|
+
Size of the patches along the spatial dimensions (e.g. [64, 64]).
|
|
465
|
+
batch_size : int
|
|
466
|
+
Batch size.
|
|
467
|
+
trainer_params : dict
|
|
468
|
+
Parameters for the training configuration.
|
|
469
|
+
augmentations : list of transforms, default=None
|
|
470
|
+
List of transforms to apply, either both or one of XYFlipConfig and
|
|
471
|
+
XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
|
|
472
|
+
and XYRandomRotate90 (in XY) to the images.
|
|
473
|
+
independent_channels : bool, optional
|
|
474
|
+
Whether to train all channels independently, by default False.
|
|
475
|
+
loss : Literal["mae", "mse"], optional
|
|
476
|
+
Loss function to use, by default "mae".
|
|
477
|
+
n_channels_in : int or None, default=None
|
|
478
|
+
Number of channels in.
|
|
479
|
+
n_channels_out : int or None, default=None
|
|
480
|
+
Number of channels out.
|
|
481
|
+
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
482
|
+
Logger to use, by default "none".
|
|
483
|
+
model_params : dict, default=None
|
|
484
|
+
UNetModel parameters.
|
|
485
|
+
optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
|
|
486
|
+
Optimizer to use.
|
|
487
|
+
optimizer_params : dict, default=None
|
|
488
|
+
Parameters for the optimizer, see PyTorch documentation for more details.
|
|
489
|
+
lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
|
|
490
|
+
Learning rate scheduler to use.
|
|
491
|
+
lr_scheduler_params : dict, default=None
|
|
492
|
+
Parameters for the learning rate scheduler, see PyTorch documentation for more
|
|
493
|
+
details.
|
|
494
|
+
train_dataloader_params : dict
|
|
495
|
+
Parameters for the training dataloader, see PyTorch notes, by default None.
|
|
496
|
+
val_dataloader_params : dict
|
|
497
|
+
Parameters for the validation dataloader, see PyTorch notes, by default None.
|
|
498
|
+
checkpoint_params : dict, default=None
|
|
499
|
+
Parameters for the checkpoint callback, see PyTorch Lightning documentation
|
|
500
|
+
(`ModelCheckpoint`) for the list of available parameters.
|
|
501
|
+
num_epochs : int or None, default=None
|
|
502
|
+
Number of epochs to train for. If provided, this will be added to
|
|
503
|
+
trainer_params.
|
|
504
|
+
num_steps : int or None, default=None
|
|
505
|
+
Number of batches in 1 epoch. If provided, this will be added to trainer_params.
|
|
506
|
+
Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
|
|
507
|
+
documentation for more details.
|
|
508
|
+
|
|
509
|
+
Returns
|
|
510
|
+
-------
|
|
511
|
+
Configuration
|
|
512
|
+
Configuration for training CARE or Noise2Noise.
|
|
513
|
+
|
|
514
|
+
Raises
|
|
515
|
+
------
|
|
516
|
+
ValueError
|
|
517
|
+
If the number of channels is not specified when using channels.
|
|
518
|
+
ValueError
|
|
519
|
+
If the number of channels is specified but "C" is not in the axes.
|
|
520
|
+
"""
|
|
521
|
+
# if there are channels, we need to specify their number
|
|
522
|
+
if "C" in axes and n_channels_in is None:
|
|
523
|
+
raise ValueError("Number of channels in must be specified when using channels ")
|
|
524
|
+
elif "C" not in axes and (n_channels_in is not None and n_channels_in > 1):
|
|
525
|
+
raise ValueError(
|
|
526
|
+
f"C is not present in the axes, but number of channels is specified "
|
|
527
|
+
f"(got {n_channels_in} channels)."
|
|
528
|
+
)
|
|
529
|
+
|
|
530
|
+
if n_channels_in is None:
|
|
531
|
+
n_channels_in = 1
|
|
532
|
+
|
|
533
|
+
if n_channels_out is None:
|
|
534
|
+
n_channels_out = n_channels_in
|
|
535
|
+
|
|
536
|
+
# augmentations
|
|
537
|
+
spatial_transform_list = _list_spatial_augmentations(augmentations)
|
|
538
|
+
|
|
539
|
+
# algorithm
|
|
540
|
+
algorithm_params = _create_algorithm_configuration(
|
|
541
|
+
axes=axes,
|
|
542
|
+
algorithm=algorithm,
|
|
543
|
+
loss=loss,
|
|
544
|
+
independent_channels=independent_channels,
|
|
545
|
+
n_channels_in=n_channels_in,
|
|
546
|
+
n_channels_out=n_channels_out,
|
|
547
|
+
model_params=model_params,
|
|
548
|
+
optimizer=optimizer,
|
|
549
|
+
optimizer_params=optimizer_params,
|
|
550
|
+
lr_scheduler=lr_scheduler,
|
|
551
|
+
lr_scheduler_params=lr_scheduler_params,
|
|
552
|
+
)
|
|
553
|
+
|
|
554
|
+
# data
|
|
555
|
+
data_params = _create_data_configuration(
|
|
556
|
+
data_type=data_type,
|
|
557
|
+
axes=axes,
|
|
558
|
+
patch_size=patch_size,
|
|
559
|
+
batch_size=batch_size,
|
|
560
|
+
augmentations=spatial_transform_list,
|
|
561
|
+
train_dataloader_params=train_dataloader_params,
|
|
562
|
+
val_dataloader_params=val_dataloader_params,
|
|
563
|
+
)
|
|
564
|
+
|
|
565
|
+
# training
|
|
566
|
+
final_trainer_params = update_trainer_params(
|
|
567
|
+
trainer_params=trainer_params,
|
|
568
|
+
num_epochs=num_epochs,
|
|
569
|
+
num_steps=num_steps,
|
|
570
|
+
)
|
|
571
|
+
training_params = _create_training_configuration(
|
|
572
|
+
trainer_params=final_trainer_params,
|
|
573
|
+
logger=logger,
|
|
574
|
+
checkpoint_params=checkpoint_params,
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
return {
|
|
578
|
+
"experiment_name": experiment_name,
|
|
579
|
+
"algorithm_config": algorithm_params,
|
|
580
|
+
"data_config": data_params,
|
|
581
|
+
"training_config": training_params,
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
def create_care_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    independent_channels: bool = True,
    loss: Literal["mae", "mse"] = "mae",
    n_channels_in: int | None = None,
    n_channels_out: int | None = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    trainer_params: dict | None = None,
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    checkpoint_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training CARE.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
    2.

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set `n_channels_in` to more than one channel, then "C"
    must be present in `axes`.

    To set the number of output channels, use the `n_channels_out` parameter. If it is
    not specified, it will be assumed to be equal to `n_channels_in`.

    By default, all channels are trained independently. To train all channels
    together, set `independent_channels` to False.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "czi", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence of int
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. It is stored as `max_epochs` in the trainer
        parameters and therefore takes precedence over any `max_epochs` entry in
        `trainer_params`.
    num_steps : int or None, default=None
        Number of batches in 1 epoch. If provided, it is stored as
        `limit_train_batches` in the trainer parameters, taking precedence over any
        such entry in `trainer_params`. See the PyTorch Lightning Trainer
        documentation for more details.
    augmentations : list of transforms, default=None
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
        and XYRandomRotate90 (in XY) to the images.
    independent_channels : bool, default=True
        Whether to train all channels independently.
    loss : Literal["mae", "mse"], default="mae"
        Loss function to use.
    n_channels_in : int or None, default=None
        Number of channels in.
    n_channels_out : int or None, default=None
        Number of channels out.
    logger : Literal["wandb", "tensorboard", "none"], default="none"
        Logger to use.
    trainer_params : dict, optional
        Parameters for the trainer class, see PyTorch Lightning documentation.
    model_params : dict, default=None
        UNetModel parameters.
    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
        the `GeneralDataConfig`.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see the PyTorch docs for
        `DataLoader`. If left as `None`, the empty dict `{}` will be used, this is set
        in the `GeneralDataConfig`.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters.

    Returns
    -------
    Configuration
        Configuration for training CARE.

    Raises
    ------
    ValueError
        If "C" is in `axes` but `n_channels_in` is not specified, or if
        `n_channels_in` is greater than 1 but "C" is not in `axes`.

    Examples
    --------
    Minimum example:
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    You can also limit the number of batches per epoch:
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_steps=100  # limit to 100 batches per epoch
    ... )

    To disable transforms, simply set `augmentations` to an empty list:
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     augmentations=[]
    ... )

    A list of transforms can be passed to the `augmentations` parameter to replace the
    default augmentations:
    >>> from careamics.config.transformations import XYFlipConfig
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     augmentations=[
    ...         # No rotation and only Y flipping
    ...         XYFlipConfig(flip_x = False, flip_y = True)
    ...     ]
    ... )

    If you are training multiple channels they will be trained independently by default,
    you simply need to specify the number of channels input (and optionally, the number
    of channels output):
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YXC", # channels must be in the axes
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels_in=3, # number of input channels
    ...     n_channels_out=1 # if applicable
    ... )

    If instead you want to train multiple channels together, you need to turn off the
    `independent_channels` parameter:
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YXC", # channels must be in the axes
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     independent_channels=False,
    ...     n_channels_in=3,
    ...     n_channels_out=1 # if applicable
    ... )

    If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
    `axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
    for 3-D data but spatial context along the Z dimension will then not be taken into
    account.
    >>> config_2d = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="czi",
    ...     axes="SCYX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels_in=1,
    ... )
    >>> config_3d = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="czi",
    ...     axes="SCZYX",
    ...     patch_size=[16, 64, 64],
    ...     batch_size=16,
    ...     num_epochs=100,
    ...     n_channels_in=1,
    ... )
    """
    return Configuration(
        **_create_supervised_config_dict(
            algorithm="care",
            experiment_name=experiment_name,
            data_type=data_type,
            axes=axes,
            patch_size=patch_size,
            batch_size=batch_size,
            augmentations=augmentations,
            independent_channels=independent_channels,
            loss=loss,
            n_channels_in=n_channels_in,
            n_channels_out=n_channels_out,
            logger=logger,
            trainer_params=trainer_params,
            model_params=model_params,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
            train_dataloader_params=train_dataloader_params,
            val_dataloader_params=val_dataloader_params,
            checkpoint_params=checkpoint_params,
            num_epochs=num_epochs,
            num_steps=num_steps,
        )
    )
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
def create_n2n_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    independent_channels: bool = True,
    loss: Literal["mae", "mse"] = "mae",
    n_channels_in: int | None = None,
    n_channels_out: int | None = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    trainer_params: dict | None = None,
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    checkpoint_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training Noise2Noise.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
    2.

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set `n_channels_in` to more than one channel, then "C"
    must be present in `axes`.

    To set the number of output channels, use the `n_channels_out` parameter. If it is
    not specified, it will be assumed to be equal to `n_channels_in`.

    By default, all channels are trained independently. To train all channels
    together, set `independent_channels` to False.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "czi", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence of int
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. It is stored as `max_epochs` in the trainer
        parameters and therefore takes precedence over any `max_epochs` entry in
        `trainer_params`.
    num_steps : int or None, default=None
        Number of batches in 1 epoch. If provided, it is stored as
        `limit_train_batches` in the trainer parameters, taking precedence over any
        such entry in `trainer_params`. See the PyTorch Lightning Trainer
        documentation for more details.
    augmentations : list of transforms, default=None
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
        and XYRandomRotate90 (in XY) to the images.
    independent_channels : bool, default=True
        Whether to train all channels independently.
    loss : Literal["mae", "mse"], default="mae"
        Loss function to use.
    n_channels_in : int or None, default=None
        Number of channels in.
    n_channels_out : int or None, default=None
        Number of channels out.
    logger : Literal["wandb", "tensorboard", "none"], default="none"
        Logger to use.
    trainer_params : dict, optional
        Parameters for the trainer class, see PyTorch Lightning documentation.
    model_params : dict, default=None
        UNetModel parameters.
    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
        the `GeneralDataConfig`.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see the PyTorch docs for
        `DataLoader`. If left as `None`, the empty dict `{}` will be used, this is set
        in the `GeneralDataConfig`.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters.

    Returns
    -------
    Configuration
        Configuration for training Noise2Noise.

    Raises
    ------
    ValueError
        If "C" is in `axes` but `n_channels_in` is not specified, or if
        `n_channels_in` is greater than 1 but "C" is not in `axes`.

    Examples
    --------
    Minimum example:
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    You can also limit the number of batches per epoch:
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_steps=100  # limit to 100 batches per epoch
    ... )

    To disable transforms, simply set `augmentations` to an empty list:
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     augmentations=[]
    ... )

    A list of transforms can be passed to the `augmentations` parameter:
    >>> from careamics.config.transformations import XYFlipConfig
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     augmentations=[
    ...         # No rotation and only Y flipping
    ...         XYFlipConfig(flip_x = False, flip_y = True)
    ...     ]
    ... )

    If you are training multiple channels they will be trained independently by default,
    you simply need to specify the number of channels input (and optionally, the number
    of channels output):
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YXC", # channels must be in the axes
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels_in=3, # number of input channels
    ...     n_channels_out=1 # if applicable
    ... )

    If instead you want to train multiple channels together, you need to turn off the
    `independent_channels` parameter:
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YXC", # channels must be in the axes
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     independent_channels=False,
    ...     n_channels_in=3,
    ...     n_channels_out=1 # if applicable
    ... )

    If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
    `axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
    for 3-D data but spatial context along the Z dimension will then not be taken into
    account.
    >>> config_2d = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="czi",
    ...     axes="SCYX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels_in=1,
    ... )
    >>> config_3d = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="czi",
    ...     axes="SCZYX",
    ...     patch_size=[16, 64, 64],
    ...     batch_size=16,
    ...     num_epochs=100,
    ...     n_channels_in=1,
    ... )
    """
    return Configuration(
        **_create_supervised_config_dict(
            algorithm="n2n",
            experiment_name=experiment_name,
            data_type=data_type,
            axes=axes,
            patch_size=patch_size,
            batch_size=batch_size,
            trainer_params=trainer_params,
            augmentations=augmentations,
            independent_channels=independent_channels,
            loss=loss,
            n_channels_in=n_channels_in,
            n_channels_out=n_channels_out,
            logger=logger,
            model_params=model_params,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
            train_dataloader_params=train_dataloader_params,
            val_dataloader_params=val_dataloader_params,
            checkpoint_params=checkpoint_params,
            num_epochs=num_epochs,
            num_steps=num_steps,
        )
    )
|
|
1056
|
+
|
|
1057
|
+
|
|
1058
|
+
def create_n2v_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    independent_channels: bool = True,
    use_n2v2: bool = False,
    n_channels: int | None = None,
    roi_size: int = 11,
    masked_pixel_percentage: float = 0.2,
    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
    struct_n2v_span: int = 5,
    trainer_params: dict | None = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    checkpoint_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training Noise2Void.

    N2V uses a UNet model to denoise images in a self-supervised manner. To use its
    variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
    (structN2V) parameters, or set `use_n2v2` to True (N2V2).

    N2V2 modifies the UNet architecture by adding blur pool layers and removes the skip
    connections, thus removing checkerboard artefacts. StructN2V is used when vertical
    or horizontal correlations are present in the noise; it applies an additional mask
    to the manipulated pixel neighbors.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
    2.

    If "C" is present in `axes`, then you need to set `n_channels` to the number of
    channels.

    By default, all channels are trained independently. To train all channels together,
    set `independent_channels` to False.

    By default, the transformations applied are a random flip along X or Y, and a random
    90 degrees rotation in the XY plane. Normalization is always applied, as well as the
    N2V manipulation.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    The `roi_size` parameter specifies the size of the area around each pixel that will
    be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
    pixels per patch will be manipulated.

    The parameters of the UNet can be specified in the `model_params` (passed as a
    parameter-value dictionary). Note that `use_n2v2` and `n_channels` override the
    corresponding parameters passed in `model_params`.

    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
    will be applied to each manipulated pixel.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "czi", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. If provided, this will be added to
        trainer_params.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    augmentations : list of transforms, default=None
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
        and XYRandomRotate90 (in XY) to the images.
    independent_channels : bool, default=True
        Whether to train the channels independently. Set to False to train all
        channels together.
    use_n2v2 : bool, default=False
        Whether to use N2V2. This also switches the N2V pixel manipulation strategy
        from uniform to median.
    n_channels : int or None, default=None
        Number of channels (in and out). Must be set when "C" is in `axes`. If left
        as None (only allowed without "C"), a single channel is used.
    roi_size : int, default=11
        N2V pixel manipulation area.
    masked_pixel_percentage : float, default=0.2
        Percentage of pixels masked in each patch.
    struct_n2v_axis : Literal["horizontal", "vertical", "none"], default="none"
        Axis along which to apply structN2V mask.
    struct_n2v_span : int, default=5
        Span of the structN2V mask.
    trainer_params : dict, optional
        Parameters for the trainer, see the relevant documentation.
    logger : Literal["wandb", "tensorboard", "none"], default="none"
        Logger to use.
    model_params : dict, default=None
        UNetModel parameters.
    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
        the `GeneralDataConfig`.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see the PyTorch docs for
        `DataLoader`. If left as `None`, the empty dict `{}` will be used, this is set
        in the `GeneralDataConfig`.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters.

    Returns
    -------
    Configuration
        Configuration for training N2V.

    Raises
    ------
    ValueError
        If "C" is in `axes` but `n_channels` is None, or if "C" is not in `axes`
        but `n_channels` is greater than 1.

    Examples
    --------
    Minimum example:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    You can also limit the number of batches per epoch:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_steps=100  # limit to 100 batches per epoch
    ... )

    To disable transforms, simply set `augmentations` to an empty list:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     augmentations=[]
    ... )

    A list of transforms can be passed to the `augmentations` parameter:
    >>> from careamics.config.transformations import XYFlipConfig
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     augmentations=[
    ...         # No rotation and only Y flipping
    ...         XYFlipConfig(flip_x = False, flip_y = True)
    ...     ]
    ... )

    To use N2V2, simply pass the `use_n2v2` parameter:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v2_experiment",
    ...     data_type="tiff",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     use_n2v2=True
    ... )

    For structN2V, there are two parameters to set, `struct_n2v_axis` and
    `struct_n2v_span`:
    >>> config = create_n2v_configuration(
    ...     experiment_name="structn2v_experiment",
    ...     data_type="tiff",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     struct_n2v_axis="horizontal",
    ...     struct_n2v_span=7
    ... )

    If you are training multiple channels they will be trained independently by default,
    you simply need to specify the number of channels:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YXC",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels=3
    ... )

    If instead you want to train multiple channels together, you need to turn off the
    `independent_channels` parameter:
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YXC",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     independent_channels=False,
    ...     n_channels=3
    ... )

    If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
    `axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
    for 3-D data but spatial context along the Z dimension will then not be taken into
    account.
    >>> config_2d = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="czi",
    ...     axes="SCYX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels=1,
    ... )
    >>> config_3d = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="czi",
    ...     axes="SCZYX",
    ...     patch_size=[16, 64, 64],
    ...     batch_size=16,
    ...     num_epochs=100,
    ...     n_channels=1,
    ... )
    """
    # if there are channels, we need to specify their number
    if "C" in axes and n_channels is None:
        raise ValueError("Number of channels must be specified when using channels.")
    elif "C" not in axes and (n_channels is not None and n_channels > 1):
        raise ValueError(
            "C is not present in the axes, but number of channels is specified "
            f"(got {n_channels} channel)."
        )

    # without a "C" axis and no explicit value, a single channel is assumed
    if n_channels is None:
        n_channels = 1

    # augmentations
    spatial_transforms = _list_spatial_augmentations(augmentations)

    # create the N2VManipulate transform using the supplied parameters; N2V2 uses
    # the median replacement strategy, plain N2V the uniform one
    n2v_transform = N2VManipulateConfig(
        name=SupportedTransform.N2V_MANIPULATE.value,
        strategy=(
            SupportedPixelManipulation.MEDIAN.value
            if use_n2v2
            else SupportedPixelManipulation.UNIFORM.value
        ),
        roi_size=roi_size,
        masked_pixel_percentage=masked_pixel_percentage,
        struct_mask_axis=struct_n2v_axis,
        struct_mask_span=struct_n2v_span,
    )

    # algorithm (N2V is self-supervised: as many output channels as input channels)
    algorithm_params = _create_algorithm_configuration(
        axes=axes,
        algorithm="n2v",
        loss="n2v",
        independent_channels=independent_channels,
        n_channels_in=n_channels,
        n_channels_out=n_channels,
        use_n2v2=use_n2v2,
        model_params=model_params,
        optimizer=optimizer,
        optimizer_params=optimizer_params,
        lr_scheduler=lr_scheduler,
        lr_scheduler_params=lr_scheduler_params,
    )
    algorithm_params["n2v_config"] = n2v_transform

    # data
    data_params = _create_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        augmentations=spatial_transforms,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # training
    final_trainer_params = update_trainer_params(
        trainer_params=trainer_params,
        num_epochs=num_epochs,
        num_steps=num_steps,
    )
    training_params = _create_training_configuration(
        trainer_params=final_trainer_params,
        logger=logger,
        checkpoint_params=checkpoint_params,
    )

    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_params,
        data_config=data_params,
        training_config=training_params,
    )
|
|
1390
|
+
|
|
1391
|
+
|
|
1392
|
+
def _create_vae_configuration(
    input_shape: Sequence[int],
    encoder_conv_strides: tuple[int, ...],
    decoder_conv_strides: tuple[int, ...],
    multiscale_count: int,
    z_dims: tuple[int, ...],
    output_channels: int,
    encoder_n_filters: int,
    decoder_n_filters: int,
    encoder_dropout: float,
    decoder_dropout: float,
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ],
    predict_logvar: Literal[None, "pixelwise"],
    analytical_kl: bool,
) -> LVAEConfig:
    """Create the LVAE architecture configuration of a VAE-based algorithm.

    Every argument is forwarded unchanged to `LVAEConfig`, with the architecture
    set to `SupportedArchitecture.LVAE`.

    Parameters
    ----------
    input_shape : Sequence[int]
        Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D.
    encoder_conv_strides : tuple[int, ...]
        Strides of the encoder convolutional layers, length also defines 2D or 3D.
    decoder_conv_strides : tuple[int, ...]
        Strides of the decoder convolutional layers, length also defines 2D or 3D.
    multiscale_count : int
        Number of lateral context layers, specific to MicroSplit.
    z_dims : tuple[int, ...]
        Latent space dimensions of the LVAE hierarchy (presumably one entry per
        hierarchy level; TODO confirm against `LVAEConfig`).
    output_channels : int
        Number of output channels.
    encoder_n_filters : int
        Number of filters in the convolutional layers of the encoder.
    decoder_n_filters : int
        Number of filters in the convolutional layers of the decoder.
    encoder_dropout : float
        Dropout rate for the encoder.
    decoder_dropout : float
        Dropout rate for the decoder.
    nonlinearity : Literal
        Type of nonlinearity function to use.
    predict_logvar : Literal[None, "pixelwise"]
        Log-variance prediction mode, either `"pixelwise"` or None. Passed through
        to `LVAEConfig`.
    analytical_kl : bool
        Whether to use analytical KL divergence. Passed through to `LVAEConfig`.

    Returns
    -------
    LVAEConfig
        LVAE architecture configuration with the specified parameters.
    """
    return LVAEConfig(
        architecture=SupportedArchitecture.LVAE.value,
        input_shape=input_shape,
        encoder_conv_strides=encoder_conv_strides,
        decoder_conv_strides=decoder_conv_strides,
        multiscale_count=multiscale_count,
        z_dims=z_dims,
        output_channels=output_channels,
        encoder_n_filters=encoder_n_filters,
        decoder_n_filters=decoder_n_filters,
        encoder_dropout=encoder_dropout,
        decoder_dropout=decoder_dropout,
        nonlinearity=nonlinearity,
        predict_logvar=predict_logvar,
        analytical_kl=analytical_kl,
    )
|
|
1461
|
+
|
|
1462
|
+
|
|
1463
|
+
def _create_vae_based_algorithm(
    algorithm: Literal["hdn", "microsplit"],
    loss: LVAELossConfig,
    input_shape: Sequence[int],
    encoder_conv_strides: tuple[int, ...],
    decoder_conv_strides: tuple[int, ...],
    multiscale_count: int,
    z_dims: tuple[int, ...],
    output_channels: int,
    encoder_n_filters: int,
    decoder_n_filters: int,
    encoder_dropout: float,
    decoder_dropout: float,
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ],
    predict_logvar: Literal[None, "pixelwise"],
    analytical_kl: bool,
    gaussian_likelihood: GaussianLikelihoodConfig | None = None,
    nm_likelihood: NMLikelihoodConfig | None = None,
) -> dict:
    """
    Create a dictionary with the parameters of the VAE-based algorithm model.

    Parameters
    ----------
    algorithm : Literal["hdn", "microsplit"]
        The algorithm type.
    loss : LVAELossConfig
        The loss configuration.
    input_shape : Sequence[int]
        The shape of the input data.
    encoder_conv_strides : tuple[int, ...]
        The strides of the encoder convolutional layers.
    decoder_conv_strides : tuple[int, ...]
        The strides of the decoder convolutional layers.
    multiscale_count : int
        The number of multiscale layers.
    z_dims : tuple[int, ...]
        The dimensions of the latent space.
    output_channels : int
        The number of output channels.
    encoder_n_filters : int
        The number of filters in the encoder.
    decoder_n_filters : int
        The number of filters in the decoder.
    encoder_dropout : float
        The dropout rate for the encoder.
    decoder_dropout : float
        The dropout rate for the decoder.
    nonlinearity : Literal
        The nonlinearity function to use.
    predict_logvar : Literal[None, "pixelwise"]
        The type of log variance prediction.
    analytical_kl : bool
        Whether to use analytical KL divergence.
    gaussian_likelihood : GaussianLikelihoodConfig or None, optional
        The Gaussian likelihood model, by default None.
    nm_likelihood : NMLikelihoodConfig or None, optional
        The noise model likelihood model, by default None.

    Returns
    -------
    dict
        A dictionary with the parameters of the VAE-based algorithm model, with the
        keys "algorithm", "loss", "model", "gaussian_likelihood" and
        "noise_model_likelihood".

    Raises
    ------
    ValueError
        If neither `gaussian_likelihood` nor `nm_likelihood` is provided.
    """
    # Validate first so a missing likelihood fails before any model is built. An
    # explicit exception is used instead of `assert`, which is stripped under -O.
    if gaussian_likelihood is None and nm_likelihood is None:
        raise ValueError("Likelihood model must be specified")

    network_model = _create_vae_configuration(
        input_shape=input_shape,
        encoder_conv_strides=encoder_conv_strides,
        decoder_conv_strides=decoder_conv_strides,
        multiscale_count=multiscale_count,
        z_dims=z_dims,
        output_channels=output_channels,
        encoder_n_filters=encoder_n_filters,
        decoder_n_filters=decoder_n_filters,
        encoder_dropout=encoder_dropout,
        decoder_dropout=decoder_dropout,
        nonlinearity=nonlinearity,
        predict_logvar=predict_logvar,
        analytical_kl=analytical_kl,
    )
    return {
        "algorithm": algorithm,
        "loss": loss,
        "model": network_model,
        "gaussian_likelihood": gaussian_likelihood,
        "noise_model_likelihood": nm_likelihood,
    }
|
|
1552
|
+
|
|
1553
|
+
|
|
1554
|
+
def get_likelihood_config(
    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
    # TODO remove different microsplit loss types, refac
    predict_logvar: Literal["pixelwise"] | None = None,
    logvar_lowerbound: float | None = -5.0,
    nm_paths: list[str] | None = None,
    data_stats: tuple[float, float] | None = None,
) -> tuple[
    GaussianLikelihoodConfig | None,
    MultiChannelNMConfig | None,
    NMLikelihoodConfig | None,
]:
    """Get the likelihood configuration for split models.

    Returns a tuple containing the following optional entries:
    - GaussianLikelihoodConfig: Gaussian likelihood configuration for musplit losses
    - MultiChannelNMConfig: Multi-channel noise model configuration for denoisplit
      losses
    - NMLikelihoodConfig: Noise model likelihood configuration for denoisplit losses

    Parameters
    ----------
    loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
        The type of loss function to use.
    predict_logvar : Literal["pixelwise"] | None, optional
        Type of log variance prediction, by default None.
        Expected when loss_type is "musplit" or "denoisplit_musplit" (not validated
        here).
    logvar_lowerbound : float | None, optional
        Lower bound for the log variance, by default -5.0.
        Used when loss_type is "musplit" or "denoisplit_musplit".
    nm_paths : list[str] | None, optional
        Paths to the noise model files, by default None. One Gaussian mixture noise
        model is configured per path. Expected when loss_type is "denoisplit" or
        "denoisplit_musplit" (not validated here; None yields an empty list of
        noise models).
    data_stats : tuple[float, float] | None, optional
        Data statistics (mean, std), by default None.
        NOTE(review): currently not used by this function.

    Returns
    -------
    gaussian_lik_config : GaussianLikelihoodConfig | None
        Gaussian likelihood configuration for musplit losses, or None.
    nm_config : MultiChannelNMConfig | None
        Multi-channel noise model configuration for denoisplit losses, or None.
    nm_lik_config : NMLikelihoodConfig | None
        Noise model likelihood configuration for denoisplit losses, or None.
    """
    # gaussian likelihood
    if loss_type in ("musplit", "denoisplit_musplit"):
        # TODO validators should be in pydantic models (predict_logvar is expected
        # for these loss types but is not checked here)
        gaussian_lik_config = GaussianLikelihoodConfig(
            predict_logvar=predict_logvar,
            logvar_lowerbound=logvar_lowerbound,
        )
    else:
        gaussian_lik_config = None

    # noise model likelihood
    if loss_type in ("denoisplit", "denoisplit_musplit"):
        # TODO validators should be in pydantic models (nm_paths and data_stats
        # are expected for these loss types but are not checked here)
        gmm_list = [
            GaussianMixtureNMConfig(
                model_type="GaussianMixtureNoiseModel",
                path=nm_path,
            )
            for nm_path in (nm_paths if nm_paths is not None else [])
        ]
        noise_model_config = MultiChannelNMConfig(noise_models=gmm_list)
        nm_lik_config = NMLikelihoodConfig()  # TODO this config isn't needed probably
    else:
        noise_model_config = None
        nm_lik_config = None

    return gaussian_lik_config, noise_model_config, nm_lik_config
|
|
1640
|
+
|
|
1641
|
+
|
|
1642
|
+
# TODO wrap parameters into model, loss etc
|
|
1643
|
+
# TODO refac likelihood configs to make it 1. Can it be done ?
|
|
1644
|
+
def create_hdn_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    encoder_conv_strides: tuple[int, ...] = (2, 2),
    decoder_conv_strides: tuple[int, ...] = (2, 2),
    multiscale_count: int = 1,
    z_dims: tuple[int, ...] = (128, 128),
    output_channels: int = 1,
    encoder_n_filters: int = 32,
    decoder_n_filters: int = 32,
    encoder_dropout: float = 0.0,
    decoder_dropout: float = 0.0,
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ] = "ReLU",
    analytical_kl: bool = False,
    predict_logvar: Literal["pixelwise"] | None = None,
    logvar_lowerbound: Union[float, None] = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    trainer_params: dict | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training HDN.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
    2.

    The patch size is also used as the input shape of the LVAE model. The number of
    output channels of the model is set with `output_channels`.

    HDN is configured with the "hdn" LVAE loss (`denoisplit_weight=1`,
    `musplit_weight=0`) and a Gaussian likelihood built from `predict_logvar` and
    `logvar_lowerbound`; no noise model likelihood is used.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : List[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. If provided, this will be added to
        trainer_params.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    encoder_conv_strides : tuple[int, ...], optional
        Strides for the encoder convolutional layers, by default (2, 2).
    decoder_conv_strides : tuple[int, ...], optional
        Strides for the decoder convolutional layers, by default (2, 2).
    multiscale_count : int, optional
        Number of scales in the multiscale architecture, by default 1.
    z_dims : tuple[int, ...], optional
        Dimensions of the latent space, by default (128, 128).
    output_channels : int, optional
        Number of output channels, by default 1.
    encoder_n_filters : int, optional
        Number of filters in the encoder, by default 32.
    decoder_n_filters : int, optional
        Number of filters in the decoder, by default 32.
    encoder_dropout : float, optional
        Dropout rate for the encoder, by default 0.0.
    decoder_dropout : float, optional
        Dropout rate for the decoder, by default 0.0.
    nonlinearity : Literal, optional
        Nonlinearity function to use, by default "ReLU".
    analytical_kl : bool, optional
        Whether to use analytical KL divergence, by default False.
    predict_logvar : Literal[None, "pixelwise"], optional
        Type of log variance prediction, by default None. Used both by the model and
        by the Gaussian likelihood.
    logvar_lowerbound : Union[float, None], optional
        Lower bound for the log variance used by the Gaussian likelihood, by default
        None.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use for training, by default "none".
    trainer_params : dict, optional
        Parameters for the trainer class, see PyTorch Lightning documentation.
    augmentations : list[XYFlipConfig | XYRandomRotate90Config] | None, optional
        List of augmentations to apply, by default None.
    train_dataloader_params : Optional[dict[str, Any]], optional
        Parameters for the training dataloader, by default None.
    val_dataloader_params : Optional[dict[str, Any]], optional
        Parameters for the validation dataloader, by default None.

    Returns
    -------
    Configuration
        The configuration object for training HDN.

    Examples
    --------
    Minimum example:
    >>> config = create_hdn_configuration(
    ...     experiment_name="hdn_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    You can also limit the number of batches per epoch:
    >>> config = create_hdn_configuration(
    ...     experiment_name="hdn_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_steps=100  # limit to 100 batches per epoch
    ... )
    """
    # TODO revisit the necessity of model_params
    transform_list = _list_spatial_augmentations(augmentations)

    loss_config = LVAELossConfig(
        loss_type="hdn", denoisplit_weight=1, musplit_weight=0
    )  # TODO what are the correct defaults for HDN?

    # HDN always relies on a Gaussian likelihood (no noise model likelihood)
    gaussian_likelihood = GaussianLikelihoodConfig(
        predict_logvar=predict_logvar, logvar_lowerbound=logvar_lowerbound
    )

    # algorithm & model (the patch size doubles as the model input shape)
    algorithm_params = _create_vae_based_algorithm(
        algorithm="hdn",
        loss=loss_config,
        input_shape=patch_size,
        encoder_conv_strides=encoder_conv_strides,
        decoder_conv_strides=decoder_conv_strides,
        multiscale_count=multiscale_count,
        z_dims=z_dims,
        output_channels=output_channels,
        encoder_n_filters=encoder_n_filters,
        decoder_n_filters=decoder_n_filters,
        encoder_dropout=encoder_dropout,
        decoder_dropout=decoder_dropout,
        nonlinearity=nonlinearity,
        predict_logvar=predict_logvar,
        analytical_kl=analytical_kl,
        gaussian_likelihood=gaussian_likelihood,
        nm_likelihood=None,
    )

    # data
    data_params = _create_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        augmentations=transform_list,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # training
    final_trainer_params = update_trainer_params(
        trainer_params=trainer_params,
        num_epochs=num_epochs,
        num_steps=num_steps,
    )
    training_params = _create_training_configuration(
        trainer_params=final_trainer_params,
        logger=logger,
    )

    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_params,
        data_config=data_params,
        training_config=training_params,
    )
|
|
1838
|
+
|
|
1839
|
+
|
|
1840
|
+
def create_microsplit_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    encoder_conv_strides: tuple[int, ...] = (2, 2),
    decoder_conv_strides: tuple[int, ...] = (2, 2),
    multiscale_count: int = 3,
    grid_size: int = 32,  # TODO most likely can be derived from patch size
    z_dims: tuple[int, ...] = (128, 128),
    output_channels: int = 1,
    encoder_n_filters: int = 32,
    decoder_n_filters: int = 32,
    encoder_dropout: float = 0.0,
    decoder_dropout: float = 0.0,
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ] = "ReLU",  # TODO do we need all these?
    analytical_kl: bool = False,
    predict_logvar: Literal["pixelwise"] = "pixelwise",
    logvar_lowerbound: Union[float, None] = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    trainer_params: dict | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    nm_paths: list[str] | None = None,
    data_stats: tuple[float, float] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training MicroSplit.

    The loss is always the combined ``"denoisplit_musplit"`` loss, with a
    denoiSplit weight of 0.9 and a muSplit weight of 0.1.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. If provided, this will be added to
        trainer_params.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    encoder_conv_strides : tuple[int, ...], optional
        Strides for the encoder convolutional layers, by default (2, 2).
    decoder_conv_strides : tuple[int, ...], optional
        Strides for the decoder convolutional layers, by default (2, 2).
    multiscale_count : int, optional
        Number of multiscale levels, by default 3. Passed to both the LVAE model
        and the MicroSplit data configuration.
    grid_size : int, optional
        Size of the grid for multiscale training, by default 32. Passed to the
        MicroSplit data configuration.
    z_dims : tuple[int, ...], optional
        List of latent dims for each hierarchy level in the LVAE, default (128, 128).
    output_channels : int, optional
        Number of output channels for the model, by default 1.
    encoder_n_filters : int, optional
        Number of filters in the encoder, by default 32.
    decoder_n_filters : int, optional
        Number of filters in the decoder, by default 32.
    encoder_dropout : float, optional
        Dropout rate for the encoder, by default 0.0.
    decoder_dropout : float, optional
        Dropout rate for the decoder, by default 0.0.
    nonlinearity : Literal["None", "Sigmoid", "Softmax", "Tanh", "ReLU", \
"LeakyReLU", "ELU"], optional
        Nonlinearity to use in the model, by default "ReLU".
    analytical_kl : bool, optional
        Whether to use analytical KL divergence, by default False.
    predict_logvar : Literal["pixelwise"], optional
        Type of log-variance prediction, by default "pixelwise".
    logvar_lowerbound : Union[float, None], optional
        Lower bound for the log variance, by default None.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use for training, by default "none".
    trainer_params : dict, optional
        Parameters for the trainer class, see PyTorch Lightning documentation.
    augmentations : list[Union[XYFlipConfig, XYRandomRotate90Config]] | None, optional
        List of augmentations to apply, by default None.
    nm_paths : list[str] | None, optional
        Paths to the noise model files, by default None.
    data_stats : tuple[float, float] | None, optional
        Data statistics (mean, std), by default None.
    train_dataloader_params : dict[str, Any] | None, optional
        Parameters for the training dataloader, by default None.
    val_dataloader_params : dict[str, Any] | None, optional
        Parameters for the validation dataloader, by default None.

    Returns
    -------
    Configuration
        A configuration object for the microsplit algorithm.

    Examples
    --------
    Minimum example:
    # >>> config = create_microsplit_configuration(
    # ...     experiment_name="microsplit_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     num_epochs=100
    # ... )

    # You can also limit the number of batches per epoch:
    # >>> config = create_microsplit_configuration(
    # ...     experiment_name="microsplit_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     num_steps=100  # limit to 100 batches per epoch
    # ... )
    """
    transform_list = _list_spatial_augmentations(augmentations)

    loss_config = LVAELossConfig(
        loss_type="denoisplit_musplit", denoisplit_weight=0.9, musplit_weight=0.1
    )  # TODO losses need to be refactored! just for example. Add validator if sum to 1

    # Create likelihood configurations
    gaussian_likelihood_config, noise_model_config, nm_likelihood_config = (
        get_likelihood_config(
            loss_type="denoisplit_musplit",
            predict_logvar=predict_logvar,
            logvar_lowerbound=logvar_lowerbound,
            nm_paths=nm_paths,
            data_stats=data_stats,
        )
    )

    # Create the LVAE model
    network_model = _create_vae_configuration(
        input_shape=patch_size,
        encoder_conv_strides=encoder_conv_strides,
        decoder_conv_strides=decoder_conv_strides,
        multiscale_count=multiscale_count,
        z_dims=z_dims,
        output_channels=output_channels,
        encoder_n_filters=encoder_n_filters,
        decoder_n_filters=decoder_n_filters,
        encoder_dropout=encoder_dropout,
        decoder_dropout=decoder_dropout,
        nonlinearity=nonlinearity,
        predict_logvar=predict_logvar,
        analytical_kl=analytical_kl,
    )

    # Create the MicroSplit algorithm configuration
    algorithm_params = {
        "algorithm": "microsplit",
        "loss": loss_config,
        "model": network_model,
        "gaussian_likelihood": gaussian_likelihood_config,
        "noise_model": noise_model_config,
        "noise_model_likelihood": nm_likelihood_config,
    }

    # Convert to MicroSplitAlgorithm instance
    algorithm_config = MicroSplitAlgorithm(**algorithm_params)

    # data
    data_params = _create_microsplit_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        grid_size=grid_size,
        multiscale_count=multiscale_count,
        batch_size=batch_size,
        augmentations=transform_list,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # training
    final_trainer_params = update_trainer_params(
        trainer_params=trainer_params,
        num_epochs=num_epochs,
        num_steps=num_steps,
    )
    training_params = _create_training_configuration(
        trainer_params=final_trainer_params,
        logger=logger,
    )

    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_config,
        data_config=data_params,
        training_config=training_params,
    )
|
|
2042
|
+
|
|
2043
|
+
|
|
2044
|
+
def create_pn2v_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    nm_path: str,
    num_epochs: int = 100,
    num_steps: int | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    independent_channels: bool = True,
    use_n2v2: bool = False,
    num_in_channels: int = 1,
    num_out_channels: int = 100,
    roi_size: int = 11,
    masked_pixel_percentage: float = 0.2,
    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
    struct_n2v_span: int = 5,
    trainer_params: dict | None = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    checkpoint_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training Probabilistic Noise2Void (PN2V).

    PN2V extends N2V by incorporating a probabilistic noise model to estimate the
    posterior distribution of each pixel more precisely.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
    otherwise 2.

    If "C" is present in `axes`, then you need to set `num_in_channels` to the number of
    channels.

    By default, all channels are trained independently. To train all channels together,
    set `independent_channels` to False. When training independently, each input channel
    will have `num_out_channels` outputs (default 100). When training together, all
    input channels will share `num_out_channels` outputs.

    By default, the transformations applied are a random flip along X or Y, and a random
    90 degrees rotation in the XY plane. Normalization is always applied, as well as the
    N2V manipulation.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    The `roi_size` parameter specifies the size of the area around each pixel that will
    be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
    pixels per patch will be manipulated.

    The parameters of the UNet can be specified in the `model_params` (passed as a
    parameter-value dictionary). Note that `use_n2v2`, `num_in_channels`, and
    `num_out_channels` override the corresponding parameters passed in `model_params`.

    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
    will be applied to each manipulated pixel.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "czi", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    nm_path : str
        Path to the noise model file.
    num_epochs : int, default=100
        Number of epochs to train for. If provided, this will be added to
        trainer_params.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    augmentations : list of transforms, default=None
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
        and XYRandomRotate90 (in XY) to the images.
    independent_channels : bool, optional
        Whether to train all channels independently, by default True. If True, each
        input channel will correspond to num_out_channels output channels (e.g., 3
        input channels with num_out_channels=100 results in 300 total output
        channels).
    use_n2v2 : bool, optional
        Whether to use N2V2, by default False.
    num_in_channels : int, default=1
        Number of input channels.
    num_out_channels : int, default=100
        Number of output channels per input channel when independent_channels is True,
        or total number of output channels when independent_channels is False.
    roi_size : int, optional
        N2V pixel manipulation area, by default 11.
    masked_pixel_percentage : float, optional
        Percentage of pixels masked in each patch, by default 0.2.
    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
        Axis along which to apply structN2V mask, by default "none".
    struct_n2v_span : int, optional
        Span of the structN2V mask, by default 5.
    trainer_params : dict, optional
        Parameters for the trainer, see the relevant documentation.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_params : dict, default=None
        UNetModel parameters.
    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
        the `GeneralDataConfig`.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see the PyTorch docs for
        `DataLoader`. If left as `None`, the empty dict `{}` will be used, this is set
        in the `GeneralDataConfig`.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters.

    Returns
    -------
    Configuration
        Configuration for training PN2V.

    Raises
    ------
    ValueError
        If "C" is in `axes` and `num_in_channels` is smaller than 1, or if "C" is
        not in `axes` and `num_in_channels` is greater than 1.

    Examples
    --------
    Minimum example:
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100
    # ... )

    # You can also limit the number of batches per epoch:
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_steps=100  # limit to 100 batches per epoch
    # ... )

    # To disable transforms, simply set `augmentations` to an empty list:
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     augmentations=[]
    # ... )

    # A list of transforms can be passed to the `augmentations` parameter:
    # >>> from careamics.config.transformations import XYFlipConfig
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     augmentations=[
    # ...         # No rotation and only Y flipping
    # ...         XYFlipConfig(flip_x = False, flip_y = True)
    # ...     ]
    # ... )

    # To use N2V2, simply pass the `use_n2v2` parameter:
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v2_experiment",
    # ...     data_type="tiff",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     use_n2v2=True
    # ... )

    # For structN2V, there are two parameters to set, `struct_n2v_axis` and
    # `struct_n2v_span`:
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="structpn2v_experiment",
    # ...     data_type="tiff",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     struct_n2v_axis="horizontal",
    # ...     struct_n2v_span=7
    # ... )

    # If you are training multiple channels they will be trained independently by
    # default, you simply need to specify the number of input channels. Each input
    # channel will correspond to num_out_channels outputs (300 total for 3
    # channels with default num_out_channels=100):
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YXC",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     num_in_channels=3
    # ... )

    # If instead you want to train multiple channels together, you need to turn
    # off the `independent_channels` parameter (resulting in 100 total output
    # channels regardless of the number of input channels):
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YXC",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     independent_channels=False,
    # ...     num_in_channels=3
    # ... )

    # >>> config_2d = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="czi",
    # ...     axes="SCYX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     num_in_channels=1,
    # ... )
    # >>> config_3d = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="czi",
    # ...     axes="SCZYX",
    # ...     patch_size=[16, 64, 64],
    # ...     batch_size=16,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     num_in_channels=1,
    # ... )
    """
    # Validate channel configuration
    if "C" in axes and num_in_channels < 1:
        raise ValueError("num_in_channels must be at least 1 when using channels.")
    elif "C" not in axes and num_in_channels > 1:
        raise ValueError(
            f"C is not present in the axes, but num_in_channels is specified "
            f"(got {num_in_channels} channels)."
        )

    # Calculate total output channels based on independent_channels setting:
    # independent training gives every input channel its own block of
    # `num_out_channels` outputs, joint training shares a single block.
    if independent_channels:
        total_out_channels = num_in_channels * num_out_channels
    else:
        total_out_channels = num_out_channels

    # augmentations
    spatial_transforms = _list_spatial_augmentations(augmentations)

    # create the N2VManipulate transform using the supplied parameters
    # (N2V2 replaces masked pixels with the median strategy, plain N2V with the
    # uniform strategy)
    n2v_transform = N2VManipulateConfig(
        name=SupportedTransform.N2V_MANIPULATE.value,
        strategy=(
            SupportedPixelManipulation.MEDIAN.value
            if use_n2v2
            else SupportedPixelManipulation.UNIFORM.value
        ),
        roi_size=roi_size,
        masked_pixel_percentage=masked_pixel_percentage,
        struct_mask_axis=struct_n2v_axis,
        struct_mask_span=struct_n2v_span,
    )

    # Create noise model configuration
    noise_model_config = GaussianMixtureNMConfig(path=nm_path)

    # algorithm
    algorithm_params = _create_algorithm_configuration(
        axes=axes,
        algorithm="pn2v",
        loss="pn2v",
        independent_channels=independent_channels,
        n_channels_in=num_in_channels,
        n_channels_out=total_out_channels,
        use_n2v2=use_n2v2,
        model_params=model_params,
        optimizer=optimizer,
        optimizer_params=optimizer_params,
        lr_scheduler=lr_scheduler,
        lr_scheduler_params=lr_scheduler_params,
    )
    algorithm_params["n2v_config"] = n2v_transform
    algorithm_params["noise_model"] = noise_model_config

    # Convert to PN2VAlgorithm instance
    algorithm_config = PN2VAlgorithm(**algorithm_params)

    # data
    data_params = _create_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        augmentations=spatial_transforms,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # training
    final_trainer_params = update_trainer_params(
        trainer_params=trainer_params,
        num_epochs=num_epochs,
        num_steps=num_steps,
    )
    training_params = _create_training_configuration(
        trainer_params=final_trainer_params,
        logger=logger,
        checkpoint_params=checkpoint_params,
    )

    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_config,
        data_config=data_params,
        training_config=training_params,
    )
|