careamics-0.0.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,191 @@
1
+ from typing import Union, Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ from scipy import stats
6
+
7
+
8
def get_last_index(bin_count, quantile):
    """Return the number of trailing bins that lie beyond ``quantile``.

    The bin counts are turned into a normalized cumulative distribution.
    Walking backwards from the last bin, the function finds the first position
    ``-i`` whose cumulative fraction is below ``quantile`` and returns ``i - 1``.
    Callers use the result to trim bins from the end of per-bin lists.

    Parameters
    ----------
    bin_count : array-like
        Number of samples in each bin.
    quantile : float
        Upper quantile threshold, typically close to 1.

    Returns
    -------
    int or None
        Number of trailing bins to drop, or ``None`` when no position has a
        cumulative fraction below ``quantile``.
    """
    cumsum = np.cumsum(bin_count)
    normalized_cumsum = cumsum / cumsum[-1]
    # Iterate up to ``len`` inclusive so the first bin (index ``-len``) is also
    # examined. Stopping at ``len - 1`` skipped it and returned None for inputs
    # such as two bins where only the first one falls below the quantile.
    for i in range(1, len(normalized_cumsum) + 1):
        if normalized_cumsum[-i] < quantile:
            return i - 1
    return None
15
+
16
+
17
def get_first_index(bin_count, quantile):
    """Return the index of the first bin whose cumulative fraction exceeds ``quantile``.

    Parameters
    ----------
    bin_count : array-like
        Number of samples in each bin.
    quantile : float
        Lower quantile threshold, typically close to 0.

    Returns
    -------
    int or None
        Index of the first bin where the normalized cumulative count is strictly
        greater than ``quantile``, or ``None`` if there is no such bin.
    """
    running_total = np.cumsum(bin_count)
    fractions = running_total / running_total[-1]
    exceeding = np.flatnonzero(fractions > quantile)
    if exceeding.size == 0:
        return None
    return int(exceeding[0])
