careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Logging submodule.
|
|
3
|
+
|
|
4
|
+
The methods are responsible for the in-console logger.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import sys
|
|
9
|
+
import time
|
|
10
|
+
from collections.abc import Generator
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, Union
|
|
13
|
+
|
|
14
|
+
LOGGERS: dict = {}
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_logger(
|
|
18
|
+
name: str,
|
|
19
|
+
log_level: int = logging.INFO,
|
|
20
|
+
log_path: Union[str, Path] | None = None,
|
|
21
|
+
) -> logging.Logger:
|
|
22
|
+
"""
|
|
23
|
+
Create a python logger instance with configured handlers.
|
|
24
|
+
|
|
25
|
+
Parameters
|
|
26
|
+
----------
|
|
27
|
+
name : str
|
|
28
|
+
Name of the logger.
|
|
29
|
+
log_level : int, optional
|
|
30
|
+
Log level (info, error etc.), by default logging.INFO.
|
|
31
|
+
log_path : Optional[Union[str, Path]], optional
|
|
32
|
+
Path in which to save the log, by default None.
|
|
33
|
+
|
|
34
|
+
Returns
|
|
35
|
+
-------
|
|
36
|
+
logging.Logger
|
|
37
|
+
Logger.
|
|
38
|
+
"""
|
|
39
|
+
logger = logging.getLogger(name)
|
|
40
|
+
logger.propagate = False
|
|
41
|
+
|
|
42
|
+
if name in LOGGERS:
|
|
43
|
+
return logger
|
|
44
|
+
|
|
45
|
+
for logger_name in LOGGERS:
|
|
46
|
+
if name.startswith(logger_name):
|
|
47
|
+
return logger
|
|
48
|
+
|
|
49
|
+
logger.propagate = False
|
|
50
|
+
|
|
51
|
+
if log_path:
|
|
52
|
+
handlers = [
|
|
53
|
+
logging.StreamHandler(),
|
|
54
|
+
logging.FileHandler(log_path),
|
|
55
|
+
]
|
|
56
|
+
else:
|
|
57
|
+
handlers = [logging.StreamHandler()]
|
|
58
|
+
|
|
59
|
+
formatter = logging.Formatter("%(message)s")
|
|
60
|
+
|
|
61
|
+
for handler in handlers:
|
|
62
|
+
handler.setFormatter(formatter) # type: ignore
|
|
63
|
+
handler.setLevel(log_level) # type: ignore
|
|
64
|
+
logger.addHandler(handler) # type: ignore
|
|
65
|
+
|
|
66
|
+
logger.setLevel(log_level)
|
|
67
|
+
LOGGERS[name] = True
|
|
68
|
+
|
|
69
|
+
logger.propagate = False
|
|
70
|
+
|
|
71
|
+
return logger
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ProgressBar:
|
|
75
|
+
"""
|
|
76
|
+
Keras style progress bar.
|
|
77
|
+
|
|
78
|
+
Adapted from https://github.com/yueyericardo/pkbar.
|
|
79
|
+
|
|
80
|
+
Parameters
|
|
81
|
+
----------
|
|
82
|
+
max_value : Optional[int], optional
|
|
83
|
+
Maximum progress bar value, by default None.
|
|
84
|
+
epoch : Optional[int], optional
|
|
85
|
+
Zero-indexed current epoch, by default None.
|
|
86
|
+
num_epochs : Optional[int], optional
|
|
87
|
+
Total number of epochs, by default None.
|
|
88
|
+
stateful_metrics : Optional[list], optional
|
|
89
|
+
Iterable of string names of metrics that should *not* be averaged over time.
|
|
90
|
+
Metrics in this list will be displayed as-is. All others will be averaged by
|
|
91
|
+
the progress bar before display, by default None.
|
|
92
|
+
always_stateful : bool, optional
|
|
93
|
+
Whether to set all metrics to be stateful, by default False.
|
|
94
|
+
mode : str, optional
|
|
95
|
+
Mode, one of "train", "val", or "predict", by default "train".
|
|
96
|
+
"""
|
|
97
|
+
|
|
98
|
+
def __init__(
|
|
99
|
+
self,
|
|
100
|
+
max_value: int | None = None,
|
|
101
|
+
epoch: int | None = None,
|
|
102
|
+
num_epochs: int | None = None,
|
|
103
|
+
stateful_metrics: list | None = None,
|
|
104
|
+
always_stateful: bool = False,
|
|
105
|
+
mode: str = "train",
|
|
106
|
+
) -> None:
|
|
107
|
+
"""
|
|
108
|
+
Constructor.
|
|
109
|
+
|
|
110
|
+
Parameters
|
|
111
|
+
----------
|
|
112
|
+
max_value : Optional[int], optional
|
|
113
|
+
Maximum progress bar value, by default None.
|
|
114
|
+
epoch : Optional[int], optional
|
|
115
|
+
Zero-indexed current epoch, by default None.
|
|
116
|
+
num_epochs : Optional[int], optional
|
|
117
|
+
Total number of epochs, by default None.
|
|
118
|
+
stateful_metrics : Optional[list], optional
|
|
119
|
+
Iterable of string names of metrics that should *not* be averaged over time.
|
|
120
|
+
Metrics in this list will be displayed as-is. All others will be averaged by
|
|
121
|
+
the progress bar before display, by default None.
|
|
122
|
+
always_stateful : bool, optional
|
|
123
|
+
Whether to set all metrics to be stateful, by default False.
|
|
124
|
+
mode : str, optional
|
|
125
|
+
Mode, one of "train", "val", or "predict", by default "train".
|
|
126
|
+
"""
|
|
127
|
+
self.max_value = max_value
|
|
128
|
+
# Width of the progress bar
|
|
129
|
+
self.width = 30
|
|
130
|
+
self.always_stateful = always_stateful
|
|
131
|
+
|
|
132
|
+
if (epoch is not None) and (num_epochs is not None):
|
|
133
|
+
print(f"Epoch: {epoch + 1}/{num_epochs}")
|
|
134
|
+
|
|
135
|
+
if stateful_metrics:
|
|
136
|
+
self.stateful_metrics = set(stateful_metrics)
|
|
137
|
+
else:
|
|
138
|
+
self.stateful_metrics = set()
|
|
139
|
+
|
|
140
|
+
self._dynamic_display = (
|
|
141
|
+
(hasattr(sys.stdout, "isatty") and sys.stdout.isatty())
|
|
142
|
+
or "ipykernel" in sys.modules
|
|
143
|
+
or "posix" in sys.modules
|
|
144
|
+
)
|
|
145
|
+
self._total_width = 0
|
|
146
|
+
self._seen_so_far = 0
|
|
147
|
+
# We use a dict + list to avoid garbage collection
|
|
148
|
+
# issues found in OrderedDict
|
|
149
|
+
self._values: dict[Any, Any] = {}
|
|
150
|
+
self._values_order: list[Any] = []
|
|
151
|
+
self._start = time.time()
|
|
152
|
+
self._last_update = 0.0
|
|
153
|
+
self.spin = self.spinning_cursor() if self.max_value is None else None
|
|
154
|
+
if mode == "train" and self.max_value is None:
|
|
155
|
+
self.message = "Estimating dataset size"
|
|
156
|
+
elif mode == "val":
|
|
157
|
+
self.message = "Validating"
|
|
158
|
+
elif mode == "predict":
|
|
159
|
+
self.message = "Denoising"
|
|
160
|
+
|
|
161
|
+
def update(
|
|
162
|
+
self, current_step: int, batch_size: int = 1, values: list | None = None
|
|
163
|
+
) -> None:
|
|
164
|
+
"""
|
|
165
|
+
Update the progress bar.
|
|
166
|
+
|
|
167
|
+
Parameters
|
|
168
|
+
----------
|
|
169
|
+
current_step : int
|
|
170
|
+
Index of the current step.
|
|
171
|
+
batch_size : int, optional
|
|
172
|
+
Batch size, by default 1.
|
|
173
|
+
values : Optional[list], optional
|
|
174
|
+
Updated metrics values, by default None.
|
|
175
|
+
"""
|
|
176
|
+
values = values or []
|
|
177
|
+
for k, v in values:
|
|
178
|
+
# if torch tensor, convert it to numpy
|
|
179
|
+
if str(type(v)) == "<class 'torch.Tensor'>":
|
|
180
|
+
v = v.detach().cpu().numpy()
|
|
181
|
+
|
|
182
|
+
if k not in self._values_order:
|
|
183
|
+
self._values_order.append(k)
|
|
184
|
+
if k not in self.stateful_metrics and not self.always_stateful:
|
|
185
|
+
if k not in self._values:
|
|
186
|
+
self._values[k] = [
|
|
187
|
+
v * (current_step - self._seen_so_far),
|
|
188
|
+
current_step - self._seen_so_far,
|
|
189
|
+
]
|
|
190
|
+
else:
|
|
191
|
+
self._values[k][0] += v * (current_step - self._seen_so_far)
|
|
192
|
+
self._values[k][1] += current_step - self._seen_so_far
|
|
193
|
+
else:
|
|
194
|
+
# Stateful metrics output a numeric value. This representation
|
|
195
|
+
# means "take an average from a single value" but keeps the
|
|
196
|
+
# numeric formatting.
|
|
197
|
+
self._values[k] = [v, 1]
|
|
198
|
+
|
|
199
|
+
self._seen_so_far = current_step
|
|
200
|
+
|
|
201
|
+
now = time.time()
|
|
202
|
+
info = f" - {(now - self._start):.0f}s"
|
|
203
|
+
|
|
204
|
+
prev_total_width = self._total_width
|
|
205
|
+
if self._dynamic_display:
|
|
206
|
+
sys.stdout.write("\b" * prev_total_width)
|
|
207
|
+
sys.stdout.write("\r")
|
|
208
|
+
else:
|
|
209
|
+
sys.stdout.write("\n")
|
|
210
|
+
|
|
211
|
+
if self.max_value is not None:
|
|
212
|
+
bar = f"{current_step}/{self.max_value} ["
|
|
213
|
+
progress = float(current_step) / self.max_value
|
|
214
|
+
progress_width = int(self.width * progress)
|
|
215
|
+
if progress_width > 0:
|
|
216
|
+
bar += "=" * (progress_width - 1)
|
|
217
|
+
if current_step < self.max_value:
|
|
218
|
+
bar += ">"
|
|
219
|
+
else:
|
|
220
|
+
bar += "="
|
|
221
|
+
bar += "." * (self.width - progress_width)
|
|
222
|
+
bar += "]"
|
|
223
|
+
else:
|
|
224
|
+
bar = (
|
|
225
|
+
f"{self.message} {next(self.spin)}, tile " # type: ignore
|
|
226
|
+
f"No. {current_step * batch_size}"
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
self._total_width = len(bar)
|
|
230
|
+
sys.stdout.write(bar)
|
|
231
|
+
|
|
232
|
+
if current_step > 0:
|
|
233
|
+
time_per_unit = (now - self._start) / current_step
|
|
234
|
+
else:
|
|
235
|
+
time_per_unit = 0
|
|
236
|
+
|
|
237
|
+
if time_per_unit >= 1 or time_per_unit == 0:
|
|
238
|
+
info += f" {time_per_unit:.0f}s/step"
|
|
239
|
+
elif time_per_unit >= 1e-3:
|
|
240
|
+
info += f" {time_per_unit * 1e3:.0f}ms/step"
|
|
241
|
+
else:
|
|
242
|
+
info += f" {time_per_unit * 1e6:.0f}us/step"
|
|
243
|
+
|
|
244
|
+
for k in self._values_order:
|
|
245
|
+
info += f" - {k}:"
|
|
246
|
+
if isinstance(self._values[k], list):
|
|
247
|
+
avg = self._values[k][0] / max(1, self._values[k][1])
|
|
248
|
+
if abs(avg) > 1e-3:
|
|
249
|
+
info += f" {avg:.4f}"
|
|
250
|
+
else:
|
|
251
|
+
info += f" {avg:.4e}"
|
|
252
|
+
else:
|
|
253
|
+
info += f" {self._values[k]}s"
|
|
254
|
+
|
|
255
|
+
self._total_width += len(info)
|
|
256
|
+
if prev_total_width > self._total_width:
|
|
257
|
+
info += " " * (prev_total_width - self._total_width)
|
|
258
|
+
|
|
259
|
+
if self.max_value is not None and current_step >= self.max_value:
|
|
260
|
+
info += "\n"
|
|
261
|
+
|
|
262
|
+
sys.stdout.write(info)
|
|
263
|
+
sys.stdout.flush()
|
|
264
|
+
|
|
265
|
+
self._last_update = now
|
|
266
|
+
|
|
267
|
+
def add(self, n: int, values: list | None = None) -> None:
|
|
268
|
+
"""
|
|
269
|
+
Update the progress bar by n steps.
|
|
270
|
+
|
|
271
|
+
Parameters
|
|
272
|
+
----------
|
|
273
|
+
n : int
|
|
274
|
+
Number of steps to increase the progress bar with.
|
|
275
|
+
values : Optional[list], optional
|
|
276
|
+
Updated metrics values, by default None.
|
|
277
|
+
"""
|
|
278
|
+
self.update(self._seen_so_far + n, 1, values=values)
|
|
279
|
+
|
|
280
|
+
def spinning_cursor(self) -> Generator:
|
|
281
|
+
"""
|
|
282
|
+
Generate a spinning cursor animation.
|
|
283
|
+
|
|
284
|
+
Taken from https://github.com/manrajgrover/py-spinners/tree/master.
|
|
285
|
+
|
|
286
|
+
Returns
|
|
287
|
+
-------
|
|
288
|
+
Generator
|
|
289
|
+
Generator of animation frames.
|
|
290
|
+
"""
|
|
291
|
+
while True:
|
|
292
|
+
yield from [
|
|
293
|
+
"▓ ----- ▒",
|
|
294
|
+
"▓ ----- ▒",
|
|
295
|
+
"▓ ----- ▒",
|
|
296
|
+
"▓ ->--- ▒",
|
|
297
|
+
"▓ ->--- ▒",
|
|
298
|
+
"▓ ->--- ▒",
|
|
299
|
+
"▓ -->-- ▒",
|
|
300
|
+
"▓ -->-- ▒",
|
|
301
|
+
"▓ -->-- ▒",
|
|
302
|
+
"▓ --->- ▒",
|
|
303
|
+
"▓ --->- ▒",
|
|
304
|
+
"▓ --->- ▒",
|
|
305
|
+
"▓ ----> ▒",
|
|
306
|
+
"▓ ----> ▒",
|
|
307
|
+
"▓ ----> ▒",
|
|
308
|
+
"▒ ----- ░",
|
|
309
|
+
"▒ ----- ░",
|
|
310
|
+
"▒ ----- ░",
|
|
311
|
+
"▒ ->--- ░",
|
|
312
|
+
"▒ ->--- ░",
|
|
313
|
+
"▒ ->--- ░",
|
|
314
|
+
"▒ -->-- ░",
|
|
315
|
+
"▒ -->-- ░",
|
|
316
|
+
"▒ -->-- ░",
|
|
317
|
+
"▒ --->- ░",
|
|
318
|
+
"▒ --->- ░",
|
|
319
|
+
"▒ --->- ░",
|
|
320
|
+
"▒ ----> ░",
|
|
321
|
+
"▒ ----> ░",
|
|
322
|
+
"▒ ----> ░",
|
|
323
|
+
]
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Metrics submodule.
|
|
3
|
+
|
|
4
|
+
This module contains various metrics and a metrics tracking class.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from typing import Union
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
|
|
13
|
+
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
|
|
14
|
+
|
|
15
|
+
# TODO: does this add additional dependency?
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# TODO revisit metric for notebook
|
|
19
|
+
def avg_range_invariant_psnr(
    pred: np.ndarray,
    target: np.ndarray,
) -> float:
    """Compute the average range-invariant PSNR.

    Each sample along the first axis is scored with `scale_invariant_psnr`,
    using ``target[i]`` as the ground truth and ``pred[i]`` as the prediction.
    The per-sample values are then averaged.

    Parameters
    ----------
    pred : np.ndarray
        Predicted images, indexed along the first axis.
    target : np.ndarray
        Target images, indexed along the first axis.

    Returns
    -------
    float
        Average range-invariant PSNR value.
    """
    psnr_arr = [
        # `scale_invariant_psnr` takes (ground truth, prediction); the range
        # normalization is derived from its first argument, so the target must
        # come first.
        scale_invariant_psnr(target[i], pred[i])
        for i in range(pred.shape[0])
    ]
    return np.mean(psnr_arr)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def psnr(gt: np.ndarray, pred: np.ndarray, data_range: float) -> float:
    """
    Peak Signal to Noise Ratio.

    Thin wrapper around ``skimage.metrics.peak_signal_noise_ratio``, see:
    https://scikit-image.org/docs/dev/api/skimage.metrics.html.

    NOTE: `data_range` is required on purpose, so that skimage never falls back
    to inferring the range from the array dtype.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth array.
    pred : np.ndarray
        Predicted array.
    data_range : float
        The images pixel range.

    Returns
    -------
    float
        PSNR value.
    """
    value = peak_signal_noise_ratio(gt, pred, data_range=data_range)
    return value
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _zero_mean(x: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
|
71
|
+
"""
|
|
72
|
+
Zero the mean of an array.
|
|
73
|
+
|
|
74
|
+
Parameters
|
|
75
|
+
----------
|
|
76
|
+
x : numpy.ndarray or torch.Tensor
|
|
77
|
+
Input array.
|
|
78
|
+
|
|
79
|
+
Returns
|
|
80
|
+
-------
|
|
81
|
+
numpy.ndarray or torch.Tensor
|
|
82
|
+
Zero-mean array.
|
|
83
|
+
"""
|
|
84
|
+
return x - x.mean()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _fix_range(
|
|
88
|
+
gt: Union[np.ndarray, torch.Tensor], x: Union[np.ndarray, torch.Tensor]
|
|
89
|
+
) -> Union[np.ndarray, torch.Tensor]:
|
|
90
|
+
"""
|
|
91
|
+
Adjust the range of an array based on a reference ground-truth array.
|
|
92
|
+
|
|
93
|
+
Parameters
|
|
94
|
+
----------
|
|
95
|
+
gt : Union[np.ndarray, torch.Tensor]
|
|
96
|
+
Ground truth array.
|
|
97
|
+
x : Union[np.ndarray, torch.Tensor]
|
|
98
|
+
Input array.
|
|
99
|
+
|
|
100
|
+
Returns
|
|
101
|
+
-------
|
|
102
|
+
Union[np.ndarray, torch.Tensor]
|
|
103
|
+
Range-adjusted array.
|
|
104
|
+
"""
|
|
105
|
+
a = (gt * x).sum() / (x * x).sum()
|
|
106
|
+
return x * a
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _fix(
    gt: Union[np.ndarray, torch.Tensor], x: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    """
    Zero-mean both arrays, then rescale `x` to best match the ground truth.

    Parameters
    ----------
    gt : Union[np.ndarray, torch.Tensor]
        Ground truth image.
    x : Union[np.ndarray, torch.Tensor]
        Input array.

    Returns
    -------
    Union[np.ndarray, torch.Tensor]
        Zero-mean and range-adjusted version of `x`.
    """
    return _fix_range(_zero_mean(gt), _zero_mean(x))
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def scale_invariant_psnr(
    gt: np.ndarray, pred: np.ndarray
) -> Union[float, torch.Tensor]:
    """
    Scale invariant PSNR.

    The ground truth is standardized (zero mean, unit standard deviation). The
    prediction is zero-meaned and rescaled by least squares to match it. The PSNR
    is then computed with the ground-truth dynamic range expressed in units of
    its standard deviation.

    Parameters
    ----------
    gt : np.ndarray
        Ground truth image.
    pred : np.ndarray
        Predicted image.

    Returns
    -------
    Union[float, torch.Tensor]
        Scale invariant PSNR value.
    """
    gt_std = np.std(gt)
    # Dynamic range of the ground truth, in the same (standardized) units as gt_.
    range_parameter = (np.max(gt) - np.min(gt)) / gt_std
    gt_ = _zero_mean(gt) / gt_std
    return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class RunningPSNR:
|
|
155
|
+
"""Compute the running PSNR during validation step in training.
|
|
156
|
+
|
|
157
|
+
This class allows to compute the PSNR on the entire validation set
|
|
158
|
+
one batch at the time.
|
|
159
|
+
|
|
160
|
+
Attributes
|
|
161
|
+
----------
|
|
162
|
+
N : int
|
|
163
|
+
Number of elements seen so far during the epoch.
|
|
164
|
+
mse_sum : float
|
|
165
|
+
Running sum of the MSE over the N elements seen so far.
|
|
166
|
+
max : float
|
|
167
|
+
Running max value of the N target images seen so far.
|
|
168
|
+
min : float
|
|
169
|
+
Running min value of the N target images seen so far.
|
|
170
|
+
"""
|
|
171
|
+
|
|
172
|
+
def __init__(self):
|
|
173
|
+
"""Constructor."""
|
|
174
|
+
self.N = None
|
|
175
|
+
self.mse_sum = None
|
|
176
|
+
self.max = self.min = None
|
|
177
|
+
self.reset()
|
|
178
|
+
|
|
179
|
+
def reset(self):
|
|
180
|
+
"""Reset the running PSNR computation.
|
|
181
|
+
|
|
182
|
+
Usually called at the end of each epoch.
|
|
183
|
+
"""
|
|
184
|
+
self.mse_sum = 0
|
|
185
|
+
self.N = 0
|
|
186
|
+
self.max = self.min = None
|
|
187
|
+
|
|
188
|
+
def update(self, rec: torch.Tensor, tar: torch.Tensor) -> None:
|
|
189
|
+
"""Update the running PSNR statistics given a new batch.
|
|
190
|
+
|
|
191
|
+
Parameters
|
|
192
|
+
----------
|
|
193
|
+
rec : torch.Tensor
|
|
194
|
+
Reconstructed batch.
|
|
195
|
+
tar : torch.Tensor
|
|
196
|
+
Target batch.
|
|
197
|
+
"""
|
|
198
|
+
ins_max = torch.max(tar).item()
|
|
199
|
+
ins_min = torch.min(tar).item()
|
|
200
|
+
if self.max is None:
|
|
201
|
+
assert self.min is None
|
|
202
|
+
self.max = ins_max
|
|
203
|
+
self.min = ins_min
|
|
204
|
+
else:
|
|
205
|
+
self.max = max(self.max, ins_max)
|
|
206
|
+
self.min = min(self.min, ins_min)
|
|
207
|
+
|
|
208
|
+
mse = (rec - tar) ** 2
|
|
209
|
+
elementwise_mse = torch.mean(mse.view(len(mse), -1), dim=1)
|
|
210
|
+
self.mse_sum += torch.nansum(elementwise_mse)
|
|
211
|
+
self.N += len(elementwise_mse) - torch.sum(torch.isnan(elementwise_mse))
|
|
212
|
+
|
|
213
|
+
def get(self) -> torch.Tensor | None:
|
|
214
|
+
"""Get the actual PSNR value given the running statistics.
|
|
215
|
+
|
|
216
|
+
Returns
|
|
217
|
+
-------
|
|
218
|
+
Optional[torch.Tensor]
|
|
219
|
+
PSNR value.
|
|
220
|
+
"""
|
|
221
|
+
if self.N == 0 or self.N is None:
|
|
222
|
+
return None
|
|
223
|
+
rmse = torch.sqrt(self.mse_sum / self.N)
|
|
224
|
+
return 20 * torch.log10((self.max - self.min) / rmse)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _range_invariant_multiscale_ssim(
    gt_: Union[np.ndarray, torch.Tensor], pred_: Union[np.ndarray, torch.Tensor]
) -> float:
    """Compute range invariant multiscale SSIM for a single channel.

    Compared to the commonly used SSIM, this metric is invariant to scalar
    multiplications of the prediction: before scoring, the prediction is
    centred and rescaled onto the centred ground truth with ``_fix``.

    Parameters
    ----------
    gt_ : Union[np.ndarray, torch.Tensor]
        Ground truth image with shape (N, H, W).
    pred_ : Union[np.ndarray, torch.Tensor]
        Predicted image with shape (N, H, W).

    Returns
    -------
    float
        Range invariant multiscale SSIM value.
    """
    # TODO: Add reference to the paper.
    original_shape = gt_.shape
    n_images = original_shape[0]

    # Work on a flattened (N, H*W) view while centring and rescaling.
    flat_gt = _zero_mean(torch.Tensor(gt_.reshape((n_images, -1))))
    flat_pred = _zero_mean(torch.Tensor(pred_.reshape((n_images, -1))))
    flat_pred = _fix(flat_gt, flat_pred)

    gt_images = flat_gt.reshape(original_shape)
    pred_images = flat_pred.reshape(original_shape)

    metric = MultiScaleStructuralSimilarityIndexMeasure(
        data_range=gt_images.max() - gt_images.min()
    )
    # `[:, None]` inserts a singleton channel axis: (N, H, W) -> (N, 1, H, W).
    score = metric(
        torch.Tensor(pred_images[:, None]), torch.Tensor(gt_images[:, None])
    )
    return score.item()
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def multiscale_ssim(
    gt_: Union[np.ndarray, torch.Tensor],
    pred_: Union[np.ndarray, torch.Tensor],
    range_invariant: bool = True,
) -> list[float]:
    """Compute channel-wise multiscale SSIM for each channel.

    It allows to use either standard multiscale SSIM or its range-invariant version.

    NOTE: images fed to this function should have channels dimension as the last one.

    Parameters
    ----------
    gt_ : Union[np.ndarray, torch.Tensor]
        Ground truth image with shape (N, H, W, C).
    pred_ : Union[np.ndarray, torch.Tensor]
        Predicted image with shape (N, H, W, C).
    range_invariant : bool
        Whether to use standard or range invariant multiscale SSIM.

    Returns
    -------
    list[float]
        List of SSIM values for each channel, in channel order.
    """
    # TODO: do we want to allow this behavior? or we want the usual (N, C, H, W)?
    ms_ssim_values: list[float] = []
    for ch_idx in range(gt_.shape[-1]):
        tar_tmp = gt_[..., ch_idx]
        pred_tmp = pred_[..., ch_idx]
        if range_invariant:
            value = _range_invariant_multiscale_ssim(gt_=tar_tmp, pred_=pred_tmp)
        else:
            ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(
                data_range=tar_tmp.max() - tar_tmp.min()
            )
            # `[:, None]` adds the singleton channel axis: (N, H, W) -> (N, 1, H, W).
            value = ms_ssim(
                torch.Tensor(pred_tmp[:, None]), torch.Tensor(tar_tmp[:, None])
            ).item()
        ms_ssim_values.append(value)

    return ms_ssim_values
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _avg_psnr(target: np.ndarray, prediction: np.ndarray, psnr_fn: Callable) -> float:
    """Compute the average PSNR over a batch of images.

    ``psnr_fn`` is called once per image and the results are averaged.

    Parameters
    ----------
    target : np.ndarray
        Array of ground truth images, shape is (N, C, H, W).
    prediction : np.ndarray
        Array of predicted images, shape is (N, C, H, W).
    psnr_fn : Callable
        PSNR function to use. It receives ``(target_slice, prediction_slice)``
        and must return an object exposing ``.item()``.

    Returns
    -------
    float
        Average PSNR value over the batch.
    """
    per_image_psnr = []
    for start in range(len(prediction)):
        stop = start + 1
        # Slicing (rather than indexing) keeps the leading batch axis, so each
        # call receives a batch of exactly one image.
        value = psnr_fn(target[start:stop], prediction[start:stop])
        per_image_psnr.append(value.item())
    return np.mean(per_image_psnr)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def avg_range_inv_psnr(target: np.ndarray, prediction: np.ndarray) -> float:
    """Compute the average range-invariant PSNR over a batch of images.

    Each image pair is scored with ``scale_invariant_psnr`` and the scores are
    averaged with ``_avg_psnr``.

    Parameters
    ----------
    target : np.ndarray
        Array of ground truth images, shape is (N, C, H, W).
    prediction : np.ndarray
        Array of predicted images, shape is (N, C, H, W).

    Returns
    -------
    float
        Average range-invariant PSNR value over the batch.
    """
    return _avg_psnr(
        target=target, prediction=prediction, psnr_fn=scale_invariant_psnr
    )
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def avg_psnr(target: np.ndarray, prediction: np.ndarray) -> float:
    """Compute the average PSNR over a batch of images.

    Each image pair is scored with ``psnr`` and the scores are averaged with
    ``_avg_psnr``.

    Parameters
    ----------
    target : np.ndarray
        Array of ground truth images, shape is (N, C, H, W).
    prediction : np.ndarray
        Array of predicted images, shape is (N, C, H, W).

    Returns
    -------
    float
        Average PSNR value over the batch.
    """
    return _avg_psnr(target=target, prediction=prediction, psnr_fn=psnr)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def avg_ssim(
    target: Union[np.ndarray, torch.Tensor], prediction: Union[np.ndarray, torch.Tensor]
) -> tuple[float, float]:
    """Compute the average Structural Similarity (SSIM) over a batch of images.

    Each ``target[i]`` / ``prediction[i]`` pair is scored with
    ``structural_similarity``. The data range is taken from the ground-truth
    image.

    NOTE(review): ``structural_similarity`` is called without ``channel_axis``,
    so each ``target[i]`` is treated as one (possibly multi-dimensional)
    image. Confirm whether callers pass (N, H, W) or (N, C, H, W).

    Parameters
    ----------
    target : np.ndarray
        Array of ground truth images, shape is (N, C, H, W).
    prediction : np.ndarray
        Array of predicted images, shape is (N, C, H, W).

    Returns
    -------
    tuple[float, float]
        Mean and standard deviation of SSIM values over the batch.
    """
    per_image_ssim = []
    for idx, gt_image in enumerate(target):
        value_range = gt_image.max() - gt_image.min()
        per_image_ssim.append(
            structural_similarity(gt_image, prediction[idx], data_range=value_range)
        )
    return np.mean(per_image_ssim), np.std(per_image_ssim)
|