careamics 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/config/algorithms/care_algorithm_model.py +12 -24
- careamics/config/algorithms/n2n_algorithm_model.py +13 -25
- careamics/config/algorithms/n2v_algorithm_model.py +13 -19
- careamics/config/configuration_factories.py +84 -23
- careamics/config/data/data_model.py +47 -2
- careamics/config/support/supported_algorithms.py +5 -1
- careamics/config/validators/__init__.py +12 -1
- careamics/config/validators/model_validators.py +84 -0
- careamics/lightning/callbacks/progress_bar_callback.py +1 -1
- careamics/lightning/train_data_module.py +10 -19
- careamics/lvae_training/calibration.py +64 -57
- careamics/lvae_training/dataset/lc_dataset.py +2 -1
- careamics/lvae_training/dataset/multich_dataset.py +2 -2
- careamics/lvae_training/dataset/types.py +1 -1
- careamics/lvae_training/eval_utils.py +123 -128
- careamics/models/lvae/likelihoods.py +2 -0
- careamics/models/lvae/lvae.py +13 -1
- careamics/models/lvae/noise_models.py +280 -217
- careamics/models/lvae/stochastic.py +1 -0
- careamics/utils/metrics.py +25 -0
- careamics/utils/plotting.py +78 -0
- {careamics-0.0.6.dist-info → careamics-0.0.7.dist-info}/METADATA +5 -3
- {careamics-0.0.6.dist-info → careamics-0.0.7.dist-info}/RECORD +26 -24
- {careamics-0.0.6.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
- {careamics-0.0.6.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.6.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Union
|
|
1
|
+
from typing import Union, Optional
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
@@ -34,9 +34,6 @@ class Calibration:
|
|
|
34
34
|
self._bins = num_bins
|
|
35
35
|
self._bin_boundaries = None
|
|
36
36
|
|
|
37
|
-
def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
|
|
38
|
-
return np.exp(logvar / 2)
|
|
39
|
-
|
|
40
37
|
def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray:
|
|
41
38
|
"""Compute the bin boundaries for `num_bins` bins and predicted std values."""
|
|
42
39
|
min_std = np.min(predict_std)
|
|
@@ -104,65 +101,75 @@ class Calibration:
|
|
|
104
101
|
)
|
|
105
102
|
rmse_stderr = np.sqrt(stderr) if stderr is not None else None
|
|
106
103
|
|
|
107
|
-
bin_var = np.mean(
|
|
104
|
+
bin_var = np.mean(std_ch[bin_mask] ** 2)
|
|
108
105
|
stats_dict[ch_idx]["rmse"].append(bin_error)
|
|
109
106
|
stats_dict[ch_idx]["rmse_err"].append(rmse_stderr)
|
|
110
107
|
stats_dict[ch_idx]["rmv"].append(np.sqrt(bin_var))
|
|
111
108
|
stats_dict[ch_idx]["bin_count"].append(bin_size)
|
|
109
|
+
self.stats_dict = stats_dict
|
|
112
110
|
return stats_dict
|
|
113
111
|
|
|
112
|
+
def get_calibrated_factor_for_stdev(
|
|
113
|
+
self,
|
|
114
|
+
pred: Optional[np.ndarray] = None,
|
|
115
|
+
pred_std: Optional[np.ndarray] = None,
|
|
116
|
+
target: Optional[np.ndarray] = None,
|
|
117
|
+
q_s: float = 0.00001,
|
|
118
|
+
q_e: float = 0.99999,
|
|
119
|
+
) -> dict[str, float]:
|
|
120
|
+
"""Calibrate the uncertainty by multiplying the predicted std with a scalar.
|
|
114
121
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
122
|
+
Parameters
|
|
123
|
+
----------
|
|
124
|
+
stats_dict : dict[int, dict[str, Union[np.ndarray, list]]]
|
|
125
|
+
Dictionary containing the stats for each channel.
|
|
126
|
+
q_s : float, optional
|
|
127
|
+
Start quantile, by default 0.00001.
|
|
128
|
+
q_e : float, optional
|
|
129
|
+
End quantile, by default 0.99999.
|
|
130
|
+
|
|
131
|
+
Returns
|
|
132
|
+
-------
|
|
133
|
+
dict[str, float]
|
|
134
|
+
Calibrated factor for each channel (slope + intercept).
|
|
135
|
+
"""
|
|
136
|
+
if not hasattr(self, "stats_dict"):
|
|
137
|
+
print("No stats found. Computing stats...")
|
|
138
|
+
if any(v is None for v in [pred, pred_std, target]):
|
|
139
|
+
raise ValueError("pred, pred_std, and target must be provided.")
|
|
140
|
+
self.stats_dict = self.compute_stats(
|
|
141
|
+
pred=pred, pred_std=pred_std, target=target
|
|
142
|
+
)
|
|
143
|
+
outputs = {}
|
|
144
|
+
for ch_idx in self.stats_dict.keys():
|
|
145
|
+
y = self.stats_dict[ch_idx]["rmse"]
|
|
146
|
+
x = self.stats_dict[ch_idx]["rmv"]
|
|
147
|
+
count = self.stats_dict[ch_idx]["bin_count"]
|
|
148
|
+
|
|
149
|
+
first_idx = get_first_index(count, q_s)
|
|
150
|
+
last_idx = get_last_index(count, q_e)
|
|
151
|
+
x = x[first_idx:-last_idx]
|
|
152
|
+
y = y[first_idx:-last_idx]
|
|
153
|
+
slope, intercept, *_ = stats.linregress(x, y)
|
|
154
|
+
output = {"scalar": slope, "offset": intercept}
|
|
155
|
+
outputs[ch_idx] = output
|
|
156
|
+
factors = self.get_factors_array(factors_dict=outputs)
|
|
157
|
+
return outputs, factors
|
|
158
|
+
|
|
159
|
+
def get_factors_array(self, factors_dict: list[dict]):
|
|
160
|
+
"""Get the calibration factors as a numpy array."""
|
|
161
|
+
calib_scalar = [factors_dict[i]["scalar"] for i in range(len(factors_dict))]
|
|
162
|
+
calib_scalar = np.array(calib_scalar).reshape(1, 1, 1, -1)
|
|
163
|
+
calib_offset = [
|
|
164
|
+
factors_dict[i].get("offset", 0.0) for i in range(len(factors_dict))
|
|
165
|
+
]
|
|
166
|
+
calib_offset = np.array(calib_offset).reshape(1, 1, 1, -1)
|
|
167
|
+
return {"scalar": calib_scalar, "offset": calib_offset}
|
|
161
168
|
|
|
162
169
|
|
|
163
170
|
def plot_calibration(ax, calibration_stats):
|
|
164
|
-
first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.
|
|
165
|
-
last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.
|
|
171
|
+
first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.0001)
|
|
172
|
+
last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.9999)
|
|
166
173
|
ax.plot(
|
|
167
174
|
calibration_stats[0]["rmv"][first_idx:-last_idx],
|
|
168
175
|
calibration_stats[0]["rmse"][first_idx:-last_idx],
|
|
@@ -170,15 +177,15 @@ def plot_calibration(ax, calibration_stats):
|
|
|
170
177
|
label=r"$\hat{C}_0$: Ch1",
|
|
171
178
|
)
|
|
172
179
|
|
|
173
|
-
first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.
|
|
174
|
-
last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.
|
|
180
|
+
first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.0001)
|
|
181
|
+
last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.9999)
|
|
175
182
|
ax.plot(
|
|
176
183
|
calibration_stats[1]["rmv"][first_idx:-last_idx],
|
|
177
184
|
calibration_stats[1]["rmse"][first_idx:-last_idx],
|
|
178
185
|
"o",
|
|
179
|
-
label=r"$\hat{C}_1
|
|
186
|
+
label=r"$\hat{C}_1$: Ch2",
|
|
180
187
|
)
|
|
181
|
-
|
|
188
|
+
# TODO add multichannel
|
|
182
189
|
ax.set_xlabel("RMV")
|
|
183
190
|
ax.set_ylabel("RMSE")
|
|
184
191
|
ax.legend()
|
|
@@ -97,7 +97,8 @@ class LCMultiChDloader(MultiChDloader):
|
|
|
97
97
|
]
|
|
98
98
|
|
|
99
99
|
self.N = len(t_list)
|
|
100
|
-
self.
|
|
100
|
+
# TODO where tf is self._img_sz defined?
|
|
101
|
+
self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
|
|
101
102
|
print(
|
|
102
103
|
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
|
|
103
104
|
)
|
|
@@ -359,8 +359,8 @@ class MultiChDloader:
|
|
|
359
359
|
self._noise_data = self._noise_data[
|
|
360
360
|
t_list, h_start:h_end, w_start:w_end, :
|
|
361
361
|
].copy()
|
|
362
|
-
|
|
363
|
-
self.set_img_sz(self._img_sz, self._grid_sz)
|
|
362
|
+
# TODO where tf is self._img_sz defined?
|
|
363
|
+
self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
|
|
364
364
|
print(
|
|
365
365
|
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
|
|
366
366
|
)
|
|
@@ -7,23 +7,18 @@ It includes functions to:
|
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
import os
|
|
10
|
-
from typing import
|
|
10
|
+
from typing import Optional
|
|
11
11
|
|
|
12
12
|
import matplotlib
|
|
13
13
|
import matplotlib.pyplot as plt
|
|
14
14
|
import numpy as np
|
|
15
|
-
from scipy import stats
|
|
16
15
|
import torch
|
|
17
|
-
from torch import nn
|
|
18
|
-
from torch.utils.data import Dataset
|
|
19
16
|
from matplotlib.gridspec import GridSpec
|
|
20
|
-
from torch.utils.data import DataLoader
|
|
17
|
+
from torch.utils.data import DataLoader, Dataset, Subset
|
|
21
18
|
from tqdm import tqdm
|
|
22
19
|
|
|
23
20
|
from careamics.lightning import VAEModule
|
|
24
|
-
|
|
25
|
-
from careamics.models.lvae.utils import ModelType
|
|
26
|
-
from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
|
|
21
|
+
from careamics.utils.metrics import scale_invariant_psnr
|
|
27
22
|
|
|
28
23
|
|
|
29
24
|
class TilingMode:
|
|
@@ -149,11 +144,10 @@ def plot_crops(
|
|
|
149
144
|
tar,
|
|
150
145
|
tar_hsnr,
|
|
151
146
|
recon_img_list,
|
|
152
|
-
calibration_stats,
|
|
147
|
+
calibration_stats=None,
|
|
153
148
|
num_samples=2,
|
|
154
149
|
baseline_preds=None,
|
|
155
150
|
):
|
|
156
|
-
""" """
|
|
157
151
|
if baseline_preds is None:
|
|
158
152
|
baseline_preds = []
|
|
159
153
|
if len(baseline_preds) > 0:
|
|
@@ -164,15 +158,13 @@ def plot_crops(
|
|
|
164
158
|
)
|
|
165
159
|
print("This happens when we want to predict the edges of the image.")
|
|
166
160
|
return
|
|
161
|
+
color_ch_list = ["goldenrod", "cyan"]
|
|
162
|
+
color_pred = "red"
|
|
163
|
+
insetplot_xmax_value = 10000
|
|
164
|
+
insetplot_xmin_value = -1000
|
|
165
|
+
inset_min_labelsize = 10
|
|
166
|
+
inset_rect = [0.05, 0.05, 0.4, 0.2]
|
|
167
167
|
|
|
168
|
-
# color_ch_list = ['goldenrod', 'cyan']
|
|
169
|
-
# color_pred = 'red'
|
|
170
|
-
# insetplot_xmax_value = 10000
|
|
171
|
-
# insetplot_xmin_value = -1000
|
|
172
|
-
# inset_min_labelsize = 10
|
|
173
|
-
# inset_rect = [0.05, 0.05, 0.4, 0.2]
|
|
174
|
-
|
|
175
|
-
# Set plot attributes
|
|
176
168
|
img_sz = 3
|
|
177
169
|
ncols = num_samples + len(baseline_preds) + 1 + 1 + 1 + 1 + 1 * (num_samples > 1)
|
|
178
170
|
grid_factor = 5
|
|
@@ -191,7 +183,6 @@ def plot_crops(
|
|
|
191
183
|
)
|
|
192
184
|
params = {"mathtext.default": "regular"}
|
|
193
185
|
plt.rcParams.update(params)
|
|
194
|
-
|
|
195
186
|
# plot baselines
|
|
196
187
|
for i in range(2, 2 + len(baseline_preds)):
|
|
197
188
|
for col_idx in range(baseline_preds[0].shape[0]):
|
|
@@ -471,52 +462,17 @@ def plot_error(target, prediction, cmap=matplotlib.cm.coolwarm, ax=None, max_val
|
|
|
471
462
|
plt.colorbar(img_err, ax=ax)
|
|
472
463
|
|
|
473
464
|
|
|
474
|
-
#
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
|
|
478
|
-
"""
|
|
479
|
-
Given an index and a validation/test set, it returns the input, target and the reconstructed images for that index.
|
|
480
|
-
"""
|
|
481
|
-
print(f"Predicting for {idx}")
|
|
482
|
-
val_dset.set_img_sz(patch_size, 64)
|
|
483
|
-
|
|
484
|
-
with torch.no_grad():
|
|
485
|
-
# val_dset.enable_noise()
|
|
486
|
-
inp, tar = val_dset[idx]
|
|
487
|
-
# val_dset.disable_noise()
|
|
488
|
-
|
|
489
|
-
inp = torch.Tensor(inp[None])
|
|
490
|
-
tar = torch.Tensor(tar[None])
|
|
491
|
-
inp = inp.cuda()
|
|
492
|
-
x_normalized = model.normalize_input(inp)
|
|
493
|
-
tar = tar.cuda()
|
|
494
|
-
tar_normalized = model.normalize_target(tar)
|
|
495
|
-
|
|
496
|
-
recon_img_list = []
|
|
497
|
-
for _ in range(mmse_count):
|
|
498
|
-
recon_normalized, td_data = model(x_normalized)
|
|
499
|
-
rec_loss, imgs = model.get_reconstruction_loss(
|
|
500
|
-
recon_normalized,
|
|
501
|
-
x_normalized,
|
|
502
|
-
tar_normalized,
|
|
503
|
-
return_predicted_img=True,
|
|
504
|
-
)
|
|
505
|
-
imgs = model.unnormalize_target(imgs)
|
|
506
|
-
recon_img_list.append(imgs.cpu().numpy()[0])
|
|
507
|
-
|
|
508
|
-
recon_img_list = np.array(recon_img_list)
|
|
509
|
-
return inp, tar, recon_img_list
|
|
465
|
+
# -------------------------------------------------------------------------------------
|
|
510
466
|
|
|
511
467
|
|
|
512
|
-
def
|
|
468
|
+
def get_predictions(
|
|
513
469
|
model: VAEModule,
|
|
514
470
|
dset: Dataset,
|
|
515
471
|
batch_size: int,
|
|
516
|
-
|
|
472
|
+
tile_size: Optional[tuple[int, int]] = None,
|
|
517
473
|
mmse_count: int = 1,
|
|
518
474
|
num_workers: int = 4,
|
|
519
|
-
) -> tuple[
|
|
475
|
+
) -> tuple[dict, dict, dict]:
|
|
520
476
|
"""Get patch-wise predictions from a model for the entire dataset.
|
|
521
477
|
|
|
522
478
|
Parameters
|
|
@@ -545,6 +501,55 @@ def get_dset_predictions(
|
|
|
545
501
|
- losses: Reconstruction losses for the predictions.
|
|
546
502
|
- psnr: PSNR values for the predictions.
|
|
547
503
|
"""
|
|
504
|
+
if hasattr(dset, "dsets"):
|
|
505
|
+
multifile_stitched_predictions = {}
|
|
506
|
+
multifile_stitched_stds = {}
|
|
507
|
+
for d in dset.dsets:
|
|
508
|
+
stitched_predictions, stitched_stds = get_single_file_mmse(
|
|
509
|
+
model=model,
|
|
510
|
+
dset=d,
|
|
511
|
+
batch_size=batch_size,
|
|
512
|
+
tile_size=tile_size,
|
|
513
|
+
mmse_count=mmse_count,
|
|
514
|
+
num_workers=num_workers,
|
|
515
|
+
)
|
|
516
|
+
# get filename without extension and path
|
|
517
|
+
filename = str(d._fpath).split("/")[-1].split(".")[0]
|
|
518
|
+
multifile_stitched_predictions[filename] = stitched_predictions
|
|
519
|
+
multifile_stitched_stds[filename] = stitched_stds
|
|
520
|
+
return (
|
|
521
|
+
multifile_stitched_predictions,
|
|
522
|
+
multifile_stitched_stds,
|
|
523
|
+
)
|
|
524
|
+
else:
|
|
525
|
+
stitched_predictions, stitched_stds = get_single_file_mmse(
|
|
526
|
+
model=model,
|
|
527
|
+
dset=dset,
|
|
528
|
+
batch_size=batch_size,
|
|
529
|
+
tile_size=tile_size,
|
|
530
|
+
mmse_count=mmse_count,
|
|
531
|
+
num_workers=num_workers,
|
|
532
|
+
)
|
|
533
|
+
# get filename without extension and path
|
|
534
|
+
filename = str(dset._fpath).split("/")[-1].split(".")[0]
|
|
535
|
+
return (
|
|
536
|
+
{filename: stitched_predictions},
|
|
537
|
+
{filename: stitched_stds},
|
|
538
|
+
)
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def get_single_file_predictions(
|
|
542
|
+
model: VAEModule,
|
|
543
|
+
dset: Dataset,
|
|
544
|
+
batch_size: int,
|
|
545
|
+
tile_size: Optional[tuple[int, int]] = None,
|
|
546
|
+
grid_size: Optional[int] = None,
|
|
547
|
+
num_workers: int = 4,
|
|
548
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
549
|
+
"""Get patch-wise predictions from a model for a single file dataset."""
|
|
550
|
+
if tile_size and grid_size:
|
|
551
|
+
dset.set_img_sz(tile_size, grid_size)
|
|
552
|
+
|
|
548
553
|
dloader = DataLoader(
|
|
549
554
|
dset,
|
|
550
555
|
pin_memory=False,
|
|
@@ -552,43 +557,64 @@ def get_dset_predictions(
|
|
|
552
557
|
shuffle=False,
|
|
553
558
|
batch_size=batch_size,
|
|
554
559
|
)
|
|
560
|
+
model.eval()
|
|
561
|
+
model.cuda()
|
|
562
|
+
tiles = []
|
|
563
|
+
logvar_arr = []
|
|
564
|
+
with torch.no_grad():
|
|
565
|
+
for batch in tqdm(dloader, desc="Predicting tiles"):
|
|
566
|
+
inp, tar = batch
|
|
567
|
+
inp = inp.cuda()
|
|
568
|
+
tar = tar.cuda()
|
|
569
|
+
|
|
570
|
+
# get model output
|
|
571
|
+
rec, _ = model(inp)
|
|
555
572
|
|
|
556
|
-
|
|
557
|
-
|
|
573
|
+
# get reconstructed img
|
|
574
|
+
if model.model.predict_logvar is None:
|
|
575
|
+
rec_img = rec
|
|
576
|
+
logvar = torch.tensor([-1])
|
|
577
|
+
else:
|
|
578
|
+
rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
|
|
579
|
+
logvar_arr.append(logvar.cpu().numpy()) # Why do we need this ?
|
|
580
|
+
|
|
581
|
+
tiles.append(rec_img.cpu().numpy())
|
|
582
|
+
|
|
583
|
+
tile_samples = np.concatenate(tiles, axis=0)
|
|
584
|
+
return stitch_predictions_new(tile_samples, dset)
|
|
558
585
|
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
586
|
+
|
|
587
|
+
def get_single_file_mmse(
|
|
588
|
+
model: VAEModule,
|
|
589
|
+
dset: Dataset,
|
|
590
|
+
batch_size: int,
|
|
591
|
+
tile_size: Optional[tuple[int, int]] = None,
|
|
592
|
+
mmse_count: int = 1,
|
|
593
|
+
num_workers: int = 4,
|
|
594
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
595
|
+
"""Get patch-wise predictions from a model for a single file dataset."""
|
|
596
|
+
dloader = DataLoader(
|
|
597
|
+
dset,
|
|
598
|
+
pin_memory=False,
|
|
599
|
+
num_workers=num_workers,
|
|
600
|
+
shuffle=False,
|
|
601
|
+
batch_size=batch_size,
|
|
602
|
+
)
|
|
603
|
+
if tile_size:
|
|
604
|
+
dset.set_img_sz(tile_size, tile_size[-1] // 2)
|
|
605
|
+
model.eval()
|
|
606
|
+
model.cuda()
|
|
607
|
+
tile_mmse = []
|
|
608
|
+
tile_stds = []
|
|
562
609
|
logvar_arr = []
|
|
563
|
-
num_channels = dset[0][1].shape[0]
|
|
564
|
-
patch_psnr_channels = [RunningPSNR() for _ in range(num_channels)]
|
|
565
610
|
with torch.no_grad():
|
|
566
|
-
for batch in tqdm(dloader, desc="Predicting
|
|
611
|
+
for batch in tqdm(dloader, desc="Predicting tiles"):
|
|
567
612
|
inp, tar = batch
|
|
568
613
|
inp = inp.cuda()
|
|
569
614
|
tar = tar.cuda()
|
|
570
615
|
|
|
571
616
|
rec_img_list = []
|
|
572
|
-
for
|
|
573
|
-
|
|
574
|
-
# TODO: case of HDN left for future refactoring
|
|
575
|
-
# if model_type == ModelType.Denoiser:
|
|
576
|
-
# assert model.denoise_channel in [
|
|
577
|
-
# "Ch1",
|
|
578
|
-
# "Ch2",
|
|
579
|
-
# "input",
|
|
580
|
-
# ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
|
|
581
|
-
|
|
582
|
-
# x_normalized_new, tar_new = model.get_new_input_target(
|
|
583
|
-
# (inp, tar, *batch[2:])
|
|
584
|
-
# )
|
|
585
|
-
# rec, _ = model(x_normalized_new)
|
|
586
|
-
# rec_loss, imgs = model.get_reconstruction_loss(
|
|
587
|
-
# rec,
|
|
588
|
-
# tar,
|
|
589
|
-
# x_normalized_new,
|
|
590
|
-
# return_predicted_img=True,
|
|
591
|
-
# )
|
|
617
|
+
for _ in range(mmse_count):
|
|
592
618
|
|
|
593
619
|
# get model output
|
|
594
620
|
rec, _ = model(inp)
|
|
@@ -600,52 +626,21 @@ def get_dset_predictions(
|
|
|
600
626
|
else:
|
|
601
627
|
rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
|
|
602
628
|
rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim
|
|
603
|
-
logvar_arr.append(logvar.cpu().numpy())
|
|
604
|
-
|
|
605
|
-
# compute reconstruction loss
|
|
606
|
-
# if loss_type == "musplit":
|
|
607
|
-
# rec_loss = get_reconstruction_loss(
|
|
608
|
-
# reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
|
|
609
|
-
# )
|
|
610
|
-
# elif loss_type == "denoisplit":
|
|
611
|
-
# rec_loss = get_reconstruction_loss(
|
|
612
|
-
# reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
|
|
613
|
-
# )
|
|
614
|
-
# elif loss_type == "denoisplit_musplit":
|
|
615
|
-
# rec_loss = reconstruction_loss_musplit_denoisplit(
|
|
616
|
-
# predictions=rec,
|
|
617
|
-
# targets=tar,
|
|
618
|
-
# gaussian_likelihood=gauss_likelihood,
|
|
619
|
-
# nm_likelihood=nm_likelihood,
|
|
620
|
-
# nm_weight=model.loss_parameters.denoisplit_weight,
|
|
621
|
-
# gaussian_weight=model.loss_parameters.musplit_weight,
|
|
622
|
-
# )
|
|
623
|
-
# rec_loss = {"loss": rec_loss} # hacky, but ok for now
|
|
624
|
-
|
|
625
|
-
# # store rec loss values for first pred
|
|
626
|
-
# if mmse_idx == 0:
|
|
627
|
-
# try:
|
|
628
|
-
# losses.append(rec_loss["loss"].cpu().numpy())
|
|
629
|
-
# except:
|
|
630
|
-
# losses.append(rec_loss["loss"])
|
|
631
|
-
|
|
632
|
-
# update running PSNR
|
|
633
|
-
# for i in range(num_channels):
|
|
634
|
-
# patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
|
|
629
|
+
logvar_arr.append(logvar.cpu().numpy()) # Why do we need this ?
|
|
635
630
|
|
|
636
631
|
# aggregate results
|
|
637
632
|
samples = torch.cat(rec_img_list, dim=0)
|
|
638
633
|
mmse_imgs = torch.mean(samples, dim=0) # avg over MMSE dim
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
634
|
+
std_imgs = torch.std(samples, dim=0) # std over MMSE dim
|
|
635
|
+
|
|
636
|
+
tile_mmse.append(mmse_imgs.cpu().numpy())
|
|
637
|
+
tile_stds.append(std_imgs.cpu().numpy())
|
|
638
|
+
|
|
639
|
+
tiles_arr = np.concatenate(tile_mmse, axis=0)
|
|
640
|
+
tile_stds = np.concatenate(tile_stds, axis=0)
|
|
641
|
+
stitched_predictions = stitch_predictions_new(tiles_arr, dset)
|
|
642
|
+
stitched_stds = stitch_predictions_new(tile_stds, dset)
|
|
643
|
+
return stitched_predictions, stitched_stds
|
|
649
644
|
|
|
650
645
|
|
|
651
646
|
# ------------------------------------------------------------------------------------------
|
|
@@ -324,6 +324,8 @@ class NoiseModelLikelihood(LikelihoodModule):
|
|
|
324
324
|
if self.data_mean.device != correct_device_tensor.device:
|
|
325
325
|
self.data_mean = self.data_mean.to(correct_device_tensor.device)
|
|
326
326
|
self.data_std = self.data_std.to(correct_device_tensor.device)
|
|
327
|
+
if correct_device_tensor.device != self.noiseModel.device:
|
|
328
|
+
self.noiseModel.to_device(correct_device_tensor.device)
|
|
327
329
|
|
|
328
330
|
def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
|
|
329
331
|
return x, None
|
careamics/models/lvae/lvae.py
CHANGED
|
@@ -6,7 +6,7 @@ and Artefact Removal, Prakash et al."
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
from collections.abc import Iterable
|
|
9
|
-
from typing import Union
|
|
9
|
+
from typing import Optional, Union
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import torch
|
|
@@ -834,3 +834,15 @@ class LadderVAE(nn.Module):
|
|
|
834
834
|
# TODO check if model_3D_depth is needed ?
|
|
835
835
|
top_layer_shape = (n_imgs, mu_logvar, self._model_3D_depth, h, w)
|
|
836
836
|
return top_layer_shape
|
|
837
|
+
|
|
838
|
+
def reset_for_inference(self, tile_size: Optional[tuple[int, int]] = None):
|
|
839
|
+
"""Should be called if we want to predict for a different input/output size."""
|
|
840
|
+
self.mode_pred = True
|
|
841
|
+
if tile_size is None:
|
|
842
|
+
tile_size = self.image_size
|
|
843
|
+
self.image_size = tile_size
|
|
844
|
+
for i in range(self.n_layers):
|
|
845
|
+
self.bottom_up_layers[i].output_expected_shape = (
|
|
846
|
+
ts // 2 ** (i + 1) for ts in tile_size
|
|
847
|
+
)
|
|
848
|
+
self.top_down_layers[i].latent_shape = tile_size
|