careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279) hide show
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,149 @@
1
+ """Noise models config."""
2
+
3
+ from pathlib import Path
4
+ from typing import Annotated, Literal, Self, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from pydantic import (
9
+ BaseModel,
10
+ ConfigDict,
11
+ Field,
12
+ PlainSerializer,
13
+ PlainValidator,
14
+ model_validator,
15
+ )
16
+
17
+ from careamics.utils.serializers import _array_to_json, _to_numpy
18
+
19
# TODO: temporary workaround for (de)serializing array fields of pydantic models.
# It lets configs that hold arrays be written to JSON files during training and
# read back from those files during evaluation.
Array = Annotated[
    # accepted in-memory representations
    Union[np.ndarray, torch.Tensor],
    # serialization hook: array/tensor -> JSON string
    PlainSerializer(_array_to_json, return_type=str),
    # validation hook, also used when reading the JSON string back
    PlainValidator(_to_numpy),
]
"""Annotated array type: arrays or tensors are serialized to JSON strings and
deserialized back into arrays."""
29
+
30
+
31
+ # TODO: add histogram-based noise model
32
+
33
+
34
class GaussianMixtureNMConfig(BaseModel):
    """Gaussian mixture noise model (GMM) configuration."""

    model_config = ConfigDict(
        # pydantic protects the `model_` namespace by default; disabling it allows
        # the `model_type` field below without warnings
        protected_namespaces=(),
        validate_assignment=True,
        arbitrary_types_allowed=True,
        extra="allow",
    )

    # model type
    model_type: Literal["GaussianMixtureNoiseModel"] = "GaussianMixtureNoiseModel"

    path: Union[Path, str] | None = None
    """Path to a trained noise model file (*.npz).

    If provided, it must point to an existing file with a `.npz` suffix, otherwise
    validation fails (see `validate_path`)."""

    # TODO remove and use as parameters to the NM functions?
    signal: Union[str, Path, np.ndarray] | None = Field(default=None, exclude=True)
    """Path to the file containing signal or respective numpy array. Excluded from
    serialization."""

    # TODO remove and use as parameters to the NM functions?
    observation: Union[str, Path, np.ndarray] | None = Field(default=None, exclude=True)
    """Path to the file containing observation or respective numpy array. Excluded
    from serialization."""

    weight: Array | None = None
    """A [3*n_gaussian, n_coeff] sized array containing the values of the weights
    describing the GMM noise model.

    The rows form three consecutive blocks of n_gaussian rows, one block per
    gaussian parameter (mean, standard deviation and weight).
    NOTE(review): the order of the blocks was previously documented inconsistently
    (both as [mean, standard deviation, weight] and as [means, weights, standard
    deviations]); confirm it against the noise model implementation.

    If `weight=None`, the weight array is initialized using the `min_signal`
    and `max_signal` parameters."""

    n_gaussian: int = Field(default=1, ge=1)
    """Number of gaussians used for the GMM."""

    n_coeff: int = Field(default=2, ge=2)
    """Number of coefficients to describe the functional relationship between gaussian
    parameters and the signal. 2 implies a linear relationship, 3 implies a quadratic
    relationship and so on."""

    min_signal: float = Field(default=0.0, ge=0.0)
    """Minimum signal intensity expected in the image."""

    max_signal: float = Field(default=1.0, ge=0.0)
    """Maximum signal intensity expected in the image."""

    min_sigma: float = Field(default=200.0, ge=0.0)  # TODO took from nb in pn2v
    """Minimum value of `standard deviation` allowed in the GMM.
    All values of `standard deviation` below this are clamped to this value."""

    tol: float = Field(default=1e-10)
    """Tolerance used in the computation of the noise model likelihood."""

    @model_validator(mode="after")
    def validate_path(self: Self) -> Self:
        """Validate that the path points to a valid .npz file if provided.

        Runs after field validation and, because `validate_assignment=True`, also
        after every assignment to a field.

        Returns
        -------
        Self
            Returns itself.

        Raises
        ------
        ValueError
            If the path is provided but does not exist, does not have a `.npz`
            suffix, or is not a file (e.g. a directory named `*.npz`).
        """
        if self.path is not None:
            path = Path(self.path)
            if not path.exists():
                raise ValueError(f"Path {path} does not exist.")
            if path.suffix != ".npz":
                raise ValueError(f"Path {path} must point to a .npz file.")
            if not path.is_file():
                raise ValueError(f"Path {path} must point to a file.")
        return self

    # @model_validator(mode="after")
    # def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
    #     """Validate paths provided in the config.
    #
    #     Returns
    #     -------
    #     Self
    #         Returns itself.
    #     """
    #     if self.path and (self.signal is not None or self.observation is not None):
    #         raise ValueError(
    #             "Either only 'path' to pre-trained noise model should be"
    #             "provided or only signal and observation in form of paths"
    #             "or numpy arrays."
    #         )
    #     if not self.path and (self.signal is None or self.observation is None):
    #         raise ValueError(
    #             "Either only 'path' to pre-trained noise model should be"
    #             "provided or only signal and observation in form of paths"
    #             "or numpy arrays."
    #         )
    #     return self
    # TODO revisit validation
