careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,393 @@
1
+ """A module for random patching strategies."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ import numpy as np
6
+
7
+ from .patching_strategy_protocol import PatchSpecs
8
+
9
+
10
class RandomPatchingStrategy:
    """
    A patching strategy for sampling random patches.

    It implements the `PatchingStrategy` `Protocol`.

    The output of `get_patch_spec` will be random, i.e. if the same index is given
    twice the two outputs can be different.

    However the strategy still ensures that there will be a known number of patches
    for each sample in each image stack. This is achieved through defining a set of
    bins that map to each sample in each image stack. Whichever bin an `index` passed
    to `get_patch_spec` falls into, determines the `"data_idx"` and `"sample_idx"` in
    the returned `PatchSpecs`, but the `"coords"` will be random.

    The number of patches in each sample is based on the number of patches that would
    fit if they were sampled sequentially, non-overlapping, and covering the entire
    array.
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        patch_size: Sequence[int],
        seed: int | None = None,
    ):
        """
        A patching strategy for sampling random patches.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        patch_size : sequence of int
            The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.
        seed : int, optional
            An optional seed to ensure the reproducibility of the random patches.
        """
        self.rng = np.random.default_rng(seed=seed)
        self.patch_size = patch_size
        self.data_shapes = data_shapes

        # these bins will determine which image stack and sample a patch comes from
        # the image_stack_cumulative_patches map a patch index to each image stack
        # the sample_cumulative_patches map a patch index to each sample
        # the image_stack_cumulative_samples map a sample index to each image stack
        (
            self.image_stack_cumulative_patches,
            self.sample_cumulative_patches,
            self.image_stack_cumulative_samples,
        ) = self._calc_bins(self.data_shapes, self.patch_size)

    @property
    def n_patches(self) -> int:
        """
        The number of patches that this patching strategy will return.

        It also determines the maximum index that can be given to `get_patch_spec`.
        """
        # no image stacks means there are no bins and therefore no patches
        if len(self.image_stack_cumulative_patches) == 0:
            return 0
        # last bin boundary will be total patches
        return int(self.image_stack_cumulative_patches[-1])

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """Return the patch specs for a given index.

        Negative indices follow the Python sequence convention and count from the
        end, i.e. `-1` refers to the last patch index.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        PatchSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.

        Raises
        ------
        IndexError
            If `index` is outside the range `[-n_patches, n_patches)`.
        """
        n_patches = self.n_patches
        if index >= n_patches or index < -n_patches:
            raise IndexError(
                f"Index {index} out of bounds for RandomPatchingStrategy with number "
                f"of patches {n_patches}"
            )
        if index < 0:
            # without this, every negative index would fall below the first bin
            # boundary and silently be assigned to the first sample
            index += n_patches

        # digitize returns the bin that `index` belongs to
        data_index = np.digitize(index, bins=self.image_stack_cumulative_patches).item()
        # maps to a particular sample within the whole series of image stacks
        # (not just a single image stack)
        total_samples_index = np.digitize(
            index, bins=self.sample_cumulative_patches
        ).item()

        data_shape = self.data_shapes[data_index]
        spatial_shape = data_shape[2:]

        # calculate sample index relative to image stack:
        # subtract the total number of samples in the previous image stacks
        if data_index == 0:
            n_previous_samples = 0
        else:
            n_previous_samples = self.image_stack_cumulative_samples[data_index - 1]
        sample_index = total_samples_index - n_previous_samples
        coords = _generate_random_coords(spatial_shape, self.patch_size, self.rng)
        return {
            "data_idx": data_index,
            "sample_idx": sample_index,
            "coords": coords,
            "patch_size": self.patch_size,
        }

    # Note: this is used by the FileIterSampler
    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
        """
        Get the patch indices that will return patches for a specific `image_stack`.

        The `image_stack` corresponds to the given `data_idx`.

        Parameters
        ----------
        data_idx : int
            An index that corresponds to a given `image_stack`.

        Returns
        -------
        sequence of int
            A sequence of patch indices, that when used to index the
            `CAREamicsDataset` will return a patch that comes from the `image_stack`
            corresponding to the given `data_idx`.
        """
        # return all the values in the corresponding bin
        if data_idx == 0:
            start = 0
        else:
            start = self.image_stack_cumulative_patches[data_idx - 1]

        return list(range(start, self.image_stack_cumulative_patches[data_idx]))

    @staticmethod
    def _calc_bins(
        data_shapes: Sequence[Sequence[int]], patch_size: Sequence[int]
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Calculate bins used to map an index to an image_stack and a sample.

        The number of patches in each sample is based on the number of patches that
        would fit if they were sampled sequentially.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        patch_size : sequence of int
            The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.

        Returns
        -------
        image_stack_cumulative_patches: tuple of int
            The bins that map a patch index to an image stack. E.g. if a patch index
            falls below the first bin boundary it belongs to the first image stack, if
            a patch index falls between the first bin boundary and the second bin
            boundary it belongs to the second image stack, and so on.
        sample_cumulative_patches: tuple of int
            The bins that map a patch index to a sample. E.g. if a patch index
            falls below the first bin boundary it belongs to the first sample, if
            a patch index falls between the first bin boundary and the second bin
            boundary it belongs to the second sample, and so on.
        image_stack_cumulative_samples: tuple of int
            The bins that map a sample index to an image stack. E.g. if a sample index
            falls below the first bin boundary it belongs to the first image stack, if
            a sample index falls between the first bin boundary and the second bin
            boundary it belongs to the second image stack, and so on.
        """
        patches_per_image_stack: list[int] = []
        patches_per_sample: list[int] = []
        samples_per_image_stack: list[int] = []
        for data_shape in data_shapes:
            n_samples = data_shape[0]
            spatial_shape = data_shape[2:]
            n_single_sample_patches = _calc_n_patches(spatial_shape, patch_size)
            # multiply by number of samples in image_stack
            patches_per_image_stack.append(n_single_sample_patches * n_samples)
            # list of length `sample` filled with `n_single_sample_patches`
            patches_per_sample.extend([n_single_sample_patches] * n_samples)
            # number of samples in each image stack
            samples_per_image_stack.append(n_samples)

        def _cumulative(counts: list[int]) -> tuple[int, ...]:
            # cumulative sum creates the bins; `tolist` converts the NumPy scalars
            # to built-in ints so the values match the annotated return type
            return tuple(np.cumsum(counts, dtype=np.int64).tolist())

        return (
            _cumulative(patches_per_image_stack),
            _cumulative(patches_per_sample),
            _cumulative(samples_per_image_stack),
        )
203
+
204
+
205
class FixedRandomPatchingStrategy:
    """
    A patching strategy whose random patches are fixed at construction time.

    It implements the `PatchingStrategy` `Protocol`.

    Every patch is drawn once in `__init__` and stored, so `get_patch_spec` is
    deterministic: asking for the same index twice gives the same patch.

    The number of patches in each sample is based on the number of patches that would
    fit if they were sampled sequentially, non-overlapping, and covering the entire
    array.
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        patch_size: Sequence[int],
        seed: int | None = None,
    ):
        """Draw and store all random patches.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        patch_size : sequence of int
            The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.
        seed : int, optional
            An optional seed to ensure the reproducibility of the random patches.
        """
        self.rng = np.random.default_rng(seed=seed)
        self.patch_size = patch_size
        self.data_shapes = data_shapes

        # all patches are generated up front, which is what makes them fixed
        self.fixed_patch_specs: list[PatchSpecs] = []
        for stack_idx, stack_shape in enumerate(self.data_shapes):
            stack_spatial_shape = stack_shape[2:]
            patches_per_sample = _calc_n_patches(stack_spatial_shape, self.patch_size)
            n_samples = stack_shape[0]
            for sample_number in range(n_samples):
                # coordinates are drawn in stack -> sample -> patch order
                sample_specs: list[PatchSpecs] = [
                    {
                        "data_idx": stack_idx,
                        "sample_idx": sample_number,
                        "coords": _generate_random_coords(
                            stack_spatial_shape, self.patch_size, self.rng
                        ),
                        "patch_size": self.patch_size,
                    }
                    for _ in range(patches_per_sample)
                ]
                self.fixed_patch_specs.extend(sample_specs)

    @property
    def n_patches(self):
        """
        The number of patches that this patching strategy will return.

        It also determines the maximum index that can be given to `get_patch_spec`.
        """
        stored_specs = self.fixed_patch_specs
        return len(stored_specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """Return the patch specs for a given index.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        PatchSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.
        """
        if index < self.n_patches:
            # look the patch up among the pre-generated ones
            return self.fixed_patch_specs[index]
        raise IndexError(
            f"Index {index} out of bounds for FixedRandomPatchingStrategy with "
            f"number of patches, {self.n_patches}"
        )

    # Note: this is used by the FileIterSampler
    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
        """
        Get the patch indices that select patches from a specific `image_stack`.

        The `image_stack` corresponds to the given `data_idx`.

        Parameters
        ----------
        data_idx : int
            An index that corresponds to a given `image_stack`.

        Returns
        -------
        sequence of int
            The positions in `fixed_patch_specs` whose `"data_idx"` equals the given
            `data_idx`.
        """
        matching_indices: list[int] = []
        for position, spec in enumerate(self.fixed_patch_specs):
            if spec["data_idx"] == data_idx:
                matching_indices.append(position)
        return matching_indices
313
+
314
+
315
def _generate_random_coords(
    spatial_shape: Sequence[int], patch_size: Sequence[int], rng: np.random.Generator
) -> tuple[int, ...]:
    """Draw random patch coordinates for a given `spatial_shape` and `patch_size`.

    The coords are the top-left (and first z-slice for 3D data) of a patch. The
    sequence will have length 2 or 3, for 2D and 3D data respectively.

    Parameters
    ----------
    spatial_shape : sequence of int
        The dimension of the axes (Z)YX, a sequence of length 2 or 3, for 2D and 3D
        data respectively.
    patch_size : sequence of int
        The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
        data respectively.
    rng : numpy.random.Generator
        A numpy generator to ensure the reproducibility of the random patches.

    Returns
    -------
    coords: tuple of int
        The top-left (and first z-slice for 3D data) coords of a patch. The tuple will
        have length 2 or 3, for 2D and 3D data respectively.

    Raises
    ------
    ValueError
        Raises if the number of spatial dimensions do not match the number of patch
        dimensions.
    """
    n_patch_dims = len(patch_size)
    if n_patch_dims != len(spatial_shape):
        raise ValueError(
            f"Number of patch dimension {len(patch_size)}, do not match the number of "
            f"spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
            f"and `spatial_shape={spatial_shape}`."
        )

    lower_bounds = np.zeros(n_patch_dims, dtype=int)
    # largest start along each axis that keeps the patch inside the array; clipped
    # at zero so an axis smaller than the patch can only start at 0
    upper_bounds = np.clip(np.array(spatial_shape) - np.array(patch_size), 0, None)
    # one vectorised draw; the upper bounds are inclusive
    drawn = rng.integers(lower_bounds, upper_bounds, endpoint=True, dtype=int)
    return tuple(drawn.tolist())
360
+
361
+
362
+ def _calc_n_patches(spatial_shape: Sequence[int], patch_size: Sequence[int]) -> int:
363
+ """
364
+ Calculates the number of patches for a given `spatial_shape` and `patch_size`.
365
+
366
+ This is based on the number of patches that would fit if they were sampled
367
+ sequentially.
368
+
369
+ Parameters
370
+ ----------
371
+ spatial_shape : sequence of int
372
+ The dimension of the axes (Z)YX, a sequence of length 2 or 3, for 2D and 3D
373
+ data respectively.
374
+ patch_size : sequence of int
375
+ The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
376
+ data respectively.
377
+
378
+ Returns
379
+ -------
380
+ int
381
+ The number of patches.
382
+ """
383
+ if len(patch_size) != len(spatial_shape):
384
+ raise ValueError(
385
+ f"Number of patch dimension {len(patch_size)}, do not match the number of "
386
+ f"spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
387
+ f"and `spatial_shape={spatial_shape}`."
388
+ )
389
+ patches_per_dim = [
390
+ np.ceil(s / p) for s, p in zip(spatial_shape, patch_size, strict=False)
391
+ ]
392
+ total_patches = int(np.prod(patches_per_dim))
393
+ return total_patches
@@ -0,0 +1,99 @@
1
+ import itertools
2
+ from collections.abc import Sequence
3
+
4
+ import numpy as np
5
+ from typing_extensions import ParamSpec
6
+
7
+ from .patching_strategy_protocol import PatchSpecs
8
+
9
+ P = ParamSpec("P")
10
+
11
+
12
+ # TODO: this is an unfinished prototype based on current tiling implementation
13
+ # not guaranteed to work!
14
class SequentialPatchingStrategy:
    """
    Patching strategy that extracts patches on a regular grid.

    Patches are placed one after another along every spatial axis, optionally
    overlapping, and the same grid is repeated for every sample of every data
    shape. All patch specifications are computed once, at construction time.
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        patch_size: Sequence[int],
        overlaps: Sequence[int] | None = None,
    ):
        """
        Compute the patch specifications for all the given data shapes.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. The first element of each shape is
            used as the number of samples and the last `len(patch_size)` elements
            are used as the spatial dimensions.
        patch_size : sequence of int
            The size of the patch in each spatial dimension.
        overlaps : sequence of int, optional
            How much adjacent patches overlap in each spatial dimension. If `None`,
            patches do not overlap.

        Raises
        ------
        ValueError
            If an overlap is greater than or equal to the corresponding patch size.
        """
        self.data_shapes = data_shapes
        self.patch_size = patch_size
        if overlaps is None:
            overlaps = [0] * len(patch_size)
        self.overlaps = np.asarray(overlaps)

        self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()

    @property
    def n_patches(self) -> int:
        """
        The number of patches that this patching strategy will return.

        It also determines the maximum index that can be given to `get_patch_spec`.
        """
        return len(self.patch_specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """
        Return the patch specification for a given index.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        PatchSpecs
            The specification of the patch stored at position `index`.
        """
        return self.patch_specs[index]

    # Note: this is used by the FileIterSampler
    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
        """
        Get the indices of the patches that belong to a specific `image_stack`.

        The `image_stack` corresponds to the given `data_idx`.

        Parameters
        ----------
        data_idx : int
            An index that corresponds to a given `image_stack`.

        Returns
        -------
        sequence of int
            A sequence of patch indices, that when used to index the
            `CAREamicsDataset` will return a patch that comes from the `image_stack`
            corresponding to the given `data_idx`.
        """
        return [
            i
            for i, patch_spec in enumerate(self.patch_specs)
            if patch_spec["data_idx"] == data_idx
        ]

    def _compute_coords_1d(
        self, patch_size: int, spatial_shape: int, overlap: int
    ) -> list[tuple[int, int]]:
        """
        Compute the `(start, end)` extent of every patch along a single axis.

        Patches start at 0 and advance by `patch_size - overlap`. If the regular
        grid does not reach the end of the axis, a final patch aligned with the end
        of the axis is added. If the axis is smaller than the patch, a single patch
        starting at 0 and ending at the end of the axis is returned.

        Parameters
        ----------
        patch_size : int
            The size of the patch along the axis.
        spatial_shape : int
            The size of the axis.
        overlap : int
            The overlap between adjacent patches along the axis.

        Returns
        -------
        list of (int, int)
            The start (inclusive) and end (exclusive) of each patch.

        Raises
        ------
        ValueError
            If `overlap` is greater than or equal to `patch_size`.
        """
        # The overlaps are stored in a numpy array; cast so that every returned
        # coordinate is a plain Python int rather than a mix of int and np.int64.
        patch_size = int(patch_size)
        spatial_shape = int(spatial_shape)
        overlap = int(overlap)

        step = patch_size - overlap
        if step <= 0:
            # A non-positive step would never advance along the axis.
            raise ValueError(
                f"Overlap must be smaller than the patch size, got "
                f"`overlap={overlap}` and `patch_size={patch_size}`."
            )

        # Start of the patch that is aligned with the end of the axis, clamped to 0
        # for an axis that is smaller than the patch.
        last_start = max(0, spatial_shape - patch_size)
        starts = list(range(0, last_start + 1, step))
        if starts[-1] != last_start:
            starts.append(last_start)

        return [(start, min(start + patch_size, spatial_shape)) for start in starts]

    def _initialize_patch_specs(self) -> list[PatchSpecs]:
        """
        Build the patch specifications for every data shape and sample.

        Returns
        -------
        list of PatchSpecs
            One entry per patch, ordered by data index, then sample index, then
            patch position on the grid.
        """
        n_spatial_dims = len(self.patch_size)
        patch_specs: list[PatchSpecs] = []
        for data_idx, data_shape in enumerate(self.data_shapes):
            data_spatial_shape = data_shape[-n_spatial_dims:]
            starts_per_axis = [
                [
                    start
                    for start, _ in self._compute_coords_1d(
                        axis_patch_size, data_spatial_shape[axis], self.overlaps[axis]
                    )
                ]
                for axis, axis_patch_size in enumerate(self.patch_size)
            ]
            # The grid is identical for every sample, so it is only computed once.
            grid = list(itertools.product(*starts_per_axis))
            for sample_idx in range(data_shape[0]):
                for coords in grid:
                    patch_specs.append(
                        PatchSpecs(
                            data_idx=data_idx,
                            sample_idx=sample_idx,
                            coords=coords,
                            patch_size=self.patch_size,
                        )
                    )
        return patch_specs
@@ -0,0 +1,207 @@
1
+ """Module for the `TilingStrategy` class."""
2
+
3
+ import itertools
4
+ from collections.abc import Sequence
5
+ from math import prod
6
+
7
+ from .patching_strategy_protocol import TileSpecs
8
+
9
+
10
class TilingStrategy:
    """
    The tiling strategy should be used for prediction. The `get_patch_spec`
    method returns `TileSpecs` dictionaries that contain information on how to
    stitch the tiles back together to create the full image.
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        patch_size: Sequence[int],
        overlaps: Sequence[int],
    ):
        """
        The tiling strategy should be used for prediction. The `get_patch_spec`
        method returns `TileSpecs` dictionaries that contain information on how to
        stitch the tiles back together to create the full image.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        patch_size : sequence of int
            The size of the tile. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.
        overlaps : sequence of int
            How much a tile will overlap with adjacent tiles in each spatial dimension.

        Raises
        ------
        ValueError
            If an overlap is negative, or greater than or equal to the corresponding
            tile size.
        """
        self.data_shapes = data_shapes
        self.patch_size = patch_size
        self.overlaps = overlaps
        # patch_size and overlap should have same length validated in pydantic configs
        self.tile_specs: list[TileSpecs] = self._generate_specs()

    @property
    def n_patches(self) -> int:
        """
        The number of patches that this patching strategy will return.

        It also determines the maximum index that can be given to `get_patch_spec`.
        """
        return len(self.tile_specs)

    def get_patch_spec(self, index: int) -> TileSpecs:
        """Return the tile specs for a given index.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        TileSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.
        """
        return self.tile_specs[index]

    # Note: this is used by the FileIterSampler
    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
        """
        Get the indices of the patches that belong to a specific `image_stack`.

        The `image_stack` corresponds to the given `data_idx`.

        Parameters
        ----------
        data_idx : int
            An index that corresponds to a given `image_stack`.

        Returns
        -------
        sequence of int
            A sequence of patch indices, that when used to index the
            `CAREamicsDataset` will return a patch that comes from the `image_stack`
            corresponding to the given `data_idx`.
        """
        return [
            i
            for i, patch_spec in enumerate(self.tile_specs)
            if patch_spec["data_idx"] == data_idx
        ]

    def _generate_specs(self) -> list[TileSpecs]:
        """
        Build the tile specifications for every data shape and sample.

        Returns
        -------
        list of TileSpecs
            One entry per tile, ordered by data index, then sample index, then tile
            position. The `total_tiles` field of each entry is the number of tiles
            for its `data_idx`, counted over all samples.
        """
        tile_specs: list[TileSpecs] = []
        for data_idx, data_shape in enumerate(self.data_shapes):
            # skip the first two axes (S and C) to keep the spatial axes only
            spatial_shape = data_shape[2:]

            # spec info for each axis
            axis_specs: list[tuple[list[int], list[int], list[int], list[int]]] = [
                self._compute_1d_coords(
                    axis_size, self.patch_size[axis_idx], self.overlaps[axis_idx]
                )
                for axis_idx, axis_size in enumerate(spatial_shape)
            ]

            # regroup from "per axis" to "per field"
            all_coords, all_stitch_coords, all_crop_coords, all_crop_size = zip(
                *axis_specs, strict=False
            )

            # number of tiles for this data_idx
            n_tiles = prod(len(dim) for dim in all_coords) * data_shape[0]

            # patches will be the same for each sample in a stack
            for sample_idx in range(data_shape[0]):
                # The four products iterate over the axes in the same order, so
                # zipping them keeps the fields of each tile aligned.
                for coords, stitch_coords, crop_coords, crop_size in zip(
                    itertools.product(*all_coords),
                    itertools.product(*all_stitch_coords),
                    itertools.product(*all_crop_coords),
                    itertools.product(*all_crop_size),
                    strict=False,
                ):
                    tile_specs.append(
                        {
                            # PatchSpecs
                            "data_idx": data_idx,
                            "sample_idx": sample_idx,
                            "coords": coords,
                            "patch_size": self.patch_size,
                            # TileSpecs additional fields
                            "crop_coords": crop_coords,
                            "crop_size": crop_size,
                            "stitch_coords": stitch_coords,
                            "total_tiles": n_tiles,
                        }
                    )

        return tile_specs

    @staticmethod
    def _compute_1d_coords(
        axis_size: int, patch_size: int, overlap: int
    ) -> tuple[list[int], list[int], list[int], list[int]]:
        """
        Computes the TileSpec information for a single axis.

        Parameters
        ----------
        axis_size : int
            The size of the axis.
        patch_size : int
            The tile size.
        overlap : int
            The tile overlap.

        Returns
        -------
        coords: list of int
            The top-left (and first z-slice for 3D data) of a tile, in coords relative
            to the image.
        stitch_coords: list of int
            Where the tile will be stitched back into an image, taking into account
            that the tile will be cropped, in coords relative to the image.
        crop_coords: list of int
            The top-left side of where the tile will be cropped, in coordinates relative
            to the tile.
        crop_size: list of int
            The size of the cropped tile.

        Raises
        ------
        ValueError
            If `overlap` is negative, or greater than or equal to `patch_size`.
        """
        # A zero step makes `range` fail with an unrelated message, a negative step
        # silently yields no tiles, and a negative overlap yields crops that are
        # larger than the tile, so reject all of them explicitly.
        if not 0 <= overlap < patch_size:
            raise ValueError(
                f"Overlap must be non-negative and smaller than the patch size, got "
                f"`overlap={overlap}` and `patch_size={patch_size}`."
            )

        coords: list[int] = []
        stitch_coords: list[int] = []
        crop_coords: list[int] = []
        crop_size: list[int] = []

        step = patch_size - overlap
        half_overlap = overlap // 2
        for i in range(0, max(1, axis_size - overlap), step):
            if i == 0:
                # first tile: nothing is cropped from its leading edge
                coords.append(0)
                crop_coords.append(0)
                stitch_coords.append(0)
                if axis_size <= patch_size:
                    # a single tile covers the whole axis
                    crop_size.append(axis_size)
                else:
                    crop_size.append(patch_size - half_overlap)
            elif i + patch_size < axis_size:
                # interior tile: half of the overlap is cropped on each side
                coords.append(i)
                crop_coords.append(half_overlap)
                stitch_coords.append(i + half_overlap)
                crop_size.append(step)
            else:
                # last tile: aligned with the end of the axis, stitched right after
                # the previous tile (which always exists, since the loop starts at 0)
                previous_tile_end = stitch_coords[-1] + crop_size[-1]
                last_coord = max(0, axis_size - patch_size)

                coords.append(last_coord)
                stitch_coords.append(previous_tile_end)
                crop_coords.append(previous_tile_end - last_coord)
                crop_size.append(axis_size - previous_tile_end)

        return (
            coords,
            stitch_coords,
            crop_coords,
            crop_size,
        )