careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,300 @@
1
+ """Patching functions."""
2
+
3
+ from collections.abc import Callable
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ from numpy.typing import NDArray
10
+
11
+ from ...utils.logging import get_logger
12
+ from ..dataset_utils import reshape_array
13
+ from ..dataset_utils.running_stats import compute_normalization_stats
14
+ from .sequential_patching import extract_patches_sequential
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
@dataclass
class Stats:
    """Per-channel normalization statistics.

    Either field may be ``None`` when no statistics are available.
    """

    means: Union[NDArray, tuple, list, None]
    """Mean of the data across channels."""

    stds: Union[NDArray, tuple, list, None]
    """Standard deviation of the data across channels."""

    def get_statistics(self) -> tuple[list[float], list[float]]:
        """Return the statistics converted to plain lists.

        Returns
        -------
        tuple of two lists of floats
            ``(means, stds)``; both lists are empty when either field is
            ``None``.
        """
        stats_available = self.means is not None and self.stds is not None
        if not stats_available:
            return [], []

        means_as_list = list(self.means)
        stds_as_list = list(self.stds)
        return means_as_list, stds_as_list
41
+
42
+
43
@dataclass
class PatchedOutput:
    """Dataclass to store patches and statistics.

    Returned by the ``prepare_patches_*`` functions of this module.
    """

    patches: NDArray
    """Image patches."""

    targets: Union[NDArray, None]
    """Target patches, ``None`` for unsupervised data."""

    image_stats: Stats
    """Statistics of the image patches."""

    target_stats: Stats
    """Statistics of the target patches (empty for unsupervised data)."""
58
+
59
+
60
# called by in memory dataset
def prepare_patches_supervised(
    train_files: list[Path],
    target_files: list[Path],
    axes: str,
    patch_size: Union[list[int], tuple[int, ...]],
    read_source_func: Callable,
) -> PatchedOutput:
    """
    Iterate over data source and create an array of patches and corresponding targets.

    The lists of Paths should be pre-sorted: training and target files are
    paired by position. Pairs that cannot be read or patched are logged and
    skipped.

    Parameters
    ----------
    train_files : list of pathlib.Path
        List of paths to training data.
    target_files : list of pathlib.Path
        List of paths to target data.
    axes : str
        Axes of the data.
    patch_size : list or tuple of int
        Size of the patches.
    read_source_func : Callable
        Function to read the data.

    Returns
    -------
    PatchedOutput
        Dataclass holding the source and target patches, with their statistics.

    Raises
    ------
    ValueError
        If no valid pair of files could be read and patched.
    """
    if len(train_files) != len(target_files):
        # files are paired with a non-strict zip, so surplus files are ignored
        logger.warning(
            f"Number of training files ({len(train_files)}) and target files "
            f"({len(target_files)}) differ, unpaired files will be ignored."
        )

    all_patches, all_targets = [], []
    for train_filename, target_filename in zip(train_files, target_files, strict=False):
        try:
            sample: np.ndarray = read_source_func(train_filename, axes)
            target: np.ndarray = read_source_func(target_filename, axes)

            # reshape arrays
            sample = reshape_array(sample, axes)
            target = reshape_array(target, axes)

            # extract patches as arrays
            patches, targets = extract_patches_sequential(
                sample, patch_size=patch_size, target=target
            )

            # validate before storing anything, so that patches and targets
            # always stay aligned even when a pair is skipped
            if targets is None:
                raise ValueError(f"No target found for {target_filename}.")
        except Exception as e:
            # best-effort: log the failure and continue with the next pair
            logger.error(f"Failed to read {train_filename} or {target_filename}: {e}")
            continue

        all_patches.append(patches)
        all_targets.append(targets)

    # raise error if no valid samples found
    if not all_patches:
        raise ValueError(
            f"No valid samples found in the input data: {train_files} and "
            f"{target_files}."
        )

    # concatenate once and reuse the arrays for the statistics
    patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
    target_array: np.ndarray = np.concatenate(all_targets, axis=0)

    image_means, image_stds = compute_normalization_stats(patch_array)
    target_means, target_stds = compute_normalization_stats(target_array)

    logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")

    return PatchedOutput(
        patch_array,
        target_array,
        Stats(image_means, image_stds),
        Stats(target_means, target_stds),
    )