137
+
138
+
139
class MultiChannelNMConfig(BaseModel):
    """Noise model configuration aggregating one noise model per output channel.

    The overall noise model is a set of GMMs, one for each target channel; for
    example, two target channels use two noise models.
    """

    # TODO: check that this model config is OK
    model_config = ConfigDict(
        extra="allow",
        arbitrary_types_allowed=True,
        validate_assignment=True,
    )

    noise_models: list[GaussianMixtureNMConfig]
    """List of noise models, one for each target channel."""
@@ -0,0 +1,31 @@
1
+ """Supported configuration options.
2
+
3
+ Used throughout the code to ensure consistency. These should be kept in sync with the
4
+ corresponding configuration options in the Pydantic models.
5
+ """
6
+
7
# Public names re-exported by this subpackage; keep in sync with the imports below.
__all__ = [
    "SupportedActivation",
    "SupportedAlgorithm",
    "SupportedArchitecture",
    "SupportedData",
    "SupportedLogger",
    "SupportedLoss",
    "SupportedOptimizer",
    "SupportedPixelManipulation",
    "SupportedScheduler",
    "SupportedStructAxis",
    "SupportedTransform",
]


from .supported_activations import SupportedActivation
from .supported_algorithms import SupportedAlgorithm
from .supported_architectures import SupportedArchitecture
from .supported_data import SupportedData
from .supported_loggers import SupportedLogger
from .supported_losses import SupportedLoss
from .supported_optimizers import (
    SupportedOptimizer,
    SupportedScheduler,
)
from .supported_pixel_manipulations import SupportedPixelManipulation
from .supported_struct_axis import SupportedStructAxis
from .supported_transforms import SupportedTransform
@@ -0,0 +1,27 @@
1
+ """Activations supported by CAREamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
class SupportedActivation(str, BaseEnum):
    """Supported activation functions.

    - None, no activation will be used.
    - Sigmoid
    - Softmax
    - Tanh
    - ReLU
    - LeakyReLU
    - ELU

    All activations other than None are defined in PyTorch.

    See: https://pytorch.org/docs/stable/nn.html (sections "Non-linear Activations
    (weighted sum, nonlinearity)" and "Non-linear Activations (other)").
    """

    NONE = "None"
    SIGMOID = "Sigmoid"
    SOFTMAX = "Softmax"
    TANH = "Tanh"
    RELU = "ReLU"
    LEAKYRELU = "LeakyReLU"
    ELU = "ELU"
@@ -0,0 +1,40 @@
1
+ """Algorithms supported by CAREamics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from careamics.utils import BaseEnum
6
+
7
+
8
class SupportedAlgorithm(str, BaseEnum):
    """Algorithms available in CAREamics.

    These definitions are the same as the keyword `name` of the algorithm
    configurations.
    """

    N2V = "n2v"
    """Noise2Void algorithm, a self-supervised approach based on blind denoising."""

    CARE = "care"
    """Content-aware image restoration, a supervised algorithm used for a variety
    of tasks."""

    N2N = "n2n"
    """Noise2Noise algorithm, a self-supervised denoising scheme based on comparing
    noisy images of the same sample."""

    MUSPLIT = "musplit"  # TODO remove
    """An image splitting approach based on ladder VAE architectures."""

    MICROSPLIT = "microsplit"
    """A micro-level image splitting approach based on ladder VAE architectures."""

    DENOISPLIT = "denoisplit"
    """An image splitting and denoising approach based on ladder VAE architectures."""

    HDN = "hdn"
    """Hierarchical Denoising Network, an unsupervised denoising algorithm."""

    PN2V = "pn2v"
    """Probabilistic Noise2Void, an extension of Noise2Void that is not restricted to
    Gaussian noise models or Gaussian intensity predictions."""
