careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (118) hide show
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +239 -28
  3. careamics/cli/conf.py +19 -31
  4. careamics/cli/main.py +112 -12
  5. careamics/cli/utils.py +29 -0
  6. careamics/config/__init__.py +48 -24
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +109 -21
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +8 -8
  25. careamics/config/loss_model.py +56 -0
  26. careamics/config/n2n_configuration.py +101 -0
  27. careamics/config/n2v_configuration.py +266 -0
  28. careamics/config/nm_model.py +24 -25
  29. careamics/config/support/__init__.py +7 -7
  30. careamics/config/support/supported_algorithms.py +0 -3
  31. careamics/config/support/supported_architectures.py +0 -4
  32. careamics/config/transformations/__init__.py +10 -4
  33. careamics/config/transformations/transform_model.py +3 -3
  34. careamics/config/transformations/transform_unions.py +42 -0
  35. careamics/config/validators/validator_utils.py +3 -3
  36. careamics/dataset/__init__.py +2 -2
  37. careamics/dataset/dataset_utils/__init__.py +3 -3
  38. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  39. careamics/dataset/dataset_utils/file_utils.py +9 -9
  40. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  41. careamics/dataset/dataset_utils/running_stats.py +22 -23
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  63. careamics/lightning/lightning_module.py +69 -34
  64. careamics/lightning/train_data_module.py +41 -27
  65. careamics/losses/__init__.py +3 -3
  66. careamics/losses/loss_factory.py +1 -85
  67. careamics/losses/lvae/losses.py +223 -164
  68. careamics/lvae_training/calibration.py +184 -0
  69. careamics/lvae_training/dataset/config.py +2 -2
  70. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  71. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  72. careamics/lvae_training/dataset/types.py +15 -26
  73. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  74. careamics/lvae_training/eval_utils.py +125 -213
  75. careamics/model_io/__init__.py +1 -1
  76. careamics/model_io/bioimage/__init__.py +1 -1
  77. careamics/model_io/bioimage/_readme_factory.py +26 -34
  78. careamics/model_io/bioimage/cover_factory.py +171 -0
  79. careamics/model_io/bioimage/model_description.py +56 -34
  80. careamics/model_io/bmz_io.py +42 -42
  81. careamics/model_io/model_io_utils.py +9 -9
  82. careamics/models/layers.py +22 -20
  83. careamics/models/lvae/layers.py +348 -975
  84. careamics/models/lvae/likelihoods.py +10 -8
  85. careamics/models/lvae/lvae.py +214 -275
  86. careamics/models/lvae/noise_models.py +179 -112
  87. careamics/models/lvae/stochastic.py +393 -0
  88. careamics/models/lvae/utils.py +82 -73
  89. careamics/models/model_factory.py +2 -15
  90. careamics/models/unet.py +8 -8
  91. careamics/prediction_utils/__init__.py +1 -1
  92. careamics/prediction_utils/prediction_outputs.py +15 -15
  93. careamics/prediction_utils/stitch_prediction.py +6 -6
  94. careamics/transforms/__init__.py +5 -5
  95. careamics/transforms/compose.py +13 -13
  96. careamics/transforms/n2v_manipulate.py +3 -3
  97. careamics/transforms/pixel_manipulation.py +9 -9
  98. careamics/transforms/xy_random_rotate90.py +4 -4
  99. careamics/utils/__init__.py +5 -5
  100. careamics/utils/context.py +2 -1
  101. careamics/utils/lightning_utils.py +57 -0
  102. careamics/utils/logging.py +11 -10
  103. careamics/utils/serializers.py +2 -0
  104. careamics/utils/torch_utils.py +8 -8
  105. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
  106. careamics-0.0.6.dist-info/RECORD +176 -0
  107. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
  108. careamics/config/architectures/custom_model.py +0 -162
  109. careamics/config/architectures/register_model.py +0 -103
  110. careamics/config/configuration_model.py +0 -603
  111. careamics/config/fcn_algorithm_model.py +0 -152
  112. careamics/config/references/__init__.py +0 -45
  113. careamics/config/references/algorithm_descriptions.py +0 -132
  114. careamics/config/references/references.py +0 -39
  115. careamics/config/transformations/transform_union.py +0 -20
  116. careamics-0.0.4.2.dist-info/RECORD +0 -165
  117. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  118. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