143
+
144
+
145
# called by in_memory_dataset
def prepare_patches_unsupervised(
    train_files: list[Path],
    axes: str,
    patch_size: Union[list[int], tuple[int, ...]],
    read_source_func: Callable,
) -> PatchedOutput:
    """Iterate over data source and create an array of patches.

    This method returns the mean and standard deviation of the image. Files
    that cannot be read or patched are logged and skipped.

    Parameters
    ----------
    train_files : list of pathlib.Path
        List of paths to training data.
    axes : str
        Axes of the data.
    patch_size : list or tuple of int
        Size of the patches.
    read_source_func : Callable
        Function to read the data.

    Returns
    -------
    PatchedOutput
        Dataclass holding patches and their statistics.

    Raises
    ------
    ValueError
        If no file could be read and patched.
    """
    all_patches = []
    for filename in train_files:
        try:
            sample: np.ndarray = read_source_func(filename, axes)

            # reshape array
            sample = reshape_array(sample, axes)

            # extract patches as an array
            patches, _ = extract_patches_sequential(sample, patch_size=patch_size)
        except Exception as e:
            # best-effort: log the failure and continue with the next file
            logger.error(f"Failed to read {filename}: {e}")
            continue

        all_patches.append(patches)

    # Check the collected patches rather than the number of files read: this
    # also covers files that were read but whose patching failed, which would
    # otherwise surface as an opaque error from np.concatenate.
    if not all_patches:
        raise ValueError(f"No valid samples found in the input data: {train_files}.")

    # concatenate once and reuse the array for the statistics
    patch_array: np.ndarray = np.concatenate(all_patches)
    image_means, image_stds = compute_normalization_stats(patch_array)

    logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")

    return PatchedOutput(
        patch_array, None, Stats(image_means, image_stds), Stats((), ())
    )
205
+
206
+
207
# called on arrays by in memory dataset
def prepare_patches_supervised_array(
    data: NDArray,
    axes: str,
    data_target: NDArray,
    patch_size: Union[list[int], tuple[int, ...]],
) -> PatchedOutput:
    """Iterate over data source and create an array of patches.

    This method expects an array of shape SC(Z)YX, where S and C can be singleton
    dimensions.

    Patches returned are of shape SC(Z)YX, where S is now the patches dimension.

    Parameters
    ----------
    data : numpy.ndarray
        Input data array.
    axes : str
        Axes of the data.
    data_target : numpy.ndarray
        Target data array.
    patch_size : list or tuple of int
        Size of the patches.

    Returns
    -------
    PatchedOutput
        Dataclass holding the source and target patches, with their statistics.

    Raises
    ------
    ValueError
        If no target patches are extracted.
    """
    # reshape arrays
    reshaped_sample = reshape_array(data, axes)
    reshaped_target = reshape_array(data_target, axes)

    # compute statistics on the full (unpatched) arrays
    image_means, image_stds = compute_normalization_stats(reshaped_sample)
    target_means, target_stds = compute_normalization_stats(reshaped_target)

    # extract patches as arrays, with the patches along the first axis
    patches, patch_targets = extract_patches_sequential(
        reshaped_sample, patch_size=patch_size, target=reshaped_target
    )

    if patch_targets is None:
        raise ValueError("No target extracted.")

    logger.info(f"Extracted {patches.shape[0]} patches from input array.")

    return PatchedOutput(
        patches,
        patch_targets,
        Stats(image_means, image_stds),
        Stats(target_means, target_stds),
    )
261
+
262
+
263
# called by in memory dataset
def prepare_patches_unsupervised_array(
    data: NDArray,
    axes: str,
    patch_size: Union[list[int], tuple[int, ...]],
) -> PatchedOutput:
    """
    Iterate over data source and create an array of patches.

    This method expects an array of shape SC(Z)YX, where S and C can be singleton
    dimensions.

    Patches returned are of shape SC(Z)YX, where S is now the patches dimension.

    Parameters
    ----------
    data : numpy.ndarray
        Input data array.
    axes : str
        Axes of the data.
    patch_size : list or tuple of int
        Size of the patches.

    Returns
    -------
    PatchedOutput
        Dataclass holding the patches and their statistics.
    """
    # reshape array
    reshaped_sample = reshape_array(data, axes)

    # calculate mean and std on the full (unpatched) array
    means, stds = compute_normalization_stats(reshaped_sample)

    # extract patches as an array, with the patches along the first axis
    patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)

    # report the patch count, like the other prepare_patches_* functions
    logger.info(f"Extracted {patches.shape[0]} patches from input array.")

    return PatchedOutput(patches, None, Stats(means, stds), Stats((), ()))
