careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,294 @@
1
+ """Iterable dataset used to load data file by file."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ from collections.abc import Callable, Generator
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ from torch.utils.data import IterableDataset
11
+
12
+ from careamics.config import DataConfig
13
+ from careamics.config.transformations import NormalizeConfig
14
+ from careamics.file_io.read import read_tiff
15
+ from careamics.transforms import Compose
16
+
17
+ from ..utils.logging import get_logger
18
+ from .dataset_utils import iterate_over_files
19
+ from .dataset_utils.running_stats import WelfordStatistics
20
+ from .patching.patching import Stats
21
+ from .patching.random_patching import extract_patches_random
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
class PathIterableDataset(IterableDataset):
    """
    Dataset allowing extracting patches w/o loading whole data into memory.

    Files are read one at a time through `iterate_over_files`, random patches
    are extracted from each sample and passed through a normalization transform
    followed by the transforms listed in the data configuration.

    Parameters
    ----------
    data_config : DataConfig
        Data configuration.
    src_files : list of pathlib.Path
        List of data files.
    target_files : list of pathlib.Path, optional
        Optional list of target files, by default None.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.

    Attributes
    ----------
    data_config : DataConfig
        Data configuration, updated in place with the computed statistics when
        none were provided.
    data_files : list of pathlib.Path
        List of data files.
    target_files : list of pathlib.Path or None
        List of target files, if any.
    read_source_func : Callable
        Function used to read the source files.
    image_stats : Stats
        Means and standard deviations of the images.
    target_stats : Stats
        Means and standard deviations of the targets (fields are None when no
        target is available).
    patch_transform : Compose
        Normalization followed by the transforms of the data configuration.
    """

    def __init__(
        self,
        data_config: DataConfig,
        src_files: list[Path],
        target_files: list[Path] | None = None,
        read_source_func: Callable = read_tiff,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        data_config : DataConfig
            Data configuration.
        src_files : list[Path]
            List of data files.
        target_files : list[Path] or None, optional
            Optional list of target files, by default None.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        """
        self.data_config = data_config
        self.data_files = src_files
        self.target_files = target_files
        self.read_source_func = read_source_func

        # compute mean and std over the dataset
        # only checking the image_mean because the DataConfig class ensures that
        # if image_mean is provided, image_std is also provided
        if not self.data_config.image_means:
            self.image_stats, self.target_stats = self._calculate_mean_and_std()
            logger.info(
                f"Computed dataset mean: {self.image_stats.means}, "
                f"std: {self.image_stats.stds}"
            )

            # write the computed statistics back into the configuration
            self.data_config.set_means_and_stds(
                image_means=self.image_stats.means,
                image_stds=self.image_stats.stds,
                target_means=(
                    list(self.target_stats.means)
                    if self.target_stats.means is not None
                    else None
                ),
                target_stds=(
                    list(self.target_stats.stds)
                    if self.target_stats.stds is not None
                    else None
                ),
            )
        else:
            # if mean and std are provided in the config, use them
            self.image_stats = Stats(
                self.data_config.image_means, self.data_config.image_stds
            )
            self.target_stats = Stats(
                self.data_config.target_means, self.data_config.target_stds
            )

        # create transform composed of normalization and other transforms
        self.patch_transform = Compose(
            transform_list=[
                NormalizeConfig(
                    image_means=self.image_stats.means,
                    image_stds=self.image_stats.stds,
                    target_means=self.target_stats.means,
                    target_stds=self.target_stats.stds,
                )
            ]
            + list(data_config.transforms)
        )

    def _calculate_mean_and_std(self) -> tuple[Stats, Stats]:
        """
        Calculate mean and std of the dataset.

        Returns
        -------
        tuple of Stats and Stats
            Data classes containing the image and target statistics. The target
            statistics hold None values when no target is available.

        Raises
        ------
        ValueError
            If no sample is found in the dataset.
        """
        num_samples = 0
        image_stats = WelfordStatistics()

        # The accumulator only exists when target files were given; every later
        # decision is based on it rather than on the leaked loop variable.
        target_stats = (
            WelfordStatistics() if self.target_files is not None else None
        )
        has_targets = False

        for sample, target in iterate_over_files(
            self.data_config, self.data_files, self.target_files, self.read_source_func
        ):
            # update the image statistics
            image_stats.update(sample, num_samples)

            # update the target statistics if target is available
            if target_stats is not None and target is not None:
                target_stats.update(target, num_samples)
                has_targets = True

            num_samples += 1

        if num_samples == 0:
            raise ValueError("No samples found in the dataset.")

        image_means, image_stds = image_stats.finalize()

        if target_stats is None or not has_targets:
            return Stats(image_means, image_stds), Stats(None, None)

        target_means, target_stds = target_stats.finalize()
        return (
            Stats(image_means, image_stds),
            Stats(np.array(target_means), np.array(target_stds)),
        )

    def __iter__(
        self,
    ) -> Generator[tuple[np.ndarray, ...], None, None]:
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        tuple of np.ndarray
            Output of the patch transform applied to a single patch (and its
            target, if any).
        """
        assert (
            self.image_stats.means is not None and self.image_stats.stds is not None
        ), "Mean and std must be provided"

        # iterate over files
        for sample_input, sample_target in iterate_over_files(
            self.data_config, self.data_files, self.target_files, self.read_source_func
        ):
            patches = extract_patches_random(
                arr=sample_input,
                patch_size=self.data_config.patch_size,
                target=sample_target,
            )

            # iterate over patches
            # patches are tuples of (patch, target) if target is available
            # or (patch, None) only if no target is available
            # patch is of dimensions (C)ZYX
            for patch_data in patches:
                yield self.patch_transform(
                    patch=patch_data[0],
                    target=patch_data[1],
                )

    def get_data_statistics(self) -> tuple[list[float], list[float]]:
        """Return training data statistics.

        Returns
        -------
        tuple of list of floats
            Means and standard deviations across channels of the training data.
        """
        return self.image_stats.get_statistics()

    def get_number_of_files(self) -> int:
        """
        Return the number of files in the dataset.

        Returns
        -------
        int
            Number of files in the dataset.
        """
        return len(self.data_files)

    def split_dataset(
        self,
        percentage: float = 0.1,
        minimum_number: int = 5,
    ) -> PathIterableDataset:
        """Split up dataset in two.

        Splits the dataset using a percentage of the data (files) to extract, or
        the minimum number if the percentage amounts to fewer files than the
        minimum number. The extracted files are removed from this dataset and
        returned in a new dataset.

        Parameters
        ----------
        percentage : float, optional
            Percentage of files to split up, by default 0.1.
        minimum_number : int, optional
            Minimum number of files to split up, by default 5.

        Returns
        -------
        IterableDataset
            Dataset containing the split data.

        Raises
        ------
        ValueError
            If the percentage is smaller than 0 or larger than 1.
        ValueError
            If the minimum number is smaller than 1 or larger than the number of files.
        """
        if percentage < 0 or percentage > 1:
            raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")

        total_files = self.get_number_of_files()
        if minimum_number < 1 or minimum_number > total_files:
            raise ValueError(
                f"Minimum number of files must be between 1 and "
                f"{total_files} (number of files), got "
                f"{minimum_number}."
            )

        # compute number of files
        n_files = max(round(percentage * total_files), minimum_number)

        # get random indices, as plain ints and as a set for O(1) membership
        indices = np.random.choice(total_files, n_files, replace=False).tolist()
        held_out = set(indices)

        # extract the files of the new dataset, then drop them from this one
        val_files = [self.data_files[i] for i in indices]
        self.data_files = [
            file for i, file in enumerate(self.data_files) if i not in held_out
        ]

        # same for targets
        val_target_files = None
        if self.target_files is not None:
            val_target_files = [self.target_files[i] for i in indices]
            self.target_files = [
                file for i, file in enumerate(self.target_files) if i not in held_out
            ]

        # clone the dataset (configuration, statistics and transforms included)
        dataset = copy.deepcopy(self)

        # reassign files and targets
        dataset.data_files = val_files
        if val_target_files is not None:
            dataset.target_files = val_target_files

        return dataset
@@ -0,0 +1,121 @@
1
+ """Iterable prediction dataset used to load data file by file."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable, Generator
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from numpy.typing import NDArray
10
+ from torch.utils.data import IterableDataset
11
+
12
+ from careamics.file_io.read import read_tiff
13
+ from careamics.transforms import Compose
14
+
15
+ from ..config import InferenceConfig
16
+ from ..config.transformations import NormalizeConfig
17
+ from .dataset_utils import iterate_over_files
18
+
19
+
20
class IterablePredDataset(IterableDataset):
    """Simple iterable prediction dataset.

    Files are read one by one and each entry along the first axis of a sample is
    normalized and yielded.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Inference configuration.
    src_files : list of pathlib.Path
        List of data files.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.

    Attributes
    ----------
    prediction_config : InferenceConfig
        Inference configuration.
    data_files : list of pathlib.Path
        List of data files.
    axes : str
        Axes of the data, taken from the inference configuration.
    read_source_func : Callable
        Function used to read the source files.
    image_means : list of float
        Means used for normalization, taken from the inference configuration.
    image_stds : list of float
        Standard deviations used for normalization, taken from the inference
        configuration.
    patch_transform : Compose
        Transform applying the normalization.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        src_files: list[Path],
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Inference configuration.
        src_files : list of pathlib.Path
            List of data files.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.

        Raises
        ------
        ValueError
            If mean and std are not provided in the inference configuration.
        """
        self.prediction_config = prediction_config
        self.data_files = src_files
        self.axes = prediction_config.axes
        self.read_source_func = read_source_func

        # normalization statistics are mandatory for prediction
        means = self.prediction_config.image_means
        stds = self.prediction_config.image_stds
        if means is None or stds is None:
            raise ValueError("Mean and std must be provided for prediction.")

        self.image_means = means
        self.image_stds = stds

        # the only transform applied at prediction time is the normalization
        normalization = NormalizeConfig(
            image_means=self.image_means,
            image_stds=self.image_stds,
        )
        self.patch_transform = Compose(transform_list=[normalization])

    def __iter__(
        self,
    ) -> Generator[tuple[NDArray, ...], None, None]:
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        tuple of numpy.ndarray
            Output of the patch transform for a single sample.
        """
        assert not (
            self.image_means is None or self.image_stds is None
        ), "Mean and std must be provided"

        file_iterator = iterate_over_files(
            self.prediction_config,
            self.data_files,
            read_source_func=self.read_source_func,
        )
        for sample, _ in file_iterator:
            # sample has S dimension
            n_entries = sample.shape[0]
            for index in range(n_entries):
                yield self.patch_transform(patch=sample[index])
@@ -0,0 +1,141 @@
1
+ """Iterable tiled prediction dataset used to load data file by file."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable, Generator
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from numpy.typing import NDArray
10
+ from torch.utils.data import IterableDataset
11
+
12
+ from careamics.file_io.read import read_tiff
13
+ from careamics.transforms import Compose
14
+
15
+ from ..config import InferenceConfig
16
+ from ..config.data.tile_information import TileInformation
17
+ from ..config.transformations import NormalizeConfig
18
+ from .dataset_utils import iterate_over_files
19
+ from .tiling import extract_tiles
20
+
21
+
22
class IterableTiledPredDataset(IterableDataset):
    """Tiled prediction dataset.

    Files are read one by one, each sample is split into tiles and every tile
    is normalized before being yielded together with its tile information.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Inference configuration.
    src_files : list of pathlib.Path
        List of data files.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.

    Attributes
    ----------
    prediction_config : InferenceConfig
        Inference configuration.
    data_files : list of pathlib.Path
        List of data files.
    axes : str
        Axes of the data, taken from the inference configuration.
    tile_size : list of int
        Tile size, taken from the inference configuration.
    tile_overlap : list of int
        Tile overlap, taken from the inference configuration.
    read_source_func : Callable
        Function used to read the source files.
    image_means : list of float
        Means used for normalization, taken from the inference configuration.
    image_stds : list of float
        Standard deviations used for normalization, taken from the inference
        configuration.
    patch_transform : Compose
        Transform applying the normalization.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        src_files: list[Path],
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Inference configuration.
        src_files : list of pathlib.Path
            List of data files.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.

        Raises
        ------
        ValueError
            If tile size or tile overlap are not provided in the inference
            configuration.
        ValueError
            If mean and std are not provided in the inference configuration.
        """
        # tiling parameters are checked before anything is stored
        tiling_missing = (
            prediction_config.tile_size is None
            or prediction_config.tile_overlap is None
        )
        if tiling_missing:
            raise ValueError(
                "Tile size and overlap must be provided for tiled prediction."
            )

        self.prediction_config = prediction_config
        self.data_files = src_files
        self.axes = prediction_config.axes
        self.tile_size = prediction_config.tile_size
        self.tile_overlap = prediction_config.tile_overlap
        self.read_source_func = read_source_func

        # normalization statistics are mandatory for prediction
        means = self.prediction_config.image_means
        stds = self.prediction_config.image_stds
        if means is None or stds is None:
            raise ValueError("Mean and std must be provided for prediction.")

        self.image_means = means
        self.image_stds = stds

        # the only transform applied at prediction time is the normalization
        normalization = NormalizeConfig(
            image_means=self.image_means,
            image_stds=self.image_stds,
        )
        self.patch_transform = Compose(transform_list=[normalization])

    def __iter__(
        self,
    ) -> Generator[tuple[tuple[NDArray, ...], TileInformation], None, None]:
        """
        Iterate over data source and yield single tiles.

        Yields
        ------
        tuple of (tuple of np.ndarray) and TileInformation
            Output of the patch transform for a single tile, and the
            corresponding tile information.
        """
        assert not (
            self.image_means is None or self.image_stds is None
        ), "Mean and std must be provided"

        file_iterator = iterate_over_files(
            self.prediction_config,
            self.data_files,
            read_source_func=self.read_source_func,
        )
        for sample, _ in file_iterator:
            # lazily split the sample into tiles
            tiles = extract_tiles(
                arr=sample,
                tile_size=self.tile_size,
                overlaps=self.tile_overlap,
            )

            # normalize each tile and pass its tile information along
            for tile_array, tile_info in tiles:
                yield self.patch_transform(patch=tile_array), tile_info
@@ -0,0 +1 @@
1
+ """Patching and tiling functions."""