careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,2400 @@
1
+ """Convenience functions to create configurations for training and inference."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Annotated, Any, Literal, Union
5
+
6
+ from pydantic import Field, TypeAdapter
7
+
8
+ from careamics.config.algorithms import (
9
+ CAREAlgorithm,
10
+ MicroSplitAlgorithm,
11
+ N2NAlgorithm,
12
+ N2VAlgorithm,
13
+ PN2VAlgorithm,
14
+ )
15
+ from careamics.config.architectures import LVAEConfig, UNetConfig
16
+ from careamics.config.data import DataConfig
17
+ from careamics.config.lightning.training_config import TrainingConfig
18
+ from careamics.config.losses.loss_config import LVAELossConfig
19
+ from careamics.config.noise_model.likelihood_config import (
20
+ GaussianLikelihoodConfig,
21
+ NMLikelihoodConfig,
22
+ )
23
+ from careamics.config.noise_model.noise_model_config import (
24
+ GaussianMixtureNMConfig,
25
+ MultiChannelNMConfig,
26
+ )
27
+ from careamics.config.support import (
28
+ SupportedArchitecture,
29
+ SupportedPixelManipulation,
30
+ SupportedTransform,
31
+ )
32
+ from careamics.config.transformations import (
33
+ SPATIAL_TRANSFORMS_UNION,
34
+ N2VManipulateConfig,
35
+ XYFlipConfig,
36
+ XYRandomRotate90Config,
37
+ )
38
+ from careamics.lvae_training.dataset.config import MicroSplitDataConfig
39
+
40
+ from .configuration import Configuration
41
+
42
+
43
def algorithm_factory(
    algorithm: dict[str, Any],
) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm, PN2VAlgorithm]:
    """
    Instantiate the algorithm model described by a dictionary.

    The dictionary is validated against a union of the supported algorithm
    models, the value stored under the `algorithm` key being used as
    discriminator to select the model.

    Parameters
    ----------
    algorithm : dict
        Algorithm dictionary.

    Returns
    -------
    N2VAlgorithm or N2NAlgorithm or CAREAlgorithm or PN2VAlgorithm
        Algorithm model for training CAREamics.
    """
    # union of the candidate models, told apart by their `algorithm` field
    discriminated_algorithms = Annotated[
        Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm, PN2VAlgorithm],
        Field(discriminator="algorithm"),
    ]
    validator: TypeAdapter = TypeAdapter(discriminated_algorithms)

    return validator.validate_python(algorithm)
66
+
67
+
68
def _list_spatial_augmentations(
    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
) -> list[SPATIAL_TRANSFORMS_UNION]:
    """
    Resolve the spatial augmentations to apply.

    When no list is given, both `XYFlipConfig` and `XYRandomRotate90Config`
    are returned with their default parameters. Otherwise the given list is
    checked and returned as is.

    Parameters
    ----------
    augmentations : list of transforms, optional
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config.

    Returns
    -------
    list of transforms
        List of transforms to apply.

    Raises
    ------
    ValueError
        If the transforms are not XYFlipConfig or XYRandomRotate90Config.
    ValueError
        If there are duplicate transforms.
    """
    # nothing requested explicitly: use the default pair of augmentations
    if augmentations is None:
        return [XYFlipConfig(), XYRandomRotate90Config()]

    # reject anything that is not one of the two spatial transform configs
    if any(
        not isinstance(transform, (XYFlipConfig, XYRandomRotate90Config))
        for transform in augmentations
    ):
        raise ValueError(
            "Accepted transforms are either XYFlipConfig or XYRandomRotate90Config."
        )

    # each transform class may appear at most once
    transform_classes = [transform.__class__ for transform in augmentations]
    if len(transform_classes) != len(set(transform_classes)):
        raise ValueError("Duplicate transforms are not allowed.")

    return augmentations
116
+
117
+
118
def _create_unet_configuration(
    axes: str,
    n_channels_in: int,
    n_channels_out: int,
    independent_channels: bool,
    use_n2v2: bool,
    model_params: dict[str, Any] | None = None,
) -> UNetConfig:
    """
    Create a UNet configuration from the given parameters.

    The entries derived from the explicit arguments (`n2v2`, `conv_dims`,
    `in_channels`, `num_classes` and `independent_channels`) take precedence
    over entries with the same name in `model_params`. The dictionary passed as
    `model_params` is left untouched.

    Parameters
    ----------
    axes : str
        Axes of the data.
    n_channels_in : int
        Number of input channels.
    n_channels_out : int
        Number of output channels.
    independent_channels : bool
        Whether to train all channels independently.
    use_n2v2 : bool
        Whether to use N2V2.
    model_params : dict, default=None
        Additional UNet configuration parameters.

    Returns
    -------
    UNetConfig
        UNet configuration with the specified parameters.
    """
    # work on a shallow copy so that the caller's dictionary is never modified
    unet_params: dict[str, Any] = {} if model_params is None else dict(model_params)

    # values derived from the explicit arguments override user-provided entries
    unet_params.update(
        n2v2=use_n2v2,
        conv_dims=3 if "Z" in axes else 2,  # 3D convolutions only with a Z axis
        in_channels=n_channels_in,
        num_classes=n_channels_out,
        independent_channels=independent_channels,
    )

    return UNetConfig(
        architecture=SupportedArchitecture.UNET.value,
        **unet_params,
    )
162
+
163
+
164
def _create_algorithm_configuration(
    axes: str,
    algorithm: Literal["n2v", "care", "n2n", "pn2v"],
    loss: Literal["n2v", "mae", "mse", "pn2v"],
    independent_channels: bool,
    n_channels_in: int,
    n_channels_out: int,
    use_n2v2: bool = False,
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
) -> dict:
    """
    Assemble the parameters of an algorithm model into a dictionary.

    The returned dictionary holds the algorithm name, the loss, a UNet
    configuration built by `_create_unet_configuration`, and the optimizer and
    learning rate scheduler descriptions (`name` and `parameters`).

    Parameters
    ----------
    axes : str
        Axes of the data.
    algorithm : {"n2v", "care", "n2n", "pn2v"}
        Algorithm to use.
    loss : {"n2v", "mae", "mse", "pn2v"}
        Loss function to use.
    independent_channels : bool
        Whether to train all channels independently.
    n_channels_in : int
        Number of input channels.
    n_channels_out : int
        Number of output channels.
    use_n2v2 : bool, default=False
        Whether to use N2V2.
    model_params : dict, default=None
        UNet configuration parameters.
    optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.

    Returns
    -------
    dict
        Algorithm model as a dictionary with the specified parameters.
    """
    # architecture
    model_config = _create_unet_configuration(
        axes=axes,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        independent_channels=independent_channels,
        use_n2v2=use_n2v2,
        model_params=model_params,
    )

    # optimizer and scheduler, with empty parameters when none are given
    optimizer_config = {
        "name": optimizer,
        "parameters": optimizer_params if optimizer_params is not None else {},
    }
    scheduler_config = {
        "name": lr_scheduler,
        "parameters": (
            lr_scheduler_params if lr_scheduler_params is not None else {}
        ),
    }

    return {
        "algorithm": algorithm,
        "loss": loss,
        "model": model_config,
        "optimizer": optimizer_config,
        "lr_scheduler": scheduler_config,
    }
238
+
239
+
240
def _create_data_configuration(
    data_type: Literal["array", "tiff", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    augmentations: list[SPATIAL_TRANSFORMS_UNION],
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
) -> DataConfig:
    """
    Create a data configuration from the given parameters.

    The dictionary passed as `train_dataloader_params` is not modified: the
    default `shuffle` entry is added to a copy.

    Parameters
    ----------
    data_type : {"array", "tiff", "czi", "custom"}
        Type of the data.
    axes : str
        Axes of the data.
    patch_size : list of int
        Size of the patches along the spatial dimensions.
    batch_size : int
        Batch size.
    augmentations : list of transforms
        List of transforms to apply.
    train_dataloader_params : dict, default=None
        Parameters for the training dataloader, see PyTorch notes. If `shuffle`
        is absent, it is set to True.
    val_dataloader_params : dict, default=None
        Parameters for the validation dataloader, see PyTorch notes.

    Returns
    -------
    DataConfig
        Data model with the specified parameters.
    """
    # data model
    data: dict[str, Any] = {
        "data_type": data_type,
        "axes": axes,
        "patch_size": patch_size,
        "batch_size": batch_size,
        "transforms": augmentations,
    }

    # Dataloader parameters are only passed when provided, so as not to
    # override the defaults set in the DataConfig class
    if train_dataloader_params is not None:
        # copy first: adding `shuffle` must not leak into the caller's dictionary
        train_params = dict(train_dataloader_params)

        # DataConfig enforces the presence of `shuffle` key in the dataloader
        # parameters
        train_params.setdefault("shuffle", True)

        data["train_dataloader_params"] = train_params

    if val_dataloader_params is not None:
        data["val_dataloader_params"] = val_dataloader_params

    return DataConfig(**data)
294
+
295
+
296
def _create_microsplit_data_configuration(
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: Sequence[int],
    grid_size: int,
    multiscale_count: int,
    batch_size: int,
    augmentations: list[SPATIAL_TRANSFORMS_UNION],
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
) -> MicroSplitDataConfig:
    """
    Create a MicroSplit data configuration from the given parameters.

    `patch_size` is passed to the configuration as `image_size` and
    `multiscale_count` as `multiscale_lowres_count`. The dictionary passed as
    `train_dataloader_params` is not modified: the default `shuffle` entry is
    added to a copy.

    Parameters
    ----------
    data_type : {"array", "tiff", "custom"}
        Type of the data.
    axes : str
        Axes of the data.
    patch_size : list of int
        Size of the patches along the spatial dimensions.
    grid_size : int
        Size of the grid for multiscale data configuration.
    multiscale_count : int
        Number of multiscale levels.
    batch_size : int
        Batch size.
    augmentations : list of transforms
        List of transforms to apply.
    train_dataloader_params : dict, default=None
        Parameters for the training dataloader, see PyTorch notes. If `shuffle`
        is absent, it is set to True.
    val_dataloader_params : dict, default=None
        Parameters for the validation dataloader, see PyTorch notes.

    Returns
    -------
    MicroSplitDataConfig
        Data model with the specified parameters.
    """
    # data model
    data: dict[str, Any] = {
        "data_type": data_type,
        "axes": axes,
        "image_size": patch_size,
        "grid_size": grid_size,
        "multiscale_lowres_count": multiscale_count,
        "batch_size": batch_size,
        "transforms": augmentations,
    }

    # Dataloader parameters are only passed when provided, so as not to
    # override the defaults of the configuration class
    if train_dataloader_params is not None:
        # copy first: adding `shuffle` must not leak into the caller's dictionary
        train_params = dict(train_dataloader_params)

        # same default as in `_create_data_configuration`; presumably
        # MicroSplitDataConfig expects the `shuffle` key as well - TODO confirm
        train_params.setdefault("shuffle", True)

        data["train_dataloader_params"] = train_params

    if val_dataloader_params is not None:
        data["val_dataloader_params"] = val_dataloader_params

    return MicroSplitDataConfig(**data)
358
+
359
+
360
def _create_training_configuration(
    trainer_params: dict,
    logger: Literal["wandb", "tensorboard", "none"],
    checkpoint_params: dict[str, Any] | None = None,
) -> TrainingConfig:
    """
    Build the training configuration.

    Parameters
    ----------
    trainer_params : dict
        Parameters for Lightning Trainer class, see PyTorch Lightning documentation.
    logger : {"wandb", "tensorboard", "none"}
        Logger to use. The value "none" is passed to the configuration as None.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters. An empty
        dictionary is used when None.

    Returns
    -------
    TrainingConfig
        Training model with the specified parameters.
    """
    # the "none" keyword stands for the absence of a logger
    selected_logger = logger if logger != "none" else None

    # no checkpoint parameters means an empty checkpoint callback description
    checkpoint_config = checkpoint_params if checkpoint_params is not None else {}

    return TrainingConfig(
        lightning_trainer_config=trainer_params,
        logger=selected_logger,
        checkpoint_callback=checkpoint_config,
    )
388
+
389
+
390
+ def update_trainer_params(
391
+ trainer_params: dict[str, Any] | None = None,
392
+ num_epochs: int | None = None,
393
+ num_steps: int | None = None,
394
+ ) -> dict[str, Any]:
395
+ """
396
+ Update trainer parameters with num_epochs and num_steps.
397
+
398
+ Parameters
399
+ ----------
400
+ trainer_params : dict, optional
401
+ Parameters for Lightning Trainer class, by default None.
402
+ num_epochs : int, optional
403
+ Number of epochs to train for. If provided, this will be added as max_epochs
404
+ to trainer_params, by default None.
405
+ num_steps : int, optional
406
+ Number of batches in 1 epoch. If provided, this will be added as
407
+ limit_train_batches to trainer_params, by default None.
408
+
409
+ Returns
410
+ -------
411
+ dict
412
+ Updated trainer parameters dictionary.
413
+ """
414
+ final_trainer_params = {} if trainer_params is None else trainer_params.copy()
415
+
416
+ if num_epochs is not None:
417
+ final_trainer_params["max_epochs"] = num_epochs
418
+ if num_steps is not None:
419
+ final_trainer_params["limit_train_batches"] = num_steps
420
+
421
+ return final_trainer_params
422
+
423
+
424
+ # TODO reconsider naming once we officially support LVAE approaches
425
def _create_supervised_config_dict(
    algorithm: Literal["care", "n2n"],
    experiment_name: str,
    data_type: Literal["array", "tiff", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    trainer_params: dict | None = None,
    augmentations: list[SPATIAL_TRANSFORMS_UNION] | None = None,
    independent_channels: bool = True,
    loss: Literal["mae", "mse"] = "mae",
    n_channels_in: int | None = None,
    n_channels_out: int | None = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    checkpoint_params: dict[str, Any] | None = None,
    num_epochs: int | None = None,
    num_steps: int | None = None,
) -> dict:
    """
    Create the configuration dictionary for training CARE or Noise2Noise.

    The returned dictionary holds the keyword arguments expected by
    `Configuration`; it is not a `Configuration` instance itself.

    Parameters
    ----------
    algorithm : Literal["care", "n2n"]
        Algorithm to use.
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "czi", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    trainer_params : dict, optional
        Parameters for the training configuration, by default None. The dictionary
        is copied before `num_epochs` and `num_steps` are applied to it.
    augmentations : list of transforms, default=None
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
        and XYRandomRotate90 (in XY) to the images.
    independent_channels : bool, optional
        Whether to train all channels independently, by default True.
    loss : Literal["mae", "mse"], optional
        Loss function to use, by default "mae".
    n_channels_in : int or None, default=None
        Number of channels in. Required if "C" is in `axes`; otherwise it defaults
        to 1.
    n_channels_out : int or None, default=None
        Number of channels out. If None, it is set to `n_channels_in`.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_params : dict, default=None
        UNetModel parameters.
    optimizer : {"Adam", "Adamax", "SGD"}, default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : {"ReduceLROnPlateau", "StepLR"}, default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see PyTorch notes, by default None.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see PyTorch notes, by default None.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters.
    num_epochs : int or None, default=None
        Number of epochs to train for. If provided, it is stored as `max_epochs` in
        the trainer parameters, replacing any value passed in `trainer_params`.
    num_steps : int or None, default=None
        Number of batches in 1 epoch. If provided, it is stored as
        `limit_train_batches` in the trainer parameters, replacing any value passed
        in `trainer_params`. See the PyTorch Lightning Trainer documentation for
        more details.

    Returns
    -------
    dict
        Dictionary with the keys "experiment_name", "algorithm_config",
        "data_config" and "training_config".

    Raises
    ------
    ValueError
        If "C" is in the axes but the number of input channels is not specified.
    ValueError
        If more than one input channel is specified but "C" is not in the axes.
    """
    # a channel axis requires an explicit channel count, and a multi-channel count
    # requires a channel axis (a count of 1 without "C" is accepted)
    has_channel_axis = "C" in axes
    if has_channel_axis and n_channels_in is None:
        raise ValueError("Number of channels in must be specified when using channels ")
    if not has_channel_axis and n_channels_in is not None and n_channels_in > 1:
        raise ValueError(
            f"C is not present in the axes, but number of channels is specified "
            f"(got {n_channels_in} channels)."
        )

    # single channel unless told otherwise
    if n_channels_in is None:
        n_channels_in = 1

    # output mirrors the input unless told otherwise
    if n_channels_out is None:
        n_channels_out = n_channels_in

    # augmentations
    spatial_transform_list = _list_spatial_augmentations(augmentations)

    # algorithm
    algorithm_params = _create_algorithm_configuration(
        axes=axes,
        algorithm=algorithm,
        loss=loss,
        independent_channels=independent_channels,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        model_params=model_params,
        optimizer=optimizer,
        optimizer_params=optimizer_params,
        lr_scheduler=lr_scheduler,
        lr_scheduler_params=lr_scheduler_params,
    )

    # data
    data_params = _create_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        augmentations=spatial_transform_list,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # training: fold num_epochs / num_steps into a copy of the trainer parameters
    final_trainer_params = update_trainer_params(
        trainer_params=trainer_params,
        num_epochs=num_epochs,
        num_steps=num_steps,
    )
    training_params = _create_training_configuration(
        trainer_params=final_trainer_params,
        logger=logger,
        checkpoint_params=checkpoint_params,
    )

    return {
        "experiment_name": experiment_name,
        "algorithm_config": algorithm_params,
        "data_config": data_params,
        "training_config": training_params,
    }
583
+
584
+
585
def create_care_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    independent_channels: bool = True,
    loss: Literal["mae", "mse"] = "mae",
    n_channels_in: int | None = None,
    n_channels_out: int | None = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    trainer_params: dict | None = None,
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    checkpoint_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training CARE.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
    2.

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set more than one input channel, then "C" must be
    present in `axes`.

    To set the number of output channels, use the `n_channels_out` parameter. If it is
    not specified, it will be assumed to be equal to `n_channels_in`.

    By default, all channels are trained independently. To train all channels together,
    set `independent_channels` to False.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "czi", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. It is stored as `max_epochs` in the trainer
        parameters and takes precedence over a `max_epochs` entry in
        `trainer_params`.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    augmentations : list of transforms, default=None
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
        and XYRandomRotate90 (in XY) to the images.
    independent_channels : bool, optional
        Whether to train all channels independently, by default True.
    loss : Literal["mae", "mse"], default="mae"
        Loss function to use.
    n_channels_in : int or None, default=None
        Number of channels in.
    n_channels_out : int or None, default=None
        Number of channels out.
    logger : Literal["wandb", "tensorboard", "none"], default="none"
        Logger to use.
    trainer_params : dict, optional
        Parameters for the trainer class, see PyTorch Lightning documentation.
    model_params : dict, default=None
        UNetModel parameters.
    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
        the `GeneralDataConfig`.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the empty dict `{}` will be used, this is set in the
        `GeneralDataConfig`.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters.

    Returns
    -------
    Configuration
        Configuration for training CARE.

    Examples
    --------
    Minimum example:
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    You can also limit the number of batches per epoch:
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_steps=100  # limit to 100 batches per epoch
    ... )

    To disable transforms, simply set `augmentations` to an empty list:
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     augmentations=[]
    ... )

    A list of transforms can be passed to the `augmentations` parameter to replace the
    default augmentations:
    >>> from careamics.config.transformations import XYFlipConfig
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     augmentations=[
    ...         # No rotation and only Y flipping
    ...         XYFlipConfig(flip_x=False, flip_y=True)
    ...     ]
    ... )

    If you are training multiple channels they will be trained independently by default,
    you simply need to specify the number of channels input (and optionally, the number
    of channels output):
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YXC",  # channels must be in the axes
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels_in=3,  # number of input channels
    ...     n_channels_out=1  # if applicable
    ... )

    If instead you want to train multiple channels together, you need to turn off the
    `independent_channels` parameter:
    >>> config = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="array",
    ...     axes="YXC",  # channels must be in the axes
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     independent_channels=False,
    ...     n_channels_in=3,
    ...     n_channels_out=1  # if applicable
    ... )

    If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
    `axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
    for 3-D data but spatial context along the Z dimension will then not be taken into
    account.
    >>> config_2d = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="czi",
    ...     axes="SCYX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels_in=1,
    ... )
    >>> config_3d = create_care_configuration(
    ...     experiment_name="care_experiment",
    ...     data_type="czi",
    ...     axes="SCZYX",
    ...     patch_size=[16, 64, 64],
    ...     batch_size=16,
    ...     num_epochs=100,
    ...     n_channels_in=1,
    ... )
    """
    # the shared supervised helper does the validation and assembles the sections;
    # only the algorithm name is specific to this factory
    config_dict = _create_supervised_config_dict(
        algorithm="care",
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        trainer_params=trainer_params,
        augmentations=augmentations,
        independent_channels=independent_channels,
        loss=loss,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        logger=logger,
        model_params=model_params,
        optimizer=optimizer,
        optimizer_params=optimizer_params,
        lr_scheduler=lr_scheduler,
        lr_scheduler_params=lr_scheduler_params,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
        checkpoint_params=checkpoint_params,
        num_epochs=num_epochs,
        num_steps=num_steps,
    )
    return Configuration(**config_dict)