@@ -0,0 +1,110 @@
1
+ """Random patching utilities."""
2
+
3
+ from collections.abc import Generator
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+
8
+ from .validate_patch_dimension import validate_patch_dimensions
9
+
10
+
11
# TODO split in testable functions
def extract_patches_random(
    arr: np.ndarray,
    patch_size: Union[list[int], tuple[int, ...]],
    target: np.ndarray | None = None,
    seed: int | None = None,
) -> Generator[tuple[np.ndarray, np.ndarray | None], None, None]:
    """
    Generate patches from an array in a random manner.

    For each sample, the method calculates how many patches the sample can be
    divided into and then extracts an equal number of random patches.

    It returns a generator that yields the following:

    - patch: np.ndarray, dimension C(Z)YX.
    - target_patch: np.ndarray, dimension C(Z)YX, if the target is present, None
        otherwise.

    Parameters
    ----------
    arr : np.ndarray
        Input image array, indexed as sample (S or T), channel, then the spatial
        dimensions.
    patch_size : tuple of int
        Patch sizes in each spatial dimension.
    target : np.ndarray or None, default=None
        Target array, indexed like `arr`.
    seed : int or None, default=None
        Random seed.

    Yields
    ------
    tuple of (np.ndarray, np.ndarray or None)
        Image patch and corresponding target patch (``None`` without target),
        both cast to float32.
    """
    rng = np.random.default_rng(seed=seed)

    is_3d_patch = len(patch_size) == 3

    # patches sanity check
    validate_patch_dimensions(arr, patch_size, is_3d_patch)

    # Full patch shape, including the S (singleton) and C dimensions
    full_patch_size = [1, arr.shape[1], *patch_size]
    # Patch shape of a single sample, i.e. without the S dimension
    sample_patch_size = full_patch_size[1:]
    n_elements_per_patch = np.prod(full_patch_size)

    # iterate over the number of samples (S or T)
    for sample_idx in range(arr.shape[0]):
        sample: np.ndarray = arr[sample_idx, ...]
        target_sample = target[sample_idx, ...] if target is not None else None

        # calculate the number of patches
        n_patches = np.ceil(np.prod(sample.shape) / n_elements_per_patch).astype(int)

        for _ in range(n_patches):
            # Draw one start coordinate per dimension, always in the same order,
            # so that seeded runs stay reproducible.
            crop_coords = [
                rng.integers(0, sample.shape[i] - size, endpoint=True)
                for i, size in enumerate(sample_patch_size)
            ]
            crop = (
                ...,
                *(
                    slice(c, c + sample_patch_size[i])
                    for i, c in enumerate(crop_coords)
                ),
            )

            # astype always returns a new array, so no explicit copy is needed
            patch = sample[crop].astype(np.float32)

            target_patch = None
            if target_sample is not None:
                target_patch = target_sample[crop].astype(np.float32)

            yield patch, target_patch
@@ -0,0 +1,212 @@
1
+ """Sequential patching functions."""
2
+
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ from skimage.util import view_as_windows
7
+
8
+ from .validate_patch_dimension import validate_patch_dimensions
9
+
10
+
11
+ def _compute_number_of_patches(
12
+ arr_shape: tuple[int, ...], patch_sizes: Union[list[int], tuple[int, ...]]
13
+ ) -> tuple[int, ...]:
14
+ """
15
+ Compute the number of patches that fit in each dimension.
16
+
17
+ Parameters
18
+ ----------
19
+ arr_shape : tuple[int, ...]
20
+ Shape of the input array.
21
+ patch_sizes : Union[list[int], tuple[int, ...]
22
+ Shape of the patches.
23
+
24
+ Returns
25
+ -------
26
+ tuple[int, ...]
27
+ Number of patches in each dimension.
28
+ """
29
+ if len(arr_shape) != len(patch_sizes):
30
+ raise ValueError(
31
+ f"Array shape {arr_shape} and patch size {patch_sizes} should have the "
32
+ f"same dimension, including singleton dimension for S and equal dimension "
33
+ f"for C."
34
+ )
35
+
36
+ try:
37
+ n_patches = [
38
+ np.ceil(arr_shape[i] / patch_sizes[i]).astype(int)
39
+ for i in range(len(patch_sizes))
40
+ ]
41
+ except IndexError as e:
42
+ raise ValueError(
43
+ f"Patch size {patch_sizes} is not compatible with array shape {arr_shape}"
44
+ ) from e
45
+
46
+ return tuple(n_patches)
47
+
48
+
49
def _compute_overlap(
    arr_shape: tuple[int, ...], patch_sizes: Union[list[int], tuple[int, ...]]
) -> tuple[int, ...]:
    """
    Compute the overlap between consecutive patches along each dimension.

    For every dimension, the number of patches ``n`` needed to span the array is
    obtained from `_compute_number_of_patches`. The surplus
    ``max(0, n * patch - size)`` is spread over the ``n - 1`` gaps between
    patches and rounded up. The overlap is therefore 0 whenever the array
    dimension is a multiple of the patch size.

    NOTE(review): because the overlap is rounded up, the resulting step
    ``patch - overlap`` can leave trailing pixels uncovered when the surplus is
    not divisible by ``n - 1``. Example: size 11, patch 4 gives overlap 1 and
    step 3, so windows start at 0, 3 and 6 and stop at index 9. Confirm whether
    full coverage is required.

    Parameters
    ----------
    arr_shape : tuple[int, ...]
        Input array shape.
    patch_sizes : Union[list[int], tuple[int, ...]]
        Size of the patches, one entry per array dimension.

    Returns
    -------
    tuple[int, ...]
        Overlap between patches in each dimension.

    Raises
    ------
    ValueError
        Propagated from `_compute_number_of_patches`, e.g. when `arr_shape` and
        `patch_sizes` have different lengths.
    """
    counts = _compute_number_of_patches(arr_shape, patch_sizes)

    overlaps = []
    for dim_size, size, count in zip(arr_shape, patch_sizes, counts):
        surplus = np.clip(count * size - dim_size, 0, None)
        # A single patch has no gap to share the surplus with.
        gaps = max(1, count - 1)
        overlaps.append(np.ceil(surplus / gaps).astype(int))
    return tuple(overlaps)
