careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,33 @@
1
+ """Collate function for tiling."""
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ from torch.utils.data.dataloader import default_collate
7
+
8
+ from careamics.config.data.tile_information import TileInformation
9
+
10
+
11
def collate_tiles(batch: list[tuple[np.ndarray, TileInformation]]) -> Any:
    """
    Collate tiles received from CAREamics prediction dataloader.

    The CAREamics prediction dataloader yields ``(array, TileInformation)`` pairs.
    The arrays are collated into a single batch with PyTorch's ``default_collate``.
    The ``TileInformation`` objects are not passed to ``default_collate``; they are
    returned unchanged in a plain list, in the same order as the arrays.

    Parameters
    ----------
    batch : list of tuple of (numpy.ndarray, TileInformation)
        Batch of tiles, each paired with the information needed to stitch it.

    Returns
    -------
    tuple
        ``(collated_tiles, tile_infos)`` where ``collated_tiles`` is the output of
        ``default_collate`` applied to the arrays and ``tile_infos`` is the list of
        the corresponding ``TileInformation`` objects.
    """
    tiles = [tile for tile, _ in batch]
    tile_infos = [tile_info for _, tile_info in batch]

    return default_collate(tiles), tile_infos
@@ -0,0 +1,375 @@
1
+ """Functions to reimplement the tiling in the Disentangle repository."""
2
+
3
+ import builtins
4
+ import itertools
5
+ from collections.abc import Generator
6
+ from typing import Any, Union
7
+
8
+ import numpy as np
9
+ from numpy.typing import NDArray
10
+
11
+ from careamics.config.data.tile_information import TileInformation
12
+ from careamics.lvae_training.dataset.utils.index_manager import GridIndexManager
13
+
14
+
15
def extract_tiles(
    arr: NDArray,
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
    padding_kwargs: dict[str, Any] | None = None,
) -> Generator[tuple[NDArray, TileInformation], None, None]:
    """Generate tiles from the input array with specified overlap.

    The tiles cover the whole array, which is additionally padded so that the
    section of each tile that contributes to the final image comes from the center
    of the tile.

    The method returns a generator that yields tuples of array and tile information.
    The latter includes whether the tile is the last one, the coordinates of the
    overlap crop, and the coordinates of the stitched tile.

    Input array should have shape SC(Z)YX, while the returned tiles have shape
    C(Z)YX, where C can be a singleton.

    Parameters
    ----------
    arr : numpy.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : array-like of int
        Tile sizes in each spatial dimension, ((Z), Y, X), of length 2 or 3. Tuples
        and lists are accepted as well as numpy arrays.
    overlaps : array-like of int
        Overlap values in each spatial dimension, ((Z), Y, X), of length 2 or 3.
        Tuples and lists are accepted as well as numpy arrays.
    padding_kwargs : dict, optional
        The arguments of `np.pad` after the first two arguments, `array` and
        `pad_width`. If not specified the default will be `{"mode": "reflect"}`.
        See `numpy.pad` docs:
        https://numpy.org/doc/stable/reference/generated/numpy.pad.html.

    Yields
    ------
    tuple of (numpy.ndarray, TileInformation)
        A tile of shape (C, (Z), Y, X) and the information needed to crop and
        stitch it back into the full image.
    """
    if padding_kwargs is None:
        padding_kwargs = {"mode": "reflect"}

    # Normalise to arrays: the element-wise arithmetic below and in the helper
    # functions (subtraction, `// 2`) fails when both arguments are tuples.
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)

    # Distance between the starts of consecutive tiles; it is the same for every
    # sample, so compute it once.
    stitch_size = tile_size - overlaps

    # Iterate over num samples (S)
    for sample_idx in range(arr.shape[0]):
        sample = arr[sample_idx, ...]
        data_shape = np.array(sample.shape)

        # Add padding to ensure evenly spaced & overlapping tiles. The first entry
        # leaves the channel axis (C) unpadded.
        spatial_padding = compute_padding(data_shape, tile_size, overlaps)
        padding = ((0, 0), *spatial_padding)
        sample = np.pad(sample, padding, **padding_kwargs)

        # The number of tiles in each spatial dimension, of length 2 or 3.
        tile_grid_shape = compute_tile_grid_shape(data_shape, tile_size, overlaps)

        # itertools.product is the equivalent of nested loops over the grid axes.
        for tile_grid_indices in itertools.product(
            *[range(n) for n in tile_grid_shape]
        ):
            grid_indices = np.array(tile_grid_indices)

            # Crop coordinates in the padded sample.
            crop_coords_start = grid_indices * stitch_size
            crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
                ...,
                *[
                    slice(start, start + extent)
                    for start, extent in zip(
                        crop_coords_start, tile_size, strict=False
                    )
                ],
            )
            tile = sample[crop_slices]

            tile_info = compute_tile_info(
                grid_indices,
                np.array(data_shape),  # copy, as in the original per-tile call
                tile_size,
                overlaps,
                sample_idx,
            )
            # TODO: kinda weird this is a generator,
            # -> doesn't really save memory ? Don't think there are any places the
            # tiles are not extracted all at the same time.
            # Although I guess it would make sense for a zarr tile extractor.
            yield tile, tile_info
