careamics-0.0.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/lvae_training/dataset/multicrop_dset.py
@@ -0,0 +1,196 @@
+ """
+ Here, we have multiple folders, each containing images of a single channel.
+ """
+
+ from collections import defaultdict
+ from functools import cache
+
+ import numpy as np
+
+ from .types import DataSplitType
+
+
+ def l2(x):
+     return np.sqrt(np.mean(np.array(x) ** 2))
+
+
+ class MultiCropDset:
+     def __init__(
+         self,
+         data_config,
+         fpath: str,
+         load_data_fn=None,
+         val_fraction=None,
+         test_fraction=None,
+     ):
+
+         assert (
+             data_config.input_is_sum == True
+         ), "This dataset is designed for sum of images"
+
+         self._img_sz = data_config.image_size
+         self._enable_rotation = data_config.enable_rotation_aug
+
+         self._background_values = data_config.background_values
+         self._data = load_data_fn(
+             data_config, fpath, data_config.datasplit_type, val_fraction, test_fraction
+         )
+
+         # remove upper quantiles, crucial for removing puncta
+         self.max_val = data_config.max_val
+         if self.max_val is not None:
+             for ch_idx, data in enumerate(self._data):
+                 if self.max_val[ch_idx] is not None:
+                     for idx in range(len(data)):
+                         data[idx][data[idx] > self.max_val[ch_idx]] = self.max_val[
+                             ch_idx
+                         ]
+
+         # remove background values
+         if self._background_values is not None:
+             final_data_arr = []
+             for ch_idx, data in enumerate(self._data):
+                 data_float = [x.astype(np.float32) for x in data]
+                 final_data_arr.append(
+                     [x - self._background_values[ch_idx] for x in data_float]
+                 )
+             self._data = final_data_arr
+
+         print(
+             f"{self.__class__.__name__} N:{len(self)} Rot:{self._enable_rotation} Ch:{len(self._data)} MaxVal:{self.max_val} Bg:{self._background_values}"
+         )
+
+     def get_max_val(self):
+         return self.max_val
+
+     def compute_mean_std(self):
+         mean_tar_dict = defaultdict(list)
+         std_tar_dict = defaultdict(list)
+         mean_inp = []
+         std_inp = []
+         for _ in range(30000):
+             crops = []
+             for ch_idx in range(len(self._data)):
+                 crop = self.sample_crop(ch_idx)
+                 mean_tar_dict[ch_idx].append(np.mean(crop))
+                 std_tar_dict[ch_idx].append(np.std(crop))
+                 crops.append(crop)
+
+             inp = 0
+             for img in crops:
+                 inp += img
+
+             mean_inp.append(np.mean(inp))
+             std_inp.append(np.std(inp))
+
+         output_mean = defaultdict(list)
+         output_std = defaultdict(list)
+         NC = len(self._data)
+         for ch_idx in range(NC):
+             output_mean["target"].append(np.mean(mean_tar_dict[ch_idx]))
+             output_std["target"].append(l2(std_tar_dict[ch_idx]))
+
+         output_mean["target"] = np.array(output_mean["target"]).reshape(NC, 1, 1)
+         output_std["target"] = np.array(output_std["target"]).reshape(NC, 1, 1)
+
+         output_mean["input"] = np.array([np.mean(mean_inp)]).reshape(1, 1, 1)
+         output_std["input"] = np.array([l2(std_inp)]).reshape(1, 1, 1)
+         return dict(output_mean), dict(output_std)
+
+     def set_mean_std(self, mean_dict, std_dict):
+         self._data_mean = mean_dict
+         self._data_std = std_dict
+
+     def get_mean_std(self):
+         return self._data_mean, self._data_std
+
+     def get_num_frames(self):
+         return len(self._data)
+
+     @cache
+     def crop_probablities(self, ch_idx):
+         sizes = np.array([np.prod(x.shape) for x in self._data[ch_idx]])
+         return sizes / sizes.sum()
+
+     def sample_crop(self, ch_idx):
+         idx = None
+         count = 0
+         while idx is None:
+             count += 1
+             idx = np.random.choice(
+                 len(self._data[ch_idx]), p=self.crop_probablities(ch_idx)
+             )
+             data = self._data[ch_idx][idx]
+             if data.shape[0] >= self._img_sz[0] and data.shape[1] >= self._img_sz[1]:
+                 h = np.random.randint(0, data.shape[0] - self._img_sz[0])
+                 w = np.random.randint(0, data.shape[1] - self._img_sz[1])
+                 return data[h : h + self._img_sz[0], w : w + self._img_sz[1]]
+             elif count > 100:
+                 raise ValueError("Cannot find a valid crop")
+             else:
+                 idx = None
+
+         return None
+
+     def len_per_channel(self, ch_idx):
+         return np.sum([np.prod(x.shape) for x in self._data[ch_idx]]) / np.prod(
+             self._img_sz
+         )
+
+     def imgs_for_patch(self):
+         return [self.sample_crop(ch_idx) for ch_idx in range(len(self._data))]
+
+     def __len__(self):
+         len_per_channel = [
+             self.len_per_channel(ch_idx) for ch_idx in range(len(self._data))
+         ]
+         return int(np.max(len_per_channel))
+
+     def _rotate(self, img_tuples):
+         return self._rotate2D(img_tuples)
+
+     def _rotate2D(self, img_tuples):
+         img_kwargs = {}
+         for i, img in enumerate(img_tuples):
+             for k in range(len(img)):
+                 img_kwargs[f"img{i}_{k}"] = img[k]
+
+         keys = list(img_kwargs.keys())
+         self._rotation_transform.add_targets({k: "image" for k in keys})
+         rot_dic = self._rotation_transform(image=img_tuples[0][0], **img_kwargs)
+
+         rotated_img_tuples = []
+         for i, img in enumerate(img_tuples):
+             if len(img) == 1:
+                 rotated_img_tuples.append(rot_dic[f"img{i}_0"][None])
+             else:
+                 rotated_img_tuples.append(
+                     np.concatenate(
+                         [rot_dic[f"img{i}_{k}"][None] for k in range(len(img))], axis=0
+                     )
+                 )
+
+         return rotated_img_tuples
+
+     def _compute_input(self, imgs):
+         inp = 0
+         for img in imgs:
+             inp += img
+
+         inp = inp[None]
+         inp = (inp - self._data_mean["input"]) / (self._data_std["input"])
+         return inp
+
+     def _compute_target(self, imgs):
+         imgs = np.stack(imgs)
+         target = (imgs - self._data_mean["target"]) / (self._data_std["target"])
+         return target
+
+     def __getitem__(self, idx):
+         imgs = self.imgs_for_patch()
+         if self._enable_rotation:
+             imgs = self._rotate(imgs)
+
+         inp = self._compute_input(imgs)
+         target = self._compute_target(imgs)
+         return inp, target
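To make the flow of this dataset concrete, here is a minimal, self-contained sketch that drives MultiCropDset with synthetic data. The SimpleNamespace config and fake_load_fn below are illustration-only stand-ins, not careamics APIs; only the attribute names actually read by __init__ above are assumed.

# Illustration only: a hand-rolled config namespace and loader stand in for the real
# careamics config/loading machinery, which this sketch does not use.
from types import SimpleNamespace

import numpy as np

from careamics.lvae_training.dataset.multicrop_dset import MultiCropDset
from careamics.lvae_training.dataset.types import DataSplitType


def fake_load_fn(config, fpath, split, val_fraction, test_fraction):
    # Two channels, each a list of differently sized 2D frames.
    rng = np.random.default_rng(0)
    return [
        [rng.normal(size=(128, 128)), rng.normal(size=(96, 160))],  # channel 0
        [rng.normal(size=(128, 128)), rng.normal(size=(96, 160))],  # channel 1
    ]


config = SimpleNamespace(
    input_is_sum=True,               # required by the assert in __init__
    image_size=(64, 64),
    enable_rotation_aug=False,
    background_values=None,
    datasplit_type=DataSplitType.Train,
    max_val=None,
)

dset = MultiCropDset(config, "unused_path", load_data_fn=fake_load_fn)
# Normalization statistics must be set before __getitem__ is used.
dset.set_mean_std(
    {"target": np.zeros((2, 1, 1)), "input": np.zeros((1, 1, 1))},
    {"target": np.ones((2, 1, 1)), "input": np.ones((1, 1, 1))},
)
inp, target = dset[0]
print(inp.shape, target.shape)  # (1, 64, 64) summed input, (2, 64, 64) per-channel target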
careamics/lvae_training/dataset/multifile_dataset.py
@@ -0,0 +1,335 @@
+ from collections.abc import Sequence
+ from typing import Callable, Union
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from .config import MicroSplitDataConfig
+ from .lc_dataset import LCMultiChDloader
+ from .multich_dataset import MultiChDloader
+ from .types import DataSplitType
+
+
+ class TwoChannelData(Sequence):
+     """
+     Each element in data_arr should be a N*H*W array.
+     """
+
+     def __init__(self, data_arr1, data_arr2, paths_data1=None, paths_data2=None):
+         assert len(data_arr1) == len(data_arr2)
+         self.paths1 = paths_data1
+         self.paths2 = paths_data2
+
+         self._data = []
+         for i in range(len(data_arr1)):
+             assert data_arr1[i].shape == data_arr2[i].shape
+             assert (
+                 len(data_arr1[i].shape) == 3
+             ), f"Each element in data arrays should be a N*H*W, but {data_arr1[i].shape}"
+             self._data.append(
+                 np.concatenate(
+                     [data_arr1[i][..., None], data_arr2[i][..., None]], axis=-1
+                 )
+             )
+
+     def __len__(self):
+         n = 0
+         for x in self._data:
+             n += x.shape[0]
+         return n
+
+     def __getitem__(self, idx):
+         n = 0
+         for dataidx, x in enumerate(self._data):
+             if idx < n + x.shape[0]:
+                 if self.paths1 is None:
+                     return x[idx - n], None
+                 else:
+                     return x[idx - n], (self.paths1[dataidx], self.paths2[dataidx])
+             n += x.shape[0]
+         raise IndexError("Index out of range")
+
+
+ class MultiChannelData(Sequence):
+     """
+     Each element in data_arr should be a N*H*W array.
+     """
+
+     def __init__(self, data_arr, paths=None):
+         self.paths = paths
+
+         self._data = data_arr
+
+     def __len__(self):
+         n = 0
+         for x in self._data:
+             n += x.shape[0]
+         return n
+
+     def __getitem__(self, idx):
+         n = 0
+         for dataidx, x in enumerate(self._data):
+             if idx < n + x.shape[0]:
+                 if self.paths is None:
+                     return x[idx - n], None
+                 else:
+                     return x[idx - n], (self.paths[dataidx])
+             n += x.shape[0]
+         raise IndexError("Index out of range")
+
+
+ class SingleFileLCDset(LCMultiChDloader):
+     def __init__(
+         self,
+         preloaded_data: NDArray,
+         data_config: MicroSplitDataConfig,
+         fpath: str,
+         load_data_fn: Callable,
+         val_fraction=None,
+         test_fraction=None,
+     ):
+         self._preloaded_data = preloaded_data
+         super().__init__(
+             data_config,
+             fpath,
+             load_data_fn=load_data_fn,
+             val_fraction=val_fraction,
+             test_fraction=test_fraction,
+         )
+
+     @property
+     def data_path(self):
+         return self._fpath
+
+     def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
+         pass
+
+     def load_data(
+         self,
+         data_config: MicroSplitDataConfig,
+         datasplit_type: DataSplitType,
+         load_data_fn: Callable,
+         val_fraction=None,
+         test_fraction=None,
+         allow_generation=None,
+     ):
+         self._data = self._preloaded_data
+         assert "channel_1" not in data_config or isinstance(data_config.channel_1, str)
+         assert "channel_2" not in data_config or isinstance(data_config.channel_2, str)
+         assert "channel_3" not in data_config or isinstance(data_config.channel_3, str)
+         self._loaded_data_preprocessing(data_config)
+
+
+ class SingleFileDset(MultiChDloader):
+     def __init__(
+         self,
+         preloaded_data: NDArray,
+         data_config: MicroSplitDataConfig,
+         fpath: str,
+         load_data_fn: Callable,
+         val_fraction=None,
+         test_fraction=None,
+     ):
+         self._preloaded_data = preloaded_data
+         super().__init__(
+             data_config,
+             fpath,
+             load_data_fn=load_data_fn,
+             val_fraction=val_fraction,
+             test_fraction=test_fraction,
+         )
+
+     def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
+         pass
+
+     @property
+     def data_path(self):
+         return self._fpath
+
+     def load_data(
+         self,
+         data_config: MicroSplitDataConfig,
+         datasplit_type: DataSplitType,
+         load_data_fn: Callable[..., NDArray],
+         val_fraction=None,
+         test_fraction=None,
+         allow_generation=None,
+     ):
+         self._data = self._preloaded_data
+         assert (
+             "channel_1" not in data_config
+         ), "Outdated config file. Please remove channel_1, channel_2, channel_3 from the config file."
+         assert (
+             "channel_2" not in data_config
+         ), "Outdated config file. Please remove channel_1, channel_2, channel_3 from the config file."
+         assert (
+             "channel_3" not in data_config
+         ), "Outdated config file. Please remove channel_1, channel_2, channel_3 from the config file."
+         self._loaded_data_preprocessing(data_config)
+
+
+ class MultiFileDset:
+     """
+     Here, we handle datasets having multiple files. Each file can have a different spatial dimension and number of frames (Z stack).
+     """
+
+     def __init__(
+         self,
+         data_config: MicroSplitDataConfig,
+         fpath: str,
+         load_data_fn: Callable[..., Union[TwoChannelData, MultiChannelData]],
+         val_fraction=None,
+         test_fraction=None,
+     ):
+         self._fpath = fpath
+         data: Union[TwoChannelData, MultiChannelData] = load_data_fn(
+             data_config,
+             self._fpath,
+             data_config.datasplit_type,
+             val_fraction=val_fraction,
+             test_fraction=test_fraction,
+         )
+         self.dsets = []
+
+         for i in range(len(data)):
+             prefetched_data, fpath_tuple = data[i]
+             if (
+                 data_config.multiscale_lowres_count is not None
+                 and data_config.multiscale_lowres_count > 1
+             ):
+
+                 self.dsets.append(
+                     SingleFileLCDset(
+                         prefetched_data[None],
+                         data_config,
+                         fpath_tuple,
+                         load_data_fn,
+                         val_fraction=val_fraction,
+                         test_fraction=test_fraction,
+                     )
+                 )
+
+             else:
+                 self.dsets.append(
+                     SingleFileDset(
+                         prefetched_data[None],
+                         data_config,
+                         fpath_tuple,
+                         load_data_fn,
+                         val_fraction=val_fraction,
+                         test_fraction=test_fraction,
+                     )
+                 )
+
+         self.rm_bkground_set_max_val_and_upperclip_data(
+             data_config.max_val, data_config.datasplit_type
+         )
+         count = 0
+         avg_height = 0
+         avg_width = 0
+         for dset in self.dsets:
+             shape = dset.get_data_shape()
+             avg_height += shape[1]
+             avg_width += shape[2]
+             count += shape[0]
+
+         avg_height = int(avg_height / len(self.dsets))
+         avg_width = int(avg_width / len(self.dsets))
+         print(
+             f"{self.__class__.__name__} avg height: {avg_height}, avg width: {avg_width}, count: {count}"
+         )
+
+     def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
+         self.set_max_val(max_val, datasplit_type)
+         self.upperclip_data()
+
+     def set_mean_std(self, mean_val, std_val):
+         for dset in self.dsets:
+             dset.set_mean_std(mean_val, std_val)
+
+     def get_mean_std(self):
+         return self.dsets[0].get_mean_std()
+
+     def compute_max_val(self):
+         max_val_arr = []
+         for dset in self.dsets:
+             max_val_arr.append(dset.compute_max_val())
+         return np.max(max_val_arr)
+
+     def set_max_val(self, max_val, datasplit_type):
+         if datasplit_type == DataSplitType.Train:
+             assert max_val is None
+             max_val = self.compute_max_val()
+         for dset in self.dsets:
+             dset.set_max_val(max_val, datasplit_type)
+
+     def upperclip_data(self):
+         for dset in self.dsets:
+             dset.upperclip_data()
+
+     def get_max_val(self):
+         return self.dsets[0].get_max_val()
+
+     def get_img_sz(self):
+         return self.dsets[0].get_img_sz()
+
+     def set_img_sz(self, image_size, grid_size):
+         for dset in self.dsets:
+             dset.set_img_sz(image_size, grid_size)
+
+     def compute_mean_std(self):
+         cur_mean = {"target": 0, "input": 0}
+         cur_std = {"target": 0, "input": 0}
+         for dset in self.dsets:
+             mean, std = dset.compute_mean_std()
+             cur_mean["target"] += mean["target"]
+             cur_mean["input"] += mean["input"]
+
+             cur_std["target"] += std["target"]
+             cur_std["input"] += std["input"]
+
+         cur_mean["target"] /= len(self.dsets)
+         cur_mean["input"] /= len(self.dsets)
+         cur_std["target"] /= len(self.dsets)
+         cur_std["input"] /= len(self.dsets)
+         return cur_mean, cur_std
+
+     def compute_individual_mean_std(self):
+         cum_mean = 0
+         cum_std = 0
+         for dset in self.dsets:
+             mean, std = dset.compute_individual_mean_std()
+             cum_mean += mean
+             cum_std += std
+         return cum_mean / len(self.dsets), cum_std / len(self.dsets)
+
+     def get_num_frames(self):
+         return len(self.dsets)
+
+     def reduce_data(
+         self, t_list=None, h_start=None, h_end=None, w_start=None, w_end=None
+     ):
+         assert h_start is None
+         assert h_end is None
+         assert w_start is None
+         assert w_end is None
+         self.dsets = [self.dsets[t] for t in t_list]
+         print(
+             f"[{self.__class__.__name__}] Data reduced. New data count: {len(self.dsets)}"
+         )
+
+     def __len__(self):
+         out = 0
+         for dset in self.dsets:
+             out += len(dset)
+         return out
+
+     def __getitem__(self, idx):
+         cum_len = 0
+         for dset in self.dsets:
+             cum_len += len(dset)
+             if idx < cum_len:
+                 rel_idx = idx - (cum_len - len(dset))
+                 return dset[rel_idx]
+
+         raise IndexError("Index out of range")
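As a quick illustration of the flat indexing these container classes implement (one global index running over all files, with the two channels stacked on the last axis), here is a small synthetic example; the arrays and file names below are made up purely for the sketch.

# Illustration of TwoChannelData's flat indexing over multiple files (synthetic arrays).
import numpy as np

from careamics.lvae_training.dataset.multifile_dataset import TwoChannelData

# Two "files": 3 frames of 32x32 and 2 frames of 48x48, for each of the two channels.
ch1 = [np.zeros((3, 32, 32)), np.zeros((2, 48, 48))]
ch2 = [np.ones((3, 32, 32)), np.ones((2, 48, 48))]

data = TwoChannelData(ch1, ch2, paths_data1=["a1.tif", "b1.tif"], paths_data2=["a2.tif", "b2.tif"])

assert len(data) == 5              # 3 + 2 frames across the two files
frame, paths = data[4]             # the last frame comes from the second file
assert frame.shape == (48, 48, 2)  # channels are stacked on the last axis
assert paths == ("b1.tif", "b2.tif")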
careamics/lvae_training/dataset/types.py
@@ -0,0 +1,32 @@
+ from enum import Enum
+
+
+ class DataType(Enum):
+     HTH24Data = 0
+     HTLIF24Data = 1
+     PaviaP24Data = 2
+     TavernaSox2GolgiV2 = 3
+     Dao3ChannelWithInput = 4
+     ExpMicroscopyV1 = 5
+     ExpMicroscopyV2 = 6
+     Dao3Channel = 7
+     TavernaSox2Golgi = 8
+     HTIba1Ki67 = 9
+     OptiMEM100_014 = 10
+     SeparateTiffData = 11
+     BioSR_MRC = 12
+     HTH23BData = 13  # puncta, in case we have differently sized crops for each channel.
+     Care3D = 14
+
+
+ class DataSplitType(Enum):
+     All = 0
+     Train = 1
+     Val = 2
+     Test = 3
+
+
+ class TilingMode(Enum):
+     TrimBoundary = 0
+     PadBoundary = 1
+     ShiftBoundary = 2
careamics/lvae_training/dataset/utils/__init__.py
File without changes (empty file).
careamics/lvae_training/dataset/utils/data_utils.py
@@ -0,0 +1,114 @@
+ """
+ Utility functions needed by dataloader & co.
+ """
+
+ from typing import List
+
+ import numpy as np
+ from skimage.io import imread, imsave
+
+
+ def load_tiff(path):
+     """
+     Returns a 4d numpy array: num_imgs*h*w*num_channels
+     """
+     data = imread(path, plugin="tifffile")
+     return data
+
+
+ def save_tiff(path, data):
+     imsave(path, data, plugin="tifffile")
+
+
+ def load_tiffs(paths):
+     data = [load_tiff(path) for path in paths]
+     return np.concatenate(data, axis=0)
+
+
+ def split_in_half(s, e):
+     n = e - s
+     s1 = list(np.arange(n // 2))
+     s2 = list(np.arange(n // 2, n))
+     return [x + s for x in s1], [x + s for x in s2]
+
+
+ def adjust_for_imbalance_in_fraction_value(
+     val: List[int],
+     test: List[int],
+     val_fraction: float,
+     test_fraction: float,
+     total_size: int,
+ ):
+     """
+     Here, val and test are divided almost equally. We need to take into account their respective fractions
+     and pick elements randomly from one array and put them in the other array.
+     """
+     if val_fraction == 0:
+         test += val
+         val = []
+     elif test_fraction == 0:
+         val += test
+         test = []
+     else:
+         diff_fraction = test_fraction - val_fraction
+         if diff_fraction > 0:
+             imb_count = int(diff_fraction * total_size / 2)
+             val = list(np.random.RandomState(seed=955).permutation(val))
+             test += val[:imb_count]
+             val = val[imb_count:]
+         elif diff_fraction < 0:
+             imb_count = int(-1 * diff_fraction * total_size / 2)
+             test = list(np.random.RandomState(seed=955).permutation(test))
+             val += test[:imb_count]
+             test = test[imb_count:]
+     return val, test
+
+
+ def get_datasplit_tuples(
+     val_fraction: float,
+     test_fraction: float,
+     total_size: int,
+     starting_test: bool = False,
+ ):
+     if starting_test:
+         # test => val => train
+         test = list(range(0, int(total_size * test_fraction)))
+         val = list(range(test[-1] + 1, test[-1] + 1 + int(total_size * val_fraction)))
+         train = list(range(val[-1] + 1, total_size))
+     else:
+         # {test,val} => train
+         test_val_size = int((val_fraction + test_fraction) * total_size)
+         train = list(range(test_val_size, total_size))
+
+         if test_val_size == 0:
+             test = []
+             val = []
+             return train, val, test
+
+         # Split the test and validation in chunks.
+         chunksize = max(1, min(3, test_val_size // 2))
+
+         nchunks = test_val_size // chunksize
+
+         test = []
+         val = []
+         s = 0
+         for i in range(nchunks):
+             if i % 2 == 0:
+                 val += list(np.arange(s, s + chunksize))
+             else:
+                 test += list(np.arange(s, s + chunksize))
+             s += chunksize
+
+         if i % 2 == 0:
+             test += list(np.arange(s, test_val_size))
+         else:
+             p1, p2 = split_in_half(s, test_val_size)
+             test += p1
+             val += p2
+
+     val, test = adjust_for_imbalance_in_fraction_value(
+         val, test, val_fraction, test_fraction, total_size
+     )
+
+     return train, val, test
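A short, illustrative sanity check of the split logic above (not shipped with the package): with equal val and test fractions and the default starting_test=False, the first test_val_size indices are dealt out in alternating chunks of three between validation and test, and the remaining indices form the training set.

# Illustrative call: 10% validation, 10% test over 100 frames.
from careamics.lvae_training.dataset.utils.data_utils import get_datasplit_tuples

train, val, test = get_datasplit_tuples(val_fraction=0.1, test_fraction=0.1, total_size=100)

print(len(train), len(val), len(test))  # 80 10 10
# Every index 0..99 lands in exactly one of the three splits.
print(sorted(set(train) | set(val) | set(test)) == list(range(100)))  # True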