careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
careamics/models/lvae/noise_models.py
@@ -0,0 +1,738 @@
+ from __future__ import annotations
+
+ import os
+ from typing import TYPE_CHECKING, Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from numpy.typing import NDArray
+
+ if TYPE_CHECKING:
+     from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
+
+ # TODO this module shouldn't be in lvae folder
+
+
+ def create_histogram(
+     bins: int, min_val: float, max_val: float, observation: NDArray, signal: NDArray
+ ) -> NDArray:
+     """
+     Creates a 2D histogram from 'observation' and 'signal'.
+
+     Parameters
+     ----------
+     bins : int
+         Number of bins in x and y.
+     min_val : float
+         Lower bound of the lowest bin in x and y.
+     max_val : float
+         Upper bound of the highest bin in x and y.
+     observation : np.ndarray
+         3D numpy array (stack of 2D images).
+         Observation.shape[0] must be divisible by signal.shape[0].
+         Assumes that n subsequent images in observation belong to one image in 'signal'.
+     signal : np.ndarray
+         3D numpy array (stack of 2D images).
+
+     Returns
+     -------
+     histogram : np.ndarray
+         A 3D array:
+         - histogram[0]: Normalized 2D counts.
+         - histogram[1]: Lower boundaries of bins along y.
+         - histogram[2]: Upper boundaries of bins along y.
+         The values for x can be obtained by transposing 'histogram[1]' and 'histogram[2]'.
+     """
+     histogram = np.zeros((3, bins, bins))
+
+     value_range = [min_val, max_val]
+
+     # Compute mapping factor between observation and signal samples
+     obs_to_signal_shape_factor = int(observation.shape[0] / signal.shape[0])
+
+     # Flatten arrays and align signal values
+     signal_indices = np.arange(observation.shape[0]) // obs_to_signal_shape_factor
+     signal_values = signal[signal_indices].ravel()
+     observation_values = observation.ravel()
+
+     count_histogram, signal_edges, _ = np.histogram2d(
+         signal_values, observation_values, bins=bins, range=[value_range, value_range]
+     )
+
+     # Normalize rows to obtain probabilities
+     row_sums = count_histogram.sum(axis=1, keepdims=True)
+     count_histogram /= np.clip(row_sums, a_min=1e-20, a_max=None)
+
+     histogram[0] = count_histogram
+     histogram[1] = signal_edges[:-1][..., np.newaxis]
+     histogram[2] = signal_edges[1:][..., np.newaxis]
+
+     return histogram
+
+
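Illustrative usage sketch (not part of the package diff): building a histogram from four noisy acquisitions of two clean frames. The array shapes and noise level are assumptions chosen to satisfy the divisibility requirement above.

    import numpy as np

    rng = np.random.default_rng(0)
    signal = rng.uniform(0.0, 1.0, size=(2, 64, 64))         # 2 clean frames
    observation = np.repeat(signal, 2, axis=0)               # 4 noisy acquisitions, 2 per frame
    observation = observation + rng.normal(0.0, 0.1, observation.shape)

    hist = create_histogram(
        bins=128, min_val=-0.5, max_val=1.5,
        observation=observation, signal=signal,
    )
    print(hist.shape)  # (3, 128, 128): normalized counts, lower bin edges, upper bin edges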
+ def noise_model_factory(
+     model_config: Optional[GaussianMixtureNMConfig],
+ ) -> Optional[GaussianMixtureNoiseModel]:
+     """Noise model factory for single-channel noise models.
+
+     Parameters
+     ----------
+     model_config : Optional[GaussianMixtureNMConfig]
+         Noise model configuration for a single Gaussian mixture noise model.
+
+     Returns
+     -------
+     Optional[GaussianMixtureNoiseModel]
+         A single noise model instance, or None if no config is provided.
+
+     Raises
+     ------
+     NotImplementedError
+         If the chosen noise model `model_type` is not implemented.
+         Currently only `GaussianMixtureNoiseModel` is implemented.
+     """
+     if model_config:
+         if model_config.path:
+             if model_config.model_type == "GaussianMixtureNoiseModel":
+                 return GaussianMixtureNoiseModel(model_config)
+             else:
+                 raise NotImplementedError(
+                     f"Model {model_config.model_type} is not implemented"
+                 )
+
+         # TODO this is outdated and likely should be removed !!
+         else:  # TODO this means signal/obs are provided. Controlled in pydantic model
+             # TODO train a new model. Config should always be provided?
+             if model_config.model_type == "GaussianMixtureNoiseModel":
+                 # TODO one model for each channel, or make this choice inside the model?
+                 # trained_nm = train_gm_noise_model(model_config)
+                 # return trained_nm
+                 raise NotImplementedError(
+                     "GaussianMixtureNoiseModel model training is not implemented."
+                 )
+             else:
+                 raise NotImplementedError(
+                     f"Model {model_config.model_type} is not implemented"
+                 )
+     return None
+
+
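Illustrative usage sketch (not part of the package diff): loading a pretrained noise model through the factory. The `GaussianMixtureNMConfig` fields are assumed from how the factory accesses them above (`path`, `model_type`), not from the config's documentation, and the file path is hypothetical.

    from careamics.config import GaussianMixtureNMConfig

    nm_config = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        path="noise_models/gmm_channel0.npz",  # hypothetical path to a trained *.npz
    )
    noise_model = noise_model_factory(nm_config)  # GaussianMixtureNoiseModel instance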
+ def multichannel_noise_model_factory(
+     model_config: Optional[MultiChannelNMConfig],
+ ) -> Optional[MultiChannelNoiseModel]:
+     """Multi-channel noise model factory.
+
+     Parameters
+     ----------
+     model_config : Optional[MultiChannelNMConfig]
+         Noise model configuration, a `MultiChannelNMConfig` config that defines
+         noise models for the different output channels.
+
+     Returns
+     -------
+     Optional[MultiChannelNoiseModel]
+         A noise model instance.
+
+     Raises
+     ------
+     NotImplementedError
+         If the chosen noise model `model_type` is not implemented.
+         Currently only `GaussianMixtureNoiseModel` is implemented.
+     """
+     if model_config:
+         noise_models = []
+         for nm in model_config.noise_models:
+             if nm.path:
+                 if nm.model_type == "GaussianMixtureNoiseModel":
+                     noise_models.append(GaussianMixtureNoiseModel(nm))
+                 else:
+                     raise NotImplementedError(
+                         f"Model {nm.model_type} is not implemented"
+                     )
+
+             # TODO this is outdated and likely should be removed !!
+             else:  # TODO this means signal/obs are provided. Controlled in pydantic model
+                 # TODO train a new model. Config should always be provided?
+                 if nm.model_type == "GaussianMixtureNoiseModel":
+                     # TODO one model for each channel, or make this choice inside the model?
+                     # trained_nm = train_gm_noise_model(nm)
+                     # noise_models.append(trained_nm)
+                     raise NotImplementedError(
+                         "GaussianMixtureNoiseModel model training is not implemented."
+                     )
+                 else:
+                     raise NotImplementedError(
+                         f"Model {nm.model_type} is not implemented"
+                     )
+         return MultiChannelNoiseModel(noise_models)
+     return None
+
+
+ def train_gm_noise_model(
+     model_config: GaussianMixtureNMConfig,
+     signal: np.ndarray,
+     observation: np.ndarray,
+ ) -> GaussianMixtureNoiseModel:
+     """Train a Gaussian mixture noise model.
+
+     Parameters
+     ----------
+     model_config : GaussianMixtureNMConfig
+         Configuration of the Gaussian mixture noise model.
+     signal : np.ndarray
+         Clean signal data.
+     observation : np.ndarray
+         Noisy observation data.
+
+     Returns
+     -------
+     GaussianMixtureNoiseModel
+         The trained noise model.
+     """
+     # TODO where to put train params?
+     # TODO any training params ? Different channels ?
+     noise_model = GaussianMixtureNoiseModel(model_config)
+     # TODO revisit config unpacking
+     noise_model.fit(signal, observation)
+     return noise_model
+
+
+ class MultiChannelNoiseModel(nn.Module):
+     def __init__(self, nmodels: list[GaussianMixtureNoiseModel]):
+         """Constructor.
+
+         Handles noise models and the corresponding likelihood computation for
+         multiple output channels (e.g., muSplit, denoiseSplit).
+
+         This class:
+         - receives as input a variable number of noise models, one for each channel.
+         - computes the likelihood of observations given signals for each channel.
+         - returns the concatenation of these likelihoods.
+
+         Parameters
+         ----------
+         nmodels : list[GaussianMixtureNoiseModel]
+             List of noise models, one for each output channel.
+         """
+         super().__init__()
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         for i, nmodel in enumerate(nmodels):  # TODO refactor this !!!
+             if nmodel is not None:
+                 self.add_module(
+                     f"nmodel_{i}", nmodel
+                 )  # TODO: wouldn't it be easier to use a list?
+
+         self._nm_cnt = 0
+         for nmodel in nmodels:
+             if nmodel is not None:
+                 self._nm_cnt += 1
+
+         print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}")
+
+     def to_device(self, device: torch.device):
+         self.device = device
+         self.to(device)
+         for ch_idx in range(self._nm_cnt):
+             nmodel = getattr(self, f"nmodel_{ch_idx}")
+             nmodel.to_device(device)
+
+     def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
+         """Compute the likelihood of observations given signals for each channel.
+
+         Parameters
+         ----------
+         obs : torch.Tensor
+             Noisy observations, i.e., the target(s). Specifically, the input noisy
+             image for HDN, or the noisy unmixed images used for supervision
+             for denoiSplit. Shape: (B, C, [Z], Y, X), where C is the number of
+             unmixed channels.
+         signal : torch.Tensor
+             Underlying signals, i.e., the (clean) output of the model. Specifically, the
+             denoised image for HDN, or the unmixed images for denoiSplit.
+             Shape: (B, C, [Z], Y, X), where C is the number of unmixed channels.
+         """
+         # Case 1: obs and signal have a single channel (e.g., denoising)
+         if obs.shape[1] == 1:
+             assert signal.shape[1] == 1
+             return self.nmodel_0.likelihood(obs, signal)
+
+         # Case 2: obs and signal have multiple channels (e.g., denoiSplit)
+         assert obs.shape[1] == self._nm_cnt, (
+             "The number of channels in `obs` must match the number of noise models."
+             f" Got instead: obs={obs.shape[1]}, nm={self._nm_cnt}"
+         )
+         ll_list = []
+         for ch_idx in range(obs.shape[1]):
+             nmodel = getattr(self, f"nmodel_{ch_idx}")
+             ll_list.append(
+                 nmodel.likelihood(
+                     obs[:, ch_idx : ch_idx + 1], signal[:, ch_idx : ch_idx + 1]
+                 )  # slicing to keep the channel dimension
+             )
+         return torch.cat(ll_list, dim=1)
+
+
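Illustrative usage sketch (not part of the package diff): wrapping two pretrained per-channel models and evaluating the likelihood of a two-channel batch. `config_ch0` and `config_ch1` are hypothetical per-channel configs pointing at existing trained *.npz files.

    nm0 = GaussianMixtureNoiseModel(config_ch0)  # hypothetical configs
    nm1 = GaussianMixtureNoiseModel(config_ch1)
    multi_nm = MultiChannelNoiseModel([nm0, nm1])

    obs = torch.rand(8, 2, 64, 64)        # (B, C, Y, X) noisy targets
    signal = torch.rand(8, 2, 64, 64)     # (B, C, Y, X) predicted clean signals
    p = multi_nm.likelihood(obs, signal)  # per-pixel likelihood, shape (8, 2, 64, 64)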
+ class GaussianMixtureNoiseModel(nn.Module):
+     """Define a noise model parameterized as a mixture of gaussians.
+
+     If `config.path` is not provided a new object is initialized from scratch.
+     Otherwise, a model is loaded from `config.path`.
+
+     Parameters
+     ----------
+     config : GaussianMixtureNMConfig
+         A `pydantic` model that defines the configuration of the GMM noise model.
+
+     Attributes
+     ----------
+     min_signal : float
+         Minimum signal intensity expected in the image.
+     max_signal : float
+         Maximum signal intensity expected in the image.
+     path : Union[str, Path]
+         Path to the directory where the trained noise model (*.npz) is saved by the `save` method.
+     weight : torch.nn.Parameter
+         A [3*n_gaussian, n_coeff] sized array containing the values of the weights
+         describing the GMM noise model, with each row corresponding to one
+         parameter of one gaussian, namely mean, standard deviation or weight.
+         Specifically, rows are organized as follows:
+         - first n_gaussian rows correspond to the means
+         - next n_gaussian rows correspond to the standard deviations
+         - last n_gaussian rows correspond to the weights
+         If `weight=None`, the weight array is initialized using the `min_signal`
+         and `max_signal` parameters.
+     n_gaussian : int
+         Number of gaussians in the mixture.
+     n_coeff : int
+         Number of coefficients to describe the functional relationship between gaussian
+         parameters and the signal. 2 implies a linear relationship, 3 implies a quadratic
+         relationship, and so on.
+     device : device
+         GPU device.
+     min_sigma : float
+         All values of `standard deviation` below this are clamped to this value.
+     """
+
+     # TODO training a NM relies on getting clean data (e.g., from N2V)
+     def __init__(self, config: GaussianMixtureNMConfig) -> None:
+         super().__init__()
+         self.device = torch.device("cpu")
+
+         if config.path is not None:
+             params = np.load(config.path)
+         else:
+             params = config.model_dump(exclude_none=True)
+
+         min_sigma = torch.tensor(params["min_sigma"])
+         min_signal = torch.tensor(params["min_signal"])
+         max_signal = torch.tensor(params["max_signal"])
+         self.register_buffer("min_signal", min_signal)
+         self.register_buffer("max_signal", max_signal)
+         self.register_buffer("min_sigma", min_sigma)
+         self.register_buffer("tolerance", torch.tensor([1e-10]))
+
+         if "trained_weight" in params:
+             weight = torch.tensor(params["trained_weight"])
+         elif "weight" in params and params["weight"] is not None:
+             weight = torch.tensor(params["weight"])
+         else:
+             weight = self._initialize_weights(
+                 params["n_gaussian"], params["n_coeff"], max_signal, min_signal
+             )
+
+         self.n_gaussian = weight.shape[0] // 3
+         self.n_coeff = weight.shape[1]
+
+         self.register_parameter("weight", nn.Parameter(weight))
+         self._set_model_mode(mode="prediction")
+
+         print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}")
+
+     def _initialize_weights(
+         self,
+         n_gaussian: int,
+         n_coeff: int,
+         max_signal: torch.Tensor,
+         min_signal: torch.Tensor,
+     ) -> torch.Tensor:
+         """Create random weight initialization."""
+         weight = torch.randn(n_gaussian * 3, n_coeff)
+         weight[n_gaussian : 2 * n_gaussian, 1] = torch.log(
+             max_signal - min_signal
+         ).float()
+         return weight
+
+     def to_device(self, device: torch.device):
+         self.device = device
+         self.to(device)
+
+     def _set_model_mode(self, mode: str) -> None:
+         """Set the weights' `requires_grad` flag depending on the mode."""
+         if mode == "train":
+             self.weight.requires_grad = True
+         else:
+             self.weight.requires_grad = False
+
+     def polynomial_regressor(
+         self, weight_params: torch.Tensor, signals: torch.Tensor
+     ) -> torch.Tensor:
+         """Evaluate the polynomial with coefficients `weight_params` at the normalized `signals`.
+
+         Parameters
+         ----------
+         weight_params : Tensor
+             Corresponds to specific rows of `self.weight`.
+         signals : Tensor
+             Signals.
+
+         Returns
+         -------
+         value : Tensor
+             Corresponds to one of mean, standard deviation or weight, evaluated at `signals`.
+         """
+         value = torch.zeros_like(signals)
+         device = (
+             value.device
+         )  # TODO the whole device handling in this class needs to be refactored
+         weight_params = weight_params.to(device)
+         self.min_signal = self.min_signal.to(device)
+         self.max_signal = self.max_signal.to(device)
+         for i in range(weight_params.shape[0]):
+             value += weight_params[i] * (
+                 ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i
+             )
+         return value
+
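In equation form (a note, not part of the package diff): with the normalized signal s_norm = (signals - min_signal) / (max_signal - min_signal), the regressor evaluates value = sum over i of weight_params[i] * s_norm**i, for i = 0 .. n_coeff - 1. So n_coeff = 2 gives a linear dependence of each gaussian parameter on the signal and n_coeff = 3 a quadratic one, consistent with the class docstring.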
+     def normal_density(
+         self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Evaluates the normal probability density at `x` given the mean `mean` and standard deviation `std`.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             The ground-truth tensor. Shape is (batch, 1, dim1, dim2).
+         mean : torch.Tensor
+             The inferred mean of the distribution. Shape is (batch, 1, dim1, dim2).
+         std : torch.Tensor
+             The inferred standard deviation of the distribution. Shape is (batch, 1, dim1, dim2).
+
+         Returns
+         -------
+         tmp : torch.Tensor
+             Normal probability density of `x` given `mean` and `std`.
+         """
+         tmp = -((x - mean) ** 2)
+         tmp = tmp / (2.0 * std * std)
+         tmp = torch.exp(tmp)
+         tmp = tmp / torch.sqrt((2.0 * np.pi) * std * std)
+         return tmp
+
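Illustrative sanity check (not part of the package diff): the hand-rolled density above should agree with `torch.distributions.Normal` up to floating-point error. Here `nm` is assumed to be a loaded GaussianMixtureNoiseModel.

    import torch
    from torch.distributions import Normal

    x = torch.randn(4, 1, 8, 8)
    mean = torch.zeros_like(x)
    std = torch.ones_like(x)
    # exp(log_prob) is the reference density
    reference = Normal(mean, std).log_prob(x).exp()
    assert torch.allclose(nm.normal_density(x, mean, std), reference, atol=1e-6)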
+     def likelihood(
+         self, observations: torch.Tensor, signals: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters.
+
+         Parameters
+         ----------
+         observations : Tensor
+             Noisy observations. Shape is (batch, 1, dim1, dim2).
+         signals : Tensor
+             Underlying signals. Shape is (batch, 1, dim1, dim2).
+
+         Returns
+         -------
+         value : torch.Tensor
+             Likelihood of observations given the signals and the GMM noise model.
+         """
+         observations = observations.float()
+         signals = signals.float()
+         gaussian_parameters: list[torch.Tensor] = self.get_gaussian_parameters(signals)
+         p = 0  # torch.zeros_like(observations)
+         for gaussian in range(self.n_gaussian):
+             # Ensure all tensors have compatible shapes
+             mean = gaussian_parameters[gaussian]
+             std = gaussian_parameters[self.n_gaussian + gaussian]
+             weight = gaussian_parameters[2 * self.n_gaussian + gaussian]
+
+             # Compute normal density
+             p += (
+                 self.normal_density(
+                     observations,
+                     mean,
+                     std,
+                 )
+                 * weight
+             )
+         return p + self.tolerance
+
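In equation form (a note, not part of the package diff): the returned tensor is p(obs | signal) = sum over k of alpha_k(signal) * N(obs; mu_k(signal), sigma_k(signal)) + tolerance, for k = 1 .. n_gaussian. The small additive tolerance keeps the -log(p) training loss in `fit` finite when the density underflows.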
+     def get_gaussian_parameters(self, signals: torch.Tensor) -> list[torch.Tensor]:
+         """
+         Returns the GMM parameters (`mu`, `sigma`, `alpha`) evaluated at the given signals.
+
+         Parameters
+         ----------
+         signals : Tensor
+             Underlying signals.
+
+         Returns
+         -------
+         noise_model : list of Tensor
+             Contains a list of `mu`, `sigma` and `alpha` for the `signals`.
+         """
+         noise_model = []
+         mu = []
+         sigma = []
+         alpha = []
+         kernels = self.weight.shape[0] // 3
+         device = signals.device
+         self.min_signal = self.min_signal.to(device)
+         self.max_signal = self.max_signal.to(device)
+         self.min_sigma = self.min_sigma.to(device)
+         self.tolerance = self.tolerance.to(device)
+         for num in range(kernels):
+             mu.append(self.polynomial_regressor(self.weight[num, :], signals))
+             expval = torch.exp(self.weight[kernels + num, :])
+             sigma_temp = self.polynomial_regressor(expval, signals)
+             sigma_temp = torch.clamp(sigma_temp, min=self.min_sigma)
+             sigma.append(torch.sqrt(sigma_temp))
+
+             expval = torch.exp(
+                 self.polynomial_regressor(self.weight[2 * kernels + num, :], signals)
+                 + self.tolerance
+             )
+             alpha.append(expval)
+
+         sum_alpha = 0
+         for al in range(kernels):
+             sum_alpha = alpha[al] + sum_alpha
+
+         # sum of alpha is forced to be 1.
+         for ker in range(kernels):
+             alpha[ker] = alpha[ker] / sum_alpha
+
+         sum_means = 0
+         # sum_means is the alpha-weighted average of the means
+         for ker in range(kernels):
+             sum_means = alpha[ker] * mu[ker] + sum_means
+
+         # subtracting the alpha-weighted average of the means from the means
+         # ensures that the GMM has the inclination to have mean=signals;
+         # it's like a residual connection. Why do we need to learn the mean at all?
+         for ker in range(kernels):
+             mu[ker] = mu[ker] - sum_means + signals
+
+         for i in range(kernels):
+             noise_model.append(mu[i])
+         for j in range(kernels):
+             noise_model.append(sigma[j])
+         for k in range(kernels):
+             noise_model.append(alpha[k])
+
+         return noise_model
+
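Illustrative unpacking sketch (not part of the package diff): the flat list returned above groups parameters by component, and the mixture weights are normalized per pixel. `nm` is assumed to be a loaded GaussianMixtureNoiseModel.

    signals = torch.rand(4, 1, 32, 32)
    params = nm.get_gaussian_parameters(signals)
    k = nm.n_gaussian
    mus, sigmas, alphas = params[:k], params[k : 2 * k], params[2 * k :]
    # mixture weights sum to 1 at every pixel
    assert torch.allclose(sum(alphas), torch.ones_like(signals))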
+     @staticmethod
+     def _fast_shuffle(series: torch.Tensor, num: int) -> torch.Tensor:
+         """Shuffle the input rows randomly `num` times."""
+         length = series.shape[0]
+         for _ in range(num):
+             idx = torch.randperm(length)
+             series = series[idx, :]
+         return series
+
+     def get_signal_observation_pairs(
+         self,
+         signal: NDArray,
+         observation: NDArray,
+         lower_clip: float,
+         upper_clip: float,
+     ) -> torch.Tensor:
+         """Returns the signal-observation pixel intensities as a two-column array.
+
+         Parameters
+         ----------
+         signal : numpy array
+             Clean signal data.
+         observation : numpy array
+             Noisy observation data.
+         lower_clip : float
+             Lower percentile bound for clipping.
+         upper_clip : float
+             Upper percentile bound for clipping.
+
+         Returns
+         -------
+         sig_obs_pairs : torch.Tensor
+             Shuffled two-column tensor of (signal, observation) pixel intensity pairs.
+         """
+         lb = np.percentile(signal, lower_clip)
+         ub = np.percentile(signal, upper_clip)
+         stepsize = observation[0].size
+         n_observations = observation.shape[0]
+         n_signals = signal.shape[0]
+         sig_obs_pairs = np.zeros((n_observations * stepsize, 2))
+
+         for i in range(n_observations):
+             j = i // (n_observations // n_signals)
+             sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel()
+             sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel()
+         sig_obs_pairs = sig_obs_pairs[
+             (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub)
+         ]
+         sig_obs_pairs = sig_obs_pairs.astype(np.float32)
+         sig_obs_pairs = torch.from_numpy(sig_obs_pairs)
+         return self._fast_shuffle(sig_obs_pairs, 2)
+
+     def fit(
+         self,
+         signal: NDArray,
+         observation: NDArray,
+         learning_rate: float = 1e-1,
+         batch_size: int = 250000,
+         n_epochs: int = 2000,
+         lower_clip: float = 0.0,
+         upper_clip: float = 100.0,
+     ) -> list[float]:
+         """Training to learn the noise model from signal-observation pairs.
+
+         Parameters
+         ----------
+         signal : numpy array
+             Clean signal data.
+         observation : numpy array
+             Noisy observation data.
+         learning_rate : float
+             Learning rate. Default = 1e-1.
+         batch_size : int
+             Mini-batch size. Default = 250000.
+         n_epochs : int
+             Number of epochs. Default = 2000.
+         lower_clip : float
+             Lower percentile for clipping. Default is 0.
+         upper_clip : float
+             Upper percentile for clipping. Default is 100.
+
+         Returns
+         -------
+         list[float]
+             Training loss for each epoch.
+         """
+         self._set_model_mode(mode="train")
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.to_device(device)
+         optimizer = torch.optim.Adam([self.weight], lr=learning_rate)
+
+         sig_obs_pairs = self.get_signal_observation_pairs(
+             signal, observation, lower_clip, upper_clip
+         )
+
+         train_losses = []
+         counter = 0
+         for t in range(n_epochs):
+             if (counter + 1) * batch_size >= sig_obs_pairs.shape[0]:
+                 counter = 0
+                 sig_obs_pairs = self._fast_shuffle(sig_obs_pairs, 1)
+
+             batch_vectors = sig_obs_pairs[
+                 counter * batch_size : (counter + 1) * batch_size, :
+             ]
+             observations = batch_vectors[:, 1].to(self.device)
+             signals = batch_vectors[:, 0].to(self.device)
+
+             p = self.likelihood(observations, signals)
+
+             joint_loss = torch.mean(-torch.log(p))
+             train_losses.append(joint_loss.item())
+
+             if self.weight.isnan().any() or self.weight.isinf().any():
+                 print(
+                     "NaN or Inf detected in the weights. Aborting training at epoch: ",
+                     t,
+                 )
+                 break
+
+             if t % 100 == 0:
+                 last_losses = train_losses[-100:]
+                 print(t, np.mean(last_losses))
+
+             optimizer.zero_grad()
+             joint_loss.backward()
+             optimizer.step()
+             counter += 1
+
+         self._set_model_mode(mode="prediction")
+         self.to_device(torch.device("cpu"))
+         print("===================\n")
+         return train_losses
+
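Illustrative training sketch (not part of the package diff): fitting a two-component GMM noise model on synthetic signal-dependent noise. The `GaussianMixtureNMConfig` field names are assumed from how `__init__` reads them above; the data and hyperparameters are arbitrary.

    import numpy as np
    from careamics.config import GaussianMixtureNMConfig

    rng = np.random.default_rng(0)
    signal = rng.uniform(10.0, 100.0, size=(4, 128, 128))
    observation = signal + rng.normal(0.0, np.sqrt(signal))  # noise grows with signal

    config = GaussianMixtureNMConfig(
        model_type="GaussianMixtureNoiseModel",
        min_signal=float(signal.min()),
        max_signal=float(signal.max()),
        n_gaussian=2,
        n_coeff=2,
        min_sigma=0.125,
    )
    nm = GaussianMixtureNoiseModel(config)
    losses = nm.fit(signal, observation, n_epochs=200)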
+     def sample_observation_from_signal(self, signal: NDArray) -> NDArray:
+         """
+         Sample an instance of observation based on an input signal using a
+         learned Gaussian Mixture Model. For each pixel in the input signal,
+         samples a corresponding noisy pixel.
+
+         Parameters
+         ----------
+         signal : numpy array
+             Clean 2D signal data.
+
+         Returns
+         -------
+         observation : numpy array
+             An instance of noisy observation data based on the input signal.
+         """
+         assert len(signal.shape) == 2, "Only 2D inputs are supported."
+
+         signal_tensor = torch.from_numpy(signal).to(torch.float32)
+         height, width = signal_tensor.shape
+
+         with torch.no_grad():
+             # Get gaussian parameters for each pixel
+             gaussian_params = self.get_gaussian_parameters(signal_tensor)
+             means = np.array(gaussian_params[: self.n_gaussian])
+             stds = np.array(gaussian_params[self.n_gaussian : self.n_gaussian * 2])
+             alphas = np.array(gaussian_params[self.n_gaussian * 2 :])
+
+             if self.n_gaussian == 1:
+                 # Single gaussian case
+                 observation = np.random.normal(
+                     loc=means[0], scale=stds[0], size=(height, width)
+                 )
+             else:
+                 # Multiple gaussians: sample a component for each pixel
+                 uniform = np.random.rand(1, height, width)
+                 # Compute cumulative probabilities for component selection
+                 cumulative_alphas = np.cumsum(
+                     alphas, axis=0
+                 )  # Shape: (n_gaussian, height, width)
+                 selected_component = np.argmax(
+                     uniform < cumulative_alphas, axis=0, keepdims=True
+                 )
+
+                 # For every pixel, choose the corresponding gaussian
+                 # and get the learned mu and sigma
+                 selected_mus = np.take_along_axis(means, selected_component, axis=0)
+                 selected_stds = np.take_along_axis(stds, selected_component, axis=0)
+                 selected_mus = selected_mus.squeeze(0)
+                 selected_stds = selected_stds.squeeze(0)
+
+                 # Sample from the normal distribution with learned mu and sigma
+                 observation = np.random.normal(
+                     selected_mus, selected_stds, size=(height, width)
+                 )
+         return observation
+
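Illustrative sampling sketch (not part of the package diff): generating a synthetic noisy frame from a clean one, assuming `nm` is a trained GaussianMixtureNoiseModel.

    clean = np.linspace(10.0, 100.0, 64 * 64, dtype=np.float32).reshape(64, 64)
    noisy = nm.sample_observation_from_signal(clean)  # same (64, 64) shape, noise added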
+     def save(self, path: str, name: str) -> None:
+         """Save the trained parameters of the noise model.
+
+         Parameters
+         ----------
+         path : str
+             Path to save the trained parameters.
+         name : str
+             File name to save the trained parameters.
+         """
+         os.makedirs(path, exist_ok=True)
+         np.savez(
+             os.path.join(path, name),
+             trained_weight=self.weight.numpy(),
+             min_signal=self.min_signal.numpy(),
+             max_signal=self.max_signal.numpy(),
+             min_sigma=self.min_sigma,
+         )
+         print("The trained parameters (" + name + ") are saved at location: " + path)