careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,180 @@
1
+ from functools import partial
2
+ from typing import Any
3
+
4
+ from typing_extensions import ParamSpec
5
+
6
+ from careamics.config.data.ng_data_config import NGDataConfig
7
+ from careamics.config.support import SupportedData
8
+ from careamics.file_io.read import ReadFunc
9
+
10
+ from .dataset import CareamicsDataset
11
+ from .image_stack import (
12
+ GenericImageStack,
13
+ ImageStack,
14
+ )
15
+ from .image_stack_loader import (
16
+ ImageStackLoader,
17
+ load_arrays,
18
+ load_custom_file,
19
+ load_czis,
20
+ load_iter_tiff,
21
+ load_tiffs,
22
+ load_zarrs,
23
+ )
24
+ from .patch_extractor import LimitFilesPatchExtractor, PatchExtractor
25
+
26
+ P = ParamSpec("P")
27
+
28
+
29
+ # Convenience function; in most cases the `create_dataloader` function should be used instead.
30
+ # For lazy loading, a custom batch sampler also needs to be set.
31
def create_dataset(
    config: NGDataConfig,
    inputs: Any,
    targets: Any = None,
    masks: Any = None,
    read_func: ReadFunc | None = None,
    read_kwargs: dict[str, Any] | None = None,
    image_stack_loader: ImageStackLoader | None = None,
    image_stack_loader_kwargs: dict[str, Any] | None = None,
) -> CareamicsDataset[ImageStack]:
    """
    Convenience function to create the CAREamicsDataset.

    Parameters
    ----------
    config : NGDataConfig
        The data configuration.
    inputs : Any
        The input sources to the dataset.
    targets : Any, optional
        The target sources to the dataset. If `None`, no target extractor is
        created.
    masks : Any, optional
        The mask sources used to filter patches. If `None`, no mask extractor is
        created.
    read_func : ReadFunc, optional
        A function that can be used to load custom data. This argument is
        ignored unless the `data_type` in the `config` is "custom".
    read_kwargs : dict of {str, Any}, optional
        Additional key-word arguments to pass to the `read_func`.
    image_stack_loader : ImageStackLoader, optional
        A function for custom image stack loading. This argument is ignored unless the
        `data_type` in the `config` is "custom".
    image_stack_loader_kwargs : dict of {str, Any}, optional
        Additional key-word arguments to pass to the `image_stack_loader`.

    Returns
    -------
    CareamicsDataset[ImageStack]
        The dataset built from the input, target and mask patch extractors.

    Raises
    ------
    ValueError
        If the data type is "custom" and not exactly one of `read_func` or
        `image_stack_loader` is provided.
    NotImplementedError
        If no image stack loader is implemented for the configured data type.
    """
    # Resolve the data type once; it drives both the loader and extractor choice.
    data_type = SupportedData(config.data_type)

    image_stack_loader = select_image_stack_loader(
        data_type=data_type,
        in_memory=config.in_memory,
        read_func=read_func,
        read_kwargs=read_kwargs,
        image_stack_loader=image_stack_loader,
        image_stack_loader_kwargs=image_stack_loader_kwargs,
    )
    patch_extractor_type = select_patch_extractor_type(
        data_type=data_type, in_memory=config.in_memory
    )

    # The same loader and extractor type are used for inputs, targets and masks.
    input_extractor = init_patch_extractor(
        patch_extractor_type, image_stack_loader, inputs, config.axes
    )
    target_extractor = (
        None
        if targets is None
        else init_patch_extractor(
            patch_extractor_type, image_stack_loader, targets, config.axes
        )
    )
    mask_extractor = (
        None
        if masks is None
        else init_patch_extractor(
            patch_extractor_type, image_stack_loader, masks, config.axes
        )
    )

    return CareamicsDataset(
        data_config=config,
        input_extractor=input_extractor,
        target_extractor=target_extractor,
        mask_extractor=mask_extractor,
    )
