careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,529 @@
1
+ """Next-Generation CAREamics DataModule."""
2
+
3
+ import copy
4
+ from collections.abc import Callable, Sequence
5
+ from pathlib import Path
6
+ from typing import Any, Literal, Union, overload
7
+
8
+ import numpy as np
9
+ import pytorch_lightning as L
10
+ from numpy.typing import NDArray
11
+ from torch.utils.data import DataLoader, Sampler
12
+ from torch.utils.data._utils.collate import default_collate
13
+
14
+ from careamics.config.data.ng_data_config import NGDataConfig
15
+ from careamics.config.support import SupportedData
16
+ from careamics.dataset_ng.factory import create_dataset
17
+ from careamics.dataset_ng.grouped_index_sampler import GroupedIndexSampler
18
+ from careamics.dataset_ng.image_stack_loader import ImageStackLoader
19
+ from careamics.lightning.dataset_ng.data_module_utils import initialize_data_pair
20
+ from careamics.utils import get_logger
21
+
22
+ logger = get_logger(__name__)
23
+
24
ItemType = Union[
    Path,
    str,
    NDArray[Any],
]
"""A single input item passed to the dataset: a `Path`, a `str` or an array."""

InputType = Union[
    ItemType,
    Sequence[ItemType],
    None,
]
"""Input data passed to the dataset: one item, a sequence of items, or `None`."""
29
+
30
+
31
+ class CareamicsDataModule(L.LightningDataModule):
32
+ """Data module for Careamics dataset.
33
+
34
+ Parameters
35
+ ----------
36
+ data_config : NGDataConfig
37
+ Pydantic model for CAREamics data configuration.
38
+ train_data : Optional[InputType]
39
+ Training data, can be a path to a folder, a list of paths, or a numpy array.
40
+ train_data_target : Optional[InputType]
41
+ Training data target, can be a path to a folder,
42
+ a list of paths, or a numpy array.
43
+ train_data_mask : InputType (when filtering is needed)
44
+ Training data mask, can be a path to a folder,
45
+ a list of paths, or a numpy array. Used for coordinate filtering.
46
+ Only required when using coordinate-based patch filtering.
47
+ val_data : Optional[InputType]
48
+ Validation data, can be a path to a folder,
49
+ a list of paths, or a numpy array.
50
+ val_data_target : Optional[InputType]
51
+ Validation data target, can be a path to a folder,
52
+ a list of paths, or a numpy array.
53
+ pred_data : Optional[InputType]
54
+ Prediction data, can be a path to a folder, a list of paths,
55
+ or a numpy array.
56
+ pred_data_target : Optional[InputType]
57
+ Prediction data target, can be a path to a folder,
58
+ a list of paths, or a numpy array.
59
+ read_source_func : Optional[Callable], default=None
60
+ Function to read the source data. Only used for `custom`
61
+ data type (see DataModel).
62
+ read_kwargs : Optional[dict[str, Any]]
63
+ The kwargs for the read source function.
64
+ image_stack_loader : Optional[ImageStackLoader]
65
+ The image stack loader.
66
+ image_stack_loader_kwargs : Optional[dict[str, Any]]
67
+ The image stack loader kwargs.
68
+ extension_filter : str, default=""
69
+ Filter for file extensions. Only used for `custom` data types
70
+ (see DataModel).
71
+ val_percentage : Optional[float]
72
+ Percentage of the training data to use for validation. Only
73
+ used if `val_data` is None.
74
+ val_minimum_split : int, default=5
75
+ Minimum number of patches or files to split from the training data for
76
+ validation. Only used if `val_data` is None.
77
+
78
+
79
+ Attributes
80
+ ----------
81
+ config : NGDataConfig
82
+ Pydantic model for CAREamics data configuration.
83
+ data_type : str
84
+ Type of data, one of SupportedData.
85
+ batch_size : int
86
+ Batch size for the dataloaders.
87
+ extension_filter : str
88
+ Filter for file extensions, by default "".
89
+ read_source_func : Optional[Callable], default=None
90
+ Function to read the source data.
91
+ read_kwargs : Optional[dict[str, Any]], default=None
92
+ The kwargs for the read source function.
93
+ val_percentage : Optional[float]
94
+ Percentage of the training data to use for validation.
95
+ val_minimum_split : int, default=5
96
+ Minimum number of patches or files to split from the training data for
97
+ validation.
98
+ train_data : Optional[Any]
99
+ Training data, can be a path to a folder, a list of paths, or a numpy array.
100
+ train_data_target : Optional[Any]
101
+ Training data target, can be a path to a folder, a list of paths, or a numpy
102
+ array.
103
+ train_data_mask : Optional[Any]
104
+ Training data mask, can be a path to a folder, a list of paths, or a numpy
105
+ array.
106
+ val_data : Optional[Any]
107
+ Validation data, can be a path to a folder, a list of paths, or a numpy array.
108
+ val_data_target : Optional[Any]
109
+ Validation data target, can be a path to a folder, a list of paths, or a numpy
110
+ array.
111
+ pred_data : Optional[Any]
112
+ Prediction data, can be a path to a folder, a list of paths, or a numpy array.
113
+ pred_data_target : Optional[Any]
114
+ Prediction data target, can be a path to a folder, a list of paths, or a numpy
115
+ array.
116
+
117
+ Raises
118
+ ------
119
+ ValueError
120
+ If none of train_data, val_data or pred_data is provided.
121
+ ValueError
122
+ If input and target data types are not consistent.
123
+ """
124
+
125
+ # standard use (no mask)
126
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: InputType | None = None,
    train_data_target: InputType | None = None,
    val_data: InputType | None = None,
    val_data_target: InputType | None = None,
    pred_data: InputType | None = None,
    pred_data_target: InputType | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None:
    """Typing overload: no training mask, no custom reader, no stack loader."""
    ...
