careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/losses/lvae/losses.py
@@ -0,0 +1,589 @@
+ """Methods for Loss Computation."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, Literal, Union
+
+ import numpy as np
+ import torch
+
+ from careamics.losses.lvae.loss_utils import free_bits_kl, get_kl_weight
+ from careamics.models.lvae.likelihoods import (
+     GaussianLikelihood,
+     LikelihoodModule,
+     NoiseModelLikelihood,
+ )
+
+ if TYPE_CHECKING:
+     from careamics.config import LVAELossConfig
+
+ Likelihood = Union[LikelihoodModule, GaussianLikelihood, NoiseModelLikelihood]
+
+
+ def get_reconstruction_loss(
+     reconstruction: torch.Tensor,
+     target: torch.Tensor,
+     likelihood_obj: Likelihood,
+ ) -> torch.Tensor:
+     """Compute the reconstruction loss (negative log-likelihood).
+
+     Parameters
+     ----------
+     reconstruction : torch.Tensor
+         The output of the LVAE decoder. Shape is (B, C, [Z], Y, X), where C is the
+         number of output channels (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
+     target : torch.Tensor
+         The target image used to compute the reconstruction loss. Shape is
+         (B, C, [Z], Y, X), where C is the number of output channels
+         (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
+     likelihood_obj : Likelihood
+         The likelihood object used to compute the reconstruction loss.
+
+     Returns
+     -------
+     torch.Tensor
+         The reconstruction loss (negative log-likelihood).
+     """
+     # Compute the log-likelihood
+     ll, _ = likelihood_obj(reconstruction, target)  # shape: (B, C, [Z], Y, X)
+     return -1 * ll.mean()
+
+
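For orientation, here is a minimal sketch of how `get_reconstruction_loss` can be exercised. The `_UnitGaussianLL` object below is a hypothetical stand-in that only mimics the call protocol used above (`ll, _ = likelihood_obj(reconstruction, target)`); it is not the package's `GaussianLikelihood` API.

import math
import torch

class _UnitGaussianLL:
    """Hypothetical stand-in returning an element-wise log-likelihood."""

    def __call__(self, reconstruction, target):
        # log N(target; mean=reconstruction, var=1), element-wise
        ll = -0.5 * ((target - reconstruction) ** 2 + math.log(2 * math.pi))
        return ll, {"mean": reconstruction}

pred = torch.randn(4, 1, 64, 64)  # (B, C, Y, X)
tgt = torch.randn(4, 1, 64, 64)
loss = get_reconstruction_loss(pred, tgt, _UnitGaussianLL())  # scalar tensor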
+ def _reconstruction_loss_musplit_denoisplit(
+     predictions: torch.Tensor,
+     targets: torch.Tensor,
+     nm_likelihood: NoiseModelLikelihood,
+     gaussian_likelihood: GaussianLikelihood,
+     nm_weight: float,
+     gaussian_weight: float,
+ ) -> torch.Tensor:
+     """Compute the reconstruction loss for the combined muSplit-denoiSplit loss.
+
+     The resulting loss is a weighted mean of the noise model likelihood and the
+     Gaussian likelihood.
+
+     Parameters
+     ----------
+     predictions : torch.Tensor
+         The output of the LVAE decoder. Shape is (B, C, [Z], Y, X) or
+         (B, 2*C, [Z], Y, X), where C is the number of output channels;
+         the factor of 2 covers the case of a predicted log-variance.
+     targets : torch.Tensor
+         The target image used to compute the reconstruction loss. Shape is
+         (B, C, [Z], Y, X), where C is the number of output channels
+         (e.g., 1 in HDN, >1 in muSplit/denoiSplit).
+     nm_likelihood : NoiseModelLikelihood
+         A `NoiseModelLikelihood` object used to compute the noise model likelihood.
+     gaussian_likelihood : GaussianLikelihood
+         A `GaussianLikelihood` object used to compute the Gaussian likelihood.
+     nm_weight : float
+         The weight for the noise model likelihood.
+     gaussian_weight : float
+         The weight for the Gaussian likelihood.
+
+     Returns
+     -------
+     recons_loss : torch.Tensor
+         The reconstruction loss. Shape is (1, ).
+     """
+     if predictions.shape[1] == 2 * targets.shape[1]:
+         # predictions contain both mean and log-variance
+         pred_mean, _ = predictions.chunk(2, dim=1)
+         # TODO if this condition does not hold, everything breaks later!
+     else:
+         pred_mean = predictions
+
+     recons_loss_nm = get_reconstruction_loss(
+         reconstruction=pred_mean, target=targets, likelihood_obj=nm_likelihood
+     )
+
+     recons_loss_gm = get_reconstruction_loss(
+         reconstruction=predictions,
+         target=targets,
+         likelihood_obj=gaussian_likelihood,
+     )
+
+     recons_loss = nm_weight * recons_loss_nm + gaussian_weight * recons_loss_gm
+     return recons_loss
+
+
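The branch above relies purely on channel counts to detect a predicted log-variance. A toy illustration of that shape logic (all shapes fabricated):

import torch

# C = 2 target channels; the decoder predicts mean and log-variance: 2*C channels.
predictions = torch.randn(4, 4, 64, 64)  # (B, 2*C, Y, X)
targets = torch.randn(4, 2, 64, 64)      # (B, C, Y, X)

if predictions.shape[1] == 2 * targets.shape[1]:
    pred_mean, pred_logvar = predictions.chunk(2, dim=1)  # each (B, C, Y, X)
    print(pred_mean.shape)  # torch.Size([4, 2, 64, 64])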
+ def get_kl_divergence_loss(
+     kl_type: Literal["kl", "kl_restricted"],
+     topdown_data: dict[str, torch.Tensor],
+     rescaling: Literal["latent_dim", "image_dim"],
+     aggregation: Literal["mean", "sum"],
+     free_bits_coeff: float,
+     img_shape: tuple[int, ...] | None = None,
+ ) -> torch.Tensor:
+     """Compute the KL divergence loss.
+
+     NOTE: Description of the `rescaling` methods:
+     - If "latent_dim", the KL-loss values are rescaled w.r.t. the latent space
+       dimensions (spatial + number of channels, i.e., (C, [Z], Y, X)). In this way
+       they have the same magnitude across layers.
+     - If "image_dim", the KL-loss values are rescaled w.r.t. the input image
+       spatial dimensions. In this way, the lower layers have a larger KL-loss
+       value than the higher layers, since the latent space, and hence the KL
+       tensor, has more entries. Specifically, at hierarchy `i`, the total KL loss
+       is larger by a factor (128/i**2).
+
+     NOTE: the type of `aggregation` determines the magnitude of the KL-loss:
+     "sum" aggregation results in a KL-loss value larger than "mean" by a factor
+     of `n_layers`.
+
+     NOTE: recall that the sample-wise KL is obtained by summing over all
+     dimensions, including Z. Recall also that the current 3D implementation of the
+     LVAE does no downsampling along Z. Therefore, to avoid overweighting the KL
+     loss, we divide it by the Z dimension of the input image in every case.
+
+     Parameters
+     ----------
+     kl_type : Literal["kl", "kl_restricted"]
+         The type of KL divergence loss to compute.
+     topdown_data : dict[str, torch.Tensor]
+         A dictionary containing information computed for each layer during the
+         top-down pass. The dictionary must include the following keys:
+         - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
+         - "z": The sampled latents for each layer. Shape of each tensor is
+           (B, layers, `z_dims[i]`, H, W).
+     rescaling : Literal["latent_dim", "image_dim"]
+         The rescaling method used for the KL-loss values. If "latent_dim", the
+         KL-loss values are rescaled w.r.t. the latent space dimensions (spatial +
+         number of channels, i.e., (C, [Z], Y, X)). If "image_dim", the KL-loss
+         values are rescaled w.r.t. the input image spatial dimensions.
+     aggregation : Literal["mean", "sum"]
+         The aggregation method used to combine the KL-loss values across layers.
+         If "mean", the KL-loss values are averaged across layers. If "sum", they
+         are summed across layers.
+     free_bits_coeff : float
+         The free bits coefficient used for the KL-loss computation.
+     img_shape : Optional[tuple[int, ...]]
+         The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
+
+     Returns
+     -------
+     kl_loss : torch.Tensor
+         The KL divergence loss. Shape is (1, ).
+     """
+     kl = torch.cat(
+         [kl_layer.unsqueeze(1) for kl_layer in topdown_data[kl_type]],
+         dim=1,
+     )  # shape: (B, n_layers)
+
+     # Apply free bits (& batch average)
+     kl = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
+
+     # In the 3D case, rescale by the Z dimension
+     # TODO If we have downsampling in the Z dimension, this needs to change.
+     if img_shape is not None and len(img_shape) == 3:
+         kl = kl / img_shape[0]
+
+     # Rescaling
+     if rescaling == "latent_dim":
+         for i in range(len(kl)):
+             latent_dim = topdown_data["z"][i].shape[1:]
+             norm_factor = np.prod(latent_dim)
+             kl[i] = kl[i] / norm_factor
+     elif rescaling == "image_dim":
+         kl = kl / np.prod(img_shape[-2:])
+
+     # Aggregation
+     if aggregation == "mean":
+         kl = kl.mean()  # shape: (1,)
+     elif aggregation == "sum":
+         kl = kl.sum()  # shape: (1,)
+
+     return kl
+
+
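To make the two rescaling modes concrete, a self-contained sketch with fabricated per-layer KL values; the plain batch mean below stands in for `free_bits_kl` with a zero coefficient, and all shapes are assumptions:

import numpy as np
import torch

# Two hierarchy levels, batch of 4; fabricated per-sample KL sums per layer.
kl = torch.tensor([[512.0, 128.0]] * 4)  # (B, n_layers)
kl = kl.mean(dim=0)                      # stand-in for free_bits_kl with coeff 0.0

# "latent_dim": divide each layer's KL by its number of latent entries (C, Y, X)
latent_shapes = [(32, 16, 16), (32, 8, 8)]  # assumed per-layer z shapes
kl_latent = torch.stack([kl[i] / np.prod(s) for i, s in enumerate(latent_shapes)])
print(kl_latent)  # tensor([0.0625, 0.0625]) -- comparable magnitude across layers

# "image_dim": divide every layer's KL by the input (Y, X) size
kl_image = kl / np.prod((64, 64))
print(kl_image)   # tensor([0.1250, 0.0312]) -- lower layers dominate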
+ def _get_kl_divergence_loss_musplit(
+     topdown_data: dict[str, torch.Tensor],
+     img_shape: tuple[int, ...],
+     kl_type: Literal["kl", "kl_restricted"],
+ ) -> torch.Tensor:
+     """Compute the KL divergence loss for muSplit.
+
+     Parameters
+     ----------
+     topdown_data : dict[str, torch.Tensor]
+         A dictionary containing information computed for each layer during the
+         top-down pass. The dictionary must include the following keys:
+         - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
+         - "z": The sampled latents for each layer. Shape of each tensor is
+           (B, layers, `z_dims[i]`, H, W).
+     img_shape : tuple[int, ...]
+         The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
+     kl_type : Literal["kl", "kl_restricted"]
+         The type of KL divergence loss to compute.
+
+     Returns
+     -------
+     kl_loss : torch.Tensor
+         The KL divergence loss for the muSplit case. Shape is (1, ).
+     """
+     return get_kl_divergence_loss(
+         kl_type="kl",  # TODO: hardcoded, deal in future PR
+         topdown_data=topdown_data,
+         rescaling="latent_dim",
+         aggregation="mean",
+         free_bits_coeff=0.0,
+         img_shape=img_shape,
+     )
+
+
+ def _get_kl_divergence_loss_denoisplit(
+     topdown_data: dict[str, torch.Tensor],
+     img_shape: tuple[int, ...],
+     kl_type: Literal["kl", "kl_restricted"],
+ ) -> torch.Tensor:
+     """Compute the KL divergence loss for denoiSplit.
+
+     Parameters
+     ----------
+     topdown_data : dict[str, torch.Tensor]
+         A dictionary containing information computed for each layer during the
+         top-down pass. The dictionary must include the following keys:
+         - "kl": The KL-loss values for each layer. Shape of each tensor is (B,).
+         - "z": The sampled latents for each layer. Shape of each tensor is
+           (B, layers, `z_dims[i]`, H, W).
+     img_shape : tuple[int, ...]
+         The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
+     kl_type : Literal["kl", "kl_restricted"]
+         The type of KL divergence loss to compute.
+
+     Returns
+     -------
+     kl_loss : torch.Tensor
+         The KL divergence loss for the denoiSplit case. Shape is (1, ).
+     """
+     return get_kl_divergence_loss(
+         kl_type=kl_type,
+         topdown_data=topdown_data,
+         rescaling="image_dim",
+         aggregation="sum",
+         free_bits_coeff=1.0,
+         img_shape=img_shape,
+     )
+
+
+ # TODO: @melisande-c suggested to refactor this as a class (see PR #208)
+ # - loss computation happens by calling the `__call__` method
+ # - `__init__` method initializes the loss parameters now contained in
+ #   the `LVAELossParameters` class
+ # NOTE: same for the other loss functions
+
+
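For reference, a minimal sketch of the class-based shape this TODO alludes to (hypothetical names; the actual proposal lives in PR #208):

class HDNLoss:
    """Hypothetical class-based wrapper; not part of this package."""

    def __init__(self, config: "LVAELossConfig") -> None:
        # loss hyperparameters are bound once instead of being passed per call
        self.config = config

    def __call__(self, model_outputs, targets, likelihood):
        # would reuse the same computation as `hdn_loss` below,
        # reading weights and KL parameters from self.config
        raise NotImplementedError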
+ def hdn_loss(
+     model_outputs: tuple[torch.Tensor, dict[str, Any]],
+     targets: torch.Tensor,
+     config: LVAELossConfig,
+     gaussian_likelihood: GaussianLikelihood | None,
+     noise_model_likelihood: NoiseModelLikelihood | None,
+ ) -> dict[str, torch.Tensor] | None:
+     """Loss function for HDN.
+
+     Parameters
+     ----------
+     model_outputs : tuple[torch.Tensor, dict[str, Any]]
+         Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y,
+         X)) and the top-down layer data (e.g., sampled latents, KL-loss values,
+         etc.).
+     targets : torch.Tensor
+         The target image used to compute the reconstruction loss. In this case we
+         use the input patch itself as target. Shape is (B, `target_ch`, [Z], Y, X).
+     config : LVAELossConfig
+         The loss configuration containing all loss hyperparameters.
+     gaussian_likelihood : Optional[GaussianLikelihood]
+         The Gaussian likelihood object, if used.
+     noise_model_likelihood : Optional[NoiseModelLikelihood]
+         The noise model likelihood object, if used.
+
+     Returns
+     -------
+     output : Optional[dict[str, torch.Tensor]]
+         A dictionary containing the overall loss `["loss"]`, the reconstruction
+         loss `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`,
+         or `None` if the total loss is NaN.
+     """
+     if gaussian_likelihood is not None:
+         likelihood = gaussian_likelihood
+     elif noise_model_likelihood is not None:
+         likelihood = noise_model_likelihood
+     else:
+         raise ValueError("Invalid likelihood object.")
+     # TODO refactor loss signature
+     predictions, td_data = model_outputs
+
+     # Reconstruction loss computation
+     recons_loss = config.reconstruction_weight * get_reconstruction_loss(
+         reconstruction=predictions,
+         target=targets,
+         likelihood_obj=likelihood,
+     )
+     if torch.isnan(recons_loss).any():
+         recons_loss = 0.0
+
+     # KL loss computation
+     kl_weight = get_kl_weight(
+         config.kl_params.annealing,
+         config.kl_params.start,
+         config.kl_params.annealtime,
+         config.kl_weight,
+         config.kl_params.current_epoch,
+     )
+     kl_loss = (
+         _get_kl_divergence_loss_denoisplit(
+             topdown_data=td_data,
+             img_shape=targets.shape[2:],
+             kl_type=config.kl_params.loss_type,
+         )
+         * kl_weight
+     )
+
+     net_loss = recons_loss + kl_loss  # TODO add check that losses coefs sum to 1
+     output = {
+         "loss": net_loss,
+         "reconstruction_loss": (
+             recons_loss.detach()
+             if isinstance(recons_loss, torch.Tensor)
+             else recons_loss
+         ),
+         "kl_loss": kl_loss.detach(),
+     }
+     # https://github.com/openai/vdvae/blob/main/train.py#L26
+     if torch.isnan(net_loss).any():
+         return None
+
+     return output
+
+
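A `None` return (on a NaN total loss, following the vdvae reference linked above) is meant to let the training loop skip the batch. A hypothetical Lightning-style training step consuming that contract; every attribute name here is an assumption:

def training_step(self, batch, batch_idx):
    x, target = batch
    model_outputs = self.model(x)  # (predictions, top-down data)
    loss_out = hdn_loss(
        model_outputs,
        targets=target,
        config=self.loss_config,
        gaussian_likelihood=self.gaussian_likelihood,
        noise_model_likelihood=None,
    )
    if loss_out is None:
        return None  # NaN loss: Lightning skips the optimizer step for this batch
    self.log("train_loss", loss_out["loss"])
    return loss_out["loss"]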
+ def musplit_loss(
+     model_outputs: tuple[torch.Tensor, dict[str, Any]],
+     targets: torch.Tensor,
+     config: LVAELossConfig,
+     gaussian_likelihood: GaussianLikelihood | None,
+     noise_model_likelihood: NoiseModelLikelihood | None = None,  # TODO: ugly
+ ) -> dict[str, torch.Tensor] | None:
+     """Loss function for muSplit.
+
+     Parameters
+     ----------
+     model_outputs : tuple[torch.Tensor, dict[str, Any]]
+         Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y,
+         X)) and the top-down layer data (e.g., sampled latents, KL-loss values,
+         etc.).
+     targets : torch.Tensor
+         The target image used to compute the reconstruction loss. Shape is
+         (B, `target_ch`, [Z], Y, X).
+     config : LVAELossConfig
+         The loss configuration (e.g., KL hyperparameters, loss weights, etc.).
+     gaussian_likelihood : GaussianLikelihood
+         The Gaussian likelihood object.
+     noise_model_likelihood : Optional[NoiseModelLikelihood]
+         The noise model likelihood object. Not used here.
+
+     Returns
+     -------
+     output : Optional[dict[str, torch.Tensor]]
+         A dictionary containing the overall loss `["loss"]`, the reconstruction
+         loss `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`,
+         or `None` if the total loss is NaN.
+     """
+     assert gaussian_likelihood is not None
+
+     predictions, td_data = model_outputs
+
+     # Reconstruction loss computation
+     recons_loss = config.reconstruction_weight * get_reconstruction_loss(
+         reconstruction=predictions,
+         target=targets,
+         likelihood_obj=gaussian_likelihood,
+     )
+     if torch.isnan(recons_loss).any():
+         recons_loss = 0.0
+
+     # KL loss computation
+     kl_weight = get_kl_weight(
+         config.kl_params.annealing,
+         config.kl_params.start,
+         config.kl_params.annealtime,
+         config.kl_weight,
+         config.kl_params.current_epoch,
+     )
+     kl_loss = (
+         _get_kl_divergence_loss_musplit(
+             topdown_data=td_data,
+             img_shape=targets.shape[2:],
+             kl_type=config.kl_params.loss_type,
+         )
+         * kl_weight
+     )
+
+     net_loss = recons_loss + kl_loss
+     output = {
+         "loss": net_loss,
+         "reconstruction_loss": (
+             recons_loss.detach()
+             if isinstance(recons_loss, torch.Tensor)
+             else recons_loss
+         ),
+         "kl_loss": kl_loss.detach(),
+     }
+     # https://github.com/openai/vdvae/blob/main/train.py#L26
+     if torch.isnan(net_loss).any():
+         return None
+
+     return output
+
+
+ def denoisplit_loss(
+     model_outputs: tuple[torch.Tensor, dict[str, Any]],
+     targets: torch.Tensor,
+     config: LVAELossConfig,
+     gaussian_likelihood: GaussianLikelihood | None = None,
+     noise_model_likelihood: NoiseModelLikelihood | None = None,
+ ) -> dict[str, torch.Tensor] | None:
+     """Loss function for denoiSplit.
+
+     Parameters
+     ----------
+     model_outputs : tuple[torch.Tensor, dict[str, Any]]
+         Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y,
+         X)) and the top-down layer data (e.g., sampled latents, KL-loss values,
+         etc.).
+     targets : torch.Tensor
+         The target image used to compute the reconstruction loss. Shape is
+         (B, `target_ch`, [Z], Y, X).
+     config : LVAELossConfig
+         The loss configuration containing all loss hyperparameters.
+     gaussian_likelihood : Optional[GaussianLikelihood]
+         The Gaussian likelihood object. Not used here.
+     noise_model_likelihood : NoiseModelLikelihood
+         The noise model likelihood object.
+
+     Returns
+     -------
+     output : Optional[dict[str, torch.Tensor]]
+         A dictionary containing the overall loss `["loss"]`, the reconstruction
+         loss `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`,
+         or `None` if the total loss is NaN.
+     """
+     assert noise_model_likelihood is not None
+
+     predictions, td_data = model_outputs
+
+     # Reconstruction loss computation
+     recons_loss = config.reconstruction_weight * get_reconstruction_loss(
+         reconstruction=predictions,
+         target=targets,
+         likelihood_obj=noise_model_likelihood,
+     )
+     if torch.isnan(recons_loss).any():
+         recons_loss = 0.0
+
+     # KL loss computation
+     kl_weight = get_kl_weight(
+         config.kl_params.annealing,
+         config.kl_params.start,
+         config.kl_params.annealtime,
+         config.kl_weight,
+         config.kl_params.current_epoch,
+     )
+     kl_loss = (
+         _get_kl_divergence_loss_denoisplit(
+             topdown_data=td_data,
+             img_shape=targets.shape[2:],
+             kl_type=config.kl_params.loss_type,
+         )
+         * kl_weight
+     )
+
+     net_loss = recons_loss + kl_loss
+     output = {
+         "loss": net_loss,
+         "reconstruction_loss": (
+             recons_loss.detach()
+             if isinstance(recons_loss, torch.Tensor)
+             else recons_loss
+         ),
+         "kl_loss": kl_loss.detach(),
+     }
+     # https://github.com/openai/vdvae/blob/main/train.py#L26
+     if torch.isnan(net_loss).any():
+         return None
+
+     return output
+
+
+ def denoisplit_musplit_loss(
+     model_outputs: tuple[torch.Tensor, dict[str, Any]],
+     targets: torch.Tensor,
+     config: LVAELossConfig,
+     gaussian_likelihood: GaussianLikelihood,
+     noise_model_likelihood: NoiseModelLikelihood,
+ ) -> dict[str, torch.Tensor] | None:
+     """Loss function for the combined denoiSplit-muSplit case.
+
+     Parameters
+     ----------
+     model_outputs : tuple[torch.Tensor, dict[str, Any]]
+         Tuple containing the model predictions (shape is (B, `target_ch`, [Z], Y,
+         X)) and the top-down layer data (e.g., sampled latents, KL-loss values,
+         etc.).
+     targets : torch.Tensor
+         The target image used to compute the reconstruction loss. Shape is
+         (B, `target_ch`, [Z], Y, X).
+     config : LVAELossConfig
+         The loss configuration containing all loss hyperparameters.
+     gaussian_likelihood : GaussianLikelihood
+         The Gaussian likelihood object.
+     noise_model_likelihood : NoiseModelLikelihood
+         The noise model likelihood object.
+
+     Returns
+     -------
+     output : Optional[dict[str, torch.Tensor]]
+         A dictionary containing the overall loss `["loss"]`, the reconstruction
+         loss `["reconstruction_loss"]`, and the KL divergence loss `["kl_loss"]`,
+         or `None` if the total loss is NaN.
+     """
+     predictions, td_data = model_outputs
+
+     # Reconstruction loss computation
+     recons_loss = _reconstruction_loss_musplit_denoisplit(
+         predictions=predictions,
+         targets=targets,
+         nm_likelihood=noise_model_likelihood,
+         gaussian_likelihood=gaussian_likelihood,
+         nm_weight=config.denoisplit_weight,
+         gaussian_weight=config.musplit_weight,
+     )
+     if torch.isnan(recons_loss).any():
+         recons_loss = 0.0
+
+     # KL loss computation
+     # NOTE: 'kl' key stands for the 'kl_samplewise' key in the TopDownLayer class.
+     # The different naming comes from `top_down_pass()` method in the LadderVAE.
+     denoisplit_kl = _get_kl_divergence_loss_denoisplit(
+         topdown_data=td_data,
+         img_shape=targets.shape[2:],
+         kl_type=config.kl_params.loss_type,
+     )
+     musplit_kl = _get_kl_divergence_loss_musplit(
+         topdown_data=td_data,
+         img_shape=targets.shape[2:],
+         kl_type=config.kl_params.loss_type,
+     )
+     kl_loss = (
+         config.denoisplit_weight * denoisplit_kl + config.musplit_weight * musplit_kl
+     )
+     # TODO `kl_weight` is hardcoded (???)
+     kl_loss = config.kl_weight * kl_loss
+
+     net_loss = recons_loss + kl_loss
+     output = {
+         "loss": net_loss,
+         "reconstruction_loss": (
+             recons_loss.detach()
+             if isinstance(recons_loss, torch.Tensor)
+             else recons_loss
+         ),
+         "kl_loss": kl_loss.detach(),
+     }
+     # https://github.com/openai/vdvae/blob/main/train.py#L26
+     if torch.isnan(net_loss).any():
+         return None
+
+     return output
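As a quick sanity check on the weighting arithmetic above, with fabricated scalar values (the weights are assumptions, not package defaults):

import torch

denoisplit_weight, musplit_weight, kl_weight = 0.9, 0.1, 1.0  # assumed config values

recons_nm, recons_gauss = torch.tensor(2.0), torch.tensor(1.0)
denoisplit_kl, musplit_kl = torch.tensor(0.4), torch.tensor(0.2)

recons_loss = denoisplit_weight * recons_nm + musplit_weight * recons_gauss  # 1.9
kl_loss = kl_weight * (
    denoisplit_weight * denoisplit_kl + musplit_weight * musplit_kl
)  # 0.38
net_loss = recons_loss + kl_loss
print(net_loss.item())  # ≈ 2.28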