careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""Pydantic model representing CAREamics prediction configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Literal, Self, Union
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
8
|
+
|
|
9
|
+
from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class InferenceConfig(BaseModel):
    """Configuration class for the prediction model.

    Holds the input data type, data axes, per-channel normalization statistics,
    optional tiling parameters, the test-time augmentation flag and the batch
    size used during prediction. Assignments are re-validated
    (`validate_assignment=True`).
    """

    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

    data_type: Literal["array", "tiff", "czi", "custom"]  # As defined in SupportedData
    """Type of input data: numpy.ndarray (array) or path (tiff, czi, or custom)."""

    tile_size: Union[list[int]] | None = Field(default=None, min_length=2, max_length=3)
    """Tile size of prediction, only effective if `tile_overlap` is specified."""

    tile_overlap: Union[list[int]] | None = Field(
        default=None, min_length=2, max_length=3
    )
    """Overlap between tiles, only effective if `tile_size` is specified."""

    axes: str
    """Data axes (TSCZYX) in the order of the input data."""

    # NOTE: `min_length=0` lets empty lists through the field constraint, but
    # `std_only_with_mean` rejects empty statistics at model level.
    image_means: list = Field(..., min_length=0, max_length=32)
    """Mean values for each input channel."""

    image_stds: list = Field(..., min_length=0, max_length=32)
    """Standard deviation values for each input channel."""

    # TODO only default TTAs are supported for now
    tta_transforms: bool = Field(default=True)
    """Whether to apply test-time augmentation (all 90 degrees rotations and flips)."""

    # Dataloader parameters
    batch_size: int = Field(default=1, ge=1)
    """Batch size for prediction."""

    @field_validator("tile_overlap")
    @classmethod
    def all_elements_non_zero_even(
        cls, tile_overlap: list[int] | None
    ) -> list[int] | None:
        """
        Validate tile overlap.

        Overlaps must be non-zero, positive and even.

        Parameters
        ----------
        tile_overlap : list[int] or None
            Tile overlap, one entry per spatial dimension.

        Returns
        -------
        list[int] or None
            Validated tile overlap (unchanged).

        Raises
        ------
        ValueError
            If any overlap is smaller than 1.
        ValueError
            If any overlap is not even.
        """
        if tile_overlap is None:
            return tile_overlap

        for dim in tile_overlap:
            if dim < 1:
                raise ValueError(
                    f"Tile overlap must be non-zero positive (got {dim})."
                )

            if dim % 2 != 0:
                raise ValueError(f"Tile overlap must be even (got {dim}).")

        return tile_overlap

    @field_validator("tile_size")
    @classmethod
    def tile_min_8_power_of_2(cls, tile_list: list[int] | None) -> list[int] | None:
        """
        Validate that each entry is greater or equal than 8 and a power of 2.

        Delegates the check to `patch_size_ge_than_8_power_of_2`.

        Parameters
        ----------
        tile_list : list of int or None
            Tile size.

        Returns
        -------
        list of int or None
            Validated tile size (unchanged).

        Raises
        ------
        ValueError
            If a tile dimension is smaller than 8.
        ValueError
            If a tile dimension is not a power of 2.
        """
        patch_size_ge_than_8_power_of_2(tile_list)

        return tile_list

    @field_validator("axes")
    @classmethod
    def axes_valid(cls, axes: str) -> str:
        """
        Validate axes.

        Axes must:
        - be a combination of 'STCZYX'
        - not contain duplicates
        - contain at least 2 contiguous axes: X and Y
        - contain at most 4 axes
        - not contain both S and T axes

        Parameters
        ----------
        axes : str
            Axes to validate.

        Returns
        -------
        str
            Validated axes.

        Raises
        ------
        ValueError
            If axes are not valid.
        """
        # Validate axes
        check_axes_validity(axes)

        return axes

    @model_validator(mode="after")
    def validate_dimensions(self: Self) -> Self:
        """
        Validate 2D/3D dimensions between axes and tile size.

        The check only runs when both `tile_size` and `tile_overlap` are set,
        since tiling is only effective when both are specified.

        Returns
        -------
        Self
            Validated prediction model.

        Raises
        ------
        ValueError
            If tile size or tile overlap do not have 3 entries when "Z" is in
            the axes (2 otherwise).
        ValueError
            If any overlap is greater than or equal to the matching tile size.
        """
        expected_len = 3 if "Z" in self.axes else 2

        if self.tile_size is not None and self.tile_overlap is not None:
            if len(self.tile_size) != expected_len:
                raise ValueError(
                    f"Tile size must have {expected_len} dimensions given axes "
                    f"{self.axes} (got {self.tile_size})."
                )

            if len(self.tile_overlap) != expected_len:
                raise ValueError(
                    f"Tile overlap must have {expected_len} dimensions given axes "
                    f"{self.axes} (got {self.tile_overlap})."
                )

            # Both lengths were checked equal above, so `strict=True` can
            # never fire; it documents the invariant.
            if any(
                overlap >= size
                for overlap, size in zip(self.tile_overlap, self.tile_size, strict=True)
            ):
                raise ValueError("Tile overlap must be smaller than tile size.")

        return self

    @model_validator(mode="after")
    def std_only_with_mean(self: Self) -> Self:
        """
        Check that mean and std are both specified and of matching length.

        Returns
        -------
        Self
            Validated prediction model.

        Raises
        ------
        ValueError
            If both mean and std are empty.
        ValueError
            If only one of mean and std is specified.
        ValueError
            If mean and std do not have the same number of entries.
        """
        has_means = bool(self.image_means)
        has_stds = bool(self.image_stds)

        if not has_means and not has_stds:
            raise ValueError("Mean and std must be specified during inference.")

        if has_means != has_stds:
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        # Both are non-empty here, so `len` is safe.
        if len(self.image_means) != len(self.image_stds):
            raise ValueError(
                "Mean and std must be specified for each " "input channel."
            )

        return self

    def _update(self, **kwargs: Any) -> None:
        """
        Update multiple arguments at once.

        The merged values are validated as a whole *before* the instance is
        modified. A failed update therefore leaves the configuration unchanged.

        Parameters
        ----------
        **kwargs : Any
            Key-value pairs of arguments to update.

        Raises
        ------
        pydantic.ValidationError
            If the updated configuration is invalid.
        """
        # Validate a merged copy first: writing into `__dict__` before validating
        # would leave this instance in an invalid state if validation failed.
        self.__class__.model_validate({**self.__dict__, **kwargs})
        self.__dict__.update(kwargs)

    def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> None:
        """
        Set 3D parameters.

        The three values are updated together, so the cross-field dimension
        checks see a consistent set of values.

        Parameters
        ----------
        axes : str
            Axes.
        tile_size : list of int
            Tile size.
        tile_overlap : list of int
            Tile overlap.
        """
        self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
|