careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +239 -28
- careamics/cli/conf.py +19 -31
- careamics/cli/main.py +112 -12
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +48 -24
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +109 -21
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +8 -8
- careamics/config/loss_model.py +56 -0
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +24 -25
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +69 -34
- careamics/lightning/train_data_module.py +41 -27
- careamics/losses/__init__.py +3 -3
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +26 -34
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +56 -34
- careamics/model_io/bmz_io.py +42 -42
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +22 -20
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -275
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/logging.py +11 -10
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +8 -8
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
- careamics-0.0.6.dist-info/RECORD +176 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.4.2.dist-info/RECORD +0 -165
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/cli/utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""Utility functions for the CAREamics CLI."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def handle_2D_3D_callback(
    value: Optional[tuple[int, int, int]]
) -> Optional[tuple[int, ...]]:
    """
    Reduce a 3-element option value to 2 elements when it encodes a 2D input.

    A 2D input is signalled by the third element being -1. Any other value,
    including `None`, is handed back unchanged.

    Parameters
    ----------
    value : (int, int, int) or None
        Option value (e.g. a tile size), or `None`.

    Returns
    -------
    (int, int, int) | (int, int) | None
        The first two elements of `value` if its last element is -1, otherwise
        `value` itself.
    """
    # Short-circuit keeps `None` from being indexed.
    encodes_2d = value is not None and value[2] == -1
    return value[:2] if encodes_2d else value
