careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,303 @@
1
+ """In-memory dataset module."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ from collections.abc import Callable
7
+ from pathlib import Path
8
+ from typing import Any, Union
9
+
10
+ import numpy as np
11
+ from torch.utils.data import Dataset
12
+
13
+ from careamics.config import DataConfig
14
+ from careamics.config.transformations import NormalizeConfig
15
+ from careamics.dataset.patching.patching import (
16
+ PatchedOutput,
17
+ Stats,
18
+ prepare_patches_supervised,
19
+ prepare_patches_supervised_array,
20
+ prepare_patches_unsupervised,
21
+ prepare_patches_unsupervised_array,
22
+ )
23
+ from careamics.file_io.read import read_tiff
24
+ from careamics.transforms import Compose
25
+ from careamics.utils.logging import get_logger
26
+
27
# Module-level logger (via careamics.utils.logging.get_logger); used by
# InMemoryDataset to report the statistics it computes.
logger = get_logger(__name__)
28
+
29
+
30
class InMemoryDataset(Dataset):
    """Dataset storing data in memory and allowing generating patches from it.

    Patches (and targets, when provided) are extracted once at construction
    time and kept in memory. Normalization and the transforms listed in the
    data configuration are applied when an item is requested.

    Parameters
    ----------
    data_config : CAREamics DataConfig
        (see careamics.config.data_model.DataConfig)
        Data configuration.
    inputs : numpy.ndarray or list[pathlib.Path]
        Input data.
    input_target : numpy.ndarray or list[pathlib.Path], optional
        Target data, by default None.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.
    """

    def __init__(
        self,
        data_config: DataConfig,
        inputs: Union[np.ndarray, list[Path]],
        input_target: Union[np.ndarray, list[Path]] | None = None,
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """
        Constructor.

        Parameters
        ----------
        data_config : DataConfig
            Data configuration. Its means and standard deviations are updated
            in place with the statistics used by this dataset.
        inputs : numpy.ndarray or list[pathlib.Path]
            Input data.
        input_target : numpy.ndarray or list[pathlib.Path], optional
            Target data, by default None.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.

        Raises
        ------
        ValueError
            If `inputs` and `input_target` are not of the same type (both numpy
            arrays or both lists of paths).
        """
        self.data_config = data_config
        self.inputs = inputs
        self.input_targets = input_target
        self.axes = self.data_config.axes
        self.patch_size = self.data_config.patch_size

        # read function
        self.read_source_func = read_source_func

        # generate patches (the dataset is supervised iff a target was given)
        supervised = self.input_targets is not None
        patches_data = self._prepare_patches(supervised)

        # unpack the dataclass
        self.data = patches_data.patches
        self.data_targets = patches_data.targets

        # set image statistics: computed from the data unless given in the config
        if self.data_config.image_means is None:
            self.image_stats = patches_data.image_stats
            logger.info(
                f"Computed dataset mean: {self.image_stats.means}, "
                f"std: {self.image_stats.stds}"
            )
        else:
            self.image_stats = Stats(
                self.data_config.image_means, self.data_config.image_stds
            )

        # set target statistics: computed from the data unless given in the config
        if self.data_config.target_means is None:
            self.target_stats = patches_data.target_stats
        else:
            self.target_stats = Stats(
                self.data_config.target_means, self.data_config.target_stds
            )

        # update mean and std in configuration
        # the object is mutable and should then be recorded in the CAREamist obj
        self.data_config.set_means_and_stds(
            image_means=self.image_stats.means,
            image_stds=self.image_stats.stds,
            target_means=self.target_stats.means,
            target_stds=self.target_stats.stds,
        )

        # get transforms: normalization first, then the configured transforms
        self.patch_transform = Compose(
            transform_list=[
                NormalizeConfig(
                    image_means=self.image_stats.means,
                    image_stds=self.image_stats.stds,
                    target_means=self.target_stats.means,
                    target_stds=self.target_stats.stds,
                )
            ]
            + list(self.data_config.transforms),
        )

    def _prepare_patches(self, supervised: bool) -> PatchedOutput:
        """
        Iterate over data source and create an array of patches.

        Parameters
        ----------
        supervised : bool
            Whether the dataset is supervised or not.

        Returns
        -------
        PatchedOutput
            Patches, targets and the image/target statistics computed while
            patching.

        Raises
        ------
        ValueError
            If the dataset is supervised and inputs and targets are not both
            numpy arrays or both lists of paths.
        """
        if supervised:
            if isinstance(self.inputs, np.ndarray) and isinstance(
                self.input_targets, np.ndarray
            ):
                return prepare_patches_supervised_array(
                    self.inputs,
                    self.axes,
                    self.input_targets,
                    self.patch_size,
                )
            elif isinstance(self.inputs, list) and isinstance(self.input_targets, list):
                return prepare_patches_supervised(
                    self.inputs,
                    self.input_targets,
                    self.axes,
                    self.patch_size,
                    self.read_source_func,
                )
            else:
                raise ValueError(
                    f"Data and target must be of the same type, either both numpy "
                    f"arrays or both lists of paths, got {type(self.inputs)} (data) "
                    f"and {type(self.input_targets)} (target)."
                )
        else:
            if isinstance(self.inputs, np.ndarray):
                return prepare_patches_unsupervised_array(
                    self.inputs,
                    self.axes,
                    self.patch_size,
                )
            else:
                return prepare_patches_unsupervised(
                    self.inputs,
                    self.axes,
                    self.patch_size,
                    self.read_source_func,
                )

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset (number of patches).
        """
        return self.data.shape[0]

    def __getitem__(self, index: int) -> tuple[np.ndarray, ...]:
        """
        Return the patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        tuple of numpy.ndarray
            Transformed patch (and transformed target, if the dataset has
            targets).
        """
        patch = self.data[index]

        # if there is a target
        if self.data_targets is not None:
            # get target
            target = self.data_targets[index]
            return self.patch_transform(patch=patch, target=target)

        return self.patch_transform(patch=patch)

    def get_data_statistics(self) -> tuple[list[float], list[float]]:
        """Return training data statistics.

        This does not return the target data statistics, only those of the input.

        Returns
        -------
        tuple of list of floats
            Means and standard deviations across channels of the training data.
        """
        return self.image_stats.get_statistics()

    def split_dataset(
        self,
        percentage: float = 0.1,
        minimum_patches: int = 1,
    ) -> InMemoryDataset:
        """Split a new dataset away from the current one.

        This method is used to extract random validation patches from the dataset.
        The extracted patches (and targets) are removed from the current dataset.

        Parameters
        ----------
        percentage : float, optional
            Percentage of patches to extract, by default 0.1.
        minimum_patches : int, optional
            Minimum number of patches to extract, by default 1.

        Returns
        -------
        CAREamics InMemoryDataset
            New dataset with the extracted patches.

        Raises
        ------
        ValueError
            If `percentage` is not between 0 and 1.
        ValueError
            If `minimum_patches` is not between 1 and the number of patches.
        """
        if percentage < 0 or percentage > 1:
            raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")

        total_patches = len(self)

        if minimum_patches < 1 or minimum_patches > total_patches:
            raise ValueError(
                f"Minimum number of patches must be between 1 and "
                f"{total_patches} (number of patches), got "
                f"{minimum_patches}. Adjust the patch size or the minimum number of "
                f"patches."
            )

        # number of patches to extract (either percentage rounded or minimum number)
        n_patches = max(round(total_patches * percentage), minimum_patches)
        if n_patches == total_patches:
            logger.warning(
                "All patches are extracted for validation, the remaining dataset "
                "will be empty."
            )

        # get random indices
        indices = np.random.choice(total_patches, n_patches, replace=False)

        # split patches (and targets, if any) into validation and remaining sets
        val_patches = self.data[indices]
        train_patches = np.delete(self.data, indices, axis=0)

        if self.data_targets is not None:
            val_targets = self.data_targets[indices]
            train_targets = np.delete(self.data_targets, indices, axis=0)
        else:
            val_targets = None
            train_targets = None

        # The raw inputs are only read by `_prepare_patches` during construction
        # in this class, so share them with the clone instead of duplicating them.
        memo: dict[int, Any] = {id(self.inputs): self.inputs}
        if self.input_targets is not None:
            memo[id(self.input_targets)] = self.input_targets

        # Swap the validation arrays in before cloning so that deepcopy only
        # duplicates the (small) validation split, not the remaining patches.
        self.data = val_patches
        self.data_targets = val_targets
        try:
            dataset = copy.deepcopy(self, memo)
        finally:
            # the current dataset keeps the remaining patches
            self.data = train_patches
            self.data_targets = train_targets

        return dataset