24
+
25
+
26
class Calibration:
    """Calibrate the uncertainty computed over samples from LVAE model.

    Calibration is done by learning a scalar that maps the pixel-wise standard
    deviation of the predicted samples into the actual prediction error.

    Parameters
    ----------
    num_bins : int, optional
        Number of equal-width bins spanning the predicted std range of each
        channel, by default 15.
    """

    def __init__(self, num_bins: int = 15):
        self._bins = num_bins
        self._bin_boundaries = None

    def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray:
        """Compute the bin boundaries for `num_bins` bins and predicted std values.

        Returns ``num_bins + 1`` equally spaced values from ``min(predict_std)``
        to ``max(predict_std)``.
        """
        min_std = np.min(predict_std)
        max_std = np.max(predict_std)
        return np.linspace(min_std, max_std, self._bins + 1)

    def compute_stats(
        self, pred: np.ndarray, pred_std: np.ndarray, target: np.ndarray
    ) -> dict[int, dict[str, Union[np.ndarray, list]]]:
        """
        It computes the bin-wise RMSE and RMV for each channel of the predicted image.

        Recall that:
            - RMSE = np.sqrt(np.sum((pred - target)**2) / num_pixels)
            - RMV = np.sqrt(np.mean(pred_std**2))

        ALGORITHM
        - For each channel:
            - Given the bin boundaries, assign pixels of `std_ch` array to a specific
              bin index (pixels equal to the maximum std go to the last bin).
            - For each bin index:
                - Compute the RMSE, RMV, and number of pixels for that bin.

        NOTE: each channel of the predicted image/logvar has its own stats.

        Parameters
        ----------
        pred: np.ndarray
            Predicted patches, shape (n, h, w, c).
        pred_std: np.ndarray
            Std computed over the predicted patches, shape (n, h, w, c).
        target: np.ndarray
            Target GT image, shape (n, h, w, c).

        Returns
        -------
        dict[int, dict[str, Union[np.ndarray, list]]]
            Per channel index: ``"bin_count"``, ``"rmv"``, ``"rmse"``,
            ``"rmse_err"`` (one entry per bin; ``rmse``/``rmse_err`` are ``None``
            and ``rmv`` is NaN for empty bins), ``"bin_boundaries"`` and
            ``"bin_matrix"`` (1-based bin index of every pixel). The result is
            also stored in ``self.stats_dict``.
        """
        self._bin_boundaries = {}
        stats_dict = {}
        for ch_idx in range(pred.shape[-1]):
            ch_stats = {
                "bin_count": [],
                "rmv": [],
                "rmse": [],
                "bin_boundaries": None,
                "bin_matrix": [],
                "rmse_err": [],
            }
            stats_dict[ch_idx] = ch_stats

            pred_ch = pred[..., ch_idx]
            std_ch = pred_std[..., ch_idx]
            target_ch = target[..., ch_idx]

            boundaries = self.compute_bin_boundaries(std_ch)
            self._bin_boundaries[ch_idx] = boundaries
            ch_stats["bin_boundaries"] = boundaries

            # np.digitize returns len(boundaries) == num_bins + 1 for values equal
            # to the last boundary (the maximum std, or every pixel when the std is
            # constant). Clip so those pixels land in the last bin instead of being
            # silently excluded from all bins.
            bin_matrix = np.clip(
                np.digitize(std_ch.reshape(-1), boundaries), 1, self._bins
            )
            bin_matrix = bin_matrix.reshape(std_ch.shape)
            ch_stats["bin_matrix"] = bin_matrix

            error = (pred_ch - target_ch) ** 2
            for bin_idx in range(1, 1 + self._bins):
                bin_mask = bin_matrix == bin_idx
                bin_size = np.sum(bin_mask)
                if bin_size > 0:
                    bin_error = error[bin_mask]
                    rmse = np.sqrt(np.sum(bin_error) / bin_size)
                    stderr = np.std(bin_error) / np.sqrt(bin_size)
                    rmse_stderr = np.sqrt(stderr)
                    rmv = np.sqrt(np.mean(std_ch[bin_mask] ** 2))
                else:
                    # Empty bin: no error statistics; RMV is NaN (as the mean of
                    # an empty slice would be) without triggering a warning.
                    rmse = None
                    rmse_stderr = None
                    rmv = np.nan
                ch_stats["rmse"].append(rmse)
                ch_stats["rmse_err"].append(rmse_stderr)
                ch_stats["rmv"].append(rmv)
                ch_stats["bin_count"].append(bin_size)
        self.stats_dict = stats_dict
        return stats_dict

    def get_calibrated_factor_for_stdev(
        self,
        pred: Optional[np.ndarray] = None,
        pred_std: Optional[np.ndarray] = None,
        target: Optional[np.ndarray] = None,
        q_s: float = 0.00001,
        q_e: float = 0.99999,
    ) -> tuple[dict[int, dict[str, float]], dict[str, np.ndarray]]:
        """Calibrate the uncertainty by multiplying the predicted std with a scalar.

        If ``pred``, ``pred_std`` and ``target`` are all given, the bin stats are
        (re)computed from them. Otherwise the stats from a previous
        ``compute_stats`` call are reused.

        Parameters
        ----------
        pred : np.ndarray, optional
            Predicted patches, shape (n, h, w, c).
        pred_std : np.ndarray, optional
            Std computed over the predicted patches, shape (n, h, w, c).
        target : np.ndarray, optional
            Target GT image, shape (n, h, w, c).
        q_s : float, optional
            Start quantile, by default 0.00001. Leading low-count bins are
            excluded from the fit.
        q_e : float, optional
            End quantile, by default 0.99999. Trailing low-count bins are
            excluded from the fit.

        Returns
        -------
        tuple[dict[int, dict[str, float]], dict[str, np.ndarray]]
            Per-channel ``{"scalar": slope, "offset": intercept}`` of the linear
            fit RMSE ~ RMV, and the same factors as arrays (see
            ``get_factors_array``).

        Raises
        ------
        ValueError
            If no stats have been computed yet and any of ``pred``, ``pred_std``
            or ``target`` is missing.
        """
        has_data = all(v is not None for v in (pred, pred_std, target))
        if getattr(self, "stats_dict", None) is None:
            print("No stats found. Computing stats...")
            if not has_data:
                raise ValueError("pred, pred_std, and target must be provided.")
        if has_data:
            # Fresh data always wins over cached stats, so a second call with new
            # arrays does not silently reuse results from an earlier dataset.
            self.stats_dict = self.compute_stats(
                pred=pred, pred_std=pred_std, target=target
            )

        outputs = {}
        for ch_idx, ch_stats in self.stats_dict.items():
            count = ch_stats["bin_count"]
            first_idx = get_first_index(count, q_s)
            last_idx = get_last_index(count, q_e)
            # A None index means "nothing to trim". The explicit stop also avoids
            # the ``x[a:-0]`` pitfall, which would yield an empty slice.
            start = 0 if first_idx is None else first_idx
            stop = len(count) - (0 if last_idx is None else last_idx)

            # Empty bins carry rmse=None / rmv=NaN and cannot enter the fit.
            points = [
                (rmv, rmse)
                for rmv, rmse in zip(
                    ch_stats["rmv"][start:stop], ch_stats["rmse"][start:stop]
                )
                if rmse is not None and rmv is not None and np.isfinite(rmv)
            ]
            x = np.array([p[0] for p in points], dtype=float)
            y = np.array([p[1] for p in points], dtype=float)
            slope, intercept, *_ = stats.linregress(x, y)
            outputs[ch_idx] = {"scalar": slope, "offset": intercept}
        factors = self.get_factors_array(factors_dict=outputs)
        return outputs, factors

    def get_factors_array(self, factors_dict: dict[int, dict[str, float]]):
        """Get the calibration factors as a numpy array.

        Parameters
        ----------
        factors_dict : dict[int, dict[str, float]]
            Factors indexed by channel ``0 .. n-1``, each with a ``"scalar"`` and
            an optional ``"offset"`` (defaults to 0.0).

        Returns
        -------
        dict[str, np.ndarray]
            ``"scalar"`` and ``"offset"`` arrays of shape (1, 1, 1, n_channels),
            so they broadcast over channel-last (n, h, w, c) arrays.
        """
        calib_scalar = [factors_dict[i]["scalar"] for i in range(len(factors_dict))]
        calib_scalar = np.array(calib_scalar).reshape(1, 1, 1, -1)
        calib_offset = [
            factors_dict[i].get("offset", 0.0) for i in range(len(factors_dict))
        ]
        calib_offset = np.array(calib_offset).reshape(1, 1, 1, -1)
        return {"scalar": calib_scalar, "offset": calib_offset}
