careamics-0.0.19-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0

careamics/config/algorithms/vae_algorithm_config.py
@@ -0,0 +1,178 @@
+"""VAE-based algorithm Pydantic model."""
+
+from __future__ import annotations
+
+from pprint import pformat
+from typing import Literal, Self
+
+from pydantic import BaseModel, ConfigDict, model_validator
+
+from careamics.config.architectures import LVAEConfig
+from careamics.config.lightning.optimizer_configs import (
+    LrSchedulerConfig,
+    OptimizerConfig,
+)
+from careamics.config.losses.loss_config import LVAELossConfig
+from careamics.config.noise_model.likelihood_config import (
+    GaussianLikelihoodConfig,
+    NMLikelihoodConfig,
+)
+from careamics.config.noise_model.noise_model_config import MultiChannelNMConfig
+from careamics.config.support import SupportedAlgorithm, SupportedLoss
+
+
+class VAEBasedAlgorithm(BaseModel):
+    """VAE-based algorithm configuration.
+
+    # TODO
+
+    Examples
+    --------
+    # TODO add once finalized
+    """
+
+    # Pydantic class configuration
+    model_config = ConfigDict(
+        protected_namespaces=(),  # allows to use model_* as a field name
+        validate_assignment=True,
+        extra="allow",
+    )
+
+    # Mandatory fields
+    # defined in SupportedAlgorithm
+    # TODO: Use supported Enum classes for typing?
+    # - values can still be passed as strings and they will be cast to Enum
+    algorithm: Literal["hdn", "microsplit"]
+
+    # NOTE: these are all configs (pydantic models)
+    loss: LVAELossConfig
+    model: LVAEConfig
+    noise_model: MultiChannelNMConfig | None = None
+    noise_model_likelihood: NMLikelihoodConfig | None = None
+    gaussian_likelihood: GaussianLikelihoodConfig | None = None  # TODO change to str
+
+    mmse_count: int = 1
+    is_supervised: bool = False
+
+    # Optional fields
+    optimizer: OptimizerConfig = OptimizerConfig()
+    """Optimizer to use, defined in SupportedOptimizer."""
+
+    lr_scheduler: LrSchedulerConfig = LrSchedulerConfig()
+
+    @model_validator(mode="after")
+    def algorithm_cross_validation(self: Self) -> Self:
+        """Validate the algorithm model based on `algorithm`.
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        # hdn
+        # TODO move to designated configurations
+        if self.algorithm == SupportedAlgorithm.HDN:
+            if self.loss.loss_type != SupportedLoss.HDN:
+                raise ValueError(
+                    f"Algorithm {self.algorithm} only supports loss `hdn`."
+                )
+            if self.model.multiscale_count > 1:
+                raise ValueError("Algorithm `hdn` does not support multiscale models.")
+        # microsplit
+        if self.algorithm == SupportedAlgorithm.MICROSPLIT:
+            if self.loss.loss_type not in [
+                SupportedLoss.MUSPLIT,
+                SupportedLoss.DENOISPLIT,
+                SupportedLoss.DENOISPLIT_MUSPLIT,
+            ]:  # TODO Update losses configs, make loss just microsplit
+                raise ValueError(
+                    f"Algorithm {self.algorithm} only supports loss `microsplit`."
+                )  # TODO Update losses configs
+
+            if (
+                self.loss.loss_type == SupportedLoss.DENOISPLIT
+                and self.model.predict_logvar is not None
+            ):
+                raise ValueError(
+                    "Algorithm `denoisplit` with loss `denoisplit` only supports "
+                    "`predict_logvar` as `None`."
+                )
+            if (
+                self.loss.loss_type == SupportedLoss.DENOISPLIT
+                and self.noise_model is None
+            ):
+                raise ValueError("Algorithm `denoisplit` requires a noise model.")
+        # TODO: what if algorithm is not musplit or denoisplit
+        return self
+
+    @model_validator(mode="after")
+    def output_channels_validation(self: Self) -> Self:
+        """Validate consistency between the number of output channels and noise models.
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        if self.noise_model is not None:
+            assert self.model.output_channels == len(self.noise_model.noise_models), (
+                f"Number of output channels ({self.model.output_channels}) must match "
+                f"the number of noise models ({len(self.noise_model.noise_models)})."
+            )
+
+        if self.algorithm == SupportedAlgorithm.HDN:
+            assert self.model.output_channels == 1, (
+                f"Number of output channels ({self.model.output_channels}) must be 1 "
+                "for algorithm `hdn`."
+            )
+        return self
+
+    @model_validator(mode="after")
+    def predict_logvar_validation(self: Self) -> Self:
+        """Validate the consistency of `predict_logvar` throughout the model.
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        if self.gaussian_likelihood is not None:
+            assert (
+                self.model.predict_logvar == self.gaussian_likelihood.predict_logvar
+            ), (
+                f"Model `predict_logvar` ({self.model.predict_logvar}) must match "
+                "Gaussian likelihood model `predict_logvar` "
+                f"({self.gaussian_likelihood.predict_logvar})."
+            )
+        # if self.algorithm == SupportedAlgorithm.HDN:
+        #     assert (
+        #         self.model.predict_logvar is None
+        #     ), "Model `predict_logvar` must be `None` for algorithm `hdn`."
+        #     if self.gaussian_likelihood is not None:
+        #         assert self.gaussian_likelihood.predict_logvar is None, (
+        #             "Gaussian likelihood model `predict_logvar` must be `None` "
+        #             "for algorithm `hdn`."
+        #         )
+        # TODO check this
+        return self
+
+    def __str__(self) -> str:
+        """Pretty string representing the configuration.
+
+        Returns
+        -------
+        str
+            Pretty string.
+        """
+        return pformat(self.model_dump())
+
+    @classmethod
+    def get_compatible_algorithms(cls) -> list[str]:
+        """Get the list of compatible algorithms.
+
+        Returns
+        -------
+        list of str
+            List of compatible algorithms.
+        """
+        return ["hdn", "microsplit"]

careamics/config/architectures/__init__.py
@@ -0,0 +1,7 @@
+"""Deep-learning model configurations."""
+
+__all__ = ["ArchitectureConfig", "LVAEConfig", "UNetConfig"]
+
+from .architecture_config import ArchitectureConfig
+from .lvae_config import LVAEConfig
+from .unet_config import UNetConfig

careamics/config/architectures/architecture_config.py
@@ -0,0 +1,37 @@
+"""Base model for the various CAREamics architectures."""
+
+from typing import Any
+
+from pydantic import BaseModel
+
+
+class ArchitectureConfig(BaseModel):
+    """
+    Base Pydantic model for all model architectures.
+
+    The `model_dump` method allows removing the `architecture` key from the model.
+    """
+
+    architecture: str
+    """Name of the architecture."""
+
+    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
+        """
+        Dump the model as a dictionary, ignoring the architecture keyword.
+
+        Parameters
+        ----------
+        **kwargs : Any
+            Additional keyword arguments from Pydantic BaseModel model_dump method.
+
+        Returns
+        -------
+        {str: Any}
+            Model as a dictionary.
+        """
+        model_dict = super().model_dump(**kwargs)
+
+        # remove the architecture key
+        model_dict.pop("architecture")
+
+        return model_dict
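
As a quick illustration of the `model_dump` override, a short sketch using the `UNetConfig` subclass defined later in this diff (imported from `careamics.config.architectures`, as exposed by the package `__init__` above):

from careamics.config.architectures import UNetConfig

cfg = UNetConfig(architecture="UNet", depth=3)
dump = cfg.model_dump()

# The overridden model_dump pops the discriminator key, so only the
# architecture-specific parameters remain in the dictionary.
assert "architecture" not in dump
print(dump["depth"])  # 3
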

careamics/config/architectures/lvae_config.py
@@ -0,0 +1,262 @@
+"""LVAE Pydantic model."""
+
+from typing import Literal, Self
+
+from pydantic import ConfigDict, Field, field_validator, model_validator
+
+from .architecture_config import ArchitectureConfig
+
+
+# TODO: it is quite confusing to call this LVAEModel, as it is basically a config
+class LVAEConfig(ArchitectureConfig):
+    """LVAE model."""
+
+    model_config = ConfigDict(validate_assignment=True, validate_default=True)
+
+    architecture: Literal["LVAE"]
+
+    input_shape: tuple[int, ...] = Field(default=(64, 64), validate_default=True)
+    """Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D."""
+    encoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
+
+    # TODO make this per hierarchy step ?
+    decoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
+    """Dimensions (2D or 3D) of the convolutional layers."""
+
+    multiscale_count: int = Field(default=1)
+    # TODO there should be a check for multiscale_count in dataset !!
+
+    # 1 - off, len(z_dims) + 1  # TODO Consider starting from 0
+    z_dims: list = Field(default=[128, 128, 128, 128])
+    output_channels: int = Field(default=1, ge=1)
+    encoder_n_filters: int = Field(default=64, ge=8, le=1024)
+    decoder_n_filters: int = Field(default=64, ge=8, le=1024)
+    encoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
+    decoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
+    nonlinearity: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
+    ] = Field(
+        default="ELU",
+    )
+
+    predict_logvar: Literal[None, "pixelwise"] = "pixelwise"
+    analytical_kl: bool = Field(default=False)
+
+    @model_validator(mode="after")
+    def validate_conv_strides(self: Self) -> Self:
+        """
+        Validate the convolutional strides.
+
+        Returns
+        -------
+        Self
+            The validated model.
+
+        Raises
+        ------
+        ValueError
+            If the number of strides is not 2 or 3.
+        """
+        if len(self.encoder_conv_strides) < 2 or len(self.encoder_conv_strides) > 3:
+            raise ValueError(
+                f"Strides must be 2 or 3 (got {len(self.encoder_conv_strides)})."
+            )
+
+        if len(self.decoder_conv_strides) < 2 or len(self.decoder_conv_strides) > 3:
+            raise ValueError(
+                f"Strides must be 2 or 3 (got {len(self.decoder_conv_strides)})."
+            )
+
+        # adding 1 to encoder strides for the number of input channels
+        if len(self.input_shape) != len(self.encoder_conv_strides):
+            raise ValueError(
+                f"Input dimensions must be equal to the number of encoder conv strides"
+                f" (got {len(self.input_shape)} and {len(self.encoder_conv_strides)})."
+            )
+
+        if len(self.encoder_conv_strides) < len(self.decoder_conv_strides):
+            raise ValueError(
+                f"Decoder can't be 3D when encoder is 2D (got"
+                f" {len(self.encoder_conv_strides)} and"
+                f" {len(self.decoder_conv_strides)})."
+            )
+
+        if any(s < 1 for s in self.encoder_conv_strides) or any(
+            s < 1 for s in self.decoder_conv_strides
+        ):
+            raise ValueError(
+                f"All strides must be greater than or equal to 1"
+                f" (got {self.encoder_conv_strides} and {self.decoder_conv_strides})."
+            )
+        # TODO: validate max stride size ?
+        return self
+
+    @field_validator("input_shape")
+    @classmethod
+    def validate_input_shape(cls, input_shape: list) -> list:
+        """
+        Validate the input shape.
+
+        Parameters
+        ----------
+        input_shape : list
+            Shape of the input patch.
+
+        Returns
+        -------
+        list
+            Validated input shape.
+
+        Raises
+        ------
+        ValueError
+            If the number of dimensions is not 2 or 3.
+        """
+        if len(input_shape) < 2 or len(input_shape) > 3:
+            raise ValueError(
+                f"Number of input dimensions must be 2 for 2D data or 3 for 3D data"
+                f" (got {len(input_shape)})."
+            )
+
+        if any(s < 1 for s in input_shape):
+            raise ValueError(
+                f"Input shape must be greater than or equal to 1 in all dimensions"
+                f" (got {input_shape})."
+            )
+
+        if any(s < 64 for s in input_shape[-2:]):
+            raise ValueError(
+                f"Input shape must be greater than or equal to 64 in the XY dimensions"
+                f" (got {input_shape})."
+            )
+
+        return input_shape
+
+    @field_validator("encoder_n_filters")
+    @classmethod
+    def validate_encoder_even(cls, encoder_n_filters: int) -> int:
+        """
+        Validate that encoder_n_filters is even.
+
+        Parameters
+        ----------
+        encoder_n_filters : int
+            Number of channels.
+
+        Returns
+        -------
+        int
+            Validated number of channels.
+
+        Raises
+        ------
+        ValueError
+            If the number of channels is odd.
+        """
+        # if odd
+        if encoder_n_filters % 2 != 0:
+            raise ValueError(
+                f"Number of channels for the bottom layer must be even"
+                f" (got {encoder_n_filters})."
+            )
+
+        return encoder_n_filters
+
+    @field_validator("decoder_n_filters")
+    @classmethod
+    def validate_decoder_even(cls, decoder_n_filters: int) -> int:
+        """
+        Validate that decoder_n_filters is even.
+
+        Parameters
+        ----------
+        decoder_n_filters : int
+            Number of channels.
+
+        Returns
+        -------
+        int
+            Validated number of channels.
+
+        Raises
+        ------
+        ValueError
+            If the number of channels is odd.
+        """
+        # if odd
+        if decoder_n_filters % 2 != 0:
+            raise ValueError(
+                f"Number of channels for the bottom layer must be even"
+                f" (got {decoder_n_filters})."
+            )
+
+        return decoder_n_filters
+
+    @field_validator("z_dims")
+    def validate_z_dims(cls, z_dims: tuple) -> tuple:
+        """
+        Validate the z_dims.
+
+        Parameters
+        ----------
+        z_dims : tuple
+            Tuple of z dimensions.
+
+        Returns
+        -------
+        tuple
+            Validated z dimensions.
+
+        Raises
+        ------
+        ValueError
+            If the number of z dims is less than 2.
+        """
+        if len(z_dims) < 2:
+            raise ValueError(
+                f"Number of z dimensions must be at least 2 (got {len(z_dims)})."
+            )
+
+        return z_dims
+
+    @model_validator(mode="after")
+    def validate_multiscale_count(self: Self) -> Self:
+        """
+        Validate the multiscale count.
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        if self.multiscale_count < 1 or self.multiscale_count > len(self.z_dims) + 1:
+            raise ValueError(
+                f"Multiscale count must be 1 (LC off) or less than or equal to the number"
+                f" of Z dims + 1 (got {self.multiscale_count} and {len(self.z_dims)})."
+            )
+        return self
+
+    def set_3D(self, is_3D: bool) -> None:
+        """
+        Set 3D model by setting the `conv_dims` parameter.
+
+        Parameters
+        ----------
+        is_3D : bool
+            Whether the algorithm is 3D or not.
+        """
+        if is_3D:
+            self.conv_dims = 3
+        else:
+            self.conv_dims = 2
+
+    def is_3D(self) -> bool:
+        """
+        Return whether the model is 3D or not.
+
+        Returns
+        -------
+        bool
+            Whether the model is 3D or not.
+        """
+        return len(self.input_shape) == 3
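
A short sketch of how these validators behave, assuming `LVAEConfig` is imported from `careamics.config.architectures` as exposed by the package `__init__` above; the defaults referenced in the comments are the field defaults declared in this file.

from pydantic import ValidationError

from careamics.config.architectures import LVAEConfig

# Defaults: 2D model, (64, 64) input patches, strides (2, 2), four z levels.
cfg = LVAEConfig(architecture="LVAE")
print(cfg.is_3D())  # False, the input shape has only two dimensions

# validate_input_shape requires at least 64 pixels in the XY dimensions,
# so a smaller patch is rejected at construction time.
try:
    LVAEConfig(architecture="LVAE", input_shape=(32, 32))
except ValidationError as err:
    print(err)
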

careamics/config/architectures/unet_config.py
@@ -0,0 +1,125 @@
+"""UNet Pydantic model."""
+
+from __future__ import annotations
+
+from typing import Literal
+
+from pydantic import ConfigDict, Field, field_validator
+
+from .architecture_config import ArchitectureConfig
+
+
+# TODO tests activation <-> pydantic model, test the literals!
+# TODO annotations for the json schema?
+class UNetConfig(ArchitectureConfig):
+    """
+    Pydantic model for a N2V(2)-compatible UNet.
+
+    Attributes
+    ----------
+    depth : int
+        Depth of the model, between 1 and 10 (default 2).
+    num_channels_init : int
+        Number of filters of the first level of the network, should be even
+        and minimum 8 (default 32).
+    """
+
+    # pydantic model config
+    model_config = ConfigDict(validate_assignment=True)
+
+    # discriminator used for choosing the pydantic model in Model
+    architecture: Literal["UNet"]
+    """Name of the architecture."""
+
+    # parameters
+    # validate_default ensures that default values are also validated
+    conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
+    """Dimensions (2D or 3D) of the convolutional layers."""
+
+    num_classes: int = Field(default=1, ge=1, validate_default=True)
+    """Number of classes or channels in the model output."""
+
+    in_channels: int = Field(default=1, ge=1, validate_default=True)
+    """Number of channels in the input to the model."""
+
+    depth: int = Field(default=2, ge=1, le=10, validate_default=True)
+    """Number of levels in the UNet."""
+
+    num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
+    """Number of convolutional filters in the first layer of the UNet."""
+
+    # TODO we are not using this, so why make it a choice?
+    final_activation: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
+    ] = Field(default="None", validate_default=True)
+    """Final activation function."""
+
+    n2v2: bool = Field(default=False, validate_default=True)
+    """Whether to use N2V2 architecture modifications, with blur pool layers and fewer
+    skip connections.
+    """
+
+    independent_channels: bool = Field(default=True, validate_default=True)
+    """Whether information is processed independently in each channel, used to train
+    channels independently."""
+
+    use_batch_norm: bool = Field(default=True, validate_default=True)
+    """Whether to use batch normalization in the model."""
+
+    @field_validator("num_channels_init")
+    @classmethod
+    def validate_num_channels_init(cls, num_channels_init: int) -> int:
+        """
+        Validate that num_channels_init is even.
+
+        Parameters
+        ----------
+        num_channels_init : int
+            Number of channels.
+
+        Returns
+        -------
+        int
+            Validated number of channels.
+
+        Raises
+        ------
+        ValueError
+            If the number of channels is odd.
+        """
+        # if odd
+        if num_channels_init % 2 != 0:
+            raise ValueError(
+                f"Number of channels for the bottom layer must be even"
+                f" (got {num_channels_init})."
+            )
+
+        return num_channels_init
+
+    def set_3D(self, is_3D: bool) -> None:
+        """
+        Set 3D model by setting the `conv_dims` parameter.
+
+        Parameters
+        ----------
+        is_3D : bool
+            Whether the algorithm is 3D or not.
+        """
+        if is_3D:
+            self.conv_dims = 3
+        else:
+            self.conv_dims = 2
+
+    def is_3D(self) -> bool:
+        """
+        Return whether the model is 3D or not.
+
+        This method is used in the NG configuration validation to check that the model
+        dimensions match the data dimensions.
+
+        Returns
+        -------
+        bool
+            Whether the model is 3D or not.
+        """
+        return self.conv_dims == 3
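
And a final sketch exercising the `set_3D`/`is_3D` helpers of the UNet configuration, again assuming the import path exposed by the package `__init__` above; the field values used are taken from the defaults and constraints declared in this file.

from careamics.config.architectures import UNetConfig

cfg = UNetConfig(architecture="UNet", num_channels_init=64, n2v2=True)
print(cfg.is_3D())  # False, conv_dims defaults to 2

# set_3D flips conv_dims; with validate_assignment enabled the new value is
# re-checked against the Literal[2, 3] annotation.
cfg.set_3D(True)
print(cfg.is_3D())  # True
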