careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,238 @@
1
+ """Module containing functions to convert prediction outputs to desired form."""
2
+
3
+ from typing import Any, Literal, Union, overload
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+
8
+ from ..config.data.tile_information import TileInformation
9
+ from .stitch_prediction import stitch_prediction, stitch_prediction_vae
10
+
11
+
12
def convert_outputs(predictions: list[Any], tiled: bool) -> list[NDArray]:
    """
    Convert the Lightning trainer outputs to the desired form.

    Batched outputs are split into individual samples and, when `tiled` is True,
    the tiles are stitched back together into full images.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`. For tiled predictions
        each element is a tuple of (batch of tiles, list of TileInformation);
        otherwise each element is a batch array of shape (B, C, (Z), Y, X).
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    list of numpy.ndarray
        List of arrays with the axes SC(Z)YX. The result is always a list, even
        when it holds a single array, and it is empty when `predictions` is empty.
    """
    if len(predictions) == 0:
        # Return a fresh list rather than the caller's (empty) input list, so the
        # result never aliases the argument.
        return []

    # this layout is to stop mypy complaining
    if tiled:
        predictions_comb = combine_batches(predictions, tiled)
        predictions_output = stitch_prediction(*predictions_comb)
    else:
        predictions_output = combine_batches(predictions, tiled)

    return predictions_output
42
+
43
+
44
def convert_outputs_pn2v(
    predictions: list[Any], tiled: bool
) -> tuple[list[NDArray], list[NDArray]]:
    """
    Convert PN2V Lightning trainer outputs into per-image prediction and MMSE arrays.

    Tiled outputs are split into two streams (predictions and MMSE), and each
    stream is stitched back into full images independently.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`, one element per batch.
        When tiled, each element is ``((prediction, mse), tile_info_list)``; the
        first dimension of each array is the batch size and ``tile_info_list``
        has one entry per sample in the batch. When not tiled, each element is
        ``(prediction, mse)``.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    tuple[list[NDArray], list[NDArray]]
        Tuple of (predictions, mmse) where each is a list of arrays with axes
        SC(Z)YX. Both lists are empty when `predictions` is empty.
    """
    if not predictions:
        return [], []

    # TODO test with multi_channel predictions
    if not tiled:
        # Transpose [(pred, mse), ...] into (pred, ...) and (mse, ...).
        pred_batches, mse_batches = zip(*predictions, strict=False)
        return (
            combine_batches(list(pred_batches), tiled=False),
            combine_batches(list(mse_batches), tiled=False),
        )

    # Pair each stream with the tile information of its batch.
    pred_tiled = []
    mse_tiled = []
    for (batch_pred, batch_mse), batch_tile_infos in predictions:
        pred_tiled.append((batch_pred, batch_tile_infos))
        mse_tiled.append((batch_mse, batch_tile_infos))

    stitched_pred = stitch_prediction(*combine_batches(pred_tiled, tiled))
    stitched_mse = stitch_prediction(*combine_batches(mse_tiled, tiled))
    return stitched_pred, stitched_mse
99
+
100
+
101
def convert_outputs_microsplit(
    predictions: list[tuple[NDArray, NDArray]], dataset
) -> tuple[NDArray, NDArray]:
    """
    Stitch MicroSplit tile predictions and standard deviations into full arrays.

    Every batch yields a ``(tile_prediction, tile_std)`` pair. All batches are
    concatenated along the first (tile) axis, following the same logic as
    ``get_single_file_mmse``. Each resulting stack is then stitched with
    `stitch_prediction_vae` using the dataset's tiling information.

    Parameters
    ----------
    predictions : list of tuple[NDArray, NDArray]
        Predictions from the Lightning trainer for MicroSplit, one
        ``(tile_prediction, tile_std)`` pair of numpy arrays per batch.
    dataset : Dataset
        The dataset used for prediction; it provides the tiling information
        needed for stitching.

    Returns
    -------
    tuple[NDArray, NDArray]
        ``(stitched_predictions, stitched_stds)``.

    Raises
    ------
    ValueError
        If `predictions` is empty.
    """
    if len(predictions) == 0:
        raise ValueError("No predictions provided")

    pred_batches: list[NDArray] = []
    std_batches: list[NDArray] = []
    for batch_pred, batch_std in predictions:
        pred_batches.append(batch_pred)
        std_batches.append(batch_std)

    # Merge all batches into a single stack of tiles per output.
    all_pred_tiles = np.concatenate(pred_batches, axis=0)
    all_std_tiles = np.concatenate(std_batches, axis=0)

    return (
        stitch_prediction_vae(all_pred_tiles, dataset),
        stitch_prediction_vae(all_std_tiles, dataset),
    )