@@ -0,0 +1,13 @@
1
+ """Architectures supported by CAREamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
class SupportedArchitecture(str, BaseEnum):
    """Model architectures available in CAREamics."""

    UNET = "UNet"
    """UNet, the architecture behind N2V, CARE and Noise2Noise."""

    LVAE = "LVAE"
    """Ladder Variational Autoencoder, the architecture behind muSplit and
    denoiSplit."""
@@ -0,0 +1,122 @@
1
+ """Data supported by CAREamics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Union
6
+
7
+ from careamics.utils import BaseEnum
8
+
9
+
10
class SupportedData(str, BaseEnum):
    """Supported data types.

    Attributes
    ----------
    ARRAY : str
        Array data.
    TIFF : str
        TIFF image data.
    CZI : str
        CZI image data.
    ZARR : str
        Zarr data.
    CUSTOM : str
        Custom data.
    """

    ARRAY = "array"
    TIFF = "tiff"
    CZI = "czi"
    CUSTOM = "custom"
    ZARR = "zarr"

    # TODO remove?
    @classmethod
    def _missing_(cls, value: object) -> SupportedData | None:
        """
        Override default behaviour for missing values.

        This method is called when `value` is not found in the enum values. It converts
        `value` to lowercase, removes "." if it is the first character and tries to
        match it with enum values.

        Parameters
        ----------
        value : object
            Value to be matched with enum values.

        Returns
        -------
        SupportedData or None
            Matched enum member, or `None` if nothing matches, in which case the
            enum lookup raises a `ValueError`.
        """
        if isinstance(value, str):
            # e.g. ".TIFF" -> "tiff"; only a single leading "." is stripped
            normalized = value.lower().removeprefix(".")

            # attempt to match normalized value with enum values
            for member in cls:
                if member.value == normalized:
                    return member

        # still missing
        return super()._missing_(value)

    @classmethod
    def _to_member(cls, data_type: Union[str, SupportedData]) -> SupportedData:
        """
        Resolve `data_type` to a member, reusing the `_missing_` normalization.

        Parameters
        ----------
        data_type : str or SupportedData
            Data type.

        Returns
        -------
        SupportedData
            Corresponding enum member.

        Raises
        ------
        ValueError
            If `data_type` does not correspond to any supported data type.
        """
        try:
            return cls(data_type)
        except ValueError:
            raise ValueError(f"Data type {data_type} is not supported.") from None

    @classmethod
    def get_extension_pattern(cls, data_type: Union[str, SupportedData]) -> str:
        """
        Get Path.rglob and fnmatch compatible extension.

        `data_type` is normalized in the same way as enum lookups (case-insensitive,
        optional leading "."), so e.g. "TIFF" and ".tiff" are accepted.

        Parameters
        ----------
        data_type : str or SupportedData
            Data type.

        Returns
        -------
        str
            Corresponding extension pattern.

        Raises
        ------
        NotImplementedError
            If `data_type` is array data, which is not loaded from a file.
        ValueError
            If `data_type` is not a supported data type.
        """
        member = cls._to_member(data_type)
        if member is cls.ARRAY:
            raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")

        patterns = {
            cls.TIFF: "*.tif*",  # matches both ".tif" and ".tiff"
            cls.ZARR: "*.zarr",
            cls.CZI: "*.czi",
            cls.CUSTOM: "*.*",
        }
        if member not in patterns:
            raise ValueError(f"Data type {data_type} is not supported.")
        return patterns[member]

    @classmethod
    def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
        """
        Get file extension of corresponding data type.

        `data_type` is normalized in the same way as enum lookups (case-insensitive,
        optional leading "."), so e.g. "TIFF" and ".tiff" are accepted.

        Parameters
        ----------
        data_type : str or SupportedData
            Data type.

        Returns
        -------
        str
            Corresponding extension.

        Raises
        ------
        NotImplementedError
            If `data_type` is array data (not loaded from a file) or custom data
            (extension must be provided elsewhere).
        ValueError
            If `data_type` is not a supported data type.
        """
        member = cls._to_member(data_type)
        if member is cls.ARRAY:
            raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
        if member is cls.CUSTOM:
            # TODO: improve this message
            raise NotImplementedError("Custom extensions have to be passed elsewhere.")

        extensions = {
            cls.TIFF: ".tiff",
            cls.CZI: ".czi",
            cls.ZARR: ".zarr",
        }
        if member not in extensions:
            raise ValueError(f"Data type {data_type} is not supported.")
        return extensions[member]
