careamics-0.0.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/lightning/train_data_module.py
@@ -0,0 +1,666 @@
+ """Training and validation Lightning data modules."""
+
+ from collections.abc import Callable
+ from pathlib import Path
+ from typing import Any, Literal, Union
+
+ import numpy as np
+ import pytorch_lightning as L
+ from numpy.typing import NDArray
+ from torch.utils.data import DataLoader, IterableDataset
+
+ from careamics.config.data import DataConfig
+ from careamics.config.support import SupportedData
+ from careamics.config.transformations import TransformConfig
+ from careamics.dataset.dataset_utils import (
+     get_files_size,
+     list_files,
+     validate_source_target_files,
+ )
+ from careamics.dataset.in_memory_dataset import (
+     InMemoryDataset,
+ )
+ from careamics.dataset.iterable_dataset import (
+     PathIterableDataset,
+ )
+ from careamics.file_io.read import get_read_func
+ from careamics.utils import get_logger, get_ram_size
+
+ DatasetType = Union[InMemoryDataset, PathIterableDataset]
+
+ logger = get_logger(__name__)
+
+
+ class TrainDataModule(L.LightningDataModule):
+     """
+     CAREamics Lightning training and validation data module.
+
+     The data module can be used with Path, str or numpy arrays. In the case of
+     numpy arrays, it loads and computes all the patches in memory. For Path and str
+     inputs, it calculates the total file size and estimates whether it can fit in
+     memory. If it does not, it iterates through the files. This behaviour can be
+     deactivated by setting `use_in_memory` to False, in which case it will
+     always use the iterating dataset to train on a Path or str.
+
+     The data can be either a folder containing images or a single file.
+
+     Validation can be omitted, in which case the validation data is extracted from
+     the training data. The percentage of the training data to use for validation,
+     as well as the minimum number of patches or files to split from the training
+     data, can be set using `val_percentage` and `val_minimum_split`, respectively.
+
+     To read custom data types, you can set `data_type` to `custom` in `data_config`
+     and provide a function that returns a numpy array from a path as the
+     `read_source_func` parameter. The function will receive a Path object and
+     an axes string as arguments, the axes being derived from the `data_config`.
+
+     You can also provide a `fnmatch`- and `Path.rglob`-compatible expression (e.g.
+     "*.czi") to filter the file extensions using `extension_filter`.
+
+     Parameters
+     ----------
+     data_config : DataConfig
+         Pydantic model for CAREamics data configuration.
+     train_data : pathlib.Path or str or numpy.ndarray
+         Training data, can be a path to a folder, a file or a numpy array.
+     val_data : pathlib.Path or str or numpy.ndarray, optional
+         Validation data, can be a path to a folder, a file or a numpy array, by
+         default None.
+     train_data_target : pathlib.Path or str or numpy.ndarray, optional
+         Training target data, can be a path to a folder, a file or a numpy array, by
+         default None.
+     val_data_target : pathlib.Path or str or numpy.ndarray, optional
+         Validation target data, can be a path to a folder, a file or a numpy array,
+         by default None.
+     read_source_func : Callable, optional
+         Function to read the source data, by default None. Only used for the
+         `custom` data type (see DataConfig).
+     extension_filter : str, optional
+         Filter for file extensions, by default "". Only used for `custom` data types
+         (see DataConfig).
+     val_percentage : float, optional
+         Percentage of the training data to use for validation, by default 0.1. Only
+         used if `val_data` is None.
+     val_minimum_split : int, optional
+         Minimum number of patches or files to split from the training data for
+         validation, by default 5. Only used if `val_data` is None.
+     use_in_memory : bool, optional
+         Use in-memory dataset if possible, by default True.
+
+     Attributes
+     ----------
+     data_config : DataConfig
+         CAREamics data configuration.
+     data_type : SupportedData
+         Expected data type, one of "tiff", "array" or "custom".
+     batch_size : int
+         Batch size.
+     use_in_memory : bool
+         Whether to use in-memory dataset if possible.
+     train_data : pathlib.Path or numpy.ndarray
+         Training data.
+     val_data : pathlib.Path or numpy.ndarray
+         Validation data.
+     train_data_target : pathlib.Path or numpy.ndarray
+         Training target data.
+     val_data_target : pathlib.Path or numpy.ndarray
+         Validation target data.
+     val_percentage : float
+         Percentage of the training data to use for validation, if no validation
+         data is provided.
+     val_minimum_split : int
+         Minimum number of patches or files to split from the training data for
+         validation, if no validation data is provided.
+     read_source_func : Optional[Callable]
+         Function to read the source data, used if `data_type` is `custom`.
+     extension_filter : str
+         Filter for file extensions, used if `data_type` is `custom`.
+     """
+
+     def __init__(
+         self,
+         data_config: DataConfig,
+         train_data: Union[Path, str, NDArray],
+         val_data: Union[Path, str, NDArray] | None = None,
+         train_data_target: Union[Path, str, NDArray] | None = None,
+         val_data_target: Union[Path, str, NDArray] | None = None,
+         read_source_func: Callable | None = None,
+         extension_filter: str = "",
+         val_percentage: float = 0.1,
+         val_minimum_split: int = 5,
+         use_in_memory: bool = True,
+     ) -> None:
+         """
+         Constructor.
+
+         Parameters
+         ----------
+         data_config : DataConfig
+             Pydantic model for CAREamics data configuration.
+         train_data : pathlib.Path or str or numpy.ndarray
+             Training data, can be a path to a folder, a file or a numpy array.
+         val_data : pathlib.Path or str or numpy.ndarray, optional
+             Validation data, can be a path to a folder, a file or a numpy array, by
+             default None.
+         train_data_target : pathlib.Path or str or numpy.ndarray, optional
+             Training target data, can be a path to a folder, a file or a numpy
+             array, by default None.
+         val_data_target : pathlib.Path or str or numpy.ndarray, optional
+             Validation target data, can be a path to a folder, a file or a numpy
+             array, by default None.
+         read_source_func : Callable, optional
+             Function to read the source data, by default None. Only used for the
+             `custom` data type (see DataConfig).
+         extension_filter : str, optional
+             Filter for file extensions, by default "". Only used for `custom` data
+             types (see DataConfig).
+         val_percentage : float, optional
+             Percentage of the training data to use for validation, by default 0.1.
+             Only used if `val_data` is None.
+         val_minimum_split : int, optional
+             Minimum number of patches or files to split from the training data for
+             validation, by default 5. Only used if `val_data` is None.
+         use_in_memory : bool, optional
+             Use in-memory dataset if possible, by default True.
+
+         Raises
+         ------
+         ValueError
+             If the input types are mixed (e.g. Path and numpy.ndarray).
+         ValueError
+             If the data type is `custom` and no `read_source_func` is provided.
+         ValueError
+             If the data type is `array` and the input is not a numpy array.
+         ValueError
+             If the data type is `tiff` and the input is neither a Path nor a str.
+         """
+         super().__init__()
+
+         # check input types coherence (no mixed types)
+         inputs = [train_data, val_data, train_data_target, val_data_target]
+         types_set = {type(i) for i in inputs}
+         if len(types_set) > 2:  # None + expected type
+             raise ValueError(
+                 f"Inputs for `train_data`, `val_data`, `train_data_target` and "
+                 f"`val_data_target` must be of the same type or None. Got "
+                 f"{types_set}."
+             )
+
+         # check that a read source function is provided for custom types
+         if data_config.data_type == SupportedData.CUSTOM and read_source_func is None:
+             raise ValueError(
+                 f"Data type {SupportedData.CUSTOM} is not allowed without "
+                 f"specifying a `read_source_func` and an `extension_filter`."
+             )
+
+         # check correct input type
+         if (
+             isinstance(train_data, np.ndarray)
+             and data_config.data_type != SupportedData.ARRAY
+         ):
+             raise ValueError(
+                 f"Received a numpy array as input, but the data type was set to "
+                 f"{data_config.data_type}. Set the data type in the configuration "
+                 f"to {SupportedData.ARRAY} to train on numpy arrays."
+             )
+
+         # and that Path or str are passed, if tiff file type specified
+         elif (isinstance(train_data, Path) or isinstance(train_data, str)) and (
+             data_config.data_type != SupportedData.TIFF
+             and data_config.data_type != SupportedData.CUSTOM
+         ):
+             raise ValueError(
+                 f"Received a path as input, but the data type was neither set to "
+                 f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
+                 f"in the configuration to {SupportedData.TIFF} or "
+                 f"{SupportedData.CUSTOM} to train on files."
+             )
+
+         # configuration
+         self.data_config: DataConfig = data_config
+         self.data_type: str = data_config.data_type
+         self.batch_size: int = data_config.batch_size
+         self.use_in_memory: bool = use_in_memory
+
+         # data: make data Path or np.ndarray, use type annotations for mypy
+         self.train_data: Union[Path, NDArray] = (
+             Path(train_data) if isinstance(train_data, str) else train_data
+         )
+
+         self.val_data: Union[Path, NDArray] = (
+             Path(val_data) if isinstance(val_data, str) else val_data
+         )
+
+         self.train_data_target: Union[Path, NDArray] = (
+             Path(train_data_target)
+             if isinstance(train_data_target, str)
+             else train_data_target
+         )
+
+         self.val_data_target: Union[Path, NDArray] = (
+             Path(val_data_target)
+             if isinstance(val_data_target, str)
+             else val_data_target
+         )
+
+         # validation split
+         self.val_percentage = val_percentage
+         self.val_minimum_split = val_minimum_split
+
+         # read source function corresponding to the requested type
+         if data_config.data_type == SupportedData.CUSTOM.value:
+             # mypy check
+             assert read_source_func is not None
+
+             self.read_source_func: Callable = read_source_func
+
+         elif data_config.data_type != SupportedData.ARRAY:
+             self.read_source_func = get_read_func(data_config.data_type)
+
+         self.extension_filter: str = extension_filter
+
+     def prepare_data(self) -> None:
+         """
+         Hook used to prepare the data before calling `setup`.
+
+         Here, we only need to examine the data if it was provided as a str or a Path.
+
+         TODO: from lightning doc:
+         prepare_data is called from the main process. It is not recommended to assign
+         state here (e.g. self.x = y) since it is called on a single process and if you
+         assign states here then they won't be available for other processes.
+
+         https://lightning.ai/docs/pytorch/stable/data/datamodule.html
+         """
+         # if the data is a Path or a str
+         if (
+             not isinstance(self.train_data, np.ndarray)
+             and not isinstance(self.val_data, np.ndarray)
+             and not isinstance(self.train_data_target, np.ndarray)
+             and not isinstance(self.val_data_target, np.ndarray)
+         ):
+             # list training files
+             self.train_files = list_files(
+                 self.train_data, self.data_type, self.extension_filter
+             )
+             self.train_files_size = get_files_size(self.train_files)
+
+             # list validation files
+             if self.val_data is not None:
+                 self.val_files = list_files(
+                     self.val_data, self.data_type, self.extension_filter
+                 )
+
+             # same for target data
+             if self.train_data_target is not None:
+                 self.train_target_files: list[Path] = list_files(
+                     self.train_data_target, self.data_type, self.extension_filter
+                 )
+
+                 # verify that they match the training data
+                 validate_source_target_files(self.train_files, self.train_target_files)
+
+             if self.val_data_target is not None:
+                 self.val_target_files = list_files(
+                     self.val_data_target, self.data_type, self.extension_filter
+                 )
+
+                 # verify that they match the validation data
+                 validate_source_target_files(self.val_files, self.val_target_files)
+
+     def setup(self, *args: Any, **kwargs: Any) -> None:
+         """Hook called at the beginning of fit, validate, or predict.
+
+         Parameters
+         ----------
+         *args : Any
+             Unused.
+         **kwargs : Any
+             Unused.
+         """
+         # if numpy array
+         if self.data_type == SupportedData.ARRAY:
+             # mypy checks
+             assert isinstance(self.train_data, np.ndarray)
+             if self.train_data_target is not None:
+                 assert isinstance(self.train_data_target, np.ndarray)
+
+             # train dataset
+             self.train_dataset: DatasetType = InMemoryDataset(
+                 data_config=self.data_config,
+                 inputs=self.train_data,
+                 input_target=self.train_data_target,
+             )
+
+             # validation dataset
+             if self.val_data is not None:
+                 # mypy checks
+                 assert isinstance(self.val_data, np.ndarray)
+                 if self.val_data_target is not None:
+                     assert isinstance(self.val_data_target, np.ndarray)
+
+                 # create its own dataset
+                 self.val_dataset: DatasetType = InMemoryDataset(
+                     data_config=self.data_config,
+                     inputs=self.val_data,
+                     input_target=self.val_data_target,
+                 )
+             else:
+                 # extract validation from the training patches
+                 self.val_dataset = self.train_dataset.split_dataset(
+                     percentage=self.val_percentage,
+                     minimum_patches=self.val_minimum_split,
+                 )
+
+         # else we read files
+         else:
+             # Heuristics, if the file size is smaller than 80% of the RAM,
+             # we run the training in memory, otherwise we switch to iterable dataset
+             # The switch is deactivated if use_in_memory is False
+             if self.use_in_memory and self.train_files_size < get_ram_size() * 0.8:
+                 # train dataset
+                 self.train_dataset = InMemoryDataset(
+                     data_config=self.data_config,
+                     inputs=self.train_files,
+                     input_target=(
+                         self.train_target_files if self.train_data_target else None
+                     ),
+                     read_source_func=self.read_source_func,
+                 )
+
+                 # validation dataset
+                 if self.val_data is not None:
+                     self.val_dataset = InMemoryDataset(
+                         data_config=self.data_config,
+                         inputs=self.val_files,
+                         input_target=(
+                             self.val_target_files if self.val_data_target else None
+                         ),
+                         read_source_func=self.read_source_func,
+                     )
+                 else:
+                     # split dataset
+                     self.val_dataset = self.train_dataset.split_dataset(
+                         percentage=self.val_percentage,
+                         minimum_patches=self.val_minimum_split,
+                     )
+
+             # else if the data is too large, load file by file during training
+             else:
+                 # create training dataset
+                 self.train_dataset = PathIterableDataset(
+                     data_config=self.data_config,
+                     src_files=self.train_files,
+                     target_files=(
+                         self.train_target_files if self.train_data_target else None
+                     ),
+                     read_source_func=self.read_source_func,
+                 )
+
+                 # create validation dataset
+                 if self.val_data is not None:
+                     # create its own dataset
+                     self.val_dataset = PathIterableDataset(
+                         data_config=self.data_config,
+                         src_files=self.val_files,
+                         target_files=(
+                             self.val_target_files if self.val_data_target else None
+                         ),
+                         read_source_func=self.read_source_func,
+                     )
+                 elif len(self.train_files) <= self.val_minimum_split:
+                     raise ValueError(
+                         f"Not enough files to split a minimum of "
+                         f"{self.val_minimum_split} files, got {len(self.train_files)} "
+                         f"files."
+                     )
+                 else:
+                     # extract validation from the training patches
+                     self.val_dataset = self.train_dataset.split_dataset(
+                         percentage=self.val_percentage,
+                         minimum_number=self.val_minimum_split,
+                     )
+
+     def get_data_statistics(self) -> tuple[list[float], list[float]]:
+         """Return training data statistics.
+
+         Returns
+         -------
+         tuple of list
+             Means and standard deviations across channels of the training data.
+         """
+         return self.train_dataset.get_data_statistics()
+
+     def train_dataloader(self) -> Any:
+         """
+         Create a dataloader for training.
+
+         Returns
+         -------
+         Any
+             Training dataloader.
+         """
+         train_dataloader_params = self.data_config.train_dataloader_params.copy()
+
+         # NOTE: When next-gen datasets are completed this can be removed
+         # iterable dataset cannot be shuffled
+         if isinstance(self.train_dataset, IterableDataset):
+             del train_dataloader_params["shuffle"]
+
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             **train_dataloader_params,
+         )
+
+     def val_dataloader(self) -> Any:
+         """
+         Create a dataloader for validation.
+
+         Returns
+         -------
+         Any
+             Validation dataloader.
+         """
+         return DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             **self.data_config.val_dataloader_params,
+         )
+
+
+ def create_train_datamodule(
+     train_data: Union[str, Path, NDArray],
+     data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
+     patch_size: list[int],
+     axes: str,
+     batch_size: int,
+     val_data: Union[str, Path, NDArray] | None = None,
+     transforms: list[TransformConfig] | None = None,
+     train_target_data: Union[str, Path, NDArray] | None = None,
+     val_target_data: Union[str, Path, NDArray] | None = None,
+     read_source_func: Callable | None = None,
+     extension_filter: str = "",
+     val_percentage: float = 0.1,
+     val_minimum_patches: int = 5,
+     train_dataloader_params: dict | None = None,
+     val_dataloader_params: dict | None = None,
+     use_in_memory: bool = True,
+ ) -> TrainDataModule:
+     """Create a TrainDataModule.
+
+     This function is used to explicitly pass the parameters usually contained in a
+     `GenericDataConfig` to a TrainDataModule.
+
+     Since the lightning datamodule has no access to the model, make sure that the
+     parameters passed to the datamodule are consistent with the model's requirements.
+
+     The default augmentations are XY flip and XY rotation. To use a different set of
+     transformations, you can pass a list of transforms to `transforms`.
+
+     The data module can be used with Path, str or numpy arrays. In the case of
+     numpy arrays, it loads and computes all the patches in memory. For Path and str
+     inputs, it calculates the total file size and estimates whether it can fit in
+     memory. If it does not, it iterates through the files. This behaviour can be
+     deactivated by setting `use_in_memory` to False, in which case it will
+     always use the iterating dataset to train on a Path or str.
+
+     To use array data, set `data_type` to `array` and pass a numpy array to
+     `train_data`.
+
+     By default, CAREamics only supports types defined in
+     `careamics.config.support.SupportedData`. To read custom data types, you can set
+     `data_type` to `custom` and provide a function that returns a numpy array from a
+     path. Additionally, pass a `fnmatch`- and `Path.rglob`-compatible expression
+     (e.g. "*.jpeg") to filter the file extensions using `extension_filter`.
+
+     In the absence of validation data, the validation data is extracted from the
+     training data. The percentage of the training data to use for validation, as
+     well as the minimum number of patches to split from the training data for
+     validation, can be set using `val_percentage` and `val_minimum_patches`,
+     respectively.
+
+     In `dataloader_params`, you can pass any parameter accepted by PyTorch
+     dataloaders, except for `batch_size`, which is set by the `batch_size` parameter.
+
+     Parameters
+     ----------
+     train_data : pathlib.Path or str or numpy.ndarray
+         Training data.
+     data_type : {"array", "tiff", "custom"}
+         Data type, see `SupportedData` for available options.
+     patch_size : list of int
+         Patch size, 2D or 3D.
+     axes : str
+         Axes of the data, chosen amongst SCZYX.
+     batch_size : int
+         Batch size.
+     val_data : pathlib.Path or str or numpy.ndarray, optional
+         Validation data, by default None.
+     transforms : list of Transforms, optional
+         List of transforms to apply to training patches. If None, default transforms
+         are applied.
+     train_target_data : pathlib.Path or str or numpy.ndarray, optional
+         Training target data, by default None.
+     val_target_data : pathlib.Path or str or numpy.ndarray, optional
+         Validation target data, by default None.
+     read_source_func : Callable, optional
+         Function to read the source data, used if `data_type` is `custom`, by
+         default None.
+     extension_filter : str, optional
+         Filter for file extensions, used if `data_type` is `custom`, by default "".
+     val_percentage : float, optional
+         Percentage of the training data to use for validation if no validation data
+         is given, by default 0.1.
+     val_minimum_patches : int, optional
+         Minimum number of patches to split from the training data for validation if
+         no validation data is given, by default 5.
+     train_dataloader_params : dict, optional
+         PyTorch dataloader parameters for the training data, by default None.
+     val_dataloader_params : dict, optional
+         PyTorch dataloader parameters for the validation data, by default None.
+     use_in_memory : bool, optional
+         Use in-memory dataset if possible, by default True.
+
+     Returns
+     -------
+     TrainDataModule
+         CAREamics training Lightning data module.
+
+     Examples
+     --------
+     Create a TrainDataModule with default transforms from a numpy array:
+     >>> import numpy as np
+     >>> from careamics.lightning import create_train_datamodule
+     >>> my_array = np.arange(256).reshape(16, 16)
+     >>> data_module = create_train_datamodule(
+     ...     train_data=my_array,
+     ...     data_type="array",
+     ...     patch_size=(8, 8),
+     ...     axes='YX',
+     ...     batch_size=2,
+     ... )
+
+     For custom data types (those not supported by CAREamics), one can pass a read
+     function and a filter for the file extensions:
+     >>> import numpy as np
+     >>> from careamics.lightning import create_train_datamodule
+     >>>
+     >>> def read_npy(path):
+     ...     return np.load(path)
+     >>>
+     >>> data_module = create_train_datamodule(
+     ...     train_data="path/to/data",
+     ...     data_type="custom",
+     ...     patch_size=(8, 8),
+     ...     axes='YX',
+     ...     batch_size=2,
+     ...     read_source_func=read_npy,
+     ...     extension_filter="*.npy",
+     ... )
+
+     If you want to use a different set of transformations, you can pass a list of
+     transforms:
+     >>> import numpy as np
+     >>> from careamics.lightning import create_train_datamodule
+     >>> from careamics.config.transformations import XYFlipConfig
+     >>> from careamics.config.support import SupportedTransform
+     >>> my_array = np.arange(256).reshape(16, 16)
+     >>> my_transforms = [
+     ...     XYFlipConfig(flip_y=False),
+     ... ]
+     >>> data_module = create_train_datamodule(
+     ...     train_data=my_array,
+     ...     data_type="array",
+     ...     patch_size=(8, 8),
+     ...     axes='YX',
+     ...     batch_size=2,
+     ...     transforms=my_transforms,
+     ... )
+     """
+     if train_dataloader_params is None:
+         train_dataloader_params = {"shuffle": True}
+
+     if val_dataloader_params is None:
+         val_dataloader_params = {"shuffle": False}
+
+     data_dict: dict[str, Any] = {
+         "mode": "train",
+         "data_type": data_type,
+         "patch_size": patch_size,
+         "axes": axes,
+         "batch_size": batch_size,
+         "train_dataloader_params": train_dataloader_params,
+         "val_dataloader_params": val_dataloader_params,
+     }
+
+     # if transforms are passed (otherwise it will use the default ones)
+     if transforms is not None:
+         data_dict["transforms"] = transforms
+
+     # instantiate data configuration
+     data_config = DataConfig(**data_dict)
+
+     # sanity check on the dataloader parameters
+     if "batch_size" in train_dataloader_params:
+         # remove it
+         del train_dataloader_params["batch_size"]
+
+     if "batch_size" in val_dataloader_params:
+         # remove it
+         del val_dataloader_params["batch_size"]
+
+     return TrainDataModule(
+         data_config=data_config,
+         train_data=train_data,
+         val_data=val_data,
+         train_data_target=train_target_data,
+         val_data_target=val_target_data,
+         read_source_func=read_source_func,
+         extension_filter=extension_filter,
+         val_percentage=val_percentage,
+         val_minimum_split=val_minimum_patches,
+         use_in_memory=use_in_memory,
+     )
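The in-memory versus iterable heuristic described in `setup` above can also be exercised by constructing the data module directly from a `DataConfig`. Below is a minimal sketch, not part of the diff: it assumes `TrainDataModule` is exported from `careamics.lightning` (as `create_train_datamodule` is in the examples above), the `DataConfig` fields simply mirror the `data_dict` assembled in `create_train_datamodule`, and the training folder path is hypothetical.

# Minimal sketch (assumptions noted above): drive TrainDataModule directly.
from careamics.config.data import DataConfig
from careamics.lightning import TrainDataModule

config = DataConfig(
    mode="train",                               # as set by create_train_datamodule
    data_type="tiff",
    patch_size=[64, 64],
    axes="YX",
    batch_size=8,
    train_dataloader_params={"shuffle": True},
    val_dataloader_params={"shuffle": False},
)

data_module = TrainDataModule(
    data_config=config,
    train_data="path/to/train_folder",  # hypothetical folder of TIFF files
    val_percentage=0.1,                 # split 10% of the training data
    val_minimum_split=5,                # but never fewer than 5 patches/files
    use_in_memory=True,                 # iterable dataset used if files exceed ~80% of RAM
)

With string or Path inputs, `prepare_data` lists the files and `setup` then chooses between `InMemoryDataset` and `PathIterableDataset` based on the 80%-of-RAM size check shown above.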
careamics/losses/__init__.py
@@ -0,0 +1,21 @@
+ """Losses module."""
+
+ __all__ = [
+     "denoisplit_loss",
+     "denoisplit_musplit_loss",
+     "hdn_loss",
+     "loss_factory",
+     "mae_loss",
+     "mse_loss",
+     "musplit_loss",
+     "n2v_loss",
+ ]
+
+ from .fcn.losses import mae_loss, mse_loss, n2v_loss
+ from .loss_factory import loss_factory
+ from .lvae.losses import (
+     denoisplit_loss,
+     denoisplit_musplit_loss,
+     hdn_loss,
+     musplit_loss,
+ )
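The re-exports above flatten the loss namespace. As a small illustration of the resulting import surface (names only; the function signatures live in fcn/losses.py, loss_factory.py and lvae/losses.py, which are not shown in this diff):

# Names re-exported by careamics/losses/__init__.py (signatures not shown here).
from careamics.losses import (
    loss_factory,
    mae_loss,
    mse_loss,
    n2v_loss,
)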
careamics/losses/fcn/__init__.py
@@ -0,0 +1 @@
+ """FCN losses."""