careamics 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (279)
  1. careamics/__init__.py +24 -0
  2. careamics/careamist.py +961 -0
  3. careamics/cli/__init__.py +5 -0
  4. careamics/cli/conf.py +394 -0
  5. careamics/cli/main.py +234 -0
  6. careamics/cli/utils.py +27 -0
  7. careamics/config/__init__.py +66 -0
  8. careamics/config/algorithms/__init__.py +21 -0
  9. careamics/config/algorithms/care_algorithm_config.py +122 -0
  10. careamics/config/algorithms/hdn_algorithm_config.py +103 -0
  11. careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
  12. careamics/config/algorithms/n2n_algorithm_config.py +115 -0
  13. careamics/config/algorithms/n2v_algorithm_config.py +296 -0
  14. careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
  15. careamics/config/algorithms/unet_algorithm_config.py +91 -0
  16. careamics/config/algorithms/vae_algorithm_config.py +178 -0
  17. careamics/config/architectures/__init__.py +7 -0
  18. careamics/config/architectures/architecture_config.py +37 -0
  19. careamics/config/architectures/lvae_config.py +262 -0
  20. careamics/config/architectures/unet_config.py +125 -0
  21. careamics/config/configuration.py +367 -0
  22. careamics/config/configuration_factories.py +2400 -0
  23. careamics/config/data/__init__.py +27 -0
  24. careamics/config/data/data_config.py +472 -0
  25. careamics/config/data/inference_config.py +237 -0
  26. careamics/config/data/ng_data_config.py +1038 -0
  27. careamics/config/data/patch_filter/__init__.py +15 -0
  28. careamics/config/data/patch_filter/filter_config.py +16 -0
  29. careamics/config/data/patch_filter/mask_filter_config.py +17 -0
  30. careamics/config/data/patch_filter/max_filter_config.py +15 -0
  31. careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
  32. careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
  33. careamics/config/data/patching_strategies/__init__.py +15 -0
  34. careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
  35. careamics/config/data/patching_strategies/_patched_config.py +56 -0
  36. careamics/config/data/patching_strategies/random_patching_config.py +45 -0
  37. careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
  38. careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
  39. careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
  40. careamics/config/data/tile_information.py +65 -0
  41. careamics/config/lightning/__init__.py +15 -0
  42. careamics/config/lightning/callbacks/__init__.py +8 -0
  43. careamics/config/lightning/callbacks/callback_config.py +116 -0
  44. careamics/config/lightning/optimizer_configs.py +186 -0
  45. careamics/config/lightning/training_config.py +70 -0
  46. careamics/config/losses/__init__.py +8 -0
  47. careamics/config/losses/loss_config.py +60 -0
  48. careamics/config/ng_configs/__init__.py +5 -0
  49. careamics/config/ng_configs/n2v_configuration.py +64 -0
  50. careamics/config/ng_configs/ng_configuration.py +256 -0
  51. careamics/config/ng_factories/__init__.py +9 -0
  52. careamics/config/ng_factories/algorithm_factory.py +120 -0
  53. careamics/config/ng_factories/data_factory.py +154 -0
  54. careamics/config/ng_factories/n2v_factory.py +256 -0
  55. careamics/config/ng_factories/training_factory.py +69 -0
  56. careamics/config/noise_model/__init__.py +12 -0
  57. careamics/config/noise_model/likelihood_config.py +60 -0
  58. careamics/config/noise_model/noise_model_config.py +149 -0
  59. careamics/config/support/__init__.py +31 -0
  60. careamics/config/support/supported_activations.py +27 -0
  61. careamics/config/support/supported_algorithms.py +40 -0
  62. careamics/config/support/supported_architectures.py +13 -0
  63. careamics/config/support/supported_data.py +122 -0
  64. careamics/config/support/supported_filters.py +17 -0
  65. careamics/config/support/supported_loggers.py +10 -0
  66. careamics/config/support/supported_losses.py +32 -0
  67. careamics/config/support/supported_optimizers.py +57 -0
  68. careamics/config/support/supported_patching_strategies.py +22 -0
  69. careamics/config/support/supported_pixel_manipulations.py +15 -0
  70. careamics/config/support/supported_struct_axis.py +21 -0
  71. careamics/config/support/supported_transforms.py +12 -0
  72. careamics/config/transformations/__init__.py +22 -0
  73. careamics/config/transformations/n2v_manipulate_config.py +79 -0
  74. careamics/config/transformations/normalize_config.py +59 -0
  75. careamics/config/transformations/transform_config.py +45 -0
  76. careamics/config/transformations/transform_unions.py +29 -0
  77. careamics/config/transformations/xy_flip_config.py +43 -0
  78. careamics/config/transformations/xy_random_rotate90_config.py +35 -0
  79. careamics/config/utils/__init__.py +8 -0
  80. careamics/config/utils/configuration_io.py +85 -0
  81. careamics/config/validators/__init__.py +18 -0
  82. careamics/config/validators/axes_validators.py +90 -0
  83. careamics/config/validators/model_validators.py +84 -0
  84. careamics/config/validators/patch_validators.py +55 -0
  85. careamics/conftest.py +39 -0
  86. careamics/dataset/__init__.py +17 -0
  87. careamics/dataset/dataset_utils/__init__.py +19 -0
  88. careamics/dataset/dataset_utils/dataset_utils.py +118 -0
  89. careamics/dataset/dataset_utils/file_utils.py +141 -0
  90. careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
  91. careamics/dataset/dataset_utils/running_stats.py +189 -0
  92. careamics/dataset/in_memory_dataset.py +303 -0
  93. careamics/dataset/in_memory_pred_dataset.py +88 -0
  94. careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
  95. careamics/dataset/iterable_dataset.py +294 -0
  96. careamics/dataset/iterable_pred_dataset.py +121 -0
  97. careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
  98. careamics/dataset/patching/__init__.py +1 -0
  99. careamics/dataset/patching/patching.py +300 -0
  100. careamics/dataset/patching/random_patching.py +110 -0
  101. careamics/dataset/patching/sequential_patching.py +212 -0
  102. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  103. careamics/dataset/tiling/__init__.py +10 -0
  104. careamics/dataset/tiling/collate_tiles.py +33 -0
  105. careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
  106. careamics/dataset/tiling/tiled_patching.py +166 -0
  107. careamics/dataset_ng/README.md +212 -0
  108. careamics/dataset_ng/__init__.py +0 -0
  109. careamics/dataset_ng/dataset.py +365 -0
  110. careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
  111. careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
  112. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  113. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
  114. careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
  115. careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
  116. careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
  117. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
  118. careamics/dataset_ng/factory.py +180 -0
  119. careamics/dataset_ng/grouped_index_sampler.py +73 -0
  120. careamics/dataset_ng/image_stack/__init__.py +14 -0
  121. careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
  122. careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
  123. careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
  124. careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
  125. careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
  126. careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
  127. careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
  128. careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
  129. careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
  130. careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
  131. careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
  132. careamics/dataset_ng/legacy_interoperability.py +175 -0
  133. careamics/dataset_ng/microsplit_input_synth.py +377 -0
  134. careamics/dataset_ng/patch_extractor/__init__.py +7 -0
  135. careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
  136. careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
  137. careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
  138. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  139. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  140. careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
  141. careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
  142. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  143. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  144. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  145. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  146. careamics/dataset_ng/patching_strategies/__init__.py +26 -0
  147. careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
  148. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
  149. careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
  150. careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
  151. careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
  152. careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
  153. careamics/file_io/__init__.py +15 -0
  154. careamics/file_io/read/__init__.py +11 -0
  155. careamics/file_io/read/get_func.py +57 -0
  156. careamics/file_io/read/tiff.py +58 -0
  157. careamics/file_io/write/__init__.py +15 -0
  158. careamics/file_io/write/get_func.py +63 -0
  159. careamics/file_io/write/tiff.py +40 -0
  160. careamics/lightning/__init__.py +32 -0
  161. careamics/lightning/callbacks/__init__.py +13 -0
  162. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  163. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  164. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  165. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  166. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
  167. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
  168. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  169. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  170. careamics/lightning/dataset_ng/__init__.py +1 -0
  171. careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
  172. careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
  173. careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
  174. careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
  175. careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
  176. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
  177. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
  178. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
  179. careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
  180. careamics/lightning/dataset_ng/data_module.py +529 -0
  181. careamics/lightning/dataset_ng/data_module_utils.py +395 -0
  182. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  183. careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
  184. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
  185. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
  186. careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
  187. careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
  188. careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
  189. careamics/lightning/lightning_module.py +914 -0
  190. careamics/lightning/microsplit_data_module.py +632 -0
  191. careamics/lightning/predict_data_module.py +341 -0
  192. careamics/lightning/train_data_module.py +666 -0
  193. careamics/losses/__init__.py +21 -0
  194. careamics/losses/fcn/__init__.py +1 -0
  195. careamics/losses/fcn/losses.py +125 -0
  196. careamics/losses/loss_factory.py +80 -0
  197. careamics/losses/lvae/__init__.py +1 -0
  198. careamics/losses/lvae/loss_utils.py +83 -0
  199. careamics/losses/lvae/losses.py +589 -0
  200. careamics/lvae_training/__init__.py +0 -0
  201. careamics/lvae_training/calibration.py +191 -0
  202. careamics/lvae_training/dataset/__init__.py +20 -0
  203. careamics/lvae_training/dataset/config.py +135 -0
  204. careamics/lvae_training/dataset/lc_dataset.py +274 -0
  205. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  206. careamics/lvae_training/dataset/multich_dataset.py +1121 -0
  207. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  208. careamics/lvae_training/dataset/multifile_dataset.py +335 -0
  209. careamics/lvae_training/dataset/types.py +32 -0
  210. careamics/lvae_training/dataset/utils/__init__.py +0 -0
  211. careamics/lvae_training/dataset/utils/data_utils.py +114 -0
  212. careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
  213. careamics/lvae_training/dataset/utils/index_manager.py +491 -0
  214. careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
  215. careamics/lvae_training/eval_utils.py +987 -0
  216. careamics/lvae_training/get_config.py +84 -0
  217. careamics/lvae_training/lightning_module.py +701 -0
  218. careamics/lvae_training/metrics.py +214 -0
  219. careamics/lvae_training/train_lvae.py +342 -0
  220. careamics/lvae_training/train_utils.py +121 -0
  221. careamics/model_io/__init__.py +7 -0
  222. careamics/model_io/bioimage/__init__.py +11 -0
  223. careamics/model_io/bioimage/_readme_factory.py +113 -0
  224. careamics/model_io/bioimage/bioimage_utils.py +56 -0
  225. careamics/model_io/bioimage/cover_factory.py +171 -0
  226. careamics/model_io/bioimage/model_description.py +341 -0
  227. careamics/model_io/bmz_io.py +251 -0
  228. careamics/model_io/model_io_utils.py +95 -0
  229. careamics/models/__init__.py +5 -0
  230. careamics/models/activation.py +40 -0
  231. careamics/models/layers.py +495 -0
  232. careamics/models/lvae/__init__.py +3 -0
  233. careamics/models/lvae/layers.py +1371 -0
  234. careamics/models/lvae/likelihoods.py +394 -0
  235. careamics/models/lvae/lvae.py +848 -0
  236. careamics/models/lvae/noise_models.py +738 -0
  237. careamics/models/lvae/stochastic.py +394 -0
  238. careamics/models/lvae/utils.py +404 -0
  239. careamics/models/model_factory.py +54 -0
  240. careamics/models/unet.py +449 -0
  241. careamics/nm_training_placeholder.py +203 -0
  242. careamics/prediction_utils/__init__.py +21 -0
  243. careamics/prediction_utils/lvae_prediction.py +158 -0
  244. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  245. careamics/prediction_utils/prediction_outputs.py +238 -0
  246. careamics/prediction_utils/stitch_prediction.py +193 -0
  247. careamics/py.typed +5 -0
  248. careamics/transforms/__init__.py +22 -0
  249. careamics/transforms/compose.py +173 -0
  250. careamics/transforms/n2v_manipulate.py +150 -0
  251. careamics/transforms/n2v_manipulate_torch.py +149 -0
  252. careamics/transforms/normalize.py +374 -0
  253. careamics/transforms/pixel_manipulation.py +406 -0
  254. careamics/transforms/pixel_manipulation_torch.py +388 -0
  255. careamics/transforms/struct_mask_parameters.py +20 -0
  256. careamics/transforms/transform.py +24 -0
  257. careamics/transforms/tta.py +88 -0
  258. careamics/transforms/xy_flip.py +131 -0
  259. careamics/transforms/xy_random_rotate90.py +108 -0
  260. careamics/utils/__init__.py +19 -0
  261. careamics/utils/autocorrelation.py +40 -0
  262. careamics/utils/base_enum.py +60 -0
  263. careamics/utils/context.py +67 -0
  264. careamics/utils/deprecation.py +63 -0
  265. careamics/utils/lightning_utils.py +71 -0
  266. careamics/utils/logging.py +323 -0
  267. careamics/utils/metrics.py +394 -0
  268. careamics/utils/path_utils.py +26 -0
  269. careamics/utils/plotting.py +76 -0
  270. careamics/utils/ram.py +15 -0
  271. careamics/utils/receptive_field.py +108 -0
  272. careamics/utils/serializers.py +62 -0
  273. careamics/utils/torch_utils.py +150 -0
  274. careamics/utils/version.py +38 -0
  275. careamics-0.0.19.dist-info/METADATA +80 -0
  276. careamics-0.0.19.dist-info/RECORD +279 -0
  277. careamics-0.0.19.dist-info/WHEEL +4 -0
  278. careamics-0.0.19.dist-info/entry_points.txt +2 -0
  279. careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,323 @@
