careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,395 @@
1
+ """Utility functions for file and paths solver."""
2
+
3
+ from collections.abc import Sequence
4
+ from pathlib import Path
5
+ from typing import Any, Literal
6
+
7
+ from numpy import ndarray
8
+ from numpy.typing import NDArray
9
+
10
+ from careamics.config.support import SupportedData
11
+ from careamics.dataset.dataset_utils import list_files, validate_source_target_files
12
+ from careamics.dataset_ng.image_stack_loader.zarr_utils import is_valid_uri
13
+
14
# A single data item: a filesystem path, a string (path or URI), or an array.
ItemType = Path | str | NDArray[Any]
"""Type of input items passed to the dataset."""

# A single item, a sequence of items, or None when no data is provided.
InputType = ItemType | Sequence[ItemType] | None
"""Type of input data passed to the dataset."""
19
+
20
+
21
def list_files_in_directory(
    data_type: Literal["tiff", "zarr", "czi", "custom"],
    input_data,
    target_data=None,
    extension_filter: str = "",
) -> tuple[list[Path], list[Path] | None]:
    """Collect the files found at the input (and optional target) location.

    Parameters
    ----------
    data_type : Literal["tiff", "zarr", "czi", "custom"]
        The type of data to list.
    input_data : str or Path
        Path to a folder, or to a single file, holding the input data.
    target_data : str or Path, optional
        Path to a folder, or to a single file, holding the target data, or None.
    extension_filter : str, default=""
        File extension filter forwarded to `list_files`.

    Returns
    -------
    list[Path]
        Files found for the input data.
    list[Path] | None
        Files found for the target data, or None when `target_data` is None.
    """
    # `list_files` returns a one-element list when the path is a single file
    # with the expected extension.
    source_files = list_files(Path(input_data), data_type, extension_filter)
    if target_data is None:
        return source_files, None

    paired_files = list_files(Path(target_data), data_type, extension_filter)
    # Delegate the source/target pairing checks to the shared dataset helper.
    validate_source_target_files(source_files, paired_files)
    return source_files, paired_files
60
+
61
+
62
+ def convert_paths_to_pathlib(
63
+ input_data: Sequence[str | Path],
64
+ target_data: Sequence[str | Path] | None = None,
65
+ ) -> tuple[list[Path], list[Path] | None]:
66
+ """Create a list of file paths from the input and target data.
67
+
68
+ Parameters
69
+ ----------
70
+ input_data : Sequence[str | Path]
71
+ Input data, can be a path to a folder, or a list of paths.
72
+ target_data : Sequence[str | Path] | None
73
+ Target data, can be None, a path to a folder, or a list of paths.
74
+
75
+ Returns
76
+ -------
77
+ list[Path]
78
+ A list of file paths for input data.
79
+ list[Path] | None
80
+ A list of file paths for target data, or None if target_data is None.
81
+ """
82
+ input_files = [Path(item) if isinstance(item, str) else item for item in input_data]
83
+ if target_data is None:
84
+ return input_files, None
85
+ else:
86
+ target_files = [
87
+ Path(item) if isinstance(item, str) else item for item in target_data
88
+ ]
89
+ validate_source_target_files(input_files, target_files)
90
+ return input_files, target_files
91
+
92
+
93
def validate_input_target_type_consistency(
    input_data: InputType,
    target_data: InputType | None,
) -> None:
    """Validate if the input and target data types are consistent.

    Nothing is checked when either side is None. Otherwise the input must be an
    instance of the target's type. For lists, the lengths must match and the
    first elements must be of matching types (empty lists are accepted).

    Parameters
    ----------
    input_data : InputType
        Input data, can be a path to a folder, a list of paths, or a numpy array.
    target_data : Optional[InputType]
        Target data, can be None, a path to a folder, a list of paths, or a numpy
        array.

    Raises
    ------
    ValueError
        If the input and target data types are not consistent, if input and
        target lists differ in length, or if their first elements differ in type.
    """
    if input_data is None or target_data is None:
        return

    if not isinstance(input_data, type(target_data)):
        raise ValueError(
            "Inputs for input and target must be of the same type or None. "
            f"Got {type(input_data)} and {type(target_data)}."
        )

    if isinstance(input_data, list) and isinstance(target_data, list):
        if len(input_data) != len(target_data):
            raise ValueError(
                "Inputs and targets must have the same length. "
                f"Got {len(input_data)} and {len(target_data)}."
            )
        # The lengths are equal here, so checking `input_data` alone guards
        # against indexing into two empty lists (previously an IndexError).
        if input_data and not isinstance(input_data[0], type(target_data[0])):
            raise ValueError(
                "Inputs and targets must have the same type. "
                f"Got {type(input_data[0])} and {type(target_data[0])}."
            )
