careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,394 @@
1
+ """
2
+ Script containing modules for defining different likelihood functions (as nn.Module).
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ from typing import TYPE_CHECKING, Any, Literal, Optional, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch import nn
13
+
14
+ from careamics.config.noise_model.likelihood_config import (
15
+ GaussianLikelihoodConfig,
16
+ NMLikelihoodConfig,
17
+ )
18
+
19
+ if TYPE_CHECKING:
20
+ from careamics.models.lvae.noise_models import (
21
+ GaussianMixtureNoiseModel,
22
+ MultiChannelNoiseModel,
23
+ )
24
+
25
+ NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
26
+
27
+
28
def likelihood_factory(
    config: Optional[Union[GaussianLikelihoodConfig, NMLikelihoodConfig]],
    noise_model: Optional[NoiseModel] = None,
) -> Optional[LikelihoodModule]:
    """
    Factory function for creating likelihood modules.

    Parameters
    ----------
    config : Optional[Union[GaussianLikelihoodConfig, NMLikelihoodConfig]]
        The configuration object for the likelihood module. If `None`, no
        likelihood module is created.
    noise_model : Optional[NoiseModel], optional
        The noise model instance used to define the `NoiseModelLikelihood`. Only
        used when `config` is a `NMLikelihoodConfig`. Default is `None`.

    Returns
    -------
    Optional[LikelihoodModule]
        A `GaussianLikelihood` for a `GaussianLikelihoodConfig`, a
        `NoiseModelLikelihood` for a `NMLikelihoodConfig`, or `None` when
        `config` is `None`.

    Raises
    ------
    ValueError
        If `config` is neither `None`, a `GaussianLikelihoodConfig`, nor a
        `NMLikelihoodConfig`.
    """
    if config is None:
        return None

    if isinstance(config, GaussianLikelihoodConfig):
        return GaussianLikelihood(
            predict_logvar=config.predict_logvar,
            logvar_lowerbound=config.logvar_lowerbound,
        )
    if isinstance(config, NMLikelihoodConfig):
        return NoiseModelLikelihood(
            noise_model=noise_model,
        )

    # An unsupported config type would otherwise fall through and return
    # `None`, which is indistinguishable from "no likelihood requested".
    raise ValueError(
        f"Unsupported likelihood configuration type: {type(config).__name__}. "
        "Expected GaussianLikelihoodConfig or NMLikelihoodConfig."
    )
59
+
60
+
61
+ # TODO: is it really worth to have this class? Or it just adds complexity? --> REFACTOR
62
class LikelihoodModule(nn.Module):
    """
    Base class of the likelihood modules.

    Subclasses override the hooks `distr_params`, `mean`, `mode`, `sample`,
    `logvar`, `log_likelihood` and `get_mean_lv`. In this base class every hook
    returns `None`. `forward` chains the hooks together.
    """

    def distr_params(self, x: Any) -> None:
        """Return the distribution parameters derived from `x` (here: `None`)."""
        return None

    def set_params_to_same_device_as(self, correct_device_tensor: Any) -> None:
        """Move internal parameters to the device of a tensor (here: no-op)."""

    @staticmethod
    def logvar(params: Any) -> None:
        """Extract the log-variance from `params` (here: `None`)."""
        return None

    @staticmethod
    def mean(params: Any) -> None:
        """Extract the mean from `params` (here: `None`)."""
        return None

    @staticmethod
    def mode(params: Any) -> None:
        """Extract the mode from `params` (here: `None`)."""
        return None

    @staticmethod
    def sample(params: Any) -> None:
        """Draw a sample according to `params` (here: `None`)."""
        return None

    def log_likelihood(self, x: Any, params: Any) -> None:
        """Compute the log-likelihood of `x` under `params` (here: `None`)."""
        return None

    def get_mean_lv(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Split `x` into mean and log-variance (not implemented here; returns `None`)."""

    def forward(
        self, input_: torch.Tensor, x: Union[torch.Tensor, None]
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """
        Evaluate the likelihood hooks on the output of the top-down pass.

        Parameters
        ----------
        input_ : torch.Tensor
            The output of the top-down pass (e.g., reconstructed image in HDN,
            or the unmixed images in 'Split' models).
        x : Union[torch.Tensor, None]
            The target tensor. If None, the log-likelihood is not computed.

        Returns
        -------
        tuple[torch.Tensor, dict[str, torch.Tensor]]
            The log-likelihood (`None` when `x` is `None`). The dictionary holds
            the keys "mean", "mode", "sample", "params" and "logvar", filled from
            the corresponding hooks ("params" is the output of `distr_params`).
        """
        params = self.distr_params(input_)
        outputs = {
            "mean": self.mean(params),
            "mode": self.mode(params),
            "sample": self.sample(params),
            "params": params,
            "logvar": self.logvar(params),
        }
        log_lik = None if x is None else self.log_likelihood(x, params)
        return log_lik, outputs
129
+
130
+
131
class GaussianLikelihood(LikelihoodModule):
    r"""A specialized `LikelihoodModule` for Gaussian likelihood.

    Specifically, in the LVAE model, the likelihood is defined as:
    p(x|z_1) = N(x|\mu_{p,1}, \sigma_{p,1}^2)

    If `predict_logvar` is `"pixelwise"`, the input tensor carries the mean and
    the log-variance stacked along dimension 1. Otherwise it carries only the
    mean, and no log-variance is produced.
    """

    def __init__(
        self,
        predict_logvar: Union[Literal["pixelwise"], None] = None,
        logvar_lowerbound: Union[float, None] = None,
    ):
        """Constructor.

        Parameters
        ----------
        predict_logvar : Union[Literal["pixelwise"], None], optional
            If `pixelwise`, log-variance is computed for each pixel, else
            log-variance is not computed. Default is `None`.
        logvar_lowerbound : float, optional
            The lower bound used to clip the log-variance. Default is `None`
            (no clipping).

        Raises
        ------
        ValueError
            If `predict_logvar` is neither `None` nor `"pixelwise"`.
        """
        super().__init__()

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if predict_logvar not in (None, "pixelwise"):
            raise ValueError(
                "`predict_logvar` must be None or 'pixelwise', "
                f"got {predict_logvar!r}."
            )

        self.predict_logvar = predict_logvar
        self.logvar_lowerbound = logvar_lowerbound

        print(
            f"[{self.__class__.__name__}] PredLVar:{self.predict_logvar} LowBLVar:{self.logvar_lowerbound}"
        )

    def get_mean_lv(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute the mean and log-variance of the Gaussian likelihood.

        The input is the output of the top-down pass.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor to the likelihood module, i.e., the output of the
            top-down pass.

        Returns
        -------
        tuple of (torch.Tensor, optional torch.Tensor)
            The first element of the tuple is the mean, the second element is
            the log-variance. If the attribute `predict_logvar` is `None`, the
            second element is `None`.
        """
        # Without log-variance prediction, dim 1 of `x` holds only the target
        # channels, so the whole tensor is the mean.
        if self.predict_logvar is None:
            return x, None

        # With pixel-wise log-variance, dim 1 holds twice the number of target
        # channels: the first half is the mean, the second half the log-variance.
        mean, lv = x.chunk(2, dim=1)

        # Optionally, clip the log-variance to a lower bound.
        if self.logvar_lowerbound is not None:
            lv = torch.clip(lv, min=self.logvar_lowerbound)

        return mean, lv

    def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Get the parameters (mean, log-var) of the Gaussian likelihood.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor to the likelihood module, i.e., the output of the
            LVAE 'output_layer'. Shape is (B, 2 * C, [Z], Y, X) when
            `predict_logvar` is not None, or (B, C, [Z], Y, X) otherwise.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with keys "mean" and "logvar" ("logvar" is `None` when
            `predict_logvar` is `None`).
        """
        mean, lv = self.get_mean_lv(x)
        params = {
            "mean": mean,
            "logvar": lv,
        }
        return params

    @staticmethod
    def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mean of the distribution."""
        return params["mean"]

    @staticmethod
    def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mode of the distribution (equal to the mean for a Gaussian)."""
        return params["mean"]

    @staticmethod
    def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mean of the distribution in place of a sample.

        No random draw is taken. A stochastic alternative would be
        `Normal(params["mean"], (params["logvar"] / 2).exp()).rsample()`.
        """
        return params["mean"]

    @staticmethod
    def logvar(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the log-variance of the distribution (may be `None`)."""
        return params["logvar"]

    def log_likelihood(
        self, x: torch.Tensor, params: dict[str, Union[torch.Tensor, None]]
    ):
        """Compute the element-wise Gaussian log-likelihood.

        Parameters
        ----------
        x : torch.Tensor
            The target tensor. Shape is (B, C, [Z], Y, X).
        params : dict[str, Union[torch.Tensor, None]]
            The tensors obtained by chunking the output of the top-down pass,
            here used as parameters of the Gaussian distribution.

        Returns
        -------
        torch.Tensor
            The log-likelihood tensor. Shape is (B, C, [Z], Y, X). When
            `predict_logvar` is `None`, this is `-0.5 * (mean - x) ** 2`, i.e. a
            unit-variance Gaussian log-density without its additive constant.
        """
        if self.predict_logvar is not None:
            logprob = log_normal(x, params["mean"], params["logvar"])
        else:
            logprob = -0.5 * (params["mean"] - x) ** 2
        return logprob
257
+
258
+
259
def log_normal(
    x: torch.Tensor, mean: torch.Tensor, logvar: torch.Tensor
) -> torch.Tensor:
    r"""
    Compute the log-probability at `x` of a Gaussian distribution
    with parameters `(mean, exp(logvar))`.

    Element-wise, this is

    .. math::

        \log \mathcal{N}(x \mid \mu, \sigma^2)
        = -\frac{1}{2}\left(\frac{(x - \mu)^2}{\sigma^2} + \log \sigma^2
        + \log 2\pi\right), \qquad \sigma^2 = \exp(\mathrm{logvar}).

    NOTE: In the case of LVAE, the log-likelihood formula becomes:
    \mathbb{E}_{z_1\sim{q_\phi}}[\log{p_\theta(x|z_1)}]=-\frac{1}{2}(\mathbb{E}_{z_1\sim{q_\phi}}[\log{2\pi\sigma_{p,0}^2(z_1)}] +\mathbb{E}_{z_1\sim{q_\phi}}[\frac{(x-\mu_{p,0}(z_1))^2}{\sigma_{p,0}^2(z_1)}])

    Parameters
    ----------
    x : torch.Tensor
        The ground-truth tensor, e.g. of shape (B, C, [Z], Y, X).
    mean : torch.Tensor
        The inferred mean of the distribution, broadcastable with `x`.
    logvar : torch.Tensor
        The inferred log-variance of the distribution. Must be a 0-dim tensor
        or broadcastable with `x`.

    Returns
    -------
    torch.Tensor
        The element-wise log-probability, with the broadcast shape of the
        inputs.
    """
    var = torch.exp(logvar)
    # A Python float constant adopts the inputs' dtype and device, instead of
    # allocating a float32 CPU tensor (with float32 precision) on every call.
    log_prob = -0.5 * (((x - mean) ** 2) / var + logvar + math.log(2 * math.pi))
    return log_prob
283
+
284
+
285
class NoiseModelLikelihood(LikelihoodModule):
    """Likelihood module that delegates the likelihood to a noise model.

    The log-likelihood of a target `x` given the network output is
    `log(noise_model.likelihood(x_denorm, prediction_denorm))`. Both tensors
    are first denormalized with the statistics passed to `set_data_stats`.
    The whole network output is used as the mean; no log-variance is produced.
    """

    def __init__(
        self,
        noise_model: NoiseModel,
    ):
        """Constructor.

        Parameters
        ----------
        noise_model : NoiseModel
            The noise model instance used to compute the likelihood.
        """
        super().__init__()
        # Normalization statistics; they must be provided through
        # `set_data_stats` before `log_likelihood` can be called.
        self.data_mean = None
        self.data_std = None
        self.noiseModel = noise_model

    def set_data_stats(
        self,
        data_mean: Union[np.ndarray, torch.Tensor],
        data_std: Union[np.ndarray, torch.Tensor],
    ) -> None:
        """Set the data mean and std used to denormalize in `log_likelihood`.

        Parameters
        ----------
        data_mean : Union[np.ndarray, torch.Tensor]
            Mean values for each channel. Stored as a float32 tensor without
            any reshaping, so it must already broadcast against the
            (B, C, [Z], Y, X) tensors passed to `log_likelihood`.
        data_std : Union[np.ndarray, torch.Tensor]
            Standard deviation values for each channel. Same shape
            requirements as `data_mean`.
        """
        self.data_mean = torch.as_tensor(data_mean, dtype=torch.float32)
        self.data_std = torch.as_tensor(data_std, dtype=torch.float32)
        # TODO: check whether an extra singleton dimension is needed for 3D data.

    def _set_params_to_same_device_as(
        self, correct_device_tensor: torch.Tensor
    ) -> None:
        """Move the data statistics and the noise model to a tensor's device.

        Parameters
        ----------
        correct_device_tensor : torch.Tensor
            The tensor whose device is used as the target device.
        """
        # The statistics are plain attributes (not registered buffers), so
        # they are moved lazily here rather than by `Module.to`.
        if (
            self.data_mean is not None
            and self.data_mean.device != correct_device_tensor.device
        ):
            self.data_mean = self.data_mean.to(correct_device_tensor.device)
            self.data_std = self.data_std.to(correct_device_tensor.device)
        if correct_device_tensor.device != self.noiseModel.device:
            self.noiseModel.to_device(correct_device_tensor.device)

    def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
        """Return `x` as the mean and `None` as the log-variance."""
        return x, None

    def distr_params(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Return a dictionary with "mean" (= `x`) and "logvar" (= `None`)."""
        mean, lv = self.get_mean_lv(x)
        params = {
            "mean": mean,
            "logvar": lv,
        }
        return params

    @staticmethod
    def mean(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mean (the whole network output)."""
        return params["mean"]

    @staticmethod
    def mode(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mean in place of the mode."""
        return params["mean"]

    @staticmethod
    def sample(params: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the mean in place of a sample (no random draw is taken)."""
        return params["mean"]

    def log_likelihood(self, x: torch.Tensor, params: dict[str, torch.Tensor]):
        """Compute the log-likelihood of the target `x` given `params`.

        `params` is obtained from the reconstruction tensor.

        Parameters
        ----------
        x : torch.Tensor
            The target tensor. Shape is (B, C, [Z], Y, X).
        params : dict[str, Union[torch.Tensor, None]]
            The tensors obtained from the output of the top-down pass.
            Here, "mean" corresponds to the whole output, while "logvar" is
            `None`.

        Returns
        -------
        torch.Tensor
            The log-likelihood tensor. Shape is (B, C, [Z], Y, X).

        Raises
        ------
        RuntimeError
            If `set_data_stats` has not been called yet.
        """
        if self.data_mean is None or self.data_std is None:
            raise RuntimeError(
                "NoiseModelLikelihood: data_mean and data_std must be set "
                "(via `set_data_stats`) before calling log_likelihood."
            )
        self._set_params_to_same_device_as(x)
        # The noise model is evaluated on denormalized data.
        predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
        x_denormalized = x * self.data_std + self.data_mean
        likelihoods = self.noiseModel.likelihood(
            x_denormalized, predicted_s_denormalized
        )
        # NOTE(review): zero likelihoods from the noise model would yield -inf here.
        logprob = torch.log(likelihoods)
        return logprob