168
+
169
+
170
def plot_calibration(ax, calibration_stats, q_s=0.0001, q_e=0.9999):
    """Scatter-plot bin-wise RMSE against RMV for every channel.

    Low-count bins at both ends of each channel are trimmed using the
    ``q_s`` / ``q_e`` quantiles of that channel's ``bin_count``.

    Parameters
    ----------
    ax : matplotlib Axes-like
        Axes to draw on (uses ``plot``, ``set_xlabel``, ``set_ylabel``, ``legend``).
    calibration_stats : dict[int, dict] or sequence of dict
        Per-channel stats with ``"bin_count"``, ``"rmv"`` and ``"rmse"`` entries,
        e.g. the output of ``Calibration.compute_stats``.
    q_s : float, optional
        Start quantile used to trim leading bins, by default 0.0001.
    q_e : float, optional
        End quantile used to trim trailing bins, by default 0.9999.
    """
    from collections.abc import Mapping

    if isinstance(calibration_stats, Mapping):
        channels = calibration_stats.items()
    else:
        channels = enumerate(calibration_stats)

    for ch_idx, ch_stats in channels:
        count = ch_stats["bin_count"]
        first_idx = get_first_index(count, q_s)
        last_idx = get_last_index(count, q_e)
        # None means "nothing to trim"; an explicit stop avoids ``x[a:-0]`` (empty).
        start = 0 if first_idx is None else first_idx
        stop = len(count) - (0 if last_idx is None else last_idx)
        ax.plot(
            ch_stats["rmv"][start:stop],
            ch_stats["rmse"][start:stop],
            "o",
            label=rf"$\hat{{C}}_{ch_idx}$: Ch{ch_idx + 1}",
        )
    ax.set_xlabel("RMV")
    ax.set_ylabel("RMSE")
    ax.legend()
