careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""CAREamics transformation Pydantic models."""
|
|
2
|
+
|
|
3
|
+
# Public API of the transformation models. Every name listed here must be
# imported below, otherwise `from ... import *` raises AttributeError.
__all__ = [
    "N2VManipulateModel",
    "XYFlipModel",
    "NormalizeModel",
    "XYRandomRotate90Model",
]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
from .n2v_manipulate_model import N2VManipulateModel
|
|
13
|
+
from .normalize_model import NormalizeModel
|
|
14
|
+
from .xy_flip_model import XYFlipModel
|
|
15
|
+
from .xy_random_rotate90_model import XYRandomRotate90Model
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Pydantic model for the N2VManipulate transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, field_validator
|
|
6
|
+
|
|
7
|
+
from .transform_model import TransformModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class N2VManipulateModel(TransformModel):
    """
    Pydantic model describing the N2V pixel manipulation transform.

    Attributes
    ----------
    name : Literal["N2VManipulate"]
        Identifier of the transformation.
    roi_size : int
        Size of the masking region, an odd value between 3 and 21, by default 11.
    masked_pixel_percentage : float
        Percentage of masked pixels, between 0.05 and 10.0, by default 0.2.
    strategy : Literal["uniform", "median"]
        Pixel value replacement strategy, by default "uniform".
    struct_mask_axis : Literal["horizontal", "vertical", "none"]
        Axis of the structN2V mask, by default "none".
    struct_mask_span : int
        Span of the structN2V mask, an odd value between 3 and 15, by default 5.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["N2VManipulate"] = "N2VManipulate"
    roi_size: int = Field(11, ge=3, le=21)
    masked_pixel_percentage: float = Field(0.2, ge=0.05, le=10.0)
    strategy: Literal["uniform", "median"] = Field("uniform")
    struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field("none")
    struct_mask_span: int = Field(5, ge=3, le=15)

    @field_validator("roi_size", "struct_mask_span")
    @classmethod
    def odd_value(cls, v: int) -> int:
        """
        Reject even sizes.

        Parameters
        ----------
        v : int
            Size to check.

        Returns
        -------
        int
            The unchanged size, when odd.

        Raises
        ------
        ValueError
            If the size is even.
        """
        is_even = v % 2 == 0
        if is_even:
            raise ValueError("Size must be an odd number.")
        return v
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Pydantic model for the Normalize transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
6
|
+
from typing_extensions import Self
|
|
7
|
+
|
|
8
|
+
from .transform_model import TransformModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class NormalizeModel(TransformModel):
    """
    Pydantic model used to represent Normalize transformation.

    The Normalize transform is a zero mean and unit variance transformation.

    Attributes
    ----------
    name : Literal["Normalize"]
        Name of the transformation.
    image_means : list
        Mean values used to normalize the input images (at most 32 entries).
    image_stds : list
        Standard deviation values used to normalize the input images. Must have
        the same length as `image_means`.
    target_means : Optional[list]
        Mean values used to normalize the targets, by default None. Must be
        provided together with `target_stds`.
    target_stds : Optional[list]
        Standard deviation values used to normalize the targets, by default None.
        Must be provided together with `target_means` and have the same length.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["Normalize"] = "Normalize"
    image_means: list = Field(..., min_length=0, max_length=32)
    image_stds: list = Field(..., min_length=0, max_length=32)
    target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
    target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)

    @model_validator(mode="after")
    def validate_means_stds(self: Self) -> Self:
        """Validate that the means and stds have the same length.

        Returns
        -------
        Self
            The instance of the model.

        Raises
        ------
        ValueError
            If the image means and stds have different lengths.
        ValueError
            If only one of the target means and stds is provided.
        ValueError
            If the target means and stds have different lengths.
        """
        if len(self.image_means) != len(self.image_stds):
            raise ValueError("The number of image means and stds must be the same.")

        # targets statistics are optional, but they come as a pair
        if (self.target_means is None) != (self.target_stds is None):
            raise ValueError(
                "Both target means and stds must be provided together, or both None."
            )

        if self.target_means is not None and self.target_stds is not None:
            if len(self.target_means) != len(self.target_stds):
                raise ValueError(
                    "The number of target means and stds must be the same."
                )

        return self
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Parent model for the transforms."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TransformModel(BaseModel):
    """
    Pydantic model used to represent a transformation.

    The `model_dump` method is overwritten to exclude the name field.

    Attributes
    ----------
    name : str
        Name of the transformation.
    """

    model_config = ConfigDict(
        extra="forbid",  # throw errors if the parameters are not properly passed
    )

    name: str

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """
        Return the model as a dictionary, without the `name` field.

        Parameters
        ----------
        **kwargs
            Pydantic BaseModel `model_dump` method keyword arguments.

        Returns
        -------
        Dict[str, Any]
            Dictionary representation of the model.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the name field; it can already be absent when the caller filters
        # it out (e.g. `exclude={"name"}`, `include=...` or `exclude_defaults=True`
        # on subclasses that give `name` a default value)
        model_dict.pop("name", None)

        return model_dict
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Pydantic model for the XYFlip transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field
|
|
6
|
+
|
|
7
|
+
from .transform_model import TransformModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class XYFlipModel(TransformModel):
    """
    Pydantic model describing the XYFlip transformation.

    Attributes
    ----------
    name : Literal["XYFlip"]
        Identifier of the transformation.
    flip_x : bool
        Whether to flip along the X axis, by default True.
    flip_y : bool
        Whether to flip along the Y axis, by default True.
    p : float
        Probability of applying the transform, in [0, 1], by default 0.5.
    seed : Optional[int]
        Seed for the random number generator, by default None.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["XYFlip"] = "XYFlip"
    flip_x: bool = Field(
        default=True,
        description="Whether to flip along the X axis.",
    )
    flip_y: bool = Field(
        default=True,
        description="Whether to flip along the Y axis.",
    )
    p: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="Probability of applying the transform.",
    )
    seed: Optional[int] = None
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Pydantic model for the XYRandomRotate90 transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field
|
|
6
|
+
|
|
7
|
+
from .transform_model import TransformModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class XYRandomRotate90Model(TransformModel):
    """
    Pydantic model describing the random 90 degree rotation in the XY plane.

    Attributes
    ----------
    name : Literal["XYRandomRotate90"]
        Identifier of the transformation.
    p : float
        Probability of applying the transform, in [0, 1], by default 0.5.
    seed : Optional[int]
        Seed for the random number generator, by default None.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["XYRandomRotate90"] = "XYRandomRotate90"
    p: float = Field(
        default=0.5,
        ge=0,
        le=1,
        description="Probability of applying the transform.",
    )
    seed: Optional[int] = None
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Algorithm configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pprint import pformat
|
|
6
|
+
from typing import Literal, Optional, Union
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
9
|
+
from typing_extensions import Self
|
|
10
|
+
|
|
11
|
+
from careamics.config.support import SupportedAlgorithm, SupportedLoss
|
|
12
|
+
|
|
13
|
+
from .architectures import CustomModel, LVAEModel
|
|
14
|
+
from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
|
|
15
|
+
from .nm_model import MultiChannelNMConfig
|
|
16
|
+
from .optimizer_models import LrSchedulerModel, OptimizerModel
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class VAEAlgorithmConfig(BaseModel):
    """Algorithm configuration for VAE-based algorithms.

    This Pydantic model validates the parameters governing the components of the
    training algorithm: which algorithm, loss function, model architecture, noise
    and likelihood models, optimizer, and learning rate scheduler to use.

    Currently, the supported algorithms are `musplit`, `denoisplit` and `custom`.
    The `musplit` algorithm is only compatible with the `musplit` loss. The
    `denoisplit` algorithm is compatible with the `denoisplit` and
    `denoisplit_musplit` losses and requires a noise model. The `custom` algorithm
    allows you to register your own architecture and select it using its name as
    `name` in the custom pydantic model.

    Attributes
    ----------
    algorithm_type : Literal["vae"]
        Type of algorithm.
    algorithm : Literal["musplit", "denoisplit", "custom"]
        Algorithm to use.
    loss : Literal["musplit", "denoisplit", "denoisplit_musplit"]
        Loss function to use.
    model : Union[LVAEModel, CustomModel]
        Model architecture to use, discriminated by its `architecture` field.
    noise_model : Optional[MultiChannelNMConfig]
        Noise model to use.
    noise_model_likelihood_model : Optional[NMLikelihoodConfig]
        Noise model likelihood model to use.
    gaussian_likelihood_model : Optional[GaussianLikelihoodConfig]
        Gaussian likelihood model to use.
    optimizer : OptimizerModel, optional
        Optimizer to use.
    lr_scheduler : LrSchedulerModel, optional
        Learning rate scheduler to use.

    Raises
    ------
    ValueError
        Algorithm parameter type validation errors.
    ValueError
        If the algorithm, loss and model are not compatible.
    ValueError
        If the number of output channels does not match the number of noise
        models, or if `predict_logvar` differs between the model and the Gaussian
        likelihood model.

    Examples
    --------
    # TODO add once finalized
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        protected_namespaces=(),  # allows to use model_* as a field name
        validate_assignment=True,
        extra="allow",
    )

    # Mandatory fields
    # defined in SupportedAlgorithm
    # TODO: Use supported Enum classes for typing?
    #   - values can still be passed as strings and they will be cast to Enum
    algorithm_type: Literal["vae"]
    algorithm: Literal["musplit", "denoisplit", "custom"]
    loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
    model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")

    # TODO: these are configs, change naming of attrs
    noise_model: Optional[MultiChannelNMConfig] = None
    noise_model_likelihood_model: Optional[NMLikelihoodConfig] = None
    gaussian_likelihood_model: Optional[GaussianLikelihoodConfig] = None

    # Optional fields
    optimizer: OptimizerModel = OptimizerModel()
    """Optimizer to use, defined in SupportedOptimizer."""

    lr_scheduler: LrSchedulerModel = LrSchedulerModel()

    @model_validator(mode="after")
    def algorithm_cross_validation(self: Self) -> Self:
        """Validate the algorithm model based on `algorithm`.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If the loss is not compatible with the algorithm, if `denoisplit` with
            the `denoisplit` loss has a non-None `predict_logvar`, or if
            `denoisplit` has no noise model.
        """
        # musplit
        if self.algorithm == SupportedAlgorithm.MUSPLIT:
            if self.loss != SupportedLoss.MUSPLIT:
                raise ValueError(
                    f"Algorithm {self.algorithm} only supports loss `musplit`."
                )

        # denoisplit
        if self.algorithm == SupportedAlgorithm.DENOISPLIT:
            if self.loss not in [
                SupportedLoss.DENOISPLIT,
                SupportedLoss.DENOISPLIT_MUSPLIT,
            ]:
                raise ValueError(
                    f"Algorithm {self.algorithm} only supports loss `denoisplit` "
                    "or `denoisplit_musplit`."
                )
            if (
                self.loss == SupportedLoss.DENOISPLIT
                and self.model.predict_logvar is not None
            ):
                raise ValueError(
                    "Algorithm `denoisplit` with loss `denoisplit` only supports "
                    "`predict_logvar` as `None`."
                )
            if self.noise_model is None:
                raise ValueError("Algorithm `denoisplit` requires a noise model.")
        # TODO: what if algorithm is not musplit or denoisplit (HDN?)
        return self

    @model_validator(mode="after")
    def output_channels_validation(self: Self) -> Self:
        """Validate the consistency between number of out channels and noise models.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If a noise model is set and the number of model output channels
            differs from the number of noise models.
        """
        # explicit raise instead of `assert`, which is stripped under `python -O`
        if self.noise_model is not None:
            n_noise_models = len(self.noise_model.noise_models)
            if self.model.output_channels != n_noise_models:
                raise ValueError(
                    f"Number of output channels ({self.model.output_channels}) must "
                    f"match the number of noise models ({n_noise_models})."
                )
        return self

    @model_validator(mode="after")
    def predict_logvar_validation(self: Self) -> Self:
        """Validate the consistency of `predict_logvar` throughout the model.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If a Gaussian likelihood model is set and its `predict_logvar` differs
            from the model's `predict_logvar`.
        """
        if self.gaussian_likelihood_model is not None:
            model_logvar = self.model.predict_logvar
            likelihood_logvar = self.gaussian_likelihood_model.predict_logvar
            if model_logvar != likelihood_logvar:
                # a plain string message (the previous assert message was a
                # one-element tuple because of a trailing comma)
                raise ValueError(
                    f"Model `predict_logvar` ({model_logvar}) must match "
                    "Gaussian likelihood model `predict_logvar` "
                    f"({likelihood_logvar})."
                )
        return self

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Validator functions.
|
|
3
|
+
|
|
4
|
+
These functions are used to validate dimensions and axes of inputs.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import List, Optional, Tuple, Union
|
|
8
|
+
|
|
9
|
+
# Reference set of allowed axes; `check_axes_validity` requires every character
# of an axes string (after upper-casing) to be one of these.
_AXES = "STCZYX"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def check_axes_validity(axes: str) -> None:
|
|
13
|
+
"""
|
|
14
|
+
Sanity check on axes.
|
|
15
|
+
|
|
16
|
+
The constraints on the axes are the following:
|
|
17
|
+
- must be a combination of 'STCZYX'
|
|
18
|
+
- must not contain duplicates
|
|
19
|
+
- must contain at least 2 contiguous axes: X and Y
|
|
20
|
+
- must contain at most 4 axes
|
|
21
|
+
- cannot contain both S and T axes
|
|
22
|
+
|
|
23
|
+
Axes do not need to be in the order 'STCZYX', as this depends on the user data.
|
|
24
|
+
|
|
25
|
+
Parameters
|
|
26
|
+
----------
|
|
27
|
+
axes : str
|
|
28
|
+
Axes to validate.
|
|
29
|
+
"""
|
|
30
|
+
_axes = axes.upper()
|
|
31
|
+
|
|
32
|
+
# Minimum is 2 (XY) and maximum is 4 (TZYX)
|
|
33
|
+
if len(_axes) < 2 or len(_axes) > 6:
|
|
34
|
+
raise ValueError(
|
|
35
|
+
f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes."
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
if "YX" not in _axes and "XY" not in _axes:
|
|
39
|
+
raise ValueError(
|
|
40
|
+
f"Invalid axes {axes}. Must contain at least X and Y axes consecutively."
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
# all characters must be in REF_AXES = 'STCZYX'
|
|
44
|
+
if not all(s in _AXES for s in _axes):
|
|
45
|
+
raise ValueError(f"Invalid axes {axes}. Must be a combination of {_AXES}.")
|
|
46
|
+
|
|
47
|
+
# check for repeating characters
|
|
48
|
+
for i, s in enumerate(_axes):
|
|
49
|
+
if i != _axes.rfind(s):
|
|
50
|
+
raise ValueError(
|
|
51
|
+
f"Invalid axes {axes}. Cannot contain duplicate axes"
|
|
52
|
+
f" (got multiple {axes[i]})."
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def value_ge_than_8_power_of_2(
    value: int,
) -> None:
    """
    Validate that the value is greater or equal than 8 and a power of 2.

    Parameters
    ----------
    value : int
        Value to validate.

    Raises
    ------
    ValueError
        If the value is smaller than 8.
    ValueError
        If the value is not a power of 2.
    """
    # 8 itself is valid, so the message must say "greater than or equal to"
    if value < 8:
        raise ValueError(f"Value must be greater than or equal to 8 (got {value}).")

    # a positive power of 2 has a single bit set, so clearing the lowest set bit
    # yields 0
    if (value & (value - 1)) != 0:
        raise ValueError(f"Value must be a power of 2 (got {value}).")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def patch_size_ge_than_8_power_of_2(
    patch_list: Optional[Union[List[int], Tuple[int, ...]]],
) -> None:
    """
    Check that every patch dimension is a power of 2 no smaller than 8.

    Parameters
    ----------
    patch_list : Optional[Union[List[int], Tuple[int, ...]]]
        Patch size, one entry per dimension. Nothing is checked when None.

    Raises
    ------
    ValueError
        If a patch dimension is smaller than 8.
    ValueError
        If a patch dimension is not a power of 2.
    """
    if patch_list is None:
        return

    for size in patch_list:
        value_ge_than_8_power_of_2(size)
|
careamics/conftest.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""File used to discover python modules and run doctest.
|
|
2
|
+
|
|
3
|
+
See https://sybil.readthedocs.io/en/latest/use.html#pytest
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
from pytest import TempPathFactory
|
|
10
|
+
from sybil import Sybil
|
|
11
|
+
from sybil.parsers.codeblock import PythonCodeBlockParser
|
|
12
|
+
from sybil.parsers.doctest import DocTestParser
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.fixture(scope="module")
def my_path(tmp_path_factory: TempPathFactory) -> Path:
    """Fixture used in doctest to create a temporary directory.

    Parameters
    ----------
    tmp_path_factory : TempPathFactory
        Temporary path factory from pytest.

    Returns
    -------
    Path
        Temporary directory path.
    """
    # `tmp_path_factory` returns `pathlib.Path` objects, matching the annotations;
    # the legacy `tmpdir_factory` returns `py.path.local` objects instead.
    return tmp_path_factory.mktemp("my_path")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Collect doctests and Python code blocks from every module of the package,
# exposing the `my_path` fixture to them.
pytest_collect_file = Sybil(
    pattern="*.py",
    fixtures=["my_path"],
    parsers=[
        DocTestParser(),
        PythonCodeBlockParser(future_imports=["print_function"]),
    ],
).pytest()
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Dataset module."""
|
|
2
|
+
|
|
3
|
+
# Public API of the dataset module: the training datasets (in memory or
# iterating over paths) and their prediction / tiled prediction counterparts.
# Every name here is imported from its submodule below.
__all__ = [
    "InMemoryDataset",
    "InMemoryPredDataset",
    "InMemoryTiledPredDataset",
    "PathIterableDataset",
    "IterableTiledPredDataset",
    "IterablePredDataset",
]
|
|
11
|
+
|
|
12
|
+
from .in_memory_dataset import InMemoryDataset
|
|
13
|
+
from .in_memory_pred_dataset import InMemoryPredDataset
|
|
14
|
+
from .in_memory_tiled_pred_dataset import InMemoryTiledPredDataset
|
|
15
|
+
from .iterable_dataset import PathIterableDataset
|
|
16
|
+
from .iterable_pred_dataset import IterablePredDataset
|
|
17
|
+
from .iterable_tiled_pred_dataset import IterableTiledPredDataset
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Files and arrays utils used in the datasets."""
|
|
2
|
+
|
|
3
|
+
# Public API of the dataset utilities. Every name here is imported from its
# submodule below (dataset_utils, file_utils, iterate_over_files, running_stats).
__all__ = [
    "reshape_array",
    "compute_normalization_stats",
    "get_files_size",
    "list_files",
    "validate_source_target_files",
    "iterate_over_files",
    "WelfordStatistics",
]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
from .dataset_utils import (
|
|
15
|
+
reshape_array,
|
|
16
|
+
)
|
|
17
|
+
from .file_utils import get_files_size, list_files, validate_source_target_files
|
|
18
|
+
from .iterate_over_files import iterate_over_files
|
|
19
|
+
from .running_stats import WelfordStatistics, compute_normalization_stats
|