141
+
142
+ # with training mask for filtering
143
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: InputType | None = None,
    train_data_target: InputType | None = None,
    train_data_mask: InputType,
    val_data: InputType | None = None,
    val_data_target: InputType | None = None,
    pred_data: InputType | None = None,
    pred_data_target: InputType | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None:
    """Typing overload: a training mask is given, no custom reader or loader."""
    ...
159
+
160
+ # custom read function (no mask)
161
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: InputType | None = None,
    train_data_target: InputType | None = None,
    val_data: InputType | None = None,
    val_data_target: InputType | None = None,
    pred_data: InputType | None = None,
    pred_data_target: InputType | None = None,
    read_source_func: Callable,
    read_kwargs: dict[str, Any] | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None:
    """Typing overload: a custom read function is given, no training mask."""
    ...
178
+
179
+ # custom read function with training mask
180
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: InputType | None = None,
    train_data_target: InputType | None = None,
    train_data_mask: InputType,
    val_data: InputType | None = None,
    val_data_target: InputType | None = None,
    pred_data: InputType | None = None,
    pred_data_target: InputType | None = None,
    read_source_func: Callable,
    read_kwargs: dict[str, Any] | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None:
    """Typing overload: a custom read function and a training mask are given."""
    ...
198
+
199
+ # image stack loader (no mask)
200
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: Any | None = None,
    train_data_target: Any | None = None,
    val_data: Any | None = None,
    val_data_target: Any | None = None,
    pred_data: Any | None = None,
    pred_data_target: Any | None = None,
    image_stack_loader: ImageStackLoader,
    image_stack_loader_kwargs: dict[str, Any] | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None:
    """Typing overload: an image stack loader is given, no training mask."""
    ...