142
+
143
+
144
# for mypy
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: list[Any], tiled: Literal[True]
) -> tuple[list[NDArray], list[TileInformation]]: ...


# for mypy
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: list[Any], tiled: Literal[False]
) -> list[NDArray]: ...


# for mypy: fallback for callers passing a plain (non-literal) bool
@overload
def combine_batches(  # numpydoc ignore=GL08
    predictions: list[Any], tiled: bool
) -> Union[list[NDArray], tuple[list[NDArray], list[TileInformation]]]: ...


def combine_batches(
    predictions: list[Any], tiled: bool
) -> Union[list[NDArray], tuple[list[NDArray], list[TileInformation]]]:
    """
    Flatten batched `Trainer.predict` outputs into per-sample arrays.

    Every batch array of shape (B, C, (Z), Y, X) is split into B arrays of shape
    (1, C, (Z), Y, X). For tiled outputs, the per-batch lists of TileInformation
    are also flattened into a single list, aligned with the returned arrays.

    Parameters
    ----------
    predictions : list
        Predictions that are output from `Trainer.predict`. When `tiled` is True,
        each element is a tuple of (batch array, list of TileInformation);
        otherwise each element is a batch array.
    tiled : bool
        Whether the predictions are tiled.

    Returns
    -------
    (list of numpy.ndarray) or tuple of (list of numpy.ndarray, list of TileInformation)
        The per-sample arrays, plus the flattened tile information when `tiled`
        is True.
    """
    if tiled:
        return _combine_tiled_batches(predictions)
    return _combine_array_batches(predictions)
189
+
190
+
191
def _combine_tiled_batches(
    predictions: list[tuple[NDArray, list[TileInformation]]],
) -> tuple[list[NDArray], list[TileInformation]]:
    """
    Split tiled prediction batches into individual tiles with their tile info.

    Parameters
    ----------
    predictions : list of (numpy.ndarray, list of TileInformation)
        Predictions that are output from `Trainer.predict`. Each element pairs a
        batch of tiles of shape (B, C, (Z), Y, X) with a list of B
        TileInformation objects.

    Returns
    -------
    tuple of (list of numpy.ndarray, list of TileInformation)
        One (1, C, (Z), Y, X) array per tile, and the matching flat list of
        TileInformation.
    """
    # Flatten the per-batch tile-info lists (last element of each batch entry).
    flat_tile_infos = []
    for *_, batch_tile_infos in predictions:
        flat_tile_infos.extend(batch_tile_infos)

    # The tile array is the first element of each batch entry.
    batch_arrays = [batch_tiles for batch_tiles, *_ in predictions]
    tiles: list[NDArray] = _combine_array_batches(batch_arrays)

    return tiles, flat_tile_infos
219
+
220
+
221
def _combine_array_batches(predictions: list[NDArray]) -> list[NDArray]:
    """
    Split batches of arrays into one array per sample.

    Parameters
    ----------
    predictions : list
        Prediction arrays that are output from `Trainer.predict`. A list of arrays
        that have dimensions (B, C, (Z), Y, X), where B is batch size.

    Returns
    -------
    list of numpy.ndarray
        A list of arrays with dimensions (1, C, (Z), Y, X), in input order. The
        list is empty when there are no batches or all batches are empty.
    """
    # np.concatenate raises on an empty sequence, so short-circuit it.
    if len(predictions) == 0:
        return []

    prediction_concat: NDArray = np.concatenate(predictions, axis=0)
    n_samples = prediction_concat.shape[0]
    # np.split refuses to split into 0 sections.
    if n_samples == 0:
        return []

    return np.split(prediction_concat, n_samples, axis=0)
