careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (87)
  1. careamics/careamist.py +39 -28
  2. careamics/cli/__init__.py +5 -0
  3. careamics/cli/conf.py +391 -0
  4. careamics/cli/main.py +134 -0
  5. careamics/config/__init__.py +7 -3
  6. careamics/config/architectures/__init__.py +2 -2
  7. careamics/config/architectures/architecture_model.py +1 -1
  8. careamics/config/architectures/custom_model.py +11 -8
  9. careamics/config/architectures/lvae_model.py +170 -0
  10. careamics/config/configuration_factory.py +481 -170
  11. careamics/config/configuration_model.py +6 -3
  12. careamics/config/data_model.py +31 -20
  13. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
  14. careamics/config/likelihood_model.py +60 -0
  15. careamics/config/nm_model.py +127 -0
  16. careamics/config/optimizer_models.py +3 -1
  17. careamics/config/support/supported_activations.py +1 -0
  18. careamics/config/support/supported_algorithms.py +17 -4
  19. careamics/config/support/supported_architectures.py +8 -11
  20. careamics/config/support/supported_losses.py +3 -1
  21. careamics/config/support/supported_optimizers.py +1 -1
  22. careamics/config/support/supported_transforms.py +1 -0
  23. careamics/config/training_model.py +35 -6
  24. careamics/config/transformations/__init__.py +4 -1
  25. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  26. careamics/config/transformations/transform_union.py +20 -0
  27. careamics/config/vae_algorithm_model.py +137 -0
  28. careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
  29. careamics/file_io/read/tiff.py +1 -1
  30. careamics/lightning/__init__.py +3 -2
  31. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  32. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  33. careamics/lightning/lightning_module.py +367 -9
  34. careamics/lightning/predict_data_module.py +2 -2
  35. careamics/lightning/train_data_module.py +4 -4
  36. careamics/losses/__init__.py +11 -1
  37. careamics/losses/fcn/__init__.py +1 -0
  38. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  39. careamics/losses/loss_factory.py +112 -6
  40. careamics/losses/lvae/__init__.py +1 -0
  41. careamics/losses/lvae/loss_utils.py +83 -0
  42. careamics/losses/lvae/losses.py +445 -0
  43. careamics/lvae_training/dataset/__init__.py +15 -0
  44. careamics/lvae_training/dataset/config.py +123 -0
  45. careamics/lvae_training/dataset/lc_dataset.py +267 -0
  46. careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
  47. careamics/lvae_training/dataset/multifile_dataset.py +334 -0
  48. careamics/lvae_training/dataset/types.py +43 -0
  49. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  50. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  51. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  52. careamics/lvae_training/dataset/utils/index_manager.py +232 -0
  53. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  54. careamics/lvae_training/eval_utils.py +109 -64
  55. careamics/lvae_training/get_config.py +1 -1
  56. careamics/lvae_training/train_lvae.py +6 -3
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +2 -2
  59. careamics/model_io/bmz_io.py +20 -7
  60. careamics/model_io/model_io_utils.py +16 -4
  61. careamics/models/__init__.py +1 -3
  62. careamics/models/activation.py +2 -0
  63. careamics/models/lvae/__init__.py +3 -0
  64. careamics/models/lvae/layers.py +21 -21
  65. careamics/models/lvae/likelihoods.py +190 -129
  66. careamics/models/lvae/lvae.py +60 -148
  67. careamics/models/lvae/noise_models.py +318 -186
  68. careamics/models/lvae/utils.py +2 -2
  69. careamics/models/model_factory.py +22 -7
  70. careamics/prediction_utils/lvae_prediction.py +158 -0
  71. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  72. careamics/prediction_utils/stitch_prediction.py +16 -2
  73. careamics/transforms/compose.py +90 -15
  74. careamics/transforms/n2v_manipulate.py +6 -2
  75. careamics/transforms/normalize.py +14 -3
  76. careamics/transforms/pixel_manipulation.py +1 -1
  77. careamics/transforms/xy_flip.py +16 -6
  78. careamics/transforms/xy_random_rotate90.py +16 -7
  79. careamics/utils/metrics.py +277 -24
  80. careamics/utils/serializers.py +60 -0
  81. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
  82. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
  83. careamics-0.0.4.dist-info/entry_points.txt +2 -0
  84. careamics/config/architectures/vae_model.py +0 -42
  85. careamics/lvae_training/data_utils.py +0 -618
  86. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
  87. {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/lightning_module.py
@@ -1,11 +1,12 @@
  """CAREamics Lightning module."""

- from typing import Any, Optional, Union
+ from typing import Any, Callable, Optional, Union

+ import numpy as np
  import pytorch_lightning as L
  from torch import Tensor, nn

- from careamics.config import AlgorithmConfig
+ from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig
  from careamics.config.support import (
      SupportedAlgorithm,
      SupportedArchitecture,
@@ -14,12 +15,26 @@ from careamics.config.support import (
      SupportedScheduler,
  )
  from careamics.losses import loss_factory
+ from careamics.losses.loss_factory import LVAELossParameters
+ from careamics.models.lvae.likelihoods import (
+     GaussianLikelihood,
+     NoiseModelLikelihood,
+     likelihood_factory,
+ )
+ from careamics.models.lvae.noise_models import (
+     GaussianMixtureNoiseModel,
+     MultiChannelNoiseModel,
+     noise_model_factory,
+ )
  from careamics.models.model_factory import model_factory
  from careamics.transforms import Denormalize, ImageRestorationTTA
+ from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr
  from careamics.utils.torch_utils import get_optimizer, get_scheduler

+ NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+

- class CAREamicsModule(L.LightningModule):
+ class FCNModule(L.LightningModule):
      """
      CAREamics Lightning module.

@@ -45,7 +60,7 @@ class CAREamicsModule(L.LightningModule):
          Learning rate scheduler name.
      """

-     def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
+     def __init__(self, algorithm_config: Union[FCNAlgorithmConfig, dict]) -> None:
          """Lightning module for CAREamics.

          This class encapsulates the a PyTorch model along with the training, validation,
@@ -59,7 +74,7 @@ class CAREamicsModule(L.LightningModule):
          super().__init__()
          # if loading from a checkpoint, AlgorithmModel needs to be instantiated
          if isinstance(algorithm_config, dict):
-             algorithm_config = AlgorithmConfig(**algorithm_config)
+             algorithm_config = FCNAlgorithmConfig(**algorithm_config)

          # create model and loss function
          self.model: nn.Module = model_factory(algorithm_config.model)
@@ -203,6 +218,343 @@ class CAREamicsModule(L.LightningModule):
          }


+ class VAEModule(L.LightningModule):
+     """
+     CAREamics Lightning module.
+
+     This class encapsulates the a PyTorch model along with the training, validation,
+     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+     Parameters
+     ----------
+     algorithm_config : Union[VAEAlgorithmConfig, dict]
+         Algorithm configuration.
+
+     Attributes
+     ----------
+     model : nn.Module
+         PyTorch model.
+     loss_func : nn.Module
+         Loss function.
+     optimizer_name : str
+         Optimizer name.
+     optimizer_params : dict
+         Optimizer parameters.
+     lr_scheduler_name : str
+         Learning rate scheduler name.
+     """
+
+     def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None:
+         """Lightning module for CAREamics.
+
+         This class encapsulates the a PyTorch model along with the training, validation,
+         and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+         Parameters
+         ----------
+         algorithm_config : Union[AlgorithmModel, dict]
+             Algorithm configuration.
+         """
+         super().__init__()
+         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
+         self.algorithm_config = (
+             VAEAlgorithmConfig(**algorithm_config)
+             if isinstance(algorithm_config, dict)
+             else algorithm_config
+         )
+
+         # TODO: log algorithm config
+         # self.save_hyperparameters(self.algorithm_config.model_dump())
+
+         # create model and loss function
+         self.model: nn.Module = model_factory(self.algorithm_config.model)
+         self.noise_model: NoiseModel = noise_model_factory(
+             self.algorithm_config.noise_model
+         )
+         # TODO: here we can add some code to check whether the noise model is not None
+         # and `self.algorithm_config.noise_model_likelihood_model.noise_model` is,
+         # instead, None. In that case we could assign the noise model to the latter.
+         # This is particular useful when loading an algorithm config from file.
+         # Indeed, in that case the noise model in the nm likelihood is likely
+         # not available since excluded from serializaion.
+         self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
+             self.algorithm_config.noise_model_likelihood_model
+         )
+         self.gaussian_likelihood: GaussianLikelihood = likelihood_factory(
+             self.algorithm_config.gaussian_likelihood_model
+         )
+         self.loss_parameters = LVAELossParameters(
+             noise_model_likelihood=self.noise_model_likelihood,
+             gaussian_likelihood=self.gaussian_likelihood,
+             # TODO: musplit/denoisplit weights ?
+         ) # type: ignore
+         self.loss_func = loss_factory(self.algorithm_config.loss)
+
+         # save optimizer and lr_scheduler names and parameters
+         self.optimizer_name = self.algorithm_config.optimizer.name
+         self.optimizer_params = self.algorithm_config.optimizer.parameters
+         self.lr_scheduler_name = self.algorithm_config.lr_scheduler.name
+         self.lr_scheduler_params = self.algorithm_config.lr_scheduler.parameters
+
+         # initialize running PSNR
+         self.running_psnr = [
+             RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
+         ]
+
+     def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+             number of lateral inputs.
+
+         Returns
+         -------
+         tuple[Tensor, dict[str, Any]]
+             A tuple with the output tensor and additional data from the top-down pass.
+         """
+         return self.model(x) # TODO Different model can have more than one output
+
+     def training_step(
+         self, batch: tuple[Tensor, Tensor], batch_idx: Any
+     ) -> Optional[dict[str, Tensor]]:
+         """Training step.
+
+         Parameters
+         ----------
+         batch : tuple[Tensor, Tensor]
+             Input batch. It is a tuple with the input tensor and the target tensor.
+             The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+             number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
+             where C is the number of target channels (e.g., 1 in HDN, >1 in
+             muSplit/denoiSplit).
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Loss value.
+         """
+         x, target = batch
+
+         # Forward pass
+         out = self.model(x)
+
+         # Update loss parameters
+         # TODO rethink loss parameters
+         self.loss_parameters.current_epoch = self.current_epoch
+
+         # Compute loss
+         loss = self.loss_func(out, target, self.loss_parameters) # TODO ugly ?
+
+         # Logging
+         # TODO: implement a separate logging method?
+         self.log_dict(loss, on_step=True, on_epoch=True)
+         # self.log("lr", self, on_epoch=True)
+         return loss
+
+     def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: Any) -> None:
+         """Validation step.
+
+         Parameters
+         ----------
+         batch : tuple[Tensor, Tensor]
+             Input batch. It is a tuple with the input tensor and the target tensor.
+             The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+             number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
+             where C is the number of target channels (e.g., 1 in HDN, >1 in
+             muSplit/denoiSplit).
+         batch_idx : Any
+             Batch index.
+         """
+         x, target = batch
+
+         # Forward pass
+         out = self.model(x)
+
+         # Compute loss
+         loss = self.loss_func(out, target, self.loss_parameters)
+
+         # Logging
+         # Rename val_loss dict
+         loss = {"_".join(["val", k]): v for k, v in loss.items()}
+         self.log_dict(loss, on_epoch=True, prog_bar=True)
+         curr_psnr = self.compute_val_psnr(out, target)
+         for i, psnr in enumerate(curr_psnr):
+             self.log(f"val_psnr_ch{i+1}_batch", psnr, on_epoch=True)
+
+     def on_validation_epoch_end(self) -> None:
+         """Validation epoch end."""
+         psnr_ = self.reduce_running_psnr()
+         if psnr_ is not None:
+             self.log("val_psnr", psnr_, on_epoch=True, prog_bar=True)
+         else:
+             self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
+
+     def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+         """Prediction step.
+
+         Parameters
+         ----------
+         batch : Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Model output.
+         """
+         if self._trainer.datamodule.tiled:
+             x, *aux = batch
+         else:
+             x = batch
+             aux = []
+
+         # apply test-time augmentation if available
+         # TODO: probably wont work with batch size > 1
+         if self._trainer.datamodule.prediction_config.tta_transforms:
+             tta = ImageRestorationTTA()
+             augmented_batch = tta.forward(x) # list of augmented tensors
+             augmented_output = []
+             for augmented in augmented_batch:
+                 augmented_pred = self.model(augmented)
+                 augmented_output.append(augmented_pred)
+             output = tta.backward(augmented_output)
+         else:
+             output = self.model(x)
+
+         # Denormalize the output
+         denorm = Denormalize(
+             image_means=self._trainer.datamodule.predict_dataset.image_means,
+             image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+         )
+         denormalized_output = denorm(patch=output.cpu().numpy())
+
+         if len(aux) > 0: # aux can be tiling information
+             return denormalized_output, *aux
+         else:
+             return denormalized_output
+
+     def configure_optimizers(self) -> Any:
+         """Configure optimizers and learning rate schedulers.
+
+         Returns
+         -------
+         Any
+             Optimizer and learning rate scheduler.
+         """
+         # instantiate optimizer
+         optimizer_func = get_optimizer(self.optimizer_name)
+         optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
+
+         # and scheduler
+         scheduler_func = get_scheduler(self.lr_scheduler_name)
+         scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": scheduler,
+             "monitor": "val_loss", # otherwise triggers MisconfigurationException
+         }
+
+     # TODO: find a way to move the following methods to a separate module
+     # TODO: this same operation is done in many other places, like in loss_func
+     # should we refactor LadderVAE so that it already outputs
+     # tuple(`mean`, `logvar`, `td_data`)?
+     def get_reconstructed_tensor(
+         self, model_outputs: tuple[Tensor, dict[str, Any]]
+     ) -> Tensor:
+         """Get the reconstructed tensor from the LVAE model outputs.
+
+         Parameters
+         ----------
+         model_outputs : tuple[Tensor, dict[str, Any]]
+             Model outputs. It is a tuple with a tensor representing the predicted mean
+             and (optionally) logvar, and the top-down data dictionary.
+
+         Returns
+         -------
+         Tensor
+             Reconstructed tensor, i.e., the predicted mean.
+         """
+         predictions, _ = model_outputs
+         if self.model.predict_logvar is None:
+             return predictions
+         elif self.model.predict_logvar == "pixelwise":
+             return predictions.chunk(2, dim=1)[0]
+
+     def compute_val_psnr(
+         self,
+         model_output: tuple[Tensor, dict[str, Any]],
+         target: Tensor,
+         psnr_func: Callable = scale_invariant_psnr,
+     ) -> list[float]:
+         """Compute the PSNR for the current validation batch.
+
+         Parameters
+         ----------
+         model_output : tuple[Tensor, dict[str, Any]]
+             Model output, a tuple with the predicted mean and (optionally) logvar,
+             and the top-down data dictionary.
+         target : Tensor
+             Target tensor.
+         psnr_func : Callable, optional
+             PSNR function to use, by default `scale_invariant_psnr`.
+
+         Returns
+         -------
+         list[float]
+             PSNR for each channel in the current batch.
+         """
+         out_channels = target.shape[1]
+
+         # get the reconstructed image
+         recons_img = self.get_reconstructed_tensor(model_output)
+
+         # update running psnr
+         for i in range(out_channels):
+             self.running_psnr[i].update(rec=recons_img[:, i], tar=target[:, i])
+
+         # compute psnr for each channel in the current batch
+         # TODO: this doesn't need do be a method of this class
+         # and hence can be moved to a separate module
+         return [
+             psnr_func(
+                 gt=target[:, i].clone().detach().cpu().numpy(),
+                 pred=recons_img[:, i].clone().detach().cpu().numpy(),
+             )
+             for i in range(out_channels)
+         ]
+
+     def reduce_running_psnr(self) -> Optional[float]:
+         """Reduce the running PSNR statistics and reset the running PSNR.
+
+         Returns
+         -------
+         Optional[float]
+             Running PSNR averaged over the different output channels.
+         """
+         psnr_arr = [] # type: ignore
+         for i in range(len(self.running_psnr)):
+             psnr = self.running_psnr[i].get()
+             if psnr is None:
+                 psnr_arr = None # type: ignore
+                 break
+             psnr_arr.append(psnr.cpu().numpy())
+             self.running_psnr[i].reset()
+             # TODO: this line forces it to be a method of this class
+             # alternative is returning also the reset `running_psnr`
+         if psnr_arr is not None:
+             psnr = np.mean(psnr_arr)
+         return psnr
+
+
+ # TODO: make this LVAE compatible (?)
  def create_careamics_module(
      algorithm: Union[SupportedAlgorithm, str],
      loss: Union[SupportedLoss, str],
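For context, the per-channel validation PSNR bookkeeping in the hunk above relies on `RunningPSNR` and `scale_invariant_psnr` from `careamics.utils.metrics` (a module also expanded in this release). A minimal standalone sketch of the same pattern, assuming only the call signatures visible in the diff (`update(rec=..., tar=...)`, `get()`, `reset()`, and `scale_invariant_psnr(gt=..., pred=...)`):

import torch

from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr

# Fake batch with two output channels, as in a muSplit/denoiSplit setting.
pred = torch.rand(4, 2, 64, 64)    # (B, C, Y, X) reconstructed means
target = torch.rand(4, 2, 64, 64)  # (B, C, Y, X) targets

running = [RunningPSNR() for _ in range(2)]
for c in range(2):
    # accumulate running statistics, as in `compute_val_psnr`
    running[c].update(rec=pred[:, c], tar=target[:, c])
    # per-batch scale-invariant PSNR for this channel
    print(scale_invariant_psnr(gt=target[:, c].numpy(), pred=pred[:, c].numpy()))

# reduce and reset at the end of the epoch, as in `reduce_running_psnr`
channel_psnr = [r.get() for r in running]
for r in running:
    r.reset()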
@@ -212,8 +564,8 @@ def create_careamics_module(
      optimizer_parameters: Optional[dict] = None,
      lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
      lr_scheduler_parameters: Optional[dict] = None,
- ) -> CAREamicsModule:
-     """Create a CAREamics Lithgning module.
+ ) -> Union[FCNModule, VAEModule]:
+     """Create a CAREamics Lightning module.

      This function exposes parameters used to create an AlgorithmModel instance,
      triggering parameters validation.
@@ -254,7 +606,7 @@ def create_careamics_module(
          optimizer_parameters = {}
      if model_parameters is None:
          model_parameters = {}
-     algorithm_configuration = {
+     algorithm_configuration: dict[str, Any] = {
          "algorithm": algorithm,
          "loss": loss,
          "optimizer": {
@@ -273,4 +625,10 @@ def create_careamics_module(
      algorithm_configuration["model"] = model_configuration

      # call the parent init using an AlgorithmModel instance
-     return CAREamicsModule(AlgorithmConfig(**algorithm_configuration))
+     algorithm_str = algorithm_configuration["algorithm"]
+     if algorithm_str in FCNAlgorithmConfig.get_compatible_algorithms():
+         return FCNModule(FCNAlgorithmConfig(**algorithm_configuration))
+     else:
+         raise NotImplementedError(
+             f"Model {algorithm_str} is not implemented or unknown."
+         )
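The factory now dispatches on the algorithm name instead of always returning a single module class: algorithms accepted by `FCNAlgorithmConfig` are routed to `FCNModule`, anything else raises `NotImplementedError` (VAE-based models presumably go through `VAEModule` directly for now). A small sketch to inspect the dispatch set, using only names shown in this diff:

from careamics.config import FCNAlgorithmConfig

# Algorithms in this set are routed to FCNModule by create_careamics_module;
# others currently raise NotImplementedError.
print(FCNAlgorithmConfig.get_compatible_algorithms())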
careamics/lightning/predict_data_module.py
@@ -240,7 +240,7 @@ def create_predict_datamodule(
  ) -> PredictDataModule:
      """Create a CAREamics prediction Lightning datamodule.

-     This function is used to explicitely pass the parameters usually contained in an
+     This function is used to explicitly pass the parameters usually contained in an
      `inference_model` configuration.

      Since the lightning datamodule has no access to the model, make sure that the
@@ -268,7 +268,7 @@ def create_predict_datamodule(
      data_type : {"array", "tiff", "custom"}
          Data type, see `SupportedData` for available options.
      axes : str
-         Axes of the data, choosen among SCZYX.
+         Axes of the data, chosen among SCZYX.
      image_means : list of float
          Mean values for normalization, only used if Normalization is defined.
      image_stds : list of float
careamics/lightning/train_data_module.py
@@ -9,8 +9,8 @@ from numpy.typing import NDArray
  from torch.utils.data import DataLoader

  from careamics.config import DataConfig
- from careamics.config.data_model import TRANSFORMS_UNION
  from careamics.config.support import SupportedData
+ from careamics.config.transformations import TransformModel
  from careamics.dataset.dataset_utils import (
      get_files_size,
      list_files,
@@ -472,7 +472,7 @@ def create_train_datamodule(
      axes: str,
      batch_size: int,
      val_data: Optional[Union[str, Path, NDArray]] = None,
-     transforms: Optional[list[TRANSFORMS_UNION]] = None,
+     transforms: Optional[list[TransformModel]] = None,
      train_target_data: Optional[Union[str, Path, NDArray]] = None,
      val_target_data: Optional[Union[str, Path, NDArray]] = None,
      read_source_func: Optional[Callable] = None,
@@ -487,7 +487,7 @@ def create_train_datamodule(
  ) -> TrainDataModule:
      """Create a TrainDataModule.

-     This function is used to explicitely pass the parameters usually contained in a
+     This function is used to explicitly pass the parameters usually contained in a
      `data_model` configuration to a TrainDataModule.

      Since the lightning datamodule has no access to the model, make sure that the
@@ -537,7 +537,7 @@ def create_train_datamodule(
      patch_size : list of int
          Patch size, 2D or 3D patch size.
      axes : str
-         Axes of the data, choosen amongst SCZYX.
+         Axes of the data, chosen amongst SCZYX.
      batch_size : int
          Batch size.
      val_data : pathlib.Path or str or numpy.ndarray, optional
careamics/losses/__init__.py
@@ -1,5 +1,15 @@
  """Losses module."""

- __all__ = ["loss_factory"]
+ __all__ = [
+     "loss_factory",
+     "mae_loss",
+     "mse_loss",
+     "n2v_loss",
+     "denoisplit_loss",
+     "musplit_loss",
+     "denoisplit_musplit_loss",
+ ]

+ from .fcn.losses import mae_loss, mse_loss, n2v_loss
  from .loss_factory import loss_factory
+ from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
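With the expanded `__all__` above, both the FCN and the new LVAE losses are re-exported from the package-level `careamics.losses` namespace. A minimal sketch using only names listed in `__all__` (the `mae_loss` signature is taken from the `fcn/losses.py` hunk further below):

import torch

from careamics.losses import mae_loss  # FCN and LVAE losses are re-exported here

# mae_loss keeps a plain tensor signature: mae_loss(samples, labels)
samples = torch.rand(2, 1, 8, 8)
labels = torch.rand(2, 1, 8, 8)
print(mae_loss(samples, labels))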
careamics/losses/fcn/__init__.py
@@ -0,0 +1 @@
+ """FCN losses."""
careamics/losses/{losses.py → fcn/losses.py}
@@ -53,7 +53,7 @@ def n2v_loss(
      errors = (original_patches - manipulated_patches) ** 2
      # Average over pixels and batch
      loss = torch.sum(errors * masks) / torch.sum(masks)
-     return loss
+     return loss # TODO change output to dict ?


  def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
careamics/losses/loss_factory.py
@@ -4,14 +4,114 @@ Loss factory module.
  This module contains a factory function for creating loss functions.
  """

- from typing import Callable, Union
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
+
+ from torch import Tensor as tensor

  from ..config.support import SupportedLoss
- from .losses import mae_loss, mse_loss, n2v_loss
+ from .fcn.losses import mae_loss, mse_loss, n2v_loss
+ from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
+
+ if TYPE_CHECKING:
+     from careamics.models.lvae.likelihoods import (
+         GaussianLikelihood,
+         NoiseModelLikelihood,
+     )
+     from careamics.models.lvae.noise_models import (
+         GaussianMixtureNoiseModel,
+         MultiChannelNoiseModel,
+     )
+
+     NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+ @dataclass
+ class FCNLossParameters:
+     """Dataclass for FCN loss."""
+
+     # TODO check
+     prediction: tensor
+     targets: tensor
+     mask: tensor
+     current_epoch: int
+     loss_weight: float
+
+
+ @dataclass # TODO why not pydantic?
+ class LVAELossParameters:
+     """Dataclass for LVAE loss."""
+
+     # TODO: refactor in more modular blocks (otherwise it gets messy very easily)
+     # e.g., - weights, - kl_params, ...
+
+     noise_model_likelihood: Optional[NoiseModelLikelihood] = None
+     """Noise model likelihood instance."""
+     gaussian_likelihood: Optional[GaussianLikelihood] = None
+     """Gaussian likelihood instance."""
+     current_epoch: int = 0
+     """Current epoch in the training loop."""
+     reconstruction_weight: float = 1.0
+     """Weight for the reconstruction loss in the total net loss
+     (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+     musplit_weight: float = 0.1
+     """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
+     denoisplit_weight: float = 0.9
+     """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
+     kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
+     """Type of KL divergence used as KL loss."""
+     kl_weight: float = 1.0
+     """Weight for the KL loss in the total net loss.
+     (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+     kl_annealing: bool = False
+     """Whether to apply KL loss annealing."""
+     kl_start: int = -1
+     """Epoch at which KL loss annealing starts."""
+     kl_annealtime: int = 10
+     """Number of epochs for which KL loss annealing is applied."""
+     non_stochastic: bool = False
+     """Whether to sample latents and compute KL."""
+
+
+ # TODO: really needed?
+ # like it is now, it is difficult to use, we need a way to specify the
+ # loss parameters in a more user-friendly way.
+ def loss_parameters_factory(
+     type: SupportedLoss,
+ ) -> Union[FCNLossParameters, LVAELossParameters]:
+     """Return loss parameters.
+
+     Parameters
+     ----------
+     type : SupportedLoss
+         Requested loss.
+
+     Returns
+     -------
+     Union[FCNLossParameters, LVAELossParameters]
+         Loss parameters.
+
+     Raises
+     ------
+     NotImplementedError
+         If the loss is unknown.
+     """
+     if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]:
+         return FCNLossParameters
+
+     elif type in [
+         SupportedLoss.MUSPLIT,
+         SupportedLoss.DENOISPLIT,
+         SupportedLoss.DENOISPLIT_MUSPLIT,
+     ]:
+         return LVAELossParameters # it returns the class, not an instance
+
+     else:
+         raise NotImplementedError(f"Loss {type} is not yet supported.")


- # TODO add tests
- # TODO add custom?
  def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
      """Return loss function.

@@ -42,8 +142,14 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
      elif loss == SupportedLoss.MSE:
          return mse_loss

-     # elif loss_type == SupportedLoss.DICE:
-     # return dice_loss
+     elif loss == SupportedLoss.MUSPLIT:
+         return musplit_loss
+
+     elif loss == SupportedLoss.DENOISPLIT:
+         return denoisplit_loss
+
+     elif loss == SupportedLoss.DENOISPLIT_MUSPLIT:
+         return denoisplit_musplit_loss

      else:
          raise NotImplementedError(f"Loss {loss} is not yet supported.")
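A short sketch of how the two factories above fit together; the enum members, dataclass fields, and import paths are taken from this diff. Note that `loss_parameters_factory` returns the parameter class, not an instance, which the caller then instantiates (as `VAEModule` does in `lightning_module.py`):

from careamics.config.support import SupportedLoss
from careamics.losses.loss_factory import (
    LVAELossParameters,
    loss_factory,
    loss_parameters_factory,
)

# map a supported loss to its callable
musplit = loss_factory(SupportedLoss.MUSPLIT)  # -> musplit_loss
n2v = loss_factory(SupportedLoss.N2V)          # -> n2v_loss

# the parameters factory returns the dataclass itself, not an instance
params_cls = loss_parameters_factory(SupportedLoss.MUSPLIT)  # -> LVAELossParameters
params = params_cls(musplit_weight=0.1, denoisplit_weight=0.9, kl_weight=1.0)
assert isinstance(params, LVAELossParameters)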
careamics/losses/lvae/__init__.py
@@ -0,0 +1 @@
+ """LVAE losses."""