careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: the registry has flagged this version of careamics as potentially problematic.

Files changed (64)
  1. careamics/careamist.py +14 -11
  2. careamics/config/__init__.py +7 -3
  3. careamics/config/architectures/__init__.py +2 -2
  4. careamics/config/architectures/architecture_model.py +1 -1
  5. careamics/config/architectures/custom_model.py +11 -8
  6. careamics/config/architectures/lvae_model.py +174 -0
  7. careamics/config/configuration_factory.py +11 -3
  8. careamics/config/configuration_model.py +7 -3
  9. careamics/config/data_model.py +33 -8
  10. careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
  11. careamics/config/likelihood_model.py +43 -0
  12. careamics/config/nm_model.py +101 -0
  13. careamics/config/support/supported_activations.py +1 -0
  14. careamics/config/support/supported_algorithms.py +17 -4
  15. careamics/config/support/supported_architectures.py +8 -11
  16. careamics/config/support/supported_losses.py +3 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +1 -1
  18. careamics/config/vae_algorithm_model.py +171 -0
  19. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  20. careamics/file_io/read/tiff.py +1 -1
  21. careamics/lightning/__init__.py +3 -2
  22. careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  23. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  24. careamics/lightning/lightning_module.py +365 -9
  25. careamics/lightning/predict_data_module.py +2 -2
  26. careamics/lightning/train_data_module.py +2 -2
  27. careamics/losses/__init__.py +11 -1
  28. careamics/losses/fcn/__init__.py +1 -0
  29. careamics/losses/{losses.py → fcn/losses.py} +1 -1
  30. careamics/losses/loss_factory.py +112 -6
  31. careamics/losses/lvae/__init__.py +1 -0
  32. careamics/losses/lvae/loss_utils.py +83 -0
  33. careamics/losses/lvae/losses.py +445 -0
  34. careamics/lvae_training/dataset/__init__.py +0 -0
  35. careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
  36. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  37. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  38. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  39. careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
  40. careamics/lvae_training/get_config.py +1 -1
  41. careamics/lvae_training/train_lvae.py +6 -3
  42. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  43. careamics/model_io/bioimage/model_description.py +2 -2
  44. careamics/model_io/bmz_io.py +19 -6
  45. careamics/model_io/model_io_utils.py +16 -4
  46. careamics/models/__init__.py +1 -3
  47. careamics/models/activation.py +2 -0
  48. careamics/models/lvae/__init__.py +3 -0
  49. careamics/models/lvae/layers.py +21 -21
  50. careamics/models/lvae/likelihoods.py +180 -128
  51. careamics/models/lvae/lvae.py +52 -136
  52. careamics/models/lvae/noise_models.py +318 -186
  53. careamics/models/lvae/utils.py +2 -2
  54. careamics/models/model_factory.py +22 -7
  55. careamics/prediction_utils/lvae_prediction.py +158 -0
  56. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  57. careamics/prediction_utils/stitch_prediction.py +16 -2
  58. careamics/transforms/pixel_manipulation.py +1 -1
  59. careamics/utils/metrics.py +74 -1
  60. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
  61. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
  62. careamics/config/architectures/vae_model.py +0 -42
  63. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
  64. {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
careamics/file_io/read/tiff.py
@@ -1,4 +1,4 @@
-"""Funtions to read tiff images."""
+"""Functions to read tiff images."""
 
 import logging
 from fnmatch import fnmatch
careamics/lightning/__init__.py
@@ -1,7 +1,8 @@
 """CAREamics PyTorch Lightning modules."""
 
 __all__ = [
-    "CAREamicsModule",
+    "FCNModule",
+    "VAEModule",
     "create_careamics_module",
     "TrainDataModule",
     "create_train_datamodule",
@@ -12,6 +13,6 @@ __all__ = [
 ]
 
 from .callbacks import HyperParametersCallback, ProgressBarCallback
-from .lightning_module import CAREamicsModule, create_careamics_module
+from .lightning_module import FCNModule, VAEModule, create_careamics_module
 from .predict_data_module import PredictDataModule, create_predict_datamodule
 from .train_data_module import TrainDataModule, create_train_datamodule
careamics/lightning/callbacks/hyperparameters_callback.py
@@ -10,7 +10,7 @@ class HyperParametersCallback(Callback):
     """
     Callback allowing saving CAREamics configuration as hyperparameters in the model.
 
-    This allows saving the configuration as dictionnary in the checkpoints, and
+    This allows saving the configuration as dictionary in the checkpoints, and
     loading it subsequently in a CAREamist instance.
 
     Parameters
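
For orientation, a minimal usage sketch (not part of the diff): the callback is attached to a Lightning Trainer, assuming it is constructed with the CAREamics configuration object, as its docstring suggests; `config` is a placeholder for a configuration instance built elsewhere.

    import pytorch_lightning as pl
    from careamics.lightning import HyperParametersCallback

    # config: a CAREamics configuration instance (placeholder, built elsewhere)
    trainer = pl.Trainer(callbacks=[HyperParametersCallback(config)])
    # the configuration is stored in every checkpoint and can be recovered
    # later through a CAREamist instance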
careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py
@@ -1,4 +1,4 @@
-"""Module containing convienience function to create `WriteStrategy`."""
+"""Module containing convenience function to create `WriteStrategy`."""
 
 from typing import Any, Optional
 
careamics/lightning/lightning_module.py
@@ -1,11 +1,12 @@
 """CAREamics Lightning module."""
 
-from typing import Any, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union
 
+import numpy as np
 import pytorch_lightning as L
 from torch import Tensor, nn
 
-from careamics.config import AlgorithmConfig
+from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -14,12 +15,26 @@ from careamics.config.support import (
     SupportedScheduler,
 )
 from careamics.losses import loss_factory
+from careamics.losses.loss_factory import LVAELossParameters
+from careamics.models.lvae.likelihoods import (
+    GaussianLikelihood,
+    NoiseModelLikelihood,
+    likelihood_factory,
+)
+from careamics.models.lvae.noise_models import (
+    GaussianMixtureNoiseModel,
+    MultiChannelNoiseModel,
+    noise_model_factory,
+)
 from careamics.models.model_factory import model_factory
 from careamics.transforms import Denormalize, ImageRestorationTTA
+from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr
 from careamics.utils.torch_utils import get_optimizer, get_scheduler
 
+NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
 
-class CAREamicsModule(L.LightningModule):
+class FCNModule(L.LightningModule):
     """
     CAREamics Lightning module.
 
@@ -45,7 +60,7 @@ class CAREamicsModule(L.LightningModule):
         Learning rate scheduler name.
     """
 
-    def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
+    def __init__(self, algorithm_config: Union[FCNAlgorithmConfig, dict]) -> None:
         """Lightning module for CAREamics.
 
         This class encapsulates the a PyTorch model along with the training, validation,
@@ -59,7 +74,7 @@ class CAREamicsModule(L.LightningModule):
         super().__init__()
         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
        if isinstance(algorithm_config, dict):
-            algorithm_config = AlgorithmConfig(**algorithm_config)
+            algorithm_config = FCNAlgorithmConfig(**algorithm_config)
 
         # create model and loss function
         self.model: nn.Module = model_factory(algorithm_config.model)
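
The practical effect of the rename: checkpoints or dicts that previously validated against AlgorithmConfig now go through FCNAlgorithmConfig. A hedged sketch of direct instantiation; the field values ("n2v", "UNet") are illustrative and may not form a complete configuration:

    from careamics.lightning import FCNModule

    # illustrative dict; the exact required fields are defined by
    # FCNAlgorithmConfig, not by this diff
    config = {
        "algorithm_type": "fcn",
        "algorithm": "n2v",
        "loss": "n2v",
        "model": {"architecture": "UNet"},
    }
    module = FCNModule(config)  # dicts are re-validated via FCNAlgorithmConfig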
@@ -203,7 +218,339 @@ class CAREamicsModule(L.LightningModule):
         }
 
 
+class VAEModule(L.LightningModule):
+    """
+    CAREamics Lightning module.
+
+    This class encapsulates the a PyTorch model along with the training, validation,
+    and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+    Parameters
+    ----------
+    algorithm_config : Union[VAEAlgorithmConfig, dict]
+        Algorithm configuration.
+
+    Attributes
+    ----------
+    model : nn.Module
+        PyTorch model.
+    loss_func : nn.Module
+        Loss function.
+    optimizer_name : str
+        Optimizer name.
+    optimizer_params : dict
+        Optimizer parameters.
+    lr_scheduler_name : str
+        Learning rate scheduler name.
+    """
+
+    def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None:
+        """Lightning module for CAREamics.
+
+        This class encapsulates the a PyTorch model along with the training, validation,
+        and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+        Parameters
+        ----------
+        algorithm_config : Union[AlgorithmModel, dict]
+            Algorithm configuration.
+        """
+        super().__init__()
+        # if loading from a checkpoint, AlgorithmModel needs to be instantiated
+        self.algorithm_config = (
+            VAEAlgorithmConfig(**algorithm_config)
+            if isinstance(algorithm_config, dict)
+            else algorithm_config
+        )
+
+        # TODO: log algorithm config
+        # self.save_hyperparameters(self.algorithm_config.model_dump())
+
+        # create model and loss function
+        self.model: nn.Module = model_factory(self.algorithm_config.model)
+        self.noise_model: NoiseModel = noise_model_factory(
+            self.algorithm_config.noise_model
+        )
+        self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
+            self.algorithm_config.noise_model_likelihood_model
+        )
+        self.gaussian_likelihood: GaussianLikelihood = likelihood_factory(
+            self.algorithm_config.gaussian_likelihood_model
+        )
+        self.loss_parameters = LVAELossParameters(
+            noise_model_likelihood=self.noise_model_likelihood,
+            gaussian_likelihood=self.gaussian_likelihood,
+            # TODO: musplit/denoisplit weights ?
+        )  # type: ignore
+        self.loss_func = loss_factory(self.algorithm_config.loss)
+
+        # save optimizer and lr_scheduler names and parameters
+        self.optimizer_name = self.algorithm_config.optimizer.name
+        self.optimizer_params = self.algorithm_config.optimizer.parameters
+        self.lr_scheduler_name = self.algorithm_config.lr_scheduler.name
+        self.lr_scheduler_params = self.algorithm_config.lr_scheduler.parameters
+
+        # initialize running PSNR
+        self.running_psnr = [
+            RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
+        ]
+
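
To summarize the wiring above: the config drives four factories (model, noise model, and two likelihoods), and the likelihoods are then bundled into LVAELossParameters before the loss callable is resolved. A hedged sketch, where `vae_config` stands in for a fully populated VAEAlgorithmConfig (its exact fields live in careamics/config/vae_algorithm_model.py, which is not reproduced in this diff):

    from careamics.lightning import VAEModule

    # vae_config: a VAEAlgorithmConfig instance (or equivalent dict); it must
    # provide model, noise_model, noise_model_likelihood_model,
    # gaussian_likelihood_model, loss, optimizer and lr_scheduler settings
    module = VAEModule(vae_config)
    # module.loss_parameters bundles both likelihoods and is handed to
    # module.loss_func on every training/validation step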
+    def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
+        """Forward pass.
+
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+            number of lateral inputs.
+
+        Returns
+        -------
+        tuple[Tensor, dict[str, Any]]
+            A tuple with the output tensor and additional data from the top-down pass.
+        """
+        return self.model(x)  # TODO Different model can have more than one output
+
+    def training_step(
+        self, batch: tuple[Tensor, Tensor], batch_idx: Any
+    ) -> Optional[dict[str, Tensor]]:
+        """Training step.
+
+        Parameters
+        ----------
+        batch : tuple[Tensor, Tensor]
+            Input batch. It is a tuple with the input tensor and the target tensor.
+            The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+            number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
+            where C is the number of target channels (e.g., 1 in HDN, >1 in
+            muSplit/denoiSplit).
+        batch_idx : Any
+            Batch index.
+
+        Returns
+        -------
+        Any
+            Loss value.
+        """
+        x, target = batch
+
+        # Forward pass
+        out = self.model(x)
+
+        # Update loss parameters
+        # TODO rethink loss parameters
+        self.loss_parameters.current_epoch = self.current_epoch
+
+        # Compute loss
+        loss = self.loss_func(out, target, self.loss_parameters)  # TODO ugly ?
+
+        # Logging
+        # TODO: implement a separate logging method?
+        self.log_dict(loss, on_step=True, on_epoch=True)
+        # self.log("lr", self, on_epoch=True)
+        return loss
+
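
Note the call convention: unlike the FCN path, the LVAE losses return a dictionary of named terms rather than a scalar, which is what makes log_dict usable here. When training_step returns a dict, PyTorch Lightning backpropagates through its "loss" entry. A hedged sketch with placeholder names (the exact keys are defined in careamics/losses/lvae/losses.py, not shown here):

    # out: the tuple returned by the LVAE model; target: target tensor;
    # params: an LVAELossParameters instance (all placeholders)
    loss_dict = musplit_loss(out, target, params)
    # expected shape of the result: a dict of scalars, e.g. a "loss" entry
    # plus individual terms such as reconstruction and KL losses
    total = loss_dict["loss"]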
+    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: Any) -> None:
+        """Validation step.
+
+        Parameters
+        ----------
+        batch : tuple[Tensor, Tensor]
+            Input batch. It is a tuple with the input tensor and the target tensor.
+            The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+            number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
+            where C is the number of target channels (e.g., 1 in HDN, >1 in
+            muSplit/denoiSplit).
+        batch_idx : Any
+            Batch index.
+        """
+        x, target = batch
+
+        # Forward pass
+        out = self.model(x)
+
+        # Compute loss
+        loss = self.loss_func(out, target, self.loss_parameters)
+
+        # Logging
+        # Rename val_loss dict
+        loss = {"_".join(["val", k]): v for k, v in loss.items()}
+        self.log_dict(loss, on_epoch=True, prog_bar=True)
+        curr_psnr = self.compute_val_psnr(out, target)
+        for i, psnr in enumerate(curr_psnr):
+            self.log(f"val_psnr_ch{i+1}_batch", psnr, on_epoch=True)
+
+    def on_validation_epoch_end(self) -> None:
+        """Validation epoch end."""
+        psnr_ = self.reduce_running_psnr()
+        if psnr_ is not None:
+            self.log("val_psnr", psnr_, on_epoch=True, prog_bar=True)
+        else:
+            self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
+
+    def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+        """Prediction step.
+
+        Parameters
+        ----------
+        batch : Tensor
+            Input batch.
+        batch_idx : Any
+            Batch index.
+
+        Returns
+        -------
+        Any
+            Model output.
+        """
+        if self._trainer.datamodule.tiled:
+            x, *aux = batch
+        else:
+            x = batch
+            aux = []
+
+        # apply test-time augmentation if available
+        # TODO: probably wont work with batch size > 1
+        if self._trainer.datamodule.prediction_config.tta_transforms:
+            tta = ImageRestorationTTA()
+            augmented_batch = tta.forward(x)  # list of augmented tensors
+            augmented_output = []
+            for augmented in augmented_batch:
+                augmented_pred = self.model(augmented)
+                augmented_output.append(augmented_pred)
+            output = tta.backward(augmented_output)
+        else:
+            output = self.model(x)
+
+        # Denormalize the output
+        denorm = Denormalize(
+            image_means=self._trainer.datamodule.predict_dataset.image_means,
+            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+        )
+        denormalized_output = denorm(patch=output.cpu().numpy())
+
+        if len(aux) > 0:  # aux can be tiling information
+            return denormalized_output, *aux
+        else:
+            return denormalized_output
+
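
As a consumption sketch (placeholders, not part of the diff): when the datamodule is tiled, each prediction comes back as a tuple carrying the tiling information that downstream stitching can use.

    import pytorch_lightning as pl

    # module and predict_datamodule are placeholders built elsewhere
    trainer = pl.Trainer()
    predictions = trainer.predict(module, datamodule=predict_datamodule)
    # with a tiled datamodule each element is (denormalized_patch, *tile_info);
    # otherwise it is just the denormalized output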
+    def configure_optimizers(self) -> Any:
+        """Configure optimizers and learning rate schedulers.
+
+        Returns
+        -------
+        Any
+            Optimizer and learning rate scheduler.
+        """
+        # instantiate optimizer
+        optimizer_func = get_optimizer(self.optimizer_name)
+        optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
+
+        # and scheduler
+        scheduler_func = get_scheduler(self.lr_scheduler_name)
+        scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+        }
+
+    # TODO: find a way to move the following methods to a separate module
+    # TODO: this same operation is done in many other places, like in loss_func
+    # should we refactor LadderVAE so that it already outputs
+    # tuple(`mean`, `logvar`, `td_data`)?
+    def get_reconstructed_tensor(
+        self, model_outputs: tuple[Tensor, dict[str, Any]]
+    ) -> Tensor:
+        """Get the reconstructed tensor from the LVAE model outputs.
+
+        Parameters
+        ----------
+        model_outputs : tuple[Tensor, dict[str, Any]]
+            Model outputs. It is a tuple with a tensor representing the predicted mean
+            and (optionally) logvar, and the top-down data dictionary.
+
+        Returns
+        -------
+        Tensor
+            Reconstructed tensor, i.e., the predicted mean.
+        """
+        predictions, _ = model_outputs
+        if self.model.predict_logvar is None:
+            return predictions
+        elif self.model.predict_logvar == "pixelwise":
+            return predictions.chunk(2, dim=1)[0]
+
+    def compute_val_psnr(
+        self,
+        model_output: tuple[Tensor, dict[str, Any]],
+        target: Tensor,
+        psnr_func: Callable = scale_invariant_psnr,
+    ) -> list[float]:
+        """Compute the PSNR for the current validation batch.
+
+        Parameters
+        ----------
+        model_output : tuple[Tensor, dict[str, Any]]
+            Model output, a tuple with the predicted mean and (optionally) logvar,
+            and the top-down data dictionary.
+        target : Tensor
+            Target tensor.
+        psnr_func : Callable, optional
+            PSNR function to use, by default `scale_invariant_psnr`.
+
+        Returns
+        -------
+        list[float]
+            PSNR for each channel in the current batch.
+        """
+        out_channels = target.shape[1]
+
+        # get the reconstructed image
+        recons_img = self.get_reconstructed_tensor(model_output)
+
+        # update running psnr
+        for i in range(out_channels):
+            self.running_psnr[i].update(rec=recons_img[:, i], tar=target[:, i])
+
+        # compute psnr for each channel in the current batch
+        # TODO: this doesn't need do be a method of this class
+        # and hence can be moved to a separate module
+        return [
+            psnr_func(
+                gt=target[:, i].clone().detach().cpu().numpy(),
+                pred=recons_img[:, i].clone().detach().cpu().numpy(),
+            )
+            for i in range(out_channels)
+        ]
+
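
The RunningPSNR helper used above accumulates statistics across batches; as exercised in this module, its interface is update(rec=..., tar=...), get() and reset(), and, per the reduce_running_psnr logic below, get() may return None when nothing was accumulated. A small self-contained illustration (shapes are arbitrary):

    import torch
    from careamics.utils.metrics import RunningPSNR

    psnr = RunningPSNR()
    for _ in range(4):  # e.g. four validation batches
        rec = torch.rand(8, 64, 64)  # reconstruction
        tar = torch.rand(8, 64, 64)  # target
        psnr.update(rec=rec, tar=tar)
    value = psnr.get()  # aggregated PSNR over all updates (None if none)
    psnr.reset()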
+    def reduce_running_psnr(self) -> Optional[float]:
+        """Reduce the running PSNR statistics and reset the running PSNR.
+
+        Returns
+        -------
+        Optional[float]
+            Running PSNR averaged over the different output channels.
+        """
+        psnr_arr = []  # type: ignore
+        for i in range(len(self.running_psnr)):
+            psnr = self.running_psnr[i].get()
+            if psnr is None:
+                psnr_arr = None  # type: ignore
+                break
+            psnr_arr.append(psnr.cpu().numpy())
+            self.running_psnr[i].reset()
+        # TODO: this line forces it to be a method of this class
+        # alternative is returning also the reset `running_psnr`
+        if psnr_arr is not None:
+            psnr = np.mean(psnr_arr)
+        return psnr
+
+
+# TODO: make this LVAE compatible (?)
 def create_careamics_module(
+    algorithm_type: Literal["fcn"],
     algorithm: Union[SupportedAlgorithm, str],
     loss: Union[SupportedLoss, str],
     architecture: Union[SupportedArchitecture, str],
@@ -212,14 +559,16 @@ def create_careamics_module(
     optimizer_parameters: Optional[dict] = None,
     lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
     lr_scheduler_parameters: Optional[dict] = None,
-) -> CAREamicsModule:
-    """Create a CAREamics Lithgning module.
+) -> Union[FCNModule, VAEModule]:
+    """Create a CAREamics Lightning module.
 
     This function exposes parameters used to create an AlgorithmModel instance,
     triggering parameters validation.
 
     Parameters
     ----------
+    algorithm_type : Literal["fcn"]
+        Algorithm type to use for training.
     algorithm : SupportedAlgorithm or str
         Algorithm to use for training (see SupportedAlgorithm).
     loss : SupportedLoss or str
@@ -254,7 +603,8 @@ def create_careamics_module(
        optimizer_parameters = {}
    if model_parameters is None:
        model_parameters = {}
-    algorithm_configuration = {
+    algorithm_configuration: dict[str, Any] = {
+        "algorithm_type": algorithm_type,
        "algorithm": algorithm,
        "loss": loss,
        "optimizer": {
@@ -273,4 +623,10 @@ def create_careamics_module(
    algorithm_configuration["model"] = model_configuration
 
    # call the parent init using an AlgorithmModel instance
-    return CAREamicsModule(AlgorithmConfig(**algorithm_configuration))
+    if algorithm_configuration["algorithm_type"] == "fcn":
+        return FCNModule(FCNAlgorithmConfig(**algorithm_configuration))
+    else:
+        raise NotImplementedError(
+            f"Model {algorithm_configuration['model']['architecture']} is not"
+            f"implemented or unknown."
+        )
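
With the new required algorithm_type argument, a call now looks like the sketch below; the concrete values ("n2v", "UNet") are illustrative, not mandated by this diff:

    from careamics.lightning import create_careamics_module

    module = create_careamics_module(
        algorithm_type="fcn",  # currently the only implemented branch
        algorithm="n2v",
        loss="n2v",
        architecture="UNet",
    )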
careamics/lightning/predict_data_module.py
@@ -240,7 +240,7 @@
 ) -> PredictDataModule:
     """Create a CAREamics prediction Lightning datamodule.
 
-    This function is used to explicitely pass the parameters usually contained in an
+    This function is used to explicitly pass the parameters usually contained in an
     `inference_model` configuration.
 
     Since the lightning datamodule has no access to the model, make sure that the
@@ -268,7 +268,7 @@
     data_type : {"array", "tiff", "custom"}
         Data type, see `SupportedData` for available options.
     axes : str
-        Axes of the data, choosen among SCZYX.
+        Axes of the data, chosen among SCZYX.
     image_means : list of float
         Mean values for normalization, only used if Normalization is defined.
     image_stds : list of float
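
A hedged example of the explicit-parameter style this docstring describes; the full signature is not shown in the diff, so the data-argument name (`pred_data`) and values are assumptions:

    import numpy as np
    from careamics.lightning import create_predict_datamodule

    data = np.random.rand(1, 256, 256).astype(np.float32)
    datamodule = create_predict_datamodule(
        pred_data=data,  # parameter name assumed; not shown in this diff
        data_type="array",
        axes="SYX",
        image_means=[float(data.mean())],
        image_stds=[float(data.std())],
    )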
careamics/lightning/train_data_module.py
@@ -487,7 +487,7 @@
 ) -> TrainDataModule:
     """Create a TrainDataModule.
 
-    This function is used to explicitely pass the parameters usually contained in a
+    This function is used to explicitly pass the parameters usually contained in a
     `data_model` configuration to a TrainDataModule.
 
     Since the lightning datamodule has no access to the model, make sure that the
@@ -537,7 +537,7 @@
     patch_size : list of int
         Patch size, 2D or 3D patch size.
     axes : str
-        Axes of the data, choosen amongst SCZYX.
+        Axes of the data, chosen amongst SCZYX.
     batch_size : int
         Batch size.
     val_data : pathlib.Path or str or numpy.ndarray, optional
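
A matching hedged sketch for the training datamodule; again, the data-argument name (`train_data`) is an assumption, while patch_size, axes and batch_size follow the documented parameters:

    import numpy as np
    from careamics.lightning import create_train_datamodule

    train = np.random.rand(512, 512).astype(np.float32)
    datamodule = create_train_datamodule(
        train_data=train,  # parameter name assumed; not shown in this diff
        data_type="array",
        patch_size=[64, 64],
        axes="YX",
        batch_size=8,
    )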
careamics/losses/__init__.py
@@ -1,5 +1,15 @@
 """Losses module."""
 
-__all__ = ["loss_factory"]
+__all__ = [
+    "loss_factory",
+    "mae_loss",
+    "mse_loss",
+    "n2v_loss",
+    "denoisplit_loss",
+    "musplit_loss",
+    "denoisplit_musplit_loss",
+]
 
+from .fcn.losses import mae_loss, mse_loss, n2v_loss
 from .loss_factory import loss_factory
+from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
careamics/losses/fcn/__init__.py (new file)
@@ -0,0 +1 @@
+"""FCN losses."""
careamics/losses/{losses.py → fcn/losses.py}
@@ -53,7 +53,7 @@ def n2v_loss(
     errors = (original_patches - manipulated_patches) ** 2
     # Average over pixels and batch
     loss = torch.sum(errors * masks) / torch.sum(masks)
-    return loss
+    return loss  # TODO change output to dict ?
 
 
 def mae_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
careamics/losses/loss_factory.py
@@ -4,14 +4,114 @@ Loss factory module.
 This module contains a factory function for creating loss functions.
 """
 
-from typing import Callable, Union
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
+
+from torch import Tensor as tensor
 
 from ..config.support import SupportedLoss
-from .losses import mae_loss, mse_loss, n2v_loss
+from .fcn.losses import mae_loss, mse_loss, n2v_loss
+from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss
+
+if TYPE_CHECKING:
+    from careamics.models.lvae.likelihoods import (
+        GaussianLikelihood,
+        NoiseModelLikelihood,
+    )
+    from careamics.models.lvae.noise_models import (
+        GaussianMixtureNoiseModel,
+        MultiChannelNoiseModel,
+    )
+
+    NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+@dataclass
+class FCNLossParameters:
+    """Dataclass for FCN loss."""
+
+    # TODO check
+    prediction: tensor
+    targets: tensor
+    mask: tensor
+    current_epoch: int
+    loss_weight: float
+
+
+@dataclass  # TODO why not pydantic?
+class LVAELossParameters:
+    """Dataclass for LVAE loss."""
+
+    # TODO: refactor in more modular blocks (otherwise it gets messy very easily)
+    # e.g., - weights, - kl_params, ...
+
+    noise_model_likelihood: Optional[NoiseModelLikelihood] = None
+    """Noise model likelihood instance."""
+    gaussian_likelihood: Optional[GaussianLikelihood] = None
+    """Gaussian likelihood instance."""
+    current_epoch: int = 0
+    """Current epoch in the training loop."""
+    reconstruction_weight: float = 1.0
+    """Weight for the reconstruction loss in the total net loss
+    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+    musplit_weight: float = 0.0
+    """Weight for the muSplit loss (used in the muSplit-deonoiSplit loss)."""
+    denoisplit_weight: float = 1.0
+    """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
+    kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
+    """Type of KL divergence used as KL loss."""
+    kl_weight: float = 1.0
+    """Weight for the KL loss in the total net loss.
+    (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+    kl_annealing: bool = False
+    """Whether to apply KL loss annealing."""
+    kl_start: int = -1
+    """Epoch at which KL loss annealing starts."""
+    kl_annealtime: int = 10
+    """Number of epochs for which KL loss annealing is applied."""
+    non_stochastic: bool = False
+    """Whether to sample latents and compute KL."""
+
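
Given the defaults above, a pure denoiSplit setup only needs the likelihood instances; switching to the combined muSplit-denoiSplit loss amounts to adjusting the two weights. A hedged sketch (the likelihood instances are placeholders produced elsewhere by likelihood_factory):

    from careamics.losses.loss_factory import LVAELossParameters

    params = LVAELossParameters(
        noise_model_likelihood=nm_likelihood,  # placeholder instance
        gaussian_likelihood=g_likelihood,      # placeholder instance
        denoisplit_weight=0.9,  # weights used by the combined
        musplit_weight=0.1,     # denoiSplit-muSplit loss
        kl_annealing=True,
        kl_start=0,
        kl_annealtime=10,
    )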
+# TODO: really needed?
+# like it is now, it is difficult to use, we need a way to specify the
+# loss parameters in a more user-friendly way.
+def loss_parameters_factory(
+    type: SupportedLoss,
+) -> Union[FCNLossParameters, LVAELossParameters]:
+    """Return loss parameters.
+
+    Parameters
+    ----------
+    type : SupportedLoss
+        Requested loss.
+
+    Returns
+    -------
+    Union[FCNLossParameters, LVAELossParameters]
+        Loss parameters.
+
+    Raises
+    ------
+    NotImplementedError
+        If the loss is unknown.
+    """
+    if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]:
+        return FCNLossParameters
+
+    elif type in [
+        SupportedLoss.MUSPLIT,
+        SupportedLoss.DENOISPLIT,
+        SupportedLoss.DENOISPLIT_MUSPLIT,
+    ]:
+        return LVAELossParameters  # it returns the class, not an instance
+
+    else:
+        raise NotImplementedError(f"Loss {type} is not yet supported.")
 
 
-# TODO add tests
-# TODO add custom?
 def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
     """Return loss function.
 
@@ -42,8 +142,14 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
     elif loss == SupportedLoss.MSE:
         return mse_loss
 
-    # elif loss_type == SupportedLoss.DICE:
-    #     return dice_loss
+    elif loss == SupportedLoss.MUSPLIT:
+        return musplit_loss
+
+    elif loss == SupportedLoss.DENOISPLIT:
+        return denoisplit_loss
+
+    elif loss == SupportedLoss.DENOISPLIT_MUSPLIT:
+        return denoisplit_musplit_loss
 
     else:
         raise NotImplementedError(f"Loss {loss} is not yet supported.")
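
Usage ties the two factories together: loss_factory resolves the callable, while loss_parameters_factory returns the matching parameter class (the class itself, not an instance, per the comment above):

    from careamics.config.support import SupportedLoss
    from careamics.losses import loss_factory
    from careamics.losses.loss_factory import loss_parameters_factory

    loss_func = loss_factory(SupportedLoss.DENOISPLIT)  # -> denoisplit_loss
    params_cls = loss_parameters_factory(SupportedLoss.DENOISPLIT)
    # params_cls is the LVAELossParameters class; instantiate it with the
    # likelihoods before calling loss_func(model_outputs, targets, params)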
careamics/losses/lvae/__init__.py (new file)
@@ -0,0 +1 @@
+"""LVAE losses."""