careamics/cli/utils.py ADDED
@@ -0,0 +1,29 @@
1
+ """Utility functions for the CAREamics CLI."""
2
+
3
+ from typing import Optional
4
+
5
+
6
+ def handle_2D_3D_callback(
7
+ value: Optional[tuple[int, int, int]]
8
+ ) -> Optional[tuple[int, ...]]:
9
+ """
10
+ Callback for options that require 2D or 3D inputs.
11
+
12
+ In the case of 2D, the 3rd element should be set to -1.
13
+
14
+ Parameters
15
+ ----------
16
+ value : (int, int, int)
17
+ Tile size value.
18
+
19
+ Returns
20
+ -------
21
+ (int, int, int) | (int, int)
22
+ If the last element in `value` is -1 the tuple is reduced to the first two
23
+ values.
24
+ """
25
+ if value is None:
26
+ return value
27
+ if value[2] == -1:
28
+ return value[:2]
29
+ return value
@@ -1,39 +1,63 @@
1
- """Configuration module."""
1
+ """CAREamics Pydantic configuration models.
2
+
3
+ To maintain clarity at the module level, we use the following naming conventions:
4
+ `*_model` is specific for sub-configurations (e.g. architecture, data, algorithm),
5
+ while `*_configuration` is reserved for the main configuration models, including the
6
+ `Configuration` base class and its algorithm-specific child classes.
7
+ """
2
8
 
3
9
  __all__ = [
4
- "FCNAlgorithmConfig",
5
- "VAEAlgorithmConfig",
6
- "DataConfig",
7
- "Configuration",
10
+ "CAREAlgorithm",
11
+ "CAREConfiguration",
8
12
  "CheckpointModel",
13
+ "Configuration",
14
+ "DataConfig",
15
+ "GaussianMixtureNMConfig",
16
+ "GeneralDataConfig",
9
17
  "InferenceConfig",
10
- "load_configuration",
11
- "save_configuration",
18
+ "LVAELossConfig",
19
+ "MultiChannelNMConfig",
20
+ "N2NAlgorithm",
21
+ "N2NConfiguration",
22
+ "N2VAlgorithm",
23
+ "N2VConfiguration",
24
+ "N2VDataConfig",
12
25
  "TrainingConfig",
13
- "create_n2v_configuration",
14
- "create_n2n_configuration",
26
+ "UNetBasedAlgorithm",
27
+ "VAEBasedAlgorithm",
28
+ "algorithm_factory",
29
+ "configuration_factory",
15
30
  "create_care_configuration",
16
- "register_model",
17
- "CustomModel",
18
- "clear_custom_models",
19
- "GaussianMixtureNMConfig",
20
- "MultiChannelNMConfig",
31
+ "create_n2n_configuration",
32
+ "create_n2v_configuration",
33
+ "data_factory",
34
+ "load_configuration",
35
+ "save_configuration",
21
36
  ]
22
- from .architectures import CustomModel, clear_custom_models, register_model
37
+
38
+ from .algorithms import (
39
+ CAREAlgorithm,
40
+ N2NAlgorithm,
41
+ N2VAlgorithm,
42
+ UNetBasedAlgorithm,
43
+ VAEBasedAlgorithm,
44
+ )
23
45
  from .callback_model import CheckpointModel
