careamics 0.0.4.2__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (43)
  1. careamics/careamist.py +235 -25
  2. careamics/cli/conf.py +19 -30
  3. careamics/cli/main.py +111 -10
  4. careamics/cli/utils.py +29 -0
  5. careamics/config/__init__.py +2 -0
  6. careamics/config/architectures/lvae_model.py +104 -21
  7. careamics/config/configuration_factory.py +49 -45
  8. careamics/config/configuration_model.py +2 -2
  9. careamics/config/likelihood_model.py +7 -6
  10. careamics/config/loss_model.py +56 -0
  11. careamics/config/nm_model.py +24 -24
  12. careamics/config/vae_algorithm_model.py +14 -13
  13. careamics/dataset/dataset_utils/running_stats.py +22 -23
  14. careamics/lightning/lightning_module.py +58 -27
  15. careamics/lightning/train_data_module.py +15 -1
  16. careamics/losses/loss_factory.py +1 -85
  17. careamics/losses/lvae/losses.py +223 -164
  18. careamics/lvae_training/calibration.py +184 -0
  19. careamics/lvae_training/dataset/config.py +2 -2
  20. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  21. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  22. careamics/lvae_training/dataset/types.py +15 -26
  23. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  24. careamics/lvae_training/eval_utils.py +125 -213
  25. careamics/model_io/bioimage/_readme_factory.py +25 -33
  26. careamics/model_io/bioimage/cover_factory.py +171 -0
  27. careamics/model_io/bioimage/model_description.py +39 -17
  28. careamics/model_io/bmz_io.py +36 -25
  29. careamics/models/layers.py +6 -4
  30. careamics/models/lvae/layers.py +348 -975
  31. careamics/models/lvae/likelihoods.py +10 -8
  32. careamics/models/lvae/lvae.py +214 -272
  33. careamics/models/lvae/noise_models.py +179 -112
  34. careamics/models/lvae/stochastic.py +393 -0
  35. careamics/models/lvae/utils.py +82 -73
  36. careamics/utils/lightning_utils.py +57 -0
  37. careamics/utils/serializers.py +2 -0
  38. careamics/utils/torch_utils.py +1 -1
  39. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
  40. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
  41. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
  42. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
  43. {careamics-0.0.4.2.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
careamics/config/architectures/lvae_model.py

@@ -15,9 +15,16 @@ class LVAEModel(ArchitectureModel):
     model_config = ConfigDict(validate_assignment=True, validate_default=True)
 
     architecture: Literal["LVAE"]
-    input_shape: int = Field(default=64, ge=8, le=1024)
-    multiscale_count: int = Field(default=5)  # TODO clarify
-    # 0 - off, len(z_dims) + 1 # TODO can/should be le to z_dims len + 1
+    input_shape: list[int] = Field(default=(64, 64), validate_default=True)
+    """Shape of the input patch (C, Z, Y, X) or (C, Y, X) if the data is 2D."""
+    encoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
+    # TODO make this per hierarchy step ?
+    decoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
+    """Dimensions (2D or 3D) of the convolutional layers."""
+    multiscale_count: int = Field(default=1)
+    # TODO there should be a check for multiscale_count in dataset !!
+
+    # 1 - off, len(z_dims) + 1 # TODO Consider starting from 0
     z_dims: list = Field(default=[128, 128, 128, 128])
     output_channels: int = Field(default=1, ge=1)
     encoder_n_filters: int = Field(default=64, ge=8, le=1024)
@@ -31,10 +38,90 @@ class LVAEModel(ArchitectureModel):
     )
 
     predict_logvar: Literal[None, "pixelwise"] = None
+    analytical_kl: bool = Field(default=False)
 
-    analytical_kl: bool = Field(
-        default=False,
-    )
+    @model_validator(mode="after")
+    def validate_conv_strides(self: Self) -> Self:
+        """
+        Validate the convolutional strides.
+
+        Returns
+        -------
+        list
+            Validated strides.
+
+        Raises
+        ------
+        ValueError
+            If the number of strides is not 2.
+        """
+        if len(self.encoder_conv_strides) < 2 or len(self.encoder_conv_strides) > 3:
+            raise ValueError(
+                f"Strides must be 2 or 3 (got {len(self.encoder_conv_strides)})."
+            )
+
+        if len(self.decoder_conv_strides) < 2 or len(self.decoder_conv_strides) > 3:
+            raise ValueError(
+                f"Strides must be 2 or 3 (got {len(self.decoder_conv_strides)})."
+            )
+
+        # adding 1 to encoder strides for the number of input channels
+        if len(self.input_shape) != len(self.encoder_conv_strides):
+            raise ValueError(
+                f"Input dimensions must be equal to the number of encoder conv strides"
+                f" (got {len(self.input_shape)} and {len(self.encoder_conv_strides)})."
+            )
+
+        if len(self.encoder_conv_strides) < len(self.decoder_conv_strides):
+            raise ValueError(
+                f"Decoder can't be 3D when encoder is 2D (got"
+                f" {len(self.encoder_conv_strides)} and"
+                f"{len(self.decoder_conv_strides)})."
+            )
+
+        if any(s < 1 for s in self.encoder_conv_strides) or any(
+            s < 1 for s in self.decoder_conv_strides
+        ):
+            raise ValueError(
+                f"All strides must be greater or equal to 1"
+                f"(got {self.encoder_conv_strides} and {self.decoder_conv_strides})."
+            )
+        # TODO: validate max stride size ?
+        return self
+
+    @field_validator("input_shape")
+    @classmethod
+    def validate_input_shape(cls, input_shape: list) -> list:
+        """
+        Validate the input shape.
+
+        Parameters
+        ----------
+        input_shape : list
+            Shape of the input patch.
+
+        Returns
+        -------
+        list
+            Validated input shape.
+
+        Raises
+        ------
+        ValueError
+            If the number of dimensions is not 3 or 4.
+        """
+        if len(input_shape) < 2 or len(input_shape) > 3:
+            raise ValueError(
+                f"Number of input dimensions must be 2 for 2D data 3 for 3D"
+                f"(got {len(input_shape)})."
+            )
+
+        if any(s < 1 for s in input_shape):
+            raise ValueError(
+                f"Input shape must be greater than 1 in all dimensions"
+                f"(got {input_shape})."
+            )
+        return input_shape
 
     @field_validator("encoder_n_filters")
     @classmethod
@@ -124,27 +211,20 @@ class LVAEModel(ArchitectureModel):
         return z_dims
 
     @model_validator(mode="after")
-    def validate_multiscale_count(cls, self: Self) -> Self:
+    def validate_multiscale_count(self: Self) -> Self:
         """
         Validate the multiscale count.
 
-        Parameters
-        ----------
-        self : Self
-            The model.
-
         Returns
         -------
         Self
            The validated model.
        """
-        # if self.multiscale_count != 0:
-        #     if self.multiscale_count != len(self.z_dims) - 1:
-        #         raise ValueError(
-        #             f"Multiscale count must be 0 or equal to the number of Z "
-        #             f"dims - 1 (got {self.multiscale_count} and {len(self.z_dims)})."
-        #         )
-
+        if self.multiscale_count < 1 or self.multiscale_count > len(self.z_dims) + 1:
+            raise ValueError(
+                f"Multiscale count must be 1 for LC off or less or equal to the number"
+                f" of Z dims + 1 (got {self.multiscale_count} and {len(self.z_dims)})."
+            )
         return self
 
     def set_3D(self, is_3D: bool) -> None:
@@ -156,7 +236,10 @@ class LVAEModel(ArchitectureModel):
         is_3D : bool
             Whether the algorithm is 3D or not.
         """
-        raise NotImplementedError("VAE is not implemented yet.")
+        if is_3D:
+            self.conv_dims = 3
+        else:
+            self.conv_dims = 2
 
     def is_3D(self) -> bool:
         """
@@ -167,4 +250,4 @@ class LVAEModel(ArchitectureModel):
         bool
             Whether the model is 3D or not.
         """
-        raise NotImplementedError("VAE is not implemented yet.")
+        return self.conv_dims == 3
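For orientation, the sketch below instantiates the reworked LVAEModel config with the new fields introduced above (input_shape as a list, explicit encoder/decoder conv strides, and a 1-based multiscale_count). It is a minimal sketch based only on the fields and validators visible in this diff; defaults and additional fields in the released package may differ.

from careamics.config.architectures import LVAEModel

# 2D configuration: two spatial entries in input_shape and in both stride
# lists, which satisfies the new validate_conv_strides / validate_input_shape
# checks added in this release.
model_config = LVAEModel(
    architecture="LVAE",
    input_shape=[64, 64],
    encoder_conv_strides=[2, 2],
    decoder_conv_strides=[2, 2],
    multiscale_count=1,  # 1 now means lateral contextualization off
    z_dims=[128, 128, 128, 128],
    output_channels=2,
)

# A 3D decoder with a 2D encoder, or multiscale_count > len(z_dims) + 1,
# is rejected by the validators shown above.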
careamics/config/configuration_factory.py

@@ -234,8 +234,8 @@ def _create_supervised_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: int = 1,
-    n_channels_out: int = 1,
+    n_channels_in: Optional[int] = None,
+    n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
     dataloader_params: Optional[dict] = None,
@@ -267,10 +267,10 @@ def _create_supervised_configuration(
         Whether to train all channels independently, by default False.
     loss : Literal["mae", "mse"], optional
         Loss function to use, by default "mae".
-    n_channels_in : int, optional
-        Number of channels in, by default 1.
-    n_channels_out : int, optional
-        Number of channels out, by default 1.
+    n_channels_in : int or None, default=None
+        Number of channels in.
+    n_channels_out : int or None, default=None
+        Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_params : dict, optional
@@ -282,19 +282,29 @@ def _create_supervised_configuration(
     -------
     Configuration
         Configuration for training CARE or Noise2Noise.
+
+    Raises
+    ------
+    ValueError
+        If the number of channels is not specified when using channels.
+    ValueError
+        If the number of channels is specified but "C" is not in the axes.
     """
     # if there are channels, we need to specify their number
-    if "C" in axes and n_channels_in == 1:
-        raise ValueError(
-            f"Number of channels in must be specified when using channels "
-            f"(got {n_channels_in} channel)."
-        )
-    elif "C" not in axes and n_channels_in > 1:
+    if "C" in axes and n_channels_in is None:
+        raise ValueError("Number of channels in must be specified when using channels ")
+    elif "C" not in axes and (n_channels_in is not None and n_channels_in > 1):
         raise ValueError(
             f"C is not present in the axes, but number of channels is specified "
             f"(got {n_channels_in} channels)."
         )
 
+    if n_channels_in is None:
+        n_channels_in = 1
+
+    if n_channels_out is None:
+        n_channels_out = n_channels_in
+
     # augmentations
     transform_list = _list_augmentations(augmentations)
 
@@ -327,8 +337,8 @@ def create_care_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: int = 1,
-    n_channels_out: int = -1,
+    n_channels_in: Optional[int] = None,
+    n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
     dataloader_params: Optional[dict] = None,
@@ -374,16 +384,16 @@ def create_care_configuration(
         and XYRandomRotate90 (in XY) to the images.
     independent_channels : bool, optional
         Whether to train all channels independently, by default False.
-    loss : Literal["mae", "mse"], optional
-        Loss function to use, by default "mae".
-    n_channels_in : int, optional
-        Number of channels in, by default 1.
-    n_channels_out : int, optional
-        Number of channels out, by default -1.
-    logger : Literal["wandb", "tensorboard", "none"], optional
-        Logger to use, by default "none".
-    model_params : dict, optional
-        UNetModel parameters, by default None.
+    loss : Literal["mae", "mse"], default="mae"
+        Loss function to use.
+    n_channels_in : int or None, default=None
+        Number of channels in.
+    n_channels_out : int or None, default=None
+        Number of channels out.
+    logger : Literal["wandb", "tensorboard", "none"], default="none"
+        Logger to use.
+    model_params : dict, default=None
+        UNetModel parameters.
     dataloader_params : dict, optional
         Parameters for the dataloader, see PyTorch notes, by default None.
 
@@ -459,9 +469,6 @@ def create_care_configuration(
     ...     n_channels_out=1 # if applicable
     ... )
     """
-    if n_channels_out == -1:
-        n_channels_out = n_channels_in
-
     return _create_supervised_configuration(
         algorithm="care",
         experiment_name=experiment_name,
@@ -491,8 +498,8 @@ def create_n2n_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     loss: Literal["mae", "mse"] = "mae",
-    n_channels_in: int = 1,
-    n_channels_out: int = -1,
+    n_channels_in: Optional[int] = None,
+    n_channels_out: Optional[int] = None,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_params: Optional[dict] = None,
     dataloader_params: Optional[dict] = None,
@@ -540,10 +547,10 @@ def create_n2n_configuration(
         Whether to train all channels independently, by default False.
     loss : Literal["mae", "mse"], optional
         Loss function to use, by default "mae".
-    n_channels_in : int, optional
-        Number of channels in, by default 1.
-    n_channels_out : int, optional
-        Number of channels out, by default -1.
+    n_channels_in : int or None, default=None
+        Number of channels in.
+    n_channels_out : int or None, default=None
+        Number of channels out.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_params : dict, optional
@@ -623,9 +630,6 @@ def create_n2n_configuration(
     ...     n_channels_out=1 # if applicable
     ... )
     """
-    if n_channels_out == -1:
-        n_channels_out = n_channels_in
-
     return _create_supervised_configuration(
         algorithm="n2n",
         experiment_name=experiment_name,
@@ -655,7 +659,7 @@ def create_n2v_configuration(
     augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
     independent_channels: bool = True,
     use_n2v2: bool = False,
-    n_channels: int = 1,
+    n_channels: Optional[int] = None,
     roi_size: int = 11,
     masked_pixel_percentage: float = 0.2,
     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
@@ -727,8 +731,8 @@ def create_n2v_configuration(
         Whether to train all channels together, by default True.
     use_n2v2 : bool, optional
         Whether to use N2V2, by default False.
-    n_channels : int, optional
-        Number of channels (in and out), by default 1.
+    n_channels : int or None, default=None
+        Number of channels (in and out).
     roi_size : int, optional
         N2V pixel manipulation area, by default 11.
     masked_pixel_percentage : float, optional
@@ -837,17 +841,17 @@ def create_n2v_configuration(
     ... )
     """
     # if there are channels, we need to specify their number
-    if "C" in axes and n_channels == 1:
-        raise ValueError(
-            f"Number of channels must be specified when using channels "
-            f"(got {n_channels} channel)."
-        )
-    elif "C" not in axes and n_channels > 1:
+    if "C" in axes and n_channels is None:
+        raise ValueError("Number of channels must be specified when using channels.")
+    elif "C" not in axes and (n_channels is not None and n_channels > 1):
         raise ValueError(
             f"C is not present in the axes, but number of channels is specified "
             f"(got {n_channels} channel)."
        )
 
+    if n_channels is None:
+        n_channels = 1
+
     # augmentations
     transform_list = _list_augmentations(augmentations)
 
careamics/config/configuration_model.py

@@ -565,8 +565,8 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
     config : Configuration
         Configuration to save.
     path : str or Path
-        Path to a existing folder in which to save the configuration or to an existing
-        configuration file.
+        Path to a existing folder in which to save the configuration, or to a valid
+        configuration file path (uses a .yml or .yaml extension).
 
     Returns
     -------
careamics/config/likelihood_model.py

@@ -4,7 +4,7 @@ from typing import Literal, Optional, Union
 
 import numpy as np
 import torch
-from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, PlainValidator
+from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator
 from typing_extensions import Annotated
 
 from careamics.models.lvae.noise_models import (
@@ -41,7 +41,12 @@ class GaussianLikelihoodConfig(BaseModel):
 
 
 class NMLikelihoodConfig(BaseModel):
-    """Noise model likelihood configuration."""
+    """Noise model likelihood configuration.
+
+    NOTE: we need to define the data mean and std here because the noise model
+    is trained on not-normalized data. Hence, we need to unnormalize the model
+    output to compute the noise model likelihood.
+    """
 
     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
 
@@ -54,7 +59,3 @@ class NMLikelihoodConfig(BaseModel):
     data_std: Tensor = torch.ones(1)
     """The standard deviation of the data, used to unnormalize data for noise
     model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
-
-    # TODO: serialization/deserialization for this
-    noise_model: Optional[NoiseModel] = Field(default=None, exclude=True)
-    """The noise model instance used to compute the likelihood."""
careamics/config/loss_model.py (new file)

@@ -0,0 +1,56 @@
+"""Configuration classes for LVAE losses."""
+
+from typing import Literal
+
+from pydantic import BaseModel, ConfigDict
+
+
+class KLLossConfig(BaseModel):
+    """KL loss configuration."""
+
+    model_config = ConfigDict(validate_assignment=True, validate_default=True)
+
+    loss_type: Literal["kl", "kl_restricted"] = "kl"
+    """Type of KL divergence used as KL loss."""
+    rescaling: Literal["latent_dim", "image_dim"] = "latent_dim"
+    """Rescaling of the KL loss."""
+    aggregation: Literal["sum", "mean"] = "mean"
+    """Aggregation of the KL loss across different layers."""
+    free_bits_coeff: float = 0.0
+    """Free bits coefficient for the KL loss."""
+    annealing: bool = False
+    """Whether to apply KL loss annealing."""
+    start: int = -1
+    """Epoch at which KL loss annealing starts."""
+    annealtime: int = 10
+    """Number of epochs for which KL loss annealing is applied."""
+    current_epoch: int = 0
+    """Current epoch in the training loop."""
+
+
+class LVAELossConfig(BaseModel):
+    """LVAE loss configuration."""
+
+    model_config = ConfigDict(
+        validate_assignment=True, validate_default=True, arbitrary_types_allowed=True
+    )
+
+    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"]
+    """Type of loss to use for LVAE."""
+
+    reconstruction_weight: float = 1.0
+    """Weight for the reconstruction loss in the total net loss
+    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+    kl_weight: float = 1.0
+    """Weight for the KL loss in the total net loss.
+    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+    musplit_weight: float = 0.1
+    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
+    denoisplit_weight: float = 0.9
+    """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
+    kl_params: KLLossConfig = KLLossConfig()
+    """KL loss configuration."""
+
+    # TODO: remove?
+    non_stochastic: bool = False
+    """Whether to sample latents and compute KL."""
careamics/config/nm_model.py

@@ -11,9 +11,8 @@ from pydantic import (
     Field,
     PlainSerializer,
     PlainValidator,
-    model_validator,
 )
-from typing_extensions import Annotated, Self
+from typing_extensions import Annotated
 
 from careamics.utils.serializers import _array_to_json, _to_numpy
 
@@ -90,28 +89,29 @@ class GaussianMixtureNMConfig(BaseModel):
     tol: float = Field(default=1e-10)
     """Tolerance used in the computation of the noise model likelihood."""
 
-    @model_validator(mode="after")
-    def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
-        """Validate paths provided in the config.
-
-        Returns
-        -------
-        Self
-            Returns itself.
-        """
-        if self.path and (self.signal is not None or self.observation is not None):
-            raise ValueError(
-                "Either only 'path' to pre-trained noise model should be"
-                "provided or only signal and observation in form of paths"
-                "or numpy arrays."
-            )
-        if not self.path and (self.signal is None or self.observation is None):
-            raise ValueError(
-                "Either only 'path' to pre-trained noise model should be"
-                "provided or only signal and observation in form of paths"
-                "or numpy arrays."
-            )
-        return self
+    # @model_validator(mode="after")
+    # def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
+    #     """Validate paths provided in the config.
+
+    #     Returns
+    #     -------
+    #     Self
+    #         Returns itself.
+    #     """
+    #     if self.path and (self.signal is not None or self.observation is not None):
+    #         raise ValueError(
+    #             "Either only 'path' to pre-trained noise model should be"
+    #             "provided or only signal and observation in form of paths"
+    #             "or numpy arrays."
+    #         )
+    #     if not self.path and (self.signal is None or self.observation is None):
+    #         raise ValueError(
+    #             "Either only 'path' to pre-trained noise model should be"
+    #             "provided or only signal and observation in form of paths"
+    #             "or numpy arrays."
+    #         )
+    #     return self
+    # TODO revisit validation
 
 
 # The noise model is given by a set of GMMs, one for each target
careamics/config/vae_algorithm_model.py

@@ -12,6 +12,7 @@ from careamics.config.support import SupportedAlgorithm, SupportedLoss
 
 from .architectures import CustomModel, LVAEModel
 from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
+from .loss_model import LVAELossConfig
 from .nm_model import MultiChannelNMConfig
 from .optimizer_models import LrSchedulerModel, OptimizerModel
 
@@ -38,13 +39,13 @@ class VAEAlgorithmConfig(BaseModel):
     # TODO: Use supported Enum classes for typing?
     # - values can still be passed as strings and they will be cast to Enum
     algorithm: Literal["musplit", "denoisplit"]
-    loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
-    model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
 
-    # TODO: these are configs, change naming of attrs
+    # NOTE: these are all configs (pydantic models)
+    loss: LVAELossConfig
+    model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
     noise_model: Optional[MultiChannelNMConfig] = None
-    noise_model_likelihood_model: Optional[NMLikelihoodConfig] = None
-    gaussian_likelihood_model: Optional[GaussianLikelihoodConfig] = None
+    noise_model_likelihood: Optional[NMLikelihoodConfig] = None
+    gaussian_likelihood: Optional[GaussianLikelihoodConfig] = None
 
     # Optional fields
     optimizer: OptimizerModel = OptimizerModel()
@@ -63,13 +64,13 @@ class VAEAlgorithmConfig(BaseModel):
         """
         # musplit
         if self.algorithm == SupportedAlgorithm.MUSPLIT:
-            if self.loss != SupportedLoss.MUSPLIT:
+            if self.loss.loss_type != SupportedLoss.MUSPLIT:
                 raise ValueError(
                     f"Algorithm {self.algorithm} only supports loss `musplit`."
                 )
 
         if self.algorithm == SupportedAlgorithm.DENOISPLIT:
-            if self.loss not in [
+            if self.loss.loss_type not in [
                 SupportedLoss.DENOISPLIT,
                 SupportedLoss.DENOISPLIT_MUSPLIT,
             ]:
@@ -78,16 +79,17 @@ class VAEAlgorithmConfig(BaseModel):
                     "or `denoisplit_musplit."
                 )
             if (
-                self.loss == SupportedLoss.DENOISPLIT
+                self.loss.loss_type == SupportedLoss.DENOISPLIT
                 and self.model.predict_logvar is not None
             ):
                 raise ValueError(
                     "Algorithm `denoisplit` with loss `denoisplit` only supports "
                     "`predict_logvar` as `None`."
                 )
+
             if self.noise_model is None:
                 raise ValueError("Algorithm `denoisplit` requires a noise model.")
-        # TODO: what if algorithm is not musplit or denoisplit (HDN?)
+        # TODO: what if algorithm is not musplit or denoisplit
         return self
 
     @model_validator(mode="after")
@@ -115,14 +117,13 @@ class VAEAlgorithmConfig(BaseModel):
         Self
             The validated model.
         """
-        if self.gaussian_likelihood_model is not None:
+        if self.gaussian_likelihood is not None:
             assert (
-                self.model.predict_logvar
-                == self.gaussian_likelihood_model.predict_logvar
+                self.model.predict_logvar == self.gaussian_likelihood.predict_logvar
            ), (
                f"Model `predict_logvar` ({self.model.predict_logvar}) must match "
                "Gaussian likelihood model `predict_logvar` "
-                f"({self.gaussian_likelihood_model.predict_logvar}).",
+                f"({self.gaussian_likelihood.predict_logvar}).",
            )
        return self
 
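Putting the renames together, the sketch below builds a VAEAlgorithmConfig against the 0.0.5 field names (loss is now an LVAELossConfig, and the likelihood fields drop the _model suffix). The LVAEModel and GaussianLikelihoodConfig arguments are minimal illustrative choices, not the package defaults.

from careamics.config.architectures import LVAEModel
from careamics.config.likelihood_model import GaussianLikelihoodConfig
from careamics.config.loss_model import LVAELossConfig
from careamics.config.vae_algorithm_model import VAEAlgorithmConfig

algorithm_config = VAEAlgorithmConfig(
    algorithm="musplit",
    # previously: loss="musplit" (a plain string)
    loss=LVAELossConfig(loss_type="musplit"),
    model=LVAEModel(
        architecture="LVAE",
        output_channels=2,
        predict_logvar="pixelwise",
    ),
    # previously: gaussian_likelihood_model=...
    # predict_logvar must match the model's, per the validator above.
    gaussian_likelihood=GaussianLikelihoodConfig(predict_logvar="pixelwise"),
)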