careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +726 -0
- careamics/config/__init__.py +35 -0
- careamics/config/algorithm_model.py +162 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +159 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/architectures/vae_model.py +42 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +575 -0
- careamics/config/configuration_model.py +600 -0
- careamics/config/data_model.py +502 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +26 -0
- careamics/config/support/supported_algorithms.py +20 -0
- careamics/config/support/supported_architectures.py +20 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +27 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +17 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +276 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +5 -0
- careamics/losses/loss_factory.py +49 -0
- careamics/losses/losses.py +98 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +233 -0
- careamics/model_io/model_io_utils.py +83 -0
- careamics/models/__init__.py +7 -0
- careamics/models/activation.py +37 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +52 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +98 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +115 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.2.dist-info/METADATA +78 -0
- careamics-0.0.2.dist-info/RECORD +140 -0
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Algorithms supported by CAREamics."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from careamics.utils import BaseEnum
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SupportedAlgorithm(str, BaseEnum):
    """Algorithms available in CAREamics.

    Attributes
    ----------
    N2V : str
        Noise2Void algorithm.
    CARE : str
        CARE algorithm.
    N2N : str
        Noise2Noise algorithm.
    CUSTOM : str
        Custom algorithm.
    """

    N2V = "n2v"
    CARE = "care"
    N2N = "n2n"
    CUSTOM = "custom"

    # Not supported yet:
    # PN2V = "pn2v"
    # HDN = "hdn"
    # SEG = "segmentation"
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Architectures supported by CAREamics."""
|
|
2
|
+
|
|
3
|
+
from careamics.utils import BaseEnum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SupportedArchitecture(str, BaseEnum):
    """Supported architectures.

    # TODO add details, in particular where to find the API for the models

    Attributes
    ----------
    UNET : str
        Classical UNet, compatible with N2V2.
    VAE : str
        Variational autoencoder.
    CUSTOM : str
        Custom model registered with the `@register_model` decorator.
    """

    UNET = "UNet"
    VAE = "VAE"
    # TODO: all the other tags in CAREamics are lowercase, except the architecture
    # tags, which are capitalized.
    CUSTOM = "Custom"
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""Data supported by CAREamics."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Union
|
|
6
|
+
|
|
7
|
+
from careamics.utils import BaseEnum
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SupportedData(str, BaseEnum):
    """Data types that CAREamics can handle.

    Attributes
    ----------
    ARRAY : str
        Array data, not loaded from a file.
    TIFF : str
        TIFF image data.
    CUSTOM : str
        Custom data; its file extension has to be supplied separately.
    """

    ARRAY = "array"
    TIFF = "tiff"
    CUSTOM = "custom"
    # ZARR = "zarr"

    # TODO remove?
    @classmethod
    def _missing_(cls, value: object) -> str:
        """
        Resolve values that do not exactly match a member value.

        The enum machinery calls this hook when `value` is not one of the member
        values. String inputs are lowercased and stripped of a single leading "."
        (so ".TIFF" resolves to `TIFF`) before being compared again with the
        member values. Anything else is delegated to the parent implementation.

        Parameters
        ----------
        value : object
            Value to be matched with enum values.

        Returns
        -------
        str
            Matched enum member, or whatever the parent `_missing_` returns when
            no member matches.
        """
        if not isinstance(value, str):
            return super()._missing_(value)

        candidate = value.lower()
        if candidate.startswith("."):
            candidate = candidate[1:]

        for member in cls:
            if member.value == candidate:
                return member

        # no member matched, even after normalisation
        return super()._missing_(value)

    @classmethod
    def get_extension_pattern(cls, data_type: Union[str, SupportedData]) -> str:
        """
        Get a `Path.rglob` and `fnmatch` compatible extension pattern.

        Parameters
        ----------
        data_type : str or SupportedData
            Data type.

        Returns
        -------
        str
            Corresponding extension pattern.

        Raises
        ------
        NotImplementedError
            If `data_type` is array data, which is not loaded from files.
        ValueError
            If `data_type` is not a supported data type.
        """
        if data_type == cls.ARRAY:
            raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
        if data_type == cls.TIFF:
            return "*.tif*"
        if data_type == cls.CUSTOM:
            return "*.*"
        raise ValueError(f"Data type {data_type} is not supported.")

    @classmethod
    def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
        """
        Get the file extension of the corresponding data type.

        Parameters
        ----------
        data_type : str or SupportedData
            Data type.

        Returns
        -------
        str
            Corresponding extension.

        Raises
        ------
        NotImplementedError
            If `data_type` is array data (not loaded from files) or custom data
            (extension provided elsewhere).
        ValueError
            If `data_type` is not a supported data type.
        """
        if data_type == cls.ARRAY:
            raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
        if data_type == cls.TIFF:
            return ".tiff"
        if data_type == cls.CUSTOM:
            # TODO: improve this message
            raise NotImplementedError("Custom extensions have to be passed elsewhere.")
        raise ValueError(f"Data type {data_type} is not supported.")
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Losses supported by CAREamics."""
|
|
2
|
+
|
|
3
|
+
from careamics.utils import BaseEnum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# TODO register loss with custom_loss decorator?
class SupportedLoss(str, BaseEnum):
    """Loss functions available in CAREamics.

    Attributes
    ----------
    MSE : str
        Mean squared error loss.
    MAE : str
        Mean absolute error loss.
    N2V : str
        Noise2Void loss.
    """

    MSE = "mse"
    MAE = "mae"
    N2V = "n2v"

    # Not supported yet:
    # PN2V = "pn2v"
    # HDN = "hdn"
    # CE = "ce"
    # DICE = "dice"
    # CUSTOM = "custom"  # TODO create mechanism for that
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Optimizers and schedulers supported by CAREamics."""
|
|
2
|
+
|
|
3
|
+
from careamics.utils import BaseEnum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SupportedOptimizer(str, BaseEnum):
    """Supported optimizers.

    Attributes
    ----------
    ADAM : str
        Adam optimizer.
    SGD : str
        Stochastic Gradient Descent optimizer.
    """

    ADAM = "Adam"
    SGD = "SGD"

    # Not supported yet:
    # ASGD = "ASGD"
    # Adadelta = "Adadelta"
    # Adagrad = "Adagrad"
    # AdamW = "AdamW"
    # Adamax = "Adamax"
    # LBFGS = "LBFGS"
    # NAdam = "NAdam"
    # RAdam = "RAdam"
    # RMSprop = "RMSprop"
    # Rprop = "Rprop"
    # SparseAdam = "SparseAdam"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SupportedScheduler(str, BaseEnum):
    """Supported schedulers.

    Attributes
    ----------
    REDUCE_LR_ON_PLATEAU : str
        Reduce learning rate on plateau.
    STEP_LR : str
        Step learning rate.
    """

    REDUCE_LR_ON_PLATEAU = "ReduceLROnPlateau"
    STEP_LR = "StepLR"

    # Not supported yet:
    # ChainedScheduler = "ChainedScheduler"
    # ConstantLR = "ConstantLR"
    # CosineAnnealingLR = "CosineAnnealingLR"
    # CosineAnnealingWarmRestarts = "CosineAnnealingWarmRestarts"
    # CyclicLR = "CyclicLR"
    # ExponentialLR = "ExponentialLR"
    # LambdaLR = "LambdaLR"
    # LinearLR = "LinearLR"
    # MultiStepLR = "MultiStepLR"
    # MultiplicativeLR = "MultiplicativeLR"
    # OneCycleLR = "OneCycleLR"
    # PolynomialLR = "PolynomialLR"
    # SequentialLR = "SequentialLR"
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Pixel manipulation methods supported by CAREamics."""
|
|
2
|
+
|
|
3
|
+
from careamics.utils import BaseEnum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SupportedPixelManipulation(str, BaseEnum):
    """Supported Noise2Void pixel manipulations.

    Attributes
    ----------
    UNIFORM : str
        Replace the masked pixel value by the value of a neighbor pixel selected
        uniformly at random.
    MEDIAN : str
        Replace the masked pixel value by the median of its neighborhood.
    """

    UNIFORM = "uniform"
    MEDIAN = "median"
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""StructN2V axes supported by CAREamics."""
|
|
2
|
+
|
|
3
|
+
from careamics.utils import BaseEnum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SupportedStructAxis(str, BaseEnum):
    """Axes along which the structN2V mask can be applied.

    Attributes
    ----------
    HORIZONTAL : str
        Mask along the horizontal axis.
    VERTICAL : str
        Mask along the vertical axis.
    NONE : str
        No axis, the structN2V mask is not applied.
    """

    HORIZONTAL = "horizontal"
    VERTICAL = "vertical"
    NONE = "none"
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Transforms supported by CAREamics."""
|
|
2
|
+
|
|
3
|
+
from careamics.utils import BaseEnum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SupportedTransform(str, BaseEnum):
    """Transforms officially supported by CAREamics.

    Attributes
    ----------
    XY_FLIP : str
        XYFlip transform.
    XY_RANDOM_ROTATE90 : str
        XYRandomRotate90 transform.
    N2V_MANIPULATE : str
        Noise2Void pixel manipulation transform.
    """

    XY_FLIP = "XYFlip"
    XY_RANDOM_ROTATE90 = "XYRandomRotate90"
    N2V_MANIPULATE = "N2VManipulate"
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""Pydantic model representing the metadata of a prediction tile."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Annotated
|
|
6
|
+
|
|
7
|
+
from annotated_types import Len
|
|
8
|
+
from pydantic import BaseModel, ConfigDict
|
|
9
|
+
|
|
10
|
+
# Shape of a C(Z)YX array: a tuple with 3 or 4 entries.
DimTuple = Annotated[tuple, Len(min_length=3, max_length=4)]


