careamics 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package registry's release page for more details.

Files changed (59):
  1. careamics/careamist.py +6 -12
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +9 -8
  12. careamics/config/configuration_factories.py +843 -29
  13. careamics/config/data/data_model.py +1 -2
  14. careamics/config/data/ng_data_model.py +1 -2
  15. careamics/config/inference_model.py +1 -2
  16. careamics/config/likelihood_model.py +2 -2
  17. careamics/config/loss_model.py +6 -2
  18. careamics/config/nm_model.py +26 -1
  19. careamics/config/optimizer_models.py +1 -2
  20. careamics/config/support/supported_algorithms.py +5 -3
  21. careamics/config/support/supported_losses.py +5 -2
  22. careamics/config/training_model.py +6 -36
  23. careamics/config/transformations/normalize_model.py +1 -2
  24. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  25. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  26. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  27. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  28. careamics/file_io/read/__init__.py +0 -1
  29. careamics/lightning/__init__.py +16 -2
  30. careamics/lightning/callbacks/__init__.py +2 -0
  31. careamics/lightning/callbacks/data_stats_callback.py +23 -0
  32. careamics/lightning/lightning_module.py +161 -61
  33. careamics/lightning/microsplit_data_module.py +631 -0
  34. careamics/lightning/predict_data_module.py +8 -1
  35. careamics/lightning/train_data_module.py +19 -8
  36. careamics/losses/__init__.py +7 -1
  37. careamics/losses/loss_factory.py +9 -1
  38. careamics/losses/lvae/losses.py +85 -0
  39. careamics/lvae_training/dataset/__init__.py +8 -8
  40. careamics/lvae_training/dataset/config.py +56 -44
  41. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  42. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  43. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  44. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  45. careamics/model_io/bmz_io.py +9 -5
  46. careamics/models/lvae/likelihoods.py +30 -14
  47. careamics/models/lvae/lvae.py +2 -2
  48. careamics/models/lvae/noise_models.py +20 -14
  49. careamics/prediction_utils/__init__.py +8 -2
  50. careamics/prediction_utils/prediction_outputs.py +48 -3
  51. careamics/prediction_utils/stitch_prediction.py +71 -0
  52. careamics/transforms/xy_random_rotate90.py +1 -1
  53. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -15
  54. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/RECORD +57 -55
  55. careamics/dataset/zarr_dataset.py +0 -151
  56. careamics/file_io/read/zarr.py +0 -60
  57. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
  58. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
  59. {careamics-0.0.15.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
@@ -5,7 +5,7 @@ from typing import Any, Literal, Union
5
5
 
6
6
  import numpy as np
7
7
  import pytorch_lightning as L
8
- from torch import Tensor, nn
8
+ import torch
9
9
 
10
10
  from careamics.config import (
11
11
  N2VAlgorithm,
@@ -71,7 +71,9 @@ class FCNModule(L.LightningModule):
71
71
  Learning rate scheduler name.
72
72
  """
73
73
 
74
- def __init__(self, algorithm_config: Union[UNetBasedAlgorithm, dict]) -> None:
74
+ def __init__(
75
+ self, algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm, dict]
76
+ ) -> None:
75
77
  """Lightning module for CAREamics.
76
78
 
77
79
  This class encapsulates the a PyTorch model along with the training, validation,
@@ -98,7 +100,7 @@ class FCNModule(L.LightningModule):
98
100
  self.n2v_preprocess = None
99
101
 
100
102
  self.algorithm = algorithm_config.algorithm
101
- self.model: nn.Module = model_factory(algorithm_config.model)
103
+ self.model: torch.nn.Module = model_factory(algorithm_config.model)
102
104
  self.loss_func = loss_factory(algorithm_config.loss)
103
105
 
104
106
  # save optimizer and lr_scheduler names and parameters
@@ -122,12 +124,12 @@ class FCNModule(L.LightningModule):
122
124
  """
123
125
  return self.model(x)
124
126
 
125
- def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
127
+ def training_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
126
128
  """Training step.
127
129
 
128
130
  Parameters
129
131
  ----------
130
- batch : torch.Tensor
132
+ batch : torch.torch.Tensor
131
133
  Input batch.
132
134
  batch_idx : Any
133
135
  Batch index.
@@ -154,12 +156,12 @@ class FCNModule(L.LightningModule):
154
156
  self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
155
157
  return loss
156
158
 
157
- def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
159
+ def validation_step(self, batch: torch.Tensor, batch_idx: Any) -> None:
158
160
  """Validation step.
159
161
 
160
162
  Parameters
161
163
  ----------
162
- batch : torch.Tensor
164
+ batch : torch.torch.Tensor
163
165
  Input batch.
164
166
  batch_idx : Any
165
167
  Batch index.
@@ -184,12 +186,12 @@ class FCNModule(L.LightningModule):
184
186
  logger=True,
185
187
  )
186
188
 
187
- def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
189
+ def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
188
190
  """Prediction step.
189
191
 
190
192
  Parameters
191
193
  ----------
192
- batch : torch.Tensor
194
+ batch : torch.torch.torch.Tensor
193
195
  Input batch.
194
196
  batch_idx : Any
195
197
  Batch index.
@@ -330,17 +332,21 @@ class VAEModule(L.LightningModule):
330
332
  # self.save_hyperparameters(self.algorithm_config.model_dump())
331
333
 
332
334
  # create model
333
- self.model: nn.Module = model_factory(self.algorithm_config.model)
335
+ self.model: torch.nn.Module = model_factory(self.algorithm_config.model)
334
336
 
337
+ # supervised_mode
338
+ self.supervised_mode = self.algorithm_config.is_supervised
335
339
  # create loss function
336
340
  self.noise_model: NoiseModel | None = noise_model_factory(
337
341
  self.algorithm_config.noise_model
338
342
  )
339
343
 
340
- self.noise_model_likelihood: NoiseModelLikelihood | None = likelihood_factory(
341
- config=self.algorithm_config.noise_model_likelihood,
342
- noise_model=self.noise_model,
343
- )
344
+ self.noise_model_likelihood: NoiseModelLikelihood | None = None
345
+ if self.algorithm_config.noise_model_likelihood is not None:
346
+ self.noise_model_likelihood = likelihood_factory(
347
+ config=self.algorithm_config.noise_model_likelihood,
348
+ noise_model=self.noise_model,
349
+ )
344
350
 
345
351
  self.gaussian_likelihood: GaussianLikelihood | None = likelihood_factory(
346
352
  self.algorithm_config.gaussian_likelihood
@@ -360,30 +366,43 @@ class VAEModule(L.LightningModule):
360
366
  RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
361
367
  ]
362
368
 
363
- def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
369
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
364
370
  """Forward pass.
365
371
 
366
372
  Parameters
367
373
  ----------
368
- x : Tensor
374
+ x : torch.Tensor
369
375
  Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
370
376
  number of lateral inputs.
371
377
 
372
378
  Returns
373
379
  -------
374
- tuple[Tensor, dict[str, Any]]
380
+ tuple[torch.Tensor, dict[str, Any]]
375
381
  A tuple with the output tensor and additional data from the top-down pass.
376
382
  """
377
383
  return self.model(x) # TODO Different model can have more than one output
378
384
 
385
+ def set_data_stats(self, data_mean, data_std):
386
+ """Set data mean and std for the noise model likelihood.
387
+
388
+ Parameters
389
+ ----------
390
+ data_mean : float
391
+ Mean of the data.
392
+ data_std : float
393
+ Standard deviation of the data.
394
+ """
395
+ if self.noise_model_likelihood is not None:
396
+ self.noise_model_likelihood.set_data_stats(data_mean, data_std)
397
+
379
398
  def training_step(
380
- self, batch: tuple[Tensor, Tensor], batch_idx: Any
381
- ) -> dict[str, Tensor] | None:
399
+ self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
400
+ ) -> dict[str, torch.Tensor] | None:
382
401
  """Training step.
383
402
 
384
403
  Parameters
385
404
  ----------
386
- batch : tuple[Tensor, Tensor]
405
+ batch : tuple[torch.Tensor, torch.Tensor]
387
406
  Input batch. It is a tuple with the input tensor and the target tensor.
388
407
  The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
389
408
  number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
@@ -397,15 +416,29 @@ class VAEModule(L.LightningModule):
397
416
  Any
398
417
  Loss value.
399
418
  """
400
- x, target = batch
419
+ x, *target = batch
401
420
 
402
421
  # Forward pass
403
422
  out = self.model(x)
423
+ if not self.supervised_mode:
424
+ target = x
425
+ else:
426
+ target = target[
427
+ 0
428
+ ] # hacky way to unpack. #TODO maybe should be fixed on the dataset level
404
429
 
405
430
  # Update loss parameters
406
431
  self.loss_parameters.kl_params.current_epoch = self.current_epoch
407
432
 
408
433
  # Compute loss
434
+ if self.noise_model_likelihood is not None:
435
+ if (
436
+ self.noise_model_likelihood.data_mean is None
437
+ or self.noise_model_likelihood.data_std is None
438
+ ):
439
+ raise RuntimeError(
440
+ "NoiseModelLikelihood: data_mean and data_std must be set before training."
441
+ )
409
442
  loss = self.loss_func(
410
443
  model_outputs=out,
411
444
  targets=target,
@@ -417,15 +450,26 @@ class VAEModule(L.LightningModule):
417
450
  # Logging
418
451
  # TODO: implement a separate logging method?
419
452
  self.log_dict(loss, on_step=True, on_epoch=True)
420
- # self.log("lr", self, on_epoch=True)
453
+
454
+ try:
455
+ optimizer = self.optimizers()
456
+ current_lr = optimizer.param_groups[0]["lr"]
457
+ self.log(
458
+ "learning_rate", current_lr, on_step=False, on_epoch=True, logger=True
459
+ )
460
+ except RuntimeError:
461
+ # This happens when the module is not attached to a trainer, e.g., in tests
462
+ pass
421
463
  return loss
422
464
 
423
- def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: Any) -> None:
465
+ def validation_step(
466
+ self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
467
+ ) -> None:
424
468
  """Validation step.
425
469
 
426
470
  Parameters
427
471
  ----------
428
- batch : tuple[Tensor, Tensor]
472
+ batch : tuple[torch.Tensor, torch.Tensor]
429
473
  Input batch. It is a tuple with the input tensor and the target tensor.
430
474
  The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
431
475
  number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
@@ -434,11 +478,16 @@ class VAEModule(L.LightningModule):
434
478
  batch_idx : Any
435
479
  Batch index.
436
480
  """
437
- x, target = batch
481
+ x, *target = batch
438
482
 
439
483
  # Forward pass
440
484
  out = self.model(x)
441
-
485
+ if not self.supervised_mode:
486
+ target = x
487
+ else:
488
+ target = target[
489
+ 0
490
+ ] # hacky way to unpack. #TODO maybe should be fixed on the datasel level
442
491
  # Compute loss
443
492
  loss = self.loss_func(
444
493
  model_outputs=out,
@@ -464,12 +513,12 @@ class VAEModule(L.LightningModule):
464
513
  else:
465
514
  self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
466
515
 
467
- def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
516
+ def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
468
517
  """Prediction step.
469
518
 
470
519
  Parameters
471
520
  ----------
472
- batch : Tensor
521
+ batch : torch.Tensor
473
522
  Input batch.
474
523
  batch_idx : Any
475
524
  Batch index.
@@ -479,36 +528,86 @@ class VAEModule(L.LightningModule):
479
528
  Any
480
529
  Model output.
481
530
  """
482
- if self._trainer.datamodule.tiled:
531
+ if self.algorithm_config.algorithm == "microsplit":
483
532
  x, *aux = batch
484
- else:
485
- x = batch
486
- aux = []
533
+ # Reset model for inference with spatial dimensions only (H, W)
534
+ self.model.reset_for_inference(x.shape[-2:])
487
535
 
488
- # apply test-time augmentation if available
489
- # TODO: probably wont work with batch size > 1
490
- if self._trainer.datamodule.prediction_config.tta_transforms:
491
- tta = ImageRestorationTTA()
492
- augmented_batch = tta.forward(x) # list of augmented tensors
493
- augmented_output = []
494
- for augmented in augmented_batch:
495
- augmented_pred = self.model(augmented)
496
- augmented_output.append(augmented_pred)
497
- output = tta.backward(augmented_output)
498
- else:
499
- output = self.model(x)
536
+ rec_img_list = []
537
+ for _ in range(self.algorithm_config.mmse_count):
538
+ # get model output
539
+ rec, _ = self.model(x)
500
540
 
501
- # Denormalize the output
502
- denorm = Denormalize(
503
- image_means=self._trainer.datamodule.predict_dataset.image_means,
504
- image_stds=self._trainer.datamodule.predict_dataset.image_stds,
505
- )
506
- denormalized_output = denorm(patch=output.cpu().numpy())
541
+ # get reconstructed img
542
+ if self.model.predict_logvar is None:
543
+ rec_img = rec
544
+ logvar = torch.tensor([-1])
545
+ else:
546
+ rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
547
+ rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim
548
+
549
+ # aggregate results
550
+ samples = torch.cat(rec_img_list, dim=0)
551
+ mmse_imgs = torch.mean(samples, dim=0) # avg over MMSE dim
552
+ std_imgs = torch.std(samples, dim=0) # std over MMSE dim
553
+
554
+ tile_prediction = mmse_imgs.cpu().numpy()
555
+ tile_std = std_imgs.cpu().numpy()
556
+
557
+ return tile_prediction, tile_std
507
558
 
508
- if len(aux) > 0: # aux can be tiling information
509
- return denormalized_output, *aux
510
559
  else:
511
- return denormalized_output
560
+ # Regular prediction logic
561
+ if self._trainer.datamodule.tiled:
562
+ # TODO tile_size should match model input size
563
+ x, *aux = batch
564
+ x = (
565
+ x[0] if isinstance(x, list | tuple) else x
566
+ ) # TODO ugly, so far i don't know why x might be a list
567
+ self.model.reset_for_inference(x.shape) # TODO should it be here ?
568
+ else:
569
+ x = batch[0] if isinstance(batch, list | tuple) else batch
570
+ aux = []
571
+ self.model.reset_for_inference(x.shape)
572
+
573
+ mmse_list = []
574
+ for _ in range(self.algorithm_config.mmse_count):
575
+ # apply test-time augmentation if available
576
+ if self._trainer.datamodule.prediction_config.tta_transforms:
577
+ tta = ImageRestorationTTA()
578
+ augmented_batch = tta.forward(x) # list of augmented tensors
579
+ augmented_output = []
580
+ for augmented in augmented_batch:
581
+ augmented_pred = self.model(augmented)
582
+ augmented_output.append(augmented_pred)
583
+ output = tta.backward(augmented_output)
584
+ else:
585
+ output = self.model(x)
586
+
587
+ # taking the 1st element of the output, 2nd is std if
588
+ # predict_logvar=="pixelwise"
589
+ output = (
590
+ output[0]
591
+ if self.model.predict_logvar is None
592
+ else output[0][:, 0:1, ...]
593
+ )
594
+ mmse_list.append(output)
595
+
596
+ mmse = torch.stack(mmse_list).mean(0)
597
+ std = torch.stack(mmse_list).std(0) # TODO why?
598
+ # TODO better way to unpack if pred logvar
599
+ # Denormalize the output
600
+ denorm = Denormalize(
601
+ image_means=self._trainer.datamodule.predict_dataset.image_means,
602
+ image_stds=self._trainer.datamodule.predict_dataset.image_stds,
603
+ )
604
+
605
+ denormalized_output = denorm(patch=mmse.cpu().numpy())
606
+
607
+ if len(aux) > 0: # aux can be tiling information
608
+ return denormalized_output, std, *aux
609
+ else:
610
+ return denormalized_output, std
512
611
 
513
612
  def configure_optimizers(self) -> Any:
514
613
  """Configure optimizers and learning rate schedulers.
@@ -537,19 +636,19 @@ class VAEModule(L.LightningModule):
537
636
  # should we refactor LadderVAE so that it already outputs
538
637
  # tuple(`mean`, `logvar`, `td_data`)?
539
638
  def get_reconstructed_tensor(
540
- self, model_outputs: tuple[Tensor, dict[str, Any]]
541
- ) -> Tensor:
639
+ self, model_outputs: tuple[torch.Tensor, dict[str, Any]]
640
+ ) -> torch.Tensor:
542
641
  """Get the reconstructed tensor from the LVAE model outputs.
543
642
 
544
643
  Parameters
545
644
  ----------
546
- model_outputs : tuple[Tensor, dict[str, Any]]
645
+ model_outputs : tuple[torch.Tensor, dict[str, Any]]
547
646
  Model outputs. It is a tuple with a tensor representing the predicted mean
548
647
  and (optionally) logvar, and the top-down data dictionary.
549
648
 
550
649
  Returns
551
650
  -------
552
- Tensor
651
+ torch.Tensor
553
652
  Reconstructed tensor, i.e., the predicted mean.
554
653
  """
555
654
  predictions, _ = model_outputs
@@ -560,18 +659,18 @@ class VAEModule(L.LightningModule):
560
659
 
561
660
  def compute_val_psnr(
562
661
  self,
563
- model_output: tuple[Tensor, dict[str, Any]],
564
- target: Tensor,
662
+ model_output: tuple[torch.Tensor, dict[str, Any]],
663
+ target: torch.Tensor,
565
664
  psnr_func: Callable = scale_invariant_psnr,
566
665
  ) -> list[float]:
567
666
  """Compute the PSNR for the current validation batch.
568
667
 
569
668
  Parameters
570
669
  ----------
571
- model_output : tuple[Tensor, dict[str, Any]]
670
+ model_output : tuple[torch.Tensor, dict[str, Any]]
572
671
  Model output, a tuple with the predicted mean and (optionally) logvar,
573
672
  and the top-down data dictionary.
574
- target : Tensor
673
+ target : torch.Tensor
575
674
  Target tensor.
576
675
  psnr_func : Callable, optional
577
676
  PSNR function to use, by default `scale_invariant_psnr`.
@@ -581,6 +680,7 @@ class VAEModule(L.LightningModule):
581
680
  list[float]
582
681
  PSNR for each channel in the current batch.
583
682
  """
683
+ # TODO check this! Related to is_supervised which is also wacky
584
684
  out_channels = target.shape[1]
585
685
 
586
686
  # get the reconstructed image