careamics-0.0.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,120 @@
+ """Convenience functions to create algorithm configurations."""
+
+ from typing import Annotated, Any, Literal, Union
+
+ from pydantic import Field, TypeAdapter
+
+ from careamics.config.algorithms import (
+     CAREAlgorithm,
+     N2NAlgorithm,
+     N2VAlgorithm,
+     # PN2VAlgorithm,  # TODO not yet compatible with NG Dataset
+ )
+ from careamics.config.architectures import UNetConfig
+ from careamics.config.support.supported_architectures import SupportedArchitecture
+
+
+ # TODO rename so that it does not bear the same name as the module?
+ def algorithm_factory(
+     algorithm: dict[str, Any],
+ ) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]:
+     """
+     Create an algorithm model for training CAREamics.
+
+     Parameters
+     ----------
+     algorithm : dict
+         Algorithm dictionary.
+
+     Returns
+     -------
+     N2VAlgorithm or N2NAlgorithm or CAREAlgorithm
+         Algorithm model for training CAREamics.
+     """
+     adapter: TypeAdapter = TypeAdapter(
+         Annotated[
+             Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm],
+             Field(discriminator="algorithm"),
+         ]
+     )
+     return adapter.validate_python(algorithm)
+
+
+ def create_algorithm_configuration(
+     dimensions: Literal[2, 3],
+     algorithm: Literal["n2v", "care", "n2n"],
+     loss: Literal["n2v", "mae", "mse"],
+     independent_channels: bool,
+     n_channels_in: int,
+     n_channels_out: int,
+     use_n2v2: bool = False,
+     model_params: dict | None = None,
+     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
+     optimizer_params: dict[str, Any] | None = None,
+     lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
+     lr_scheduler_params: dict[str, Any] | None = None,
+ ) -> dict:
+     """
+     Create a dictionary with the parameters of the algorithm model.
+
+     Parameters
+     ----------
+     dimensions : {2, 3}
+         Dimension of the model, either 2D or 3D.
+     algorithm : {"n2v", "care", "n2n"}
+         Algorithm to use.
+     loss : {"n2v", "mae", "mse"}
+         Loss function to use.
+     independent_channels : bool
+         Whether to train all channels independently.
+     n_channels_in : int
+         Number of input channels.
+     n_channels_out : int
+         Number of output channels.
+     use_n2v2 : bool, default=False
+         Whether to use N2V2.
+     model_params : dict, default=None
+         UNetModel parameters.
+     optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
+         Optimizer to use.
+     optimizer_params : dict, default=None
+         Parameters for the optimizer, see PyTorch documentation for more details.
+     lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
+         Learning rate scheduler to use.
+     lr_scheduler_params : dict, default=None
+         Parameters for the learning rate scheduler, see PyTorch documentation for
+         more details.
+
+     Returns
+     -------
+     dict
+         Algorithm model as a dictionary with the specified parameters.
+     """
+     # create dictionary to ensure priority of explicit parameters over model_params
+     # and prevent multiple same parameters being passed to UNetConfig
+     model_params = {} if model_params is None else model_params
+     model_params["n2v2"] = use_n2v2
+     model_params["conv_dims"] = dimensions
+     model_params["in_channels"] = n_channels_in
+     model_params["num_classes"] = n_channels_out
+     model_params["independent_channels"] = independent_channels
+
+     unet_model = UNetConfig(
+         architecture=SupportedArchitecture.UNET.value,
+         **model_params,
+     )
+
+     return {
+         "algorithm": algorithm,
+         "loss": loss,
+         "model": unet_model,
+         "optimizer": {
+             "name": optimizer,
+             "parameters": {} if optimizer_params is None else optimizer_params,
+         },
+         "lr_scheduler": {
+             "name": lr_scheduler,
+             "parameters": {} if lr_scheduler_params is None else lr_scheduler_params,
+         },
+     }
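
For context, a minimal usage sketch of the two helpers above. This code is not part of the wheel; the argument values are illustrative, and it assumes CAREAlgorithm requires no fields beyond those the factory sets.

# Hypothetical sketch: build a CARE algorithm dict, then validate it into a
# pydantic model via the discriminated union in algorithm_factory.
from careamics.config.ng_factories.algorithm_factory import (
    algorithm_factory,
    create_algorithm_configuration,
)

algo_dict = create_algorithm_configuration(
    dimensions=2,
    algorithm="care",
    loss="mae",
    independent_channels=True,
    n_channels_in=1,
    n_channels_out=1,
)
algo_model = algorithm_factory(algo_dict)
print(type(algo_model).__name__)  # expected: CAREAlgorithm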
@@ -0,0 +1,154 @@
+ """Convenience functions to create NG data configurations."""
+
+ from collections.abc import Sequence
+ from typing import Any, Literal
+
+ from careamics.config.data import NGDataConfig
+ from careamics.config.transformations import (
+     SPATIAL_TRANSFORMS_UNION,
+     XYFlipConfig,
+     XYRandomRotate90Config,
+ )
+
+
+ def list_spatial_augmentations(
+     augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
+ ) -> list[SPATIAL_TRANSFORMS_UNION]:
+     """
+     List the augmentations to apply.
+
+     Parameters
+     ----------
+     augmentations : list of transforms, optional
+         List of transforms to apply, either both or one of XYFlipConfig and
+         XYRandomRotate90Config.
+
+     Returns
+     -------
+     list of transforms
+         List of transforms to apply.
+
+     Raises
+     ------
+     ValueError
+         If the transforms are not XYFlipConfig or XYRandomRotate90Config.
+     ValueError
+         If there are duplicate transforms.
+     """
+     if augmentations is None:
+         transform_list: list[SPATIAL_TRANSFORMS_UNION] = [
+             XYFlipConfig(),
+             XYRandomRotate90Config(),
+         ]
+     else:
+         # throw error if not all transforms are pydantic models
+         if not all(
+             isinstance(t, XYFlipConfig) or isinstance(t, XYRandomRotate90Config)
+             for t in augmentations
+         ):
+             raise ValueError(
+                 "Accepted transforms are either XYFlipConfig or "
+                 "XYRandomRotate90Config."
+             )
+
+         # check that there is no duplication
+         aug_types = [t.__class__ for t in augmentations]
+         if len(set(aug_types)) != len(aug_types):
+             raise ValueError("Duplicate transforms are not allowed.")
+
+         transform_list = augmentations
+
+     return transform_list
+
+
+ def create_ng_data_configuration(
+     data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
+     axes: str,
+     patch_size: Sequence[int],
+     batch_size: int,
+     augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
+     channels: Sequence[int] | None = None,
+     in_memory: bool | None = None,
+     train_dataloader_params: dict[str, Any] | None = None,
+     val_dataloader_params: dict[str, Any] | None = None,
+     pred_dataloader_params: dict[str, Any] | None = None,
+     seed: int | None = None,
+ ) -> NGDataConfig:
+     """
+     Create a training NGDataConfig.
+
+     Parameters
+     ----------
+     data_type : {"array", "tiff", "zarr", "czi", "custom"}
+         Type of the data.
+     axes : str
+         Axes of the data.
+     patch_size : Sequence of int
+         Size of the patches along the spatial dimensions.
+     batch_size : int
+         Batch size.
+     augmentations : list of transforms or None, default=None
+         List of transforms to apply. If `None`, the default augmentations are
+         applied (flip in X and Y, rotations by 90 degrees in the XY plane).
+     channels : Sequence of int, default=None
+         List of channels to use. If `None`, all channels are used.
+     in_memory : bool, default=None
+         Whether to load all data into memory. This is only supported for 'array',
+         'tiff' and 'custom' data types. If `None`, defaults to `True` for 'array',
+         'tiff' and 'custom', and `False` for 'zarr' and 'czi' data types. Must be
+         `True` for 'array'.
+     train_dataloader_params : dict, default=None
+         Parameters for the training dataloader, see PyTorch notes.
+     val_dataloader_params : dict, default=None
+         Parameters for the validation dataloader, see PyTorch notes.
+     pred_dataloader_params : dict, default=None
+         Parameters for the prediction dataloader, see PyTorch notes.
+     seed : int, default=None
+         Random seed for reproducibility. If `None`, no seed is set.
+
+     Returns
+     -------
+     NGDataConfig
+         Next-Generation Data model with the specified parameters.
+     """
+     if augmentations is None:
+         augmentations = list_spatial_augmentations()
+
+     # data model
+     data: dict[str, Any] = {
+         "mode": "training",
+         "data_type": data_type,
+         "axes": axes,
+         "batch_size": batch_size,
+         "channels": channels,
+         "transforms": augmentations,
+         "seed": seed,
+     }
+
+     if in_memory is not None:
+         data["in_memory"] = in_memory
+
+     # don't override defaults set in DataConfig class
+     if train_dataloader_params is not None:
+         # the presence of the `shuffle` key in the dataloader parameters is
+         # enforced by the NGDataConfig class
+         if "shuffle" not in train_dataloader_params:
+             train_dataloader_params["shuffle"] = True
+
+         data["train_dataloader_params"] = train_dataloader_params
+
+     if val_dataloader_params is not None:
+         data["val_dataloader_params"] = val_dataloader_params
+
+     if pred_dataloader_params is not None:
+         data["pred_dataloader_params"] = pred_dataloader_params
+
+     # add training patching
+     data["patching"] = {
+         "name": "random",
+         "patch_size": patch_size,
+     }
+
+     return NGDataConfig(**data)
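
A minimal, hypothetical call to the factory above; the argument values are illustrative only and this snippet is not part of the package sources.

from careamics.config.ng_factories.data_factory import create_ng_data_configuration

data_config = create_ng_data_configuration(
    data_type="tiff",
    axes="SYX",
    patch_size=[64, 64],
    batch_size=8,
    train_dataloader_params={"num_workers": 4},  # factory adds "shuffle": True
    seed=42,
)
# the factory always injects the random patching strategy for training, so
# data_config is built from {"name": "random", "patch_size": [64, 64]}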
@@ -0,0 +1,256 @@
+ """Convenience function to create N2V configurations."""
+
+ from collections.abc import Sequence
+ from typing import Any, Literal
+
+ from careamics.config.ng_configs import N2VConfiguration
+ from careamics.config.support import (
+     SupportedPixelManipulation,
+     SupportedTransform,
+ )
+ from careamics.config.transformations import (
+     N2VManipulateConfig,
+     XYFlipConfig,
+     XYRandomRotate90Config,
+ )
+
+ from .algorithm_factory import create_algorithm_configuration
+ from .data_factory import create_ng_data_configuration, list_spatial_augmentations
+ from .training_factory import create_training_configuration, update_trainer_params
+
+
+ def create_n2v_configuration(
+     experiment_name: str,
+     data_type: Literal["array", "tiff", "zarr", "czi", "custom"],
+     axes: str,
+     patch_size: Sequence[int],
+     batch_size: int,
+     num_epochs: int = 100,
+     num_steps: int | None = None,
+     augmentations: list[XYFlipConfig | XYRandomRotate90Config] | None = None,
+     channels: Sequence[int] | None = None,
+     in_memory: bool | None = None,
+     independent_channels: bool = True,
+     use_n2v2: bool = False,
+     n_channels: int | None = None,
+     roi_size: int = 11,
+     masked_pixel_percentage: float = 0.2,
+     struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
+     struct_n2v_span: int = 5,
+     trainer_params: dict | None = None,
+     logger: Literal["wandb", "tensorboard", "none"] = "none",
+     model_params: dict | None = None,
+     optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
+     optimizer_params: dict[str, Any] | None = None,
+     lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
+     lr_scheduler_params: dict[str, Any] | None = None,
+     train_dataloader_params: dict[str, Any] | None = None,
+     val_dataloader_params: dict[str, Any] | None = None,
+     checkpoint_params: dict[str, Any] | None = None,
+ ) -> N2VConfiguration:
+     """
+     Create a configuration for training Noise2Void.
+
+     N2V uses a UNet model to denoise images in a self-supervised manner. To use its
+     variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
+     (structN2V) parameters, or set `use_n2v2` to True (N2V2).
+
+     N2V2 modifies the UNet architecture by adding blur pool layers and removing the
+     skip connections, thus removing checkerboard artefacts. StructN2V is used when
+     vertical or horizontal correlations are present in the noise; it applies an
+     additional mask to the manipulated pixel neighbors.
+
+     If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
+     otherwise 2.
+
+     If "C" is present in `axes`, then you need to set `n_channels` to the number of
+     channels.
+
+     By default, all channels are trained independently. To train all channels
+     together, set `independent_channels` to False.
+
+     By default, the transformations applied are a random flip along X or Y, and a
+     random 90 degrees rotation in the XY plane. Normalization is always applied, as
+     well as the N2V manipulation.
+
+     By setting `augmentations` to `None`, the default transformations (flip in X
+     and Y, rotations by 90 degrees in the XY plane) are applied. Rather than the
+     default transforms, a list of transforms can be passed to the `augmentations`
+     parameter. To disable the transforms, simply pass an empty list.
+
+     The `roi_size` parameter specifies the size of the area around each pixel that
+     will be manipulated by N2V. The `masked_pixel_percentage` parameter specifies
+     how many pixels per patch will be manipulated.
+
+     The parameters of the UNet can be specified in the `model_params` (passed as a
+     parameter-value dictionary). Note that `use_n2v2` and `n_channels` override the
+     corresponding parameters passed in `model_params`.
+
+     If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then a structN2V
+     mask will be applied to each manipulated pixel.
+
+     Parameters
+     ----------
+     experiment_name : str
+         Name of the experiment.
+     data_type : {"array", "tiff", "zarr", "czi", "custom"}
+         Type of the data.
+     axes : str
+         Axes of the data (e.g. SYX).
+     patch_size : Sequence of int
+         Size of the patches along the spatial dimensions (e.g. [64, 64]).
+     batch_size : int
+         Batch size.
+     num_epochs : int, default=100
+         Number of epochs to train for. If provided, this will be added to
+         trainer_params.
+     num_steps : int, optional
+         Number of batches in 1 epoch. If provided, this will be added to
+         trainer_params. Translates to `limit_train_batches` in the PyTorch
+         Lightning Trainer. See the relevant documentation for more details.
+     augmentations : list of transforms, default=None
+         List of transforms to apply, either both or one of XYFlipConfig and
+         XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
+         and XYRandomRotate90 (in XY) to the images.
+     channels : Sequence of int, optional
+         List of channels to use. If `None`, all channels are used.
+     in_memory : bool, optional
+         Whether to load all data into memory. This is only supported for 'array',
+         'tiff' and 'custom' data types. If `None`, defaults to `True` for 'array',
+         'tiff' and 'custom', and `False` for 'zarr' and 'czi' data types. Must be
+         `True` for 'array'.
+     independent_channels : bool, optional
+         Whether to train all channels independently, by default True.
+     use_n2v2 : bool, optional
+         Whether to use N2V2, by default False.
+     n_channels : int or None, default=None
+         Number of channels (in and out). If `channels` is specified, then the
+         number of channels is inferred from its length.
+     roi_size : int, optional
+         N2V pixel manipulation area, by default 11.
+     masked_pixel_percentage : float, optional
+         Percentage of pixels masked in each patch, by default 0.2.
+     struct_n2v_axis : {"horizontal", "vertical", "none"}, optional
+         Axis along which to apply the structN2V mask, by default "none".
+     struct_n2v_span : int, optional
+         Span of the structN2V mask, by default 5.
+     trainer_params : dict, optional
+         Parameters for the trainer, see the relevant documentation.
+     logger : {"wandb", "tensorboard", "none"}, optional
+         Logger to use, by default "none".
+     model_params : dict, default=None
+         UNetModel parameters.
+     optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
+         Optimizer to use.
+     optimizer_params : dict, default=None
+         Parameters for the optimizer, see PyTorch documentation for more details.
+     lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
+         Learning rate scheduler to use.
+     lr_scheduler_params : dict, default=None
+         Parameters for the learning rate scheduler, see PyTorch documentation for
+         more details.
+     train_dataloader_params : dict, optional
+         Parameters for the training dataloader, see the PyTorch docs for
+         `DataLoader`. If left as `None`, the dict `{"shuffle": True}` will be
+         used, this is set in the `GeneralDataConfig`.
+     val_dataloader_params : dict, optional
+         Parameters for the validation dataloader, see the PyTorch docs for
+         `DataLoader`. If left as `None`, the empty dict `{}` will be used, this
+         is set in the `GeneralDataConfig`.
+     checkpoint_params : dict, default=None
+         Parameters for the checkpoint callback, see PyTorch Lightning
+         documentation (`ModelCheckpoint`) for the list of available parameters.
+
+     Returns
+     -------
+     N2VConfiguration
+         Configuration for training N2V.
+     """
+     # if there are channels, we need to specify their number
+     channels_present = "C" in axes
+
+     if channels_present and (n_channels is None and channels is None):
+         raise ValueError(
+             "`n_channels` or `channels` must be specified when using channels."
+         )
+     elif not channels_present and (n_channels is not None and n_channels > 1):
+         raise ValueError(
+             f"C is not present in the axes, but a number of channels was "
+             f"specified (got {n_channels} channels)."
+         )
+
+     if n_channels is not None and channels is not None:
+         if n_channels != len(channels):
+             raise ValueError(
+                 f"Number of channels ({n_channels}) does not match length of "
+                 f"`channels` ({len(channels)}). Only specify `channels`."
+             )
+
+     if n_channels is None:
+         n_channels = 1 if channels is None else len(channels)
+
+     # augmentations
+     spatial_transforms = list_spatial_augmentations(augmentations)
+
+     # data
+     data_config = create_ng_data_configuration(
+         data_type=data_type,
+         axes=axes,
+         patch_size=patch_size,
+         batch_size=batch_size,
+         augmentations=spatial_transforms,
+         channels=channels,
+         in_memory=in_memory,
+         train_dataloader_params=train_dataloader_params,
+         val_dataloader_params=val_dataloader_params,
+     )
+
+     # algorithm
+     algorithm_params = create_algorithm_configuration(
+         dimensions=3 if data_config.is_3D() else 2,
+         algorithm="n2v",
+         loss="n2v",
+         independent_channels=independent_channels,
+         n_channels_in=n_channels,
+         n_channels_out=n_channels,
+         use_n2v2=use_n2v2,
+         model_params=model_params,
+         optimizer=optimizer,
+         optimizer_params=optimizer_params,
+         lr_scheduler=lr_scheduler,
+         lr_scheduler_params=lr_scheduler_params,
+     )
+
+     # create the N2VManipulate transform using the supplied parameters
+     n2v_transform = N2VManipulateConfig(
+         name=SupportedTransform.N2V_MANIPULATE.value,
+         strategy=(
+             SupportedPixelManipulation.MEDIAN.value
+             if use_n2v2
+             else SupportedPixelManipulation.UNIFORM.value
+         ),
+         roi_size=roi_size,
+         masked_pixel_percentage=masked_pixel_percentage,
+         struct_mask_axis=struct_n2v_axis,
+         struct_mask_span=struct_n2v_span,
+     )
+     algorithm_params["n2v_config"] = n2v_transform
+
+     # training
+     final_trainer_params = update_trainer_params(
+         trainer_params=trainer_params,
+         num_epochs=num_epochs,
+         num_steps=num_steps,
+     )
+     training_params = create_training_configuration(
+         trainer_params=final_trainer_params,
+         logger=logger,
+         checkpoint_params=checkpoint_params,
+     )
+
+     return N2VConfiguration(
+         experiment_name=experiment_name,
+         algorithm_config=algorithm_params,
+         data_config=data_config,
+         training_config=training_params,
+     )
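
An end-to-end sketch of the factory above, with illustrative values only (a 2D single-channel TIFF dataset trained with N2V2); this snippet is not part of the wheel.

from careamics.config.ng_factories.n2v_factory import create_n2v_configuration

config = create_n2v_configuration(
    experiment_name="n2v2_2d_example",
    data_type="tiff",
    axes="SYX",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=50,
    use_n2v2=True,  # switches pixel manipulation to "median" and enables blur pool
)
# for the structN2V variant instead: struct_n2v_axis="horizontal", struct_n2v_span=5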
@@ -0,0 +1,69 @@
+ """Convenience functions to create training configurations."""
+
+ from typing import Any, Literal
+
+ from careamics.config.lightning.training_config import TrainingConfig
+
+
+ def create_training_configuration(
+     trainer_params: dict,
+     logger: Literal["wandb", "tensorboard", "none"],
+     checkpoint_params: dict[str, Any] | None = None,
+ ) -> TrainingConfig:
+     """
+     Create a training configuration model.
+
+     Parameters
+     ----------
+     trainer_params : dict
+         Parameters for the Lightning Trainer class, see PyTorch Lightning
+         documentation.
+     logger : {"wandb", "tensorboard", "none"}
+         Logger to use.
+     checkpoint_params : dict, default=None
+         Parameters for the checkpoint callback, see PyTorch Lightning
+         documentation (`ModelCheckpoint`) for the list of available parameters.
+
+     Returns
+     -------
+     TrainingConfig
+         Training model with the specified parameters.
+     """
+     return TrainingConfig(
+         lightning_trainer_config=trainer_params,
+         logger=None if logger == "none" else logger,
+         checkpoint_callback={} if checkpoint_params is None else checkpoint_params,
+     )
+
+
+ def update_trainer_params(
+     trainer_params: dict[str, Any] | None = None,
+     num_epochs: int | None = None,
+     num_steps: int | None = None,
+ ) -> dict[str, Any]:
+     """
+     Update trainer parameters with num_epochs and num_steps.
+
+     Parameters
+     ----------
+     trainer_params : dict, optional
+         Parameters for the Lightning Trainer class, by default None.
+     num_epochs : int, optional
+         Number of epochs to train for. If provided, this will be added as
+         max_epochs to trainer_params, by default None.
+     num_steps : int, optional
+         Number of batches in 1 epoch. If provided, this will be added as
+         limit_train_batches to trainer_params, by default None.
+
+     Returns
+     -------
+     dict
+         Updated trainer parameters dictionary.
+     """
+     final_trainer_params = {} if trainer_params is None else trainer_params.copy()
+
+     if num_epochs is not None:
+         final_trainer_params["max_epochs"] = num_epochs
+     if num_steps is not None:
+         final_trainer_params["limit_train_batches"] = num_steps
+
+     return final_trainer_params
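
A short sketch combining the two helpers. The trainer and checkpoint values are hypothetical, though "precision", "monitor" and "save_top_k" are standard Lightning `Trainer`/`ModelCheckpoint` keys.

from careamics.config.ng_factories.training_factory import (
    create_training_configuration,
    update_trainer_params,
)

params = update_trainer_params(
    trainer_params={"precision": "16-mixed"},
    num_epochs=100,  # stored as "max_epochs"
    num_steps=200,   # stored as "limit_train_batches"
)
training_config = create_training_configuration(
    trainer_params=params,
    logger="tensorboard",
    checkpoint_params={"monitor": "val_loss", "save_top_k": 3},
)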
@@ -0,0 +1,12 @@
+ """Noise models Pydantic configurations."""
+
+ __all__ = [
+     "GaussianLikelihoodConfig",
+     "GaussianMixtureNMConfig",
+     "MultiChannelNMConfig",
+     "NMLikelihoodConfig",
+ ]
+
+
+ from .likelihood_config import GaussianLikelihoodConfig, NMLikelihoodConfig
+ from .noise_model_config import GaussianMixtureNMConfig, MultiChannelNMConfig
@@ -0,0 +1,60 @@
+ """Likelihood model configurations."""
+
+ from typing import Annotated, Literal, Union
+
+ import numpy as np
+ import torch
+ from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator
+
+ from careamics.models.lvae.noise_models import (
+     GaussianMixtureNoiseModel,
+     MultiChannelNoiseModel,
+ )
+ from careamics.utils.serializers import _array_to_json, _to_torch
+
+ NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+ # TODO: this is a temporary solution to serialize and deserialize tensor fields
+ # in pydantic models. Specifically, the aim is to enable saving and loading
+ # configs with such tensors to/from JSON files during training and evaluation,
+ # respectively.
+ Tensor = Annotated[
+     Union[np.ndarray, torch.Tensor],
+     PlainSerializer(_array_to_json, return_type=str),
+     PlainValidator(_to_torch),
+ ]
+ """Annotated tensor type, used to serialize arrays or tensors to JSON strings
+ and deserialize them back to tensors."""
+
+
+ class GaussianLikelihoodConfig(BaseModel):
+     """Gaussian likelihood configuration."""
+
+     model_config = ConfigDict(validate_assignment=True)
+
+     predict_logvar: Literal["pixelwise"] | None = None
+     """If `pixelwise`, log-variance is computed for each pixel, else log-variance
+     is not computed."""
+
+     logvar_lowerbound: Union[float, None] = None
+     """The lower bound value for the log-variance."""
+
+
+ class NMLikelihoodConfig(BaseModel):
+     """Noise model likelihood configuration.
+
+     NOTE: we need to define the data mean and std here because the noise model
+     is trained on non-normalized data. Hence, we need to unnormalize the model
+     output to compute the noise model likelihood.
+     """
+
+     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
+
+     # TODO remove and use as parameters to the likelihood functions?
+     data_mean: Tensor | None = None
+     """The mean of the data, used to unnormalize data for noise model evaluation.
+     Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
+
+     # TODO remove and use as parameters to the likelihood functions?
+     data_std: Tensor | None = None
+     """The standard deviation of the data, used to unnormalize data for noise
+     model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
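
To illustrate the annotated Tensor type, a hypothetical round trip; this assumes, per the TODO above, that `_to_torch` accepts the JSON-serialized string form produced by `_array_to_json`.

import torch

from careamics.config.noise_model import NMLikelihoodConfig

nm = NMLikelihoodConfig(
    data_mean=torch.tensor([100.0]),  # shape (target_ch,)
    data_std=torch.tensor([15.0]),
)
as_json = nm.model_dump_json()  # tensor fields serialized to JSON strings
restored = NMLikelihoodConfig.model_validate_json(as_json)  # back to tensors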