careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,404 @@
1
+ """
2
+ Script for utility functions needed by the LVAE model.
3
+ """
4
+
5
+ from typing import Literal, Sequence
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torchvision.transforms.functional as F
11
+ from torch.distributions.normal import Normal
12
+
13
+
14
def torch_nanmean(inp):
    """Return the mean of ``inp`` over its non-NaN entries.

    Boolean masking flattens the selection, so the result is a 0-d tensor
    whatever the input shape. If no entry survives the mask (all NaN or an
    empty input), the mean of an empty selection is returned (NaN in torch).
    """
    not_nan = torch.logical_not(torch.isnan(inp))
    kept = inp[not_nan]
    return kept.mean()
16
+
17
+
18
def power_of_2(self, x):
    """Return whether ``x`` is a positive power of two.

    The ``self`` parameter is ignored. It is kept only so existing call sites
    that pass a leading positional argument keep working. The previous
    implementation recursed through ``self.power_of_2``, which fails for this
    module-level function whenever ``x`` is an even number greater than 1.

    Parameters
    ----------
    self : Any
        Ignored.
    x : int
        Value to test.

    Returns
    -------
    bool
        ``True`` for 1, 2, 4, 8, ...; ``False`` otherwise, including 0 (which
        can happen during validation) and negative values.

    Raises
    ------
    TypeError
        If ``x`` is not an ``int``.
    """
    if not isinstance(x, int):
        raise TypeError(f"Expected an int, got {type(x).__name__}.")
    # A positive power of two has exactly one bit set, so clearing its lowest
    # set bit (x & (x - 1)) leaves zero.
    return x > 0 and (x & (x - 1)) == 0
28
+
29
+
30
class Enum:
    """Minimal enum-like base class built on plain class attributes.

    Subclasses declare members as class attributes (``NAME = value``). Lookups
    consider only attributes defined directly on the subclass
    (``cls.__dict__``). Dunder entries that Python adds to every class
    namespace (``__module__``, ``__doc__``, ``__qualname__``, ...) are skipped.
    Without that, calls such as ``name(None)`` would match ``__doc__``.
    """

    @classmethod
    def _members(cls):
        """Yield ``(name, value)`` pairs for members declared on ``cls``."""
        for key, value in cls.__dict__.items():
            if key.startswith("__") and key.endswith("__"):
                continue
            yield key, value

    @classmethod
    def name(cls, enum_type):
        """Return the name of the first member equal to ``enum_type``, else ``None``."""
        return next(
            (key for key, value in cls._members() if enum_type == value), None
        )

    @classmethod
    def contains(cls, enum_type):
        """Return ``True`` if some member's value equals ``enum_type``."""
        return any(enum_type == value for _, value in cls._members())

    @classmethod
    def from_name(cls, enum_type_str):
        """Return the value of the member named ``enum_type_str``.

        Raises
        ------
        ValueError
            If no member has that name. Previously this was a no-op
            ``assert`` on a non-empty string, so ``None`` was returned silently.
        """
        for key, value in cls._members():
            if key == enum_type_str:
                return value
        raise ValueError(f"{cls.__name__}:{enum_type_str} does not exist.")
50
+
51
+
52
# Integer identifiers for the loss formulations. Reverse lookups
# (``LossType.name(value)``) come from the ``Enum`` base class defined above
# and return the first attribute with a matching value, so values are kept
# unique.
class LossType(Enum):
    Elbo = 0
    ElboWithCritic = 1
    ElboMixedReconstruction = 2
    MSE = 3
    ElboWithNbrConsistency = 4
    ElboSemiSupMixedReconstruction = 5
    ElboCL = 6
    ElboRestrictedReconstruction = 7
    DenoiSplitMuSplit = 8
