careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,987 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This script provides methods to evaluate the performance of the LVAE model.
|
|
3
|
+
It includes functions to:
|
|
4
|
+
- make predictions,
|
|
5
|
+
- quantify the performance of the model
|
|
6
|
+
- create plots to visualize the results.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
import matplotlib
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
import numpy as np
|
|
15
|
+
import torch
|
|
16
|
+
from matplotlib.gridspec import GridSpec
|
|
17
|
+
from torch.utils.data import DataLoader, Dataset
|
|
18
|
+
from tqdm import tqdm
|
|
19
|
+
|
|
20
|
+
from careamics.lightning import VAEModule
|
|
21
|
+
from careamics.lvae_training.dataset import MultiChDloaderRef
|
|
22
|
+
from careamics.utils.metrics import scale_invariant_psnr
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TilingMode:
    """Integer constants naming the available tiling modes.

    Despite serving as a set of named options, this is a plain class whose
    attributes are ordinary ``int`` values (it is not an ``enum.Enum``), so
    the members compare equal to ``0``, ``1`` and ``2`` respectively.
    """

    # NOTE(review): how each mode is applied is decided by the tiling code
    # that consumes these values; the names suggest trimming, padding or
    # shifting tiles at the image boundary.
    TrimBoundary, PadBoundary, ShiftBoundary = range(3)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ------------------------------------------------------------------------------------