97
+
98
+
99
def init_patch_extractor(
    patch_extractor: type[PatchExtractor],
    image_stack_loader: ImageStackLoader[..., GenericImageStack],
    source: Any,
    axes: str,
) -> PatchExtractor[GenericImageStack]:
    """
    Build a patch extractor from a data source.

    The `image_stack_loader` is called with `source` and `axes`, and its result
    is passed to the `patch_extractor` constructor.

    Parameters
    ----------
    patch_extractor : type[PatchExtractor]
        The patch extractor class to instantiate.
    image_stack_loader : ImageStackLoader
        Callable that loads image stacks from `source` given `axes`.
    source : Any
        The data source handed to the `image_stack_loader`.
    axes : str
        The axes handed to the `image_stack_loader`.

    Returns
    -------
    PatchExtractor
        An instance of `patch_extractor` wrapping the loaded image stacks.
    """
    return patch_extractor(image_stack_loader(source, axes))
107
+
108
+
109
def select_patch_extractor_type(
    data_type: SupportedData,
    in_memory: bool,
) -> type[PatchExtractor]:
    """Choose the PatchExtractor class for a data type and memory mode.

    The `LimitFilesPatchExtractor` is only used for TIFF or CUSTOM data that is
    not loaded into memory; in every other case the standard `PatchExtractor`
    is returned.

    Parameters
    ----------
    data_type : SupportedData
        The type of data being handled.
    in_memory : bool
        Indicates whether data is to be loaded into memory.

    Returns
    -------
    type[PatchExtractor]
        The selected PatchExtractor type.
    """
    # Guard clause: in-memory data and non-TIFF/CUSTOM data use the standard class.
    if in_memory or data_type not in (SupportedData.TIFF, SupportedData.CUSTOM):
        return PatchExtractor
    return LimitFilesPatchExtractor
134
+
135
+
136
def select_image_stack_loader(
    data_type: SupportedData,
    in_memory: bool,
    read_func: ReadFunc | None = None,
    read_kwargs: dict[str, Any] | None = None,
    image_stack_loader: ImageStackLoader | None = None,
    image_stack_loader_kwargs: dict[str, Any] | None = None,
) -> ImageStackLoader:
    """
    Select the image stack loader matching a data type and memory mode.

    Parameters
    ----------
    data_type : SupportedData
        The type of data being handled.
    in_memory : bool
        Whether data is to be loaded into memory. Only affects the choice for
        TIFF data.
    read_func : ReadFunc, optional
        Read function for custom data; only used if `data_type` is CUSTOM.
    read_kwargs : dict of {str, Any}, optional
        Key-word arguments bound to `read_func`; an empty dict is used if `None`.
    image_stack_loader : ImageStackLoader, optional
        Custom image stack loader; only used if `data_type` is CUSTOM.
    image_stack_loader_kwargs : dict of {str, Any}, optional
        Key-word arguments bound to `image_stack_loader`.

    Returns
    -------
    ImageStackLoader
        The selected loader. For CUSTOM data this is a `functools.partial` with
        the relevant key-word arguments already bound.

    Raises
    ------
    ValueError
        If `data_type` is CUSTOM and not exactly one of `read_func` or
        `image_stack_loader` is provided.
    NotImplementedError
        If `data_type` is not handled by any branch.
    """
    if data_type == SupportedData.ARRAY:
        return load_arrays

    if data_type == SupportedData.TIFF:
        return load_tiffs if in_memory else load_iter_tiff

    if data_type == SupportedData.CUSTOM:
        # Exactly one of `read_func` / `image_stack_loader` must be given.
        if read_func is not None and image_stack_loader is None:
            bound_read_kwargs = {} if read_kwargs is None else read_kwargs
            return partial(
                load_custom_file, read_func=read_func, read_kwargs=bound_read_kwargs
            )
        if read_func is None and image_stack_loader is not None:
            bound_loader_kwargs = (
                {} if image_stack_loader_kwargs is None else image_stack_loader_kwargs
            )
            return partial(image_stack_loader, **bound_loader_kwargs)
        raise ValueError(
            "Found `data_type='custom'` **one** of `read_func` or "
            "`image_stack_loader` must be provided."
        )

    if data_type == SupportedData.ZARR:
        # TODO: in_memory or not
        return load_zarrs

    if data_type == SupportedData.CZI:
        # TODO: in_memory or not
        return load_czis

    raise NotImplementedError(
        f"Selecting an image stack for data type '{data_type}' has not been "
        "implemented yet."
    )