|
careamics/config/__init__.py
CHANGED
|
@@ -1,39 +1,63 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""CAREamics Pydantic configuration models.
|
|
2
|
+
|
|
3
|
+
To maintain clarity at the module level, we follow the following naming conventions:
|
|
4
|
+
`*_model` is specific for sub-configurations (e.g. architecture, data, algorithm),
|
|
5
|
+
while `*_configuration` is reserved for the main configuration models, including the
|
|
6
|
+
`Configuration` base class and its algorithm-specific child classes.
|
|
7
|
+
"""
|
|
2
8
|
|
|
3
9
|
__all__ = [
|
|
4
|
-
"
|
|
5
|
-
"
|
|
6
|
-
"DataConfig",
|
|
7
|
-
"Configuration",
|
|
10
|
+
"CAREAlgorithm",
|
|
11
|
+
"CAREConfiguration",
|
|
8
12
|
"CheckpointModel",
|
|
13
|
+
"Configuration",
|
|
14
|
+
"DataConfig",
|
|
15
|
+
"GaussianMixtureNMConfig",
|
|
16
|
+
"GeneralDataConfig",
|
|
9
17
|
"InferenceConfig",
|
|
10
|
-
"
|
|
11
|
-
"
|
|
18
|
+
"LVAELossConfig",
|
|
19
|
+
"MultiChannelNMConfig",
|
|
20
|
+
"N2NAlgorithm",
|
|
21
|
+
"N2NConfiguration",
|
|
22
|
+
"N2VAlgorithm",
|
|
23
|
+
"N2VConfiguration",
|
|
24
|
+
"N2VDataConfig",
|
|
12
25
|
"TrainingConfig",
|
|
13
|
-
"
|
|
14
|
-
"
|
|
26
|
+
"UNetBasedAlgorithm",
|
|
27
|
+
"VAEBasedAlgorithm",
|
|
28
|
+
"algorithm_factory",
|
|
29
|
+
"configuration_factory",
|
|
15
30
|
"create_care_configuration",
|
|
16
|
-
"
|
|
17
|
-
"
|
|
18
|
-
"
|
|
19
|
-
"
|
|
20
|
-
"
|
|
31
|
+
"create_n2n_configuration",
|
|
32
|
+
"create_n2v_configuration",
|
|
33
|
+
"data_factory",
|
|
34
|
+
"load_configuration",
|
|
35
|
+
"save_configuration",
|
|
21
36
|
]
|
|
22
|
-
|
|
37
|
+
|
|
38
|
+
from .algorithms import (
|
|
39
|
+
CAREAlgorithm,
|
|
40
|
+
N2NAlgorithm,
|
|
41
|
+
N2VAlgorithm,
|
|
42
|
+
UNetBasedAlgorithm,
|
|
43
|
+
VAEBasedAlgorithm,
|
|
44
|
+
)
|
|
23
45
|
from .callback_model import CheckpointModel
|
|
24
|
-
from .
|
|
46
|
+
from .care_configuration import CAREConfiguration
|
|
47
|
+
from .configuration import Configuration
|
|
48
|
+
from .configuration_factories import (
|
|
49
|
+
algorithm_factory,
|
|
50
|
+
configuration_factory,
|
|
25
51
|
create_care_configuration,
|
|
26
52
|
create_n2n_configuration,
|
|
27
53
|
create_n2v_configuration,
|
|
54
|
+
data_factory,
|
|
28
55
|
)
|
|
29
|
-
from .
|
|
30
|
-
|
|
31
|
-
load_configuration,
|
|
32
|
-
save_configuration,
|
|
33
|
-
)
|
|
34
|
-
from .data_model import DataConfig
|
|
35
|
-
from .fcn_algorithm_model import FCNAlgorithmConfig
|
|
56
|
+
from .configuration_io import load_configuration, save_configuration
|
|
57
|
+
from .data import DataConfig, GeneralDataConfig, N2VDataConfig
|
|
36
58
|
from .inference_model import InferenceConfig
|
|
59
|
+
from .loss_model import LVAELossConfig
|
|
60
|
+
from .n2n_configuration import N2NConfiguration
|
|
61
|
+
from .n2v_configuration import N2VConfiguration
|
|
37
62
|
from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
38
63
|
from .training_model import TrainingConfig
|
|
39
|
-
from .vae_algorithm_model import VAEAlgorithmConfig
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Algorithm configurations."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"CAREAlgorithm",
|
|
5
|
+
"N2NAlgorithm",
|
|
6
|
+
"N2VAlgorithm",
|
|
7
|
+
"UNetBasedAlgorithm",
|
|
8
|
+
"VAEBasedAlgorithm",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
from .care_algorithm_model import CAREAlgorithm
|
|
12
|
+
from .n2n_algorithm_model import N2NAlgorithm
|
|
13
|
+
from .n2v_algorithm_model import N2VAlgorithm
|
|
14
|
+
from .unet_algorithm_model import UNetBasedAlgorithm
|
|
15
|
+
from .vae_algorithm_model import VAEBasedAlgorithm
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""CARE algorithm configuration."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import field_validator
|
|
6
|
+
|
|
7
|
+
from careamics.config.architectures import UNetModel
|
|
8
|
+
|
|
9
|
+
from .unet_algorithm_model import UNetBasedAlgorithm
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CAREAlgorithm(UNetBasedAlgorithm):
    """CARE algorithm configuration.

    Attributes
    ----------
    algorithm : "care"
        CARE Algorithm name.
    loss : {"mae", "mse"}
        CARE-compatible loss function.
    """

    algorithm: Literal["care"] = "care"
    """CARE Algorithm name."""

    loss: Literal["mae", "mse"] = "mae"
    """CARE-compatible loss function."""

    # `@field_validator` has to be the outermost decorator. If `@classmethod`
    # wraps it instead, the class attribute is a plain `classmethod` object and
    # pydantic does not register the function as a validator, so the check
    # would silently never run.
    @field_validator("model")
    @classmethod
    def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
        """Validate that the model does not have the n2v2 attribute.

        Parameters
        ----------
        value : UNetModel
            Model to validate.

        Returns
        -------
        UNetModel
            The validated model.

        Raises
        ------
        ValueError
            If `value.n2v2` is truthy.
        """
        if value.n2v2:
            raise ValueError(
                "The CARE algorithm does not support the `n2v2` attribute. "
                "Set it to `False`."
            )

        return value
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""N2N Algorithm configuration."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import field_validator
|
|
6
|
+
|
|
7
|
+
from careamics.config.architectures import UNetModel
|
|
8
|
+
|
|
9
|
+
from .unet_algorithm_model import UNetBasedAlgorithm
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class N2NAlgorithm(UNetBasedAlgorithm):
    """N2N Algorithm configuration.

    Attributes
    ----------
    algorithm : "n2n"
        N2N Algorithm name.
    loss : {"mae", "mse"}
        N2N-compatible loss function.
    """

    algorithm: Literal["n2n"] = "n2n"
    """N2N Algorithm name."""

    loss: Literal["mae", "mse"] = "mae"
    """N2N-compatible loss function."""

    # `@field_validator` has to be the outermost decorator. If `@classmethod`
    # wraps it instead, the class attribute is a plain `classmethod` object and
    # pydantic does not register the function as a validator, so the check
    # would silently never run.
    @field_validator("model")
    @classmethod
    def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
        """Validate that the model does not have the n2v2 attribute.

        Parameters
        ----------
        value : UNetModel
            Model to validate.

        Returns
        -------
        UNetModel
            The validated model.

        Raises
        ------
        ValueError
            If `value.n2v2` is truthy.
        """
        if value.n2v2:
            raise ValueError(
                "The N2N algorithm does not support the `n2v2` attribute. "
                "Set it to `False`."
            )

        return value
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""N2V Algorithm configuration."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import model_validator
|
|
6
|
+
from typing_extensions import Self
|
|
7
|
+
|
|
8
|
+
from .unet_algorithm_model import UNetBasedAlgorithm
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class N2VAlgorithm(UNetBasedAlgorithm):
    """N2V Algorithm configuration.

    Attributes
    ----------
    algorithm : "n2v"
        N2V Algorithm name.
    loss : "n2v"
        N2V loss function.
    """

    algorithm: Literal["n2v"] = "n2v"
    """N2V Algorithm name."""

    loss: Literal["n2v"] = "n2v"
    """N2V loss function."""

    @model_validator(mode="after")
    def algorithm_cross_validation(self: Self) -> Self:
        """Check that the model has as many output as input channels.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If `model.in_channels` differs from `model.num_classes`.
        """
        unet = self.model
        if unet.in_channels == unet.num_classes:
            return self

        raise ValueError(
            "N2V requires the same number of input and output channels. Make "
            "sure that `in_channels` and `num_classes` are the same."
        )
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""UNet-based algorithm Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from pprint import pformat
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
|
7
|
+
|
|
8
|
+
from careamics.config.architectures import UNetModel
|
|
9
|
+
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class UNetBasedAlgorithm(BaseModel):
    """Base configuration shared by the UNet-based algorithms.

    The model groups the components of the training algorithm: the algorithm
    name, the loss function, the UNet architecture, the optimizer and the
    learning rate scheduler.

    The accepted algorithm names are "n2v", "care" and "n2n". To train one of
    them, use the matching child class (e.g. `N2VAlgorithm`), which narrows the
    accepted values (e.g. the losses) to a coherent set.

    Attributes
    ----------
    algorithm : {"n2v", "care", "n2n"}
        Algorithm to use.
    loss : {"n2v", "mae", "mse"}
        Loss function to use.
    model : UNetModel
        Model architecture to use.
    optimizer : OptimizerModel, optional
        Optimizer to use.
    lr_scheduler : LrSchedulerModel, optional
        Learning rate scheduler to use.

    Raises
    ------
    ValueError
        Algorithm parameter type validation errors.
    ValueError
        If the algorithm, loss and model are not compatible.
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        protected_namespaces=(),  # allows to use model_* as a field name
        validate_assignment=True,
        extra="allow",
    )

    # Mandatory fields
    algorithm: Literal["n2v", "care", "n2n"]
    """Algorithm name, as defined in SupportedAlgorithm."""

    loss: Literal["n2v", "mae", "mse"]
    """Loss function to use, as defined in SupportedLoss."""

    model: UNetModel
    """UNet model configuration."""

    # Optional fields
    optimizer: OptimizerModel = OptimizerModel()
    """Optimizer to use, defined in SupportedOptimizer."""

    lr_scheduler: LrSchedulerModel = LrSchedulerModel()
    """Learning rate scheduler to use, defined in SupportedLrScheduler."""

    def __str__(self) -> str:
        """Return the dumped configuration as a pretty-printed string.

        Returns
        -------
        str
            Pretty string.
        """
        dumped = self.model_dump()
        return pformat(dumped)

    @classmethod
    def get_compatible_algorithms(cls) -> list[str]:
        """Return the names of the algorithms this configuration can describe.

        Returns
        -------
        list of str
            List of compatible algorithms.
        """
        compatible = ["n2v", "care", "n2n"]
        return compatible
