careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (118)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +239 -28
  3. careamics/cli/conf.py +19 -31
  4. careamics/cli/main.py +112 -12
  5. careamics/cli/utils.py +29 -0
  6. careamics/config/__init__.py +48 -24
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +109 -21
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +8 -8
  25. careamics/config/loss_model.py +56 -0
  26. careamics/config/n2n_configuration.py +101 -0
  27. careamics/config/n2v_configuration.py +266 -0
  28. careamics/config/nm_model.py +24 -25
  29. careamics/config/support/__init__.py +7 -7
  30. careamics/config/support/supported_algorithms.py +0 -3
  31. careamics/config/support/supported_architectures.py +0 -4
  32. careamics/config/transformations/__init__.py +10 -4
  33. careamics/config/transformations/transform_model.py +3 -3
  34. careamics/config/transformations/transform_unions.py +42 -0
  35. careamics/config/validators/validator_utils.py +3 -3
  36. careamics/dataset/__init__.py +2 -2
  37. careamics/dataset/dataset_utils/__init__.py +3 -3
  38. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  39. careamics/dataset/dataset_utils/file_utils.py +9 -9
  40. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  41. careamics/dataset/dataset_utils/running_stats.py +22 -23
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  63. careamics/lightning/lightning_module.py +69 -34
  64. careamics/lightning/train_data_module.py +41 -27
  65. careamics/losses/__init__.py +3 -3
  66. careamics/losses/loss_factory.py +1 -85
  67. careamics/losses/lvae/losses.py +223 -164
  68. careamics/lvae_training/calibration.py +184 -0
  69. careamics/lvae_training/dataset/config.py +2 -2
  70. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  71. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  72. careamics/lvae_training/dataset/types.py +15 -26
  73. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  74. careamics/lvae_training/eval_utils.py +125 -213
  75. careamics/model_io/__init__.py +1 -1
  76. careamics/model_io/bioimage/__init__.py +1 -1
  77. careamics/model_io/bioimage/_readme_factory.py +26 -34
  78. careamics/model_io/bioimage/cover_factory.py +171 -0
  79. careamics/model_io/bioimage/model_description.py +56 -34
  80. careamics/model_io/bmz_io.py +42 -42
  81. careamics/model_io/model_io_utils.py +9 -9
  82. careamics/models/layers.py +22 -20
  83. careamics/models/lvae/layers.py +348 -975
  84. careamics/models/lvae/likelihoods.py +10 -8
  85. careamics/models/lvae/lvae.py +214 -275
  86. careamics/models/lvae/noise_models.py +179 -112
  87. careamics/models/lvae/stochastic.py +393 -0
  88. careamics/models/lvae/utils.py +82 -73
  89. careamics/models/model_factory.py +2 -15
  90. careamics/models/unet.py +8 -8
  91. careamics/prediction_utils/__init__.py +1 -1
  92. careamics/prediction_utils/prediction_outputs.py +15 -15
  93. careamics/prediction_utils/stitch_prediction.py +6 -6
  94. careamics/transforms/__init__.py +5 -5
  95. careamics/transforms/compose.py +13 -13
  96. careamics/transforms/n2v_manipulate.py +3 -3
  97. careamics/transforms/pixel_manipulation.py +9 -9
  98. careamics/transforms/xy_random_rotate90.py +4 -4
  99. careamics/utils/__init__.py +5 -5
  100. careamics/utils/context.py +2 -1
  101. careamics/utils/lightning_utils.py +57 -0
  102. careamics/utils/logging.py +11 -10
  103. careamics/utils/serializers.py +2 -0
  104. careamics/utils/torch_utils.py +8 -8
  105. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
  106. careamics-0.0.6.dist-info/RECORD +176 -0
  107. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
  108. careamics/config/architectures/custom_model.py +0 -162
  109. careamics/config/architectures/register_model.py +0 -103
  110. careamics/config/configuration_model.py +0 -603
  111. careamics/config/fcn_algorithm_model.py +0 -152
  112. careamics/config/references/__init__.py +0 -45
  113. careamics/config/references/algorithm_descriptions.py +0 -132
  114. careamics/config/references/references.py +0 -39
  115. careamics/config/transformations/transform_union.py +0 -20
  116. careamics-0.0.4.2.dist-info/RECORD +0 -165
  117. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  118. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0

careamics/file_io/read/get_func.py
@@ -1,7 +1,7 @@
 """Module to get read functions."""
 
 from pathlib import Path
-from typing import Callable, Dict, Protocol, Union
+from typing import Callable, Protocol, Union
 
 from numpy.typing import NDArray
 
@@ -30,7 +30,7 @@ class ReadFunc(Protocol):
     """
 
 
-READ_FUNCS: Dict[SupportedData, ReadFunc] = {
+READ_FUNCS: dict[SupportedData, ReadFunc] = {
    SupportedData.TIFF: read_tiff,
 }
 
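
Note: this release replaces the deprecated `typing` generics with builtins (PEP 585, Python >= 3.9) throughout the package, as here and in `progress_bar_callback.py` below. A minimal before/after sketch of the pattern, using illustrative types rather than the real `SupportedData`/`ReadFunc`:

    # before (0.0.4.2)
    from typing import Dict
    read_funcs_old: Dict[str, object] = {"tiff": object()}

    # after (0.0.6): builtin generic, no typing import required
    read_funcs_new: dict[str, object] = {"tiff": object()}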

careamics/file_io/write/__init__.py
@@ -1,10 +1,10 @@
 """Functions relating to writing image files of different formats."""
 
 __all__ = [
+    "SupportedWriteType",
+    "WriteFunc",
     "get_write_func",
     "write_tiff",
-    "WriteFunc",
-    "SupportedWriteType",
 ]
 
 from .get_func import (

careamics/lightning/__init__.py
@@ -2,14 +2,14 @@
 
 __all__ = [
     "FCNModule",
+    "HyperParametersCallback",
+    "PredictDataModule",
+    "ProgressBarCallback",
+    "TrainDataModule",
     "VAEModule",
     "create_careamics_module",
-    "TrainDataModule",
-    "create_train_datamodule",
-    "PredictDataModule",
     "create_predict_datamodule",
-    "HyperParametersCallback",
-    "ProgressBarCallback",
+    "create_train_datamodule",
 ]
 
 from .callbacks import HyperParametersCallback, ProgressBarCallback

careamics/lightning/callbacks/__init__.py
@@ -2,8 +2,8 @@
 
 __all__ = [
     "HyperParametersCallback",
-    "ProgressBarCallback",
     "PredictionWriterCallback",
+    "ProgressBarCallback",
 ]
 
 from .hyperparameters_callback import HyperParametersCallback

careamics/lightning/callbacks/prediction_writer_callback/__init__.py
@@ -1,12 +1,12 @@
 """A package for the `PredictionWriterCallback` class and utilities."""
 
 __all__ = [
+    "CacheTiles",
     "PredictionWriterCallback",
-    "create_write_strategy",
-    "WriteStrategy",
     "WriteImage",
-    "CacheTiles",
+    "WriteStrategy",
     "WriteTilesZarr",
+    "create_write_strategy",
     "select_write_extension",
     "select_write_func",
 ]

careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py
@@ -2,8 +2,9 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Optional, Union
 
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import BasePredictionWriter

careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py
@@ -1,7 +1,8 @@
 """Module containing different strategies for writing predictions."""
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Optional, Protocol, Sequence, Union
+from typing import Any, Optional, Protocol, Union
 
 import numpy as np
 from numpy.typing import NDArray
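
Note: both prediction-writer modules now import `Sequence` from `collections.abc`, since the `typing.Sequence` alias is deprecated (PEP 585). A minimal sketch; the `print_paths` helper is hypothetical, for illustration only:

    from collections.abc import Sequence  # replaces: from typing import Sequence

    def print_paths(paths: Sequence[str]) -> None:
        for path in paths:
            print(path)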

careamics/lightning/callbacks/progress_bar_callback.py
@@ -1,7 +1,7 @@
 """Progressbar callback."""
 
 import sys
-from typing import Dict, Union
+from typing import Union
 
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import TQDMProgressBar
@@ -71,7 +71,7 @@ class ProgressBarCallback(TQDMProgressBar):
 
     def get_metrics(
         self, trainer: Trainer, pl_module: LightningModule
-    ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
+    ) -> dict[str, Union[int, str, float, dict[str, float]]]:
         """Override this to customize the metrics displayed in the progress bar.
 
         Parameters

careamics/lightning/lightning_module.py
@@ -6,7 +6,7 @@ import numpy as np
 import pytorch_lightning as L
 from torch import Tensor, nn
 
-from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig
+from careamics.config import UNetBasedAlgorithm, VAEBasedAlgorithm
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -14,8 +14,8 @@ from careamics.config.support import (
     SupportedOptimizer,
     SupportedScheduler,
 )
+from careamics.config.tile_information import TileInformation
 from careamics.losses import loss_factory
-from careamics.losses.loss_factory import LVAELossParameters
 from careamics.models.lvae.likelihoods import (
     GaussianLikelihood,
     NoiseModelLikelihood,
@@ -34,6 +34,7 @@ from careamics.utils.torch_utils import get_optimizer, get_scheduler
 NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
 
 
+# TODO rename to UNetModule
 class FCNModule(L.LightningModule):
     """
     CAREamics Lightning module.
@@ -60,7 +61,7 @@ class FCNModule(L.LightningModule):
         Learning rate scheduler name.
     """
 
-    def __init__(self, algorithm_config: Union[FCNAlgorithmConfig, dict]) -> None:
+    def __init__(self, algorithm_config: Union[UNetBasedAlgorithm, dict]) -> None:
         """Lightning module for CAREamics.
 
         This class encapsulates the a PyTorch model along with the training, validation,
@@ -74,7 +75,9 @@ class FCNModule(L.LightningModule):
         super().__init__()
         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
         if isinstance(algorithm_config, dict):
-            algorithm_config = FCNAlgorithmConfig(**algorithm_config)
+            algorithm_config = UNetBasedAlgorithm(
+                **algorithm_config
+            )  # TODO this needs to be updated using the algorithm-specific class
 
         # create model and loss function
         self.model: nn.Module = model_factory(algorithm_config.model)
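
Note: `FCNModule` now takes a `UNetBasedAlgorithm` (or a plain dict, e.g. when restoring from a checkpoint). A minimal sketch, assuming `UNetBasedAlgorithm` keeps the `algorithm`/`loss`/`model` fields of the former `FCNAlgorithmConfig`; the field values are illustrative:

    from careamics.config import UNetBasedAlgorithm
    from careamics.lightning import FCNModule

    # illustrative values; optimizer and scheduler fall back to defaults
    config = UNetBasedAlgorithm(
        algorithm="n2v",
        loss="n2v",
        model={"architecture": "UNet"},
    )
    module = FCNModule(config)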
@@ -164,7 +167,17 @@ class FCNModule(L.LightningModule):
         Any
             Model output.
         """
-        if self._trainer.datamodule.tiled:
+        # TODO refactor when redoing datasets
+        # hacky way to determine if it is PredictDataModule, otherwise there is a
+        # circular import to solve with isinstance
+        from_prediction = hasattr(self._trainer.datamodule, "tiled")
+        is_tiled = (
+            len(batch) > 1
+            and isinstance(batch[1], list)
+            and isinstance(batch[1][0], TileInformation)
+        )
+
+        if is_tiled:
             x, *aux = batch
         else:
             x = batch
@@ -172,7 +185,10 @@ class FCNModule(L.LightningModule):
 
         # apply test-time augmentation if available
         # TODO: probably wont work with batch size > 1
-        if self._trainer.datamodule.prediction_config.tta_transforms:
+        if (
+            from_prediction
+            and self._trainer.datamodule.prediction_config.tta_transforms
+        ):
             tta = ImageRestorationTTA()
             augmented_batch = tta.forward(x)  # list of augmented tensors
             augmented_output = []
@@ -184,9 +200,18 @@ class FCNModule(L.LightningModule):
             output = self.model(x)
 
         # Denormalize the output
+        # TODO incompatible API between predict and train datasets
         denorm = Denormalize(
-            image_means=self._trainer.datamodule.predict_dataset.image_means,
-            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+            image_means=(
+                self._trainer.datamodule.predict_dataset.image_means
+                if from_prediction
+                else self._trainer.datamodule.train_dataset.image_stats.means
+            ),
+            image_stds=(
+                self._trainer.datamodule.predict_dataset.image_stds
+                if from_prediction
+                else self._trainer.datamodule.train_dataset.image_stats.stds
+            ),
         )
         denormalized_output = denorm(patch=output.cpu().numpy())
 
@@ -244,7 +269,7 @@ class VAEModule(L.LightningModule):
         Learning rate scheduler name.
     """
 
-    def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None:
+    def __init__(self, algorithm_config: Union[VAEBasedAlgorithm, dict]) -> None:
         """Lightning module for CAREamics.
 
         This class encapsulates the a PyTorch model along with the training, validation,
@@ -258,7 +283,7 @@ class VAEModule(L.LightningModule):
         super().__init__()
         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
         self.algorithm_config = (
-            VAEAlgorithmConfig(**algorithm_config)
+            VAEBasedAlgorithm(**algorithm_config)
             if isinstance(algorithm_config, dict)
             else algorithm_config
         )
@@ -266,29 +291,27 @@ class VAEModule(L.LightningModule):
         # TODO: log algorithm config
         # self.save_hyperparameters(self.algorithm_config.model_dump())
 
-        # create model and loss function
+        # create model
         self.model: nn.Module = model_factory(self.algorithm_config.model)
-        self.noise_model: NoiseModel = noise_model_factory(
+
+        # create loss function
+        self.noise_model: Optional[NoiseModel] = noise_model_factory(
             self.algorithm_config.noise_model
         )
-        # TODO: here we can add some code to check whether the noise model is not None
-        # and `self.algorithm_config.noise_model_likelihood_model.noise_model` is,
-        # instead, None. In that case we could assign the noise model to the latter.
-        # This is particular useful when loading an algorithm config from file.
-        # Indeed, in that case the noise model in the nm likelihood is likely
-        # not available since excluded from serializaion.
-        self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
-            self.algorithm_config.noise_model_likelihood_model
+
+        self.noise_model_likelihood: Optional[NoiseModelLikelihood] = (
+            likelihood_factory(
+                config=self.algorithm_config.noise_model_likelihood,
+                noise_model=self.noise_model,
+            )
         )
-        self.gaussian_likelihood: GaussianLikelihood = likelihood_factory(
-            self.algorithm_config.gaussian_likelihood_model
+
+        self.gaussian_likelihood: Optional[GaussianLikelihood] = likelihood_factory(
+            self.algorithm_config.gaussian_likelihood
         )
-        self.loss_parameters = LVAELossParameters(
-            noise_model_likelihood=self.noise_model_likelihood,
-            gaussian_likelihood=self.gaussian_likelihood,
-            # TODO: musplit/denoisplit weights ?
-        )  # type: ignore
-        self.loss_func = loss_factory(self.algorithm_config.loss)
+
+        self.loss_parameters = self.algorithm_config.loss
+        self.loss_func = loss_factory(self.algorithm_config.loss.loss_type)
 
         # save optimizer and lr_scheduler names and parameters
         self.optimizer_name = self.algorithm_config.optimizer.name
@@ -344,11 +367,16 @@ class VAEModule(L.LightningModule):
         out = self.model(x)
 
         # Update loss parameters
-        # TODO rethink loss parameters
-        self.loss_parameters.current_epoch = self.current_epoch
+        self.loss_parameters.kl_params.current_epoch = self.current_epoch
 
         # Compute loss
-        loss = self.loss_func(out, target, self.loss_parameters)  # TODO ugly ?
+        loss = self.loss_func(
+            model_outputs=out,
+            targets=target,
+            config=self.loss_parameters,
+            gaussian_likelihood=self.gaussian_likelihood,
+            noise_model_likelihood=self.noise_model_likelihood,
+        )
 
         # Logging
         # TODO: implement a separate logging method?
@@ -376,7 +404,13 @@ class VAEModule(L.LightningModule):
         out = self.model(x)
 
         # Compute loss
-        loss = self.loss_func(out, target, self.loss_parameters)
+        loss = self.loss_func(
+            model_outputs=out,
+            targets=target,
+            config=self.loss_parameters,
+            gaussian_likelihood=self.gaussian_likelihood,
+            noise_model_likelihood=self.noise_model_likelihood,
+        )
 
         # Logging
         # Rename val_loss dict
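
Note: with `LVAELossParameters` removed (see `loss_factory.py` below), the LVAE losses now receive the loss config from `algorithm_config.loss` and the likelihood objects as explicit keyword arguments. A sketch of the factory side; the string value is assumed to match `SupportedLoss.DENOISPLIT_MUSPLIT`:

    from careamics.config.support import SupportedLoss
    from careamics.losses import loss_factory

    # loss_factory accepts a SupportedLoss member or its string value
    n2v_loss_fn = loss_factory(SupportedLoss.N2V)
    lvae_loss_fn = loss_factory("denoisplit_musplit")

    # LVAE losses are then called with explicit keywords, as in the diff above:
    # lvae_loss_fn(model_outputs=..., targets=..., config=...,
    #              gaussian_likelihood=..., noise_model_likelihood=...)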
@@ -625,9 +659,10 @@ def create_careamics_module(
     algorithm_configuration["model"] = model_configuration
 
     # call the parent init using an AlgorithmModel instance
+    # TODO broken by new configutations!
     algorithm_str = algorithm_configuration["algorithm"]
-    if algorithm_str in FCNAlgorithmConfig.get_compatible_algorithms():
-        return FCNModule(FCNAlgorithmConfig(**algorithm_configuration))
+    if algorithm_str in UNetBasedAlgorithm.get_compatible_algorithms():
+        return FCNModule(UNetBasedAlgorithm(**algorithm_configuration))
     else:
         raise NotImplementedError(
             f"Model {algorithm_str} is not implemented or unknown."

careamics/lightning/train_data_module.py
@@ -2,13 +2,14 @@
 
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
+from warnings import warn
 
 import numpy as np
 import pytorch_lightning as L
 from numpy.typing import NDArray
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterableDataset
 
-from careamics.config import DataConfig
+from careamics.config.data import DataConfig, GeneralDataConfig, N2VDataConfig
 from careamics.config.support import SupportedData
 from careamics.config.transformations import TransformModel
 from careamics.dataset.dataset_utils import (
@@ -118,7 +119,7 @@ class TrainDataModule(L.LightningDataModule):
 
     def __init__(
         self,
-        data_config: DataConfig,
+        data_config: GeneralDataConfig,
         train_data: Union[Path, str, NDArray],
         val_data: Optional[Union[Path, str, NDArray]] = None,
        train_data_target: Optional[Union[Path, str, NDArray]] = None,
@@ -218,7 +219,7 @@ class TrainDataModule(L.LightningDataModule):
         )
 
         # configuration
-        self.data_config: DataConfig = data_config
+        self.data_config: GeneralDataConfig = data_config
         self.data_type: str = data_config.data_type
         self.batch_size: int = data_config.batch_size
         self.use_in_memory: bool = use_in_memory
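
Note: `TrainDataModule` is now typed against `GeneralDataConfig` and therefore accepts any of its subclasses, such as the new `N2VDataConfig` or the supervised `DataConfig`. A minimal sketch, assuming the new config keeps the usual `data_type`/`axes`/`patch_size`/`batch_size` fields:

    import numpy as np
    from careamics.config.data import N2VDataConfig
    from careamics.lightning import TrainDataModule

    config = N2VDataConfig(
        data_type="array",
        axes="YX",
        patch_size=[64, 64],
        batch_size=8,
    )
    datamodule = TrainDataModule(
        data_config=config,  # any GeneralDataConfig subclass
        train_data=np.random.rand(256, 256),
    )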
@@ -446,6 +447,19 @@ class TrainDataModule(L.LightningDataModule):
         Any
             Training dataloader.
         """
+        # check because iterable dataset cannot be shuffled
+        if not isinstance(self.train_dataset, IterableDataset):
+            if ("shuffle" in self.dataloader_params) and (
+                not self.dataloader_params["shuffle"]
+            ):
+                warn(
+                    "Dataloader parameters include `shuffle=False`, this will be "
+                    "passed to the training dataloader and may result in bad results.",
+                    stacklevel=1,
+                )
+            else:
+                self.dataloader_params["shuffle"] = True
+
         return DataLoader(
             self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
         )
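
Note: the training dataloader now defaults to `shuffle=True` for map-style datasets, leaves iterable datasets untouched (they cannot be shuffled), and merely warns on an explicit `shuffle=False`. A sketch using `create_train_datamodule`; the parameter names are assumed from the docstring example further down:

    import numpy as np
    from careamics.lightning import create_train_datamodule

    data_module = create_train_datamodule(
        train_data=np.arange(256.0).reshape(16, 16),
        data_type="array",
        patch_size=(8, 8),
        axes="YX",
        batch_size=2,
        dataloader_params={"shuffle": False},  # accepted, but now emits a warning
    )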
@@ -488,12 +502,23 @@ def create_train_datamodule(
     """Create a TrainDataModule.
 
     This function is used to explicitly pass the parameters usually contained in a
-    `data_model` configuration to a TrainDataModule.
+    `GenericDataConfig` to a TrainDataModule.
 
     Since the lightning datamodule has no access to the model, make sure that the
     parameters passed to the datamodule are consistent with the model's requirements and
     are coherent.
 
+    By default, the train DataModule will be set for Noise2Void if no target data is
+    provided. That means that it will add a `N2VManipulateModel` transformation to the
+    list of augmentations. The default augmentations are XY flip, XY rotation, and N2V
+    pixel manipulation. If you pass a training target data, the default behaviour is to
+    train a supervised model. It will use the default XY flip and rotation
+    augmentations.
+
+    To use a different set of transformations, you can pass a list of transforms to
+    `transforms`. Note that if you intend to use Noise2Void, you should add
+    `N2VManipulateModel` as the last transform in the list of transformations.
+
     The data module can be used with Path, str or numpy arrays. In the case of
     numpy arrays, it loads and computes all the patches in memory. For Path and str
     inputs, it calculates the total file size and estimate whether it can fit in
@@ -504,11 +529,6 @@ def create_train_datamodule(
     To use array data, set `data_type` to `array` and pass a numpy array to
     `train_data`.
 
-    In particular, N2V requires a specific transformation (N2V manipulates), which is
-    not compatible with supervised training. The default transformations applied to the
-    training patches are defined in `careamics.config.data_model`. To use different
-    transformations, pass a list of transforms. See examples for more details.
-
     By default, CAREamics only supports types defined in
     `careamics.config.support.SupportedData`. To read custom data types, you can set
     `data_type` to `custom` and provide a function that returns a numpy array from a
@@ -613,12 +633,12 @@ def create_train_datamodule(
     transforms:
     >>> import numpy as np
     >>> from careamics.lightning import create_train_datamodule
+    >>> from careamics.config.transformations import XYFlipModel, N2VManipulateModel
     >>> from careamics.config.support import SupportedTransform
     >>> my_array = np.arange(256).reshape(16, 16)
     >>> my_transforms = [
-    ...     {
-    ...         "name": SupportedTransform.XY_FLIP.value,
-    ...     }
+    ...     XYFlipModel(flip_y=False),
+    ...     N2VManipulateModel()
     ... ]
     >>> data_module = create_train_datamodule(
     ...     train_data=my_array,
@@ -645,21 +665,15 @@ def create_train_datamodule(
     if transforms is not None:
         data_dict["transforms"] = transforms
 
-    # validate configuration
-    data_config = DataConfig(**data_dict)
+    # TODO not compatible with HDN, consider adding an argument for n2v/hdn
+    if train_target_data is None:
+        data_config: GeneralDataConfig = N2VDataConfig(**data_dict)
+        assert isinstance(data_config, N2VDataConfig)
 
-    # N2V specific checks, N2V, structN2V, and transforms
-    if data_config.has_n2v_manipulate():
-        # there is not target, n2v2 and structN2V can be changed
-        if train_target_data is None:
-            data_config.set_N2V2(use_n2v2)
-            data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
-        else:
-            raise ValueError(
-                "Cannot have both supervised training (target data) and "
-                "N2V manipulation in the transforms. Pass a list of transforms "
-                "that is compatible with your supervised training."
-            )
+        data_config.set_n2v2(use_n2v2)
+        data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span)
+    else:
+        data_config = DataConfig(**data_dict)
 
     # sanity check on the dataloader parameters
     if "batch_size" in dataloader_params:

careamics/losses/__init__.py
@@ -1,13 +1,13 @@
 """Losses module."""
 
 __all__ = [
+    "denoisplit_loss",
+    "denoisplit_musplit_loss",
     "loss_factory",
     "mae_loss",
     "mse_loss",
-    "n2v_loss",
-    "denoisplit_loss",
     "musplit_loss",
-    "denoisplit_musplit_loss",
+    "n2v_loss",
 ]
 
 from .fcn.losses import mae_loss, mse_loss, n2v_loss

careamics/losses/loss_factory.py
@@ -7,7 +7,7 @@ This module contains a factory function for creating loss functions.
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
+from typing import Callable, Union
 
 from torch import Tensor as tensor
 
@@ -15,18 +15,6 @@ from ..config.support import SupportedLoss
 from .fcn.losses import mae_loss, mse_loss, n2v_loss
 from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
 
-if TYPE_CHECKING:
-    from careamics.models.lvae.likelihoods import (
-        GaussianLikelihood,
-        NoiseModelLikelihood,
-    )
-    from careamics.models.lvae.noise_models import (
-        GaussianMixtureNoiseModel,
-        MultiChannelNoiseModel,
-    )
-
-    NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
-
 
 @dataclass
 class FCNLossParameters:
@@ -40,78 +28,6 @@ class FCNLossParameters:
     loss_weight: float
 
 
-@dataclass  # TODO why not pydantic?
-class LVAELossParameters:
-    """Dataclass for LVAE loss."""
-
-    # TODO: refactor in more modular blocks (otherwise it gets messy very easily)
-    # e.g., - weights, - kl_params, ...
-
-    noise_model_likelihood: Optional[NoiseModelLikelihood] = None
-    """Noise model likelihood instance."""
-    gaussian_likelihood: Optional[GaussianLikelihood] = None
-    """Gaussian likelihood instance."""
-    current_epoch: int = 0
-    """Current epoch in the training loop."""
-    reconstruction_weight: float = 1.0
-    """Weight for the reconstruction loss in the total net loss
-    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
-    musplit_weight: float = 0.1
-    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
-    denoisplit_weight: float = 0.9
-    """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
-    kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
-    """Type of KL divergence used as KL loss."""
-    kl_weight: float = 1.0
-    """Weight for the KL loss in the total net loss.
-    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
-    kl_annealing: bool = False
-    """Whether to apply KL loss annealing."""
-    kl_start: int = -1
-    """Epoch at which KL loss annealing starts."""
-    kl_annealtime: int = 10
-    """Number of epochs for which KL loss annealing is applied."""
-    non_stochastic: bool = False
-    """Whether to sample latents and compute KL."""
-
-
-# TODO: really needed?
-# like it is now, it is difficult to use, we need a way to specify the
-# loss parameters in a more user-friendly way.
-def loss_parameters_factory(
-    type: SupportedLoss,
-) -> Union[FCNLossParameters, LVAELossParameters]:
-    """Return loss parameters.
-
-    Parameters
-    ----------
-    type : SupportedLoss
-        Requested loss.
-
-    Returns
-    -------
-    Union[FCNLossParameters, LVAELossParameters]
-        Loss parameters.
-
-    Raises
-    ------
-    NotImplementedError
-        If the loss is unknown.
-    """
-    if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]:
-        return FCNLossParameters
-
-    elif type in [
-        SupportedLoss.MUSPLIT,
-        SupportedLoss.DENOISPLIT,
-        SupportedLoss.DENOISPLIT_MUSPLIT,
-    ]:
-        return LVAELossParameters  # it returns the class, not an instance
-
-    else:
-        raise NotImplementedError(f"Loss {type} is not yet supported.")
-
-
 def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
     """Return loss function.