62
+
63
+
64
# Integer identifiers for model variants. Reverse lookups
# (``ModelType.name(value)``) come from the ``Enum`` base class defined above
# and return the first attribute with a matching value, so values are kept
# unique. Numbering starts at 3 as defined here.
class ModelType(Enum):
    LadderVae = 3
    LadderVaeTwinDecoder = 4
    LadderVAECritic = 5
    # Separate vampprior: two optimizers.
    LadderVaeSepVampprior = 6
    # One encoder for mixed input, two for separate inputs.
    LadderVaeSepEncoder = 7
    LadderVAEMultiTarget = 8
    LadderVaeSepEncoderSingleOptim = 9
    UNet = 10
    BraveNet = 11
    LadderVaeStitch = 12
    LadderVaeSemiSupervised = 13
    # NOTE: these two values were swapped relative to earlier versions
    # (previously LadderVaeStitch2Stage = 13 and LadderVaeSemiSupervised = 14),
    # so previously trained models that stored the old numbers may be affected.
    LadderVaeStitch2Stage = 14
    LadderVaeMixedRecons = 15
    LadderVaeCL = 16
    LadderVaeTwoDataSet = (
        17  # On one sub-dataset apply disentanglement, on the other reconstruction.
    )
    LadderVaeTwoDatasetMultiBranch = 18
    LadderVaeTwoDatasetMultiOptim = 19
    LVaeDeepEncoderIntensityAug = 20
    AutoRegresiveLadderVAE = 21
    LadderVAEInterleavedOptimization = 22
    Denoiser = 23
    DenoiserSplitter = 24
    SplitterDenoiser = 25
    LadderVAERestrictedReconstruction = 26
    LadderVAETwoDataSetRestRecon = 27
    LadderVAETwoDataSetFinetuning = 28
