careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+ from tqdm import tqdm
3
+
4
+
5
class EmptyPatchFetcher:
    """
    Collect dataset indices of "empty" patches and sample from them.

    The idea is to fetch empty patches so that real content can be replaced with
    this. A window location counts as empty when the maximum value inside the
    ``patch_size`` x ``patch_size`` window starting there lies in
    ``[0, max_val_threshold)``.

    Parameters
    ----------
    idx_manager
        Object exposing ``get_dataset_idx_from_location(loc)`` and
        ``get_location_from_dataset_idx(idx)``; used to turn window locations
        into dataset indices.
    patch_size
        Side length of the square window scanned over every frame.
    data_frames
        Array that unpacks as ``(N, H, W)``.
    max_val_threshold
        Exclusive upper bound on the window maximum for a window to be empty.
        Must be a number: the default ``None`` is rejected with ``TypeError``
        as soon as the empty indices are computed (i.e. during construction).
    """

    def __init__(self, idx_manager, patch_size, data_frames, max_val_threshold=None):
        self._frames = data_frames
        self._idx_manager = idx_manager
        self._max_val_threshold = max_val_threshold
        self._idx_list = []
        self._patch_size = patch_size
        self._grid_size = 1
        # Populates self._idx_list; this scans every frame and can be slow.
        self.set_empty_idx()

        print(f"[{self.__class__.__name__}] MaxVal:{self._max_val_threshold}")

    def compute_max(self, window):
        """
        Rolling compute of the maximum over every ``window`` x ``window`` region.

        Returns
        -------
        numpy.ndarray
            Float array of shape ``(N, H - window, W - window)`` where entry
            ``[n, h, w]`` is the maximum of
            ``frames[n, h : h + window, w : w + window]``.
        """
        N, H, W = self._frames.shape
        sentinel = -954321
        assert self._grid_size == 1
        # Pre-fill with a sentinel so the sanity check below can detect cells the
        # loops failed to overwrite. (Multiplying a zeros array by the sentinel,
        # as done previously, leaves only zeros and makes the check meaningless.)
        max_data = np.full((N, H - window, W - window), sentinel, dtype=float)

        # NOTE(review): the ranges stop at H - window / W - window, so the last
        # valid window offset in each direction is never visited. Kept as-is
        # because the output shape is tied to how the indices are consumed.
        for h in tqdm(range(H - window)):
            for w in range(W - window):
                max_data[:, h, w] = self._frames[:, h : h + window, w : w + window].max(
                    axis=(1, 2)
                )

        # There must be at least one window and every cell must have been written.
        assert max_data.size > 0 and (max_data != sentinel).all()
        return max_data

    def set_empty_idx(self):
        """
        Recompute ``self._idx_list``, the dataset indices of all empty windows.

        Raises
        ------
        TypeError
            If ``max_val_threshold`` is ``None``.
        AssertionError
            If the index manager does not round-trip a location, or if no empty
            window is found.
        """
        if self._max_val_threshold is None:
            # Fail before the expensive rolling max; comparing the array with
            # None would raise a less helpful TypeError afterwards anyway.
            raise TypeError(
                "max_val_threshold must be a number, got None; "
                "pass max_val_threshold explicitly"
            )

        max_data = self.compute_max(self._patch_size)
        empty_loc = np.where(
            np.logical_and(max_data >= 0, max_data < self._max_val_threshold)
        )

        idx_list = []
        for n_idx, h_start, w_start in zip(*empty_loc):
            # The last location component is fixed to 0.
            loc = (n_idx, h_start, w_start, 0)
            dataset_idx = self._idx_manager.get_dataset_idx_from_location(loc)
            # Round-trip through the index manager to verify the mapping.
            t, h, w, _ = self._idx_manager.get_location_from_dataset_idx(dataset_idx)
            assert h == h_start, f"{h} != {h_start}"
            assert w == w_start, f"{w} != {w_start}"
            assert t == n_idx, f"{t} != {n_idx}"
            idx_list.append(dataset_idx)

        self._idx_list = np.array(idx_list)

        assert len(self._idx_list) > 0

    def sample(self):
        """
        Return ``(dataset_idx, grid_size)`` for one randomly chosen empty patch.
        """
        return (np.random.choice(self._idx_list), self._grid_size)