80
+
81
+
82
+ def _compute_patch_steps(
83
+ patch_sizes: Union[list[int], tuple[int, ...]], overlaps: tuple[int, ...]
84
+ ) -> tuple[int, ...]:
85
+ """
86
+ Compute steps between patches.
87
+
88
+ Parameters
89
+ ----------
90
+ patch_sizes : tuple[int]
91
+ Size of the patches.
92
+ overlaps : tuple[int]
93
+ Overlap between patches.
94
+
95
+ Returns
96
+ -------
97
+ tuple[int]
98
+ Steps between patches.
99
+ """
100
+ steps = [
101
+ min(patch_sizes[i] - overlaps[i], patch_sizes[i])
102
+ for i in range(len(patch_sizes))
103
+ ]
104
+ return tuple(steps)
105
+
106
+
107
# TODO why stack the target here and not on a different dimension before this function?
def _compute_patch_views(
    arr: np.ndarray,
    window_shape: list[int],
    step: tuple[int, ...],
    output_shape: list[int],
    target: np.ndarray | None = None,
    seed: int | None = None,
) -> np.ndarray:
    """
    Extract patches as strided windows of an array, in random order.

    When `target` is given, it is stacked with `arr` along a new leading axis.
    Input and target patches are then windowed and shuffled together. The
    result has shape ``(n_patches, 2, ...)``: index 0 on the second axis holds
    input patches and index 1 holds target patches.

    Parameters
    ----------
    arr : np.ndarray
        Array from which the views are extracted.
    window_shape : list[int]
        Shape of the views, one entry per dimension of `arr`.
    step : tuple[int, ...]
        Steps between views, one entry per dimension of `arr`.
    output_shape : list[int]
        Shape the windows are reshaped to; the first entry is expected to be -1
        (number of patches).
    target : np.ndarray or None, optional
        Target array with the same shape as `arr`, by default None.
    seed : int or None, optional
        Seed of the random generator used to shuffle the patches, by default
        None (unseeded).

    Returns
    -------
    np.ndarray
        Shuffled patches, first axis indexing patches. The result never shares
        memory with `arr` or `target`.
    """
    rng = np.random.default_rng(seed=seed)

    if target is not None:
        arr = np.stack([arr, target], axis=0)
        window_shape = [arr.shape[0], *window_shape]
        step = (arr.shape[0], *step)
        output_shape = [-1, arr.shape[0], arr.shape[2], *output_shape[2:]]

    patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
        *output_shape
    )
    # `reshape` returns a view of `arr` whenever it can (e.g. only size-1 window
    # axes are dropped). Shuffling such a view in place would permute the
    # caller's array and, with overlapping windows, corrupt the patches, so
    # materialise an independent copy first.
    if np.may_share_memory(patches, arr):
        patches = patches.copy()
    rng.shuffle(patches, axis=0)
    return patches
