careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (64)
  1. careamics/careamist.py +14 -11
  2. careamics/config/__init__.py +7 -3
  3. careamics/config/architectures/__init__.py +2 -2
  4. careamics/config/architectures/architecture_model.py +1 -1
  5. careamics/config/architectures/custom_model.py +11 -8
  6. careamics/config/architectures/lvae_model.py +174 -0
  7. careamics/config/configuration_factory.py +11 -3
  8. careamics/config/configuration_model.py +7 -3
  9. careamics/config/data_model.py +33 -8
  10. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
  11. careamics/config/likelihood_model.py +43 -0
  12. careamics/config/nm_model.py +101 -0
  13. careamics/config/support/supported_activations.py +1 -0
  14. careamics/config/support/supported_algorithms.py +17 -4
  15. careamics/config/support/supported_architectures.py +8 -11
  16. careamics/config/support/supported_losses.py +3 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  18. careamics/config/vae_algorithm_model.py +171 -0
  19. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  20. careamics/file_io/read/tiff.py +1 -1
  21. careamics/lightning/__init__.py +3 -2
  22. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  23. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  24. careamics/lightning/lightning_module.py +365 -9
  25. careamics/lightning/predict_data_module.py +2 -2
  26. careamics/lightning/train_data_module.py +2 -2
  27. careamics/losses/__init__.py +11 -1
  28. careamics/losses/fcn/__init__.py +1 -0
  29. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  30. careamics/losses/loss_factory.py +112 -6
  31. careamics/losses/lvae/__init__.py +1 -0
  32. careamics/losses/lvae/loss_utils.py +83 -0
  33. careamics/losses/lvae/losses.py +445 -0
  34. careamics/lvae_training/dataset/__init__.py +0 -0
  35. careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
  36. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  37. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  38. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  39. careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
  40. careamics/lvae_training/get_config.py +1 -1
  41. careamics/lvae_training/train_lvae.py +6 -3
  42. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  43. careamics/model_io/bioimage/model_description.py +2 -2
  44. careamics/model_io/bmz_io.py +19 -6
  45. careamics/model_io/model_io_utils.py +16 -4
  46. careamics/models/__init__.py +1 -3
  47. careamics/models/activation.py +2 -0
  48. careamics/models/lvae/__init__.py +3 -0
  49. careamics/models/lvae/layers.py +21 -21
  50. careamics/models/lvae/likelihoods.py +180 -128
  51. careamics/models/lvae/lvae.py +52 -136
  52. careamics/models/lvae/noise_models.py +318 -186
  53. careamics/models/lvae/utils.py +2 -2
  54. careamics/models/model_factory.py +22 -7
  55. careamics/prediction_utils/lvae_prediction.py +158 -0
  56. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  57. careamics/prediction_utils/stitch_prediction.py +16 -2
  58. careamics/transforms/pixel_manipulation.py +1 -1
  59. careamics/utils/metrics.py +74 -1
  60. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
  61. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
  62. careamics/config/architectures/vae_model.py +0 -42
  63. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
  64. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py CHANGED
@@ -13,10 +13,7 @@ from pytorch_lightning.callbacks import (
 )
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 
-from careamics.config import (
-    Configuration,
-    load_configuration,
-)
+from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -25,7 +22,7 @@ from careamics.config.support import (
 )
 from careamics.dataset.dataset_utils import reshape_array
 from careamics.lightning import (
-    CAREamicsModule,
+    FCNModule,
     HyperParametersCallback,
     PredictDataModule,
     ProgressBarCallback,
@@ -148,9 +145,12 @@ class CAREamist:
             self.cfg = source
 
             # instantiate model
-            self.model = CAREamicsModule(
-                algorithm_config=self.cfg.algorithm_config,
-            )
+            if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+                self.model = FCNModule(
+                    algorithm_config=self.cfg.algorithm_config,
+                )
+            else:
+                raise NotImplementedError("Architecture not supported.")
 
         # path to configuration file or model
         else:
@@ -164,9 +164,12 @@ class CAREamist:
                 self.cfg = load_configuration(source)
 
                 # instantiate model
-                self.model = CAREamicsModule(
-                    algorithm_config=self.cfg.algorithm_config,
-                )
+                if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+                    self.model = FCNModule(
+                        algorithm_config=self.cfg.algorithm_config,
+                    )  # type: ignore
+                else:
+                    raise NotImplementedError("Architecture not supported.")
 
             # attempt loading a pre-trained model
             else:
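Net effect of the hunks above: CAREamist now dispatches on the concrete type of algorithm_config instead of unconditionally building a CAREamicsModule. A minimal sketch of the new dispatch (imports as in the diff; build_model is an illustrative helper, not a function in the package):

    from careamics.config import Configuration, FCNAlgorithmConfig
    from careamics.lightning import FCNModule

    def build_model(cfg: Configuration):
        # 0.0.3 only wires FCN configurations into a Lightning module;
        # any other algorithm configuration raises.
        if isinstance(cfg.algorithm_config, FCNAlgorithmConfig):
            return FCNModule(algorithm_config=cfg.algorithm_config)
        raise NotImplementedError("Architecture not supported.")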
careamics/config/__init__.py CHANGED
@@ -1,7 +1,8 @@
 """Configuration module."""
 
 __all__ = [
-    "AlgorithmConfig",
+    "FCNAlgorithmConfig",
+    "VAEAlgorithmConfig",
     "DataConfig",
     "Configuration",
     "CheckpointModel",
@@ -15,9 +16,9 @@ __all__ = [
     "register_model",
     "CustomModel",
     "clear_custom_models",
+    "GaussianMixtureNMConfig",
+    "MultiChannelNMConfig",
 ]
-
-from .algorithm_model import AlgorithmConfig
 from .architectures import CustomModel, clear_custom_models, register_model
 from .callback_model import CheckpointModel
 from .configuration_factory import (
@@ -31,5 +32,8 @@ from .configuration_model import (
     save_configuration,
 )
 from .data_model import DataConfig
+from .fcn_algorithm_model import FCNAlgorithmConfig
 from .inference_model import InferenceConfig
+from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
 from .training_model import TrainingConfig
+from .vae_algorithm_model import VAEAlgorithmConfig
careamics/config/architectures/__init__.py CHANGED
@@ -4,7 +4,7 @@ __all__ = [
     "ArchitectureModel",
     "CustomModel",
     "UNetModel",
-    "VAEModel",
+    "LVAEModel",
     "clear_custom_models",
     "get_custom_model",
     "register_model",
@@ -12,6 +12,6 @@ __all__ = [
 
 from .architecture_model import ArchitectureModel
 from .custom_model import CustomModel
+from .lvae_model import LVAEModel
 from .register_model import clear_custom_models, get_custom_model, register_model
 from .unet_model import UNetModel
-from .vae_model import VAEModel
careamics/config/architectures/architecture_model.py CHANGED
@@ -27,7 +27,7 @@ class ArchitectureModel(BaseModel):
         Returns
         -------
         dict[str, Any]
-            Model as a dictionnary.
+            Model as a dictionary.
         """
         model_dict = super().model_dump(**kwargs)
 
careamics/config/architectures/custom_model.py CHANGED
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import inspect
 from pprint import pformat
 from typing import Any, Literal
 
@@ -23,12 +24,13 @@ class CustomModel(ArchitectureModel):
 
     Attributes
     ----------
-    architecture : Literal["Custom"]
-        Discriminator for the custom model, must be set to "Custom".
+    architecture : Literal["custom"]
+        Discriminator for the custom model, must be set to "custom".
     name : str
         Name of the custom model.
     parameters : CustomParametersModel
-        Parameters of the custom model.
+        All parameters, required for the initialization of the torch module have to be
+        passed here.
 
     Raises
     ------
@@ -57,7 +59,7 @@ class CustomModel(ArchitectureModel):
    ...
    >>> # Create a configuration
    >>> config_dict = {
-    ...     "architecture": "Custom",
+    ...     "architecture": "custom",
    ...     "name": "my_linear",
    ...     "in_features": 10,
    ...     "out_features": 5,
@@ -71,10 +73,9 @@ class CustomModel(ArchitectureModel):
     )
 
     # discriminator used for choosing the pydantic model in Model
-    architecture: Literal["Custom"]
+    architecture: Literal["custom"]
     """Name of the architecture."""
 
-    # name of the custom model
     name: str
     """Name of the custom model."""
 
@@ -120,10 +121,12 @@ class CustomModel(ArchitectureModel):
             get_custom_model(self.name)(**self.model_dump())
         except Exception as e:
             raise ValueError(
-                f"error while passing parameters to the model {e}. Verify that all "
+                f"while passing parameters to the model {e}. Verify that all "
                 f"mandatory parameters are provided, and that either the {e} accepts "
                 f"*args and **kwargs in its __init__() method, or that no additional"
-                f"parameter is provided."
+                f"parameter is provided. Trace: "
+                f"filename: {inspect.trace()[-1].filename}, function: "
+                f"{inspect.trace()[-1].function}, line: {inspect.trace()[-1].lineno}"
             ) from None
 
         return self
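These hunks lowercase the architecture discriminator ("Custom" becomes "custom") and extend the validation error with the failing filename, function, and line from inspect.trace(). A sketch of the new spelling, assembled from the docstring fragments visible in this diff (the LinearModel class is illustrative):

    from torch import nn, ones

    from careamics.config import CustomModel, register_model

    @register_model(name="my_linear")  # registers the class under this name
    class LinearModel(nn.Module):
        def __init__(self, in_features, out_features, *args, **kwargs):
            super().__init__()
            self.weight = nn.Parameter(ones(in_features, out_features))
            self.bias = nn.Parameter(ones(out_features))

        def forward(self, input):
            return (input @ self.weight) + self.bias

    config = CustomModel(
        architecture="custom",  # "Custom" in 0.0.2, now lowercase
        name="my_linear",
        in_features=10,
        out_features=5,
    )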
careamics/config/architectures/lvae_model.py ADDED
@@ -0,0 +1,174 @@
+"""LVAE Pydantic model."""
+
+from typing import Literal
+
+from pydantic import ConfigDict, Field, field_validator, model_validator
+from typing_extensions import Self
+
+from .architecture_model import ArchitectureModel
+
+
+# TODO: it is quite confusing to call this LVAEModel, as it is basically a config
+class LVAEModel(ArchitectureModel):
+    """LVAE model."""
+
+    model_config = ConfigDict(validate_assignment=True, validate_default=True)
+
+    architecture: Literal["LVAE"]
+    input_shape: int = Field(default=64, ge=8, le=1024)
+    multiscale_count: int = Field(default=5)  # TODO clarify
+    # 0 - off, len(z_dims) + 1  # TODO can/should be le to z_dims len + 1
+    z_dims: list = Field(default=[128, 128, 128, 128])
+    output_channels: int = Field(default=1, ge=1)
+    encoder_n_filters: int = Field(default=64, ge=8, le=1024)
+    decoder_n_filters: int = Field(default=64, ge=8, le=1024)
+    encoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
+    decoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
+    nonlinearity: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
+    ] = Field(
+        default="ELU",
+    )
+
+    predict_logvar: Literal[None, "pixelwise"] = None
+
+    # TODO this parameter is excessive -> Remove & refactor
+    enable_noise_model: bool = Field(
+        default=True,
+    )
+    analytical_kl: bool = Field(
+        default=False,
+    )
+
+    @field_validator("encoder_n_filters")
+    @classmethod
+    def validate_encoder_even(cls, encoder_n_filters: int) -> int:
+        """
+        Validate that encoder_n_filters is even.
+
+        Parameters
+        ----------
+        encoder_n_filters : int
+            Number of channels.
+
+        Returns
+        -------
+        int
+            Validated number of channels.
+
+        Raises
+        ------
+        ValueError
+            If the number of channels is odd.
+        """
+        # if odd
+        if encoder_n_filters % 2 != 0:
+            raise ValueError(
+                f"Number of channels for the bottom layer must be even"
+                f" (got {encoder_n_filters})."
+            )
+
+        return encoder_n_filters
+
+    @field_validator("decoder_n_filters")
+    @classmethod
+    def validate_decoder_even(cls, decoder_n_filters: int) -> int:
+        """
+        Validate that decoder_n_filters is even.
+
+        Parameters
+        ----------
+        decoder_n_filters : int
+            Number of channels.
+
+        Returns
+        -------
+        int
+            Validated number of channels.
+
+        Raises
+        ------
+        ValueError
+            If the number of channels is odd.
+        """
+        # if odd
+        if decoder_n_filters % 2 != 0:
+            raise ValueError(
+                f"Number of channels for the bottom layer must be even"
+                f" (got {decoder_n_filters})."
+            )
+
+        return decoder_n_filters
+
+    @field_validator("z_dims")
+    def validate_z_dims(cls, z_dims: tuple) -> tuple:
+        """
+        Validate the z_dims.
+
+        Parameters
+        ----------
+        z_dims : tuple
+            Tuple of z dimensions.
+
+        Returns
+        -------
+        tuple
+            Validated z dimensions.
+
+        Raises
+        ------
+        ValueError
+            If fewer than 2 z dimensions are provided.
+        """
+        if len(z_dims) < 2:
+            raise ValueError(
+                f"Number of z dimensions must be at least 2 (got {len(z_dims)})."
+            )
+
+        return z_dims
+
+    @model_validator(mode="after")
+    def validate_multiscale_count(cls, self: Self) -> Self:
+        """
+        Validate the multiscale count.
+
+        Parameters
+        ----------
+        self : Self
+            The model.
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        # if self.multiscale_count != 0:
+        #     if self.multiscale_count != len(self.z_dims) - 1:
+        #         raise ValueError(
+        #             f"Multiscale count must be 0 or equal to the number of Z "
+        #             f"dims - 1 (got {self.multiscale_count} and {len(self.z_dims)})."
+        #         )
+
+        return self
+
+    def set_3D(self, is_3D: bool) -> None:
+        """
+        Set 3D model by setting the `conv_dims` parameters.
+
+        Parameters
+        ----------
+        is_3D : bool
+            Whether the algorithm is 3D or not.
+        """
+        raise NotImplementedError("VAE is not implemented yet.")
+
+    def is_3D(self) -> bool:
+        """
+        Return whether the model is 3D or not.
+
+        Returns
+        -------
+        bool
+            Whether the model is 3D or not.
+        """
+        raise NotImplementedError("VAE is not implemented yet.")
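LVAEModel validates with validate_assignment=True, so the field validators above fire on construction and on attribute assignment alike. A small sketch of that behaviour (import path as exported by careamics.config.architectures earlier in this diff):

    from pydantic import ValidationError

    from careamics.config.architectures import LVAEModel

    model = LVAEModel(architecture="LVAE")  # defaults: 64 filters, z_dims=[128]*4

    try:
        model.encoder_n_filters = 63  # odd filter counts are rejected
    except ValidationError as err:
        print(err)

    try:
        LVAEModel(architecture="LVAE", z_dims=[128])  # fewer than 2 z dims rejected
    except ValidationError as err:
        print(err)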
careamics/config/configuration_factory.py CHANGED
@@ -2,10 +2,10 @@
 
 from typing import Any, Dict, List, Literal, Optional
 
-from .algorithm_model import AlgorithmConfig
 from .architectures import UNetModel
 from .configuration_model import Configuration
 from .data_model import DataConfig
+from .fcn_algorithm_model import FCNAlgorithmConfig
 from .support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -16,7 +16,9 @@ from .support import (
 )
 from .training_model import TrainingConfig
 
+# TODO rename ?
 def _create_supervised_configuration(
+    algorithm_type: Literal["fcn"],
     algorithm: Literal["care", "n2n"],
     experiment_name: str,
     data_type: Literal["array", "tiff", "custom"],
@@ -37,6 +39,8 @@
 
     Parameters
    ----------
+    algorithm_type : Literal["fcn"]
+        Type of the algorithm.
    algorithm : Literal["care", "n2n"]
        Algorithm to use.
    experiment_name : str
@@ -97,7 +101,8 @@
     )
 
     # algorithm model
-    algorithm = AlgorithmConfig(
+    algorithm = FCNAlgorithmConfig(
+        algorithm_type=algorithm_type,
         algorithm=algorithm,
         loss=loss,
         model=unet_model,
@@ -215,6 +220,7 @@
         n_channels_out = n_channels_in
 
     return _create_supervised_configuration(
+        algorithm_type="fcn",
         algorithm="care",
         experiment_name=experiment_name,
         data_type=data_type,
@@ -304,6 +310,7 @@
         n_channels_out = n_channels_in
 
     return _create_supervised_configuration(
+        algorithm_type="fcn",
        algorithm="n2n",
        experiment_name=experiment_name,
        data_type=data_type,
@@ -514,7 +521,8 @@
     )
 
     # algorithm model
-    algorithm = AlgorithmConfig(
+    algorithm = FCNAlgorithmConfig(
+        algorithm_type="fcn",
         algorithm=SupportedAlgorithm.N2V.value,
         loss=SupportedLoss.N2V.value,
         model=unet_model,
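All three factories now stamp algorithm_type="fcn" into the algorithm model they build, so configurations created through them satisfy the new discriminated union without user intervention. A hedged usage sketch (only the algorithm_type plumbing is visible in this diff; the remaining parameter names are assumed from the factory's public signature):

    from careamics.config import create_care_configuration

    config = create_care_configuration(
        experiment_name="care_example",
        data_type="tiff",
        axes="YX",
        patch_size=[64, 64],
        batch_size=8,
        num_epochs=20,
    )
    assert config.algorithm_config.algorithm_type == "fcn"  # added by the factory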
careamics/config/configuration_model.py CHANGED
@@ -9,11 +9,11 @@ from typing import Literal, Union
 
 import yaml
 from bioimageio.spec.generic.v0_3 import CiteEntry
-from pydantic import BaseModel, ConfigDict, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from typing_extensions import Self
 
-from .algorithm_model import AlgorithmConfig
 from .data_model import DataConfig
+from .fcn_algorithm_model import FCNAlgorithmConfig
 from .references import (
     CARE,
     CUSTOM,
@@ -39,6 +39,7 @@ from .training_model import TrainingConfig
 from .transformations.n2v_manipulate_model import (
     N2VManipulateModel,
 )
+from .vae_algorithm_model import VAEAlgorithmConfig
 
 
 class Configuration(BaseModel):
@@ -123,6 +124,7 @@ class Configuration(BaseModel):
     >>> config_dict = {
     ...     "experiment_name": "N2V_experiment",
     ...     "algorithm_config": {
+    ...         "algorithm_type": "fcn",
     ...         "algorithm": "n2v",
     ...         "loss": "n2v",
     ...         "model": {
@@ -155,7 +157,9 @@ class Configuration(BaseModel):
     """Name of the experiment, used to name logs and checkpoints."""
 
     # Sub-configurations
-    algorithm_config: AlgorithmConfig
+    algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field(
+        discriminator="algorithm_type"
+    )
     """Algorithm configuration, holding all parameters required to configure the
     model."""
 
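Replacing the plain AlgorithmConfig field with a discriminated union means Pydantic now routes the raw algorithm_config dict to FCNAlgorithmConfig or VAEAlgorithmConfig purely on the value of the algorithm_type key. A self-contained sketch of the same pattern (the Holder class is illustrative, standing in for Configuration):

    from typing import Union

    from pydantic import BaseModel, Field

    from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig

    class Holder(BaseModel):
        # same declaration as Configuration.algorithm_config in the hunk above
        algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field(
            discriminator="algorithm_type"
        )

    holder = Holder(
        algorithm_config={
            "algorithm_type": "fcn",  # discriminator value selects the FCN model
            "algorithm": "n2v",
            "loss": "n2v",
            "model": {"architecture": "UNet"},
        }
    )
    assert isinstance(holder.algorithm_config, FCNAlgorithmConfig)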
careamics/config/data_model.py CHANGED
@@ -5,12 +5,14 @@ from __future__ import annotations
 from pprint import pformat
 from typing import Any, Literal, Optional, Union
 
+import numpy as np
 from numpy.typing import NDArray
 from pydantic import (
     BaseModel,
     ConfigDict,
     Discriminator,
     Field,
+    PlainSerializer,
     field_validator,
     model_validator,
 )
@@ -22,6 +24,30 @@ from .transformations.xy_flip_model import XYFlipModel
 from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
 from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
 
+
+def np_float_to_scientific_str(x: float) -> str:
+    """Return a string scientific representation of a float.
+
+    In particular, this method is used to serialize floats to strings, allowing
+    numpy.float32 to be passed in the Pydantic model and written to a yaml file as str.
+
+    Parameters
+    ----------
+    x : float
+        Input value.
+
+    Returns
+    -------
+    str
+        Scientific string representation of the input value.
+    """
+    return np.format_float_scientific(x, precision=7)
+
+
+Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type=str)]
+"""Annotated float type, used to serialize floats to strings."""
+
+
 TRANSFORMS_UNION = Annotated[
     Union[
         XYFlipModel,
@@ -30,6 +56,7 @@ TRANSFORMS_UNION = Annotated[
     ],
     Discriminator("name"),  # used to tell the different transform models apart
 ]
+"""Available transforms in CAREamics."""
 
 
 class DataConfig(BaseModel):
@@ -94,20 +121,20 @@ class DataConfig(BaseModel):
     """Batch size for training."""
 
     # Optional fields
-    image_means: Optional[list[float]] = Field(
+    image_means: Optional[list[Float]] = Field(
         default=None, min_length=0, max_length=32
     )
     """Means of the data across channels, used for normalization."""
 
-    image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
+    image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
     """Standard deviations of the data across channels, used for normalization."""
 
-    target_means: Optional[list[float]] = Field(
+    target_means: Optional[list[Float]] = Field(
         default=None, min_length=0, max_length=32
     )
     """Means of the target data across channels, used for normalization."""
 
-    target_stds: Optional[list[float]] = Field(
+    target_stds: Optional[list[Float]] = Field(
         default=None, min_length=0, max_length=32
     )
     """Standard deviations of the target data across channels, used for
@@ -265,9 +292,7 @@ class DataConfig(BaseModel):
         elif (self.image_means is not None and self.image_stds is not None) and (
             len(self.image_means) != len(self.image_stds)
         ):
-            raise ValueError(
-                "Mean and std must be specified for each " "input channel."
-            )
+            raise ValueError("Mean and std must be specified for each input channel.")
 
         if (self.target_means and not self.target_stds) or (
             self.target_stds and not self.target_means
@@ -380,7 +405,7 @@
 
         Parameters
         ----------
-        image_means : numpy.ndarray ,tuple or list
+        image_means : numpy.ndarray, tuple or list
            Mean values for normalization.
        image_stds : numpy.ndarray, tuple or list
            Standard deviation values for normalization.
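The new Float alias changes serialization only, not validation: values are stored as regular floats but dumped as scientific-notation strings, which is what lets numpy scalars survive the round trip to YAML. A standalone sketch of the same Annotated/PlainSerializer pattern (mirroring np_float_to_scientific_str above; the Stats model is illustrative):

    import numpy as np
    from pydantic import BaseModel, PlainSerializer
    from typing_extensions import Annotated

    Float = Annotated[
        float,
        PlainSerializer(
            lambda x: np.format_float_scientific(x, precision=7), return_type=str
        ),
    ]

    class Stats(BaseModel):
        means: list[Float]

    stats = Stats(means=[0.25])
    print(stats.model_dump())  # {'means': ['2.5e-01']}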
careamics/config/{algorithm_model.py → fcn_algorithm_model.py} RENAMED
@@ -1,6 +1,4 @@
-"""Algorithm configuration."""
-
-from __future__ import annotations
+"""Module containing `FCNAlgorithmConfig` class."""
 
 from pprint import pformat
 from typing import Literal, Union
@@ -8,11 +6,11 @@ from typing import Literal, Union
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Self
 
-from .architectures import CustomModel, UNetModel, VAEModel
-from .optimizer_models import LrSchedulerModel, OptimizerModel
+from careamics.config.architectures import CustomModel, UNetModel
+from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
 
 
-class AlgorithmConfig(BaseModel):
+class FCNAlgorithmConfig(BaseModel):
     """Algorithm configuration.
 
     This Pydantic model validates the parameters governing the components of the
@@ -30,7 +28,7 @@ class AlgorithmConfig(BaseModel):
         Algorithm to use.
     loss : Literal["n2v", "mae", "mse"]
         Loss function to use.
-    model : Union[UNetModel, VAEModel, CustomModel]
+    model : Union[UNetModel, LVAEModel, CustomModel]
         Model architecture to use.
     optimizer : OptimizerModel, optional
         Optimizer to use.
@@ -47,66 +45,51 @@ class AlgorithmConfig(BaseModel):
     Examples
     --------
     Minimum example:
-    >>> from careamics.config import AlgorithmConfig
+    >>> from careamics.config import FCNAlgorithmConfig
     >>> config_dict = {
     ...     "algorithm": "n2v",
+    ...     "algorithm_type": "fcn",
     ...     "loss": "n2v",
     ...     "model": {
     ...         "architecture": "UNet",
     ...     }
     ... }
-    >>> config = AlgorithmConfig(**config_dict)
-
-    Using a custom model:
-    >>> from torch import nn, ones
-    >>> from careamics.config import AlgorithmConfig, register_model
-    ...
-    >>> @register_model(name="linear_model")
-    ... class LinearModel(nn.Module):
-    ...     def __init__(self, in_features, out_features, *args, **kwargs):
-    ...         super().__init__()
-    ...         self.in_features = in_features
-    ...         self.out_features = out_features
-    ...         self.weight = nn.Parameter(ones(in_features, out_features))
-    ...         self.bias = nn.Parameter(ones(out_features))
-    ...     def forward(self, input):
-    ...         return (input @ self.weight) + self.bias
-    ...
-    >>> config_dict = {
-    ...     "algorithm": "custom",
-    ...     "loss": "mse",
-    ...     "model": {
-    ...         "architecture": "Custom",
-    ...         "name": "linear_model",
-    ...         "in_features": 10,
-    ...         "out_features": 5,
-    ...     }
-    ... }
-    >>> config = AlgorithmConfig(**config_dict)
+    >>> config = FCNAlgorithmConfig(**config_dict)
     """
 
     # Pydantic class configuration
     model_config = ConfigDict(
         protected_namespaces=(),  # allows to use model_* as a field name
         validate_assignment=True,
+        extra="allow",
     )
 
     # Mandatory fields
-    algorithm: Literal["n2v", "care", "n2n", "custom"]  # defined in SupportedAlgorithm
-    """Name of the algorithm, as defined in SupportedAlgorithm."""
+    # defined in SupportedAlgorithm
+    algorithm_type: Literal["fcn"]
+    """Algorithm type must be `fcn` (fully convolutional network) to differentiate this
+    configuration from LVAE."""
+
+    algorithm: Literal["n2v", "care", "n2n", "custom"]
+    """Name of the algorithm, as defined in SupportedAlgorithm. Use `custom` for custom
+    model architecture."""
 
     loss: Literal["n2v", "mae", "mse"]
     """Loss function to use, as defined in SupportedLoss."""
 
-    model: Union[UNetModel, VAEModel, CustomModel] = Field(discriminator="architecture")
-    """Model architecture to use, defined in SupportedArchitecture."""
+    model: Union[UNetModel, CustomModel] = Field(discriminator="architecture")
+    """Model architecture to use, along with its parameters. Compatible architectures
+    are defined in SupportedArchitecture, and their Pydantic models in
+    `careamics.config.architectures`."""
+    # TODO supported architectures are now all the architectures but does not warn users
+    # of the compatibility with the algorithm
 
     # Optional fields
     optimizer: OptimizerModel = OptimizerModel()
     """Optimizer to use, defined in SupportedOptimizer."""
 
     lr_scheduler: LrSchedulerModel = LrSchedulerModel()
-    """Learning rate scheduler to use, defined in SupportedScheduler."""
+    """Learning rate scheduler to use, defined in SupportedLrScheduler."""
 
     @model_validator(mode="after")
     def algorithm_cross_validation(self: Self) -> Self:
@@ -146,8 +129,10 @@
         if self.loss == "n2v":
             raise ValueError("Supervised algorithms do not support loss `n2v`.")
 
-        if isinstance(self.model, VAEModel):
-            raise ValueError("VAE are currently not implemented.")
+        if (self.algorithm == "custom") != (self.model.architecture == "custom"):
+            raise ValueError(
+                "Algorithm and model architecture must be both `custom` or not."
+            )
 
         return self
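With the rename in place, the minimum docstring example requires the new algorithm_type tag, and the validator swapped in by the last hunk rejects half-custom configurations. A short sketch:

    from careamics.config import FCNAlgorithmConfig

    config = FCNAlgorithmConfig(
        algorithm_type="fcn",
        algorithm="n2v",
        loss="n2v",
        model={"architecture": "UNet"},
    )

    # `algorithm` and `model.architecture` must both be "custom" or neither:
    try:
        FCNAlgorithmConfig(
            algorithm_type="fcn",
            algorithm="custom",
            loss="mse",
            model={"architecture": "UNet"},
        )
    except ValueError as err:
        print(err)  # Algorithm and model architecture must be both `custom` or not.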