class TileInformation(BaseModel):
    """
    Pydantic model holding the metadata of a single prediction tile.

    It carries the information needed to stitch a predicted tile back into the
    larger image it was cut from, and is used throughout the CAREamics prediction
    pipeline.

    The original array is expected to have shape C(Z)YX, Z being optional.
    """

    model_config = ConfigDict(validate_default=True)

    array_shape: DimTuple  # TODO: find a way to add custom error message?
    """Shape of the original (untiled) array."""

    last_tile: bool = False
    """Whether this tile is the last one of the array."""

    overlap_crop_coords: tuple[tuple[int, ...], ...]
    """Inner coordinates of the tile where to crop the prediction in order to stitch
    it back into the original image."""

    stitch_coords: tuple[tuple[int, ...], ...]
    """Coordinates in the original image where to stitch the cropped tile back."""

    sample_id: int
    """Sample ID of the tile."""

    # TODO: Test that ZYX axes are not singleton ?

    def __eq__(self, other_tile: object):
        """Compare two tile information objects field by field.

        Parameters
        ----------
        other_tile : object
            Object to compare with.

        Returns
        -------
        bool
            True if all tile fields are equal; `NotImplemented` if `other_tile`
            is not a `TileInformation`.
        """
        if not isinstance(other_tile, TileInformation):
            return NotImplemented

        compared_fields = (
            "array_shape",
            "last_tile",
            "overlap_crop_coords",
            "stitch_coords",
            "sample_id",
        )
        return all(
            getattr(self, field_name) == getattr(other_tile, field_name)
            for field_name in compared_fields
        )
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""Training configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pprint import pformat
|
|
6
|
+
from typing import Literal, Optional
|
|
7
|
+
|
|
8
|
+
from pydantic import (
|
|
9
|
+
BaseModel,
|
|
10
|
+
ConfigDict,
|
|
11
|
+
Field,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from .callback_model import CheckpointModel, EarlyStoppingModel
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TrainingConfig(BaseModel):
    """
    Parameters related to the training.

    Every parameter has a default value, so none of them is mandatory.

    Attributes
    ----------
    num_epochs : int
        Number of epochs, greater than 0, by default 20.
    logger : {"wandb", "tensorboard"} or None
        Logger to use during training, by default None (no logger).
    checkpoint_callback : CheckpointModel
        Checkpoint callback configuration, by default `CheckpointModel()`.
    early_stopping_callback : EarlyStoppingModel or None
        Early stopping callback configuration, by default None.
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    num_epochs: int = Field(default=20, ge=1)
    """Number of epochs, greater than 0."""

    logger: Optional[Literal["wandb", "tensorboard"]] = None
    """Logger to use during training. If None, no logger will be used. Available
    loggers are defined in SupportedLogger."""

    checkpoint_callback: CheckpointModel = CheckpointModel()
    """Checkpoint callback configuration, following PyTorch Lightning Checkpoint
    callback."""

    early_stopping_callback: Optional[EarlyStoppingModel] = Field(
        default=None, validate_default=True
    )
    """Early stopping callback configuration, following PyTorch Lightning
    EarlyStopping callback."""

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def has_logger(self) -> bool:
        """Check if the logger is defined.

        Returns
        -------
        bool
            Whether the logger is defined or not.
        """
        return self.logger is not None
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""CAREamics transformation Pydantic models."""
|
|
2
|
+
|
|
3
|
+
# Public API of the package. Every entry must be a name imported in this module,
# otherwise `from careamics.config.transformations import *` raises AttributeError.
__all__ = [
    "N2VManipulateModel",
    "XYFlipModel",
    "NormalizeModel",
    "XYRandomRotate90Model",
]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
from .n2v_manipulate_model import N2VManipulateModel
|
|
13
|
+
from .normalize_model import NormalizeModel
|
|
14
|
+
from .xy_flip_model import XYFlipModel
|
|
15
|
+
from .xy_random_rotate90_model import XYRandomRotate90Model
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Pydantic model for the N2VManipulate transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, field_validator
|
|
6
|
+
|
|
7
|
+
from .transform_model import TransformModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class N2VManipulateModel(TransformModel):
    """
    Pydantic model describing the N2V pixel manipulation transform.

    Attributes
    ----------
    name : Literal["N2VManipulate"]
        Name of the transformation.
    roi_size : int
        Size of the masking region, odd and within [3, 21], by default 11.
    masked_pixel_percentage : float
        Percentage of masked pixels, within [0.05, 1.0], by default 0.2.
    strategy : Literal["uniform", "median"]
        Pixel value replacement strategy, by default "uniform".
    struct_mask_axis : Literal["horizontal", "vertical", "none"]
        Axis of the structN2V mask, by default "none".
    struct_mask_span : int
        Span of the structN2V mask, odd and within [3, 15], by default 5.
    """

    model_config = ConfigDict(validate_assignment=True)

    name: Literal["N2VManipulate"] = "N2VManipulate"
    roi_size: int = Field(11, ge=3, le=21)
    masked_pixel_percentage: float = Field(0.2, ge=0.05, le=1.0)
    strategy: Literal["uniform", "median"] = Field("uniform")
    struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field("none")
    struct_mask_span: int = Field(5, ge=3, le=15)

    @field_validator("roi_size", "struct_mask_span")
    @classmethod
    def odd_value(cls, v: int) -> int:
        """
        Ensure that a size parameter is odd.

        Parameters
        ----------
        v : int
            Value to validate.

        Returns
        -------
        int
            The value, unchanged.

        Raises
        ------
        ValueError
            If the value is even.
        """
        if not v % 2:
            raise ValueError("Size must be an odd number.")
        return v
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Pydantic model for the Normalize transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, model_validator
|
|
6
|
+
from typing_extensions import Self
|
|
7
|
+
|
|
8
|
+
from .transform_model import TransformModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class NormalizeModel(TransformModel):
    """
    Pydantic model used to represent Normalize transformation.

    The Normalize transform is a zero mean and unit variance transformation.

    Attributes
    ----------
    name : Literal["Normalize"]
        Name of the transformation.
    image_means : list
        Mean values used to normalize the input images (at most 32 values).
    image_stds : list
        Standard deviation values used to normalize the input images (at most 32
        values), same length as `image_means`.
    target_means : list or None
        Mean values used to normalize the targets, by default None.
    target_stds : list or None
        Standard deviation values used to normalize the targets, same length as
        `target_means`, by default None.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["Normalize"] = "Normalize"
    image_means: list = Field(..., min_length=0, max_length=32)
    image_stds: list = Field(..., min_length=0, max_length=32)
    target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
    target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)

    @model_validator(mode="after")
    def validate_means_stds(self: Self) -> Self:
        """Validate that the means and stds are consistent with each other.

        Returns
        -------
        Self
            The instance of the model.

        Raises
        ------
        ValueError
            If image means and stds have different lengths, if only one of the
            target means and stds is provided, or if target means and stds have
            different lengths.
        """
        if len(self.image_means) != len(self.image_stds):
            raise ValueError("The number of image means and stds must be the same.")

        # target statistics are optional, but must be given (or omitted) together
        if (self.target_means is None) != (self.target_stds is None):
            raise ValueError(
                "Both target means and stds must be provided together, or both None."
            )

        if self.target_means is not None and self.target_stds is not None:
            if len(self.target_means) != len(self.target_stds):
                raise ValueError(
                    "The number of target means and stds must be the same."
                )

        return self
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Parent model for the transforms."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TransformModel(BaseModel):
    """
    Pydantic model used to represent a transformation.

    The `model_dump` method is overwritten to exclude the name field.

    Attributes
    ----------
    name : str
        Name of the transformation.
    """

    model_config = ConfigDict(
        extra="forbid",  # throw errors if the parameters are not properly passed
    )

    name: str

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """
        Return the model as a dictionary, without the `name` field.

        Parameters
        ----------
        **kwargs
            Pydantic BaseModel `model_dump` method keyword arguments.

        Returns
        -------
        Dict[str, Any]
            Dictionary representation of the model, without the `name` field.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the name field; it may already be absent when the caller passed
        # `include`/`exclude` or `exclude_defaults=True` (subclasses give `name` a
        # default value), so do not fail in that case
        model_dict.pop("name", None)

        return model_dict
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Pydantic model for the XYFlip transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field
|
|
6
|
+
|
|
7
|
+
from .transform_model import TransformModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class XYFlipModel(TransformModel):
    """
    Pydantic model used to represent XYFlip transformation.

    Attributes
    ----------
    name : Literal["XYFlip"]
        Name of the transformation.
    flip_x : bool
        Whether to flip along the X axis, by default True.
    flip_y : bool
        Whether to flip along the Y axis, by default True.
    p : float
        Probability of applying the transform, between 0 and 1, by default 0.5.
    seed : Optional[int]
        Seed for the random number generator, by default None.
    """

    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["XYFlip"] = "XYFlip"
    flip_x: bool = Field(
        True,
        description="Whether to flip along the X axis.",
    )
    flip_y: bool = Field(
        True,
        description="Whether to flip along the Y axis.",
    )
    p: float = Field(
        0.5,
        description="Probability of applying the transform.",
        ge=0,
        le=1,
    )
    seed: Optional[int] = None
|