deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,529 @@
|
|
|
1
|
+
import torchvision.utils
|
|
2
|
+
from deepinv.utils import (
|
|
3
|
+
save_model,
|
|
4
|
+
AverageMeter,
|
|
5
|
+
get_timestamp,
|
|
6
|
+
cal_psnr,
|
|
7
|
+
)
|
|
8
|
+
from deepinv.utils import plot, plot_curves, wandb_plot_curves, rescale_img, zeros_like
|
|
9
|
+
import numpy as np
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
import torch
|
|
12
|
+
import wandb
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def train(
    model,
    train_dataloader,
    epochs,
    losses,
    eval_dataloader=None,
    physics=None,
    optimizer=None,
    grad_clip=None,
    scheduler=None,
    device="cpu",
    ckp_interval=1,
    eval_interval=1,
    save_path=".",
    verbose=False,
    unsupervised=False,
    plot_images=False,
    plot_metrics=False,
    wandb_vis=False,
    wandb_setup=None,
    online_measurements=False,
    plot_measurements=True,
    check_grad=False,
    ckpt_pretrained=None,
    fact_losses=None,
    freq_plot=1,
):
    r"""
    Trains a reconstruction network.

    .. note::

        The losses can be chosen from :ref:`the libraries' training losses <loss>`, or can be a custom loss function,
        as long as it takes as input ``(x, x_net, y, physics, model)`` and returns a scalar, where ``x`` is the
        ground-truth image, ``x_net`` is the network reconstruction :math:`\inversef{y, A}`,
        ``y`` is the measurement vector, ``physics`` is the forward operator
        and ``model`` is the reconstruction network. Note that not all inputs need to be used by the loss,
        e.g., self-supervised losses will not make use of ``x``. Each loss must also expose a ``name`` attribute,
        which is used for logging.

    :param torch.nn.Module, deepinv.models.ArtifactRemoval model: Reconstruction network, which can be PnP, unrolled,
        artifact removal or any other custom reconstruction network.
    :param torch.utils.data.DataLoader, list[torch.utils.data.DataLoader] train_dataloader: Train dataloader(s), one
        per forward operator.
    :param int epochs: Number of training epochs.
    :param torch.nn.Module, list of torch.nn.Module losses: Loss or list/tuple of losses used for training the model.
    :param torch.utils.data.DataLoader eval_dataloader: Evaluation dataloader.
    :param deepinv.physics.Physics, list[deepinv.physics.Physics] physics: Forward operator(s)
        used by the reconstruction network at train time.
    :param torch.optim.Optimizer optimizer: Torch optimizer for training the network.
    :param float grad_clip: Gradient clipping value (max gradient norm). If None, no gradient clipping is performed.
    :param torch.optim.lr_scheduler.LRScheduler scheduler: Torch scheduler for changing the learning rate across
        epochs. It is stepped once per epoch.
    :param torch.device device: gpu or cpu.
    :param int ckp_interval: The model is saved every ``ckp_interval`` epochs.
    :param int eval_interval: Number of epochs between each evaluation of the model on the evaluation set.
    :param str save_path: Directory in which to save the trained model (a timestamped sub-directory is used).
    :param bool verbose: Output training progress information in the console.
    :param bool unsupervised: Train an unsupervised network, i.e., uses only measurement vectors y for training.
    :param bool plot_images: Plot reconstructions of the evaluation set whenever an evaluation is performed
        (see ``eval_interval``).
    :param bool plot_metrics: Plot the metrics computed along the iterations of the model during evaluation.
    :param bool wandb_vis: Use Weights & Biases visualization, see https://wandb.ai/ for more details.
    :param dict wandb_setup: Dictionary with the setup for wandb, see https://docs.wandb.ai/quickstart for more
        details. Defaults to an empty setup.
    :param bool online_measurements: Generate the measurements in an online manner at each iteration by calling
        ``physics(x)``. This results in a wider range of measurements if the physics' parameters, such as
        parameters of the forward operator or noise realizations, can change between each sample; these are updated
        with the ``physics.reset()`` method. If ``online_measurements=False``, the measurements are loaded from the
        training dataset.
    :param bool plot_measurements: Plot the measurements y. default=True.
    :param bool check_grad: Compute and record the gradient norm at each iteration.
    :param str ckpt_pretrained: path of the pretrained checkpoint. If None, no pretrained checkpoint is loaded.
    :param list fact_losses: List of factors to multiply the losses. If None, all losses are multiplied by 1.
    :param int freq_plot: Frequency (in epochs) of plotting training images to wandb. If 1, plots at each epoch.
    :returns: Trained model.
    """
    save_path = Path(save_path)
    if wandb_setup is None:
        wandb_setup = {}

    # wandb initialisation
    if wandb_vis and wandb.run is None:
        wandb.init(**wandb_setup)

    # set the different metrics
    meters = []
    total_loss = AverageMeter("loss", ":.2e")
    meters.append(total_loss)
    # Accept a single loss or a list/tuple of losses. The previous check wrapped
    # a tuple of losses into a one-element list.
    if isinstance(losses, (list, tuple)):
        losses = list(losses)
    else:
        losses = [losses]
    if fact_losses is None:
        fact_losses = [1] * len(losses)
    losses_verbose = [AverageMeter("Loss_" + l.name, ":.2e") for l in losses]
    meters.extend(losses_verbose)
    train_psnr = AverageMeter("Train_psnr_model", ":.2f")
    meters.append(train_psnr)
    if eval_dataloader:
        eval_psnr = AverageMeter("Eval_psnr_model", ":.2f")
        meters.append(eval_psnr)
    if check_grad:
        check_grad_val = AverageMeter("Gradient norm", ":.2e")
        meters.append(check_grad_val)

    save_path = f"{save_path}/{get_timestamp()}"

    # count the overall training parameters
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"The model has {params} trainable parameters")

    # make physics and data_loaders of list type
    if type(physics) is not list:
        physics = [physics]
    if type(train_dataloader) is not list:
        train_dataloader = [train_dataloader]
    if eval_dataloader and type(eval_dataloader) is not list:
        eval_dataloader = [eval_dataloader]

    G = len(train_dataloader)

    loss_history = []
    log_dict = {}

    epoch_start = 0
    if ckpt_pretrained is not None:
        # map_location lets a checkpoint saved on another device be resumed here
        checkpoint = torch.load(ckpt_pretrained, map_location=device)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_start = checkpoint["epoch"]

    for epoch in range(epoch_start, epochs):
        # Reset the metrics at the start of each epoch, BEFORE evaluation, so that
        # the eval meter still holds this epoch's value when the model is saved.
        for meter in meters:
            meter.reset()

        ### Evaluation
        if wandb_vis:
            wandb_log_dict_epoch = {"epoch": epoch}

        # perform evaluation every eval_interval epoch
        perform_eval = bool(
            (not unsupervised)
            and eval_dataloader
            and ((epoch + 1) % eval_interval == 0 or epoch + 1 == epochs)
        )

        if perform_eval:
            test_psnr, _, _, _ = test(
                model,
                eval_dataloader,
                physics,
                device,
                verbose=False,
                plot_images=plot_images,
                plot_metrics=plot_metrics,
                wandb_vis=wandb_vis,
                wandb_setup=wandb_setup,
                step=epoch,
                online_measurements=online_measurements,
            )
            eval_psnr.update(test_psnr)
            log_dict["eval_psnr"] = test_psnr
            if wandb_vis:
                wandb_log_dict_epoch["eval_psnr"] = test_psnr

        # wandb logging
        if wandb_vis:
            last_lr = None if scheduler is None else scheduler.get_last_lr()[0]
            wandb_log_dict_epoch["learning rate"] = last_lr
            wandb.log(wandb_log_dict_epoch)

        ### Training
        model.train()

        iterators = [iter(loader) for loader in train_dataloader]
        # Each step draws one batch from every dataloader, so the number of steps
        # is bounded by the shortest one (using only the last loader's length
        # raised StopIteration when another loader was shorter).
        batches = min(len(loader) for loader in train_dataloader)

        # x stays None if no batch is processed; guards the post-epoch plotting
        x = None

        for step_idx in (progress_bar := tqdm(range(batches), disable=not verbose)):
            progress_bar.set_description(f"Epoch {epoch + 1}")

            # Always defined (even without wandb) so that check_grad can record into it.
            wandb_log_dict_iter = {}

            # random permutation of the dataloaders
            G_perm = np.random.permutation(G)

            for g in G_perm:  # for each dataloader
                physics_cur = physics[g]

                if online_measurements:  # the measurements y are created on-the-fly
                    # In this case the dataloader outputs also a class label
                    x, _ = next(iterators[g])
                    x = x.to(device)

                    if isinstance(physics_cur, torch.nn.DataParallel):
                        physics_cur.module.noise_model.__init__()
                    else:
                        physics_cur.reset()

                    y = physics_cur(x)

                else:  # the measurements y were pre-computed
                    if unsupervised:
                        y = next(iterators[g])
                        x = None
                    else:
                        x, y = next(iterators[g])
                        if type(x) is list or type(x) is tuple:
                            x = [s.to(device) for s in x]
                        else:
                            x = x.to(device)

                    y = y.to(device)

                optimizer.zero_grad()

                # run the forward model
                x_net = model(y, physics_cur)

                # compute the losses
                loss_total = 0
                for k, l in enumerate(losses):
                    loss = l(x=x, x_net=x_net, y=y, physics=physics_cur, model=model)
                    loss_total += fact_losses[k] * loss
                    losses_verbose[k].update(loss.item())
                    if len(losses) > 1:
                        log_dict["loss_" + l.name] = losses_verbose[k].avg
                        if wandb_vis:
                            wandb_log_dict_iter["loss_" + l.name] = loss.item()
                if wandb_vis:
                    wandb_log_dict_iter["training loss"] = loss_total.item()
                total_loss.update(loss_total.item())
                log_dict["total_loss"] = total_loss.avg

                # backward the total loss
                loss_total.backward()

                # gradient clipping
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

                if check_grad:
                    # from https://discuss.pytorch.org/t/check-the-norm-of-gradients/27961/7
                    grads = [
                        param.grad.detach().flatten()
                        for param in model.parameters()
                        if param.grad is not None
                    ]
                    if grads:  # torch.cat fails on an empty list
                        norm_grads = torch.cat(grads).norm()
                        wandb_log_dict_iter["gradient norm"] = norm_grads.item()
                        check_grad_val.update(norm_grads.item())

                # optimize step
                optimizer.step()

                # training psnr and logging
                if not unsupervised:
                    with torch.no_grad():
                        psnr = cal_psnr(x_net, x)
                    train_psnr.update(psnr)
                    log_dict["train_psnr"] = train_psnr.avg
                    wandb_log_dict_iter["train_psnr"] = psnr

                # Log the iteration in every mode (previously only supervised runs logged).
                if wandb_vis:
                    wandb.log(wandb_log_dict_iter)

                progress_bar.set_postfix(log_dict)

        # wandb logging of average training metrics and training images
        if wandb_vis:
            log_dict_post_epoch = {
                "mean training loss": total_loss.avg,
                "mean training psnr": train_psnr.avg,
            }
            if check_grad:
                log_dict_post_epoch["mean gradient norm"] = check_grad_val.avg

            # Images of the last training batch; skipped when no tensor ground truth
            # is available (e.g. unsupervised training, where x is None).
            if epoch % freq_plot == 0 and isinstance(x, torch.Tensor):
                # DataParallel does not expose the wrapped physics' methods.
                phys = (
                    physics_cur.module
                    if isinstance(physics_cur, torch.nn.DataParallel)
                    else physics_cur
                )
                with torch.no_grad():
                    if plot_measurements and y.shape != x.shape:
                        y_reshaped = torch.nn.functional.interpolate(
                            y, size=x.shape[2]
                        )
                        if hasattr(phys, "A_adjoint"):
                            imgs = [y_reshaped, phys.A_adjoint(y), x_net, x]
                            caption = (
                                "From top to bottom: input, backprojection, output, target"
                            )
                        else:
                            imgs = [y_reshaped, x_net, x]
                            caption = "From top to bottom: input, output, target"
                    else:
                        if hasattr(phys, "A_adjoint"):
                            imgs = [phys.A_adjoint(y), x_net, x]
                            caption = "From top to bottom: backprojection, output, target"
                        else:
                            imgs = [x_net, x]
                            caption = "From top to bottom: output, target"

                    vis_array = torch.cat(imgs, dim=0)
                    for j in range(len(vis_array)):
                        vis_array[j] = rescale_img(vis_array[j], rescale_mode="min_max")
                    grid_image = torchvision.utils.make_grid(vis_array, nrow=y.shape[0])
                log_dict_post_epoch["Training samples"] = wandb.Image(
                    grid_image,
                    caption=caption,
                )

            wandb.log(log_dict_post_epoch)

        loss_history.append(total_loss.avg)

        if scheduler is not None:
            scheduler.step()

        # Saving the model
        save_model(
            epoch,
            model,
            optimizer,
            ckp_interval,
            epochs,
            loss_history,
            str(save_path),
            eval_psnr=eval_psnr if perform_eval else None,
        )

    if wandb_vis:
        wandb.save("model.h5")

    return model
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def test(
    model,
    test_dataloader,
    physics,
    device="cpu",
    plot_images=False,
    save_folder="results",
    plot_metrics=False,
    verbose=True,
    plot_only_first_batch=True,
    wandb_vis=False,
    wandb_setup=None,
    step=0,
    online_measurements=False,
    plot_measurements=True,
    **kwargs,
):
    r"""
    Tests a reconstruction network.

    This function computes the PSNR of the reconstruction network on the test set,
    and optionally plots the reconstructions as well as the metrics computed along the iterations.
    Note that by default only the first batch is plotted, and images are only plotted for the
    first 5 operators.

    :param torch.nn.Module, deepinv.models.ArtifactRemoval model: Reconstruction network, which can be PnP, unrolled,
        artifact removal or any other custom reconstruction network.
    :param torch.utils.data.DataLoader test_dataloader: Test data loader, which should provide a tuple of (x, y) pairs.
        See :ref:`datasets <datasets>` for more details.
    :param deepinv.physics.Physics, list[deepinv.physics.Physics] physics: Forward operator(s)
        used by the reconstruction network at test time.
    :param torch.device device: gpu or cpu.
    :param bool plot_images: Plot the ground-truth and estimated images.
    :param str save_folder: Directory in which to save plotted reconstructions.
    :param bool plot_metrics: plot the metrics to be plotted w.r.t iteration.
    :param bool verbose: Output progress information in the console.
    :param bool plot_only_first_batch: Plot only the first batch of the test set.
    :param bool wandb_vis: Use Weights & Biases visualization, see https://wandb.ai/ for more details.
    :param dict wandb_setup: Dictionary with the setup for wandb, see https://docs.wandb.ai/quickstart for more
        details. Defaults to an empty setup.
    :param int step: Step number for wandb visualization.
    :param bool online_measurements: Generate the measurements in an online manner at each iteration by calling
        ``physics(x)``.
    :param bool plot_measurements: Plot the measurements y. default=True.
    :returns: A tuple of floats ``(test_psnr, test_std_psnr, linear_psnr, linear_std_psnr)`` with the mean and
        standard deviation of the PSNR of the reconstruction network, and of a simple linear inverse
        (adjoint, pseudo-inverse, or zeros when neither is available), on the test set.
    """
    save_folder = Path(save_folder)
    if wandb_setup is None:
        wandb_setup = {}

    psnr_init = []
    psnr_net = []

    model.eval()

    if type(physics) is not list:
        physics = [physics]

    if type(test_dataloader) is not list:
        test_dataloader = [test_dataloader]

    G = len(test_dataloader)

    # images are only plotted / logged for the first `show_operators` operators
    show_operators = 5

    if wandb_vis and wandb.run is None:
        wandb.init(**wandb_setup)

    for g in range(G):
        dataloader = test_dataloader[g]
        physics_cur = physics[g]
        # DataParallel does not expose the wrapped physics' methods, so the
        # A_adjoint / A_dagger lookups must be done on the wrapped module.
        phys = (
            physics_cur.module
            if isinstance(physics_cur, torch.nn.DataParallel)
            else physics_cur
        )

        # output folders only depend on the operator index: create them once
        group_folder = (save_folder / ("G" + str(g))) if G > 1 else save_folder
        save_folder_im = None
        if plot_images:
            save_folder_im = group_folder / "images"
            save_folder_im.mkdir(parents=True, exist_ok=True)
        if plot_metrics:
            save_folder_curve = group_folder / "curves"
            save_folder_curve.mkdir(parents=True, exist_ok=True)

        if verbose:
            print(f"Processing data of operator {g+1} out of {G}")
        for i, batch in enumerate(tqdm(dataloader, disable=not verbose)):
            with torch.no_grad():
                if online_measurements:
                    # In this case the dataloader outputs also a class label
                    x, _ = batch
                    x = x.to(device)
                    if isinstance(physics_cur, torch.nn.DataParallel):
                        physics_cur.module.noise_model.__init__()
                    else:
                        physics_cur.reset()
                    y = physics_cur(x)
                else:
                    x, y = batch
                    if type(x) is list or type(x) is tuple:
                        x = [s.to(device) for s in x]
                    else:
                        x = x.to(device)

                    y = y.to(device)

                if plot_metrics:
                    x1, metrics = model(y, physics_cur, x_gt=x, compute_metrics=True)
                else:
                    x1 = model(y, physics_cur)

                # "no learning" baseline
                if hasattr(phys, "A_adjoint"):
                    x_init = phys.A_adjoint(y)
                elif hasattr(phys, "A_dagger"):
                    x_init = phys.A_dagger(y)
                else:
                    x_init = zeros_like(x)

                cur_psnr_init = cal_psnr(x_init, x)
                cur_psnr = cal_psnr(x1, x)
                psnr_init.append(cur_psnr_init)
                psnr_net.append(cur_psnr)

                should_plot = g < show_operators and (
                    not plot_only_first_batch or i == 0
                )
                if (plot_images or wandb_vis) and should_plot:
                    if plot_measurements and len(y.shape) == 4:
                        imgs = [y, x_init, x1, x]
                        name_imgs = ["Input", "No learning", "Recons.", "GT"]
                    else:
                        imgs = [x_init, x1, x]
                        name_imgs = ["No learning", "Recons.", "GT"]
                    fig = plot(
                        imgs,
                        titles=name_imgs,
                        save_dir=save_folder_im if plot_images else None,
                        show=plot_images,
                        return_fig=True,
                    )
                    if wandb_vis:
                        wandb.log(
                            {f"Test images batch_{i} (G={g}) ": wandb.Image(fig)}
                        )

                if plot_metrics:
                    plot_curves(metrics, save_dir=save_folder_curve, show=True)
                    if wandb_vis:
                        wandb_plot_curves(metrics, batch_idx=i, step=step)

    test_psnr = np.mean(psnr_net)
    test_std_psnr = np.std(psnr_net)
    linear_psnr = np.mean(psnr_init)
    linear_std_psnr = np.std(psnr_init)
    if verbose:
        print(
            f"Test PSNR: No learning rec.: {linear_psnr:.2f}+-{linear_std_psnr:.2f} dB | Model: {test_psnr:.2f}+-{test_std_psnr:.2f} dB. "
        )
    if wandb_vis:
        wandb.log({"Test PSNR": test_psnr}, step=step)

    return test_psnr, test_std_psnr, linear_psnr, linear_std_psnr
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torchvision.transforms.functional import rotate
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Rotate(torch.nn.Module):
    r"""
    2D Rotations.

    Generates ``n_trans`` randomly rotated versions of 2D images with zero padding.
    The rotated copies are concatenated along the first (batch) dimension.

    The angles (in degrees) are drawn without replacement and never include the identity:

    - if ``degrees == 360``, they are drawn from the integers :math:`\{1, \dots, 359\}`;
    - otherwise, they are drawn from the first ``degrees`` non-zero multiples of
      ``int(360 / (degrees + 1))``. For example, ``degrees=3`` gives rotations by 90, 180 and 270
      degrees, i.e. the non-trivial elements of the cyclic group of order 4.

    :param int n_trans: number of rotated versions generated per input image. At most 359
        (``degrees == 360``) or ``degrees`` (otherwise) versions are returned.
    :param int degrees: selects the set of candidate angles, as described above. Note that it is
        not a maximum rotation angle.
    """

    def __init__(self, n_trans=1, degrees=360):
        super().__init__()
        self.n_trans, self.group_size = n_trans, degrees

    def forward(self, data):
        if self.group_size == 360:
            # every non-zero integer angle
            candidates = np.arange(1, 360)
            order = torch.randperm(359)
        else:
            # non-zero multiples of a fixed angular step
            step = int(360 / (self.group_size + 1))
            candidates = np.arange(0, 360, step)[1:]
            order = torch.randperm(self.group_size)
        angles = candidates[order.numpy()][: self.n_trans]
        return torch.cat([rotate(data, float(angle)) for angle in angles])
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# if __name__ == "__main__":
|
|
31
|
+
# device = "cuda:0"
|
|
32
|
+
#
|
|
33
|
+
# x = torch.zeros(1, 1, 64, 64, device=device)
|
|
34
|
+
# x[:, :, 16:48, 16:48] = 1
|
|
35
|
+
#
|
|
36
|
+
# t = Rotate(4)
|
|
37
|
+
# y = t(x)
|
|
38
|
+
#
|
|
39
|
+
# from deepinv.utils import plot
|
|
40
|
+
#
|
|
41
|
+
# plot([x, y[0, :, :, :].unsqueeze(0), y[1, :, :, :].unsqueeze(0)])
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Shift(torch.nn.Module):
    r"""
    Fast integer 2D translations.

    Generates ``n_trans`` randomly shifted versions of 2D images with circular padding.

    The offsets along the second-to-last and last dimensions are drawn independently and
    without replacement from ``[-H, H)`` and ``[-W, W)``, where ``(H, W)`` are the sizes of
    those two dimensions. The shifted copies are concatenated along the first (batch) dimension.

    :param int n_trans: number of shifted versions generated per input image; must not exceed
        ``min(H, W) - 1``.
    """

    def __init__(self, n_trans=1):
        super().__init__()
        self.n_trans = n_trans

    def forward(self, data):
        height, width = data.shape[-2:]
        assert self.n_trans <= min(height, width) - 1
        # Draw row offsets first, then column offsets.
        row_shifts = torch.arange(-height, height)[torch.randperm(2 * height)]
        col_shifts = torch.arange(-width, width)[torch.randperm(2 * width)]
        keep = self.n_trans
        shifted = [
            torch.roll(data, [dr, dc], [-2, -1])
            for dr, dc in zip(row_shifts[:keep], col_shifts[:keep])
        ]
        return torch.cat(shifted, dim=0)
|