careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,27 @@
1
"""Pydantic models describing data configurations (datasets, filters, patching)."""

# Dataset configurations.
from .data_config import DataConfig
from .ng_data_config import NGDataConfig

# Patch filter configurations.
from .patch_filter import (
    MaskFilterConfig,
    MaxFilterConfig,
    MeanSTDFilterConfig,
    ShannonFilterConfig,
)

# Patching strategy configurations.
from .patching_strategies import (
    RandomPatchingConfig,
    TiledPatchingConfig,
    WholePatchingConfig,
)

# Public API of this package (kept in alphabetical order).
__all__ = [
    "DataConfig",
    "MaskFilterConfig",
    "MaxFilterConfig",
    "MeanSTDFilterConfig",
    "NGDataConfig",
    "RandomPatchingConfig",
    "ShannonFilterConfig",
    "TiledPatchingConfig",
    "WholePatchingConfig",
]
@@ -0,0 +1,472 @@
1
+ """Data configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import sys
7
+ from collections.abc import Sequence
8
+ from pprint import pformat
9
+ from typing import Annotated, Any, Literal, Self, Union
10
+ from warnings import warn
11
+
12
+ import numpy as np
13
+ from numpy.typing import NDArray
14
+ from pydantic import (
15
+ BaseModel,
16
+ ConfigDict,
17
+ Field,
18
+ PlainSerializer,
19
+ field_validator,
20
+ model_validator,
21
+ )
22
+
23
+ from ..transformations import XYFlipConfig, XYRandomRotate90Config
24
+ from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
25
+
26
+
27
def np_float_to_scientific_str(x: float) -> str:
    """Serialize a float as a string in scientific notation.

    Intended as a serializer for float fields, so that NumPy scalars such as
    ``numpy.float32`` can be stored in a Pydantic model and written out (e.g. to a
    YAML file) as plain strings.

    Parameters
    ----------
    x : float
        Value to format.

    Returns
    -------
    str
        Output of ``numpy.format_float_scientific`` for `x` with ``precision=7``.
    """
    scientific = np.format_float_scientific(x, precision=7)
    return scientific
44
+
45
+
46
# Float type whose serialized form is a scientific-notation string produced by
# `np_float_to_scientific_str` (validation is that of a plain `float`).
Float = Annotated[
    float,
    PlainSerializer(np_float_to_scientific_str, return_type=str),
]
48
+
49
+
50
class DataConfig(BaseModel):
    """Data configuration.

    If std is specified, mean must be specified as well. Note that setting the std first
    and then the mean (if they were both `None` before) will raise a validation error.
    Prefer instead `set_means_and_stds` to set both at once. Means and stds are expected
    to be lists of floats, one for each channel. For supervised tasks, the mean and std
    of the target could be different from the input data.

    All supported transforms are defined in the SupportedTransform enum.

    Examples
    --------
    Minimum example:

    >>> data = DataConfig(
    ...     data_type="array", # defined in SupportedData
    ...     patch_size=[128, 128],
    ...     batch_size=4,
    ...     axes="YX"
    ... )

    To change the image_means and image_stds of the data:

    >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])

    One can pass also a list of transformations, by keyword, using the
    SupportedTransform value:

    >>> from careamics.config.support import SupportedTransform
    >>> data = DataConfig(
    ...     data_type="tiff",
    ...     patch_size=[128, 128],
    ...     batch_size=4,
    ...     axes="YX",
    ...     transforms=[
    ...         {
    ...             "name": "XYFlip",
    ...         }
    ...     ]
    ... )
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Dataset configuration
    data_type: Literal["array", "tiff", "czi", "custom"]
    """Type of input data, numpy.ndarray (array) or paths (tiff, czi, and custom), as
    defined in SupportedData."""

    axes: str
    """Axes of the data, as defined in SupportedAxes."""

    patch_size: list[int] = Field(..., min_length=2, max_length=3)
    """Patch size, as used during training."""

    batch_size: int = Field(default=1, ge=1, validate_default=True)
    """Batch size for training."""

    # Optional fields
    image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
    """Means of the data across channels, used for normalization."""

    image_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
    """Standard deviations of the data across channels, used for normalization."""

    target_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
    """Means of the target data across channels, used for normalization."""

    target_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
    """Standard deviations of the target data across channels, used for
    normalization."""

    transforms: Sequence[Union[XYFlipConfig, XYRandomRotate90Config]] = Field(
        default=[
            XYFlipConfig(),
            XYRandomRotate90Config(),
        ],
        validate_default=True,
    )
    """List of transformations to apply to the data, available transforms are defined
    in SupportedTransform."""

    train_dataloader_params: dict[str, Any] = Field(
        default={"shuffle": True}, validate_default=True
    )
    """Dictionary of PyTorch training dataloader parameters. The dataloader parameters,
    should include the `shuffle` key, which is set to `True` by default. We strongly
    recommend to keep it as `True` to ensure the best training results."""

    val_dataloader_params: dict[str, Any] = Field(default={}, validate_default=True)
    """Dictionary of PyTorch validation dataloader parameters."""

    @field_validator("patch_size")
    @classmethod
    def all_elements_power_of_2_minimum_8(cls, patch_list: list[int]) -> list[int]:
        """
        Validate patch size.

        Patch size must be powers of 2 and minimum 8.

        Parameters
        ----------
        patch_list : list of int
            Patch size.

        Returns
        -------
        list of int
            Validated patch size.

        Raises
        ------
        ValueError
            If the patch size is smaller than 8.
        ValueError
            If the patch size is not a power of 2.
        """
        patch_size_ge_than_8_power_of_2(patch_list)

        return patch_list

    @field_validator("axes")
    @classmethod
    def axes_valid(cls, axes: str) -> str:
        """
        Validate axes.

        Axes must:
        - be a combination of 'STCZYX'
        - not contain duplicates
        - contain at least 2 contiguous axes: X and Y
        - contain at most 4 axes
        - not contain both S and T axes

        Parameters
        ----------
        axes : str
            Axes to validate.

        Returns
        -------
        str
            Validated axes.

        Raises
        ------
        ValueError
            If axes are not valid.
        """
        check_axes_validity(axes)

        return axes

    @field_validator("train_dataloader_params", "val_dataloader_params", mode="before")
    @classmethod
    def set_default_pin_memory(
        cls, dataloader_params: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Set default pin_memory for dataloader parameters if not provided.

        - If 'pin_memory' is not set, it defaults to True if CUDA is available.

        The input is never modified in place: when the default has to be added, a
        new dictionary is returned so that a dictionary passed in by the caller is
        left untouched.

        Parameters
        ----------
        dataloader_params : dict of {str: Any}
            The dataloader parameters.

        Returns
        -------
        dict of {str: Any}
            The dataloader parameters with pin_memory default applied.
        """
        from collections.abc import Mapping

        # This runs before Pydantic's own dict validation: leave non-mapping input
        # to that validation instead of failing here with a raw TypeError.
        if not isinstance(dataloader_params, Mapping):
            return dataloader_params

        if "pin_memory" not in dataloader_params:
            import torch

            dataloader_params = {
                **dataloader_params,
                "pin_memory": torch.cuda.is_available(),
            }

        return dataloader_params

    @field_validator("train_dataloader_params", mode="before")
    @classmethod
    def set_default_train_workers(
        cls, dataloader_params: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Set default num_workers for training dataloader if not provided.

        - If 'num_workers' is not set, it defaults to the number of available CPU
          cores (0 when running under pytest, or when the count is undetermined).

        The input is never modified in place; a new dictionary is returned when the
        default has to be added.

        Parameters
        ----------
        dataloader_params : dict of {str: Any}
            The training dataloader parameters.

        Returns
        -------
        dict of {str: Any}
            The dataloader parameters with num_workers default applied.
        """
        from collections.abc import Mapping

        # Leave non-mapping input to Pydantic's own dict validation.
        if not isinstance(dataloader_params, Mapping):
            return dataloader_params

        if "num_workers" not in dataloader_params:
            # Use 0 workers during tests, otherwise use all available CPU cores.
            # `os.cpu_count()` may return None, which is not a valid worker count.
            if "pytest" in sys.modules:
                num_workers = 0
            else:
                num_workers = os.cpu_count() or 0
            dataloader_params = {**dataloader_params, "num_workers": num_workers}

        return dataloader_params

    @model_validator(mode="after")
    def set_val_workers_to_match_train(self: Self) -> Self:
        """
        Set validation dataloader num_workers to match training dataloader.

        If num_workers is not specified in val_dataloader_params, it will be set to the
        same value as train_dataloader_params["num_workers"].

        Returns
        -------
        Self
            Validated data model with synchronized num_workers.
        """
        if "num_workers" not in self.val_dataloader_params:
            self.val_dataloader_params["num_workers"] = self.train_dataloader_params[
                "num_workers"
            ]
        return self

    @field_validator("train_dataloader_params")
    @classmethod
    def shuffle_train_dataloader(
        cls, train_dataloader_params: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Validate that "shuffle" is included in the training dataloader params.

        A warning will be raised if `shuffle=False`.

        Parameters
        ----------
        train_dataloader_params : dict of {str: Any}
            The training dataloader parameters.

        Returns
        -------
        dict of {str: Any}
            The validated training dataloader parameters.

        Raises
        ------
        ValueError
            If "shuffle" is not included in the training dataloader params.
        """
        if "shuffle" not in train_dataloader_params:
            raise ValueError(
                "Value for 'shuffle' was not included in the `train_dataloader_params`."
            )

        if not train_dataloader_params["shuffle"]:
            warn(
                "Dataloader parameters include `shuffle=False`, this will be passed to "
                "the training dataloader and may lead to lower quality results.",
                stacklevel=1,
            )
        return train_dataloader_params

    @model_validator(mode="after")
    def std_only_with_mean(self: Self) -> Self:
        """
        Check that mean and std are either both None, or both specified.

        The check is applied separately to the image and the target statistics, and
        the number of means must match the number of stds.

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If only one of mean and std is specified, or if their lengths differ.
        """
        # Image statistics: both set (non-empty) or both unset.
        if (self.image_means and not self.image_stds) or (
            self.image_stds and not self.image_means
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        elif (self.image_means is not None and self.image_stds is not None) and (
            len(self.image_means) != len(self.image_stds)
        ):
            raise ValueError("Mean and std must be specified for each input channel.")

        # Target statistics: same rules as for the image.
        if (self.target_means and not self.target_stds) or (
            self.target_stds and not self.target_means
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        elif self.target_means is not None and self.target_stds is not None:
            if len(self.target_means) != len(self.target_stds):
                raise ValueError(
                    "Mean and std must be either both None, or both specified for each "
                    "target channel."
                )

        return self

    @model_validator(mode="after")
    def validate_dimensions(self: Self) -> Self:
        """
        Validate that the patch size dimensionality agrees with the axes.

        Data with a "Z" axis is 3D and requires a 3-element patch size; otherwise
        the data is 2D and requires a 2-element patch size.

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If the number of patch dimensions does not match the data dimensionality.
        """
        expected_dims = 3 if "Z" in self.axes else 2

        if len(self.patch_size) != expected_dims:
            raise ValueError(
                f"Patch size must have {expected_dims} dimensions if the data is "
                f"{expected_dims}D ({self.axes})."
            )

        return self

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def _update(self, **kwargs: Any) -> None:
        """
        Update multiple arguments at once.

        The candidate state (current values overridden by `kwargs`) is validated
        before anything is written, so a failed validation leaves the model
        unchanged. This allows interdependent fields (e.g. means and stds, or axes
        and patch size) to be changed together.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments to update.

        Raises
        ------
        pydantic.ValidationError
            If the updated configuration is not valid.
        """
        # Validate first: writing into __dict__ before validating would leave the
        # instance in an invalid state if validation raised.
        self.__class__.model_validate({**self.__dict__, **kwargs})
        self.__dict__.update(kwargs)

    def set_means_and_stds(
        self,
        image_means: Union[NDArray, tuple, list, None],
        image_stds: Union[NDArray, tuple, list, None],
        target_means: Union[NDArray, tuple, list, None] = None,
        target_stds: Union[NDArray, tuple, list, None] = None,
    ) -> None:
        """
        Set mean and standard deviation of the data across channels.

        This method should be used instead of setting the fields directly, as it
        would otherwise trigger a validation error.

        Parameters
        ----------
        image_means : numpy.ndarray, tuple, list or None
            Mean values for normalization.
        image_stds : numpy.ndarray, tuple, list or None
            Standard deviation values for normalization.
        target_means : numpy.ndarray, tuple, list or None, optional
            Target mean values for normalization, by default None.
        target_stds : numpy.ndarray, tuple, list or None, optional
            Target standard deviation values for normalization, by default None.
        """
        # The fields are lists; convert any provided sequence/array into one.
        self._update(
            image_means=None if image_means is None else list(image_means),
            image_stds=None if image_stds is None else list(image_stds),
            target_means=None if target_means is None else list(target_means),
            target_stds=None if target_stds is None else list(target_stds),
        )

    def set_3D(self, axes: str, patch_size: list[int]) -> None:
        """
        Set 3D parameters.

        Axes and patch size are updated together because the patch size
        dimensionality must agree with the presence of a "Z" axis.

        Parameters
        ----------
        axes : str
            Axes.
        patch_size : list of int
            Patch size.
        """
        self._update(axes=axes, patch_size=patch_size)