deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,647 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch.utils.data import DataLoader
|
|
5
|
+
|
|
6
|
+
import deepinv as dinv
|
|
7
|
+
from deepinv.optim import DataFidelity
|
|
8
|
+
from deepinv.optim.data_fidelity import L2, IndicatorL2, L1
|
|
9
|
+
from deepinv.optim.prior import Prior, PnP, RED
|
|
10
|
+
from deepinv.optim.optimizers import optim_builder
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def custom_init_CP(y, physics):
    r"""
    Build a custom initialisation for the Chambolle-Pock (CP) algorithm.

    The first two entries of the returned tuple are both the back-projection
    :math:`A^\top y` (the very same object); the third entry is ``y`` itself.

    :param y: measurements.
    :param physics: operator exposing an ``A_adjoint`` method.
    :return: dictionary ``{"est": (A^T y, A^T y, y)}``, passed to ``optim_builder`` as ``custom_init``.
    """
    back_projection = physics.A_adjoint(y)
    return {"est": (back_projection, back_projection, y)}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def test_data_fidelity_l2(device):
    r"""
    Check the :class:`L2` data fidelity :math:`f(x) = \frac{1}{2}\|x - y\|^2`.

    The value, gradient and proximal operator of :math:`f \circ A` are compared with
    closed-form expressions. The autograd-based ``grad_d`` / ``prox_d`` of a generic
    :class:`DataFidelity` are also checked.
    """
    data_fidelity = L2()

    # 1. Value, gradient and prox of f \circ A for a simple diagonal operator
    x = torch.Tensor([[1], [4]]).unsqueeze(0).to(device)
    y = torch.Tensor([[1], [1]]).unsqueeze(0).to(device)

    A = torch.Tensor([[2, 0], [0, 0.5]]).to(device)
    A_forward = lambda v: A @ v
    A_adjoint = lambda v: A.transpose(0, 1) @ v
    physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)

    # Ax - y = [1, 1], hence f(Ax) = 0.5 * (1 + 1) = 1
    assert torch.allclose(data_fidelity(x, y, physics), torch.Tensor([1.0]).to(device))

    # Gradient A^T (Ax - y) = [2, 0.5]
    grad_dA = data_fidelity.grad(x, y, physics)
    assert torch.allclose(grad_dA, torch.Tensor([[[2.0], [0.5]]]).to(device))

    # Prox with gamma = 1: (I + A^T A)^{-1} (x + A^T y) = [3 / 5, 4.5 / 1.25] = [0.6, 3.6]
    prox_dA = data_fidelity.prox(x, y, physics, gamma=1.0)
    assert torch.allclose(
        prox_dA, torch.Tensor([[[0.6], [3.6]]]).to(device), atol=1e-5
    )

    # 2. Trivial operations on f alone (not f \circ A)
    gamma = 1.0
    assert torch.allclose(
        data_fidelity.prox_d(x, y, gamma), (x + gamma * y) / (1 + gamma)
    )
    assert torch.allclose(data_fidelity.grad_d(x, y), x - y)

    # 3. Prox of f \circ B for a non-symmetric operator B, against the closed form
    #    (I + gamma B^T B)^{-1} (x + gamma B^T y)
    B = torch.Tensor([[2, 1], [-1, 0.5]]).to(device)
    B_forward = lambda v: B @ v
    B_adjoint = lambda v: B.transpose(0, 1) @ v
    physics = dinv.physics.LinearPhysics(A=B_forward, A_adjoint=B_adjoint)

    Id = torch.eye(2).to(device)
    manual_prox = (Id + gamma * B.transpose(0, 1) @ B).inverse() @ (
        x + gamma * B.transpose(0, 1) @ y
    )
    deepinv_prox = data_fidelity.prox(x, y, physics, gamma=gamma)
    assert torch.allclose(deepinv_prox, manual_prox)

    # 4. Gradient of f \circ B: B^T (Bx - y)
    grad_deepinv = data_fidelity.grad(x, y, physics)
    grad_manual = B.transpose(0, 1) @ (B @ x - y)
    assert torch.allclose(grad_deepinv, grad_manual)

    # 5. Autograd gradient of a generic DataFidelity d(x, y) = 0.5 ||B(x - y)||^2,
    #    whose gradient is B^T B (x - y)
    def dummy_torch_l2(x, y):
        return 0.5 * torch.norm((B @ (x - y)).flatten(), p=2, dim=-1) ** 2

    torch_loss = DataFidelity(d=dummy_torch_l2)
    torch_loss_grad = torch_loss.grad_d(x, y)
    grad_manual = B.transpose(0, 1) @ (B @ (x - y))
    assert torch.allclose(torch_loss_grad, grad_manual)

    # 6. Iterative (autograd-based) prox of the same generic DataFidelity, against
    #    (I + gamma B^T B)^{-1} (x + gamma B^T B y)
    torch_loss = DataFidelity(d=dummy_torch_l2)
    torch_loss_prox = torch_loss.prox_d(
        x, y, gamma=gamma, stepsize_inter=0.1, max_iter_inter=1000, tol_inter=1e-6
    )

    manual_prox = (Id + gamma * B.transpose(0, 1) @ B).inverse() @ (
        x + gamma * B.transpose(0, 1) @ B @ y
    )
    assert torch.allclose(torch_loss_prox, manual_prox)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def test_data_fidelity_indicator(device):
    r"""
    Check the :class:`IndicatorL2` data fidelity.

    This is the indicator of the :math:`\ell_2` ball of radius ``radius`` centred at
    ``y``. The test covers the loss value, the projection ``prox_d`` and the proximal
    operator of its composition with a linear operator :math:`A`.
    """
    radius = 0.5
    data_fidelity = IndicatorL2(radius=radius)

    x = torch.Tensor([[1], [4]]).unsqueeze(0).to(device)
    y = torch.Tensor([[1], [1]]).unsqueeze(0).to(device)

    # Diagonal measurement operator and its adjoint
    A = torch.Tensor([[2, 0], [0, 0.5]]).to(device)

    def A_forward(v):
        return A @ v

    def A_adjoint(v):
        return A.transpose(0, 1) @ v

    physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)

    # 1. Loss value inside / outside the ball (the implementation returns 1e16 outside)
    assert data_fidelity(x, y, physics) == 1e16
    assert data_fidelity(x / 2, y, physics) == 0
    assert data_fidelity.d(x, y, radius=1) == 1e16
    assert data_fidelity.d(x, y, radius=3.1) == 0

    # 2. Projection onto the ball (f alone, not f \circ A):
    #    x - y = [0, 3] is pulled back to distance `radius` from y along the same direction
    expected_projection = torch.Tensor([[[1.0], [1 + radius]]]).to(device)
    assert torch.allclose(data_fidelity.prox_d(x, y), expected_projection)

    # 3. Prox of f \circ A, computed iteratively
    expected_prox = torch.Tensor([[[0.5290], [2.9932]]]).to(device)
    computed_prox = data_fidelity.prox(x, y, physics, max_iter=1000, crit_conv=1e-12)
    assert torch.allclose(expected_prox, computed_prox, atol=1e-4)
    # The result must be feasible: A x lies in the ball around y
    assert torch.norm(A_forward(computed_prox) - y) <= radius + 1e-06
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def test_data_fidelity_l1(device):
    r"""
    Check the :class:`L1` data fidelity :math:`f(x) = \|x - y\|_1`.

    The test covers the value of :math:`f`, the value of :math:`f \circ A`, the
    subgradient and the proximal operator (soft-thresholding of ``x - y``, shifted back
    by ``y``).
    """
    x = torch.Tensor([[[1], [4], [-0.5]]]).to(device)
    y = torch.Tensor([[[1], [1], [1]]]).to(device)
    data_fidelity = L1()

    # Value of f alone
    assert torch.allclose(data_fidelity.d(x, y), (x - y).abs().sum())

    # Value of f \circ A for a diagonal operator
    A = torch.Tensor([[2, 0, 0], [0, -0.5, 0], [0, 0, 1]]).to(device)
    physics = dinv.physics.LinearPhysics(
        A=lambda v: A @ v, A_adjoint=lambda v: A.transpose(0, 1) @ v
    )
    measured = A @ x
    assert data_fidelity(x, y, physics) == (measured - y).abs().sum()

    # Subgradient: sign(x - y), with sign(0) = 0
    assert torch.allclose(data_fidelity.grad_d(x, y), torch.sign(x - y))

    # Prox with threshold 0.5: x - y = [0, 3, -1.5] is soft-thresholded to
    # [0, 2.5, -1], then shifted back by y
    expected_prox = torch.Tensor([[[1.0], [3.5], [0.0]]]).to(device)
    assert torch.allclose(data_fidelity.prox_d(x, y, 0.5), expected_prox)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# CP (Chambolle-Pock) is not tested here: it has dedicated tests (test_CP_K and