1
+ """
2
+ Logging submodule.
3
+
4
+ The methods are responsible for the in-console logger.
5
+ """
6
+
7
+ import logging
8
+ import sys
9
+ import time
10
+ from collections.abc import Generator
11
+ from pathlib import Path
12
+ from typing import Any, Union
13
+
14
+ LOGGERS: dict = {}
15
+
16
+
17
+ def get_logger(
18
+ name: str,
19
+ log_level: int = logging.INFO,
20
+ log_path: Union[str, Path] | None = None,
21
+ ) -> logging.Logger:
22
+ """
23
+ Create a python logger instance with configured handlers.
24
+
25
+ Parameters
26
+ ----------
27
+ name : str
28
+ Name of the logger.
29
+ log_level : int, optional
30
+ Log level (info, error etc.), by default logging.INFO.
31
+ log_path : Optional[Union[str, Path]], optional
32
+ Path in which to save the log, by default None.
33
+
34
+ Returns
35
+ -------
36
+ logging.Logger
37
+ Logger.
38
+ """
39
+ logger = logging.getLogger(name)
40
+ logger.propagate = False
41
+
42
+ if name in LOGGERS:
43
+ return logger
44
+
45
+ for logger_name in LOGGERS:
46
+ if name.startswith(logger_name):
47
+ return logger
48
+
49
+ logger.propagate = False
50
+
51
+ if log_path:
52
+ handlers = [
53
+ logging.StreamHandler(),
54
+ logging.FileHandler(log_path),
55
+ ]
56
+ else:
57
+ handlers = [logging.StreamHandler()]
58
+
59
+ formatter = logging.Formatter("%(message)s")
60
+
61
+ for handler in handlers:
62
+ handler.setFormatter(formatter) # type: ignore
63
+ handler.setLevel(log_level) # type: ignore
64
+ logger.addHandler(handler) # type: ignore
65
+
66
+ logger.setLevel(log_level)
67
+ LOGGERS[name] = True
68
+
69
+ logger.propagate = False
70
+
71
+ return logger
72
+
73
+
74
+ class ProgressBar:
75
+ """
76
+ Keras style progress bar.
77
+
78
+ Adapted from https://github.com/yueyericardo/pkbar.
79
+
80
+ Parameters
81
+ ----------
82
+ max_value : Optional[int], optional
83
+ Maximum progress bar value, by default None.
84
+ epoch : Optional[int], optional
85
+ Zero-indexed current epoch, by default None.
86
+ num_epochs : Optional[int], optional
87
+ Total number of epochs, by default None.
88
+ stateful_metrics : Optional[list], optional
89
+ Iterable of string names of metrics that should *not* be averaged over time.
90
+ Metrics in this list will be displayed as-is. All others will be averaged by
91
+ the progress bar before display, by default None.
92
+ always_stateful : bool, optional
93
+ Whether to set all metrics to be stateful, by default False.
94
+ mode : str, optional
95
+ Mode, one of "train", "val", or "predict", by default "train".
96
+ """
97
+
98
+ def __init__(
99
+ self,
100
+ max_value: int | None = None,
101
+ epoch: int | None = None,
102
+ num_epochs: int | None = None,
103
+ stateful_metrics: list | None = None,
104
+ always_stateful: bool = False,
105
+ mode: str = "train",
106
+ ) -> None:
107
+ """
108
+ Constructor.
109
+
110
+ Parameters
111
+ ----------
112
+ max_value : Optional[int], optional
113
+ Maximum progress bar value, by default None.
114
+ epoch : Optional[int], optional
115
+ Zero-indexed current epoch, by default None.
116
+ num_epochs : Optional[int], optional
117
+ Total number of epochs, by default None.
118
+ stateful_metrics : Optional[list], optional
119
+ Iterable of string names of metrics that should *not* be averaged over time.
120
+ Metrics in this list will be displayed as-is. All others will be averaged by
121
+ the progress bar before display, by default None.
122
+ always_stateful : bool, optional
123
+ Whether to set all metrics to be stateful, by default False.
124
+ mode : str, optional
125
+ Mode, one of "train", "val", or "predict", by default "train".
126
+ """
127
+ self.max_value = max_value
128
+ # Width of the progress bar
129
+ self.width = 30
130
+ self.always_stateful = always_stateful
131
+
132
+ if (epoch is not None) and (num_epochs is not None):
133
+ print(f"Epoch: {epoch + 1}/{num_epochs}")
134
+
135
+ if stateful_metrics:
136
+ self.stateful_metrics = set(stateful_metrics)
137
+ else:
138
+ self.stateful_metrics = set()
139
+
140
+ self._dynamic_display = (
141
+ (hasattr(sys.stdout, "isatty") and sys.stdout.isatty())
142
+ or "ipykernel" in sys.modules
143
+ or "posix" in sys.modules
144
+ )
145
+ self._total_width = 0
146
+ self._seen_so_far = 0
147
+ # We use a dict + list to avoid garbage collection
148
+ # issues found in OrderedDict
149
+ self._values: dict[Any, Any] = {}
150
+ self._values_order: list[Any] = []
151
+ self._start = time.time()
152
+ self._last_update = 0.0
153
+ self.spin = self.spinning_cursor() if self.max_value is None else None
154
+ if mode == "train" and self.max_value is None:
155
+ self.message = "Estimating dataset size"
156
+ elif mode == "val":
157
+ self.message = "Validating"
158
+ elif mode == "predict":
159
+ self.message = "Denoising"
160
+
161
+ def update(
162
+ self, current_step: int, batch_size: int = 1, values: list | None = None
163
+ ) -> None:
164
+ """
165
+ Update the progress bar.
166
+
167
+ Parameters
168
+ ----------
169
+ current_step : int
170
+ Index of the current step.
171
+ batch_size : int, optional
172
+ Batch size, by default 1.
173
+ values : Optional[list], optional
174
+ Updated metrics values, by default None.
175
+ """
176
+ values = values or []
177
+ for k, v in values:
178
+ # if torch tensor, convert it to numpy
179
+ if str(type(v)) == "<class 'torch.Tensor'>":
180
+ v = v.detach().cpu().numpy()
181
+
182
+ if k not in self._values_order:
183
+ self._values_order.append(k)
184
+ if k not in self.stateful_metrics and not self.always_stateful:
185
+ if k not in self._values:
186
+ self._values[k] = [
187
+ v * (current_step - self._seen_so_far),
188
+ current_step - self._seen_so_far,
189
+ ]
190
+ else:
191
+ self._values[k][0] += v * (current_step - self._seen_so_far)
192
+ self._values[k][1] += current_step - self._seen_so_far
193
+ else:
194
+ # Stateful metrics output a numeric value. This representation
195
+ # means "take an average from a single value" but keeps the
196
+ # numeric formatting.
197
+ self._values[k] = [v, 1]
198
+
199
+ self._seen_so_far = current_step
200
+
201
+ now = time.time()
202
+ info = f" - {(now - self._start):.0f}s"
203
+
204
+ prev_total_width = self._total_width
205
+ if self._dynamic_display:
206
+ sys.stdout.write("\b" * prev_total_width)
207
+ sys.stdout.write("\r")
208
+ else:
209
+ sys.stdout.write("\n")
210
+
211
+ if self.max_value is not None:
212
+ bar = f"{current_step}/{self.max_value} ["
213
+ progress = float(current_step) / self.max_value
214
+ progress_width = int(self.width * progress)
215
+ if progress_width > 0:
216
+ bar += "=" * (progress_width - 1)
217
+ if current_step < self.max_value:
218
+ bar += ">"
219
+ else:
220
+ bar += "="
221
+ bar += "." * (self.width - progress_width)
222
+ bar += "]"
223
+ else:
224
+ bar = (
225
+ f"{self.message} {next(self.spin)}, tile " # type: ignore
226
+ f"No. {current_step * batch_size}"
227
+ )
228
+
229
+ self._total_width = len(bar)
230
+ sys.stdout.write(bar)
231
+
232
+ if current_step > 0:
233
+ time_per_unit = (now - self._start) / current_step
234
+ else:
235
+ time_per_unit = 0
236
+
237
+ if time_per_unit >= 1 or time_per_unit == 0:
238
+ info += f" {time_per_unit:.0f}s/step"
239
+ elif time_per_unit >= 1e-3:
240
+ info += f" {time_per_unit * 1e3:.0f}ms/step"
241
+ else:
242
+ info += f" {time_per_unit * 1e6:.0f}us/step"
243
+
244
+ for k in self._values_order:
245
+ info += f" - {k}:"
246
+ if isinstance(self._values[k], list):
247
+ avg = self._values[k][0] / max(1, self._values[k][1])
248
+ if abs(avg) > 1e-3:
249
+ info += f" {avg:.4f}"
250
+ else:
251
+ info += f" {avg:.4e}"
252
+ else:
253
+ info += f" {self._values[k]}s"
254
+
255
+ self._total_width += len(info)
256
+ if prev_total_width > self._total_width:
257
+ info += " " * (prev_total_width - self._total_width)
258
+
259
+ if self.max_value is not None and current_step >= self.max_value:
260
+ info += "\n"
261
+
262
+ sys.stdout.write(info)
263
+ sys.stdout.flush()
264
+
265
+ self._last_update = now
266
+
267
+ def add(self, n: int, values: list | None = None) -> None:
268
+ """
269
+ Update the progress bar by n steps.
270
+
271
+ Parameters
272
+ ----------
273
+ n : int
274
+ Number of steps to increase the progress bar with.
275
+ values : Optional[list], optional
276
+ Updated metrics values, by default None.
277
+ """
278
+ self.update(self._seen_so_far + n, 1, values=values)
279
+
280
+ def spinning_cursor(self) -> Generator:
281
+ """
282
+ Generate a spinning cursor animation.
283
+
284
+ Taken from https://github.com/manrajgrover/py-spinners/tree/master.
285
+
286
+ Returns
287
+ -------
288
+ Generator
289
+ Generator of animation frames.
290
+ """
291
+ while True:
292
+ yield from [
293
+ "▓ ----- ▒",
294
+ "▓ ----- ▒",
295
+ "▓ ----- ▒",
296
+ "▓ ->--- ▒",
297
+ "▓ ->--- ▒",
298
+ "▓ ->--- ▒",
299
+ "▓ -->-- ▒",
300
+ "▓ -->-- ▒",
301
+ "▓ -->-- ▒",
302
+ "▓ --->- ▒",
303
+ "▓ --->- ▒",
304
+ "▓ --->- ▒",
305
+ "▓ ----> ▒",
306
+ "▓ ----> ▒",
307
+ "▓ ----> ▒",
308
+ "▒ ----- ░",
309
+ "▒ ----- ░",
310
+ "▒ ----- ░",
311
+ "▒ ->--- ░",
312
+ "▒ ->--- ░",
313
+ "▒ ->--- ░",
314
+ "▒ -->-- ░",
315
+ "▒ -->-- ░",
316
+ "▒ -->-- ░",
317
+ "▒ --->- ░",
318
+ "▒ --->- ░",
319
+ "▒ --->- ░",
320
+ "▒ ----> ░",
321
+ "▒ ----> ░",
322
+ "▒ ----> ░",
323
+ ]
@@ -0,0 +1,394 @@
1
+ """
2
+ Metrics submodule.
3
+
4
+ This module contains various metrics and a metrics tracking class.
5
+ """
6
+
7
+ from collections.abc import Callable
8
+ from typing import Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from skimage.metrics import peak_signal_noise_ratio, structural_similarity
13
+ from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
14
+
15
+ # TODO: does this add additional dependency?
16
+
17
+
18
+ # TODO revisit metric for notebook
19
def avg_range_invariant_psnr(
    pred: np.ndarray,
    target: np.ndarray,
) -> float:
    """Compute the average range-invariant PSNR.

    The scale-invariant PSNR is computed independently for each sample along
    the first axis and the per-sample values are averaged. ``target`` is the
    reference (ground truth) for every sample, matching the
    ``scale_invariant_psnr(gt, pred)`` argument order.

    Parameters
    ----------
    pred : np.ndarray
        Predicted images, indexed by sample along the first axis.
    target : np.ndarray
        Target images, with at least as many samples as ``pred``.

    Returns
    -------
    float
        Average range-invariant PSNR value.
    """
    psnr_arr = [
        # scale_invariant_psnr expects (gt, pred): the reference must come
        # first because its range and standard deviation set the normalization.
        scale_invariant_psnr(target[i], pred[i])
        for i in range(pred.shape[0])
    ]
    return np.mean(psnr_arr)
