careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of careamics might be problematic.
Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0

careamics/config/configuration_model.py
@@ -9,11 +9,11 @@ from typing import Literal, Union
 
 import yaml
 from bioimageio.spec.generic.v0_3 import CiteEntry
-from pydantic import BaseModel, ConfigDict, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from typing_extensions import Self
 
-from .algorithm_model import AlgorithmConfig
 from .data_model import DataConfig
+from .fcn_algorithm_model import FCNAlgorithmConfig
 from .references import (
     CARE,
     CUSTOM,
@@ -39,6 +39,7 @@ from .training_model import TrainingConfig
 from .transformations.n2v_manipulate_model import (
     N2VManipulateModel,
 )
+from .vae_algorithm_model import VAEAlgorithmConfig
 
 
 class Configuration(BaseModel):
@@ -155,7 +156,9 @@ class Configuration(BaseModel):
     """Name of the experiment, used to name logs and checkpoints."""
 
     # Sub-configurations
-    algorithm_config: AlgorithmConfig
+    algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field(
+        discriminator="algorithm"
+    )
     """Algorithm configuration, holding all parameters required to configure the
     model."""
 
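
The `algorithm_config` field becomes a Pydantic discriminated union: the `algorithm` key of the incoming dictionary decides whether an `FCNAlgorithmConfig` or a `VAEAlgorithmConfig` is instantiated. A minimal sketch of the mechanism with plain Pydantic and simplified stand-in classes (the field sets below are illustrative, not the actual CAREamics models):

from typing import Literal, Union

from pydantic import BaseModel, Field


class FCNAlgorithmConfig(BaseModel):
    # simplified stand-in for careamics.config.FCNAlgorithmConfig
    algorithm: Literal["n2v", "care", "n2n", "custom"]
    loss: Literal["n2v", "mae", "mse"]


class VAEAlgorithmConfig(BaseModel):
    # simplified stand-in for careamics.config.VAEAlgorithmConfig
    algorithm: Literal["musplit", "denoisplit"]
    loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]


class Configuration(BaseModel):
    # the discriminator tells Pydantic which union member to validate against
    algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field(
        discriminator="algorithm"
    )


cfg = Configuration(algorithm_config={"algorithm": "n2v", "loss": "n2v"})
assert isinstance(cfg.algorithm_config, FCNAlgorithmConfig)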

careamics/config/data_model.py
@@ -5,31 +5,44 @@ from __future__ import annotations
 from pprint import pformat
 from typing import Any, Literal, Optional, Union
 
+import numpy as np
 from numpy.typing import NDArray
 from pydantic import (
     BaseModel,
     ConfigDict,
-    Discriminator,
     Field,
+    PlainSerializer,
     field_validator,
     model_validator,
 )
 from typing_extensions import Annotated, Self
 
 from .support import SupportedTransform
-from .transformations.n2v_manipulate_model import N2VManipulateModel
-from .transformations.xy_flip_model import XYFlipModel
-from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
+from .transformations import TRANSFORMS_UNION, N2VManipulateModel
 from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
 
-TRANSFORMS_UNION = Annotated[
-    Union[
-        XYFlipModel,
-        XYRandomRotate90Model,
-        N2VManipulateModel,
-    ],
-    Discriminator("name"),  # used to tell the different transform models apart
-]
+
+def np_float_to_scientific_str(x: float) -> str:
+    """Return a string scientific representation of a float.
+
+    In particular, this method is used to serialize floats to strings, allowing
+    numpy.float32 to be passed in the Pydantic model and written to a yaml file as str.
+
+    Parameters
+    ----------
+    x : float
+        Input value.
+
+    Returns
+    -------
+    str
+        Scientific string representation of the input value.
+    """
+    return np.format_float_scientific(x, precision=7)
+
+
+Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type=str)]
+"""Annotated float type, used to serialize floats to strings."""
 
 
 class DataConfig(BaseModel):
@@ -94,20 +107,20 @@ class DataConfig(BaseModel):
     """Batch size for training."""
 
     # Optional fields
-    image_means: Optional[list[float]] = Field(
+    image_means: Optional[list[Float]] = Field(
         default=None, min_length=0, max_length=32
     )
     """Means of the data across channels, used for normalization."""
 
-    image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
+    image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
     """Standard deviations of the data across channels, used for normalization."""
 
-    target_means: Optional[list[float]] = Field(
+    target_means: Optional[list[Float]] = Field(
         default=None, min_length=0, max_length=32
     )
     """Means of the target data across channels, used for normalization."""
 
-    target_stds: Optional[list[float]] = Field(
+    target_stds: Optional[list[Float]] = Field(
         default=None, min_length=0, max_length=32
     )
     """Standard deviations of the target data across channels, used for
@@ -265,9 +278,7 @@ class DataConfig(BaseModel):
         elif (self.image_means is not None and self.image_stds is not None) and (
             len(self.image_means) != len(self.image_stds)
         ):
-            raise ValueError(
-                "Mean and std must be specified for each " "input channel."
-            )
+            raise ValueError("Mean and std must be specified for each input channel.")
 
         if (self.target_means and not self.target_stds) or (
             self.target_stds and not self.target_means
@@ -380,7 +391,7 @@ class DataConfig(BaseModel):
 
         Parameters
         ----------
-        image_means : numpy.ndarray ,tuple or list
+        image_means : numpy.ndarray, tuple or list
            Mean values for normalization.
        image_stds : numpy.ndarray, tuple or list
            Standard deviation values for normalization.
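
The new `Float` annotated type changes only serialization, not validation: statistics such as `image_means` are still floats at runtime, but `model_dump` emits them as scientific-notation strings so that numpy scalars survive the round trip to YAML. A short sketch of that behaviour with a hypothetical `Stats` model standing in for the `DataConfig` fields:

import numpy as np
from pydantic import BaseModel, PlainSerializer
from typing_extensions import Annotated


def np_float_to_scientific_str(x: float) -> str:
    # same helper as in data_model.py: 7-digit scientific notation
    return np.format_float_scientific(x, precision=7)


Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type=str)]


class Stats(BaseModel):  # hypothetical stand-in for the DataConfig statistics fields
    image_means: list[Float]


stats = Stats(image_means=[0.0012345678])
print(stats.image_means)   # [0.0012345678] -- validation still yields plain floats
print(stats.model_dump())  # {'image_means': ['1.2345678e-03']} -- YAML/JSON-safe strings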

careamics/config/{algorithm_model.py → fcn_algorithm_model.py}
@@ -1,6 +1,4 @@
-"""Algorithm configuration."""
-
-from __future__ import annotations
+"""Module containing `FCNAlgorithmConfig` class."""
 
 from pprint import pformat
 from typing import Literal, Union
@@ -8,11 +6,11 @@ from typing import Literal, Union
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Self
 
-from .architectures import CustomModel, UNetModel, VAEModel
-from .optimizer_models import LrSchedulerModel, OptimizerModel
+from careamics.config.architectures import CustomModel, UNetModel
+from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
 
 
-class AlgorithmConfig(BaseModel):
+class FCNAlgorithmConfig(BaseModel):
     """Algorithm configuration.
 
     This Pydantic model validates the parameters governing the components of the
@@ -26,11 +24,11 @@ class AlgorithmConfig(BaseModel):
 
     Attributes
     ----------
-    algorithm : Literal["n2v", "custom"]
+    algorithm : {"n2v", "care", "n2n", "custom"}
         Algorithm to use.
-    loss : Literal["n2v", "mae", "mse"]
+    loss : {"n2v", "mae", "mse"}
        Loss function to use.
-    model : Union[UNetModel, VAEModel, CustomModel]
+    model : UNetModel or CustomModel
        Model architecture to use.
    optimizer : OptimizerModel, optional
        Optimizer to use.
@@ -47,7 +45,7 @@
     Examples
     --------
     Minimum example:
-    >>> from careamics.config import AlgorithmConfig
+    >>> from careamics.config import FCNAlgorithmConfig
     >>> config_dict = {
     ...     "algorithm": "n2v",
     ...     "loss": "n2v",
@@ -55,58 +53,37 @@
     ...         "architecture": "UNet",
     ...     }
     ... }
-    >>> config = AlgorithmConfig(**config_dict)
-
-    Using a custom model:
-    >>> from torch import nn, ones
-    >>> from careamics.config import AlgorithmConfig, register_model
-    ...
-    >>> @register_model(name="linear_model")
-    ... class LinearModel(nn.Module):
-    ...     def __init__(self, in_features, out_features, *args, **kwargs):
-    ...         super().__init__()
-    ...         self.in_features = in_features
-    ...         self.out_features = out_features
-    ...         self.weight = nn.Parameter(ones(in_features, out_features))
-    ...         self.bias = nn.Parameter(ones(out_features))
-    ...     def forward(self, input):
-    ...         return (input @ self.weight) + self.bias
-    ...
-    >>> config_dict = {
-    ...     "algorithm": "custom",
-    ...     "loss": "mse",
-    ...     "model": {
-    ...         "architecture": "Custom",
-    ...         "name": "linear_model",
-    ...         "in_features": 10,
-    ...         "out_features": 5,
-    ...     }
-    ... }
-    >>> config = AlgorithmConfig(**config_dict)
+    >>> config = FCNAlgorithmConfig(**config_dict)
     """
 
     # Pydantic class configuration
     model_config = ConfigDict(
         protected_namespaces=(),  # allows to use model_* as a field name
         validate_assignment=True,
+        extra="allow",
     )
 
     # Mandatory fields
-    algorithm: Literal["n2v", "care", "n2n", "custom"]  # defined in SupportedAlgorithm
-    """Name of the algorithm, as defined in SupportedAlgorithm."""
+    algorithm: Literal["n2v", "care", "n2n", "custom"]
+    """Name of the algorithm, as defined in SupportedAlgorithm. Use `custom` for custom
+    model architecture."""
 
     loss: Literal["n2v", "mae", "mse"]
     """Loss function to use, as defined in SupportedLoss."""
 
-    model: Union[UNetModel, VAEModel, CustomModel] = Field(discriminator="architecture")
-    """Model architecture to use, defined in SupportedArchitecture."""
+    model: Union[UNetModel, CustomModel] = Field(discriminator="architecture")
+    """Model architecture to use, along with its parameters. Compatible architectures
+    are defined in SupportedArchitecture, and their Pydantic models in
+    `careamics.config.architectures`."""
+    # TODO supported architectures are now all the architectures but does not warn users
+    # of the compatibility with the algorithm
 
     # Optional fields
     optimizer: OptimizerModel = OptimizerModel()
     """Optimizer to use, defined in SupportedOptimizer."""
 
     lr_scheduler: LrSchedulerModel = LrSchedulerModel()
-    """Learning rate scheduler to use, defined in SupportedScheduler."""
+    """Learning rate scheduler to use, defined in SupportedLrScheduler."""
 
     @model_validator(mode="after")
     def algorithm_cross_validation(self: Self) -> Self:
@@ -146,8 +123,10 @@
         if self.loss == "n2v":
             raise ValueError("Supervised algorithms do not support loss `n2v`.")
 
-        if isinstance(self.model, VAEModel):
-            raise ValueError("VAE are currently not implemented.")
+        if (self.algorithm == "custom") != (self.model.architecture == "custom"):
+            raise ValueError(
+                "Algorithm and model architecture must be both `custom` or not."
+            )
 
         return self
 
@@ -160,3 +139,14 @@
             Pretty string.
         """
         return pformat(self.model_dump())
+
+    @classmethod
+    def get_compatible_algorithms(cls) -> list[str]:
+        """Get the list of compatible algorithms.
+
+        Returns
+        -------
+        list of str
+            List of compatible algorithms.
+        """
+        return ["n2v", "care", "n2n"]
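
With the VAE branch removed, `FCNAlgorithmConfig` only accepts `UNetModel` or `CustomModel`, and the rewritten cross-validation ties `algorithm="custom"` to a `custom` architecture. A sketch of the intended behaviour, assuming `FCNAlgorithmConfig` is re-exported from `careamics.config` as in the docstring example:

from pydantic import ValidationError

from careamics.config import FCNAlgorithmConfig

# valid: a supported algorithm paired with a UNet architecture
config = FCNAlgorithmConfig(
    algorithm="n2v",
    loss="n2v",
    model={"architecture": "UNet"},
)
print(config.get_compatible_algorithms())  # ['n2v', 'care', 'n2n']

# invalid: `custom` algorithm paired with a non-custom architecture
try:
    FCNAlgorithmConfig(algorithm="custom", loss="mse", model={"architecture": "UNet"})
except ValidationError as err:
    print(err)  # Algorithm and model architecture must be both `custom` or not.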

careamics/config/likelihood_model.py (new file)
@@ -0,0 +1,60 @@
+"""Likelihood model."""
+
+from typing import Literal, Optional, Union
+
+import numpy as np
+import torch
+from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, PlainValidator
+from typing_extensions import Annotated
+
+from careamics.models.lvae.noise_models import (
+    GaussianMixtureNoiseModel,
+    MultiChannelNoiseModel,
+)
+from careamics.utils.serializers import _array_to_json, _to_torch
+
+NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+# TODO: this is a temporary solution to serialize and deserialize tensor fields
+# in pydantic models. Specifically, the aim is to enable saving and loading configs
+# with such tensors to/from JSON files during, resp., training and evaluation.
+Tensor = Annotated[
+    Union[np.ndarray, torch.Tensor],
+    PlainSerializer(_array_to_json, return_type=str),
+    PlainValidator(_to_torch),
+]
+"""Annotated tensor type, used to serialize arrays or tensors to JSON strings
+and deserialize them back to tensors."""
+
+
+class GaussianLikelihoodConfig(BaseModel):
+    """Gaussian likelihood configuration."""
+
+    model_config = ConfigDict(validate_assignment=True)
+
+    predict_logvar: Optional[Literal["pixelwise"]] = None
+    """If `pixelwise`, log-variance is computed for each pixel, else log-variance
+    is not computed."""
+
+    logvar_lowerbound: Union[float, None] = None
+    """The lowerbound value for log-variance."""
+
+
+class NMLikelihoodConfig(BaseModel):
+    """Noise model likelihood configuration."""
+
+    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
+
+    # TODO remove and use as parameters to the likelihood functions?
+    data_mean: Tensor = torch.zeros(1)
+    """The mean of the data, used to unnormalize data for noise model evaluation.
+    Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
+
+    # TODO remove and use as parameters to the likelihood functions?
+    data_std: Tensor = torch.ones(1)
+    """The standard deviation of the data, used to unnormalize data for noise
+    model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
+
+    # TODO: serialization/deserialization for this
+    noise_model: Optional[NoiseModel] = Field(default=None, exclude=True)
+    """The noise model instance used to compute the likelihood."""

careamics/config/nm_model.py (new file)
@@ -0,0 +1,127 @@
+"""Noise models config."""
+
+from pathlib import Path
+from typing import Literal, Optional, Union
+
+import numpy as np
+import torch
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PlainSerializer,
+    PlainValidator,
+    model_validator,
+)
+from typing_extensions import Annotated, Self
+
+from careamics.utils.serializers import _array_to_json, _to_numpy
+
+# TODO: this is a temporary solution to serialize and deserialize array fields
+# in pydantic models. Specifically, the aim is to enable saving and loading configs
+# with such arrays to/from JSON files during, resp., training and evaluation.
+Array = Annotated[
+    Union[np.ndarray, torch.Tensor],
+    PlainSerializer(_array_to_json, return_type=str),
+    PlainValidator(_to_numpy),
+]
+"""Annotated array type, used to serialize arrays or tensors to JSON strings
+and deserialize them back to arrays."""
+
+
+# TODO: add histogram-based noise model
+
+
+class GaussianMixtureNMConfig(BaseModel):
+    """Gaussian mixture noise model."""
+
+    model_config = ConfigDict(
+        protected_namespaces=(),
+        validate_assignment=True,
+        arbitrary_types_allowed=True,
+        extra="allow",
+    )
+    # model type
+    model_type: Literal["GaussianMixtureNoiseModel"]
+
+    path: Optional[Union[Path, str]] = None
+    """Path to the directory where the trained noise model (*.npz) is saved in the
+    `train` method."""
+
+    # TODO remove and use as parameters to the NM functions?
+    signal: Optional[Union[str, Path, np.ndarray]] = Field(default=None, exclude=True)
+    """Path to the file containing signal or respective numpy array."""
+
+    # TODO remove and use as parameters to the NM functions?
+    observation: Optional[Union[str, Path, np.ndarray]] = Field(
+        default=None, exclude=True
+    )
+    """Path to the file containing observation or respective numpy array."""
+
+    weight: Optional[Array] = None
+    """A [3*n_gaussian, n_coeff] sized array containing the values of the weights
+    describing the GMM noise model, with each row corresponding to one
+    parameter of each gaussian, namely [mean, standard deviation and weight].
+    Specifically, rows are organized as follows:
+    - first n_gaussian rows correspond to the means
+    - next n_gaussian rows correspond to the weights
+    - last n_gaussian rows correspond to the standard deviations
+    If `weight=None`, the weight array is initialized using the `min_signal`
+    and `max_signal` parameters."""
+
+    n_gaussian: int = Field(default=1, ge=1)
+    """Number of gaussians used for the GMM."""
+
+    n_coeff: int = Field(default=2, ge=2)
+    """Number of coefficients to describe the functional relationship between gaussian
+    parameters and the signal. 2 implies a linear relationship, 3 implies a quadratic
+    relationship and so on."""
+
+    min_signal: float = Field(default=0.0, ge=0.0)
+    """Minimum signal intensity expected in the image."""
+
+    max_signal: float = Field(default=1.0, ge=0.0)
+    """Maximum signal intensity expected in the image."""
+
+    min_sigma: float = Field(default=200.0, ge=0.0)  # TODO took from nb in pn2v
+    """Minimum value of `standard deviation` allowed in the GMM.
+    All values of `standard deviation` below this are clamped to this value."""
+
+    tol: float = Field(default=1e-10)
+    """Tolerance used in the computation of the noise model likelihood."""
+
+    @model_validator(mode="after")
+    def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
+        """Validate paths provided in the config.
+
+        Returns
+        -------
+        Self
+            Returns itself.
+        """
+        if self.path and (self.signal is not None or self.observation is not None):
+            raise ValueError(
+                "Either only 'path' to pre-trained noise model should be"
+                "provided or only signal and observation in form of paths"
+                "or numpy arrays."
+            )
+        if not self.path and (self.signal is None or self.observation is None):
+            raise ValueError(
+                "Either only 'path' to pre-trained noise model should be"
+                "provided or only signal and observation in form of paths"
+                "or numpy arrays."
+            )
+        return self
+
+
+# The noise model is given by a set of GMMs, one for each target
+# e.g., 2 target channels, 2 noise models
+class MultiChannelNMConfig(BaseModel):
+    """Noise Model config aggregating noise models for single output channels."""
+
+    # TODO: check that this model config is OK
+    model_config = ConfigDict(
+        validate_assignment=True, arbitrary_types_allowed=True, extra="allow"
+    )
+    noise_models: list[GaussianMixtureNMConfig]
+    """List of noise models, one for each target channel."""

careamics/config/optimizer_models.py
@@ -44,7 +44,9 @@ class OptimizerModel(BaseModel):
     )
 
     # Mandatory field
-    name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
+    name: Literal["Adam", "SGD", "Adamax"] = Field(
+        default="Adam", validate_default=True
+    )
     """Name of the optimizer, supported optimizers are defined in SupportedOptimizer."""
 
     # Optional parameters, empty dict default value to allow filtering dictionary
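
`Adamax` is now selectable by name. A small sketch, assuming the optional `parameters` dictionary keeps its existing name and is forwarded to the corresponding torch optimizer:

from careamics.config.optimizer_models import OptimizerModel

default_optimizer = OptimizerModel()  # defaults to Adam
adamax = OptimizerModel(name="Adamax", parameters={"lr": 1e-3})
print(adamax.name)  # Adamax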

careamics/config/support/supported_activations.py
@@ -24,3 +24,4 @@ class SupportedActivation(str, BaseEnum):
     TANH = "Tanh"
     RELU = "ReLU"
     LEAKYRELU = "LeakyReLU"
+    ELU = "ELU"

careamics/config/support/supported_algorithms.py
@@ -6,15 +6,28 @@ from careamics.utils import BaseEnum
 
 
 class SupportedAlgorithm(str, BaseEnum):
-    """Algorithms available in CAREamics.
-
-    # TODO
-    """
+    """Algorithms available in CAREamics."""
 
     N2V = "n2v"
+    """Noise2Void algorithm, a self-supervised approach based on blind denoising."""
+
     CARE = "care"
+    """Content-aware image restoration, a supervised algorithm used for a variety
+    of tasks."""
+
     N2N = "n2n"
+    """Noise2Noise algorithm, a self-supervised denoising scheme based on comparing
+    noisy images of the same sample."""
+
+    MUSPLIT = "musplit"
+    """An image splitting approach based on ladder VAE architectures."""
+
+    DENOISPLIT = "denoisplit"
+    """An image splitting and denoising approach based on ladder VAE architectures."""
+
     CUSTOM = "custom"
+    """Custom algorithm, used for cases where a custom architecture is provided."""
+
     # PN2V = "pn2v"
     # HDN = "hdn"
     # SEG = "segmentation"
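
Because the enum derives from `str`, its members compare equal to the plain strings used in configuration files; the two new entries cover the LVAE-based algorithms. A small illustration:

from careamics.config.support.supported_algorithms import SupportedAlgorithm

# members of the str enum compare equal to the plain strings used in configurations
assert SupportedAlgorithm.MUSPLIT == "musplit"
assert SupportedAlgorithm.DENOISPLIT.value == "denoisplit"
print([algorithm.value for algorithm in SupportedAlgorithm])
# ['n2v', 'care', 'n2n', 'musplit', 'denoisplit', 'custom']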

careamics/config/support/supported_architectures.py
@@ -4,17 +4,14 @@ from careamics.utils import BaseEnum
 
 
 class SupportedArchitecture(str, BaseEnum):
-    """Supported architectures.
+    """Supported architectures."""
 
-    # TODO add details, in particular where to find the API for the models
+    UNET = "UNet"
+    """UNet architecture used with N2V, CARE and Noise2Noise."""
 
-    - UNet: classical UNet compatible with N2V2
-    - VAE: variational Autoencoder
-    - Custom: custom model registered with `@register_model` decorator
-    """
+    LVAE = "LVAE"
+    """Ladder Variational Autoencoder used for muSplit and denoiSplit."""
 
-    UNET = "UNet"
-    VAE = "VAE"
-    CUSTOM = (
-        "Custom"  # TODO all the others tags are small letters, except the architect
-    )
+    CUSTOM = "custom"
+    """Keyword used for custom architectures provided by users and only compatible
+    with `FCNAlgorithmConfig` configuration."""

careamics/config/support/supported_losses.py
@@ -22,6 +22,8 @@ class SupportedLoss(str, BaseEnum):
     N2V = "n2v"
     # PN2V = "pn2v"
     # HDN = "hdn"
+    MUSPLIT = "musplit"
+    DENOISPLIT = "denoisplit"
+    DENOISPLIT_MUSPLIT = "denoisplit_musplit"
     # CE = "ce"
     # DICE = "dice"
-    # CUSTOM = "custom"  # TODO create mechanism for that

careamics/config/support/supported_optimizers.py
@@ -19,7 +19,7 @@ class SupportedOptimizer(str, BaseEnum):
     # Adagrad = "Adagrad"
     ADAM = "Adam"
     # AdamW = "AdamW"
-    # Adamax = "Adamax"
+    ADAMAX = "Adamax"
     # LBFGS = "LBFGS"
     # NAdam = "NAdam"
     # RAdam = "RAdam"

careamics/config/support/supported_transforms.py
@@ -9,3 +9,4 @@ class SupportedTransform(str, BaseEnum):
     XY_FLIP = "XYFlip"
     XY_RANDOM_ROTATE90 = "XYRandomRotate90"
     N2V_MANIPULATE = "N2VManipulate"
+    NORMALIZE = "Normalize"

careamics/config/training_model.py
@@ -3,13 +3,9 @@
 from __future__ import annotations
 
 from pprint import pformat
-from typing import Literal, Optional
+from typing import Literal, Optional, Union
 
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-)
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 from .callback_model import CheckpointModel, EarlyStoppingModel
 
@@ -37,6 +33,20 @@ class TrainingConfig(BaseModel):
     num_epochs: int = Field(default=20, ge=1)
     """Number of epochs, greater than 0."""
 
+    precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = Field(default="32")
+    """Numerical precision"""
+    max_steps: int = Field(default=-1, ge=-1)
+    """Maximum number of steps to train for. -1 means no limit."""
+    check_val_every_n_epoch: int = Field(default=1, ge=1)
+    """Validation step frequency."""
+    enable_progress_bar: bool = Field(default=True)
+    """Whether to enable the progress bar."""
+    accumulate_grad_batches: int = Field(default=1, ge=1)
+    """Number of batches to accumulate gradients over before stepping the optimizer."""
+    gradient_clip_val: Optional[Union[int, float]] = None
+    """The value to which to clip the gradient"""
+    gradient_clip_algorithm: Literal["value", "norm"] = "norm"
+    """The algorithm to use for gradient clipping (see lightning `Trainer`)."""
     logger: Optional[Literal["wandb", "tensorboard"]] = None
     """Logger to use during training. If None, no logger will be used. Available
     loggers are defined in SupportedLogger."""
@@ -70,3 +80,22 @@
             Whether the logger is defined or not.
         """
         return self.logger is not None
+
+    @field_validator("max_steps")
+    @classmethod
+    def validate_max_steps(cls, max_steps: int) -> int:
+        """Validate the max_steps parameter.
+
+        Parameters
+        ----------
+        max_steps : int
+            Maximum number of steps to train for. -1 means no limit.
+
+        Returns
+        -------
+        int
+            Validated max_steps.
+        """
+        if max_steps == 0:
+            raise ValueError("max_steps must be greater than 0. Use -1 for no limit.")
+        return max_steps
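
The added fields mirror arguments of the Lightning `Trainer` (numerical precision, step budget, validation frequency, gradient accumulation and clipping), and `max_steps` now rejects 0 explicitly. A quick sketch of the intended usage (values are illustrative):

from careamics.config.training_model import TrainingConfig

training = TrainingConfig(
    num_epochs=50,
    precision="16-mixed",       # forwarded to the Lightning Trainer
    max_steps=10_000,           # -1 (the default) means no step limit
    accumulate_grad_batches=4,
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",
)

# max_steps=0 is ambiguous and is rejected by the new field validator
try:
    TrainingConfig(max_steps=0)
except ValueError as err:
    print(err)  # ... max_steps must be greater than 0. Use -1 for no limit.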

careamics/config/transformations/__init__.py
@@ -5,11 +5,14 @@ __all__ = [
     "XYFlipModel",
     "NormalizeModel",
     "XYRandomRotate90Model",
-    "XorYFlipModel",
+    "TransformModel",
+    "TRANSFORMS_UNION",
 ]
 
 
 from .n2v_manipulate_model import N2VManipulateModel
 from .normalize_model import NormalizeModel
+from .transform_model import TransformModel
+from .transform_union import TRANSFORMS_UNION
 from .xy_flip_model import XYFlipModel
 from .xy_random_rotate90_model import XYRandomRotate90Model
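
`TRANSFORMS_UNION` now lives next to the transform models and is re-exported from the package; the `name` field acts as the discriminator when a transform list is validated. A rough sketch, assuming `DataConfig.transforms` still accepts a list of these models:

from careamics.config.data_model import DataConfig

config = DataConfig(
    data_type="array",
    axes="YX",
    patch_size=[64, 64],
    batch_size=8,
    transforms=[
        {"name": "XYFlip"},
        {"name": "XYRandomRotate90"},
        {"name": "N2VManipulate", "roi_size": 11},
    ],
)
print([t.name for t in config.transforms])
# ['XYFlip', 'XYRandomRotate90', 'N2VManipulate']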

careamics/config/transformations/n2v_manipulate_model.py
@@ -33,7 +33,7 @@ class N2VManipulateModel(TransformModel):
 
     name: Literal["N2VManipulate"] = "N2VManipulate"
     roi_size: int = Field(default=11, ge=3, le=21)
-    masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=1.0)
+    masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=10.0)
     strategy: Literal["uniform", "median"] = Field(default="uniform")
     struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field(default="none")
     struct_mask_span: int = Field(default=5, ge=3, le=15)