careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,125 @@
1
+ """
2
+ Loss submodule.
3
+
4
+ This submodule contains the various losses used in CAREamics.
5
+ """
6
+
7
+ import torch
8
+ from torch.nn import L1Loss, MSELoss
9
+
10
+ from careamics.models.lvae.noise_models import GaussianMixtureNoiseModel
11
+
12
+
13
def mse_loss(source: torch.Tensor, target: torch.Tensor, *args) -> torch.Tensor:
    """
    Compute the mean squared error between two tensors.

    The error is averaged over every element (``reduction="mean"``).

    Parameters
    ----------
    source : torch.Tensor
        Source patches.
    target : torch.Tensor
        Target patches.
    *args : Any
        Extra positional arguments; accepted for a uniform loss signature and
        ignored.

    Returns
    -------
    torch.Tensor
        Scalar loss value.
    """
    criterion = MSELoss(reduction="mean")
    return criterion(source, target)
33
+
34
+
35
def n2v_loss(
    manipulated_batch: torch.Tensor,
    original_batch: torch.Tensor,
    masks: torch.Tensor,
    *args,
) -> torch.Tensor:
    """
    N2V loss function described in A Krull et al 2018.

    Squared errors between the original and the manipulated batch are
    weighted by `masks`, summed, and divided by the sum of `masks`. With a
    binary mask this is the mean squared error over the masked pixels,
    pooled over the whole batch.

    Parameters
    ----------
    manipulated_batch : torch.Tensor
        Batch after the manipulation function was applied.
    original_batch : torch.Tensor
        Original images.
    masks : torch.Tensor
        Mask of the changed pixels (multiplied element-wise with the errors).
    *args : Any
        Extra positional arguments; accepted for a uniform loss signature and
        ignored.

    Returns
    -------
    torch.Tensor
        Scalar loss value.
    """
    # NOTE(review): an all-zero mask yields 0/0 (NaN); callers are assumed to
    # always mask at least one pixel.
    squared_error = torch.square(original_batch - manipulated_batch)
    masked_error_total = (squared_error * masks).sum()
    mask_total = masks.sum()
    # TODO change output to dict ?
    return masked_error_total / mask_total
64
+
65
+
66
def pn2v_loss(
    samples: torch.Tensor,
    labels: torch.Tensor,
    masks: torch.Tensor,
    noise_model: GaussianMixtureNoiseModel,
) -> torch.Tensor:
    """
    Probabilistic N2V loss function described in A Krull et al., CVF (2019).

    The likelihoods returned by ``noise_model.likelihood(labels, samples)``
    are averaged over dimension 1 (kept as a size-1 dimension), the log is
    taken, and the negative mask-weighted mean of that log-likelihood is
    returned.

    Parameters
    ----------
    samples : torch.Tensor
        Predicted pixel values from the network.
        TODO: the name `samples` is confusing; consider renaming.
    labels : torch.Tensor
        Original pixel values.
    masks : torch.Tensor
        Mask of the manipulated pixels (multiplied element-wise with the
        log-likelihoods).
    noise_model : GaussianMixtureNoiseModel
        Noise model used to compute the likelihood.

    Returns
    -------
    torch.Tensor
        Scalar loss value.
    """
    per_sample_likelihood = noise_model.likelihood(labels, samples)
    # Average over dimension 1 (presumably the sample dimension — TODO
    # confirm), then move to log space.
    log_mean_likelihood = torch.mean(per_sample_likelihood, dim=1, keepdim=True).log()

    # Mask-weighted average over pixels and batch.
    masked_total = (log_mean_likelihood * masks).sum()
    return -masked_total / masks.sum()
97
+
98
+
99
def mae_loss(samples: torch.Tensor, labels: torch.Tensor, *args) -> torch.Tensor:
    """
    N2N loss function described in J Lehtinen et al 2018.

    Computes the mean absolute error (L1), averaged over every element.

    Parameters
    ----------
    samples : torch.Tensor
        Raw patches.
    labels : torch.Tensor
        Different subset of noisy patches.
    *args : Any
        Extra positional arguments; accepted for a uniform loss signature and
        ignored.

    Returns
    -------
    torch.Tensor
        Scalar loss value.
    """
    criterion = L1Loss(reduction="mean")
    return criterion(samples, labels)