|
|
36
|
+
# Plotting functions. TODO: move them to a separate file, plot_utils.py
|
|
37
|
+
def clean_ax(ax):
    """Strip tick labels and tick marks from a matplotlib axis.

    ``ax`` may be a single axis or an ``np.ndarray`` of axes (e.g. the 1D/2D
    array returned by ``plt.subplots``). Arrays are processed recursively,
    element by element.
    """
    if not isinstance(ax, np.ndarray):
        no_ticks = {"left": False, "right": False, "top": False, "bottom": False}
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        ax.tick_params(**no_ticks)
        return

    # Iterating a 2D array yields its 1D rows, which recurse once more.
    for sub_ax in ax:
        clean_ax(sub_ax)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_eval_output_dir(
    saveplotsdir: str, patch_size: int, mmse_count: int = 50
) -> str:
    """
    Create (if missing) and return the directory for evaluation outputs.

    The directory is ``saveplotsdir`` joined with
    ``eval_outputs/patch_<patch_size>_mmse_<mmse_count>``. Its path is
    printed to stdout before being returned.
    """
    leaf = f"eval_outputs/patch_{patch_size}_mmse_{mmse_count}"
    out_dir = os.path.join(saveplotsdir, leaf)
    os.makedirs(out_dir, exist_ok=True)
    print(out_dir)
    return out_dir
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def get_psnr_str(tar_hsnr, pred, col_idx):
    """
    Return the scale-invariant PSNR of one channel, formatted with one decimal.

    Parameters
    ----------
    tar_hsnr : array-like
        Ground truth (high-SNR) image, indexed by channel first.
    pred : array-like
        Predicted image, indexed by channel first.
    col_idx : int
        Channel to compare.

    Returns
    -------
    str
        PSNR value formatted as e.g. ``"31.4"``.
    """
    # ``[None]`` prepends a singleton axis (presumably a batch axis expected
    # by ``scale_invariant_psnr`` -- confirm against the metric's API).
    target = tar_hsnr[col_idx][None]
    prediction = pred[col_idx][None]
    value = scale_invariant_psnr(target, prediction).item()
    return f"{value:.1f}"
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def add_psnr_str(ax_, psnr):
    """
    Overlay a two-line ``PSNR`` label in the top-left corner of ``ax_``.

    ``psnr`` is inserted as-is below the word "PSNR" (callers in this module
    pass the string produced by ``get_psnr_str``). The label is white text on
    a semi-transparent, rounded gray box. It is positioned in axes
    coordinates, so it does not depend on the image size.
    """
    label_box = {"boxstyle": "round", "facecolor": "gray", "alpha": 0.5}
    label = f"PSNR\n{psnr}"
    # (0.05, 0.95) in axes coordinates == near the upper-left corner.
    ax_.text(
        0.05,
        0.95,
        label,
        color="white",
        bbox=label_box,
        fontsize=11,
        transform=ax_.transAxes,
        verticalalignment="top",
    )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def get_last_index(bin_count, quantile):
    """
    Count the trailing histogram bins at or above ``quantile`` of the mass.

    The cumulative sum of ``bin_count`` is normalised to end at 1. The result
    is the number of bins that come after the last bin whose normalised
    cumulative count is still below ``quantile``. It is intended for slicing
    those tail bins off, e.g. ``values[:-n]``.

    Parameters
    ----------
    bin_count : array-like of non-negative numbers
        Histogram counts.
    quantile : float
        Cumulative-mass threshold, typically close to 1 (e.g. 0.999).

    Returns
    -------
    int or None
        Number of trailing bins to drop, or ``None`` if no bin lies below
        ``quantile``, or if the histogram is empty or has zero total count.
    """
    cumsum = np.cumsum(bin_count)
    # Guard against an empty histogram (cumsum[-1] would raise) and a
    # zero-mass histogram (normalisation would divide by zero).
    if cumsum.size == 0 or cumsum[-1] == 0:
        return None
    normalized_cumsum = cumsum / cumsum[-1]
    # Every bin is checked, including index 0, which the previous
    # backwards loop (range(1, len)) skipped.
    below = np.flatnonzero(normalized_cumsum < quantile)
    if below.size == 0:
        return None
    return int(normalized_cumsum.size - 1 - below[-1])
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def get_first_index(bin_count, quantile):
    """
    Return the first histogram bin whose cumulative mass exceeds ``quantile``.

    The cumulative sum of ``bin_count`` is normalised to end at 1, and the
    index of the first entry strictly greater than ``quantile`` is returned
    as a Python ``int``. If no entry exceeds it, ``None`` is returned.
    """
    cdf = np.cumsum(bin_count)
    cdf = cdf / cdf[-1]
    above = np.flatnonzero(cdf > quantile)
    return int(above[0]) if above.size else None
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def get_device():
    """
    Return the name of the preferred available torch device.

    CUDA is preferred, then Apple's MPS backend; ``"cpu"`` is the fallback.
    """
    # Checks are wrapped in lambdas so MPS is only queried when CUDA is absent.
    candidates = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    for name, is_available in candidates:
        if is_available():
            return name
    return "cpu"
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def show_for_one(
    idx,
    val_dset,
    highsnr_val_dset,
    model,
    calibration_stats,
    mmse_count=5,
    patch_size=256,
    num_samples=2,
    baseline_preds=None,
    grid_size=64,
):
    """
    Plot the input, targets and predictions for one validation index.

    The predictions (individual samples, their difference and their mean)
    are compared with a ground-truth image taken from the high-SNR dataset
    at the same index. Rendering is delegated to ``plot_crops``.

    Parameters
    ----------
    idx : int
        Index into both ``val_dset`` and ``highsnr_val_dset``.
    val_dset, highsnr_val_dset :
        Validation dataset and its high-SNR counterpart.
    model :
        Model used by ``get_predictions``.
    calibration_stats :
        Forwarded to ``plot_crops`` (``None`` disables the calibration plot).
    mmse_count : int
        Number of predictions requested from ``get_predictions``.
    patch_size : int
        Patch size used for both datasets.
    num_samples : int
        Number of individual samples shown.
    baseline_preds : list, optional
        Extra predictions to display, forwarded to ``plot_crops``.
    grid_size : int
        Second argument of ``highsnr_val_dset.set_img_sz`` (previously
        hard-coded to 64; presumably the tiling grid size -- confirm against
        the dataset API).

    Notes
    -----
    ``highsnr_val_dset`` is reconfigured in place (image size set, noise
    disabled), and the previous settings are not restored.
    """
    highsnr_val_dset.set_img_sz(patch_size, grid_size)
    highsnr_val_dset.disable_noise()
    _, tar_hsnr = highsnr_val_dset[idx]
    inp, tar, recon_img_list = get_predictions(
        idx, val_dset, model, mmse_count=mmse_count, patch_size=patch_size
    )
    plot_crops(
        inp,
        tar,
        tar_hsnr,
        recon_img_list,
        calibration_stats,
        num_samples=num_samples,
        baseline_preds=baseline_preds,
    )
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def plot_crops(
    inp,
    tar,
    tar_hsnr,
    recon_img_list,
    calibration_stats=None,
    num_samples=2,
    baseline_preds=None,
):
    """
    Plot input, targets, baselines, samples and MMSE prediction side by side.

    One image row is drawn per channel. The columns, from left to right, are:
    input (first channel only) | target | one column per baseline |
    ``num_samples`` individual samples | difference of the first two samples
    (only when ``num_samples > 1``) | mean of all samples | high-SNR ground
    truth. Baseline, sample and mean panels are annotated with their PSNR
    against ``tar_hsnr``. When ``calibration_stats`` is given, a calibration
    plot is drawn below the input image.

    Parameters
    ----------
    inp : torch.Tensor
        Input batch; ``inp[0, 0]`` is displayed.
    tar : torch.Tensor
        Target batch; ``tar[0, c]`` is displayed for each channel ``c``.
    tar_hsnr : array-like
        High-SNR ground truth indexed by channel first (``tar_hsnr[c]``).
    recon_img_list : array-like
        Predictions indexed as ``[sample, channel, ...]``. Its
        ``shape[1]`` sets the number of channel rows, and it must hold at
        least ``num_samples`` (and at least 2 if ``num_samples > 1``) samples.
    calibration_stats : optional
        Forwarded to ``plot_calibration``; ``None`` skips the calibration plot.
    num_samples : int
        Number of individual samples to display.
    baseline_preds : list, optional
        Baseline predictions, each with the same shape as ``tar_hsnr``.
        If any shape differs, a message is printed and nothing is plotted.
    """
    if baseline_preds is None:
        baseline_preds = []
    # Baselines are compared with tar_hsnr channel by channel; bail out
    # (before creating a figure) when the shapes disagree.
    for i, baseline in enumerate(baseline_preds):
        if baseline.shape != tar_hsnr.shape:
            print(
                f"Baseline prediction {i} shape {baseline.shape} does not "
                f"match target shape {tar_hsnr.shape}"
            )
            print("This happens when we want to predict the edges of the image.")
            return

    n_baselines = len(baseline_preds)
    n_channels = recon_img_list.shape[1]

    img_sz = 3  # inches per image panel
    # input + target + baselines + samples + [difference] + mean + high-SNR GT
    ncols = num_samples + n_baselines + 1 + 1 + 1 + 1 + 1 * (num_samples > 1)
    grid_factor = 5  # GridSpec cells per inch
    grid_img_sz = img_sz * grid_factor  # GridSpec cells spanned by one panel
    c0_extra = 1  # horizontal offset (GridSpec cells) for columns after the target
    fig_w = ncols * img_sz + 2 * c0_extra / grid_factor
    # NOTE(review): the height scales with the number of columns rather than
    # the number of channel rows, which leaves blank space below the panels;
    # kept unchanged for layout compatibility.
    fig_h = img_sz * ncols
    fig = plt.figure(figsize=(fig_w, fig_h))
    gs = GridSpec(
        nrows=int(grid_factor * fig_h),
        ncols=int(grid_factor * fig_w),
        hspace=0.2,
        wspace=0.2,
    )
    plt.rcParams.update({"mathtext.default": "regular"})

    def _panel(row, col, offset=c0_extra):
        """Add the subplot for image row ``row`` and image column ``col``."""
        return fig.add_subplot(
            gs[
                row * grid_img_sz : (row + 1) * grid_img_sz,
                col * grid_img_sz + offset : (col + 1) * grid_img_sz + offset,
            ]
        )

    # Baselines: one column each, right after the target column.
    for b_idx, baseline in enumerate(baseline_preds):
        for col_idx in range(baseline.shape[0]):
            ax_temp = _panel(col_idx, 2 + b_idx)
            psnr = get_psnr_str(tar_hsnr, baseline, col_idx)
            ax_temp.imshow(baseline[col_idx], cmap="magma")
            add_psnr_str(ax_temp, psnr)
            clean_ax(ax_temp)

    # Individual samples: exactly num_samples columns. The range must not be
    # tied to ncols - 3, which is one column short when num_samples == 1
    # (no difference column follows).
    sample_start_idx = 2 + n_baselines
    for s_idx in range(num_samples):
        sample = recon_img_list[s_idx]
        for col_idx in range(n_channels):
            ax_temp = _panel(col_idx, sample_start_idx + s_idx)
            psnr = get_psnr_str(tar_hsnr, sample, col_idx)
            ax_temp.imshow(sample[col_idx], cmap="magma")
            add_psnr_str(ax_temp, psnr)
            clean_ax(ax_temp)

    # Difference between the first two samples.
    if num_samples > 1:
        for col_idx in range(n_channels):
            ax_temp = _panel(col_idx, ncols - 3)
            ax_temp.imshow(
                recon_img_list[1][col_idx] - recon_img_list[0][col_idx], cmap="coolwarm"
            )
            clean_ax(ax_temp)

    # Mean of all samples, high-SNR ground truth and the target, per channel.
    mmse = recon_img_list.mean(axis=0)
    for col_idx in range(n_channels):
        ax_temp = _panel(col_idx, ncols - 2)
        psnr = get_psnr_str(tar_hsnr, mmse, col_idx)
        ax_temp.imshow(mmse[col_idx], cmap="magma")
        add_psnr_str(ax_temp, psnr)
        clean_ax(ax_temp)

        ax_temp = _panel(col_idx, ncols - 1, offset=2 * c0_extra)
        ax_temp.imshow(tar_hsnr[col_idx], cmap="magma")
        clean_ax(ax_temp)

        ax_temp = _panel(col_idx, 1, offset=0)
        ax_temp.imshow(tar[0, col_idx].cpu().numpy(), cmap="magma")
        clean_ax(ax_temp)

    # Input (first channel only) in the top-left panel.
    ax_temp = _panel(0, 0, offset=0)
    ax_temp.imshow(inp[0, 0].cpu().numpy(), cmap="magma")
    clean_ax(ax_temp)

    if calibration_stats is not None:
        # Slightly smaller inset in the second image row of the first column,
        # i.e. below the input image.
        smaller_offset = 4
        ax_temp = fig.add_subplot(
            gs[
                grid_img_sz + 1 : 2 * grid_img_sz - smaller_offset + 1,
                smaller_offset - 1 : grid_img_sz - 1,
            ]
        )
        plot_calibration(ax_temp, calibration_stats)
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
def plot_calibration(ax, calibration_stats):
    """
    Scatter-plot calibration statistics (RMV vs RMSE) for channels 0 and 1.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to draw on.
    calibration_stats : Mapping or Sequence
        Indexed by channel (entries ``0`` and ``1`` are used). Each entry
        provides ``"bin_count"``, ``"rmv"`` and ``"rmse"`` arrays
        (presumably of equal length).

    Notes
    -----
    Sparsely populated bins at both ends are omitted. The leading bins are
    cut at the first bin whose cumulative count exceeds 0.1 %
    (``get_first_index``). The trailing bins whose cumulative count reaches
    99.9 % are dropped (``get_last_index``).
    """
    for ch_idx in (0, 1):
        stats = calibration_stats[ch_idx]
        first_idx = get_first_index(stats["bin_count"], 0.001)
        n_trailing = get_last_index(stats["bin_count"], 0.999)
        # n_trailing is the number of bins to drop from the end. Slicing with
        # -0 would select nothing, and -None raises TypeError, so keep every
        # bin up to the end in those cases.
        stop = -n_trailing if n_trailing else None
        ax.plot(
            stats["rmv"][first_idx:stop],
            stats["rmse"][first_idx:stop],
            "o",
            label=rf"$\hat{{C}}_{ch_idx}$",
        )

    ax.set_xlabel("RMV")
    ax.set_ylabel("RMSE")
    ax.legend()
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name="shiftedcmap"):
    """
    Return a copy of ``cmap`` whose centre is moved to ``midpoint``.

    Adapted from
    https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-
    matplotlib

    Useful for data with a negative minimum and a positive maximum, when the middle
    of the colormap's dynamic range should sit at zero.

    Parameters
    ----------
    cmap : callable
        Colormap sampled as ``cmap(x)`` and returning an RGBA tuple.
    start : float
        Lower end of the range of ``cmap`` that is sampled (default 0.0). Should be
        between 0.0 and ``midpoint``.
    midpoint : float
        New centre of the colormap, between 0.0 and 1.0 (default 0.5, i.e. no
        shift). In general ``1 - vmax / (vmax + abs(vmin))``; e.g. for data in
        [-15, 5] it is ``1 - 5 / (5 + 15) = 0.75``.
    stop : float
        Upper end of the range of ``cmap`` that is sampled (default 1.0). Should be
        between ``midpoint`` and 1.0.
    name : str
        Name of the new colormap.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        The shifted colormap. It is also registered in ``matplotlib.colormaps``
        under ``name`` (``force=True`` overwrites an existing entry).

    Notes
    -----
    The alpha channel of ``cmap`` is discarded. Opacity is instead the distance of
    each sample from the centre of ``[start, stop]``, divided by that centre.
    For example, with ``start=0`` and ``stop=1`` it goes from 0 at the centre to 1
    at both ends.
    """
    n_samples = 257
    n_below = n_samples // 2  # 128 anchors below the new midpoint

    # points at which the original colormap is sampled
    sample_points = np.linspace(start, stop, n_samples)
    centre = sample_points[n_samples // 2]
    # where those samples are placed in the new colormap
    anchor_points = np.concatenate(
        (
            np.linspace(0.0, midpoint, n_below, endpoint=False),
            np.linspace(midpoint, 1.0, n_samples - n_below, endpoint=True),
        )
    )

    channel_keys = ("red", "green", "blue", "alpha")
    segment_data = {key: [] for key in channel_keys}
    for src, dst in zip(sample_points, anchor_points):
        red, green, blue, _ = cmap(src)
        opacity = np.abs(src - centre) / centre
        for key, value in zip(channel_keys, (red, green, blue, opacity)):
            segment_data[key].append((dst, value, value))

    shifted = matplotlib.colors.LinearSegmentedColormap(name, segment_data)
    matplotlib.colormaps.register(cmap=shifted, force=True)

    return shifted
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def get_fractional_change(target, prediction, max_val=None):
    """
    Return the difference ``target - prediction`` divided by a normalisation value.

    Parameters
    ----------
    target : array-like
        Reference image (must provide ``.max()`` when ``max_val`` is omitted).
    prediction : array-like
        Predicted image, broadcast-compatible with ``target``.
    max_val : float, optional
        Normalisation value. Defaults to ``target.max()``.

    Returns
    -------
    array-like
        ``(target - prediction) / max_val``.
    """
    scale = target.max() if max_val is None else max_val
    difference = target - prediction
    return difference / scale
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def get_zero_centered_midval(error):
    """
    Return the relative position of zero within ``[error.min(), error.max()]``.

    The value is the ``midpoint`` argument that makes ``shiftedColorMap`` place
    the centre of the colormap at 0. For data straddling zero it equals
    ``1 - vmax / (vmax + abs(vmin)) == -vmin / (vmax - vmin)``, the fraction of
    the data range that lies below zero.

    The result is always clamped to ``[0, 1]``, which is a valid colormap midpoint:

    - all values >= 0 returns 0.0 (zero at or below the bottom of the range),
    - all values <= 0 returns 1.0 (zero at or above the top of the range),
    - all values == 0 returns 0.5 (degenerate range, no shift).

    Parameters
    ----------
    error : array-like
        Error map providing ``.min()`` and ``.max()``.

    Returns
    -------
    float
        Midpoint in ``[0, 1]`` (NaN if ``error`` contains NaN).
    """
    vmax = float(error.max())
    vmin = float(error.min())
    if vmin >= 0 and vmax <= 0:
        # constant zero error: nothing to shift, avoid 0 / 0
        return 0.5
    if vmin >= 0:
        # zero is not inside the data; the original formula gave an arbitrary
        # interior value here
        return 0.0
    if vmax <= 0:
        # the original formula could exceed 1 here, producing invalid segments
        return 1.0
    return 1 - vmax / (vmax + abs(vmin))
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val=None):
    """
    Overlay the relative error between target and prediction on the prediction.

    The prediction is drawn in grayscale. The error ``(target - prediction) /
    max_val`` is drawn on top with a colormap whose centre is shifted to 0, and a
    colorbar is added.

    Parameters
    ----------
    target, prediction : array-like
        Images to compare.
    cmap : matplotlib colormap
        Diverging colormap for the error (default ``coolwarm``).
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. A new 6x6 figure is created when omitted.
    max_val : float, optional
        Normalisation value passed to ``get_fractional_change``.
    """
    axis = ax
    if axis is None:
        _, axis = plt.subplots(figsize=(6, 6))

    error_map = get_fractional_change(target, prediction, max_val=max_val)
    zero_position = get_zero_centered_midval(error_map)
    centred_cmap = shiftedColorMap(
        cmap, start=0, midpoint=zero_position, stop=1.0, name="shiftedcmap"
    )

    # grayscale prediction underneath, error map on top
    axis.imshow(prediction, cmap="gray")
    overlay = axis.imshow(error_map, cmap=centred_cmap, alpha=1)
    plt.colorbar(overlay, ax=axis)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
