careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of careamics might be problematic.

Files changed (118)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/eval_utils.py (new file)
@@ -0,0 +1,905 @@
1
+ """
2
+ This script provides methods to evaluate the performance of the LVAE model.
3
+ It includes functions to:
4
+ - make predictions,
5
+ - quantify the performance of the model
6
+ - create plots to visualize the results.
7
+ """
8
+
9
+ import math
10
+ import os
11
+ from typing import Dict, List, Literal, Union
12
+
13
+ import matplotlib
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+ import torch
17
+ from matplotlib.gridspec import GridSpec
18
+ from torch.utils.data import DataLoader
19
+ from tqdm import tqdm
20
+
21
+ from careamics.models.lvae.utils import ModelType
22
+
23
+ from .metrics import RangeInvariantPsnr, RunningPSNR
24
+
25
+
26
+ # ------------------------------------------------------------------------------------------------
27
+ # Plotting functions. TODO: move them to another file, plot_utils.py
28
+ def clean_ax(ax):
29
+ """
30
+ Helper function to remove ticks from axes in plots.
31
+ """
32
+ # 2D or 1D axes are of type np.ndarray
33
+ if isinstance(ax, np.ndarray):
34
+ for one_ax in ax:
35
+ clean_ax(one_ax)
36
+ return
37
+
38
+ ax.set_yticklabels([])
39
+ ax.set_xticklabels([])
40
+ ax.tick_params(left=False, right=False, top=False, bottom=False)
41
+
42
+
43
+ def get_plots_output_dir(
44
+ saveplotsdir: str, patch_size: int, mmse_count: int = 50
45
+ ) -> str:
46
+ """
47
+ Given the path to a root directory to save plots, patch size, and mmse count,
48
+ it returns the specific directory to save the plots.
49
+ """
50
+ plotsrootdir = os.path.join(
51
+ saveplotsdir, f"plots/patch_{patch_size}_mmse_{mmse_count}"
52
+ )
53
+ os.makedirs(plotsrootdir, exist_ok=True)
54
+ print(plotsrootdir)
55
+ return plotsrootdir
56
+
57
+
58
+ def get_psnr_str(tar_hsnr, pred, col_idx):
59
+ """
60
+ Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
61
+ """
62
+ return (
63
+ f"{RangeInvariantPsnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
64
+ )
65
+
66
+
67
+ def add_psnr_str(ax_, psnr):
68
+ """
69
+ Add the PSNR string to the axes.
70
+ """
71
+ textstr = f"PSNR\n{psnr}"
72
+ props = dict(boxstyle="round", facecolor="gray", alpha=0.5)
73
+ # place a text box in upper left in axes coords
74
+ ax_.text(
75
+ 0.05,
76
+ 0.95,
77
+ textstr,
78
+ transform=ax_.transAxes,
79
+ fontsize=11,
80
+ verticalalignment="top",
81
+ bbox=props,
82
+ color="white",
83
+ )
84
+
85
+
86
+ def get_last_index(bin_count, quantile):
87
+ cumsum = np.cumsum(bin_count)
88
+ normalized_cumsum = cumsum / cumsum[-1]
89
+ for i in range(1, len(normalized_cumsum)):
90
+ if normalized_cumsum[-i] < quantile:
91
+ return i - 1
92
+ return None
93
+
94
+
95
+ def get_first_index(bin_count, quantile):
96
+ cumsum = np.cumsum(bin_count)
97
+ normalized_cumsum = cumsum / cumsum[-1]
98
+ for i in range(len(normalized_cumsum)):
99
+ if normalized_cumsum[i] > quantile:
100
+ return i
101
+ return None
102
+
103
+
104
+ def show_for_one(
105
+ idx,
106
+ val_dset,
107
+ highsnr_val_dset,
108
+ model,
109
+ calibration_stats,
110
+ mmse_count=5,
111
+ patch_size=256,
112
+ num_samples=2,
113
+ baseline_preds=None,
114
+ ):
115
+ """
116
+ Given an index, it plots the input, target, reconstructed images and the difference image.
117
+ Note that the difference image is computed with respect to a ground-truth image obtained from the high-SNR dataset.
118
+ """
119
+ highsnr_val_dset.set_img_sz(patch_size, 64)
120
+ highsnr_val_dset.disable_noise()
121
+ _, tar_hsnr = highsnr_val_dset[idx]
122
+ inp, tar, recon_img_list = get_predictions(
123
+ idx, val_dset, model, mmse_count=mmse_count, patch_size=patch_size
124
+ )
125
+ plot_crops(
126
+ inp,
127
+ tar,
128
+ tar_hsnr,
129
+ recon_img_list,
130
+ calibration_stats,
131
+ num_samples=num_samples,
132
+ baseline_preds=baseline_preds,
133
+ )
134
+
135
+
136
+ def plot_crops(
137
+ inp,
138
+ tar,
139
+ tar_hsnr,
140
+ recon_img_list,
141
+ calibration_stats,
142
+ num_samples=2,
143
+ baseline_preds=None,
144
+ ):
145
+ """Plot input, target, high-SNR target, posterior samples and their MMSE average, with optional baseline predictions and a calibration plot."""
146
+ if baseline_preds is None:
147
+ baseline_preds = []
148
+ if len(baseline_preds) > 0:
149
+ for i in range(len(baseline_preds)):
150
+ if baseline_preds[i].shape != tar_hsnr.shape:
151
+ print(
152
+ f"Baseline prediction {i} shape {baseline_preds[i].shape} does not match target shape {tar_hsnr.shape}"
153
+ )
154
+ print("This happens when we want to predict the edges of the image.")
155
+ return
156
+
157
+ # color_ch_list = ['goldenrod', 'cyan']
158
+ # color_pred = 'red'
159
+ # insetplot_xmax_value = 10000
160
+ # insetplot_xmin_value = -1000
161
+ # inset_min_labelsize = 10
162
+ # inset_rect = [0.05, 0.05, 0.4, 0.2]
163
+
164
+ # Set plot attributes
165
+ img_sz = 3
166
+ ncols = num_samples + len(baseline_preds) + 1 + 1 + 1 + 1 + 1 * (num_samples > 1)
167
+ grid_factor = 5
168
+ grid_img_sz = img_sz * grid_factor
169
+ example_spacing = 1
170
+ c0_extra = 1
171
+ nimgs = 1
172
+ fig_w = ncols * img_sz + 2 * c0_extra / grid_factor
173
+ fig_h = int(img_sz * ncols + (example_spacing * (nimgs - 1)) / grid_factor)
174
+ fig = plt.figure(figsize=(fig_w, fig_h))
175
+ gs = GridSpec(
176
+ nrows=int(grid_factor * fig_h),
177
+ ncols=int(grid_factor * fig_w),
178
+ hspace=0.2,
179
+ wspace=0.2,
180
+ )
181
+ params = {"mathtext.default": "regular"}
182
+ plt.rcParams.update(params)
183
+
184
+ # plot baselines
185
+ for i in range(2, 2 + len(baseline_preds)):
186
+ for col_idx in range(baseline_preds[0].shape[0]):
187
+ ax_temp = fig.add_subplot(
188
+ gs[
189
+ col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
190
+ i * grid_img_sz + c0_extra : (i + 1) * grid_img_sz + c0_extra,
191
+ ]
192
+ )
193
+ print(tar_hsnr.shape, baseline_preds[i - 2].shape)
194
+ psnr = get_psnr_str(tar_hsnr, baseline_preds[i - 2], col_idx)
195
+ ax_temp.imshow(baseline_preds[i - 2][col_idx], cmap="magma")
196
+ add_psnr_str(ax_temp, psnr)
197
+ clean_ax(ax_temp)
198
+
199
+ # plot samples
200
+ sample_start_idx = 2 + len(baseline_preds)
201
+ for i in range(sample_start_idx, ncols - 3):
202
+ for col_idx in range(recon_img_list.shape[1]):
203
+ ax_temp = fig.add_subplot(
204
+ gs[
205
+ col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
206
+ i * grid_img_sz + c0_extra : (i + 1) * grid_img_sz + c0_extra,
207
+ ]
208
+ )
209
+ psnr = get_psnr_str(tar_hsnr, recon_img_list[i - sample_start_idx], col_idx)
210
+ ax_temp.imshow(recon_img_list[i - sample_start_idx][col_idx], cmap="magma")
211
+ add_psnr_str(ax_temp, psnr)
212
+ clean_ax(ax_temp)
213
+ # inset_ax = add_pixel_kde(ax_temp,
214
+ # inset_rect,
215
+ # [tar_hsnr[col_idx],
216
+ # recon_img_list[i - sample_start_idx][col_idx]],
217
+ # inset_min_labelsize,
218
+ # label_list=['', ''],
219
+ # color_list=[color_ch_list[col_idx], color_pred],
220
+ # plot_xmax_value=insetplot_xmax_value,
221
+ # plot_xmin_value=insetplot_xmin_value)
222
+
223
+ # inset_ax.set_xticks([])
224
+ # inset_ax.set_yticks([])
225
+
226
+ # difference image
227
+ if num_samples > 1:
228
+ for col_idx in range(recon_img_list.shape[1]):
229
+ ax_temp = fig.add_subplot(
230
+ gs[
231
+ col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
232
+ (ncols - 3) * grid_img_sz
233
+ + c0_extra : (ncols - 2) * grid_img_sz
234
+ + c0_extra,
235
+ ]
236
+ )
237
+ ax_temp.imshow(
238
+ recon_img_list[1][col_idx] - recon_img_list[0][col_idx], cmap="coolwarm"
239
+ )
240
+ clean_ax(ax_temp)
241
+
242
+ for col_idx in range(recon_img_list.shape[1]):
243
+ # print(recon_img_list.shape)
244
+ ax_temp = fig.add_subplot(
245
+ gs[
246
+ col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
247
+ c0_extra
248
+ + (ncols - 2) * grid_img_sz : (ncols - 1) * grid_img_sz
249
+ + c0_extra,
250
+ ]
251
+ )
252
+ psnr = get_psnr_str(tar_hsnr, recon_img_list.mean(axis=0), col_idx)
253
+ ax_temp.imshow(recon_img_list.mean(axis=0)[col_idx], cmap="magma")
254
+ add_psnr_str(ax_temp, psnr)
255
+ # inset_ax = add_pixel_kde(ax_temp,
256
+ # inset_rect,
257
+ # [tar_hsnr[col_idx],
258
+ # recon_img_list.mean(axis=0)[col_idx]],
259
+ # inset_min_labelsize,
260
+ # label_list=['', ''],
261
+ # color_list=[color_ch_list[col_idx], color_pred],
262
+ # plot_xmax_value=insetplot_xmax_value,
263
+ # plot_xmin_value=insetplot_xmin_value)
264
+ # inset_ax.set_xticks([])
265
+ # inset_ax.set_yticks([])
266
+
267
+ clean_ax(ax_temp)
268
+
269
+ ax_temp = fig.add_subplot(
270
+ gs[
271
+ col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
272
+ (ncols - 1) * grid_img_sz
273
+ + 2 * c0_extra : (ncols) * grid_img_sz
274
+ + 2 * c0_extra,
275
+ ]
276
+ )
277
+ ax_temp.imshow(tar_hsnr[col_idx], cmap="magma")
278
+ if col_idx == 0:
279
+ legend_ch1_ax = ax_temp
280
+ if col_idx == 1:
281
+ legend_ch2_ax = ax_temp
282
+
283
+ # inset_ax = add_pixel_kde(ax_temp,
284
+ # inset_rect,
285
+ # [tar_hsnr[col_idx],
286
+ # ],
287
+ # inset_min_labelsize,
288
+ # label_list=[''],
289
+ # color_list=[color_ch_list[col_idx]],
290
+ # plot_xmax_value=insetplot_xmax_value,
291
+ # plot_xmin_value=insetplot_xmin_value)
292
+ # inset_ax.set_xticks([])
293
+ # inset_ax.set_yticks([])
294
+
295
+ clean_ax(ax_temp)
296
+
297
+ ax_temp = fig.add_subplot(
298
+ gs[
299
+ col_idx * grid_img_sz : grid_img_sz * (col_idx + 1),
300
+ grid_img_sz : 2 * grid_img_sz,
301
+ ]
302
+ )
303
+ ax_temp.imshow(tar[0, col_idx].cpu().numpy(), cmap="magma")
304
+ # inset_ax = add_pixel_kde(ax_temp,
305
+ # inset_rect,
306
+ # [tar[0,col_idx].cpu().numpy(),
307
+ # ],
308
+ # inset_min_labelsize,
309
+ # label_list=[''],
310
+ # color_list=[color_ch_list[col_idx]],
311
+ # plot_kwargs_list=[{'linestyle':'--'}],
312
+ # plot_xmax_value=insetplot_xmax_value,
313
+ # plot_xmin_value=insetplot_xmin_value)
314
+
315
+ # inset_ax.set_xticks([])
316
+ # inset_ax.set_yticks([])
317
+
318
+ clean_ax(ax_temp)
319
+
320
+ ax_temp = fig.add_subplot(gs[0:grid_img_sz, 0:grid_img_sz])
321
+ ax_temp.imshow(inp[0, 0].cpu().numpy(), cmap="magma")
322
+ clean_ax(ax_temp)
323
+
324
+ # line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-', label='$C_1$')
325
+ # line_ch2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='-', label='$C_2$')
326
+ # line_pred = mlines.Line2D([0, 1], [0, 1], color=color_pred, linestyle='-', label='Pred')
327
+ # line_noisych1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='--', label='$C^N_1$')
328
+ # line_noisych2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='--', label='$C^N_2$')
329
+ # legend_ch1 = legend_ch1_ax.legend(handles=[line_ch1, line_noisych1, line_pred], loc='upper right', frameon=False, labelcolor='white',
330
+ # prop={'size': 11})
331
+ # legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred], loc='upper right', frameon=False, labelcolor='white',
332
+ # prop={'size': 11})
333
+
334
+ if calibration_stats is not None:
335
+ smaller_offset = 4
336
+ ax_temp = fig.add_subplot(
337
+ gs[
338
+ grid_img_sz + 1 : 2 * grid_img_sz - smaller_offset + 1,
339
+ smaller_offset - 1 : grid_img_sz - 1,
340
+ ]
341
+ )
342
+ plot_calibration(ax_temp, calibration_stats)
343
+
344
+
345
+ def plot_calibration(ax, calibration_stats):
346
+ """
347
+ To plot calibration statistics (RMV vs RMSE).
348
+ """
349
+ first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
350
+ last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
351
+ ax.plot(
352
+ calibration_stats[0]["rmv"][first_idx:-last_idx],
353
+ calibration_stats[0]["rmse"][first_idx:-last_idx],
354
+ "o",
355
+ label=r"$\hat{C}_0$",
356
+ )
357
+
358
+ first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
359
+ last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
360
+ ax.plot(
361
+ calibration_stats[1]["rmv"][first_idx:-last_idx],
362
+ calibration_stats[1]["rmse"][first_idx:-last_idx],
363
+ "o",
364
+ label=r"$\hat{C}_1$",
365
+ )
366
+
367
+ ax.set_xlabel("RMV")
368
+ ax.set_ylabel("RMSE")
369
+ ax.legend()
370
+
371
+
372
+ def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name="shiftedcmap"):
373
+ """
374
+ Adapted from https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-matplotlib
375
+
376
+ Function to offset the "center" of a colormap. Useful for
377
+ data with a negative min and positive max and you want the
378
+ middle of the colormap's dynamic range to be at zero.
379
+
380
+ Input
381
+ -----
382
+ cmap : The matplotlib colormap to be altered
383
+ start : Offset from lowest point in the colormap's range.
384
+ Defaults to 0.0 (no lower offset). Should be between
385
+ 0.0 and `midpoint`.
386
+ midpoint : The new center of the colormap. Defaults to
387
+ 0.5 (no shift). Should be between 0.0 and 1.0. In
388
+ general, this should be 1 - vmax / (vmax + abs(vmin))
389
+ For example if your data range from -15.0 to +5.0 and
390
+ you want the center of the colormap at 0.0, `midpoint`
391
+ should be set to 1 - 5/(5 + 15) = 0.75.
392
+ stop : Offset from highest point in the colormap's range.
393
+ Defaults to 1.0 (no upper offset). Should be between
394
+ `midpoint` and 1.0.
395
+ """
396
+ cdict = {"red": [], "green": [], "blue": [], "alpha": []}
397
+
398
+ # regular index to compute the colors
399
+ reg_index = np.linspace(start, stop, 257)
400
+ mid_idx = len(reg_index) // 2
401
+ # shifted index to match the data
402
+ shift_index = np.hstack(
403
+ [
404
+ np.linspace(0.0, midpoint, 128, endpoint=False),
405
+ np.linspace(midpoint, 1.0, 129, endpoint=True),
406
+ ]
407
+ )
408
+
409
+ for ri, si in zip(reg_index, shift_index):
410
+ r, g, b, a = cmap(ri)
411
+ a = np.abs(ri - reg_index[mid_idx]) / reg_index[mid_idx]
412
+ # print(a)
413
+ cdict["red"].append((si, r, r))
414
+ cdict["green"].append((si, g, g))
415
+ cdict["blue"].append((si, b, b))
416
+ cdict["alpha"].append((si, a, a))
417
+
418
+ newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)
419
+ matplotlib.colormaps.register(cmap=newcmap, force=True)
420
+
421
+ return newcmap
422
+
423
+
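# Editor's illustrative sketch (not part of the diff above): the midpoint rule from the
# shiftedColorMap docstring, applied to data spanning roughly [-15, 5] so that the center of
# the colormap lands on 0. The `demo_*` names are hypothetical.
demo_midpoint = 1 - 5.0 / (5.0 + abs(-15.0))  # = 0.75, i.e. 1 - vmax / (vmax + abs(vmin))
demo_cmap = shiftedColorMap(matplotlib.cm.coolwarm, midpoint=demo_midpoint, name="demo_shifted")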
424
+ def get_fractional_change(target, prediction, max_val=None):
425
+ """
426
+ Get relative difference between target and prediction.
427
+ """
428
+ if max_val is None:
429
+ max_val = target.max()
430
+ return (target - prediction) / max_val
431
+
432
+
433
+ def get_zero_centered_midval(error):
434
+ """
435
+ Compute the colormap midpoint value that centers the colorbar at 0 for the given error range.
436
+ """
437
+ vmax = error.max()
438
+ vmin = error.min()
439
+ midval = 1 - vmax / (vmax + abs(vmin))
440
+ return midval
441
+
442
+
443
+ def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val=None):
444
+ """
445
+ Plot the relative difference between target and prediction.
446
+ NOTE: The plot is overlaid on the prediction image (in grayscale).
447
+ NOTE: The colorbar is centered at 0.
448
+ """
449
+ if ax is None:
450
+ _, ax = plt.subplots(figsize=(6, 6))
451
+
452
+ # Relative difference between target and prediction
453
+ rel_diff = get_fractional_change(target, prediction, max_val=max_val)
454
+ midval = get_zero_centered_midval(rel_diff)
455
+ shifted_cmap = shiftedColorMap(
456
+ cmap, start=0, midpoint=midval, stop=1.0, name="shiftedcmap"
457
+ )
458
+ ax.imshow(prediction, cmap="gray")
459
+ img_err = ax.imshow(rel_diff, cmap=shifted_cmap, alpha=1)
460
+ plt.colorbar(img_err, ax=ax)
461
+
462
+
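# Editor's illustrative sketch (not part of the diff above): plot_error on toy arrays; the
# relative-error overlay is drawn on top of the grayscale "prediction" with a colorbar
# centered at 0. The `toy_*` names are hypothetical.
toy_gt = np.random.rand(64, 64)
toy_prediction = toy_gt + 0.05 * np.random.randn(64, 64)
plot_error(toy_gt, toy_prediction)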
463
+ # ------------------------------------------------------------------------------------------------
464
+
465
+
466
+ def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
467
+ """
468
+ Given an index and a validation/test set, it returns the input, target and the reconstructed images for that index.
469
+ """
470
+ print(f"Predicting for {idx}")
471
+ val_dset.set_img_sz(patch_size, 64)
472
+
473
+ with torch.no_grad():
474
+ # val_dset.enable_noise()
475
+ inp, tar = val_dset[idx]
476
+ # val_dset.disable_noise()
477
+
478
+ inp = torch.Tensor(inp[None])
479
+ tar = torch.Tensor(tar[None])
480
+ inp = inp.cuda()
481
+ x_normalized = model.normalize_input(inp)
482
+ tar = tar.cuda()
483
+ tar_normalized = model.normalize_target(tar)
484
+
485
+ recon_img_list = []
486
+ for _ in range(mmse_count):
487
+ recon_normalized, td_data = model(x_normalized)
488
+ rec_loss, imgs = model.get_reconstruction_loss(
489
+ recon_normalized,
490
+ x_normalized,
491
+ tar_normalized,
492
+ return_predicted_img=True,
493
+ )
494
+ imgs = model.unnormalize_target(imgs)
495
+ recon_img_list.append(imgs.cpu().numpy()[0])
496
+
497
+ recon_img_list = np.array(recon_img_list)
498
+ return inp, tar, recon_img_list
499
+
500
+
501
+ def get_dset_predictions(
502
+ model,
503
+ dset,
504
+ batch_size: int,
505
+ model_type: ModelType = None,
506
+ mmse_count: int = 1,
507
+ num_workers: int = 4,
508
+ ):
509
+ """
510
+ Get predictions from a model for the entire dataset.
511
+
512
+ Parameters
513
+ ----------
514
+ mmse_count : int
515
+ Number of samples to generate for each input and then to average over for MMSE estimation.
516
+ """
517
+ dloader = DataLoader(
518
+ dset,
519
+ pin_memory=False,
520
+ num_workers=num_workers,
521
+ shuffle=False,
522
+ batch_size=batch_size,
523
+ )
524
+ likelihood = model.model.likelihood
525
+ predictions = []
526
+ predictions_std = []
527
+ losses = []
528
+ logvar_arr = []
529
+ patch_psnr_channels = [RunningPSNR() for _ in range(dset[0][1].shape[0])]
530
+ with torch.no_grad():
531
+ for batch in tqdm(dloader):
532
+ inp, tar = batch[:2]
533
+ inp = inp.cuda()
534
+ tar = tar.cuda()
535
+
536
+ recon_img_list = []
537
+ for mmse_idx in range(mmse_count):
538
+ if model_type == ModelType.Denoiser:
539
+ assert model.denoise_channel in [
540
+ "Ch1",
541
+ "Ch2",
542
+ "input",
543
+ ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
544
+
545
+ x_normalized_new, tar_new = model.get_new_input_target(
546
+ (inp, tar, *batch[2:])
547
+ )
548
+ tar_normalized = model.normalize_target(tar_new)
549
+ recon_normalized, _ = model(x_normalized_new)
550
+ rec_loss, imgs = model.get_reconstruction_loss(
551
+ recon_normalized,
552
+ tar_normalized,
553
+ x_normalized_new,
554
+ return_predicted_img=True,
555
+ )
556
+ else:
557
+ x_normalized = model.normalize_input(inp)
558
+ tar_normalized = model.normalize_target(tar)
559
+ recon_normalized, _ = model(x_normalized)
560
+ rec_loss, imgs = model.get_reconstruction_loss(
561
+ recon_normalized, tar_normalized, inp, return_predicted_img=True
562
+ )
563
+
564
+ if mmse_idx == 0:
565
+ q_dic = (
566
+ likelihood.distr_params(recon_normalized)
567
+ if likelihood is not None
568
+ else {"logvar": None}
569
+ )
570
+ if q_dic["logvar"] is not None:
571
+ logvar_arr.append(q_dic["logvar"].cpu().numpy())
572
+ else:
573
+ logvar_arr.append(np.array([-1]))
574
+
575
+ try:
576
+ losses.append(rec_loss["loss"].cpu().numpy())
577
+ except Exception:
578
+ losses.append(rec_loss["loss"])
579
+
580
+ for i in range(imgs.shape[1]):
581
+ patch_psnr_channels[i].update(imgs[:, i], tar_normalized[:, i])
582
+
583
+ recon_img_list.append(imgs.cpu()[None])
584
+
585
+ samples = torch.cat(recon_img_list, dim=0)
586
+ mmse_imgs = torch.mean(samples, dim=0)
587
+ mmse_std = torch.std(samples, dim=0)
588
+ predictions.append(mmse_imgs.cpu().numpy())
589
+ predictions_std.append(mmse_std.cpu().numpy())
590
+
591
+ psnr = [x.get() for x in patch_psnr_channels]
592
+ return (
593
+ np.concatenate(predictions, axis=0),
594
+ np.array(losses),
595
+ np.concatenate(logvar_arr),
596
+ psnr,
597
+ np.concatenate(predictions_std, axis=0),
598
+ )
599
+
600
+
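# Editor's illustrative sketch (not part of the diff above): the MMSE estimate returned by
# get_dset_predictions is the per-pixel mean over `mmse_count` sampled reconstructions, and the
# reported uncertainty is their per-pixel std. Toy stack shaped (mmse_count, batch, channels, H, W);
# all names here are hypothetical.
toy_samples = torch.randn(10, 4, 2, 8, 8)
toy_mmse = toy_samples.mean(dim=0)    # analogous to `mmse_imgs`
toy_spread = toy_samples.std(dim=0)   # analogous to `mmse_std`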
601
+ # ------------------------------------------------------------------------------------------
602
+ ### Classes and Functions used to stitch predictions
603
+ class PatchLocation:
604
+ """
605
+ Encapsulates t_idx and spatial location.
606
+ """
607
+
608
+ def __init__(self, h_idx_range, w_idx_range, t_idx):
609
+ self.t = t_idx
610
+ self.h_start, self.h_end = h_idx_range
611
+ self.w_start, self.w_end = w_idx_range
612
+
613
+ def __str__(self):
614
+ msg = f"T:{self.t} [{self.h_start}-{self.h_end}) [{self.w_start}-{self.w_end}) "
615
+ return msg
616
+
617
+
618
+ def _get_location(extra_padding, hwt, pred_h, pred_w):
619
+ h_start, w_start, t_idx = hwt
620
+ h_start -= extra_padding
621
+ h_end = h_start + pred_h
622
+ w_start -= extra_padding
623
+ w_end = w_start + pred_w
624
+ return PatchLocation((h_start, h_end), (w_start, w_end), t_idx)
625
+
626
+
627
+ def get_location_from_idx(dset, dset_input_idx, pred_h, pred_w):
628
+ """
629
+ For a given index of the dataset, it returns where exactly in the dataset this prediction lies.
630
+ Note that this prediction also has padded pixels and so a subset of it will be used in the final prediction.
631
+ That is: which time frame and which spatial location (h_start, h_end, w_start, w_end).
632
+ Args:
633
+ dset:
634
+ dset_input_idx:
635
+ pred_h:
636
+ pred_w:
637
+
638
+ Returns
639
+ -------
640
+ """
641
+ extra_padding = dset.per_side_overlap_pixelcount()
642
+ htw = dset.get_idx_manager().hwt_from_idx(
643
+ dset_input_idx, grid_size=dset.get_grid_size()
644
+ )
645
+ return _get_location(extra_padding, htw, pred_h, pred_w)
646
+
647
+
648
+ def remove_pad(pred, loc, extra_padding, smoothening_pixelcount, frame_shape):
649
+ assert smoothening_pixelcount == 0
650
+ if extra_padding - smoothening_pixelcount > 0:
651
+ h_s = extra_padding - smoothening_pixelcount
652
+
653
+ # rows
654
+ h_N = frame_shape[0]
655
+ if loc.h_end > h_N:
656
+ assert loc.h_end - extra_padding + smoothening_pixelcount <= h_N
657
+ h_e = extra_padding - smoothening_pixelcount
658
+
659
+ w_s = extra_padding - smoothening_pixelcount
660
+
661
+ # columns
662
+ w_N = frame_shape[1]
663
+ if loc.w_end > w_N:
664
+ assert loc.w_end - extra_padding + smoothening_pixelcount <= w_N
665
+
666
+ w_e = extra_padding - smoothening_pixelcount
667
+
668
+ return pred[h_s:-h_e, w_s:-w_e]
669
+
670
+ return pred
671
+
672
+
673
+ def update_loc_for_final_insertion(loc, extra_padding, smoothening_pixelcount):
674
+ extra_padding = extra_padding - smoothening_pixelcount
675
+ loc.h_start += extra_padding
676
+ loc.w_start += extra_padding
677
+ loc.h_end -= extra_padding
678
+ loc.w_end -= extra_padding
679
+ return loc
680
+
681
+
682
+ def stitch_predictions(predictions, dset, smoothening_pixelcount=0):
683
+ """
684
+ Args:
685
+ smoothening_pixelcount: number of pixels which can be interpolated
686
+ """
687
+ assert smoothening_pixelcount >= 0 and isinstance(smoothening_pixelcount, int)
688
+ extra_padding = dset.per_side_overlap_pixelcount()
689
+ # if there are more channels, use all of them.
690
+ shape = list(dset.get_data_shape())
691
+ shape[-1] = max(shape[-1], predictions.shape[1])
692
+
693
+ output = np.zeros(shape, dtype=predictions.dtype)
694
+ frame_shape = dset.get_data_shape()[1:3]
695
+ for dset_input_idx in range(predictions.shape[0]):
696
+ loc = get_location_from_idx(
697
+ dset, dset_input_idx, predictions.shape[-2], predictions.shape[-1]
698
+ )
699
+
700
+ mask = None
701
+ cropped_pred_list = []
702
+ for ch_idx in range(predictions.shape[1]):
703
+ # class i
704
+ cropped_pred_i = remove_pad(
705
+ predictions[dset_input_idx, ch_idx],
706
+ loc,
707
+ extra_padding,
708
+ smoothening_pixelcount,
709
+ frame_shape,
710
+ )
711
+
712
+ if mask is None:
713
+ # NOTE: don't need to compute it for every patch.
714
+ assert (
715
+ smoothening_pixelcount == 0
716
+ ), "For smoothing, enable get_smoothing_mask. It is disabled because it is unused and needs modification to work with non-square images."
717
+ mask = 1
718
+ # mask = _get_smoothing_mask(cropped_pred_i.shape, smoothening_pixelcount, loc, frame_size)
719
+
720
+ cropped_pred_list.append(cropped_pred_i)
721
+
722
+ loc = update_loc_for_final_insertion(loc, extra_padding, smoothening_pixelcount)
723
+ for ch_idx in range(predictions.shape[1]):
724
+ output[loc.t, loc.h_start : loc.h_end, loc.w_start : loc.w_end, ch_idx] += (
725
+ cropped_pred_list[ch_idx] * mask
726
+ )
727
+
728
+ return output
729
+
730
+
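# Editor's note (not part of the diff above): a hypothetical end-to-end flow, assuming `model`
# exposes normalize_input/normalize_target/get_reconstruction_loss (e.g. the LVAE training module
# in this package) and `dset` is one of the lvae_training datasets exposing
# per_side_overlap_pixelcount(), get_idx_manager() and get_data_shape():
#     preds, losses, logvar, psnr, preds_std = get_dset_predictions(model, dset, batch_size=8, mmse_count=10)
#     frames = stitch_predictions(preds, dset, smoothening_pixelcount=0)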
731
+ # ------------------------------------------------------------------------------------------
732
+
733
+
734
+ # ------------------------------------------------------------------------------------------
735
+ ### Classes and Functions used for Calibration
736
+ class Calibration:
+ """Compute per-channel, bin-wise calibration statistics (RMSE vs RMV) from predictions and predicted log-variance."""
737
+
738
+ def __init__(
739
+ self, num_bins: int = 15, mode: Literal["pixelwise", "patchwise"] = "pixelwise"
740
+ ):
741
+ self._bins = num_bins
742
+ self._bin_boundaries = None
743
+ self._mode = mode
744
+ assert mode in ["pixelwise", "patchwise"]
745
+ self._boundary_mode = "uniform"
746
+ assert self._boundary_mode in ["quantile", "uniform"]
747
+ # self._bin_boundaries = {}
748
+
749
+ def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
750
+ return np.exp(logvar / 2)
751
+
752
+ def compute_bin_boundaries(self, predict_logvar: np.ndarray) -> np.ndarray:
753
+ """
754
+ Compute the bin boundaries for `num_bins` bins and the given logvar values.
755
+ """
756
+ if self._boundary_mode == "quantile":
757
+ boundaries = np.quantile(
758
+ self.logvar_to_std(predict_logvar), np.linspace(0, 1, self._bins + 1)
759
+ )
760
+ return boundaries
761
+ else:
762
+ min_logvar = np.min(predict_logvar)
763
+ max_logvar = np.max(predict_logvar)
764
+ min_std = self.logvar_to_std(min_logvar)
765
+ max_std = self.logvar_to_std(max_logvar)
766
+ return np.linspace(min_std, max_std, self._bins + 1)
767
+
768
+ def compute_stats(
769
+ self, pred: np.ndarray, pred_logvar: np.ndarray, target: np.ndarray
770
+ ) -> Dict[int, Dict[str, Union[np.ndarray, List]]]:
771
+ """
772
+ It computes the bin-wise RMSE and RMV for each channel of the predicted image.
773
+
774
+ Recall that:
775
+ - RMSE = np.sqrt(np.sum((pred - target)**2) / num_pixels)
776
+ - RMV = np.sqrt(np.mean(pred_std**2))
777
+
778
+ ALGORITHM
779
+ - For each channel:
780
+ - Given the bin boundaries, assign pixels of `std_ch` array to a specific bin index.
781
+ - For each bin index:
782
+ - Compute the RMSE, RMV, and number of pixels for that bin.
783
+
784
+ NOTE: each channel of the predicted image/logvar has its own stats.
785
+
786
+ Args:
787
+ pred: np.ndarray, shape (n, h, w, c)
788
+ pred_logvar: np.ndarray, shape (n, h, w, c)
789
+ target: np.ndarray, shape (n, h, w, c)
790
+ """
791
+ self._bin_boundaries = {}
792
+ stats = {}
793
+ for ch_idx in range(pred.shape[-1]):
794
+ stats[ch_idx] = {
795
+ "bin_count": [],
796
+ "rmv": [],
797
+ "rmse": [],
798
+ "bin_boundaries": None,
799
+ "bin_matrix": [],
800
+ }
801
+ pred_ch = pred[..., ch_idx]
802
+ logvar_ch = pred_logvar[..., ch_idx]
803
+ std_ch = self.logvar_to_std(logvar_ch)
804
+ target_ch = target[..., ch_idx]
805
+ if self._mode == "pixelwise":
806
+ boundaries = self.compute_bin_boundaries(logvar_ch)
807
+ stats[ch_idx]["bin_boundaries"] = boundaries
808
+ bin_matrix = np.digitize(std_ch.reshape(-1), boundaries)
809
+ bin_matrix = bin_matrix.reshape(std_ch.shape)
810
+ stats[ch_idx]["bin_matrix"] = bin_matrix
811
+ error = (pred_ch - target_ch) ** 2
812
+ for bin_idx in range(self._bins):
813
+ bin_mask = bin_matrix == bin_idx
814
+ bin_error = error[bin_mask]
815
+ bin_size = np.sum(bin_mask)
816
+ bin_error = (
817
+ np.sqrt(np.sum(bin_error) / bin_size) if bin_size > 0 else None
818
+ ) # RMSE
819
+ bin_var = np.sqrt(np.mean(std_ch[bin_mask] ** 2)) # RMV
820
+ stats[ch_idx]["rmse"].append(bin_error)
821
+ stats[ch_idx]["rmv"].append(bin_var)
822
+ stats[ch_idx]["bin_count"].append(bin_size)
823
+ else:
824
+ raise NotImplementedError("Patchwise mode is not implemented yet.")
825
+ return stats
826
+
827
+
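# Editor's illustrative sketch (not part of the diff above): Calibration on toy arrays shaped
# (n, h, w, c), matching the compute_stats docstring; the resulting per-channel stats are what
# plot_calibration consumes. All `toy_*` names are hypothetical.
toy_rng = np.random.default_rng(0)
toy_pred = toy_rng.normal(size=(4, 32, 32, 2))
toy_logvar = toy_rng.normal(size=(4, 32, 32, 2))
toy_target = toy_pred + toy_rng.normal(scale=0.1, size=toy_pred.shape)
toy_stats = Calibration(num_bins=15, mode="pixelwise").compute_stats(toy_pred, toy_logvar, toy_target)
# plot_calibration(plt.subplots()[1], toy_stats)  # RMV vs RMSE, one curve per channel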
828
+ def nll(x, mean, logvar):
829
+ """
830
+ Negative log of the probability density of the values x under the Normal
831
+ distribution with parameters mean and logvar.
832
+
833
+ :param x: tensor of points, with shape (batch, channels, dim1, dim2)
834
+ :param mean: tensor with mean of distribution, shape
835
+ (batch, channels, dim1, dim2)
836
+ :param logvar: tensor with log-variance of distribution, shape has to be
837
+ either scalar or broadcastable
838
+ """
839
+ var = torch.exp(logvar)
840
+ log_prob = -0.5 * (
841
+ ((x - mean) ** 2) / var + logvar + torch.tensor(2 * math.pi).log()
842
+ )
843
+ nll = -log_prob
844
+ return nll
845
+
846
+
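# Editor's illustrative sketch (not part of the diff above): nll() is the element-wise negative
# log-density of a Normal distribution, so it should agree with torch.distributions.Normal.
# The `check_*` tensors are hypothetical.
check_x = torch.randn(2, 1, 4, 4)
check_mean = torch.zeros_like(check_x)
check_logvar = torch.full_like(check_x, -0.5)
check_ref = -torch.distributions.Normal(check_mean, torch.exp(check_logvar / 2)).log_prob(check_x)
assert torch.allclose(nll(check_x, check_mean, check_logvar), check_ref, atol=1e-6)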
847
+ def get_calibrated_factor_for_stdev(
848
+ pred: Union[np.ndarray, torch.Tensor],
849
+ pred_logvar: Union[np.ndarray, torch.Tensor],
850
+ target: Union[np.ndarray, torch.Tensor],
851
+ batch_size: int = 32,
852
+ epochs: int = 500,
853
+ lr: float = 0.01,
854
+ ):
855
+ """
856
+ Here, we calibrate the uncertainty by multiplying the predicted std (derived from the predicted logvar) by a scalar.
857
+ We return this calibration scalar; it must then be multiplied with the predicted std.
858
+
859
+ NOTE: Why is the input logvar and not std? because the model typically predicts logvar and not std.
860
+ """
861
+ # create a learnable scalar
862
+ scalar = torch.nn.Parameter(torch.tensor(2.0))
863
+ optimizer = torch.optim.Adam([scalar], lr=lr)
864
+
865
+ bar = tqdm(range(epochs))
866
+ for _ in bar:
867
+ optimizer.zero_grad()
868
+ # Select a random batch of predictions
869
+ mask = np.random.randint(0, pred.shape[0], batch_size)
870
+ pred_batch = torch.Tensor(pred[mask]).cuda()
871
+ pred_logvar_batch = torch.Tensor(pred_logvar[mask]).cuda()
872
+ target_batch = torch.Tensor(target[mask]).cuda()
873
+
874
+ loss = torch.mean(
875
+ nll(target_batch, pred_batch, pred_logvar_batch + torch.log(scalar))
876
+ )
877
+ loss.backward()
878
+ optimizer.step()
879
+ bar.set_description(f"nll: {loss.item()} scalar: {scalar.item()}")
880
+
881
+ return np.sqrt(scalar.item())
882
+
883
+
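# Editor's note (not part of the diff above): the value returned by get_calibrated_factor_for_stdev
# is a multiplicative factor for the predicted std (the loss scales the variance by `scalar`, so the
# std scales by its square root), e.g., with hypothetical arrays:
#     factor = get_calibrated_factor_for_stdev(pred, pred_logvar, target)
#     calibrated_std = factor * np.exp(pred_logvar / 2)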
884
+ def plot_calibration(ax, calibration_stats):
+ """Plot calibration statistics (RMV vs RMSE) with channel labels; this overrides the earlier plot_calibration definition above."""
885
+ first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
886
+ last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
887
+ ax.plot(
888
+ calibration_stats[0]["rmv"][first_idx:-last_idx],
889
+ calibration_stats[0]["rmse"][first_idx:-last_idx],
890
+ "o",
891
+ label=r"$\hat{C}_0$: Ch1",
892
+ )
893
+
894
+ first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
895
+ last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
896
+ ax.plot(
897
+ calibration_stats[1]["rmv"][first_idx:-last_idx],
898
+ calibration_stats[1]["rmse"][first_idx:-last_idx],
899
+ "o",
900
+ label=r"$\hat{C}_1$: Ch2",
901
+ )
902
+
903
+ ax.set_xlabel("RMV")
904
+ ax.set_ylabel("RMSE")
905
+ ax.legend()