@@ -0,0 +1,193 @@
1
+ """Prediction utility functions."""
2
+
3
+ import builtins
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ from numpy.typing import NDArray
8
+
9
+ from careamics.config.data.tile_information import TileInformation
10
+
11
+
12
class TilingMode:
    """Integer constants naming the tiling mode.

    These are plain ``int`` class attributes, not ``enum.Enum`` members; they are
    compared with ``==`` against the index manager's ``tiling_mode``.
    """

    # 0: TrimBoundary, 1: PadBoundary, 2: ShiftBoundary
    TrimBoundary, PadBoundary, ShiftBoundary = range(3)
+
19
+
20
def stitch_prediction_vae(predictions, dset):
    """Stitch VAE tile predictions back together using the dataset's index manager.

    Tile ``predictions[i]`` is placed at the grid location returned by
    ``dset.idx_manager.get_location_from_dataset_idx(i)``. Only the valid part of
    each tile's grid region is written, clipped to ``idx_manager.data_shape``.
    The tile itself (the patch) starts ``idx_manager.patch_offset()`` before the
    grid start. In ``TilingMode.ShiftBoundary`` mode, a patch that touches the
    image border is written all the way up to that border.

    Parameters
    ----------
    predictions : numpy.ndarray
        Array of predictions with shape (n_tiles, channels, height, width). In
        the 5D-output branch, each tile is indexed with three spatial axes after
        the channel axis.
    dset : Dataset
        Dataset object exposing ``idx_manager`` (tiling information) and
        ``get_data_shape()``.

    Returns
    -------
    numpy.ndarray
        Stitched predictions with the shape of ``dset.get_data_shape()``. The
        last (channel) axis is widened to ``predictions.shape[1]`` when the
        predictions have more channels than the data. Locations not written by
        any tile remain zero.

    Raises
    ------
    ValueError
        If the output array is neither 4D nor 5D.
    AssertionError
        In the 5D case, if a tile's valid grid region spans more than one frame.
    """
    mng = dset.idx_manager

    # if there are more channels, use all of them.
    shape = list(dset.get_data_shape())
    shape[-1] = max(shape[-1], predictions.shape[1])

    output = np.zeros(shape, dtype=predictions.dtype)
    for dset_idx in range(predictions.shape[0]):
        # grid start, grid end (absolute coordinates in the data)
        gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
        ge = gs + mng.grid_shape

        # patch start, patch end: the tile extends patch_offset() before the grid
        ps = gs - mng.patch_offset()
        pe = ps + mng.patch_shape

        # valid grid start, valid grid end: grid clipped to [0, data_shape).
        # NOTE(review): zip(..., strict=False) truncates to the shorter sequence;
        # presumably mng.data_shape may carry more axes than gs — confirm.
        vgs = np.array([max(0, x) for x in gs], dtype=int)
        vge = np.array(
            [min(x, y) for x, y in zip(ge, mng.data_shape, strict=False)], dtype=int
        )

        # In shift mode, a patch that touches the image border is written up to
        # that border, not just its grid region.
        if mng.tiling_mode == TilingMode.ShiftBoundary:
            for dim in range(len(vgs)):
                if ps[dim] == 0:
                    vgs[dim] = 0
                if pe[dim] == mng.data_shape[dim]:
                    vge[dim] = mng.data_shape[dim]

        # relative start, relative end: the valid region expressed in the tile's
        # own (patch) coordinates. Note `re` is a local name here, not the module.
        rs = vgs - ps
        re = rs + (vge - vgs)

        for ch_idx in range(predictions.shape[1]):
            if len(output.shape) == 4:
                # channel dimension is the last one.
                # The first output axis is indexed by vgs[0]:vge[0]; rs[0] is not
                # used because the tile has no corresponding axis.
                output[vgs[0] : vge[0], vgs[1] : vge[1], vgs[2] : vge[2], ch_idx] = (
                    predictions[dset_idx][ch_idx, rs[1] : re[1], rs[2] : re[2]]
                )
            elif len(output.shape) == 5:
                # channel dimension is the last one.
                # The first output axis is indexed with a single position, so the
                # tile must cover exactly one entry along it.
                assert vge[0] - vgs[0] == 1, "Only one frame is supported"
                output[
                    vgs[0], vgs[1] : vge[1], vgs[2] : vge[2], vgs[3] : vge[3], ch_idx
                ] = predictions[dset_idx][
                    ch_idx, rs[1] : re[1], rs[2] : re[2], rs[3] : re[3]
                ]
            else:
                raise ValueError(f"Unsupported shape {output.shape}")

    return output
