careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Pydantic model for the Normalize transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Self
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
6
|
+
|
|
7
|
+
from .transform_config import TransformConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class NormalizeConfig(TransformConfig):
    """
    Pydantic model used to represent Normalize transformation.

    The Normalize transform is a zero mean and unit variance transformation.

    Attributes
    ----------
    name : Literal["Normalize"]
        Name of the transformation.
    image_means : list
        Mean values used to normalize the images (at most 32 entries).
    image_stds : list
        Standard deviation values used to normalize the images, must have the
        same length as `image_means`.
    target_means : list or None
        Mean values used to normalize the targets, by default None.
    target_stds : list or None
        Standard deviation values used to normalize the targets, by default None.
        Must be provided together with `target_means` and have the same length.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["Normalize"] = "Normalize"
    image_means: list = Field(..., min_length=0, max_length=32)
    image_stds: list = Field(..., min_length=0, max_length=32)
    target_means: list | None = Field(default=None, min_length=0, max_length=32)
    target_stds: list | None = Field(default=None, min_length=0, max_length=32)

    @model_validator(mode="after")
    def validate_means_stds(self: Self) -> Self:
        """Validate that the means and stds have the same length.

        Target statistics are optional, but `target_means` and `target_stds`
        must either both be provided or both be `None`.

        Returns
        -------
        Self
            The instance of the model.

        Raises
        ------
        ValueError
            If the image means and stds have different lengths.
        ValueError
            If only one of the target means and stds is provided.
        ValueError
            If the target means and stds have different lengths.
        """
        if len(self.image_means) != len(self.image_stds):
            raise ValueError("The number of image means and stds must be the same.")

        # exactly one of the two target statistics being set is invalid
        if (self.target_means is None) != (self.target_stds is None):
            raise ValueError(
                "Both target means and stds must be provided together, or both None."
            )

        if self.target_means is not None and self.target_stds is not None:
            if len(self.target_means) != len(self.target_stds):
                raise ValueError(
                    "The number of target means and stds must be the same."
                )

        return self
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Parent model for the transforms."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TransformConfig(BaseModel):
    """
    Pydantic model used to represent a transformation.

    The `model_dump` method is overwritten to exclude the name field.

    Attributes
    ----------
    name : str
        Name of the transformation.
    """

    model_config = ConfigDict(
        extra="forbid",  # throw errors if the parameters are not properly passed
    )

    name: str

    def model_dump(self, **kwargs) -> dict[str, Any]:
        """
        Return the model as a dictionary, without the `name` field.

        Parameters
        ----------
        **kwargs
            Pydantic BaseModel model_dump method keyword arguments.

        Returns
        -------
        {str: Any}
            Dictionary representation of the model.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the name field; it may already be missing when the caller
        # passed `exclude`/`include` arguments that drop it
        model_dict.pop("name", None)

        return model_dict
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""Type used to represent all transformations users can create."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import Discriminator
|
|
6
|
+
|
|
7
|
+
from .normalize_config import NormalizeConfig
|
|
8
|
+
from .xy_flip_config import XYFlipConfig
|
|
9
|
+
from .xy_random_rotate90_config import XYRandomRotate90Config
|
|
10
|
+
|
|
11
|
+
NORM_AND_SPATIAL_UNION = Annotated[
    Union[NormalizeConfig, XYFlipConfig, XYRandomRotate90Config],
    # the `name` literal of each config selects the matching model
    Discriminator("name"),
]
"""All transforms including normalization."""


