careamics-0.0.5-py3-none-any.whl → careamics-0.0.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic (further details were linked from the original page and are not reproduced here).
- careamics/__init__.py +17 -2
- careamics/careamist.py +4 -3
- careamics/cli/conf.py +1 -2
- careamics/cli/main.py +1 -2
- careamics/cli/utils.py +3 -3
- careamics/config/__init__.py +47 -25
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +38 -0
- careamics/config/algorithms/n2n_algorithm_model.py +30 -0
- careamics/config/algorithms/n2v_algorithm_model.py +29 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +6 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +91 -186
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +1 -2
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +1 -2
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +5 -4
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/__init__.py +12 -1
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +3 -3
- careamics/lightning/lightning_module.py +11 -7
- careamics/lightning/train_data_module.py +36 -45
- careamics/losses/__init__.py +3 -3
- careamics/lvae_training/calibration.py +64 -57
- careamics/lvae_training/dataset/lc_dataset.py +2 -1
- careamics/lvae_training/dataset/multich_dataset.py +2 -2
- careamics/lvae_training/dataset/types.py +1 -1
- careamics/lvae_training/eval_utils.py +123 -128
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +1 -1
- careamics/model_io/bioimage/model_description.py +17 -17
- careamics/model_io/bmz_io.py +6 -17
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +16 -16
- careamics/models/lvae/likelihoods.py +2 -0
- careamics/models/lvae/lvae.py +13 -4
- careamics/models/lvae/noise_models.py +280 -217
- careamics/models/lvae/stochastic.py +1 -0
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/logging.py +11 -10
- careamics/utils/metrics.py +25 -0
- careamics/utils/plotting.py +78 -0
- careamics/utils/torch_utils.py +7 -7
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
- careamics-0.0.7.dist-info/RECORD +178 -0
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.5.dist-info/RECORD +0 -171
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
careamics/__init__.py
CHANGED
|
@@ -7,7 +7,22 @@ try:
|
|
|
7
7
|
except PackageNotFoundError:
|
|
8
8
|
__version__ = "uninstalled"
|
|
9
9
|
|
|
10
|
-
__all__ = [
|
|
10
|
+
__all__ = [
|
|
11
|
+
"CAREamist",
|
|
12
|
+
"Configuration",
|
|
13
|
+
"algorithm_factory",
|
|
14
|
+
"configuration_factory",
|
|
15
|
+
"data_factory",
|
|
16
|
+
"load_configuration",
|
|
17
|
+
"save_configuration",
|
|
18
|
+
]
|
|
11
19
|
|
|
12
20
|
from .careamist import CAREamist
|
|
13
|
-
from .config import
|
|
21
|
+
from .config import (
|
|
22
|
+
Configuration,
|
|
23
|
+
algorithm_factory,
|
|
24
|
+
configuration_factory,
|
|
25
|
+
data_factory,
|
|
26
|
+
load_configuration,
|
|
27
|
+
save_configuration,
|
|
28
|
+
)
|
careamics/careamist.py
CHANGED
|
@@ -13,7 +13,7 @@ from pytorch_lightning.callbacks import (
|
|
|
13
13
|
)
|
|
14
14
|
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
|
|
15
15
|
|
|
16
|
-
from careamics.config import Configuration,
|
|
16
|
+
from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration
|
|
17
17
|
from careamics.config.support import (
|
|
18
18
|
SupportedAlgorithm,
|
|
19
19
|
SupportedArchitecture,
|
|
@@ -137,7 +137,7 @@ class CAREamist:
|
|
|
137
137
|
self.cfg = source
|
|
138
138
|
|
|
139
139
|
# instantiate model
|
|
140
|
-
if isinstance(self.cfg.algorithm_config,
|
|
140
|
+
if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
|
|
141
141
|
self.model = FCNModule(
|
|
142
142
|
algorithm_config=self.cfg.algorithm_config,
|
|
143
143
|
)
|
|
@@ -157,7 +157,8 @@ class CAREamist:
|
|
|
157
157
|
self.cfg = load_configuration(source)
|
|
158
158
|
|
|
159
159
|
# instantiate model
|
|
160
|
-
|
|
160
|
+
# TODO call model factory here
|
|
161
|
+
if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
|
|
161
162
|
self.model = FCNModule(
|
|
162
163
|
algorithm_config=self.cfg.algorithm_config,
|
|
163
164
|
) # type: ignore
|
careamics/cli/conf.py
CHANGED
|
@@ -3,12 +3,11 @@
|
|
|
3
3
|
import sys
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Optional
|
|
6
|
+
from typing import Annotated, Optional
|
|
7
7
|
|
|
8
8
|
import click
|
|
9
9
|
import typer
|
|
10
10
|
import yaml
|
|
11
|
-
from typing_extensions import Annotated
|
|
12
11
|
|
|
13
12
|
from ..config import (
|
|
14
13
|
Configuration,
|
careamics/cli/main.py
CHANGED
|
@@ -7,11 +7,10 @@ its implementation is contained in the conf.py file.
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from pathlib import Path
|
|
10
|
-
from typing import Optional
|
|
10
|
+
from typing import Annotated, Optional
|
|
11
11
|
|
|
12
12
|
import click
|
|
13
13
|
import typer
|
|
14
|
-
from typing_extensions import Annotated
|
|
15
14
|
|
|
16
15
|
from ..careamist import CAREamist
|
|
17
16
|
from . import conf
|
careamics/cli/utils.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
"""Utility functions for the CAREamics CLI."""
|
|
2
2
|
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Optional
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def handle_2D_3D_callback(
|
|
7
|
-
value: Optional[
|
|
8
|
-
) -> Optional[
|
|
7
|
+
value: Optional[tuple[int, int, int]]
|
|
8
|
+
) -> Optional[tuple[int, ...]]:
|
|
9
9
|
"""
|
|
10
10
|
Callback for options that require 2D or 3D inputs.
|
|
11
11
|
|
careamics/config/__init__.py
CHANGED
|
@@ -1,41 +1,63 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""CAREamics Pydantic configuration models.
|
|
2
|
+
|
|
3
|
+
To maintain clarity at the module level, we follow the following naming conventions:
|
|
4
|
+
`*_model` is specific for sub-configurations (e.g. architecture, data, algorithm),
|
|
5
|
+
while `*_configuration` is reserved for the main configuration models, including the
|
|
6
|
+
`Configuration` base class and its algorithm-specific child classes.
|
|
7
|
+
"""
|
|
2
8
|
|
|
3
9
|
__all__ = [
|
|
4
|
-
"
|
|
5
|
-
"
|
|
6
|
-
"DataConfig",
|
|
7
|
-
"Configuration",
|
|
10
|
+
"CAREAlgorithm",
|
|
11
|
+
"CAREConfiguration",
|
|
8
12
|
"CheckpointModel",
|
|
13
|
+
"Configuration",
|
|
14
|
+
"DataConfig",
|
|
15
|
+
"GaussianMixtureNMConfig",
|
|
16
|
+
"GeneralDataConfig",
|
|
9
17
|
"InferenceConfig",
|
|
10
|
-
"
|
|
11
|
-
"
|
|
18
|
+
"LVAELossConfig",
|
|
19
|
+
"MultiChannelNMConfig",
|
|
20
|
+
"N2NAlgorithm",
|
|
21
|
+
"N2NConfiguration",
|
|
22
|
+
"N2VAlgorithm",
|
|
23
|
+
"N2VConfiguration",
|
|
24
|
+
"N2VDataConfig",
|
|
12
25
|
"TrainingConfig",
|
|
13
|
-
"
|
|
14
|
-
"
|
|
26
|
+
"UNetBasedAlgorithm",
|
|
27
|
+
"VAEBasedAlgorithm",
|
|
28
|
+
"algorithm_factory",
|
|
29
|
+
"configuration_factory",
|
|
15
30
|
"create_care_configuration",
|
|
16
|
-
"
|
|
17
|
-
"
|
|
18
|
-
"
|
|
19
|
-
"
|
|
20
|
-
"
|
|
21
|
-
"LVAELossConfig",
|
|
31
|
+
"create_n2n_configuration",
|
|
32
|
+
"create_n2v_configuration",
|
|
33
|
+
"data_factory",
|
|
34
|
+
"load_configuration",
|
|
35
|
+
"save_configuration",
|
|
22
36
|
]
|
|
23
|
-
|
|
37
|
+
|
|
38
|
+
from .algorithms import (
|
|
39
|
+
CAREAlgorithm,
|
|
40
|
+
N2NAlgorithm,
|
|
41
|
+
N2VAlgorithm,
|
|
42
|
+
UNetBasedAlgorithm,
|
|
43
|
+
VAEBasedAlgorithm,
|
|
44
|
+
)
|
|
24
45
|
from .callback_model import CheckpointModel
|
|
25
|
-
from .
|
|
46
|
+
from .care_configuration import CAREConfiguration
|
|
47
|
+
from .configuration import Configuration
|
|
48
|
+
from .configuration_factories import (
|
|
49
|
+
algorithm_factory,
|
|
50
|
+
configuration_factory,
|
|
26
51
|
create_care_configuration,
|
|
27
52
|
create_n2n_configuration,
|
|
28
53
|
create_n2v_configuration,
|
|
54
|
+
data_factory,
|
|
29
55
|
)
|
|
30
|
-
from .
|
|
31
|
-
|
|
32
|
-
load_configuration,
|
|
33
|
-
save_configuration,
|
|
34
|
-
)
|
|
35
|
-
from .data_model import DataConfig
|
|
36
|
-
from .fcn_algorithm_model import FCNAlgorithmConfig
|
|
56
|
+
from .configuration_io import load_configuration, save_configuration
|
|
57
|
+
from .data import DataConfig, GeneralDataConfig, N2VDataConfig
|
|
37
58
|
from .inference_model import InferenceConfig
|
|
38
59
|
from .loss_model import LVAELossConfig
|
|
60
|
+
from .n2n_configuration import N2NConfiguration
|
|
61
|
+
from .n2v_configuration import N2VConfiguration
|
|
39
62
|
from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
40
63
|
from .training_model import TrainingConfig
|
|
41
|
-
from .vae_algorithm_model import VAEAlgorithmConfig
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Algorithm configurations."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"CAREAlgorithm",
|
|
5
|
+
"N2NAlgorithm",
|
|
6
|
+
"N2VAlgorithm",
|
|
7
|
+
"UNetBasedAlgorithm",
|
|
8
|
+
"VAEBasedAlgorithm",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
from .care_algorithm_model import CAREAlgorithm
|
|
12
|
+
from .n2n_algorithm_model import N2NAlgorithm
|
|
13
|
+
from .n2v_algorithm_model import N2VAlgorithm
|
|
14
|
+
from .unet_algorithm_model import UNetBasedAlgorithm
|
|
15
|
+
from .vae_algorithm_model import VAEBasedAlgorithm
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""CARE algorithm configuration."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import AfterValidator
|
|
6
|
+
|
|
7
|
+
from careamics.config.architectures import UNetModel
|
|
8
|
+
from careamics.config.validators import (
|
|
9
|
+
model_without_final_activation,
|
|
10
|
+
model_without_n2v2,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from .unet_algorithm_model import UNetBasedAlgorithm
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CAREAlgorithm(UNetBasedAlgorithm):
|
|
17
|
+
"""CARE algorithm configuration.
|
|
18
|
+
|
|
19
|
+
Attributes
|
|
20
|
+
----------
|
|
21
|
+
algorithm : "care"
|
|
22
|
+
CARE Algorithm name.
|
|
23
|
+
loss : {"mae", "mse"}
|
|
24
|
+
CARE-compatible loss function.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
algorithm: Literal["care"] = "care"
|
|
28
|
+
"""CARE Algorithm name."""
|
|
29
|
+
|
|
30
|
+
loss: Literal["mae", "mse"] = "mae"
|
|
31
|
+
"""CARE-compatible loss function."""
|
|
32
|
+
|
|
33
|
+
model: Annotated[
|
|
34
|
+
UNetModel,
|
|
35
|
+
AfterValidator(model_without_n2v2),
|
|
36
|
+
AfterValidator(model_without_final_activation),
|
|
37
|
+
]
|
|
38
|
+
"""UNet without a final activation function and without the `n2v2` modifications."""
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""N2N Algorithm configuration."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import AfterValidator
|
|
6
|
+
|
|
7
|
+
from careamics.config.architectures import UNetModel
|
|
8
|
+
from careamics.config.validators import (
|
|
9
|
+
model_without_final_activation,
|
|
10
|
+
model_without_n2v2,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from .unet_algorithm_model import UNetBasedAlgorithm
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class N2NAlgorithm(UNetBasedAlgorithm):
|
|
17
|
+
"""Noise2Noise Algorithm configuration."""
|
|
18
|
+
|
|
19
|
+
algorithm: Literal["n2n"] = "n2n"
|
|
20
|
+
"""N2N Algorithm name."""
|
|
21
|
+
|
|
22
|
+
loss: Literal["mae", "mse"] = "mae"
|
|
23
|
+
"""N2N-compatible loss function."""
|
|
24
|
+
|
|
25
|
+
model: Annotated[
|
|
26
|
+
UNetModel,
|
|
27
|
+
AfterValidator(model_without_n2v2),
|
|
28
|
+
AfterValidator(model_without_final_activation),
|
|
29
|
+
]
|
|
30
|
+
"""UNet without a final activation function and without the `n2v2` modifications."""
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
""""N2V Algorithm configuration."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import AfterValidator
|
|
6
|
+
|
|
7
|
+
from careamics.config.architectures import UNetModel
|
|
8
|
+
from careamics.config.validators import (
|
|
9
|
+
model_matching_in_out_channels,
|
|
10
|
+
model_without_final_activation,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from .unet_algorithm_model import UNetBasedAlgorithm
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class N2VAlgorithm(UNetBasedAlgorithm):
|
|
17
|
+
"""N2V Algorithm configuration."""
|
|
18
|
+
|
|
19
|
+
algorithm: Literal["n2v"] = "n2v"
|
|
20
|
+
"""N2V Algorithm name."""
|
|
21
|
+
|
|
22
|
+
loss: Literal["n2v"] = "n2v"
|
|
23
|
+
"""N2V loss function."""
|
|
24
|
+
|
|
25
|
+
model: Annotated[
|
|
26
|
+
UNetModel,
|
|
27
|
+
AfterValidator(model_matching_in_out_channels),
|
|
28
|
+
AfterValidator(model_without_final_activation),
|
|
29
|
+
]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""UNet-based algorithm Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from pprint import pformat
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
|
7
|
+
|
|
8
|
+
from careamics.config.architectures import UNetModel
|
|
9
|
+
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class UNetBasedAlgorithm(BaseModel):
|
|
13
|
+
"""General UNet-based algorithm configuration.
|
|
14
|
+
|
|
15
|
+
This Pydantic model validates the parameters governing the components of the
|
|
16
|
+
training algorithm: which algorithm, loss function, model architecture, optimizer,
|
|
17
|
+
and learning rate scheduler to use.
|
|
18
|
+
|
|
19
|
+
Currently, we only support N2V, CARE, and N2N algorithms. In order to train these
|
|
20
|
+
algorithms, use the corresponding configuration child classes (e.g.
|
|
21
|
+
`N2VAlgorithm`) to ensure coherent parameters (e.g. specific losses).
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
Attributes
|
|
25
|
+
----------
|
|
26
|
+
algorithm : {"n2v", "care", "n2n"}
|
|
27
|
+
Algorithm to use.
|
|
28
|
+
loss : {"n2v", "mae", "mse"}
|
|
29
|
+
Loss function to use.
|
|
30
|
+
model : UNetModel
|
|
31
|
+
Model architecture to use.
|
|
32
|
+
optimizer : OptimizerModel, optional
|
|
33
|
+
Optimizer to use.
|
|
34
|
+
lr_scheduler : LrSchedulerModel, optional
|
|
35
|
+
Learning rate scheduler to use.
|
|
36
|
+
|
|
37
|
+
Raises
|
|
38
|
+
------
|
|
39
|
+
ValueError
|
|
40
|
+
Algorithm parameter type validation errors.
|
|
41
|
+
ValueError
|
|
42
|
+
If the algorithm, loss and model are not compatible.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
# Pydantic class configuration
|
|
46
|
+
model_config = ConfigDict(
|
|
47
|
+
protected_namespaces=(), # allows to use model_* as a field name
|
|
48
|
+
validate_assignment=True,
|
|
49
|
+
extra="allow",
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
# Mandatory fields
|
|
53
|
+
algorithm: Literal["n2v", "care", "n2n"]
|
|
54
|
+
"""Algorithm name, as defined in SupportedAlgorithm."""
|
|
55
|
+
|
|
56
|
+
loss: Literal["n2v", "mae", "mse"]
|
|
57
|
+
"""Loss function to use, as defined in SupportedLoss."""
|
|
58
|
+
|
|
59
|
+
model: UNetModel
|
|
60
|
+
"""UNet model configuration."""
|
|
61
|
+
|
|
62
|
+
# Optional fields
|
|
63
|
+
optimizer: OptimizerModel = OptimizerModel()
|
|
64
|
+
"""Optimizer to use, defined in SupportedOptimizer."""
|
|
65
|
+
|
|
66
|
+
lr_scheduler: LrSchedulerModel = LrSchedulerModel()
|
|
67
|
+
"""Learning rate scheduler to use, defined in SupportedLrScheduler."""
|
|
68
|
+
|
|
69
|
+
def __str__(self) -> str:
|
|
70
|
+
"""Pretty string representing the configuration.
|
|
71
|
+
|
|
72
|
+
Returns
|
|
73
|
+
-------
|
|
74
|
+
str
|
|
75
|
+
Pretty string.
|
|
76
|
+
"""
|
|
77
|
+
return pformat(self.model_dump())
|
|
78
|
+
|
|
79
|
+
@classmethod
|
|
80
|
+
def get_compatible_algorithms(cls) -> list[str]:
|
|
81
|
+
"""Get the list of compatible algorithms.
|
|
82
|
+
|
|
83
|
+
Returns
|
|
84
|
+
-------
|
|
85
|
+
list of str
|
|
86
|
+
List of compatible algorithms.
|
|
87
|
+
"""
|
|
88
|
+
return ["n2v", "care", "n2n"]
|
|
@@ -1,24 +1,26 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""VAE-based algorithm Pydantic model."""
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from pprint import pformat
|
|
6
|
-
from typing import Literal, Optional
|
|
6
|
+
from typing import Literal, Optional
|
|
7
7
|
|
|
8
|
-
from pydantic import BaseModel, ConfigDict,
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, model_validator
|
|
9
9
|
from typing_extensions import Self
|
|
10
10
|
|
|
11
|
+
from careamics.config.architectures import LVAEModel
|
|
12
|
+
from careamics.config.likelihood_model import (
|
|
13
|
+
GaussianLikelihoodConfig,
|
|
14
|
+
NMLikelihoodConfig,
|
|
15
|
+
)
|
|
16
|
+
from careamics.config.loss_model import LVAELossConfig
|
|
17
|
+
from careamics.config.nm_model import MultiChannelNMConfig
|
|
18
|
+
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
|
|
11
19
|
from careamics.config.support import SupportedAlgorithm, SupportedLoss
|
|
12
20
|
|
|
13
|
-
from .architectures import CustomModel, LVAEModel
|
|
14
|
-
from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
|
|
15
|
-
from .loss_model import LVAELossConfig
|
|
16
|
-
from .nm_model import MultiChannelNMConfig
|
|
17
|
-
from .optimizer_models import LrSchedulerModel, OptimizerModel
|
|
18
21
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
"""Algorithm configuration.
|
|
22
|
+
class VAEBasedAlgorithm(BaseModel):
|
|
23
|
+
"""VAE-based algorithm configuration.
|
|
22
24
|
|
|
23
25
|
# TODO
|
|
24
26
|
|
|
@@ -42,7 +44,7 @@ class VAEAlgorithmConfig(BaseModel):
|
|
|
42
44
|
|
|
43
45
|
# NOTE: these are all configs (pydantic models)
|
|
44
46
|
loss: LVAELossConfig
|
|
45
|
-
model:
|
|
47
|
+
model: LVAEModel
|
|
46
48
|
noise_model: Optional[MultiChannelNMConfig] = None
|
|
47
49
|
noise_model_likelihood: Optional[NMLikelihoodConfig] = None
|
|
48
50
|
gaussian_likelihood: Optional[GaussianLikelihoodConfig] = None
|
|
@@ -1,17 +1,7 @@
|
|
|
1
1
|
"""Deep-learning model configurations."""
|
|
2
2
|
|
|
3
|
-
__all__ = [
|
|
4
|
-
"ArchitectureModel",
|
|
5
|
-
"CustomModel",
|
|
6
|
-
"UNetModel",
|
|
7
|
-
"LVAEModel",
|
|
8
|
-
"clear_custom_models",
|
|
9
|
-
"get_custom_model",
|
|
10
|
-
"register_model",
|
|
11
|
-
]
|
|
3
|
+
__all__ = ["ArchitectureModel", "LVAEModel", "UNetModel"]
|
|
12
4
|
|
|
13
5
|
from .architecture_model import ArchitectureModel
|
|
14
|
-
from .custom_model import CustomModel
|
|
15
6
|
from .lvae_model import LVAEModel
|
|
16
|
-
from .register_model import clear_custom_models, get_custom_model, register_model
|
|
17
7
|
from .unet_model import UNetModel
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Base model for the various CAREamics architectures."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
from pydantic import BaseModel
|
|
6
6
|
|
|
@@ -15,7 +15,7 @@ class ArchitectureModel(BaseModel):
|
|
|
15
15
|
architecture: str
|
|
16
16
|
"""Name of the architecture."""
|
|
17
17
|
|
|
18
|
-
def model_dump(self, **kwargs: Any) ->
|
|
18
|
+
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
|
19
19
|
"""
|
|
20
20
|
Dump the model as a dictionary, ignoring the architecture keyword.
|
|
21
21
|
|
|
@@ -26,7 +26,7 @@ class ArchitectureModel(BaseModel):
|
|
|
26
26
|
|
|
27
27
|
Returns
|
|
28
28
|
-------
|
|
29
|
-
|
|
29
|
+
{str: Any}
|
|
30
30
|
Model as a dictionary.
|
|
31
31
|
"""
|
|
32
32
|
model_dict = super().model_dump(**kwargs)
|
|
@@ -15,12 +15,17 @@ class LVAEModel(ArchitectureModel):
|
|
|
15
15
|
model_config = ConfigDict(validate_assignment=True, validate_default=True)
|
|
16
16
|
|
|
17
17
|
architecture: Literal["LVAE"]
|
|
18
|
-
|
|
18
|
+
"""Name of the architecture."""
|
|
19
|
+
|
|
20
|
+
input_shape: list[int] = Field(default=[64, 64], validate_default=True)
|
|
19
21
|
"""Shape of the input patch (C, Z, Y, X) or (C, Y, X) if the data is 2D."""
|
|
22
|
+
|
|
20
23
|
encoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
|
|
24
|
+
|
|
21
25
|
# TODO make this per hierarchy step ?
|
|
22
26
|
decoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
|
|
23
27
|
"""Dimensions (2D or 3D) of the convolutional layers."""
|
|
28
|
+
|
|
24
29
|
multiscale_count: int = Field(default=1)
|
|
25
30
|
# TODO there should be a check for multiscale_count in dataset !!
|
|
26
31
|
|
|
@@ -48,6 +48,7 @@ class UNetModel(ArchitectureModel):
|
|
|
48
48
|
num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
|
|
49
49
|
"""Number of convolutional filters in the first layer of the UNet."""
|
|
50
50
|
|
|
51
|
+
# TODO we are not using this, so why make it a choice?
|
|
51
52
|
final_activation: Literal[
|
|
52
53
|
"None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
|
|
53
54
|
] = Field(default="None", validate_default=True)
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""CARE Pydantic configuration."""
|
|
2
|
+
|
|
3
|
+
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
4
|
+
|
|
5
|
+
from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm
|
|
6
|
+
from careamics.config.configuration import Configuration
|
|
7
|
+
from careamics.config.data import DataConfig
|
|
8
|
+
|
|
9
|
+
CARE = "CARE"
|
|
10
|
+
|
|
11
|
+
CARE_DESCRIPTION = (
|
|
12
|
+
"Content-aware image restoration (CARE) is a deep-learning-based "
|
|
13
|
+
"algorithm that uses a U-Net architecture to restore images. CARE "
|
|
14
|
+
"is a supervised algorithm that requires pairs of noisy and "
|
|
15
|
+
"clean images to train the network. The algorithm learns to "
|
|
16
|
+
"predict the clean image from the noisy image. CARE is "
|
|
17
|
+
"particularly useful for denoising images acquired in low-light "
|
|
18
|
+
"conditions, such as fluorescence microscopy images."
|
|
19
|
+
)
|
|
20
|
+
CARE_REF = CiteEntry(
|
|
21
|
+
text='Weigert, Martin, et al. "Content-aware image restoration: pushing the '
|
|
22
|
+
'limits of fluorescence microscopy." Nature methods 15.12 (2018): 1090-1097.',
|
|
23
|
+
doi="10.1038/s41592-018-0216-7",
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class CAREConfiguration(Configuration):
|
|
28
|
+
"""CARE configuration."""
|
|
29
|
+
|
|
30
|
+
algorithm_config: CAREAlgorithm
|
|
31
|
+
"""Algorithm configuration."""
|
|
32
|
+
|
|
33
|
+
data_config: DataConfig
|
|
34
|
+
"""Data configuration."""
|
|
35
|
+
|
|
36
|
+
def get_algorithm_friendly_name(self) -> str:
|
|
37
|
+
"""
|
|
38
|
+
Get the algorithm friendly name.
|
|
39
|
+
|
|
40
|
+
Returns
|
|
41
|
+
-------
|
|
42
|
+
str
|
|
43
|
+
Friendly name of the algorithm.
|
|
44
|
+
"""
|
|
45
|
+
return CARE
|
|
46
|
+
|
|
47
|
+
def get_algorithm_keywords(self) -> list[str]:
|
|
48
|
+
"""
|
|
49
|
+
Get algorithm keywords.
|
|
50
|
+
|
|
51
|
+
Returns
|
|
52
|
+
-------
|
|
53
|
+
list[str]
|
|
54
|
+
List of keywords.
|
|
55
|
+
"""
|
|
56
|
+
return [
|
|
57
|
+
"restoration",
|
|
58
|
+
"UNet",
|
|
59
|
+
"3D" if "Z" in self.data_config.axes else "2D",
|
|
60
|
+
"CAREamics",
|
|
61
|
+
"pytorch",
|
|
62
|
+
CARE,
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
def get_algorithm_references(self) -> str:
|
|
66
|
+
"""
|
|
67
|
+
Get the algorithm references.
|
|
68
|
+
|
|
69
|
+
This is used to generate the README of the BioImage Model Zoo export.
|
|
70
|
+
|
|
71
|
+
Returns
|
|
72
|
+
-------
|
|
73
|
+
str
|
|
74
|
+
Algorithm references.
|
|
75
|
+
"""
|
|
76
|
+
return CARE_REF.text + " doi: " + CARE_REF.doi
|
|
77
|
+
|
|
78
|
+
def get_algorithm_citations(self) -> list[CiteEntry]:
|
|
79
|
+
"""
|
|
80
|
+
Return a list of citation entries of the current algorithm.
|
|
81
|
+
|
|
82
|
+
This is used to generate the model description for the BioImage Model Zoo.
|
|
83
|
+
|
|
84
|
+
Returns
|
|
85
|
+
-------
|
|
86
|
+
List[CiteEntry]
|
|
87
|
+
List of citation entries.
|
|
88
|
+
"""
|
|
89
|
+
return [CARE_REF]
|
|
90
|
+
|
|
91
|
+
def get_algorithm_description(self) -> str:
|
|
92
|
+
"""
|
|
93
|
+
Get the algorithm description.
|
|
94
|
+
|
|
95
|
+
Returns
|
|
96
|
+
-------
|
|
97
|
+
str
|
|
98
|
+
Algorithm description.
|
|
99
|
+
"""
|
|
100
|
+
return CARE_DESCRIPTION
|