119
+
120
+
121
+ # def dice_loss(
122
+ # samples: torch.Tensor, labels: torch.Tensor, mode: str = "multiclass"
123
+ # ) -> torch.Tensor:
124
+ # """Dice loss function."""
125
+ # return DiceLoss(mode=mode)(samples, labels.long())
@@ -0,0 +1,80 @@
1
+ """
2
+ Loss factory module.
3
+
4
+ This module contains a factory function for creating loss functions.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from collections.abc import Callable
10
+ from dataclasses import dataclass
11
+ from typing import Union
12
+
13
+ from torch import Tensor as tensor
14
+
15
+ from ..config.support import SupportedLoss
16
+ from .fcn.losses import mae_loss, mse_loss, n2v_loss, pn2v_loss
17
+ from .lvae.losses import (
18
+ denoisplit_loss,
19
+ denoisplit_musplit_loss,
20
+ hdn_loss,
21
+ musplit_loss,
22
+ )
23
+
24
+
25
@dataclass()
class FCNLossParameters:
    """Container for the arguments of an FCN loss.

    Attributes
    ----------
    prediction : tensor
        Predicted tensor (presumably the network output — TODO confirm).
    targets : tensor
        Target tensor.
    mask : tensor
        Mask tensor.
    current_epoch : int
        Current training epoch.
    loss_weight : float
        Weight of the loss.
    """

    # TODO check that these fields match what the FCN losses actually consume.
    # Field order defines the positional order of the generated __init__.
    prediction: tensor
    targets: tensor
    mask: tensor
    current_epoch: int
    loss_weight: float
35
+
36
+
37
def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
    """Look up the loss function matching the requested loss.

    Parameters
    ----------
    loss : Union[SupportedLoss, str]
        Requested loss, as a `SupportedLoss` member or an equivalent string.

    Returns
    -------
    Callable
        Loss function.

    Raises
    ------
    NotImplementedError
        If the loss is unknown.
    """
    # Ordered (member, function) pairs. Matching uses `==` rather than a dict
    # lookup so that plain strings compare the same way as enum members
    # (Enum hashes by member name, not by value).
    candidates = (
        (SupportedLoss.N2V, n2v_loss),
        (SupportedLoss.PN2V, pn2v_loss),
        (SupportedLoss.MAE, mae_loss),
        (SupportedLoss.MSE, mse_loss),
        (SupportedLoss.HDN, hdn_loss),
        (SupportedLoss.MUSPLIT, musplit_loss),
        (SupportedLoss.DENOISPLIT, denoisplit_loss),
        (SupportedLoss.DENOISPLIT_MUSPLIT, denoisplit_musplit_loss),
    )
    for supported, loss_fn in candidates:
        if loss == supported:
            return loss_fn

    raise NotImplementedError(f"Loss {loss} is not yet supported.")
@@ -0,0 +1 @@
1
+ """LVAE losses."""
@@ -0,0 +1,83 @@
1
+ import torch
2
+
3
+
4
def free_bits_kl(
    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
) -> torch.Tensor:
    """Compute free-bits version of KL divergence.

    This function ensures that the KL doesn't go to zero for any latent dimension.
    Hence, it contributes to use latent variables more efficiently, leading to
    better representation learning.

    NOTE:
    Takes in the KL with shape (batch size, layers), returns the KL with
    free bits (for optimization) with shape (layers,), which is the average
    free-bits KL per layer in the current batch.
    If batch_average is False (default), the free bits are per layer and
    per batch element. Otherwise, the free bits are still per layer, but
    are assigned on average to the whole batch. In both cases, the batch
    average is returned, so it's simply a matter of doing mean(clamp(KL))
    or clamp(mean(KL)).

    Parameters
    ----------
    kl : torch.Tensor
        The KL divergence tensor with shape (batch size, layers).
    free_bits : float
        The free bits value. Values below `eps` (e.g. 0.0) disable free bits.
    batch_average : bool
        Whether to average over the batch before clamping to `free_bits`.
    eps : float
        Threshold below which `free_bits` is treated as disabled.

    Returns
    -------
    torch.Tensor
        The free-bits version of the KL divergence with shape (layers,).

    Raises
    ------
    ValueError
        If `kl` is not a 2D tensor.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if kl.dim() != 2:
        raise ValueError(
            "Expected `kl` with shape (batch size, layers), got shape "
            f"{tuple(kl.shape)}."
        )
    if free_bits < eps:
        # Free bits disabled: plain batch average per layer.
        return kl.mean(0)
    if batch_average:
        # Clamp the batch-averaged KL of each layer.
        return kl.mean(0).clamp(min=free_bits)
    # Clamp each (batch element, layer) entry, then average over the batch.
    return kl.clamp(min=free_bits).mean(0)
45
+
46
+
47
def get_kl_weight(
    kl_annealing: bool,
    kl_start: int,
    kl_annealtime: int,
    kl_weight: float,
    current_epoch: int,
) -> float:
    """Compute the weight of the KL loss, optionally with linear annealing.

    With annealing, a factor ramps linearly from 0 at `kl_start` to 1 at
    `kl_start + kl_annealtime` (clamped to [0, 1]). That factor is then
    multiplied by `kl_weight` when one is given. Without annealing,
    `kl_weight` is returned as-is, defaulting to 1.0 when it is `None`.

    Parameters
    ----------
    kl_annealing : bool
        Whether to use KL annealing.
    kl_start : int
        The epoch at which annealing starts.
    kl_annealtime : int
        The number of epochs over which the annealing factor ramps from 0 to
        1. A non-positive value switches the KL term on at once at
        `kl_start`.
    kl_weight : float or None
        Final weight of the KL loss. If `None`, the weight is the annealing
        factor alone (with annealing) or 1.0 (without annealing).
    current_epoch : int
        The current epoch.

    Returns
    -------
    float
        The weight to apply to the KL loss at `current_epoch`.
    """
    if not kl_annealing:
        return 1.0 if kl_weight is None else kl_weight

    # Keep the ramp in its own variable: the original code overwrote
    # `kl_weight` here, so the "final weight" step squared the ramp instead
    # of scaling it by the requested weight.
    if kl_annealtime > 0:
        anneal_factor = (current_epoch - kl_start) * (1.0 / kl_annealtime)
    else:
        # Degenerate schedule: avoid division by zero, use a step function.
        anneal_factor = 1.0 if current_epoch >= kl_start else 0.0
    # clamp to [0, 1]
    anneal_factor = min(max(0.0, anneal_factor), 1.0)

    # If the final weight is given, apply it on top of the annealing factor.
    if kl_weight is None:
        return anneal_factor
    return anneal_factor * kl_weight