@@ -0,0 +1,17 @@
1
+ """Coordinate and patch filters supported by CAREamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
class SupportedPatchFilters(str, BaseEnum):
    """Names of the patch filters that can be configured in CAREamics."""

    # members; declaration order defines the iteration order of the enum
    MAX = "max"
    MEANSTD = "mean_std"
    SHANNON = "shannon"
12
+
13
+
14
class SupportedCoordinateFilters(str, BaseEnum):
    """Names of the coordinate filters that can be configured in CAREamics."""

    # currently a single coordinate filter is available
    MASK = "mask"
@@ -0,0 +1,10 @@
1
+ """Logger supported by CAREamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
class SupportedLogger(str, BaseEnum):
    """Experiment loggers available in CAREamics.

    Attributes
    ----------
    WANDB : str
        Weights & Biases logger.
    TENSORBOARD : str
        TensorBoard logger.
    """

    WANDB = "wandb"
    TENSORBOARD = "tensorboard"
@@ -0,0 +1,32 @@
1
+ """Losses supported by CAREamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
# TODO register loss with custom_loss decorator?
class SupportedLoss(str, BaseEnum):
    """Supported losses.

    Attributes
    ----------
    MSE : str
        Mean Squared Error loss.
    MAE : str
        Mean Absolute Error loss.
    N2V : str
        Noise2Void loss.
    PN2V : str
        Probabilistic Noise2Void loss.
    HDN : str
        Loss used by the HDN algorithm.
    MUSPLIT : str
        Loss used by the muSplit algorithm.
    MICROSPLIT : str
        Loss used by the microSplit algorithm.
    DENOISPLIT : str
        Loss used by the denoiSplit algorithm.
    DENOISPLIT_MUSPLIT : str
        denoiSplit/muSplit loss variant; the losses are slated for refactoring so
        that only the microSplit loss remains.
    """

    MSE = "mse"
    MAE = "mae"
    N2V = "n2v"
    PN2V = "pn2v"
    HDN = "hdn"
    MUSPLIT = "musplit"
    MICROSPLIT = "microsplit"
    DENOISPLIT = "denoisplit"
    DENOISPLIT_MUSPLIT = (
        "denoisplit_musplit"  # TODO refac losses, leave only microsplit
    )
    # CE = "ce"
    # DICE = "dice"
@@ -0,0 +1,57 @@
1
+ """Optimizers and schedulers supported by CAREamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
class SupportedOptimizer(str, BaseEnum):
    """Supported optimizers.

    Attributes
    ----------
    ADAM : str
        Adam optimizer.
    ADAMAX : str
        Adamax optimizer.
    SGD : str
        Stochastic Gradient Descent optimizer.
    """

    # ASGD = "ASGD"
    # Adadelta = "Adadelta"
    # Adagrad = "Adagrad"
    ADAM = "Adam"
    # AdamW = "AdamW"
    ADAMAX = "Adamax"
    # LBFGS = "LBFGS"
    # NAdam = "NAdam"
    # RAdam = "RAdam"
    # RMSprop = "RMSprop"
    # Rprop = "Rprop"
    SGD = "SGD"
    # SparseAdam = "SparseAdam"