100
+
101
+
102
def compute_tile_info_legacy(
    grid_index_manager: GridIndexManager, index: int
) -> TileInformation:
    """
    Build the crop-and-stitch information for the tile at a dataset index.

    Parameters
    ----------
    grid_index_manager : GridIndexManager
        Tracks where each tile lies; its `data_shape`, `grid_shape`,
        `patch_offset()`, `get_location_from_dataset_idx()` and
        `total_grid_count()` are used here.
    index : int
        Position of the tile within the dataset.

    Returns
    -------
    TileInformation
        Description of how the tile is cropped and where it is stitched in the
        full image. Only the trailing spatial dimensions are kept in the
        coordinates, the S axis is dropped from `array_shape`, and `sample_id`
        is always 0.

    Raises
    ------
    ValueError
        If `grid_index_manager.data_shape` is neither 4D (SCYX) nor 5D (SCZYX).
    """
    data_shape = np.array(grid_index_manager.data_shape)
    n_dims = len(data_shape)
    if n_dims not in (4, 5):
        raise ValueError("Data shape must have 4 or 5 dimensions, equating to SC(Z)YX.")
    # S and C are the two non-spatial axes: SCYX -> 2, SCZYX -> 3.
    n_spatial_dims = n_dims - 2

    stitch_start = np.array(grid_index_manager.get_location_from_dataset_idx(index))
    stitch_end = stitch_start + np.array(grid_index_manager.grid_shape)

    # The tile origin is derived from the *unclipped* stitch start.
    tile_start = stitch_start - grid_index_manager.patch_offset()

    # Clip the stitch region to the data bounds (masks taken before assigning).
    below_zero = stitch_start < 0
    beyond_end = stitch_end > data_shape
    stitch_start[below_zero] = 0
    stitch_end[beyond_end] = data_shape[beyond_end]

    # TODO: TilingMode not in current version
    # if grid_index_manager.tiling_mode == TilingMode.ShiftBoundary:
    #     for dim in range(len(stitch_coords_start)):
    #         if tile_coords_start[dim] == 0:
    #             stitch_coords_start[dim] = 0
    #         if tile_coords_end[dim] == grid_index_manager.data_shape[dim]:
    #             tile_coords_end [dim]= grid_index_manager.data_shape[dim]

    # Position of the stitch region inside the tile.
    overlap_crop_start = stitch_start - tile_start
    stitch_extent = stitch_end - stitch_start
    overlap_crop_end = overlap_crop_start + stitch_extent

    is_last = index == grid_index_manager.total_grid_count() - 1

    # Pair starts and ends per dimension.
    stitch_coords = tuple(zip(stitch_start, stitch_end, strict=False))
    overlap_crop_coords = tuple(
        zip(overlap_crop_start, overlap_crop_end, strict=False)
    )

    return TileInformation(
        array_shape=data_shape[1:],  # drop the S axis
        last_tile=is_last,
        overlap_crop_coords=overlap_crop_coords[-n_spatial_dims:],
        stitch_coords=stitch_coords[-n_spatial_dims:],
        sample_id=0,
    )