217
+
218
+ # image stack loader with training mask
219
@overload
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: Any | None = None,
    train_data_target: Any | None = None,
    train_data_mask: Any,
    val_data: Any | None = None,
    val_data_target: Any | None = None,
    pred_data: Any | None = None,
    pred_data_target: Any | None = None,
    image_stack_loader: ImageStackLoader,
    image_stack_loader_kwargs: dict[str, Any] | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None:
    """Typing overload: an image stack loader and a training mask are given."""
    ...
237
+
238
def __init__(
    self,
    data_config: NGDataConfig,
    *,
    train_data: Any | None = None,
    train_data_target: Any | None = None,
    train_data_mask: Any | None = None,
    val_data: Any | None = None,
    val_data_target: Any | None = None,
    pred_data: Any | None = None,
    pred_data_target: Any | None = None,
    read_source_func: Callable | None = None,
    read_kwargs: dict[str, Any] | None = None,
    image_stack_loader: ImageStackLoader | None = None,
    image_stack_loader_kwargs: dict[str, Any] | None = None,
    extension_filter: str = "",
    val_percentage: float | None = None,
    val_minimum_split: int = 5,
) -> None:
    """
    Data module for Careamics dataset initialization.

    Create a lightning datamodule that handles creating datasets for training,
    validation, and prediction.

    Parameters
    ----------
    data_config : NGDataConfig
        Pydantic model for CAREamics data configuration.
    train_data : Optional[InputType]
        Training data, can be a path to a folder, a list of paths, or a numpy array.
    train_data_target : Optional[InputType]
        Training data target, can be a path to a folder,
        a list of paths, or a numpy array.
    train_data_mask : InputType (when filtering is needed)
        Training data mask, can be a path to a folder,
        a list of paths, or a numpy array. Used for coordinate filtering.
        Only required when using coordinate-based patch filtering.
    val_data : Optional[InputType]
        Validation data, can be a path to a folder,
        a list of paths, or a numpy array.
    val_data_target : Optional[InputType]
        Validation data target, can be a path to a folder,
        a list of paths, or a numpy array.
    pred_data : Optional[InputType]
        Prediction data, can be a path to a folder, a list of paths,
        or a numpy array.
    pred_data_target : Optional[InputType]
        Prediction data target, can be a path to a folder,
        a list of paths, or a numpy array.
    read_source_func : Optional[Callable]
        Function to read the source data, by default None. Only used for `custom`
        data type (see DataModel).
    read_kwargs : Optional[dict[str, Any]]
        The kwargs for the read source function.
    image_stack_loader : Optional[ImageStackLoader]
        The image stack loader.
    image_stack_loader_kwargs : Optional[dict[str, Any]]
        The image stack loader kwargs.
    extension_filter : str
        Filter for file extensions, by default "". Only used for `custom` data types
        (see DataModel).
    val_percentage : Optional[float]
        Percentage of the training data to use for validation. Only
        used if `val_data` is None. Not implemented yet: any value other than
        None raises `NotImplementedError`.
    val_minimum_split : int
        Minimum number of patches or files to split from the training data for
        validation, by default 5. Only used if `val_data` is None.

    Raises
    ------
    ValueError
        If none of `train_data`, `val_data` or `pred_data` is provided.
    ValueError
        If exactly one of `train_data` and `val_data` is provided.
    NotImplementedError
        If `val_percentage` is not None.
    """
    super().__init__()

    if train_data is None and val_data is None and pred_data is None:
        raise ValueError(
            "At least one of train_data, val_data or pred_data must be provided."
        )
    # The parentheses are required: without them Python chains the comparison as
    # `a is None and None != b and b is None`, which can never be true.
    elif (train_data is None) != (val_data is None):
        raise ValueError(
            "If one of train_data or val_data is provided, both must be provided."
        )

    self.config: NGDataConfig = data_config
    self.data_type: str = data_config.data_type
    self.batch_size: int = data_config.batch_size

    self.extension_filter: str = (
        extension_filter  # list_files pulls the correct ext
    )
    self.read_source_func = read_source_func
    self.read_kwargs = read_kwargs
    self.image_stack_loader = image_stack_loader
    self.image_stack_loader_kwargs = image_stack_loader_kwargs

    # TODO: implement the validation split logic
    self.val_percentage = val_percentage
    self.val_minimum_split = val_minimum_split
    if self.val_percentage is not None:
        raise NotImplementedError("Validation split is not implemented.")

    # Flag passed to `initialize_data_pair` when an image stack loader is supplied.
    custom_loader = self.image_stack_loader is not None
    self.train_data, self.train_data_target = initialize_data_pair(
        self.data_type,
        train_data,
        train_data_target,
        extension_filter,
        custom_loader,
    )
    # The mask has no target of its own; the second returned value is discarded.
    self.train_data_mask, _ = initialize_data_pair(
        self.data_type, train_data_mask, None, extension_filter, custom_loader
    )

    self.val_data, self.val_data_target = initialize_data_pair(
        self.data_type, val_data, val_data_target, extension_filter, custom_loader
    )

    # The pred_data_target can be needed to count metrics on the prediction
    self.pred_data, self.pred_data_target = initialize_data_pair(
        self.data_type, pred_data, pred_data_target, extension_filter, custom_loader
    )