41
+
42
+
43
def psnr(gt: np.ndarray, pred: np.ndarray, data_range: float) -> float:
    """
    Peak Signal to Noise Ratio.

    Thin wrapper around ``skimage.metrics.peak_signal_noise_ratio``, see
    https://scikit-image.org/docs/dev/api/skimage.metrics.html.

    NOTE: ``data_range`` is deliberately required so that it is never inferred
    from the array dtype, which can give unexpected results.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth array.
    pred : np.ndarray
        Predicted array.
    data_range : float
        The images pixel range.

    Returns
    -------
    float
        PSNR value.
    """
    score = peak_signal_noise_ratio(gt, pred, data_range=data_range)
    return score
68
+
69
+
70
+ def _zero_mean(x: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
71
+ """
72
+ Zero the mean of an array.
73
+
74
+ Parameters
75
+ ----------
76
+ x : numpy.ndarray or torch.Tensor
77
+ Input array.
78
+
79
+ Returns
80
+ -------
81
+ numpy.ndarray or torch.Tensor
82
+ Zero-mean array.
83
+ """
84
+ return x - x.mean()
85
+
86
+
87
+ def _fix_range(
88
+ gt: Union[np.ndarray, torch.Tensor], x: Union[np.ndarray, torch.Tensor]
89
+ ) -> Union[np.ndarray, torch.Tensor]:
90
+ """
91
+ Adjust the range of an array based on a reference ground-truth array.
92
+
93
+ Parameters
94
+ ----------
95
+ gt : Union[np.ndarray, torch.Tensor]
96
+ Ground truth array.
97
+ x : Union[np.ndarray, torch.Tensor]
98
+ Input array.
99
+
100
+ Returns
101
+ -------
102
+ Union[np.ndarray, torch.Tensor]
103
+ Range-adjusted array.
104
+ """
105
+ a = (gt * x).sum() / (x * x).sum()
106
+ return x * a
107
+
108
+
109
def _fix(
    gt: Union[np.ndarray, torch.Tensor], x: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    """
    Center both arrays and rescale the input onto the ground truth.

    Both ``gt`` and ``x`` are shifted to zero global mean, then the centered
    ``x`` is scaled by the least-squares factor that best matches the centered
    ``gt``.

    Parameters
    ----------
    gt : Union[np.ndarray, torch.Tensor]
        Ground truth image.
    x : Union[np.ndarray, torch.Tensor]
        Input array.

    Returns
    -------
    Union[np.ndarray, torch.Tensor]
        Zero-mean and range-adjusted array.
    """
    centered_gt = _zero_mean(gt)
    centered_x = _zero_mean(x)
    return _fix_range(centered_gt, centered_x)
129
+
130
+
131
def scale_invariant_psnr(
    gt: np.ndarray, pred: np.ndarray
) -> Union[float, torch.Tensor]:
    """
    Scale invariant PSNR.

    The ground truth is centered and divided by its standard deviation. The
    prediction is then centered and rescaled by least squares onto it. The PSNR
    data range is the ground-truth range expressed in units of that standard
    deviation.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth image.
    pred : np.ndarray
        Predicted image.

    Returns
    -------
    Union[float, torch.Tensor]
        Scale invariant PSNR value.
    """
    gt_std = np.std(gt)
    range_parameter = (np.max(gt) - np.min(gt)) / gt_std
    gt_ = _zero_mean(gt) / gt_std
    return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter)
152
+
153
+
154
+ class RunningPSNR:
155
+ """Compute the running PSNR during validation step in training.
156
+
157
+ This class allows to compute the PSNR on the entire validation set
158
+ one batch at the time.
159
+
160
+ Attributes
161
+ ----------
162
+ N : int
163
+ Number of elements seen so far during the epoch.
164
+ mse_sum : float
165
+ Running sum of the MSE over the N elements seen so far.
166
+ max : float
167
+ Running max value of the N target images seen so far.
168
+ min : float
169
+ Running min value of the N target images seen so far.
170
+ """
171
+
172
+ def __init__(self):
173
+ """Constructor."""
174
+ self.N = None
175
+ self.mse_sum = None
176
+ self.max = self.min = None
177
+ self.reset()
178
+
179
+ def reset(self):
180
+ """Reset the running PSNR computation.
181
+
182
+ Usually called at the end of each epoch.
183
+ """
184
+ self.mse_sum = 0
185
+ self.N = 0
186
+ self.max = self.min = None
187
+
188
+ def update(self, rec: torch.Tensor, tar: torch.Tensor) -> None:
189
+ """Update the running PSNR statistics given a new batch.
190
+
191
+ Parameters
192
+ ----------
193
+ rec : torch.Tensor
194
+ Reconstructed batch.
195
+ tar : torch.Tensor
196
+ Target batch.
197
+ """
198
+ ins_max = torch.max(tar).item()
199
+ ins_min = torch.min(tar).item()
200
+ if self.max is None:
201
+ assert self.min is None
202
+ self.max = ins_max
203
+ self.min = ins_min
204
+ else:
205
+ self.max = max(self.max, ins_max)
206
+ self.min = min(self.min, ins_min)
207
+
208
+ mse = (rec - tar) ** 2
209
+ elementwise_mse = torch.mean(mse.view(len(mse), -1), dim=1)
210
+ self.mse_sum += torch.nansum(elementwise_mse)
211
+ self.N += len(elementwise_mse) - torch.sum(torch.isnan(elementwise_mse))
212
+
213
+ def get(self) -> torch.Tensor | None:
214
+ """Get the actual PSNR value given the running statistics.
215
+
216
+ Returns
217
+ -------
218
+ Optional[torch.Tensor]
219
+ PSNR value.
220
+ """
221
+ if self.N == 0 or self.N is None:
222
+ return None
223
+ rmse = torch.sqrt(self.mse_sum / self.N)
224
+ return 20 * torch.log10((self.max - self.min) / rmse)
225
+
226
+
227
def _range_invariant_multiscale_ssim(
    gt_: Union[np.ndarray, torch.Tensor], pred_: Union[np.ndarray, torch.Tensor]
) -> float:
    """Compute range invariant multiscale SSIM for a single channel.

    Both images are flattened per sample and zero-meaned with `_zero_mean`.
    The prediction is then passed through `_fix` together with the ground truth
    before multiscale SSIM is evaluated. According to the original description,
    this makes the metric invariant to scalar multiplications of the prediction.
    # TODO: Add reference to the paper.

    Parameters
    ----------
    gt_ : Union[np.ndarray, torch.Tensor]
        Ground truth image of a single channel, with shape (N, H, W).
    pred_ : Union[np.ndarray, torch.Tensor]
        Predicted image of a single channel, with shape (N, H, W).

    Returns
    -------
    float
        Range invariant multiscale SSIM value.
    """
    image_shape = gt_.shape
    n_images = image_shape[0]

    # Work on one flat row per image for the normalisation helpers.
    gt_flat = torch.Tensor(gt_.reshape((n_images, -1)))
    pred_flat = torch.Tensor(pred_.reshape((n_images, -1)))
    gt_flat = _zero_mean(gt_flat)
    pred_flat = _fix(gt_flat, _zero_mean(pred_flat))

    # Restore the (N, H, W) layout and add a singleton channel axis below.
    pred_img = pred_flat.reshape(image_shape)
    gt_img = gt_flat.reshape(image_shape)

    metric = MultiScaleStructuralSimilarityIndexMeasure(
        data_range=gt_img.max() - gt_img.min()
    )
    score = metric(torch.Tensor(pred_img[:, None]), torch.Tensor(gt_img[:, None]))
    return score.item()
263
+
264
+
265
def multiscale_ssim(
    gt_: Union[np.ndarray, torch.Tensor],
    pred_: Union[np.ndarray, torch.Tensor],
    range_invariant: bool = True,
) -> list[Union[float, None]]:
    """Compute multiscale SSIM separately for every channel.

    Either the standard multiscale SSIM or its range-invariant variant is used,
    depending on `range_invariant`.

    NOTE: the channel dimension of the inputs must be the last one.
    # TODO: do we want to allow this behavior? or we want the usual (N, C, H, W)?

    Parameters
    ----------
    gt_ : Union[np.ndarray, torch.Tensor]
        Ground truth image with shape (N, H, W, C).
    pred_ : Union[np.ndarray, torch.Tensor]
        Predicted image with shape (N, H, W, C).
    range_invariant : bool
        If True, use the range invariant multiscale SSIM; otherwise use the
        standard one with the channel's ground-truth range as data range.

    Returns
    -------
    list[float]
        One SSIM value per channel, in channel order.
    """

    def _channel_score(target_ch, prediction_ch):
        # Score a single (N, H, W) channel pair.
        if range_invariant:
            return _range_invariant_multiscale_ssim(gt_=target_ch, pred_=prediction_ch)
        metric = MultiScaleStructuralSimilarityIndexMeasure(
            data_range=target_ch.max() - target_ch.min()
        )
        return metric(
            torch.Tensor(prediction_ch[:, None]), torch.Tensor(target_ch[:, None])
        ).item()

    n_channels = gt_.shape[-1]
    return [_channel_score(gt_[..., c], pred_[..., c]) for c in range(n_channels)]
308
+
309
+
310
+ def _avg_psnr(target: np.ndarray, prediction: np.ndarray, psnr_fn: Callable) -> float:
311
+ """Compute the average PSNR over a batch of images.
312
+
313
+ Parameters
314
+ ----------
315
+ target : np.ndarray
316
+ Array of ground truth images, shape is (N, C, H, W).
317
+ prediction : np.ndarray
318
+ Array of predicted images, shape is (N, C, H, W).
319
+ psnr_fn : Callable
320
+ PSNR function to use.
321
+
322
+ Returns
323
+ -------
324
+ float
325
+ Average PSNR value over the batch.
326
+ """
327
+ return np.mean(
328
+ [
329
+ psnr_fn(target[i : i + 1], prediction[i : i + 1]).item()
330
+ for i in range(len(prediction))
331
+ ]
332
+ )
333
+
334
+
335
def avg_range_inv_psnr(target: np.ndarray, prediction: np.ndarray) -> float:
    """Average the range-invariant PSNR over a batch of images.

    Thin wrapper around `_avg_psnr` that uses `scale_invariant_psnr` as the
    per-image metric.

    Parameters
    ----------
    target : np.ndarray
        Array of ground truth images, shape is (N, C, H, W).
    prediction : np.ndarray
        Array of predicted images, shape is (N, C, H, W).

    Returns
    -------
    float
        Average range-invariant PSNR value over the batch.
    """
    per_image_metric = scale_invariant_psnr
    return _avg_psnr(target=target, prediction=prediction, psnr_fn=per_image_metric)
351
+
352
+
353
def avg_psnr(target: np.ndarray, prediction: np.ndarray) -> float:
    """Average the standard PSNR over a batch of images.

    Thin wrapper around `_avg_psnr` that uses `psnr` as the per-image metric.

    Parameters
    ----------
    target : np.ndarray
        Array of ground truth images, shape is (N, C, H, W).
    prediction : np.ndarray
        Array of predicted images, shape is (N, C, H, W).

    Returns
    -------
    float
        Average PSNR value over the batch.
    """
    per_image_metric = psnr
    return _avg_psnr(target=target, prediction=prediction, psnr_fn=per_image_metric)
369
+
370
+
371
def avg_ssim(
    target: Union[np.ndarray, torch.Tensor], prediction: Union[np.ndarray, torch.Tensor]
) -> tuple[float, float]:
    """Summarise the Structural Similarity (SSIM) over a batch of images.

    `structural_similarity` is evaluated per image. The data range for each
    image is that image's ground-truth ``max - min``.

    Parameters
    ----------
    target : np.ndarray
        Array of ground truth images, shape is (N, C, H, W).
    prediction : np.ndarray
        Array of predicted images, shape is (N, C, H, W).

    Returns
    -------
    tuple[float, float]
        Mean and standard deviation of the per-image SSIM values.
    """
    scores = []
    for idx in range(len(target)):
        gt_img = target[idx]
        pred_img = prediction[idx]
        value_range = gt_img.max() - gt_img.min()
        scores.append(structural_similarity(gt_img, pred_img, data_range=value_range))
    return np.mean(scores), np.std(scores)