30
+
31
+
32
class SupportedScheduler(str, BaseEnum):
    """Supported schedulers.

    Attributes
    ----------
    REDUCE_LR_ON_PLATEAU : str
        Reduce learning rate on plateau.
    STEP_LR : str
        Step learning rate.
    """

    # ChainedScheduler = "ChainedScheduler"
    # ConstantLR = "ConstantLR"
    # CosineAnnealingLR = "CosineAnnealingLR"
    # CosineAnnealingWarmRestarts = "CosineAnnealingWarmRestarts"
    # CyclicLR = "CyclicLR"
    # ExponentialLR = "ExponentialLR"
    # LambdaLR = "LambdaLR"
    # LinearLR = "LinearLR"
    # MultiStepLR = "MultiStepLR"
    # MultiplicativeLR = "MultiplicativeLR"
    # OneCycleLR = "OneCycleLR"
    # PolynomialLR = "PolynomialLR"
    REDUCE_LR_ON_PLATEAU = "ReduceLROnPlateau"
    # SequentialLR = "SequentialLR"
    STEP_LR = "StepLR"
@@ -0,0 +1,22 @@
1
+ """Patching strategies supported by Careamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
class SupportedPatchingStrategy(str, BaseEnum):
    """Patching strategies supported by CAREamics.

    Attributes
    ----------
    FIXED_RANDOM : str
        Fixed random patching strategy, used during training.
    RANDOM : str
        Random patching strategy, used during training.
    TILED : str
        Tiled patching strategy, used during prediction.
    WHOLE : str
        Whole image patching strategy, used during prediction.
    """

    # Training-time strategies.
    FIXED_RANDOM = "fixed_random"
    RANDOM = "random"
    # A sequential training strategy ("sequential") is not supported yet.

    # Prediction-time strategies.
    TILED = "tiled"
    WHOLE = "whole"
@@ -0,0 +1,15 @@
1
+ """Pixel manipulation methods supported by CAREamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
class SupportedPixelManipulation(str, BaseEnum):
    """Supported Noise2Void pixel manipulations.

    Attributes
    ----------
    UNIFORM : str
        Replace the masked pixel value by the value of a (uniformly) randomly
        selected neighbor pixel.
    MEDIAN : str
        Replace the masked pixel value by the median of its neighborhood.
    """

    UNIFORM = "uniform"
    MEDIAN = "median"
