careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,341 @@
1
+ """Prediction Lightning data modules."""
2
+
3
+ from collections.abc import Callable
4
+ from pathlib import Path
5
+ from typing import Any, Literal, Union
6
+
7
+ import numpy as np
8
+ import pytorch_lightning as L
9
+ from numpy.typing import NDArray
10
+ from torch.utils.data import DataLoader
11
+
12
+ from careamics.config import InferenceConfig
13
+ from careamics.config.support import SupportedData
14
+ from careamics.dataset import (
15
+ InMemoryPredDataset,
16
+ InMemoryTiledPredDataset,
17
+ IterablePredDataset,
18
+ IterableTiledPredDataset,
19
+ )
20
+ from careamics.dataset.dataset_utils import list_files
21
+ from careamics.dataset.tiling.collate_tiles import collate_tiles
22
+ from careamics.file_io.read import get_read_func
23
+ from careamics.utils import get_logger
24
+
25
# Module-level logger, named after this module.
logger = get_logger(__name__)

# Union of the dataset classes that `PredictDataModule.setup` can instantiate:
# in-memory or iterable, each in a tiled and a non-tiled flavour.
PredictDatasetType = Union[
    InMemoryPredDataset,
    InMemoryTiledPredDataset,
    IterablePredDataset,
    IterableTiledPredDataset,
]
33
+
34
+
35
class PredictDataModule(L.LightningDataModule):
    """
    CAREamics Lightning prediction data module.

    The data module can be used with Path, str or numpy arrays. The data can be either
    a folder containing images or a single file.

    To read custom data types, you can set `data_type` to `custom` in `pred_config`
    and provide a function that returns a numpy array from a path as
    `read_source_func` parameter. The function will receive a Path object and
    an axes string as arguments, the axes being derived from the `pred_config`.

    You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
    "*.czi") to filter the files extension using `extension_filter`.

    Parameters
    ----------
    pred_config : InferenceConfig
        Pydantic model for CAREamics prediction configuration.
    pred_data : pathlib.Path or str or numpy.ndarray
        Prediction data, can be a path to a folder, a file or a numpy array.
    read_source_func : Callable, optional
        Function to read custom types, by default None.
    extension_filter : str, optional
        Filter to filter file extensions for custom types, by default "".
    dataloader_params : dict, optional
        Dataloader parameters, by default None (treated as an empty dict).
    """

    def __init__(
        self,
        pred_config: InferenceConfig,
        pred_data: Union[Path, str, NDArray],
        read_source_func: Callable | None = None,
        extension_filter: str = "",
        dataloader_params: dict | None = None,
    ) -> None:
        """
        Constructor.

        The data module can be used with Path, str or numpy arrays. The data can be
        either a folder containing images or a single file.

        To read custom data types, you can set `data_type` to `custom` in
        `pred_config` and provide a function that returns a numpy array from a path
        as `read_source_func` parameter. The function will receive a Path object and
        an axes string as arguments, the axes being derived from the `pred_config`.

        You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
        "*.czi") to filter the files extension using `extension_filter`.

        Parameters
        ----------
        pred_config : InferenceConfig
            Pydantic model for CAREamics prediction configuration.
        pred_data : pathlib.Path or str or numpy.ndarray
            Prediction data, can be a path to a folder, a file or a numpy array.
        read_source_func : Callable, optional
            Function to read custom types, by default None.
        extension_filter : str, optional
            Filter to filter file extensions for custom types, by default "".
        dataloader_params : dict, optional
            Dataloader parameters, by default None (treated as an empty dict).

        Raises
        ------
        ValueError
            If the data type is `custom` and no `read_source_func` is provided.
        ValueError
            If the input is a numpy array and the data type is not `array`.
        ValueError
            If the input is a Path or a str and the data type is neither `tiff`
            nor `custom`.
        """
        if dataloader_params is None:
            dataloader_params = {}
        super().__init__()

        # check that a read source function is provided for custom types
        if pred_config.data_type == SupportedData.CUSTOM and read_source_func is None:
            raise ValueError(
                f"Data type {SupportedData.CUSTOM} is not allowed without "
                f"specifying a `read_source_func` and an `extension_filter`."
            )

        # check correct input type
        if (
            isinstance(pred_data, np.ndarray)
            and pred_config.data_type != SupportedData.ARRAY
        ):
            raise ValueError(
                f"Received a numpy array as input, but the data type was set to "
                f"{pred_config.data_type}. Set the data type "
                f"to {SupportedData.ARRAY} to predict on numpy arrays."
            )

        # and that the data type is file-based if a Path or a str is passed; the
        # check is on `pred_data` (the input), not on the configuration object
        elif isinstance(pred_data, (Path, str)) and (
            pred_config.data_type != SupportedData.TIFF
            and pred_config.data_type != SupportedData.CUSTOM
        ):
            raise ValueError(
                f"Received a path as input, but the data type was neither set to "
                f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
                f"to {SupportedData.TIFF} or "
                f"{SupportedData.CUSTOM} to predict on files."
            )

        # configuration data
        self.prediction_config = pred_config
        self.data_type = pred_config.data_type
        self.batch_size = pred_config.batch_size
        self.dataloader_params = dataloader_params

        self.pred_data = pred_data
        self.tile_size = pred_config.tile_size
        self.tile_overlap = pred_config.tile_overlap

        # tiling is enabled only when both the tile size and overlap are defined
        self.tiled = self.tile_size is not None and self.tile_overlap is not None

        # read source function
        if pred_config.data_type == SupportedData.CUSTOM:
            # mypy check (the None case was already rejected above)
            assert read_source_func is not None

            self.read_source_func: Callable = read_source_func
        elif pred_config.data_type != SupportedData.ARRAY:
            self.read_source_func = get_read_func(pred_config.data_type)

        self.extension_filter = extension_filter

    def prepare_data(self) -> None:
        """Hook used to prepare the data before calling `setup`."""
        # if the data is a Path or a str, list the files to predict on
        if not isinstance(self.pred_data, np.ndarray):
            self.pred_files = list_files(
                self.pred_data, self.data_type, self.extension_filter
            )

    def setup(self, stage: str | None = None) -> None:
        """
        Hook called at the beginning of predict.

        Parameters
        ----------
        stage : Optional[str], optional
            Stage, by default None.
        """
        # if numpy array
        if self.data_type == SupportedData.ARRAY:
            if self.tiled:
                self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset(
                    prediction_config=self.prediction_config,
                    inputs=self.pred_data,
                )
            else:
                self.predict_dataset = InMemoryPredDataset(
                    prediction_config=self.prediction_config,
                    inputs=self.pred_data,
                )
        else:
            # `pred_files` is only assigned in `prepare_data`; build it here if
            # that hook did not run in this process (e.g. `setup` called directly)
            if not hasattr(self, "pred_files"):
                self.prepare_data()

            if self.tiled:
                self.predict_dataset = IterableTiledPredDataset(
                    prediction_config=self.prediction_config,
                    src_files=self.pred_files,
                    read_source_func=self.read_source_func,
                )
            else:
                self.predict_dataset = IterablePredDataset(
                    prediction_config=self.prediction_config,
                    src_files=self.pred_files,
                    read_source_func=self.read_source_func,
                )

    def predict_dataloader(self) -> DataLoader:
        """
        Create a dataloader for prediction.

        Returns
        -------
        DataLoader
            Prediction dataloader.
        """
        # For tiled predictions, we need to ensure tiles are processed in order
        # to avoid stitching artifacts. Multi-worker processing can return batches
        # out of order, so we disable it for tiled predictions.
        # Work on a copy so that the stored parameters are left untouched.
        dataloader_params = self.dataloader_params.copy()
        if self.tiled:
            dataloader_params["num_workers"] = 0

        return DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            collate_fn=collate_tiles if self.tiled else None,
            **dataloader_params,
        )