129
+
130
+
131
+ def validate_array_input(
132
+ input_data: NDArray | list[NDArray],
133
+ target_data: NDArray | list[NDArray] | None,
134
+ ) -> tuple[list[NDArray], list[NDArray] | None]:
135
+ """Validate if the input data is a numpy array.
136
+
137
+ Parameters
138
+ ----------
139
+ input_data : InputType
140
+ Input data, can be a path to a folder, a list of paths, or a numpy array.
141
+ target_data : Optional[InputType]
142
+ Target data, can be None, a path to a folder, a list of paths, or a numpy
143
+ array.
144
+
145
+ Returns
146
+ -------
147
+ list[numpy.ndarray]
148
+ Validated input data.
149
+ list[numpy.ndarray] | None
150
+ Validated target data, None if the target data is None.
151
+
152
+ Raises
153
+ ------
154
+ ValueError
155
+ If the input data is not a numpy array or a list of numpy arrays.
156
+ """
157
+ if isinstance(input_data, ndarray):
158
+ input_list = [input_data]
159
+
160
+ if target_data is not None and not isinstance(target_data, ndarray):
161
+ raise ValueError(
162
+ f"Wrong target type. Expected numpy.ndarray, got {type(target_data)}. "
163
+ f"Check the data_type parameter or your inputs."
164
+ )
165
+ target_list = [target_data] if target_data is not None else None
166
+ return input_list, target_list
167
+ elif isinstance(input_data, list):
168
+ # TODO warn if wrong types inside list
169
+ input_list = [array for array in input_data if isinstance(array, ndarray)]
170
+
171
+ if target_data is None:
172
+ target_list = None
173
+ else:
174
+ assert isinstance(target_data, list)
175
+ target_list = [array for array in target_data if isinstance(array, ndarray)]
176
+ return input_list, target_list
177
+ else:
178
+ raise ValueError(
179
+ f"Wrong input type. Expected numpy.ndarray or list of numpy.ndarray, got "
180
+ f"{type(input_data)}. Check the data_type parameter or your inputs."
181
+ )
182
+
183
+
184
+ def validate_path_input(
185
+ data_type: Literal["tiff", "zarr", "czi", "custom"],
186
+ input_data: str | Path | list[str | Path],
187
+ target_data: str | Path | list[str | Path] | None,
188
+ extension_filter: str = "",
189
+ ) -> tuple[list[Path], list[Path] | None]:
190
+ """Validate if the input data is a path or a list of paths.
191
+
192
+ Parameters
193
+ ----------
194
+ data_type : Literal["tiff", "zarr", "czi", "custom"]
195
+ The type of data to validate.
196
+ input_data : str | Path | list[str | Path]
197
+ Input data, can be a path to a folder, a list of paths, or a numpy array.
198
+ target_data : str | Path | list[str | Path] | None
199
+ Target data, can be None, a path to a folder, a list of paths, or a numpy
200
+ array.
201
+ extension_filter : str, default=""
202
+ File extension filter to apply when listing files.
203
+
204
+ Returns
205
+ -------
206
+ list[Path]
207
+ A list of file paths for input data.
208
+ list[Path] | None
209
+ A list of file paths for target data, or None if target_data is None.
210
+
211
+ Raises
212
+ ------
213
+ ValueError
214
+ If the input data is not a path or a list of paths.
215
+ """
216
+ if isinstance(input_data, (str, Path)):
217
+ input_list, target_list = list_files_in_directory(
218
+ data_type, input_data, target_data, extension_filter
219
+ )
220
+ return input_list, target_list
221
+ elif isinstance(input_data, list):
222
+ # TODO warn if wrong types inside list
223
+ input_list = [
224
+ Path(item)
225
+ for item in input_data
226
+ if isinstance(item, (str, Path)) and Path(item).exists()
227
+ ]
228
+
229
+ target_list = None
230
+ if target_data is not None:
231
+ assert isinstance(target_data, list)
232
+ target_list = [
233
+ Path(item)
234
+ for item in target_data
235
+ if isinstance(item, (str, Path)) and Path(item).exists()
236
+ ] # consistency with input is enforced by convert_paths_to_pathlib
237
+
238
+ return convert_paths_to_pathlib(input_list, target_list)
239
+ else:
240
+ raise ValueError(
241
+ f"Wrong input type, expected str or Path or list[str | Path], got "
242
+ f"{type(input_data)}. Check the data_type parameter or your inputs."
243
+ )
244
+
245
+
246
+ def validate_zarr_input(
247
+ input_data: str | Path | list[str | Path],
248
+ target_data: str | Path | list[str | Path] | None,
249
+ ) -> tuple[list[str] | list[Path], list[str] | list[Path] | None]:
250
+ """Validate if the input data corresponds a zarr input.
251
+
252
+ Parameters
253
+ ----------
254
+ input_data : str | Path | list[str | Path]
255
+ Input data, can be a path to a folder, to zarr file, a URI pointing to a zarr
256
+ dataset, or a list.
257
+ target_data : str | Path | list[str | Path] | None
258
+ Target data, can be None.
259
+
260
+ Returns
261
+ -------
262
+ list[str] or list[Path]
263
+ A list of zarr URIs or path for input data.
264
+ list[str] or list[Path] | None
265
+ A list of zarr URIs or paths for target data, or None if target_data is None.
266
+
267
+ Raises
268
+ ------
269
+ ValueError
270
+ If the input and target data types are not consistent.
271
+ ValueError
272
+ If the input data is not a zarr URI or path, or a list of zarr URIs or paths.
273
+ """
274
+ # validate_input_target_type_consistency is called beforehand, ensuring the types
275
+ # of input and target are the same
276
+ if isinstance(input_data, (str, Path)):
277
+ if Path(input_data).exists():
278
+ # either a path to a folder or a zarr file
279
+ # path to a folder will trigger collection of all zarr files in that folder
280
+ assert target_data is None or isinstance(target_data, (str, Path))
281
+ if target_data is not None and not Path(target_data).exists():
282
+ raise ValueError(
283
+ f"Target provided as path, but does not exist: {target_data}."
284
+ )
285
+
286
+ return validate_path_input("zarr", input_data, target_data)
287
+ elif isinstance(input_data, str) and is_valid_uri(input_data):
288
+ input_list = [input_data]
289
+
290
+ assert target_data is None or isinstance(target_data, str)
291
+ if target_data is not None and not is_valid_uri(target_data):
292
+ raise ValueError(
293
+ f"Wrong target type for zarr data. Expected a zarr URI, got "
294
+ f"{type(target_data)}."
295
+ )
296
+ target_list = [target_data] if target_data is not None else None
297
+ return input_list, target_list
298
+ else:
299
+ raise ValueError(
300
+ f"Wrong input type for zarr data. Expected a file URI or a path to a "
301
+ f" file, got {input_data}. Path may not exist."
302
+ )
303
+ elif isinstance(input_data, list):
304
+ # use first element as determinant of type
305
+ if isinstance(input_data[0], (str, Path)):
306
+ if Path(input_data[0]).exists():
307
+ return validate_path_input("zarr", input_data, target_data)
308
+ else:
309
+ final_input_list = [
310
+ str(item) for item in input_data if is_valid_uri(item)
311
+ ]
312
+ if target_data is not None:
313
+ assert isinstance(target_data, list)
314
+ final_target_list = [
315
+ str(item) for item in target_data if is_valid_uri(item)
316
+ ]
317
+ else:
318
+ final_target_list = None
319
+ return final_input_list, final_target_list
320
+ else:
321
+ raise ValueError(
322
+ f"Wrong input type for zarr data. Expected a list of file URIs or "
323
+ f" paths to files, got {type(input_data[0])}."
324
+ )
325
+ else:
326
+ raise ValueError(
327
+ f"Wrong input type for zarr data. Expected a file URI, a path to a file, "
328
+ f" or a list of those, got {type(input_data)}."
329
+ )
330
+
331
+
332
+ def initialize_data_pair(
333
+ data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
334
+ input_data: InputType,
335
+ target_data: InputType | None = None,
336
+ extension_filter: str = "",
337
+ custom_loader: bool = False,
338
+ ) -> tuple[InputType | list[InputType], InputType | list[InputType] | None]:
339
+ """
340
+ Initialize a pair of input and target data.
341
+
342
+ Parameters
343
+ ----------
344
+ data_type : Literal["array", "tiff", "zarr", "czi", "custom"]
345
+ The type of data to initialize.
346
+ input_data : InputType
347
+ Input data, can be None, a path to a folder, a list of paths, or a numpy
348
+ array.
349
+ target_data : InputType | None
350
+ Target data, can be None, a path to a folder, a list of paths, or a numpy
351
+ array.
352
+ extension_filter : str, default=""
353
+ File extension filter to apply when listing files.
354
+ custom_loader : bool, default=False
355
+ Whether a custom image stack loader is used.
356
+
357
+ Returns
358
+ -------
359
+ list[numpy.ndarray] | list[pathlib.Path]
360
+ Initialized input data. For file paths, returns a list of Path objects. For
361
+ numpy arrays, returns the arrays directly.
362
+ list[numpy.ndarray] | list[pathlib.Path] | None
363
+ Initialized target data. For file paths, returns a list of Path objects. For
364
+ numpy arrays, returns the arrays directly. Returns None if target_data is None.
365
+ """
366
+ if input_data is None:
367
+ return None, None
368
+
369
+ validate_input_target_type_consistency(input_data, target_data)
370
+
371
+ if data_type == SupportedData.ARRAY:
372
+ return validate_array_input(input_data, target_data)
373
+ elif data_type in (SupportedData.TIFF, SupportedData.CZI):
374
+ assert data_type != SupportedData.ARRAY.value # for mypy
375
+
376
+ if isinstance(input_data, (str, Path)):
377
+ assert target_data is None or isinstance(target_data, (str, Path))
378
+
379
+ return validate_path_input(data_type, input_data, target_data)
380
+ elif isinstance(input_data, list):
381
+ assert target_data is None or isinstance(target_data, list)
382
+
383
+ return validate_path_input(data_type, input_data, target_data)
384
+ else:
385
+ raise ValueError(
386
+ f"Unsupported input type for {data_type}: {type(input_data)}"
387
+ )
388
+ elif data_type == SupportedData.ZARR:
389
+ return validate_zarr_input(input_data, target_data)
390
+ elif data_type == SupportedData.CUSTOM:
391
+ if custom_loader:
392
+ return input_data, target_data
393
+ return validate_path_input(data_type, input_data, target_data, extension_filter)
394
+ else:
395
+ raise NotImplementedError(f"Unsupported data type: {data_type}")
@@ -0,0 +1,9 @@
1
+ """CAREamics PyTorch Lightning modules."""
2
+
3
+ from .care_module import CAREModule
4
+ from .n2v_module import N2VModule
5
+
6
# Public Lightning modules re-exported by this package.
__all__ = [
    "CAREModule",
    "N2VModule",
]
@@ -0,0 +1,97 @@
1
+ """CARE Lightning DataModule."""
2
+
3
+ from collections.abc import Callable
4
+ from typing import Any, Union
5
+
6
+ from careamics.config.algorithms.care_algorithm_config import CAREAlgorithm
7
+ from careamics.config.algorithms.n2n_algorithm_config import N2NAlgorithm
8
+ from careamics.config.support import SupportedLoss
9
+ from careamics.dataset_ng.dataset import ImageRegionData
10
+ from careamics.losses import mae_loss, mse_loss
11
+ from careamics.utils.logging import get_logger
12
+
13
+ from .unet_module import UnetModule
14
+
15
logger = get_logger(__name__)


