careamics 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- careamics/config/algorithms/care_algorithm_model.py +12 -24
- careamics/config/algorithms/n2n_algorithm_model.py +13 -25
- careamics/config/algorithms/n2v_algorithm_model.py +13 -19
- careamics/config/configuration_factories.py +84 -23
- careamics/config/data/data_model.py +47 -2
- careamics/config/support/supported_algorithms.py +5 -1
- careamics/config/validators/__init__.py +12 -1
- careamics/config/validators/model_validators.py +84 -0
- careamics/lightning/callbacks/progress_bar_callback.py +1 -1
- careamics/lightning/train_data_module.py +10 -19
- careamics/lvae_training/calibration.py +64 -57
- careamics/lvae_training/dataset/lc_dataset.py +2 -1
- careamics/lvae_training/dataset/multich_dataset.py +2 -2
- careamics/lvae_training/dataset/types.py +1 -1
- careamics/lvae_training/eval_utils.py +128 -128
- careamics/models/lvae/likelihoods.py +2 -0
- careamics/models/lvae/lvae.py +13 -1
- careamics/models/lvae/noise_models.py +280 -217
- careamics/models/lvae/stochastic.py +1 -0
- careamics/utils/metrics.py +25 -0
- careamics/utils/plotting.py +78 -0
- {careamics-0.0.6.dist-info → careamics-0.0.8.dist-info}/METADATA +5 -3
- {careamics-0.0.6.dist-info → careamics-0.0.8.dist-info}/RECORD +26 -24
- {careamics-0.0.6.dist-info → careamics-0.0.8.dist-info}/WHEEL +0 -0
- {careamics-0.0.6.dist-info → careamics-0.0.8.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.6.dist-info → careamics-0.0.8.dist-info}/licenses/LICENSE +0 -0
careamics/lvae_training/calibration.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional
 
 import numpy as np
 import torch
@@ -34,9 +34,6 @@ class Calibration:
         self._bins = num_bins
         self._bin_boundaries = None
 
-    def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
-        return np.exp(logvar / 2)
-
     def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray:
         """Compute the bin boundaries for `num_bins` bins and predicted std values."""
         min_std = np.min(predict_std)
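Note: the removed logvar_to_std helper was a one-line conversion, so callers that still hold log-variance maps can simply inline it. For reference, a minimal standalone version, identical to the removed code above:

    import numpy as np

    def logvar_to_std(logvar: np.ndarray) -> np.ndarray:
        # std = sqrt(var) = sqrt(exp(logvar)) = exp(logvar / 2)
        return np.exp(logvar / 2)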
@@ -104,65 +101,75 @@ class Calibration:
             )
             rmse_stderr = np.sqrt(stderr) if stderr is not None else None
 
-            bin_var = np.mean(
+            bin_var = np.mean(std_ch[bin_mask] ** 2)
             stats_dict[ch_idx]["rmse"].append(bin_error)
             stats_dict[ch_idx]["rmse_err"].append(rmse_stderr)
             stats_dict[ch_idx]["rmv"].append(np.sqrt(bin_var))
             stats_dict[ch_idx]["bin_count"].append(bin_size)
+        self.stats_dict = stats_dict
         return stats_dict
 
+    def get_calibrated_factor_for_stdev(
+        self,
+        pred: Optional[np.ndarray] = None,
+        pred_std: Optional[np.ndarray] = None,
+        target: Optional[np.ndarray] = None,
+        q_s: float = 0.00001,
+        q_e: float = 0.99999,
+    ) -> dict[str, float]:
+        """Calibrate the uncertainty by multiplying the predicted std with a scalar.
 
-[old lines 115-160 removed; content not captured in this extract]
+        Parameters
+        ----------
+        stats_dict : dict[int, dict[str, Union[np.ndarray, list]]]
+            Dictionary containing the stats for each channel.
+        q_s : float, optional
+            Start quantile, by default 0.00001.
+        q_e : float, optional
+            End quantile, by default 0.99999.
+
+        Returns
+        -------
+        dict[str, float]
+            Calibrated factor for each channel (slope + intercept).
+        """
+        if not hasattr(self, "stats_dict"):
+            print("No stats found. Computing stats...")
+            if any(v is None for v in [pred, pred_std, target]):
+                raise ValueError("pred, pred_std, and target must be provided.")
+            self.stats_dict = self.compute_stats(
+                pred=pred, pred_std=pred_std, target=target
+            )
+        outputs = {}
+        for ch_idx in self.stats_dict.keys():
+            y = self.stats_dict[ch_idx]["rmse"]
+            x = self.stats_dict[ch_idx]["rmv"]
+            count = self.stats_dict[ch_idx]["bin_count"]
+
+            first_idx = get_first_index(count, q_s)
+            last_idx = get_last_index(count, q_e)
+            x = x[first_idx:-last_idx]
+            y = y[first_idx:-last_idx]
+            slope, intercept, *_ = stats.linregress(x, y)
+            output = {"scalar": slope, "offset": intercept}
+            outputs[ch_idx] = output
+        factors = self.get_factors_array(factors_dict=outputs)
+        return outputs, factors
+
+    def get_factors_array(self, factors_dict: list[dict]):
+        """Get the calibration factors as a numpy array."""
+        calib_scalar = [factors_dict[i]["scalar"] for i in range(len(factors_dict))]
+        calib_scalar = np.array(calib_scalar).reshape(1, 1, 1, -1)
+        calib_offset = [
+            factors_dict[i].get("offset", 0.0) for i in range(len(factors_dict))
+        ]
+        calib_offset = np.array(calib_offset).reshape(1, 1, 1, -1)
+        return {"scalar": calib_scalar, "offset": calib_offset}
 
 
 def plot_calibration(ax, calibration_stats):
-    first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.
-    last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.
+    first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.0001)
+    last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.9999)
     ax.plot(
         calibration_stats[0]["rmv"][first_idx:-last_idx],
         calibration_stats[0]["rmse"][first_idx:-last_idx],
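Together, compute_stats, get_calibrated_factor_for_stdev, and get_factors_array turn the binned RMSE-vs-RMV statistics into a per-channel linear correction of the predicted standard deviation (std_calibrated = scalar * std + offset). A hedged usage sketch follows; it assumes the Calibration constructor accepts num_bins (consistent with the self._bins = num_bins context above) and channel-last arrays, neither of which is documented API:

    import numpy as np
    from careamics.lvae_training.calibration import Calibration

    rng = np.random.default_rng(0)
    pred = rng.normal(size=(4, 64, 64, 2))        # hypothetical predictions
    pred_std = rng.uniform(0.5, 1.5, pred.shape)  # hypothetical predicted stds
    target = pred + rng.normal(size=pred.shape)   # hypothetical ground truth

    calib = Calibration(num_bins=30)
    per_channel, factors = calib.get_calibrated_factor_for_stdev(
        pred=pred, pred_std=pred_std, target=target
    )
    # Apply the per-channel linear correction to the predicted stds.
    calibrated_std = pred_std * factors["scalar"] + factors["offset"]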
@@ -170,15 +177,15 @@ def plot_calibration(ax, calibration_stats):
         label=r"$\hat{C}_0$: Ch1",
     )
 
-    first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.
-    last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.
+    first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.0001)
+    last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.9999)
     ax.plot(
         calibration_stats[1]["rmv"][first_idx:-last_idx],
         calibration_stats[1]["rmse"][first_idx:-last_idx],
         "o",
-        label=r"$\hat{C}_1
+        label=r"$\hat{C}_1$: Ch2",
     )
-
+    # TODO add multichannel
     ax.set_xlabel("RMV")
     ax.set_ylabel("RMSE")
     ax.legend()
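plot_calibration now hard-codes the 0.0001/0.9999 count quantiles and labels both channels. A minimal plotting sketch, reusing the calib object from the previous example (calib.stats_dict is populated by compute_stats, as the hunk above shows):

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(5, 4))
    plot_calibration(ax, calibration_stats=calib.stats_dict)
    plt.show()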
careamics/lvae_training/dataset/lc_dataset.py
CHANGED

@@ -97,7 +97,8 @@ class LCMultiChDloader(MultiChDloader):
         ]
 
         self.N = len(t_list)
-        self.
+        # TODO where tf is self._img_sz defined?
+        self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
         print(
             f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
         )
careamics/lvae_training/dataset/multich_dataset.py
CHANGED

@@ -359,8 +359,8 @@ class MultiChDloader:
         self._noise_data = self._noise_data[
             t_list, h_start:h_end, w_start:w_end, :
         ].copy()
-
-        self.set_img_sz(self._img_sz, self._grid_sz)
+        # TODO where tf is self._img_sz defined?
+        self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
         print(
             f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
         )
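In both dataset classes the data-reduction path now passes an explicit [height, width] pair to set_img_sz instead of a bare int. A hedged sketch of the new call shape (dset stands for any MultiChDloader-style dataset; the sizes are illustrative):

    tile_size = [64, 64]  # [height, width]; previously a single int was passed
    grid_size = 32
    dset.set_img_sz(tile_size, grid_size)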
careamics/lvae_training/eval_utils.py
CHANGED

@@ -7,23 +7,18 @@ It includes functions to:
 """
 
 import os
-from typing import
+from typing import Optional
 
 import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
-from scipy import stats
 import torch
-from torch import nn
-from torch.utils.data import Dataset
 from matplotlib.gridspec import GridSpec
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset, Subset
 from tqdm import tqdm
 
 from careamics.lightning import VAEModule
-
-from careamics.models.lvae.utils import ModelType
-from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
+from careamics.utils.metrics import scale_invariant_psnr
 
 
 class TilingMode:
@@ -149,11 +144,10 @@ def plot_crops(
     tar,
     tar_hsnr,
     recon_img_list,
-    calibration_stats,
+    calibration_stats=None,
     num_samples=2,
     baseline_preds=None,
 ):
-    """ """
     if baseline_preds is None:
         baseline_preds = []
     if len(baseline_preds) > 0:
@@ -164,15 +158,13 @@ def plot_crops(
         )
         print("This happens when we want to predict the edges of the image.")
         return
+    color_ch_list = ["goldenrod", "cyan"]
+    color_pred = "red"
+    insetplot_xmax_value = 10000
+    insetplot_xmin_value = -1000
+    inset_min_labelsize = 10
+    inset_rect = [0.05, 0.05, 0.4, 0.2]
 
-    # color_ch_list = ['goldenrod', 'cyan']
-    # color_pred = 'red'
-    # insetplot_xmax_value = 10000
-    # insetplot_xmin_value = -1000
-    # inset_min_labelsize = 10
-    # inset_rect = [0.05, 0.05, 0.4, 0.2]
-
-    # Set plot attributes
     img_sz = 3
     ncols = num_samples + len(baseline_preds) + 1 + 1 + 1 + 1 + 1 * (num_samples > 1)
     grid_factor = 5
@@ -191,7 +183,6 @@ def plot_crops(
     )
     params = {"mathtext.default": "regular"}
     plt.rcParams.update(params)
-
     # plot baselines
     for i in range(2, 2 + len(baseline_preds)):
         for col_idx in range(baseline_preds[0].shape[0]):
@@ -471,52 +462,18 @@ def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val
     plt.colorbar(img_err, ax=ax)
 
 
-#
-
-
-def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
-    """
-    Given an index and a validation/test set, it returns the input, target and the reconstructed images for that index.
-    """
-    print(f"Predicting for {idx}")
-    val_dset.set_img_sz(patch_size, 64)
-
-    with torch.no_grad():
-        # val_dset.enable_noise()
-        inp, tar = val_dset[idx]
-        # val_dset.disable_noise()
-
-        inp = torch.Tensor(inp[None])
-        tar = torch.Tensor(tar[None])
-        inp = inp.cuda()
-        x_normalized = model.normalize_input(inp)
-        tar = tar.cuda()
-        tar_normalized = model.normalize_target(tar)
-
-        recon_img_list = []
-        for _ in range(mmse_count):
-            recon_normalized, td_data = model(x_normalized)
-            rec_loss, imgs = model.get_reconstruction_loss(
-                recon_normalized,
-                x_normalized,
-                tar_normalized,
-                return_predicted_img=True,
-            )
-            imgs = model.unnormalize_target(imgs)
-            recon_img_list.append(imgs.cpu().numpy()[0])
-
-        recon_img_list = np.array(recon_img_list)
-        return inp, tar, recon_img_list
+# -------------------------------------------------------------------------------------
 
 
-def get_dset_predictions(
+def get_predictions(
     model: VAEModule,
     dset: Dataset,
     batch_size: int,
-
+    tile_size: Optional[tuple[int, int]] = None,
+    grid_size: Optional[int] = None,
     mmse_count: int = 1,
     num_workers: int = 4,
-) -> tuple[
+) -> tuple[dict, dict, dict]:
     """Get patch-wise predictions from a model for the entire dataset.
 
     Parameters
@@ -545,6 +502,57 @@ def get_dset_predictions(
     - losses: Reconstruction losses for the predictions.
     - psnr: PSNR values for the predictions.
     """
+    if hasattr(dset, "dsets"):
+        multifile_stitched_predictions = {}
+        multifile_stitched_stds = {}
+        for d in dset.dsets:
+            stitched_predictions, stitched_stds = get_single_file_mmse(
+                model=model,
+                dset=d,
+                batch_size=batch_size,
+                tile_size=tile_size,
+                grid_size=grid_size,
+                mmse_count=mmse_count,
+                num_workers=num_workers,
+            )
+            # get filename without extension and path
+            filename = str(d._fpath).split("/")[-1].split(".")[0]
+            multifile_stitched_predictions[filename] = stitched_predictions
+            multifile_stitched_stds[filename] = stitched_stds
+        return (
+            multifile_stitched_predictions,
+            multifile_stitched_stds,
+        )
+    else:
+        stitched_predictions, stitched_stds = get_single_file_mmse(
+            model=model,
+            dset=dset,
+            batch_size=batch_size,
+            tile_size=tile_size,
+            grid_size=grid_size,
+            mmse_count=mmse_count,
+            num_workers=num_workers,
+        )
+        # get filename without extension and path
+        filename = str(dset._fpath).split("/")[-1].split(".")[0]
+        return (
+            {filename: stitched_predictions},
+            {filename: stitched_stds},
+        )
+
+
+def get_single_file_predictions(
+    model: VAEModule,
+    dset: Dataset,
+    batch_size: int,
+    tile_size: Optional[tuple[int, int]] = None,
+    grid_size: Optional[int] = None,
+    num_workers: int = 4,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Get patch-wise predictions from a model for a single file dataset."""
+    if tile_size and grid_size:
+        dset.set_img_sz(tile_size, grid_size)
+
     dloader = DataLoader(
         dset,
         pin_memory=False,
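get_predictions (formerly get_dset_predictions) now dispatches on whether the dataset wraps multiple files (hasattr(dset, "dsets")) and returns dictionaries of stitched predictions and stds keyed by file stem. A hedged usage sketch; model and val_dset are assumed to be a trained VAEModule and a compatible dataset:

    predictions, stds = get_predictions(
        model=model,
        dset=val_dset,
        batch_size=8,
        tile_size=(64, 64),
        grid_size=32,
        mmse_count=10,
    )
    for filename, stitched in predictions.items():
        # stds[filename] holds the matching per-pixel std over the MMSE samples.
        print(filename, stitched.shape)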
@@ -552,43 +560,66 @@ def get_dset_predictions(
         shuffle=False,
         batch_size=batch_size,
     )
+    model.eval()
+    model.cuda()
+    tiles = []
+    logvar_arr = []
+    with torch.no_grad():
+        for batch in tqdm(dloader, desc="Predicting tiles"):
+            inp, tar = batch
+            inp = inp.cuda()
+            tar = tar.cuda()
+
+            # get model output
+            rec, _ = model(inp)
+
+            # get reconstructed img
+            if model.model.predict_logvar is None:
+                rec_img = rec
+                logvar = torch.tensor([-1])
+            else:
+                rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
+            logvar_arr.append(logvar.cpu().numpy())  # Why do we need this ?
+
+            tiles.append(rec_img.cpu().numpy())
+
+    tile_samples = np.concatenate(tiles, axis=0)
+    return stitch_predictions_new(tile_samples, dset)
 
-    gauss_likelihood = model.gaussian_likelihood
-    nm_likelihood = model.noise_model_likelihood
 
-
-
-
+def get_single_file_mmse(
+    model: VAEModule,
+    dset: Dataset,
+    batch_size: int,
+    tile_size: Optional[tuple[int, int]] = None,
+    grid_size: Optional[int] = None,
+    mmse_count: int = 1,
+    num_workers: int = 4,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Get patch-wise predictions from a model for a single file dataset."""
+    dloader = DataLoader(
+        dset,
+        pin_memory=False,
+        num_workers=num_workers,
+        shuffle=False,
+        batch_size=batch_size,
+    )
+    if tile_size and grid_size:
+        dset.set_img_sz(tile_size, grid_size)
+
+    model.eval()
+    model.cuda()
+    tile_mmse = []
+    tile_stds = []
     logvar_arr = []
-    num_channels = dset[0][1].shape[0]
-    patch_psnr_channels = [RunningPSNR() for _ in range(num_channels)]
     with torch.no_grad():
-        for batch in tqdm(dloader, desc="Predicting
+        for batch in tqdm(dloader, desc="Predicting tiles"):
            inp, tar = batch
            inp = inp.cuda()
            tar = tar.cuda()
 
            rec_img_list = []
-            for
-
-            # TODO: case of HDN left for future refactoring
-            # if model_type == ModelType.Denoiser:
-            #     assert model.denoise_channel in [
-            #         "Ch1",
-            #         "Ch2",
-            #         "input",
-            #     ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
-
-            # x_normalized_new, tar_new = model.get_new_input_target(
-            #     (inp, tar, *batch[2:])
-            # )
-            # rec, _ = model(x_normalized_new)
-            # rec_loss, imgs = model.get_reconstruction_loss(
-            #     rec,
-            #     tar,
-            #     x_normalized_new,
-            #     return_predicted_img=True,
-            # )
+            for _ in range(mmse_count):
 
                # get model output
                rec, _ = model(inp)
@@ -600,52 +631,21 @@ def get_dset_predictions(
                else:
                    rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
                rec_img_list.append(rec_img.cpu().unsqueeze(0))  # add MMSE dim
-                logvar_arr.append(logvar.cpu().numpy())
-
-            # compute reconstruction loss
-            # if loss_type == "musplit":
-            #     rec_loss = get_reconstruction_loss(
-            #         reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
-            #     )
-            # elif loss_type == "denoisplit":
-            #     rec_loss = get_reconstruction_loss(
-            #         reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
-            #     )
-            # elif loss_type == "denoisplit_musplit":
-            #     rec_loss = reconstruction_loss_musplit_denoisplit(
-            #         predictions=rec,
-            #         targets=tar,
-            #         gaussian_likelihood=gauss_likelihood,
-            #         nm_likelihood=nm_likelihood,
-            #         nm_weight=model.loss_parameters.denoisplit_weight,
-            #         gaussian_weight=model.loss_parameters.musplit_weight,
-            #     )
-            #     rec_loss = {"loss": rec_loss}  # hacky, but ok for now
-
-            # # store rec loss values for first pred
-            # if mmse_idx == 0:
-            #     try:
-            #         losses.append(rec_loss["loss"].cpu().numpy())
-            #     except:
-            #         losses.append(rec_loss["loss"])
-
-            # update running PSNR
-            # for i in range(num_channels):
-            #     patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
+                logvar_arr.append(logvar.cpu().numpy())  # Why do we need this ?
 
            # aggregate results
            samples = torch.cat(rec_img_list, dim=0)
            mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
-[old lines 639-648 removed; content not captured in this extract]
+            std_imgs = torch.std(samples, dim=0)  # std over MMSE dim
+
+            tile_mmse.append(mmse_imgs.cpu().numpy())
+            tile_stds.append(std_imgs.cpu().numpy())
+
+    tiles_arr = np.concatenate(tile_mmse, axis=0)
+    tile_stds = np.concatenate(tile_stds, axis=0)
+    stitched_predictions = stitch_predictions_new(tiles_arr, dset)
+    stitched_stds = stitch_predictions_new(tile_stds, dset)
+    return stitched_predictions, stitched_stds
 
 
 # ------------------------------------------------------------------------------------------
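At its core, get_single_file_mmse computes a Monte-Carlo MMSE estimate: run the stochastic model mmse_count times on the same tile, then take the mean (point estimate) and std (uncertainty) over the sample dimension. Reduced to its essentials, as a hedged sketch that ignores the predict_logvar chunking shown above:

    import torch

    def mmse_estimate(model, inp: torch.Tensor, mmse_count: int = 10):
        # Stack mmse_count stochastic reconstructions of the same input
        # and reduce over the sample dimension.
        with torch.no_grad():
            samples = torch.stack([model(inp)[0] for _ in range(mmse_count)], dim=0)
        return samples.mean(dim=0), samples.std(dim=0)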
careamics/models/lvae/likelihoods.py
CHANGED

@@ -324,6 +324,8 @@ class NoiseModelLikelihood(LikelihoodModule):
         if self.data_mean.device != correct_device_tensor.device:
             self.data_mean = self.data_mean.to(correct_device_tensor.device)
             self.data_std = self.data_std.to(correct_device_tensor.device)
+        if correct_device_tensor.device != self.noiseModel.device:
+            self.noiseModel.to_device(correct_device_tensor.device)
 
     def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
         return x, None
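The two added lines extend the existing lazy device synchronization to the noise model, mirroring how data_mean and data_std are already migrated. The underlying idiom, as a standalone sketch (not careamics API):

    import torch

    def sync_to(ref: torch.Tensor, *tensors: torch.Tensor) -> list[torch.Tensor]:
        # Move cached tensors to the device of the tensor they will be
        # combined with, touching only those that actually differ.
        return [t if t.device == ref.device else t.to(ref.device) for t in tensors]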
careamics/models/lvae/lvae.py
CHANGED
@@ -6,7 +6,7 @@ and Artefact Removal, Prakash et al."
 """
 
 from collections.abc import Iterable
-from typing import Union
+from typing import Optional, Union
 
 import numpy as np
 import torch
@@ -834,3 +834,15 @@ class LadderVAE(nn.Module):
         # TODO check if model_3D_depth is needed ?
         top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
         return top_layer_shape
+
+    def reset_for_inference(self, tile_size: Optional[tuple[int, int]] = None):
+        """Should be called if we want to predict for a different input/output size."""
+        self.mode_pred = True
+        if tile_size is None:
+            tile_size = self.image_size
+        self.image_size = tile_size
+        for i in range(self.n_layers):
+            self.bottom_up_layers[i].output_expected_shape = (
+                ts // 2 ** (i + 1) for ts in tile_size
+            )
+            self.top_down_layers[i].latent_shape = tile_size
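reset_for_inference reconfigures the expected per-layer shapes so a trained LadderVAE can predict at a different tile size. A hedged sketch of a tiled-inference loop built on it; the rec, _ = model(tile) forward signature follows the usage seen in eval_utils above, and everything else is illustrative:

    import torch

    def predict_tiles(model, tiles, tile_size=(64, 64)):
        # model: a LadderVAE in eval mode; tiles: iterable of (B, C, H, W)
        # tensors matching tile_size.
        model.reset_for_inference(tile_size=tile_size)
        outputs = []
        with torch.no_grad():
            for tile in tiles:
                rec, _ = model(tile)
                outputs.append(rec.cpu())
        return outputs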