careamics 0.0.19__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/config/data/__init__.py
@@ -0,0 +1,27 @@
+"""Data Pydantic configuration models."""
+
+__all__ = [
+    "DataConfig",
+    "MaskFilterConfig",
+    "MaxFilterConfig",
+    "MeanSTDFilterConfig",
+    "NGDataConfig",
+    "RandomPatchingConfig",
+    "ShannonFilterConfig",
+    "TiledPatchingConfig",
+    "WholePatchingConfig",
+]
+
+from .data_config import DataConfig
+from .ng_data_config import NGDataConfig
+from .patch_filter import (
+    MaskFilterConfig,
+    MaxFilterConfig,
+    MeanSTDFilterConfig,
+    ShannonFilterConfig,
+)
+from .patching_strategies import (
+    RandomPatchingConfig,
+    TiledPatchingConfig,
+    WholePatchingConfig,
+)
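For orientation, this __init__ module only re-exports the data-related Pydantic models. A minimal import sketch, assuming careamics 0.0.19 is installed; the names come from the `__all__` list above and the path follows the file listing:

    # Hypothetical downstream imports; nothing here beyond what __all__ exports.
    from careamics.config.data import DataConfig, NGDataConfig
    from careamics.config.data import (
        RandomPatchingConfig,
        TiledPatchingConfig,
        WholePatchingConfig,
    )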
careamics/config/data/data_config.py
@@ -0,0 +1,472 @@
+"""Data configuration."""
+
+from __future__ import annotations
+
+import os
+import sys
+from collections.abc import Sequence
+from pprint import pformat
+from typing import Annotated, Any, Literal, Self, Union
+from warnings import warn
+
+import numpy as np
+from numpy.typing import NDArray
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PlainSerializer,
+    field_validator,
+    model_validator,
+)
+
+from ..transformations import XYFlipConfig, XYRandomRotate90Config
+from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
+
+
+def np_float_to_scientific_str(x: float) -> str:
+    """Return a string scientific representation of a float.
+
+    In particular, this method is used to serialize floats to strings, allowing
+    numpy.float32 to be passed in the Pydantic model and written to a yaml file as str.
+
+    Parameters
+    ----------
+    x : float
+        Input value.
+
+    Returns
+    -------
+    str
+        Scientific string representation of the input value.
+    """
+    return np.format_float_scientific(x, precision=7)
+
+
+Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type=str)]
+"""Annotated float type, used to serialize floats to strings."""
+
+
+class DataConfig(BaseModel):
+    """Data configuration.
+
+    If std is specified, mean must be specified as well. Note that setting the std first
+    and then the mean (if they were both `None` before) will raise a validation error.
+    Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
+    to be lists of floats, one for each channel. For supervised tasks, the mean and std
+    of the target could be different from the input data.
+
+    All supported transforms are defined in the SupportedTransform enum.
+
+    Examples
+    --------
+    Minimum example:
+
+    >>> data = DataConfig(
+    ...     data_type="array", # defined in SupportedData
+    ...     patch_size=[128, 128],
+    ...     batch_size=4,
+    ...     axes="YX"
+    ... )
+
+    To change the image_means and image_stds of the data:
+    >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
+
+    One can pass also a list of transformations, by keyword, using the
+    SupportedTransform value:
+    >>> from careamics.config.support import SupportedTransform
+    >>> data = DataConfig(
+    ...     data_type="tiff",
+    ...     patch_size=[128, 128],
+    ...     batch_size=4,
+    ...     axes="YX",
+    ...     transforms=[
+    ...         {
+    ...             "name": "XYFlip",
+    ...         }
+    ...     ]
+    ... )
+    """
+
+    # Pydantic class configuration
+    model_config = ConfigDict(
+        validate_assignment=True,
+    )
+
+    # Dataset configuration
+    data_type: Literal["array", "tiff", "czi", "custom"]
+    """Type of input data, numpy.ndarray (array) or paths (tiff, czi, and custom), as
+    defined in SupportedData."""
+
+    axes: str
+    """Axes of the data, as defined in SupportedAxes."""
+
+    patch_size: Union[list[int]] = Field(..., min_length=2, max_length=3)
+    """Patch size, as used during training."""
+
+    batch_size: int = Field(default=1, ge=1, validate_default=True)
+    """Batch size for training."""
+
+    # Optional fields
+    image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
+    """Means of the data across channels, used for normalization."""
+
+    image_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
+    """Standard deviations of the data across channels, used for normalization."""
+
+    target_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
+    """Means of the target data across channels, used for normalization."""
+
+    target_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
+    """Standard deviations of the target data across channels, used for
+    normalization."""
+
+    transforms: Sequence[Union[XYFlipConfig, XYRandomRotate90Config]] = Field(
+        default=[
+            XYFlipConfig(),
+            XYRandomRotate90Config(),
+        ],
+        validate_default=True,
+    )
+    """List of transformations to apply to the data, available transforms are defined
+    in SupportedTransform."""
+
+    train_dataloader_params: dict[str, Any] = Field(
+        default={"shuffle": True}, validate_default=True
+    )
+    """Dictionary of PyTorch training dataloader parameters. The dataloader parameters,
+    should include the `shuffle` key, which is set to `True` by default. We strongly
+    recommend to keep it as `True` to ensure the best training results."""
+
+    val_dataloader_params: dict[str, Any] = Field(default={}, validate_default=True)
+    """Dictionary of PyTorch validation dataloader parameters."""
+
+    @field_validator("patch_size")
+    @classmethod
+    def all_elements_power_of_2_minimum_8(
+        cls, patch_list: Union[list[int]]
+    ) -> Union[list[int]]:
+        """
+        Validate patch size.
+
+        Patch size must be powers of 2 and minimum 8.
+
+        Parameters
+        ----------
+        patch_list : list of int
+            Patch size.
+
+        Returns
+        -------
+        list of int
+            Validated patch size.
+
+        Raises
+        ------
+        ValueError
+            If the patch size is smaller than 8.
+        ValueError
+            If the patch size is not a power of 2.
+        """
+        patch_size_ge_than_8_power_of_2(patch_list)
+
+        return patch_list
+
+    @field_validator("axes")
+    @classmethod
+    def axes_valid(cls, axes: str) -> str:
+        """
+        Validate axes.
+
+        Axes must:
+        - be a combination of 'STCZYX'
+        - not contain duplicates
+        - contain at least 2 contiguous axes: X and Y
+        - contain at most 4 axes
+        - not contain both S and T axes
+
+        Parameters
+        ----------
+        axes : str
+            Axes to validate.
+
+        Returns
+        -------
+        str
+            Validated axes.
+
+        Raises
+        ------
+        ValueError
+            If axes are not valid.
+        """
+        # Validate axes
+        check_axes_validity(axes)
+
+        return axes
+
+    @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
+    @classmethod
+    def set_default_pin_memory(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default pin_memory for dataloader parameters if not provided.
+
+        - If 'pin_memory' is not set, it defaults to True if CUDA is available.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with pin_memory default applied.
+        """
+        if "pin_memory" not in dataloader_params:
+            import torch
+
+            dataloader_params["pin_memory"] = torch.cuda.is_available()
+
+        return dataloader_params
+
+    @field_validator("train_dataloader_params", mode="before")
+    @classmethod
+    def set_default_train_workers(
+        cls, dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Set default num_workers for training dataloader if not provided.
+
+        - If 'num_workers' is not set, it defaults to the number of available CPU cores.
+
+        Parameters
+        ----------
+        dataloader_params : dict of {str: Any}
+            The training dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The dataloader parameters with num_workers default applied.
+        """
+        if "num_workers" not in dataloader_params:
+            # Use 0 workers during tests, otherwise use all available CPU cores
+            if "pytest" in sys.modules:
+                dataloader_params["num_workers"] = 0
+            else:
+                dataloader_params["num_workers"] = os.cpu_count()
+
+        return dataloader_params
+
+    @model_validator(mode="after")
+    def set_val_workers_to_match_train(self: Self) -> Self:
+        """
+        Set validation dataloader num_workers to match training dataloader.
+
+        If num_workers is not specified in val_dataloader_params, it will be set to the
+        same value as train_dataloader_params["num_workers"].
+
+        Returns
+        -------
+        Self
+            Validated data model with synchronized num_workers.
+        """
+        if "num_workers" not in self.val_dataloader_params:
+            self.val_dataloader_params["num_workers"] = self.train_dataloader_params[
+                "num_workers"
+            ]
+        return self
+
+    @field_validator("train_dataloader_params")
+    @classmethod
+    def shuffle_train_dataloader(
+        cls, train_dataloader_params: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Validate that "shuffle" is included in the training dataloader params.
+
+        A warning will be raised if `shuffle=False`.
+
+        Parameters
+        ----------
+        train_dataloader_params : dict of {str: Any}
+            The training dataloader parameters.
+
+        Returns
+        -------
+        dict of {str: Any}
+            The validated training dataloader parameters.
+
+        Raises
+        ------
+        ValueError
+            If "shuffle" is not included in the training dataloader params.
+        """
+        if "shuffle" not in train_dataloader_params:
+            raise ValueError(
+                "Value for 'shuffle' was not included in the `train_dataloader_params`."
+            )
+        elif ("shuffle" in train_dataloader_params) and (
+            not train_dataloader_params["shuffle"]
+        ):
+            warn(
+                "Dataloader parameters include `shuffle=False`, this will be passed to "
+                "the training dataloader and may lead to lower quality results.",
+                stacklevel=1,
+            )
+        return train_dataloader_params
+
+    @model_validator(mode="after")
+    def std_only_with_mean(self: Self) -> Self:
+        """
+        Check that mean and std are either both None, or both specified.
+
+        Returns
+        -------
+        Self
+            Validated data model.
+
+        Raises
+        ------
+        ValueError
+            If std is not None and mean is None.
+        """
+        # check that mean and std are either both None, or both specified
+        if (self.image_means and not self.image_stds) or (
+            self.image_stds and not self.image_means
+        ):
+            raise ValueError(
+                "Mean and std must be either both None, or both specified."
+            )
+
+        elif (self.image_means is not None and self.image_stds is not None) and (
+            len(self.image_means) != len(self.image_stds)
+        ):
+            raise ValueError("Mean and std must be specified for each input channel.")
+
+        if (self.target_means and not self.target_stds) or (
+            self.target_stds and not self.target_means
+        ):
+            raise ValueError(
+                "Mean and std must be either both None, or both specified "
+            )
+
+        elif self.target_means is not None and self.target_stds is not None:
+            if len(self.target_means) != len(self.target_stds):
+                raise ValueError(
+                    "Mean and std must be either both None, or both specified for each "
+                    "target channel."
+                )
+
+        return self
+
+    @model_validator(mode="after")
+    def validate_dimensions(self: Self) -> Self:
+        """
+        Validate 2D/3D dimensions between axes, patch size and transforms.
+
+        Returns
+        -------
+        Self
+            Validated data model.
+
+        Raises
+        ------
+        ValueError
+            If the transforms are not valid.
+        """
+        if "Z" in self.axes:
+            if len(self.patch_size) != 3:
+                raise ValueError(
+                    f"Patch size must have 3 dimensions if the data is 3D "
+                    f"({self.axes})."
+                )
+
+        else:
+            if len(self.patch_size) != 2:
+                raise ValueError(
+                    f"Patch size must have 3 dimensions if the data is 3D "
+                    f"({self.axes})."
+                )
+
+        return self
+
+    def __str__(self) -> str:
+        """
+        Pretty string reprensenting the configuration.
+
+        Returns
+        -------
+        str
+            Pretty string.
+        """
+        return pformat(self.model_dump())
+
+    def _update(self, **kwargs: Any) -> None:
+        """
+        Update multiple arguments at once.
+
+        Parameters
+        ----------
+        **kwargs : Any
+            Keyword arguments to update.
+        """
+        self.__dict__.update(kwargs)
+        self.__class__.model_validate(self.__dict__)
+
+    def set_means_and_stds(
+        self,
+        image_means: Union[NDArray, tuple, list, None],
+        image_stds: Union[NDArray, tuple, list, None],
+        target_means: Union[NDArray, tuple, list, None] | None = None,
+        target_stds: Union[NDArray, tuple, list, None] | None = None,
+    ) -> None:
+        """
+        Set mean and standard deviation of the data across channels.
+
+        This method should be used instead setting the fields directly, as it would
+        otherwise trigger a validation error.
+
+        Parameters
+        ----------
+        image_means : numpy.ndarray, tuple or list
+            Mean values for normalization.
+        image_stds : numpy.ndarray, tuple or list
+            Standard deviation values for normalization.
+        target_means : numpy.ndarray, tuple or list, optional
+            Target mean values for normalization, by default ().
+        target_stds : numpy.ndarray, tuple or list, optional
+            Target standard deviation values for normalization, by default ().
+        """
+        # make sure we pass a list
+        if image_means is not None:
+            image_means = list(image_means)
+        if image_stds is not None:
+            image_stds = list(image_stds)
+        if target_means is not None:
+            target_means = list(target_means)
+        if target_stds is not None:
+            target_stds = list(target_stds)
+
+        self._update(
+            image_means=image_means,
+            image_stds=image_stds,
+            target_means=target_means,
+            target_stds=target_stds,
+        )
+
+    def set_3D(self, axes: str, patch_size: list[int]) -> None:
+        """
+        Set 3D parameters.
+
+        Parameters
+        ----------
+        axes : str
+            Axes.
+        patch_size : list of int
+            Patch size.
+        """
+        self._update(axes=axes, patch_size=patch_size)
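As a quick orientation for the model above, here is a minimal usage sketch, assuming careamics 0.0.19 is installed. The constructor arguments follow the class docstring, the statistics values are placeholders, and the dataloader defaults come from the validators shown above:

    from careamics.config.data import DataConfig

    # Minimum construction from the class docstring: 2D "YX" data, 128x128 patches.
    config = DataConfig(
        data_type="array",
        patch_size=[128, 128],
        batch_size=4,
        axes="YX",
    )

    # Normalization statistics are set through the dedicated method so that the
    # std_only_with_mean validator sees mean and std together (placeholder values).
    config.set_means_and_stds(image_means=[214.3], image_stds=[84.5])

    # Defaults injected by the field/model validators above: shuffle=True,
    # pin_memory based on CUDA availability, num_workers mirrored to validation.
    print(config.train_dataloader_params)
    print(config.val_dataloader_params)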