class CAREModule(UnetModule):
    """CAREamics PyTorch Lightning module for the CARE and N2N algorithms.

    Both algorithms are supervised: every input patch is paired with a target
    patch, and the network is trained to map the former to the latter using the
    loss selected in the configuration (MAE or MSE).

    Parameters
    ----------
    algorithm_config : CAREAlgorithm, N2NAlgorithm or dict
        Configuration for the algorithm, either as a `CAREAlgorithm` or
        `N2NAlgorithm` instance, or as a dictionary that validates into one of
        them.
    """

    def __init__(
        self, algorithm_config: Union[CAREAlgorithm, N2NAlgorithm, dict]
    ) -> None:
        """Instantiate the CARE Lightning module.

        Parameters
        ----------
        algorithm_config : CAREAlgorithm, N2NAlgorithm or dict
            Configuration for the algorithm, either as a `CAREAlgorithm` or
            `N2NAlgorithm` instance, or as a dictionary that validates into one
            of them.

        Raises
        ------
        TypeError
            If `algorithm_config` is not a `CAREAlgorithm`, an `N2NAlgorithm`
            or a dictionary.
        ValueError
            If the configured loss is not supported, or if a dictionary does
            not validate into either configuration.
        """
        # The parent receives the argument exactly as the caller passed it.
        super().__init__(algorithm_config)

        # Validate into a configuration object under a new name. Previously a
        # dict always failed the type check even though it is an accepted input.
        config = self._to_algorithm_config(algorithm_config)

        loss = config.loss
        if loss == SupportedLoss.MAE:
            self.loss_func: Callable = mae_loss
        elif loss == SupportedLoss.MSE:
            self.loss_func = mse_loss
        else:
            raise ValueError(f"Unsupported loss for Care: {loss}")

    @staticmethod
    def _to_algorithm_config(
        algorithm_config: Union[CAREAlgorithm, N2NAlgorithm, dict],
    ) -> Union[CAREAlgorithm, N2NAlgorithm]:
        """Return `algorithm_config` as a validated CARE or N2N configuration.

        Parameters
        ----------
        algorithm_config : CAREAlgorithm, N2NAlgorithm or dict
            Configuration object, or dictionary to validate.

        Returns
        -------
        CAREAlgorithm or N2NAlgorithm
            The validated configuration.

        Raises
        ------
        TypeError
            If `algorithm_config` has an unsupported type.
        ValueError
            If a dictionary validates into neither configuration.
        """
        if isinstance(algorithm_config, dict):
            # Try CARE first and fall back to N2N. Validation errors are assumed
            # to subclass ValueError (true for pydantic) -- TODO confirm for the
            # configuration classes in use.
            try:
                return CAREAlgorithm(**algorithm_config)
            except ValueError as care_error:
                try:
                    return N2NAlgorithm(**algorithm_config)
                except ValueError as n2n_error:
                    raise n2n_error from care_error

        # Raise instead of assert so the check survives `python -O`.
        if not isinstance(algorithm_config, (CAREAlgorithm, N2NAlgorithm)):
            raise TypeError(
                "algorithm_config must be a CAREAlgorithm or a N2NAlgorithm"
            )
        return algorithm_config

    @staticmethod
    def _split_batch(
        batch: tuple[ImageRegionData, ...],
    ) -> tuple[ImageRegionData, ImageRegionData]:
        """Return the (input, target) pair from a supervised batch.

        Parameters
        ----------
        batch : tuple of ImageRegionData
            Batch yielded by the dataloader. The input is expected first and
            the target second.

        Returns
        -------
        (ImageRegionData, ImageRegionData)
            The input data and the target data.

        Raises
        ------
        ValueError
            If the batch does not contain a target.
        """
        if len(batch) < 2:
            raise ValueError(
                "CARE requires a target for every input, but the batch only "
                "contains input data."
            )
        return batch[0], batch[1]

    def training_step(
        self,
        batch: tuple[ImageRegionData, ImageRegionData],
        batch_idx: Any,
    ) -> Any:
        """Training step for CARE module.

        Parameters
        ----------
        batch : (ImageRegionData, ImageRegionData)
            A tuple containing the input data and the target data.
        batch_idx : Any
            The index of the current batch in the training loop.

        Returns
        -------
        Any
            The loss value computed for the current batch.
        """
        x, target = self._split_batch(batch)

        prediction = self.model(x.data)
        loss = self.loss_func(prediction, target.data)

        self._log_training_stats(loss, batch_size=x.data.shape[0])

        return loss

    def validation_step(
        self,
        batch: tuple[ImageRegionData, ImageRegionData],
        batch_idx: Any,
    ) -> None:
        """Validation step for CARE module.

        Computes the validation loss, updates the metrics with the prediction
        and target, and logs the loss.

        Parameters
        ----------
        batch : (ImageRegionData, ImageRegionData)
            A tuple containing the input data and the target data.
        batch_idx : Any
            The index of the current batch in the validation loop.
        """
        x, target = self._split_batch(batch)

        prediction = self.model(x.data)
        val_loss = self.loss_func(prediction, target.data)
        self.metrics(prediction, target.data)
        self._log_validation_stats(val_loss, batch_size=x.data.shape[0])