182
+
183
+
184
def compute_tile_info(
    tile_grid_indices: NDArray[np.int_],
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
    sample_id: int = 0,
) -> TileInformation:
    """
    Compute the tile information for a tile with the coordinates `tile_grid_indices`.

    Parameters
    ----------
    tile_grid_indices : array-like of int
        The coordinates of the tile within the tile grid, ((Z), Y, X). For example,
        for 2D tiling the coordinates of the second tile in the first row of tiles
        would be (0, 1).
    data_shape : array-like of int
        The shape of the data, should be (C, (Z), Y, X) where Z is optional.
    tile_size : array-like of int
        Tile sizes in each spatial dimension, of length 2 or 3.
    overlaps : array-like of int
        Overlap values in each spatial dimension, of length 2 or 3.
    sample_id : int, default=0
        An ID to identify which sample a tile belongs to.

    Returns
    -------
    TileInformation
        Information that describes how to crop and stitch a tile to create a full
        image.
    """
    # Normalise to arrays: boolean-mask indexing and element-wise arithmetic below
    # fail on tuples/lists. Existing arrays are passed through without copying.
    tile_grid_indices = np.asarray(tile_grid_indices)
    data_shape = np.asarray(data_shape)
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)

    spatial_dims_shape = data_shape[-len(tile_size) :]

    # The extent of the tile which will make up part of the stitched image.
    stitch_size = tile_size - overlaps
    stitch_coords_start = tile_grid_indices * stitch_size
    stitch_coords_end = stitch_coords_start + stitch_size

    # Tile start in original (unpadded) coordinates; it is shifted back by
    # `overlaps // 2`, the same amount `compute_padding` pads before the data.
    tile_coords_start = stitch_coords_start - overlaps // 2

    # --- replace out of bounds indices
    out_of_lower_bound = stitch_coords_start < 0
    out_of_upper_bound = stitch_coords_end > spatial_dims_shape
    stitch_coords_start[out_of_lower_bound] = 0
    stitch_coords_end[out_of_upper_bound] = spatial_dims_shape[out_of_upper_bound]

    # --- calculate overlap crop coords (position of the stitch region in the tile)
    overlap_crop_coords_start = stitch_coords_start - tile_coords_start
    overlap_crop_coords_end = overlap_crop_coords_start + (
        stitch_coords_end - stitch_coords_start
    )

    # --- combine start and end
    stitch_coords = tuple(
        (start, end)
        for start, end in zip(stitch_coords_start, stitch_coords_end, strict=False)
    )
    overlap_crop_coords = tuple(
        (start, end)
        for start, end in zip(
            overlap_crop_coords_start, overlap_crop_coords_end, strict=False
        )
    )

    # --- Check if last tile; convert numpy.bool_ to a plain Python bool.
    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
    last_tile = bool((tile_grid_indices == (tile_grid_shape - 1)).all())

    tile_info = TileInformation(
        array_shape=data_shape,
        last_tile=last_tile,
        overlap_crop_coords=overlap_crop_coords,
        stitch_coords=stitch_coords,
        sample_id=sample_id,
    )
    return tile_info
259
+
260
+
261
def compute_padding(
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
) -> tuple[tuple[int, int], ...]:
    """
    Calculate padding to ensure stitched data comes from the center of a tile.

    Padding is added to an array with shape `data_shape` so that when tiles are
    stitched together, the data used always comes from the center of a tile, even
    for tiles at the boundaries of the array. `overlaps // 2` is added before the
    data. Enough is added after the data for the padded length to equal
    `(tile_size - overlaps) * n_tiles + overlaps`.

    Parameters
    ----------
    data_shape : array-like of int
        The shape of the data to be tiled and stitched together, e.g.
        (S, C, (Z), Y, X) or (C, (Z), Y, X). Only the trailing `len(tile_size)`
        dimensions are used.
    tile_size : array-like of int
        The tile size in each spatial dimension, ((Z), Y, X).
    overlaps : array-like of int
        The tile overlap in each spatial dimension, ((Z), Y, X).

    Returns
    -------
    tuple of (int, int)
        A tuple specifying the padding to add in each spatial dimension. Each
        element is a two-element tuple giving the padding to add before and after
        the data. This can be used as (part of) the `pad_width` argument to
        `numpy.pad`.
    """
    # Normalise to arrays so that tuples/lists are accepted as well.
    data_shape = np.asarray(data_shape)
    tile_size = np.asarray(tile_size)
    overlaps = np.asarray(overlaps)

    tile_grid_shape = np.array(compute_tile_grid_shape(data_shape, tile_size, overlaps))
    # Total length spanned by `n` tiles whose starts are `tile_size - overlaps` apart.
    covered_shape = (tile_size - overlaps) * tile_grid_shape + overlaps

    pad_before = overlaps // 2
    pad_after = covered_shape - data_shape[-len(tile_size) :] - pad_before

    # Return plain ints to match the annotated return type.
    return tuple(
        (int(before), int(after))
        for before, after in zip(pad_before, pad_after, strict=False)
    )