SPATIAL_TRANSFORMS_UNION = Annotated[
    Union[XYFlipConfig, XYRandomRotate90Config],
    # the `name` literal of each config selects the matching model
    Discriminator("name"),
]
"""Available spatial transforms in CAREamics."""
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Pydantic model for the XYFlip transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field
|
|
6
|
+
|
|
7
|
+
from .transform_config import TransformConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class XYFlipConfig(TransformConfig):
    """
    Pydantic model used to represent XYFlip transformation.

    Attributes
    ----------
    name : Literal["XYFlip"]
        Name of the transformation.
    flip_x : bool
        Whether to flip along the X axis, by default True.
    flip_y : bool
        Whether to flip along the Y axis, by default True.
    p : float
        Probability of applying the transform, between 0 and 1, by default 0.5.
    seed : int or None
        Seed for the random number generator, by default None.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["XYFlip"] = "XYFlip"
    flip_x: bool = Field(
        default=True,
        description="Whether to flip along the X axis.",
    )
    flip_y: bool = Field(
        default=True,
        description="Whether to flip along the Y axis.",
    )
    p: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="Probability of applying the transform.",
    )
    seed: int | None = None
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Pydantic model for the XYRandomRotate90 transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field
|
|
6
|
+
|
|
7
|
+
from .transform_config import TransformConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class XYRandomRotate90Config(TransformConfig):
    """
    Pydantic model used to represent the XY random 90 degree rotation transformation.

    Attributes
    ----------
    name : Literal["XYRandomRotate90"]
        Name of the transformation.
    p : float
        Probability of applying the transform, between 0 and 1, by default 0.5.
    seed : int or None
        Seed for the random number generator, by default None.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
    p: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="Probability of applying the transform.",
    )
    seed: int | None = None
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""I/O functions for Configuration objects."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from careamics.config import Configuration
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    Parameters
    ----------
    path : str or Path
        Path to the configuration.

    Returns
    -------
    Configuration
        Configuration.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    """
    config_path = Path(path)

    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in " f" {Path.cwd()!s}"
        )

    # load dictionary from yaml; the context manager closes the file handle
    # (the previous inline `open` was never closed)
    with config_path.open("r") as f:
        dictionary = yaml.load(f, Loader=yaml.SafeLoader)

    return Configuration(**dictionary)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Save configuration to path.

    If `path` is an existing directory, the configuration is written to a
    `config.yml` file inside it. Otherwise `path` must carry a `.yml` or `.yaml`
    suffix; it does not need to exist yet.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : str or Path
        Path to an existing folder in which to save the configuration, or to a
        valid configuration file path (uses a .yml or .yaml extension).

    Returns
    -------
    Path
        Path object representing the configuration.

    Raises
    ------
    ValueError
        If the path does not point to an existing directory or .yml file.
    """
    config_path = Path(path)
    allowed_suffixes = (".yml", ".yaml")

    # `is_dir` is False for paths that do not exist, so only existing folders
    # are redirected to the default file name
    if config_path.is_dir():
        config_path = config_path / "config.yml"
    elif config_path.suffix not in allowed_suffixes:
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {config_path})."
        )

    with config_path.open("w") as f:
        yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)

    return config_path
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""Validator utilities."""
|
|
2
|
+
|
|
3
|
+
# Validators re-exported as the public API of this package.
__all__ = [
    "check_axes_validity",
    "check_czi_axes_validity",
    "model_matching_in_out_channels",
    "model_without_final_activation",
    "model_without_n2v2",
    "patch_size_ge_than_8_power_of_2",
]
|
|
11
|
+
|
|
12
|
+
from .axes_validators import check_axes_validity, check_czi_axes_validity
|
|
13
|
+
from .model_validators import (
|
|
14
|
+
model_matching_in_out_channels,
|
|
15
|
+
model_without_final_activation,
|
|
16
|
+
model_without_n2v2,
|
|
17
|
+
)
|
|
18
|
+
from .patch_validators import patch_size_ge_than_8_power_of_2
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Axes validation utilities."""
|
|
2
|
+
|
|
3
|
+
_AXES = "STCZYX"


def check_axes_validity(axes: str) -> None:
    """
    Sanity check on axes.

    The constraints on the axes are the following:
    - must be a combination of 'STCZYX' (case-insensitive)
    - must not contain duplicates
    - must contain at least 2 contiguous axes: X and Y
    - must contain at most 6 axes

    Axes do not need to be in the order 'STCZYX', as this depends on the user data.

    Parameters
    ----------
    axes : str
        Axes to validate.

    Raises
    ------
    ValueError
        If the axes violate any of the constraints listed above.
    """
    _axes = axes.upper()

    # Minimum is 2 (XY) and maximum is 6 (STCZYX)
    if len(_axes) < 2 or len(_axes) > 6:
        raise ValueError(
            f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes."
        )

    if "YX" not in _axes and "XY" not in _axes:
        raise ValueError(
            f"Invalid axes {axes}. Must contain at least X and Y axes consecutively."
        )

    # all characters must be in _AXES = 'STCZYX'
    if not all(s in _AXES for s in _axes):
        raise ValueError(f"Invalid axes {axes}. Must be a combination of {_AXES}.")

    # check for repeating characters: an axis whose last occurrence is not at
    # the current position appears more than once
    for i, s in enumerate(_axes):
        if i != _axes.rfind(s):
            raise ValueError(
                f"Invalid axes {axes}. Cannot contain duplicate axes"
                f" (got multiple {axes[i]})."
            )