@@ -0,0 +1,20 @@
1
+ from .config import MicroSplitDataConfig
2
+ from .lc_dataset import LCMultiChDloader
3
+ from .ms_dataset_ref import MultiChDloaderRef
4
+ from .multich_dataset import MultiChDloader
5
+ from .multicrop_dset import MultiCropDset
6
+ from .multifile_dataset import MultiFileDset
7
+ from .types import DataSplitType, DataType, TilingMode
8
+
9
# Public API of the package. Every entry must be imported above; a name listed
# here but not bound in the module makes ``from <package> import *`` raise
# AttributeError.
__all__ = [
    "DataSplitType",
    "DataType",
    "LCMultiChDloader",
    "MicroSplitDataConfig",
    "MultiChDloader",
    "MultiChDloaderRef",
    "MultiCropDset",
    "MultiFileDset",
    "TilingMode",
]
@@ -0,0 +1,135 @@
1
+ from typing import Any, Union
2
+
3
+ from pydantic import BaseModel, ConfigDict
4
+
5
+ from .types import DataSplitType, DataType, TilingMode
6
+
7
+
8
# TODO: check if any bool logic can be removed
class MicroSplitDataConfig(BaseModel):
    """Data configuration for the MicroSplit / LVAE multi-channel datasets.

    Instances are passed as ``data_config`` to dataset loaders such as
    ``LCMultiChDloader``. ``validate_assignment=True`` re-validates fields when
    they are assigned, and ``extra="allow"`` keeps unknown keyword arguments as
    extra attributes instead of rejecting them.

    Fields without a default (``data_type`` and ``image_size``) must always be
    provided.
    """

    model_config = ConfigDict(validate_assignment=True, extra="allow")

    # Required (no default), although ``None`` is an accepted value.
    data_type: Union[DataType, str] | None  # TODO remove or refactor!!
    """Type of the dataset, should be one of DataType"""

    depth3D: int | None = 1
    """Number of slices in 3D. If data is 2D depth3D is equal to 1"""

    datasplit_type: DataSplitType | None = None
    """Whether to return training, validation or test split, should be one of
    DataSplitType"""

    num_channels: int | None = 2
    """Number of channels in the input"""

    # TODO: remove ch*_fname parameters, should be parsed automatically from a name list
    ch1_fname: str | None = None
    ch2_fname: str | None = None
    ch_input_fname: str | None = None

    input_is_sum: bool | None = False
    """Whether the input is the sum or average of channels"""

    input_idx: int | None = None
    """Index of the channel where the input is stored in the data"""

    target_idx_list: list[int] | None = None
    """Indices of the channels where the targets are stored in the data"""

    # TODO: where are these used?
    start_alpha: Any | None = None
    end_alpha: Any | None = None

    image_size: tuple  # TODO: revisit, new model_config uses tuple
    """Size of one patch of data"""

    grid_size: Union[int, tuple[int, int, int]] | None = None
    """Frame is divided into square grids of this size. A patch centered on a grid
    having size `image_size` is returned. Grid size not used in training,
    used only during val / test, grid size controls the overlap of the patches"""

    empty_patch_replacement_enabled: bool | None = False
    """Whether to replace the content of one of the channels
    with background with given probability"""
    # Untyped (Any) settings for the empty-patch replacement above.
    empty_patch_replacement_channel_idx: Any | None = None
    empty_patch_replacement_probab: Any | None = None
    empty_patch_max_val_threshold: Any | None = None

    uncorrelated_channels: bool | None = False
    """Replace the content in one of the channels with given probability to make
    channel content 'uncorrelated'"""
    # Also copied by LCMultiChDloader into its own ``_uncorrelated_channel_probab``.
    uncorrelated_channel_probab: float | None = 0.5

    poisson_noise_factor: float | None = -1
    """The added poisson noise factor"""

    synthetic_gaussian_scale: float | None = 0.1

    # TODO: set to True in training code, recheck
    input_has_dependant_noise: bool | None = False

    # TODO: sometimes max_val differs between runs with fixed seeds with noise enabled
    enable_gaussian_noise: bool | None = False
    """Whether to enable gaussian noise"""

    # TODO: is this parameter used?
    allow_generation: bool = False

    # TODO: both used in IndexSwitcher, ensure correct passing
    training_validtarget_fraction: Any = None
    deterministic_grid: Any = None

    # TODO: why is this not used?
    enable_rotation_aug: bool | None = False

    max_val: Union[float, tuple] | None = None
    """Maximum data in the dataset. Is calculated for train split, and should be
    externally set for val and test splits."""

    # When not None, LCMultiChDloader asserts that this equals ``padding_kwargs``.
    overlapping_padding_kwargs: Any = None
    """Parameters for np.pad method"""

    # TODO: remove this parameter, controls debug print
    print_vars: bool | None = False

    # Hard-coded parameters (used to be in the config file)
    # NOTE(review): the docstring right below reads like a description of
    # ``use_one_mu_std`` (one vs. two mean/stdev), not of ``normalized_input`` —
    # confirm which field it belongs to.
    normalized_input: bool = True
    """If this is set to true, then one mean and stdev is used
    for both channels. Otherwise, two different mean and stdev are used."""
    use_one_mu_std: bool | None = True

    # TODO: is this parameter used?
    train_aug_rotate: bool | None = False
    enable_random_cropping: bool | None = True

    # LCMultiChDloader asserts that this is an int >= 1 (None is rejected there).
    multiscale_lowres_count: int | None = None
    """Number of LC scales"""

    tiling_mode: TilingMode | None = TilingMode.ShiftBoundary

    target_separate_normalization: bool | None = True

    mode_3D: bool | None = False
    """If training in 3D mode or not"""

    # NOTE: the field name is misspelled ("trainig"); it is part of the config
    # schema, so renaming it would break existing configs.
    trainig_datausage_fraction: float | None = 1.0

    validtarget_random_fraction: float | None = None

    validation_datausage_fraction: float | None = 1.0

    random_flip_z_3D: bool | None = False

    # LCMultiChDloader uses this as its padding kwargs and asserts that it is a
    # dict containing a "mode" key.
    padding_kwargs: dict = {"mode": "reflect"}  # TODO remove !!

    def __init__(self, **data):
        """Build the config, mapping a string ``data_type`` to ``DataType`` by name.

        Parameters
        ----------
        **data
            Field values. A string ``data_type`` is looked up with
            ``DataType[...]`` (by member name) before pydantic validation.
        """
        # Convert string data_type to enum if needed
        if "data_type" in data and isinstance(data["data_type"], str):
            try:
                data["data_type"] = DataType[data["data_type"]]
            except KeyError:
                # Keep original value to let validation handle the error
                # NOTE(review): ``data_type`` is annotated ``Union[DataType, str]``,
                # so an unknown name is accepted as a plain str rather than raising
                # a validation error — confirm this is intended.
                pass
        super().__init__(**data)

    # TODO add validators !