24
- from .configuration_factory import (
46
+ from .care_configuration import CAREConfiguration
47
+ from .configuration import Configuration
48
+ from .configuration_factories import (
49
+ algorithm_factory,
50
+ configuration_factory,
25
51
  create_care_configuration,
26
52
  create_n2n_configuration,
27
53
  create_n2v_configuration,
54
+ data_factory,
28
55
  )
29
- from .configuration_model import (
30
- Configuration,
31
- load_configuration,
32
- save_configuration,
33
- )
34
- from .data_model import DataConfig
35
- from .fcn_algorithm_model import FCNAlgorithmConfig
56
+ from .configuration_io import load_configuration, save_configuration
57
+ from .data import DataConfig, GeneralDataConfig, N2VDataConfig
36
58
  from .inference_model import InferenceConfig
59
+ from .loss_model import LVAELossConfig
60
+ from .n2n_configuration import N2NConfiguration
61
+ from .n2v_configuration import N2VConfiguration
37
62
  from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
38
63
  from .training_model import TrainingConfig
39
- from .vae_algorithm_model import VAEAlgorithmConfig
@@ -0,0 +1,15 @@
1
+ """Algorithm configurations."""
2
+
3
+ __all__ = [
4
+ "CAREAlgorithm",
5
+ "N2NAlgorithm",
6
+ "N2VAlgorithm",
7
+ "UNetBasedAlgorithm",
8
+ "VAEBasedAlgorithm",
9
+ ]
10
+
11
+ from .care_algorithm_model import CAREAlgorithm
12
+ from .n2n_algorithm_model import N2NAlgorithm
13
+ from .n2v_algorithm_model import N2VAlgorithm
14
+ from .unet_algorithm_model import UNetBasedAlgorithm
15
+ from .vae_algorithm_model import VAEBasedAlgorithm
@@ -0,0 +1,50 @@
1
+ """CARE algorithm configuration."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import field_validator
6
+
7
+ from careamics.config.architectures import UNetModel
8
+
9
+ from .unet_algorithm_model import UNetBasedAlgorithm
10
+
11
+
12
+ class CAREAlgorithm(UNetBasedAlgorithm):
13
+ """CARE algorithm configuration.
14
+
15
+ Attributes
16
+ ----------
17
+ algorithm : "care"
18
+ CARE Algorithm name.
19
+ loss : {"mae", "mse"}
20
+ CARE-compatible loss function.
21
+ """
22
+
23
+ algorithm: Literal["care"] = "care"
24
+ """CARE Algorithm name."""
25
+
26
+ loss: Literal["mae", "mse"] = "mae"
27
+ """CARE-compatible loss function."""
28
+
29
+ @classmethod
30
+ @field_validator("model")
31
+ def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
32
+ """Validate that the model does not have the n2v2 attribute.
33
+
34
+ Parameters
35
+ ----------
36
+ value : UNetModel
37
+ Model to validate.
38
+
39
+ Returns
40
+ -------
41
+ UNetModel
42
+ The validated model.
43
+ """
44
+ if value.n2v2:
45
+ raise ValueError(
46
+ "The N2N algorithm does not support the `n2v2` attribute. "
47
+ "Set it to `False`."
48
+ )
49
+
50
+ return value
@@ -0,0 +1,42 @@
1
+ """N2N Algorithm configuration."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import field_validator
6
+
7
+ from careamics.config.architectures import UNetModel
8
+
9
+ from .unet_algorithm_model import UNetBasedAlgorithm
10
+
11
+
12
+ class N2NAlgorithm(UNetBasedAlgorithm):
13
+ """N2N Algorithm configuration."""
14
+
15
+ algorithm: Literal["n2n"] = "n2n"
16
+ """N2N Algorithm name."""
17
+
18
+ loss: Literal["mae", "mse"] = "mae"
19
+ """N2N-compatible loss function."""
20
+
21
+ @classmethod
22
+ @field_validator("model")
23
+ def model_without_n2v2(cls, value: UNetModel) -> UNetModel:
24
+ """Validate that the model does not have the n2v2 attribute.
25
+
26
+ Parameters
27
+ ----------
28
+ value : UNetModel
29
+ Model to validate.
30
+
31
+ Returns
32
+ -------
33
+ UNetModel
34
+ The validated model.
35
+ """
36
+ if value.n2v2:
37
+ raise ValueError(
38
+ "The N2N algorithm does not support the `n2v2` attribute. "
39
+ "Set it to `False`."
40
+ )
41
+
42
+ return value
@@ -0,0 +1,35 @@
1
+ """N2V Algorithm configuration."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import model_validator
6
+ from typing_extensions import Self
7
+
8
+ from .unet_algorithm_model import UNetBasedAlgorithm
9
+
10
+
11
+ class N2VAlgorithm(UNetBasedAlgorithm):
12
+ """N2V Algorithm configuration."""
13
+
14
+ algorithm: Literal["n2v"] = "n2v"
15
+ """N2V Algorithm name."""
16
+
17
+ loss: Literal["n2v"] = "n2v"
18
+ """N2V loss function."""
19
+
20
+ @model_validator(mode="after")
21
+ def algorithm_cross_validation(self: Self) -> Self:
22
+ """Validate the algorithm model for N2V.
23
+
24
+ Returns
25
+ -------
26
+ Self
27
+ The validated model.
28
+ """
29
+ if self.model.in_channels != self.model.num_classes:
30
+ raise ValueError(
31
+ "N2V requires the same number of input and output channels. Make "
32
+ "sure that `in_channels` and `num_classes` are the same."
33
+ )
34
+
35
+ return self
@@ -0,0 +1,88 @@
1
+ """UNet-based algorithm Pydantic model."""
2
+
3
+ from pprint import pformat
4
+ from typing import Literal
5
+
6
+ from pydantic import BaseModel, ConfigDict
7
+
8
+ from careamics.config.architectures import UNetModel
9
+ from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
10
+
11
+
12
+ class UNetBasedAlgorithm(BaseModel):
13
+ """General UNet-based algorithm configuration.
14
+
15
+ This Pydantic model validates the parameters governing the components of the
16
+ training algorithm: which algorithm, loss function, model architecture, optimizer,
17
+ and learning rate scheduler to use.
18
+
19
+ Currently, we only support N2V, CARE, and N2N algorithms. In order to train these
20
+ algorithms, use the corresponding configuration child classes (e.g.
21
+ `N2VAlgorithm`) to ensure coherent parameters (e.g. specific losses).
22
+
23
+
24
+ Attributes
25
+ ----------
26
+ algorithm : {"n2v", "care", "n2n"}
27
+ Algorithm to use.
28
+ loss : {"n2v", "mae", "mse"}
29
+ Loss function to use.
30
+ model : UNetModel
31
+ Model architecture to use.
32
+ optimizer : OptimizerModel, optional
33
+ Optimizer to use.
34
+ lr_scheduler : LrSchedulerModel, optional
35
+ Learning rate scheduler to use.
36
+
37
+ Raises
38
+ ------
39
+ ValueError
40
+ Algorithm parameter type validation errors.
41
+ ValueError
42
+ If the algorithm, loss and model are not compatible.
43
+ """
44
+
45
+ # Pydantic class configuration
46
+ model_config = ConfigDict(
47
+ protected_namespaces=(), # allows using model_* as a field name
48
+ validate_assignment=True,
49
+ extra="allow",
50
+ )
51
+
52
+ # Mandatory fields
53
+ algorithm: Literal["n2v", "care", "n2n"]
54
+ """Algorithm name, as defined in SupportedAlgorithm."""
55
+
56
+ loss: Literal["n2v", "mae", "mse"]
57
+ """Loss function to use, as defined in SupportedLoss."""
58
+
59
+ model: UNetModel
60
+ """UNet model configuration."""
61
+
62
+ # Optional fields
63
+ optimizer: OptimizerModel = OptimizerModel()
64
+ """Optimizer to use, defined in SupportedOptimizer."""
65
+
66
+ lr_scheduler: LrSchedulerModel = LrSchedulerModel()
67
+ """Learning rate scheduler to use, defined in SupportedLrScheduler."""
68
+
69
+ def __str__(self) -> str:
70
+ """Pretty string representing the configuration.
71
+
72
+ Returns
73
+ -------
74
+ str
75
+ Pretty string.
76
+ """
77
+ return pformat(self.model_dump())
78
+
79
+ @classmethod
80
+ def get_compatible_algorithms(cls) -> list[str]:
81
+ """Get the list of compatible algorithms.
82
+
83
+ Returns
84
+ -------
85
+ list of str
86
+ List of compatible algorithms.
87
+ """
88
+ return ["n2v", "care", "n2n"]
@@ -1,23 +1,26 @@
1
- """Algorithm configuration."""
1
+ """VAE-based algorithm Pydantic model."""
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
5
  from pprint import pformat
6
- from typing import Literal, Optional, Union
6
+ from typing import Literal, Optional
7
7
 
8
- from pydantic import BaseModel, ConfigDict, Field, model_validator
8
+ from pydantic import BaseModel, ConfigDict, model_validator
9
9
  from typing_extensions import Self
10
10
 
11
+ from careamics.config.architectures import LVAEModel
12
+ from careamics.config.likelihood_model import (
13
+ GaussianLikelihoodConfig,
14
+ NMLikelihoodConfig,
15
+ )
16
+ from careamics.config.loss_model import LVAELossConfig
17
+ from careamics.config.nm_model import MultiChannelNMConfig
18
+ from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
11
19
  from careamics.config.support import SupportedAlgorithm, SupportedLoss
12
20
 
13
- from .architectures import CustomModel, LVAEModel
14
- from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
15
- from .nm_model import MultiChannelNMConfig
16
- from .optimizer_models import LrSchedulerModel, OptimizerModel
17
21
 
18
-
19
- class VAEAlgorithmConfig(BaseModel):
20
- """Algorithm configuration.
22
+ class VAEBasedAlgorithm(BaseModel):
23
+ """VAE-based algorithm configuration.
21
24
 
22
25
  # TODO
23
26
 
@@ -38,13 +41,13 @@ class VAEAlgorithmConfig(BaseModel):
38
41
  # TODO: Use supported Enum classes for typing?
39
42
  # - values can still be passed as strings and they will be cast to Enum
40
43
  algorithm: Literal["musplit", "denoisplit"]
41
- loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
42
- model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
43
44
 
44
- # TODO: these are configs, change naming of attrs
45
+ # NOTE: these are all configs (pydantic models)
46
+ loss: LVAELossConfig
47
+ model: LVAEModel
45
48
  noise_model: Optional[MultiChannelNMConfig] = None
46
- noise_model_likelihood_model: Optional[NMLikelihoodConfig] = None
47
- gaussian_likelihood_model: Optional[GaussianLikelihoodConfig] = None
49
+ noise_model_likelihood: Optional[NMLikelihoodConfig] = None
50
+ gaussian_likelihood: Optional[GaussianLikelihoodConfig] = None
48
51
 
49
52
  # Optional fields
50
53
  optimizer: OptimizerModel = OptimizerModel()
@@ -63,13 +66,13 @@ class VAEAlgorithmConfig(BaseModel):
63
66
  """
64
67
  # musplit
65
68
  if self.algorithm == SupportedAlgorithm.MUSPLIT:
66
- if self.loss != SupportedLoss.MUSPLIT:
69
+ if self.loss.loss_type != SupportedLoss.MUSPLIT:
67
70
  raise ValueError(
68
71
  f"Algorithm {self.algorithm} only supports loss `musplit`."
69
72
  )
70
73
 
71
74
  if self.algorithm == SupportedAlgorithm.DENOISPLIT:
72
- if self.loss not in [
75
+ if self.loss.loss_type not in [
73
76
  SupportedLoss.DENOISPLIT,
74
77
  SupportedLoss.DENOISPLIT_MUSPLIT,
75
78
  ]:
@@ -78,16 +81,17 @@ class VAEAlgorithmConfig(BaseModel):
78
81
  "or `denoisplit_musplit."
79
82
  )
80
83
  if (
81
- self.loss == SupportedLoss.DENOISPLIT
84
+ self.loss.loss_type == SupportedLoss.DENOISPLIT
82
85
  and self.model.predict_logvar is not None
83
86
  ):
84
87
  raise ValueError(
85
88
  "Algorithm `denoisplit` with loss `denoisplit` only supports "
86
89
  "`predict_logvar` as `None`."
87
90
  )
91
+
88
92
  if self.noise_model is None:
89
93
  raise ValueError("Algorithm `denoisplit` requires a noise model.")
90
- # TODO: what if algorithm is not musplit or denoisplit (HDN?)
94
+ # TODO: what if algorithm is not musplit or denoisplit
91
95
  return self
92
96
 
93
97
  @model_validator(mode="after")
@@ -115,14 +119,13 @@ class VAEAlgorithmConfig(BaseModel):
115
119
  Self
116
120
  The validated model.
117
121
  """
118
- if self.gaussian_likelihood_model is not None:
122
+ if self.gaussian_likelihood is not None:
119
123
  assert (
120
- self.model.predict_logvar
121
- == self.gaussian_likelihood_model.predict_logvar
124
+ self.model.predict_logvar == self.gaussian_likelihood.predict_logvar
122
125
  ), (
123
126
  f"Model `predict_logvar` ({self.model.predict_logvar}) must match "
124
127
  "Gaussian likelihood model `predict_logvar` "
125
- f"({self.gaussian_likelihood_model.predict_logvar}).",
128
+ f"({self.gaussian_likelihood.predict_logvar}).",
126
129
  )
127
130
  return self
128
131
 
@@ -1,17 +1,7 @@
1
1
  """Deep-learning model configurations."""
2
2
 
3
- __all__ = [
4
- "ArchitectureModel",
5
- "CustomModel",
6
- "UNetModel",
7
- "LVAEModel",
8
- "clear_custom_models",
9
- "get_custom_model",
10
- "register_model",
11
- ]
3
+ __all__ = ["ArchitectureModel", "LVAEModel", "UNetModel"]
12
4
 
13
5
  from .architecture_model import ArchitectureModel
14
- from .custom_model import CustomModel
15
6
  from .lvae_model import LVAEModel
16
- from .register_model import clear_custom_models, get_custom_model, register_model
17
7
  from .unet_model import UNetModel
@@ -1,6 +1,6 @@
1
1
  """Base model for the various CAREamics architectures."""
2
2
 
3
- from typing import Any, Dict
3
+ from typing import Any
4
4
 
5
5
  from pydantic import BaseModel
6
6
 
@@ -15,7 +15,7 @@ class ArchitectureModel(BaseModel):
15
15
  architecture: str
16
16
  """Name of the architecture."""
17
17
 
18
- def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
18
+ def model_dump(self, **kwargs: Any) -> dict[str, Any]:
19
19
  """
20
20
  Dump the model as a dictionary, ignoring the architecture keyword.
21
21
 
@@ -26,7 +26,7 @@ class ArchitectureModel(BaseModel):
26
26
 
27
27
  Returns
28
28
  -------
29
- dict[str, Any]
29
+ {str: Any}
30
30
  Model as a dictionary.
31
31
  """
32
32
  model_dict = super().model_dump(**kwargs)
@@ -15,9 +15,21 @@ class LVAEModel(ArchitectureModel):
15
15
  model_config = ConfigDict(validate_assignment=True, validate_default=True)
16
16
 
17
17
  architecture: Literal["LVAE"]
18
- input_shape: int = Field(default=64, ge=8, le=1024)
19
- multiscale_count: int = Field(default=5) # TODO clarify
20
- # 0 - off, len(z_dims) + 1 # TODO can/should be le to z_dims len + 1
18
+ """Name of the architecture."""
19
+
20
+ input_shape: list[int] = Field(default=[64, 64], validate_default=True)
21
+ """Shape of the input patch, (Z, Y, X) or (Y, X) if the data is 2D."""
22
+
23
+ encoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
24
+
25
+ # TODO make this per hierarchy step ?
26
+ decoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
27
+ """Dimensions (2D or 3D) of the convolutional layers."""
28
+
29
+ multiscale_count: int = Field(default=1)
30
+ # TODO there should be a check for multiscale_count in dataset !!
31
+
32
+ # 1 - off, len(z_dims) + 1 # TODO Consider starting from 0
21
33
  z_dims: list = Field(default=[128, 128, 128, 128])
22
34
  output_channels: int = Field(default=1, ge=1)
23
35
  encoder_n_filters: int = Field(default=64, ge=8, le=1024)
@@ -31,10 +43,90 @@ class LVAEModel(ArchitectureModel):
31
43
  )
32
44
 
33
45
  predict_logvar: Literal[None, "pixelwise"] = None
46
+ analytical_kl: bool = Field(default=False)
34
47
 
35
- analytical_kl: bool = Field(
36
- default=False,
37
- )
48
+ @model_validator(mode="after")
49
+ def validate_conv_strides(self: Self) -> Self:
50
+ """
51
+ Validate the convolutional strides.
52
+
53
+ Returns
54
+ -------
55
+ Self
56
+ The validated model.
57
+
58
+ Raises
59
+ ------
60
+ ValueError
61
+ If the number of strides is not 2 or 3, or the strides are otherwise invalid.
62
+ """
63
+ if len(self.encoder_conv_strides) < 2 or len(self.encoder_conv_strides) > 3:
64
+ raise ValueError(
65
+ f"Strides must be 2 or 3 (got {len(self.encoder_conv_strides)})."
66
+ )
67
+
68
+ if len(self.decoder_conv_strides) < 2 or len(self.decoder_conv_strides) > 3:
69
+ raise ValueError(
70
+ f"Strides must be 2 or 3 (got {len(self.decoder_conv_strides)})."
71
+ )
72
+
73
+ # the input shape and the encoder strides must have the same number of dimensions
74
+ if len(self.input_shape) != len(self.encoder_conv_strides):
75
+ raise ValueError(
76
+ f"Input dimensions must be equal to the number of encoder conv strides"
77
+ f" (got {len(self.input_shape)} and {len(self.encoder_conv_strides)})."
78
+ )
79
+
80
+ if len(self.encoder_conv_strides) < len(self.decoder_conv_strides):
81
+ raise ValueError(
82
+ f"Decoder can't be 3D when encoder is 2D (got"
83
+ f" {len(self.encoder_conv_strides)} and"
84
+ f"{len(self.decoder_conv_strides)})."
85
+ )
86
+
87
+ if any(s < 1 for s in self.encoder_conv_strides) or any(
88
+ s < 1 for s in self.decoder_conv_strides
89
+ ):
90
+ raise ValueError(
91
+ f"All strides must be greater or equal to 1"
92
+ f"(got {self.encoder_conv_strides} and {self.decoder_conv_strides})."
93
+ )
94
+ # TODO: validate max stride size ?
95
+ return self
96
+
97
+ @field_validator("input_shape")
98
+ @classmethod
99
+ def validate_input_shape(cls, input_shape: list) -> list:
100
+ """
101
+ Validate the input shape.
102
+
103
+ Parameters
104
+ ----------
105
+ input_shape : list
106
+ Shape of the input patch.
107
+
108
+ Returns
109
+ -------
110
+ list
111
+ Validated input shape.
112
+
113
+ Raises
114
+ ------
115
+ ValueError
116
+ If the number of dimensions is not 2 or 3.
117
+ """
118
+ if len(input_shape) < 2 or len(input_shape) > 3:
119
+ raise ValueError(
120
+ f"Number of input dimensions must be 2 for 2D data 3 for 3D"
121
+ f"(got {len(input_shape)})."
122
+ )
123
+
124
+ if any(s < 1 for s in input_shape):
125
+ raise ValueError(
126
+ f"Input shape must be greater than 1 in all dimensions"
127
+ f"(got {input_shape})."
128
+ )
129
+ return input_shape
38
130
 
39
131
  @field_validator("encoder_n_filters")
40
132
  @classmethod
@@ -124,27 +216,20 @@ class LVAEModel(ArchitectureModel):
124
216
  return z_dims
125
217
 
126
218
  @model_validator(mode="after")
127
- def validate_multiscale_count(cls, self: Self) -> Self:
219
+ def validate_multiscale_count(self: Self) -> Self:
128
220
  """
129
221
  Validate the multiscale count.
130
222
 
131
- Parameters
132
- ----------
133
- self : Self
134
- The model.
135
-
136
223
  Returns
137
224
  -------
138
225
  Self
139
226
  The validated model.
140
227
  """
141
- # if self.multiscale_count != 0:
142
- # if self.multiscale_count != len(self.z_dims) - 1:
143
- # raise ValueError(
144
- # f"Multiscale count must be 0 or equal to the number of Z "
145
- # f"dims - 1 (got {self.multiscale_count} and {len(self.z_dims)})."
146
- # )
147
-
228
+ if self.multiscale_count < 1 or self.multiscale_count > len(self.z_dims) + 1:
229
+ raise ValueError(
230
+ f"Multiscale count must be 1 for LC off or less or equal to the number"
231
+ f" of Z dims + 1 (got {self.multiscale_count} and {len(self.z_dims)})."
232
+ )
148
233
  return self
149
234
 
150
235
  def set_3D(self, is_3D: bool) -> None:
@@ -156,7 +241,10 @@ class LVAEModel(ArchitectureModel):
156
241
  is_3D : bool
157
242
  Whether the algorithm is 3D or not.
158
243
  """
159
- raise NotImplementedError("VAE is not implemented yet.")
244
+ if is_3D:
245
+ self.conv_dims = 3
246
+ else:
247
+ self.conv_dims = 2
160
248
 
161
249
  def is_3D(self) -> bool:
162
250
  """
@@ -167,4 +255,4 @@ class LVAEModel(ArchitectureModel):
167
255
  bool
168
256
  Whether the model is 3D or not.
169
257
  """
170
- raise NotImplementedError("VAE is not implemented yet.")
258
+ return self.conv_dims == 3
@@ -48,6 +48,7 @@ class UNetModel(ArchitectureModel):
48
48
  num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
49
49
  """Number of convolutional filters in the first layer of the UNet."""
50
50
 
51
+ # TODO we are not using this, so why make it a choice?
51
52
  final_activation: Literal[
52
53
  "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
53
54
  ] = Field(default="None", validate_default=True)