|
|
@@ -1,23 +1,26 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""VAE-based algorithm Pydantic model."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from pprint import pformat
|
|
6
|
-
from typing import Literal, Optional
|
|
6
|
+
from typing import Literal, Optional
|
|
7
7
|
|
|
8
|
-
from pydantic import BaseModel, ConfigDict,
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, model_validator
|
|
9
9
|
from typing_extensions import Self
|
|
10
10
|
|
|
11
|
+
from careamics.config.architectures import LVAEModel
|
|
12
|
+
from careamics.config.likelihood_model import (
|
|
13
|
+
GaussianLikelihoodConfig,
|
|
14
|
+
NMLikelihoodConfig,
|
|
15
|
+
)
|
|
16
|
+
from careamics.config.loss_model import LVAELossConfig
|
|
17
|
+
from careamics.config.nm_model import MultiChannelNMConfig
|
|
18
|
+
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
|
|
11
19
|
from careamics.config.support import SupportedAlgorithm, SupportedLoss
|
|
12
20
|
|
|
13
|
-
from .architectures import CustomModel, LVAEModel
|
|
14
|
-
from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
|
|
15
|
-
from .nm_model import MultiChannelNMConfig
|
|
16
|
-
from .optimizer_models import LrSchedulerModel, OptimizerModel
|
|
17
21
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
"""Algorithm configuration.
|
|
22
|
+
class VAEBasedAlgorithm(BaseModel):
|
|
23
|
+
"""VAE-based algorithm configuration.
|
|
21
24
|
|
|
22
25
|
# TODO
|
|
23
26
|
|
|
@@ -38,13 +41,13 @@ class VAEAlgorithmConfig(BaseModel):
|
|
|
38
41
|
# TODO: Use supported Enum classes for typing?
|
|
39
42
|
# - values can still be passed as strings and they will be cast to Enum
|
|
40
43
|
algorithm: Literal["musplit", "denoisplit"]
|
|
41
|
-
loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
|
|
42
|
-
model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
|
|
43
44
|
|
|
44
|
-
#
|
|
45
|
+
# NOTE: these are all configs (pydantic models)
|
|
46
|
+
loss: LVAELossConfig
|
|
47
|
+
model: LVAEModel
|
|
45
48
|
noise_model: Optional[MultiChannelNMConfig] = None
|
|
46
|
-
|
|
47
|
-
|
|
49
|
+
noise_model_likelihood: Optional[NMLikelihoodConfig] = None
|
|
50
|
+
gaussian_likelihood: Optional[GaussianLikelihoodConfig] = None
|
|
48
51
|
|
|
49
52
|
# Optional fields
|
|
50
53
|
optimizer: OptimizerModel = OptimizerModel()
|
|
@@ -63,13 +66,13 @@ class VAEAlgorithmConfig(BaseModel):
|
|
|
63
66
|
"""
|
|
64
67
|
# musplit
|
|
65
68
|
if self.algorithm == SupportedAlgorithm.MUSPLIT:
|
|
66
|
-
if self.loss != SupportedLoss.MUSPLIT:
|
|
69
|
+
if self.loss.loss_type != SupportedLoss.MUSPLIT:
|
|
67
70
|
raise ValueError(
|
|
68
71
|
f"Algorithm {self.algorithm} only supports loss `musplit`."
|
|
69
72
|
)
|
|
70
73
|
|
|
71
74
|
if self.algorithm == SupportedAlgorithm.DENOISPLIT:
|
|
72
|
-
if self.loss not in [
|
|
75
|
+
if self.loss.loss_type not in [
|
|
73
76
|
SupportedLoss.DENOISPLIT,
|
|
74
77
|
SupportedLoss.DENOISPLIT_MUSPLIT,
|
|
75
78
|
]:
|
|
@@ -78,16 +81,17 @@ class VAEAlgorithmConfig(BaseModel):
|
|
|
78
81
|
"or `denoisplit_musplit."
|
|
79
82
|
)
|
|
80
83
|
if (
|
|
81
|
-
self.loss == SupportedLoss.DENOISPLIT
|
|
84
|
+
self.loss.loss_type == SupportedLoss.DENOISPLIT
|
|
82
85
|
and self.model.predict_logvar is not None
|
|
83
86
|
):
|
|
84
87
|
raise ValueError(
|
|
85
88
|
"Algorithm `denoisplit` with loss `denoisplit` only supports "
|
|
86
89
|
"`predict_logvar` as `None`."
|
|
87
90
|
)
|
|
91
|
+
|
|
88
92
|
if self.noise_model is None:
|
|
89
93
|
raise ValueError("Algorithm `denoisplit` requires a noise model.")
|
|
90
|
-
# TODO: what if algorithm is not musplit or denoisplit
|
|
94
|
+
# TODO: what if algorithm is not musplit or denoisplit
|
|
91
95
|
return self
|
|
92
96
|
|
|
93
97
|
@model_validator(mode="after")
|
|
@@ -115,14 +119,13 @@ class VAEAlgorithmConfig(BaseModel):
|
|
|
115
119
|
Self
|
|
116
120
|
The validated model.
|
|
117
121
|
"""
|
|
118
|
-
if self.
|
|
122
|
+
if self.gaussian_likelihood is not None:
|
|
119
123
|
assert (
|
|
120
|
-
self.model.predict_logvar
|
|
121
|
-
== self.gaussian_likelihood_model.predict_logvar
|
|
124
|
+
self.model.predict_logvar == self.gaussian_likelihood.predict_logvar
|
|
122
125
|
), (
|
|
123
126
|
f"Model `predict_logvar` ({self.model.predict_logvar}) must match "
|
|
124
127
|
"Gaussian likelihood model `predict_logvar` "
|
|
125
|
-
f"({self.
|
|
128
|
+
f"({self.gaussian_likelihood.predict_logvar}).",
|
|
126
129
|
)
|
|
127
130
|
return self
|
|
128
131
|
|
|
@@ -1,17 +1,7 @@
|
|
|
1
1
|
"""Deep-learning model configurations."""
|
|
2
2
|
|
|
3
|
-
__all__ = [
|
|
4
|
-
"ArchitectureModel",
|
|
5
|
-
"CustomModel",
|
|
6
|
-
"UNetModel",
|
|
7
|
-
"LVAEModel",
|
|
8
|
-
"clear_custom_models",
|
|
9
|
-
"get_custom_model",
|
|
10
|
-
"register_model",
|
|
11
|
-
]
|
|
3
|
+
__all__ = ["ArchitectureModel", "LVAEModel", "UNetModel"]
|
|
12
4
|
|
|
13
5
|
from .architecture_model import ArchitectureModel
|
|
14
|
-
from .custom_model import CustomModel
|
|
15
6
|
from .lvae_model import LVAEModel
|
|
16
|
-
from .register_model import clear_custom_models, get_custom_model, register_model
|
|
17
7
|
from .unet_model import UNetModel
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Base model for the various CAREamics architectures."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
from pydantic import BaseModel
|
|
6
6
|
|
|
@@ -15,7 +15,7 @@ class ArchitectureModel(BaseModel):
|
|
|
15
15
|
architecture: str
|
|
16
16
|
"""Name of the architecture."""
|
|
17
17
|
|
|
18
|
-
def model_dump(self, **kwargs: Any) ->
|
|
18
|
+
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
|
19
19
|
"""
|
|
20
20
|
Dump the model as a dictionary, ignoring the architecture keyword.
|
|
21
21
|
|
|
@@ -26,7 +26,7 @@ class ArchitectureModel(BaseModel):
|
|
|
26
26
|
|
|
27
27
|
Returns
|
|
28
28
|
-------
|
|
29
|
-
|
|
29
|
+
{str: Any}
|
|
30
30
|
Model as a dictionary.
|
|
31
31
|
"""
|
|
32
32
|
model_dict = super().model_dump(**kwargs)
|
|
@@ -15,9 +15,21 @@ class LVAEModel(ArchitectureModel):
|
|
|
15
15
|
model_config = ConfigDict(validate_assignment=True, validate_default=True)
|
|
16
16
|
|
|
17
17
|
architecture: Literal["LVAE"]
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
18
|
+
"""Name of the architecture."""
|
|
19
|
+
|
|
20
|
+
input_shape: list[int] = Field(default=[64, 64], validate_default=True)
|
|
21
|
+
"""Shape of the input patch, (Z, Y, X) for 3D data or (Y, X) for 2D data."""
|
|
22
|
+
|
|
23
|
+
encoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
|
|
24
|
+
|
|
25
|
+
# TODO make this per hierarchy step ?
|
|
26
|
+
decoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
|
|
27
|
+
"""Dimensions (2D or 3D) of the convolutional layers."""
|
|
28
|
+
|
|
29
|
+
multiscale_count: int = Field(default=1)
|
|
30
|
+
# TODO there should be a check for multiscale_count in dataset !!
|
|
31
|
+
|
|
32
|
+
# 1 - off, len(z_dims) + 1 # TODO Consider starting from 0
|
|
21
33
|
z_dims: list = Field(default=[128, 128, 128, 128])
|
|
22
34
|
output_channels: int = Field(default=1, ge=1)
|
|
23
35
|
encoder_n_filters: int = Field(default=64, ge=8, le=1024)
|
|
@@ -31,10 +43,90 @@ class LVAEModel(ArchitectureModel):
|
|
|
31
43
|
)
|
|
32
44
|
|
|
33
45
|
predict_logvar: Literal[None, "pixelwise"] = None
|
|
46
|
+
analytical_kl: bool = Field(default=False)
|
|
34
47
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
48
|
+
@model_validator(mode="after")
def validate_conv_strides(self: Self) -> Self:
    """
    Validate the encoder and decoder convolutional strides.

    Returns
    -------
    Self
        The validated model.

    Raises
    ------
    ValueError
        If the encoder or decoder strides do not have 2 or 3 elements.
    ValueError
        If `input_shape` and the encoder strides differ in length.
    ValueError
        If the decoder strides have more elements than the encoder strides.
    ValueError
        If any stride is smaller than 1.
    """
    n_encoder = len(self.encoder_conv_strides)
    n_decoder = len(self.decoder_conv_strides)

    if n_encoder < 2 or n_encoder > 3:
        raise ValueError(f"Strides must be 2 or 3 (got {n_encoder}).")

    if n_decoder < 2 or n_decoder > 3:
        raise ValueError(f"Strides must be 2 or 3 (got {n_decoder}).")

    # one encoder stride is expected per dimension of the input shape
    if len(self.input_shape) != n_encoder:
        raise ValueError(
            f"Input dimensions must be equal to the number of encoder conv strides"
            f" (got {len(self.input_shape)} and {n_encoder})."
        )

    if n_encoder < n_decoder:
        raise ValueError(
            f"Decoder can't be 3D when encoder is 2D (got"
            f" {n_encoder} and {n_decoder})."
        )

    if any(s < 1 for s in self.encoder_conv_strides) or any(
        s < 1 for s in self.decoder_conv_strides
    ):
        raise ValueError(
            f"All strides must be greater or equal to 1"
            f" (got {self.encoder_conv_strides} and {self.decoder_conv_strides})."
        )
    # TODO: validate max stride size ?
    return self