@@ -0,0 +1,274 @@
1
+ """
2
+ Multi-channel dataset loader that also returns multiscale low-resolution (LC) patches.
3
+ """
4
+
5
+ import logging
6
+ import math
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Optional, Union
9
+
10
+ import numpy as np
11
+ from skimage.transform import resize
12
+
13
+ from .config import MicroSplitDataConfig
14
+ from .multich_dataset import MultiChDloader
15
+
16
+
17
class LCMultiChDloader(MultiChDloader):
    """Multi-channel dataset loader for LC-style datasets.

    On top of ``MultiChDloader`` this loader precomputes a pyramid of
    ``multiscale_lowres_count`` levels, each level being the previous one
    downsampled by 2 in both spatial dimensions (for the data, and for the
    noise data when present). For every patch it returns, per channel, the
    full-resolution crop stacked with lower-resolution crops centred on the
    same location.
    """

    def __init__(
        self,
        data_config: MicroSplitDataConfig,
        datapath: Union[str, Path],
        load_data_fn: Optional[Callable] = None,
        val_fraction: float = 0.1,
        test_fraction: float = 0.1,
        allow_generation: bool = False,
    ):
        """Load data through the parent class and build the multiscale pyramid.

        Parameters
        ----------
        data_config : MicroSplitDataConfig
            Dataset configuration. ``padding_kwargs``,
            ``uncorrelated_channel_probab``, ``overlapping_padding_kwargs`` and
            ``multiscale_lowres_count`` are read here; the whole object is also
            forwarded to ``MultiChDloader``.
        datapath : str or Path
            Forwarded unchanged to ``MultiChDloader``.
        load_data_fn : Callable, optional
            Forwarded unchanged to ``MultiChDloader``.
        val_fraction : float
            Forwarded unchanged to ``MultiChDloader``.
        test_fraction : float
            Forwarded unchanged to ``MultiChDloader``.
        allow_generation : bool
            Accepted for interface compatibility; not used by this class.
        """
        # Assigned before the parent constructor runs. NOTE(review): presumably
        # because the parent calls ``_init_msg`` (overridden below), which reads
        # both attributes — confirm.
        self._padding_kwargs = data_config.padding_kwargs
        self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab

        super().__init__(
            data_config,
            datapath,
            load_data_fn=load_data_fn,
            val_fraction=val_fraction,
            test_fraction=test_fraction,
        )

        if data_config.overlapping_padding_kwargs is not None:
            assert self._padding_kwargs == data_config.overlapping_padding_kwargs, (
                "During evaluation, overlapping_padding_kwargs should be same as "
                "padding_args. It should be so since we just use "
                "overlapping_padding_kwargs when it is not None"
            )
        else:
            self._overlapping_padding_kwargs = data_config.padding_kwargs

        self.multiscale_lowres_count = data_config.multiscale_lowres_count
        assert self.multiscale_lowres_count is not None
        assert (
            isinstance(self.multiscale_lowres_count, int)
            and self.multiscale_lowres_count >= 1
        )
        assert isinstance(self._padding_kwargs, dict)
        assert "mode" in self._padding_kwargs

        # Level 0 of each pyramid is the full-resolution array itself.
        self._scaled_data = [self._data]
        self._scaled_noise_data = [self._noise_data]

        for _ in range(1, self.multiscale_lowres_count):
            prev_data = self._scaled_data[-1]
            shape = prev_data.shape
            assert len(shape) == 4
            new_shape = (shape[0], shape[1] // 2, shape[2] // 2, shape[3])
            # NOTE: resize must be given np.float32 input, otherwise one gets
            # weird results; the result is cast back to the source dtype.
            ds_data = resize(prev_data.astype(np.float32), new_shape).astype(
                prev_data.dtype
            )
            # Sanity check that downsampling kept the value range. Skipped when
            # the previous level's max is 0, where the ratio is undefined.
            prev_max = prev_data.max()
            if prev_max != 0:
                ratio = ds_data.max() / prev_max
                assert (
                    ratio < 5
                ), "Downsampled image should not have very different values"
                assert (
                    ratio > 0.2
                ), "Downsampled image should not have very different values"
            self._scaled_data.append(ds_data)

            # Downsample the noise with its *own* shape: the noise array may have
            # a different channel count than the data (``__getitem__`` uses
            # noise channels 1.. for the targets), and resizing it to the data's
            # channel count would interpolate across channels.
            if self._noise_data is not None:
                prev_noise = self._scaled_noise_data[-1]
                noise_shape = prev_noise.shape
                new_noise_shape = (
                    noise_shape[0],
                    noise_shape[1] // 2,
                    noise_shape[2] // 2,
                    noise_shape[3],
                )
                self._scaled_noise_data.append(resize(prev_noise, new_noise_shape))

    def reduce_data(
        self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
    ):
        """Keep only the entries ``t_list`` (first axis) at every pyramid level.

        Only subsetting along the first axis is supported; the spatial bounds
        ``h_start``, ``h_end``, ``w_start`` and ``w_end`` must all be None.
        """
        assert t_list is not None
        assert h_start is None
        assert h_end is None
        assert w_start is None
        assert w_end is None

        self._data = self._data[t_list].copy()
        self._scaled_data = [level[t_list].copy() for level in self._scaled_data]

        if self._noise_data is not None:
            self._noise_data = self._noise_data[t_list].copy()
            self._scaled_noise_data = [
                level[t_list].copy() for level in self._scaled_noise_data
            ]

        self.N = len(t_list)
        # NOTE(review): ``self._img_sz`` is not assigned in this class; it is
        # presumably set by ``MultiChDloader`` — confirm.
        self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
        print(
            f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
        )

    def _init_msg(self):
        """Extend the parent's init message with padding and uncorrelated-channel info."""
        msg = super()._init_msg()
        msg += f" Pad:{self._padding_kwargs}"
        if self._uncorrelated_channels:
            msg += f" UncorrChProbab:{self._uncorrelated_channel_probab}"
        return msg

    def _load_scaled_img(
        self, scaled_index, index: Union[int, tuple[int, int]]
    ) -> tuple[np.ndarray, ...]:
        """Return the full image at pyramid level ``scaled_index``, split by channel.

        Parameters
        ----------
        scaled_index : int
            Pyramid level to read (0 is full resolution).
        index : int or tuple of (int, int)
            Dataset index, optionally paired with a grid size (ignored here).

        Returns
        -------
        tuple of np.ndarray
            One array per channel, each with a leading axis of length 1. If
            noise data is present, noise channel 0 (scaled by sqrt(2) when
            ``self._input_is_sum`` is set) is added to every channel.
        """
        if isinstance(index, int):
            idx = index
        else:
            idx, _ = index

        patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
        nidx = patch_loc_list[0]

        imgs = self._scaled_data[scaled_index][nidx]
        imgs = tuple(imgs[None, ..., i] for i in range(imgs.shape[-1]))
        if self._noise_data is not None:
            noisedata = self._scaled_noise_data[scaled_index][nidx]
            noise = tuple(noisedata[None, ..., i] for i in range(noisedata.shape[-1]))
            factor = np.sqrt(2) if self._input_is_sum else 1.0
            imgs = tuple(img + noise[0] * factor for img in imgs)
        return imgs

    def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
        """Crop ``img`` at ``patch_start_loc`` via ``_crop_img_with_padding``.

        Here, h_start and w_start could be negative. That simply means the
        content is picked from 0, so the cropped image will be smaller than
        ``self._img_sz`` x ``self._img_sz``.
        """
        max_len_vals = list(self.idx_manager.data_shape[1:-1])
        # Spatial limits come from ``img`` itself so the same method works for
        # the lower-resolution pyramid levels, which are smaller than data_shape.
        max_len_vals[-2:] = img.shape[-2:]
        return self._crop_img_with_padding(
            img, patch_start_loc, max_len_vals=max_len_vals
        )

    def _get_img(self, index: Union[int, tuple[int, int]]):
        """Return the primary patch along with low-resolution patches centred on it.

        Returns
        -------
        tuple
            ``(output_img_tuples, cropped_noise_tuples)``. Each element of
            ``output_img_tuples`` is one channel: the full-resolution crop
            concatenated along axis 0 with ``multiscale_lowres_count - 1``
            lower-resolution crops centred on the same location.
        """
        # noise_tuples is populated when there is synthetic noise in training;
        # it should be of a similar type to the noise model. Starting with
        # MicroSplit, the noise is dropped and instead used as an augmentation
        # if necessary.
        img_tuples, noise_tuples = self._load_img(index)
        assert self._img_sz is not None
        h, w = img_tuples[0].shape[-2:]
        if self._enable_random_cropping:
            patch_start_loc = self._get_random_hw(h, w)
            if self._5Ddata:
                z_start = np.random.choice(img_tuples[0].shape[-3] - self._depth3D)
                patch_start_loc = (z_start,) + patch_start_loc
        else:
            patch_start_loc = self._get_deterministic_loc(index)

        # LC logic is located here: first crop the highest-resolution level.
        cropped_img_tuples = [
            self._crop_flip_img(img, patch_start_loc, False, False)
            for img in img_tuples
        ]
        cropped_noise_tuples = [
            self._crop_flip_img(noise, patch_start_loc, False, False)
            for noise in noise_tuples
        ]

        patch_start_loc = list(patch_start_loc)
        h_start, w_start = patch_start_loc[-2], patch_start_loc[-1]
        h_center = h_start + self._img_sz // 2
        w_center = w_start + self._img_sz // 2
        allres_versions = {
            i: [cropped_img_tuples[i]] for i in range(len(cropped_img_tuples))
        }
        for scale_idx in range(1, self.multiscale_lowres_count):
            # Full image at the next lower resolution.
            scaled_img_tuples = self._load_scaled_img(scale_idx, index)

            # Each level halves the spatial size, so the patch centre moves to
            # half its previous coordinates; crop an _img_sz patch around it.
            h_center = h_center // 2
            w_center = w_center // 2
            h_start = h_center - self._img_sz // 2
            w_start = w_center - self._img_sz // 2
            patch_start_loc[-2:] = [h_start, w_start]
            scaled_cropped_img_tuples = [
                self._crop_flip_img(img, patch_start_loc, False, False)
                for img in scaled_img_tuples
            ]
            for ch_idx in range(len(img_tuples)):
                allres_versions[ch_idx].append(scaled_cropped_img_tuples[ch_idx])

        output_img_tuples = tuple(
            np.concatenate(allres_versions[ch_idx])
            for ch_idx in range(len(img_tuples))
        )
        return output_img_tuples, cropped_noise_tuples

    def __getitem__(self, index: Union[int, tuple[int, int]]):
        """Return ``(input, normalized_target[, alpha][, grid_size])`` for ``index``.

        ``alpha`` is appended when ``self._return_alpha`` is set. ``grid_size``
        is appended when ``index`` is an ``(idx, grid_size)`` tuple.
        """
        img_tuples, noise_tuples = self._get_img(index)
        if self._uncorrelated_channels:
            assert (
                self._input_idx is None
            ), "Uncorrelated channels is not implemented when there is a separate input channel."
            if np.random.rand() < self._uncorrelated_channel_probab:
                # Keep channel 0 from this patch and take every other channel
                # from an independently drawn random patch.
                img_tuples_new = [None] * len(img_tuples)
                img_tuples_new[0] = img_tuples[0]
                for i in range(1, len(img_tuples)):
                    new_index = np.random.randint(len(self))
                    img_tuples_tmp, _ = self._get_img(new_index)
                    img_tuples_new[i] = img_tuples_tmp[i]
                img_tuples = img_tuples_new

        if self._is_train:
            if self._empty_patch_replacement_enabled:
                if np.random.rand() < self._empty_patch_replacement_probab:
                    img_tuples = self.replace_with_empty_patch(img_tuples)

        if self._enable_rotation:
            img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)

        # Add noise to the input if present. The factor keeps the computed
        # input from getting too much noise, since it averages two gaussians.
        if len(noise_tuples) > 0:
            factor = np.sqrt(2) if self._input_is_sum else 1.0
            input_tuples = []
            for x in img_tuples:
                # Copy so the original patch, later used for the target, is untouched.
                x = x.copy()
                # NOTE: other LC levels already have noise added (see
                # _load_scaled_img), so only the highest resolution gets it here.
                x[0] = x[0] + noise_tuples[0] * factor
                input_tuples.append(x)
        else:
            input_tuples = img_tuples

        # Compute the input by summing/averaging the channels. Alpha is the
        # weight applied to the channels when combining them; how to sample it
        # is still under research.
        inp, alpha = self._compute_input(input_tuples)
        # The target uses only the full-resolution level of each channel.
        target_tuples = [img[:1] for img in img_tuples]
        # Add noise to the target; noise channels 1.. pair with the channels.
        if len(noise_tuples) >= 1:
            target_tuples = [
                x + noise for x, noise in zip(target_tuples, noise_tuples[1:])
            ]

        target = self._compute_target(target_tuples, alpha)

        norm_target = self.normalize_target(target)

        output = [inp, norm_target]

        if self._return_alpha:
            output.append(alpha)

        if isinstance(index, int):
            return tuple(output)

        _, grid_size = index
        output.append(grid_size)
        return tuple(output)