careamics-0.0.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/lvae_training/dataset/multich_dataset.py
@@ -0,0 +1,1121 @@
+ """
+ A place for Datasets and Dataloaders.
+ """
+
+ from pathlib import Path
+ from typing import Any, Callable, Optional, Union
+
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+
+ from .utils.empty_patch_fetcher import EmptyPatchFetcher
+ from .utils.index_manager import GridIndexManager
+ from .utils.index_switcher import IndexSwitcher
+ from .config import MicroSplitDataConfig
+ from .types import DataSplitType, TilingMode
+
+
+ class MultiChDloader(Dataset):
+     """Multi-channel dataset loader."""
+
+     def __init__(
+         self,
+         data_config: MicroSplitDataConfig,
+         datapath: Union[str, Path],
+         load_data_fn: Optional[Callable] = None,
+         val_fraction: float = 0.1,
+         test_fraction: float = 0.1,
+         allow_generation: bool = False,
+     ):
+         """Build the dataset from the data at `datapath` according to `data_config`."""
+         self._data_type = data_config.data_type
+         self._fpath = datapath
+         self._data = self._noise_data = None
+         self.Z = 1
+         self._5Ddata = False
+         self._tiling_mode = data_config.tiling_mode
+         # By default, if noise is present, add it to both the input and the target.
+         self._disable_noise = False  # to add synthetic noise
+         self._poisson_noise_factor = None
+         self._train_index_switcher = None
+         self._depth3D = data_config.depth3D
+         self._mode_3D = data_config.mode_3D
+         # NOTE: The input is the sum of the different channels, not their average.
+         self._input_is_sum = data_config.input_is_sum
+         self._num_channels = data_config.num_channels
+         self._input_idx = data_config.input_idx
+         self._tar_idx_list = data_config.target_idx_list
+
+         if data_config.datasplit_type == DataSplitType.Train:
+             self._datausage_fraction = data_config.trainig_datausage_fraction
+             # assert self._datausage_fraction == 1.0, 'Not supported. Use validtarget_random_fraction and training_validtarget_fraction to get the same effect'
+             self._validtarget_rand_fract = data_config.validtarget_random_fraction
+             # self._validtarget_random_fraction_final = data_config.get('validtarget_random_fraction_final', None)
+             # self._validtarget_random_fraction_stepepoch = data_config.get('validtarget_random_fraction_stepepoch', None)
+             # self._idx_count = 0
+         elif data_config.datasplit_type == DataSplitType.Val:
+             self._datausage_fraction = data_config.validation_datausage_fraction
+         else:
+             self._datausage_fraction = 1.0
+
+         self.load_data(
+             data_config,
+             data_config.datasplit_type,
+             load_data_fn=load_data_fn,
+             val_fraction=val_fraction,
+             test_fraction=test_fraction,
+             allow_generation=data_config.allow_generation,
+         )
+         self._normalized_input = data_config.normalized_input
+         self._quantile = 1.0
+         self._channelwise_quantile = False
+         self._background_quantile = 0.0
+         self._clip_background_noise_to_zero = False
+         self._skip_normalization_using_mean = False
+         self._empty_patch_replacement_enabled = False
+
+         self._background_values = None
+
+         self._overlapping_padding_kwargs = data_config.overlapping_padding_kwargs
+         if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
+             if (
+                 self._overlapping_padding_kwargs is None
+                 or data_config.multiscale_lowres_count is not None
+             ):
+                 # raise warning
+                 print("Padding is not used with this alignment style")
+         else:
+             assert (
+                 self._overlapping_padding_kwargs is not None
+             ), "When not trimming the boundary, padding is needed."
+
+         self._is_train = data_config.datasplit_type == DataSplitType.Train
+
+         # input = alpha * ch1 + (1 - alpha) * ch2;
+         # alpha is sampled randomly between these two extremes.
+         self._start_alpha_arr = self._end_alpha_arr = self._return_alpha = None
+
+         self._img_sz = self._grid_sz = self._repeat_factor = self.idx_manager = None
+
+         # changed set_img_sz because "grid_size" in data_config returns false
+         try:
+             grid_size = data_config.grid_size
+         except AttributeError:
+             grid_size = data_config.image_size
+
+         if self._is_train:
+             self._start_alpha_arr = data_config.start_alpha
+             self._end_alpha_arr = data_config.end_alpha
+
+             self.set_img_sz(data_config.image_size, grid_size)
+
+             if self._validtarget_rand_fract is not None:
+                 self._train_index_switcher = IndexSwitcher(
+                     self.idx_manager, data_config, self._img_sz
+                 )
+         else:
+             self.set_img_sz(data_config.image_size, grid_size)
+
+         self._return_alpha = False
+         self._return_index = False
+
+         self._empty_patch_replacement_enabled = (
+             data_config.empty_patch_replacement_enabled and self._is_train
+         )
+         if self._empty_patch_replacement_enabled:
+             self._empty_patch_replacement_channel_idx = (
+                 data_config.empty_patch_replacement_channel_idx
+             )
+             self._empty_patch_replacement_probab = (
+                 data_config.empty_patch_replacement_probab
+             )
+             data_frames = self._data[..., self._empty_patch_replacement_channel_idx]
+             # NOTE: This works on the raw data, so it must be called before removing the background.
+             self._empty_patch_fetcher = EmptyPatchFetcher(
+                 self.idx_manager,
+                 self._img_sz,
+                 data_frames,
+                 max_val_threshold=data_config.empty_patch_max_val_threshold,
+             )
+
+         self.rm_bkground_set_max_val_and_upperclip_data(
+             data_config.max_val, data_config.datasplit_type
+         )
+
+         # For the overlapping dataloader, image_size and repeat_factors are not related; hence a different function.
+
+         self._mean = None
+         self._std = None
+         self._use_one_mu_std = data_config.use_one_mu_std
+
+         self._target_separate_normalization = data_config.target_separate_normalization
+
+         self._enable_rotation = data_config.enable_rotation_aug
+         flipz_3D = data_config.random_flip_z_3D
+         self._flipz_3D = flipz_3D and self._enable_rotation
+
+         self._enable_random_cropping = data_config.enable_random_cropping
+         self._uncorrelated_channels = (
+             data_config.uncorrelated_channels and self._is_train
+         )
+         self._uncorrelated_channel_probab = data_config.uncorrelated_channel_probab
+         assert self._is_train or self._uncorrelated_channels is False
+         assert (
+             self._enable_random_cropping is True or self._uncorrelated_channels is False
+         )
+         # Randomly rotate [-90, 90]
+
+         self._rotation_transform = None
+         if self._enable_rotation:
+             # TODO: fix this import
+             import albumentations as A
+
+             self._rotation_transform = A.Compose([A.Flip(), A.RandomRotate90()])
+
+         # TODO: remove print log messages
+         # if print_vars:
+         #     msg = self._init_msg()
+         #     print(msg)
+
+     def disable_noise(self):
+         assert (
+             self._poisson_noise_factor is None
+         ), "This is not supported. Poisson noise is added to the data itself, so it cannot be disabled."
+         self._disable_noise = True
+
+     def enable_noise(self):
+         self._disable_noise = False
+
+     def get_data_shape(self):
+         return self._data.shape
+
+     def load_data(
+         self,
+         data_config,
+         datasplit_type,
+         load_data_fn: Callable,
+         val_fraction=None,
+         test_fraction=None,
+         allow_generation=None,
+     ):
+         self._data = load_data_fn(
+             data_config,
+             self._fpath,
+             datasplit_type,
+             val_fraction=val_fraction,
+             test_fraction=test_fraction,
+             allow_generation=allow_generation,
+         )
+         self._loaded_data_preprocessing(data_config)
+
+     def _loaded_data_preprocessing(self, data_config):
+         old_shape = self._data.shape
+         if self._datausage_fraction < 1.0:
+             framepixelcount = np.prod(self._data.shape[1:3])
+             pixelcount = int(
+                 len(self._data) * framepixelcount * self._datausage_fraction
+             )
+             frame_count = int(np.ceil(pixelcount / framepixelcount))
+             last_frame_reduced_size, _ = IndexSwitcher.get_reduced_frame_size(
+                 self._data.shape[:3], self._datausage_fraction
+             )
+             self._data = self._data[:frame_count].copy()
+             if frame_count == 1:
+                 self._data = self._data[
+                     :, :last_frame_reduced_size, :last_frame_reduced_size
+                 ].copy()
+             print(
+                 f"[{self.__class__.__name__}] New data shape: {self._data.shape} Old: {old_shape}"
+             )
+
+         msg = ""
+         if data_config.poisson_noise_factor > 0:
+             self._poisson_noise_factor = data_config.poisson_noise_factor
+             msg += f"Adding Poisson noise with factor {self._poisson_noise_factor}.\t"
+             self._data = np.random.poisson(self._data / self._poisson_noise_factor)
+
+         if data_config.enable_gaussian_noise:
+             synthetic_scale = data_config.synthetic_gaussian_scale
+             msg += f"Adding Gaussian noise with scale {synthetic_scale}"
+             # Channel 0 => noise for the input; channels 1: => noise for the targets.
+             shape = self._data.shape
+             self._noise_data = np.random.normal(
+                 0, synthetic_scale, (*shape[:-1], shape[-1] + 1)
+             )
+             if data_config.input_has_dependant_noise:
+                 msg += ". Moreover, the input has dependent noise"
+                 self._noise_data[..., 0] = np.mean(self._noise_data[..., 1:], axis=-1)
+         print(msg)
+
+         if len(self._data.shape) == 5:
+             if self._mode_3D:
+                 self._5Ddata = True
+             else:
+                 assert self._depth3D == 1, "Depth3D must be 1 for 2D training"
+                 self._data = self._data.reshape(-1, *self._data.shape[2:])
+
+         if self._5Ddata:
+             self.Z = self._data.shape[1]
+
+         if self._depth3D > 1:
+             assert self._5Ddata, "Data must be 5D (NxZxHxWxC) for 3D data"
+
+         assert (
+             self._data.shape[-1] == self._num_channels
+         ), "Number of channels in data and config do not match."
+
+     def save_background(self, channel_idx, frame_idx, background_value):
+         self._background_values[frame_idx, channel_idx] = background_value
+
+     def get_background(self, channel_idx, frame_idx):
+         return self._background_values[frame_idx, channel_idx]
+
+     def remove_background(self):
+         self._background_values = np.zeros((self._data.shape[0], self._data.shape[-1]))
+
+         if self._background_quantile == 0.0:
+             assert (
+                 self._clip_background_noise_to_zero is False
+             ), "This operation currently happens later in this function."
+             return
+
+         if self._data.dtype in [np.uint16]:
+             # unsigned integers create havoc
+             self._data = self._data.astype(np.int32)
+
+         for ch in range(self._data.shape[-1]):
+             for idx in range(self._data.shape[0]):
+                 qval = np.quantile(self._data[idx, ..., ch], self._background_quantile)
+                 assert (
+                     np.abs(qval) > 20
+                 ), "We truncate qval to an integer, which only makes sense if it is large enough"
+                 # NOTE: There can be an issue here if you work with normalized data.
+                 qval = int(qval)
+                 self.save_background(ch, idx, qval)
+                 self._data[idx, ..., ch] -= qval
+
+         if self._clip_background_noise_to_zero:
+             self._data[self._data < 0] = 0
+
+     def rm_bkground_set_max_val_and_upperclip_data(self, max_val, datasplit_type):
+         self.remove_background()
+         self.set_max_val(max_val, datasplit_type)
+         self.upperclip_data()
+
+     def upperclip_data(self):
+         if isinstance(self.max_val, list):
+             chN = self._data.shape[-1]
+             assert chN == len(self.max_val)
+             for ch in range(chN):
+                 ch_data = self._data[..., ch]
+                 ch_q = self.max_val[ch]
+                 ch_data[ch_data > ch_q] = ch_q
+                 self._data[..., ch] = ch_data
+         else:
+             self._data[self._data > self.max_val] = self.max_val
+
+     def compute_max_val(self):
+         if self._channelwise_quantile:
+             max_val_arr = [
+                 np.quantile(self._data[..., i], self._quantile)
+                 for i in range(self._data.shape[-1])
+             ]
+             return max_val_arr
+         else:
+             return np.quantile(self._data, self._quantile)
+
+     def set_max_val(self, max_val, datasplit_type):
+         if max_val is None:
+             assert datasplit_type == DataSplitType.Train
+             self.max_val = self.compute_max_val()
+         else:
+             assert max_val is not None
+             self.max_val = max_val
+
+     def get_max_val(self):
+         return self.max_val
+
+     def get_img_sz(self):
+         return self._img_sz
+
+     def get_num_frames(self):
+         return self._data.shape[0]
+
+     def reduce_data(
+         self,
+         t_list=None,
+         z_start=None,
+         z_end=None,
+         h_start=None,
+         h_end=None,
+         w_start=None,
+         w_end=None,
+     ):
+         if self._5Ddata:
+             if t_list is None:
+                 t_list = list(range(self._data.shape[0]))
+             if z_start is None:
+                 z_start = 0
+             if z_end is None:
+                 z_end = self._data.shape[1]
+             if h_start is None:
+                 h_start = 0
+             if h_end is None:
+                 h_end = self._data.shape[2]
+             if w_start is None:
+                 w_start = 0
+             if w_end is None:
+                 w_end = self._data.shape[3]
+             self._data = self._data[
+                 t_list, z_start:z_end, h_start:h_end, w_start:w_end, :
+             ].copy()
+             if self._noise_data is not None:
+                 self._noise_data = self._noise_data[
+                     t_list, z_start:z_end, h_start:h_end, w_start:w_end, :
+                 ].copy()
+         else:
+             if t_list is None:
+                 t_list = list(range(self._data.shape[0]))
+             if h_start is None:
+                 h_start = 0
+             if h_end is None:
+                 h_end = self._data.shape[1]
+             if w_start is None:
+                 w_start = 0
+             if w_end is None:
+                 w_end = self._data.shape[2]
+
+             self._data = self._data[t_list, h_start:h_end, w_start:w_end, :].copy()
+             if self._noise_data is not None:
+                 self._noise_data = self._noise_data[
+                     t_list, h_start:h_end, w_start:w_end, :
+                 ].copy()
+         # TODO: clarify where self._img_sz is first set.
+         self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
+         print(
+             f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
+         )
+
+     def get_idx_manager_shapes(
+         self, patch_size: int, grid_size: Union[int, tuple[int, int, int]]
+     ):
+         numC = self._data.shape[-1]
+         if self._5Ddata:
+             patch_shape = (1, self._depth3D, patch_size, patch_size, numC)
+             if isinstance(grid_size, int):
+                 grid_shape = (1, 1, grid_size, grid_size, numC)
+             else:
+                 assert len(grid_size) == 3
+                 assert all(
+                     [g <= p for g, p in zip(grid_size, patch_shape[1:-1])]
+                 ), f"Grid size {grid_size} must not exceed patch size {patch_shape[1:-1]}"
+                 grid_shape = (1, grid_size[0], grid_size[1], grid_size[2], numC)
+         else:
+             assert isinstance(grid_size, int)
+             grid_shape = (1, grid_size, grid_size, numC)
+             patch_shape = (1, patch_size, patch_size, numC)
+
+         return patch_shape, grid_shape
+
+     def set_img_sz(self, image_size, grid_size: Union[int, tuple[int, int, int]]):
+         """
+         Can be used to change the image size on the fly.
+
+         Args:
+             image_size: size of one patch.
+             grid_size: the frame is divided into square grids of this size. A patch of size `image_size`, centered on a grid cell, is returned.
+         """
+         # hacky way to deal with the image shape from the new config
+         self._img_sz = image_size[-1]  # TODO: revisit!
+         self._grid_sz = grid_size
+         shape = self._data.shape
+
+         patch_shape, grid_shape = self.get_idx_manager_shapes(
+             self._img_sz, self._grid_sz
+         )
+         self.idx_manager = GridIndexManager(
+             shape, grid_shape, patch_shape, self._tiling_mode
+         )
+         # self.set_repeat_factor()
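
For orientation, a self-contained restatement of the 2D branch of get_idx_manager_shapes; the concrete sizes are illustrative assumptions, chosen only to show the channels-last shapes:

    # 2D branch only; channels-last with a singleton leading dim.
    def idx_manager_shapes_2d(patch_size: int, grid_size: int, num_channels: int):
        patch_shape = (1, patch_size, patch_size, num_channels)
        grid_shape = (1, grid_size, grid_size, num_channels)
        return patch_shape, grid_shape

    # A 64-pixel patch centered on a 32-pixel grid cell, two channels:
    assert idx_manager_shapes_2d(64, 32, 2) == ((1, 64, 64, 2), (1, 32, 32, 2))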
+
+     def __len__(self):
+         # Vera: N is the number of frames in the Z stack.
+         # The repeat factor is n_rows * n_cols.
+         return self.idx_manager.total_grid_count()
+
+     def set_repeat_factor(self):
+         if self._grid_sz > 1:
+             self._repeat_factor = self.idx_manager.grid_rows(
+                 self._grid_sz
+             ) * self.idx_manager.grid_cols(self._grid_sz)
+         else:
+             self._repeat_factor = self.idx_manager.grid_rows(
+                 self._img_sz
+             ) * self.idx_manager.grid_cols(self._img_sz)
+
+     def _init_msg(self):
+         msg = (
+             f"[{self.__class__.__name__}] Train:{int(self._is_train)} Sz:{self._img_sz}"
+         )
+         dim_sizes = [
+             self.idx_manager.get_individual_dim_grid_count(dim)
+             for dim in range(len(self._data.shape))
+         ]
+         dim_sizes = ",".join([str(x) for x in dim_sizes])
+         msg += f" N:{self.N} NumPatchPerN:{self._repeat_factor}"
+         msg += f"{self.idx_manager.total_grid_count()} DimSz:({dim_sizes})"
+         msg += f" TrimB:{self._tiling_mode}"
+         # msg += f' NormInp:{self._normalized_input}'
+         # msg += f' SingleNorm:{self._use_one_mu_std}'
+         msg += f" Rot:{self._enable_rotation}"
+         if self._flipz_3D:
+             msg += f" FlipZ:{self._flipz_3D}"
+
+         msg += f" RandCrop:{self._enable_random_cropping}"
+         msg += f" Channel:{self._num_channels}"
+         # msg += f' Q:{self._quantile}'
+         if self._input_is_sum:
+             msg += f" SummedInput:{self._input_is_sum}"
+
+         if self._empty_patch_replacement_enabled:
+             msg += f" ReplaceWithRandSample:{self._empty_patch_replacement_enabled}"
+         if self._uncorrelated_channels:
+             msg += f" Uncorr:{self._uncorrelated_channels}"
+         if self._empty_patch_replacement_enabled:
+             msg += f"-{self._empty_patch_replacement_channel_idx}-{self._empty_patch_replacement_probab}"
+         if self._background_quantile > 0.0:
+             msg += f" BckQ:{self._background_quantile}"
+
+         if self._start_alpha_arr is not None:
+             msg += f" Alpha:[{self._start_alpha_arr},{self._end_alpha_arr}]"
+         return msg
+
+     def _crop_imgs(self, index, *img_tuples: np.ndarray):
+         h, w = img_tuples[0].shape[-2:]
+         if self._img_sz is None:
+             return (
+                 *img_tuples,
+                 {"h": [0, h], "w": [0, w], "hflip": False, "wflip": False},
+             )
+
+         if self._enable_random_cropping:
+             patch_start_loc = self._get_random_hw(h, w)
+             if self._5Ddata:
+                 patch_start_loc = (
+                     np.random.choice(1 + img_tuples[0].shape[-3] - self._depth3D),
+                 ) + patch_start_loc
+         else:
+             patch_start_loc = self._get_deterministic_loc(index)
+
+         cropped_imgs = []
+         for img in img_tuples:
+             img = self._crop_flip_img(img, patch_start_loc, False, False)
+             cropped_imgs.append(img)
+
+         return (
+             *tuple(cropped_imgs),
+             {
+                 "hflip": False,
+                 "wflip": False,
+             },
+         )
+
+     def _crop_img(self, img: np.ndarray, patch_start_loc: tuple):
+         if self._tiling_mode in [TilingMode.TrimBoundary, TilingMode.ShiftBoundary]:
+             # This branch is used in training.
+             # NOTE: It is my opinion that just using self._crop_img_with_padding would work perfectly fine.
+             # The only benefit of this if/else branch is that it makes it easier to see what happens during training.
+             patch_end_loc = (
+                 np.array(patch_start_loc, dtype=np.int32)
+                 + self.idx_manager.patch_shape[1:-1]
+             )
+             if self._5Ddata:
+                 z_start, h_start, w_start = patch_start_loc
+                 z_end, h_end, w_end = patch_end_loc
+                 new_img = img[..., z_start:z_end, h_start:h_end, w_start:w_end]
+             else:
+                 h_start, w_start = patch_start_loc
+                 h_end, w_end = patch_end_loc
+                 new_img = img[..., h_start:h_end, w_start:w_end]
+
+             return new_img
+         else:
+             # This branch is used during evaluation. Here, h_start or w_start can be negative,
+             # or h_start + self._img_sz can be larger than the frame, so some padding is needed.
+             # This is not needed with the LeftTop alignment.
+             return self._crop_img_with_padding(img, patch_start_loc)
+
+     def get_begin_end_padding(self, start_pos, end_pos, max_len):
+         """
+         The effect is that the image of size self._grid_sz sits in the center of the patch,
+         with sufficient padding on all four sides so that the final patch size is self._img_sz.
+         """
+         pad_start = 0
+         pad_end = 0
+         if start_pos < 0:
+             pad_start = -1 * start_pos
+
+         pad_end = max(0, end_pos - max_len)
+
+         return pad_start, pad_end
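
A standalone restatement of get_begin_end_padding with two worked cases (the pixel values are illustrative assumptions):

    def begin_end_padding(start_pos: int, end_pos: int, max_len: int):
        pad_start = max(0, -start_pos)       # crop started before the frame
        pad_end = max(0, end_pos - max_len)  # crop ran past the frame
        return pad_start, pad_end

    # A 64-pixel patch whose top-left sits at -16 in a 256-pixel frame needs
    # 16 leading pad pixels; one that ends 8 pixels past the frame needs 8 trailing:
    assert begin_end_padding(-16, 48, 256) == (16, 0)
    assert begin_end_padding(200, 264, 256) == (0, 8)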
+
+     def _crop_img_with_padding(
+         self, img: np.ndarray, patch_start_loc, max_len_vals=None
+     ):
+         if max_len_vals is None:
+             max_len_vals = self.idx_manager.data_shape[1:-1]
+         patch_end_loc = np.array(patch_start_loc, dtype=int) + np.array(
+             self.idx_manager.patch_shape[1:-1], dtype=int
+         )
+         boundary_crossed = []
+         valid_slice = []
+         padding = [[0, 0]]
+         for start_idx, end_idx, max_len in zip(
+             patch_start_loc, patch_end_loc, max_len_vals
+         ):
+             boundary_crossed.append(end_idx > max_len or start_idx < 0)
+             valid_slice.append((max(0, start_idx), min(max_len, end_idx)))
+             pad = [0, 0]
+             if boundary_crossed[-1]:
+                 pad = self.get_begin_end_padding(start_idx, end_idx, max_len)
+             padding.append(pad)
+         # max() is needed since h_start could be negative.
+         if self._5Ddata:
+             new_img = img[
+                 ...,
+                 valid_slice[0][0] : valid_slice[0][1],
+                 valid_slice[1][0] : valid_slice[1][1],
+                 valid_slice[2][0] : valid_slice[2][1],
+             ]
+         else:
+             new_img = img[
+                 ...,
+                 valid_slice[0][0] : valid_slice[0][1],
+                 valid_slice[1][0] : valid_slice[1][1],
+             ]
+
+         # print(np.array(padding).shape, img.shape, new_img.shape)
+         # print(padding)
+         if not np.all(np.array(padding) == 0):
+             new_img = np.pad(new_img, padding, **self._overlapping_padding_kwargs)
+
+         return new_img
+
+     def _crop_flip_img(
+         self, img: np.ndarray, patch_start_loc: tuple, h_flip: bool, w_flip: bool
+     ):
+         new_img = self._crop_img(img, patch_start_loc)
+         if h_flip:
+             new_img = new_img[..., ::-1, :]
+         if w_flip:
+             new_img = new_img[..., :, ::-1]
+
+         return new_img.astype(np.float32)
+
+     def _load_img(
+         self, index: Union[int, tuple[int, int]]
+     ) -> tuple[np.ndarray, np.ndarray]:
+         """
+         Returns the channels and also the respective noise channels.
+         """
+         if isinstance(index, (int, np.int64)):
+             idx = index
+         else:
+             idx = index[0]
+
+         patch_loc_list = self.idx_manager.get_patch_location_from_dataset_idx(idx)
+         imgs = self._data[patch_loc_list[0]]
+         # if self._5Ddata:
+         #     assert self._noise_data is None, 'Noise is not supported for 5D data'
+         #     n_loc, z_loc = patch_loc_list[:2]
+         #     z_loc_interval = range(z_loc, z_loc + self._depth3D)
+         #     imgs = self._data[n_loc, z_loc_interval]
+         # else:
+         #     imgs = self._data[patch_loc_list[0]]
+
+         loaded_imgs = [imgs[None, ..., i] for i in range(imgs.shape[-1])]
+         noise = []
+         if self._noise_data is not None and not self._disable_noise:
+             noise = [
+                 self._noise_data[patch_loc_list[0]][None, ..., i]
+                 for i in range(self._noise_data.shape[-1])
+             ]
+         return tuple(loaded_imgs), tuple(noise)
+
+     def get_mean_std(self):
+         return self._mean, self._std
+
+     def set_mean_std(self, mean_val, std_val):
+         self._mean = mean_val
+         self._std = std_val
+
+     def normalize_img(self, *img_tuples):
+         mean, std = self.get_mean_std()
+         mean = mean["target"]
+         std = std["target"]
+         mean = mean.squeeze()
+         std = std.squeeze()
+         normalized_imgs = []
+         for i, img in enumerate(img_tuples):
+             img = (img - mean[i]) / std[i]
+             normalized_imgs.append(img)
+         return tuple(normalized_imgs)
+
+     def normalize_input(self, x):
+         mean_dict, std_dict = self.get_mean_std()
+         mean_ = mean_dict["input"].mean()
+         std_ = std_dict["input"].mean()
+         return (x - mean_) / std_
+
+     def normalize_target(self, target):
+         mean_dict, std_dict = self.get_mean_std()
+         mean_ = mean_dict["target"].squeeze(0)
+         std_ = std_dict["target"].squeeze(0)
+         return (target - mean_) / std_
+
+     def get_grid_size(self):
+         return self._grid_sz
+
+     def get_idx_manager(self):
+         return self.idx_manager
+
+     def per_side_overlap_pixelcount(self):
+         return (self._img_sz - self._grid_sz) // 2
+
+     # def on_boundary(self, cur_loc, frame_size):
+     #     return cur_loc + self._img_sz > frame_size or cur_loc < 0
+
+     def _get_deterministic_loc(self, index: int):
+         """
+         Returns the top-left corner of the patch corresponding to `index`.
+         """
+         loc_list = self.idx_manager.get_patch_location_from_dataset_idx(index)
+         # The first entry is the frame index and the last is the channel; keep the spatial dims in between.
+         return loc_list[1:-1]
+
+     def compute_individual_mean_std(self):
+         # numpy 1.19.2 has issues computing this for large arrays. https://github.com/numpy/numpy/issues/8869
+         # mean = np.mean(self._data, axis=(0, 1, 2))
+         # std = np.std(self._data, axis=(0, 1, 2))
+         mean_arr = []
+         std_arr = []
+         for ch_idx in range(self._data.shape[-1]):
+             mean_ = (
+                 0.0
+                 if self._skip_normalization_using_mean
+                 else self._data[..., ch_idx].mean()
+             )
+             if self._noise_data is not None:
+                 std_ = (
+                     self._data[..., ch_idx] + self._noise_data[..., ch_idx + 1]
+                 ).std()
+             else:
+                 std_ = self._data[..., ch_idx].std()
+
+             mean_arr.append(mean_)
+             std_arr.append(std_)
+
+         mean = np.array(mean_arr)
+         std = np.array(std_arr)
+         if self._5Ddata:
+             # NOTE: Ideally this should apply only when the model expects 3D data.
+             return mean[None, :, None, None, None], std[None, :, None, None, None]
+
+         return mean[None, :, None, None], std[None, :, None, None]
+
+     def compute_mean_std(self, allow_for_validation_data=False):
+         """
+         Note that we must compute this only for the training data.
+         """
+         assert (
+             self._is_train is True or allow_for_validation_data
+         ), "This is only allowed for training data"
+         assert self._use_one_mu_std is True, "This is the only supported case"
+
+         if self._input_idx is not None:
+             assert (
+                 self._tar_idx_list is not None
+             ), "tar_idx_list must be set if input_idx is set."
+             assert self._noise_data is None, "This is not supported with noise"
+             assert (
+                 self._target_separate_normalization is True
+             ), "This is not supported with target_separate_normalization=False"
+
+             mean, std = self.compute_individual_mean_std()
+             mean_dict = {
+                 "input": mean[:, self._input_idx : self._input_idx + 1],
+                 "target": mean[:, self._tar_idx_list],
+             }
+             std_dict = {
+                 "input": std[:, self._input_idx : self._input_idx + 1],
+                 "target": std[:, self._tar_idx_list],
+             }
+             return mean_dict, std_dict
+
+         if self._input_is_sum:
+             assert self._noise_data is None, "This is not supported with noise"
+             mean = [
+                 np.mean(self._data[..., k : k + 1], keepdims=True)
+                 for k in range(self._num_channels)
+             ]
+             mean = np.sum(mean, keepdims=True)[0]
+             std = np.linalg.norm(
+                 [
+                     np.std(self._data[..., k : k + 1], keepdims=True)
+                     for k in range(self._num_channels)
+                 ],
+                 keepdims=True,
+             )[0]
+         else:
+             mean = np.mean(self._data, keepdims=True).reshape(1, 1, 1, 1)
+             if self._noise_data is not None:
+                 std = np.std(
+                     self._data + self._noise_data[..., 1:], keepdims=True
+                 ).reshape(1, 1, 1, 1)
+             else:
+                 std = np.std(self._data, keepdims=True).reshape(1, 1, 1, 1)
+
+         mean = np.repeat(mean, self._num_channels, axis=1)
+         std = np.repeat(std, self._num_channels, axis=1)
+
+         if self._skip_normalization_using_mean:
+             mean = np.zeros_like(mean)
+
+         if self._5Ddata:
+             mean = mean[:, :, None]
+             std = std[:, :, None]
+
+         mean_dict = {"input": mean}  # , 'target': mean}
+         std_dict = {"input": std}  # , 'target': std}
+
+         if self._target_separate_normalization:
+             mean, std = self.compute_individual_mean_std()
+
+         mean_dict["target"] = mean
+         std_dict["target"] = std
+         return mean_dict, std_dict
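
The statistics flow: compute_mean_std returns (mean_dict, std_dict) keyed by "input" and "target", and these must be stored via set_mean_std before indexing, since __getitem__ normalizes the target. A minimal sketch, with assumed shapes and values for 2-channel 2D data:

    import numpy as np

    # Shapes follow the (1, C, 1, 1) broadcasting layout used above.
    mean_dict = {"input": np.full((1, 2, 1, 1), 110.0),
                 "target": np.array([100.0, 120.0]).reshape(1, 2, 1, 1)}
    std_dict = {"input": np.full((1, 2, 1, 1), 15.0),
                "target": np.array([10.0, 20.0]).reshape(1, 2, 1, 1)}

    x = np.zeros((1, 64, 64))
    x_norm = (x - mean_dict["input"].mean()) / std_dict["input"].mean()  # cf. normalize_input
    t = np.zeros((2, 64, 64))
    t_norm = (t - mean_dict["target"].squeeze(0)) / std_dict["target"].squeeze(0)  # cf. normalize_target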
+
+     def _get_random_hw(self, h: int, w: int):
+         """
+         Random top-left starting position for a crop within an (h, w) frame.
+         """
+         if h != self._img_sz:
+             h_start = np.random.choice(h - self._img_sz)
+             w_start = np.random.choice(w - self._img_sz)
+         else:
+             h_start = 0
+             w_start = 0
+         return h_start, w_start
+
+     def _get_img(self, index: Union[int, tuple[int, int]]):
+         """
+         Loads an image and crops it so that the cropped image has content.
+         """
+         img_tuples, noise_tuples = self._load_img(index)
+         cropped_img_tuples = self._crop_imgs(index, *img_tuples, *noise_tuples)[:-1]
+         cropped_noise_tuples = cropped_img_tuples[len(img_tuples) :]
+         cropped_img_tuples = cropped_img_tuples[: len(img_tuples)]
+         return cropped_img_tuples, cropped_noise_tuples
+
+     def replace_with_empty_patch(self, img_tuples):
+         """
+         Replaces the content of one of the channels with background.
+         """
+         empty_index = self._empty_patch_fetcher.sample()
+         empty_img_tuples, empty_img_noise_tuples = self._get_img(empty_index)
+         assert (
+             len(empty_img_noise_tuples) == 0
+         ), "Noise is not supported with empty patch replacement"
+         final_img_tuples = []
+         for tuple_idx in range(len(img_tuples)):
+             if tuple_idx == self._empty_patch_replacement_channel_idx:
+                 final_img_tuples.append(empty_img_tuples[tuple_idx])
+             else:
+                 final_img_tuples.append(img_tuples[tuple_idx])
+         return tuple(final_img_tuples)
+
+     def get_mean_std_for_input(self):
+         mean, std = self.get_mean_std()
+         return mean["input"], std["input"]
+
+     def _compute_target(self, img_tuples, alpha):
+         if self._tar_idx_list is not None and isinstance(self._tar_idx_list, int):
+             target = img_tuples[self._tar_idx_list]
+         else:
+             if self._tar_idx_list is not None:
+                 assert isinstance(self._tar_idx_list, (list, tuple))
+                 img_tuples = [img_tuples[i] for i in self._tar_idx_list]
+
+             target = np.concatenate(img_tuples, axis=0)
+         return target
+
+     def _compute_input_with_alpha(self, img_tuples, alpha_list):
+         # assert self._normalized_input is True, "normalization should happen here"
+         if self._input_idx is not None:
+             inp = img_tuples[self._input_idx]
+         else:
+             inp = 0
+             for alpha, img in zip(alpha_list, img_tuples):
+                 inp += img * alpha
+
+         if self._normalized_input is False:
+             return inp.astype(np.float32)
+
+         mean, std = self.get_mean_std_for_input()
+         mean = mean.squeeze()
+         std = std.squeeze()
+         if mean.size == 1:
+             mean = mean.reshape(1)
+             std = std.reshape(1)
+
+         for i in range(len(mean)):
+             assert mean[0] == mean[i]
+             assert std[0] == std[i]
+
+         inp = (inp - mean[0]) / std[0]
+         return inp.astype(np.float32)
+
+     def _sample_alpha(self):
+         alpha_arr = []
+         for i in range(self._num_channels):
+             alpha_pos = np.random.rand()
+             alpha = self._start_alpha_arr[i] + alpha_pos * (
+                 self._end_alpha_arr[i] - self._start_alpha_arr[i]
+             )
+             alpha_arr.append(alpha)
+         return alpha_arr
+
+     def _compute_input(self, img_tuples):
+         alpha = [1 / len(img_tuples) for _ in range(len(img_tuples))]
+         if self._start_alpha_arr is not None:
+             alpha = self._sample_alpha()
+
+         inp = self._compute_input_with_alpha(img_tuples, alpha)
+         if self._input_is_sum:
+             inp = len(img_tuples) * inp
+         return inp, alpha
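
The input mixing above reduces to inp = sum_i alpha_i * ch_i, with each alpha_i drawn uniformly between its start and end values and the result rescaled by the channel count when input_is_sum is set. A small sketch with assumed values (not part of the package):

    import numpy as np

    rng = np.random.default_rng(0)
    start_alpha, end_alpha = [0.3, 0.3], [0.7, 0.7]
    alpha = [s + rng.random() * (e - s) for s, e in zip(start_alpha, end_alpha)]

    ch1 = np.ones((1, 64, 64), dtype=np.float32)
    ch2 = 2 * np.ones((1, 64, 64), dtype=np.float32)
    inp = alpha[0] * ch1 + alpha[1] * ch2  # weighted mix of the channels
    inp_sum_mode = 2 * inp                 # with input_is_sum, scale by the channel count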
+
+     def _get_index_from_valid_target_logic(self, index):
+         if self._validtarget_rand_fract is not None:
+             if np.random.rand() < self._validtarget_rand_fract:
+                 index = self._train_index_switcher.get_valid_target_index()
+             else:
+                 index = self._train_index_switcher.get_invalid_target_index()
+         return index
+
+     def _rotate2D(self, img_tuples, noise_tuples):
+         img_kwargs = {}
+         for i, img in enumerate(img_tuples):
+             for k in range(len(img)):
+                 img_kwargs[f"img{i}_{k}"] = img[k]
+
+         noise_kwargs = {}
+         for i, nimg in enumerate(noise_tuples):
+             for k in range(len(nimg)):
+                 noise_kwargs[f"noise{i}_{k}"] = nimg[k]
+
+         keys = list(img_kwargs.keys()) + list(noise_kwargs.keys())
+         self._rotation_transform.add_targets({k: "image" for k in keys})
+         rot_dic = self._rotation_transform(
+             image=img_tuples[0][0], **img_kwargs, **noise_kwargs
+         )
+
+         rotated_img_tuples = []
+         for i, img in enumerate(img_tuples):
+             if len(img) == 1:
+                 rotated_img_tuples.append(rot_dic[f"img{i}_0"][None])
+             else:
+                 rotated_img_tuples.append(
+                     np.concatenate(
+                         [rot_dic[f"img{i}_{k}"][None] for k in range(len(img))], axis=0
+                     )
+                 )
+
+         rotated_noise_tuples = []
+         for i, nimg in enumerate(noise_tuples):
+             if len(nimg) == 1:
+                 rotated_noise_tuples.append(rot_dic[f"noise{i}_0"][None])
+             else:
+                 rotated_noise_tuples.append(
+                     np.concatenate(
+                         [rot_dic[f"noise{i}_{k}"][None] for k in range(len(nimg))],
+                         axis=0,
+                     )
+                 )
+
+         return rotated_img_tuples, rotated_noise_tuples
+
+     def _rotate(self, img_tuples, noise_tuples):
+         if self._5Ddata:
+             return self._rotate3D(img_tuples, noise_tuples)
+         else:
+             return self._rotate2D(img_tuples, noise_tuples)
+
+     def _rotate3D(self, img_tuples, noise_tuples):
+         img_kwargs = {}
+         # random flip in the z direction
+         flip_z = self._flipz_3D and np.random.rand() < 0.5
+         for i, img in enumerate(img_tuples):
+             for j in range(self._depth3D):
+                 for k in range(len(img)):
+                     if flip_z:
+                         z_idx = self._depth3D - 1 - j
+                     else:
+                         z_idx = j
+                     img_kwargs[f"img{i}_{z_idx}_{k}"] = img[k, j]
+
+         noise_kwargs = {}
+         for i, nimg in enumerate(noise_tuples):
+             for j in range(self._depth3D):
+                 for k in range(len(nimg)):
+                     if flip_z:
+                         z_idx = self._depth3D - 1 - j
+                     else:
+                         z_idx = j
+                     noise_kwargs[f"noise{i}_{z_idx}_{k}"] = nimg[k, j]
+
+         keys = list(img_kwargs.keys()) + list(noise_kwargs.keys())
+         self._rotation_transform.add_targets({k: "image" for k in keys})
+         rot_dic = self._rotation_transform(
+             image=img_tuples[0][0][0], **img_kwargs, **noise_kwargs
+         )
+         rotated_img_tuples = []
+         for i, img in enumerate(img_tuples):
+             if len(img) == 1:
+                 rotated_img_tuples.append(
+                     np.concatenate(
+                         [
+                             rot_dic[f"img{i}_{j}_0"][None, None]
+                             for j in range(self._depth3D)
+                         ],
+                         axis=1,
+                     )
+                 )
+             else:
+                 temp_arr = []
+                 for k in range(len(img)):
+                     temp_arr.append(
+                         np.concatenate(
+                             [
+                                 rot_dic[f"img{i}_{j}_{k}"][None, None]
+                                 for j in range(self._depth3D)
+                             ],
+                             axis=1,
+                         )
+                     )
+                 rotated_img_tuples.append(np.concatenate(temp_arr, axis=0))
+
+         rotated_noise_tuples = []
+         for i, nimg in enumerate(noise_tuples):
+             if len(nimg) == 1:
+                 rotated_noise_tuples.append(
+                     np.concatenate(
+                         [
+                             rot_dic[f"noise{i}_{j}_0"][None, None]
+                             for j in range(self._depth3D)
+                         ],
+                         axis=1,
+                     )
+                 )
+             else:
+                 temp_arr = []
+                 for k in range(len(nimg)):
+                     temp_arr.append(
+                         np.concatenate(
+                             [
+                                 rot_dic[f"noise{i}_{j}_{k}"][None, None]
+                                 for j in range(self._depth3D)
+                             ],
+                             axis=1,
+                         )
+                     )
+                 rotated_noise_tuples.append(np.concatenate(temp_arr, axis=0))
+
+         return rotated_img_tuples, rotated_noise_tuples
+
+     def get_uncorrelated_img_tuples(self, index):
+         """
+         Content of channels like actin and nuclei is "correlated" at any given
+         location; this function picks each channel's content from a different
+         patch of the image to make the channels "uncorrelated".
+         """
+         img_tuples, noise_tuples = self._get_img(index)
+         assert len(noise_tuples) == 0
+         # Keep the first channel from this patch; draw every other channel from a random patch.
+         new_img_tuples = [img_tuples[0]]
+         for ch_idx in range(1, len(img_tuples)):
+             new_index = np.random.randint(len(self))
+             other_img_tuples, _ = self._get_img(new_index)
+             new_img_tuples.append(other_img_tuples[ch_idx])
+         return new_img_tuples, noise_tuples
+
+     def __getitem__(
+         self, index: Union[int, tuple[int, int]]
+     ) -> tuple[np.ndarray, np.ndarray]:
+         # Vera: the input can be either a real microscopy image or two separate channels that are summed in the code.
+         if self._train_index_switcher is not None:
+             index = self._get_index_from_valid_target_logic(index)
+
+         if (
+             self._uncorrelated_channels
+             and np.random.rand() < self._uncorrelated_channel_probab
+         ):
+             img_tuples, noise_tuples = self.get_uncorrelated_img_tuples(index)
+         else:
+             img_tuples, noise_tuples = self._get_img(index)
+
+         assert (
+             not self._empty_patch_replacement_enabled
+         ), "This is not supported with noise"
+
+         # Replace the content of one of the channels
+         # with background with the given probability.
+         if self._empty_patch_replacement_enabled:
+             if np.random.rand() < self._empty_patch_replacement_probab:
+                 img_tuples = self.replace_with_empty_patch(img_tuples)
+
+         # Noise tuples are not needed for the paper;
+         # the image tuples are noisy by default.
+         # TODO: remove noise tuples completely?
+         if self._enable_rotation:
+             img_tuples, noise_tuples = self._rotate(img_tuples, noise_tuples)
+
+         # Add the noise tuples to the image tuples to create the input
+         if len(noise_tuples) > 0:
+             factor = np.sqrt(2) if self._input_is_sum else 1.0
+             input_tuples = [x + noise_tuples[0] * factor for x in img_tuples]
+         else:
+             input_tuples = img_tuples
+
+         # Weight the individual channels; typically alpha is fixed
+         inp, alpha = self._compute_input(input_tuples)
+
+         # Add the noise tuples to the image tuples to create the target
+         if len(noise_tuples) >= 1:
+             img_tuples = [x + noise for x, noise in zip(img_tuples, noise_tuples[1:])]
+
+         target = self._compute_target(img_tuples, alpha)
+         norm_target = self.normalize_target(target)
+
+         output = [inp, norm_target]
+
+         if self._return_alpha:
+             output.append(alpha)
+
+         if self._return_index:
+             output.append(index)
+
+         return tuple(output)
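
For context, a minimal usage sketch. Nothing below is prescribed by the package: the import paths follow the file listing above, the config fields mirror the attributes accessed in __init__ and are left unspecified here, and my_load_fn is a hypothetical loader. set_mean_std must be called before indexing, since __getitem__ normalizes the target.

    import numpy as np

    from careamics.lvae_training.dataset.multich_dataset import MultiChDloader  # assumed path
    from careamics.lvae_training.dataset.config import MicroSplitDataConfig  # per the listing
    from careamics.lvae_training.dataset.types import DataSplitType

    def my_load_fn(data_config, fpath, datasplit_type, val_fraction=None,
                   test_fraction=None, allow_generation=None):
        # Hypothetical loader: return an N x H x W x C stack for the requested split.
        return np.random.rand(4, 256, 256, 2).astype(np.float32)

    config = MicroSplitDataConfig(...)  # fields as accessed in __init__ above
    dset = MultiChDloader(config, "/path/to/data", load_data_fn=my_load_fn)
    dset.set_mean_std(*dset.compute_mean_std())  # required before indexing
    inp, target = dset[0]  # alpha-weighted input patch, normalized target patch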