careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +239 -28
- careamics/cli/conf.py +19 -31
- careamics/cli/main.py +112 -12
- careamics/cli/utils.py +29 -0
- careamics/config/__init__.py +48 -24
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +109 -21
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +8 -8
- careamics/config/loss_model.py +56 -0
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +24 -25
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/dataset_utils/running_stats.py +22 -23
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +69 -34
- careamics/lightning/train_data_module.py +41 -27
- careamics/losses/__init__.py +3 -3
- careamics/losses/loss_factory.py +1 -85
- careamics/losses/lvae/losses.py +223 -164
- careamics/lvae_training/calibration.py +184 -0
- careamics/lvae_training/dataset/config.py +2 -2
- careamics/lvae_training/dataset/multich_dataset.py +11 -19
- careamics/lvae_training/dataset/multifile_dataset.py +3 -2
- careamics/lvae_training/dataset/types.py +15 -26
- careamics/lvae_training/dataset/utils/index_manager.py +4 -4
- careamics/lvae_training/eval_utils.py +125 -213
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +26 -34
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +56 -34
- careamics/model_io/bmz_io.py +42 -42
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +22 -20
- careamics/models/lvae/layers.py +348 -975
- careamics/models/lvae/likelihoods.py +10 -8
- careamics/models/lvae/lvae.py +214 -275
- careamics/models/lvae/noise_models.py +179 -112
- careamics/models/lvae/stochastic.py +393 -0
- careamics/models/lvae/utils.py +82 -73
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/lightning_utils.py +57 -0
- careamics/utils/logging.py +11 -10
- careamics/utils/serializers.py +2 -0
- careamics/utils/torch_utils.py +8 -8
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
- careamics-0.0.6.dist-info/RECORD +176 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- careamics-0.0.4.2.dist-info/RECORD +0 -165
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Module to get read functions."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Callable,
|
|
4
|
+
from typing import Callable, Protocol, Union
|
|
5
5
|
|
|
6
6
|
from numpy.typing import NDArray
|
|
7
7
|
|
|
@@ -30,7 +30,7 @@ class ReadFunc(Protocol):
|
|
|
30
30
|
"""
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
READ_FUNCS:
|
|
33
|
+
READ_FUNCS: dict[SupportedData, ReadFunc] = {
|
|
34
34
|
SupportedData.TIFF: read_tiff,
|
|
35
35
|
}
|
|
36
36
|
|
careamics/lightning/__init__.py
CHANGED
|
@@ -2,14 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
4
|
"FCNModule",
|
|
5
|
+
"HyperParametersCallback",
|
|
6
|
+
"PredictDataModule",
|
|
7
|
+
"ProgressBarCallback",
|
|
8
|
+
"TrainDataModule",
|
|
5
9
|
"VAEModule",
|
|
6
10
|
"create_careamics_module",
|
|
7
|
-
"TrainDataModule",
|
|
8
|
-
"create_train_datamodule",
|
|
9
|
-
"PredictDataModule",
|
|
10
11
|
"create_predict_datamodule",
|
|
11
|
-
"
|
|
12
|
-
"ProgressBarCallback",
|
|
12
|
+
"create_train_datamodule",
|
|
13
13
|
]
|
|
14
14
|
|
|
15
15
|
from .callbacks import HyperParametersCallback, ProgressBarCallback
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
"""A package for the `PredictionWriterCallback` class and utilities."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
+
"CacheTiles",
|
|
4
5
|
"PredictionWriterCallback",
|
|
5
|
-
"create_write_strategy",
|
|
6
|
-
"WriteStrategy",
|
|
7
6
|
"WriteImage",
|
|
8
|
-
"
|
|
7
|
+
"WriteStrategy",
|
|
9
8
|
"WriteTilesZarr",
|
|
9
|
+
"create_write_strategy",
|
|
10
10
|
"select_write_extension",
|
|
11
11
|
"select_write_func",
|
|
12
12
|
]
|
|
@@ -2,8 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
from collections.abc import Sequence
|
|
5
6
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Optional,
|
|
7
|
+
from typing import Any, Optional, Union
|
|
7
8
|
|
|
8
9
|
from pytorch_lightning import LightningModule, Trainer
|
|
9
10
|
from pytorch_lightning.callbacks import BasePredictionWriter
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Module containing different strategies for writing predictions."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Sequence
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Optional, Protocol,
|
|
5
|
+
from typing import Any, Optional, Protocol, Union
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
from numpy.typing import NDArray
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Progressbar callback."""
|
|
2
2
|
|
|
3
3
|
import sys
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Union
|
|
5
5
|
|
|
6
6
|
from pytorch_lightning import LightningModule, Trainer
|
|
7
7
|
from pytorch_lightning.callbacks import TQDMProgressBar
|
|
@@ -71,7 +71,7 @@ class ProgressBarCallback(TQDMProgressBar):
|
|
|
71
71
|
|
|
72
72
|
def get_metrics(
|
|
73
73
|
self, trainer: Trainer, pl_module: LightningModule
|
|
74
|
-
) ->
|
|
74
|
+
) -> dict[str, Union[int, str, float, dict[str, float]]]:
|
|
75
75
|
"""Override this to customize the metrics displayed in the progress bar.
|
|
76
76
|
|
|
77
77
|
Parameters
|
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
|
6
6
|
import pytorch_lightning as L
|
|
7
7
|
from torch import Tensor, nn
|
|
8
8
|
|
|
9
|
-
from careamics.config import
|
|
9
|
+
from careamics.config import UNetBasedAlgorithm, VAEBasedAlgorithm
|
|
10
10
|
from careamics.config.support import (
|
|
11
11
|
SupportedAlgorithm,
|
|
12
12
|
SupportedArchitecture,
|
|
@@ -14,8 +14,8 @@ from careamics.config.support import (
|
|
|
14
14
|
SupportedOptimizer,
|
|
15
15
|
SupportedScheduler,
|
|
16
16
|
)
|
|
17
|
+
from careamics.config.tile_information import TileInformation
|
|
17
18
|
from careamics.losses import loss_factory
|
|
18
|
-
from careamics.losses.loss_factory import LVAELossParameters
|
|
19
19
|
from careamics.models.lvae.likelihoods import (
|
|
20
20
|
GaussianLikelihood,
|
|
21
21
|
NoiseModelLikelihood,
|
|
@@ -34,6 +34,7 @@ from careamics.utils.torch_utils import get_optimizer, get_scheduler
|
|
|
34
34
|
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
35
35
|
|
|
36
36
|
|
|
37
|
+
# TODO rename to UNetModule
|
|
37
38
|
class FCNModule(L.LightningModule):
|
|
38
39
|
"""
|
|
39
40
|
CAREamics Lightning module.
|
|
@@ -60,7 +61,7 @@ class FCNModule(L.LightningModule):
|
|
|
60
61
|
Learning rate scheduler name.
|
|
61
62
|
"""
|
|
62
63
|
|
|
63
|
-
def __init__(self, algorithm_config: Union[
|
|
64
|
+
def __init__(self, algorithm_config: Union[UNetBasedAlgorithm, dict]) -> None:
|
|
64
65
|
"""Lightning module for CAREamics.
|
|
65
66
|
|
|
66
67
|
This class encapsulates a PyTorch model along with the training, validation,
|
|
@@ -74,7 +75,9 @@ class FCNModule(L.LightningModule):
|
|
|
74
75
|
super().__init__()
|
|
75
76
|
# if loading from a checkpoint, AlgorithmModel needs to be instantiated
|
|
76
77
|
if isinstance(algorithm_config, dict):
|
|
77
|
-
algorithm_config =
|
|
78
|
+
algorithm_config = UNetBasedAlgorithm(
|
|
79
|
+
**algorithm_config
|
|
80
|
+
) # TODO this needs to be updated using the algorithm-specific class
|
|
78
81
|
|
|
79
82
|
# create model and loss function
|
|
80
83
|
self.model: nn.Module = model_factory(algorithm_config.model)
|
|
@@ -164,7 +167,17 @@ class FCNModule(L.LightningModule):
|
|
|
164
167
|
Any
|
|
165
168
|
Model output.
|
|
166
169
|
"""
|
|
167
|
-
|
|
170
|
+
# TODO refactor when redoing datasets
|
|
171
|
+
# hacky way to determine if it is PredictDataModule, otherwise there is a
|
|
172
|
+
# circular import to solve with isinstance
|
|
173
|
+
from_prediction = hasattr(self._trainer.datamodule, "tiled")
|
|
174
|
+
is_tiled = (
|
|
175
|
+
len(batch) > 1
|
|
176
|
+
and isinstance(batch[1], list)
|
|
177
|
+
and isinstance(batch[1][0], TileInformation)
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
if is_tiled:
|
|
168
181
|
x, *aux = batch
|
|
169
182
|
else:
|
|
170
183
|
x = batch
|
|
@@ -172,7 +185,10 @@ class FCNModule(L.LightningModule):
|
|
|
172
185
|
|
|
173
186
|
# apply test-time augmentation if available
|
|
174
187
|
# TODO: probably wont work with batch size > 1
|
|
175
|
-
if
|
|
188
|
+
if (
|
|
189
|
+
from_prediction
|
|
190
|
+
and self._trainer.datamodule.prediction_config.tta_transforms
|
|
191
|
+
):
|
|
176
192
|
tta = ImageRestorationTTA()
|
|
177
193
|
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
178
194
|
augmented_output = []
|
|
@@ -184,9 +200,18 @@ class FCNModule(L.LightningModule):
|
|
|
184
200
|
output = self.model(x)
|
|
185
201
|
|
|
186
202
|
# Denormalize the output
|
|
203
|
+
# TODO incompatible API between predict and train datasets
|
|
187
204
|
denorm = Denormalize(
|
|
188
|
-
image_means=
|
|
189
|
-
|
|
205
|
+
image_means=(
|
|
206
|
+
self._trainer.datamodule.predict_dataset.image_means
|
|
207
|
+
if from_prediction
|
|
208
|
+
else self._trainer.datamodule.train_dataset.image_stats.means
|
|
209
|
+
),
|
|
210
|
+
image_stds=(
|
|
211
|
+
self._trainer.datamodule.predict_dataset.image_stds
|
|
212
|
+
if from_prediction
|
|
213
|
+
else self._trainer.datamodule.train_dataset.image_stats.stds
|
|
214
|
+
),
|
|
190
215
|
)
|
|
191
216
|
denormalized_output = denorm(patch=output.cpu().numpy())
|
|
192
217
|
|
|
@@ -244,7 +269,7 @@ class VAEModule(L.LightningModule):
|
|
|
244
269
|
Learning rate scheduler name.
|
|
245
270
|
"""
|
|
246
271
|
|
|
247
|
-
def __init__(self, algorithm_config: Union[
|
|
272
|
+
def __init__(self, algorithm_config: Union[VAEBasedAlgorithm, dict]) -> None:
|
|
248
273
|
"""Lightning module for CAREamics.
|
|
249
274
|
|
|
250
275
|
This class encapsulates a PyTorch model along with the training, validation,
|
|
@@ -258,7 +283,7 @@ class VAEModule(L.LightningModule):
|
|
|
258
283
|
super().__init__()
|
|
259
284
|
# if loading from a checkpoint, AlgorithmModel needs to be instantiated
|
|
260
285
|
self.algorithm_config = (
|
|
261
|
-
|
|
286
|
+
VAEBasedAlgorithm(**algorithm_config)
|
|
262
287
|
if isinstance(algorithm_config, dict)
|
|
263
288
|
else algorithm_config
|
|
264
289
|
)
|
|
@@ -266,29 +291,27 @@ class VAEModule(L.LightningModule):
|
|
|
266
291
|
# TODO: log algorithm config
|
|
267
292
|
# self.save_hyperparameters(self.algorithm_config.model_dump())
|
|
268
293
|
|
|
269
|
-
# create model
|
|
294
|
+
# create model
|
|
270
295
|
self.model: nn.Module = model_factory(self.algorithm_config.model)
|
|
271
|
-
|
|
296
|
+
|
|
297
|
+
# create loss function
|
|
298
|
+
self.noise_model: Optional[NoiseModel] = noise_model_factory(
|
|
272
299
|
self.algorithm_config.noise_model
|
|
273
300
|
)
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
|
|
281
|
-
self.algorithm_config.noise_model_likelihood_model
|
|
301
|
+
|
|
302
|
+
self.noise_model_likelihood: Optional[NoiseModelLikelihood] = (
|
|
303
|
+
likelihood_factory(
|
|
304
|
+
config=self.algorithm_config.noise_model_likelihood,
|
|
305
|
+
noise_model=self.noise_model,
|
|
306
|
+
)
|
|
282
307
|
)
|
|
283
|
-
|
|
284
|
-
|
|
308
|
+
|
|
309
|
+
self.gaussian_likelihood: Optional[GaussianLikelihood] = likelihood_factory(
|
|
310
|
+
self.algorithm_config.gaussian_likelihood
|
|
285
311
|
)
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
# TODO: musplit/denoisplit weights ?
|
|
290
|
-
) # type: ignore
|
|
291
|
-
self.loss_func = loss_factory(self.algorithm_config.loss)
|
|
312
|
+
|
|
313
|
+
self.loss_parameters = self.algorithm_config.loss
|
|
314
|
+
self.loss_func = loss_factory(self.algorithm_config.loss.loss_type)
|
|
292
315
|
|
|
293
316
|
# save optimizer and lr_scheduler names and parameters
|
|
294
317
|
self.optimizer_name = self.algorithm_config.optimizer.name
|
|
@@ -344,11 +367,16 @@ class VAEModule(L.LightningModule):
|
|
|
344
367
|
out = self.model(x)
|
|
345
368
|
|
|
346
369
|
# Update loss parameters
|
|
347
|
-
|
|
348
|
-
self.loss_parameters.current_epoch = self.current_epoch
|
|
370
|
+
self.loss_parameters.kl_params.current_epoch = self.current_epoch
|
|
349
371
|
|
|
350
372
|
# Compute loss
|
|
351
|
-
loss = self.loss_func(
|
|
373
|
+
loss = self.loss_func(
|
|
374
|
+
model_outputs=out,
|
|
375
|
+
targets=target,
|
|
376
|
+
config=self.loss_parameters,
|
|
377
|
+
gaussian_likelihood=self.gaussian_likelihood,
|
|
378
|
+
noise_model_likelihood=self.noise_model_likelihood,
|
|
379
|
+
)
|
|
352
380
|
|
|
353
381
|
# Logging
|
|
354
382
|
# TODO: implement a separate logging method?
|
|
@@ -376,7 +404,13 @@ class VAEModule(L.LightningModule):
|
|
|
376
404
|
out = self.model(x)
|
|
377
405
|
|
|
378
406
|
# Compute loss
|
|
379
|
-
loss = self.loss_func(
|
|
407
|
+
loss = self.loss_func(
|
|
408
|
+
model_outputs=out,
|
|
409
|
+
targets=target,
|
|
410
|
+
config=self.loss_parameters,
|
|
411
|
+
gaussian_likelihood=self.gaussian_likelihood,
|
|
412
|
+
noise_model_likelihood=self.noise_model_likelihood,
|
|
413
|
+
)
|
|
380
414
|
|
|
381
415
|
# Logging
|
|
382
416
|
# Rename val_loss dict
|
|
@@ -625,9 +659,10 @@ def create_careamics_module(
|
|
|
625
659
|
algorithm_configuration["model"] = model_configuration
|
|
626
660
|
|
|
627
661
|
# call the parent init using an AlgorithmModel instance
|
|
662
|
+
# TODO broken by new configurations!
|
|
628
663
|
algorithm_str = algorithm_configuration["algorithm"]
|
|
629
|
-
if algorithm_str in
|
|
630
|
-
return FCNModule(
|
|
664
|
+
if algorithm_str in UNetBasedAlgorithm.get_compatible_algorithms():
|
|
665
|
+
return FCNModule(UNetBasedAlgorithm(**algorithm_configuration))
|
|
631
666
|
else:
|
|
632
667
|
raise NotImplementedError(
|
|
633
668
|
f"Model {algorithm_str} is not implemented or unknown."
|
|
@@ -2,13 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import Any, Callable, Literal, Optional, Union
|
|
5
|
+
from warnings import warn
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import pytorch_lightning as L
|
|
8
9
|
from numpy.typing import NDArray
|
|
9
|
-
from torch.utils.data import DataLoader
|
|
10
|
+
from torch.utils.data import DataLoader, IterableDataset
|
|
10
11
|
|
|
11
|
-
from careamics.config import DataConfig
|
|
12
|
+
from careamics.config.data import DataConfig, GeneralDataConfig, N2VDataConfig
|
|
12
13
|
from careamics.config.support import SupportedData
|
|
13
14
|
from careamics.config.transformations import TransformModel
|
|
14
15
|
from careamics.dataset.dataset_utils import (
|
|
@@ -118,7 +119,7 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
118
119
|
|
|
119
120
|
def __init__(
|
|
120
121
|
self,
|
|
121
|
-
data_config:
|
|
122
|
+
data_config: GeneralDataConfig,
|
|
122
123
|
train_data: Union[Path, str, NDArray],
|
|
123
124
|
val_data: Optional[Union[Path, str, NDArray]] = None,
|
|
124
125
|
train_data_target: Optional[Union[Path, str, NDArray]] = None,
|
|
@@ -218,7 +219,7 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
218
219
|
)
|
|
219
220
|
|
|
220
221
|
# configuration
|
|
221
|
-
self.data_config:
|
|
222
|
+
self.data_config: GeneralDataConfig = data_config
|
|
222
223
|
self.data_type: str = data_config.data_type
|
|
223
224
|
self.batch_size: int = data_config.batch_size
|
|
224
225
|
self.use_in_memory: bool = use_in_memory
|
|
@@ -446,6 +447,19 @@ class TrainDataModule(L.LightningDataModule):
|
|
|
446
447
|
Any
|
|
447
448
|
Training dataloader.
|
|
448
449
|
"""
|
|
450
|
+
# check because iterable dataset cannot be shuffled
|
|
451
|
+
if not isinstance(self.train_dataset, IterableDataset):
|
|
452
|
+
if ("shuffle" in self.dataloader_params) and (
|
|
453
|
+
not self.dataloader_params["shuffle"]
|
|
454
|
+
):
|
|
455
|
+
warn(
|
|
456
|
+
"Dataloader parameters include `shuffle=False`, this will be "
|
|
457
|
+
"passed to the training dataloader and may result in bad results.",
|
|
458
|
+
stacklevel=1,
|
|
459
|
+
)
|
|
460
|
+
else:
|
|
461
|
+
self.dataloader_params["shuffle"] = True
|
|
462
|
+
|
|
449
463
|
return DataLoader(
|
|
450
464
|
self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
|
|
451
465
|
)
|
|
@@ -488,12 +502,23 @@ def create_train_datamodule(
|
|
|
488
502
|
"""Create a TrainDataModule.
|
|
489
503
|
|
|
490
504
|
This function is used to explicitly pass the parameters usually contained in a
|
|
491
|
-
`
|
|
505
|
+
`GeneralDataConfig` to a TrainDataModule.
|
|
492
506
|
|
|
493
507
|
Since the lightning datamodule has no access to the model, make sure that the
|
|
494
508
|
parameters passed to the datamodule are consistent with the model's requirements and
|
|
495
509
|
are coherent.
|
|
496
510
|
|
|
511
|
+
By default, the train DataModule will be set for Noise2Void if no target data is
|
|
512
|
+
provided. That means that it will add a `N2VManipulateModel` transformation to the
|
|
513
|
+
list of augmentations. The default augmentations are XY flip, XY rotation, and N2V
|
|
514
|
+
pixel manipulation. If you pass a training target data, the default behaviour is to
|
|
515
|
+
train a supervised model. It will use the default XY flip and rotation
|
|
516
|
+
augmentations.
|
|
517
|
+
|
|
518
|
+
To use a different set of transformations, you can pass a list of transforms to
|
|
519
|
+
`transforms`. Note that if you intend to use Noise2Void, you should add
|
|
520
|
+
`N2VManipulateModel` as the last transform in the list of transformations.
|
|
521
|
+
|
|
497
522
|
The data module can be used with Path, str or numpy arrays. In the case of
|
|
498
523
|
numpy arrays, it loads and computes all the patches in memory. For Path and str
|
|
499
524
|
inputs, it calculates the total file size and estimates whether it can fit in
|
|
@@ -504,11 +529,6 @@ def create_train_datamodule(
|
|
|
504
529
|
To use array data, set `data_type` to `array` and pass a numpy array to
|
|
505
530
|
`train_data`.
|
|
506
531
|
|
|
507
|
-
In particular, N2V requires a specific transformation (N2V manipulates), which is
|
|
508
|
-
not compatible with supervised training. The default transformations applied to the
|
|
509
|
-
training patches are defined in `careamics.config.data_model`. To use different
|
|
510
|
-
transformations, pass a list of transforms. See examples for more details.
|
|
511
|
-
|
|
512
532
|
By default, CAREamics only supports types defined in
|
|
513
533
|
`careamics.config.support.SupportedData`. To read custom data types, you can set
|
|
514
534
|
`data_type` to `custom` and provide a function that returns a numpy array from a
|
|
@@ -613,12 +633,12 @@ def create_train_datamodule(
|
|
|
613
633
|
transforms:
|
|
614
634
|
>>> import numpy as np
|
|
615
635
|
>>> from careamics.lightning import create_train_datamodule
|
|
636
|
+
>>> from careamics.config.transformations import XYFlipModel, N2VManipulateModel
|
|
616
637
|
>>> from careamics.config.support import SupportedTransform
|
|
617
638
|
>>> my_array = np.arange(256).reshape(16, 16)
|
|
618
639
|
>>> my_transforms = [
|
|
619
|
-
...
|
|
620
|
-
...
|
|
621
|
-
... }
|
|
640
|
+
... XYFlipModel(flip_y=False),
|
|
641
|
+
... N2VManipulateModel()
|
|
622
642
|
... ]
|
|
623
643
|
>>> data_module = create_train_datamodule(
|
|
624
644
|
... train_data=my_array,
|
|
@@ -645,21 +665,15 @@ def create_train_datamodule(
|
|
|
645
665
|
if transforms is not None:
|
|
646
666
|
data_dict["transforms"] = transforms
|
|
647
667
|
|
|
648
|
-
#
|
|
649
|
-
|
|
668
|
+
# TODO not compatible with HDN, consider adding an argument for n2v/hdn
|
|
669
|
+
if train_target_data is None:
|
|
670
|
+
data_config: GeneralDataConfig = N2VDataConfig(**data_dict)
|
|
671
|
+
assert isinstance(data_config, N2VDataConfig)
|
|
650
672
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
data_config.set_N2V2(use_n2v2)
|
|
656
|
-
data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
|
|
657
|
-
else:
|
|
658
|
-
raise ValueError(
|
|
659
|
-
"Cannot have both supervised training (target data) and "
|
|
660
|
-
"N2V manipulation in the transforms. Pass a list of transforms "
|
|
661
|
-
"that is compatible with your supervised training."
|
|
662
|
-
)
|
|
673
|
+
data_config.set_n2v2(use_n2v2)
|
|
674
|
+
data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
|
|
675
|
+
else:
|
|
676
|
+
data_config = DataConfig(**data_dict)
|
|
663
677
|
|
|
664
678
|
# sanity check on the dataloader parameters
|
|
665
679
|
if "batch_size" in dataloader_params:
|
careamics/losses/__init__.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
"""Losses module."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
+
"denoisplit_loss",
|
|
5
|
+
"denoisplit_musplit_loss",
|
|
4
6
|
"loss_factory",
|
|
5
7
|
"mae_loss",
|
|
6
8
|
"mse_loss",
|
|
7
|
-
"n2v_loss",
|
|
8
|
-
"denoisplit_loss",
|
|
9
9
|
"musplit_loss",
|
|
10
|
-
"
|
|
10
|
+
"n2v_loss",
|
|
11
11
|
]
|
|
12
12
|
|
|
13
13
|
from .fcn.losses import mae_loss, mse_loss, n2v_loss
|
careamics/losses/loss_factory.py
CHANGED
|
@@ -7,7 +7,7 @@ This module contains a factory function for creating loss functions.
|
|
|
7
7
|
from __future__ import annotations
|
|
8
8
|
|
|
9
9
|
from dataclasses import dataclass
|
|
10
|
-
from typing import
|
|
10
|
+
from typing import Callable, Union
|
|
11
11
|
|
|
12
12
|
from torch import Tensor as tensor
|
|
13
13
|
|
|
@@ -15,18 +15,6 @@ from ..config.support import SupportedLoss
|
|
|
15
15
|
from .fcn.losses import mae_loss, mse_loss, n2v_loss
|
|
16
16
|
from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
|
|
17
17
|
|
|
18
|
-
if TYPE_CHECKING:
|
|
19
|
-
from careamics.models.lvae.likelihoods import (
|
|
20
|
-
GaussianLikelihood,
|
|
21
|
-
NoiseModelLikelihood,
|
|
22
|
-
)
|
|
23
|
-
from careamics.models.lvae.noise_models import (
|
|
24
|
-
GaussianMixtureNoiseModel,
|
|
25
|
-
MultiChannelNoiseModel,
|
|
26
|
-
)
|
|
27
|
-
|
|
28
|
-
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
29
|
-
|
|
30
18
|
|
|
31
19
|
@dataclass
|
|
32
20
|
class FCNLossParameters:
|
|
@@ -40,78 +28,6 @@ class FCNLossParameters:
|
|
|
40
28
|
loss_weight: float
|
|
41
29
|
|
|
42
30
|
|
|
43
|
-
@dataclass # TODO why not pydantic?
|
|
44
|
-
class LVAELossParameters:
|
|
45
|
-
"""Dataclass for LVAE loss."""
|
|
46
|
-
|
|
47
|
-
# TODO: refactor in more modular blocks (otherwise it gets messy very easily)
|
|
48
|
-
# e.g., - weights, - kl_params, ...
|
|
49
|
-
|
|
50
|
-
noise_model_likelihood: Optional[NoiseModelLikelihood] = None
|
|
51
|
-
"""Noise model likelihood instance."""
|
|
52
|
-
gaussian_likelihood: Optional[GaussianLikelihood] = None
|
|
53
|
-
"""Gaussian likelihood instance."""
|
|
54
|
-
current_epoch: int = 0
|
|
55
|
-
"""Current epoch in the training loop."""
|
|
56
|
-
reconstruction_weight: float = 1.0
|
|
57
|
-
"""Weight for the reconstruction loss in the total net loss
|
|
58
|
-
(i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
|
|
59
|
-
musplit_weight: float = 0.1
|
|
60
|
-
"""Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
|
|
61
|
-
denoisplit_weight: float = 0.9
|
|
62
|
-
"""Weight for the denoiSplit loss (used in the muSplit-denoiSplit loss)."""
|
|
63
|
-
kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
|
|
64
|
-
"""Type of KL divergence used as KL loss."""
|
|
65
|
-
kl_weight: float = 1.0
|
|
66
|
-
"""Weight for the KL loss in the total net loss.
|
|
67
|
-
(i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
|
|
68
|
-
kl_annealing: bool = False
|
|
69
|
-
"""Whether to apply KL loss annealing."""
|
|
70
|
-
kl_start: int = -1
|
|
71
|
-
"""Epoch at which KL loss annealing starts."""
|
|
72
|
-
kl_annealtime: int = 10
|
|
73
|
-
"""Number of epochs for which KL loss annealing is applied."""
|
|
74
|
-
non_stochastic: bool = False
|
|
75
|
-
"""Whether to sample latents and compute KL."""
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
# TODO: really needed?
|
|
79
|
-
# like it is now, it is difficult to use, we need a way to specify the
|
|
80
|
-
# loss parameters in a more user-friendly way.
|
|
81
|
-
def loss_parameters_factory(
|
|
82
|
-
type: SupportedLoss,
|
|
83
|
-
) -> Union[FCNLossParameters, LVAELossParameters]:
|
|
84
|
-
"""Return loss parameters.
|
|
85
|
-
|
|
86
|
-
Parameters
|
|
87
|
-
----------
|
|
88
|
-
type : SupportedLoss
|
|
89
|
-
Requested loss.
|
|
90
|
-
|
|
91
|
-
Returns
|
|
92
|
-
-------
|
|
93
|
-
Union[FCNLossParameters, LVAELossParameters]
|
|
94
|
-
Loss parameters.
|
|
95
|
-
|
|
96
|
-
Raises
|
|
97
|
-
------
|
|
98
|
-
NotImplementedError
|
|
99
|
-
If the loss is unknown.
|
|
100
|
-
"""
|
|
101
|
-
if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]:
|
|
102
|
-
return FCNLossParameters
|
|
103
|
-
|
|
104
|
-
elif type in [
|
|
105
|
-
SupportedLoss.MUSPLIT,
|
|
106
|
-
SupportedLoss.DENOISPLIT,
|
|
107
|
-
SupportedLoss.DENOISPLIT_MUSPLIT,
|
|
108
|
-
]:
|
|
109
|
-
return LVAELossParameters # it returns the class, not an instance
|
|
110
|
-
|
|
111
|
-
else:
|
|
112
|
-
raise NotImplementedError(f"Loss {type} is not yet supported.")
|
|
113
|
-
|
|
114
|
-
|
|
115
31
|
def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
|
|
116
32
|
"""Return loss function.
|
|
117
33
|
|