careamics-0.0.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,186 @@
+ """Optimizers and schedulers Pydantic models."""
+
+ from __future__ import annotations
+
+ from typing import Literal, Self
+
+ from pydantic import (
+     BaseModel,
+     ConfigDict,
+     Field,
+     ValidationInfo,
+     field_validator,
+     model_validator,
+ )
+ from torch import optim
+
+ from careamics.utils.torch_utils import filter_parameters
+
+ from ..support import SupportedOptimizer
+
+
+ class OptimizerConfig(BaseModel):
+     """Torch optimizer Pydantic model.
+
+     Only parameters supported by the corresponding torch optimizer will be taken
+     into account. For more details, check:
+     https://pytorch.org/docs/stable/optim.html#algorithms
+
+     Note that mandatory parameters (see the specific Optimizer signature in the
+     link above) must be provided. For example, SGD requires `lr`.
+
+     Attributes
+     ----------
+     name : {"Adam", "SGD", "Adamax"}
+         Name of the optimizer.
+     parameters : dict
+         Parameters of the optimizer (see torch documentation).
+     """
+
+     # Pydantic class configuration
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     # Mandatory field
+     name: Literal["Adam", "SGD", "Adamax"] = Field(
+         default="Adam", validate_default=True
+     )
+     """Name of the optimizer, supported optimizers are defined in
+     SupportedOptimizer."""
+
+     # Optional parameters, empty dict default value to allow filtering dictionary
+     parameters: dict = Field(
+         default={},
+         validate_default=True,
+     )
+     """Parameters of the optimizer, see PyTorch documentation for more details."""
+
+     @field_validator("parameters")
+     @classmethod
+     def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
+         """
+         Validate optimizer parameters.
+
+         This method filters out unknown parameters, given the optimizer name.
+
+         Parameters
+         ----------
+         user_params : dict
+             Parameters passed on to the torch optimizer.
+         values : ValidationInfo
+             Pydantic field validation info, used to get the optimizer name.
+
+         Returns
+         -------
+         dict
+             Filtered optimizer parameters.
+
+         Raises
+         ------
+         ValueError
+             If the optimizer name is not specified.
+         """
+         optimizer_name = values.data["name"]
+
+         # retrieve the corresponding optimizer class
+         optimizer_class = getattr(optim, optimizer_name)
+
+         # filter the user parameters according to the optimizer's signature
+         parameters = filter_parameters(optimizer_class, user_params)
+
+         return parameters
+
+     @model_validator(mode="after")
+     def sgd_lr_parameter(self) -> Self:
+         """
+         Check that the SGD optimizer has the mandatory `lr` parameter specified.
+
+         This is specific to PyTorch < 2.2.
+
+         Returns
+         -------
+         Self
+             Validated optimizer.
+
+         Raises
+         ------
+         ValueError
+             If the optimizer is SGD and the `lr` parameter is not specified.
+         """
+         if self.name == SupportedOptimizer.SGD and "lr" not in self.parameters:
+             raise ValueError(
+                 "SGD optimizer requires `lr` parameter, check that it has correctly "
+                 "been specified in `parameters`."
+             )
+
+         return self
+
+
+ class LrSchedulerConfig(BaseModel):
+     """Torch learning rate scheduler Pydantic model.
+
+     Only parameters supported by the corresponding torch lr scheduler will be taken
+     into account. For more details, check:
+     https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
+
+     Note that mandatory parameters (see the specific LrScheduler signature in the
+     link above) must be provided. For example, StepLR requires `step_size`.
+
+     Attributes
+     ----------
+     name : {"ReduceLROnPlateau", "StepLR"}
+         Name of the learning rate scheduler.
+     parameters : dict
+         Parameters of the learning rate scheduler (see torch documentation).
+     """
+
+     # Pydantic class configuration
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     # Mandatory field
+     name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau")
+     """Name of the learning rate scheduler, supported schedulers are defined in
+     SupportedScheduler."""
+
+     # Optional parameters
+     parameters: dict = Field(default={}, validate_default=True)
+     """Parameters of the learning rate scheduler, see PyTorch documentation for more
+     details."""
+
+     @field_validator("parameters")
+     @classmethod
+     def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
+         """Filter parameters based on the learning rate scheduler's signature.
+
+         Parameters
+         ----------
+         user_params : dict
+             User parameters.
+         values : ValidationInfo
+             Pydantic field validation info, used to get the scheduler name.
+
+         Returns
+         -------
+         dict
+             Filtered scheduler parameters.
+
+         Raises
+         ------
+         ValueError
+             If the scheduler is StepLR and the `step_size` parameter is not
+             specified.
+         """
+         # retrieve the corresponding scheduler class
+         scheduler_class = getattr(optim.lr_scheduler, values.data["name"])
+
+         # filter the user parameters according to the scheduler's signature
+         parameters = filter_parameters(scheduler_class, user_params)
+
+         if values.data["name"] == "StepLR" and "step_size" not in parameters:
+             raise ValueError(
+                 "StepLR scheduler requires `step_size` parameter, check that it has "
+                 "correctly been specified in `parameters`."
+             )
+
+         return parameters
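
A minimal usage sketch of the two models above (assuming `careamics` and `torch` are installed; the module path follows the file table above):

    from pydantic import ValidationError

    from careamics.config.lightning.optimizer_configs import (
        LrSchedulerConfig,
        OptimizerConfig,
    )

    # unknown keys are filtered out against torch.optim.Adam's signature
    optimizer = OptimizerConfig(name="Adam", parameters={"lr": 1e-3, "bogus": 42})
    assert optimizer.parameters == {"lr": 1e-3}

    # SGD without `lr` fails the `sgd_lr_parameter` model validator
    try:
        OptimizerConfig(name="SGD")
    except ValidationError as err:
        print(err)

    # StepLR must keep `step_size` after filtering
    scheduler = LrSchedulerConfig(name="StepLR", parameters={"step_size": 10})
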
@@ -0,0 +1,70 @@
+ """Training configuration."""
+
+ from __future__ import annotations
+
+ from pprint import pformat
+ from typing import Literal
+
+ from pydantic import BaseModel, ConfigDict, Field
+
+ from .callbacks.callback_config import CheckpointConfig, EarlyStoppingConfig
+
+
+ class TrainingConfig(BaseModel):
+     """
+     Parameters related to the training.
+
+     Attributes
+     ----------
+     lightning_trainer_config : dict | None
+         Configuration passed to the PyTorch Lightning Trainer.
+     logger : {"wandb", "tensorboard"} | None
+         Logger to use during training, or None to disable logging.
+     checkpoint_callback : CheckpointConfig
+         Checkpoint callback configuration.
+     early_stopping_callback : EarlyStoppingConfig | None
+         Early stopping callback configuration.
+     """
+
+     # Pydantic class configuration
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     lightning_trainer_config: dict | None = None
+     """Configuration for the PyTorch Lightning Trainer, following the PyTorch
+     Lightning Trainer class."""
+
+     logger: Literal["wandb", "tensorboard"] | None = None
+     """Logger to use during training. If None, no logger will be used. Available
+     loggers are defined in SupportedLogger."""
+
+     # Only basic callbacks
+     checkpoint_callback: CheckpointConfig = CheckpointConfig()
+     """Checkpoint callback configuration, following the PyTorch Lightning
+     ModelCheckpoint callback."""
+
+     early_stopping_callback: EarlyStoppingConfig | None = Field(
+         default=None, validate_default=True
+     )
+     """Early stopping callback configuration, following the PyTorch Lightning
+     EarlyStopping callback."""
+
+     def __str__(self) -> str:
+         """Pretty string representing the configuration.
+
+         Returns
+         -------
+         str
+             Pretty string.
+         """
+         return pformat(self.model_dump())
+
+     def has_logger(self) -> bool:
+         """Check whether a logger is defined.
+
+         Returns
+         -------
+         bool
+             Whether a logger is defined or not.
+         """
+         return self.logger is not None
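
For example (a sketch, again assuming the module path from the file table), a training configuration with a TensorBoard logger:

    from careamics.config.lightning.training_config import TrainingConfig

    training = TrainingConfig(
        lightning_trainer_config={"max_epochs": 100},
        logger="tensorboard",
    )
    assert training.has_logger()
    print(training)  # pretty-printed via pformat(self.model_dump())
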
@@ -0,0 +1,8 @@
+ """Losses Pydantic configurations."""
+
+ __all__ = [
+     "KLLossConfig",
+     "LVAELossConfig",
+ ]
+
+ from .loss_config import KLLossConfig, LVAELossConfig
@@ -0,0 +1,60 @@
+ """Configuration classes for LVAE losses."""
+
+ from typing import Literal
+
+ from pydantic import BaseModel, ConfigDict
+
+
+ class KLLossConfig(BaseModel):
+     """KL loss configuration."""
+
+     model_config = ConfigDict(validate_assignment=True, validate_default=True)
+
+     loss_type: Literal["kl", "kl_restricted"] = "kl"
+     """Type of KL divergence used as KL loss."""
+     rescaling: Literal["latent_dim", "image_dim"] = "latent_dim"
+     """Rescaling of the KL loss."""
+     aggregation: Literal["sum", "mean"] = "mean"
+     """Aggregation of the KL loss across different layers."""
+     free_bits_coeff: float = 0.0
+     """Free bits coefficient for the KL loss."""
+     annealing: bool = False
+     """Whether to apply KL loss annealing."""
+     start: int = -1
+     """Epoch at which KL loss annealing starts."""
+     annealtime: int = 10
+     """Number of epochs over which KL loss annealing is applied."""
+     current_epoch: int = 0
+     """Current epoch in the training loop."""
+
+
+ class LVAELossConfig(BaseModel):
+     """LVAE loss configuration."""
+
+     model_config = ConfigDict(
+         validate_assignment=True, validate_default=True, arbitrary_types_allowed=True
+     )
+
+     loss_type: Literal[
+         "hdn", "microsplit", "musplit", "denoisplit", "denoisplit_musplit"
+     ]
+     """Type of loss to use for LVAE."""
+
+     reconstruction_weight: float = 1.0
+     """Weight for the reconstruction loss in the total net loss
+     (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+     kl_weight: float = 1.0
+     """Weight for the KL loss in the total net loss
+     (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
+     musplit_weight: float = 0.1
+     """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
+     denoisplit_weight: float = 0.9
+     """Weight for the denoiSplit loss (used in the muSplit-denoiSplit loss)."""
+     kl_params: KLLossConfig = KLLossConfig()
+     """KL loss configuration."""
+     # TODO revisit weights for the losses
+     # TODO: remove?
+     non_stochastic: bool = False
+     """If True, skip sampling the latents and computing the KL loss
+     (deterministic, non-stochastic mode)."""
+
+     # TODO what are the correct parameters for HDN?
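
As a sketch of how these models compose (using the exports from `careamics.config.losses` shown in the `__init__.py` above):

    from careamics.config.losses import KLLossConfig, LVAELossConfig

    # `loss_type` is the only required field; the weights keep their defaults,
    # so net_loss = 1.0 * rec_loss + 1.0 * kl_loss
    loss = LVAELossConfig(
        loss_type="denoisplit_musplit",
        kl_params=KLLossConfig(annealing=True, start=5, annealtime=10),
    )
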
@@ -0,0 +1,5 @@
+ """Definitions of configurations for CAREamics, compatible with the NG dataset."""
+
+ __all__ = ["N2VConfiguration"]
+
+ from .n2v_configuration import N2VConfiguration
@@ -0,0 +1,64 @@
+ """Configuration for N2V."""
+
+ from typing import Self
+
+ import numpy as np
+ from pydantic import model_validator
+
+ from careamics.config.algorithms import N2VAlgorithm
+ from careamics.config.data.patching_strategies import RandomPatchingConfig
+
+ from .ng_configuration import NGConfiguration
+
+
+ class N2VConfiguration(NGConfiguration):
+     """N2V-specific configuration."""
+
+     algorithm_config: N2VAlgorithm
+
+     @model_validator(mode="after")
+     def validate_n2v_mask_pixel_perc(self: Self) -> Self:
+         """
+         Validate that, on average, at least one blind-spot pixel falls in every
+         patch.
+
+         The expected number of blind-spot pixels per patch is a function of the
+         chosen masked pixel percentage and patch size.
+
+         Returns
+         -------
+         Self
+             Validated configuration.
+
+         Raises
+         ------
+         ValueError
+             If the expected number of masked pixels within a patch is less than 1
+             for the chosen masked pixel percentage and patch size.
+         """
+         if self.data_config.mode == "training":
+             assert isinstance(self.data_config.patching, RandomPatchingConfig)
+
+             mask_pixel_perc = self.algorithm_config.n2v_config.masked_pixel_percentage
+             patch_size = self.data_config.patching.patch_size
+             # patch area "paid" for each blind-spot pixel
+             expected_area_per_pixel = 1 / (mask_pixel_perc / 100)
+
+             n_dims = 3 if self.algorithm_config.model.is_3D() else 2
+             patch_size_lower_bound = int(
+                 np.ceil(expected_area_per_pixel ** (1 / n_dims))
+             )
+             required_patch_size = tuple(
+                 2 ** int(np.ceil(np.log2(patch_size_lower_bound)))
+                 for _ in range(n_dims)
+             )
+             required_mask_pixel_perc = (1 / np.prod(patch_size)) * 100
+
+             if expected_area_per_pixel > np.prod(patch_size):
+                 raise ValueError(
+                     "The expected number of blind-spot pixels within a patch is "
+                     f"below 1 for a patch size of {patch_size} with a masked pixel "
+                     f"percentage of {mask_pixel_perc}%. Either increase the patch "
+                     f"size to {required_patch_size} or increase the masked pixel "
+                     f"percentage to at least {required_mask_pixel_perc}%."
+                 )
+
+         return self
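
The validator's arithmetic can be checked by hand (a worked example, with a hypothetical 0.2% masked pixel percentage):

    import numpy as np

    mask_pixel_perc = 0.2  # one blind-spot pixel per 1 / 0.002 = 500 pixels
    expected_area_per_pixel = 1 / (mask_pixel_perc / 100)  # 500.0

    # a 64x64 patch holds 4096 pixels -> 4096 / 500 ~ 8.2 expected blind-spot
    # pixels per patch, so the configuration is valid
    assert np.prod((64, 64)) > expected_area_per_pixel

    # a 16x16 patch holds 256 pixels -> ~0.5 expected blind-spot pixels, so the
    # validator raises, suggesting ceil(sqrt(500)) = 23 rounded up to the next
    # power of two, i.e. a (32, 32) patch
    assert np.prod((16, 16)) < expected_area_per_pixel
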
@@ -0,0 +1,256 @@
+ """CAREamics configuration compatible with the NG Dataset."""
+
+ from __future__ import annotations
+
+ import re
+ from pprint import pformat
+ from typing import Any, Literal, Self, Union
+
+ from bioimageio.spec.generic.v0_3 import CiteEntry
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+
+ from careamics.config.algorithms import (
+     CAREAlgorithm,
+     N2NAlgorithm,
+     N2VAlgorithm,
+ )
+ from careamics.config.data import NGDataConfig
+ from careamics.config.lightning.training_config import TrainingConfig
+
+ ALGORITHMS = Union[
+     CAREAlgorithm,
+     N2NAlgorithm,
+     N2VAlgorithm,
+ ]
+
+
+ class NGConfiguration(BaseModel):
+     """
+     CAREamics configuration.
+
+     The configuration defines all parameters used to build and train a CAREamics
+     model. These parameters are validated to ensure that they are compatible with
+     each other.
+
+     It contains three sub-configurations:
+
+     - algorithm_config: configuration of the algorithm, which includes the
+       architecture, loss function, optimizer, and other hyperparameters.
+     - data_config: configuration of the dataloader, which includes the type of
+       data, transformations, mean/std and other parameters.
+     - training_config: configuration of the training, which includes the number
+       of epochs or the callbacks.
+
+     Attributes
+     ----------
+     experiment_name : str
+         Name of the experiment, used when saving logs and checkpoints.
+     algorithm_config : ALGORITHMS
+         Algorithm configuration.
+     data_config : NGDataConfig
+         Data configuration.
+     training_config : TrainingConfig
+         Training configuration.
+
+     Raises
+     ------
+     ValueError
+         Configuration parameter type validation errors.
+     ValueError
+         If the experiment name contains invalid characters or is empty.
+     ValueError
+         If the algorithm is 3D but there is no "Z" in the data axes, or if the
+         algorithm is 2D but "Z" is present in the data axes.
+     ValueError
+         Algorithm, data or training validation errors.
+     """
+
+     model_config = ConfigDict(
+         validate_assignment=True,
+         arbitrary_types_allowed=True,
+     )
+
+     # version
+     version: Literal["0.1.0"] = "0.1.0"
+     """CAREamics configuration version."""
+
+     # required parameters
+     experiment_name: str
+     """Name of the experiment, used to name logs and checkpoints."""
+
+     # Sub-configurations
+     algorithm_config: ALGORITHMS = Field(discriminator="algorithm")
+     """Algorithm configuration, holding all parameters required to configure the
+     model."""
+
+     data_config: NGDataConfig
+     """Data configuration, holding all parameters required to configure the
+     training data loader."""
+
+     training_config: TrainingConfig
+     """Training configuration, holding all parameters required to configure the
+     training process."""
+
+     @field_validator("experiment_name")
+     @classmethod
+     def no_symbol(cls, name: str) -> str:
+         """
+         Validate the experiment name.
+
+         A valid experiment name is a non-empty string that contains only letters,
+         numbers, underscores, dashes and spaces.
+
+         Parameters
+         ----------
+         name : str
+             Name to validate.
+
+         Returns
+         -------
+         str
+             Validated name.
+
+         Raises
+         ------
+         ValueError
+             If the name is empty or contains invalid characters.
+         """
+         if len(name) == 0 or name.isspace():
+             raise ValueError("Experiment name is empty.")
+
+         # validate using a regex that the name contains only letters, numbers,
+         # underscores, dashes and spaces
+         if not re.match(r"^[a-zA-Z0-9_\- ]*$", name):
+             raise ValueError(
+                 f"Experiment name contains invalid characters (got {name}). "
+                 f"Only letters, numbers, underscores, dashes and spaces are allowed."
+             )
+
+         return name
+
+     @model_validator(mode="after")
+     def validate_3D(self: Self) -> Self:
+         """
+         Validate that the algorithm dimensions match the data dimensions.
+
+         Returns
+         -------
+         Self
+             Validated configuration.
+
+         Raises
+         ------
+         ValueError
+             If the data and algorithm dimensionality (2D/3D) do not match.
+         """
+         if self.data_config.is_3D() != self.algorithm_config.model.is_3D():
+             raise ValueError(
+                 f"Mismatch between data ({'3D' if self.data_config.is_3D() else '2D'}) "
+                 f"and algorithm ("
+                 f"{'3D' if self.algorithm_config.model.is_3D() else '2D'}). Data "
+                 f"dimensionality is determined by the axes ({self.data_config.axes}), "
+                 f"as well as patch size (if applicable) and data type (if data type "
+                 f"is 'czi', which uses 3D when 'T' axis is specified)."
+             )
+
+         return self
+
+     def __str__(self) -> str:
+         """
+         Pretty string representing the configuration.
+
+         Returns
+         -------
+         str
+             Pretty string.
+         """
+         return pformat(self.model_dump())
+
+     def get_algorithm_friendly_name(self) -> str:
+         """
+         Get the algorithm name.
+
+         Returns
+         -------
+         str
+             Algorithm name.
+         """
+         return self.algorithm_config.get_algorithm_friendly_name()
+
+     def get_algorithm_description(self) -> str:
+         """
+         Return a description of the algorithm.
+
+         This method is used to generate the README of the BioImage Model Zoo export.
+
+         Returns
+         -------
+         str
+             Description of the algorithm.
+         """
+         return self.algorithm_config.get_algorithm_description()
+
+     def get_algorithm_citations(self) -> list[CiteEntry]:
+         """
+         Return a list of citation entries for the current algorithm.
+
+         This is used to generate the model description for the BioImage Model Zoo.
+
+         Returns
+         -------
+         list[CiteEntry]
+             List of citation entries.
+         """
+         return self.algorithm_config.get_algorithm_citations()
+
+     def get_algorithm_references(self) -> str:
+         """
+         Get the algorithm references.
+
+         This is used to generate the README of the BioImage Model Zoo export.
+
+         Returns
+         -------
+         str
+             Algorithm references.
+         """
+         return self.algorithm_config.get_algorithm_references()
+
+     def get_algorithm_keywords(self) -> list[str]:
+         """
+         Get the algorithm keywords.
+
+         Returns
+         -------
+         list[str]
+             List of keywords.
+         """
+         return self.algorithm_config.get_algorithm_keywords()
+
+     def model_dump(
+         self,
+         **kwargs: Any,
+     ) -> dict[str, Any]:
+         """
+         Override the model_dump method in order to set default values.
+
+         As opposed to the parent model_dump method, this method sets
+         `exclude_none` to True by default.
+
+         Parameters
+         ----------
+         **kwargs : Any
+             Additional arguments to pass to the parent model_dump method.
+
+         Returns
+         -------
+         dict[str, Any]
+             Dictionary containing the model parameters.
+         """
+         if "exclude_none" not in kwargs:
+             kwargs["exclude_none"] = True
+
+         return super().model_dump(**kwargs)
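
The `no_symbol` validator reduces to a single regex check; for illustration:

    import re

    pattern = r"^[a-zA-Z0-9_\- ]*$"
    assert re.match(pattern, "n2v_BSD68 run-2")  # letters, digits, _, -, space
    assert not re.match(pattern, "run#2")        # "#" triggers the ValueError
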
@@ -0,0 +1,9 @@
+ """Convenience functions to create coherent configurations for CAREamics."""
+
+ __all__ = [
+     "create_n2v_configuration",
+     "create_ng_data_configuration",
+ ]
+
+ from .data_factory import create_ng_data_configuration
+ from .n2v_factory import create_n2v_configuration