@@ -0,0 +1,106 @@
1
+ """Noise2Void Lightning DataModule."""
2
+
3
+ from typing import Any, Union
4
+
5
+ from careamics.config import (
6
+ N2VAlgorithm,
7
+ )
8
+ from careamics.dataset_ng.dataset import ImageRegionData
9
+ from careamics.losses import n2v_loss
10
+ from careamics.transforms import N2VManipulateTorch
11
+ from careamics.utils.logging import get_logger
12
+
13
+ from .unet_module import UnetModule
14
+
15
logger = get_logger(__name__)


class N2VModule(UnetModule):
    """CAREamics PyTorch Lightning module for the N2V algorithm.

    Each input batch is masked with `N2VManipulateTorch`, and the network is
    trained with `n2v_loss` to predict the original values at the masked
    positions.

    Parameters
    ----------
    algorithm_config : N2VAlgorithm or dict
        Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
        dictionary.
    """

    def __init__(self, algorithm_config: Union[N2VAlgorithm, dict]) -> None:
        """Instantiate the N2V Lightning module.

        Parameters
        ----------
        algorithm_config : N2VAlgorithm or dict
            Configuration for the N2V algorithm, either as an N2VAlgorithm instance or a
            dictionary.

        Raises
        ------
        TypeError
            If `algorithm_config` is neither an `N2VAlgorithm` nor a dictionary.
        """
        # The parent receives the argument exactly as the caller passed it.
        super().__init__(algorithm_config)

        # Previously a dict always failed the type check even though it is an
        # accepted input. Validate it here, and raise instead of assert so the
        # check survives `python -O`.
        if isinstance(algorithm_config, dict):
            config = N2VAlgorithm(**algorithm_config)
        elif isinstance(algorithm_config, N2VAlgorithm):
            config = algorithm_config
        else:
            raise TypeError("algorithm_config must be a N2VAlgorithm")

        self.n2v_manipulate = N2VManipulateTorch(
            n2v_manipulate_config=config.n2v_config
        )
        self.loss_func = n2v_loss

    def _load_best_checkpoint(self) -> None:
        """Load the best checkpoint for N2V model, warning about its limits."""
        logger.warning(
            "Loading best checkpoint for N2V model. Note that for N2V, "
            "the checkpoint with the best validation metrics may not necessarily "
            "have the best denoising performance."
        )
        super()._load_best_checkpoint()

    def _masked_forward(self, x: ImageRegionData) -> tuple[Any, Any, Any]:
        """Mask the input with N2V manipulation and run the model on it.

        Parameters
        ----------
        x : ImageRegionData
            Input data of the batch.

        Returns
        -------
        tuple
            `(prediction, x_original, mask)`, where `x_original` and `mask` are
            the outputs of `self.n2v_manipulate` used by the loss.
        """
        x_masked, x_original, mask = self.n2v_manipulate(x.data)
        prediction = self.model(x_masked)
        return prediction, x_original, mask

    def training_step(
        self,
        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
        batch_idx: Any,
    ) -> Any:
        """Training step for N2V model.

        Parameters
        ----------
        batch : ImageRegionData or (ImageRegionData, ImageRegionData)
            A tuple whose first element is the input data. Any target is ignored.
        batch_idx : Any
            The index of the current batch in the training loop.

        Returns
        -------
        Any
            The loss value for the current training step.
        """
        x = batch[0]
        prediction, x_original, mask = self._masked_forward(x)
        loss = self.loss_func(prediction, x_original, mask)

        self._log_training_stats(loss, batch_size=x.data.shape[0])

        return loss

    def validation_step(
        self,
        batch: Union[tuple[ImageRegionData], tuple[ImageRegionData, ImageRegionData]],
        batch_idx: Any,
    ) -> None:
        """Validation step for N2V model.

        Parameters
        ----------
        batch : ImageRegionData or (ImageRegionData, ImageRegionData)
            A tuple whose first element is the input data. Any target is ignored.
        batch_idx : Any
            The index of the current batch in the validation loop.
        """
        x = batch[0]
        prediction, x_original, mask = self._masked_forward(x)

        val_loss = self.loss_func(prediction, x_original, mask)
        # Metrics compare the prediction against the unmasked input.
        self.metrics(prediction, x_original)
        self._log_validation_stats(val_loss, batch_size=x.data.shape[0])