89
+
90
+
91
# TODO: why not allow input and output of torch.tensor ?
def stitch_prediction(
    tiles: list[np.ndarray],
    tile_infos: list[TileInformation],
) -> list[np.ndarray]:
    """
    Stitch a flat sequence of tiles, possibly from several images, into images.

    Tiles are grouped into images using the ``last_tile`` flag of their
    TileInformation. A group runs from the tile after the previous flagged tile
    up to and including the next flagged tile. Each group is stitched with
    `stitch_prediction_single`. Tiles after the final flagged tile belong to no
    group and are not stitched.

    Tiles are of dimensions SC(Z)YX, where C is the number of channels and can be
    a singleton dimension.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Cropped tiles and their respective stitching coordinates. Can contain
        tiles from multiple images.
    tile_infos : list of TileInformation
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    list of numpy.ndarray
        Full image(s), one per flagged last tile.
    """
    # Exclusive end index of every image: one past each tile flagged as last.
    image_ends = [
        position + 1 for position, info in enumerate(tile_infos) if info.last_tile
    ]
    # Each image starts where the previous one ended (the first starts at 0).
    image_starts = [0, *image_ends[:-1]]

    return [
        stitch_prediction_single(tiles[start:stop], tile_infos[start:stop])
        for start, stop in zip(image_starts, image_ends, strict=False)
    ]
133
+
134
+
135
def stitch_prediction_single(
    tiles: list[NDArray],
    tile_infos: list[TileInformation],
) -> NDArray:
    """
    Assemble the tiles of a single image into the full image.

    Each tile is cropped with its ``overlap_crop_coords`` and written into the
    output at its ``stitch_coords``. Tiles and output have dimensions SC(Z)YX,
    where C is the number of channels and can be a singleton dimension.

    The number of output channels is read from the first tile, not from
    ``TileInformation.array_shape``. This lets the prediction have a different
    number of channels than the input that was tiled.

    Parameters
    ----------
    tiles : list of numpy.ndarray
        Predicted tiles of one image.
    tile_infos : list of TileInformation
        Tiling information matching `tiles` element by element, obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    numpy.ndarray
        Full image as float32, with dimensions SC(Z)YX where S is 1.

    Raises
    ------
    ValueError
        If ``array_shape`` of the first TileInformation has neither 3 nor 4
        dimensions.
    """
    # TODO: this is hacky... need a better way to deal with when input channels and
    # target channels do not match
    reference_info = tile_infos[0]
    n_array_dims = len(reference_info.array_shape)
    # array_shape is C(Z)YX, so its length is also the position of the channel
    # axis counted from the end of an SC(Z)YX tile.
    if n_array_dims not in (3, 4):
        # Note pretty sure this is unreachable because array shape is already
        # validated by TileInformation
        raise ValueError(f"Unsupported number of output dimension {n_array_dims}")
    n_channels = tiles[0].shape[-n_array_dims]

    # Whole image: singleton S axis, the tiles' channel count, then the spatial
    # dimensions recorded in the tile information.
    stitched = np.zeros(
        (1, n_channels, *reference_info.array_shape[1:]), dtype=np.float32
    )

    for tile, info in zip(tiles, tile_infos, strict=False):
        # Drop the overlap region of the tile.
        crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
            ...,
            *[slice(c[0], c[1]) for c in info.overlap_crop_coords],
        )
        cropped = tile[crop_slices]

        # Place the cropped tile at its location in the full image.
        target_slices = (..., *[slice(c[0], c[1]) for c in info.stitch_coords])

        # TODO fix mypy error here, potentially due to numpy 2
        stitched[target_slices] = cropped.astype(np.float32)  # type: ignore

    return stitched