@@ -0,0 +1,491 @@
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+
5
+ from careamics.lvae_training.dataset.types import TilingMode
6
+
7
+
8
@dataclass
class GridIndexManager:
    """
    Convert between flat dataset indices, per-dimension grid indices and
    grid-start coordinates for data of shape ``data_shape``.

    ``data_shape``, ``grid_shape`` and ``patch_shape`` must have the same length,
    and ``patch_shape - grid_shape`` must be non-negative and even in every
    dimension. ``tiling_mode`` selects how the data boundary is handled.
    """

    data_shape: tuple
    grid_shape: tuple
    patch_shape: tuple
    tiling_mode: TilingMode

    # Patch is centered on index in the grid, grid size not used in training,
    # used only during val / test, grid size controls the overlap of the patches
    # in training you only get random patches every time
    # For borders - just cropped the data, so it perfectly divisible

    def __post_init__(self):
        """
        Validate the three shapes.

        Raises
        ------
        AssertionError
            If the shapes do not all have the same number of dimensions.
        ValueError
            If the patch is smaller than the grid in some dimension, or the
            difference between them is odd.
        """
        assert len(self.data_shape) == len(
            self.grid_shape
        ), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
        assert len(self.data_shape) == len(
            self.patch_shape
        ), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
        for dim, pad in enumerate(innerpad):
            if pad < 0:
                raise ValueError(
                    f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}"
                )
            if pad % 2 != 0:
                raise ValueError(
                    f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}"
                )

    def _assert_valid_dim(self, dim: int):
        """Assert that ``dim`` is a valid, non-negative index into ``data_shape``."""
        assert dim < len(
            self.data_shape
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"

    def patch_offset(self):
        """
        Returns the per-dimension offset ``(patch_shape - grid_shape) // 2``.
        """
        return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2

    def get_individual_dim_grid_count(self, dim: int):
        """
        Returns the number of the grid in the specified dimension, ignoring all other dimensions.

        A dimension whose grid and patch size are both 1 contributes one grid per
        element. Otherwise the count is ``ceil(data / grid)`` for PadBoundary,
        ``ceil((data - (patch - grid)) / grid)`` for ShiftBoundary and
        ``floor((data - (patch - grid)) / grid)`` for any other mode.
        """
        self._assert_valid_dim(dim)

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return self.data_shape[dim]
        elif self.tiling_mode == TilingMode.PadBoundary:
            return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(
                np.ceil((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
            )
        else:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(
                np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
            )

    def total_grid_count(self):
        """
        Returns the total number of grids in the dataset.
        """
        return self.grid_count(0) * self.get_individual_dim_grid_count(0)

    def grid_count(self, dim: int):
        """
        Returns the total number of grids for one value in the specified dimension.

        This is the product of the per-dimension grid counts of all dimensions
        after ``dim`` (1 for the last dimension), i.e. the stride of ``dim`` in
        the flat dataset index.
        """
        self._assert_valid_dim(dim)
        if dim == len(self.data_shape) - 1:
            return 1

        return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)

    def get_grid_index(self, dim: int, coordinate: int):
        """
        Returns the index of the grid in the specified dimension.

        The result is an integer (``coordinate`` itself for size-1 dimensions),
        so it can be used directly in index arithmetic.

        Raises
        ------
        ValueError
            If the tiling mode is not PadBoundary, TrimBoundary or ShiftBoundary.
        """
        self._assert_valid_dim(dim)
        assert (
            coordinate < self.data_shape[dim]
        ), f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return coordinate
        elif self.tiling_mode == TilingMode.PadBoundary:
            # np.floor returns a float; cast so indices stay integral.
            return int(np.floor(coordinate / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.TrimBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            # can be <0 if coordinate is in [0,grid_shape[dim]]
            return max(
                0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim]))
            )
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            if coordinate + self.grid_shape[dim] + excess_size == self.data_shape[dim]:
                # Coordinate is the start of the shifted last grid.
                return self.get_individual_dim_grid_count(dim) - 1
            else:
                # can be <0 if coordinate is in [0,grid_shape[dim]]
                return max(
                    0,
                    int(np.floor((coordinate - excess_size) / self.grid_shape[dim])),
                )

        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def dataset_idx_from_grid_idx(self, grid_idx: tuple):
        """
        Returns the index of the grid in the dataset.

        ``grid_idx`` holds one grid index per dimension; each is weighted by
        ``grid_count(dim)`` and the results are summed.
        """
        assert len(grid_idx) == len(
            self.data_shape
        ), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
        return sum(
            dim_index * self.grid_count(dim) for dim, dim_index in enumerate(grid_idx)
        )

    def get_patch_location_from_dataset_idx(self, dataset_idx: int):
        """
        Returns the patch location of the grid in the dataset.

        This is the grid-start location shifted back by ``patch_offset()``.
        """
        grid_location = self.get_location_from_dataset_idx(dataset_idx)
        offset = self.patch_offset()
        return tuple(np.array(grid_location) - np.array(offset))

    def get_dataset_idx_from_grid_location(self, location: tuple):
        """
        Returns the dataset index of the grid for the given per-dimension coordinates.
        """
        assert len(location) == len(
            self.data_shape
        ), f"Location {location} must have the same dimension as data shape {self.data_shape}"
        grid_idx = [
            self.get_grid_index(dim, location[dim]) for dim in range(len(location))
        ]
        return self.dataset_idx_from_grid_idx(tuple(grid_idx))

    def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
        """
        Returns the grid-start coordinate of the grid in the specified dimension.

        Raises
        ------
        ValueError
            If the tiling mode is not PadBoundary, TrimBoundary or ShiftBoundary.
        """
        self._assert_valid_dim(dim)
        # TODO: an upper-bound check,
        #   dim_index < self.get_individual_dim_grid_count(dim),
        # is intentionally not enforced here: it was observed to fail and the
        # cause has not been investigated yet.
        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return dim_index
        elif self.tiling_mode == TilingMode.PadBoundary:
            return dim_index * self.grid_shape[dim]
        elif self.tiling_mode == TilingMode.TrimBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            return dim_index * self.grid_shape[dim] + excess_size
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            if dim_index < self.get_individual_dim_grid_count(dim) - 1:
                return dim_index * self.grid_shape[dim] + excess_size
            else:
                # on boundary. grid should be placed such that the patch covers the entire data.
                return self.data_shape[dim] - self.grid_shape[dim] - excess_size
        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def get_location_from_dataset_idx(self, dataset_idx: int):
        """
        Returns the start location of the grid in the dataset.

        The flat index is decomposed into one grid index per dimension, and each
        is converted to its grid-start coordinate.
        """
        grid_idx = []
        for dim in range(len(self.data_shape)):
            stride = self.grid_count(dim)
            grid_idx.append(dataset_idx // stride)
            dataset_idx = dataset_idx % stride
        location = [
            self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
            for dim in range(len(self.data_shape))
        ]
        return tuple(location)

    def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
        """
        Returns True if the grid is on the boundary in the specified dimension.

        With ``only_end=True`` only the last grid along ``dim`` counts as
        boundary; otherwise both the first and the last do.
        """
        self._assert_valid_dim(dim)

        if dim > 0:
            # Drop the contribution of all dimensions before `dim`.
            dataset_idx = dataset_idx % self.grid_count(dim - 1)

        dim_index = dataset_idx // self.grid_count(dim)
        last_index = self.get_individual_dim_grid_count(dim) - 1
        if only_end:
            return dim_index == last_index

        return dim_index == 0 or dim_index == last_index

    def next_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Returns the dataset index one grid step forward along ``dim``, or None
        if that index would be past the end of the dataset.
        """
        self._assert_valid_dim(dim)
        # NOTE(review): only the global upper bound is checked; stepping past
        # the last grid of `dim` into the next slice is not detected here.
        new_idx = dataset_idx + self.grid_count(dim)
        if new_idx >= self.total_grid_count():
            return None
        return new_idx

    def prev_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Returns the dataset index one grid step backward along ``dim``, or None
        if that index would be negative.
        """
        self._assert_valid_dim(dim)
        new_idx = dataset_idx - self.grid_count(dim)
        if new_idx < 0:
            return None
        # Previously this fell through and returned None for every input.
        return new_idx
233
+
234
+
235
+ @dataclass
236
+ class GridIndexManagerRef:
237
+ data_shapes: tuple
238
+ grid_shape: tuple
239
+ patch_shape: tuple
240
+ tiling_mode: TilingMode
241
+
242
+ # This class is used to calculate and store information about patches, and calculate
243
+ # the total length of the dataset in patches.
244
+ # It introduces the concept of a grid into which input images are split.
245
+ # The grid is defined by grid_shape and patch_shape, with the former controlling the
246
+ # overlap.
247
+ # In this reimplementation it can accept multiple channels with different lengths,
248
+ # and every image can have different shape.
249
+
250
def __post_init__(self):
    """
    Validate the shape arguments and cache the per-file patch counts.

    Checks that the shapes of the first channel (and of the second one, when
    there is more than one channel) have as many dimensions as ``grid_shape``
    and ``patch_shape``, that the patch is at least as large as the grid with
    an even difference in every dimension, and then stores the second element
    returned by ``total_grid_count()`` in ``self.num_patches_per_channel``.

    Raises
    ------
    AssertionError
        If the number of dimensions is inconsistent.
    ValueError
        If the patch is smaller than the grid in some dimension, or the
        difference between them is odd.
    """

    def _ndim_of(channel_shapes):
        # Length of one of the shapes in the channel (taken from the set of
        # distinct lengths, exactly as the inline expression did).
        return {len(ds) for ds in channel_shapes}.pop()

    if len(self.data_shapes) > 1:
        # TODO better way to assert this
        assert _ndim_of(self.data_shapes[0]) == _ndim_of(
            self.data_shapes[1]
        ), "Data shape for all channels must be the same"
    assert _ndim_of(self.data_shapes[0]) == len(
        self.grid_shape
    ), "Data shape and grid size must have the same dimension"
    assert _ndim_of(self.data_shapes[0]) == len(
        self.patch_shape
    ), "Data shape and patch shape must have the same dimension"

    padding = np.array(self.patch_shape) - np.array(self.grid_shape)
    for dim, pad in enumerate(padding):
        if pad < 0:
            raise ValueError(
                f"Patch shape must be greater than or equal to grid shape in dimension {dim}"
            )
        if pad % 2 != 0:
            raise ValueError(
                f"Patch shape must have even padding in dimension {dim}"
            )

    _, patches_per_file = self.total_grid_count()
    self.num_patches_per_channel = patches_per_file
272
+
273
def patch_offset(self):
    """
    Returns the per-dimension offset ``(patch_shape - grid_shape) // 2`` as an array.
    """
    patch = np.array(self.patch_shape)
    grid = np.array(self.grid_shape)
    return (patch - grid) // 2
275
+
276
def get_individual_dim_grid_count(self, shape: tuple, dim: int):
    """
    Returns the number of the grid in the specified dimension, ignoring all other dimensions.

    A dimension whose grid and patch size are both 1 contributes ``shape[dim]``
    grids. Otherwise the count is ``ceil(size / grid)`` for PadBoundary,
    ``ceil((size - (patch - grid)) / grid)`` for ShiftBoundary and
    ``floor((size - (patch - grid)) / grid)`` for any other mode.
    """
    # NOTE: `dim` is not range-checked against `shape` here.
    if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
        return shape[dim]

    if self.tiling_mode == TilingMode.PadBoundary:
        return int(np.ceil(shape[dim] / self.grid_shape[dim]))

    # Part of the patch that extends beyond the grid in this dimension.
    overhang = self.patch_shape[dim] - self.grid_shape[dim]
    grids_that_fit = (shape[dim] - overhang) / self.grid_shape[dim]
    if self.tiling_mode == TilingMode.ShiftBoundary:
        return int(np.ceil(grids_that_fit))
    return int(np.floor(grids_that_fit))
297
+
298
def total_grid_count(self):
    """
    Returns the patch totals of the dataset.

    The result is a pair ``(len_per_channel, num_patches_per_sample)``: for
    every entry of ``self.data_shapes`` the second element holds the list of
    per-file patch counts (product of the per-dimension grid counts) and the
    first element holds the sum of that list.
    """
    patches_per_file = [
        [
            np.prod(self.grid_count_per_sample(file_shape))
            for file_shape in channel_shapes
        ]
        for channel_shapes in self.data_shapes
    ]
    patches_per_channel = [np.sum(counts) for counts in patches_per_file]
    return patches_per_channel, patches_per_file
310
+
311
def grid_count_per_sample(self, shape: tuple):
    """Return the grid count along every dimension of ``shape``.

    Parameters
    ----------
    shape : tuple
        Shape of a single sample.

    Returns
    -------
    list
        One entry per dimension of ``shape``, as computed by
        ``get_individual_dim_grid_count``.
    """
    return [
        self.get_individual_dim_grid_count(shape, axis)
        for axis in range(len(shape))
    ]
317
+
318
def get_grid_index(self, shape, dim: int, coordinate: int):
    """Return the index of the grid cell containing ``coordinate`` along ``dim``.

    Parameters
    ----------
    shape : tuple
        Shape of the sample the coordinate refers to.
    dim : int
        Dimension of ``shape`` to look at.
    coordinate : int
        Position along ``dim``; must be smaller than ``shape[dim]``.

    Returns
    -------
    int or numpy.floating
        ``coordinate`` itself when grid and patch both have size 1 in this
        dimension. Otherwise the cell index for the active tiling mode; note
        that the branches based on ``np.floor`` return a NumPy float.

    Raises
    ------
    AssertionError
        If ``dim`` or ``coordinate`` is out of bounds for ``shape``.
    ValueError
        If the tiling mode is not one of the handled modes.
    """
    assert dim < len(
        shape
    ), f"Dimension {dim} is out of bounds for data shape {shape}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"
    assert (
        coordinate < shape[dim]
    ), f"Coordinate {coordinate} is out of bounds for data shape {shape}"

    if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
        return coordinate
    elif self.tiling_mode == TilingMode.PadBoundary:
        return np.floor(coordinate / self.grid_shape[dim])
    elif self.tiling_mode == TilingMode.TrimBoundary:
        excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
        # can be <0 if coordinate is in [0,grid_shape[dim]]
        return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
    elif self.tiling_mode == TilingMode.ShiftBoundary:
        excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
        # The last cell is shifted so that its patch ends exactly at the sample
        # border. Compare against the extent of *this* sample (`shape`), not
        # against `self.data_shapes`, which is the nested list of all shapes.
        if coordinate + self.grid_shape[dim] + excess_size == shape[dim]:
            return self.get_individual_dim_grid_count(shape, dim) - 1
        # can be <0 if coordinate is in [0,grid_shape[dim]]
        return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
    else:
        raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")
348
+
349
def patch_idx_from_grid_idx(self, shape: tuple, grid_idx: tuple):
    """Flatten a per-dimension grid index into a single patch index.

    Each entry of ``grid_idx`` is weighted by ``self.grid_count(shape, dim)``
    and the products are summed.

    NOTE(review): ``self.grid_count`` is not defined in the visible part of
    this file - confirm it exists and accepts ``(shape, dim)``.

    Parameters
    ----------
    shape : tuple
        Shape of the sample the grid index refers to.
    grid_idx : tuple
        Grid index, one entry per dimension of ``shape``.

    Raises
    ------
    AssertionError
        If ``grid_idx`` and ``shape`` differ in length.
    """
    assert len(grid_idx) == len(
        shape
    ), f"Dimension indices {grid_idx} must have the same dimension as data shape {shape}"
    return sum(
        position * self.grid_count(shape, axis)
        for axis, position in enumerate(grid_idx)
    )
358
+
359
def get_patch_location_from_patch_idx(self, ch_idx: int, patch_idx: int):
    """Return the start location of the patch for a given patch index.

    The grid start location from ``get_location_from_patch_idx`` is moved
    back by ``patch_offset()``. A zero is prepended to the offset so that the
    first location entry is left untouched.

    Returns
    -------
    tuple
        Grid start location minus the (zero-prefixed) patch offset.
    """
    grid_start = np.array(self.get_location_from_patch_idx(ch_idx, patch_idx))
    shift = np.concatenate((np.array((0,)), self.patch_offset()))
    return tuple(grid_start - shift)
364
+
365
def get_patch_idx_from_grid_location(self, shape, location: tuple):
    """Return the patch index of the grid cell containing ``location``.

    Parameters
    ----------
    shape : tuple
        Shape of the sample that ``location`` refers to.
    location : tuple
        Coordinate, one entry per dimension of ``shape``.

    Returns
    -------
    The value produced by ``patch_idx_from_grid_idx`` for the per-dimension
    grid indices of ``location``.

    Raises
    ------
    AssertionError
        If ``location`` and ``shape`` differ in length.
    """
    assert len(location) == len(
        shape
    ), f"Location {location} must have the same dimension as data shape {shape}"
    # Both helpers take the sample shape as their first argument; it has to be
    # forwarded, otherwise the calls fail with a TypeError.
    grid_idx = tuple(
        self.get_grid_index(shape, dim, coordinate)
        for dim, coordinate in enumerate(location)
    )
    return self.patch_idx_from_grid_idx(shape, grid_idx)
373
+
374
def get_gridstart_location_from_dim_index(
    self, shape: tuple, dim_idx: int, dim: int
):
    """Return the start coordinate of a grid cell along one dimension.

    Parameters
    ----------
    shape : tuple
        Shape of the sample being tiled.
    dim_idx : int
        Index of the dimension in ``shape``.
    dim : int
        Position of the grid cell along that dimension (0-based, relative to
        the number of grid cells in the dimension).

    Returns
    -------
    int
        Start coordinate of the grid cell along dimension ``dim_idx``.

    Raises
    ------
    ValueError
        If the dimension is not degenerate and the tiling mode is anything
        other than ``TilingMode.ShiftBoundary``.
    """
    if self.grid_shape[dim_idx] == 1 and self.patch_shape[dim_idx] == 1:
        # Grid and patch both have size 1: cell number `dim` starts at
        # coordinate `dim`. (Returning `dim_idx` here would yield the axis
        # number rather than a position along the axis.)
        return dim
    elif self.tiling_mode == TilingMode.ShiftBoundary:
        excess_size = (self.patch_shape[dim_idx] - self.grid_shape[dim_idx]) // 2
        if dim < self.get_individual_dim_grid_count(shape, dim_idx) - 1:
            return dim * self.grid_shape[dim_idx] + excess_size
        # on boundary. grid should be placed such that the patch covers the
        # entire data.
        return shape[dim_idx] - self.grid_shape[dim_idx] - excess_size
    else:
        raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")
395
+
396
def get_location_from_patch_idx(self, channel_idx: int, patch_idx: int):
    """Return the start location of the grid for a patch index. Per channel!

    Parameters
    ----------
    channel_idx : int
        Index into ``self.data_shapes`` selecting the channel.
    patch_idx : int
        The index of the patch in a list of samples within a channel. Channels
        can be different in length.

    Returns
    -------
    tuple
        ``(sample_idx, *grid_start)`` where ``grid_start`` has one entry for
        the channel dimension followed by one entry per remaining dimension
        of the sample.

    Raises
    ------
    IndexError
        If ``patch_idx`` is negative or not smaller than the number of patches
        in the channel.
    """
    # TODO duplicated runs, revisit
    patches_per_sample = self.total_grid_count()[1][channel_idx]
    # cumulative_counts[k] is the number of patches in samples 0..k
    cumulative_counts = np.cumsum(patches_per_sample)
    total = int(cumulative_counts[-1]) if len(cumulative_counts) else 0
    if not 0 <= patch_idx < total:
        raise IndexError(
            f"Patch index {patch_idx} is out of range for channel "
            f"{channel_idx} with {total} patches"
        )

    # find the sample that contains this patch
    sample_idx = int(np.searchsorted(cumulative_counts, patch_idx, side="right"))
    sample_shape = self.data_shapes[channel_idx][sample_idx]

    # Convert the channel-wide index into an index local to the selected
    # sample; the decomposition below is only valid for a local index.
    patches_before = int(cumulative_counts[sample_idx - 1]) if sample_idx > 0 else 0
    local_idx = int(patch_idx) - patches_before

    # ignoring the channel dimension because we index it explicitly
    grid_count = self.grid_count_per_sample(sample_shape)[1:]

    # Decompose the local index; the first remaining dimension varies fastest.
    grid_idx = []
    for i in range(len(grid_count) - 1, -1, -1):
        stride = int(np.prod(grid_count[:i])) if i > 0 else 1
        grid_idx.insert(0, local_idx // stride)
        local_idx %= stride
    # TODO check for 3D !
    # adding channel index
    grid_idx = [channel_idx] + grid_idx

    location = [sample_idx] + [
        self.get_gridstart_location_from_dim_index(
            shape=sample_shape, dim_idx=dim_idx, dim=position
        )
        for dim_idx, position in enumerate(grid_idx)
    ]
    return tuple(location)
433
+
434
def get_location_from_patch_idx_o(self, dataset_idx: int):
    """Return the start location of the grid for a flat dataset index.

    NOTE(review): this looks like a stale earlier variant. It reads
    ``self.data_shape`` and calls ``self.grid_count(dim)``, neither of which is
    defined in the visible part of this file, and it calls
    ``get_gridstart_location_from_dim_index`` with two positional arguments
    although that method takes ``(shape, dim_idx, dim)``. Confirm whether it is
    still used before relying on it.
    """
    num_dims = len(self.data_shape)

    # Peel off one grid index per dimension, carrying the remainder forward.
    remainder = dataset_idx
    grid_idx = []
    for dim in range(num_dims):
        grid_idx.append(remainder // self.grid_count(dim))
        remainder = remainder % self.grid_count(dim)

    return tuple(
        self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
        for dim in range(len(self.data_shape))
    )
447
+
448
def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
    """Return True if the grid is on the boundary in the specified dimension.

    With ``only_end`` set, only the last grid position along ``dim`` counts
    as boundary; otherwise the first position counts as well.

    NOTE(review): ``dim`` is checked against ``len(self.data_shapes)``,
    ``self.grid_count`` is not defined in the visible part of this file, and
    ``get_individual_dim_grid_count`` is called with a single argument although
    it takes ``(shape, dim)``. This looks like a stale call pattern - confirm.

    Raises
    ------
    AssertionError
        If ``dim`` is negative or not smaller than ``len(self.data_shapes)``.
    """
    assert dim < len(
        self.data_shapes
    ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"

    if dim > 0:
        # Drop the contribution of the preceding dimension.
        dataset_idx = dataset_idx % self.grid_count(dim - 1)
    dim_index = dataset_idx // self.grid_count(dim)

    if not only_end:
        at_start = dim_index == 0
        if at_start:
            return at_start
    return dim_index == self.get_individual_dim_grid_count(dim) - 1
467
+
468
def next_grid_along_dim(self, dataset_idx: int, dim: int):
    """Return the index of the next grid along ``dim``, or None past the end.

    The candidate index is ``dataset_idx + self.grid_count(dim)``.

    NOTE(review): ``self.grid_count`` is not defined in the visible part of
    this file, and ``total_grid_count()`` as defined there returns a tuple of
    two lists, so the ``>=`` comparison below would not compare against a
    single number. This looks like a stale call pattern - confirm.

    Raises
    ------
    AssertionError
        If ``dim`` is negative or not smaller than ``len(self.data_shapes)``.
    """
    assert dim < len(
        self.data_shapes
    ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
    assert dim >= 0, "Dimension must be greater than or equal to 0"

    candidate = dataset_idx + self.grid_count(dim)
    return None if candidate >= self.total_grid_count() else candidate
480
+
481
+ def prev_grid_along_dim(self, dataset_idx: int, dim: int):
482
+ """
483
+ Returns the index of the grid in the specified dimension in the specified direction.
484
+ """
485
+ assert dim < len(
486
+ self.data_shapes
487
+ ), f"Dimension {dim} is out of bounds for data shape {self.data_shapes}"
488
+ assert dim >= 0, "Dimension must be greater than or equal to 0"
489
+ new_idx = dataset_idx - self.grid_count(dim)
490
+ if new_idx < 0:
491
+ return None