careamics 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of careamics might be problematic.

Files changed (92)
  1. careamics/careamist.py +55 -61
  2. careamics/cli/conf.py +24 -9
  3. careamics/cli/main.py +8 -8
  4. careamics/cli/utils.py +2 -4
  5. careamics/config/__init__.py +8 -0
  6. careamics/config/algorithms/__init__.py +4 -0
  7. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  8. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  9. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  10. careamics/config/algorithms/vae_algorithm_model.py +53 -18
  11. careamics/config/architectures/lvae_model.py +12 -8
  12. careamics/config/callback_model.py +15 -11
  13. careamics/config/configuration.py +9 -8
  14. careamics/config/configuration_factories.py +892 -78
  15. careamics/config/data/data_model.py +7 -14
  16. careamics/config/data/ng_data_model.py +8 -15
  17. careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
  18. careamics/config/inference_model.py +6 -11
  19. careamics/config/likelihood_model.py +4 -4
  20. careamics/config/loss_model.py +6 -2
  21. careamics/config/nm_model.py +30 -7
  22. careamics/config/optimizer_models.py +1 -2
  23. careamics/config/support/supported_algorithms.py +5 -3
  24. careamics/config/support/supported_losses.py +5 -2
  25. careamics/config/training_model.py +8 -38
  26. careamics/config/transformations/normalize_model.py +3 -4
  27. careamics/config/transformations/xy_flip_model.py +2 -2
  28. careamics/config/transformations/xy_random_rotate90_model.py +2 -2
  29. careamics/config/validators/validator_utils.py +1 -2
  30. careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
  31. careamics/dataset/in_memory_dataset.py +2 -2
  32. careamics/dataset/iterable_dataset.py +1 -2
  33. careamics/dataset/patching/random_patching.py +6 -6
  34. careamics/dataset/patching/sequential_patching.py +4 -4
  35. careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
  36. careamics/dataset_ng/dataset.py +3 -3
  37. careamics/dataset_ng/factory.py +19 -19
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  39. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  40. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  41. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  42. careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
  43. careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
  44. careamics/file_io/read/__init__.py +0 -1
  45. careamics/lightning/__init__.py +16 -2
  46. careamics/lightning/callbacks/__init__.py +2 -0
  47. careamics/lightning/callbacks/data_stats_callback.py +23 -0
  48. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
  49. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
  50. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
  51. careamics/lightning/dataset_ng/data_module.py +43 -43
  52. careamics/lightning/lightning_module.py +166 -68
  53. careamics/lightning/microsplit_data_module.py +631 -0
  54. careamics/lightning/predict_data_module.py +16 -9
  55. careamics/lightning/train_data_module.py +29 -18
  56. careamics/losses/__init__.py +7 -1
  57. careamics/losses/loss_factory.py +9 -1
  58. careamics/losses/lvae/losses.py +94 -9
  59. careamics/lvae_training/dataset/__init__.py +8 -8
  60. careamics/lvae_training/dataset/config.py +56 -44
  61. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  62. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  63. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  64. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  65. careamics/model_io/bioimage/model_description.py +12 -11
  66. careamics/model_io/bmz_io.py +12 -8
  67. careamics/models/layers.py +5 -5
  68. careamics/models/lvae/likelihoods.py +30 -14
  69. careamics/models/lvae/lvae.py +2 -2
  70. careamics/models/lvae/noise_models.py +20 -14
  71. careamics/prediction_utils/__init__.py +8 -2
  72. careamics/prediction_utils/lvae_prediction.py +5 -5
  73. careamics/prediction_utils/prediction_outputs.py +48 -3
  74. careamics/prediction_utils/stitch_prediction.py +71 -0
  75. careamics/transforms/compose.py +9 -9
  76. careamics/transforms/n2v_manipulate.py +3 -3
  77. careamics/transforms/n2v_manipulate_torch.py +4 -4
  78. careamics/transforms/normalize.py +4 -6
  79. careamics/transforms/pixel_manipulation.py +6 -8
  80. careamics/transforms/pixel_manipulation_torch.py +5 -7
  81. careamics/transforms/xy_flip.py +3 -5
  82. careamics/transforms/xy_random_rotate90.py +4 -6
  83. careamics/utils/logging.py +8 -8
  84. careamics/utils/metrics.py +2 -2
  85. careamics/utils/plotting.py +1 -3
  86. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
  87. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
  88. careamics/dataset/zarr_dataset.py +0 -151
  89. careamics/file_io/read/zarr.py +0 -60
  90. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
  91. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
  92. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/dataset_ng/data_module.py

@@ -2,7 +2,7 @@
 
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Optional, Union, overload
+from typing import Any, Union, overload
 
 import numpy as np
 import pytorch_lightning as L
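
Note: the change above is the pattern repeated throughout this release: typing.Optional[X] is replaced by the PEP 604 union syntax X | None. The two spellings denote the same type; a minimal illustration (find_user is an invented name for the example):

    from typing import Optional

    # Optional[str] and str | None describe the same type; the bar syntax
    # requires Python >= 3.10 when evaluated at runtime.
    def find_user(name: Optional[str] = None) -> str | None:
        return name

    assert Optional[str] == (str | None)  # the two typing objects compare equal
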
@@ -124,14 +124,14 @@ class CareamicsDataModule(L.LightningDataModule):
         self,
         data_config: NGDataConfig,
         *,
-        train_data: Optional[InputType] = None,
-        train_data_target: Optional[InputType] = None,
-        val_data: Optional[InputType] = None,
-        val_data_target: Optional[InputType] = None,
-        pred_data: Optional[InputType] = None,
-        pred_data_target: Optional[InputType] = None,
+        train_data: InputType | None = None,
+        train_data_target: InputType | None = None,
+        val_data: InputType | None = None,
+        val_data_target: InputType | None = None,
+        pred_data: InputType | None = None,
+        pred_data_target: InputType | None = None,
         extension_filter: str = "",
-        val_percentage: Optional[float] = None,
+        val_percentage: float | None = None,
         val_minimum_split: int = 5,
         use_in_memory: bool = True,
     ) -> None: ...
@@ -142,16 +142,16 @@ class CareamicsDataModule(L.LightningDataModule):
         self,
         data_config: NGDataConfig,
         *,
-        train_data: Optional[InputType] = None,
-        train_data_target: Optional[InputType] = None,
-        val_data: Optional[InputType] = None,
-        val_data_target: Optional[InputType] = None,
-        pred_data: Optional[InputType] = None,
-        pred_data_target: Optional[InputType] = None,
+        train_data: InputType | None = None,
+        train_data_target: InputType | None = None,
+        val_data: InputType | None = None,
+        val_data_target: InputType | None = None,
+        pred_data: InputType | None = None,
+        pred_data_target: InputType | None = None,
         read_source_func: Callable,
-        read_kwargs: Optional[dict[str, Any]] = None,
+        read_kwargs: dict[str, Any] | None = None,
         extension_filter: str = "",
-        val_percentage: Optional[float] = None,
+        val_percentage: float | None = None,
         val_minimum_split: int = 5,
         use_in_memory: bool = True,
     ) -> None: ...
@@ -161,16 +161,16 @@ class CareamicsDataModule(L.LightningDataModule):
         self,
         data_config: NGDataConfig,
         *,
-        train_data: Optional[Any] = None,
-        train_data_target: Optional[Any] = None,
-        val_data: Optional[Any] = None,
-        val_data_target: Optional[Any] = None,
-        pred_data: Optional[Any] = None,
-        pred_data_target: Optional[Any] = None,
+        train_data: Any | None = None,
+        train_data_target: Any | None = None,
+        val_data: Any | None = None,
+        val_data_target: Any | None = None,
+        pred_data: Any | None = None,
+        pred_data_target: Any | None = None,
         image_stack_loader: ImageStackLoader,
-        image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
+        image_stack_loader_kwargs: dict[str, Any] | None = None,
         extension_filter: str = "",
-        val_percentage: Optional[float] = None,
+        val_percentage: float | None = None,
         val_minimum_split: int = 5,
         use_in_memory: bool = True,
     ) -> None: ...
@@ -179,18 +179,18 @@ class CareamicsDataModule(L.LightningDataModule):
         self,
         data_config: NGDataConfig,
         *,
-        train_data: Optional[Any] = None,
-        train_data_target: Optional[Any] = None,
-        val_data: Optional[Any] = None,
-        val_data_target: Optional[Any] = None,
-        pred_data: Optional[Any] = None,
-        pred_data_target: Optional[Any] = None,
-        read_source_func: Optional[Callable] = None,
-        read_kwargs: Optional[dict[str, Any]] = None,
-        image_stack_loader: Optional[ImageStackLoader] = None,
-        image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
+        train_data: Any | None = None,
+        train_data_target: Any | None = None,
+        val_data: Any | None = None,
+        val_data_target: Any | None = None,
+        pred_data: Any | None = None,
+        pred_data_target: Any | None = None,
+        read_source_func: Callable | None = None,
+        read_kwargs: dict[str, Any] | None = None,
+        image_stack_loader: ImageStackLoader | None = None,
+        image_stack_loader_kwargs: dict[str, Any] | None = None,
         extension_filter: str = "",
-        val_percentage: Optional[float] = None,
+        val_percentage: float | None = None,
         val_minimum_split: int = 5,
         use_in_memory: bool = True,
     ) -> None:
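
Note: the four __init__ signatures above are typing overloads: the first three declare the mutually exclusive keyword groups (plain arrays; read_source_func plus read_kwargs; image_stack_loader plus image_stack_loader_kwargs), and the last, non-overloaded signature is the actual implementation that accepts all of them as optional. A minimal sketch of the same pattern (class and parameter names invented for the example):

    from typing import Any, overload

    class Loader:
        @overload
        def __init__(self, *, reader: Any) -> None: ...
        @overload
        def __init__(self, *, stack_loader: Any) -> None: ...

        def __init__(
            self,
            *,
            reader: Any | None = None,
            stack_loader: Any | None = None,
        ) -> None:
            # the overloads only constrain type checkers; at runtime exactly
            # one of the two keyword groups is expected to be provided
            self.reader = reader
            self.stack_loader = stack_loader
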
@@ -280,7 +280,7 @@ class CareamicsDataModule(L.LightningDataModule):
     def _validate_input_target_type_consistency(
         self,
         input_data: InputType,
-        target_data: Optional[InputType],
+        target_data: InputType | None,
     ) -> None:
         """Validate if the input and target data types are consistent.
 
@@ -314,7 +314,7 @@ class CareamicsDataModule(L.LightningDataModule):
         self,
         input_data,
         target_data=None,
-    ) -> tuple[list[Path], Optional[list[Path]]]:
+    ) -> tuple[list[Path], list[Path] | None]:
         """List files from input and target directories.
 
         Parameters
@@ -347,7 +347,7 @@ class CareamicsDataModule(L.LightningDataModule):
         self,
         input_data,
         target_data=None,
-    ) -> tuple[list[Path], Optional[list[Path]]]:
+    ) -> tuple[list[Path], list[Path] | None]:
         """Create a list of file paths from the input and target data.
 
         Parameters
@@ -379,7 +379,7 @@ class CareamicsDataModule(L.LightningDataModule):
     def _validate_array_input(
         self,
         input_data: InputType,
-        target_data: Optional[InputType],
+        target_data: InputType | None,
     ) -> tuple[Any, Any]:
         """Validate if the input data is a numpy array.
 
@@ -408,8 +408,8 @@ class CareamicsDataModule(L.LightningDataModule):
         )
 
     def _validate_path_input(
-        self, input_data: InputType, target_data: Optional[InputType]
-    ) -> tuple[list[Path], Optional[list[Path]]]:
+        self, input_data: InputType, target_data: InputType | None
+    ) -> tuple[list[Path], list[Path] | None]:
         """Validate if the input data is a path or a list of paths.
 
         Parameters
@@ -488,8 +488,8 @@ class CareamicsDataModule(L.LightningDataModule):
 
     def _initialize_data_pair(
         self,
-        input_data: Optional[InputType],
-        target_data: Optional[InputType],
+        input_data: InputType | None,
+        target_data: InputType | None,
     ) -> tuple[Any, Any]:
         """
         Initialize a pair of input and target data.
careamics/lightning/lightning_module.py

@@ -1,11 +1,11 @@
 """CAREamics Lightning module."""
 
 from collections.abc import Callable
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Union
 
 import numpy as np
 import pytorch_lightning as L
-from torch import Tensor, nn
+import torch
 
 from careamics.config import (
     N2VAlgorithm,
@@ -71,7 +71,9 @@ class FCNModule(L.LightningModule):
         Learning rate scheduler name.
     """
 
-    def __init__(self, algorithm_config: Union[UNetBasedAlgorithm, dict]) -> None:
+    def __init__(
+        self, algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm, dict]
+    ) -> None:
         """Lightning module for CAREamics.
 
         This class encapsulates a PyTorch model along with the training, validation,
@@ -90,7 +92,7 @@ class FCNModule(L.LightningModule):
         # create preprocessing, model and loss function
         if isinstance(algorithm_config, N2VAlgorithm):
             self.use_n2v = True
-            self.n2v_preprocess: Optional[N2VManipulateTorch] = N2VManipulateTorch(
+            self.n2v_preprocess: N2VManipulateTorch | None = N2VManipulateTorch(
                 n2v_manipulate_config=algorithm_config.n2v_config
             )
         else:
@@ -98,7 +100,7 @@ class FCNModule(L.LightningModule):
             self.n2v_preprocess = None
 
         self.algorithm = algorithm_config.algorithm
-        self.model: nn.Module = model_factory(algorithm_config.model)
+        self.model: torch.nn.Module = model_factory(algorithm_config.model)
         self.loss_func = loss_factory(algorithm_config.loss)
 
         # save optimizer and lr_scheduler names and parameters
@@ -122,12 +124,12 @@ class FCNModule(L.LightningModule):
         """
         return self.model(x)
 
-    def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
+    def training_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
         """Training step.
 
         Parameters
         ----------
-        batch : Tensor
+        batch : torch.Tensor
             Input batch.
         batch_idx : Any
             Batch index.
@@ -154,12 +156,12 @@ class FCNModule(L.LightningModule):
         self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
         return loss
 
-    def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
+    def validation_step(self, batch: torch.Tensor, batch_idx: Any) -> None:
         """Validation step.
 
         Parameters
         ----------
-        batch : Tensor
+        batch : torch.Tensor
             Input batch.
         batch_idx : Any
             Batch index.
@@ -184,12 +186,12 @@ class FCNModule(L.LightningModule):
             logger=True,
         )
 
-    def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+    def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
         """Prediction step.
 
         Parameters
         ----------
-        batch : Tensor
+        batch : torch.Tensor
             Input batch.
         batch_idx : Any
             Batch index.
@@ -330,21 +332,23 @@ class VAEModule(L.LightningModule):
         # self.save_hyperparameters(self.algorithm_config.model_dump())
 
         # create model
-        self.model: nn.Module = model_factory(self.algorithm_config.model)
+        self.model: torch.nn.Module = model_factory(self.algorithm_config.model)
 
+        # supervised_mode
+        self.supervised_mode = self.algorithm_config.is_supervised
         # create loss function
-        self.noise_model: Optional[NoiseModel] = noise_model_factory(
+        self.noise_model: NoiseModel | None = noise_model_factory(
             self.algorithm_config.noise_model
         )
 
-        self.noise_model_likelihood: Optional[NoiseModelLikelihood] = (
-            likelihood_factory(
+        self.noise_model_likelihood: NoiseModelLikelihood | None = None
+        if self.algorithm_config.noise_model_likelihood is not None:
+            self.noise_model_likelihood = likelihood_factory(
                 config=self.algorithm_config.noise_model_likelihood,
                 noise_model=self.noise_model,
             )
-        )
 
-        self.gaussian_likelihood: Optional[GaussianLikelihood] = likelihood_factory(
+        self.gaussian_likelihood: GaussianLikelihood | None = likelihood_factory(
            self.algorithm_config.gaussian_likelihood
        )
 
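
Note: the constructor change above stops calling likelihood_factory when no noise-model likelihood is configured and keeps the attribute None instead. The same pattern in isolation (all names invented for the example):

    from dataclasses import dataclass

    @dataclass
    class LikelihoodConfig:  # stand-in for the careamics config object
        scale: float = 1.0

    def build_likelihood(config: LikelihoodConfig) -> dict:  # stand-in factory
        return {"scale": config.scale}

    config: LikelihoodConfig | None = None
    likelihood = None
    if config is not None:  # construct only when a config is present
        likelihood = build_likelihood(config)
    assert likelihood is None
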
@@ -362,30 +366,43 @@ class VAEModule(L.LightningModule):
             RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
         ]
 
-    def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
         """Forward pass.
 
         Parameters
         ----------
-        x : Tensor
+        x : torch.Tensor
             Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
             number of lateral inputs.
 
         Returns
         -------
-        tuple[Tensor, dict[str, Any]]
+        tuple[torch.Tensor, dict[str, Any]]
             A tuple with the output tensor and additional data from the top-down pass.
         """
         return self.model(x)  # TODO Different model can have more than one output
 
+    def set_data_stats(self, data_mean, data_std):
+        """Set data mean and std for the noise model likelihood.
+
+        Parameters
+        ----------
+        data_mean : float
+            Mean of the data.
+        data_std : float
+            Standard deviation of the data.
+        """
+        if self.noise_model_likelihood is not None:
+            self.noise_model_likelihood.set_data_stats(data_mean, data_std)
+
     def training_step(
-        self, batch: tuple[Tensor, Tensor], batch_idx: Any
-    ) -> Optional[dict[str, Tensor]]:
+        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
+    ) -> dict[str, torch.Tensor] | None:
         """Training step.
 
         Parameters
         ----------
-        batch : tuple[Tensor, Tensor]
+        batch : tuple[torch.Tensor, torch.Tensor]
             Input batch. It is a tuple with the input tensor and the target tensor.
             The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
             number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
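
Note: the new set_data_stats hook lets external code inject normalization statistics into the noise-model likelihood before training starts; the DataStatsCallback added in this release (file 47 above) is presumably its intended caller. A hypothetical sketch of such a callback (the real one may differ):

    import pytorch_lightning as pl

    class DataStatsCallbackSketch(pl.Callback):
        """Hypothetical sketch, not the released DataStatsCallback."""

        def __init__(self, data_mean: float, data_std: float) -> None:
            self.data_mean = data_mean
            self.data_std = data_std

        def on_fit_start(
            self, trainer: pl.Trainer, pl_module: pl.LightningModule
        ) -> None:
            # VAEModule.set_data_stats forwards the stats to the noise-model
            # likelihood when one is configured
            if hasattr(pl_module, "set_data_stats"):
                pl_module.set_data_stats(self.data_mean, self.data_std)
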
@@ -399,15 +416,29 @@ class VAEModule(L.LightningModule):
         Any
             Loss value.
         """
-        x, target = batch
+        x, *target = batch
 
         # Forward pass
         out = self.model(x)
+        if not self.supervised_mode:
+            target = x
+        else:
+            target = target[
+                0
+            ]  # hacky way to unpack. #TODO maybe should be fixed on the dataset level
 
         # Update loss parameters
         self.loss_parameters.kl_params.current_epoch = self.current_epoch
 
         # Compute loss
+        if self.noise_model_likelihood is not None:
+            if (
+                self.noise_model_likelihood.data_mean is None
+                or self.noise_model_likelihood.data_std is None
+            ):
+                raise RuntimeError(
+                    "NoiseModelLikelihood: data_mean and data_std must be set before training."
+                )
         loss = self.loss_func(
             model_outputs=out,
             targets=target,
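
Note: switching from `x, target = batch` to `x, *target = batch` lets the same step accept both one-element (unsupervised) and two-element (supervised) batches; in the unsupervised case the input itself becomes the target. An illustration with plain tensors:

    import torch

    x = torch.zeros(2, 1, 8, 8)
    y = torch.zeros(2, 2, 8, 8)

    inp, *rest = (x,)    # unsupervised batch: rest == []
    assert rest == []

    inp, *rest = (x, y)  # supervised batch: rest == [y]
    target = inp if not rest else rest[0]
    assert target is y
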
@@ -419,15 +450,26 @@ class VAEModule(L.LightningModule):
         # Logging
         # TODO: implement a separate logging method?
         self.log_dict(loss, on_step=True, on_epoch=True)
-        # self.log("lr", self, on_epoch=True)
+
+        try:
+            optimizer = self.optimizers()
+            current_lr = optimizer.param_groups[0]["lr"]
+            self.log(
+                "learning_rate", current_lr, on_step=False, on_epoch=True, logger=True
+            )
+        except RuntimeError:
+            # This happens when the module is not attached to a trainer, e.g., in tests
+            pass
         return loss
 
-    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: Any) -> None:
+    def validation_step(
+        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
+    ) -> None:
         """Validation step.
 
         Parameters
         ----------
-        batch : tuple[Tensor, Tensor]
+        batch : tuple[torch.Tensor, torch.Tensor]
             Input batch. It is a tuple with the input tensor and the target tensor.
             The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
             number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
@@ -436,11 +478,16 @@ class VAEModule(L.LightningModule):
         batch_idx : Any
             Batch index.
         """
-        x, target = batch
+        x, *target = batch
 
         # Forward pass
         out = self.model(x)
-
+        if not self.supervised_mode:
+            target = x
+        else:
+            target = target[
+                0
+            ]  # hacky way to unpack. #TODO maybe should be fixed on the dataset level
         # Compute loss
         loss = self.loss_func(
             model_outputs=out,
@@ -466,12 +513,12 @@ class VAEModule(L.LightningModule):
         else:
             self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
 
-    def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+    def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
         """Prediction step.
 
         Parameters
         ----------
-        batch : Tensor
+        batch : torch.Tensor
             Input batch.
         batch_idx : Any
             Batch index.
@@ -481,36 +528,86 @@ class VAEModule(L.LightningModule):
         Any
             Model output.
         """
-        if self._trainer.datamodule.tiled:
+        if self.algorithm_config.algorithm == "microsplit":
             x, *aux = batch
-        else:
-            x = batch
-            aux = []
+            # Reset model for inference with spatial dimensions only (H, W)
+            self.model.reset_for_inference(x.shape[-2:])
 
-        # apply test-time augmentation if available
-        # TODO: probably wont work with batch size > 1
-        if self._trainer.datamodule.prediction_config.tta_transforms:
-            tta = ImageRestorationTTA()
-            augmented_batch = tta.forward(x)  # list of augmented tensors
-            augmented_output = []
-            for augmented in augmented_batch:
-                augmented_pred = self.model(augmented)
-                augmented_output.append(augmented_pred)
-            output = tta.backward(augmented_output)
-        else:
-            output = self.model(x)
+            rec_img_list = []
+            for _ in range(self.algorithm_config.mmse_count):
+                # get model output
+                rec, _ = self.model(x)
 
-        # Denormalize the output
-        denorm = Denormalize(
-            image_means=self._trainer.datamodule.predict_dataset.image_means,
-            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
-        )
-        denormalized_output = denorm(patch=output.cpu().numpy())
+                # get reconstructed img
+                if self.model.predict_logvar is None:
+                    rec_img = rec
+                    logvar = torch.tensor([-1])
+                else:
+                    rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
+                rec_img_list.append(rec_img.cpu().unsqueeze(0))  # add MMSE dim
+
+            # aggregate results
+            samples = torch.cat(rec_img_list, dim=0)
+            mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
+            std_imgs = torch.std(samples, dim=0)  # std over MMSE dim
+
+            tile_prediction = mmse_imgs.cpu().numpy()
+            tile_std = std_imgs.cpu().numpy()
+
+            return tile_prediction, tile_std
 
-        if len(aux) > 0:  # aux can be tiling information
-            return denormalized_output, *aux
         else:
-            return denormalized_output
+            # Regular prediction logic
+            if self._trainer.datamodule.tiled:
+                # TODO tile_size should match model input size
+                x, *aux = batch
+                x = (
+                    x[0] if isinstance(x, list | tuple) else x
+                )  # TODO ugly, so far I don't know why x might be a list
+                self.model.reset_for_inference(x.shape)  # TODO should it be here?
+            else:
+                x = batch[0] if isinstance(batch, list | tuple) else batch
+                aux = []
+                self.model.reset_for_inference(x.shape)
+
+            mmse_list = []
+            for _ in range(self.algorithm_config.mmse_count):
+                # apply test-time augmentation if available
+                if self._trainer.datamodule.prediction_config.tta_transforms:
+                    tta = ImageRestorationTTA()
+                    augmented_batch = tta.forward(x)  # list of augmented tensors
+                    augmented_output = []
+                    for augmented in augmented_batch:
+                        augmented_pred = self.model(augmented)
+                        augmented_output.append(augmented_pred)
+                    output = tta.backward(augmented_output)
+                else:
+                    output = self.model(x)
+
+                # taking the 1st element of the output, 2nd is std if
+                # predict_logvar == "pixelwise"
+                output = (
+                    output[0]
+                    if self.model.predict_logvar is None
+                    else output[0][:, 0:1, ...]
+                )
+                mmse_list.append(output)
+
+            mmse = torch.stack(mmse_list).mean(0)
+            std = torch.stack(mmse_list).std(0)  # TODO why?
+            # TODO better way to unpack if pred logvar
+            # Denormalize the output
+            denorm = Denormalize(
+                image_means=self._trainer.datamodule.predict_dataset.image_means,
+                image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+            )
+
+            denormalized_output = denorm(patch=mmse.cpu().numpy())
+
+            if len(aux) > 0:  # aux can be tiling information
+                return denormalized_output, std, *aux
+            else:
+                return denormalized_output, std
 
     def configure_optimizers(self) -> Any:
         """Configure optimizers and learning rate schedulers.
@@ -539,19 +636,19 @@ class VAEModule(L.LightningModule):
     # should we refactor LadderVAE so that it already outputs
     # tuple(`mean`, `logvar`, `td_data`)?
     def get_reconstructed_tensor(
-        self, model_outputs: tuple[Tensor, dict[str, Any]]
-    ) -> Tensor:
+        self, model_outputs: tuple[torch.Tensor, dict[str, Any]]
+    ) -> torch.Tensor:
         """Get the reconstructed tensor from the LVAE model outputs.
 
         Parameters
         ----------
-        model_outputs : tuple[Tensor, dict[str, Any]]
+        model_outputs : tuple[torch.Tensor, dict[str, Any]]
             Model outputs. It is a tuple with a tensor representing the predicted mean
             and (optionally) logvar, and the top-down data dictionary.
 
         Returns
         -------
-        Tensor
+        torch.Tensor
             Reconstructed tensor, i.e., the predicted mean.
         """
         predictions, _ = model_outputs
@@ -562,18 +659,18 @@ class VAEModule(L.LightningModule):
 
     def compute_val_psnr(
         self,
-        model_output: tuple[Tensor, dict[str, Any]],
-        target: Tensor,
+        model_output: tuple[torch.Tensor, dict[str, Any]],
+        target: torch.Tensor,
         psnr_func: Callable = scale_invariant_psnr,
     ) -> list[float]:
         """Compute the PSNR for the current validation batch.
 
         Parameters
         ----------
-        model_output : tuple[Tensor, dict[str, Any]]
+        model_output : tuple[torch.Tensor, dict[str, Any]]
             Model output, a tuple with the predicted mean and (optionally) logvar,
             and the top-down data dictionary.
-        target : Tensor
+        target : torch.Tensor
             Target tensor.
         psnr_func : Callable, optional
             PSNR function to use, by default `scale_invariant_psnr`.
@@ -583,6 +680,7 @@ class VAEModule(L.LightningModule):
         list[float]
             PSNR for each channel in the current batch.
         """
+        # TODO check this! Related to is_supervised which is also wacky
         out_channels = target.shape[1]
 
         # get the reconstructed image
603
701
  for i in range(out_channels)
604
702
  ]
605
703
 
606
- def reduce_running_psnr(self) -> Optional[float]:
704
+ def reduce_running_psnr(self) -> float | None:
607
705
  """Reduce the running PSNR statistics and reset the running PSNR.
608
706
 
609
707
  Returns
@@ -634,11 +732,11 @@ def create_careamics_module(
     use_n2v2: bool = False,
     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
     struct_n2v_span: int = 5,
-    model_parameters: Optional[dict] = None,
+    model_parameters: dict | None = None,
     optimizer: Union[SupportedOptimizer, str] = "Adam",
-    optimizer_parameters: Optional[dict] = None,
+    optimizer_parameters: dict | None = None,
     lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
-    lr_scheduler_parameters: Optional[dict] = None,
+    lr_scheduler_parameters: dict | None = None,
 ) -> Union[FCNModule, VAEModule]:
     """Create a CAREamics Lightning module.
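
Note: create_careamics_module still returns Union[FCNModule, VAEModule], so callers needing VAE-specific attributes (such as the new supervised_mode flag) must narrow the union first. A sketch, assuming both classes are exported from careamics.lightning (its __init__.py also changed in this release); the describe function is invented for the example:

    from careamics.lightning import FCNModule, VAEModule

    def describe(module: FCNModule | VAEModule) -> str:
        # isinstance narrows the union before touching VAE-only attributes
        if isinstance(module, VAEModule):
            return f"VAE module (supervised={module.supervised_mode})"
        return "FCN module"
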