careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/eval_utils.py
@@ -0,0 +1,905 @@
"""
This script provides methods to evaluate the performance of the LVAE model.
It includes functions to:
    - make predictions,
    - quantify the performance of the model,
    - create plots to visualize the results.
"""

import math
import os
from typing import Dict, List, Literal, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.gridspec import GridSpec
from torch.utils.data import DataLoader
from tqdm import tqdm

from careamics.models.lvae.utils import ModelType

from .metrics import RangeInvariantPsnr, RunningPSNR


# ------------------------------------------------------------------------------------------------
# Plotting functions: TODO -> move them to another file, plot_utils.py
def clean_ax(ax):
    """
    Helper function to remove ticks from axes in plots.
    """
    # 2D or 1D axes are of type np.ndarray
    if isinstance(ax, np.ndarray):
        for one_ax in ax:
            clean_ax(one_ax)
        return

    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.tick_params(left=False, right=False, top=False, bottom=False)


def get_plots_output_dir(
    saveplotsdir: str, patch_size: int, mmse_count: int = 50
) -> str:
    """
    Given the path to a root directory to save plots, patch size, and mmse count,
    it returns the specific directory to save the plots.
    """
    plotsrootdir = os.path.join(
        saveplotsdir, f"plots/patch_{patch_size}_mmse_{mmse_count}"
    )
    os.makedirs(plotsrootdir, exist_ok=True)
    print(plotsrootdir)
    return plotsrootdir


def get_psnr_str(tar_hsnr, pred, col_idx):
    """
    Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
    """
    return (
        f"{RangeInvariantPsnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
    )


def add_psnr_str(ax_, psnr):
    """
    Add the PSNR string to the axes.
    """
    textstr = f"PSNR\n{psnr}"
    props = dict(boxstyle="round", facecolor="gray", alpha=0.5)
    # place a text box in upper left in axes coords
    ax_.text(
        0.05,
        0.95,
        textstr,
        transform=ax_.transAxes,
        fontsize=11,
        verticalalignment="top",
        bbox=props,
        color="white",
    )


def get_last_index(bin_count, quantile):
    cumsum = np.cumsum(bin_count)
    normalized_cumsum = cumsum / cumsum[-1]
    for i in range(1, len(normalized_cumsum)):
        if normalized_cumsum[-i] < quantile:
            return i - 1
    return None


def get_first_index(bin_count, quantile):
    cumsum = np.cumsum(bin_count)
    normalized_cumsum = cumsum / cumsum[-1]
    for i in range(len(normalized_cumsum)):
        if normalized_cumsum[i] > quantile:
            return i
    return None
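
# Usage sketch (illustrative, not part of the packaged file): the two helpers
# above trim nearly-empty bins at both ends of a histogram. `get_first_index`
# returns the first bin whose cumulative mass exceeds the quantile, while
# `get_last_index` returns how many bins to drop from the right tail; both are
# used as `arr[first_idx:-last_idx]` in plot_calibration() below. `bin_count_demo`
# is a made-up example.
bin_count_demo = np.array([0, 1, 50, 200, 120, 30, 2, 0])
first_idx_demo = get_first_index(bin_count_demo, 0.001)  # -> 1
last_idx_demo = get_last_index(bin_count_demo, 0.999)  # -> 2
trimmed_demo = bin_count_demo[first_idx_demo:-last_idx_demo]  # -> [1, 50, 200, 120, 30]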


def show_for_one(
    idx,
    val_dset,
    highsnr_val_dset,
    model,
    calibration_stats,
    mmse_count=5,
    patch_size=256,
    num_samples=2,
    baseline_preds=None,
):
    """
    Given an index, it plots the input, target, reconstructed images and the difference image.
    Note that the difference image is computed with respect to a ground truth image, obtained from the high SNR dataset.
    """
    highsnr_val_dset.set_img_sz(patch_size, 64)
    highsnr_val_dset.disable_noise()
    _, tar_hsnr = highsnr_val_dset[idx]
    inp, tar, recon_img_list = get_predictions(
        idx, val_dset, model, mmse_count=mmse_count, patch_size=patch_size
    )
    plot_crops(
        inp,
        tar,
        tar_hsnr,
        recon_img_list,
        calibration_stats,
        num_samples=num_samples,
        baseline_preds=baseline_preds,
    )


def plot_crops(
    inp,
    tar,
    tar_hsnr,
    recon_img_list,
    calibration_stats,
    num_samples=2,
    baseline_preds=None,
):
    """
    Plot the input, the noisy and high-SNR targets, individual samples, the MMSE
    estimate and, when available, the calibration curve in a single grid.
    """
    if baseline_preds is None:
        baseline_preds = []
    if len(baseline_preds) > 0:
        for i in range(len(baseline_preds)):
            if baseline_preds[i].shape != tar_hsnr.shape:
                print(
                    f"Baseline prediction {i} shape {baseline_preds[i].shape} does not match target shape {tar_hsnr.shape}"
                )
                print("This happens when we want to predict the edges of the image.")
                return

    # color_ch_list = ['goldenrod', 'cyan']
    # color_pred = 'red'
    # insetplot_xmax_value = 10000
    # insetplot_xmin_value = -1000
    # inset_min_labelsize = 10
    # inset_rect = [0.05, 0.05, 0.4, 0.2]

    # Set plot attributes
    img_sz = 3
    ncols = num_samples + len(baseline_preds) + 1 + 1 + 1 + 1 + 1 * (num_samples > 1)
    grid_factor = 5
    grid_img_sz = img_sz * grid_factor
    example_spacing = 1
    c0_extra = 1
    nimgs = 1
    fig_w = ncols * img_sz + 2 * c0_extra / grid_factor
    fig_h = int(img_sz * ncols + (example_spacing * (nimgs - 1)) / grid_factor)
    fig = plt.figure(figsize=(fig_w, fig_h))
    gs = GridSpec(
        nrows=int(grid_factor * fig_h),
        ncols=int(grid_factor * fig_w),
        hspace=0.2,
        wspace=0.2,
    )
    params = {"mathtext.default": "regular"}
    plt.rcParams.update(params)

    # plot baselines
    for i in range(2, 2 + len(baseline_preds)):
        for col_idx in range(baseline_preds[0].shape[0]):
            ax_temp = fig.add_subplot(
                gs[
                    col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
                    i * grid_img_sz + c0_extra : (i + 1) * grid_img_sz + c0_extra,
                ]
            )
            print(tar_hsnr.shape, baseline_preds[i - 2].shape)
            psnr = get_psnr_str(tar_hsnr, baseline_preds[i - 2], col_idx)
            ax_temp.imshow(baseline_preds[i - 2][col_idx], cmap="magma")
            add_psnr_str(ax_temp, psnr)
            clean_ax(ax_temp)

    # plot samples
    sample_start_idx = 2 + len(baseline_preds)
    for i in range(sample_start_idx, ncols - 3):
        for col_idx in range(recon_img_list.shape[1]):
            ax_temp = fig.add_subplot(
                gs[
                    col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
                    i * grid_img_sz + c0_extra : (i + 1) * grid_img_sz + c0_extra,
                ]
            )
            psnr = get_psnr_str(tar_hsnr, recon_img_list[i - sample_start_idx], col_idx)
            ax_temp.imshow(recon_img_list[i - sample_start_idx][col_idx], cmap="magma")
            add_psnr_str(ax_temp, psnr)
            clean_ax(ax_temp)
            # inset_ax = add_pixel_kde(ax_temp,
            #                          inset_rect,
            #                          [tar_hsnr[col_idx],
            #                           recon_img_list[i - sample_start_idx][col_idx]],
            #                          inset_min_labelsize,
            #                          label_list=['', ''],
            #                          color_list=[color_ch_list[col_idx], color_pred],
            #                          plot_xmax_value=insetplot_xmax_value,
            #                          plot_xmin_value=insetplot_xmin_value)

            # inset_ax.set_xticks([])
            # inset_ax.set_yticks([])

    # difference image
    if num_samples > 1:
        for col_idx in range(recon_img_list.shape[1]):
            ax_temp = fig.add_subplot(
                gs[
                    col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
                    (ncols - 3) * grid_img_sz
                    + c0_extra : (ncols - 2) * grid_img_sz
                    + c0_extra,
                ]
            )
            ax_temp.imshow(
                recon_img_list[1][col_idx] - recon_img_list[0][col_idx], cmap="coolwarm"
            )
            clean_ax(ax_temp)

    for col_idx in range(recon_img_list.shape[1]):
        # print(recon_img_list.shape)
        ax_temp = fig.add_subplot(
            gs[
                col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
                c0_extra
                + (ncols - 2) * grid_img_sz : (ncols - 1) * grid_img_sz
                + c0_extra,
            ]
        )
        psnr = get_psnr_str(tar_hsnr, recon_img_list.mean(axis=0), col_idx)
        ax_temp.imshow(recon_img_list.mean(axis=0)[col_idx], cmap="magma")
        add_psnr_str(ax_temp, psnr)
        # inset_ax = add_pixel_kde(ax_temp,
        #                          inset_rect,
        #                          [tar_hsnr[col_idx],
        #                           recon_img_list.mean(axis=0)[col_idx]],
        #                          inset_min_labelsize,
        #                          label_list=['', ''],
        #                          color_list=[color_ch_list[col_idx], color_pred],
        #                          plot_xmax_value=insetplot_xmax_value,
        #                          plot_xmin_value=insetplot_xmin_value)
        # inset_ax.set_xticks([])
        # inset_ax.set_yticks([])

        clean_ax(ax_temp)

        ax_temp = fig.add_subplot(
            gs[
                col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
                (ncols - 1) * grid_img_sz
                + 2 * c0_extra : (ncols) * grid_img_sz
                + 2 * c0_extra,
            ]
        )
        ax_temp.imshow(tar_hsnr[col_idx], cmap="magma")
        if col_idx == 0:
            legend_ch1_ax = ax_temp
        if col_idx == 1:
            legend_ch2_ax = ax_temp

        # inset_ax = add_pixel_kde(ax_temp,
        #                          inset_rect,
        #                          [tar_hsnr[col_idx],
        #                           ],
        #                          inset_min_labelsize,
        #                          label_list=[''],
        #                          color_list=[color_ch_list[col_idx]],
        #                          plot_xmax_value=insetplot_xmax_value,
        #                          plot_xmin_value=insetplot_xmin_value)
        # inset_ax.set_xticks([])
        # inset_ax.set_yticks([])

        clean_ax(ax_temp)

        ax_temp = fig.add_subplot(
            gs[
                col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
                grid_img_sz : 2 * grid_img_sz,
            ]
        )
        ax_temp.imshow(tar[0, col_idx].cpu().numpy(), cmap="magma")
        # inset_ax = add_pixel_kde(ax_temp,
        #                          inset_rect,
        #                          [tar[0, col_idx].cpu().numpy(),
        #                           ],
        #                          inset_min_labelsize,
        #                          label_list=[''],
        #                          color_list=[color_ch_list[col_idx]],
        #                          plot_kwargs_list=[{'linestyle': '--'}],
        #                          plot_xmax_value=insetplot_xmax_value,
        #                          plot_xmin_value=insetplot_xmin_value)

        # inset_ax.set_xticks([])
        # inset_ax.set_yticks([])

        clean_ax(ax_temp)

    ax_temp = fig.add_subplot(gs[0:grid_img_sz, 0:grid_img_sz])
    ax_temp.imshow(inp[0, 0].cpu().numpy(), cmap="magma")
    clean_ax(ax_temp)

    # line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-', label='$C_1$')
    # line_ch2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='-', label='$C_2$')
    # line_pred = mlines.Line2D([0, 1], [0, 1], color=color_pred, linestyle='-', label='Pred')
    # line_noisych1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='--', label='$C^N_1$')
    # line_noisych2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='--', label='$C^N_2$')
    # legend_ch1 = legend_ch1_ax.legend(handles=[line_ch1, line_noisych1, line_pred], loc='upper right', frameon=False, labelcolor='white',
    #                                   prop={'size': 11})
    # legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred], loc='upper right', frameon=False, labelcolor='white',
    #                                   prop={'size': 11})

    if calibration_stats is not None:
        smaller_offset = 4
        ax_temp = fig.add_subplot(
            gs[
                grid_img_sz + 1 : 2 * grid_img_sz - smaller_offset + 1,
                smaller_offset - 1 : grid_img_sz - 1,
            ]
        )
        plot_calibration(ax_temp, calibration_stats)


def plot_calibration(ax, calibration_stats):
    """
    Plot calibration statistics (RMV vs RMSE).
    """
    first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
    last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
    ax.plot(
        calibration_stats[0]["rmv"][first_idx:-last_idx],
        calibration_stats[0]["rmse"][first_idx:-last_idx],
        "o",
        label=r"$\hat{C}_0$",
    )

    first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
    last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
    ax.plot(
        calibration_stats[1]["rmv"][first_idx:-last_idx],
        calibration_stats[1]["rmse"][first_idx:-last_idx],
        "o",
        label=r"$\hat{C}_1$",
    )

    ax.set_xlabel("RMV")
    ax.set_ylabel("RMSE")
    ax.legend()


def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name="shiftedcmap"):
    """
    Adapted from https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-matplotlib

    Function to offset the "center" of a colormap. Useful for
    data with a negative min and positive max and you want the
    middle of the colormap's dynamic range to be at zero.

    Input
    -----
    cmap : The matplotlib colormap to be altered
    start : Offset from lowest point in the colormap's range.
        Defaults to 0.0 (no lower offset). Should be between
        0.0 and `midpoint`.
    midpoint : The new center of the colormap. Defaults to
        0.5 (no shift). Should be between 0.0 and 1.0. In
        general, this should be 1 - vmax / (vmax + abs(vmin)).
        For example, if your data range from -15.0 to +5.0 and
        you want the center of the colormap at 0.0, `midpoint`
        should be set to 1 - 5/(5 + 15), i.e. 0.75.
    stop : Offset from highest point in the colormap's range.
        Defaults to 1.0 (no upper offset). Should be between
        `midpoint` and 1.0.
    """
    cdict = {"red": [], "green": [], "blue": [], "alpha": []}

    # regular index to compute the colors
    reg_index = np.linspace(start, stop, 257)
    mid_idx = len(reg_index) // 2
    # shifted index to match the data
    shift_index = np.hstack(
        [
            np.linspace(0.0, midpoint, 128, endpoint=False),
            np.linspace(midpoint, 1.0, 129, endpoint=True),
        ]
    )

    for ri, si in zip(reg_index, shift_index):
        r, g, b, a = cmap(ri)
        a = np.abs(ri - reg_index[mid_idx]) / reg_index[mid_idx]
        # print(a)
        cdict["red"].append((si, r, r))
        cdict["green"].append((si, g, g))
        cdict["blue"].append((si, b, b))
        cdict["alpha"].append((si, a, a))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)
    matplotlib.colormaps.register(cmap=newcmap, force=True)

    return newcmap
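
# Usage sketch (illustrative, not part of the packaged file): recentre a
# diverging colormap so that 0 sits at its midpoint for data spanning roughly
# [-15, 5]. The colormap name "demo_shifted" is arbitrary.
data_demo = np.random.uniform(-15.0, 5.0, size=(64, 64))
midpoint_demo = 1 - data_demo.max() / (data_demo.max() + abs(data_demo.min()))
cmap_demo = shiftedColorMap(
    matplotlib.cm.coolwarm, midpoint=midpoint_demo, name="demo_shifted"
)
plt.imshow(data_demo, cmap=cmap_demo)
plt.colorbar()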


def get_fractional_change(target, prediction, max_val=None):
    """
    Get the relative difference between target and prediction.
    """
    if max_val is None:
        max_val = target.max()
    return (target - prediction) / max_val


def get_zero_centered_midval(error):
    """
    Return the colormap midpoint, 1 - vmax / (vmax + |vmin|), which maps the value 0
    to the center of the colormap so that the colorbar is centered at 0
    (see `shiftedColorMap`).
    """
    vmax = error.max()
    vmin = error.min()
    midval = 1 - vmax / (vmax + abs(vmin))
    return midval


def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val=None):
    """
    Plot the relative difference between target and prediction.
    NOTE: The error map is overlaid on the prediction image (in gray scale).
    NOTE: The colorbar is centered at 0.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))

    # Relative difference between target and prediction
    rel_diff = get_fractional_change(target, prediction, max_val=max_val)
    midval = get_zero_centered_midval(rel_diff)
    shifted_cmap = shiftedColorMap(
        cmap, start=0, midpoint=midval, stop=1.0, name="shiftedcmap"
    )
    ax.imshow(prediction, cmap="gray")
    img_err = ax.imshow(rel_diff, cmap=shifted_cmap, alpha=1)
    plt.colorbar(img_err, ax=ax)
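
# Usage sketch (illustrative, not part of the packaged file): visualize the
# relative error between a slightly noisy prediction and its target on random data.
target_demo = np.random.rand(128, 128)
prediction_demo = target_demo + 0.05 * np.random.randn(128, 128)
fig_demo, ax_demo = plt.subplots(figsize=(6, 6))
plot_error(target_demo, prediction_demo, ax=ax_demo)
ax_demo.set_title("Relative error (target - prediction) / max(target)")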


# ------------------------------------------------------------------------------------------------


def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
    """
    Given an index and a validation/test set, it returns the input, target and the reconstructed images for that index.
    """
    print(f"Predicting for {idx}")
    val_dset.set_img_sz(patch_size, 64)

    with torch.no_grad():
        # val_dset.enable_noise()
        inp, tar = val_dset[idx]
        # val_dset.disable_noise()

        inp = torch.Tensor(inp[None])
        tar = torch.Tensor(tar[None])
        inp = inp.cuda()
        x_normalized = model.normalize_input(inp)
        tar = tar.cuda()
        tar_normalized = model.normalize_target(tar)

        recon_img_list = []
        for _ in range(mmse_count):
            recon_normalized, td_data = model(x_normalized)
            rec_loss, imgs = model.get_reconstruction_loss(
                recon_normalized,
                x_normalized,
                tar_normalized,
                return_predicted_img=True,
            )
            imgs = model.unnormalize_target(imgs)
            recon_img_list.append(imgs.cpu().numpy()[0])

    recon_img_list = np.array(recon_img_list)
    return inp, tar, recon_img_list


def get_dset_predictions(
    model,
    dset,
    batch_size: int,
    model_type: ModelType = None,
    mmse_count: int = 1,
    num_workers: int = 4,
):
    """
    Get predictions from a model for the entire dataset.

    Parameters
    ----------
    mmse_count : int
        Number of samples to generate for each input and then to average over for MMSE estimation.
    """
    dloader = DataLoader(
        dset,
        pin_memory=False,
        num_workers=num_workers,
        shuffle=False,
        batch_size=batch_size,
    )
    likelihood = model.model.likelihood
    predictions = []
    predictions_std = []
    losses = []
    logvar_arr = []
    patch_psnr_channels = [RunningPSNR() for _ in range(dset[0][1].shape[0])]
    with torch.no_grad():
        for batch in tqdm(dloader):
            inp, tar = batch[:2]
            inp = inp.cuda()
            tar = tar.cuda()

            recon_img_list = []
            for mmse_idx in range(mmse_count):
                if model_type == ModelType.Denoiser:
                    assert model.denoise_channel in [
                        "Ch1",
                        "Ch2",
                        "input",
                    ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'

                    x_normalized_new, tar_new = model.get_new_input_target(
                        (inp, tar, *batch[2:])
                    )
                    tar_normalized = model.normalize_target(tar_new)
                    recon_normalized, _ = model(x_normalized_new)
                    rec_loss, imgs = model.get_reconstruction_loss(
                        recon_normalized,
                        tar_normalized,
                        x_normalized_new,
                        return_predicted_img=True,
                    )
                else:
                    x_normalized = model.normalize_input(inp)
                    tar_normalized = model.normalize_target(tar)
                    recon_normalized, _ = model(x_normalized)
                    rec_loss, imgs = model.get_reconstruction_loss(
                        recon_normalized, tar_normalized, inp, return_predicted_img=True
                    )

                if mmse_idx == 0:
                    q_dic = (
                        likelihood.distr_params(recon_normalized)
                        if likelihood is not None
                        else {"logvar": None}
                    )
                    if q_dic["logvar"] is not None:
                        logvar_arr.append(q_dic["logvar"].cpu().numpy())
                    else:
                        logvar_arr.append(np.array([-1]))

                    try:
                        losses.append(rec_loss["loss"].cpu().numpy())
                    except:
                        losses.append(rec_loss["loss"])

                for i in range(imgs.shape[1]):
                    patch_psnr_channels[i].update(imgs[:, i], tar_normalized[:, i])

                recon_img_list.append(imgs.cpu()[None])

            samples = torch.cat(recon_img_list, dim=0)
            mmse_imgs = torch.mean(samples, dim=0)
            mmse_std = torch.std(samples, dim=0)
            predictions.append(mmse_imgs.cpu().numpy())
            predictions_std.append(mmse_std.cpu().numpy())

    psnr = [x.get() for x in patch_psnr_channels]
    return (
        np.concatenate(predictions, axis=0),
        np.array(losses),
        np.concatenate(logvar_arr),
        psnr,
        np.concatenate(predictions_std, axis=0),
    )


# ------------------------------------------------------------------------------------------
### Classes and Functions used to stitch predictions
class PatchLocation:
    """
    Encapsulates t_idx and spatial location.
    """

    def __init__(self, h_idx_range, w_idx_range, t_idx):
        self.t = t_idx
        self.h_start, self.h_end = h_idx_range
        self.w_start, self.w_end = w_idx_range

    def __str__(self):
        msg = f"T:{self.t} [{self.h_start}-{self.h_end}) [{self.w_start}-{self.w_end}) "
        return msg


def _get_location(extra_padding, hwt, pred_h, pred_w):
    h_start, w_start, t_idx = hwt
    h_start -= extra_padding
    h_end = h_start + pred_h
    w_start -= extra_padding
    w_end = w_start + pred_w
    return PatchLocation((h_start, h_end), (w_start, w_end), t_idx)
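
# Usage sketch (illustrative, not part of the packaged file): a patch whose
# top-left corner sits at (h=32, w=64) in frame t=0, predicted with 16 padded
# pixels on each side and a 96x96 network output.
loc_demo = _get_location(extra_padding=16, hwt=(32, 64, 0), pred_h=96, pred_w=96)
print(loc_demo)  # T:0 [16-112) [48-144)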


def get_location_from_idx(dset, dset_input_idx, pred_h, pred_w):
    """
    For a given index of the dataset, it returns where exactly in the dataset this prediction lies:
    which time frame and which spatial location (h_start, h_end, w_start, w_end).
    Note that this prediction also has padded pixels and so a subset of it will be used in the final prediction.

    Args:
        dset:
        dset_input_idx:
        pred_h:
        pred_w:

    Returns
    -------
    """
    extra_padding = dset.per_side_overlap_pixelcount()
    htw = dset.get_idx_manager().hwt_from_idx(
        dset_input_idx, grid_size=dset.get_grid_size()
    )
    return _get_location(extra_padding, htw, pred_h, pred_w)


def remove_pad(pred, loc, extra_padding, smoothening_pixelcount, frame_shape):
    assert smoothening_pixelcount == 0
    if extra_padding - smoothening_pixelcount > 0:
        h_s = extra_padding - smoothening_pixelcount

        # rows
        h_N = frame_shape[0]
        if loc.h_end > h_N:
            assert loc.h_end - extra_padding + smoothening_pixelcount <= h_N
        h_e = extra_padding - smoothening_pixelcount

        w_s = extra_padding - smoothening_pixelcount

        # columns
        w_N = frame_shape[1]
        if loc.w_end > w_N:
            assert loc.w_end - extra_padding + smoothening_pixelcount <= w_N

        w_e = extra_padding - smoothening_pixelcount

        return pred[h_s:-h_e, w_s:-w_e]

    return pred


def update_loc_for_final_insertion(loc, extra_padding, smoothening_pixelcount):
    extra_padding = extra_padding - smoothening_pixelcount
    loc.h_start += extra_padding
    loc.w_start += extra_padding
    loc.h_end -= extra_padding
    loc.w_end -= extra_padding
    return loc


def stitch_predictions(predictions, dset, smoothening_pixelcount=0):
    """
    Args:
        smoothening_pixelcount: number of pixels which can be interpolated
    """
    assert smoothening_pixelcount >= 0 and isinstance(smoothening_pixelcount, int)
    extra_padding = dset.per_side_overlap_pixelcount()
    # if there are more channels, use all of them.
    shape = list(dset.get_data_shape())
    shape[-1] = max(shape[-1], predictions.shape[1])

    output = np.zeros(shape, dtype=predictions.dtype)
    frame_shape = dset.get_data_shape()[1:3]
    for dset_input_idx in range(predictions.shape[0]):
        loc = get_location_from_idx(
            dset, dset_input_idx, predictions.shape[-2], predictions.shape[-1]
        )

        mask = None
        cropped_pred_list = []
        for ch_idx in range(predictions.shape[1]):
            # class i
            cropped_pred_i = remove_pad(
                predictions[dset_input_idx, ch_idx],
                loc,
                extra_padding,
                smoothening_pixelcount,
                frame_shape,
            )

            if mask is None:
                # NOTE: don't need to compute it for every patch.
                assert (
                    smoothening_pixelcount == 0
                ), "For smoothing, enable the get_smoothing_mask. It is disabled since I don't use it and it needs modification to work with non-square images"
                mask = 1
                # mask = _get_smoothing_mask(cropped_pred_i.shape, smoothening_pixelcount, loc, frame_size)

            cropped_pred_list.append(cropped_pred_i)

        loc = update_loc_for_final_insertion(loc, extra_padding, smoothening_pixelcount)
        for ch_idx in range(predictions.shape[1]):
            output[loc.t, loc.h_start : loc.h_end, loc.w_start : loc.w_end, ch_idx] += (
                cropped_pred_list[ch_idx] * mask
            )

    return output
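
# Usage sketch (illustrative, not part of the packaged file). `_DemoDset` is a
# hypothetical stand-in exposing only the interface stitch_predictions() relies
# on: non-overlapping 4x4 tiles of two 8x8, 2-channel frames are stitched back
# into full frames of shape (T, H, W, C).
class _DemoDset:
    def __init__(self, n_frames=2, frame_hw=(8, 8), n_channels=2, grid_size=4):
        self._shape = (n_frames, *frame_hw, n_channels)
        self._grid = grid_size

    def per_side_overlap_pixelcount(self):
        return 0  # no padding, so remove_pad() is a no-op

    def get_grid_size(self):
        return self._grid

    def get_data_shape(self):
        return self._shape

    def get_idx_manager(self):
        return self  # the dataset itself plays the role of the index manager here

    def hwt_from_idx(self, idx, grid_size=None):
        _, height, width, _ = self._shape
        tiles_per_row = width // self._grid
        tiles_per_frame = (height // self._grid) * tiles_per_row
        t, rest = divmod(idx, tiles_per_frame)
        h = (rest // tiles_per_row) * self._grid
        w = (rest % tiles_per_row) * self._grid
        return h, w, t


demo_dset = _DemoDset()
demo_preds = np.random.rand(8, 2, 4, 4).astype(np.float32)  # (N patches, C, h, w)
demo_stitched = stitch_predictions(demo_preds, demo_dset)
print(demo_stitched.shape)  # (2, 8, 8, 2)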


# ------------------------------------------------------------------------------------------


# ------------------------------------------------------------------------------------------
### Classes and Functions used for Calibration
class Calibration:
    """
    Compute bin-wise calibration statistics (RMV vs RMSE) from predicted means and log-variances.
    """

    def __init__(
        self, num_bins: int = 15, mode: Literal["pixelwise", "patchwise"] = "pixelwise"
    ):
        self._bins = num_bins
        self._bin_boundaries = None
        self._mode = mode
        assert mode in ["pixelwise", "patchwise"]
        self._boundary_mode = "uniform"
        assert self._boundary_mode in ["quantile", "uniform"]
        # self._bin_boundaries = {}

    def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
        return np.exp(logvar / 2)

    def compute_bin_boundaries(self, predict_logvar: np.ndarray) -> np.ndarray:
        """
        Compute the bin boundaries for `num_bins` bins and the given logvar values.
        """
        if self._boundary_mode == "quantile":
            boundaries = np.quantile(
                self.logvar_to_std(predict_logvar), np.linspace(0, 1, self._bins + 1)
            )
            return boundaries
        else:
            min_logvar = np.min(predict_logvar)
            max_logvar = np.max(predict_logvar)
            min_std = self.logvar_to_std(min_logvar)
            max_std = self.logvar_to_std(max_logvar)
            return np.linspace(min_std, max_std, self._bins + 1)

    def compute_stats(
        self, pred: np.ndarray, pred_logvar: np.ndarray, target: np.ndarray
    ) -> Dict[int, Dict[str, Union[np.ndarray, List]]]:
        """
        It computes the bin-wise RMSE and RMV for each channel of the predicted image.

        Recall that:
        - RMSE = np.sqrt(np.sum((pred - target)**2) / num_pixels)
        - RMV = np.sqrt(np.mean(pred_std**2))

        ALGORITHM
        - For each channel:
            - Given the bin boundaries, assign pixels of the `std_ch` array to a specific bin index.
            - For each bin index:
                - Compute the RMSE, RMV, and number of pixels for that bin.

        NOTE: each channel of the predicted image/logvar has its own stats.

        Args:
            pred: np.ndarray, shape (n, h, w, c)
            pred_logvar: np.ndarray, shape (n, h, w, c)
            target: np.ndarray, shape (n, h, w, c)
        """
        self._bin_boundaries = {}
        stats = {}
        for ch_idx in range(pred.shape[-1]):
            stats[ch_idx] = {
                "bin_count": [],
                "rmv": [],
                "rmse": [],
                "bin_boundaries": None,
                "bin_matrix": [],
            }
            pred_ch = pred[..., ch_idx]
            logvar_ch = pred_logvar[..., ch_idx]
            std_ch = self.logvar_to_std(logvar_ch)
            target_ch = target[..., ch_idx]
            if self._mode == "pixelwise":
                boundaries = self.compute_bin_boundaries(logvar_ch)
                stats[ch_idx]["bin_boundaries"] = boundaries
                bin_matrix = np.digitize(std_ch.reshape(-1), boundaries)
                bin_matrix = bin_matrix.reshape(std_ch.shape)
                stats[ch_idx]["bin_matrix"] = bin_matrix
                error = (pred_ch - target_ch) ** 2
                for bin_idx in range(self._bins):
                    bin_mask = bin_matrix == bin_idx
                    bin_error = error[bin_mask]
                    bin_size = np.sum(bin_mask)
                    bin_error = (
                        np.sqrt(np.sum(bin_error) / bin_size) if bin_size > 0 else None
                    )  # RMSE
                    bin_var = np.sqrt(np.mean(std_ch[bin_mask] ** 2))  # RMV
                    stats[ch_idx]["rmse"].append(bin_error)
                    stats[ch_idx]["rmv"].append(bin_var)
                    stats[ch_idx]["bin_count"].append(bin_size)
            else:
                raise NotImplementedError("Patchwise mode is not implemented yet.")
        return stats
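
# Usage sketch (illustrative, not part of the packaged file): bin-wise calibration
# statistics on synthetic, roughly well-calibrated predictions of shape (n, h, w, c).
rng_demo = np.random.default_rng(0)
calib_target_demo = rng_demo.normal(size=(4, 32, 32, 2)).astype(np.float32)
calib_logvar_demo = rng_demo.uniform(-2.0, 0.0, size=calib_target_demo.shape).astype(np.float32)
calib_pred_demo = calib_target_demo + np.exp(calib_logvar_demo / 2) * rng_demo.normal(
    size=calib_target_demo.shape
)
calib_stats_demo = Calibration(num_bins=15).compute_stats(
    calib_pred_demo, calib_logvar_demo, calib_target_demo
)
# For well-calibrated predictions the per-bin "rmv" and "rmse" values should lie
# close to each other; the stats dict can be passed to plot_calibration(ax, ...).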


def nll(x, mean, logvar):
    """
    Negative log of the probability density of the values x under the Normal
    distribution with parameters mean and logvar.

    :param x: tensor of points, with shape (batch, channels, dim1, dim2)
    :param mean: tensor with mean of distribution, shape
        (batch, channels, dim1, dim2)
    :param logvar: tensor with log-variance of distribution, shape has to be
        either scalar or broadcastable
    """
    var = torch.exp(logvar)
    log_prob = -0.5 * (
        ((x - mean) ** 2) / var + logvar + torch.tensor(2 * math.pi).log()
    )
    nll = -log_prob
    return nll
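
# Consistency check (illustrative, not part of the packaged file): nll() should
# match the negative log-density of a Normal distribution with std = exp(logvar / 2).
from torch.distributions import Normal

x_demo = torch.randn(2, 3, 8, 8)
mean_demo = torch.zeros_like(x_demo)
logvar_demo = torch.full_like(x_demo, -1.0)
ref_nll = -Normal(mean_demo, torch.exp(logvar_demo / 2)).log_prob(x_demo)
assert torch.allclose(nll(x_demo, mean_demo, logvar_demo), ref_nll, atol=1e-5)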


def get_calibrated_factor_for_stdev(
    pred: Union[np.ndarray, torch.Tensor],
    pred_logvar: Union[np.ndarray, torch.Tensor],
    target: Union[np.ndarray, torch.Tensor],
    batch_size: int = 32,
    epochs: int = 500,
    lr: float = 0.01,
):
    """
    Here, we calibrate the uncertainty by multiplying the predicted std (mmse estimate or predicted logvar) with a scalar.
    We return the calibrated scalar. This needs to be multiplied with the std.

    NOTE: Why is the input logvar and not std? because the model typically predicts logvar and not std.
    """
    # create a learnable scalar
    scalar = torch.nn.Parameter(torch.tensor(2.0))
    optimizer = torch.optim.Adam([scalar], lr=lr)

    bar = tqdm(range(epochs))
    for _ in bar:
        optimizer.zero_grad()
        # Select a random batch of predictions
        mask = np.random.randint(0, pred.shape[0], batch_size)
        pred_batch = torch.Tensor(pred[mask]).cuda()
        pred_logvar_batch = torch.Tensor(pred_logvar[mask]).cuda()
        target_batch = torch.Tensor(target[mask]).cuda()

        loss = torch.mean(
            nll(target_batch, pred_batch, pred_logvar_batch + torch.log(scalar))
        )
        loss.backward()
        optimizer.step()
        bar.set_description(f"nll: {loss.item()} scalar: {scalar.item()}")

    return np.sqrt(scalar.item())
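
# Usage sketch (illustrative, not part of the packaged file): the returned
# factor rescales the predicted per-pixel standard deviation. The function
# moves every batch to the GPU, so the sketch is guarded by a CUDA check.
if torch.cuda.is_available():
    calib_pred = np.random.randn(128, 2, 16, 16).astype(np.float32)
    calib_logvar = np.random.uniform(-2.0, 0.0, calib_pred.shape).astype(np.float32)
    calib_target = calib_pred + np.exp(calib_logvar / 2) * np.random.randn(
        *calib_pred.shape
    ).astype(np.float32)
    calib_factor = get_calibrated_factor_for_stdev(
        calib_pred, calib_logvar, calib_target, epochs=50
    )
    calibrated_std = calib_factor * np.exp(calib_logvar / 2)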


def plot_calibration(ax, calibration_stats):
    """
    Plot calibration statistics (RMV vs RMSE) with per-channel labels.
    NOTE: this later definition shadows the plot_calibration defined earlier in this module.
    """
    first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
    last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
    ax.plot(
        calibration_stats[0]["rmv"][first_idx:-last_idx],
        calibration_stats[0]["rmse"][first_idx:-last_idx],
        "o",
        label=r"$\hat{C}_0$: Ch1",
    )

    first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
    last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
    ax.plot(
        calibration_stats[1]["rmv"][first_idx:-last_idx],
        calibration_stats[1]["rmse"][first_idx:-last_idx],
        "o",
        label=r"$\hat{C}_1$: Ch2",
    )

    ax.set_xlabel("RMV")
    ax.set_ylabel("RMSE")
    ax.legend()