deepinv 0.1.0.dev0 (deepinv-0.1.0.dev0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/deepinv/training_utils.py
@@ -0,0 +1,529 @@
+import torchvision.utils
+from deepinv.utils import (
+    save_model,
+    AverageMeter,
+    get_timestamp,
+    cal_psnr,
+)
+from deepinv.utils import plot, plot_curves, wandb_plot_curves, rescale_img, zeros_like
+import numpy as np
+from tqdm import tqdm
+import torch
+import wandb
+from pathlib import Path
+
+
+def train(
+    model,
+    train_dataloader,
+    epochs,
+    losses,
+    eval_dataloader=None,
+    physics=None,
+    optimizer=None,
+    grad_clip=None,
+    scheduler=None,
+    device="cpu",
+    ckp_interval=1,
+    eval_interval=1,
+    save_path=".",
+    verbose=False,
+    unsupervised=False,
+    plot_images=False,
+    plot_metrics=False,
+    wandb_vis=False,
+    wandb_setup={},
+    online_measurements=False,
+    plot_measurements=True,
+    check_grad=False,
+    ckpt_pretrained=None,
+    fact_losses=None,
+    freq_plot=1,
+):
+    r"""
+    Trains a reconstruction network.
+
+    .. note::
+
+        The losses can be chosen from :ref:`the library's training losses <loss>`, or can be a custom loss function,
+        as long as it takes as input ``(x, x_net, y, physics, model)`` and returns a scalar, where ``x`` is the
+        ground-truth image, ``x_net`` is the network reconstruction :math:`\inversef{y, A}`,
+        ``y`` is the measurement vector, ``physics`` is the forward operator and ``model`` is the reconstruction
+        network. Note that not all inputs need to be used by the loss, e.g., self-supervised losses will not make
+        use of ``x``.
+
+    :param torch.nn.Module, deepinv.models.ArtifactRemoval model: Reconstruction network, which can be PnP,
+        unrolled, artifact removal or any other custom reconstruction network.
+    :param torch.utils.data.DataLoader train_dataloader: Train dataloader.
+    :param int epochs: Number of training epochs.
+    :param torch.nn.Module, list of torch.nn.Module losses: Loss or list of losses used for training the model.
+    :param torch.utils.data.DataLoader eval_dataloader: Evaluation dataloader.
+    :param deepinv.physics.Physics, list[deepinv.physics.Physics] physics: Forward operator(s)
+        used by the reconstruction network at train time.
+    :param torch.optim.Optimizer optimizer: Torch optimizer for training the network.
+    :param float grad_clip: Gradient clipping value for the optimizer. If None, no gradient clipping is performed.
+    :param torch.optim.lr_scheduler scheduler: Torch scheduler for changing the learning rate across iterations.
+    :param torch.device device: gpu or cpu.
+    :param int ckp_interval: The model is saved every ``ckp_interval`` epochs.
+    :param int eval_interval: Number of epochs between each evaluation of the model on the evaluation set.
+    :param str save_path: Directory in which to save the trained model.
+    :param bool verbose: Output training progress information in the console.
+    :param bool unsupervised: Train an unsupervised network, i.e., use only the measurement vectors ``y`` for
+        training.
+    :param bool plot_images: Plot reconstructions every ``ckp_interval`` epochs.
+    :param bool plot_metrics: Plot (and save) the metrics computed along the iterations during evaluation.
+    :param bool wandb_vis: Use Weights & Biases visualization, see https://wandb.ai/ for more details.
+    :param dict wandb_setup: Dictionary with the setup for wandb, see https://docs.wandb.ai/quickstart for more
+        details.
+    :param bool online_measurements: Generate the measurements in an online manner at each iteration by calling
+        ``physics(x)``. This results in a wider range of measurements if the physics' parameters, such as
+        parameters of the forward operator or noise realizations, can change between each sample; these are updated
+        with the ``physics.reset()`` method. If ``online_measurements=False``, the measurements are loaded from the
+        training dataset.
+    :param bool plot_measurements: Plot the measurements ``y``. Default: True.
+    :param bool check_grad: Check the gradient norm at each iteration.
+    :param str ckpt_pretrained: Path of the pretrained checkpoint. If None, no pretrained checkpoint is loaded.
+    :param list fact_losses: List of factors multiplying each loss. If None, all losses are multiplied by 1.
+    :param int freq_plot: Frequency of plotting images to wandb. If 1, plots at each epoch.
+    :returns: Trained model.
+    """
+    save_path = Path(save_path)
+
+    # wandb initialization
+    if wandb_vis:
+        if wandb.run is None:
+            wandb.init(**wandb_setup)
+
+    # set up the different metrics
+    meters = []
+    total_loss = AverageMeter("loss", ":.2e")
+    meters.append(total_loss)
+    if not isinstance(losses, (list, tuple)):
+        losses = [losses]
+    if fact_losses is None:
+        fact_losses = [1] * len(losses)
+    losses_verbose = [AverageMeter("Loss_" + l.name, ":.2e") for l in losses]
+    for loss in losses_verbose:
+        meters.append(loss)
+    train_psnr = AverageMeter("Train_psnr_model", ":.2f")
+    meters.append(train_psnr)
+    if eval_dataloader:
+        eval_psnr = AverageMeter("Eval_psnr_model", ":.2f")
+        meters.append(eval_psnr)
+    if check_grad:
+        check_grad_val = AverageMeter("Gradient norm", ":.2e")
+        meters.append(check_grad_val)
+
+    save_path = f"{save_path}/{get_timestamp()}"
+
+    # count the overall trainable parameters
+    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(f"The model has {params} trainable parameters")
+
+    # make physics and dataloaders of list type
+    if type(physics) is not list:
+        physics = [physics]
+    if type(train_dataloader) is not list:
+        train_dataloader = [train_dataloader]
+    if eval_dataloader and type(eval_dataloader) is not list:
+        eval_dataloader = [eval_dataloader]
+
+    G = len(train_dataloader)
+
+    loss_history = []
+
+    log_dict = {}
+
+    epoch_start = 0
+    if ckpt_pretrained is not None:
+        checkpoint = torch.load(ckpt_pretrained)
+        model.load_state_dict(checkpoint["state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        epoch_start = checkpoint["epoch"]
+
+    for epoch in range(epoch_start, epochs):
+        ### Evaluation
+
+        if wandb_vis:
+            wandb_log_dict_epoch = {"epoch": epoch}
+
+        # perform evaluation every eval_interval epochs
+        perform_eval = (
+            (not unsupervised)
+            and eval_dataloader
+            and ((epoch + 1) % eval_interval == 0 or epoch + 1 == epochs)
+        )
+
+        if perform_eval:
+            test_psnr, _, _, _ = test(
+                model,
+                eval_dataloader,
+                physics,
+                device,
+                verbose=False,
+                plot_images=plot_images,
+                plot_metrics=plot_metrics,
+                wandb_vis=wandb_vis,
+                wandb_setup=wandb_setup,
+                step=epoch,
+                online_measurements=online_measurements,
+            )
+            eval_psnr.update(test_psnr)
+            log_dict["eval_psnr"] = test_psnr
+            if wandb_vis:
+                wandb_log_dict_epoch["eval_psnr"] = test_psnr
+
+        # wandb logging
+        if wandb_vis:
+            last_lr = None if scheduler is None else scheduler.get_last_lr()[0]
+            wandb_log_dict_epoch["learning rate"] = last_lr
+            wandb.log(wandb_log_dict_epoch)
+
+        ### Training
+
+        model.train()
+
+        for meter in meters:
+            meter.reset()  # reset the metrics at each epoch
+
+        iterators = [iter(loader) for loader in train_dataloader]
+        batches = len(train_dataloader[G - 1])
+
+        for i in (progress_bar := tqdm(range(batches), disable=not verbose)):
+            progress_bar.set_description(f"Epoch {epoch + 1}")
+
+            if wandb_vis:
+                wandb_log_dict_iter = {}
+
+            # random permutation of the dataloaders
+            G_perm = np.random.permutation(G)
+
+            for g in G_perm:  # for each dataloader
+                if online_measurements:  # the measurements y are created on the fly
+                    x, _ = next(
+                        iterators[g]
+                    )  # in this case the dataloader also outputs a class label
+                    x = x.to(device)
+                    physics_cur = physics[g]
+
+                    if isinstance(physics_cur, torch.nn.DataParallel):
+                        physics_cur.module.noise_model.__init__()
+                    else:
+                        physics_cur.reset()
+
+                    y = physics_cur(x)
+
+                else:  # the measurements y were pre-computed
+                    if unsupervised:
+                        y = next(iterators[g])
+                        x = None
+                    else:
+                        x, y = next(iterators[g])
+                        if type(x) is list or type(x) is tuple:
+                            x = [s.to(device) for s in x]
+                        else:
+                            x = x.to(device)
+
+                    physics_cur = physics[g]
+
+                    y = y.to(device)
+
+                optimizer.zero_grad()
+
+                # run the forward model
+                x_net = model(y, physics_cur)
+
+                # compute the losses
+                loss_total = 0
+                for k, l in enumerate(losses):
+                    loss = l(x=x, x_net=x_net, y=y, physics=physics[g], model=model)
+                    loss_total += fact_losses[k] * loss
+                    losses_verbose[k].update(loss.item())
+                    if len(losses) > 1:
+                        log_dict["loss_" + l.name] = losses_verbose[k].avg
+                        if wandb_vis:
+                            wandb_log_dict_iter["loss_" + l.name] = loss.item()
+                if wandb_vis:
+                    wandb_log_dict_iter["training loss"] = loss_total.item()
+                total_loss.update(loss_total.item())
+                log_dict["total_loss"] = total_loss.avg
+
+                # backpropagate the total loss
+                loss_total.backward()
+
+                # gradient clipping
+                if grad_clip is not None:
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+
+                if check_grad:
+                    # from https://discuss.pytorch.org/t/check-the-norm-of-gradients/27961/7
+                    grads = [
+                        param.grad.detach().flatten()
+                        for param in model.parameters()
+                        if param.grad is not None
+                    ]
+                    norm_grads = torch.cat(grads).norm()
+                    if wandb_vis:
+                        wandb_log_dict_iter["gradient norm"] = norm_grads.item()
+                    check_grad_val.update(norm_grads.item())
+
+                # optimizer step
+                optimizer.step()
+
+                # training psnr and logging
+                if not unsupervised:
+                    with torch.no_grad():
+                        psnr = cal_psnr(x_net, x)
+                        train_psnr.update(psnr)
+                        if wandb_vis:
+                            wandb_log_dict_iter["train_psnr"] = psnr
+                            wandb.log(wandb_log_dict_iter)
+                    log_dict["train_psnr"] = train_psnr.avg
+
+            progress_bar.set_postfix(log_dict)
+
+        # wandb plotting of training images
+        if wandb_vis:
+            # log average training metrics
+            log_dict_post_epoch = {}
+            log_dict_post_epoch["mean training loss"] = total_loss.avg
+            log_dict_post_epoch["mean training psnr"] = train_psnr.avg
+            if check_grad:
+                log_dict_post_epoch["mean gradient norm"] = check_grad_val.avg
+
+            with torch.no_grad():
+                if plot_measurements and y.shape != x.shape:
+                    y_reshaped = torch.nn.functional.interpolate(y, size=x.shape[2])
+                    if hasattr(physics_cur, "A_adjoint"):
+                        imgs = [y_reshaped, physics_cur.A_adjoint(y), x_net, x]
+                        caption = (
+                            "From top to bottom: input, backprojection, output, target"
+                        )
+                    else:
+                        imgs = [y_reshaped, x_net, x]
+                        caption = "From top to bottom: input, output, target"
+                else:
+                    if hasattr(physics_cur, "A_adjoint"):
+                        if isinstance(physics_cur, torch.nn.DataParallel):
+                            back = physics_cur.module.A_adjoint(y)
+                        else:
+                            back = physics_cur.A_adjoint(y)
+                        imgs = [back, x_net, x]
+                        caption = "From top to bottom: backprojection, output, target"
+                    else:
+                        imgs = [x_net, x]
+                        caption = "From top to bottom: output, target"
+
+                vis_array = torch.cat(imgs, dim=0)
+                for j in range(len(vis_array)):
+                    vis_array[j] = rescale_img(vis_array[j], rescale_mode="min_max")
+                grid_image = torchvision.utils.make_grid(vis_array, nrow=y.shape[0])
+                if epoch % freq_plot == 0:
+                    images = wandb.Image(
+                        grid_image,
+                        caption=caption,
+                    )
+                    log_dict_post_epoch["Training samples"] = images
+
+            wandb.log(log_dict_post_epoch)
+
+        loss_history.append(total_loss.avg)
+
+        if scheduler:
+            scheduler.step()
+
+        # save the model every ckp_interval epochs
+        save_model(
+            epoch,
+            model,
+            optimizer,
+            ckp_interval,
+            epochs,
+            loss_history,
+            str(save_path),
+            eval_psnr=eval_psnr if perform_eval else None,
+        )
+
+    if wandb_vis:
+        wandb.save("model.h5")
+
+    return model
+
+
+def test(
+    model,
+    test_dataloader,
+    physics,
+    device="cpu",
+    plot_images=False,
+    save_folder="results",
+    plot_metrics=False,
+    verbose=True,
+    plot_only_first_batch=True,
+    wandb_vis=False,
+    wandb_setup={},
+    step=0,
+    online_measurements=False,
+    plot_measurements=True,
+    **kwargs,
+):
+    r"""
+    Tests a reconstruction network.
+
+    This function computes the PSNR of the reconstruction network on the test set,
+    and optionally plots the reconstructions as well as the metrics computed along the iterations.
+    Note that by default only the first batch is plotted.
+
+    :param torch.nn.Module, deepinv.models.ArtifactRemoval model: Reconstruction network, which can be PnP,
+        unrolled, artifact removal or any other custom reconstruction network.
+    :param torch.utils.data.DataLoader test_dataloader: Test dataloader, which should provide tuples of
+        ``(x, y)`` pairs. See :ref:`datasets <datasets>` for more details.
+    :param deepinv.physics.Physics, list[deepinv.physics.Physics] physics: Forward operator(s)
+        used by the reconstruction network at test time.
+    :param torch.device device: gpu or cpu.
+    :param bool plot_images: Plot the ground-truth and estimated images.
+    :param str save_folder: Directory in which to save plotted reconstructions.
+    :param bool plot_metrics: Plot the metrics computed along the iterations.
+    :param bool verbose: Output testing progress information in the console.
+    :param bool plot_only_first_batch: Plot only the first batch of the test set.
+    :param bool wandb_vis: Use Weights & Biases visualization, see https://wandb.ai/ for more details.
+    :param dict wandb_setup: Dictionary with the setup for wandb, see https://docs.wandb.ai/quickstart for more
+        details.
+    :param int step: Step number for wandb visualization.
+    :param bool online_measurements: Generate the measurements in an online manner at each iteration by calling
+        ``physics(x)``.
+    :param bool plot_measurements: Plot the measurements ``y``. Default: True.
+    :returns: A tuple of floats ``(test_psnr, test_std_psnr, linear_psnr, linear_std_psnr)`` with the PSNR of the
+        reconstruction network and of a simple linear inverse on the test set.
+    """
+    save_folder = Path(save_folder)
+
+    psnr_init = []
+    psnr_net = []
+
+    model.eval()
+
+    if type(physics) is not list:
+        physics = [physics]
+
+    if type(test_dataloader) is not list:
+        test_dataloader = [test_dataloader]
+
+    G = len(test_dataloader)
+
+    show_operators = 5
+
+    if wandb_vis:
+        if wandb.run is None:
+            wandb.init(**wandb_setup)
+        psnr_data = []
+
+    for g in range(G):
+        dataloader = test_dataloader[g]
+        if verbose:
+            print(f"Processing data of operator {g + 1} out of {G}")
+        for i, batch in enumerate(tqdm(dataloader, disable=not verbose)):
+            with torch.no_grad():
+                if online_measurements:
+                    x, _ = batch  # in this case the dataloader also outputs a class label
+                    x = x.to(device)
+                    physics_cur = physics[g]
+                    if isinstance(physics_cur, torch.nn.DataParallel):
+                        physics_cur.module.noise_model.__init__()
+                    else:
+                        physics_cur.reset()
+                    y = physics_cur(x)
+                else:
+                    x, y = batch
+                    if type(x) is list or type(x) is tuple:
+                        x = [s.to(device) for s in x]
+                    else:
+                        x = x.to(device)
+                    physics_cur = physics[g]
+
+                    y = y.to(device)
+
+                if plot_metrics:
+                    x1, metrics = model(y, physics_cur, x_gt=x, compute_metrics=True)
+                else:
+                    x1 = model(y, physics[g])
+
+                # linear reconstruction used as a no-learning baseline
+                if hasattr(physics_cur, "A_adjoint"):
+                    if isinstance(physics_cur, torch.nn.DataParallel):
+                        x_init = physics_cur.module.A_adjoint(y)
+                    else:
+                        x_init = physics_cur.A_adjoint(y)
+                elif hasattr(physics_cur, "A_dagger"):
+                    if isinstance(physics_cur, torch.nn.DataParallel):
+                        x_init = physics_cur.module.A_dagger(y)
+                    else:
+                        x_init = physics_cur.A_dagger(y)
+                else:
+                    x_init = zeros_like(x)
+
+                cur_psnr_init = cal_psnr(x_init, x)
+                cur_psnr = cal_psnr(x1, x)
+                psnr_init.append(cur_psnr_init)
+                psnr_net.append(cur_psnr)
+
+                if wandb_vis:
+                    psnr_data.append([g, i, cur_psnr_init, cur_psnr])
+
+                if plot_images:
+                    save_folder_im = (
+                        (save_folder / ("G" + str(g))) if G > 1 else save_folder
+                    ) / "images"
+                    save_folder_im.mkdir(parents=True, exist_ok=True)
+                else:
+                    save_folder_im = None
+                if plot_metrics:
+                    save_folder_curve = (
+                        (save_folder / ("G" + str(g))) if G > 1 else save_folder
+                    ) / "curves"
+                    save_folder_curve.mkdir(parents=True, exist_ok=True)
+
+                if plot_images or wandb_vis:
+                    if g < show_operators:
+                        if not plot_only_first_batch or i == 0:
+                            if plot_measurements and len(y.shape) == 4:
+                                imgs = [y, x_init, x1, x]
+                                name_imgs = ["Input", "No learning", "Recons.", "GT"]
+                            else:
+                                imgs = [x_init, x1, x]
+                                name_imgs = ["No learning", "Recons.", "GT"]
+                            fig = plot(
+                                imgs,
+                                titles=name_imgs,
+                                save_dir=save_folder_im if plot_images else None,
+                                show=plot_images,
+                                return_fig=True,
+                            )
+                            if wandb_vis:
+                                wandb.log(
+                                    {f"Test images batch_{i} (G={g})": wandb.Image(fig)}
+                                )
+
+                if plot_metrics:
+                    plot_curves(metrics, save_dir=save_folder_curve, show=True)
+                    if wandb_vis:
+                        wandb_plot_curves(metrics, batch_idx=i, step=step)
+
+    test_psnr = np.mean(psnr_net)
+    test_std_psnr = np.std(psnr_net)
+    linear_psnr = np.mean(psnr_init)
+    linear_std_psnr = np.std(psnr_init)
+    if verbose:
+        print(
+            f"Test PSNR: No learning rec.: {linear_psnr:.2f}+-{linear_std_psnr:.2f} dB | "
+            f"Model: {test_psnr:.2f}+-{test_std_psnr:.2f} dB."
+        )
+    if wandb_vis:
+        wandb.log({"Test PSNR": test_psnr}, step=step)
+
+    return test_psnr, test_std_psnr, linear_psnr, linear_std_psnr
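
To make the training/testing API above concrete, here is a minimal, self-contained usage sketch. The IdentityPhysics, SupLoss and Net classes below are hypothetical stand-ins, not part of the package: any network called as model(y, physics), any physics object callable as physics(x) (optionally exposing A_adjoint), and any loss with a .name attribute and the (x, x_net, y, physics, model) signature returning a scalar tensor should fit the same loop.

import torch
from torch.utils.data import DataLoader, TensorDataset

from deepinv.training_utils import train, test


class IdentityPhysics(torch.nn.Module):
    """Toy forward operator y = A(x) = x, with an adjoint for the no-learning baseline."""

    def forward(self, x):
        return x

    def A_adjoint(self, y):
        return y


class SupLoss(torch.nn.Module):
    """Toy supervised loss ||x_net - x||^2; the .name attribute feeds the loss meters."""

    name = "sup"

    def forward(self, x, x_net, y, physics, model):
        return torch.nn.functional.mse_loss(x_net, x)


class Net(torch.nn.Module):
    """Stand-in reconstruction network; this toy model ignores the physics argument."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, y, physics):
        return self.conv(y)


physics = IdentityPhysics()
x = torch.rand(8, 1, 16, 16)  # ground-truth images
y = physics(x)  # pre-computed measurements (online_measurements=False)
loader = DataLoader(TensorDataset(x, y), batch_size=4)

model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model = train(
    model, loader, epochs=2, losses=SupLoss(), physics=physics,
    optimizer=optimizer, verbose=True,
)
test(model, loader, physics)
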
--- /dev/null
+++ b/deepinv/transform/__init__.py
@@ -0,0 +1,2 @@
+from .rotate import Rotate
+from .shift import Shift
--- /dev/null
+++ b/deepinv/transform/rotate.py
@@ -0,0 +1,41 @@
+import torch
+from torchvision.transforms.functional import rotate
+import numpy as np
+
+
+class Rotate(torch.nn.Module):
+    r"""
+    2D rotations.
+
+    Generates ``n_trans`` randomly rotated versions of 2D images with zero padding.
+
+    :param n_trans: number of rotated versions generated per input image.
+    :param degrees: images are rotated in the range of angles (-degrees, degrees).
+    """
+
+    def __init__(self, n_trans=1, degrees=360):
+        super(Rotate, self).__init__()
+        self.n_trans, self.group_size = n_trans, degrees
+
+    def forward(self, data):
+        if self.group_size == 360:
+            theta = np.arange(0, 360)[1:][torch.randperm(359)]
+            theta = theta[: self.n_trans]
+        else:
+            theta = np.arange(0, 360, int(360 / (self.group_size + 1)))[1:]
+            theta = theta[torch.randperm(self.group_size)][: self.n_trans]
+        return torch.cat([rotate(data, float(_theta)) for _theta in theta])
+
+
+# if __name__ == "__main__":
+#     device = "cuda:0"
+#
+#     x = torch.zeros(1, 1, 64, 64, device=device)
+#     x[:, :, 16:48, 16:48] = 1
+#
+#     t = Rotate(4)
+#     y = t(x)
+#
+#     from deepinv.utils import plot
+#
+#     plot([x, y[0, :, :, :].unsqueeze(0), y[1, :, :, :].unsqueeze(0)])
--- /dev/null
+++ b/deepinv/transform/shift.py
@@ -0,0 +1,26 @@
+import torch
+
+
+class Shift(torch.nn.Module):
+    r"""
+    Fast integer 2D translations.
+
+    Generates ``n_trans`` randomly shifted versions of 2D images with circular padding.
+
+    :param n_trans: number of shifted versions generated per input image.
+    """
+
+    def __init__(self, n_trans=1):
+        super(Shift, self).__init__()
+        self.n_trans = n_trans
+
+    def forward(self, data):
+        H, W = data.shape[-2:]
+        assert self.n_trans <= H - 1 and self.n_trans <= W - 1
+        x = torch.arange(-H, H)[torch.randperm(2 * H)][: self.n_trans]
+        y = torch.arange(-W, W)[torch.randperm(2 * W)][: self.n_trans]
+
+        out = torch.cat(
+            [torch.roll(data, [sx, sy], [-2, -1]) for sx, sy in zip(x, y)], dim=0
+        )
+        return out
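
A quick sketch of the two transforms above. Both return the n_trans transformed copies concatenated along the batch dimension, which appears to be the form consumed by the equivariance-based training losses in this package (e.g. deepinv/loss/ei.py in the file list); the sizes below are illustrative only.

import torch
from deepinv.transform import Rotate, Shift

x = torch.rand(2, 3, 32, 32)  # batch of 2 RGB images
x_rot = Rotate(n_trans=4, degrees=360)(x)  # -> (8, 3, 32, 32): 4 random angles per batch
x_shift = Shift(n_trans=4)(x)  # -> (8, 3, 32, 32): 4 random circular shifts per batch
print(x_rot.shape, x_shift.shape)
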
--- /dev/null
+++ b/deepinv/unfolded/__init__.py
@@ -0,0 +1,2 @@
+from .unfolded import unfolded_builder, BaseUnfold
+from .deep_equilibrium import DEQ_builder, BaseDEQ