# test_CP_datafidsplit) because its optimality conditions are more specific.
@pytest.mark.parametrize("name_algo", ["PGD", "ADMM", "DRS", "HQS"])
def test_optim_algo(name_algo, imsize, dummy_dataset, device):
    r"""
    Check that PGD, ADMM, DRS and HQS converge to a point satisfying a first-order
    optimality condition, for both orderings of the splitting (``g_first``).

    The problem is

    .. math::

        \min_x \lambda f(Bx, y) + g(x), \qquad g(x) = 0.1 \|x\|_1,

    where :math:`f` is the :class:`L2` data fidelity and :math:`B` a non-symmetric
    operator. HQS converges to the minimiser of a Moreau-envelope surrogate instead,
    whose optimality condition is checked separately.
    """
    for g_first in [True, False]:
        # Ground-truth point and non-symmetric measurement operator, on the test device
        x = torch.tensor([[[10], [10]]], dtype=torch.float64).to(device)
        B = torch.tensor([[2, 1], [-1, 0.5]], dtype=torch.float64).to(device)
        B_forward = lambda v: B @ v
        B_adjoint = lambda v: B.transpose(0, 1) @ v

        # Define the physics model associated to this operator
        physics = dinv.physics.LinearPhysics(A=B_forward, A_adjoint=B_adjoint)
        y = physics(x)

        data_fidelity = L2()  # The data fidelity term

        def prior_g(x, *args):
            ths = 0.1
            return ths * torch.norm(x.view(x.shape[0], -1), p=1, dim=-1)

        prior = Prior(g=prior_g)  # The prior term

        # Not all algorithms need this bound on the stepsize, but using it for all of
        # them lets us check that the computations are correct.
        stepsize = 0.9 / physics.compute_norm(x, tol=1e-4).item()
        # Only relevant for primal-dual (CP), which is not parametrised in this test
        sigma = None

        lamb = 1.1
        max_iter = 1000
        params_algo = {"stepsize": stepsize, "lambda": lamb, "sigma": sigma}

        optimalgo = optim_builder(
            name_algo,
            prior=prior,
            data_fidelity=data_fidelity,
            max_iter=max_iter,
            crit_conv="residual",
            thres_conv=1e-11,
            verbose=True,
            params_algo=params_algo,
            early_stop=True,
            g_first=g_first,
        )

        # Run the optimization algorithm
        x = optimalgo(y, physics)

        assert optimalgo.has_converged

        # Check the optimality condition at the limit point of the algorithm.
        if name_algo == "HQS":
            # HQS does not converge to the minimum of :math:`\lambda f+g` but to that of
            # :math:`\lambda M_{\lambda \tau f}+g`, where :math:`M_{\lambda \tau f}` denotes
            # the Moreau envelope of :math:`f` with parameter :math:`\lambda \tau`.
            # Beware: these parameters are not fetched automatically but handwritten here.
            # The optimality condition is then :math:`0 \in \lambda \nabla M_{\lambda \tau f}(x)+\partial g(x)`.
            if not g_first:
                subdiff = prior.grad(x)
                # Gradient of the Moreau envelope of f
                moreau_grad = (
                    x - data_fidelity.prox(x, y, physics, gamma=lamb * stepsize)
                ) / (lamb * stepsize)
                assert torch.allclose(lamb * moreau_grad, -subdiff, atol=1e-8)
            else:
                subdiff = lamb * data_fidelity.grad(x, y, physics)
                # Gradient of the Moreau envelope of g
                moreau_grad = (x - prior.prox(x, gamma=stepsize)) / stepsize
                assert torch.allclose(moreau_grad, -subdiff, atol=1e-8)
        else:
            # The algorithm converges to the minimum of :math:`\lambda f+g`, so
            # :math:`0 \in \lambda \nabla f(x)+\partial g(x)`.
            subdiff = prior.grad(x)
            grad_deepinv = data_fidelity.grad(x, y, physics)
            assert torch.allclose(lamb * grad_deepinv, -subdiff, atol=1e-8)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def test_denoiser(imsize, dummy_dataset, device):
    r"""
    Check that the TGV model converges when applied to a measurement produced by
    :class:`dinv.physics.Denoising` from the first sample of the dummy dataset.
    """
    # 1. First sample of the dummy dataset
    loader = DataLoader(dummy_dataset, batch_size=1, shuffle=False, num_workers=0)
    sample = next(iter(loader))

    # 2. Denoising experiment; cast back to the sample dtype, then move to the device
    physics = dinv.physics.Denoising()
    measurement = physics(sample)
    measurement = measurement.type(sample.dtype).to(device)

    # 3. Apply TGV with threshold 2.0
    threshold = 2.0
    model = dinv.models.TGV(n_it_max=5000, verbose=True, crit=1e-4)
    model(measurement, threshold)

    assert model.has_converged
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
# GD is excluded: it is not implemented for this (PnP) setting.
@pytest.mark.parametrize("pnp_algo", ["PGD", "HQS", "DRS", "ADMM", "CP"])
def test_pnp_algo(pnp_algo, imsize, dummy_dataset, device):
    r"""
    Check that Plug-and-Play with a wavelet denoiser converges on a deblurring problem
    for each splitting algorithm in the parametrisation.

    Requires ``pytorch_wavelets`` (for :class:`WaveletPrior`); the test is skipped otherwise.
    """
    pytest.importorskip("pytorch_wavelets")

    # 1. First sample of the dummy dataset
    dataloader = DataLoader(dummy_dataset, batch_size=1, shuffle=False, num_workers=0)
    test_sample = next(iter(dataloader)).to(device)

    # 2. Deblurring with an anisotropic Gaussian kernel
    physics = dinv.physics.Blur(
        dinv.physics.blur.gaussian_blur(sigma=(2, 0.1), angle=45.0), device=device
    )
    y = physics(test_sample)
    max_iter = 1000
    # Keep the denoiser strength on the same device as the data (as test_priors_algo
    # does): a non-scalar CPU tensor cannot be mixed with CUDA tensors in arithmetic.
    # Note: results are better for sigma_denoiser=0.001, but it takes longer to run.
    sigma_denoiser = torch.tensor([[0.1]], device=device)
    stepsize = 1.0
    lamb = 1.0

    data_fidelity = L2()

    # Here the prior model is common to all iterations
    prior = PnP(denoiser=dinv.models.WaveletPrior(wv="db8", level=3, device=device))

    stepsize_dual = 1.0 if pnp_algo == "CP" else None
    params_algo = {
        "stepsize": stepsize,
        "g_param": sigma_denoiser,
        "lambda": lamb,
        "stepsize_dual": stepsize_dual,
    }

    custom_init = custom_init_CP if pnp_algo == "CP" else None

    pnp = optim_builder(
        pnp_algo,
        prior=prior,
        data_fidelity=data_fidelity,
        max_iter=max_iter,
        thres_conv=1e-4,
        verbose=True,
        params_algo=params_algo,
        early_stop=True,
        custom_init=custom_init,
    )

    pnp(y, physics)

    assert pnp.has_converged
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
@pytest.mark.parametrize("pnp_algo", ["PGD", "HQS", "DRS", "ADMM", "CP"])
def test_priors_algo(pnp_algo, imsize, dummy_dataset, device):
    r"""
    Check that each splitting algorithm converges on a deblurring problem with the
    explicit priors :class:`L1Prior` and :class:`Tikhonov`.
    """
    # The problem setup does not depend on the prior, so it is built once.
    # 1. First sample of the dummy dataset
    dataloader = DataLoader(dummy_dataset, batch_size=1, shuffle=False, num_workers=0)
    test_sample = next(iter(dataloader)).to(device)

    # 2. Deblurring with an anisotropic Gaussian kernel
    physics = dinv.physics.Blur(
        dinv.physics.blur.gaussian_blur(sigma=(2, 0.1), angle=45.0), device=device
    )
    y = physics(test_sample)

    max_iter = 1000
    sigma_denoiser = torch.tensor([[1.0]], device=device)
    stepsize = 1.0
    lamb = 1.0
    stepsize_dual = 1.0 if pnp_algo == "CP" else None
    custom_init = custom_init_CP if pnp_algo == "CP" else None

    # Explicit mapping: an unknown name fails loudly instead of leaving `prior` unbound
    prior_classes = {
        "L1Prior": dinv.optim.prior.L1Prior,
        "Tikhonov": dinv.optim.prior.Tikhonov,
    }

    for prior_name, prior_cls in prior_classes.items():
        prior = prior_cls()
        data_fidelity = L2()

        # Built per prior so that each optimizer receives its own dictionary
        params_algo = {
            "stepsize": stepsize,
            "g_param": sigma_denoiser,
            "lambda": lamb,
            "stepsize_dual": stepsize_dual,
        }

        opt_algo = optim_builder(
            pnp_algo,
            prior=prior,
            data_fidelity=data_fidelity,
            max_iter=max_iter,
            thres_conv=1e-4,
            verbose=True,
            params_algo=params_algo,
            early_stop=True,
            custom_init=custom_init,
        )

        opt_algo(y, physics)

        assert opt_algo.has_converged, f"{pnp_algo} did not converge with {prior_name}"
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
@pytest.mark.parametrize("red_algo", ["GD", "PGD"])
def test_red_algo(red_algo, imsize, dummy_dataset, device):
    r"""
    Check that RED with a wavelet denoiser converges on a deblurring problem.

    Requires ``pytorch_wavelets`` (for :class:`WaveletPrior`); the test is skipped otherwise.
    """
    # TODO: a dummy trainable denoiser with a linear layer would remove the
    # pytorch_wavelets dependency.
    pytest.importorskip("pytorch_wavelets")

    # 1. First sample of the dummy dataset
    loader = DataLoader(dummy_dataset, batch_size=1, shuffle=False, num_workers=0)
    ground_truth = next(iter(loader)).to(device)

    # 2. Deblurring with an anisotropic Gaussian kernel
    kernel = dinv.physics.blur.gaussian_blur(sigma=(2, 0.1), angle=45.0)
    physics = dinv.physics.Blur(kernel, device=device)
    measurement = physics(ground_truth)

    data_fidelity = L2()
    prior = RED(denoiser=dinv.models.WaveletPrior(wv="db8", level=3, device=device))

    # Note: results are better for sigma_denoiser=0.001, but it takes longer to run.
    params = {"stepsize": 1.0, "g_param": 1.0, "lambda": 1.0}

    red = optim_builder(
        red_algo,
        prior=prior,
        data_fidelity=data_fidelity,
        max_iter=1000,
        thres_conv=1e-4,
        verbose=True,
        params_algo=params,
        early_stop=True,
        g_first=True,
    )
    red(measurement, physics)

    assert red.has_converged
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def test_CP_K(imsize, dummy_dataset, device):
    r"""
    Check that the Chambolle-Pock (CP) algorithm converges to a minimiser of

    .. math::

        \min_x \lambda a(x) + b(Kx)

    for a non-symmetric linear operator :math:`K`, with the identity as measurement
    operator :math:`A`. Both orderings of the splitting are tested. :math:`f` is the
    :class:`L2` data fidelity and :math:`g` the :math:`\ell_1` prior. The optimality
    conditions checked at the end are:

    * ``g_first=True``: :math:`0 \in \lambda A^\top \nabla f(Ax, y) + K^\top \partial g(Kx)`,
      i.e. the prior is composed with :math:`K`;
    * ``g_first=False``: :math:`0 \in \lambda K^\top \nabla f(Kx, y) + \partial g(x)`,
      i.e. the data fidelity is composed with :math:`K`.
    """
    for g_first in [True, False]:
        x = torch.tensor([[[10], [10]]], dtype=torch.float64).to(device)

        # Identity measurement operator A
        Id_forward = lambda v: v
        Id_adjoint = lambda v: v
        physics = dinv.physics.LinearPhysics(A=Id_forward, A_adjoint=Id_adjoint)
        y = physics(x)

        data_fidelity = L2()  # The data fidelity term

        def prior_g(x, *args):
            ths = 1.0
            return ths * torch.norm(x.view(x.shape[0], -1), p=1, dim=-1)

        prior = Prior(g=prior_g)  # The prior term (l1 norm)

        # Non-symmetric linear operator K of the splitting
        K = torch.tensor([[2, 1], [-1, 0.5]], dtype=torch.float64).to(device)
        K_forward = lambda v: K @ v
        K_adjoint = lambda v: K.transpose(0, 1) @ v

        # stepsize * stepsize_dual * ||K||_2^2 = 0.9 < 1 (standard CP step-size condition)
        stepsize = 0.9 / torch.linalg.norm(K, ord=2).item() ** 2
        reg_param = 1.0
        stepsize_dual = 1.0

        lamb = 1.5
        max_iter = 1000

        params_algo = {
            "stepsize": stepsize,
            "g_param": reg_param,
            "lambda": lamb,
            "stepsize_dual": stepsize_dual,
            "K": K_forward,
            "K_adjoint": K_adjoint,
        }

        optimalgo = optim_builder(
            "CP",
            prior=prior,
            data_fidelity=data_fidelity,
            max_iter=max_iter,
            crit_conv="residual",
            thres_conv=1e-11,
            verbose=True,
            params_algo=params_algo,
            early_stop=True,
            g_first=g_first,
            custom_init=custom_init_CP,
        )

        # Run the optimization algorithm
        x = optimalgo(y, physics)

        # Identify the failing splitting in the failure message rather than printing it
        assert optimalgo.has_converged, f"CP did not converge (g_first={g_first})"

        # Optimality condition at the limit point of the algorithm
        if not g_first:
            subdiff = prior.grad(x, 0)
            # Only valid for differentiable data fidelity terms
            grad_deepinv = K_adjoint(data_fidelity.grad(K_forward(x), y, physics))
            assert torch.allclose(lamb * grad_deepinv, -subdiff, atol=1e-12)
        else:
            subdiff = K_adjoint(prior.grad(K_forward(x), 0))
            grad_deepinv = data_fidelity.grad(x, y, physics)
            assert torch.allclose(lamb * grad_deepinv, -subdiff, atol=1e-12)
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def test_CP_datafidsplit(imsize, dummy_dataset, device):
    r"""
    Check that the Chambolle-Pock (CP) algorithm converges to a minimiser of

    .. math::

        \min_x \lambda d(Ax,y) + g(x)

    where :math:`d` is a distance function and :math:`g` is a prior term. The splitting
    operator is :math:`K = A`, so the data fidelity is evaluated directly on :math:`Ax`
    (through ``grad_d``) in the optimality check.
    """
    dtype = torch.float64
    x = torch.tensor([[[10], [10]]], dtype=dtype).to(device)

    # Non-symmetric measurement operator, also used as the CP splitting operator K
    A = torch.tensor([[2, 1], [-1, 0.5]], dtype=dtype).to(device)

    def A_forward(v):
        return A @ v

    def A_adjoint(v):
        return A.transpose(0, 1) @ v

    physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)
    y = physics(x)

    data_fidelity = L2()

    def prior_g(x, *args):
        weight = 1.0
        return weight * torch.norm(x.view(x.shape[0], -1), p=1, dim=-1)

    prior = Prior(g=prior_g)

    lamb = 1.5
    params_algo = {
        # 0.9 / ||A||_2^2, so that stepsize * stepsize_dual * ||A||_2^2 = 0.9
        "stepsize": 0.9 / torch.linalg.norm(A, ord=2).item() ** 2,
        "g_param": 1.0,
        "lambda": lamb,
        "stepsize_dual": 1.0,
        "K": A_forward,
        "K_adjoint": A_adjoint,
    }

    optimalgo = optim_builder(
        "CP",
        prior=prior,
        data_fidelity=data_fidelity,
        max_iter=1000,
        crit_conv="residual",
        thres_conv=1e-11,
        verbose=True,
        params_algo=params_algo,
        early_stop=True,
        g_first=False,
        custom_init=custom_init_CP,
    )

    x = optimalgo(y, physics)
    assert optimalgo.has_converged

    # Optimality condition: 0 = lambda * A^T grad_d(Ax, y) + dg(x)
    subdiff = prior.grad(x, 0)
    # Only valid for differentiable data fidelity terms
    grad_data = A_adjoint(data_fidelity.grad_d(A_forward(x), y))
    assert torch.allclose(lamb * grad_data, -subdiff, atol=1e-12)
|