careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,188 @@
1
+ """Filter patches based on Shannon entropy threshold."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ import numpy as np
6
+ from skimage.measure import shannon_entropy
7
+ from tqdm import tqdm
8
+
9
+ from careamics.dataset_ng.image_stack_loader import load_arrays
10
+ from careamics.dataset_ng.patch_extractor import PatchExtractor
11
+ from careamics.dataset_ng.patch_filter.patch_filter_protocol import PatchFilterProtocol
12
+ from careamics.dataset_ng.patching_strategies import TilingStrategy
13
+
14
+
15
class ShannonPatchFilter(PatchFilterProtocol):
    """
    Filter patches based on Shannon entropy threshold.

    Attributes
    ----------
    threshold : float
        Threshold for the Shannon entropy of the patch. Patches whose entropy is
        below this value are filtered out (subject to `p`).
    p : float
        Probability of applying the filter to a patch.
    rng : np.random.Generator
        Random number generator for stochastic filtering.
    """

    def __init__(
        self, threshold: float, p: float = 1.0, seed: int | None = None
    ) -> None:
        """
        Create a ShannonPatchFilter.

        This filter removes patches whose Shannon entropy is below a specified
        threshold.

        Parameters
        ----------
        threshold : float
            Threshold for the Shannon entropy of the patch. Must be non-negative.
        p : float, default=1
            Probability of applying the filter to a patch. Must be between 0 and 1.
        seed : int | None, default=None
            Seed for the random number generator for reproducibility.

        Raises
        ------
        ValueError
            If threshold is negative.
        ValueError
            If p is not between 0 and 1.
        """
        if threshold < 0:
            raise ValueError("Threshold must be non-negative.")
        if not (0 <= p <= 1):
            raise ValueError("Probability p must be between 0 and 1.")

        self.threshold = threshold

        self.p = p
        self.rng = np.random.default_rng(seed)

    def filter_out(self, patch: np.ndarray) -> bool:
        """
        Determine whether to filter out a patch based on its Shannon entropy.

        With probability `p` the entropy test is applied; otherwise the patch is
        kept.

        Parameters
        ----------
        patch : numpy.NDArray
            The patch to evaluate.

        Returns
        -------
        bool
            True if the patch should be filtered out, False otherwise.
        """
        # One uniform draw is consumed on every call, whatever the outcome.
        if self.rng.uniform(0, 1) < self.p:
            # Cast so callers get a real `bool`, not a `numpy.bool_`.
            return bool(shannon_entropy(patch) < self.threshold)
        return False

    @staticmethod
    def filter_map(
        image: np.ndarray,
        patch_size: Sequence[int],
    ) -> np.ndarray:
        """
        Compute the Shannon entropy map of an image.

        The image is tiled with zero overlap using `patch_size`, and every region
        covered by a tile is filled with the entropy of that tile. This method can
        be used to assess a useful threshold for the Shannon entropy filter.

        Parameters
        ----------
        image : numpy.NDArray
            The image for which to compute the entropy map, must be 2D or 3D.
        patch_size : Sequence[int]
            The size of the patches to compute the entropy over. Must have the
            same number of dimensions as `image` (two integers for a 2D image,
            three for a 3D image).

        Returns
        -------
        numpy.NDArray
            The Shannon entropy map of the image, with the same shape as `image`.

        Raises
        ------
        ValueError
            If the image is not 2D or 3D.
        ValueError
            If `patch_size` does not have the same number of dimensions as the
            image.

        Examples
        --------
        The `filter_map` method can be used to assess a useful threshold for the
        Shannon entropy filter. Below is an example of how to compute and visualize
        the Shannon entropy map of a random image and visualize thresholded versions
        of the map.
        >>> import numpy as np
        >>> from matplotlib import pyplot as plt
        >>> from careamics.dataset_ng.patch_filter import ShannonPatchFilter
        >>> rng = np.random.default_rng(42)
        >>> image = rng.binomial(20, 0.1, (256, 256)).astype(np.float32)
        >>> image[64:192, 64:192] += rng.normal(50, 5, (128, 128))
        >>> image[96:160, 96:160] = rng.poisson(image[96:160, 96:160])
        >>> patch_size = (16, 16)
        >>> entropy_map = ShannonPatchFilter.filter_map(image, patch_size)
        >>> fig, ax = plt.subplots(1, 5, figsize=(20, 5)) # doctest: +SKIP
        >>> for i, thresh in enumerate([2 + 1.5 * i for i in range(5)]):
        ...     ax[i].imshow(entropy_map >= thresh, cmap="gray") #doctest: +SKIP
        ...     ax[i].set_title(f"Threshold: {thresh}") #doctest: +SKIP
        >>> plt.show() # doctest: +SKIP
        """
        if image.ndim < 2 or image.ndim > 3:
            raise ValueError("Image must be 2D or 3D.")
        # The axes string below is derived from `patch_size`; a mismatch with the
        # image dimensionality would otherwise only fail deep inside the loaders.
        if len(patch_size) != image.ndim:
            raise ValueError(
                f"Patch size {tuple(patch_size)} must have the same number of "
                f"dimensions as the image ({image.ndim}D)."
            )

        axes = "YX" if len(patch_size) == 2 else "ZYX"

        shannon_img = np.zeros_like(image, dtype=float)

        image_stacks = load_arrays(source=[image], axes=axes)
        extractor = PatchExtractor(image_stacks)
        # The image is wrapped as a single sample with a single channel.
        tiling = TilingStrategy(
            data_shapes=[(1, 1, *image.shape)],
            patch_size=patch_size,
            overlaps=(0,) * len(patch_size),  # no overlap
        )

        for idx in tqdm(range(tiling.n_patches), desc="Computing Shannon Entropy map"):
            patch_spec = tiling.get_patch_spec(idx)
            coords = patch_spec["coords"]
            patch = extractor.extract_patch(
                data_idx=0,
                sample_idx=0,
                coords=coords,
                patch_size=patch_size,
            )

            region = tuple(
                slice(start, start + size) for start, size in zip(coords, patch_size)
            )
            shannon_img[region] = shannon_entropy(patch)

        return shannon_img

    @staticmethod
    def apply_filter(
        filter_map: np.ndarray,
        threshold: float,
    ) -> np.ndarray:
        """
        Apply the Shannon entropy filter to a precomputed filter map.

        The filter map is the output of the `filter_map` method.

        Parameters
        ----------
        filter_map : numpy.NDArray
            The precomputed Shannon entropy map of the image.
        threshold : float
            The Shannon entropy threshold for filtering.

        Returns
        -------
        numpy.NDArray
            A boolean array where True indicates that the patch should be kept
            (not filtered out) and False indicates that the patch should be filtered
            out.
        """
        return filter_map >= threshold
@@ -0,0 +1,26 @@
1
"""Patching strategies, their factory, and the patch/tile specification types."""

# Public API of this package: every name below is imported from a submodule.
__all__ = [
    "FixedRandomPatchingStrategy",
    "PatchSpecs",
    "PatchingStrategy",
    "RandomPatchingStrategy",
    "RegionSpecs",
    "SequentialPatchingStrategy",
    "TileSpecs",
    "TilingStrategy",
    "WholeSamplePatchingStrategy",
    "create_patching_strategy",
    "is_tile_specs",
]
14
+
15
+ from .patching_strategy_factory import create_patching_strategy
16
+ from .patching_strategy_protocol import (
17
+ PatchingStrategy,
18
+ PatchSpecs,
19
+ RegionSpecs,
20
+ TileSpecs,
21
+ is_tile_specs,
22
+ )
23
+ from .random_patching import FixedRandomPatchingStrategy, RandomPatchingStrategy
24
+ from .sequential_patching import SequentialPatchingStrategy
25
+ from .tiling_strategy import TilingStrategy
26
+ from .whole_sample import WholeSamplePatchingStrategy
@@ -0,0 +1,50 @@
1
+ """Patching strategy factory."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ from careamics.config.data.ng_data_config import PatchingConfig
6
+ from careamics.config.support.supported_patching_strategies import (
7
+ SupportedPatchingStrategy,
8
+ )
9
+
10
+ from .patching_strategy_protocol import PatchingStrategy
11
+ from .random_patching import FixedRandomPatchingStrategy, RandomPatchingStrategy
12
+ from .tiling_strategy import TilingStrategy
13
+ from .whole_sample import WholeSamplePatchingStrategy
14
+
15
+
16
+ def create_patching_strategy(
17
+ data_shapes: list[Sequence[int]], patching_config: PatchingConfig
18
+ ) -> PatchingStrategy:
19
+ """Factory function to create a patching strategy based on the provided config.
20
+
21
+ Parameters
22
+ ----------
23
+ data_shapes : list of Sequence of int
24
+ The shapes of the data stacks to be patched.
25
+ patching_config: PatchingConfig
26
+ The configuration for the desired patching strategy.
27
+
28
+ Returns
29
+ -------
30
+ PatchingStrategy
31
+ An instance of the specified patching strategy.
32
+ """
33
+ patch_class = None
34
+ match patching_config.name:
35
+ case SupportedPatchingStrategy.RANDOM:
36
+ patch_class = RandomPatchingStrategy
37
+ case SupportedPatchingStrategy.FIXED_RANDOM:
38
+ patch_class = FixedRandomPatchingStrategy
39
+ case SupportedPatchingStrategy.TILED:
40
+ patch_class = TilingStrategy
41
+ case SupportedPatchingStrategy.WHOLE:
42
+ patch_class = WholeSamplePatchingStrategy
43
+ case _:
44
+ raise ValueError(f"Unsupported patching strategy: {patching_config.name}")
45
+
46
+ # remove `name` to match the class signatures
47
+ # tiling requires `tile_size` instead of `patch_size`, hence the aliasing
48
+ return patch_class(
49
+ data_shapes=data_shapes, **patching_config.model_dump(exclude={"name"})
50
+ )
@@ -0,0 +1,161 @@
1
+ """A module to contain type definitions relating to patching strategies."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Protocol, TypedDict, TypeGuard, TypeVar
5
+
6
# Type variable bound to `PatchSpecs`: generic code can accept `PatchSpecs` or any
# subtype of it (such as `TileSpecs`) and keep the specific type. The bound is a
# string forward reference because `PatchSpecs` is defined further down.
RegionSpecs = TypeVar("RegionSpecs", bound="PatchSpecs")
7
+
8
+
9
class PatchSpecs(TypedDict):
    """A dictionary that specifies a single patch in a series of `ImageStacks`.

    Attributes
    ----------
    data_idx : int
        Determines which `ImageStack` a patch belongs to, within a series of
        `ImageStack`s.
    sample_idx : int
        Determines which sample a patch belongs to, within an `ImageStack`.
    coords : sequence of int
        The coordinates of the top-left corner (and first z-slice for 3D data) of
        the patch. The sequence will have length 2 or 3, for 2D and 3D data
        respectively.
    patch_size : sequence of int
        The size of the patch. The sequence will have length 2 or 3, for 2D and 3D
        data respectively.
    """

    data_idx: int
    sample_idx: int
    coords: Sequence[int]
    patch_size: Sequence[int]
31
+
32
+
33
class TileSpecs(PatchSpecs):
    """A dictionary that specifies a single tile in a series of `ImageStacks`.

    Extends `PatchSpecs` with the information needed to crop each tile and to
    stitch it back into the full image.

    Attributes
    ----------
    data_idx : int
        Determines which `ImageStack` a tile belongs to, within a series of
        `ImageStack`s.
    sample_idx : int
        Determines which sample a tile belongs to, within an `ImageStack`.
    coords : sequence of int
        The coordinates of the top-left corner (and first z-slice for 3D data) of
        the tile. The sequence will have length 2 or 3, for 2D and 3D data
        respectively.
    patch_size : sequence of int
        The size of the tile. The sequence will have length 2 or 3, for 2D and 3D
        data respectively.
    crop_coords : sequence of int
        The top-left side of where the tile will be cropped, in coordinates relative
        to the tile.
    crop_size : sequence of int
        The size of the cropped tile.
    stitch_coords : sequence of int
        Where the tile will be stitched back into an image, taking into account
        that the tile will be cropped, in coords relative to the image.
    total_tiles : int
        Number of tiles belonging to the same data.
    """

    crop_coords: Sequence[int]
    crop_size: Sequence[int]
    stitch_coords: Sequence[int]
    total_tiles: int
65
+
66
+
67
def is_tile_specs(specs: PatchSpecs) -> TypeGuard[TileSpecs]:
    """Determine whether a given `PatchSpecs` is a `TileSpecs`.

    Used for type narrowing: when this returns True, static type checkers treat
    `specs` as a `TileSpecs`.

    Parameters
    ----------
    specs : PatchSpecs
        A patch specification.

    Returns
    -------
    bool
        Whether `specs` contains every key that `TileSpecs` adds on top of
        `PatchSpecs` (`crop_coords`, `crop_size`, `stitch_coords` and
        `total_tiles`).
    """
    # `total_tiles` is a required key of `TileSpecs` too. Without checking it, the
    # TypeGuard could narrow a dict that lacks it, and a later
    # `specs["total_tiles"]` lookup would raise a KeyError.
    return all(
        key in specs
        for key in ("crop_coords", "crop_size", "stitch_coords", "total_tiles")
    )
87
+
88
+
89
class PatchingStrategy(Protocol):
    """
    An interface for patching strategies.

    Patching strategies are a component of the `CAREamicsDataset`; they determine
    how patches are extracted from the underlying data.

    Attributes
    ----------
    n_patches : int
        The number of patches that the patching strategy will return.

    Methods
    -------
    get_patch_spec(index: int) -> PatchSpecs
        Get a patch specification for a given patch index.
    get_patch_indices(data_idx: int) -> Sequence[int]
        Get the patch indices whose patches come from a specific `image_stack`.
    """

    @property
    def n_patches(self) -> int:
        """
        The number of patches that the patching strategy will return.

        It also determines the range of indices that can be given to
        `get_patch_spec`, and the length of the `CAREamicsDataset`.

        Returns
        -------
        int
            Number of patches.
        """
        ...

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """
        Get a patch specification for a given patch index.

        This method is intended to be called from within the
        `CAREamicsDataset.__getitem__`. The index will be passed through from this
        method.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        PatchSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.
        """
        ...

    # Note: this is used by the FileIterSampler
    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
        """
        Get the patch indices whose patches come from a specific `image_stack`.

        The `image_stack` corresponds to the given `data_idx`.

        Parameters
        ----------
        data_idx : int
            An index that corresponds to a given `image_stack`.

        Returns
        -------
        sequence of int
            A sequence of patch indices that, when used to index the
            `CAREamicsDataset`, will return patches that come from the
            `image_stack` corresponding to the given `data_idx`.
        """
        ...