careamics 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (98)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +103 -36
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +0 -3
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/validator_utils.py +3 -3
  35. careamics/dataset/__init__.py +2 -2
  36. careamics/dataset/dataset_utils/__init__.py +3 -3
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  38. careamics/dataset/dataset_utils/file_utils.py +9 -9
  39. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  40. careamics/dataset/in_memory_dataset.py +11 -12
  41. careamics/dataset/iterable_dataset.py +4 -4
  42. careamics/dataset/iterable_pred_dataset.py +2 -1
  43. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  44. careamics/dataset/patching/random_patching.py +11 -10
  45. careamics/dataset/patching/sequential_patching.py +26 -26
  46. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  47. careamics/dataset/tiling/__init__.py +2 -2
  48. careamics/dataset/tiling/collate_tiles.py +3 -3
  49. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  50. careamics/dataset/tiling/tiled_patching.py +11 -10
  51. careamics/file_io/__init__.py +5 -5
  52. careamics/file_io/read/__init__.py +1 -1
  53. careamics/file_io/read/get_func.py +2 -2
  54. careamics/file_io/write/__init__.py +2 -2
  55. careamics/lightning/__init__.py +5 -5
  56. careamics/lightning/callbacks/__init__.py +1 -1
  57. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  58. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  60. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  61. careamics/lightning/lightning_module.py +11 -7
  62. careamics/lightning/train_data_module.py +26 -26
  63. careamics/losses/__init__.py +3 -3
  64. careamics/model_io/__init__.py +1 -1
  65. careamics/model_io/bioimage/__init__.py +1 -1
  66. careamics/model_io/bioimage/_readme_factory.py +1 -1
  67. careamics/model_io/bioimage/model_description.py +17 -17
  68. careamics/model_io/bmz_io.py +6 -17
  69. careamics/model_io/model_io_utils.py +9 -9
  70. careamics/models/layers.py +16 -16
  71. careamics/models/lvae/lvae.py +0 -3
  72. careamics/models/model_factory.py +2 -15
  73. careamics/models/unet.py +8 -8
  74. careamics/prediction_utils/__init__.py +1 -1
  75. careamics/prediction_utils/prediction_outputs.py +15 -15
  76. careamics/prediction_utils/stitch_prediction.py +6 -6
  77. careamics/transforms/__init__.py +5 -5
  78. careamics/transforms/compose.py +13 -13
  79. careamics/transforms/n2v_manipulate.py +3 -3
  80. careamics/transforms/pixel_manipulation.py +9 -9
  81. careamics/transforms/xy_random_rotate90.py +4 -4
  82. careamics/utils/__init__.py +5 -5
  83. careamics/utils/context.py +2 -1
  84. careamics/utils/logging.py +11 -10
  85. careamics/utils/torch_utils.py +7 -7
  86. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/METADATA +11 -11
  87. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/RECORD +90 -85
  88. careamics/config/architectures/custom_model.py +0 -162
  89. careamics/config/architectures/register_model.py +0 -103
  90. careamics/config/configuration_model.py +0 -603
  91. careamics/config/fcn_algorithm_model.py +0 -152
  92. careamics/config/references/__init__.py +0 -45
  93. careamics/config/references/algorithm_descriptions.py +0 -132
  94. careamics/config/references/references.py +0 -39
  95. careamics/config/transformations/transform_union.py +0 -20
  96. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/__init__.py CHANGED
@@ -7,7 +7,22 @@ try:
 except PackageNotFoundError:
     __version__ = "uninstalled"
 
-__all__ = ["CAREamist", "Configuration", "load_configuration", "save_configuration"]
+__all__ = [
+    "CAREamist",
+    "Configuration",
+    "algorithm_factory",
+    "configuration_factory",
+    "data_factory",
+    "load_configuration",
+    "save_configuration",
+]
 
 from .careamist import CAREamist
-from .config import Configuration, load_configuration, save_configuration
+from .config import (
+    Configuration,
+    algorithm_factory,
+    configuration_factory,
+    data_factory,
+    load_configuration,
+    save_configuration,
+)
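For reference, the top-level package now re-exports the factory helpers alongside the existing configuration I/O. A minimal usage sketch, assuming `configuration_factory` accepts a plain configuration dictionary and returns the matching algorithm-specific `Configuration` subclass (its exact signature is not shown in this diff):

    from careamics import configuration_factory, load_configuration

    cfg = load_configuration("config.yml")  # path-based loading, unchanged from 0.0.5
    cfg_again = configuration_factory(cfg.model_dump())  # assumed: dict -> algorithm-specific configuration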
careamics/careamist.py CHANGED
@@ -13,7 +13,7 @@ from pytorch_lightning.callbacks import (
 )
 from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
 
-from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
+from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -137,7 +137,7 @@ class CAREamist:
            self.cfg = source

            # instantiate model
-            if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+            if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
                self.model = FCNModule(
                    algorithm_config=self.cfg.algorithm_config,
                )
@@ -157,7 +157,8 @@ class CAREamist:
            self.cfg = load_configuration(source)

            # instantiate model
-            if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+            # TODO call model factory here
+            if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
                self.model = FCNModule(
                    algorithm_config=self.cfg.algorithm_config,
                )  # type: ignore
careamics/cli/conf.py CHANGED
@@ -3,12 +3,11 @@
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional
+from typing import Annotated, Optional
 
 import click
 import typer
 import yaml
-from typing_extensions import Annotated
 
 from ..config import (
     Configuration,
careamics/cli/main.py CHANGED
@@ -7,11 +7,10 @@ its implementation is contained in the conf.py file.
 """
 
 from pathlib import Path
-from typing import Optional
+from typing import Annotated, Optional
 
 import click
 import typer
-from typing_extensions import Annotated
 
 from ..careamist import CAREamist
 from . import conf
careamics/cli/utils.py CHANGED
@@ -1,11 +1,11 @@
 """Utility functions for the CAREamics CLI."""
 
-from typing import Optional, Tuple
+from typing import Optional
 
 
 def handle_2D_3D_callback(
-    value: Optional[Tuple[int, int, int]]
-) -> Optional[Tuple[int, ...]]:
+    value: Optional[tuple[int, int, int]]
+) -> Optional[tuple[int, ...]]:
     """
     Callback for options that require 2D or 3D inputs.
 
careamics/config/__init__.py CHANGED
@@ -1,41 +1,63 @@
-"""Configuration module."""
+"""CAREamics Pydantic configuration models.
+
+To maintain clarity at the module level, we follow the following naming conventions:
+`*_model` is specific for sub-configurations (e.g. architecture, data, algorithm),
+while `*_configuration` is reserved for the main configuration models, including the
+`Configuration` base class and its algorithm-specific child classes.
+"""
 
 __all__ = [
-    "FCNAlgorithmConfig",
-    "VAEAlgorithmConfig",
-    "DataConfig",
-    "Configuration",
+    "CAREAlgorithm",
+    "CAREConfiguration",
     "CheckpointModel",
+    "Configuration",
+    "DataConfig",
+    "GaussianMixtureNMConfig",
+    "GeneralDataConfig",
     "InferenceConfig",
-    "load_configuration",
-    "save_configuration",
+    "LVAELossConfig",
+    "MultiChannelNMConfig",
+    "N2NAlgorithm",
+    "N2NConfiguration",
+    "N2VAlgorithm",
+    "N2VConfiguration",
+    "N2VDataConfig",
     "TrainingConfig",
-    "create_n2v_configuration",
-    "create_n2n_configuration",
+    "UNetBasedAlgorithm",
+    "VAEBasedAlgorithm",
+    "algorithm_factory",
+    "configuration_factory",
     "create_care_configuration",
-    "register_model",
-    "CustomModel",
-    "clear_custom_models",
-    "GaussianMixtureNMConfig",
-    "MultiChannelNMConfig",
-    "LVAELossConfig",
+    "create_n2n_configuration",
+    "create_n2v_configuration",
+    "data_factory",
+    "load_configuration",
+    "save_configuration",
 ]
-from .architectures import CustomModel, clear_custom_models, register_model
+
+from .algorithms import (
+    CAREAlgorithm,
+    N2NAlgorithm,
+    N2VAlgorithm,
+    UNetBasedAlgorithm,
+    VAEBasedAlgorithm,
+)
 from .callback_model import CheckpointModel
-from .configuration_factory import (
+from .care_configuration import CAREConfiguration
+from .configuration import Configuration
+from .configuration_factories import (
+    algorithm_factory,
+    configuration_factory,
     create_care_configuration,
     create_n2n_configuration,
     create_n2v_configuration,
+    data_factory,
 )
-from .configuration_model import (
-    Configuration,
-    load_configuration,
-    save_configuration,
-)
-from .data_model import DataConfig
-from .fcn_algorithm_model import FCNAlgorithmConfig
+from .configuration_io import load_configuration, save_configuration
+from .data import DataConfig, GeneralDataConfig, N2VDataConfig
 from .inference_model import InferenceConfig
 from .loss_model import LVAELossConfig
+from .n2n_configuration import N2NConfiguration
+from .n2v_configuration import N2VConfiguration
 from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
 from .training_model import TrainingConfig
-from .vae_algorithm_model import VAEAlgorithmConfig
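Downstream code that imported the removed names must migrate to the new ones. A sketch of the mapping implied by this diff (the 0.0.5 imports are shown commented out because they no longer exist):

    # 0.0.5 (removed)
    # from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig
    # from careamics.config.architectures import CustomModel, register_model

    # 0.0.6
    from careamics.config import (
        N2VConfiguration,       # new algorithm-specific configuration classes
        UNetBasedAlgorithm,     # replaces FCNAlgorithmConfig (see careamist.py above)
        VAEBasedAlgorithm,      # renamed from VAEAlgorithmConfig
        algorithm_factory,      # new factory helpers
        configuration_factory,
        data_factory,
    )

Note that `CustomModel` and `register_model` are removed in this diff with no replacement shown.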
careamics/config/algorithms/__init__.py ADDED
@@ -0,0 +1,15 @@
+"""Algorithm configurations."""
+
+__all__ = [
+    "CAREAlgorithm",
+    "N2NAlgorithm",
+    "N2VAlgorithm",
+    "UNetBasedAlgorithm",
+    "VAEBasedAlgorithm",
+]
+
+from .care_algorithm_model import CAREAlgorithm
+from .n2n_algorithm_model import N2NAlgorithm
+from .n2v_algorithm_model import N2VAlgorithm
+from .unet_algorithm_model import UNetBasedAlgorithm
+from .vae_algorithm_model import VAEBasedAlgorithm
careamics/config/algorithms/care_algorithm_model.py ADDED
@@ -0,0 +1,50 @@
+"""CARE algorithm configuration."""
+
+from typing import Literal
+
+from pydantic import field_validator
+
+from careamics.config.architectures import UNetModel
+
+from .unet_algorithm_model import UNetBasedAlgorithm
+
+
+class CAREAlgorithm(UNetBasedAlgorithm):
+    """CARE algorithm configuration.
+
+    Attributes
+    ----------
+    algorithm : "care"
+        CARE Algorithm name.
+    loss : {"mae", "mse"}
+        CARE-compatible loss function.
+    """
+
+    algorithm: Literal["care"] = "care"
+    """CARE Algorithm name."""
+
+    loss: Literal["mae", "mse"] = "mae"
+    """CARE-compatible loss function."""
+
+    @classmethod
+    @field_validator("model")
+    def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
+        """Validate that the model does not have the n2v2 attribute.
+
+        Parameters
+        ----------
+        value : UNetModel
+            Model to validate.
+
+        Returns
+        -------
+        UNetModel
+            The validated model.
+        """
+        if value.n2v2:
+            raise ValueError(
+                "The N2N algorithm does not support the `n2v2` attribute. "
+                "Set it to `False`."
+            )
+
+        return value
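A minimal instantiation sketch for the new class. The `UNetModel` arguments are assumptions for illustration (only `n2v2`, `in_channels` and `num_classes` appear elsewhere in this diff, and the remaining fields are assumed to have defaults):

    from careamics.config.algorithms import CAREAlgorithm
    from careamics.config.architectures import UNetModel

    model = UNetModel(architecture="UNet")  # assumed literal value and default fields
    algo = CAREAlgorithm(model=model)       # algorithm="care", loss="mae" by default
    print(algo)                             # __str__ pretty-prints the dumped configuration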
careamics/config/algorithms/n2n_algorithm_model.py ADDED
@@ -0,0 +1,42 @@
+"""N2N Algorithm configuration."""
+
+from typing import Literal
+
+from pydantic import field_validator
+
+from careamics.config.architectures import UNetModel
+
+from .unet_algorithm_model import UNetBasedAlgorithm
+
+
+class N2NAlgorithm(UNetBasedAlgorithm):
+    """N2N Algorithm configuration."""
+
+    algorithm: Literal["n2n"] = "n2n"
+    """N2N Algorithm name."""
+
+    loss: Literal["mae", "mse"] = "mae"
+    """N2N-compatible loss function."""
+
+    @classmethod
+    @field_validator("model")
+    def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
+        """Validate that the model does not have the n2v2 attribute.
+
+        Parameters
+        ----------
+        value : UNetModel
+            Model to validate.
+
+        Returns
+        -------
+        UNetModel
+            The validated model.
+        """
+        if value.n2v2:
+            raise ValueError(
+                "The N2N algorithm does not support the `n2v2` attribute. "
+                "Set it to `False`."
+            )
+
+        return value
careamics/config/algorithms/n2v_algorithm_model.py ADDED
@@ -0,0 +1,35 @@
+""""N2V Algorithm configuration."""
+
+from typing import Literal
+
+from pydantic import model_validator
+from typing_extensions import Self
+
+from .unet_algorithm_model import UNetBasedAlgorithm
+
+
+class N2VAlgorithm(UNetBasedAlgorithm):
+    """N2V Algorithm configuration."""
+
+    algorithm: Literal["n2v"] = "n2v"
+    """N2V Algorithm name."""
+
+    loss: Literal["n2v"] = "n2v"
+    """N2V loss function."""
+
+    @model_validator(mode="after")
+    def algorithm_cross_validation(self: Self) -> Self:
+        """Validate the algorithm model for N2V.
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        if self.model.in_channels != self.model.num_classes:
+            raise ValueError(
+                "N2V requires the same number of input and output channels. Make "
+                "sure that `in_channels` and `num_classes` are the same."
+            )
+
+        return self
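A short sketch of the cross-validation above: a model whose `in_channels` and `num_classes` differ is rejected at construction time. The field values below are assumptions for illustration; pydantic collects the `ValueError` raised by the validator into a `ValidationError`:

    from pydantic import ValidationError

    from careamics.config.algorithms import N2VAlgorithm
    from careamics.config.architectures import UNetModel

    try:
        # assumed values: 2 input channels vs 1 output channel
        N2VAlgorithm(model=UNetModel(architecture="UNet", in_channels=2, num_classes=1))
    except ValidationError as err:
        print(err)  # N2V requires the same number of input and output channels...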
careamics/config/algorithms/unet_algorithm_model.py ADDED
@@ -0,0 +1,88 @@
+"""UNet-based algorithm Pydantic model."""
+
+from pprint import pformat
+from typing import Literal
+
+from pydantic import BaseModel, ConfigDict
+
+from careamics.config.architectures import UNetModel
+from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
+
+
+class UNetBasedAlgorithm(BaseModel):
+    """General UNet-based algorithm configuration.
+
+    This Pydantic model validates the parameters governing the components of the
+    training algorithm: which algorithm, loss function, model architecture, optimizer,
+    and learning rate scheduler to use.
+
+    Currently, we only support N2V, CARE, and N2N algorithms. In order to train these
+    algorithms, use the corresponding configuration child classes (e.g.
+    `N2VAlgorithm`) to ensure coherent parameters (e.g. specific losses).
+
+
+    Attributes
+    ----------
+    algorithm : {"n2v", "care", "n2n"}
+        Algorithm to use.
+    loss : {"n2v", "mae", "mse"}
+        Loss function to use.
+    model : UNetModel
+        Model architecture to use.
+    optimizer : OptimizerModel, optional
+        Optimizer to use.
+    lr_scheduler : LrSchedulerModel, optional
+        Learning rate scheduler to use.
+
+    Raises
+    ------
+    ValueError
+        Algorithm parameter type validation errors.
+    ValueError
+        If the algorithm, loss and model are not compatible.
+    """
+
+    # Pydantic class configuration
+    model_config = ConfigDict(
+        protected_namespaces=(),  # allows to use model_* as a field name
+        validate_assignment=True,
+        extra="allow",
+    )
+
+    # Mandatory fields
+    algorithm: Literal["n2v", "care", "n2n"]
+    """Algorithm name, as defined in SupportedAlgorithm."""
+
+    loss: Literal["n2v", "mae", "mse"]
+    """Loss function to use, as defined in SupportedLoss."""
+
+    model: UNetModel
+    """UNet model configuration."""
+
+    # Optional fields
+    optimizer: OptimizerModel = OptimizerModel()
+    """Optimizer to use, defined in SupportedOptimizer."""
+
+    lr_scheduler: LrSchedulerModel = LrSchedulerModel()
+    """Learning rate scheduler to use, defined in SupportedLrScheduler."""
+
+    def __str__(self) -> str:
+        """Pretty string representing the configuration.
+
+        Returns
+        -------
+        str
+            Pretty string.
+        """
+        return pformat(self.model_dump())
+
+    @classmethod
+    def get_compatible_algorithms(cls) -> list[str]:
+        """Get the list of compatible algorithms.
+
+        Returns
+        -------
+        list of str
+            List of compatible algorithms.
+        """
+        return ["n2v", "care", "n2n"]
careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} CHANGED
@@ -1,24 +1,26 @@
-"""Algorithm configuration."""
+"""VAE-based algorithm Pydantic model."""
 
 from __future__ import annotations
 
 from pprint import pformat
-from typing import Literal, Optional, Union
+from typing import Literal, Optional
 
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import BaseModel, ConfigDict, model_validator
 from typing_extensions import Self
 
+from careamics.config.architectures import LVAEModel
+from careamics.config.likelihood_model import (
+    GaussianLikelihoodConfig,
+    NMLikelihoodConfig,
+)
+from careamics.config.loss_model import LVAELossConfig
+from careamics.config.nm_model import MultiChannelNMConfig
+from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
 from careamics.config.support import SupportedAlgorithm, SupportedLoss
 
-from .architectures import CustomModel, LVAEModel
-from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
-from .loss_model import LVAELossConfig
-from .nm_model import MultiChannelNMConfig
-from .optimizer_models import LrSchedulerModel, OptimizerModel
 
-
-class VAEAlgorithmConfig(BaseModel):
-    """Algorithm configuration.
+class VAEBasedAlgorithm(BaseModel):
+    """VAE-based algorithm configuration.
 
     # TODO
 
@@ -42,7 +44,7 @@ class VAEAlgorithmConfig(BaseModel):
 
     # NOTE: these are all configs (pydantic models)
     loss: LVAELossConfig
-    model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
+    model: LVAEModel
     noise_model: Optional[MultiChannelNMConfig] = None
     noise_model_likelihood: Optional[NMLikelihoodConfig] = None
     gaussian_likelihood: Optional[GaussianLikelihoodConfig] = None
careamics/config/architectures/__init__.py CHANGED
@@ -1,17 +1,7 @@
 """Deep-learning model configurations."""
 
-__all__ = [
-    "ArchitectureModel",
-    "CustomModel",
-    "UNetModel",
-    "LVAEModel",
-    "clear_custom_models",
-    "get_custom_model",
-    "register_model",
-]
+__all__ = ["ArchitectureModel", "LVAEModel", "UNetModel"]
 
 from .architecture_model import ArchitectureModel
-from .custom_model import CustomModel
 from .lvae_model import LVAEModel
-from .register_model import clear_custom_models, get_custom_model, register_model
 from .unet_model import UNetModel
careamics/config/architectures/architecture_model.py CHANGED
@@ -1,6 +1,6 @@
 """Base model for the various CAREamics architectures."""
 
-from typing import Any, Dict
+from typing import Any
 
 from pydantic import BaseModel
 
@@ -15,7 +15,7 @@ class ArchitectureModel(BaseModel):
     architecture: str
     """Name of the architecture."""
 
-    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
+    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """
        Dump the model as a dictionary, ignoring the architecture keyword.

@@ -26,7 +26,7 @@
 
        Returns
        -------
-        dict[str, Any]
+        {str: Any}
            Model as a dictionary.
        """
        model_dict = super().model_dump(**kwargs)
careamics/config/architectures/lvae_model.py CHANGED
@@ -15,12 +15,17 @@ class LVAEModel(ArchitectureModel):
     model_config = ConfigDict(validate_assignment=True, validate_default=True)
 
     architecture: Literal["LVAE"]
-    input_shape: list[int] = Field(default=(64, 64), validate_default=True)
+    """Name of the architecture."""
+
+    input_shape: list[int] = Field(default=[64, 64], validate_default=True)
     """Shape of the input patch (C, Z, Y, X) or (C, Y, X) if the data is 2D."""
+
     encoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
+
     # TODO make this per hierarchy step ?
     decoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
     """Dimensions (2D or 3D) of the convolutional layers."""
+
     multiscale_count: int = Field(default=1)
     # TODO there should be a check for multiscale_count in dataset !!
 
careamics/config/architectures/unet_model.py CHANGED
@@ -48,6 +48,7 @@ class UNetModel(ArchitectureModel):
     num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
     """Number of convolutional filters in the first layer of the UNet."""
 
+    # TODO we are not using this, so why make it a choice?
     final_activation: Literal[
         "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
     ] = Field(default="None", validate_default=True)
careamics/config/care_configuration.py ADDED
@@ -0,0 +1,100 @@
+"""CARE Pydantic configuration."""
+
+from bioimageio.spec.generic.v0_3 import CiteEntry
+
+from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm
+from careamics.config.configuration import Configuration
+from careamics.config.data import DataConfig
+
+CARE = "CARE"
+
+CARE_DESCRIPTION = (
+    "Content-aware image restoration (CARE) is a deep-learning-based "
+    "algorithm that uses a U-Net architecture to restore images. CARE "
+    "is a supervised algorithm that requires pairs of noisy and "
+    "clean images to train the network. The algorithm learns to "
+    "predict the clean image from the noisy image. CARE is "
+    "particularly useful for denoising images acquired in low-light "
+    "conditions, such as fluorescence microscopy images."
+)
+CARE_REF = CiteEntry(
+    text='Weigert, Martin, et al. "Content-aware image restoration: pushing the '
+    'limits of fluorescence microscopy." Nature methods 15.12 (2018): 1090-1097.',
+    doi="10.1038/s41592-018-0216-7",
+)
+
+
+class CAREConfiguration(Configuration):
+    """CARE configuration."""
+
+    algorithm_config: CAREAlgorithm
+    """Algorithm configuration."""
+
+    data_config: DataConfig
+    """Data configuration."""
+
+    def get_algorithm_friendly_name(self) -> str:
+        """
+        Get the algorithm friendly name.
+
+        Returns
+        -------
+        str
+            Friendly name of the algorithm.
+        """
+        return CARE
+
+    def get_algorithm_keywords(self) -> list[str]:
+        """
+        Get algorithm keywords.
+
+        Returns
+        -------
+        list[str]
+            List of keywords.
+        """
+        return [
+            "restoration",
+            "UNet",
+            "3D" if "Z" in self.data_config.axes else "2D",
+            "CAREamics",
+            "pytorch",
+            CARE,
+        ]
+
+    def get_algorithm_references(self) -> str:
+        """
+        Get the algorithm references.
+
+        This is used to generate the README of the BioImage Model Zoo export.
+
+        Returns
+        -------
+        str
+            Algorithm references.
+        """
+        return CARE_REF.text + " doi: " + CARE_REF.doi
+
+    def get_algorithm_citations(self) -> list[CiteEntry]:
+        """
+        Return a list of citation entries of the current algorithm.
+
+        This is used to generate the model description for the BioImage Model Zoo.
+
+        Returns
+        -------
+        List[CiteEntry]
+            List of citation entries.
+        """
+        return [CARE_REF]
+
+    def get_algorithm_description(self) -> str:
+        """
+        Get the algorithm description.
+
+        Returns
+        -------
+        str
+            Algorithm description.
+        """
+        return CARE_DESCRIPTION