careamics/py.typed ADDED
@@ -0,0 +1,5 @@
1
+ You may remove this file if you don't intend to add types to your package
2
+
3
+ Details at:
4
+
5
+ https://mypy.readthedocs.io/en/stable/installed_packages.html#creating-pep-561-compatible-packages
@@ -0,0 +1,22 @@
1
+ """Transforms that are used to augment the data."""
2
+
3
+ __all__ = [
4
+ "Compose",
5
+ "Denormalize",
6
+ "ImageRestorationTTA",
7
+ "N2VManipulate",
8
+ "N2VManipulateTorch",
9
+ "Normalize",
10
+ "TrainDenormalize",
11
+ "XYFlip",
12
+ "XYRandomRotate90",
13
+ "get_all_transforms",
14
+ ]
15
+
16
+ from .compose import Compose, get_all_transforms
17
+ from .n2v_manipulate import N2VManipulate
18
+ from .n2v_manipulate_torch import N2VManipulateTorch
19
+ from .normalize import Denormalize, Normalize, TrainDenormalize
20
+ from .tta import ImageRestorationTTA
21
+ from .xy_flip import XYFlip
22
+ from .xy_random_rotate90 import XYRandomRotate90
@@ -0,0 +1,173 @@
1
+ """A class chaining transforms together."""
2
+
3
+ from typing import Union, cast
4
+
5
+ from numpy.typing import NDArray
6
+
7
+ from careamics.config.transformations import NORM_AND_SPATIAL_UNION
8
+
9
+ from .normalize import Normalize
10
+ from .transform import Transform
11
+ from .xy_flip import XYFlip
12
+ from .xy_random_rotate90 import XYRandomRotate90
13
+
14
+ ALL_TRANSFORMS = {
15
+ "Normalize": Normalize,
16
+ "XYFlip": XYFlip,
17
+ "XYRandomRotate90": XYRandomRotate90,
18
+ }
19
+
20
+
21
+ def get_all_transforms() -> dict[str, type]:
22
+ """Return all the transforms accepted by CAREamics.
23
+
24
+ Returns
25
+ -------
26
+ dict
27
+ A dictionary with all the transforms accepted by CAREamics, where the keys are
28
+ the transform names and the values are the transform classes.
29
+ """
30
+ return ALL_TRANSFORMS
31
+
32
+
33
+ class Compose:
34
+ """A class chaining transforms together.
35
+
36
+ Parameters
37
+ ----------
38
+ transform_list : list[TransformConfig]
39
+ A list of dictionaries where each dictionary contains the name of a
40
+ transform and its parameters.
41
+
42
+ Attributes
43
+ ----------
44
+ _callable_transforms : Callable
45
+ A callable that applies the transforms to the input data.
46
+ """
47
+
48
+ def __init__(self, transform_list: list[NORM_AND_SPATIAL_UNION]) -> None:
49
+ """Instantiate a Compose object.
50
+
51
+ Parameters
52
+ ----------
53
+ transform_list : list[NORM_AND_SPATIAL_UNION]
54
+ A list of dictionaries where each dictionary contains the name of a
55
+ transform and its parameters.
56
+ """
57
+ # retrieve all available transforms
58
+ # TODO: correctly type hint get_all_transforms function output
59
+ all_transforms: dict[str, type[Transform]] = get_all_transforms()
60
+
61
+ # instantiate all transforms
62
+ self.transforms: list[Transform] = [
63
+ all_transforms[t.name](**t.model_dump()) for t in transform_list
64
+ ]
65
+
66
+ def _chain_transforms(
67
+ self, patch: NDArray, target: NDArray | None
68
+ ) -> tuple[NDArray | None, ...]:
69
+ """Chain transforms on the input data.
70
+
71
+ Parameters
72
+ ----------
73
+ patch : np.ndarray
74
+ Input data.
75
+ target : Optional[np.ndarray]
76
+ Target data, by default None.
77
+
78
+ Returns
79
+ -------
80
+ tuple[np.ndarray, Optional[np.ndarray]]
81
+ The output of the transformations.
82
+ """
83
+ params: Union[tuple[NDArray, NDArray | None],] = (patch, target)
84
+
85
+ for t in self.transforms:
86
+ *params, _ = t(*params) # ignore additional_arrays dict
87
+
88
+ # avoid None values that create problems for collating
89
+ # TODO: removing None should be handled in dataset, not here
90
+ return tuple(p for p in params if p is not None)
91
+
92
+ def _chain_transforms_additional_arrays(
93
+ self,
94
+ patch: NDArray,
95
+ target: NDArray | None,
96
+ **additional_arrays: NDArray,
97
+ ) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
98
+ """Chain transforms on the input data, with additional arrays.
99
+
100
+ Parameters
101
+ ----------
102
+ patch : np.ndarray
103
+ Input data.
104
+ target : Optional[np.ndarray]
105
+ Target data, by default None.
106
+ **additional_arrays : NDArray
107
+ Additional arrays that will be transformed identically to `patch` and
108
+ `target`.
109
+
110
+ Returns
111
+ -------
112
+ tuple[np.ndarray, Optional[np.ndarray]]
113
+ The output of the transformations.
114
+ """
115
+ params = {"patch": patch, "target": target, **additional_arrays}
116
+
117
+ for t in self.transforms:
118
+ patch, target, additional_arrays = t(**params)
119
+ params = {"patch": patch, "target": target, **additional_arrays}
120
+
121
+ return patch, target, additional_arrays
122
+
123
+ def __call__(
124
+ self, patch: NDArray, target: NDArray | None = None
125
+ ) -> tuple[NDArray, ...]:
126
+ """Apply the transforms to the input data.
127
+
128
+ Parameters
129
+ ----------
130
+ patch : np.ndarray
131
+ The input data.
132
+ target : Optional[np.ndarray], optional
133
+ Target data, by default None.
134
+
135
+ Returns
136
+ -------
137
+ tuple[np.ndarray, ...]
138
+ The output of the transformations.
139
+ """
140
+ # TODO: solve casting Compose.__call__ ouput
141
+ return cast(tuple[NDArray, ...], self._chain_transforms(patch, target))
142
+
143
+ def transform_with_additional_arrays(
144
+ self,
145
+ patch: NDArray,
146
+ target: NDArray | None = None,
147
+ **additional_arrays: NDArray,
148
+ ) -> tuple[NDArray, NDArray | None, dict[str, NDArray]]:
149
+ """Apply the transforms to the input data, including additional arrays.
150
+
151
+ Parameters
152
+ ----------
153
+ patch : np.ndarray
154
+ The input data.
155
+ target : Optional[np.ndarray], optional
156
+ Target data, by default None.
157
+ **additional_arrays : NDArray
158
+ Additional arrays that will be transformed identically to `patch` and
159
+ `target`.
160
+
161
+ Returns
162
+ -------
163
+ NDArray
164
+ The transformed patch.
165
+ NDArray | None
166
+ The transformed target.
167
+ dict of {str, NDArray}
168
+ Transformed additional arrays. Keys correspond to the keyword argument
169
+ names.
170
+ """
171
+ return self._chain_transforms_additional_arrays(
172
+ patch, target, **additional_arrays
173
+ )