careamics 0.0.4.1__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic; see the registry's advisory for this release for more details.
- careamics/careamist.py +235 -25
- careamics/cli/conf.py +19 -30
- careamics/cli/main.py +111 -10
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +2 -0
- careamics/config/architectures/lvae_model.py +104 -21
- careamics/config/configuration_factory.py +49 -45
- careamics/config/configuration_model.py +2 -2
- careamics/config/likelihood_model.py +7 -6
- careamics/config/loss_model.py +56 -0
- careamics/config/nm_model.py +24 -24
- careamics/config/vae_algorithm_model.py +14 -13
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/lightning/lightning_module.py +58 -27
- careamics/lightning/train_data_module.py +15 -1
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/bioimage/_readme_factory.py +25 -33
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +35 -22
- careamics/model_io/bmz_io.py +36 -25
- careamics/models/layers.py +6 -4
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -272
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +1 -1
- {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/METADATA +12 -9
- {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/RECORD +43 -37
- {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/WHEEL +1 -1
- {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.1.dist-info → careamics-0.0.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -15,9 +15,16 @@ class LVAEModel(ArchitectureModel):
|
|
|
15
15
|
model_config = ConfigDict(validate_assignment=True, validate_default=True)
|
|
16
16
|
|
|
17
17
|
architecture: Literal["LVAE"]
|
|
18
|
-
input_shape: int = Field(default=64,
|
|
19
|
-
|
|
20
|
-
|
|
18
|
+
input_shape: list[int] = Field(default=(64, 64), validate_default=True)
|
|
19
|
+
"""Shape of the input patch (C, Z, Y, X) or (C, Y, X) if the data is 2D."""
|
|
20
|
+
encoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
|
|
21
|
+
# TODO make this per hierarchy step ?
|
|
22
|
+
decoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
|
|
23
|
+
"""Dimensions (2D or 3D) of the convolutional layers."""
|
|
24
|
+
multiscale_count: int = Field(default=1)
|
|
25
|
+
# TODO there should be a check for multiscale_count in dataset !!
|
|
26
|
+
|
|
27
|
+
# 1 - off, len(z_dims) + 1 # TODO Consider starting from 0
|
|
21
28
|
z_dims: list = Field(default=[128, 128, 128, 128])
|
|
22
29
|
output_channels: int = Field(default=1, ge=1)
|
|
23
30
|
encoder_n_filters: int = Field(default=64, ge=8, le=1024)
|
|
@@ -31,10 +38,90 @@ class LVAEModel(ArchitectureModel):
|
|
|
31
38
|
)
|
|
32
39
|
|
|
33
40
|
predict_logvar: Literal[None, "pixelwise"] = None
|
|
41
|
+
analytical_kl: bool = Field(default=False)
|
|
34
42
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
43
|
+
@model_validator(mode="after")
|
|
44
|
+
def validate_conv_strides(self: Self) -> Self:
|
|
45
|
+
"""
|
|
46
|
+
Validate the convolutional strides.
|
|
47
|
+
|
|
48
|
+
Returns
|
|
49
|
+
-------
|
|
50
|
+
list
|
|
51
|
+
Validated strides.
|
|
52
|
+
|
|
53
|
+
Raises
|
|
54
|
+
------
|
|
55
|
+
ValueError
|
|
56
|
+
If the number of strides is not 2.
|
|
57
|
+
"""
|
|
58
|
+
if len(self.encoder_conv_strides) < 2 or len(self.encoder_conv_strides) > 3:
|
|
59
|
+
raise ValueError(
|
|
60
|
+
f"Strides must be 2 or 3 (got {len(self.encoder_conv_strides)})."
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
if len(self.decoder_conv_strides) < 2 or len(self.decoder_conv_strides) > 3:
|
|
64
|
+
raise ValueError(
|
|
65
|
+
f"Strides must be 2 or 3 (got {len(self.decoder_conv_strides)})."
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# adding 1 to encoder strides for the number of input channels
|
|
69
|
+
if len(self.input_shape) != len(self.encoder_conv_strides):
|
|
70
|
+
raise ValueError(
|
|
71
|
+
f"Input dimensions must be equal to the number of encoder conv strides"
|
|
72
|
+
f" (got {len(self.input_shape)} and {len(self.encoder_conv_strides)})."
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
if len(self.encoder_conv_strides) < len(self.decoder_conv_strides):
|
|
76
|
+
raise ValueError(
|
|
77
|
+
f"Decoder can't be 3D when encoder is 2D (got"
|
|
78
|
+
f" {len(self.encoder_conv_strides)} and"
|
|
79
|
+
f"{len(self.decoder_conv_strides)})."
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
if any(s < 1 for s in self.encoder_conv_strides) or any(
|
|
83
|
+
s < 1 for s in self.decoder_conv_strides
|
|
84
|
+
):
|
|
85
|
+
raise ValueError(
|
|
86
|
+
f"All strides must be greater or equal to 1"
|
|
87
|
+
f"(got {self.encoder_conv_strides} and {self.decoder_conv_strides})."
|
|
88
|
+
)
|
|
89
|
+
# TODO: validate max stride size ?
|
|
90
|
+
return self
|
|
91
|
+
|
|
92
|
+
@field_validator("input_shape")
|
|
93
|
+
@classmethod
|
|
94
|
+
def validate_input_shape(cls, input_shape: list) -> list:
|
|
95
|
+
"""
|
|
96
|
+
Validate the input shape.
|
|
97
|
+
|
|
98
|
+
Parameters
|
|
99
|
+
----------
|
|
100
|
+
input_shape : list
|
|
101
|
+
Shape of the input patch.
|
|
102
|
+
|
|
103
|
+
Returns
|
|
104
|
+
-------
|
|
105
|
+
list
|
|
106
|
+
Validated input shape.
|
|
107
|
+
|
|
108
|
+
Raises
|
|
109
|
+
------
|
|
110
|
+
ValueError
|
|
111
|
+
If the number of dimensions is not 3 or 4.
|
|
112
|
+
"""
|
|
113
|
+
if len(input_shape) < 2 or len(input_shape) > 3:
|
|
114
|
+
raise ValueError(
|
|
115
|
+
f"Number of input dimensions must be 2 for 2D data 3 for 3D"
|
|
116
|
+
f"(got {len(input_shape)})."
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
if any(s < 1 for s in input_shape):
|
|
120
|
+
raise ValueError(
|
|
121
|
+
f"Input shape must be greater than 1 in all dimensions"
|
|
122
|
+
f"(got {input_shape})."
|
|
123
|
+
)
|
|
124
|
+
return input_shape
|
|
38
125
|
|
|
39
126
|
@field_validator("encoder_n_filters")
|
|
40
127
|
@classmethod
|
|
@@ -124,27 +211,20 @@ class LVAEModel(ArchitectureModel):
|
|
|
124
211
|
return z_dims
|
|
125
212
|
|
|
126
213
|
@model_validator(mode="after")
|
|
127
|
-
def validate_multiscale_count(
|
|
214
|
+
def validate_multiscale_count(self: Self) -> Self:
|
|
128
215
|
"""
|
|
129
216
|
Validate the multiscale count.
|
|
130
217
|
|
|
131
|
-
Parameters
|
|
132
|
-
----------
|
|
133
|
-
self : Self
|
|
134
|
-
The model.
|
|
135
|
-
|
|
136
218
|
Returns
|
|
137
219
|
-------
|
|
138
220
|
Self
|
|
139
221
|
The validated model.
|
|
140
222
|
"""
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
# )
|
|
147
|
-
|
|
223
|
+
if self.multiscale_count < 1 or self.multiscale_count > len(self.z_dims) + 1:
|
|
224
|
+
raise ValueError(
|
|
225
|
+
f"Multiscale count must be 1 for LC off or less or equal to the number"
|
|
226
|
+
f" of Z dims + 1 (got {self.multiscale_count} and {len(self.z_dims)})."
|
|
227
|
+
)
|
|
148
228
|
return self
|
|
149
229
|
|
|
150
230
|
def set_3D(self, is_3D: bool) -> None:
|
|
@@ -156,7 +236,10 @@ class LVAEModel(ArchitectureModel):
|
|
|
156
236
|
is_3D : bool
|
|
157
237
|
Whether the algorithm is 3D or not.
|
|
158
238
|
"""
|
|
159
|
-
|
|
239
|
+
if is_3D:
|
|
240
|
+
self.conv_dims = 3
|
|
241
|
+
else:
|
|
242
|
+
self.conv_dims = 2
|
|
160
243
|
|
|
161
244
|
def is_3D(self) -> bool:
|
|
162
245
|
"""
|
|
@@ -167,4 +250,4 @@ class LVAEModel(ArchitectureModel):
|
|
|
167
250
|
bool
|
|
168
251
|
Whether the model is 3D or not.
|
|
169
252
|
"""
|
|
170
|
-
|
|
253
|
+
return self.conv_dims == 3
|
|
@@ -234,8 +234,8 @@ def _create_supervised_configuration(
|
|
|
234
234
|
augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
|
|
235
235
|
independent_channels: bool = True,
|
|
236
236
|
loss: Literal["mae", "mse"] = "mae",
|
|
237
|
-
n_channels_in: int =
|
|
238
|
-
n_channels_out: int =
|
|
237
|
+
n_channels_in: Optional[int] = None,
|
|
238
|
+
n_channels_out: Optional[int] = None,
|
|
239
239
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
240
240
|
model_params: Optional[dict] = None,
|
|
241
241
|
dataloader_params: Optional[dict] = None,
|
|
@@ -267,10 +267,10 @@ def _create_supervised_configuration(
|
|
|
267
267
|
Whether to train all channels independently, by default False.
|
|
268
268
|
loss : Literal["mae", "mse"], optional
|
|
269
269
|
Loss function to use, by default "mae".
|
|
270
|
-
n_channels_in : int,
|
|
271
|
-
Number of channels in
|
|
272
|
-
n_channels_out : int,
|
|
273
|
-
Number of channels out
|
|
270
|
+
n_channels_in : int or None, default=None
|
|
271
|
+
Number of channels in.
|
|
272
|
+
n_channels_out : int or None, default=None
|
|
273
|
+
Number of channels out.
|
|
274
274
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
275
275
|
Logger to use, by default "none".
|
|
276
276
|
model_params : dict, optional
|
|
@@ -282,19 +282,29 @@ def _create_supervised_configuration(
|
|
|
282
282
|
-------
|
|
283
283
|
Configuration
|
|
284
284
|
Configuration for training CARE or Noise2Noise.
|
|
285
|
+
|
|
286
|
+
Raises
|
|
287
|
+
------
|
|
288
|
+
ValueError
|
|
289
|
+
If the number of channels is not specified when using channels.
|
|
290
|
+
ValueError
|
|
291
|
+
If the number of channels is specified but "C" is not in the axes.
|
|
285
292
|
"""
|
|
286
293
|
# if there are channels, we need to specify their number
|
|
287
|
-
if "C" in axes and n_channels_in
|
|
288
|
-
raise ValueError(
|
|
289
|
-
|
|
290
|
-
f"(got {n_channels_in} channel)."
|
|
291
|
-
)
|
|
292
|
-
elif "C" not in axes and n_channels_in > 1:
|
|
294
|
+
if "C" in axes and n_channels_in is None:
|
|
295
|
+
raise ValueError("Number of channels in must be specified when using channels ")
|
|
296
|
+
elif "C" not in axes and (n_channels_in is not None and n_channels_in > 1):
|
|
293
297
|
raise ValueError(
|
|
294
298
|
f"C is not present in the axes, but number of channels is specified "
|
|
295
299
|
f"(got {n_channels_in} channels)."
|
|
296
300
|
)
|
|
297
301
|
|
|
302
|
+
if n_channels_in is None:
|
|
303
|
+
n_channels_in = 1
|
|
304
|
+
|
|
305
|
+
if n_channels_out is None:
|
|
306
|
+
n_channels_out = n_channels_in
|
|
307
|
+
|
|
298
308
|
# augmentations
|
|
299
309
|
transform_list = _list_augmentations(augmentations)
|
|
300
310
|
|
|
@@ -327,8 +337,8 @@ def create_care_configuration(
|
|
|
327
337
|
augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
|
|
328
338
|
independent_channels: bool = True,
|
|
329
339
|
loss: Literal["mae", "mse"] = "mae",
|
|
330
|
-
n_channels_in: int =
|
|
331
|
-
n_channels_out: int =
|
|
340
|
+
n_channels_in: Optional[int] = None,
|
|
341
|
+
n_channels_out: Optional[int] = None,
|
|
332
342
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
333
343
|
model_params: Optional[dict] = None,
|
|
334
344
|
dataloader_params: Optional[dict] = None,
|
|
@@ -374,16 +384,16 @@ def create_care_configuration(
|
|
|
374
384
|
and XYRandomRotate90 (in XY) to the images.
|
|
375
385
|
independent_channels : bool, optional
|
|
376
386
|
Whether to train all channels independently, by default False.
|
|
377
|
-
loss : Literal["mae", "mse"],
|
|
378
|
-
Loss function to use
|
|
379
|
-
n_channels_in : int,
|
|
380
|
-
Number of channels in
|
|
381
|
-
n_channels_out : int,
|
|
382
|
-
Number of channels out
|
|
383
|
-
logger : Literal["wandb", "tensorboard", "none"],
|
|
384
|
-
Logger to use
|
|
385
|
-
model_params : dict,
|
|
386
|
-
UNetModel parameters
|
|
387
|
+
loss : Literal["mae", "mse"], default="mae"
|
|
388
|
+
Loss function to use.
|
|
389
|
+
n_channels_in : int or None, default=None
|
|
390
|
+
Number of channels in.
|
|
391
|
+
n_channels_out : int or None, default=None
|
|
392
|
+
Number of channels out.
|
|
393
|
+
logger : Literal["wandb", "tensorboard", "none"], default="none"
|
|
394
|
+
Logger to use.
|
|
395
|
+
model_params : dict, default=None
|
|
396
|
+
UNetModel parameters.
|
|
387
397
|
dataloader_params : dict, optional
|
|
388
398
|
Parameters for the dataloader, see PyTorch notes, by default None.
|
|
389
399
|
|
|
@@ -459,9 +469,6 @@ def create_care_configuration(
|
|
|
459
469
|
... n_channels_out=1 # if applicable
|
|
460
470
|
... )
|
|
461
471
|
"""
|
|
462
|
-
if n_channels_out == -1:
|
|
463
|
-
n_channels_out = n_channels_in
|
|
464
|
-
|
|
465
472
|
return _create_supervised_configuration(
|
|
466
473
|
algorithm="care",
|
|
467
474
|
experiment_name=experiment_name,
|
|
@@ -491,8 +498,8 @@ def create_n2n_configuration(
|
|
|
491
498
|
augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
|
|
492
499
|
independent_channels: bool = True,
|
|
493
500
|
loss: Literal["mae", "mse"] = "mae",
|
|
494
|
-
n_channels_in: int =
|
|
495
|
-
n_channels_out: int =
|
|
501
|
+
n_channels_in: Optional[int] = None,
|
|
502
|
+
n_channels_out: Optional[int] = None,
|
|
496
503
|
logger: Literal["wandb", "tensorboard", "none"] = "none",
|
|
497
504
|
model_params: Optional[dict] = None,
|
|
498
505
|
dataloader_params: Optional[dict] = None,
|
|
@@ -540,10 +547,10 @@ def create_n2n_configuration(
|
|
|
540
547
|
Whether to train all channels independently, by default False.
|
|
541
548
|
loss : Literal["mae", "mse"], optional
|
|
542
549
|
Loss function to use, by default "mae".
|
|
543
|
-
n_channels_in : int,
|
|
544
|
-
Number of channels in
|
|
545
|
-
n_channels_out : int,
|
|
546
|
-
Number of channels out
|
|
550
|
+
n_channels_in : int or None, default=None
|
|
551
|
+
Number of channels in.
|
|
552
|
+
n_channels_out : int or None, default=None
|
|
553
|
+
Number of channels out.
|
|
547
554
|
logger : Literal["wandb", "tensorboard", "none"], optional
|
|
548
555
|
Logger to use, by default "none".
|
|
549
556
|
model_params : dict, optional
|
|
@@ -623,9 +630,6 @@ def create_n2n_configuration(
|
|
|
623
630
|
... n_channels_out=1 # if applicable
|
|
624
631
|
... )
|
|
625
632
|
"""
|
|
626
|
-
if n_channels_out == -1:
|
|
627
|
-
n_channels_out = n_channels_in
|
|
628
|
-
|
|
629
633
|
return _create_supervised_configuration(
|
|
630
634
|
algorithm="n2n",
|
|
631
635
|
experiment_name=experiment_name,
|
|
@@ -655,7 +659,7 @@ def create_n2v_configuration(
|
|
|
655
659
|
augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
|
|
656
660
|
independent_channels: bool = True,
|
|
657
661
|
use_n2v2: bool = False,
|
|
658
|
-
n_channels: int =
|
|
662
|
+
n_channels: Optional[int] = None,
|
|
659
663
|
roi_size: int = 11,
|
|
660
664
|
masked_pixel_percentage: float = 0.2,
|
|
661
665
|
struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
|
|
@@ -727,8 +731,8 @@ def create_n2v_configuration(
|
|
|
727
731
|
Whether to train all channels together, by default True.
|
|
728
732
|
use_n2v2 : bool, optional
|
|
729
733
|
Whether to use N2V2, by default False.
|
|
730
|
-
n_channels : int,
|
|
731
|
-
Number of channels (in and out)
|
|
734
|
+
n_channels : int or None, default=None
|
|
735
|
+
Number of channels (in and out).
|
|
732
736
|
roi_size : int, optional
|
|
733
737
|
N2V pixel manipulation area, by default 11.
|
|
734
738
|
masked_pixel_percentage : float, optional
|
|
@@ -837,17 +841,17 @@ def create_n2v_configuration(
|
|
|
837
841
|
... )
|
|
838
842
|
"""
|
|
839
843
|
# if there are channels, we need to specify their number
|
|
840
|
-
if "C" in axes and n_channels
|
|
841
|
-
raise ValueError(
|
|
842
|
-
|
|
843
|
-
f"(got {n_channels} channel)."
|
|
844
|
-
)
|
|
845
|
-
elif "C" not in axes and n_channels > 1:
|
|
844
|
+
if "C" in axes and n_channels is None:
|
|
845
|
+
raise ValueError("Number of channels must be specified when using channels.")
|
|
846
|
+
elif "C" not in axes and (n_channels is not None and n_channels > 1):
|
|
846
847
|
raise ValueError(
|
|
847
848
|
f"C is not present in the axes, but number of channels is specified "
|
|
848
849
|
f"(got {n_channels} channel)."
|
|
849
850
|
)
|
|
850
851
|
|
|
852
|
+
if n_channels is None:
|
|
853
|
+
n_channels = 1
|
|
854
|
+
|
|
851
855
|
# augmentations
|
|
852
856
|
transform_list = _list_augmentations(augmentations)
|
|
853
857
|
|
|
@@ -565,8 +565,8 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
|
|
|
565
565
|
config : Configuration
|
|
566
566
|
Configuration to save.
|
|
567
567
|
path : str or Path
|
|
568
|
-
Path to a existing folder in which to save the configuration or to
|
|
569
|
-
configuration file.
|
|
568
|
+
Path to a existing folder in which to save the configuration, or to a valid
|
|
569
|
+
configuration file path (uses a .yml or .yaml extension).
|
|
570
570
|
|
|
571
571
|
Returns
|
|
572
572
|
-------
|
|
@@ -4,7 +4,7 @@ from typing import Literal, Optional, Union
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import torch
|
|
7
|
-
from pydantic import BaseModel, ConfigDict,
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator
|
|
8
8
|
from typing_extensions import Annotated
|
|
9
9
|
|
|
10
10
|
from careamics.models.lvae.noise_models import (
|
|
@@ -41,7 +41,12 @@ class GaussianLikelihoodConfig(BaseModel):
|
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
class NMLikelihoodConfig(BaseModel):
|
|
44
|
-
"""Noise model likelihood configuration.
|
|
44
|
+
"""Noise model likelihood configuration.
|
|
45
|
+
|
|
46
|
+
NOTE: we need to define the data mean and std here because the noise model
|
|
47
|
+
is trained on not-normalized data. Hence, we need to unnormalize the model
|
|
48
|
+
output to compute the noise model likelihood.
|
|
49
|
+
"""
|
|
45
50
|
|
|
46
51
|
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
|
|
47
52
|
|
|
@@ -54,7 +59,3 @@ class NMLikelihoodConfig(BaseModel):
|
|
|
54
59
|
data_std: Tensor = torch.ones(1)
|
|
55
60
|
"""The standard deviation of the data, used to unnormalize data for noise
|
|
56
61
|
model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
|
|
57
|
-
|
|
58
|
-
# TODO: serialization/deserialization for this
|
|
59
|
-
noise_model: Optional[NoiseModel] = Field(default=None, exclude=True)
|
|
60
|
-
"""The noise model instance used to compute the likelihood."""
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Configuration classes for LVAE losses."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class KLLossConfig(BaseModel):
|
|
9
|
+
"""KL loss configuration."""
|
|
10
|
+
|
|
11
|
+
model_config = ConfigDict(validate_assignment=True, validate_default=True)
|
|
12
|
+
|
|
13
|
+
loss_type: Literal["kl", "kl_restricted"] = "kl"
|
|
14
|
+
"""Type of KL divergence used as KL loss."""
|
|
15
|
+
rescaling: Literal["latent_dim", "image_dim"] = "latent_dim"
|
|
16
|
+
"""Rescaling of the KL loss."""
|
|
17
|
+
aggregation: Literal["sum", "mean"] = "mean"
|
|
18
|
+
"""Aggregation of the KL loss across different layers."""
|
|
19
|
+
free_bits_coeff: float = 0.0
|
|
20
|
+
"""Free bits coefficient for the KL loss."""
|
|
21
|
+
annealing: bool = False
|
|
22
|
+
"""Whether to apply KL loss annealing."""
|
|
23
|
+
start: int = -1
|
|
24
|
+
"""Epoch at which KL loss annealing starts."""
|
|
25
|
+
annealtime: int = 10
|
|
26
|
+
"""Number of epochs for which KL loss annealing is applied."""
|
|
27
|
+
current_epoch: int = 0
|
|
28
|
+
"""Current epoch in the training loop."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class LVAELossConfig(BaseModel):
|
|
32
|
+
"""LVAE loss configuration."""
|
|
33
|
+
|
|
34
|
+
model_config = ConfigDict(
|
|
35
|
+
validate_assignment=True, validate_default=True, arbitrary_types_allowed=True
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"]
|
|
39
|
+
"""Type of loss to use for LVAE."""
|
|
40
|
+
|
|
41
|
+
reconstruction_weight: float = 1.0
|
|
42
|
+
"""Weight for the reconstruction loss in the total net loss
|
|
43
|
+
(i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
|
|
44
|
+
kl_weight: float = 1.0
|
|
45
|
+
"""Weight for the KL loss in the total net loss.
|
|
46
|
+
(i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
|
|
47
|
+
musplit_weight: float = 0.1
|
|
48
|
+
"""Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
|
|
49
|
+
denoisplit_weight: float = 0.9
|
|
50
|
+
"""Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
|
|
51
|
+
kl_params: KLLossConfig = KLLossConfig()
|
|
52
|
+
"""KL loss configuration."""
|
|
53
|
+
|
|
54
|
+
# TODO: remove?
|
|
55
|
+
non_stochastic: bool = False
|
|
56
|
+
"""Whether to sample latents and compute KL."""
|
careamics/config/nm_model.py
CHANGED
|
@@ -11,9 +11,8 @@ from pydantic import (
|
|
|
11
11
|
Field,
|
|
12
12
|
PlainSerializer,
|
|
13
13
|
PlainValidator,
|
|
14
|
-
model_validator,
|
|
15
14
|
)
|
|
16
|
-
from typing_extensions import Annotated
|
|
15
|
+
from typing_extensions import Annotated
|
|
17
16
|
|
|
18
17
|
from careamics.utils.serializers import _array_to_json, _to_numpy
|
|
19
18
|
|
|
@@ -90,28 +89,29 @@ class GaussianMixtureNMConfig(BaseModel):
|
|
|
90
89
|
tol: float = Field(default=1e-10)
|
|
91
90
|
"""Tolerance used in the computation of the noise model likelihood."""
|
|
92
91
|
|
|
93
|
-
@model_validator(mode="after")
|
|
94
|
-
def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
92
|
+
# @model_validator(mode="after")
|
|
93
|
+
# def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
|
|
94
|
+
# """Validate paths provided in the config.
|
|
95
|
+
|
|
96
|
+
# Returns
|
|
97
|
+
# -------
|
|
98
|
+
# Self
|
|
99
|
+
# Returns itself.
|
|
100
|
+
# """
|
|
101
|
+
# if self.path and (self.signal is not None or self.observation is not None):
|
|
102
|
+
# raise ValueError(
|
|
103
|
+
# "Either only 'path' to pre-trained noise model should be"
|
|
104
|
+
# "provided or only signal and observation in form of paths"
|
|
105
|
+
# "or numpy arrays."
|
|
106
|
+
# )
|
|
107
|
+
# if not self.path and (self.signal is None or self.observation is None):
|
|
108
|
+
# raise ValueError(
|
|
109
|
+
# "Either only 'path' to pre-trained noise model should be"
|
|
110
|
+
# "provided or only signal and observation in form of paths"
|
|
111
|
+
# "or numpy arrays."
|
|
112
|
+
# )
|
|
113
|
+
# return self
|
|
114
|
+
# TODO revisit validation
|
|
115
115
|
|
|
116
116
|
|
|
117
117
|
# The noise model is given by a set of GMMs, one for each target
|
|
@@ -12,6 +12,7 @@ from careamics.config.support import SupportedAlgorithm, SupportedLoss
|
|
|
12
12
|
|
|
13
13
|
from .architectures import CustomModel, LVAEModel
|
|
14
14
|
from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig
|
|
15
|
+
from .loss_model import LVAELossConfig
|
|
15
16
|
from .nm_model import MultiChannelNMConfig
|
|
16
17
|
from .optimizer_models import LrSchedulerModel, OptimizerModel
|
|
17
18
|
|
|
@@ -38,13 +39,13 @@ class VAEAlgorithmConfig(BaseModel):
|
|
|
38
39
|
# TODO: Use supported Enum classes for typing?
|
|
39
40
|
# - values can still be passed as strings and they will be cast to Enum
|
|
40
41
|
algorithm: Literal["musplit", "denoisplit"]
|
|
41
|
-
loss: Literal["musplit", "denoisplit", "denoisplit_musplit"]
|
|
42
|
-
model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
|
|
43
42
|
|
|
44
|
-
#
|
|
43
|
+
# NOTE: these are all configs (pydantic models)
|
|
44
|
+
loss: LVAELossConfig
|
|
45
|
+
model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture")
|
|
45
46
|
noise_model: Optional[MultiChannelNMConfig] = None
|
|
46
|
-
|
|
47
|
-
|
|
47
|
+
noise_model_likelihood: Optional[NMLikelihoodConfig] = None
|
|
48
|
+
gaussian_likelihood: Optional[GaussianLikelihoodConfig] = None
|
|
48
49
|
|
|
49
50
|
# Optional fields
|
|
50
51
|
optimizer: OptimizerModel = OptimizerModel()
|
|
@@ -63,13 +64,13 @@ class VAEAlgorithmConfig(BaseModel):
|
|
|
63
64
|
"""
|
|
64
65
|
# musplit
|
|
65
66
|
if self.algorithm == SupportedAlgorithm.MUSPLIT:
|
|
66
|
-
if self.loss != SupportedLoss.MUSPLIT:
|
|
67
|
+
if self.loss.loss_type != SupportedLoss.MUSPLIT:
|
|
67
68
|
raise ValueError(
|
|
68
69
|
f"Algorithm {self.algorithm} only supports loss `musplit`."
|
|
69
70
|
)
|
|
70
71
|
|
|
71
72
|
if self.algorithm == SupportedAlgorithm.DENOISPLIT:
|
|
72
|
-
if self.loss not in [
|
|
73
|
+
if self.loss.loss_type not in [
|
|
73
74
|
SupportedLoss.DENOISPLIT,
|
|
74
75
|
SupportedLoss.DENOISPLIT_MUSPLIT,
|
|
75
76
|
]:
|
|
@@ -78,16 +79,17 @@ class VAEAlgorithmConfig(BaseModel):
|
|
|
78
79
|
"or `denoisplit_musplit."
|
|
79
80
|
)
|
|
80
81
|
if (
|
|
81
|
-
self.loss == SupportedLoss.DENOISPLIT
|
|
82
|
+
self.loss.loss_type == SupportedLoss.DENOISPLIT
|
|
82
83
|
and self.model.predict_logvar is not None
|
|
83
84
|
):
|
|
84
85
|
raise ValueError(
|
|
85
86
|
"Algorithm `denoisplit` with loss `denoisplit` only supports "
|
|
86
87
|
"`predict_logvar` as `None`."
|
|
87
88
|
)
|
|
89
|
+
|
|
88
90
|
if self.noise_model is None:
|
|
89
91
|
raise ValueError("Algorithm `denoisplit` requires a noise model.")
|
|
90
|
-
# TODO: what if algorithm is not musplit or denoisplit
|
|
92
|
+
# TODO: what if algorithm is not musplit or denoisplit
|
|
91
93
|
return self
|
|
92
94
|
|
|
93
95
|
@model_validator(mode="after")
|
|
@@ -115,14 +117,13 @@ class VAEAlgorithmConfig(BaseModel):
|
|
|
115
117
|
Self
|
|
116
118
|
The validated model.
|
|
117
119
|
"""
|
|
118
|
-
if self.
|
|
120
|
+
if self.gaussian_likelihood is not None:
|
|
119
121
|
assert (
|
|
120
|
-
self.model.predict_logvar
|
|
121
|
-
== self.gaussian_likelihood_model.predict_logvar
|
|
122
|
+
self.model.predict_logvar == self.gaussian_likelihood.predict_logvar
|
|
122
123
|
), (
|
|
123
124
|
f"Model `predict_logvar` ({self.model.predict_logvar}) must match "
|
|
124
125
|
"Gaussian likelihood model `predict_logvar` "
|
|
125
|
-
f"({self.
|
|
126
|
+
f"({self.gaussian_likelihood.predict_logvar}).",
|
|
126
127
|
)
|
|
127
128
|
return self
|
|
128
129
|
|