96
+
97
+
98
+ def _pad_crop_img(
99
+ x: torch.Tensor, size: Sequence[int], mode: Literal["crop", "pad"]
100
+ ) -> torch.Tensor:
101
+ """Pads or crops a tensor.
102
+
103
+ Pads or crops a tensor of shape (B, C, [Z], Y, X) to new shape.
104
+
105
+ Parameters:
106
+ -----------
107
+ x: torch.Tensor
108
+ Input image of shape (B, C, [Z], Y, X)
109
+ size: Sequence[int]
110
+ Desired size ([Z*], Y*, X*)
111
+ mode: Literal["crop", "pad"]
112
+ Mode, either 'pad' or 'crop'
113
+
114
+ Returns:
115
+ --------
116
+ torch.Tensor:
117
+ The padded or cropped tensor
118
+ """
119
+ # TODO: Support cropping/padding on selected dimensions
120
+ assert (x.dim() == 4 and len(size) == 2) or (x.dim() == 5 and len(size) == 3)
121
+
122
+ size = tuple(size)
123
+ x_size = x.size()[2:]
124
+
125
+ if mode == "pad":
126
+ cond = any(x_size[i] > size[i] for i in range(len(size)))
127
+ elif mode == "crop":
128
+ cond = any(x_size[i] < size[i] for i in range(len(size)))
129
+
130
+ if cond:
131
+ raise ValueError(f"Trying to {mode} from size {x_size} to size {size}")
132
+
133
+ diffs = [abs(x - s) for x, s in zip(x_size, size)]
134
+ d1 = [d // 2 for d in diffs]
135
+ d2 = [d - (d // 2) for d in diffs]
136
+
137
+ if mode == "pad":
138
+ if x.dim() == 4:
139
+ padding = [d1[1], d2[1], d1[0], d2[0], 0, 0, 0, 0]
140
+ elif x.dim() == 5:
141
+ padding = [d1[2], d2[2], d1[1], d2[1], d1[0], d2[0], 0, 0, 0, 0]
142
+ return nn.functional.pad(x, padding)
143
+ elif mode == "crop":
144
+ if x.dim() == 4:
145
+ return x[:, :, d1[0] : (x_size[0] - d2[0]), d1[1] : (x_size[1] - d2[1])]
146
+ elif x.dim() == 5:
147
+ return x[
148
+ :,
149
+ :,
150
+ d1[0] : (x_size[0] - d2[0]),
151
+ d1[1] : (x_size[1] - d2[1]),
152
+ d1[2] : (x_size[2] - d2[2]),
153
+ ]
154
+
155
+
156
def pad_img_tensor(x: torch.Tensor, size: Sequence[int]) -> torch.Tensor:
    """Pad a tensor to the requested spatial size.

    Thin wrapper around ``_pad_crop_img`` in ``"pad"`` mode.

    Parameters
    ----------
    x : torch.Tensor
        Input image of shape (B, C, [Z], Y, X).
    size : Sequence[int]
        Desired spatial size ([Z*], Y*, X*). Each entry must be at least as
        large as the matching dimension of ``x``; otherwise ``_pad_crop_img``
        raises ``ValueError``.

    Returns
    -------
    torch.Tensor
        The padded tensor.
    """
    padded = _pad_crop_img(x, size, mode="pad")
    return padded
171
+
172
+
173
def crop_img_tensor(x, size) -> torch.Tensor:
    """Center-crop a tensor to the requested spatial size.

    Thin wrapper around ``_pad_crop_img`` in ``"crop"`` mode. Both 2D
    (B, C, Y, X) and 3D (B, C, Z, Y, X) inputs are accepted.

    Parameters
    ----------
    x : torch.Tensor
        Input image of shape (B, C, [Z], Y, X).
    size : Sequence[int]
        Desired spatial size ([Z*], Y*, X*). Each entry must not exceed the
        matching dimension of ``x``; otherwise ``_pad_crop_img`` raises
        ``ValueError``.

    Returns
    -------
    torch.Tensor
        The cropped tensor.
    """
    cropped = _pad_crop_img(x, size, mode="crop")
    return cropped
186
+
187
+
188
class StableExponential:
    """Numerically stable surrogate for ``exp`` together with its ``log``.

    The exponential is replaced by a piecewise function that grows only
    linearly for positive inputs, so it cannot overflow::

        exp(x) = exp(x)     if x <= 0
                 x + 1      if x > 0

    ``log()`` returns the natural log of that surrogate, evaluated on the
    same input tensor::

        log(x) = x          if x <= 0
                 log(1 + x) if x > 0

    Hence ``torch.log(StableExponential(t).exp())`` and
    ``StableExponential(t).log()`` agree (up to floating-point rounding).

    Both methods work on the tensor passed to the constructor, which is kept
    in ``self._raw_tensor``.
    """

    def __init__(self, tensor):
        self._raw_tensor = tensor
        separated = self.posneg_separation(self._raw_tensor)
        pos_mask, neg_mask = separated["filter"]
        pos_values, neg_values = separated["value"]
        self.pos_f = pos_mask
        self.neg_f = neg_mask
        self.pos_data = pos_values
        self.neg_data = neg_values

    def posneg_separation(self, tensor):
        """Split ``tensor`` into sign masks and sign-clipped copies.

        Returns a dict with two entries. ``"filter"`` maps to
        ``[tensor > 0, tensor <= 0]``. ``"value"`` maps to ``[tensor`` clipped
        from below at 0, ``tensor`` clipped from above at 0``]``.
        """
        masks = [tensor > 0, tensor <= 0]
        values = [torch.clip(tensor, min=0), torch.clip(tensor, max=0)]
        return {"filter": masks, "value": values}

    def exp(self):
        """Return the stable exponential of the stored tensor."""
        negative_part = torch.exp(self.neg_data) * self.neg_f
        positive_part = (1 + self.pos_data) * self.pos_f
        return negative_part + positive_part

    def log(self):
        """Return the log of the stable exponential of the stored tensor."""
        negative_part = self.neg_data * self.neg_f
        positive_part = torch.log(1 + self.pos_data) * self.pos_f
        return negative_part + positive_part
235
+
236
+
237
class StableLogVar:
    """
    Log-variance container with an optional numerically stable transform.

    When ``enable_stable`` is ``True``, the variance is computed from the stored
    log-variance with the ``StableExponential`` formulas (plus ``var_eps``)
    instead of ``torch.exp``. In that case ``get()`` returns the log of that
    variance. When it is ``False``, the raw log-variance is used with plain
    ``torch.exp``.
    """

    def __init__(
        self, logvar: torch.Tensor, enable_stable: bool = True, var_eps: float = 1e-6
    ):
        """
        Constructor.

        Parameters
        ----------
        logvar: torch.Tensor
            The input (true) logvar tensor, to be converted into the stable version.
        enable_stable: bool, optional
            Whether to compute the stable version of log-variance. Default is `True`.
        var_eps: float, optional
            Value added to the stable variance, acting as its lower bound.
            Default is `1e-6`.
        """
        self._lv = logvar
        self._enable_stable = enable_stable
        self._eps = var_eps

    def get(self) -> torch.Tensor:
        """Return the log-variance (raw, or ``log(get_var())`` when stable)."""
        if self._enable_stable is False:
            return self._lv

        return torch.log(self.get_var())

    def get_var(self) -> torch.Tensor:
        """
        Get Variance from Log-Variance.
        """
        if self._enable_stable is False:
            return torch.exp(self._lv)
        return StableExponential(self._lv).exp() + self._eps

    def get_std(self) -> torch.Tensor:
        """Return the standard deviation, ``sqrt(get_var())``."""
        return torch.sqrt(self.get_var())

    @property
    def is_3D(self) -> bool:
        """Check if the _lv tensor is 3D.

        Recall that, in this framework, tensors have shape (B, C, [Z], Y, X).
        """
        return self._lv.dim() == 5

    @staticmethod
    def _target_hw(size) -> tuple:
        """Normalize an int or a (height, width) sequence to an ``(H, W)`` tuple."""
        if isinstance(size, (int, np.integer)):
            return int(size), int(size)
        hw = tuple(int(s) for s in size)
        if len(hw) != 2:
            raise ValueError(
                f"Expected an int or a (height, width) pair, got {size!r}."
            )
        return hw

    def centercrop_to_size(self, size: "int | Sequence[int]") -> None:
        """
        Center-crop the log-variance tensor in place to the desired size.

        Implemented only for 2D tensors (B, C, Y, X).

        Parameters
        ----------
        size: int or Sequence[int]
            Target spatial size. An ``int`` requests a square ``(size, size)``
            crop; a two-element sequence gives ``(height, width)``.

        Raises
        ------
        NotImplementedError
            If the stored tensor is 3D.
        ValueError
            If a target dimension exceeds the current one. Also raised if the
            size difference along a dimension is odd, which would make the
            crop asymmetric.
        """
        if self.is_3D:
            raise NotImplementedError("Centercrop is implemented only for 2D tensors.")

        target = self._target_hw(size)
        current = tuple(self._lv.shape[-2:])
        if current == target:
            return

        for cur, tgt in zip(current, target):
            diff = cur - tgt
            if diff < 0 or diff % 2 != 0:
                raise ValueError(
                    f"Cannot symmetrically center-crop from {current} to {target}."
                )
        self._lv = F.center_crop(self._lv, target)
304
+
305
+
306
class StableMean:
    """Container for a mean tensor with the same accessor API as ``StableLogVar``.

    ``get()`` returns the stored tensor unchanged; ``centercrop_to_size`` crops
    it in place.
    """

    def __init__(self, mean):
        self._mean = mean

    def get(self) -> torch.Tensor:
        """Return the stored mean tensor."""
        return self._mean

    @property
    def is_3D(self) -> bool:
        """Check if the _mean tensor is 3D.

        Recall that, in this framework, tensors have shape (B, C, [Z], Y, X).
        """
        return self._mean.dim() == 5

    @staticmethod
    def _target_hw(size) -> tuple:
        """Normalize an int or a (height, width) sequence to an ``(H, W)`` tuple."""
        if isinstance(size, (int, np.integer)):
            return int(size), int(size)
        hw = tuple(int(s) for s in size)
        if len(hw) != 2:
            raise ValueError(
                f"Expected an int or a (height, width) pair, got {size!r}."
            )
        return hw

    def centercrop_to_size(self, size: "int | Sequence[int]") -> None:
        """Center-crop the mean tensor in place to the desired size.

        Implemented only for 2D tensors (B, C, Y, X).

        Parameters
        ----------
        size: int or Sequence[int]
            Target spatial size of the mean tensor. An ``int`` requests a
            square ``(size, size)`` crop; a two-element sequence gives
            ``(height, width)``.

        Raises
        ------
        NotImplementedError
            If the stored tensor is 3D.
        ValueError
            If a target dimension exceeds the current one. Also raised if the
            size difference along a dimension is odd, which would make the
            crop asymmetric.
        """
        if self.is_3D:
            raise NotImplementedError("Centercrop is implemented only for 2D tensors.")

        target = self._target_hw(size)
        current = tuple(self._mean.shape[-2:])
        if current == target:
            return

        for cur, tgt in zip(current, target):
            diff = cur - tgt
            if diff < 0 or diff % 2 != 0:
                raise ValueError(
                    f"Cannot symmetrically center-crop from {current} to {target}."
                )
        self._mean = F.center_crop(self._mean, target)
340
+
341
+
342
def allow_numpy(func):
    """
    Decorator that converts positional NumPy-array arguments to torch tensors.

    Every positional argument that is an ``np.ndarray`` is converted with the
    legacy ``torch.Tensor(...)`` constructor, which yields torch's default
    floating dtype (so, for instance, integer arrays become floats). All other
    positional arguments and all keyword arguments are passed through
    unchanged. The wrapper preserves the wrapped function's metadata
    (``__name__``, ``__doc__``, ...).
    """
    from functools import wraps

    @wraps(func)
    def numpy_wrapper(*args, **kwargs):
        new_args = tuple(
            torch.Tensor(arg) if isinstance(arg, np.ndarray) else arg for arg in args
        )
        return func(*new_args, **kwargs)

    return numpy_wrapper
360
+
361
+
362
class Interpolate(nn.Module):
    """``nn.Module`` wrapper for ``torch.nn.functional.interpolate``.

    Exactly one of ``size`` and ``scale`` must be given.

    Parameters
    ----------
    size : int or tuple of int, optional
        Output spatial size.
    scale : float or tuple of float, optional
        Spatial scale factor.
    mode : str, default "bilinear"
        Interpolation mode passed to ``interpolate``.
    align_corners : bool or None, default False
        Passed to ``interpolate``. Must be ``None`` for modes that do not
        support it (e.g. ``"nearest"``).

    Raises
    ------
    ValueError
        If both or neither of ``size`` and ``scale`` are given.
    """

    def __init__(self, size=None, scale=None, mode="bilinear", align_corners=False):
        super().__init__()
        if (size is None) == (scale is None):
            raise ValueError("Exactly one of `size` or `scale` must be provided.")
        self.size = size
        self.scale = scale
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # Use torch's interpolate explicitly: the module-level ``F`` alias is
        # torchvision.transforms.functional, which has no ``interpolate``.
        return nn.functional.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale,
            mode=self.mode,
            align_corners=self.align_corners,
        )
382
+
383
+
384
def kl_normal_mc(z, p_mulv, q_mulv):
    """Single-sample estimate of the element-wise KL between two diagonal Gaussians.

    Evaluates ``log q(z) - log p(z)`` element-wise using
    ``torch.distributions.Normal``, so any number of dimensions and
    broadcasting are supported (be careful with unintended broadcasting).

    Parameters
    ----------
    z : torch.Tensor
        Point at which both log-densities are evaluated (presumably a sample
        from q; TODO confirm at call sites).
    p_mulv : tuple
        ``(mean, logvar)`` for p. ``mean`` must provide ``get()`` and
        ``logvar`` must provide ``get_std()`` (e.g. ``StableMean`` and
        ``StableLogVar``).
    q_mulv : tuple
        ``(mean, logvar)`` for q, with the same requirements.

    Returns
    -------
    torch.Tensor
        Element-wise ``log q(z) - log p(z)``.
    """
    assert isinstance(p_mulv, tuple)
    assert isinstance(q_mulv, tuple)

    p_mean, p_logvar = p_mulv
    q_mean, q_logvar = q_mulv

    p_dist = Normal(p_mean.get(), p_logvar.get_std())
    q_dist = Normal(q_mean.get(), q_logvar.get_std())

    log_q = q_dist.log_prob(z)
    log_p = p_dist.log_prob(z)
    return log_q - log_p
@@ -0,0 +1,54 @@
1
+ """Model creation factory functions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Union
6
+
7
+ import torch
8
+
9
+ from careamics.config.support import SupportedArchitecture
10
+ from careamics.models.lvae import LadderVAE as LVAE
11
+ from careamics.models.unet import UNet
12
+ from careamics.utils import get_logger
13
+
14
+ if TYPE_CHECKING:
15
+ from careamics.config.architectures import (
16
+ LVAEConfig,
17
+ UNetConfig,
18
+ )
19
+
20
+
21
# Module-level logger; not referenced by model_factory as currently written.
logger = get_logger(__name__)
22
+
23
+
24
def model_factory(
    model_configuration: Union[UNetConfig, LVAEConfig],
) -> torch.nn.Module:
    """
    Instantiate the deep learning model described by a configuration.

    Supported architectures are listed in
    ``careamics.config.support.SupportedArchitecture``. The configuration's
    ``model_dump()`` is passed as keyword arguments to the model constructor.

    Parameters
    ----------
    model_configuration : UNetConfig or LVAEConfig
        Model configuration.

    Returns
    -------
    torch.nn.Module
        The instantiated model (``UNet`` or ``LadderVAE``).

    Raises
    ------
    NotImplementedError
        If the requested architecture is not implemented.
    """
    architecture = model_configuration.architecture

    if architecture == SupportedArchitecture.UNET:
        return UNet(**model_configuration.model_dump())

    if architecture == SupportedArchitecture.LVAE:
        return LVAE(**model_configuration.model_dump())

    raise NotImplementedError(
        f"Model {architecture} is not implemented or unknown."
    )
+ )