298
+
299
+
300
def n_tiles_1d(axis_size: int, tile_size: int, overlap: int) -> int:
    """Calculate the number of tiles in a specific dimension.

    Consecutive tiles start `tile_size - overlap` pixels apart, so the number of
    tiles needed to cover the axis is `ceil(axis_size / (tile_size - overlap))`.

    Parameters
    ----------
    axis_size : int
        The length of the data in a specific dimension.
    tile_size : int
        The length of the tiles in a specific dimension.
    overlap : int
        The tile overlap in a specific dimension.

    Returns
    -------
    int
        The number of tiles that fit in one dimension given the arguments.

    Raises
    ------
    ValueError
        If `tile_size` is not strictly larger than `overlap`. In that case the
        tiles would not advance along the axis.
    """
    step = tile_size - overlap
    if step <= 0:
        raise ValueError(
            f"Tile size ({tile_size}) must be strictly larger than the overlap "
            f"({overlap})."
        )
    # Integer ceiling division: exact for arbitrarily large sizes, no float rounding.
    return int(-(-axis_size // step))
318
+
319
+
320
def total_n_tiles(
    data_shape: tuple[int, ...], tile_size: tuple[int, ...], overlaps: tuple[int, ...]
) -> int:
    """Calculate the total number of tiles over all spatial dimensions.

    Parameters
    ----------
    data_shape : sequence of int
        The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
        Only the trailing `len(tile_size)` dimensions are used.
    tile_size : sequence of int
        The tile size in each spatial dimension, ((Z), Y, X).
    overlaps : sequence of int
        The tile overlap in each spatial dimension, ((Z), Y, X).

    Returns
    -------
    int
        The product of the number of tiles along each spatial dimension.
    """
    n_spatial_dims = len(tile_size)
    n_tiles = 1
    # Spatial dimensions are the trailing ones; walk them from the last backwards.
    for offset in range(1, n_spatial_dims + 1):
        n_tiles *= n_tiles_1d(
            data_shape[-offset], tile_size[-offset], overlaps[-offset]
        )
    return n_tiles
346
+
347
+
348
def compute_tile_grid_shape(
    data_shape: NDArray[np.int_],
    tile_size: NDArray[np.int_],
    overlaps: NDArray[np.int_],
) -> tuple[int, ...]:
    """Calculate the number of tiles along each spatial dimension.

    The result can be thought of as the shape of a grid of tiles.

    Parameters
    ----------
    data_shape : 1D numpy.array of int
        The shape of the data to be tiled and stitched together, (S, C, (Z), Y, X).
        Only the trailing `len(tile_size)` dimensions are used.
    tile_size : 1D numpy.array of int
        The tile size in each dimension, ((Z), Y, X).
    overlaps : 1D numpy.array of int
        The tile overlap in each dimension, ((Z), Y, X).

    Returns
    -------
    tuple of int
        The number of tiles in each direction, ((Z), Y, X).
    """
    n_spatial_dims = len(tile_size)
    # Spatial dimensions are the trailing ones; evaluate them from the last one
    # backwards, then flip so the result is ordered ((Z), Y, X).
    counts_last_first = [
        n_tiles_1d(data_shape[-offset], tile_size[-offset], overlaps[-offset])
        for offset in range(1, n_spatial_dims + 1)
    ]
    return tuple(reversed(counts_last_first))
@@ -0,0 +1,166 @@
1
+ """Tiled patching utilities."""
2
+
3
+ import itertools
4
+ from collections.abc import Generator
5
+ from typing import Union
6
+
7
+ import numpy as np
8
+
9
+ from careamics.config.data.tile_information import TileInformation
10
+
11
+
12
+ def _compute_crop_and_stitch_coords_1d(
13
+ axis_size: int, tile_size: int, overlap: int
14
+ ) -> tuple[list[tuple[int, int]], list[tuple[int, int]], list[tuple[int, int]]]:
15
+ """
16
+ Compute the coordinates of each tile along an axis, given the overlap.
17
+
18
+ Parameters
19
+ ----------
20
+ axis_size : int
21
+ Length of the axis.
22
+ tile_size : int
23
+ Size of the tile for the given axis.
24
+ overlap : int
25
+ Size of the overlap for the given axis.
26
+
27
+ Returns
28
+ -------
29
+ tuple[tuple[int, ...], ...]
30
+ tuple of all coordinates for given axis.
31
+ """
32
+ # Compute the step between tiles
33
+ step = tile_size - overlap
34
+ crop_coords = []
35
+ stitch_coords = []
36
+ overlap_crop_coords = []
37
+
38
+ # Iterate over the axis with step
39
+ for i in range(0, max(1, axis_size - overlap), step):
40
+ # Check if the tile fits within the axis
41
+ if i + tile_size <= axis_size:
42
+ # Add the coordinates to crop one tile
43
+ crop_coords.append((i, i + tile_size))
44
+
45
+ # Add the pixel coordinates of the cropped tile in the original image space
46
+ stitch_coords.append(
47
+ (
48
+ i + overlap // 2 if i > 0 else 0,
49
+ (
50
+ i + tile_size - overlap // 2
51
+ if crop_coords[-1][1] < axis_size
52
+ else axis_size
53
+ ),
54
+ )
55
+ )
56
+
57
+ # Add the coordinates to crop the overlap from the prediction.
58
+ overlap_crop_coords.append(
59
+ (
60
+ overlap // 2 if i > 0 else 0,
61
+ (
62
+ tile_size - overlap // 2
63
+ if crop_coords[-1][1] < axis_size
64
+ else tile_size
65
+ ),
66
+ )
67
+ )
68
+
69
+ # If the tile does not fit within the axis, perform the abovementioned
70
+ # operations starting from the end of the axis
71
+ else:
72
+ # if (axis_size - tile_size, axis_size) not in crop_coords:
73
+ crop_coords.append((max(0, axis_size - tile_size), axis_size))
74
+ last_tile_end_coord = stitch_coords[-1][1] if stitch_coords else 1
75
+ stitch_coords.append((last_tile_end_coord, axis_size))
76
+ overlap_crop_coords.append(
77
+ (tile_size - (axis_size - last_tile_end_coord), tile_size)
78
+ )
79
+ break
80
+ return crop_coords, stitch_coords, overlap_crop_coords
81
+
82
+
83
def extract_tiles(
    arr: np.ndarray,
    tile_size: Union[list[int], tuple[int, ...]],
    overlaps: Union[list[int], tuple[int, ...]],
) -> Generator[tuple[np.ndarray, TileInformation], None, None]:
    """Yield overlapping tiles, with their metadata, covering every sample of `arr`.

    Each sample ``arr[s]`` is tiled independently along its spatial axes. The
    channel axis is never split, so the input of shape SC(Z)YX produces tiles of
    shape C(Z)YX, where C can be a singleton. The tiles cover the whole sample.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : Union[list[int], tuple[int]]
        Tile sizes in each spatial dimension, of length 2 or 3.
    overlaps : Union[list[int], tuple[int]]
        Overlap values in each spatial dimension, of length 2 or 3.

    Yields
    ------
    Generator[tuple[np.ndarray, TileInformation], None, None]
        Tile generator. Each item is the tile and its information: the shape
        of the sample it was cut from, whether it is the last tile of that
        sample, the overlap crop coordinates, the stitch coordinates and the
        sample index.
    """
    n_spatial_dims = len(tile_size)

    for sample_idx in range(arr.shape[0]):
        sample: np.ndarray = arr[sample_idx, ...]

        # One (crop, stitch, overlap crop) triple of coordinate lists per spatial
        # axis; spatial dim `d` is axis `d + 1` of the sample because axis 0 is C.
        # E.g. an axis of size 35 with tile size 32 and overlap 24 gives
        # ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]).
        per_axis_coords = [
            _compute_crop_and_stitch_coords_1d(
                sample.shape[dim + 1], tile_size[dim], overlaps[dim]
            )
            for dim in range(n_spatial_dims)
        ]

        # Transpose into three groups (crops, stitches, overlap crops), each
        # holding one coordinate list per axis.
        crops_per_axis, stitches_per_axis, overlap_crops_per_axis = zip(
            *per_axis_coords, strict=False
        )

        n_tiles = int(np.prod([len(axis_crops) for axis_crops in crops_per_axis]))
        last_tile_idx = n_tiles - 1

        # The three cartesian products walk the same tile grid in the same order,
        # so zipping them pairs up the coordinates belonging to one tile.
        tile_coords = zip(
            itertools.product(*crops_per_axis),
            itertools.product(*stitches_per_axis),
            itertools.product(*overlap_crops_per_axis),
            strict=False,
        )
        for tile_idx, (crop, stitch, overlap_crop) in enumerate(tile_coords):
            # Leading Ellipsis keeps the channel axis whole; the slices apply to
            # the trailing spatial axes.
            spatial_slices = tuple(slice(start, stop) for start, stop in crop)
            tile: np.ndarray = sample[(Ellipsis, *spatial_slices)]

            yield tile, TileInformation(
                array_shape=sample.shape,
                last_tile=tile_idx == last_tile_idx,
                overlap_crop_coords=overlap_crop,
                stitch_coords=stitch,
                sample_id=sample_idx,
            )