careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,118 @@
1
+ """Dataset utilities."""
2
+
3
+ import numpy as np
4
+
5
+ from careamics.utils.logging import get_logger
6
+
7
+ logger = get_logger(__name__)
8
+
9
+
10
def get_axes_order(axes_in, ref_axes="STCZYX"):
    """
    Return the positions of the input axes, sorted by the reference order.

    Each axis of `ref_axes` is looked up in `axes_in`, in reference order. Axes
    that do not appear in `axes_in` are skipped.

    Parameters
    ----------
    axes_in : str
        Input axes string.
    ref_axes : str
        Reference axes string.

    Returns
    -------
    list[int]
        Indices into `axes_in`, listed in the order given by `ref_axes`.
    """
    positions = []
    for ref_axis in ref_axes:
        position = axes_in.find(ref_axis)

        # `find` returns -1 when the reference axis is absent from the input
        if position != -1:
            positions.append(position)

    return positions
30
+
31
+
32
def _get_shape_order(
    shape_in: tuple[int, ...], axes_in: str, ref_axes: str = "STCZYX"
) -> tuple[tuple[int, ...], str, list[int]]:
    """
    Reorder a shape and its axes string so that they follow the reference axes.

    Parameters
    ----------
    shape_in : tuple[int, ...]
        Input shape.
    axes_in : str
        Input axes.
    ref_axes : str
        Reference axes.

    Returns
    -------
    tuple[tuple[int, ...], str, list[int]]
        Reordered shape, reordered axes string, and the indices of the input axes
        listed in the reference order.
    """
    order = get_axes_order(axes_in, ref_axes)

    # pick axes and dimensions in the reference order
    reordered_axes = "".join([axes_in[position] for position in order])
    reordered_shape = tuple(shape_in[position] for position in order)

    return reordered_shape, reordered_axes, order
59
+
60
+
61
def reshape_array(x: np.ndarray, axes: str) -> np.ndarray:
    """Bring the data to (S, C, (Z), Y, X) layout by moving axes.

    A singleton S axis is prepended if the data has none. If a T axis is present,
    it is merged into S. A singleton C axis is inserted after S if the data has no
    C axis.

    Parameters
    ----------
    x : np.ndarray
        Input array.
    axes : str
        Description of current axes in format `STCZYX`.

    Returns
    -------
    np.ndarray
        Reshaped array with shape (S, C, (Z), Y, X).

    Raises
    ------
    ValueError
        If the number of axes does not match the number of array dimensions.
    """
    # sanity check: one axis letter per array dimension
    if len(axes) != len(x.shape):
        raise ValueError(
            f"Incompatible data shape ({x.shape}) and axes ({axes}). Are the axes "
            f"correct?"
        )

    target_shape, ordered_axes, source = _get_shape_order(x.shape, axes)
    array = x

    # prepend a singleton S axis when the data has none
    if "S" not in ordered_axes:
        ordered_axes = "S" + ordered_axes
        array = array[np.newaxis, ...]
        target_shape = (1,) + target_shape

        # the new leading axis shifts every existing axis by one
        source = [0] + [1 + index for index in source]

    # move the axes into the reference order
    array = np.moveaxis(array, source, list(range(len(source))))

    # merge T into S: both leading dimensions collapse into a single one
    if "T" in ordered_axes:
        target_shape = (-1,) + target_shape[2:]
        ordered_axes = ordered_axes.replace("T", "")
        array = array.reshape(target_shape)

    # insert a singleton channel axis right after S when there is no C
    if "C" not in ordered_axes:
        array = np.expand_dims(array, ordered_axes.index("S") + 1)

    return array
@@ -0,0 +1,141 @@
1
+ """File utilities."""
2
+
3
+ from fnmatch import fnmatch
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import numpy as np
8
+
9
+ from careamics.config.support import SupportedData
10
+ from careamics.utils.logging import get_logger
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
def get_files_size(files: list[Path]) -> float:
    """Return the cumulated size of the files, in MB.

    Parameters
    ----------
    files : list of pathlib.Path
        List of files.

    Returns
    -------
    float
        Total size of the files in MB (bytes divided by 1024**2).
    """
    bytes_per_mb = 1024**2
    sizes_in_mb = [path.stat().st_size / bytes_per_mb for path in files]
    return np.sum(sizes_in_mb)