149
+
150
+
151
def extract_patches_sequential(
    arr: np.ndarray,
    patch_size: Union[list[int], tuple[int, ...]],
    target: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray | None]:
    """
    Cut an array into patches laid out on a regular, possibly overlapping grid.

    The array must be SC(Z)YX, where S and C can be singleton dimensions. Each
    patch comes from a single sample and spans all channels. The spatial steps
    are derived from the overlap computed by `_compute_overlap`. Patches are
    returned in random order.

    NOTE(review): the grid is meant to cover the whole array. When an array
    dimension is not a multiple of the patch size, however, the rounded-up
    overlap can leave trailing border pixels out. Confirm whether that matters.

    Parameters
    ----------
    arr : np.ndarray
        Input image array, SC(Z)YX.
    patch_size : list[int] or tuple[int, ...]
        Patch size along each spatial dimension, (Z)YX.
    target : np.ndarray or None, optional
        Target array with the same shape as `arr`, by default None. When given,
        target patches are taken at the same positions as the input patches.

    Returns
    -------
    tuple[np.ndarray, np.ndarray or None]
        Input patches of shape (n_patches, C, (Z), Y, X) and the matching target
        patches, or None when `target` is None.

    Raises
    ------
    ValueError
        If `patch_size` is incompatible with the array shape.
    """
    # Reject patch sizes that do not fit the spatial dimensions of `arr`.
    validate_patch_dimensions(arr, patch_size, len(patch_size) == 3)

    # Prepend a singleton S entry and the full channel count so that each
    # window covers one sample and all channels at once.
    full_patch = [1, arr.shape[1], *patch_size]

    steps = _compute_patch_steps(
        patch_sizes=full_patch,
        overlaps=_compute_overlap(arr_shape=arr.shape, patch_sizes=full_patch),
    )

    # Windows are flattened to (n_patches, C, (Z), Y, X).
    views = _compute_patch_views(
        arr,
        window_shape=full_patch,
        step=steps,
        output_shape=[-1, *full_patch[1:]],
        target=target,
    )

    if target is None:
        return views, None

    # `_compute_patch_views` stacked input (index 0) and target (index 1) on
    # the second axis of its output.
    return views[:, 0, ...], views[:, 1, ...]
@@ -0,0 +1,64 @@
1
+ """Patch validation functions."""
2
+
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+
7
+
8
def validate_patch_dimensions(
    arr: np.ndarray,
    patch_size: Union[list[int], tuple[int, ...]],
    is_3d_patch: bool,
) -> None:
    """
    Check patch size and array compatibility.

    This function validates the patch sizes with respect to the array dimensions:

    - The patch must have two dimensions fewer than the array (S and C).
    - Patch sizes must not exceed the corresponding array dimensions.

    If one of these conditions is not met, a ValueError is raised.

    This function should be called after inputs have been resized.

    Parameters
    ----------
    arr : np.ndarray
        Input array, with S and C as its first two dimensions.
    patch_size : Union[list[int], tuple[int, ...]]
        Size of the patches along each spatial dimension of the array, i.e. all
        dimensions except the first two (S and C).
    is_3d_patch : bool
        Whether the patch is 3D (ZYX) or not (YX).

    Raises
    ------
    ValueError
        If the patch size is not consistent with the array shape (the array must
        have exactly two more dimensions than the patch).
    ValueError
        If the patch size in Z is larger than the array dimension.
    ValueError
        If either of the patch sizes in X or Y is larger than the corresponding array
        dimension.
    """
    if len(patch_size) != len(arr.shape[2:]):
        raise ValueError(
            f"There must be a patch size for each spatial dimensions "
            f"(got {patch_size} patches for dims {arr.shape}). Check the axes order."
        )

    # Sanity checks on patch sizes versus array dimension
    if is_3d_patch and patch_size[0] > arr.shape[-3]:
        raise ValueError(
            f"Z patch size is inconsistent with image shape "
            # Report the Z dimension that was actually compared (previously the
            # message printed arr.shape[1], i.e. the C dimension).
            f"(got {patch_size[0]} patches for dim {arr.shape[-3]}). Check the axes "
            f"order."
        )

    if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
        raise ValueError(
            f"At least one of YX patch dimensions is larger than the corresponding "
            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
            f"Check the axes order."
        )
@@ -0,0 +1,10 @@
1
+ """Tiling functions."""
2
+
3
# Names exported by this package on `from ... import *`.
__all__ = [
    "collate_tiles",
    "extract_tiles",
    "stitch_prediction",
]
8
+
9
+ from .collate_tiles import collate_tiles
10
+ from .tiled_patching import extract_tiles