@@ -0,0 +1,73 @@
1
+ """Module for the `GroupedIndexSampler`."""
2
+
3
+ from collections.abc import Iterator, Sequence
4
+ from typing import Self
5
+
6
+ import numpy as np
7
+ from numpy.random import Generator, default_rng
8
+ from torch.utils.data import Sampler
9
+
10
+ from careamics.dataset_ng.dataset import CareamicsDataset
11
+
12
+
13
+ class GroupedIndexSampler(Sampler):
14
+ """
15
+ A PyTorch Sampler iterates through groups of indices.
16
+
17
+ The order of the groups will be shuffled and the order of the indices within the
18
+ groups will be shuffled.
19
+
20
+ This sampler is useful for iterative file loading — one file should be loaded at a
21
+ time so indices belonging to the same file should be grouped, but the order of the
22
+ files and the order of the indices should be shuffled.
23
+ """
24
+
25
+ def __init__(self, grouped_indices: Sequence[Sequence[int]], rng: Generator | None):
26
+ """
27
+ Parameters
28
+ ----------
29
+ grouped_indices : Sequence of (Sequence of int)
30
+ The indices that should be iterated through in groups.
31
+ """
32
+ super().__init__()
33
+ if rng is None:
34
+ self.rng = default_rng()
35
+ else:
36
+ self.rng = rng
37
+ # TODO: validate indices are unique across groups
38
+ self.grouped_indices = grouped_indices
39
+
40
+ @classmethod
41
+ def from_dataset(
42
+ cls, dataset: CareamicsDataset, rng: Generator | None = None
43
+ ) -> Self:
44
+ """
45
+ Create the sampler from a CareamicsDataset.
46
+
47
+ The grouped indices will be retrieved from the dataset's patching strategy.
48
+
49
+ Parameters
50
+ ----------
51
+ dataset: CareamicsDataset
52
+ An instance of the CareamicsDataset to create the sampler for.
53
+ rng: numpy.random.Generator, optional
54
+ Numpy random number generator that can be used to seed the sampler.
55
+ """
56
+ n_data_samples = len(dataset.input_extractor.shapes)
57
+ grouped_indices: list[Sequence[int]] = [
58
+ dataset.patching_strategy.get_patch_indices(i)
59
+ for i in range(n_data_samples)
60
+ ]
61
+ return cls(grouped_indices=grouped_indices, rng=rng)
62
+
63
+ def __iter__(self) -> Iterator[int]:
64
+
65
+ # shuffle the groups and the sub groups but keep indices in a group adjacent
66
+ group_order = np.arange(len(self.grouped_indices))
67
+ self.rng.shuffle(group_order)
68
+ for group_idx in group_order:
69
+ group = self.grouped_indices[group_idx.item()]
70
+ index_order = np.arange(len(group))
71
+ self.rng.shuffle(index_order)
72
+ for idx in index_order:
73
+ yield group[idx.item()]
@@ -0,0 +1,14 @@
1
+ __all__ = [
2
+ "CziImageStack",
3
+ "FileImageStack",
4
+ "GenericImageStack",
5
+ "ImageStack",
6
+ "InMemoryImageStack",
7
+ "ZarrImageStack",
8
+ ]
9
+
10
+ from .czi_image_stack import CziImageStack
11
+ from .file_image_stack import FileImageStack
12
+ from .image_stack_protocol import GenericImageStack, ImageStack
13
+ from .in_memory_image_stack import InMemoryImageStack
14
+ from .zarr_image_stack import ZarrImageStack
@@ -0,0 +1,396 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from collections.abc import Iterable, Sequence
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any, Literal
7
+
8
+ import numpy as np
9
+ from numpy.typing import NDArray
10
+
11
+ try:
12
+ from pylibCZIrw.czi import CziReader, Rectangle, open_czi
13
+
14
+ pyczi_available = True
15
+ except ImportError:
16
+ pyczi_available = False
17
+
18
+ if TYPE_CHECKING:
19
+ try:
20
+ from pylibCZIrw.czi import CziReader, Rectangle, open_czi
21
+ except ImportError:
22
+ CziReader = Rectangle = open_czi = None # type: ignore
23
+
24
+
25
+ class CziImageStack:
26
+ """
27
+ A class for extracting patches from an image stack that is stored as a CZI file.
28
+
29
+ Parameters
30
+ ----------
31
+ data_path : str or Path
32
+ Path to the CZI file.
33
+
34
+ scene : int, optional
35
+ Index of the scene to extract.
36
+
37
+ A single CZI file can contain multiple "scenes", which are stored alongside each
38
+ other at different coordinates in the image plane, often separated by empty
39
+ space. Specifying this argument will read only the single scene with that index
40
+ from the file. Think of it as cropping the CZI file to the region where that
41
+ scene is located.
42
+
43
+ If no scene index is specified, the entire image will be read. In case it
44
+ contains multiple scenes, they will all be present in the resulting image.
45
+ This is usually not desirable due to the empty space between them.
46
+ In general, only omit this argument or set it to `None` if you know that
47
+ your CZI file does not contain any scenes.
48
+
49
+ The static function :py:meth:`get_bounding_rectangles` can be used to find out
50
+ how many scenes a given file contains and what their bounding rectangles are.
51
+
52
+ The scene can also be provided as part of `data_path` by appending an `"@"`
53
+ followed by the scene index to the filename.
54
+
55
+ depth_axis : {"none", "Z", "T"}, default: "none"
56
+ Which axis to use as depth-axis for providing 3-D patches.
57
+
58
+ - `"none"`: Only provide 2-D patches. If a Z or T dimension is present in the
59
+ data, they will be combined into the sample dimension `S`.
60
+ - `"Z"`: Use the Z-axis as depth-axis. If a T axis is present as well, it will
61
+ be merged into the sample dimensions `S`.
62
+ - `"T"`: Use the T-axis as depth-axis. If a Z axis is present as well, it will
63
+ be merged into the sample dimensions `S`.
64
+
65
+ Attributes
66
+ ----------
67
+ source : Path
68
+ Path to the CZI file, including the scene index if specified.
69
+ data_path : Path
70
+ Path to the CZI file without scene index.
71
+ scene : int or None
72
+ Index of the scene to extract, or None if not specified.
73
+ data_shape : Sequence[int]
74
+ The shape of the data in the order `(SC(Z)YX)`.
75
+ axes : str
76
+ The axes in the CZI file corresponding to the dimensions in `data_shape`.
77
+ The following values can occur:
78
+
79
+ - "SCZYX" for 3-D volumes if `depth_axis` is `"Z"`.
80
+ - "SCTYX" for time-series if `depth_axis` is `"T"`.
81
+ - "SCYX" if `depth_axis` is `"none"`.
82
+
83
+ The axis `S` (sample) is the only one not mapping one-to-one to an axis in the
84
+ CZI file but combines all remaining axes present in the file into one.
85
+
86
+ Examples
87
+ --------
88
+ Create an image stack for the first scene in a CZI file:
89
+ >>> stack = CziImageStack("path/to/file.czi", scene=0) # doctest: +SKIP
90
+
91
+ Alternatively, the scene index can also be provided as part of the filename.
92
+ This is mainly intended for re-creating an image stack from the `source` property:
93
+ >>> stack = CziImageStack("path/to/file.czi@0") # doctest: +SKIP
94
+ >>> stack2 = CziImageStack(stack.source) # doctest: +SKIP
95
+
96
+ If the CZI file contains a third dimension (Z or T) and you want to perform 3-D
97
+ denoising, you need to explicitly set `depth_axis` to `"Z"` or `"T"`:
98
+ >>> stack_2d = CziImageStack("path/to/file.czi", scene=0) # doctest: +SKIP
99
+ >>> stack_2d.axes, stack_2d.data_shape # doctest: +SKIP
100
+ ('SCYX', [40, 1, 512, 512])
101
+ >>> stack_3d = CziImageStack( # doctest: +SKIP
102
+ ... "path/to/file.czi", scene=0, depth_axis="Z"
103
+ ... )
104
+ >>> stack_3d.axes, stack_3d.data_shape # doctest: +SKIP
105
+ ('SCZYX', [4, 1, 10, 512, 512])
106
+ """
107
+
108
def __init__(
    self,
    data_path: str | Path,
    scene: int | None = None,
    depth_axis: Literal["none", "Z", "T"] = "none",
) -> None:
    """Open a CZI file and determine the layout of the selected scene.

    Parameters
    ----------
    data_path : str or Path
        Path to the CZI file. The scene index may be encoded in the file name
        as a trailing `@<index>`, which is the form produced by the `source`
        property.
    scene : int or None, default=None
        Index of the scene to read. Must stay `None` when the scene index is
        already encoded in `data_path`.
    depth_axis : {"none", "Z", "T"}, default="none"
        Axis to use as the third patch dimension; `"none"` selects 2-D data.

    Raises
    ------
    ImportError
        If the `pylibCZIrw` package is not available.
    ValueError
        If the scene is given both in the file name and as an argument.
    """
    if not pyczi_available:
        raise ImportError(
            "The CZI image stack requires the `pylibCZIrw` package to be installed."
            " Please install it with `pip install careamics[czi]`."
        )

    _data_path = Path(data_path)

    # The `source` property appends "@<scene>" to the file name so that a stack
    # can be re-created from it. Undo that encoding here; normally the path and
    # the scene are passed as separate arguments.
    encoded_scene = re.match(r"^(.*)@(\d+)$", _data_path.name)
    if encoded_scene is not None:
        if scene is not None:
            raise ValueError(
                f"Scene index is specified in the filename ({_data_path.name}) and "
                f"as an argument ({scene}). Please specify only one."
            )
        plain_name, scene_digits = encoded_scene.groups()
        _data_path = _data_path.parent / plain_name
        scene = int(scene_digits)

    # These attributes must be set before `_get_shape` is called, as it reads
    # them.
    self.data_path, self.scene, self._depth_axis = _data_path, scene, depth_axis

    # Keep the CZI file open for the lifetime of the stack
    self._czi = CziReader(str(self.data_path))

    # Derive axes, shape, bounding rectangle and sample axes from the file
    layout = self._get_shape()
    self.axes, self.data_shape, self._bounding_rectangle, self._sample_axes = layout
    self.data_dtype = np.float32