|
|
96
|
+
|
|
97
|
+
@field_validator("input_shape")
@classmethod
def validate_input_shape(cls, input_shape: list) -> list:
    """
    Validate the input shape.

    Parameters
    ----------
    input_shape : list
        Shape of the input patch.

    Returns
    -------
    list
        Validated input shape.

    Raises
    ------
    ValueError
        If the number of dimensions is not 2 or 3.
    ValueError
        If any dimension is smaller than 1.
    """
    if len(input_shape) < 2 or len(input_shape) > 3:
        raise ValueError(
            f"Number of input dimensions must be 2 for 2D data or 3 for 3D data"
            f" (got {len(input_shape)})."
        )

    if any(s < 1 for s in input_shape):
        raise ValueError(
            f"Input shape must be greater than or equal to 1 in all dimensions"
            f" (got {input_shape})."
        )
    return input_shape
|
|
38
130
|
|
|
39
131
|
@field_validator("encoder_n_filters")
|
|
40
132
|
@classmethod
|
|
@@ -124,27 +216,20 @@ class LVAEModel(ArchitectureModel):
|
|
|
124
216
|
return z_dims
|
|
125
217
|
|
|
126
218
|
@model_validator(mode="after")
def validate_multiscale_count(self: Self) -> Self:
    """
    Check that the multiscale count lies between 1 and `len(z_dims) + 1`.

    Returns
    -------
    Self
        The validated model.

    Raises
    ------
    ValueError
        If `multiscale_count` is outside that range.
    """
    upper_bound = len(self.z_dims) + 1
    if not (1 <= self.multiscale_count <= upper_bound):
        raise ValueError(
            f"Multiscale count must be 1 for LC off or less or equal to the number"
            f" of Z dims + 1 (got {self.multiscale_count} and {len(self.z_dims)})."
        )
    return self