# -------------------------------------------------------------------------------------
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def get_predictions(
    model: VAEModule,
    dset: Dataset,
    batch_size: int,
    tile_size: Optional[tuple[int, int]] = None,
    grid_size: Optional[int] = None,
    mmse_count: int = 1,
    num_workers: int = 4,
) -> tuple[dict, dict]:
    """Get stitched MMSE predictions and their standard deviations, per file.

    Parameters
    ----------
    model : VAEModule
        Lightning model used for prediction.
    dset : Dataset
        Dataset to predict on. If it has a ``dsets`` attribute, each contained
        dataset is predicted and stitched separately.
    batch_size : int
        Batch size to use for prediction.
    tile_size : tuple[int, int], optional
        Tile size forwarded to ``get_single_file_mmse`` (applied only together
        with ``grid_size``), by default None.
    grid_size : int, optional
        Grid size forwarded to ``get_single_file_mmse`` (applied only together
        with ``tile_size``), by default None.
    mmse_count : int, optional
        Number of samples to generate for each input and then to average over for
        MMSE estimation, by default 1.
    num_workers : int, optional
        Number of workers to use for DataLoader, by default 4.

    Returns
    -------
    tuple[dict, dict]
        Tuple containing:
        - predictions: mapping from ``dset._fpath.name`` to the stitched
          MMSE prediction of that dataset.
        - stds: mapping from the same keys to the stitched standard deviation
          over the ``mmse_count`` samples.
    """
    # A multi-file dataset exposes its per-file datasets through ``dsets``;
    # a single-file dataset is treated as a one-element collection.
    datasets = dset.dsets if hasattr(dset, "dsets") else [dset]

    stitched_predictions = {}
    stitched_stds = {}
    for single_dset in datasets:
        predictions, stds = get_single_file_mmse(
            model=model,
            dset=single_dset,
            batch_size=batch_size,
            tile_size=tile_size,
            grid_size=grid_size,
            mmse_count=mmse_count,
            num_workers=num_workers,
        )
        # TODO stitching still not working properly for weirdly shaped images
        # get filename without extension and path
        # TODO in the ref ds this is the name of a folder not file :(
        filename = single_dset._fpath.name
        stitched_predictions[filename] = predictions
        stitched_stds[filename] = stds

    return stitched_predictions, stitched_stds
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
def get_single_file_predictions(
    model: VAEModule,
    dset: Dataset,
    batch_size: int,
    tile_size: Optional[tuple[int, int]] = None,
    grid_size: Optional[int] = None,
    num_workers: int = 4,
) -> np.ndarray:
    """Predict every patch of a single-file dataset and stitch the result.

    The model is run once per batch (no MMSE averaging).

    Parameters
    ----------
    model : VAEModule
        Lightning model used for prediction.
    dset : Dataset
        Single-file dataset yielding ``(input, target)`` batches.
    batch_size : int
        Batch size to use for prediction.
    tile_size : tuple[int, int], optional
        If given together with ``grid_size``, passed to ``dset.set_img_sz``.
    grid_size : int, optional
        If given together with ``tile_size``, passed to ``dset.set_img_sz``.
    num_workers : int, optional
        Number of workers to use for DataLoader, by default 4.

    Returns
    -------
    np.ndarray
        Stitched prediction, as returned by ``stitch_predictions_new``.
    """
    if tile_size and grid_size:
        dset.set_img_sz(tile_size, grid_size)

    device = get_device()

    dloader = DataLoader(
        dset,
        pin_memory=False,
        num_workers=num_workers,
        shuffle=False,
        batch_size=batch_size,
    )
    model.eval()
    model.to(device)

    # When the model also predicts a log-variance, its output holds the
    # reconstruction and the log-variance as two chunks along dim 1.
    split_logvar = model.model.predict_logvar is not None

    tiles = []
    with torch.no_grad():
        for inp, _ in tqdm(dloader, desc="Predicting tiles"):
            rec, _ = model(inp.to(device))
            if split_logvar:
                rec, _ = torch.chunk(rec, chunks=2, dim=1)
            tiles.append(rec.cpu().numpy())

    tile_samples = np.concatenate(tiles, axis=0)
    return stitch_predictions_new(tile_samples, dset)
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def get_single_file_mmse(
    model: VAEModule,
    dset: Dataset,
    batch_size: int,
    tile_size: Optional[tuple[int, int]] = None,
    grid_size: Optional[int] = None,
    mmse_count: int = 1,
    num_workers: int = 4,
) -> tuple[np.ndarray, np.ndarray]:
    """Predict a single-file dataset with MMSE averaging and stitch the results.

    For every batch the model is run ``mmse_count`` times. The per-pixel mean and
    standard deviation over these runs are then stitched into full images.

    Parameters
    ----------
    model : VAEModule
        Lightning model used for prediction.
    dset : Dataset
        Single-file dataset yielding ``(input, target)`` batches.
    batch_size : int
        Batch size to use for prediction.
    tile_size : tuple[int, int], optional
        If given together with ``grid_size``, passed to ``dset.set_img_sz``.
    grid_size : int, optional
        If given together with ``tile_size``, passed to ``dset.set_img_sz``.
    mmse_count : int, optional
        Number of model runs averaged per input, by default 1.
    num_workers : int, optional
        Number of workers to use for DataLoader, by default 4.

    Returns
    -------
    tuple
        Stitched MMSE prediction and stitched standard deviation.
        ``stitch_predictions_general`` is used for ``MultiChDloaderRef`` datasets
        (it returns a list of arrays), and ``stitch_predictions_new`` otherwise.
        NOTE: with ``mmse_count == 1`` the standard deviation is NaN, because
        ``torch.std`` uses the unbiased estimator.

    Raises
    ------
    ValueError
        If ``mmse_count`` is smaller than 1.
    """
    if mmse_count < 1:
        raise ValueError(f"`mmse_count` must be at least 1, got {mmse_count}.")

    # Configure tiling before the loader is built (same order as
    # ``get_single_file_predictions``).
    if tile_size and grid_size:
        dset.set_img_sz(tile_size, grid_size)

    device = get_device()

    dloader = DataLoader(
        dset,
        pin_memory=False,
        num_workers=num_workers,
        shuffle=False,
        batch_size=batch_size,
    )
    model.eval()
    model.to(device)

    # When the model also predicts a log-variance, its output holds the
    # reconstruction and the log-variance as two chunks along dim 1.
    split_logvar = model.model.predict_logvar is not None

    tile_mmse = []
    tile_stds = []
    with torch.no_grad():
        for inp, _ in tqdm(dloader, desc="Predicting tiles"):
            inp = inp.to(device)

            rec_img_list = []
            for _ in range(mmse_count):
                rec, _ = model(inp)
                if split_logvar:
                    rec, _ = torch.chunk(rec, chunks=2, dim=1)
                rec_img_list.append(rec.cpu())

            # aggregate results over the new leading MMSE dimension
            samples = torch.stack(rec_img_list, dim=0)
            tile_mmse.append(samples.mean(dim=0).numpy())
            tile_stds.append(samples.std(dim=0).numpy())

    tiles_arr = np.concatenate(tile_mmse, axis=0)
    tile_stds = np.concatenate(tile_stds, axis=0)

    # TODO temporary hack: compare class names instead of using isinstance,
    # because re-running a jupyter cell that defines the class creates a new
    # class object and isinstance would return False.
    if type(dset).__name__ == MultiChDloaderRef.__name__:
        stitch_func = stitch_predictions_general
    else:
        stitch_func = stitch_predictions_new
    stitched_predictions = stitch_func(tiles_arr, dset)
    stitched_stds = stitch_func(tile_stds, dset)
    return stitched_predictions, stitched_stds
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
# ---------------------------------------------------------------------------------
|
|
689
|
+
### Classes and Functions used to stitch predictions
|
|
690
|
+
class PatchLocation:
    """
    Position of a predicted patch.

    Holds its time index ``t`` and its spatial extent ``[h_start, h_end)`` x
    ``[w_start, w_end)``.
    """

    def __init__(self, h_idx_range, w_idx_range, t_idx):
        self.t = t_idx
        h_first, h_last = h_idx_range
        self.h_start = h_first
        self.h_end = h_last
        w_first, w_last = w_idx_range
        self.w_start = w_first
        self.w_end = w_last

    def __str__(self):
        h_span = f"[{self.h_start}-{self.h_end})"
        w_span = f"[{self.w_start}-{self.w_end})"
        return f"T:{self.t} {h_span} {w_span} "
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def _get_location(extra_padding, hwt, pred_h, pred_w):
    """
    Build the ``PatchLocation`` covered by a padded prediction.

    ``hwt`` is unpacked as ``(h_start, w_start, t_idx)``. Both spatial starts are
    moved back by ``extra_padding``, and the extents are ``pred_h`` and ``pred_w``.
    """
    h_origin, w_origin, frame = hwt
    top = h_origin - extra_padding
    left = w_origin - extra_padding
    return PatchLocation((top, top + pred_h), (left, left + pred_w), frame)
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def get_location_from_idx(dset, dset_input_idx, pred_h, pred_w):
    """
    Locate the prediction for a given dataset index.

    The returned location still includes the padded border, so only a subset of
    the prediction is used in the final stitched output.

    Parameters
    ----------
    dset :
        Dataset exposing ``per_side_overlap_pixelcount()``, ``get_idx_manager()``
        and ``get_grid_size()``.
    dset_input_idx : int
        Index of the input within ``dset``.
    pred_h, pred_w : int
        Height and width of the (padded) prediction.

    Returns
    -------
    PatchLocation
        Time index and spatial extent ``(h_start, h_end, w_start, w_end)`` of the
        padded prediction.
    """
    padding = dset.per_side_overlap_pixelcount()
    manager = dset.get_idx_manager()
    hwt = manager.hwt_from_idx(dset_input_idx, grid_size=dset.get_grid_size())
    return _get_location(padding, hwt, pred_h, pred_w)
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
def remove_pad(pred, loc, extra_padding, smoothening_pixelcount, frame_shape):
    """
    Crop the padded border from the first two axes of ``pred``.

    ``extra_padding - smoothening_pixelcount`` pixels are removed from each side.
    If that margin is not positive, ``pred`` is returned unchanged.
    ``smoothening_pixelcount`` must be 0 (asserted).

    When the padded patch ``loc`` extends past ``frame_shape``, the function
    asserts that it fits inside the frame once the padding is removed.
    """
    assert smoothening_pixelcount == 0
    trim = extra_padding - smoothening_pixelcount
    if trim <= 0:
        return pred

    # rows
    frame_h = frame_shape[0]
    if loc.h_end > frame_h:
        assert loc.h_end - extra_padding + smoothening_pixelcount <= frame_h

    # columns
    frame_w = frame_shape[1]
    if loc.w_end > frame_w:
        assert loc.w_end - extra_padding + smoothening_pixelcount <= frame_w

    return pred[trim:-trim, trim:-trim]
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
def update_loc_for_final_insertion(loc, extra_padding, smoothening_pixelcount):
    """
    Shrink ``loc`` in place and return it.

    Each side is moved inwards by ``extra_padding - smoothening_pixelcount``
    pixels, the same margin that ``remove_pad`` crops from the prediction.
    """
    margin = extra_padding - smoothening_pixelcount
    loc.h_start, loc.h_end = loc.h_start + margin, loc.h_end - margin
    loc.w_start, loc.w_end = loc.w_start + margin, loc.w_end - margin
    return loc
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
def stitch_predictions(predictions, dset, smoothening_pixelcount=0):
    """
    Stitch padded patch predictions back into the full data array.

    Parameters
    ----------
    predictions : np.ndarray
        Patch predictions indexed ``[patch, channel, h, w]``.
    dset :
        Dataset exposing ``per_side_overlap_pixelcount()``, ``get_data_shape()``
        and the index helpers used by ``get_location_from_idx``.
    smoothening_pixelcount : int
        Number of pixels which can be interpolated between neighbouring patches.
        Only 0 is currently supported.

    Returns
    -------
    np.ndarray
        Array with the shape of ``dset.get_data_shape()`` (last axis enlarged to
        ``predictions.shape[1]`` if needed). Each cropped patch is added at
        ``[t, h_start:h_end, w_start:w_end, channel]``.
    """
    assert smoothening_pixelcount >= 0 and isinstance(smoothening_pixelcount, int)
    # Previously this check sat inside the loop, after ``remove_pad`` had
    # already asserted the same condition, and the second half of its message
    # was a separate no-op string statement.
    assert smoothening_pixelcount == 0, (
        "For smoothing, enable the get_smoothing_mask. It is disabled since I "
        "don't use it and it needs modification to work with non-square images"
    )

    extra_padding = dset.per_side_overlap_pixelcount()
    # if there are more channels, use all of them.
    shape = list(dset.get_data_shape())
    shape[-1] = max(shape[-1], predictions.shape[1])

    output = np.zeros(shape, dtype=predictions.dtype)
    frame_shape = dset.get_data_shape()[1:3]
    for dset_input_idx in range(predictions.shape[0]):
        loc = get_location_from_idx(
            dset, dset_input_idx, predictions.shape[-2], predictions.shape[-1]
        )

        cropped_pred_list = [
            remove_pad(
                predictions[dset_input_idx, ch_idx],
                loc,
                extra_padding,
                smoothening_pixelcount,
                frame_shape,
            )
            for ch_idx in range(predictions.shape[1])
        ]

        loc = update_loc_for_final_insertion(loc, extra_padding, smoothening_pixelcount)
        # Without smoothing every patch is inserted with unit weight.
        for ch_idx, cropped_pred in enumerate(cropped_pred_list):
            output[loc.t, loc.h_start : loc.h_end, loc.w_start : loc.w_end, ch_idx] += (
                cropped_pred
            )

    return output