820
+
821
+
822
def create_n2n_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    independent_channels: bool = True,
    loss: Literal["mae", "mse"] = "mae",
    n_channels_in: int | None = None,
    n_channels_out: int | None = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    trainer_params: dict | None = None,
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    checkpoint_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training Noise2Noise.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
    2.

    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
    channels. Likewise, if you set more than one input channel, then "C" must be
    present in `axes`.

    To set the number of output channels, use the `n_channels_out` parameter. If it is
    not specified, it will be assumed to be equal to `n_channels_in`.

    By default, all channels are trained independently. To train all channels together,
    set `independent_channels` to False.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "czi", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. It is stored as `max_epochs` in the trainer
        parameters and takes precedence over a `max_epochs` entry in
        `trainer_params`.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    augmentations : list of transforms, default=None
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
        and XYRandomRotate90 (in XY) to the images.
    independent_channels : bool, optional
        Whether to train all channels independently, by default True.
    loss : Literal["mae", "mse"], optional
        Loss function to use, by default "mae".
    n_channels_in : int or None, default=None
        Number of channels in.
    n_channels_out : int or None, default=None
        Number of channels out.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    trainer_params : dict, optional
        Parameters for the trainer class, see PyTorch Lightning documentation.
    model_params : dict, default=None
        UNetModel parameters.
    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
        the `GeneralDataConfig`.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the empty dict `{}` will be used, this is set in the
        `GeneralDataConfig`.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters.

    Returns
    -------
    Configuration
        Configuration for training Noise2Noise.

    Examples
    --------
    Minimum example:
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    You can also limit the number of batches per epoch:
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_steps=100  # limit to 100 batches per epoch
    ... )

    To disable transforms, simply set `augmentations` to an empty list:
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     augmentations=[]
    ... )

    A list of transforms can be passed to the `augmentations` parameter:
    >>> from careamics.config.transformations import XYFlipConfig
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     augmentations=[
    ...         # No rotation and only Y flipping
    ...         XYFlipConfig(flip_x=False, flip_y=True)
    ...     ]
    ... )

    If you are training multiple channels they will be trained independently by default,
    you simply need to specify the number of channels input (and optionally, the number
    of channels output):
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YXC",  # channels must be in the axes
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels_in=3,  # number of input channels
    ...     n_channels_out=1  # if applicable
    ... )

    If instead you want to train multiple channels together, you need to turn off the
    `independent_channels` parameter:
    >>> config = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="array",
    ...     axes="YXC",  # channels must be in the axes
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     independent_channels=False,
    ...     n_channels_in=3,
    ...     n_channels_out=1  # if applicable
    ... )

    If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
    `axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
    for 3-D data but spatial context along the Z dimension will then not be taken into
    account.
    >>> config_2d = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="czi",
    ...     axes="SCYX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100,
    ...     n_channels_in=1,
    ... )
    >>> config_3d = create_n2n_configuration(
    ...     experiment_name="n2n_experiment",
    ...     data_type="czi",
    ...     axes="SCZYX",
    ...     patch_size=[16, 64, 64],
    ...     batch_size=16,
    ...     num_epochs=100,
    ...     n_channels_in=1,
    ... )
    """
    # the shared supervised helper does the validation and assembles the sections;
    # only the algorithm name is specific to this factory
    config_dict = _create_supervised_config_dict(
        algorithm="n2n",
        experiment_name=experiment_name,
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        trainer_params=trainer_params,
        augmentations=augmentations,
        independent_channels=independent_channels,
        loss=loss,
        n_channels_in=n_channels_in,
        n_channels_out=n_channels_out,
        logger=logger,
        model_params=model_params,
        optimizer=optimizer,
        optimizer_params=optimizer_params,
        lr_scheduler=lr_scheduler,
        lr_scheduler_params=lr_scheduler_params,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
        checkpoint_params=checkpoint_params,
        num_epochs=num_epochs,
        num_steps=num_steps,
    )
    return Configuration(**config_dict)
1056
+
1057
+
1058
+ def create_n2v_configuration(
1059
+ experiment_name: str,
1060
+ data_type: Literal["array", "tiff", "czi", "custom"],
1061
+ axes: str,
1062
+ patch_size: Sequence[int],
1063
+ batch_size: int,
1064
+ num_epochs: int = 100,
1065
+ num_steps: int | None = None,
1066
+ augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
1067
+ independent_channels: bool = True,
1068
+ use_n2v2: bool = False,
1069
+ n_channels: int | None = None,
1070
+ roi_size: int = 11,
1071
+ masked_pixel_percentage: float = 0.2,
1072
+ struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
1073
+ struct_n2v_span: int = 5,
1074
+ trainer_params: dict | None = None,
1075
+ logger: Literal["wandb", "tensorboard", "none"] = "none",
1076
+ model_params: dict | None = None,
1077
+ optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
1078
+ optimizer_params: dict[str, Any] | None = None,
1079
+ lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
1080
+ lr_scheduler_params: dict[str, Any] | None = None,
1081
+ train_dataloader_params: dict[str, Any] | None = None,
1082
+ val_dataloader_params: dict[str, Any] | None = None,
1083
+ checkpoint_params: dict[str, Any] | None = None,
1084
+ ) -> Configuration:
1085
+ """
1086
+ Create a configuration for training Noise2Void.
1087
+
1088
+ N2V uses a UNet model to denoise images in a self-supervised manner. To use its
1089
+ variants structN2V and N2V2, set the `struct_n2v_axis` and `struct_n2v_span`
1090
+ (structN2V) parameters, or set `use_n2v2` to True (N2V2).
1091
+
1092
+ N2V2 modifies the UNet architecture by adding blur pool layers and removes the skip
1093
+ connections, thus removing checkerboard artefacts. StructN2V is used when vertical
1094
+ or horizontal correlations are present in the noise; it applies an additional mask
1095
+ to the manipulated pixel neighbors.
1096
+
1097
+ If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
1098
+ 2.
1099
+
1100
+ If "C" is present in `axes`, then you need to set `n_channels` to the number of
1101
+ channels.
1102
+
1103
+ By default, all channels are trained independently. To train all channels together,
1104
+ set `independent_channels` to False.
1105
+
1106
+ By default, the transformations applied are a random flip along X or Y, and a random
1107
+ 90 degrees rotation in the XY plane. Normalization is always applied, as well as the
1108
+ N2V manipulation.
1109
+
1110
+ By setting `augmentations` to `None`, the default transformations (flip in X and Y,
1111
+ rotations by 90 degrees in the XY plane) are applied. Rather than the default
1112
+ transforms, a list of transforms can be passed to the `augmentations` parameter. To
1113
+ disable the transforms, simply pass an empty list.
1114
+
1115
+ The `roi_size` parameter specifies the size of the area around each pixel that will
1116
+ be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
1117
+ pixels per patch will be manipulated.
1118
+
1119
+ The parameters of the UNet can be specified in the `model_params` (passed as a
1120
+ parameter-value dictionary). Note that `use_n2v2` and 'n_channels' override the
1121
+ corresponding parameters passed in `model_params`.
1122
+
1123
+ If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
1124
+ will be applied to each manipulated pixel.
1125
+
1126
+ Parameters
1127
+ ----------
1128
+ experiment_name : str
1129
+ Name of the experiment.
1130
+ data_type : Literal["array", "tiff", "czi", "custom"]
1131
+ Type of the data.
1132
+ axes : str
1133
+ Axes of the data (e.g. SYX).
1134
+ patch_size : List[int]
1135
+ Size of the patches along the spatial dimensions (e.g. [64, 64]).
1136
+ batch_size : int
1137
+ Batch size.
1138
+ num_epochs : int, default=100
1139
+ Number of epochs to train for. If provided, this will be added to
1140
+ trainer_params.
1141
+ num_steps : int, optional
1142
+ Number of batches in 1 epoch. If provided, this will be added to trainer_params.
1143
+ Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
1144
+ documentation for more details.
1145
+ augmentations : list of transforms, default=None
1146
+ List of transforms to apply, either both or one of XYFlipConfig and
1147
+ XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
1148
+ and XYRandomRotate90 (in XY) to the images.
1149
+ independent_channels : bool, optional
1150
+ Whether to train all channels independently, by default True.
1151
+ use_n2v2 : bool, optional
1152
+ Whether to use N2V2, by default False.
1153
+ n_channels : int or None, default=None
1154
+ Number of channels (in and out).
1155
+ roi_size : int, optional
1156
+ N2V pixel manipulation area, by default 11.
1157
+ masked_pixel_percentage : float, optional
1158
+ Percentage of pixels masked in each patch, by default 0.2.
1159
+ struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
1160
+ Axis along which to apply structN2V mask, by default "none".
1161
+ struct_n2v_span : int, optional
1162
+ Span of the structN2V mask, by default 5.
1163
+ trainer_params : dict, optional
1164
+ Parameters for the trainer, see the relevant documentation.
1165
+ logger : Literal["wandb", "tensorboard", "none"], optional
1166
+ Logger to use, by default "none".
1167
+ model_params : dict, default=None
1168
+ UNetModel parameters.
1169
+ optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
1170
+ Optimizer to use.
1171
+ optimizer_params : dict, default=None
1172
+ Parameters for the optimizer, see PyTorch documentation for more details.
1173
+ lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
1174
+ Learning rate scheduler to use.
1175
+ lr_scheduler_params : dict, default=None
1176
+ Parameters for the learning rate scheduler, see PyTorch documentation for more
1177
+ details.
1178
+ train_dataloader_params : dict, optional
1179
+ Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
1180
+ If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
1181
+ the `GeneralDataConfig`.
1182
+ val_dataloader_params : dict, optional
1183
+ Parameters for the validation dataloader, see the PyTorch docs for `DataLoader`.
1184
+ If left as `None`, the empty dict `{}` will be used, this is set in the
1185
+ `GeneralDataConfig`.
1186
+ checkpoint_params : dict, default=None
1187
+ Parameters for the checkpoint callback, see PyTorch Lightning documentation
1188
+ (`ModelCheckpoint`) for the list of available parameters.
1189
+
1190
+ Returns
1191
+ -------
1192
+ Configuration
1193
+ Configuration for training N2V.
1194
+
1195
+ Examples
1196
+ --------
1197
+ Minimum example:
1198
+ >>> config = create_n2v_configuration(
1199
+ ... experiment_name="n2v_experiment",
1200
+ ... data_type="array",
1201
+ ... axes="YX",
1202
+ ... patch_size=[64, 64],
1203
+ ... batch_size=32,
1204
+ ... num_epochs=100
1205
+ ... )
1206
+
1207
+ You can also limit the number of batches per epoch:
1208
+ >>> config = create_n2v_configuration(
1209
+ ... experiment_name="n2v_experiment",
1210
+ ... data_type="array",
1211
+ ... axes="YX",
1212
+ ... patch_size=[64, 64],
1213
+ ... batch_size=32,
1214
+ ... num_steps=100 # limit to 100 batches per epoch
1215
+ ... )
1216
+
1217
+ To disable transforms, simply set `augmentations` to an empty list:
1218
+ >>> config = create_n2v_configuration(
1219
+ ... experiment_name="n2v_experiment",
1220
+ ... data_type="array",
1221
+ ... axes="YX",
1222
+ ... patch_size=[64, 64],
1223
+ ... batch_size=32,
1224
+ ... num_epochs=100,
1225
+ ... augmentations=[]
1226
+ ... )
1227
+
1228
+ A list of transforms can be passed to the `augmentations` parameter:
1229
+ >>> from careamics.config.transformations import XYFlipConfig
1230
+ >>> config = create_n2v_configuration(
1231
+ ... experiment_name="n2v_experiment",
1232
+ ... data_type="array",
1233
+ ... axes="YX",
1234
+ ... patch_size=[64, 64],
1235
+ ... batch_size=32,
1236
+ ... num_epochs=100,
1237
+ ... augmentations=[
1238
+ ... # No rotation and only Y flipping
1239
+ ... XYFlipConfig(flip_x = False, flip_y = True)
1240
+ ... ]
1241
+ ... )
1242
+
1243
+ To use N2V2, simply pass the `use_n2v2` parameter:
1244
+ >>> config = create_n2v_configuration(
1245
+ ... experiment_name="n2v2_experiment",
1246
+ ... data_type="tiff",
1247
+ ... axes="YX",
1248
+ ... patch_size=[64, 64],
1249
+ ... batch_size=32,
1250
+ ... num_epochs=100,
1251
+ ... use_n2v2=True
1252
+ ... )
1253
+
1254
+ For structN2V, there are two parameters to set, `struct_n2v_axis` and
1255
+ `struct_n2v_span`:
1256
+ >>> config = create_n2v_configuration(
1257
+ ... experiment_name="structn2v_experiment",
1258
+ ... data_type="tiff",
1259
+ ... axes="YX",
1260
+ ... patch_size=[64, 64],
1261
+ ... batch_size=32,
1262
+ ... num_epochs=100,
1263
+ ... struct_n2v_axis="horizontal",
1264
+ ... struct_n2v_span=7
1265
+ ... )
1266
+
1267
+ If you are training multiple channels they will be trained independently by default,
1268
+ you simply need to specify the number of channels:
1269
+ >>> config = create_n2v_configuration(
1270
+ ... experiment_name="n2v_experiment",
1271
+ ... data_type="array",
1272
+ ... axes="YXC",
1273
+ ... patch_size=[64, 64],
1274
+ ... batch_size=32,
1275
+ ... num_epochs=100,
1276
+ ... n_channels=3
1277
+ ... )
1278
+
1279
+ If instead you want to train multiple channels together, you need to turn off the
1280
+ `independent_channels` parameter:
1281
+ >>> config = create_n2v_configuration(
1282
+ ... experiment_name="n2v_experiment",
1283
+ ... data_type="array",
1284
+ ... axes="YXC",
1285
+ ... patch_size=[64, 64],
1286
+ ... batch_size=32,
1287
+ ... num_epochs=100,
1288
+ ... independent_channels=False,
1289
+ ... n_channels=3
1290
+ ... )
1291
+
1292
+ If you would like to train on CZI files, use `"czi"` as `data_type` and `"SCYX"` as
1293
+ `axes` for 2-D or `"SCZYX"` for 3-D denoising. Note that `"SCYX"` can also be used
1294
+ for 3-D data but spatial context along the Z dimension will then not be taken into
1295
+ account.
1296
+ >>> config_2d = create_n2v_configuration(
1297
+ ... experiment_name="n2v_experiment",
1298
+ ... data_type="czi",
1299
+ ... axes="SCYX",
1300
+ ... patch_size=[64, 64],
1301
+ ... batch_size=32,
1302
+ ... num_epochs=100,
1303
+ ... n_channels=1,
1304
+ ... )
1305
+ >>> config_3d = create_n2v_configuration(
1306
+ ... experiment_name="n2v_experiment",
1307
+ ... data_type="czi",
1308
+ ... axes="SCZYX",
1309
+ ... patch_size=[16, 64, 64],
1310
+ ... batch_size=16,
1311
+ ... num_epochs=100,
1312
+ ... n_channels=1,
1313
+ ... )
1314
+ """
1315
+ # if there are channels, we need to specify their number
1316
+ if "C" in axes and n_channels is None:
1317
+ raise ValueError("Number of channels must be specified when using channels.")
1318
+ elif "C" not in axes and (n_channels is not None and n_channels > 1):
1319
+ raise ValueError(
1320
+ f"C is not present in the axes, but number of channels is specified "
1321
+ f"(got {n_channels} channel)."
1322
+ )
1323
+
1324
+ if n_channels is None:
1325
+ n_channels = 1
1326
+
1327
+ # augmentations
1328
+ spatial_transforms = _list_spatial_augmentations(augmentations)
1329
+
1330
+ # create the N2VManipulate transform using the supplied parameters
1331
+ n2v_transform = N2VManipulateConfig(
1332
+ name=SupportedTransform.N2V_MANIPULATE.value,
1333
+ strategy=(
1334
+ SupportedPixelManipulation.MEDIAN.value
1335
+ if use_n2v2
1336
+ else SupportedPixelManipulation.UNIFORM.value
1337
+ ),
1338
+ roi_size=roi_size,
1339
+ masked_pixel_percentage=masked_pixel_percentage,
1340
+ struct_mask_axis=struct_n2v_axis,
1341
+ struct_mask_span=struct_n2v_span,
1342
+ )
1343
+
1344
+ # algorithm
1345
+ algorithm_params = _create_algorithm_configuration(
1346
+ axes=axes,
1347
+ algorithm="n2v",
1348
+ loss="n2v",
1349
+ independent_channels=independent_channels,
1350
+ n_channels_in=n_channels,
1351
+ n_channels_out=n_channels,
1352
+ use_n2v2=use_n2v2,
1353
+ model_params=model_params,
1354
+ optimizer=optimizer,
1355
+ optimizer_params=optimizer_params,
1356
+ lr_scheduler=lr_scheduler,
1357
+ lr_scheduler_params=lr_scheduler_params,
1358
+ )
1359
+ algorithm_params["n2v_config"] = n2v_transform
1360
+
1361
+ # data
1362
+ data_params = _create_data_configuration(
1363
+ data_type=data_type,
1364
+ axes=axes,
1365
+ patch_size=patch_size,
1366
+ batch_size=batch_size,
1367
+ augmentations=spatial_transforms,
1368
+ train_dataloader_params=train_dataloader_params,
1369
+ val_dataloader_params=val_dataloader_params,
1370
+ )
1371
+
1372
+ # training
1373
+ final_trainer_params = update_trainer_params(
1374
+ trainer_params=trainer_params,
1375
+ num_epochs=num_epochs,
1376
+ num_steps=num_steps,
1377
+ )
1378
+ training_params = _create_training_configuration(
1379
+ trainer_params=final_trainer_params,
1380
+ logger=logger,
1381
+ checkpoint_params=checkpoint_params,
1382
+ )
1383
+
1384
+ return Configuration(
1385
+ experiment_name=experiment_name,
1386
+ algorithm_config=algorithm_params,
1387
+ data_config=data_params,
1388
+ training_config=training_params,
1389
+ )
1390
+
1391
+
1392
def _create_vae_configuration(
    input_shape: Sequence[int],
    encoder_conv_strides: tuple[int, ...],
    decoder_conv_strides: tuple[int, ...],
    multiscale_count: int,
    z_dims: tuple[int, ...],
    output_channels: int,
    encoder_n_filters: int,
    decoder_n_filters: int,
    encoder_dropout: float,
    decoder_dropout: float,
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ],
    predict_logvar: Literal[None, "pixelwise"],
    analytical_kl: bool,
) -> LVAEConfig:
    """Create the LVAE architecture configuration of a VAE-based algorithm.

    Every argument is forwarded unchanged to `LVAEConfig`; the architecture field is
    set to the LVAE entry of `SupportedArchitecture`.

    Parameters
    ----------
    input_shape : Sequence[int]
        Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D.
    encoder_conv_strides : tuple[int, ...]
        Strides of the encoder convolutional layers, length also defines 2D or 3D.
    decoder_conv_strides : tuple[int, ...]
        Strides of the decoder convolutional layers, length also defines 2D or 3D.
    multiscale_count : int
        Number of lateral context layers, specific to MicroSplit.
    z_dims : tuple[int, ...]
        Latent dimensions, one entry per hierarchy level of the LVAE model.
    output_channels : int
        Number of output channels.
    encoder_n_filters : int
        Number of filters in the convolutional layers of the encoder.
    decoder_n_filters : int
        Number of filters in the convolutional layers of the decoder.
    encoder_dropout : float
        Dropout rate for the encoder.
    decoder_dropout : float
        Dropout rate for the decoder.
    nonlinearity : Literal
        Type of nonlinearity function to use.
    predict_logvar : Literal[None, "pixelwise"]
        Type of log variance prediction.  # TODO needs review
    analytical_kl : bool
        Whether to use analytical KL divergence.  # TODO needs clarification

    Returns
    -------
    LVAEConfig
        LVAE model configuration with the specified parameters.
    """
    # collect the keyword arguments first, then instantiate the config in one go
    lvae_kwargs = {
        "architecture": SupportedArchitecture.LVAE.value,
        "input_shape": input_shape,
        "encoder_conv_strides": encoder_conv_strides,
        "decoder_conv_strides": decoder_conv_strides,
        "multiscale_count": multiscale_count,
        "z_dims": z_dims,
        "output_channels": output_channels,
        "encoder_n_filters": encoder_n_filters,
        "decoder_n_filters": decoder_n_filters,
        "encoder_dropout": encoder_dropout,
        "decoder_dropout": decoder_dropout,
        "nonlinearity": nonlinearity,
        "predict_logvar": predict_logvar,
        "analytical_kl": analytical_kl,
    }
    return LVAEConfig(**lvae_kwargs)
+
1462
+
1463
def _create_vae_based_algorithm(
    algorithm: Literal["hdn", "microsplit"],
    loss: LVAELossConfig,
    input_shape: Sequence[int],
    encoder_conv_strides: tuple[int, ...],
    decoder_conv_strides: tuple[int, ...],
    multiscale_count: int,
    z_dims: tuple[int, ...],
    output_channels: int,
    encoder_n_filters: int,
    decoder_n_filters: int,
    encoder_dropout: float,
    decoder_dropout: float,
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ],
    predict_logvar: Literal[None, "pixelwise"],
    analytical_kl: bool,
    gaussian_likelihood: GaussianLikelihoodConfig | None = None,
    nm_likelihood: NMLikelihoodConfig | None = None,
) -> dict:
    """
    Create a dictionary with the parameters of the VAE-based algorithm model.

    The LVAE model configuration is built with `_create_vae_configuration` and
    returned together with the algorithm name, the loss configuration and the
    likelihood configurations.

    Parameters
    ----------
    algorithm : Literal["hdn", "microsplit"]
        The algorithm type.
    loss : LVAELossConfig
        The loss configuration.
    input_shape : Sequence[int]
        The shape of the input data.
    encoder_conv_strides : tuple[int, ...]
        The strides of the encoder convolutional layers.
    decoder_conv_strides : tuple[int, ...]
        The strides of the decoder convolutional layers.
    multiscale_count : int
        The number of multiscale layers.
    z_dims : tuple[int, ...]
        The dimensions of the latent space.
    output_channels : int
        The number of output channels.
    encoder_n_filters : int
        The number of filters in the encoder.
    decoder_n_filters : int
        The number of filters in the decoder.
    encoder_dropout : float
        The dropout rate for the encoder.
    decoder_dropout : float
        The dropout rate for the decoder.
    nonlinearity : Literal
        The nonlinearity function to use.
    predict_logvar : Literal[None, "pixelwise"]
        The type of log variance prediction.
    analytical_kl : bool
        Whether to use analytical KL divergence.
    gaussian_likelihood : GaussianLikelihoodConfig | None, optional
        The Gaussian likelihood model, by default None.
    nm_likelihood : NMLikelihoodConfig | None, optional
        The noise model likelihood model, by default None.

    Returns
    -------
    dict
        A dictionary with the keys "algorithm", "loss", "model",
        "gaussian_likelihood" and "noise_model_likelihood".

    Raises
    ------
    AssertionError
        If neither `gaussian_likelihood` nor `nm_likelihood` is specified.
    """
    network_model = _create_vae_configuration(
        input_shape=input_shape,
        encoder_conv_strides=encoder_conv_strides,
        decoder_conv_strides=decoder_conv_strides,
        multiscale_count=multiscale_count,
        z_dims=z_dims,
        output_channels=output_channels,
        encoder_n_filters=encoder_n_filters,
        decoder_n_filters=decoder_n_filters,
        encoder_dropout=encoder_dropout,
        decoder_dropout=decoder_dropout,
        nonlinearity=nonlinearity,
        predict_logvar=predict_logvar,
        analytical_kl=analytical_kl,
    )

    # Raised explicitly rather than through an `assert` statement, so that the check
    # still runs when Python is started with -O (which strips asserts). The exception
    # type and message are kept as they were for backward compatibility.
    if not (gaussian_likelihood or nm_likelihood):
        raise AssertionError("Likelihood model must be specified")

    return {
        "algorithm": algorithm,
        "loss": loss,
        "model": network_model,
        "gaussian_likelihood": gaussian_likelihood,
        "noise_model_likelihood": nm_likelihood,
    }
+
1553
+
1554
def get_likelihood_config(
    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
    # TODO remove different microsplit loss types, refac
    predict_logvar: Literal["pixelwise"] | None = None,
    logvar_lowerbound: float | None = -5.0,
    nm_paths: list[str] | None = None,
    data_stats: tuple[float, float] | None = None,
) -> tuple[
    GaussianLikelihoodConfig | None,
    MultiChannelNMConfig | None,
    NMLikelihoodConfig | None,
]:
    """Get the likelihood configuration for split models.

    Returns a tuple containing the following optional entries:
        - GaussianLikelihoodConfig: Gaussian likelihood configuration for musplit losses
        - MultiChannelNMConfig: Multi-channel noise model configuration for denoisplit
        losses
        - NMLikelihoodConfig: Noise model likelihood configuration for denoisplit losses

    Parameters
    ----------
    loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
        The type of loss function to use.
    predict_logvar : Literal["pixelwise"] | None, optional
        Type of log variance prediction, by default None.
        Only used when loss_type is "musplit" or "denoisplit_musplit".
    logvar_lowerbound : float | None, optional
        Lower bound for the log variance, by default -5.0.
        Only used when loss_type is "musplit" or "denoisplit_musplit".
    nm_paths : list[str] | None, optional
        Paths to the noise model files, by default None.
        Only used when loss_type is "denoisplit" or "denoisplit_musplit", in which
        case one Gaussian mixture noise model configuration is created per path. If
        None, the multi-channel configuration is built from an empty list.
    data_stats : tuple[float, float] | None, optional
        Data statistics (mean, std), by default None.
        Currently not used by this function.

    Returns
    -------
    gaussian_lik_config : GaussianLikelihoodConfig | None
        Gaussian likelihood configuration for musplit losses, or None.
    nm_config : MultiChannelNMConfig | None
        Multi-channel noise model configuration for denoisplit losses, or None.
    nm_lik_config : NMLikelihoodConfig | None
        Noise model likelihood configuration for denoisplit losses, or None.

    Notes
    -----
    This function does not check that `predict_logvar`, `nm_paths` or `data_stats`
    are provided for the loss types that use them; any validation is left to the
    configuration classes.
    """
    # TODO validators should be in pydantic models (no required-argument checks here)

    # gaussian likelihood
    gaussian_lik_config = None
    if loss_type in ("musplit", "denoisplit_musplit"):
        gaussian_lik_config = GaussianLikelihoodConfig(
            predict_logvar=predict_logvar,
            logvar_lowerbound=logvar_lowerbound,
        )

    # noise model likelihood
    noise_model_config = None
    nm_lik_config = None
    if loss_type in ("denoisplit", "denoisplit_musplit"):
        # one Gaussian mixture noise model per provided path, none if no paths given
        gmm_list = [
            GaussianMixtureNMConfig(
                model_type="GaussianMixtureNoiseModel",
                path=nm_path,
            )
            for nm_path in (nm_paths if nm_paths is not None else [])
        ]
        noise_model_config = MultiChannelNMConfig(noise_models=gmm_list)
        nm_lik_config = NMLikelihoodConfig()  # TODO this config isn't needed probably

    return gaussian_lik_config, noise_model_config, nm_lik_config
+
1641
+
1642
+ # TODO wrap parameters into model, loss etc
1643
+ # TODO refac likelihood configs to make it 1. Can it be done ?
1644
def create_hdn_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    encoder_conv_strides: tuple[int, ...] = (2, 2),
    decoder_conv_strides: tuple[int, ...] = (2, 2),
    multiscale_count: int = 1,
    z_dims: tuple[int, ...] = (128, 128),
    output_channels: int = 1,
    encoder_n_filters: int = 32,
    decoder_n_filters: int = 32,
    encoder_dropout: float = 0.0,
    decoder_dropout: float = 0.0,
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ] = "ReLU",
    analytical_kl: bool = False,
    predict_logvar: Literal["pixelwise"] | None = None,
    logvar_lowerbound: Union[float, None] = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    trainer_params: dict | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training HDN.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3, otherwise
    2.

    The model is an LVAE whose input shape is `patch_size`. The algorithm is configured
    with a Gaussian likelihood built from `predict_logvar` and `logvar_lowerbound`, and
    without a noise model likelihood.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. If provided, this will be added to
        trainer_params.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    encoder_conv_strides : tuple[int, ...], optional
        Strides for the encoder convolutional layers, by default (2, 2).
    decoder_conv_strides : tuple[int, ...], optional
        Strides for the decoder convolutional layers, by default (2, 2).
    multiscale_count : int, optional
        Number of scales in the multiscale architecture, by default 1.
    z_dims : tuple[int, ...], optional
        Dimensions of the latent space, by default (128, 128).
    output_channels : int, optional
        Number of output channels, by default 1.
    encoder_n_filters : int, optional
        Number of filters in the encoder, by default 32.
    decoder_n_filters : int, optional
        Number of filters in the decoder, by default 32.
    encoder_dropout : float, optional
        Dropout rate for the encoder, by default 0.0.
    decoder_dropout : float, optional
        Dropout rate for the decoder, by default 0.0.
    nonlinearity : Literal, optional
        Nonlinearity function to use, by default "ReLU".
    analytical_kl : bool, optional
        Whether to use analytical KL divergence, by default False.
    predict_logvar : Literal["pixelwise"] | None, optional
        Type of log variance prediction, by default None.
    logvar_lowerbound : Union[float, None], optional
        Lower bound for the log variance, by default None.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use for training, by default "none".
    trainer_params : dict, optional
        Parameters for the trainer class, see PyTorch Lightning documentation.
    augmentations : list[XYFlipConfig | XYRandomRotate90Config] | None, optional
        List of augmentations to apply, by default None.
    train_dataloader_params : dict[str, Any] | None, optional
        Parameters for the training dataloader, by default None.
    val_dataloader_params : dict[str, Any] | None, optional
        Parameters for the validation dataloader, by default None.

    Returns
    -------
    Configuration
        The configuration object for training HDN.

    Examples
    --------
    Minimum example:
    >>> config = create_hdn_configuration(
    ...     experiment_name="hdn_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    You can also limit the number of batches per epoch:
    >>> config = create_hdn_configuration(
    ...     experiment_name="hdn_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_steps=100  # limit to 100 batches per epoch
    ... )
    """
    # LVAE architecture settings, forwarded as-is to the algorithm factory
    lvae_settings: dict[str, Any] = {
        "encoder_conv_strides": encoder_conv_strides,
        "decoder_conv_strides": decoder_conv_strides,
        "multiscale_count": multiscale_count,
        "z_dims": z_dims,
        "output_channels": output_channels,
        "encoder_n_filters": encoder_n_filters,
        "decoder_n_filters": decoder_n_filters,
        "encoder_dropout": encoder_dropout,
        "decoder_dropout": decoder_dropout,
        "nonlinearity": nonlinearity,
        "predict_logvar": predict_logvar,
        "analytical_kl": analytical_kl,
    }

    # augmentations
    spatial_augmentations = _list_spatial_augmentations(augmentations)

    # loss, TODO what are the correct defaults for HDN?
    hdn_loss = LVAELossConfig(loss_type="hdn", denoisplit_weight=1, musplit_weight=0)

    # likelihood: HDN only uses the Gaussian one
    likelihood = GaussianLikelihoodConfig(
        predict_logvar=predict_logvar, logvar_lowerbound=logvar_lowerbound
    )

    # algorithm & model
    algorithm_config = _create_vae_based_algorithm(
        algorithm="hdn",
        loss=hdn_loss,
        input_shape=patch_size,
        gaussian_likelihood=likelihood,
        nm_likelihood=None,
        **lvae_settings,
    )

    # data
    data_config = _create_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        augmentations=spatial_augmentations,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # training: fold num_epochs / num_steps into the trainer parameters first
    training_config = _create_training_configuration(
        trainer_params=update_trainer_params(
            trainer_params=trainer_params,
            num_epochs=num_epochs,
            num_steps=num_steps,
        ),
        logger=logger,
    )

    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_config,
        data_config=data_config,
        training_config=training_config,
    )
+
1839
+
1840
def create_microsplit_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    num_epochs: int = 100,
    num_steps: int | None = None,
    encoder_conv_strides: tuple[int, ...] = (2, 2),
    decoder_conv_strides: tuple[int, ...] = (2, 2),
    multiscale_count: int = 3,
    grid_size: int = 32,  # TODO most likely can be derived from patch size
    z_dims: tuple[int, ...] = (128, 128),
    output_channels: int = 1,
    encoder_n_filters: int = 32,
    decoder_n_filters: int = 32,
    encoder_dropout: float = 0.0,
    decoder_dropout: float = 0.0,
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ] = "ReLU",  # TODO do we need all these?
    analytical_kl: bool = False,
    predict_logvar: Literal["pixelwise"] = "pixelwise",
    logvar_lowerbound: Union[float, None] = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    trainer_params: dict | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    nm_paths: list[str] | None = None,
    data_stats: tuple[float, float] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training MicroSplit.

    The algorithm uses the "denoisplit_musplit" loss, with its likelihood and noise
    model configurations obtained from `get_likelihood_config`, and an LVAE model
    whose input shape is `patch_size`.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    num_epochs : int, default=100
        Number of epochs to train for. If provided, this will be added to
        trainer_params.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    encoder_conv_strides : tuple[int, ...], optional
        Strides for the encoder convolutional layers, by default (2, 2).
    decoder_conv_strides : tuple[int, ...], optional
        Strides for the decoder convolutional layers, by default (2, 2).
    multiscale_count : int, optional
        Number of multiscale levels, by default 3.
    grid_size : int, optional
        Size of the grid for multiscale training, by default 32.
    z_dims : tuple[int, ...], optional
        List of latent dims for each hierarchy level in the LVAE, default (128, 128).
    output_channels : int, optional
        Number of output channels for the model, by default 1.
    encoder_n_filters : int, optional
        Number of filters in the encoder, by default 32.
    decoder_n_filters : int, optional
        Number of filters in the decoder, by default 32.
    encoder_dropout : float, optional
        Dropout rate for the encoder, by default 0.0.
    decoder_dropout : float, optional
        Dropout rate for the decoder, by default 0.0.
    nonlinearity : Literal, optional
        Nonlinearity to use in the model, by default "ReLU".
    analytical_kl : bool, optional
        Whether to use analytical KL divergence, by default False.
    predict_logvar : Literal["pixelwise"], optional
        Type of log-variance prediction, by default "pixelwise".
    logvar_lowerbound : Union[float, None], optional
        Lower bound for the log variance, by default None.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use for training, by default "none".
    trainer_params : dict, optional
        Parameters for the trainer class, see PyTorch Lightning documentation.
    augmentations : list[Union[XYFlipConfig, XYRandomRotate90Config]] | None, optional
        List of augmentations to apply, by default None.
    nm_paths : list[str] | None, optional
        Paths to the noise model files, by default None.
    data_stats : tuple[float, float] | None, optional
        Data statistics (mean, std), by default None.
    train_dataloader_params : dict[str, Any] | None, optional
        Parameters for the training dataloader, by default None.
    val_dataloader_params : dict[str, Any] | None, optional
        Parameters for the validation dataloader, by default None.

    Returns
    -------
    Configuration
        A configuration object for the microsplit algorithm.

    Examples
    --------
    Minimum example:
    # >>> config = create_microsplit_configuration(
    # ...     experiment_name="microsplit_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     num_epochs=100
    # ... )

    # You can also limit the number of batches per epoch:
    # >>> config = create_microsplit_configuration(
    # ...     experiment_name="microsplit_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     num_steps=100  # limit to 100 batches per epoch
    # ... )
    """
    # LVAE architecture settings, forwarded as-is to the model factory
    lvae_settings: dict[str, Any] = {
        "encoder_conv_strides": encoder_conv_strides,
        "decoder_conv_strides": decoder_conv_strides,
        "multiscale_count": multiscale_count,
        "z_dims": z_dims,
        "output_channels": output_channels,
        "encoder_n_filters": encoder_n_filters,
        "decoder_n_filters": decoder_n_filters,
        "encoder_dropout": encoder_dropout,
        "decoder_dropout": decoder_dropout,
        "nonlinearity": nonlinearity,
        "predict_logvar": predict_logvar,
        "analytical_kl": analytical_kl,
    }

    # augmentations
    spatial_augmentations = _list_spatial_augmentations(augmentations)

    # TODO losses need to be refactored! just for example. Add validator if sum to 1
    split_loss = LVAELossConfig(
        loss_type="denoisplit_musplit", denoisplit_weight=0.9, musplit_weight=0.1
    )

    # likelihoods and noise model
    likelihood_configs = get_likelihood_config(
        loss_type="denoisplit_musplit",
        predict_logvar=predict_logvar,
        logvar_lowerbound=logvar_lowerbound,
        nm_paths=nm_paths,
        data_stats=data_stats,
    )
    gaussian_lik, noise_model, noise_model_lik = likelihood_configs

    # LVAE model
    lvae_model = _create_vae_configuration(input_shape=patch_size, **lvae_settings)

    # MicroSplit algorithm, instantiated directly from its components
    algorithm_config = MicroSplitAlgorithm(
        algorithm="microsplit",
        loss=split_loss,
        model=lvae_model,
        gaussian_likelihood=gaussian_lik,
        noise_model=noise_model,
        noise_model_likelihood=noise_model_lik,
    )

    # data
    data_config = _create_microsplit_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        grid_size=grid_size,
        multiscale_count=multiscale_count,
        batch_size=batch_size,
        augmentations=spatial_augmentations,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # training: fold num_epochs / num_steps into the trainer parameters first
    training_config = _create_training_configuration(
        trainer_params=update_trainer_params(
            trainer_params=trainer_params,
            num_epochs=num_epochs,
            num_steps=num_steps,
        ),
        logger=logger,
    )

    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_config,
        data_config=data_config,
        training_config=training_config,
    )
+
2043
+
2044
def create_pn2v_configuration(
    experiment_name: str,
    data_type: Literal["array", "tiff", "czi", "custom"],
    axes: str,
    patch_size: Sequence[int],
    batch_size: int,
    nm_path: str,
    num_epochs: int = 100,
    num_steps: int | None = None,
    augmentations: list[Union[XYFlipConfig, XYRandomRotate90Config]] | None = None,
    independent_channels: bool = True,
    use_n2v2: bool = False,
    num_in_channels: int = 1,
    num_out_channels: int = 100,
    roi_size: int = 11,
    masked_pixel_percentage: float = 0.2,
    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
    struct_n2v_span: int = 5,
    trainer_params: dict | None = None,
    logger: Literal["wandb", "tensorboard", "none"] = "none",
    model_params: dict | None = None,
    optimizer: Literal["Adam", "Adamax", "SGD"] = "Adam",
    optimizer_params: dict[str, Any] | None = None,
    lr_scheduler: Literal["ReduceLROnPlateau", "StepLR"] = "ReduceLROnPlateau",
    lr_scheduler_params: dict[str, Any] | None = None,
    train_dataloader_params: dict[str, Any] | None = None,
    val_dataloader_params: dict[str, Any] | None = None,
    checkpoint_params: dict[str, Any] | None = None,
) -> Configuration:
    """
    Create a configuration for training Probabilistic Noise2Void (PN2V).

    PN2V extends N2V by incorporating a probabilistic noise model to estimate the
    posterior distribution of each pixel more precisely.

    If "Z" is present in `axes`, then `patch_size` must be a list of length 3,
    otherwise 2.

    If "C" is present in `axes`, then you need to set `num_in_channels` to the number of
    channels.

    By default, all channels are trained independently. To train all channels together,
    set `independent_channels` to False. When training independently, each input channel
    will have `num_out_channels` outputs (default 100). When training together, all
    input channels will share `num_out_channels` outputs.

    By default, the transformations applied are a random flip along X or Y, and a random
    90 degrees rotation in the XY plane. Normalization is always applied, as well as the
    N2V manipulation.

    By setting `augmentations` to `None`, the default transformations (flip in X and Y,
    rotations by 90 degrees in the XY plane) are applied. Rather than the default
    transforms, a list of transforms can be passed to the `augmentations` parameter. To
    disable the transforms, simply pass an empty list.

    The `roi_size` parameter specifies the size of the area around each pixel that will
    be manipulated by N2V. The `masked_pixel_percentage` parameter specifies how many
    pixels per patch will be manipulated.

    The parameters of the UNet can be specified in the `model_params` (passed as a
    parameter-value dictionary). Note that `use_n2v2`, `num_in_channels`, and
    `num_out_channels` override the corresponding parameters passed in `model_params`.

    If you pass "horizontal" or "vertical" to `struct_n2v_axis`, then structN2V mask
    will be applied to each manipulated pixel.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment.
    data_type : Literal["array", "tiff", "czi", "custom"]
        Type of the data.
    axes : str
        Axes of the data (e.g. SYX).
    patch_size : Sequence[int]
        Size of the patches along the spatial dimensions (e.g. [64, 64]).
    batch_size : int
        Batch size.
    nm_path : str
        Path to the noise model file.
    num_epochs : int, default=100
        Number of epochs to train for. If provided, this will be added to
        trainer_params.
    num_steps : int, optional
        Number of batches in 1 epoch. If provided, this will be added to trainer_params.
        Translates to `limit_train_batches` in PyTorch Lightning Trainer. See relevant
        documentation for more details.
    augmentations : list of transforms, default=None
        List of transforms to apply, either both or one of XYFlipConfig and
        XYRandomRotate90Config. By default, it applies both XYFlip (on X and Y)
        and XYRandomRotate90 (in XY) to the images.
    independent_channels : bool, optional
        Whether to train all channels independently, by default True. If True, each
        input channel will correspond to num_out_channels output channels (e.g., 3
        input channels with num_out_channels=100 results in 300 total output
        channels).
    use_n2v2 : bool, optional
        Whether to use N2V2, by default False. When True, the median pixel
        manipulation strategy is used instead of the uniform one.
    num_in_channels : int, default=1
        Number of input channels. Must be at least 1, and must be 1 when "C" is not
        present in `axes`.
    num_out_channels : int, default=100
        Number of output channels per input channel when independent_channels is True,
        or total number of output channels when independent_channels is False.
    roi_size : int, optional
        N2V pixel manipulation area, by default 11.
    masked_pixel_percentage : float, optional
        Percentage of pixels masked in each patch, by default 0.2.
    struct_n2v_axis : Literal["horizontal", "vertical", "none"], optional
        Axis along which to apply structN2V mask, by default "none".
    struct_n2v_span : int, optional
        Span of the structN2V mask, by default 5.
    trainer_params : dict, optional
        Parameters for the trainer, see the relevant documentation.
    logger : Literal["wandb", "tensorboard", "none"], optional
        Logger to use, by default "none".
    model_params : dict, default=None
        UNetModel parameters.
    optimizer : Literal["Adam", "Adamax", "SGD"], default="Adam"
        Optimizer to use.
    optimizer_params : dict, default=None
        Parameters for the optimizer, see PyTorch documentation for more details.
    lr_scheduler : Literal["ReduceLROnPlateau", "StepLR"], default="ReduceLROnPlateau"
        Learning rate scheduler to use.
    lr_scheduler_params : dict, default=None
        Parameters for the learning rate scheduler, see PyTorch documentation for more
        details.
    train_dataloader_params : dict, optional
        Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
        the `GeneralDataConfig`.
    val_dataloader_params : dict, optional
        Parameters for the validation dataloader, see the PyTorch docs for `DataLoader`.
        If left as `None`, the empty dict `{}` will be used, this is set in the
        `GeneralDataConfig`.
    checkpoint_params : dict, default=None
        Parameters for the checkpoint callback, see PyTorch Lightning documentation
        (`ModelCheckpoint`) for the list of available parameters.

    Returns
    -------
    Configuration
        Configuration for training PN2V.

    Raises
    ------
    ValueError
        If `num_in_channels` is smaller than 1.
    ValueError
        If "C" is not present in `axes` but `num_in_channels` is larger than 1.

    Examples
    --------
    Minimum example:
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100
    # ... )

    # You can also limit the number of batches per epoch:
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_steps=100  # limit to 100 batches per epoch
    # ... )

    # To disable transforms, simply set `augmentations` to an empty list:
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     augmentations=[]
    # ... )

    # A list of transforms can be passed to the `augmentations` parameter:
    # >>> from careamics.config.transformations import XYFlipConfig
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     augmentations=[
    # ...         # No rotation and only Y flipping
    # ...         XYFlipConfig(flip_x = False, flip_y = True)
    # ...     ]
    # ... )

    # To use N2V2, simply pass the `use_n2v2` parameter:
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v2_experiment",
    # ...     data_type="tiff",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     use_n2v2=True
    # ... )

    # For structN2V, there are two parameters to set, `struct_n2v_axis` and
    # `struct_n2v_span`:
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="structpn2v_experiment",
    # ...     data_type="tiff",
    # ...     axes="YX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     struct_n2v_axis="horizontal",
    # ...     struct_n2v_span=7
    # ... )

    # If you are training multiple channels they will be trained independently by
    # default, you simply need to specify the number of input channels. Each input
    # channel will correspond to num_out_channels outputs (300 total for 3
    # channels with default num_out_channels=100):
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YXC",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     num_in_channels=3
    # ... )

    # If instead you want to train multiple channels together, you need to turn
    # off the `independent_channels` parameter (resulting in 100 total output
    # channels regardless of the number of input channels):
    # >>> config = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="array",
    # ...     axes="YXC",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     independent_channels=False,
    # ...     num_in_channels=3
    # ... )

    # >>> config_2d = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="czi",
    # ...     axes="SCYX",
    # ...     patch_size=[64, 64],
    # ...     batch_size=32,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     num_in_channels=1,
    # ... )
    # >>> config_3d = create_pn2v_configuration(
    # ...     experiment_name="pn2v_experiment",
    # ...     data_type="czi",
    # ...     axes="SCZYX",
    # ...     patch_size=[16, 64, 64],
    # ...     batch_size=16,
    # ...     nm_path="path/to/noise_model.npz",
    # ...     num_epochs=100,
    # ...     num_in_channels=1,
    # ... )
    """
    # Validate channel configuration. A non-positive channel count is rejected
    # whether or not "C" is in `axes`; otherwise it would silently yield a
    # non-positive number of output channels below.
    if num_in_channels < 1:
        if "C" in axes:
            raise ValueError(
                "num_in_channels must be at least 1 when using channels."
            )
        raise ValueError(
            f"num_in_channels must be at least 1 (got {num_in_channels})."
        )
    if "C" not in axes and num_in_channels > 1:
        raise ValueError(
            "C is not present in the axes, but num_in_channels is specified "
            f"(got {num_in_channels} channels)."
        )

    # Calculate total output channels based on independent_channels setting:
    # independent training multiplies the per-channel outputs by the number of
    # input channels, joint training shares a single set of outputs.
    if independent_channels:
        total_out_channels = num_in_channels * num_out_channels
    else:
        total_out_channels = num_out_channels

    # augmentations
    spatial_transforms = _list_spatial_augmentations(augmentations)

    # create the N2VManipulate transform using the supplied parameters
    n2v_transform = N2VManipulateConfig(
        name=SupportedTransform.N2V_MANIPULATE.value,
        # N2V2 uses the median strategy, plain N2V the uniform one
        strategy=(
            SupportedPixelManipulation.MEDIAN.value
            if use_n2v2
            else SupportedPixelManipulation.UNIFORM.value
        ),
        roi_size=roi_size,
        masked_pixel_percentage=masked_pixel_percentage,
        struct_mask_axis=struct_n2v_axis,
        struct_mask_span=struct_n2v_span,
    )

    # Create noise model configuration
    noise_model_config = GaussianMixtureNMConfig(path=nm_path)

    # algorithm
    algorithm_params = _create_algorithm_configuration(
        axes=axes,
        algorithm="pn2v",
        loss="pn2v",
        independent_channels=independent_channels,
        n_channels_in=num_in_channels,
        n_channels_out=total_out_channels,
        use_n2v2=use_n2v2,
        model_params=model_params,
        optimizer=optimizer,
        optimizer_params=optimizer_params,
        lr_scheduler=lr_scheduler,
        lr_scheduler_params=lr_scheduler_params,
    )
    # PN2V-specific entries are added on top of the generic algorithm parameters
    algorithm_params["n2v_config"] = n2v_transform
    algorithm_params["noise_model"] = noise_model_config

    # Convert to PN2VAlgorithm instance
    algorithm_config = PN2VAlgorithm(**algorithm_params)

    # data
    data_params = _create_data_configuration(
        data_type=data_type,
        axes=axes,
        patch_size=patch_size,
        batch_size=batch_size,
        augmentations=spatial_transforms,
        train_dataloader_params=train_dataloader_params,
        val_dataloader_params=val_dataloader_params,
    )

    # training
    final_trainer_params = update_trainer_params(
        trainer_params=trainer_params,
        num_epochs=num_epochs,
        num_steps=num_steps,
    )
    training_params = _create_training_configuration(
        trainer_params=final_trainer_params,
        logger=logger,
        checkpoint_params=checkpoint_params,
    )

    return Configuration(
        experiment_name=experiment_name,
        algorithm_config=algorithm_config,
        data_config=data_params,
        training_config=training_params,
    )