@@ -0,0 +1,21 @@
1
+ """StructN2V axes supported by CAREamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
class SupportedStructAxis(str, BaseEnum):
    """Axes along which the structN2V mask can be applied.

    Attributes
    ----------
    HORIZONTAL : str
        Apply the structN2V mask along the horizontal axis.
    VERTICAL : str
        Apply the structN2V mask along the vertical axis.
    NONE : str
        Do not apply any structN2V mask.
    """

    HORIZONTAL = "horizontal"
    VERTICAL = "vertical"
    NONE = "none"  # disables structN2V masking
@@ -0,0 +1,12 @@
1
+ """Transforms supported by CAREamics."""
2
+
3
+ from careamics.utils import BaseEnum
4
+
5
+
6
class SupportedTransform(str, BaseEnum):
    """Transforms officially supported by CAREamics.

    Attributes
    ----------
    XY_FLIP : str
        Flip transform in the XY plane.
    XY_RANDOM_ROTATE90 : str
        Random 90 degree rotation in the XY plane.
    NORMALIZE : str
        Normalization transform.
    N2V_MANIPULATE : str
        Noise2Void pixel manipulation transform.
    """

    # Spatial transforms.
    XY_FLIP = "XYFlip"
    XY_RANDOM_ROTATE90 = "XYRandomRotate90"

    # Intensity / masking transforms.
    NORMALIZE = "Normalize"
    N2V_MANIPULATE = "N2VManipulate"
@@ -0,0 +1,22 @@
1
+ """CAREamics transformation Pydantic models."""
2
+
3
+ __all__ = [
4
+ "NORM_AND_SPATIAL_UNION",
5
+ "SPATIAL_TRANSFORMS_UNION",
6
+ "N2VManipulateConfig",
7
+ "NormalizeConfig",
8
+ "TransformConfig",
9
+ "XYFlipConfig",
10
+ "XYRandomRotate90Config",
11
+ ]
12
+
13
+
14
+ from .n2v_manipulate_config import N2VManipulateConfig
15
+ from .normalize_config import NormalizeConfig
16
+ from .transform_config import TransformConfig
17
+ from .transform_unions import (
18
+ NORM_AND_SPATIAL_UNION,
19
+ SPATIAL_TRANSFORMS_UNION,
20
+ )
21
+ from .xy_flip_config import XYFlipConfig
22
+ from .xy_random_rotate90_config import XYRandomRotate90Config
@@ -0,0 +1,79 @@
1
+ """Pydantic model for the N2VManipulate transform."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import ConfigDict, Field, field_validator
6
+
7
+ from .transform_config import TransformConfig
8
+
9
+
10
# TODO should probably not be a TransformConfig anymore, no reason for it
# `name` is used as a discriminator field in the transforms
class N2VManipulateConfig(TransformConfig):
    """
    Pydantic model used to represent N2V manipulation.

    Attributes
    ----------
    name : Literal["N2VManipulate"]
        Name of the transformation.
    roi_size : int
        Size of the masking region, odd and between 3 and 21, by default 11.
    masked_pixel_percentage : float
        Percentage of masked pixels, between 0.05 and 10.0, by default 0.2.
    remove_center : bool
        Whether to exclude the center pixel when computing the replacement
        value, by default True.
    strategy : Literal["uniform", "median"]
        Strategy for pixel value replacement, by default "uniform".
    struct_mask_axis : Literal["horizontal", "vertical", "none"]
        Axis of the structN2V mask, by default "none" (structN2V not applied).
    struct_mask_span : int
        Span of the structN2V mask, odd and between 3 and 15, by default 5.
    """

    # Re-run field validation (including `odd_value`) on attribute assignment.
    model_config = ConfigDict(
        validate_assignment=True,
    )

    name: Literal["N2VManipulate"] = "N2VManipulate"

    roi_size: int = Field(default=11, ge=3, le=21)
    """Size of the region where the pixel manipulation is applied (must be odd)."""

    masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=10.0)
    """Percentage of masked pixels per image."""

    remove_center: bool = Field(default=True)  # TODO remove it
    """Exclude the center pixel when computing the replacement value."""

    strategy: Literal["uniform", "median"] = Field(default="uniform")
    """Strategy for pixel value replacement."""

    struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field(default="none")
    """Orientation of the structN2V mask. Set to `"none"` to not apply StructN2V."""

    struct_mask_span: int = Field(default=5, ge=3, le=15)
    """Size of the structN2V mask (must be odd)."""

    @field_validator("roi_size", "struct_mask_span")
    @classmethod
    def odd_value(cls, v: int) -> int:
        """
        Validate that the value is odd.

        Parameters
        ----------
        v : int
            Value to validate.

        Returns
        -------
        int
            The validated value.

        Raises
        ------
        ValueError
            If the value is even.
        """
        if v % 2 == 0:
            raise ValueError("Size must be an odd number.")
        return v
+ return v