|
|
819
|
+
|
|
820
|
+
|
|
821
|
+
# from disentangle.analysis.stitch_prediction import *
|
|
822
|
+
def stitch_predictions_new(predictions, dset):
    """
    Stitch patch-wise predictions into an array with the dataset's shape.

    Each patch is placed at the grid location reported by ``dset.idx_manager``.
    Only the valid grid region of each patch (clipped to the data) is written.

    Parameters
    ----------
    predictions : np.ndarray
        Patch predictions indexed ``[patch, channel, ...spatial]``, in dataset
        index order.
    dset :
        Dataset exposing ``idx_manager`` and ``get_data_shape()``. The data shape
        is channel-last, presumably ``(T, H, W, C)`` or ``(T, Z, H, W, C)``
        (TODO confirm).

    Returns
    -------
    np.ndarray
        Array of shape ``dset.get_data_shape()``, with the last (channel) axis
        enlarged to ``predictions.shape[1]`` if needed.

    Raises
    ------
    ValueError
        If a patch has to be written into an output that is neither 4D nor 5D.
    """
    mng = dset.idx_manager

    # if there are more channels, use all of them.
    shape = list(dset.get_data_shape())
    shape[-1] = max(shape[-1], predictions.shape[1])

    output = np.zeros(shape, dtype=predictions.dtype)
    for dset_idx in range(predictions.shape[0]):
        # grid start, grid end
        gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
        ge = gs + mng.grid_shape

        # patch start, patch end
        ps = gs - mng.patch_offset()
        pe = ps + mng.patch_shape

        # valid grid start, valid grid end: the grid clipped to the data.
        # NOTE: gs/ge are not always inside the data, hence clipping rather
        # than asserting.
        vgs = np.array([max(0, x) for x in gs], dtype=int)
        vge = np.array([min(x, y) for x, y in zip(ge, mng.data_shape)], dtype=int)

        if mng.tiling_mode == TilingMode.ShiftBoundary:
            # extend the valid region up to the data border when the patch
            # touches that border
            for dim in range(len(vgs)):
                if ps[dim] == 0:
                    vgs[dim] = 0
                if pe[dim] == mng.data_shape[dim]:
                    vge[dim] = mng.data_shape[dim]

        # relative start, relative end: the same region inside the patch
        rel_start = vgs - ps
        rel_end = rel_start + (vge - vgs)

        for ch_idx in range(predictions.shape[1]):
            if len(output.shape) == 4:
                # channel dimension is the last one.
                output[vgs[0] : vge[0], vgs[1] : vge[1], vgs[2] : vge[2], ch_idx] = (
                    predictions[dset_idx][
                        ch_idx, rel_start[1] : rel_end[1], rel_start[2] : rel_end[2]
                    ]
                )
            elif len(output.shape) == 5:
                # channel dimension is the last one.
                assert vge[0] - vgs[0] == 1, "Only one frame is supported"
                output[
                    vgs[0], vgs[1] : vge[1], vgs[2] : vge[2], vgs[3] : vge[3], ch_idx
                ] = predictions[dset_idx][
                    ch_idx,
                    rel_start[1] : rel_end[1],
                    rel_start[2] : rel_end[2],
                    rel_start[3] : rel_end[3],
                ]
            else:
                raise ValueError(f"Unsupported shape {output.shape}")

    return output
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
def stitch_predictions_general(predictions, dset):
    """
    Stitch predictions for a dataset made of several files of different shape.

    Parameters
    ----------
    predictions : np.ndarray
        Patch predictions indexed ``[patch, channel, ...spatial]``.
    dset :
        Dataset exposing ``idx_manager`` and ``get_data_shapes()``. The first
        element of ``get_data_shapes()`` lists one channel-first shape per
        sample/file.

    Returns
    -------
    list[np.ndarray]
        One array per sample, of shape ``(predictions.shape[1],) + shape[1:]``.
        The channel axis follows the model output, not the input data.

    Raises
    ------
    ValueError
        If a sample's output is neither 3D nor 4D.
    """
    mng = dset.idx_manager

    # TODO assert all shapes are equal len
    # adjust number of channels to match with prediction shape
    shapes = [
        (predictions.shape[1],) + shape[1:] for shape in dset.get_data_shapes()[0]
    ]

    output = [np.zeros(shape, dtype=predictions.dtype) for shape in shapes]
    for patch_idx in range(predictions.shape[0]):
        # channel_idx is 0 because during prediction we only use one channel.
        # TODO revisit this
        # 0th coordinate is the sample index in the output list; the remaining
        # ones are the grid start relative to that sample (file).
        grid_coords = np.array(
            mng.get_location_from_patch_idx(channel_idx=0, patch_idx=patch_idx),
            dtype=int,
        )
        sample_idx = grid_coords[0]
        sample_out = output[sample_idx]
        grid_start = grid_coords[1:]
        grid_end = grid_start + mng.grid_shape

        # patch start, patch end
        patch_start = grid_start - mng.patch_offset()
        patch_end = patch_start + mng.patch_shape

        # valid grid start, valid grid end: the grid clipped to the sample
        valid_grid_start = np.array([max(0, x) for x in grid_start], dtype=int)
        valid_grid_end = np.array(
            [min(x, y) for x, y in zip(grid_end, shapes[sample_idx])], dtype=int
        )

        if mng.tiling_mode == TilingMode.ShiftBoundary:
            # NOTE(review): this compares against ``mng.data_shape`` rather than
            # the per-sample ``shapes[sample_idx]``; confirm this is intended
            # for datasets whose files differ in shape.
            for dim in range(len(valid_grid_start)):
                if patch_start[dim] == 0:
                    valid_grid_start[dim] = 0
                if patch_end[dim] == mng.data_shape[dim]:
                    valid_grid_end[dim] = mng.data_shape[dim]

        # relative start, relative end: the same region inside the patch
        relative_start = valid_grid_start - patch_start
        relative_end = relative_start + (valid_grid_end - valid_grid_start)

        for ch_idx in range(predictions.shape[1]):
            if sample_out.ndim == 3:
                # Index 0 of the grid coordinates is the input-channel axis; the
                # stitched channel axis follows the model output (``ch_idx``).
                sample_out[
                    ch_idx,
                    valid_grid_start[1] : valid_grid_end[1],
                    valid_grid_start[2] : valid_grid_end[2],
                ] = predictions[patch_idx][
                    ch_idx,
                    relative_start[1] : relative_end[1],
                    relative_start[2] : relative_end[2],
                ]
            elif sample_out.ndim == 4:
                assert (
                    valid_grid_end[0] - valid_grid_start[0] == 1
                ), "Only one frame is supported"
                # Same layout as the 3D case with one more spatial axis. This
                # previously indexed the output *list* with a tuple and used an
                # empty ``valid_grid_end[1]:valid_grid_end[1]`` slice.
                sample_out[
                    ch_idx,
                    valid_grid_start[1] : valid_grid_end[1],
                    valid_grid_start[2] : valid_grid_end[2],
                    valid_grid_start[3] : valid_grid_end[3],
                ] = predictions[patch_idx][
                    ch_idx,
                    relative_start[1] : relative_end[1],
                    relative_start[2] : relative_end[2],
                    relative_start[3] : relative_end[3],
                ]
            else:
                raise ValueError(f"Unsupported shape {sample_out.shape}")

    return output
|