deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
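
The hunk below corresponds to deepinv/tests/test_optim.py (+647 lines), which tests the data fidelity terms, priors, and optimization algorithms. For orientation, here is a minimal sketch of the optim_builder API that these tests exercise; the argument values are illustrative and simply mirror the calls made in the tests below, not an authoritative usage reference:

    import torch
    import deepinv as dinv
    from deepinv.optim.data_fidelity import L2
    from deepinv.optim.prior import Prior
    from deepinv.optim.optimizers import optim_builder

    # Identity forward operator and a simple l1 prior, as in the tests below.
    physics = dinv.physics.LinearPhysics(A=lambda v: v, A_adjoint=lambda v: v)
    prior = Prior(g=lambda x, *args: torch.norm(x.view(x.shape[0], -1), p=1, dim=-1))

    algo = optim_builder(
        "PGD",
        prior=prior,
        data_fidelity=L2(),
        max_iter=100,
        crit_conv="residual",
        thres_conv=1e-8,
        verbose=False,
        params_algo={"stepsize": 0.9, "lambda": 1.0},
        early_stop=True,
    )
    x = torch.tensor([[[10.0], [10.0]]], dtype=torch.float64)
    x_hat = algo(physics(x), physics)  # reconstruct from measurements y = physics(x)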
@@ -0,0 +1,647 @@
+ import pytest
+
+ import torch
+ from torch.utils.data import DataLoader
+
+ import deepinv as dinv
+ from deepinv.optim import DataFidelity
+ from deepinv.optim.data_fidelity import L2, IndicatorL2, L1
+ from deepinv.optim.prior import Prior, PnP, RED
+ from deepinv.optim.optimizers import optim_builder
+
+
+ def custom_init_CP(y, physics):
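+     # The "est" tuple holds the initial primal iterate, its auxiliary copy,
+     # and the dual variable used by the Chambolle-Pock (CP) algorithm.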
+     x_init = physics.A_adjoint(y)
+     u_init = y
+     return {"est": (x_init, x_init, u_init)}
+
+
+ def test_data_fidelity_l2(device):
+     data_fidelity = L2()
+
+     # 1. Testing value of the loss for a simple case
+     # Define two points
+     x = torch.Tensor([[1], [4]]).unsqueeze(0).to(device)
+     y = torch.Tensor([[1], [1]]).unsqueeze(0).to(device)
+
+     # Create a measurement operator
+     A = torch.Tensor([[2, 0], [0, 0.5]]).to(device)
+     A_forward = lambda v: A @ v
+     A_adjoint = lambda v: A.transpose(0, 1) @ v
+
+     # Define the physics model associated to this operator
+     physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)
+     assert torch.allclose(data_fidelity(x, y, physics), torch.Tensor([1.0]).to(device))
+
+     # Compute the gradient of f
+     grad_dA = data_fidelity.grad(
+         x, y, physics
+     )  # print(grad_dA) gives [[[2.0000], [0.5000]]]
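+     # Closed form: grad = A^T (Ax - y), with Ax = [[2], [2]], so grad = [[2.0], [0.5]].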
+
+     # Compute the proximity operator of f
+     prox_dA = data_fidelity.prox(
+         x, y, physics, gamma=1.0
+     )  # print(prox_dA) gives [[[0.6000], [3.6000]]]
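+     # Closed form: prox = (Id + gamma A^T A)^{-1} (x + gamma A^T y)
+     # = diag(1/5, 1/1.25) @ [[3], [4.5]] = [[0.6], [3.6]].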
+
+     # 2. Testing trivial operations on f and not f \circ A
+     gamma = 1.0
+     assert torch.allclose(
+         data_fidelity.prox_d(x, y, gamma), (x + gamma * y) / (1 + gamma)
+     )
+     assert torch.allclose(data_fidelity.grad_d(x, y), x - y)
+
+     # 3. Testing the value of the proximity operator for a nonsymmetric linear operator
+     # Create a measurement operator
+     B = torch.Tensor([[2, 1], [-1, 0.5]]).to(device)
+     B_forward = lambda v: B @ v
+     B_adjoint = lambda v: B.transpose(0, 1) @ v
+
+     # Define the physics model associated to this operator
+     physics = dinv.physics.LinearPhysics(A=B_forward, A_adjoint=B_adjoint)
+
+     # Compute the proximity operator manually (closed form formula)
+     Id = torch.eye(2).to(device)
+     manual_prox = (Id + gamma * B.transpose(0, 1) @ B).inverse() @ (
+         x + gamma * B.transpose(0, 1) @ y
+     )
+
+     # Compute the deepinv proximity operator
+     deepinv_prox = data_fidelity.prox(x, y, physics, gamma=gamma)
+
+     assert torch.allclose(deepinv_prox, manual_prox)
+
+     # 4. Testing the gradient of the loss
+     grad_deepinv = data_fidelity.grad(x, y, physics)
+     grad_manual = B.transpose(0, 1) @ (B @ x - y)
+
+     assert torch.allclose(grad_deepinv, grad_manual)
+
+     # 5. Testing the torch autograd implementation of the gradient
+     def dummy_torch_l2(x, y):
+         return 0.5 * torch.norm((B @ (x - y)).flatten(), p=2, dim=-1) ** 2
+
+     torch_loss = DataFidelity(d=dummy_torch_l2)
+     torch_loss_grad = torch_loss.grad_d(x, y)
+     grad_manual = B.transpose(0, 1) @ (B @ (x - y))
+     assert torch.allclose(torch_loss_grad, grad_manual)
+
+     # 6. Testing the torch autograd implementation of the prox
+
+     torch_loss = DataFidelity(d=dummy_torch_l2)
+     torch_loss_prox = torch_loss.prox_d(
+         x, y, gamma=gamma, stepsize_inter=0.1, max_iter_inter=1000, tol_inter=1e-6
+     )
+
+     manual_prox = (Id + gamma * B.transpose(0, 1) @ B).inverse() @ (
+         x + gamma * B.transpose(0, 1) @ B @ y
+     )
+
+     assert torch.allclose(torch_loss_prox, manual_prox)
+
+
+ def test_data_fidelity_indicator(device):
+     # Define two points
+     x = torch.Tensor([[1], [4]]).unsqueeze(0).to(device)
+     y = torch.Tensor([[1], [1]]).unsqueeze(0).to(device)
+
+     # Redefine the data fidelity with a different radius
+     radius = 0.5
+     data_fidelity = IndicatorL2(radius=radius)
+
+     # Create a measurement operator
+     A = torch.Tensor([[2, 0], [0, 0.5]]).to(device)
+     A_forward = lambda v: A @ v
+     A_adjoint = lambda v: A.transpose(0, 1) @ v
+
+     # Define the physics model associated to this operator
+     physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)
+
+     # Test values of the loss for points inside and outside the l2 ball
+     assert data_fidelity(x, y, physics) == 1e16
+     assert data_fidelity(x / 2, y, physics) == 0
+     assert data_fidelity.d(x, y, radius=1) == 1e16
+     assert data_fidelity.d(x, y, radius=3.1) == 0
+
+     # 2. Testing trivial operations on f (and not f \circ A)
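+     # The projection onto the l2 ball of radius r centered at y is
+     # y + r * (x - y) / ||x - y|| when ||x - y|| > r; here ||x - y|| = 3,
+     # giving [[1.0], [1.5]].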
+     x_proj = torch.Tensor([[[1.0], [1 + radius]]]).to(device)
+     assert torch.allclose(data_fidelity.prox_d(x, y), x_proj)
+
+     # 3. Testing the proximity operator of f \circ A
+     data_fidelity = IndicatorL2(radius=0.5)
+
+     x = torch.Tensor([[1], [4]]).unsqueeze(0).to(device)
+     y = torch.Tensor([[1], [1]]).unsqueeze(0).to(device)
+
+     A = torch.Tensor([[2, 0], [0, 0.5]]).to(device)
+     A_forward = lambda v: A @ v
+     A_adjoint = lambda v: A.transpose(0, 1) @ v
+
+     # Define the physics model associated to this operator
+     physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)
+
+     x_proj = torch.Tensor([[[0.5290], [2.9932]]]).to(device)
+     dfb_proj = data_fidelity.prox(x, y, physics, max_iter=1000, crit_conv=1e-12)
+     assert torch.allclose(x_proj, dfb_proj, atol=1e-4)
+     assert torch.norm(A_forward(dfb_proj) - y) <= radius + 1e-06
+
+
+ def test_data_fidelity_l1(device):
+     # Define two points
+     x = torch.Tensor([[[1], [4], [-0.5]]]).to(device)
+     y = torch.Tensor([[[1], [1], [1]]]).to(device)
+
+     data_fidelity = L1()
+     assert torch.allclose(data_fidelity.d(x, y), (x - y).abs().sum())
+
+     A = torch.Tensor([[2, 0, 0], [0, -0.5, 0], [0, 0, 1]]).to(device)
+     A_forward = lambda v: A @ v
+     A_adjoint = lambda v: A.transpose(0, 1) @ v
+
+     # Define the physics model associated to this operator
+     physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)
+     Ax = A_forward(x)
+     assert data_fidelity(x, y, physics) == (Ax - y).abs().sum()
+
+     # Check subdifferential
+     grad_manual = torch.sign(x - y)
+     assert torch.allclose(data_fidelity.grad_d(x, y), grad_manual)
+
+     # Check prox
+     threshold = 0.5
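+     # The l1 prox soft-thresholds x - y: soft(x - y, 0.5) = [0, 2.5, -1.0],
+     # and adding back y gives [[1.0], [3.5], [0.0]].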
+     prox_manual = torch.Tensor([[[1.0], [3.5], [0.0]]]).to(device)
+     assert torch.allclose(data_fidelity.prox_d(x, y, threshold), prox_manual)
+
+
+ # We do not test CP (Chambolle-Pock) here, as it has a dedicated test (its optimality conditions are more specific).
+ @pytest.mark.parametrize("name_algo", ["PGD", "ADMM", "DRS", "HQS"])
+ def test_optim_algo(name_algo, imsize, dummy_dataset, device):
+     for g_first in [True, False]:
+         # Define two points
+         x = torch.tensor([[[10], [10]]], dtype=torch.float64)
+
+         # Create a measurement operator
+         B = torch.tensor([[2, 1], [-1, 0.5]], dtype=torch.float64)
+         B_forward = lambda v: B @ v
+         B_adjoint = lambda v: B.transpose(0, 1) @ v
+
+         # Define the physics model associated to this operator
+         physics = dinv.physics.LinearPhysics(A=B_forward, A_adjoint=B_adjoint)
+         y = physics(x)
+
+         data_fidelity = L2()  # The data fidelity term
+
+         def prior_g(x, *args):
+             ths = 0.1
+             return ths * torch.norm(x.view(x.shape[0], -1), p=1, dim=-1)
+
+         prior = Prior(g=prior_g)  # The prior term
+
+         if name_algo == "CP":
+             # In the case of primal-dual, the stepsizes must be bounded:
+             # reg_param * stepsize < 1 / physics.compute_norm(x, tol=1e-4).item()
+             stepsize = 0.9 / physics.compute_norm(x, tol=1e-4).item()
+             sigma = 1.0
+         else:
+             # Not all of the other algorithms need such constraints on their parameters,
+             # but we use them anyway to check that the computations are correct.
+             stepsize = 0.9 / physics.compute_norm(x, tol=1e-4).item()
+             sigma = None
+
+         lamb = 1.1
+         max_iter = 1000
+         params_algo = {"stepsize": stepsize, "lambda": lamb, "sigma": sigma}
+
+         optimalgo = optim_builder(
+             name_algo,
+             prior=prior,
+             data_fidelity=data_fidelity,
+             max_iter=max_iter,
+             crit_conv="residual",
+             thres_conv=1e-11,
+             verbose=True,
+             params_algo=params_algo,
+             early_stop=True,
+             g_first=g_first,
+         )
+
+         # Run the optimization algorithm
+         x = optimalgo(y, physics)
+
+         assert optimalgo.has_converged
+
+         # Compute the subdifferential of the regularisation at the limit point of the algorithm.
+
+         if name_algo == "HQS":
+             # In this case, the algorithm does not converge to the minimum of
+             # :math:`\lambda f+g` but to that of :math:`\lambda M_{\lambda \tau f}+g`,
+             # where :math:`M_{\lambda \tau f}` denotes the Moreau envelope of :math:`f`
+             # with parameter :math:`\lambda \tau`.
+             # Beware, these quantities are not fetched automatically here but are
+             # handwritten in the test.
+             # The optimality condition is then :math:`0 \in \lambda \nabla M_{\lambda \tau f}(x)+\partial g(x)`.
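+             # Recall that the Moreau envelope gradient is \nabla M_{\gamma h}(x) = (x - prox_{\gamma h}(x)) / \gamma;
+             # below, \gamma = \lambda \tau for the fidelity (g_first=False) and \gamma = \tau for the prior (g_first=True).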
+             if not g_first:
+                 subdiff = prior.grad(x)
+                 moreau_grad = (
+                     x - data_fidelity.prox(x, y, physics, gamma=lamb * stepsize)
+                 ) / (lamb * stepsize)  # Gradient of the Moreau envelope
+                 assert torch.allclose(
+                     lamb * moreau_grad, -subdiff, atol=1e-8
+                 )  # Optimality condition
+             else:
+                 subdiff = lamb * data_fidelity.grad(x, y, physics)
+                 moreau_grad = (
+                     x - prior.prox(x, gamma=stepsize)
+                 ) / stepsize  # Gradient of the Moreau envelope
+                 assert torch.allclose(
+                     moreau_grad, -subdiff, atol=1e-8
+                 )  # Optimality condition
+         else:
+             subdiff = prior.grad(x)
+             # In this case, the algorithm converges to the minimum of :math:`\lambda f+g`.
+             # The optimality condition is then :math:`0 \in \lambda \nabla f(x)+\partial g(x)`.
+             grad_deepinv = data_fidelity.grad(x, y, physics)
+             assert torch.allclose(
+                 lamb * grad_deepinv, -subdiff, atol=1e-8
+             )  # Optimality condition
+
+
+ def test_denoiser(imsize, dummy_dataset, device):
+     dataloader = DataLoader(
+         dummy_dataset, batch_size=1, shuffle=False, num_workers=0
+     )  # 1. Generate a dummy dataset
+     test_sample = next(iter(dataloader))
+
+     physics = dinv.physics.Denoising()  # 2. Set a physical experiment (here, denoising)
+     y = physics(test_sample).type(test_sample.dtype).to(device)
+
+     ths = 2.0
+
+     model = dinv.models.TGV(n_it_max=5000, verbose=True, crit=1e-4)
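+     # TGV (Total Generalized Variation) is an iterative variational denoiser;
+     # its regularization strength ths is passed at call time below.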
+
+     x = model(y, ths)  # 3. Apply the model we want to test
+
+     # For debugging
+     # plot = False
+     # if plot:
+     #     imgs = []
+     #     imgs.append(torch2cpu(y[0, :, :, :].unsqueeze(0)))
+     #     imgs.append(torch2cpu(x[0, :, :, :].unsqueeze(0)))
+     #
+     #     titles = ["Input", "Output"]
+     #     num_im = 2
+     #     plot_debug(
+     #         imgs, shape=(1, num_im), titles=titles, row_order=True, save_dir=None
+     #     )
+
+     assert model.has_converged
+
+
+ # GD is not implemented for this one
+ @pytest.mark.parametrize("pnp_algo", ["PGD", "HQS", "DRS", "ADMM", "CP"])
+ def test_pnp_algo(pnp_algo, imsize, dummy_dataset, device):
+     pytest.importorskip("pytorch_wavelets")
+
+     # 1. Generate a dummy dataset
+     dataloader = DataLoader(dummy_dataset, batch_size=1, shuffle=False, num_workers=0)
+     test_sample = next(iter(dataloader)).to(device)
+
+     # 2. Set a physical experiment (here, deblurring)
+     physics = dinv.physics.Blur(
+         dinv.physics.blur.gaussian_blur(sigma=(2, 0.1), angle=45.0), device=device
+     )
+     y = physics(test_sample)
+     max_iter = 1000
+     # Note: results are better for sigma_denoiser=0.001, but it takes longer to run.
+     sigma_denoiser = torch.tensor([[0.1]])
+     stepsize = 1.0
+     lamb = 1.0
+
+     data_fidelity = L2()
+
+     # Here the prior model is common to all iterations
+     prior = PnP(denoiser=dinv.models.WaveletPrior(wv="db8", level=3, device=device))
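+     # In the PnP prior, the proximal step of the prior is replaced by the
+     # wavelet denoiser; g_param (sigma_denoiser) controls its strength.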
+
+     stepsize_dual = 1.0 if pnp_algo == "CP" else None
+     params_algo = {
+         "stepsize": stepsize,
+         "g_param": sigma_denoiser,
+         "lambda": lamb,
+         "stepsize_dual": stepsize_dual,
+     }
+
+     custom_init = custom_init_CP if pnp_algo == "CP" else None
+
+     pnp = optim_builder(
+         pnp_algo,
+         prior=prior,
+         data_fidelity=data_fidelity,
+         max_iter=max_iter,
+         thres_conv=1e-4,
+         verbose=True,
+         params_algo=params_algo,
+         early_stop=True,
+         custom_init=custom_init,
+     )
+
+     x = pnp(y, physics)
+
+     # For debugging. Remark: to get nice results, lower sigma_denoiser to 0.001.
+     # plot = True
+     # if plot:
+     #     imgs = []
+     #     imgs.append(torch2cpu(y[0, :, :, :].unsqueeze(0)))
+     #     imgs.append(torch2cpu(x[0, :, :, :].unsqueeze(0)))
+     #     imgs.append(torch2cpu(test_sample[0, :, :, :].unsqueeze(0)))
+     #
+     #     titles = ["Input", "Output", "Groundtruth"]
+     #     num_im = 3
+     #     plot_debug(
+     #         imgs, shape=(1, num_im), titles=titles, row_order=True, save_dir=None
+     #     )
+
+     assert pnp.has_converged
+
+
+ @pytest.mark.parametrize("pnp_algo", ["PGD", "HQS", "DRS", "ADMM", "CP"])
+ def test_priors_algo(pnp_algo, imsize, dummy_dataset, device):
+     for prior_name in ["L1Prior", "Tikhonov"]:
+         # 1. Generate a dummy dataset
+         dataloader = DataLoader(
+             dummy_dataset, batch_size=1, shuffle=False, num_workers=0
+         )
+         test_sample = next(iter(dataloader)).to(device)
+
+         # 2. Set a physical experiment (here, deblurring)
+         physics = dinv.physics.Blur(
+             dinv.physics.blur.gaussian_blur(sigma=(2, 0.1), angle=45.0), device=device
+         )
+         y = physics(test_sample)
+         max_iter = 1000
+         # Note: results are better for sigma_denoiser=0.001, but it takes longer to run.
+         # sigma_denoiser = torch.tensor([[0.1]])
+         sigma_denoiser = torch.tensor([[1.0]], device=device)
+         stepsize = 1.0
+         lamb = 1.0
+
+         data_fidelity = L2()
+
+         # Here the prior model is common to all iterations
+         if prior_name == "L1Prior":
+             prior = dinv.optim.prior.L1Prior()
+         elif prior_name == "Tikhonov":
+             prior = dinv.optim.prior.Tikhonov()
+
+         stepsize_dual = 1.0 if pnp_algo == "CP" else None
+         params_algo = {
+             "stepsize": stepsize,
+             "g_param": sigma_denoiser,
+             "lambda": lamb,
+             "stepsize_dual": stepsize_dual,
+         }
+
+         custom_init = custom_init_CP if pnp_algo == "CP" else None
+
+         opt_algo = optim_builder(
+             pnp_algo,
+             prior=prior,
+             data_fidelity=data_fidelity,
+             max_iter=max_iter,
+             thres_conv=1e-4,
+             verbose=True,
+             params_algo=params_algo,
+             early_stop=True,
+             custom_init=custom_init,
+         )
+
+         x = opt_algo(y, physics)
+
+         # For debugging. Remark: to get nice results, lower sigma_denoiser to 0.001.
+         # plot = True
+         # if plot:
+         #     imgs = []
+         #     imgs.append(torch2cpu(y[0, :, :, :].unsqueeze(0)))
+         #     imgs.append(torch2cpu(x[0, :, :, :].unsqueeze(0)))
+         #     imgs.append(torch2cpu(test_sample[0, :, :, :].unsqueeze(0)))
+         #
+         #     titles = ["Input", "Output", "Groundtruth"]
+         #     num_im = 3
+         #     plot_debug(
+         #         imgs, shape=(1, num_im), titles=titles, row_order=True, save_dir=None
+         #     )
+
+         assert opt_algo.has_converged
+
+
+ @pytest.mark.parametrize("red_algo", ["GD", "PGD"])
+ def test_red_algo(red_algo, imsize, dummy_dataset, device):
+     # This test uses WaveletPrior, which requires pytorch_wavelets
+     # TODO: we could use a dummy trainable denoiser with a linear layer instead
+     pytest.importorskip("pytorch_wavelets")
+
+     # 1. Generate a dummy dataset
+     dataloader = DataLoader(dummy_dataset, batch_size=1, shuffle=False, num_workers=0)
+     test_sample = next(iter(dataloader)).to(device)
+
+     # 2. Set a physical experiment (here, deblurring)
+     physics = dinv.physics.Blur(
+         dinv.physics.blur.gaussian_blur(sigma=(2, 0.1), angle=45.0), device=device
+     )
+     y = physics(test_sample)
+     max_iter = 1000
+     sigma_denoiser = 1.0  # Note: results are better for sigma_denoiser=0.001, but it takes longer to run.
+     stepsize = 1.0
+     lamb = 1.0
+
+     data_fidelity = L2()
+
+     prior = RED(denoiser=dinv.models.WaveletPrior(wv="db8", level=3, device=device))
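+     # The RED prior uses the denoiser D to define the regularizer gradient,
+     # \nabla g(x) = x - D_\sigma(x), so GD/PGD only need denoiser evaluations.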
+
+     params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}
+
+     red = optim_builder(
+         red_algo,
+         prior=prior,
+         data_fidelity=data_fidelity,
+         max_iter=max_iter,
+         thres_conv=1e-4,
+         verbose=True,
+         params_algo=params_algo,
+         early_stop=True,
+         g_first=True,
+     )
+
+     red(y, physics)
+
+     assert red.has_converged
+
+
+ def test_CP_K(imsize, dummy_dataset, device):
+     r"""
+     This test checks that the CP algorithm converges to the solution of the following problem:
+
+     .. math::
+
+         \min_x \lambda a(x) + b(Kx)
+
+     where :math:`a` and :math:`b` are functions and :math:`K` is a linear operator. In this setting, we test both
+     :math:`a(x) = d(Ax-y)` with :math:`b(z) = g(z)`, and :math:`a(x) = g(x)` with :math:`b(z) = d(z-y)`.
+     """
+
+     for g_first in [True, False]:
+         # Define two points
+         x = torch.tensor([[[10], [10]]], dtype=torch.float64).to(device)
+
+         # Create a measurement operator
+         Id_forward = lambda v: v
+         Id_adjoint = lambda v: v
+
+         # Define the physics model associated to this operator
+         physics = dinv.physics.LinearPhysics(A=Id_forward, A_adjoint=Id_adjoint)
+         y = physics(x)
+
+         data_fidelity = L2()  # The data fidelity term
+
+         def prior_g(x, *args):
+             ths = 1.0
+             return ths * torch.norm(x.view(x.shape[0], -1), p=1, dim=-1)
+
+         prior = Prior(g=prior_g)  # The prior term
+
+         # Define a linear operator
+         K = torch.tensor([[2, 1], [-1, 0.5]], dtype=torch.float64).to(device)
+         K_forward = lambda v: K @ v
+         K_adjoint = lambda v: K.transpose(0, 1) @ v
+
+         # stepsize = 0.9 / physics.compute_norm(x, tol=1e-4).item()
+         stepsize = 0.9 / torch.linalg.norm(K, ord=2).item() ** 2
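+         # Chambolle-Pock requires stepsize * stepsize_dual * ||K||_2^2 < 1;
+         # here (0.9 / ||K||_2^2) * 1.0 * ||K||_2^2 = 0.9 < 1.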
+         reg_param = 1.0
+         stepsize_dual = 1.0
+
+         lamb = 1.5
+         max_iter = 1000
+
+         params_algo = {
+             "stepsize": stepsize,
+             "g_param": reg_param,
+             "lambda": lamb,
+             "stepsize_dual": stepsize_dual,
+             "K": K_forward,
+             "K_adjoint": K_adjoint,
+         }
+
+         optimalgo = optim_builder(
+             "CP",
+             prior=prior,
+             data_fidelity=data_fidelity,
+             max_iter=max_iter,
+             crit_conv="residual",
+             thres_conv=1e-11,
+             verbose=True,
+             params_algo=params_algo,
+             early_stop=True,
+             g_first=g_first,
+             custom_init=custom_init_CP,
+         )
+
+         # Run the optimization algorithm
+         x = optimalgo(y, physics)
+
+         print("g_first: ", g_first)
+         assert optimalgo.has_converged
+
+         # Compute the subdifferential of the regularisation at the limit point of the algorithm.
+         if not g_first:
+             subdiff = prior.grad(x, 0)
+
+             grad_deepinv = K_adjoint(
+                 data_fidelity.grad(K_forward(x), y, physics)
+             )  # This test is only valid for differentiable data fidelity terms.
+             assert torch.allclose(
+                 lamb * grad_deepinv, -subdiff, atol=1e-12
+             )  # Optimality condition
+
+         else:
+             subdiff = K_adjoint(prior.grad(K_forward(x), 0))
+
+             grad_deepinv = data_fidelity.grad(x, y, physics)
+             assert torch.allclose(
+                 lamb * grad_deepinv, -subdiff, atol=1e-12
+             )  # Optimality condition
+
+
+ def test_CP_datafidsplit(imsize, dummy_dataset, device):
+     r"""
+     This test checks that the CP algorithm converges to the solution of the following problem:
+
+     .. math::
+
+         \min_x \lambda d(Ax,y) + g(x)
+
+     where :math:`d` is a distance function and :math:`g` is a prior term.
+     """
+
+     g_first = False
+     # Define two points
+     x = torch.tensor([[[10], [10]]], dtype=torch.float64).to(device)
+
+     # Create a measurement operator
+     A = torch.tensor([[2, 1], [-1, 0.5]], dtype=torch.float64).to(device)
+     A_forward = lambda v: A @ v
+     A_adjoint = lambda v: A.transpose(0, 1) @ v
+
+     # Define the physics model associated to this operator
+     physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)
+     y = physics(x)
+
+     data_fidelity = L2()  # The data fidelity term
+
+     def prior_g(x, *args):
+         ths = 1.0
+         return ths * torch.norm(x.view(x.shape[0], -1), p=1, dim=-1)
+
+     prior = Prior(g=prior_g)  # The prior term
+
+     # stepsize = 0.9 / physics.compute_norm(x, tol=1e-4).item()
+     stepsize = 0.9 / torch.linalg.norm(A, ord=2).item() ** 2
+     reg_param = 1.0
+     stepsize_dual = 1.0
+
+     lamb = 1.5
+     max_iter = 1000
+
+     params_algo = {
+         "stepsize": stepsize,
+         "g_param": reg_param,
+         "lambda": lamb,
+         "stepsize_dual": stepsize_dual,
+         "K": A_forward,
+         "K_adjoint": A_adjoint,
+     }
+
+     optimalgo = optim_builder(
+         "CP",
+         prior=prior,
+         data_fidelity=data_fidelity,
+         max_iter=max_iter,
+         crit_conv="residual",
+         thres_conv=1e-11,
+         verbose=True,
+         params_algo=params_algo,
+         early_stop=True,
+         g_first=g_first,
+         custom_init=custom_init_CP,
+     )
+
+     # Run the optimization algorithm
+     x = optimalgo(y, physics)
+
+     assert optimalgo.has_converged
+
+     # Compute the subdifferential of the regularisation at the limit point of the algorithm.
+     subdiff = prior.grad(x, 0)
+
+     grad_deepinv = A_adjoint(
+         data_fidelity.grad_d(A_forward(x), y)
+     )  # This test is only valid for differentiable data fidelity terms.
+     assert torch.allclose(
+         lamb * grad_deepinv, -subdiff, atol=1e-12
+     )  # Optimality condition