356
+
357
+ def setup(self, stage: str) -> None:
358
+ """
359
+ Setup datasets.
360
+
361
+ Lightning hook that is called at the beginning of fit (train + validate),
362
+ validate, test, or predict. Creates the datasets for a given stage.
363
+
364
+ Parameters
365
+ ----------
366
+ stage : str
367
+ The stage to set up datasets for.
368
+ Is either 'fit', 'validate', 'test', or 'predict'.
369
+
370
+ Raises
371
+ ------
372
+ NotImplementedError
373
+ If stage is not one of "fit", "validate" or "predict".
374
+ """
375
+ if stage == "fit":
376
+ if self.config.mode != "training":
377
+ raise ValueError(
378
+ f"CAREamicsDataModule configured for {self.config.mode} cannot be "
379
+ f"used for training. Please create a new CareamicsDataModule with "
380
+ f"a configuration with mode='training'."
381
+ )
382
+
383
+ self.train_dataset = create_dataset(
384
+ config=self.config,
385
+ inputs=self.train_data,
386
+ targets=self.train_data_target,
387
+ masks=self.train_data_mask,
388
+ read_func=self.read_source_func,
389
+ read_kwargs=self.read_kwargs,
390
+ image_stack_loader=self.image_stack_loader,
391
+ image_stack_loader_kwargs=self.image_stack_loader_kwargs,
392
+ )
393
+ # TODO: ugly, need to find a better solution
394
+ self.stats = self.train_dataset.input_stats
395
+ self.config.set_means_and_stds(
396
+ self.train_dataset.input_stats.means,
397
+ self.train_dataset.input_stats.stds,
398
+ self.train_dataset.target_stats.means,
399
+ self.train_dataset.target_stats.stds,
400
+ )
401
+
402
+ validation_config = self.config.convert_mode("validating")
403
+ self.val_dataset = create_dataset(
404
+ config=validation_config,
405
+ inputs=self.val_data,
406
+ targets=self.val_data_target,
407
+ read_func=self.read_source_func,
408
+ read_kwargs=self.read_kwargs,
409
+ image_stack_loader=self.image_stack_loader,
410
+ image_stack_loader_kwargs=self.image_stack_loader_kwargs,
411
+ )
412
+ elif stage == "validate":
413
+ validation_config = self.config.convert_mode("validating")
414
+ self.val_dataset = create_dataset(
415
+ config=validation_config,
416
+ inputs=self.val_data,
417
+ targets=self.val_data_target,
418
+ read_func=self.read_source_func,
419
+ read_kwargs=self.read_kwargs,
420
+ image_stack_loader=self.image_stack_loader,
421
+ image_stack_loader_kwargs=self.image_stack_loader_kwargs,
422
+ )
423
+ self.stats = self.val_dataset.input_stats
424
+ elif stage == "predict":
425
+ if self.config.mode == "validating":
426
+ raise ValueError(
427
+ "CAREamicsDataModule configured for validating cannot be used for "
428
+ "prediction. Please create a new CareamicsDataModule with a "
429
+ "configuration with mode='predicting'."
430
+ )
431
+
432
+ self.predict_dataset = create_dataset(
433
+ config=(
434
+ self.config.convert_mode("predicting")
435
+ if self.config.mode == "training"
436
+ else self.config
437
+ ),
438
+ inputs=self.pred_data,
439
+ targets=self.pred_data_target,
440
+ read_func=self.read_source_func,
441
+ read_kwargs=self.read_kwargs,
442
+ image_stack_loader=self.image_stack_loader,
443
+ image_stack_loader_kwargs=self.image_stack_loader_kwargs,
444
+ )
445
+ self.stats = self.predict_dataset.input_stats
446
+ else:
447
+ raise NotImplementedError(f"Stage {stage} not implemented")
448
+
449
+ def _sampler(self, dataset: Literal["train", "val", "predict"]) -> Sampler | None:
450
+ sampler: GroupedIndexSampler | None
451
+ rng = np.random.default_rng(self.config.seed)
452
+ if not self.config.in_memory and self.config.data_type == SupportedData.TIFF:
453
+ match dataset:
454
+ case "train":
455
+ ds = self.train_dataset
456
+ case "val":
457
+ ds = self.val_dataset
458
+ case "predict":
459
+ ds = self.predict_dataset
460
+ case _:
461
+ raise (
462
+ f"Unrecognized dataset '{dataset}', should be one of 'train', "
463
+ "'val' or 'predict'."
464
+ )
465
+ sampler = GroupedIndexSampler.from_dataset(ds, rng=rng)
466
+ else:
467
+ sampler = None
468
+ return sampler
469
+
470
def train_dataloader(self) -> DataLoader:
    """
    Build the dataloader used during training.

    Returns
    -------
    DataLoader
        Dataloader over the training dataset.
    """
    train_sampler = self._sampler("train")
    # Work on a copy so that the parameters held by the configuration are left
    # untouched.
    loader_kwargs = copy.deepcopy(self.config.train_dataloader_params)
    if train_sampler is not None:
        # torch raises "ValueError: sampler option is mutually exclusive with
        # shuffle" when both are given, hence shuffle is dropped.
        # TODO: there might be other parameters mutually exclusive with sampler
        loader_kwargs.pop("shuffle", None)

    return DataLoader(
        self.train_dataset,
        sampler=train_sampler,
        collate_fn=default_collate,
        batch_size=self.batch_size,
        **loader_kwargs,
    )
493
+
494
def val_dataloader(self) -> DataLoader:
    """
    Build the dataloader used during validation.

    Returns
    -------
    DataLoader
        Dataloader over the validation dataset.
    """
    val_sampler = self._sampler("val")
    # Work on a copy so that the parameters held by the configuration are left
    # untouched.
    loader_kwargs = copy.deepcopy(self.config.val_dataloader_params)
    if val_sampler is not None:
        # `shuffle` cannot be passed to the dataloader together with a sampler.
        loader_kwargs.pop("shuffle", None)

    return DataLoader(
        self.val_dataset,
        sampler=val_sampler,
        collate_fn=default_collate,
        batch_size=self.batch_size,
        **loader_kwargs,
    )
514
+
515
def predict_dataloader(self) -> DataLoader:
    """
    Build the dataloader used during prediction.

    Returns
    -------
    DataLoader
        Dataloader over the prediction dataset.
    """
    # Extra dataloader parameters come straight from the configuration.
    loader_kwargs = self.config.pred_dataloader_params
    prediction_loader = DataLoader(
        self.predict_dataset,
        collate_fn=default_collate,
        batch_size=self.batch_size,
        **loader_kwargs,
    )
    return prediction_loader