|
|
149
234
|
|
|
150
235
|
def set_3D(self, is_3D: bool) -> None:
|
|
@@ -156,7 +241,10 @@ class LVAEModel(ArchitectureModel):
|
|
|
156
241
|
is_3D : bool
|
|
157
242
|
Whether the algorithm is 3D or not.
|
|
158
243
|
"""
|
|
159
|
-
|
|
244
|
+
if is_3D:
|
|
245
|
+
self.conv_dims = 3
|
|
246
|
+
else:
|
|
247
|
+
self.conv_dims = 2
|
|
160
248
|
|
|
161
249
|
def is_3D(self) -> bool:
|
|
162
250
|
"""
|
|
@@ -167,4 +255,4 @@ class LVAEModel(ArchitectureModel):
|
|
|
167
255
|
bool
|
|
168
256
|
Whether the model is 3D or not.
|
|
169
257
|
"""
|
|
170
|
-
|
|
258
|
+
return self.conv_dims == 3
|
|
@@ -48,6 +48,7 @@ class UNetModel(ArchitectureModel):
|
|
|
48
48
|
num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
|
|
49
49
|
"""Number of convolutional filters in the first layer of the UNet."""
|
|
50
50
|
|
|
51
|
+
# TODO we are not using this, so why make it a choice?
|
|
51
52
|
final_activation: Literal[
|
|
52
53
|
"None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
|
|
53
54
|
] = Field(default="None", validate_default=True)
|