careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,158 @@
1
+ """Module containing pytorch implementations for obtaining predictions from an LVAE."""
2
+
3
+ from typing import Any
4
+
5
+ import torch
6
+
7
+ from careamics.models.lvae import LadderVAE as LVAE
8
+ from careamics.models.lvae.likelihoods import LikelihoodModule
9
+
10
+ # TODO: convert these functions to lightning module `predict_step`
11
+ # -> mmse_count will have to be an instance attribute?
12
+
13
+
14
+ # The dataset output (the input to the batch functions below) can include auxiliary
15
+ # items, such as the TileInformation. This function only takes the model input, so it
16
+ # can be reused by both lvae_predict_tiled_batch and lvae_predict_mmse_tiled_batch.
17
def lvae_predict_single_sample(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Run a single forward pass of an LVAE model on a bare model input.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model. It is put into evaluation mode by this function and
        left in that mode afterwards.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class, used to split the network output into a
        sample prediction and a log-variance.
    input : torch.Tensor
        Model input to generate a prediction for. Expected shape is (S, C, Y, X).

    Returns
    -------
    tuple of (torch.Tensor, optional torch.Tensor)
        The sample prediction followed by the log-variance. The log-variance will
        be None if `model.predict_logvar is None`.
    """
    # Evaluation mode affects batch-norm and dropout layers (this call was not part
    # of the original prediction code).
    model.eval()
    with torch.no_grad():
        # The model returns a pair; the second item (top-down data dict) is unused.
        network_output: torch.Tensor
        network_output, _top_down_data = model(input)

    # `get_mean_lv` presently splits the output in two when `predict_logvar=True`
    # and optionally clips the log-variance when `logvar_lowerbound` is not None.
    # TODO: consider refactoring to remove the use of the likelihood object.
    # TODO: output denormalization using target stats that will be saved in the
    # data config -> probably not needed.
    prediction, log_variance = likelihood_obj.get_mean_lv(network_output)
    return prediction, log_variance
54
+
55
+
56
def lvae_predict_tiled_batch(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: torch.Tensor | tuple[Any, ...],
) -> tuple[tuple[Any, ...], tuple[Any, ...] | None]:
    """
    Generate a single sample prediction for a batch that may carry auxiliary data.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class.
    input : torch.Tensor | tuple of (torch.Tensor, Any, ...)
        Input to generate prediction for. This can include auxiliary inputs such as
        `TileInformation`, but the model input is always the first item of the tuple.
        A bare tensor is treated as a model input with no auxiliary items.
        Expected shape of the model input is (S, C, Y, X).

    Returns
    -------
    tuple of ((torch.Tensor, Any, ...), optional tuple of (torch.Tensor, Any, ...))
        The first element is the sample prediction, and the second element is the
        log-variance. The log-variance will be None if `model.predict_logvar is None`.
        Any auxiliary data included in the input is appended after both the
        sample prediction and the log-variance.
    """
    x: torch.Tensor
    aux: list[Any]
    # A bare tensor must not be star-unpacked: that would iterate over its sample
    # dimension and silently treat every sample after the first as auxiliary data.
    if isinstance(input, torch.Tensor):
        x, aux = input, []
    else:
        x, *aux = input

    sample_prediction, log_var = lvae_predict_single_sample(
        model=model, likelihood_obj=likelihood_obj, input=x
    )

    log_var_output = (log_var, *aux) if log_var is not None else None
    return (sample_prediction, *aux), log_var_output
94
+
95
+
96
def lvae_predict_mmse_tiled_batch(
    model: LVAE,
    likelihood_obj: LikelihoodModule,
    input: torch.Tensor | tuple[Any, ...],
    mmse_count: int,
) -> tuple[tuple[Any, ...], tuple[Any, ...], tuple[Any, ...] | None]:
    """
    Generate the MMSE (minimum mean squared error) prediction, for a given input.

    This is calculated from the mean of multiple single sample predictions.

    Parameters
    ----------
    model : LVAE
        Trained LVAE model.
    likelihood_obj : LikelihoodModule
        Instance of a likelihood class.
    input : torch.Tensor | tuple of (torch.Tensor, Any, ...)
        Input to generate prediction for. This can include auxiliary inputs such as
        `TileInformation`, but the model input is always the first item of the tuple.
        A bare tensor is treated as a model input with no auxiliary items.
        Expected shape of the model input is (S, C, Y, X).
    mmse_count : int
        Number of samples to generate to calculate MMSE (minimum mean squared error).
        Must be greater than zero.

    Returns
    -------
    tuple of (tuple of (torch.Tensor, Any, ...))
        A tuple of 3 elements. The first element contains the MMSE prediction, the
        second contains the standard deviation of the samples used to create the MMSE
        prediction. Finally the last element contains the log-variance of the first
        sample prediction, this will be `None` if `likelihood.predict_logvar` is
        `None`. Any auxiliary data included in the input is appended to each of the
        MMSE prediction, the standard deviation, and the log-variance tuples. The
        MMSE prediction and standard deviation share the device and dtype of the
        model's sample predictions.

    Raises
    ------
    ValueError
        If `mmse_count` is not greater than zero.

    Notes
    -----
    The standard deviation uses `torch.std`'s default (unbiased) estimator, so with
    `mmse_count == 1` it is undefined and evaluates to NaN.
    """
    if mmse_count <= 0:
        raise ValueError("MMSE count must be greater than zero.")

    x: torch.Tensor
    aux: list[Any]
    # A bare tensor must not be star-unpacked: that would iterate over its sample
    # dimension and silently treat every sample after the first as auxiliary data.
    if isinstance(input, torch.Tensor):
        x, aux = input, []
    else:
        x, *aux = input

    # Only the log-variance of the first sample prediction is kept.
    first_prediction, log_var = lvae_predict_single_sample(
        model=model, likelihood_obj=likelihood_obj, input=x
    )
    # Allocate the sample buffer from the first prediction so it shares its shape,
    # dtype and device. A default `torch.zeros` buffer is float32 on the CPU, which
    # silently casts other dtypes and leaves the MMSE outputs on a different device
    # from `log_var` when the model runs on a GPU.
    sample_predictions = first_prediction.new_empty(
        (mmse_count, *first_prediction.shape)
    )
    sample_predictions[0] = first_prediction
    for mmse_idx in range(1, mmse_count):
        sample_prediction, _ = lvae_predict_single_sample(
            model=model, likelihood_obj=likelihood_obj, input=x
        )
        sample_predictions[mmse_idx] = sample_prediction

    mmse_prediction = torch.mean(sample_predictions, dim=0)
    mmse_prediction_std = torch.std(sample_predictions, dim=0)

    log_var_output = (log_var, *aux) if log_var is not None else None
    return (mmse_prediction, *aux), (mmse_prediction_std, *aux), log_var_output
@@ -0,0 +1,362 @@
1
+ """Module containing tiling manager class."""
2
+
3
+ # # TODO: remove this file, left as a reference for now.
4
+
5
+ # from typing import Any, Optional
6
+
7
+ # import numpy as np
8
+ # from numpy.typing import NDArray
9
+
10
+ # from careamics.config.tile_information import TileInformation
11
+ # from careamics.config.validators import check_axes_validity
12
+
13
+
14
+ # def calculate_padding(
15
+ # patch_start_location: NDArray,
16
+ # patch_size: NDArray,
17
+ # data_shape: NDArray,
18
+ # ) -> NDArray:
19
+ # patch_end_location = patch_start_location + patch_size
20
+
21
+ # pad_before = np.zeros_like(patch_start_location)
22
+ # start_out_of_bounds = patch_start_location < 0
23
+ # pad_before[start_out_of_bounds] = -patch_start_location[start_out_of_bounds]
24
+
25
+ # pad_after = np.zeros_like(patch_start_location)
26
+ # end_out_of_bounds = patch_end_location > data_shape
27
+ # pad_after[end_out_of_bounds] = (
28
+ # patch_end_location - data_shape
29
+ # )[end_out_of_bounds]
30
+
31
+ # return np.stack([pad_before, pad_after], axis=1)
32
+
33
+
34
+ # def extract_tile(
35
+ # img: np.ndarray,
36
+ # grid_start_loc: tuple[int, ...],
37
+ # patch_size: tuple[int, ...],
38
+ # overlap: tuple[int, ...],
39
+ # padding: bool,
40
+ # padding_kwargs: Optional[dict[str, Any]] = None,
41
+ # ) -> NDArray:
42
+ # if padding_kwargs is None:
43
+ # padding_kwargs = {}
44
+
45
+ # data_shape = img.shape
46
+ # patch_start_loc = np.array(grid_start_loc) - np.array(overlap) // 2
47
+ # crop_slices = tuple(
48
+ # slice(max(0, start), min(start + size, dim_shape))
49
+ # for start, size, dim_shape in zip(patch_start_loc, patch_size, data_shape)
50
+ # )
51
+ # crop = img[crop_slices]
52
+ # if padding:
53
+ # pad = calculate_padding(
54
+ # patch_start_location=patch_start_loc,
55
+ # patch_size=patch_size,
56
+ # data_shape=data_shape,
57
+ # )
58
+ # crop = np.pad(crop, pad, **padding_kwargs)
59
+
60
+ # return crop
61
+
62
+
63
+ # class TilingManager:
64
+
65
+ # def __init__(
66
+ # self,
67
+ # data_shape: tuple[int, ...],
68
+ # tile_size: tuple[int, ...],
69
+ # overlaps: tuple[int, ...],
70
+ # trim_boundary: tuple[int, ...],
71
+ # ):
72
+ # # --- validation
73
+ # if len(data_shape) != len(tile_size):
74
+ # raise ValueError(
75
+ # f"Data shape:{data_shape} and tile size:{tile_size} must have the "
76
+ # "same dimension"
77
+ # )
78
+ # if len(data_shape) != len(overlaps):
79
+ # raise ValueError(
80
+ # f"Data shape:{data_shape} and tile overlaps:{overlaps} must have the "
81
+ # "same dimension"
82
+ # )
83
+ # # overlaps = np.array(tile_size) - np.array(grid_shape)
84
+ # if (np.array(overlaps) < 0).any():
85
+ # raise ValueError(
86
+ # "Tile overlap must be positive or zero in all dimension."
87
+ # )
88
+ # if ((np.array(overlaps) % 2) != 0).any():
89
+ # # TODO: currently not required by CAREamics tiling,
90
+ # # -> because floor divide is used.
91
+ # raise ValueError("Tile overlaps must be even.")
92
+
93
+ # # initialize attributes
94
+ # self.data_shape = data_shape
95
+ # self.overlaps = overlaps
96
+ # self.grid_shape = tuple(np.array(tile_size) - np.array(overlaps))
97
+ # self.patch_shape = tile_size
98
+ # self.trim_boundary = trim_boundary
99
+
100
+ # def compute_tile_info(self, index: int, axes: str):
101
+
102
+ # # TODO: better axis validation, data should already be in the form SC(Z)YX
103
+
104
+ # # validate axes
105
+ # check_axes_validity(axes)
106
+ # # z will be -1 if not present
107
+ # spatial_axes = [axes.find("Z"), axes.find("Y"), axes.find("X")]
108
+
109
+ # # convert to numpy for convenience
110
+ # data_shape = np.array(self.data_shape)
111
+ # patch_shape = np.array(self.patch_shape)
112
+
113
+ # # --- calculate stitch coords
114
+ # stitch_coords_start = np.array(self.get_location_from_dataset_idx(index))
115
+ # stitch_coords_end = stitch_coords_start + np.array(self.grid_shape)
116
+
117
+ # # --- patch coords
118
+ # patch_coords_start = stitch_coords_start - np.array(self.overlaps) // 2
119
+ # patch_coords_end = patch_coords_start + patch_shape
120
+
121
+ # # --- replace out of bounds indices
122
+
123
+ # out_of_lower_bound = stitch_coords_start < 0
124
+ # out_of_upper_bound = stitch_coords_end > data_shape
125
+
126
+ # stitch_coords_start[out_of_lower_bound] = 0
127
+ # stitch_coords_end[out_of_upper_bound] = data_shape[out_of_upper_bound]
128
+
129
+ # # --- calculate overlap crop coords
130
+ # overlap_crop_coords_start = stitch_coords_start - patch_coords_start
131
+ # overlap_crop_coords_end = overlap_crop_coords_start + (
132
+ # stitch_coords_end - stitch_coords_start
133
+ # )
134
+
135
+ # # --- combine start and end
136
+ # stitch_coords = tuple(
137
+ # (stitch_coords_start[axis], stitch_coords_end[axis])
138
+ # for axis in spatial_axes
139
+ # if axis != -1
140
+ # )
141
+ # overlap_crop_coords = tuple(
142
+ # (overlap_crop_coords_start[axis], overlap_crop_coords_end[axis])
143
+ # for axis in spatial_axes
144
+ # if axis != -1
145
+ # )
146
+
147
+ # channel_axis = axes.find("C")
148
+ # array_shape_processed = tuple(
149
+ # data_shape[axis] for axis in [channel_axis, *spatial_axes] if axis != -1
150
+ # )
151
+
152
+ # tile_info = TileInformation(
153
+ # array_shape=array_shape_processed,
154
+ # last_tile=index == self.total_grid_count() - 1,
155
+ # overlap_crop_coords=overlap_crop_coords,
156
+ # stitch_coords=stitch_coords,
157
+ # sample_id=0, # TODO: in iterable dataset this is also always 0 pretty sure
158
+ # )
159
+ # return tile_info
160
+
161
+ # def patch_offset(self):
162
+ # return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2
163
+
164
+ # def get_individual_dim_grid_count(self, dim: int):
165
+ # """
166
+ # Returns the number of the grid in the specified dimension, ignoring all other
167
+ # dimensions.
168
+ # """
169
+ # assert dim < len(
170
+ # self.data_shape
171
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
172
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
173
+
174
+ # if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
175
+ # return self.data_shape[dim]
176
+ # elif self.trim_boundary is False:
177
+ # return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
178
+ # else:
179
+ # excess_size = self.patch_shape[dim] - self.grid_shape[dim]
180
+ # return int(
181
+ # np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
182
+ # )
183
+
184
+ # def total_grid_count(self):
185
+ # """
186
+ # Returns the total number of grids in the dataset.
187
+ # """
188
+ # return self.grid_count(0) * self.get_individual_dim_grid_count(0)
189
+
190
+ # def grid_count(self, dim: int):
191
+ # """
192
+ # Returns the total number of grids for one value in the specified dimension.
193
+ # """
194
+ # assert dim < len(
195
+ # self.data_shape
196
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
197
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
198
+ # if dim == len(self.data_shape) - 1:
199
+ # return 1
200
+
201
+ # return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)
202
+
203
+ # def get_grid_index(self, dim: int, coordinate: int):
204
+ # """
205
+ # Returns the index of the grid in the specified dimension.
206
+ # """
207
+ # assert dim < len(
208
+ # self.data_shape
209
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
210
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
211
+ # assert (
212
+ # coordinate < self.data_shape[dim]
213
+ # ), (
214
+ # f"Coordinate {coordinate} is out of bounds for data "
215
+ # f"shape {self.data_shape}"
216
+ # )
217
+ # if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
218
+ # return coordinate
219
+ # elif self.trim_boundary is False:
220
+ # return np.floor(coordinate / self.grid_shape[dim])
221
+ # else:
222
+ # excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
223
+ # # can be <0 if coordinate is in [0,grid_shape[dim]]
224
+ # return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
225
+
226
+ # def dataset_idx_from_grid_idx(self, grid_idx: tuple):
227
+ # """
228
+ # Returns the index of the grid in the dataset.
229
+ # """
230
+ # assert len(grid_idx) == len(
231
+ # self.data_shape
232
+ # ), (
233
+ # f"Dimension indices {grid_idx} must have the same dimension as data "
234
+ # f"shape {self.data_shape}"
235
+ # )
236
+ # index = 0
237
+ # for dim in range(len(grid_idx)):
238
+ # index += grid_idx[dim] * self.grid_count(dim)
239
+ # return index
240
+
241
+ # def get_patch_location_from_dataset_idx(self, dataset_idx: int):
242
+ # """
243
+ # Returns the patch location of the grid in the dataset.
244
+ # """
245
+ # location = self.get_location_from_dataset_idx(dataset_idx)
246
+ # offset = self.patch_offset()
247
+ # return tuple(np.array(location) - np.array(offset))
248
+
249
+ # def get_dataset_idx_from_grid_location(self, location: tuple):
250
+ # assert len(location) == len(
251
+ # self.data_shape
252
+ # ), (
253
+ # f"Location {location} must have the same dimension as data shape "
254
+ # f"{self.data_shape}"
255
+ # )
256
+ # grid_idx = [
257
+ # self.get_grid_index(dim, location[dim]) for dim in range(len(location))
258
+ # ]
259
+ # return self.dataset_idx_from_grid_idx(tuple(grid_idx))
260
+
261
+ # def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
262
+ # """
263
+ # Returns the grid-start coordinate of the grid in the specified dimension.
264
+ # """
265
+ # assert dim < len(
266
+ # self.data_shape
267
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
268
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
269
+ # assert dim_index < self.get_individual_dim_grid_count(
270
+ # dim
271
+ # ), (
272
+ # f"Dimension index {dim_index} is out of bounds for data shape "
273
+ # f"{self.data_shape}"
274
+ # )
275
+
276
+ # if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
277
+ # return dim_index
278
+ # elif self.trim_boundary is False:
279
+ # return dim_index * self.grid_shape[dim]
280
+ # else:
281
+ # excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
282
+ # return dim_index * self.grid_shape[dim] + excess_size
283
+
284
+ # def get_location_from_dataset_idx(self, dataset_idx: int):
285
+ # grid_idx = []
286
+ # for dim in range(len(self.data_shape)):
287
+ # grid_idx.append(dataset_idx // self.grid_count(dim))
288
+ # dataset_idx = dataset_idx % self.grid_count(dim)
289
+ # location = [
290
+ # self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
291
+ # for dim in range(len(self.data_shape))
292
+ # ]
293
+ # return tuple(location)
294
+
295
+ # def on_boundary(self, dataset_idx: int, dim: int):
296
+ # """
297
+ # Returns True if the grid is on the boundary in the specified dimension.
298
+ # """
299
+ # assert dim < len(
300
+ # self.data_shape
301
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
302
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
303
+
304
+ # if dim > 0:
305
+ # dataset_idx = dataset_idx % self.grid_count(dim - 1)
306
+
307
+ # dim_index = dataset_idx // self.grid_count(dim)
308
+ # return (
309
+ # dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
310
+ # )
311
+
312
+ # def next_grid_along_dim(self, dataset_idx: int, dim: int):
313
+ # """
314
+ # Returns the dataset index of the next grid along the specified dimension,
315
+ # or None if that index would be past the last grid in the dataset.
316
+ # """
317
+ # assert dim < len(
318
+ # self.data_shape
319
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
320
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
321
+ # new_idx = dataset_idx + self.grid_count(dim)
322
+ # if new_idx >= self.total_grid_count():
323
+ # return None
324
+ # return new_idx
325
+
326
+ # def prev_grid_along_dim(self, dataset_idx: int, dim: int):
327
+ # """
328
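+ # Intended to return the dataset index of the previous grid along the specified
329
+ # dimension, or None if that index would be negative. NOTE(review): the body lacks a final `return new_idx`, so it currently always returns None.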
330
+ # """
331
+ # assert dim < len(
332
+ # self.data_shape
333
+ # ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
334
+ # assert dim >= 0, "Dimension must be greater than or equal to 0"
335
+ # new_idx = dataset_idx - self.grid_count(dim)
336
+ # if new_idx < 0:
337
+ # return None
338
+
339
+
340
+ # if __name__ == "__main__":
341
+ # data_shape = (1, 1, 103, 103, 2)
342
+ # grid_shape = (1, 1, 16, 16, 2)
343
+ # patch_shape = (1, 1, 32, 32, 2)
344
+ # overlap = tuple(np.array(patch_shape) - np.array(grid_shape))
345
+
346
+ # trim_boundary = False
347
+ # manager = TilingManager(
348
+ # data_shape=data_shape,
349
+ # tile_size=patch_shape,
350
+ # overlaps=overlap,
351
+ # trim_boundary=trim_boundary,
352
+ # )
353
+ # gc = manager.total_grid_count()
354
+ # print("Grid count", gc)
355
+ # for i in range(gc):
356
+ # loc = manager.get_location_from_dataset_idx(i)
357
+ # print(i, loc)
358
+ # inferred_i = manager.get_dataset_idx_from_grid_location(loc)
359
+ # assert i == inferred_i, f"Index mismatch: {i} != {inferred_i}"
360
+
361
+ # for i in range(5):
362
+ # print(manager.on_boundary(40, i))