29
+
30
+
31
def list_files(
    data_path: Union[str, Path],
    data_type: Union[str, SupportedData],
    extension_filter: str = "",
) -> list[Path]:
    """Recursively list the files in `data_path` and return them sorted.

    If `data_path` is not a plain directory (i.e. it is a file, or its suffix is
    `.zarr`), its absolute path is matched against the extension pattern of
    `data_type` using `fnmatch`, and the returned list contains `data_path` only.

    When `data_type` is `custom` and `extension_filter` is not empty,
    `extension_filter` replaces the extension pattern obtained from `data_type`.

    `extension_filter` must be compatible with `fnmatch` and `Path.rglob`, e.g.
    "*.npy" or "*.czi".

    Parameters
    ----------
    data_path : Union[str, Path]
        Path to the folder containing the data, or path to a single file.
    data_type : Union[str, SupportedData]
        One of the supported data type (e.g. tif, custom).
    extension_filter : str, optional
        Extension filter, by default "".

    Returns
    -------
    list[Path]
        List of pathlib.Path objects.

    Raises
    ------
    FileNotFoundError
        If the data path does not exist.
    ValueError
        If the data path is empty or no files with the extension were found.
    ValueError
        If the file does not match the requested extension.
    """
    root = Path(data_path)

    if not root.exists():
        raise FileNotFoundError(f"Data path {root} does not exist.")

    # pattern usable by both fnmatch and rglob
    pattern = SupportedData.get_extension_pattern(data_type)
    if data_type == SupportedData.CUSTOM and extension_filter != "":
        pattern = extension_filter

    # directories ending with ".zarr" are handled like single files
    is_plain_directory = root.is_dir() and root.suffix != ".zarr"

    if is_plain_directory:
        # search the directory tree for matching files
        found = sorted(root.rglob(pattern))
    elif fnmatch(str(root.absolute()), pattern):
        found = [root]
    else:
        raise ValueError(
            f"File {root} does not match the requested extension "
            f'"{pattern}".'
        )

    if not found:
        raise ValueError(
            f'Data path {root} is empty or files with extension "{pattern}" '
            f"were not found."
        )

    return found
106
+
107
+
108
def validate_source_target_files(src_files: list[Path], tar_files: list[Path]) -> None:
    """
    Check that the source and target path lists are compatible.

    Both lists must have the same length, and the set of file names (without
    parent folders) must be the same in both lists.

    Parameters
    ----------
    src_files : list of pathlib.Path
        List of source files.
    tar_files : list of pathlib.Path
        List of target files.

    Raises
    ------
    ValueError
        If the number of source and target files is not the same.
    ValueError
        If some file names are present in only one of the two lists.
    """
    n_sources, n_targets = len(src_files), len(tar_files)
    if n_sources != n_targets:
        raise ValueError(
            f"The number of source files ({n_sources}) is not equal to the number "
            f"of target files ({n_targets})."
        )

    # names present in only one of the two lists
    source_names = {path.name for path in src_files}
    target_names = {path.name for path in tar_files}
    unmatched = source_names.symmetric_difference(target_names)

    if unmatched:
        raise ValueError(f"Source and target files have different names: {unmatched}.")