@@ -0,0 +1,88 @@
1
+ """In-memory prediction dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numpy.typing import NDArray
6
+ from torch.utils.data import Dataset
7
+
8
+ from careamics.transforms import Compose
9
+
10
+ from ..config import InferenceConfig
11
+ from ..config.transformations import NormalizeConfig
12
+ from .dataset_utils import reshape_array
13
+
14
+
15
class InMemoryPredDataset(Dataset):
    """Simple prediction dataset returning images along the sample axis.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : NDArray
        Input data.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: NDArray,
    ) -> None:
        """Constructor.

        The input array is reshaped according to the axes of the prediction
        configuration, and a normalization transform is built from the image
        means and standard deviations stored in that configuration.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration.
        inputs : NDArray
            Input data.
        """
        self.pred_config = prediction_config
        self.input_array = inputs
        self.axes = self.pred_config.axes
        self.image_means = self.pred_config.image_means
        self.image_stds = self.pred_config.image_stds

        # Reshape data; items are then taken along the first axis
        self.data = reshape_array(self.input_array, self.axes)

        # get transforms
        self.patch_transform = Compose(
            transform_list=[
                NormalizeConfig(
                    image_means=self.image_means, image_stds=self.image_stds
                )
            ],
        )

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset (size of the first axis of the reshaped data).
        """
        return len(self.data)

    def __getitem__(self, index: int) -> tuple[NDArray, ...]:
        """
        Return the patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        tuple(numpy.ndarray, ...)
            Transformed patch.
        """
        return self.patch_transform(patch=self.data[index])
@@ -0,0 +1,131 @@
1
+ """In-memory tiled prediction dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from numpy.typing import NDArray
6
+ from torch.utils.data import Dataset
7
+
8
+ from careamics.transforms import Compose
9
+
10
+ from ..config import InferenceConfig
11
+ from ..config.data.tile_information import TileInformation
12
+ from ..config.transformations import NormalizeConfig
13
+ from .dataset_utils import reshape_array
14
+ from .tiling import extract_tiles
15
+
16
+
17
class InMemoryTiledPredDataset(Dataset):
    """Prediction dataset storing data in memory and returning tiles of each image.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : NDArray
        Input data.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: NDArray,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration.
        inputs : NDArray
            Input data.

        Raises
        ------
        ValueError
            If the tile size or the tile overlap is not set in the prediction
            configuration.
        ValueError
            If no tiles could be extracted from the input data.
        """
        if (
            prediction_config.tile_size is None
            or prediction_config.tile_overlap is None
        ):
            raise ValueError(
                "Tile size and overlap must be provided to use the tiled prediction "
                "dataset."
            )

        self.pred_config = prediction_config
        self.input_array = inputs
        self.axes = self.pred_config.axes
        self.tile_size = prediction_config.tile_size
        self.tile_overlap = prediction_config.tile_overlap
        self.image_means = self.pred_config.image_means
        self.image_stds = self.pred_config.image_stds

        # Generate tiles (all kept in memory)
        self.data = self._prepare_tiles()

        # get transforms
        self.patch_transform = Compose(
            transform_list=[
                NormalizeConfig(
                    image_means=self.image_means, image_stds=self.image_stds
                )
            ],
        )

    def _prepare_tiles(self) -> list[tuple[NDArray, TileInformation]]:
        """
        Reshape the input array and extract all of its tiles.

        Returns
        -------
        list of tuples of NDArray and TileInformation
            List of tiles and tile information.

        Raises
        ------
        ValueError
            If no tiles were generated.
        """
        # reshape array
        reshaped_sample = reshape_array(self.input_array, self.axes)

        # generate tiles, which returns a generator
        tile_generator = extract_tiles(
            arr=reshaped_sample,
            tile_size=self.tile_size,
            overlaps=self.tile_overlap,
        )
        tiles_list = list(tile_generator)

        if not tiles_list:
            raise ValueError(
                "No tiles generated, check that the tile size and overlap are "
                "compatible with the input data and axes."
            )

        return tiles_list

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset (number of tiles).
        """
        return len(self.data)

    def __getitem__(self, index: int) -> tuple[tuple[NDArray, ...], TileInformation]:
        """
        Return the tile corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the tile to return.

        Returns
        -------
        tuple of NDArray and TileInformation
            Transformed tile and its tile information.
        """
        tile_array, tile_info = self.data[index]

        # Apply transforms
        transformed_tile = self.patch_transform(patch=tile_array)

        return transformed_tile, tile_info