149
+
150
def __del__(self):
    """Close the CZI reader when the stack is garbage collected."""
    # `_czi` is missing if `__init__` failed before the file was opened.
    if not hasattr(self, "_czi"):
        return
    self._czi.close()
154
+
155
def __getstate__(self) -> dict[str, Any]:
    """Return the instance state without the open CZI reader.

    The reader object is dropped from the (shallow) copy of the state to avoid
    pickling issues; the instance itself keeps its reader.
    """
    picklable = dict(self.__dict__)
    picklable.pop("_czi")
    return picklable
160
+
161
def __setstate__(self, state: dict[str, Any]) -> None:
    """Restore the pickled state and reopen the CZI file.

    Parameters
    ----------
    state : dict[str, Any]
        Attribute dictionary as produced by `__getstate__`, i.e. without the
        reader object.
    """
    vars(self).update(state)
    # The reader is not part of the pickled state, so open the file again.
    reopened = CziReader(str(self.data_path))
    self._czi = reopened
165
+
166
# TODO: we append the scene index to the file name
# - not sure if this is a good approach
@property
def source(self) -> Path:
    """Path of the CZI file, with `@<scene>` appended if a scene is selected.

    The returned path can be passed back as `data_path` to re-create the
    image stack for the same scene.
    """
    directory, filename = self.data_path.parent, self.data_path.name
    if self.scene is None:
        return directory / filename
    filename = f"{filename}@{self.scene}"
    return directory / filename