@@ -0,0 +1,84 @@
1
+ """Function to iterate over files."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable, Generator
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+ from numpy.typing import NDArray
10
+ from torch.utils.data import get_worker_info
11
+
12
+ from careamics.config import DataConfig, InferenceConfig
13
+ from careamics.file_io.read import read_tiff
14
+ from careamics.utils.logging import get_logger
15
+
16
+ from .dataset_utils import reshape_array
17
+
18
+ logger = get_logger(__name__)
19
+
20
+
21
def iterate_over_files(
    data_config: Union[DataConfig, InferenceConfig],
    data_files: list[Path],
    target_files: list[Path] | None = None,
    read_source_func: Callable = read_tiff,
) -> Generator[tuple[NDArray, NDArray | None], None, None]:
    """Iterate over data source and yield whole reshaped images.

    Files are distributed over the data loader workers: the worker with id `k`
    only processes the files whose index `i` satisfies `i % num_workers == k`.

    A file that cannot be read or reshaped is logged and skipped. An inconsistent
    pairing between source and target files raises an error instead.

    Parameters
    ----------
    data_config : CAREamics DataConfig or InferenceConfig
        Configuration.
    data_files : list of pathlib.Path
        List of data files.
    target_files : list of pathlib.Path, optional
        List of target files, by default None.
    read_source_func : Callable, optional
        Function to read the source, by default read_tiff.

    Yields
    ------
    tuple of (NDArray, NDArray or None)
        Reshaped image and reshaped target, the latter being None if no target
        files were provided.

    Raises
    ------
    ValueError
        If there are fewer target files than data files.
    ValueError
        If a data file and the target file at the same index have different names.
    """
    # When num_workers > 0, each worker process will have a different copy of the
    # dataset object
    # Configuring each copy independently to avoid having duplicate data returned
    # from the workers
    worker_info = get_worker_info()
    worker_id = worker_info.id if worker_info is not None else 0
    num_workers = worker_info.num_workers if worker_info is not None else 1

    # every data file needs a target at the same index
    if target_files is not None and len(target_files) < len(data_files):
        raise ValueError(
            f"The number of target files ({len(target_files)}) is smaller than the "
            f"number of data files ({len(data_files)})."
        )

    # iterate over the files
    for i, filename in enumerate(data_files):
        # only process the files assigned to this worker
        if i % num_workers != worker_id:
            continue

        # validate the pairing outside of the `try` block, so that the error
        # reaches the caller rather than being logged as a reading error
        target_filename = None
        if target_files is not None:
            target_filename = target_files[i]
            if filename.name != target_filename.name:
                raise ValueError(
                    f"File {filename} does not match target file "
                    f"{target_filename}. Have you passed sorted "
                    f"arrays?"
                )

        try:
            # read and reshape data
            sample = read_source_func(filename, data_config.axes)
            reshaped_sample = reshape_array(sample, data_config.axes)

            # read and reshape target, if available
            reshaped_target = None
            if target_filename is not None:
                target = read_source_func(target_filename, data_config.axes)
                reshaped_target = reshape_array(target, data_config.axes)
        except Exception as e:
            # best effort: skip the files that cannot be read or reshaped
            logger.error(f"Error reading file {filename}: {e}")
            continue

        # yield outside of the `try` block, so that exceptions thrown into the
        # generator by the consumer are not swallowed
        yield reshaped_sample, reshaped_target
@@ -0,0 +1,189 @@
1
+ """Computing data statistics."""
2
+
3
+ import numpy as np
4
+ from numpy.typing import NDArray
5
+
6
+
7
def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
    """
    Compute the channel-wise mean and standard deviation of an array.

    Expected input shape is (S, C, (Z), Y, X). Channels are taken along the
    second axis, and each statistic is computed over all the other axes.

    Parameters
    ----------
    image : NDArray
        Input array.

    Returns
    -------
    tuple of (NDArray, NDArray)
        Arrays of mean and standard deviation values, one entry per channel.
    """
    # one slice per channel, each slice keeping all the other axes
    channels = [image[:, index, ...] for index in range(image.shape[1])]

    channel_means = np.stack([channel.mean() for channel in channels])
    channel_stds = np.stack([channel.std() for channel in channels])

    return channel_means, channel_stds
31
+
32
+
33
def update_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
) -> tuple[NDArray, NDArray, NDArray]:
    """Update the running channel-wise statistics with a new batch of values.

    `count`, `mean` and `m2` are updated in place, and the same objects are
    returned.

    Parameters
    ----------
    count : NDArray
        Number of elements seen so far. Shape: (C,).
    mean : NDArray
        Running mean. Shape: (C,).
    m2 : NDArray
        Running sum of squared differences from the mean. Shape: (C,).
    new_values : NDArray
        New values to accumulate, with the channels along the first axis. All
        the other axes are flattened.

    Returns
    -------
    tuple[NDArray, NDArray, NDArray]
        Updated count, mean, and sum of squared differences.
    """
    n_channels = len(new_values)

    # --- channel-wise counts: add the number of new values per channel ---
    count += np.ones_like(count) * np.prod(new_values.shape[1:])

    # one row of values per channel
    flat_values = new_values.reshape(n_channels, -1)

    # --- channel-wise mean ---
    # deviation from the previous mean -> shape: (C, N)
    deviation_from_old = flat_values - mean.reshape(n_channels, 1)
    mean += np.sum(deviation_from_old / count.reshape(n_channels, 1), axis=1)

    # --- channel-wise sum of squared differences ---
    # deviation from the updated mean -> shape: (C, N)
    deviation_from_new = flat_values - mean.reshape(n_channels, 1)
    m2 += np.sum(deviation_from_old * deviation_from_new, axis=1)

    return count, mean, m2
70
+
71
+
72
def finalize_iterative_stats(
    count: NDArray, mean: NDArray, m2: NDArray
) -> tuple[NDArray, NDArray]:
    """Turn the running statistics into a mean and a standard deviation.

    Parameters
    ----------
    count : NDArray
        Number of elements accumulated. Shape: (C,).
    mean : NDArray
        Running mean. Shape: (C,).
    m2 : NDArray
        Running sum of squared differences from the mean. Shape: (C,).

    Returns
    -------
    tuple[NDArray, NDArray]
        Final channel-wise mean and standard deviation. Both arrays are filled
        with NaN if any channel has a count smaller than 2.
    """
    std = np.sqrt(m2 / count)

    if not any(n_values < 2 for n_values in count):
        return mean, std

    # at least one channel does not have enough values
    return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
96
+
97
+
98
class WelfordStatistics:
    """Accumulate channel-wise mean and standard deviation over several arrays.

    The statistics are computed iteratively using the Welford algorithm, based on
    the implementation from:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm

    The state (`count`, `mean`, `m2`) is created by the first call to `update`
    with `sample_idx` equal to 0.
    """

    def update(self, array: NDArray, sample_idx: int) -> None:
        """Accumulate a new array into the statistics.

        Parameters
        ----------
        array : NDArray
            Input array, with the channels along the second axis.
        sample_idx : int
            Current sample number. The statistics are (re)initialized when it is
            equal to 0.
        """
        self.sample_idx = sample_idx
        n_channels = array.shape[1]

        # one entry per channel along the first axis
        per_channel = np.array(np.split(array, n_channels, axis=1))

        if self.sample_idx == 0:
            # start from the mean of the first array, with empty count and m2
            self.mean, _ = compute_normalization_stats(array)
            running_count = np.zeros(n_channels)
            running_m2 = np.zeros(n_channels)
        else:
            # continue from the current state
            running_count = self.count
            running_m2 = self.m2

        self.count, self.mean, self.m2 = update_iterative_stats(
            count=running_count,
            mean=self.mean,
            m2=running_m2,
            new_values=per_channel,
        )

        self.sample_idx += 1

    def finalize(self) -> tuple[NDArray, NDArray]:
        """Compute the final statistics from the accumulated state.

        Returns
        -------
        tuple of numpy arrays
            Final mean and standard deviation.
        """
        return finalize_iterative_stats(self.count, self.mean, self.m2)
147
+
148
+
149
+ # from multiprocessing import Value
150
+ # from typing import tuple
151
+
152
+ # import numpy as np
153
+
154
+
155
+ # class RunningStats:
156
+ # """Calculates running mean and std."""
157
+
158
+ # def __init__(self) -> None:
159
+ # self.reset()
160
+
161
+ # def reset(self) -> None:
162
+ # """Reset the running stats."""
163
+ # self.avg_mean = Value("d", 0)
164
+ # self.avg_std = Value("d", 0)
165
+ # self.m2 = Value("d", 0)
166
+ # self.count = Value("i", 0)
167
+
168
+ # def init(self, mean: float, std: float) -> None:
169
+ # """Initialize running stats."""
170
+ # with self.avg_mean.get_lock():
171
+ # self.avg_mean.value += mean
172
+ # with self.avg_std.get_lock():
173
+ # self.avg_std.value = std
174
+
175
+ # def compute_std(self) -> tuple[float, float]:
176
+ # """Compute std."""
177
+ # if self.count.value >= 2:
178
+ # self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
179
+
180
+ # def update(self, value: float) -> None:
181
+ # """Update running stats."""
182
+ # with self.count.get_lock():
183
+ # self.count.value += 1
184
+ # delta = value - self.avg_mean.value
185
+ # with self.avg_mean.get_lock():
186
+ # self.avg_mean.value += delta / self.count.value
187
+ # delta2 = value - self.avg_mean.value
188
+ # with self.m2.get_lock():
189
+ # self.m2.value += delta * delta2