careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,987 @@
+ """
+ This script provides methods to evaluate the performance of the LVAE model.
+ It includes functions to:
+ - make predictions,
+ - quantify the performance of the model,
+ - create plots to visualize the results.
+ """
+
+ import os
+ from typing import Optional
+
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from matplotlib.gridspec import GridSpec
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm import tqdm
+
+ from careamics.lightning import VAEModule
+ from careamics.lvae_training.dataset import MultiChDloaderRef
+ from careamics.utils.metrics import scale_invariant_psnr
+
+
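As an orientation for the definitions below, a minimal sketch of how this module is typically driven (illustrative only, not part of the package; `checkpoint_path`, `val_dset`, and `gt` are placeholder names, and the model is loaded through the standard Lightning checkpoint API):

from careamics.lightning import VAEModule
from careamics.utils.metrics import scale_invariant_psnr

model = VAEModule.load_from_checkpoint(checkpoint_path)  # trained LVAE Lightning module
preds, stds = get_predictions(model=model, dset=val_dset, batch_size=8, mmse_count=10)
for fname, pred in preds.items():
    # compare the stitched prediction against a clean reference, if one is available
    print(fname, scale_invariant_psnr(gt[fname], pred))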
+ class TilingMode:
+     """
+     Enum-like constants for the tiling mode.
+     """
+
+     TrimBoundary = 0
+     PadBoundary = 1
+     ShiftBoundary = 2
+
+
+ # ------------------------------------------------------------------------------------
+ # Plotting functions: TODO -> move them to another file, plot_utils.py
+ def clean_ax(ax):
+     """
+     Helper function to remove ticks from axes in plots.
+     """
+     # 2D or 1D axes are of type np.ndarray
+     if isinstance(ax, np.ndarray):
+         for one_ax in ax:
+             clean_ax(one_ax)
+         return
+
+     ax.set_yticklabels([])
+     ax.set_xticklabels([])
+     ax.tick_params(left=False, right=False, top=False, bottom=False)
+
+
+ def get_eval_output_dir(
+     saveplotsdir: str, patch_size: int, mmse_count: int = 50
+ ) -> str:
+     """
+     Given the path to a root directory to save plots, the patch size, and the MMSE
+     count, return the specific directory in which to save the plots.
+     """
+     eval_out_dir = os.path.join(
+         saveplotsdir, f"eval_outputs/patch_{patch_size}_mmse_{mmse_count}"
+     )
+     os.makedirs(eval_out_dir, exist_ok=True)
+     print(eval_out_dir)
+     return eval_out_dir
+
+
+ def get_psnr_str(tar_hsnr, pred, col_idx):
+     """
+     Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
+     """
+     psnr = scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item()
+
+     return f"{psnr:.1f}"
+
+
+ def add_psnr_str(ax_, psnr):
+     """
+     Add the PSNR string to the axes.
+     """
+     textstr = f"PSNR\n{psnr}"
+     props = dict(boxstyle="round", facecolor="gray", alpha=0.5)
+     # place a text box in upper left in axes coords
+     ax_.text(
+         0.05,
+         0.95,
+         textstr,
+         transform=ax_.transAxes,
+         fontsize=11,
+         verticalalignment="top",
+         bbox=props,
+         color="white",
+     )
+
+
+ def get_last_index(bin_count, quantile):
+     cumsum = np.cumsum(bin_count)
+     normalized_cumsum = cumsum / cumsum[-1]
+     for i in range(1, len(normalized_cumsum)):
+         if normalized_cumsum[-i] < quantile:
+             return i - 1
+     return None
+
+
+ def get_first_index(bin_count, quantile):
+     cumsum = np.cumsum(bin_count)
+     normalized_cumsum = cumsum / cumsum[-1]
+     for i in range(len(normalized_cumsum)):
+         if normalized_cumsum[i] > quantile:
+             return i
+     return None
+
+
+ def get_device():
+     if torch.cuda.is_available():
+         return "cuda"
+     elif torch.backends.mps.is_available():
+         return "mps"
+     else:
+         return "cpu"
+
+
+ def show_for_one(
+     idx,
+     val_dset,
+     highsnr_val_dset,
+     model,
+     calibration_stats,
+     mmse_count=5,
+     patch_size=256,
+     num_samples=2,
+     baseline_preds=None,
+ ):
+     """
+     Given an index, plot the input, target, and reconstructed images together with the
+     difference image.
+     Note that the difference image is computed with respect to a ground truth image
+     obtained from the high SNR dataset.
+     """
+     highsnr_val_dset.set_img_sz(patch_size, 64)
+     highsnr_val_dset.disable_noise()
+     _, tar_hsnr = highsnr_val_dset[idx]
+     inp, tar, recon_img_list = get_predictions(
+         idx, val_dset, model, mmse_count=mmse_count, patch_size=patch_size
+     )
+     plot_crops(
+         inp,
+         tar,
+         tar_hsnr,
+         recon_img_list,
+         calibration_stats,
+         num_samples=num_samples,
+         baseline_preds=baseline_preds,
+     )
+
+
+ def plot_crops(
+     inp,
+     tar,
+     tar_hsnr,
+     recon_img_list,
+     calibration_stats=None,
+     num_samples=2,
+     baseline_preds=None,
+ ):
+     if baseline_preds is None:
+         baseline_preds = []
+     if len(baseline_preds) > 0:
+         for i in range(len(baseline_preds)):
+             if baseline_preds[i].shape != tar_hsnr.shape:
+                 print(
+                     f"Baseline prediction {i} shape {baseline_preds[i].shape} does not "
+                     f"match target shape {tar_hsnr.shape}"
+                 )
+                 print("This happens when we want to predict the edges of the image.")
+                 return
+     color_ch_list = ["goldenrod", "cyan"]
+     color_pred = "red"
+     insetplot_xmax_value = 10000
+     insetplot_xmin_value = -1000
+     inset_min_labelsize = 10
+     inset_rect = [0.05, 0.05, 0.4, 0.2]
+
+     img_sz = 3
+     ncols = num_samples + len(baseline_preds) + 1 + 1 + 1 + 1 + 1 * (num_samples > 1)
+     grid_factor = 5
+     grid_img_sz = img_sz * grid_factor
+     example_spacing = 1
+     c0_extra = 1
+     nimgs = 1
+     fig_w = ncols * img_sz + 2 * c0_extra / grid_factor
+     fig_h = int(img_sz * ncols + (example_spacing * (nimgs - 1)) / grid_factor)
+     fig = plt.figure(figsize=(fig_w, fig_h))
+     gs = GridSpec(
+         nrows=int(grid_factor * fig_h),
+         ncols=int(grid_factor * fig_w),
+         hspace=0.2,
+         wspace=0.2,
+     )
+     params = {"mathtext.default": "regular"}
+     plt.rcParams.update(params)
+     # plot baselines
+     for i in range(2, 2 + len(baseline_preds)):
+         for col_idx in range(baseline_preds[0].shape[0]):
+             ax_temp = fig.add_subplot(
+                 gs[
+                     col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
+                     i * grid_img_sz + c0_extra : (i + 1) * grid_img_sz + c0_extra,
+                 ]
+             )
+             print(tar_hsnr.shape, baseline_preds[i - 2].shape)
+             psnr = get_psnr_str(tar_hsnr, baseline_preds[i - 2], col_idx)
+             ax_temp.imshow(baseline_preds[i - 2][col_idx], cmap="magma")
+             add_psnr_str(ax_temp, psnr)
+             clean_ax(ax_temp)
+
+     # plot samples
+     sample_start_idx = 2 + len(baseline_preds)
+     for i in range(sample_start_idx, ncols - 3):
+         for col_idx in range(recon_img_list.shape[1]):
+             ax_temp = fig.add_subplot(
+                 gs[
+                     col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
+                     i * grid_img_sz + c0_extra : (i + 1) * grid_img_sz + c0_extra,
+                 ]
+             )
+             psnr = get_psnr_str(tar_hsnr, recon_img_list[i - sample_start_idx], col_idx)
+             ax_temp.imshow(recon_img_list[i - sample_start_idx][col_idx], cmap="magma")
+             add_psnr_str(ax_temp, psnr)
+             clean_ax(ax_temp)
+             # inset_ax = add_pixel_kde(ax_temp,
+             #                          inset_rect,
+             #                          [tar_hsnr[col_idx],
+             #                           recon_img_list[i - sample_start_idx][col_idx]],
+             #                          inset_min_labelsize,
+             #                          label_list=['', ''],
+             #                          color_list=[color_ch_list[col_idx], color_pred],
+             #                          plot_xmax_value=insetplot_xmax_value,
+             #                          plot_xmin_value=insetplot_xmin_value)
+
+             # inset_ax.set_xticks([])
+             # inset_ax.set_yticks([])
+
+     # difference image
+     if num_samples > 1:
+         for col_idx in range(recon_img_list.shape[1]):
+             ax_temp = fig.add_subplot(
+                 gs[
+                     col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
+                     (ncols - 3) * grid_img_sz
+                     + c0_extra : (ncols - 2) * grid_img_sz
+                     + c0_extra,
+                 ]
+             )
+             ax_temp.imshow(
+                 recon_img_list[1][col_idx] - recon_img_list[0][col_idx], cmap="coolwarm"
+             )
+             clean_ax(ax_temp)
+
+     for col_idx in range(recon_img_list.shape[1]):
+         # print(recon_img_list.shape)
+         ax_temp = fig.add_subplot(
+             gs[
+                 col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
+                 c0_extra
+                 + (ncols - 2) * grid_img_sz : (ncols - 1) * grid_img_sz
+                 + c0_extra,
+             ]
+         )
+         psnr = get_psnr_str(tar_hsnr, recon_img_list.mean(axis=0), col_idx)
+         ax_temp.imshow(recon_img_list.mean(axis=0)[col_idx], cmap="magma")
+         add_psnr_str(ax_temp, psnr)
+         # inset_ax = add_pixel_kde(ax_temp,
+         #                          inset_rect,
+         #                          [tar_hsnr[col_idx],
+         #                           recon_img_list.mean(axis=0)[col_idx]],
+         #                          inset_min_labelsize,
+         #                          label_list=['', ''],
+         #                          color_list=[color_ch_list[col_idx], color_pred],
+         #                          plot_xmax_value=insetplot_xmax_value,
+         #                          plot_xmin_value=insetplot_xmin_value)
+         # inset_ax.set_xticks([])
+         # inset_ax.set_yticks([])
+
+         clean_ax(ax_temp)
+
+         ax_temp = fig.add_subplot(
+             gs[
+                 col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
+                 (ncols - 1) * grid_img_sz
+                 + 2 * c0_extra : (ncols) * grid_img_sz
+                 + 2 * c0_extra,
+             ]
+         )
+         ax_temp.imshow(tar_hsnr[col_idx], cmap="magma")
+         if col_idx == 0:
+             legend_ch1_ax = ax_temp
+         if col_idx == 1:
+             legend_ch2_ax = ax_temp
+
+         # inset_ax = add_pixel_kde(ax_temp,
+         #                          inset_rect,
+         #                          [tar_hsnr[col_idx],
+         #                           ],
+         #                          inset_min_labelsize,
+         #                          label_list=[''],
+         #                          color_list=[color_ch_list[col_idx]],
+         #                          plot_xmax_value=insetplot_xmax_value,
+         #                          plot_xmin_value=insetplot_xmin_value)
+         # inset_ax.set_xticks([])
+         # inset_ax.set_yticks([])
+
+         clean_ax(ax_temp)
+
+         ax_temp = fig.add_subplot(
+             gs[
+                 col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
+                 grid_img_sz : 2 * grid_img_sz,
+             ]
+         )
+         ax_temp.imshow(tar[0, col_idx].cpu().numpy(), cmap="magma")
+         # inset_ax = add_pixel_kde(ax_temp,
+         #                          inset_rect,
+         #                          [tar[0,col_idx].cpu().numpy(),
+         #                           ],
+         #                          inset_min_labelsize,
+         #                          label_list=[''],
+         #                          color_list=[color_ch_list[col_idx]],
+         #                          plot_kwargs_list=[{'linestyle':'--'}],
+         #                          plot_xmax_value=insetplot_xmax_value,
+         #                          plot_xmin_value=insetplot_xmin_value)
+
+         # inset_ax.set_xticks([])
+         # inset_ax.set_yticks([])
+
+         clean_ax(ax_temp)
+
+     ax_temp = fig.add_subplot(gs[0:grid_img_sz, 0:grid_img_sz])
+     ax_temp.imshow(inp[0, 0].cpu().numpy(), cmap="magma")
+     clean_ax(ax_temp)
+
+     # line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-',
+     #                          label='$C_1$')
+     # line_ch2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='-',
+     #                          label='$C_2$')
+     # line_pred = mlines.Line2D([0, 1], [0, 1], color=color_pred, linestyle='-',
+     #                          label='Pred')
+     # line_noisych1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0],
+     #                               linestyle='--', label='$C^N_1$')
+     # line_noisych2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1],
+     #                               linestyle='--', label='$C^N_2$')
+     # legend_ch1 = legend_ch1_ax.legend(handles=[line_ch1, line_noisych1, line_pred],
+     #                                   loc='upper right', frameon=False, labelcolor='white',
+     #                                   prop={'size': 11})
+     # legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred],
+     #                                   loc='upper right', frameon=False, labelcolor='white',
+     #                                   prop={'size': 11})
+
+     if calibration_stats is not None:
+         smaller_offset = 4
+         ax_temp = fig.add_subplot(
+             gs[
+                 grid_img_sz + 1 : 2 * grid_img_sz - smaller_offset + 1,
+                 smaller_offset - 1 : grid_img_sz - 1,
+             ]
+         )
+         plot_calibration(ax_temp, calibration_stats)
+
+
+ def plot_calibration(ax, calibration_stats):
+     """
+     Plot calibration statistics (RMV vs RMSE).
+     """
+     first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
+     last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
+     ax.plot(
+         calibration_stats[0]["rmv"][first_idx:-last_idx],
+         calibration_stats[0]["rmse"][first_idx:-last_idx],
+         "o",
+         label=r"$\hat{C}_0$",
+     )
+
+     first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
+     last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
+     ax.plot(
+         calibration_stats[1]["rmv"][first_idx:-last_idx],
+         calibration_stats[1]["rmse"][first_idx:-last_idx],
+         "o",
+         label=r"$\hat{C}_1$",
+     )
+
+     ax.set_xlabel("RMV")
+     ax.set_ylabel("RMSE")
+     ax.legend()
+
+
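plot_calibration only assumes that calibration_stats can be indexed per output channel and that each entry carries per-bin "bin_count", "rmv", and "rmse" arrays. A minimal, made-up input that satisfies this contract (the numbers are invented, for illustration only):

import numpy as np
import matplotlib.pyplot as plt

calibration_stats = [
    {"bin_count": np.array([5, 50, 40, 5]),
     "rmv": np.array([0.10, 0.20, 0.30, 0.40]),
     "rmse": np.array([0.12, 0.19, 0.33, 0.45])},
    {"bin_count": np.array([8, 45, 42, 5]),
     "rmv": np.array([0.10, 0.20, 0.30, 0.40]),
     "rmse": np.array([0.11, 0.22, 0.31, 0.40])},
]
fig, ax = plt.subplots()
plot_calibration(ax, calibration_stats)  # one RMV-vs-RMSE curve per channel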
+ def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name="shiftedcmap"):
+     """
+     Adapted from
+     https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-matplotlib
+
+     Function to offset the "center" of a colormap. Useful for
+     data with a negative min and a positive max when you want the
+     middle of the colormap's dynamic range to be at zero.
+
+     Input
+     -----
+     cmap : The matplotlib colormap to be altered
+     start : Offset from lowest point in the colormap's range.
+         Defaults to 0.0 (no lower offset). Should be between
+         0.0 and `midpoint`.
+     midpoint : The new center of the colormap. Defaults to
+         0.5 (no shift). Should be between 0.0 and 1.0. In
+         general, this should be 1 - vmax / (vmax + abs(vmin)).
+         For example, if your data range from -15.0 to +5.0 and
+         you want the center of the colormap at 0.0, `midpoint`
+         should be set to 1 - 5/(5 + 15), i.e. 0.75.
+     stop : Offset from highest point in the colormap's range.
+         Defaults to 1.0 (no upper offset). Should be between
+         `midpoint` and 1.0.
+     """
+     cdict = {"red": [], "green": [], "blue": [], "alpha": []}
+
+     # regular index to compute the colors
+     reg_index = np.linspace(start, stop, 257)
+     mid_idx = len(reg_index) // 2
+     # shifted index to match the data
+     shift_index = np.hstack(
+         [
+             np.linspace(0.0, midpoint, 128, endpoint=False),
+             np.linspace(midpoint, 1.0, 129, endpoint=True),
+         ]
+     )
+
+     for ri, si in zip(reg_index, shift_index):
+         r, g, b, a = cmap(ri)
+         a = np.abs(ri - reg_index[mid_idx]) / reg_index[mid_idx]
+         # print(a)
+         cdict["red"].append((si, r, r))
+         cdict["green"].append((si, g, g))
+         cdict["blue"].append((si, b, b))
+         cdict["alpha"].append((si, a, a))
+
+     newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)
+     matplotlib.colormaps.register(cmap=newcmap, force=True)
+
+     return newcmap
+
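A quick check of the midpoint rule quoted in the docstring, applied to an invented error map (illustrative sketch only; assumes shiftedColorMap from this module is in scope):

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

error = np.random.uniform(-15.0, 5.0, size=(64, 64))  # roughly the docstring's -15..+5 range
midpoint = 1 - error.max() / (error.max() + abs(error.min()))  # ~ 1 - 5 / (5 + 15) = 0.75
cmap = shiftedColorMap(matplotlib.cm.coolwarm, midpoint=midpoint, name="example_shifted")
plt.imshow(error, cmap=cmap)  # zero now sits at the colormap's center
plt.colorbar()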
+
+ def get_fractional_change(target, prediction, max_val=None):
+     """
+     Get the relative difference between target and prediction.
+     """
+     if max_val is None:
+         max_val = target.max()
+     return (target - prediction) / max_val
+
+
+ def get_zero_centered_midval(error):
+     """
+     Compute the colormap midpoint that centers the colorbar at 0: with
+     midval = 1 - vmax / (vmax + abs(vmin)), the fraction of the colormap above the
+     midpoint matches the fraction of the data range above zero.
+     """
+     vmax = error.max()
+     vmin = error.min()
+     midval = 1 - vmax / (vmax + abs(vmin))
+     return midval
+
+
+ def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val=None):
+     """
+     Plot the relative difference between target and prediction.
+     NOTE: The plot is overlaid on the prediction image (in gray scale).
+     NOTE: The colorbar is centered at 0.
+     """
+     if ax is None:
+         _, ax = plt.subplots(figsize=(6, 6))
+
+     # Relative difference between target and prediction
+     rel_diff = get_fractional_change(target, prediction, max_val=max_val)
+     midval = get_zero_centered_midval(rel_diff)
+     shifted_cmap = shiftedColorMap(
+         cmap, start=0, midpoint=midval, stop=1.0, name="shiftedcmap"
+     )
+     ax.imshow(prediction, cmap="gray")
+     img_err = ax.imshow(rel_diff, cmap=shifted_cmap, alpha=1)
+     plt.colorbar(img_err, ax=ax)
+
+
+ # -------------------------------------------------------------------------------------
+
+
+ def get_predictions(
+     model: VAEModule,
+     dset: Dataset,
+     batch_size: int,
+     tile_size: Optional[tuple[int, int]] = None,
+     grid_size: Optional[int] = None,
+     mmse_count: int = 1,
+     num_workers: int = 4,
+ ) -> tuple[dict, dict]:
+     """Get stitched predictions from a model for the entire dataset.
+
+     Parameters
+     ----------
+     model : VAEModule
+         Lightning model used for prediction.
+     dset : Dataset
+         Dataset to predict on.
+     batch_size : int
+         Batch size to use for prediction.
+     tile_size : tuple[int, int], optional
+         Tile size to set on the dataset before prediction, by default `None`.
+     grid_size : int, optional
+         Grid size to set on the dataset before prediction, by default `None`.
+     mmse_count : int, optional
+         Number of samples to generate for each input and then to average over for
+         MMSE estimation, by default 1.
+     num_workers : int, optional
+         Number of workers to use for DataLoader, by default 4.
+
+     Returns
+     -------
+     tuple[dict, dict]
+         Tuple containing:
+         - predictions: stitched predictions, keyed by file name.
+         - predictions_std: standard deviation of the predictions over the
+           `mmse_count` samples, keyed by file name.
+     """
+     if hasattr(dset, "dsets"):
+         multifile_stitched_predictions = {}
+         multifile_stitched_stds = {}
+         for d in dset.dsets:
+             stitched_predictions, stitched_stds = get_single_file_mmse(
+                 model=model,
+                 dset=d,
+                 batch_size=batch_size,
+                 tile_size=tile_size,
+                 grid_size=grid_size,
+                 mmse_count=mmse_count,
+                 num_workers=num_workers,
+             )
+             # key the outputs by file name (without the path)
+             filename = d._fpath.name
+             multifile_stitched_predictions[filename] = stitched_predictions
+             multifile_stitched_stds[filename] = stitched_stds
+         return (
+             multifile_stitched_predictions,
+             multifile_stitched_stds,
+         )
+     else:
+         stitched_predictions, stitched_stds = get_single_file_mmse(
+             model=model,
+             dset=dset,
+             batch_size=batch_size,
+             tile_size=tile_size,
+             grid_size=grid_size,
+             mmse_count=mmse_count,
+             num_workers=num_workers,
+         )
+         # TODO stitching still not working properly for weirdly shaped images
+         # key the output by file name (without the path)
+         # TODO in the reference dataset this is the name of a folder, not a file
+         filename = dset._fpath.name
+         return (
+             {filename: stitched_predictions},
+             {filename: stitched_stds},
+         )
+
+
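As a hedged usage note for the function above (`model` and `val_dset` are assumed to exist): both returned dictionaries are keyed by file name, and the second one holds the per-pixel standard deviation over the `mmse_count` samples drawn for each tile.

preds, stds = get_predictions(model=model, dset=val_dset, batch_size=8, mmse_count=10)
for fname in preds:
    print(fname, preds[fname].shape, float(stds[fname].mean()))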
+ def get_single_file_predictions(
+     model: VAEModule,
+     dset: Dataset,
+     batch_size: int,
+     tile_size: Optional[tuple[int, int]] = None,
+     grid_size: Optional[int] = None,
+     num_workers: int = 4,
+ ) -> tuple[np.ndarray, np.ndarray]:
+     """Get stitched predictions from a model for a single-file dataset."""
+     if tile_size and grid_size:
+         dset.set_img_sz(tile_size, grid_size)
+
+     device = get_device()
+
+     dloader = DataLoader(
+         dset,
+         pin_memory=False,
+         num_workers=num_workers,
+         shuffle=False,
+         batch_size=batch_size,
+     )
+     model.eval()
+     model.to(device)
+     tiles = []
+     logvar_arr = []
+     with torch.no_grad():
+         for batch in tqdm(dloader, desc="Predicting tiles"):
+             inp, tar = batch
+             inp = inp.to(device)
+             tar = tar.to(device)
+
+             # get model output
+             rec, _ = model(inp)
+
+             # get reconstructed img
+             if model.model.predict_logvar is None:
+                 rec_img = rec
+                 logvar = torch.tensor([-1])
+             else:
+                 rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
+             logvar_arr.append(logvar.cpu().numpy())  # Why do we need this?
+
+             tiles.append(rec_img.cpu().numpy())
+
+     tile_samples = np.concatenate(tiles, axis=0)
+     return stitch_predictions_new(tile_samples, dset)
+
+
+ def get_single_file_mmse(
+     model: VAEModule,
+     dset: Dataset,
+     batch_size: int,
+     tile_size: Optional[tuple[int, int]] = None,
+     grid_size: Optional[int] = None,
+     mmse_count: int = 1,
+     num_workers: int = 4,
+ ) -> tuple[np.ndarray, np.ndarray]:
+     """Get stitched MMSE predictions from a model for a single-file dataset."""
+     device = get_device()
+
+     dloader = DataLoader(
+         dset,
+         pin_memory=False,
+         num_workers=num_workers,
+         shuffle=False,
+         batch_size=batch_size,
+     )
+     if tile_size and grid_size:
+         dset.set_img_sz(tile_size, grid_size)
+
+     model.eval()
+     model.to(device)
+     tile_mmse = []
+     tile_stds = []
+     logvar_arr = []
+     with torch.no_grad():
+         for batch in tqdm(dloader, desc="Predicting tiles"):
+             inp, tar = batch
+             inp = inp.to(device)
+             tar = tar.to(device)
+
+             rec_img_list = []
+             for _ in range(mmse_count):
+
+                 # get model output
+                 rec, _ = model(inp)
+
+                 # get reconstructed img
+                 if model.model.predict_logvar is None:
+                     rec_img = rec
+                     logvar = torch.tensor([-1])
+                 else:
+                     rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
+                 rec_img_list.append(rec_img.cpu().unsqueeze(0))  # add MMSE dim
+                 logvar_arr.append(logvar.cpu().numpy())  # Why do we need this?
+
+             # aggregate results
+             samples = torch.cat(rec_img_list, dim=0)
+             mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
+             std_imgs = torch.std(samples, dim=0)  # std over MMSE dim
+
+             tile_mmse.append(mmse_imgs.cpu().numpy())
+             tile_stds.append(std_imgs.cpu().numpy())
+
+     tiles_arr = np.concatenate(tile_mmse, axis=0)
+     tile_stds = np.concatenate(tile_stds, axis=0)
+     # TODO: workaround for notebook re-execution: if the user reruns the cell defining
+     # the dataset class, isinstance would return False, so compare class names instead.
+     if str(MultiChDloaderRef).split(".")[-1] == str(dset.__class__).split(".")[-1]:
+         stitch_func = stitch_predictions_general
+     else:
+         stitch_func = stitch_predictions_new
+     stitched_predictions = stitch_func(tiles_arr, dset)
+     stitched_stds = stitch_func(tile_stds, dset)
+     return stitched_predictions, stitched_stds
+
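The aggregation above is simply the sample mean and standard deviation over the mmse_count stochastic forward passes; in isolation, with toy tensors (illustrative only):

import torch

samples = torch.randn(10, 2, 2, 64, 64)  # 10 samples of a (batch=2, 2-channel, 64x64) prediction
mmse_imgs = samples.mean(dim=0)          # MMSE estimate
std_imgs = samples.std(dim=0)            # per-pixel uncertainty proxy
print(mmse_imgs.shape, std_imgs.shape)   # torch.Size([2, 2, 64, 64]) twice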
+
+ # ---------------------------------------------------------------------------------
+ ### Classes and Functions used to stitch predictions
+ class PatchLocation:
+     """
+     Encapsulates t_idx and spatial location.
+     """
+
+     def __init__(self, h_idx_range, w_idx_range, t_idx):
+         self.t = t_idx
+         self.h_start, self.h_end = h_idx_range
+         self.w_start, self.w_end = w_idx_range
+
+     def __str__(self):
+         msg = f"T:{self.t} [{self.h_start}-{self.h_end}) [{self.w_start}-{self.w_end}) "
+         return msg
+
+
+ def _get_location(extra_padding, hwt, pred_h, pred_w):
+     h_start, w_start, t_idx = hwt
+     h_start -= extra_padding
+     h_end = h_start + pred_h
+     w_start -= extra_padding
+     w_end = w_start + pred_w
+     return PatchLocation((h_start, h_end), (w_start, w_end), t_idx)
+
+
+ def get_location_from_idx(dset, dset_input_idx, pred_h, pred_w):
+     """
+     For a given index of the dataset, return where exactly in the dataset this
+     prediction lies. Note that the prediction also contains padded pixels, so only a
+     subset of it will be used in the final prediction. The location is the time frame
+     and the spatial region (h_start, h_end, w_start, w_end).
+
+     Args:
+         dset: dataset providing the index manager and grid size.
+         dset_input_idx: index of the patch in the dataset.
+         pred_h: height of the predicted patch.
+         pred_w: width of the predicted patch.
+
+     Returns
+     -------
+     PatchLocation
+     """
+     extra_padding = dset.per_side_overlap_pixelcount()
+     htw = dset.get_idx_manager().hwt_from_idx(
+         dset_input_idx, grid_size=dset.get_grid_size()
+     )
+     return _get_location(extra_padding, htw, pred_h, pred_w)
+
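A small sanity check of the arithmetic above (invented numbers): a 64x64 prediction whose grid anchor is (h, w, t) = (100, 200, 0), with 16 padded pixels per side, covers rows [84, 148) and columns [184, 248) of frame 0.

loc = _get_location(extra_padding=16, hwt=(100, 200, 0), pred_h=64, pred_w=64)
print(loc)  # T:0 [84-148) [184-248)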
+
+ def remove_pad(pred, loc, extra_padding, smoothening_pixelcount, frame_shape):
+     assert smoothening_pixelcount == 0
+     if extra_padding - smoothening_pixelcount > 0:
+         h_s = extra_padding - smoothening_pixelcount
+
+         # rows
+         h_N = frame_shape[0]
+         if loc.h_end > h_N:
+             assert loc.h_end - extra_padding + smoothening_pixelcount <= h_N
+         h_e = extra_padding - smoothening_pixelcount
+
+         w_s = extra_padding - smoothening_pixelcount
+
+         # columns
+         w_N = frame_shape[1]
+         if loc.w_end > w_N:
+             assert loc.w_end - extra_padding + smoothening_pixelcount <= w_N
+
+         w_e = extra_padding - smoothening_pixelcount
+
+         return pred[h_s:-h_e, w_s:-w_e]
+
+     return pred
+
+
+ def update_loc_for_final_insertion(loc, extra_padding, smoothening_pixelcount):
+     extra_padding = extra_padding - smoothening_pixelcount
+     loc.h_start += extra_padding
+     loc.w_start += extra_padding
+     loc.h_end -= extra_padding
+     loc.w_end -= extra_padding
+     return loc
+
+
+ def stitch_predictions(predictions, dset, smoothening_pixelcount=0):
+     """
+     Args:
+         smoothening_pixelcount: number of pixels which can be interpolated
+     """
+     assert smoothening_pixelcount >= 0 and isinstance(smoothening_pixelcount, int)
+     extra_padding = dset.per_side_overlap_pixelcount()
+     # if there are more channels, use all of them.
+     shape = list(dset.get_data_shape())
+     shape[-1] = max(shape[-1], predictions.shape[1])
+
+     output = np.zeros(shape, dtype=predictions.dtype)
+     frame_shape = dset.get_data_shape()[1:3]
+     for dset_input_idx in range(predictions.shape[0]):
+         loc = get_location_from_idx(
+             dset, dset_input_idx, predictions.shape[-2], predictions.shape[-1]
+         )
+
+         mask = None
+         cropped_pred_list = []
+         for ch_idx in range(predictions.shape[1]):
+             # class i
+             cropped_pred_i = remove_pad(
+                 predictions[dset_input_idx, ch_idx],
+                 loc,
+                 extra_padding,
+                 smoothening_pixelcount,
+                 frame_shape,
+             )
+
+             if mask is None:
+                 # NOTE: don't need to compute it for every patch.
+                 assert smoothening_pixelcount == 0, (
+                     "For smoothing, enable the get_smoothing_mask. It is disabled since"
+                     " I don't use it and it needs modification to work with non-square"
+                     " images."
+                 )
+                 mask = 1
+                 # mask = _get_smoothing_mask(cropped_pred_i.shape,
+                 #                            smoothening_pixelcount, loc, frame_size)
+
+             cropped_pred_list.append(cropped_pred_i)
+
+         loc = update_loc_for_final_insertion(loc, extra_padding, smoothening_pixelcount)
+         for ch_idx in range(predictions.shape[1]):
+             output[loc.t, loc.h_start : loc.h_end, loc.w_start : loc.w_end, ch_idx] += (
+                 cropped_pred_list[ch_idx] * mask
+             )
+
+     return output
+
+
+ # from disentangle.analysis.stitch_prediction import *
+ def stitch_predictions_new(predictions, dset):
+     """
+     Stitch per-patch predictions back into full frames using the dataset's index
+     manager (channel dimension last).
+     """
+     # Commented out since it is not used as of now
+     # if isinstance(dset, MultiFileDset):
+     #     cum_count = 0
+     #     output = []
+     #     for dset in dset.dsets:
+     #         cnt = dset.idx_manager.total_grid_count()
+     #         output.append(
+     #             stitch_predictions(predictions[cum_count:cum_count + cnt], dset))
+     #         cum_count += cnt
+     #     return output
+
+     # else:
+     mng = dset.idx_manager
+
+     # if there are more channels, use all of them.
+     shape = list(dset.get_data_shape())
+     shape[-1] = max(shape[-1], predictions.shape[1])
+
+     output = np.zeros(shape, dtype=predictions.dtype)
+     # frame_shape = dset.get_data_shape()[:-1]
+     for dset_idx in range(predictions.shape[0]):
+         # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
+         #                             predictions.shape[-1])
+         # grid start, grid end
+         gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
+         ge = gs + mng.grid_shape
+
+         # patch start, patch end
+         ps = gs - mng.patch_offset()
+         pe = ps + mng.patch_shape
+         # print('PS')
+         # print(ps)
+         # print(pe)
+
+         # valid grid start, valid grid end
+         vgs = np.array([max(0, x) for x in gs], dtype=int)
+         vge = np.array([min(x, y) for x, y in zip(ge, mng.data_shape)], dtype=int)
+         # assert np.all(vgs == gs)
+         # assert np.all(vge == ge)  # TODO: commented out; this assertion fails in
+         # some cases and still needs to be investigated.
+         # print('VGS')
+         # print(gs)
+         # print(ge)
+
+         if mng.tiling_mode == TilingMode.ShiftBoundary:
+             for dim in range(len(vgs)):
+                 if ps[dim] == 0:
+                     vgs[dim] = 0
+                 if pe[dim] == mng.data_shape[dim]:
+                     vge[dim] = mng.data_shape[dim]
+
+         # relative start, relative end. This will be used on pred_tiled
+         rs = vgs - ps
+         re = rs + (vge - vgs)
+         # print('RS')
+         # print(rs)
+         # print(re)
+
+         # print(output.shape)
+         # print(predictions.shape)
+         for ch_idx in range(predictions.shape[1]):
+             if len(output.shape) == 4:
+                 # channel dimension is the last one.
+                 output[vgs[0] : vge[0], vgs[1] : vge[1], vgs[2] : vge[2], ch_idx] = (
+                     predictions[dset_idx][ch_idx, rs[1] : re[1], rs[2] : re[2]]
+                 )
+             elif len(output.shape) == 5:
+                 # channel dimension is the last one.
+                 assert vge[0] - vgs[0] == 1, "Only one frame is supported"
+                 output[
+                     vgs[0], vgs[1] : vge[1], vgs[2] : vge[2], vgs[3] : vge[3], ch_idx
+                 ] = predictions[dset_idx][
+                     ch_idx, rs[1] : re[1], rs[2] : re[2], rs[3] : re[3]
+                 ]
+             else:
+                 raise ValueError(f"Unsupported shape {output.shape}")
+
+     return output
+
+
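To make the index bookkeeping above concrete, here is a small invented example of the clipping step, mirroring the gs/ge/ps/pe/vgs/vge/rs/re variables (2D case, ignoring the leading frame index; a 64-pixel patch whose grid anchor starts 8 pixels above the top border of a 256x256 frame):

import numpy as np

gs = np.array([-8, 40])          # grid start (can be negative near the border)
grid_shape = np.array([32, 32])
patch_offset = np.array([16, 16])
data_shape = np.array([256, 256])

ge = gs + grid_shape             # grid end -> [24, 72]
ps = gs - patch_offset           # patch start -> [-24, 24]
vgs = np.maximum(gs, 0)          # valid grid start -> [0, 40]
vge = np.minimum(ge, data_shape) # valid grid end -> [24, 72]
rs = vgs - ps                    # where the valid region sits inside the patch -> [24, 16]
re = rs + (vge - vgs)            # -> [48, 48]
print(vgs, vge, rs, re)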
+ def stitch_predictions_general(predictions, dset):
+     """Stitching for a dataset with multiple files of different shapes."""
+     mng = dset.idx_manager
+
+     # TODO assert all shapes are of equal length
+     # adjust the number of channels to match the prediction shape  # TODO: ugly, refactor!
+     shapes = []
+     for shape in dset.get_data_shapes()[0]:
+         shapes.append((predictions.shape[1],) + shape[1:])
+
+     output = [np.zeros(shape, dtype=predictions.dtype) for shape in shapes]
+     # frame_shape = dset.get_data_shape()[:-1]
+     for patch_idx in range(predictions.shape[0]):
+         # grid start, grid end
+         # channel_idx is 0 because during prediction we only use one channel.
+         # TODO revisit this
+         # 0th dimension is the sample index in the output list
+         grid_coords = np.array(
+             mng.get_location_from_patch_idx(channel_idx=0, patch_idx=patch_idx),
+             dtype=int,
+         )
+         sample_idx = grid_coords[0]
+         grid_start = grid_coords[1:]
+         # from here on, coordinates are relative to the sample (file in the list of
+         # inputs)
+         grid_end = grid_start + mng.grid_shape
+
+         # patch start, patch end
+         patch_start = grid_start - mng.patch_offset()
+         patch_end = patch_start + mng.patch_shape
+
+         # valid grid start, valid grid end
+         valid_grid_start = np.array([max(0, x) for x in grid_start], dtype=int)
+         valid_grid_end = np.array(
+             [min(x, y) for x, y in zip(grid_end, shapes[sample_idx])], dtype=int
+         )
+
+         if mng.tiling_mode == TilingMode.ShiftBoundary:
+             for dim in range(len(valid_grid_start)):
+                 if patch_start[dim] == 0:
+                     valid_grid_start[dim] = 0
+                 if patch_end[dim] == mng.data_shape[dim]:
+                     valid_grid_end[dim] = mng.data_shape[dim]
+
+         # relative start, relative end. This will be used on pred_tiled
+         relative_start = valid_grid_start - patch_start
+         relative_end = relative_start + (valid_grid_end - valid_grid_start)
+
+         for ch_idx in range(predictions.shape[1]):
+             if len(output[sample_idx].shape) == 3:
+                 # starting from 1 because the 0th dimension is the channel relative to
+                 # the input; the channel dimension of the stitched output is relative
+                 # to the model output
+                 output[sample_idx][
+                     ch_idx,
+                     valid_grid_start[1] : valid_grid_end[1],
+                     valid_grid_start[2] : valid_grid_end[2],
+                 ] = predictions[patch_idx][
+                     ch_idx,
+                     relative_start[1] : relative_end[1],
+                     relative_start[2] : relative_end[2],
+                 ]
+             elif len(output[sample_idx].shape) == 4:
+                 assert (
+                     valid_grid_end[0] - valid_grid_start[0] == 1
+                 ), "Only one frame is supported"
+                 output[sample_idx][
+                     ch_idx,
+                     valid_grid_start[1] : valid_grid_end[1],
+                     valid_grid_start[2] : valid_grid_end[2],
+                     valid_grid_start[3] : valid_grid_end[3],
+                 ] = predictions[patch_idx][
+                     ch_idx,
+                     relative_start[1] : relative_end[1],
+                     relative_start[2] : relative_end[2],
+                     relative_start[3] : relative_end[3],
+                 ]
+             else:
+                 raise ValueError(f"Unsupported shape {output[sample_idx].shape}")
+
+     return output