|
|
49
|
+
def check_czi_axes_validity(axes: str) -> bool:
    """
    Check if the provided axes string is valid for CZI files.

    CZI axes is always in the "SC(Z/T)YX" format, where Z or T are optional, and S and C
    can be singleton dimensions, but must be provided. Each axis may appear at most
    once.

    Parameters
    ----------
    axes : str
        The axes string to validate.

    Returns
    -------
    bool
        True if the axes string is valid, False otherwise.
    """
    valid_axes = {"S", "C", "Z", "T", "Y", "X"}
    axes_set = set(axes)

    # check for invalid characters
    if not axes_set.issubset(valid_axes):
        return False

    # check for repeated axes (e.g. "SCYXX"), which the order check below
    # cannot detect since it only requires non-decreasing positions
    if len(axes_set) != len(axes):
        return False

    # check for mandatory axes
    if not {"S", "C", "Y", "X"}.issubset(axes_set):
        return False

    # check for mutually exclusive axes
    if "Z" in axes_set and "T" in axes_set:
        return False

    # check for correct order; Z and T occupy the same slot in the reference
    order = "SCZYX" if "Z" in axes_set else "SCTYX"
    last_index = -1
    for axis in axes:
        current_index = order.find(axis)
        if current_index < last_index:
            return False
        last_index = current_index

    return True
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Architecture model validators."""
|
|
2
|
+
|
|
3
|
+
from careamics.config.architectures import UNetConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def model_without_n2v2(model: UNetConfig) -> UNetConfig:
    """Validate that the Unet model does not have the n2v2 attribute.

    Parameters
    ----------
    model : UNetConfig
        UNet configuration to check.

    Returns
    -------
    UNetConfig
        The same configuration, unchanged, when `n2v2` is falsy.

    Raises
    ------
    ValueError
        If the model has the `n2v2` attribute set to `True`.
    """
    if not model.n2v2:
        return model

    raise ValueError(
        "The algorithm does not support the `n2v2` attribute in the model. "
        "Set it to `False`."
    )
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def model_without_final_activation(model: UNetConfig) -> UNetConfig:
    """Validate that the UNet model does not have the final_activation.

    Parameters
    ----------
    model : UNetConfig
        UNet configuration to check.

    Returns
    -------
    UNetConfig
        The same configuration, unchanged, when `final_activation` is "None".

    Raises
    ------
    ValueError
        If `final_activation` is anything other than the string "None".
    """
    if model.final_activation != "None":
        message = (
            "The algorithm does not support a `final_activation` in the model. "
            'Set it to `"None"`.'
        )
        raise ValueError(message)

    return model
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def model_matching_in_out_channels(model: UNetConfig) -> UNetConfig:
    """Validate that the UNet model has the same number of channel inputs and outputs.

    Parameters
    ----------
    model : UNetConfig
        UNet configuration to check.

    Returns
    -------
    UNetConfig
        The same configuration, unchanged, when the channel counts match.

    Raises
    ------
    ValueError
        If `num_classes` differs from `in_channels`.
    """
    if model.num_classes != model.in_channels:
        message = (
            "The algorithm requires the same number of input and output channels. "
            "Make sure that `in_channels` and `num_classes` are equal."
        )
        raise ValueError(message)

    return model
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Validator functions.
|
|
3
|
+
|
|
4
|
+
These functions are used to validate dimensions and axes of inputs.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from collections.abc import Sequence
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _value_ge_than_8_power_of_2(
    value: int,
) -> None:
    """
    Validate that the value is greater than or equal to 8 and a power of 2.

    Parameters
    ----------
    value : int
        Value to validate.

    Raises
    ------
    ValueError
        If the value is smaller than 8.
    ValueError
        If the value is not a power of 2.
    """
    # 8 itself is valid, so the message must say "greater than or equal to".
    if value < 8:
        raise ValueError(
            f"Value must be greater than or equal to 8 (got {value})."
        )

    # A positive integer is a power of 2 iff exactly one bit is set, i.e.
    # clearing its lowest set bit (`value & (value - 1)`) yields 0. The check
    # above guarantees `value` is positive at this point.
    if (value & (value - 1)) != 0:
        raise ValueError(f"Value must be a power of 2 (got {value}).")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def patch_size_ge_than_8_power_of_2(
|
|
36
|
+
patch_list: Sequence[int] | None,
|
|
37
|
+
) -> None:
|
|
38
|
+
"""
|
|
39
|
+
Validate that each entry is greater or equal than 8 and a power of 2.
|
|
40
|
+
|
|
41
|
+
Parameters
|
|
42
|
+
----------
|
|
43
|
+
patch_list : Sequence of int, or None
|
|
44
|
+
Patch size.
|
|
45
|
+
|
|
46
|
+
Raises
|
|
47
|
+
------
|
|
48
|
+
ValueError
|
|
49
|
+
If the patch size if smaller than 8.
|
|
50
|
+
ValueError
|
|
51
|
+
If the patch size is not a power of 2.
|
|
52
|
+
"""
|
|
53
|
+
if patch_list is not None:
|
|
54
|
+
for dim in patch_list:
|
|
55
|
+
_value_ge_than_8_power_of_2(dim)
|
careamics/conftest.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""File used to discover python modules and run doctest.
|
|
2
|
+
|
|
3
|
+
See https://sybil.readthedocs.io/en/latest/use.html#pytest
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
from pytest import TempPathFactory
|
|
10
|
+
from sybil import Sybil
|
|
11
|
+
from sybil.parsers.codeblock import PythonCodeBlockParser
|
|
12
|
+
from sybil.parsers.doctest import DocTestParser
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.fixture(scope="module")
def my_path(tmp_path_factory: TempPathFactory) -> Path:
    """Fixture used in doctest to create a temporary directory.

    Parameters
    ----------
    tmp_path_factory : TempPathFactory
        Temporary path factory from pytest.

    Returns
    -------
    Path
        Temporary directory path.
    """
    # `tmp_path_factory` is the fixture actually typed `TempPathFactory`; its
    # `mktemp` returns a `pathlib.Path`, matching the annotated return type
    # (the legacy `tmpdir_factory` returns a `py.path.local` instead).
    return tmp_path_factory.mktemp("my_path")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Sybil-backed pytest collection hook: examples found in every `*.py` file are
# parsed with Sybil's doctest parser and Python code-block parser, and the
# `my_path` fixture defined above is made available to them.
pytest_collect_file = Sybil(
    pattern="*.py",
    fixtures=["my_path"],
    parsers=[
        DocTestParser(),
        PythonCodeBlockParser(future_imports=["print_function"]),
    ],
).pytest()
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Dataset module."""
|
|
2
|
+
|
|
3
|
+
# Public dataset classes re-exported by this package (imported below).
__all__ = [
    "InMemoryDataset",
    "InMemoryPredDataset",
    "InMemoryTiledPredDataset",
    "IterablePredDataset",
    "IterableTiledPredDataset",
    "PathIterableDataset",
]
|
|
11
|
+
|
|
12
|
+
from .in_memory_dataset import InMemoryDataset
|
|
13
|
+
from .in_memory_pred_dataset import InMemoryPredDataset
|
|
14
|
+
from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
|
|
15
|
+
from .iterable_dataset import PathIterableDataset
|
|
16
|
+
from .iterable_pred_dataset import IterablePredDataset
|
|
17
|
+
from .iterable_tiled_pred_dataset import IterableTiledPredDataset
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Files and arrays utils used in the datasets."""
|
|
2
|
+
|
|
3
|
+
# Public helpers re-exported by this package (imported below).
__all__ = [
    "WelfordStatistics",
    "compute_normalization_stats",
    "get_files_size",
    "iterate_over_files",
    "list_files",
    "reshape_array",
    "validate_source_target_files",
]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
from .dataset_utils import (
|
|
15
|
+
reshape_array,
|
|
16
|
+
)
|
|
17
|
+
from .file_utils import get_files_size, list_files, validate_source_target_files
|
|
18
|
+
from .iterate_over_files import iterate_over_files
|
|
19
|
+
from .running_stats import WelfordStatistics, compute_normalization_stats
|