careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,140 @@
1
+ from collections.abc import Sequence
2
+ from pathlib import Path
3
+ from typing import Any, Self
4
+
5
+ import numpy as np
6
+ import tifffile
7
+ from numpy.typing import DTypeLike, NDArray
8
+
9
+ from careamics.dataset.dataset_utils import reshape_array
10
+ from careamics.file_io.read import ReadFunc, read_tiff
11
+
12
+ from .image_utils.image_stack_utils import channel_slice, pad_patch, reshape_array_shape
13
+
14
+
15
class FileImageStack:
    """
    An ImageStack implementation for data that is coming from a file.

    The data will not be loaded until the `load` method is called. The `close` method
    can be used to remove the internal reference to the data.

    Parameters
    ----------
    source : Path
        Path to the file containing the image data.
    axes : str
        The original axes of the data stored in the file.
    data_shape : tuple of int
        Shape of the data in SC(Z)YX order; used for channel bounds checks and to
        clip patch coordinates to the image extent.
    data_dtype : DTypeLike
        Data type of the image data.
    read_func : ReadFunc
        Function used to read the file; `load` calls it as
        `read_func(source, **read_kwargs)`.
    read_kwargs : dict of {str: Any} or None, default=None
        Additional keyword arguments forwarded to `read_func` by `load`.
    """

    def __init__(
        self,
        source: Path,
        axes: str,
        data_shape: tuple[int, ...],
        data_dtype: DTypeLike,
        read_func: ReadFunc,
        read_kwargs: dict[str, Any] | None = None,
    ):
        self.source = source
        self.axes = axes
        self.data_shape = data_shape
        self.data_dtype = data_dtype
        self.read_func = read_func
        self.read_kwargs = read_kwargs
        # `None` until `load` is called, and again after `close`.
        self._data: NDArray | None = None

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> NDArray:
        """Extract a patch containing all channels (see `extract_channel_patch`)."""
        return self.extract_channel_patch(sample_idx, None, coords, patch_size)

    def extract_channel_patch(
        self,
        sample_idx: int,
        channels: Sequence[int] | None,  # `channels = None` to select all channels
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray:
        """
        Extract a patch from the selected channels of a sample.

        Parameters
        ----------
        sample_idx : int
            Index into the first (S) dimension of the data.
        channels : Sequence[int] or None
            Channel indices to extract; `None` selects all channels.
        coords : Sequence[int]
            Start coordinates of the patch in each spatial dimension. Regions of the
            patch that fall outside the image are zero-padded.
        patch_size : Sequence[int]
            Size of the patch in each spatial dimension.

        Returns
        -------
        NDArray
            The patch, with dimensions C(Z)YX.

        Raises
        ------
        ValueError
            If the data has not been loaded, if `coords` and `patch_size` have
            different lengths, if a channel index exceeds the number of channels,
            or if `channels` is empty.
        """
        if self._data is None:
            raise ValueError(
                "Cannot extract patch because data has not been loaded from "
                f"'{self.source}', the `load` method must be called first."
            )

        if (coord_dims := len(coords)) != (patch_dims := len(patch_size)):
            raise ValueError(
                "Patch coordinates and patch size must have the same dimensions but "
                f"found {coord_dims} and {patch_dims}."
            )

        # check that channels are within bounds
        if channels is not None:
            max_channel = self.data_shape[1] - 1  # channel is second dimension
            for ch in channels:
                if ch > max_channel:
                    raise ValueError(
                        f"Channel index {ch} is out of bounds for data with "
                        f"{self.data_shape[1]} channels. Check the provided `channels` "
                        f"parameter in the configuration for erroneous channel "
                        f"indices."
                    )

        patch_data = self._data[
            (
                sample_idx,  # type: ignore
                # use channel slice so that channel dimension is kept
                channel_slice(channels),  # type: ignore
                *[
                    # clip to the image extent; the remainder is zero-padded below
                    slice(
                        np.clip(c, 0, self.data_shape[2 + i]),
                        np.clip(c + ps, 0, self.data_shape[2 + i]),
                    )
                    for i, (c, ps) in enumerate(zip(coords, patch_size, strict=False))
                ],  # type: ignore
            )  # type: ignore
        ]
        patch = pad_patch(coords, patch_size, self.data_shape, patch_data)

        return patch

    def load(self) -> None:
        """Load the data stored in a file and reshape it according to `axes`."""
        # Forward the reader options so custom `read_func`s receive them.
        kwargs = self.read_kwargs if self.read_kwargs is not None else {}
        data = self.read_func(self.source, **kwargs)
        self._data = reshape_array(data, self.axes)

    # TODO: maybe this should be called something else
    def close(self) -> None:
        """Remove the internal reference to the data to clear up memory."""
        # will get cleaned up by the garbage collector since there is no longer a ref
        self._data = None

    @property
    def is_loaded(self) -> bool:
        """Whether `load` has been called and `close` has not been called since."""
        return self._data is not None

    @classmethod
    def from_tiff(
        cls,
        path: Path,
        axes: str,
    ) -> Self:
        """
        Construct the `ImageStack` from a TIFF file.

        Only the shape and dtype of the first TIFF series are inspected here; the
        pixel data itself is read by `load`.

        Parameters
        ----------
        path : Path
            Path to the TIFF file.
        axes : str
            The original axes of the data, must be a subset of STCZYX.

        Returns
        -------
        Self
            The `ImageStack` with the underlying data being from a TIFF file.
        """
        # TODO: think this is correct but need more examples to test
        # Use a context manager so the file handle is released after reading metadata.
        with tifffile.TiffFile(path) as tiff:
            series = tiff.series[0]
            data_shape = reshape_array_shape(axes, series.shape)
            dtype = series.dtype
        return cls(
            source=path,
            axes=axes,
            data_shape=data_shape,
            data_dtype=dtype,
            read_func=read_tiff,
        )
@@ -0,0 +1,93 @@
1
+ from collections.abc import Sequence
2
+ from pathlib import Path
3
+ from typing import Literal, Protocol, TypeVar, Union
4
+
5
+ from numpy.typing import DTypeLike, NDArray
6
+
7
+
8
+ class ImageStack(Protocol):
9
+ """
10
+ An interface for extracting patches from an image stack.
11
+
12
+ Attributes
13
+ ----------
14
+ source: Path or "array"
15
+ Origin of the image data.
16
+ data_shape: Sequence[int]
17
+ The shape of the data, it is expected to be in the order (SC(Z)YX).
18
+ data_dtype: DTypeLike
19
+ The data type of the image data.
20
+ """
21
+
22
+ @property
23
+ def source(self) -> Union[str, Path, Literal["array"]]: ...
24
+
25
+ """Source of the image data."""
26
+
27
+ @property
28
+ def data_shape(self) -> Sequence[int]: ...
29
+
30
+ """Shape of the image data."""
31
+
32
+ @property
33
+ def data_dtype(self) -> DTypeLike: ...
34
+
35
+ """Data type of the image data."""
36
+
37
+ def extract_patch(
38
+ self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
39
+ ) -> NDArray:
40
+ """
41
+ Extract a patch for a given sample within the image stack.
42
+
43
+ Parameters
44
+ ----------
45
+ sample_idx: int
46
+ Sample index. The first dimension of the image data will be indexed at this
47
+ value.
48
+ coords: Sequence of int
49
+ The coordinates that define the start of a patch.
50
+ patch_size: Sequence of int
51
+ The size of the patch in each spatial dimension.
52
+
53
+ Returns
54
+ -------
55
+ numpy.ndarray
56
+ A patch of the image data from a particlular sample. It will have the
57
+ dimensions C(Z)YX.
58
+ """
59
+ ...
60
+
61
+ def extract_channel_patch(
62
+ self,
63
+ sample_idx: int,
64
+ channels: Sequence[int] | None,
65
+ coords: Sequence[int],
66
+ patch_size: Sequence[int],
67
+ ) -> NDArray:
68
+ """
69
+ Extract a patch of a single channel for a given sample within the image stack.
70
+
71
+ Parameters
72
+ ----------
73
+ sample_idx: int
74
+ Sample index. The first dimension of the image data will be indexed at this
75
+ value.
76
+ channels: Sequence[int] | None
77
+ Channel indices to extract. If `None` is given all channels will be
78
+ extracted.
79
+ coords: Sequence of int
80
+ The coordinates that define the start of a patch.
81
+ patch_size: Sequence of int
82
+ The size of the patch in each spatial dimension.
83
+
84
+ Returns
85
+ -------
86
+ numpy.ndarray
87
+ A patch of the image data from a particlular sample. It will have the
88
+ dimensions C(Z)YX.
89
+ """
90
+ ...
91
+
92
+
93
+ GenericImageStack = TypeVar("GenericImageStack", bound=ImageStack, covariant=True)
@@ -0,0 +1,6 @@
1
+ """Image stack utility functions."""
2
+
3
+ __all__ = ["channel_slice", "pad_patch", "reshape_array_shape"]
4
+
5
+
6
+ from .image_stack_utils import channel_slice, pad_patch, reshape_array_shape
@@ -0,0 +1,125 @@
1
+ from collections.abc import Sequence
2
+ from types import EllipsisType
3
+ from typing import TypeVar
4
+
5
+ import numpy as np
6
+ from numpy.typing import NDArray
7
+
8
+ T = TypeVar("T", bound=np.generic)
9
+
10
+
11
+ def channel_slice(
12
+ channels: Sequence[int] | None,
13
+ ) -> EllipsisType | Sequence[int]:
14
+ """Create a slice or sequence for indexing channels while preserving dimensions.
15
+
16
+ Parameters
17
+ ----------
18
+ channels : Sequence[int] | None
19
+ The channel indices to select, or None to select all channels.
20
+
21
+ Returns
22
+ -------
23
+ EllipsisType | Sequence[int]
24
+ An indexing object that can be used to index the channel dimension while
25
+ preserving it.
26
+ """
27
+ if channels is None:
28
+ return ...
29
+
30
+ if len(channels) == 0:
31
+ raise ValueError("Channel index sequence cannot be empty.")
32
+
33
+ return channels
34
+
35
+
36
# TODO: add tests
# TODO: move to dataset_utils, better name?
def reshape_array_shape(
    original_axes: str, shape: Sequence[int], add_singleton: bool = True
) -> tuple[int, ...]:
    """Compute the shape an array would have after reshaping to SC(Z)YX.

    When `T` is present, its size is folded into the sample axis `S` by
    multiplication. The exception is when `S` is absent and `add_singleton` is
    `False`: then the size of `T` is placed in front of the output instead.

    Parameters
    ----------
    original_axes : str
        Axes of the original array, e.g. "TCZYX" or "SCYX".
    shape : Sequence[int]
        Shape of the original array, ordered as in `original_axes`.
    add_singleton : bool, default=True
        If `True`, any missing axis other than `Z` is reported with size 1. If
        `False`, only axes present in `original_axes` appear in the output. A
        missing `Z` axis is never added.

    Returns
    -------
    tuple of int
        The resulting shape, ordered as S, C, Z, Y, X with absent axes handled as
        described above.
    """
    result: list[int] = []
    for axis in "SCZYX":
        if axis in original_axes:
            result.append(shape[original_axes.index(axis)])
        elif axis != "Z" and add_singleton:
            result.append(1)

    if "T" not in original_axes:
        return tuple(result)

    time_size = shape[original_axes.index("T")]
    if add_singleton or "S" in original_axes:
        # the first entry is the sample axis: multiplex T into it
        result[0] *= time_size
    else:
        # no sample axis to merge into, T becomes the leading axis
        result.insert(0, time_size)
    return tuple(result)
78
+
79
+
80
def pad_patch(
    coords: Sequence[int],
    patch_size: Sequence[int],
    data_shape: Sequence[int],
    patch_data: NDArray[T],
) -> NDArray[T]:
    """
    Zero-pad extracted patch data up to the full requested patch size.

    Patch extraction clips the requested region to the image bounds, so a patch
    that overhangs the image comes back smaller than `patch_size`. This function
    places that data into a zero-filled array of shape `(C, *patch_size)`:

    - negative `coords` leave zeros at the start of the patch, up to where the
      image begins;
    - if `coords + patch_size` exceeds the image bounds, the end of the patch is
      left as zeros.

    Parameters
    ----------
    coords : Sequence[int]
        Start coordinates of the patch in each spatial dimension of the image.
    patch_size : Sequence[int]
        Size of the patch in each spatial dimension.
    data_shape : Sequence[int]
        Shape of the source image, in SC(Z)YX order. Not used by the current
        implementation.
    patch_data : NDArray[T]
        The (possibly clipped) patch data, with the channel axis first.

    Returns
    -------
    NDArray[T]
        Array of shape `(C, *patch_size)` with the dtype of `patch_data`.
    """
    start = np.array(coords)
    padded = np.zeros((patch_data.shape[0], *patch_size), dtype=patch_data.dtype)
    # Offset of the real data inside the padded patch: zero unless the patch
    # starts before the image (negative coordinates).
    offset = np.maximum(start, 0) - start
    stop = offset + np.array(patch_data.shape[1:])
    region = (slice(None),) + tuple(slice(lo, hi) for lo, hi in zip(offset, stop))
    padded[region] = patch_data
    return padded
@@ -0,0 +1,93 @@
1
+ from collections.abc import Sequence
2
+ from pathlib import Path
3
+ from typing import Any, Literal, Self, Union
4
+
5
+ import numpy as np
6
+ from numpy.typing import DTypeLike, NDArray
7
+
8
+ from careamics.dataset.dataset_utils import reshape_array
9
+ from careamics.file_io.read import ReadFunc, read_tiff
10
+
11
+ from .image_utils.image_stack_utils import channel_slice, pad_patch
12
+
13
+
14
class InMemoryImageStack:
    """
    Patch extraction from an image stack held entirely in memory.

    The wrapped array is expected to already be in SC(Z)YX order. The
    `from_array`, `from_tiff` and `from_custom_file_type` constructors reshape
    their input into that order before wrapping it.

    Parameters
    ----------
    source : Path or "array"
        Origin of the data: the file it was read from, or "array" when the data
        was supplied directly.
    data : NDArray
        The image data, in SC(Z)YX order.
    """

    def __init__(self, source: Union[Path, Literal["array"]], data: NDArray):
        self.source: Union[str, Path, Literal["array"]] = source
        # Direct construction does not reshape; use the `from_*` constructors for that.
        self._data: NDArray = data
        self.data_shape: Sequence[int] = data.shape
        self.data_dtype: DTypeLike = data.dtype

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> NDArray:
        """Extract a patch spanning every channel (see `extract_channel_patch`)."""
        all_channels = None
        return self.extract_channel_patch(sample_idx, all_channels, coords, patch_size)

    def extract_channel_patch(
        self,
        sample_idx: int,
        channels: Sequence[int] | None,  # `None` selects every channel
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray:
        """
        Extract a zero-padded patch from selected channels of one sample.

        Parameters
        ----------
        sample_idx : int
            Index into the first (S) dimension of the data.
        channels : Sequence[int] or None
            Channel indices to keep; `None` keeps all channels.
        coords : Sequence[int]
            Start coordinates of the patch in each spatial dimension.
        patch_size : Sequence[int]
            Size of the patch in each spatial dimension.

        Returns
        -------
        NDArray
            The patch with dimensions C(Z)YX; out-of-image regions are zeros.

        Raises
        ------
        ValueError
            If `coords` and `patch_size` differ in length, if a channel index is
            beyond the last channel, or if `channels` is empty.
        """
        n_coords, n_sizes = len(coords), len(patch_size)
        if n_coords != n_sizes:
            raise ValueError(
                "Patch coordinates and patch size must have the same dimensions but "
                f"found {n_coords} ({coords}) and {n_sizes} ({patch_size})."
            )

        if channels is not None:
            # axis 1 of the SC(Z)YX data holds the channels
            last_channel = self.data_shape[1] - 1
            offending = next((ch for ch in channels if ch > last_channel), None)
            if offending is not None:
                raise ValueError(
                    f"Channel index {offending} is out of bounds for data with "
                    f"{self.data_shape[1]} channels. Check the provided `channels` "
                    f"parameter in the configuration for erroneous channel "
                    f"indices."
                )

        # TODO: test for 2D or 3D?

        # Ellipsis or an index list, both of which keep the channel axis
        channel_index = channel_slice(channels)
        spatial_index = []
        for axis, (start, size) in enumerate(zip(coords, patch_size)):
            extent = self.data_shape[axis + 2]
            # clip to the image; pad_patch restores the full patch size
            spatial_index.append(
                slice(np.clip(start, 0, extent), np.clip(start + size, 0, extent))
            )

        patch_data = self._data[(sample_idx, channel_index, *spatial_index)]  # type: ignore
        return pad_patch(coords, patch_size, self.data_shape, patch_data)

    @classmethod
    def from_array(cls, data: NDArray, axes: str) -> Self:
        """Wrap an in-memory array, reshaping it from `axes` to SC(Z)YX."""
        return cls(source="array", data=reshape_array(data, axes))

    @classmethod
    def from_tiff(cls, path: Path, axes: str) -> Self:
        """Read a TIFF file into memory and reshape it from `axes` to SC(Z)YX."""
        return cls(source=path, data=reshape_array(read_tiff(path), axes))

    @classmethod
    def from_custom_file_type(
        cls, path: Path, axes: str, read_func: ReadFunc, **read_kwargs: Any
    ) -> Self:
        """Read a file via `read_func(path, **read_kwargs)` and reshape to SC(Z)YX."""
        raw = read_func(path, **read_kwargs)
        return cls(source=path, data=reshape_array(raw, axes))
@@ -0,0 +1,170 @@
1
+ from collections.abc import Sequence
2
+
3
+ import zarr
4
+ from numpy.typing import DTypeLike, NDArray
5
+
6
+ from careamics.dataset.dataset_utils import reshape_array
7
+
8
+ from .image_utils.image_stack_utils import channel_slice, pad_patch, reshape_array_shape
9
+
10
+
11
class ZarrImageStack:
    """
    A class for extracting patches from an image stack that is stored as a zarr array.

    Patches are read directly from the zarr array, which is indexed in its
    original axes order. `data_shape`, `chunks` and `shards` are reported in
    SC(Z)YX order. Extracted patches have the layout of `data_shape` without
    the leading sample dimension.
    """

    def __init__(self, group: zarr.Group, data_path: str, axes: str):
        """
        Parameters
        ----------
        group : zarr.Group
            Zarr group that contains the image array.
        data_path : str
            Path of the array inside `group`.
        axes : str
            Axes of the stored array. It is expected, but not yet validated,
            to contain Y and X and to be a subset of STCZYX.

        Raises
        ------
        TypeError
            If `group` is not a `zarr.Group`, or if the entry at `data_path` is
            not a `zarr.Array`.
        ValueError
            If `group` has no entry at `data_path`.
        """
        if not isinstance(group, zarr.Group):
            raise TypeError(f"group must be a zarr.Group instance, got {type(group)}.")

        self._group = group
        self._store = str(group.store_path)
        try:
            self._array = group[data_path]
        except KeyError as e:
            raise ValueError(
                f"Did not find array at '{data_path}' in store '{self._store}'."
            ) from e

        if not isinstance(self._array, zarr.Array):
            raise TypeError(
                f"data at path '{data_path}' must be a zarr.Array instance, "
                f"got {type(self._array)}."
            )

        self._source = self._array.store_path

        # TODO: validate axes
        # - must contain XY
        # - must be subset of STCZYX
        self._original_axes = axes
        self._original_data_shape: tuple[int, ...] = self._array.shape
        # `reshape_array_shape` maps shapes from the original axes order to
        # SC(Z)YX order.
        self.data_shape = reshape_array_shape(axes, self._original_data_shape)
        self._data_dtype = self._array.dtype
        self._chunk_size = reshape_array_shape(
            axes, self._array.chunks, add_singleton=False
        )
        self._shard_size = (
            reshape_array_shape(axes, self._array.shards, add_singleton=False)
            if self._array.shards is not None
            else None
        )

    # Used to identify the source of the data and write to similar path during pred
    @property
    def source(self) -> str:
        """String form of the array's store path."""
        # e.g. file://data/bsd68.zarr/train/
        return str(self._source)

    @property
    def chunks(self) -> Sequence[int]:
        """Chunk size in the order of data_shape (SC(Z)YX)."""
        return self._chunk_size

    @property
    def shards(self) -> Sequence[int] | None:
        """Shard size in the order of data_shape (SC(Z)YX), or None if unsharded."""
        return self._shard_size

    @property
    def data_dtype(self) -> DTypeLike:
        """Dtype of the underlying zarr array."""
        return self._data_dtype

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> NDArray:
        """Extract a patch containing all channels; see `extract_channel_patch`."""
        return self.extract_channel_patch(sample_idx, None, coords, patch_size)

    def extract_channel_patch(
        self,
        sample_idx: int,
        channels: Sequence[int] | None,  # `channels = None` to select all channels,
        coords: Sequence[int],
        patch_size: Sequence[int],
    ) -> NDArray:
        """
        Extract a patch from one sample, optionally restricted to some channels.

        Parameters
        ----------
        sample_idx : int
            Index along the combined sample dimension (S and T flattened
            together). When the data has neither an S nor a T axis, only 0 and
            -1 are accepted.
        channels : Sequence[int] or None
            Channel indices to extract. `None` selects all channels.
        coords : Sequence[int]
            Start coordinates of the patch along the spatial axes, in (Z)YX
            order. Parts outside the data are clipped to the data bounds before
            reading, and the result is then passed to `pad_patch`.
        patch_size : Sequence[int]
            Patch size along the spatial axes, in the same order and with the
            same length as `coords`.

        Returns
        -------
        NDArray
            The patch, laid out like `data_shape` without the sample dimension.

        Raises
        ------
        ValueError
            If `coords` and `patch_size` have different lengths, a channel
            index exceeds the number of channels, or an axis is not in STCZYX.
        IndexError
            If the data has no S or T axis and `sample_idx` is not 0 or -1.
        """
        # Same validation as the in-memory image stack.
        if (coord_dims := len(coords)) != (patch_dims := len(patch_size)):
            raise ValueError(
                "Patch coordinates and patch size must have the same dimensions but "
                f"found {coord_dims} ({coords}) and {patch_dims} ({patch_size})."
            )

        # original axes assumed to be any subset of STCZYX (containing YX), in any order
        # arguments must be transformed to index data in original axes order
        # to do this: loop through original axes and append correct index/slice
        # for each case: STCZYX
        # Note: if any axis is not present in original_axes it is skipped.

        # guard for no S and T in original axes
        if ("S" not in self._original_axes) and ("T" not in self._original_axes):
            if sample_idx not in [0, -1]:
                raise IndexError(
                    f"Sample index {sample_idx} out of bounds for S axes with size "
                    f"{self.data_shape[0]}"
                )

        # check that channels are within bounds
        if channels is not None:
            max_channel = self.data_shape[1] - 1  # channel is second dimension
            for ch in channels:
                if ch > max_channel:
                    raise ValueError(
                        f"Channel index {ch} is out of bounds for data with "
                        f"{self.data_shape[1]} channels. Check the provided `channels` "
                        f"parameter in the configuration for erroneous channel "
                        f"indices."
                    )

        has_z = "Z" in self._original_axes
        patch_slice: list[int | slice] = []
        for d in self._original_axes:
            if d == "S":
                patch_slice.append(self._get_S_index(sample_idx))
            elif d == "T":
                patch_slice.append(self._get_T_index(sample_idx))
            elif d == "C":
                patch_slice.append(channel_slice(channels))  # type: ignore
            elif d == "Z":
                patch_slice.append(self._spatial_slice(0, coords, patch_size))
            elif d == "Y":
                y_idx = 1 if has_z else 0
                patch_slice.append(self._spatial_slice(y_idx, coords, patch_size))
            elif d == "X":
                x_idx = 2 if has_z else 1
                patch_slice.append(self._spatial_slice(x_idx, coords, patch_size))
            else:
                raise ValueError(f"Unrecognised axis '{d}', axes should be in STCZYX.")

        patch_data: NDArray = self._array[tuple(patch_slice)]  # type: ignore
        patch_axes = self._original_axes.replace("S", "").replace("T", "")
        patch_data = reshape_array(patch_data, patch_axes)[0]  # remove first sample dim
        patch = pad_patch(coords, patch_size, self.data_shape, patch_data)

        return patch

    def _spatial_slice(
        self, spatial_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> slice:
        """Slice along spatial axis `spatial_idx` ((Z)YX order), clipped to bounds."""
        # Spatial dimensions follow S and C in `data_shape` (SC(Z)YX).
        dim = self.data_shape[2 + spatial_idx]
        start = coords[spatial_idx]
        stop = start + patch_size[spatial_idx]
        # Clip to [0, dim]. Without this, a negative start would count from
        # the end of the axis (numpy/zarr slice semantics) instead of meaning
        # "before the data".
        return slice(min(max(start, 0), dim), min(max(stop, 0), dim))

    def _get_T_index(self, sample_idx: int) -> int:
        """Get T index given `sample_idx`."""
        if "T" not in self._original_axes:
            raise ValueError("No 'T' axis specified in original data axes.")
        axis_idx = self._original_axes.index("T")
        dim = self._original_data_shape[axis_idx]

        # S and T are combined into one sample axis S' of size S*T. This lookup
        # assumes the flat index is S_idx * T_size + T_idx, i.e.
        #   S_idx = S_idx' // T_size  (floor divide finds the row)
        #   T_idx = S_idx' % T_size   (modulus finds the column)
        return sample_idx % dim

    def _get_S_index(self, sample_idx: int) -> int:
        """Get S index given `sample_idx`."""
        if "S" not in self._original_axes:
            raise ValueError("No 'S' axis specified in original data axes.")
        if "T" in self._original_axes:
            T_axis_idx = self._original_axes.index("T")
            T_dim = self._original_data_shape[T_axis_idx]

            # S and T are combined into one sample axis S' of size S*T. This
            # lookup assumes the flat index is S_idx * T_size + T_idx, i.e.
            #   S_idx = S_idx' // T_size  (floor divide finds the row)
            #   T_idx = S_idx' % T_size   (modulus finds the column)
            return sample_idx // T_dim
        else:
            return sample_idx
@@ -0,0 +1,19 @@
1
# Public names re-exported by this subpackage (imported below).
__all__ = [
    "ImageStackLoader",
    "load_arrays", "load_custom_file", "load_czis",
    "load_iter_tiff", "load_tiffs", "load_zarrs",
]
10
+
11
+ from .image_stack_loader_protocol import ImageStackLoader
12
+ from .image_stack_loaders import (
13
+ load_arrays,
14
+ load_custom_file,
15
+ load_czis,
16
+ load_iter_tiff,
17
+ load_tiffs,
18
+ load_zarrs,
19
+ )