174
+
175
def extract_patch(
    self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
) -> NDArray:
    """Extract a patch containing all channels.

    Delegates to `extract_channel_patch` with no channel selection.

    Parameters
    ----------
    sample_idx : int
        Index of the sample to read from.
    coords : Sequence[int]
        Start coordinates of the patch.
    patch_size : Sequence[int]
        Size of the patch along each patch dimension.

    Returns
    -------
    NDArray
        The patch returned by `extract_channel_patch`.
    """
    all_channels = None  # `None` selects every channel
    patch = self.extract_channel_patch(sample_idx, all_channels, coords, patch_size)
    return patch
179
+
180
+ def extract_channel_patch(
181
+ self,
182
+ sample_idx: int,
183
+ channels: Sequence[int] | None, # `channels = None` to select all channels
184
+ coords: Sequence[int],
185
+ patch_size: Sequence[int],
186
+ ) -> NDArray:
187
+ # check that channels are within bounds
188
+ if channels is not None:
189
+ max_channel = self.data_shape[1] - 1 # channel is second dimension
190
+ for ch in channels:
191
+ if ch > max_channel:
192
+ raise ValueError(
193
+ f"Channel index {ch} is out of bounds for data with "
194
+ f"{self.data_shape[1]} channels. Check the provided `channels` "
195
+ f"parameter in the configuration for erroneous channel "
196
+ f"indices."
197
+ )
198
+
199
+ # Determine 3rd dimension (T, Z or none)
200
+ if len(coords) == 3:
201
+ if len(self.axes) != 5:
202
+ raise ValueError(
203
+ f"Requested a 3D patch from a 2D image stack with axes {self.axes}."
204
+ )
205
+ third_dim = self.axes[2]
206
+ third_dim_offset, third_dim_size = coords[0], patch_size[0]
207
+ else:
208
+ if len(self.axes) != 4:
209
+ raise ValueError(
210
+ f"Requested a 2D patch from a 3D image stack with axes {self.axes}."
211
+ )
212
+ third_dim = None
213
+ third_dim_offset, third_dim_size = 0, 1
214
+
215
+ # Set up ROI to extract from each plane as (x, y, w, h)
216
+ roi = (
217
+ self._bounding_rectangle.x + coords[-1],
218
+ self._bounding_rectangle.y + coords[-2],
219
+ patch_size[-1],
220
+ patch_size[-2],
221
+ )
222
+
223
+ # Create output array of shape (C, Z, Y, X)
224
+ n_channels = self.data_shape[1] if channels is None else len(channels)
225
+ patch = np.empty(
226
+ (n_channels, third_dim_size, *patch_size[-2:]), dtype=np.float32
227
+ )
228
+
229
+ # Set up plane to index `sample_idx`
230
+ sample_shape = list(self._sample_axes.values())
231
+ sample_indices = np.unravel_index(sample_idx, sample_shape)
232
+ plane = {
233
+ dimension: int(index)
234
+ for dimension, index in zip(
235
+ self._sample_axes.keys(), sample_indices, strict=False
236
+ )
237
+ }
238
+
239
+ # Read XY planes sequentially
240
+ channel_iter: Iterable
241
+ if channels is None:
242
+ channel_iter = range(self.data_shape[1]) # iter over number of requested C
243
+ else:
244
+ channel_iter = list(channels)
245
+
246
+ # for each channel
247
+ for patch_channel, data_channel in enumerate(channel_iter):
248
+ # pull plane with the given channel and 3rd dim index
249
+ for third_dim_index in range(third_dim_size):
250
+ plane["C"] = data_channel
251
+ if third_dim is not None:
252
+ plane[third_dim] = third_dim_offset + third_dim_index
253
+
254
+ # read plane
255
+ extracted_roi = self._czi.read(roi=roi, plane=plane, scene=self.scene)
256
+ if extracted_roi.ndim == 3:
257
+ if extracted_roi.shape[-1] > 1:
258
+ raise ValueError(
259
+ "CZI files with RGB channels are currently not supported."
260
+ )
261
+
262
+ # remove channel dimension
263
+ extracted_roi = extracted_roi.squeeze(-1)
264
+
265
+ # add requested channel into the patch
266
+ patch[patch_channel, third_dim_index] = extracted_roi
267
+
268
+ # Remove dummy 3rd dimension for 2-D data
269
+ if third_dim is None:
270
+ patch = patch.squeeze(1)
271
+
272
+ return patch
273
+
274
def _get_shape(self) -> tuple[str, list[int], Rectangle, dict[str, int]]:
    """Determine axes and shape of the selected scene.

    Returns
    -------
    axes : str
        String specifying the axis order:

        - "SCZYX" if `depth_axis` is `"Z"` (3-D volumes).
        - "SCTYX" if `depth_axis` is `"T"` (time-series).
        - "SCYX" if `depth_axis` is `"none"`.

        `S` is the sample dimension; it combines all axes of the file that
        are not listed in `axes`.

    shape : list[int]
        Size of each axis, in the order given by `axes`.

    bounding_rectangle : Rectangle
        Bounding rectangle of the scene in pixels, given by its top-left
        corner (x, y) and its width and height (w, h). If no scene is
        selected, the total bounding rectangle of the file is used.

    sample_axes : dict[str, int]
        Mapping from the name of each axis folded into the sample dimension
        (e.g., "T", "Z") to its size.

    Raises
    ------
    RuntimeError
        If the requested depth axis is missing from the file or has a size of
        one or less.
    """
    reader = self._czi

    # Get CZI dimensions
    total_bbox = reader.total_bounding_box_no_pyramid
    bounding_rectangle = (
        reader.total_bounding_rectangle_no_pyramid
        if self.scene is None
        else reader.scenes_bounding_rectangle_no_pyramid[self.scene]
    )

    # Size of every axis reported by the file
    extents = {
        dimension: end - start for dimension, (start, end) in total_bbox.items()
    }

    # Determine axis order depending on `depth_axis`.
    # Note: An axis of size 1 is as good as no axis since we cannot use it for 3-D
    # denoising.
    if self._depth_axis == "Z":
        axes = "SCZYX"
        if extents.get("Z", 0) <= 1:
            raise RuntimeError(
                f"The CZI file {self.data_path} does not contain a Z axis to use "
                'for 3-D denoising. Consider setting `axes="YX"` or '
                '`depth_axis="none"` to perform 2-D denoising instead.'
            )
    elif self._depth_axis == "T":
        axes = "SCTYX"
        if extents.get("T", 0) <= 1:
            raise RuntimeError(
                f"The CZI file {self.data_path} does not contain a T axis to use "
                'for 3-D denoising. Consider setting `axes="YX"` or '
                '`depth_axis="none"` to perform 2-D denoising instead.'
            )
    else:
        axes = "SCYX"

    # Calculate the size of the sample dimension S, combining all axes not used
    # elsewhere. This could, for example, be a time axis. For 2-D denoising, a
    # Z axis present in the file is folded into the sample dimension as well, and
    # the same holds for any other axis of the CZI file that is neither a spatial
    # nor a channel axis.
    sample_axes = {
        dimension: size for dimension, size in extents.items() if dimension not in axes
    }
    sample_size = 1
    for size in sample_axes.values():
        sample_size *= size

    # Determine data shape: S, Y and X have dedicated sources, every other axis
    # takes its size from the file and defaults to 1 if the file lacks it.
    dedicated_sizes = {
        "S": sample_size,
        "Y": bounding_rectangle.h,
        "X": bounding_rectangle.w,
    }
    shape = [
        dedicated_sizes[dimension]
        if dimension in dedicated_sizes
        else extents.get(dimension, 1)
        for dimension in axes
    ]

    return axes, shape, bounding_rectangle, sample_axes
366
+
367
@classmethod
def get_bounding_rectangles(
    cls, czi: Path | str | CziReader
) -> dict[int | None, Rectangle]:
    """Get the bounding rectangles of all scenes in a CZI file.

    Parameters
    ----------
    czi : Path or str or pyczi.CziReader
        Path to the CZI file, or an already opened file as CziReader object.
        A file given by path is opened for the duration of the call only.

    Returns
    -------
    dict[int | None, Rectangle]
        Mapping from scene index to bounding rectangle in the format
        `(x, y, w, h)`.
        If the CZI file contains no scenes, the dictionary has a single entry
        with key `None` whose bounding rectangle covers the entire image.
    """
    if isinstance(czi, CziReader):
        per_scene = czi.scenes_bounding_rectangle_no_pyramid
        if not per_scene:
            return {None: czi.total_bounding_rectangle_no_pyramid}
        # Ensure keys are int | None for type compatibility
        return {int(index): rectangle for index, rectangle in per_scene.items()}

    # Got a path: open the file and query the reader instead.
    with open_czi(str(czi)) as czi_reader:
        return cls.get_bounding_rectangles(czi_reader)