233
+
234
+
235
def create_predict_datamodule(
    pred_data: Union[str, Path, NDArray],
    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
    axes: str,
    image_means: list[float],
    image_stds: list[float],
    tile_size: tuple[int, ...] | None = None,
    tile_overlap: tuple[int, ...] | None = None,
    batch_size: int = 1,
    tta_transforms: bool = True,
    read_source_func: Callable | None = None,
    extension_filter: str = "",
    dataloader_params: dict | None = None,
) -> PredictDataModule:
    """Create a CAREamics prediction Lightning datamodule.

    This function is used to explicitly pass the parameters usually contained in an
    `inference_model` configuration.

    Since the lightning datamodule has no access to the model, make sure that the
    parameters passed to the datamodule are consistent with the model's requirements
    and are coherent. This can be done by creating a `Configuration` object beforehand
    and passing its parameters to the different Lightning modules.

    The data module can be used with Path, str or numpy arrays. To use array data, set
    `data_type` to `array` and pass a numpy array to `pred_data`.

    By default, CAREamics only supports types defined in
    `careamics.config.support.SupportedData`. To read custom data types, you can set
    `data_type` to `custom` and provide a function that returns a numpy array from a
    path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression
    (e.g. "*.jpeg") to filter the files extension using `extension_filter`.

    In `dataloader_params`, you can pass any parameter accepted by PyTorch
    dataloaders, except for `batch_size`, which is set by the `batch_size`
    parameter. A `batch_size` entry in `dataloader_params` is ignored; the
    dictionary passed by the caller is not modified.

    Parameters
    ----------
    pred_data : str or pathlib.Path or numpy.ndarray
        Prediction data.
    data_type : {"array", "tiff", "custom"}
        Data type, see `SupportedData` for available options.
    axes : str
        Axes of the data, chosen among SCZYX.
    image_means : list of float
        Mean values for normalization, only used if Normalization is defined.
    image_stds : list of float
        Std values for normalization, only used if Normalization is defined.
    tile_size : tuple of int, optional
        Tile size, 2D or 3D tile size.
    tile_overlap : tuple of int, optional
        Tile overlap, 2D or 3D tile overlap.
    batch_size : int
        Batch size.
    tta_transforms : bool, optional
        Use test time augmentation, by default True.
    read_source_func : Callable, optional
        Function to read the source data, used if `data_type` is `custom`, by
        default None.
    extension_filter : str, optional
        Filter for file extensions, used if `data_type` is `custom`, by default "".
    dataloader_params : dict, optional
        Pytorch dataloader parameters, by default None (treated as an empty dict).

    Returns
    -------
    PredictDataModule
        CAREamics prediction datamodule.

    Notes
    -----
    If you are using a UNet model and tiling, the tile size must be
    divisible in every dimension by 2**d, where d is the depth of the model. This
    avoids artefacts arising from the broken shift invariance induced by the
    pooling layers of the UNet. If your image has less dimensions, as it may
    happen in the Z dimension, consider padding your image.
    """
    # shallow copy: the `batch_size` removal below must not alter the caller's dict
    dataloader_params = {} if dataloader_params is None else dict(dataloader_params)

    prediction_dict: dict[str, Any] = {
        "data_type": data_type,
        "tile_size": tile_size,
        "tile_overlap": tile_overlap,
        "axes": axes,
        "image_means": image_means,
        "image_stds": image_stds,
        "tta_transforms": tta_transforms,
        "batch_size": batch_size,
    }

    # validate configuration
    prediction_config = InferenceConfig(**prediction_dict)

    # sanity check on the dataloader parameters: the batch size is set by the
    # configuration, so a duplicate entry is dropped
    dataloader_params.pop("batch_size", None)

    return PredictDataModule(
        pred_config=prediction_config,
        pred_data=pred_data,
        read_source_func=read_source_func,
        extension_filter=extension_filter,
        dataloader_params=dataloader_params,
    )