careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Note: this version of careamics has been flagged as a potentially problematic release.
Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/lightning/lightning_module.py
@@ -0,0 +1,632 @@
+ """CAREamics Lightning module."""
+
+ from typing import Any, Callable, Literal, Optional, Union
+
+ import numpy as np
+ import pytorch_lightning as L
+ from torch import Tensor, nn
+
+ from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig
+ from careamics.config.support import (
+     SupportedAlgorithm,
+     SupportedArchitecture,
+     SupportedLoss,
+     SupportedOptimizer,
+     SupportedScheduler,
+ )
+ from careamics.losses import loss_factory
+ from careamics.losses.loss_factory import LVAELossParameters
+ from careamics.models.lvae.likelihoods import (
+     GaussianLikelihood,
+     NoiseModelLikelihood,
+     likelihood_factory,
+ )
+ from careamics.models.lvae.noise_models import (
+     GaussianMixtureNoiseModel,
+     MultiChannelNoiseModel,
+     noise_model_factory,
+ )
+ from careamics.models.model_factory import model_factory
+ from careamics.transforms import Denormalize, ImageRestorationTTA
+ from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr
+ from careamics.utils.torch_utils import get_optimizer, get_scheduler
+
+ NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+ class FCNModule(L.LightningModule):
+     """
+     CAREamics Lightning module.
+
+     This class encapsulates the PyTorch model along with the training, validation,
+     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+     Parameters
+     ----------
+     algorithm_config : AlgorithmModel or dict
+         Algorithm configuration.
+
+     Attributes
+     ----------
+     model : torch.nn.Module
+         PyTorch model.
+     loss_func : torch.nn.Module
+         Loss function.
+     optimizer_name : str
+         Optimizer name.
+     optimizer_params : dict
+         Optimizer parameters.
+     lr_scheduler_name : str
+         Learning rate scheduler name.
+     """
+
+     def __init__(self, algorithm_config: Union[FCNAlgorithmConfig, dict]) -> None:
+         """Lightning module for CAREamics.
+
+         This class encapsulates a PyTorch model along with the training, validation,
+         and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+         Parameters
+         ----------
+         algorithm_config : AlgorithmModel or dict
+             Algorithm configuration.
+         """
+         super().__init__()
+         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
+         if isinstance(algorithm_config, dict):
+             algorithm_config = FCNAlgorithmConfig(**algorithm_config)
+
+         # create model and loss function
+         self.model: nn.Module = model_factory(algorithm_config.model)
+         self.loss_func = loss_factory(algorithm_config.loss)
+
+         # save optimizer and lr_scheduler names and parameters
+         self.optimizer_name = algorithm_config.optimizer.name
+         self.optimizer_params = algorithm_config.optimizer.parameters
+         self.lr_scheduler_name = algorithm_config.lr_scheduler.name
+         self.lr_scheduler_params = algorithm_config.lr_scheduler.parameters
+
+     def forward(self, x: Any) -> Any:
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : Any
+             Input tensor.
+
+         Returns
+         -------
+         Any
+             Output tensor.
+         """
+         return self.model(x)
+
+     def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
+         """Training step.
+
+         Parameters
+         ----------
+         batch : torch.Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Loss value.
+         """
+         # TODO can N2V be simplified by returning mask*original_patch
+         x, *aux = batch
+         out = self.model(x)
+         loss = self.loss_func(out, *aux)
+         self.log(
+             "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+         )
+         return loss
+
+     def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
+         """Validation step.
+
+         Parameters
+         ----------
+         batch : torch.Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+         """
+         x, *aux = batch
+         out = self.model(x)
+         val_loss = self.loss_func(out, *aux)
+
+         # log validation loss
+         self.log(
+             "val_loss",
+             val_loss,
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+
+     def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+         """Prediction step.
+
+         Parameters
+         ----------
+         batch : torch.Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Model output.
+         """
+         if self._trainer.datamodule.tiled:
+             x, *aux = batch
+         else:
+             x = batch
+             aux = []
+
+         # apply test-time augmentation if available
+         # TODO: probably won't work with batch size > 1
+         if self._trainer.datamodule.prediction_config.tta_transforms:
+             tta = ImageRestorationTTA()
+             augmented_batch = tta.forward(x)  # list of augmented tensors
+             augmented_output = []
+             for augmented in augmented_batch:
+                 augmented_pred = self.model(augmented)
+                 augmented_output.append(augmented_pred)
+             output = tta.backward(augmented_output)
+         else:
+             output = self.model(x)
+
+         # Denormalize the output
+         denorm = Denormalize(
+             image_means=self._trainer.datamodule.predict_dataset.image_means,
+             image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+         )
+         denormalized_output = denorm(patch=output.cpu().numpy())
+
+         if len(aux) > 0:  # aux can be tiling information
+             return denormalized_output, *aux
+         else:
+             return denormalized_output
+
+     def configure_optimizers(self) -> Any:
+         """Configure optimizers and learning rate schedulers.
+
+         Returns
+         -------
+         Any
+             Optimizer and learning rate scheduler.
+         """
+         # instantiate optimizer
+         optimizer_func = get_optimizer(self.optimizer_name)
+         optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
+
+         # and scheduler
+         scheduler_func = get_scheduler(self.lr_scheduler_name)
+         scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": scheduler,
+             "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+         }
+
+
+ class VAEModule(L.LightningModule):
+     """
+     CAREamics Lightning module.
+
+     This class encapsulates a PyTorch model along with the training, validation,
+     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+     Parameters
+     ----------
+     algorithm_config : Union[VAEAlgorithmConfig, dict]
+         Algorithm configuration.
+
+     Attributes
+     ----------
+     model : nn.Module
+         PyTorch model.
+     loss_func : nn.Module
+         Loss function.
+     optimizer_name : str
+         Optimizer name.
+     optimizer_params : dict
+         Optimizer parameters.
+     lr_scheduler_name : str
+         Learning rate scheduler name.
+     """
+
+     def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None:
+         """Lightning module for CAREamics.
+
+         This class encapsulates a PyTorch model along with the training, validation,
+         and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+         Parameters
+         ----------
+         algorithm_config : Union[VAEAlgorithmConfig, dict]
+             Algorithm configuration.
+         """
+         super().__init__()
+         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
+         self.algorithm_config = (
+             VAEAlgorithmConfig(**algorithm_config)
+             if isinstance(algorithm_config, dict)
+             else algorithm_config
+         )
+
+         # TODO: log algorithm config
+         # self.save_hyperparameters(self.algorithm_config.model_dump())
+
+         # create model and loss function
+         self.model: nn.Module = model_factory(self.algorithm_config.model)
+         self.noise_model: NoiseModel = noise_model_factory(
+             self.algorithm_config.noise_model
+         )
+         self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
+             self.algorithm_config.noise_model_likelihood_model
+         )
+         self.gaussian_likelihood: GaussianLikelihood = likelihood_factory(
+             self.algorithm_config.gaussian_likelihood_model
+         )
+         self.loss_parameters = LVAELossParameters(
+             noise_model_likelihood=self.noise_model_likelihood,
+             gaussian_likelihood=self.gaussian_likelihood,
+             # TODO: musplit/denoisplit weights ?
+         )  # type: ignore
+         self.loss_func = loss_factory(self.algorithm_config.loss)
+
+         # save optimizer and lr_scheduler names and parameters
+         self.optimizer_name = self.algorithm_config.optimizer.name
+         self.optimizer_params = self.algorithm_config.optimizer.parameters
+         self.lr_scheduler_name = self.algorithm_config.lr_scheduler.name
+         self.lr_scheduler_params = self.algorithm_config.lr_scheduler.parameters
+
+         # initialize running PSNR
+         self.running_psnr = [
+             RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
+         ]
+
+     def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Any]]:
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : Tensor
+             Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+             number of lateral inputs.
+
+         Returns
+         -------
+         tuple[Tensor, dict[str, Any]]
+             A tuple with the output tensor and additional data from the top-down pass.
+         """
+         return self.model(x)  # TODO different models can have more than one output
+
+     def training_step(
+         self, batch: tuple[Tensor, Tensor], batch_idx: Any
+     ) -> Optional[dict[str, Tensor]]:
+         """Training step.
+
+         Parameters
+         ----------
+         batch : tuple[Tensor, Tensor]
+             Input batch. It is a tuple with the input tensor and the target tensor.
+             The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+             number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
+             where C is the number of target channels (e.g., 1 in HDN, >1 in
+             muSplit/denoiSplit).
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Loss value.
+         """
+         x, target = batch
+
+         # Forward pass
+         out = self.model(x)
+
+         # Update loss parameters
+         # TODO rethink loss parameters
+         self.loss_parameters.current_epoch = self.current_epoch
+
+         # Compute loss
+         loss = self.loss_func(out, target, self.loss_parameters)  # TODO ugly ?
+
+         # Logging
+         # TODO: implement a separate logging method?
+         self.log_dict(loss, on_step=True, on_epoch=True)
+         # self.log("lr", self, on_epoch=True)
+         return loss
+
+     def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: Any) -> None:
+         """Validation step.
+
+         Parameters
+         ----------
+         batch : tuple[Tensor, Tensor]
+             Input batch. It is a tuple with the input tensor and the target tensor.
+             The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+             number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
+             where C is the number of target channels (e.g., 1 in HDN, >1 in
+             muSplit/denoiSplit).
+         batch_idx : Any
+             Batch index.
+         """
+         x, target = batch
+
+         # Forward pass
+         out = self.model(x)
+
+         # Compute loss
+         loss = self.loss_func(out, target, self.loss_parameters)
+
+         # Logging
+         # Rename val_loss dict
+         loss = {"_".join(["val", k]): v for k, v in loss.items()}
+         self.log_dict(loss, on_epoch=True, prog_bar=True)
+         curr_psnr = self.compute_val_psnr(out, target)
+         for i, psnr in enumerate(curr_psnr):
+             self.log(f"val_psnr_ch{i+1}_batch", psnr, on_epoch=True)
+
+     def on_validation_epoch_end(self) -> None:
+         """Validation epoch end."""
+         psnr_ = self.reduce_running_psnr()
+         if psnr_ is not None:
+             self.log("val_psnr", psnr_, on_epoch=True, prog_bar=True)
+         else:
+             self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
+
+     def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
+         """Prediction step.
+
+         Parameters
+         ----------
+         batch : Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Model output.
+         """
+         if self._trainer.datamodule.tiled:
+             x, *aux = batch
+         else:
+             x = batch
+             aux = []
+
+         # apply test-time augmentation if available
+         # TODO: probably won't work with batch size > 1
+         if self._trainer.datamodule.prediction_config.tta_transforms:
+             tta = ImageRestorationTTA()
+             augmented_batch = tta.forward(x)  # list of augmented tensors
+             augmented_output = []
+             for augmented in augmented_batch:
+                 augmented_pred = self.model(augmented)
+                 augmented_output.append(augmented_pred)
+             output = tta.backward(augmented_output)
+         else:
+             output = self.model(x)
+
+         # Denormalize the output
+         denorm = Denormalize(
+             image_means=self._trainer.datamodule.predict_dataset.image_means,
+             image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+         )
+         denormalized_output = denorm(patch=output.cpu().numpy())
+
+         if len(aux) > 0:  # aux can be tiling information
+             return denormalized_output, *aux
+         else:
+             return denormalized_output
+
+     def configure_optimizers(self) -> Any:
+         """Configure optimizers and learning rate schedulers.
+
+         Returns
+         -------
+         Any
+             Optimizer and learning rate scheduler.
+         """
+         # instantiate optimizer
+         optimizer_func = get_optimizer(self.optimizer_name)
+         optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
+
+         # and scheduler
+         scheduler_func = get_scheduler(self.lr_scheduler_name)
+         scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": scheduler,
+             "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+         }
+
+     # TODO: find a way to move the following methods to a separate module
+     # TODO: this same operation is done in many other places, like in loss_func
+     # should we refactor LadderVAE so that it already outputs
+     # tuple(`mean`, `logvar`, `td_data`)?
+     def get_reconstructed_tensor(
+         self, model_outputs: tuple[Tensor, dict[str, Any]]
+     ) -> Tensor:
+         """Get the reconstructed tensor from the LVAE model outputs.
+
+         Parameters
+         ----------
+         model_outputs : tuple[Tensor, dict[str, Any]]
+             Model outputs. It is a tuple with a tensor representing the predicted mean
+             and (optionally) logvar, and the top-down data dictionary.
+
+         Returns
+         -------
+         Tensor
+             Reconstructed tensor, i.e., the predicted mean.
+         """
+         predictions, _ = model_outputs
+         if self.model.predict_logvar is None:
+             return predictions
+         elif self.model.predict_logvar == "pixelwise":
+             return predictions.chunk(2, dim=1)[0]
+
+     def compute_val_psnr(
+         self,
+         model_output: tuple[Tensor, dict[str, Any]],
+         target: Tensor,
+         psnr_func: Callable = scale_invariant_psnr,
+     ) -> list[float]:
+         """Compute the PSNR for the current validation batch.
+
+         Parameters
+         ----------
+         model_output : tuple[Tensor, dict[str, Any]]
+             Model output, a tuple with the predicted mean and (optionally) logvar,
+             and the top-down data dictionary.
+         target : Tensor
+             Target tensor.
+         psnr_func : Callable, optional
+             PSNR function to use, by default `scale_invariant_psnr`.
+
+         Returns
+         -------
+         list[float]
+             PSNR for each channel in the current batch.
+         """
+         out_channels = target.shape[1]
+
+         # get the reconstructed image
+         recons_img = self.get_reconstructed_tensor(model_output)
+
+         # update running psnr
+         for i in range(out_channels):
+             self.running_psnr[i].update(rec=recons_img[:, i], tar=target[:, i])
+
+         # compute psnr for each channel in the current batch
+         # TODO: this doesn't need to be a method of this class
+         # and hence can be moved to a separate module
+         return [
+             psnr_func(
+                 gt=target[:, i].clone().detach().cpu().numpy(),
+                 pred=recons_img[:, i].clone().detach().cpu().numpy(),
+             )
+             for i in range(out_channels)
+         ]
+
+     def reduce_running_psnr(self) -> Optional[float]:
+         """Reduce the running PSNR statistics and reset the running PSNR.
+
+         Returns
+         -------
+         Optional[float]
+             Running PSNR averaged over the different output channels.
+         """
+         psnr_arr = []  # type: ignore
+         for i in range(len(self.running_psnr)):
+             psnr = self.running_psnr[i].get()
+             if psnr is None:
+                 psnr_arr = None  # type: ignore
+                 break
+             psnr_arr.append(psnr.cpu().numpy())
+             self.running_psnr[i].reset()
+         # TODO: this line forces it to be a method of this class
+         # alternative is returning also the reset `running_psnr`
+         if psnr_arr is not None:
+             psnr = np.mean(psnr_arr)
+         return psnr
+
+
+ # TODO: make this LVAE compatible (?)
+ def create_careamics_module(
+     algorithm_type: Literal["fcn"],
+     algorithm: Union[SupportedAlgorithm, str],
+     loss: Union[SupportedLoss, str],
+     architecture: Union[SupportedArchitecture, str],
+     model_parameters: Optional[dict] = None,
+     optimizer: Union[SupportedOptimizer, str] = "Adam",
+     optimizer_parameters: Optional[dict] = None,
+     lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
+     lr_scheduler_parameters: Optional[dict] = None,
+ ) -> Union[FCNModule, VAEModule]:
+     """Create a CAREamics Lightning module.
+
+     This function exposes parameters used to create an AlgorithmModel instance,
+     triggering parameter validation.
+
+     Parameters
+     ----------
+     algorithm_type : Literal["fcn"]
+         Algorithm type to use for training.
+     algorithm : SupportedAlgorithm or str
+         Algorithm to use for training (see SupportedAlgorithm).
+     loss : SupportedLoss or str
+         Loss function to use for training (see SupportedLoss).
+     architecture : SupportedArchitecture or str
+         Model architecture to use for training (see SupportedArchitecture).
+     model_parameters : dict, optional
+         Model parameters to use for training, by default {}. Model parameters are
+         defined in the relevant `torch.nn.Module` class, or Pydantic model (see
+         `careamics.config.architectures`).
+     optimizer : SupportedOptimizer or str, optional
+         Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
+     optimizer_parameters : dict, optional
+         Optimizer parameters to use for training, as defined in `torch.optim`, by
+         default {}.
+     lr_scheduler : SupportedScheduler or str, optional
+         Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
+         (see SupportedScheduler).
+     lr_scheduler_parameters : dict, optional
+         Learning rate scheduler parameters to use for training, as defined in
+         `torch.optim`, by default {}.
+
+     Returns
+     -------
+     FCNModule or VAEModule
+         CAREamics Lightning module.
+     """
+     # create an AlgorithmModel-compatible dictionary
+     if lr_scheduler_parameters is None:
+         lr_scheduler_parameters = {}
+     if optimizer_parameters is None:
+         optimizer_parameters = {}
+     if model_parameters is None:
+         model_parameters = {}
+     algorithm_configuration: dict[str, Any] = {
+         "algorithm_type": algorithm_type,
+         "algorithm": algorithm,
+         "loss": loss,
+         "optimizer": {
+             "name": optimizer,
+             "parameters": optimizer_parameters,
+         },
+         "lr_scheduler": {
+             "name": lr_scheduler,
+             "parameters": lr_scheduler_parameters,
+         },
+     }
+     model_configuration = {"architecture": architecture}
+     model_configuration.update(model_parameters)
+
+     # add model parameters to algorithm configuration
+     algorithm_configuration["model"] = model_configuration
+
+     # instantiate the Lightning module from an AlgorithmModel instance
+     if algorithm_configuration["algorithm_type"] == "fcn":
+         return FCNModule(FCNAlgorithmConfig(**algorithm_configuration))
+     else:
+         raise NotImplementedError(
+             f"Model {algorithm_configuration['model']['architecture']} is not "
+             "implemented or unknown."
+         )
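
For context, here is a minimal usage sketch of the `create_careamics_module` factory added in this file. It is hypothetical: the "n2v" algorithm and loss strings, the "UNet" architecture string, the optimizer/scheduler parameters, and the `train_datamodule` placeholder are illustrative assumptions, not values taken from this diff.

import pytorch_lightning as L

from careamics.lightning.lightning_module import create_careamics_module

# Build an FCN-based module; "n2v"/"UNet" are assumed to be valid entries of
# SupportedAlgorithm, SupportedLoss and SupportedArchitecture.
module = create_careamics_module(
    algorithm_type="fcn",  # only "fcn" is dispatched in this version
    algorithm="n2v",
    loss="n2v",
    architecture="UNet",
    optimizer="Adam",
    optimizer_parameters={"lr": 1e-4},
    lr_scheduler="ReduceLROnPlateau",
    lr_scheduler_parameters={"factor": 0.5, "patience": 5},
)

# Training goes through a standard Lightning Trainer. `train_datamodule` is a
# placeholder for a CAREamics data module yielding (input, *aux) batches; since
# configure_optimizers monitors "val_loss", a validation dataloader is required.
trainer = L.Trainer(max_epochs=10, accelerator="cpu")
# trainer.fit(module, datamodule=train_datamodule)

Note that only the "fcn" algorithm type is handled by this factory; any other value raises NotImplementedError, and the LVAE path is instead configured directly through VAEModule with a VAEAlgorithmConfig.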