careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,914 @@
+ """CAREamics Lightning module."""
+
+ from collections.abc import Callable
+ from typing import Any, Literal, Union
+
+ import numpy as np
+ import pytorch_lightning as L
+ import torch
+
+ from careamics.config import (
+     N2VAlgorithm,
+     PN2VAlgorithm,
+     UNetBasedAlgorithm,
+     VAEBasedAlgorithm,
+     algorithm_factory,
+ )
+ from careamics.config.data.tile_information import TileInformation
+ from careamics.config.support import (
+     SupportedAlgorithm,
+     SupportedArchitecture,
+     SupportedLoss,
+     SupportedOptimizer,
+     SupportedScheduler,
+ )
+ from careamics.losses import loss_factory
+ from careamics.models.lvae.likelihoods import (
+     GaussianLikelihood,
+     NoiseModelLikelihood,
+     likelihood_factory,
+ )
+ from careamics.models.lvae.noise_models import (
+     GaussianMixtureNoiseModel,
+     MultiChannelNoiseModel,
+     multichannel_noise_model_factory,
+     noise_model_factory,
+ )
+ from careamics.models.model_factory import model_factory
+ from careamics.transforms import (
+     Denormalize,
+     ImageRestorationTTA,
+     N2VManipulateTorch,
+     TrainDenormalize,
+ )
+ from careamics.utils.metrics import RunningPSNR, scale_invariant_psnr
+ from careamics.utils.torch_utils import get_optimizer, get_scheduler
+
+ NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+ # TODO rename to UNetModule
+ class FCNModule(L.LightningModule):
+     """
+     CAREamics Lightning module.
+
+     This class encapsulates the PyTorch model along with the training, validation,
+     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+     Parameters
+     ----------
+     algorithm_config : AlgorithmModel or dict
+         Algorithm configuration.
+
+     Attributes
+     ----------
+     model : torch.nn.Module
+         PyTorch model.
+     loss_func : torch.nn.Module
+         Loss function.
+     optimizer_name : str
+         Optimizer name.
+     optimizer_params : dict
+         Optimizer parameters.
+     lr_scheduler_name : str
+         Learning rate scheduler name.
+     """
+
+     def __init__(
+         self, algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm, dict]
+     ) -> None:
+         """Lightning module for CAREamics.
+
+         This class encapsulates a PyTorch model along with the training, validation,
+         and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+         Parameters
+         ----------
+         algorithm_config : AlgorithmModel or dict
+             Algorithm configuration.
+         """
+         super().__init__()
+
+         if isinstance(algorithm_config, dict):
+             algorithm_config = algorithm_factory(algorithm_config)
+
+         self.algorithm_config = algorithm_config
+         # create preprocessing, model and loss function
+         if isinstance(self.algorithm_config, N2VAlgorithm | PN2VAlgorithm):
+             self.use_n2v = True
+             self.n2v_preprocess: N2VManipulateTorch | None = N2VManipulateTorch(
+                 n2v_manipulate_config=self.algorithm_config.n2v_config
+             )
+         else:
+             self.use_n2v = False
+             self.n2v_preprocess = None
+
+         self.algorithm = self.algorithm_config.algorithm
+         self.model: torch.nn.Module = model_factory(self.algorithm_config.model)
+         self.noise_model: NoiseModel | None = noise_model_factory(
+             self.algorithm_config.noise_model
+             if isinstance(self.algorithm_config, PN2VAlgorithm)
+             else None
+         )
+
+         # Create loss function, pre-configure with noise model for PN2V
+         loss_func = loss_factory(self.algorithm_config.loss)
+         if (
+             isinstance(self.algorithm_config, PN2VAlgorithm)
+             and self.noise_model is not None
+         ):
+             # For PN2V, reorder arguments and pass the noise model
+             self.loss_func = lambda *args: loss_func(
+                 args[0], args[1], args[2], self.noise_model
+             )
+         else:
+             self.loss_func = loss_func
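+
+         # Note: the PN2V lambda above forwards the first three positional
+         # arguments (the model output plus, going by the call
+         # `self.loss_func(out, *aux, *targets)` in `training_step`, the two
+         # auxiliary tensors produced by `N2VManipulateTorch`) and appends the
+         # noise model as the final argument.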
+
+         # save optimizer and lr_scheduler names and parameters
+         self.optimizer_name = self.algorithm_config.optimizer.name
+         self.optimizer_params = self.algorithm_config.optimizer.parameters
+         self.lr_scheduler_name = self.algorithm_config.lr_scheduler.name
+         self.lr_scheduler_params = self.algorithm_config.lr_scheduler.parameters
+
+     def forward(self, x: Any) -> Any:
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : Any
+             Input tensor.
+
+         Returns
+         -------
+         Any
+             Output tensor.
+         """
+         return self.model(x)
+
+     def _train_denormalize(self, out: torch.Tensor) -> torch.Tensor:
+         """Denormalize output using training dataset statistics.
+
+         Parameters
+         ----------
+         out : torch.Tensor
+             Output tensor to denormalize.
+
+         Returns
+         -------
+         torch.Tensor
+             Denormalized tensor.
+         """
+         denorm = TrainDenormalize(
+             image_means=(self._trainer.datamodule.train_dataset.image_stats.means),
+             image_stds=(self._trainer.datamodule.train_dataset.image_stats.stds),
+         )
+         return denorm(patch=out)
+
+     def _predict_denormalize(
+         self, out: torch.Tensor, from_prediction: bool
+     ) -> torch.Tensor:
+         """Denormalize output for prediction.
+
+         Parameters
+         ----------
+         out : torch.Tensor
+             Output tensor to denormalize.
+         from_prediction : bool
+             Whether to use prediction or training dataset statistics.
+
+         Returns
+         -------
+         torch.Tensor
+             Denormalized tensor.
+         """
+         denorm = Denormalize(
+             image_means=(
+                 self._trainer.datamodule.predict_dataset.image_means
+                 if from_prediction
+                 else self._trainer.datamodule.train_dataset.image_stats.means
+             ),
+             image_stds=(
+                 self._trainer.datamodule.predict_dataset.image_stds
+                 if from_prediction
+                 else self._trainer.datamodule.train_dataset.image_stats.stds
+             ),
+         )
+         return denorm(patch=out.cpu().numpy())
+
+     def training_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
+         """Training step.
+
+         Parameters
+         ----------
+         batch : torch.Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Loss value.
+         """
+         x, *targets = batch
+         if self.use_n2v and self.n2v_preprocess is not None:
+             x_preprocessed, *aux = self.n2v_preprocess(x)
+         else:
+             x_preprocessed = x
+             aux = []
+
+         out = self.model(x_preprocessed)
+
+         # PN2V needs denormalized output and targets for loss computation
+         if isinstance(self.algorithm_config, PN2VAlgorithm):
+             out = self._train_denormalize(out)
+             aux = [self._train_denormalize(aux[0]), aux[1]]
+         # TODO hacky and ugly
+         loss = self.loss_func(out, *aux, *targets)
+         self.log(
+             "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+         )
+         optimizer = self.optimizers()
+         current_lr = optimizer.param_groups[0]["lr"]
+         self.log("learning_rate", current_lr, on_step=False, on_epoch=True, logger=True)
+         return loss
+
+     def validation_step(self, batch: torch.Tensor, batch_idx: Any) -> None:
+         """Validation step.
+
+         Parameters
+         ----------
+         batch : torch.Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+         """
+         x, *targets = batch
+         if self.use_n2v and self.n2v_preprocess is not None:
+             x_preprocessed, *aux = self.n2v_preprocess(x)
+         else:
+             x_preprocessed = x
+             aux = []
+
+         out = self.model(x_preprocessed)
+
+         # PN2V needs denormalized output and targets for loss computation
+         if isinstance(self.algorithm_config, PN2VAlgorithm):
+             out = torch.tensor(self._train_denormalize(out))
+             aux = [self._train_denormalize(aux[0]), aux[1]]
+         # TODO hacky and ugly
+         val_loss = self.loss_func(out, *aux, *targets)
+
+         # log validation loss
+         self.log(
+             "val_loss",
+             val_loss,
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+
+     def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
+         """Prediction step.
+
+         Parameters
+         ----------
+         batch : torch.Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Model output.
+         """
+         # TODO refactor when redoing datasets
+         # hacky way to determine if it is PredictDataModule, otherwise there is a
+         # circular import to solve with isinstance
+         from_prediction = hasattr(self._trainer.datamodule, "tiled")
+         is_tiled = (
+             len(batch) > 1
+             and isinstance(batch[1], list)
+             and isinstance(batch[1][0], TileInformation)
+         )
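+
+         # Tiled batches are assumed to arrive as (tiles, [TileInformation, ...]):
+         # the check above tests that the second element is a list whose first
+         # entry is a TileInformation object.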
+
+         # TODO add explanations for what is happening here
+         if is_tiled:
+             x, *aux = batch
+             if type(x) in [list, tuple]:
+                 x = x[0]
+         else:
+             if type(batch) in [list, tuple]:
+                 x = batch[0]  # TODO change, ugly way to deal with n2v refac
+             else:
+                 x = batch
+             aux = []
+
+         # apply test-time augmentation if available
+         # TODO: probably won't work with batch size > 1
+         if (
+             from_prediction
+             and self._trainer.datamodule.prediction_config.tta_transforms
+         ):
+             tta = ImageRestorationTTA()
+             augmented_batch = tta.forward(x)  # list of augmented tensors
+             augmented_output = []
+             for augmented in augmented_batch:
+                 augmented_pred = self.model(augmented)
+                 augmented_output.append(augmented_pred)
+             output = tta.backward(augmented_output)
+         else:
+             output = self.model(x)
+
+         # Denormalize the output
+         # TODO incompatible API between predict and train datasets
+
+         denormalized_input = self._predict_denormalize(
+             x, from_prediction=from_prediction
+         )
+         denormalized_output = self._predict_denormalize(
+             output, from_prediction=from_prediction
+         )
+
+         # Calculate MSE estimate
+         if isinstance(self.algorithm_config, PN2VAlgorithm):
+             assert self.noise_model is not None, "Noise model required for PN2V"
+             likelihoods = self.noise_model.likelihood(
+                 torch.tensor(denormalized_input), torch.tensor(denormalized_output)
+             )
+             mse_estimate = torch.sum(
+                 likelihoods * denormalized_output, dim=1, keepdim=True
+             )
+             mse_estimate /= torch.sum(likelihoods, dim=1, keepdim=True)
+
+         if isinstance(self.algorithm_config, PN2VAlgorithm):
+             denormalized_output = np.mean(denormalized_output, axis=1, keepdims=True)
+             denormalized_output = (denormalized_output, mse_estimate)
+         # TODO: might be ugly but otherwise we need to change the output signature
+         if len(aux) > 0:  # aux can be tiling information
+             return denormalized_output, *aux
+         else:
+             return denormalized_output
+
+     def configure_optimizers(self) -> Any:
+         """Configure optimizers and learning rate schedulers.
+
+         Returns
+         -------
+         Any
+             Optimizer and learning rate scheduler.
+         """
+         # instantiate optimizer
+         optimizer_func = get_optimizer(self.optimizer_name)
+         optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
+
+         # and scheduler
+         scheduler_func = get_scheduler(self.lr_scheduler_name)
+         scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": scheduler,
+             "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+         }
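+
+
+ # A minimal usage sketch (an assumption for illustration; the exact config
+ # keys are defined by the Pydantic algorithm models, not shown here):
+ #
+ #     module = FCNModule(
+ #         {
+ #             "algorithm": "n2v",
+ #             "loss": "n2v",
+ #             "model": {"architecture": "UNet"},
+ #         }
+ #     )
+ #
+ # A plain dict works because `__init__` routes it through `algorithm_factory`.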
+
+
+ class VAEModule(L.LightningModule):
+     """
+     CAREamics Lightning module.
+
+     This class encapsulates a PyTorch model along with the training, validation,
+     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+     Parameters
+     ----------
+     algorithm_config : Union[VAEAlgorithmConfig, dict]
+         Algorithm configuration.
+
+     Attributes
+     ----------
+     model : nn.Module
+         PyTorch model.
+     loss_func : nn.Module
+         Loss function.
+     optimizer_name : str
+         Optimizer name.
+     optimizer_params : dict
+         Optimizer parameters.
+     lr_scheduler_name : str
+         Learning rate scheduler name.
+     """
+
+     def __init__(self, algorithm_config: Union[VAEBasedAlgorithm, dict]) -> None:
+         """Lightning module for CAREamics.
+
+         This class encapsulates a PyTorch model along with the training, validation,
+         and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
+
+         Parameters
+         ----------
+         algorithm_config : Union[AlgorithmModel, dict]
+             Algorithm configuration.
+         """
+         super().__init__()
+         # if loading from a checkpoint, AlgorithmModel needs to be instantiated
+         self.algorithm_config = (
+             VAEBasedAlgorithm(**algorithm_config)
+             if isinstance(algorithm_config, dict)
+             else algorithm_config
+         )
+
+         # TODO: log algorithm config
+         # self.save_hyperparameters(self.algorithm_config.model_dump())
+
+         # create model
+         self.model: torch.nn.Module = model_factory(self.algorithm_config.model)
+
+         # supervised mode
+         self.supervised_mode = self.algorithm_config.is_supervised
+         # create noise model (VAE algorithms always use the multichannel factory)
+         self.noise_model: NoiseModel | None = multichannel_noise_model_factory(
+             self.algorithm_config.noise_model
+         )
+
+         self.noise_model_likelihood: NoiseModelLikelihood | None = None
+         if self.algorithm_config.noise_model_likelihood is not None:
+             self.noise_model_likelihood = likelihood_factory(
+                 config=self.algorithm_config.noise_model_likelihood,
+                 noise_model=self.noise_model,
+             )
+
+         self.gaussian_likelihood: GaussianLikelihood | None = likelihood_factory(
+             self.algorithm_config.gaussian_likelihood
+         )
+
+         self.loss_parameters = self.algorithm_config.loss
+         self.loss_func = loss_factory(self.algorithm_config.loss.loss_type)
+
+         # save optimizer and lr_scheduler names and parameters
+         self.optimizer_name = self.algorithm_config.optimizer.name
+         self.optimizer_params = self.algorithm_config.optimizer.parameters
+         self.lr_scheduler_name = self.algorithm_config.lr_scheduler.name
+         self.lr_scheduler_params = self.algorithm_config.lr_scheduler.parameters
+
+         # initialize running PSNR
+         self.running_psnr = [
+             RunningPSNR() for _ in range(self.algorithm_config.model.output_channels)
+         ]
+
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+             number of lateral inputs.
+
+         Returns
+         -------
+         tuple[torch.Tensor, dict[str, Any]]
+             A tuple with the output tensor and additional data from the top-down pass.
+         """
+         return self.model(x)  # TODO different models can have more than one output
+
+     def set_data_stats(self, data_mean, data_std):
+         """Set data mean and std for the noise model likelihood.
+
+         Parameters
+         ----------
+         data_mean : float
+             Mean of the data.
+         data_std : float
+             Standard deviation of the data.
+         """
+         if self.noise_model_likelihood is not None:
+             self.noise_model_likelihood.set_data_stats(data_mean, data_std)
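+             # NOTE: `training_step` raises a RuntimeError if these statistics
+             # are still unset when a noise model likelihood is used (see the
+             # check below), so this is expected to be called before fitting.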
+
+     def training_step(
+         self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
+     ) -> dict[str, torch.Tensor] | None:
+         """Training step.
+
+         Parameters
+         ----------
+         batch : tuple[torch.Tensor, torch.Tensor]
+             Input batch. It is a tuple with the input tensor and the target tensor.
+             The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+             number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
+             where C is the number of target channels (e.g., 1 in HDN, >1 in
+             muSplit/denoiSplit).
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Loss value.
+         """
+         x, *target = batch
+
+         # Forward pass
+         out = self.model(x)
+         if not self.supervised_mode:
+             target = x
+         else:
+             target = target[
+                 0
+             ]  # hacky way to unpack. TODO maybe should be fixed on the dataset level
+
+         # Update loss parameters
+         self.loss_parameters.kl_params.current_epoch = self.current_epoch
+
+         # Compute loss
+         if self.noise_model_likelihood is not None:
+             if (
+                 self.noise_model_likelihood.data_mean is None
+                 or self.noise_model_likelihood.data_std is None
+             ):
+                 raise RuntimeError(
+                     "NoiseModelLikelihood: mean and std must be set before training."
+                 )
+         loss = self.loss_func(
+             model_outputs=out,
+             targets=target,
+             config=self.loss_parameters,
+             gaussian_likelihood=self.gaussian_likelihood,
+             noise_model_likelihood=self.noise_model_likelihood,
+         )
+
+         # Logging
+         # TODO: implement a separate logging method?
+         self.log_dict(loss, on_step=True, on_epoch=True)
+
+         try:
+             optimizer = self.optimizers()
+             current_lr = optimizer.param_groups[0]["lr"]
+             self.log(
+                 "learning_rate", current_lr, on_step=False, on_epoch=True, logger=True
+             )
+         except RuntimeError:
+             # This happens when the module is not attached to a trainer, e.g., in tests
+             pass
+         return loss
+
+     def validation_step(
+         self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: Any
+     ) -> None:
+         """Validation step.
+
+         Parameters
+         ----------
+         batch : tuple[torch.Tensor, torch.Tensor]
+             Input batch. It is a tuple with the input tensor and the target tensor.
+             The input tensor has shape (B, (1 + n_LC), [Z], Y, X), where n_LC is the
+             number of lateral inputs. The target tensor has shape (B, C, [Z], Y, X),
+             where C is the number of target channels (e.g., 1 in HDN, >1 in
+             muSplit/denoiSplit).
+         batch_idx : Any
+             Batch index.
+         """
+         x, *target = batch
+
+         # Forward pass
+         out = self.model(x)
+         if not self.supervised_mode:
+             target = x
+         else:
+             target = target[
+                 0
+             ]  # hacky way to unpack. TODO maybe should be fixed on the dataset level
+         # Compute loss
+         loss = self.loss_func(
+             model_outputs=out,
+             targets=target,
+             config=self.loss_parameters,
+             gaussian_likelihood=self.gaussian_likelihood,
+             noise_model_likelihood=self.noise_model_likelihood,
+         )
+
+         # Logging
+         # Rename val_loss dict
+         loss = {"_".join(["val", k]): v for k, v in loss.items()}
+         self.log_dict(loss, on_epoch=True, prog_bar=True)
+         curr_psnr = self.compute_val_psnr(out, target)
+         for i, psnr in enumerate(curr_psnr):
+             self.log(f"val_psnr_ch{i+1}_batch", psnr, on_epoch=True)
+
+     def on_validation_epoch_end(self) -> None:
+         """Validation epoch end."""
+         psnr_ = self.reduce_running_psnr()
+         if psnr_ is not None:
+             self.log("val_psnr", psnr_, on_epoch=True, prog_bar=True)
+         else:
+             self.log("val_psnr", 0.0, on_epoch=True, prog_bar=True)
+
+     def predict_step(self, batch: torch.Tensor, batch_idx: Any) -> Any:
+         """Prediction step.
+
+         Parameters
+         ----------
+         batch : torch.Tensor
+             Input batch.
+         batch_idx : Any
+             Batch index.
+
+         Returns
+         -------
+         Any
+             Model output.
+         """
+         if self.algorithm_config.algorithm == "microsplit":
+             x, *aux = batch
+             # Reset model for inference with spatial dimensions only (H, W)
+             self.model.reset_for_inference(x.shape[-2:])
+
+             rec_img_list = []
+             for _ in range(self.algorithm_config.mmse_count):
+                 # get model output
+                 rec, _ = self.model(x)
+
+                 # get reconstructed img
+                 if self.model.predict_logvar is None:
+                     rec_img = rec
+                     _logvar = torch.tensor([-1])
+                 else:
+                     rec_img, _logvar = torch.chunk(rec, chunks=2, dim=1)
+                 rec_img_list.append(rec_img.cpu().unsqueeze(0))  # add MMSE dim
+
+             # aggregate results
+             samples = torch.cat(rec_img_list, dim=0)
+             mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
+             std_imgs = torch.std(samples, dim=0)  # std over MMSE dim
+
+             tile_prediction = mmse_imgs.cpu().numpy()
+             tile_std = std_imgs.cpu().numpy()
+
+             return tile_prediction, tile_std
+
+         else:
+             # Regular prediction logic
+             if self._trainer.datamodule.tiled:
+                 # TODO tile_size should match model input size
+                 x, *aux = batch
+                 x = (
+                     x[0] if isinstance(x, list | tuple) else x
+                 )  # TODO ugly, so far I don't know why x might be a list
+                 self.model.reset_for_inference(x.shape)  # TODO should it be here?
+             else:
+                 x = batch[0] if isinstance(batch, list | tuple) else batch
+                 aux = []
+                 self.model.reset_for_inference(x.shape)
+
+             mmse_list = []
+             for _ in range(self.algorithm_config.mmse_count):
+                 # apply test-time augmentation if available
+                 if self._trainer.datamodule.prediction_config.tta_transforms:
+                     tta = ImageRestorationTTA()
+                     augmented_batch = tta.forward(x)  # list of augmented tensors
+                     augmented_output = []
+                     for augmented in augmented_batch:
+                         augmented_pred = self.model(augmented)
+                         augmented_output.append(augmented_pred)
+                     output = tta.backward(augmented_output)
+                 else:
+                     output = self.model(x)
+
+                 # taking the 1st element of the output, 2nd is std if
+                 # predict_logvar == "pixelwise"
+                 output = (
+                     output[0]
+                     if self.model.predict_logvar is None
+                     else output[0][:, 0:1, ...]
+                 )
+                 mmse_list.append(output)
+
+             mmse = torch.stack(mmse_list).mean(0)
+             std = torch.stack(mmse_list).std(0)  # TODO why?
+             # TODO better way to unpack if pred logvar
+             # Denormalize the output
+             denorm = Denormalize(
+                 image_means=self._trainer.datamodule.predict_dataset.image_means,
+                 image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+             )
+
+             denormalized_output = denorm(patch=mmse.cpu().numpy())
+
+             if len(aux) > 0:  # aux can be tiling information
+                 return denormalized_output, std, *aux
+             else:
+                 return denormalized_output, std
+
+     def configure_optimizers(self) -> Any:
+         """Configure optimizers and learning rate schedulers.
+
+         Returns
+         -------
+         Any
+             Optimizer and learning rate scheduler.
+         """
+         # instantiate optimizer
+         optimizer_func = get_optimizer(self.optimizer_name)
+         optimizer = optimizer_func(self.model.parameters(), **self.optimizer_params)
+
+         # and scheduler
+         scheduler_func = get_scheduler(self.lr_scheduler_name)
+         scheduler = scheduler_func(optimizer, **self.lr_scheduler_params)
+
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": scheduler,
+             "monitor": "val_loss",  # otherwise triggers MisconfigurationException
+         }
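+
+     # Note on the "monitor" entry above: Lightning requires a monitored
+     # metric for schedulers such as ReduceLROnPlateau (the default in
+     # `create_careamics_module` below); omitting it triggers the
+     # MisconfigurationException mentioned in the inline comment.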
+
+     # TODO: find a way to move the following methods to a separate module
+     # TODO: this same operation is done in many other places, like in loss_func
+     # should we refactor LadderVAE so that it already outputs
+     # tuple(`mean`, `logvar`, `td_data`)?
+     def get_reconstructed_tensor(
+         self, model_outputs: tuple[torch.Tensor, dict[str, Any]]
+     ) -> torch.Tensor:
+         """Get the reconstructed tensor from the LVAE model outputs.
+
+         Parameters
+         ----------
+         model_outputs : tuple[torch.Tensor, dict[str, Any]]
+             Model outputs. It is a tuple with a tensor representing the predicted mean
+             and (optionally) logvar, and the top-down data dictionary.
+
+         Returns
+         -------
+         torch.Tensor
+             Reconstructed tensor, i.e., the predicted mean.
+         """
+         predictions, _ = model_outputs
+         if self.model.predict_logvar is None:
+             return predictions
+         elif self.model.predict_logvar == "pixelwise":
+             return predictions.chunk(2, dim=1)[0]
+
+     def compute_val_psnr(
+         self,
+         model_output: tuple[torch.Tensor, dict[str, Any]],
+         target: torch.Tensor,
+         psnr_func: Callable = scale_invariant_psnr,
+     ) -> list[float]:
+         """Compute the PSNR for the current validation batch.
+
+         Parameters
+         ----------
+         model_output : tuple[torch.Tensor, dict[str, Any]]
+             Model output, a tuple with the predicted mean and (optionally) logvar,
+             and the top-down data dictionary.
+         target : torch.Tensor
+             Target tensor.
+         psnr_func : Callable, optional
+             PSNR function to use, by default `scale_invariant_psnr`.
+
+         Returns
+         -------
+         list[float]
+             PSNR for each channel in the current batch.
+         """
+         # TODO check this! Related to is_supervised which is also wacky
+         out_channels = target.shape[1]
+
+         # get the reconstructed image
+         recons_img = self.get_reconstructed_tensor(model_output)
+
+         # update running psnr
+         for i in range(out_channels):
+             self.running_psnr[i].update(rec=recons_img[:, i], tar=target[:, i])
+
+         # compute psnr for each channel in the current batch
+         # TODO: this doesn't need to be a method of this class
+         # and hence can be moved to a separate module
+         return [
+             psnr_func(
+                 gt=target[:, i].clone().detach().cpu().numpy(),
+                 pred=recons_img[:, i].clone().detach().cpu().numpy(),
+             )
+             for i in range(out_channels)
+         ]
+
+     def reduce_running_psnr(self) -> float | None:
+         """Reduce the running PSNR statistics and reset the running PSNR.
+
+         Returns
+         -------
+         float or None
+             Running PSNR averaged over the different output channels.
+         """
+         psnr_arr = []  # type: ignore
+         for i in range(len(self.running_psnr)):
+             psnr = self.running_psnr[i].get()
+             if psnr is None:
+                 psnr_arr = None  # type: ignore
+                 break
+             psnr_arr.append(psnr.cpu().numpy())
+             self.running_psnr[i].reset()
+         # TODO: this line forces it to be a method of this class
+         # alternative is returning also the reset `running_psnr`
+         if psnr_arr is not None:
+             psnr = np.mean(psnr_arr)
+             return psnr
+
+
+ # TODO: make this LVAE compatible (?)
+ def create_careamics_module(
+     algorithm: Union[SupportedAlgorithm, str],
+     loss: Union[SupportedLoss, str],
+     architecture: Union[SupportedArchitecture, str],
+     use_n2v2: bool = False,
+     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+     struct_n2v_span: int = 5,
+     model_parameters: dict | None = None,
+     optimizer: Union[SupportedOptimizer, str] = "Adam",
+     optimizer_parameters: dict | None = None,
+     lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau",
+     lr_scheduler_parameters: dict | None = None,
+ ) -> Union[FCNModule, VAEModule]:
+     """Create a CAREamics Lightning module.
+
+     This function exposes parameters used to create an AlgorithmModel instance,
+     triggering parameter validation.
+
+     Parameters
+     ----------
+     algorithm : SupportedAlgorithm or str
+         Algorithm to use for training (see SupportedAlgorithm).
+     loss : SupportedLoss or str
+         Loss function to use for training (see SupportedLoss).
+     architecture : SupportedArchitecture or str
+         Model architecture to use for training (see SupportedArchitecture).
+     use_n2v2 : bool, default=False
+         Whether to use N2V2 or Noise2Void.
+     struct_n2v_axis : "horizontal", "vertical", or "none", default="none"
+         Axis of the StructN2V mask.
+     struct_n2v_span : int, default=5
+         Span of the StructN2V mask.
+     model_parameters : dict, optional
+         Model parameters to use for training, by default {}. Model parameters are
+         defined in the relevant `torch.nn.Module` class, or Pydantic model (see
+         `careamics.config.architectures`).
+     optimizer : SupportedOptimizer or str, optional
+         Optimizer to use for training, by default "Adam" (see SupportedOptimizer).
+     optimizer_parameters : dict, optional
+         Optimizer parameters to use for training, as defined in `torch.optim`, by
+         default {}.
+     lr_scheduler : SupportedScheduler or str, optional
+         Learning rate scheduler to use for training, by default "ReduceLROnPlateau"
+         (see SupportedScheduler).
+     lr_scheduler_parameters : dict, optional
+         Learning rate scheduler parameters to use for training, as defined in
+         `torch.optim`, by default {}.
+
+     Returns
+     -------
+     FCNModule or VAEModule
+         CAREamics Lightning module.
+     """
+     # TODO should use the same functions as in configuration_factory.py
+     # create an AlgorithmModel compatible dictionary
+     if lr_scheduler_parameters is None:
+         lr_scheduler_parameters = {}
+     if optimizer_parameters is None:
+         optimizer_parameters = {}
+     if model_parameters is None:
+         model_parameters = {}
+     algorithm_dict: dict[str, Any] = {
+         "algorithm": algorithm,
+         "loss": loss,
+         "optimizer": {
+             "name": optimizer,
+             "parameters": optimizer_parameters,
+         },
+         "lr_scheduler": {
+             "name": lr_scheduler,
+             "parameters": lr_scheduler_parameters,
+         },
+     }
+
+     model_dict = {"architecture": architecture}
+     model_dict.update(model_parameters)
+
+     # add model parameters to the algorithm configuration
+     algorithm_dict["model"] = model_dict
+
+     which_algo = algorithm_dict["algorithm"]
+     if which_algo in UNetBasedAlgorithm.get_compatible_algorithms():
+         algorithm_cfg = algorithm_factory(algorithm_dict)
+
+         # if using N2V
+         if isinstance(algorithm_cfg, N2VAlgorithm | PN2VAlgorithm):
+             algorithm_cfg.n2v_config.struct_mask_axis = struct_n2v_axis
+             algorithm_cfg.n2v_config.struct_mask_span = struct_n2v_span
+             algorithm_cfg.set_n2v2(use_n2v2)
+
+         return FCNModule(algorithm_cfg)
+     else:
+         raise NotImplementedError(
+             f"Algorithm {which_algo} is not implemented or unknown."
+         )
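
As a usage sketch of the factory function at the end of this file (argument
values are illustrative assumptions; only the parameter names and defaults come
from the signature above, and the import path assumes the module's location
inside the wheel):

    from careamics.lightning.lightning_module import create_careamics_module

    # build a Noise2Void Lightning module with the default optimizer/scheduler
    module = create_careamics_module(
        algorithm="n2v",
        loss="n2v",
        architecture="UNet",
    )
    # `module` is an FCNModule and can be passed to a pytorch_lightning.Trainer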