careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of careamics might be problematic.
Files changed (79):
  1. careamics/careamist.py +11 -14
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +15 -63
  12. careamics/config/configuration_factories.py +853 -29
  13. careamics/config/data/data_model.py +50 -11
  14. careamics/config/data/ng_data_model.py +168 -4
  15. careamics/config/data/patch_filter/__init__.py +15 -0
  16. careamics/config/data/patch_filter/filter_model.py +16 -0
  17. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  18. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  19. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  20. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  21. careamics/config/inference_model.py +1 -2
  22. careamics/config/likelihood_model.py +2 -2
  23. careamics/config/loss_model.py +6 -2
  24. careamics/config/nm_model.py +26 -1
  25. careamics/config/optimizer_models.py +1 -2
  26. careamics/config/support/supported_algorithms.py +5 -3
  27. careamics/config/support/supported_filters.py +17 -0
  28. careamics/config/support/supported_losses.py +5 -2
  29. careamics/config/training_model.py +6 -36
  30. careamics/config/transformations/normalize_model.py +1 -2
  31. careamics/dataset_ng/dataset.py +57 -5
  32. careamics/dataset_ng/factory.py +101 -18
  33. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  34. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  35. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  36. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  37. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  38. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  39. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  40. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  41. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  42. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  43. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  44. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  45. careamics/file_io/read/__init__.py +0 -1
  46. careamics/lightning/__init__.py +16 -2
  47. careamics/lightning/callbacks/__init__.py +2 -0
  48. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  49. careamics/lightning/dataset_ng/data_module.py +79 -2
  50. careamics/lightning/lightning_module.py +162 -61
  51. careamics/lightning/microsplit_data_module.py +636 -0
  52. careamics/lightning/predict_data_module.py +8 -1
  53. careamics/lightning/train_data_module.py +19 -8
  54. careamics/losses/__init__.py +7 -1
  55. careamics/losses/loss_factory.py +9 -1
  56. careamics/losses/lvae/losses.py +85 -0
  57. careamics/lvae_training/dataset/__init__.py +8 -8
  58. careamics/lvae_training/dataset/config.py +56 -44
  59. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  60. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  61. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  62. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  63. careamics/lvae_training/eval_utils.py +46 -24
  64. careamics/model_io/bmz_io.py +9 -5
  65. careamics/models/lvae/likelihoods.py +31 -14
  66. careamics/models/lvae/lvae.py +2 -2
  67. careamics/models/lvae/noise_models.py +20 -14
  68. careamics/prediction_utils/__init__.py +8 -2
  69. careamics/prediction_utils/prediction_outputs.py +49 -3
  70. careamics/prediction_utils/stitch_prediction.py +83 -1
  71. careamics/transforms/xy_random_rotate90.py +1 -1
  72. careamics/utils/version.py +4 -4
  73. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
  74. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
  75. careamics/dataset/zarr_dataset.py +0 -151
  76. careamics/file_io/read/zarr.py +0 -60
  77. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  78. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  79. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/lightning_module.py (file 50 above):

@@ -5,7 +5,7 @@ from typing import Any, Literal, Union
 
 import numpy as np
 import pytorch_lightning as L
-from torch import Tensor, nn
+import torch
 
 from careamics.config import (
     N2VAlgorithm,
@@ -71,7 +71,9 @@ class FCNModule(L.LightningModule):
         Learning rate scheduler name.
     """
 
-    def __init__(self, algorithm_config: Union[UNetBasedAlgorithm, dict]) -> None:
+    def __init__(
+        self, algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm, dict]
+    ) -> None:
         """Lightning module for CAREamics.
 
         This class encapsulates the a PyTorch model along with the training, validation,
@@ -98,7 +100,7 @@ class FCNModule(L.LightningModule):
         self.n2v_preprocess = None
 
         self.algorithm = algorithm_config.algorithm
-        self.model: nn.Module = model_factory(algorithm_config.model)
+        self.model: torch.nn.Module = model_factory(algorithm_config.model)
         self.loss_func = loss_factory(algorithm_config.loss)
 
         # save optimizer and lr_scheduler names and parameters
@@ -122,12 +124,12 @@ class FCNModule(L.LightningModule):
         """
         return self.model(x)
 
-    def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
+    def training_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
         """Training step.
 
         Parameters
         ----------
-        batch : torch.Tensor
+        batch : torch.torch.Tensor
             Input batch.
         batch_idx : Any
             Batch index.
@@ -154,12 +156,12 @@ class FCNModule(L.LightningModule):
         self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
         return loss
 
-    def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
+    def validation_step(self, batch: torch.Tensor, batch_idx: Any) -> None:
         """Validation step.
 
         Parameters
         ----------
-        batch : torch.Tensor
+        batch : torch.torch.Tensor
             Input batch.
         batch_idx : Any
             Batch index.
@@ -184,12 +186,12 @@ class FCNModule(L.LightningModule):
             logger=True,
         )
 
-    def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+    def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
         """Prediction step.
 
         Parameters
         ----------
-        batch : torch.Tensor
+        batch : torch.torch.torch.Tensor
             Input batch.
         batch_idx : Any
             Batch index.
@@ -330,17 +332,21 @@ class VAEModule(L.LightningModule):
         # self.save_hyperparameters(self.algorithm_config.model_dump())
 
         # create model
-        self.model: nn.Module = model_factory(self.algorithm_config.model)
+        self.model: torch.nn.Module = model_factory(self.algorithm_config.model)
 
+        # supervised_mode
+        self.supervised_mode = self.algorithm_config.is_supervised
         # create loss function
         self.noise_model: NoiseModel | None = noise_model_factory(
             self.algorithm_config.noise_model
         )
 
-        self.noise_model_likelihood: NoiseModelLikelihood | None = likelihood_factory(
-            config=self.algorithm_config.noise_model_likelihood,
-            noise_model=self.noise_model,
-        )
+        self.noise_model_likelihood: NoiseModelLikelihood | None = None
+        if self.algorithm_config.noise_model_likelihood is not None:
+            self.noise_model_likelihood = likelihood_factory(
+                config=self.algorithm_config.noise_model_likelihood,
+                noise_model=self.noise_model,
+            )
 
         self.gaussian_likelihood: GaussianLikelihood | None = likelihood_factory(
             self.algorithm_config.gaussian_likelihood
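
With this change the noise-model likelihood becomes optional: it is only constructed when `algorithm_config.noise_model_likelihood` is set and stays `None` otherwise, so downstream code must tolerate the `None` case. A minimal sketch of that guarded-factory pattern, using hypothetical stand-in names rather than the careamics factories:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class LikelihoodConfigSketch:
        # hypothetical stand-in for the noise-model likelihood config
        data_mean: Optional[float] = None
        data_std: Optional[float] = None

    def build_likelihood(config: Optional[LikelihoodConfigSketch]):
        """Return a likelihood object only when a config was provided."""
        if config is None:
            return None  # callers must handle None, as VAEModule now does
        # placeholder for the real factory call
        return LikelihoodConfigSketch(config.data_mean, config.data_std)

    assert build_likelihood(None) is None
    assert build_likelihood(LikelihoodConfigSketch(0.0, 1.0)) is not None
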
@@ -360,30 +366,43 @@ class VAEModule(L.LightningModule):
             RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
         ]
 
-    def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
         """Forward pass.
 
         Parameters
         ----------
-        x : Tensor
+        x : torch.Tensor
            Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
            number of lateral inputs.
 
        Returns
        -------
-        tuple[Tensor, dict[str, Any]]
+        tuple[torch.Tensor, dict[str, Any]]
            A tuple with the output tensor and additional data from the top-down pass.
        """
        return self.model(x)  # TODO Different model can have more than one output
 
+    def set_data_stats(self, data_mean, data_std):
+        """Set data mean and std for the noise model likelihood.
+
+        Parameters
+        ----------
+        data_mean : float
+            Mean of the data.
+        data_std : float
+            Standard deviation of the data.
+        """
+        if self.noise_model_likelihood is not None:
+            self.noise_model_likelihood.set_data_stats(data_mean, data_std)
+
     def training_step(
-        self, batch: tuple[Tensor, Tensor], batch_idx: Any
-    ) -> dict[str, Tensor] | None:
+        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
+    ) -> dict[str, torch.Tensor] | None:
         """Training step.
 
         Parameters
         ----------
-        batch : tuple[Tensor, Tensor]
+        batch : tuple[torch.Tensor, torch.Tensor]
            Input batch. It is a tuple with the input tensor and the target tensor.
            The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
            number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
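
The new `set_data_stats` hook pairs naturally with the `data_stats_callback.py` added in this release (file 48 above): a callback can push dataset statistics into the module before training starts. A hedged sketch of how such a callback could look; the actual careamics implementation may differ, and the `data_mean`/`data_std` attributes on the datamodule are assumed here:

    import pytorch_lightning as pl

    class DataStatsCallbackSketch(pl.Callback):
        """Forward dataset mean/std to the module before training starts."""

        def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
            datamodule = trainer.datamodule
            if datamodule is None or not hasattr(pl_module, "set_data_stats"):
                return
            # assumed attributes; the real data module may expose its stats differently
            data_mean = getattr(datamodule, "data_mean", None)
            data_std = getattr(datamodule, "data_std", None)
            if data_mean is not None and data_std is not None:
                pl_module.set_data_stats(data_mean, data_std)
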
@@ -397,15 +416,30 @@ class VAEModule(L.LightningModule):
         Any
             Loss value.
         """
-        x, target = batch
+        x, *target = batch
 
         # Forward pass
         out = self.model(x)
+        if not self.supervised_mode:
+            target = x
+        else:
+            target = target[
+                0
+            ]  # hacky way to unpack. #TODO maybe should be fixed on the dataset level
 
         # Update loss parameters
         self.loss_parameters.kl_params.current_epoch = self.current_epoch
 
         # Compute loss
+        if self.noise_model_likelihood is not None:
+            if (
+                self.noise_model_likelihood.data_mean is None
+                or self.noise_model_likelihood.data_std is None
+            ):
+                raise RuntimeError(
+                    "NoiseModelLikelihood: data_mean and data_std must be set before"
+                    "training."
+                )
         loss = self.loss_func(
             model_outputs=out,
             targets=target,
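
The `x, *target = batch` unpacking lets the same training step accept both `(input,)` and `(input, target)` batches, with the target falling back to the input itself in self-supervised mode. A minimal standalone sketch of that selection logic (dummy tensors, not careamics code):

    import torch

    def select_target(batch, supervised_mode: bool) -> torch.Tensor:
        x, *target = batch
        if not supervised_mode:
            return x       # self-supervised: reconstruct the input itself
        return target[0]   # supervised: the extra batch element is the target

    inputs = torch.randn(2, 1, 64, 64)
    labels = torch.randn(2, 2, 64, 64)
    assert select_target((inputs,), supervised_mode=False) is inputs
    assert select_target((inputs, labels), supervised_mode=True) is labels
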
@@ -417,15 +451,26 @@ class VAEModule(L.LightningModule):
         # Logging
         # TODO: implement a separate logging method?
         self.log_dict(loss, on_step=True, on_epoch=True)
-        # self.log("lr", self, on_epoch=True)
+
+        try:
+            optimizer = self.optimizers()
+            current_lr = optimizer.param_groups[0]["lr"]
+            self.log(
+                "learning_rate", current_lr, on_step=False, on_epoch=True, logger=True
+            )
+        except RuntimeError:
+            # This happens when the module is not attached to a trainer, e.g., in tests
+            pass
         return loss
 
-    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: Any) -> None:
+    def validation_step(
+        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
+    ) -> None:
         """Validation step.
 
         Parameters
         ----------
-        batch : tuple[Tensor, Tensor]
+        batch : tuple[torch.Tensor, torch.Tensor]
            Input batch. It is a tuple with the input tensor and the target tensor.
            The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
            number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
@@ -434,11 +479,16 @@ class VAEModule(L.LightningModule):
         batch_idx : Any
             Batch index.
         """
-        x, target = batch
+        x, *target = batch
 
         # Forward pass
         out = self.model(x)
-
+        if not self.supervised_mode:
+            target = x
+        else:
+            target = target[
+                0
+            ]  # hacky way to unpack. #TODO maybe should be fixed on the datasel level
         # Compute loss
         loss = self.loss_func(
             model_outputs=out,
@@ -464,12 +514,12 @@ class VAEModule(L.LightningModule):
         else:
             self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
 
-    def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+    def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
         """Prediction step.
 
         Parameters
         ----------
-        batch : Tensor
+        batch : torch.Tensor
            Input batch.
        batch_idx : Any
            Batch index.
@@ -479,36 +529,86 @@ class VAEModule(L.LightningModule):
         Any
             Model output.
         """
-        if self._trainer.datamodule.tiled:
+        if self.algorithm_config.algorithm == "microsplit":
             x, *aux = batch
-        else:
-            x = batch
-            aux = []
+            # Reset model for inference with spatial dimensions only (H, W)
+            self.model.reset_for_inference(x.shape[-2:])
 
-        # apply test-time augmentation if available
-        # TODO: probably wont work with batch size > 1
-        if self._trainer.datamodule.prediction_config.tta_transforms:
-            tta = ImageRestorationTTA()
-            augmented_batch = tta.forward(x)  # list of augmented tensors
-            augmented_output = []
-            for augmented in augmented_batch:
-                augmented_pred = self.model(augmented)
-                augmented_output.append(augmented_pred)
-            output = tta.backward(augmented_output)
-        else:
-            output = self.model(x)
+            rec_img_list = []
+            for _ in range(self.algorithm_config.mmse_count):
+                # get model output
+                rec, _ = self.model(x)
 
-        # Denormalize the output
-        denorm = Denormalize(
-            image_means=self._trainer.datamodule.predict_dataset.image_means,
-            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
-        )
-        denormalized_output = denorm(patch=output.cpu().numpy())
+                # get reconstructed img
+                if self.model.predict_logvar is None:
+                    rec_img = rec
+                    _logvar = torch.tensor([-1])
+                else:
+                    rec_img, _logvar = torch.chunk(rec, chunks=2, dim=1)
+                rec_img_list.append(rec_img.cpu().unsqueeze(0))  # add MMSE dim
+
+            # aggregate results
+            samples = torch.cat(rec_img_list, dim=0)
+            mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
+            std_imgs = torch.std(samples, dim=0)  # std over MMSE dim
+
+            tile_prediction = mmse_imgs.cpu().numpy()
+            tile_std = std_imgs.cpu().numpy()
+
+            return tile_prediction, tile_std
 
-        if len(aux) > 0:  # aux can be tiling information
-            return denormalized_output, *aux
         else:
-            return denormalized_output
+            # Regular prediction logic
+            if self._trainer.datamodule.tiled:
+                # TODO tile_size should match model input size
+                x, *aux = batch
+                x = (
+                    x[0] if isinstance(x, list | tuple) else x
+                )  # TODO ugly, so far i don't know why x might be a list
+                self.model.reset_for_inference(x.shape)  # TODO should it be here ?
+            else:
+                x = batch[0] if isinstance(batch, list | tuple) else batch
+                aux = []
+                self.model.reset_for_inference(x.shape)
+
+            mmse_list = []
+            for _ in range(self.algorithm_config.mmse_count):
+                # apply test-time augmentation if available
+                if self._trainer.datamodule.prediction_config.tta_transforms:
+                    tta = ImageRestorationTTA()
+                    augmented_batch = tta.forward(x)  # list of augmented tensors
+                    augmented_output = []
+                    for augmented in augmented_batch:
+                        augmented_pred = self.model(augmented)
+                        augmented_output.append(augmented_pred)
+                    output = tta.backward(augmented_output)
+                else:
+                    output = self.model(x)
+
+                # taking the 1st element of the output, 2nd is std if
+                # predict_logvar=="pixelwise"
+                output = (
+                    output[0]
+                    if self.model.predict_logvar is None
+                    else output[0][:, 0:1, ...]
+                )
+                mmse_list.append(output)
+
+            mmse = torch.stack(mmse_list).mean(0)
+            std = torch.stack(mmse_list).std(0)  # TODO why?
+            # TODO better way to unpack if pred logvar
+            # Denormalize the output
+            denorm = Denormalize(
+                image_means=self._trainer.datamodule.predict_dataset.image_means,
+                image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+            )
+
+            denormalized_output = denorm(patch=mmse.cpu().numpy())
+
+            if len(aux) > 0:  # aux can be tiling information
+                return denormalized_output, std, *aux
+            else:
+                return denormalized_output, std
 
     def configure_optimizers(self) -> Any:
         """Configure optimizers and learning rate schedulers.
@@ -537,19 +637,19 @@ class VAEModule(L.LightningModule):
     # should we refactor LadderVAE so that it already outputs
     # tuple(`mean`, `logvar`, `td_data`)?
     def get_reconstructed_tensor(
-        self, model_outputs: tuple[Tensor, dict[str, Any]]
-    ) -> Tensor:
+        self, model_outputs: tuple[torch.Tensor, dict[str, Any]]
+    ) -> torch.Tensor:
         """Get the reconstructed tensor from the LVAE model outputs.
 
         Parameters
         ----------
-        model_outputs : tuple[Tensor, dict[str, Any]]
+        model_outputs : tuple[torch.Tensor, dict[str, Any]]
            Model outputs. It is a tuple with a tensor representing the predicted mean
            and (optionally) logvar, and the top-down data dictionary.
 
        Returns
        -------
-        Tensor
+        torch.Tensor
            Reconstructed tensor, i.e., the predicted mean.
        """
        predictions, _ = model_outputs
@@ -560,18 +660,18 @@ class VAEModule(L.LightningModule):
 
     def compute_val_psnr(
         self,
-        model_output: tuple[Tensor, dict[str, Any]],
-        target: Tensor,
+        model_output: tuple[torch.Tensor, dict[str, Any]],
+        target: torch.Tensor,
         psnr_func: Callable = scale_invariant_psnr,
     ) -> list[float]:
         """Compute the PSNR for the current validation batch.
 
         Parameters
         ----------
-        model_output : tuple[Tensor, dict[str, Any]]
+        model_output : tuple[torch.Tensor, dict[str, Any]]
            Model output, a tuple with the predicted mean and (optionally) logvar,
            and the top-down data dictionary.
-        target : Tensor
+        target : torch.Tensor
            Target tensor.
        psnr_func : Callable, optional
            PSNR function to use, by default `scale_invariant_psnr`.
@@ -581,6 +681,7 @@ class VAEModule(L.LightningModule):
         list[float]
             PSNR for each channel in the current batch.
         """
+        # TODO check this! Related to is_supervised which is also wacky
         out_channels = target.shape[1]
 
         # get the reconstructed image
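
`compute_val_psnr` loops over the output channels and applies `scale_invariant_psnr` to each channel separately. A hedged sketch of such a per-channel loop with a simplified scale-invariant PSNR (least-squares rescaling of the prediction before measuring); careamics' own `scale_invariant_psnr` may be defined differently:

    import torch

    def scale_invariant_psnr_sketch(pred: torch.Tensor, target: torch.Tensor) -> float:
        pred, target = pred.flatten().double(), target.flatten().double()
        # least-squares fit pred -> a * pred + b so a global scale/offset does not affect PSNR
        a = ((pred - pred.mean()) * (target - target.mean())).sum() / (
            (pred - pred.mean()) ** 2
        ).sum()
        b = target.mean() - a * pred.mean()
        mse = torch.mean((a * pred + b - target) ** 2)
        data_range = target.max() - target.min()
        return float(10 * torch.log10(data_range**2 / mse))

    reconstruction = torch.randn(1, 2, 64, 64)  # (B, C, Y, X)
    target = reconstruction + 0.05 * torch.randn_like(reconstruction)
    psnr_per_channel = [
        scale_invariant_psnr_sketch(reconstruction[:, c], target[:, c])
        for c in range(target.shape[1])
    ]
    print(psnr_per_channel)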