deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
import deepinv as dinv
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Linear forward operators to test (make sure they appear in find_operator as well)
OPERATORS = [
    "CS", "fastCS",
    "inpainting", "denoising",
    "deblur_fft", "deblur",
    "singlepixel", "fast_singlepixel",
    "super_resolution", "MRI", "pansharpen",
]

# Non-linear forward operators (built by find_nonlinear_operator)
NONLINEAR_OPERATORS = [
    "haze",
    "blind_deblur",
    "lidar",
]

# Noise model names (instantiated by choose_noise)
NOISES = [
    "Gaussian", "Poisson", "PoissonGaussian",
    "UniformGaussian", "Uniform", "Neighbor2Neighbor",
]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def find_operator(name, device):
    r"""
    Build one of the linear test operators by name.

    :param str name: operator name; every entry of ``OPERATORS`` must be handled here
        (``"Tomography"`` is accepted as well).
    :param torch.device device: device on which the operator is created (it is not
        forwarded for ``"denoising"``).
    :return: tuple ``(physics, img_size, norm)``: the ``deepinv.physics`` operator, the
        (C, H, W) shape of the images it acts on, and the reference operator norm
        checked by ``test_operators_norm``.
    :raises Exception: if ``name`` is not a known operator.
    """

    def marcenko_pastur_norm(n_meas, shape, correction):
        # Marcenko-Pastur law, second term is a small n correction
        return (1 + np.sqrt(np.prod(shape) / n_meas)) ** 2 - correction

    shape = (3, 16, 8)  # default image shape; some operators override it below
    ref_norm = 1  # default reference norm; some operators override it below

    if name == "CS":
        n_meas = 30
        physics = dinv.physics.CompressedSensing(
            m=n_meas, img_shape=shape, device=device
        )
        ref_norm = marcenko_pastur_norm(n_meas, shape, 0.75)
    elif name == "fastCS":
        physics = dinv.physics.CompressedSensing(
            m=20, fast=True, channelwise=True, img_shape=shape, device=device
        )
    elif name == "inpainting":
        physics = dinv.physics.Inpainting(tensor_size=shape, mask=0.5, device=device)
    elif name == "MRI":
        shape = (2, 16, 8)
        full_mask = torch.ones(shape[-2], shape[-1])
        physics = dinv.physics.MRI(mask=full_mask, device=device)
    elif name == "Tomography":
        shape = (1, 16, 16)
        width = shape[-1]
        physics = dinv.physics.Tomography(img_width=width, angles=width, device=device)
    elif name == "denoising":
        physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(0.1))
    elif name == "pansharpen":
        shape = (3, 30, 32)
        physics = dinv.physics.Pansharpen(img_size=shape, device=device)
        ref_norm = 0.4
    elif name == "fast_singlepixel":
        physics = dinv.physics.SinglePixelCamera(
            m=20, fast=True, img_shape=shape, device=device
        )
    elif name == "singlepixel":
        n_meas = 20
        physics = dinv.physics.SinglePixelCamera(
            m=n_meas, fast=False, img_shape=shape, device=device
        )
        ref_norm = marcenko_pastur_norm(n_meas, shape, 3.7)
    elif name == "deblur":
        shape = (3, 17, 19)
        kernel = dinv.physics.blur.gaussian_blur(sigma=(2, 0.1), angle=45.0)
        physics = dinv.physics.Blur(kernel, device=device)
    elif name == "deblur_fft":
        shape = (3, 17, 19)
        kernel = dinv.physics.blur.gaussian_blur(sigma=(0.1, 0.5), angle=45.0)
        physics = dinv.physics.BlurFFT(img_size=shape, filter=kernel, device=device)
    elif name == "super_resolution":
        shape = (1, 32, 32)
        factor = 2
        ref_norm = 1 / factor**2
        physics = dinv.physics.Downsampling(img_size=shape, factor=factor, device=device)
    else:
        raise Exception("The inverse problem chosen doesn't exist")

    return physics, shape, ref_norm
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def find_nonlinear_operator(name, device):
    r"""
    Build a non-linear test operator together with a random input it accepts.

    :param str name: operator name (``"blind_deblur"``, ``"haze"`` or ``"lidar"``).
    :param torch.device device: device on which the random input (and, for
        ``"lidar"``, the operator) is created.
    :return: tuple ``(physics, x)`` with the operator and the matching input.
    :raises Exception: if ``name`` is not a known operator.
    """
    if name == "blind_deblur":
        # image of shape (1, 3, 16, 16) paired with a 3x3 kernel
        x = dinv.utils.TensorList(
            [
                torch.randn(1, 3, 16, 16, device=device),
                torch.randn(1, 1, 3, 3, device=device),
            ]
        )
        return dinv.physics.BlindBlur(kernel_size=3), x

    if name == "haze":
        x = dinv.utils.TensorList(
            [
                torch.randn(1, 1, 16, 16, device=device),
                torch.randn(1, 1, 16, 16, device=device),
                torch.randn(1, device=device),
            ]
        )
        return dinv.physics.Haze(), x

    if name == "lidar":
        # uniform (non-negative) input, unlike the Gaussian inputs above
        x = torch.rand(1, 3, 16, 16, device=device)
        return dinv.physics.SinglePhotonLidar(device=device), x

    raise Exception("The inverse problem chosen doesn't exist")
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@pytest.mark.parametrize("name", OPERATORS)
def test_operators_adjointness(name, device):
    r"""
    Tests if a linear forward operator has a well defined adjoint.
    Warning: Only test linear operators, non-linear ones will fail the test.

    The image shape is taken from ``find_operator``; a random batch of one image is
    passed to ``physics.adjointness_test``.

    :param name: operator name (see find_operator)
    :param device: (torch.device) cpu or cuda:x
    :return: asserts that the absolute value returned by ``adjointness_test`` is below 1e-3
    """
    physics, imsize, _ = find_operator(name, device)
    x = torch.randn(imsize, device=device).unsqueeze(0)  # add a batch dimension
    error = physics.adjointness_test(x).abs()
    assert error < 1e-3
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@pytest.mark.parametrize("name", OPERATORS)
def test_operators_norm(name, device):
    r"""
    Tests if the norm of a linear physics operator matches its reference value.
    Warning: Only test linear operators, non-linear ones will fail the test.

    The norm estimated by ``physics.compute_norm`` is compared with the reference value
    returned by ``find_operator``. That value is 1 by default, a Marcenko-Pastur based
    estimate for ``"CS"`` and ``"singlepixel"``, or an operator-specific constant.

    :param name: operator name (see find_operator)
    :param device: (torch.device) cpu or cuda:x
    :return: asserts that the estimated norm is within 0.2 of the reference value
    """
    if name == "singlepixel" or name == "CS":
        # NOTE(review): these operators are always built on CPU, presumably so that the
        # seeded random draws (and hence the reference norm) are reproducible across
        # devices — TODO confirm.
        device = torch.device("cpu")

    torch.manual_seed(0)
    physics, imsize, norm_ref = find_operator(name, device)
    x = torch.randn(imsize, device=device).unsqueeze(0)
    norm = physics.compute_norm(x)
    assert torch.abs(norm - norm_ref) < 0.2
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@pytest.mark.parametrize("name", NONLINEAR_OPERATORS)
def test_nonlinear_operators(name, device):
    r"""
    Tests that a non-linear physics operator can be applied and (pseudo-)inverted.

    The operator is applied to the input returned by ``find_nonlinear_operator``, and
    ``A_dagger`` is applied to the resulting measurement. Only shapes are checked, not
    reconstruction quality.

    :param name: operator name (see find_nonlinear_operator)
    :param device: (torch.device) cpu or cuda:x
    :return: asserts that ``A_dagger(physics(x))`` has the same shape as ``x``
    """
    physics, x = find_nonlinear_operator(name, device)
    y = physics(x)
    xhat = physics.A_dagger(y)
    assert x.shape == xhat.shape
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
@pytest.mark.parametrize("name", OPERATORS)
def test_pseudo_inverse(name, device):
    r"""
    Tests if a linear physics operator has a well defined pseudoinverse.
    Warning: Only test linear operators, non-linear ones will fail the test.

    ``r = A^T A x`` lies in the range of ``A^T``, so ``A_dagger(A r)`` should give
    back ``r``.

    :param name: operator name (see find_operator)
    :param device: (torch.device) cpu or cuda:x
    :return: asserts that the absolute value of the mean of ``A_dagger(A r) - r`` is below 0.01
    """
    physics, imsize, _ = find_operator(name, device)
    x = torch.randn(imsize, device=device).unsqueeze(0)

    r = physics.A_adjoint(physics.A(x))
    y = physics.A(r)
    # NOTE(review): this is the absolute value of a *signed* mean, so positive and
    # negative errors can cancel. ``.abs().mean()`` would be a stricter check, but the
    # 0.01 threshold may need retuning for some operators — TODO confirm before tightening.
    error = (physics.A_dagger(y) - r).flatten().mean().abs()
    assert error < 0.01
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def test_MRI(device):
    r"""
    Tests the MRI operator built without an explicit mask (``mask=None``).

    First checks that ``A_adjoint(A(x))`` keeps the shape of a (1, 2, 320, 320) input.
    Then builds a seeded operator (acceleration factor 8) and checks that calling
    ``physics.reset()`` changes the measurements of the same ``x``. The presumed cause
    is that a new sampling mask is drawn — TODO confirm.

    :param device: (torch.device) cpu or cuda:x
    :return: asserts shape preservation and, when both measurements have the same
        shape, that the mean difference of their magnitudes is non-zero
    """
    physics = dinv.physics.MRI(mask=None, device=device, acceleration_factor=4)
    x = torch.randn((2, 320, 320), device=device).unsqueeze(0)
    x2 = physics.A_adjoint(physics.A(x))
    assert x2.shape == x.shape

    physics = dinv.physics.MRI(mask=None, device=device, acceleration_factor=8, seed=0)
    y1 = physics.A(x)
    physics.reset()
    y2 = physics.A(x)
    # NOTE(review): when the shapes differ nothing is asserted and the test passes
    # silently — TODO confirm whether a shape mismatch should fail instead.
    if y1.shape == y2.shape:
        error = (y1.abs() - y2.abs()).flatten().mean().abs()
        assert error > 0.0
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def choose_noise(noise_type):
    r"""
    Instantiate the noise model used by ``test_noise`` for a given name.

    The models are built from ``gain = 0.1`` and ``sigma = 0.1``.

    :param str noise_type: one of the names in ``NOISES``.
    :return: the corresponding ``deepinv.physics`` noise model.
    :raises Exception: if ``noise_type`` is not a known name.
    """
    gain = 0.1
    sigma = 0.1
    # Ordered (name, factory) pairs; only the factory of the matching name is called.
    factories = (
        (
            "PoissonGaussian",
            lambda: dinv.physics.PoissonGaussianNoise(sigma=sigma, gain=gain),
        ),
        ("Gaussian", lambda: dinv.physics.GaussianNoise(sigma)),
        # This is equivalent to GaussianNoise when sigma is fixed
        ("UniformGaussian", lambda: dinv.physics.UniformGaussianNoise(sigma=sigma)),
        ("Uniform", lambda: dinv.physics.UniformNoise(a=gain)),
        ("Poisson", lambda: dinv.physics.PoissonNoise(gain)),
        # "Neighbor2Neighbor" is exercised with the same Poisson model as "Poisson"
        ("Neighbor2Neighbor", lambda: dinv.physics.PoissonNoise(gain)),
    )
    for candidate, make in factories:
        if noise_type == candidate:
            return make()
    raise Exception("Noise model not found")
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
@pytest.mark.parametrize("noise_type", NOISES)
def test_noise(device, noise_type):
    r"""
    Tests noise models.

    Each noise model is attached to a bare ``DecomposablePhysics`` and applied to an
    all-ones image. The measurement must keep the image shape. For ``"UniformGaussian"``,
    a second measurement taken after ``physics.reset()`` must also differ from the first.
    """
    physics = dinv.physics.DecomposablePhysics()
    physics.noise_model = choose_noise(noise_type)
    ones = torch.ones((1, 12, 7), device=device).unsqueeze(0)

    # Note: this works but not physics.A(x) because only the noise is reset
    # (A does not encapsulate noise)
    first = physics(ones)
    assert first.shape == ones.shape

    if noise_type != "UniformGaussian":
        return

    physics.reset()
    second = physics(ones)
    assert (first - second).flatten().abs().sum() > 0.0
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def test_reset_noise(device):
    r"""
    Tests that the reset function works.

    A ``UniformGaussianNoise`` model built with ``sigma=None`` is attached to a bare
    ``DecomposablePhysics``. Two measurements of the same all-ones image, separated by
    ``physics.reset()``, must differ.

    :param device: (torch.device) cpu or cuda:x
    :return: asserts that the summed absolute difference of the two measurements is > 0
    """
    physics = dinv.physics.DecomposablePhysics()
    # NOTE(review): original comment said "Should be 20/255 (to check)" for this model.
    physics.noise_model = dinv.physics.UniformGaussianNoise(sigma=None)
    ones = torch.ones((1, 12, 7), device=device).unsqueeze(0)

    # physics(...) rather than physics.A(...): only the noise is reset
    # (A does not encapsulate noise)
    before = physics(ones)
    physics.reset()
    after = physics(ones)

    assert (before - after).flatten().abs().sum() > 0.0
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def test_tomography(device):
    r"""
    Tests tomography operator which does not have a numerically precise adjoint.

    For both ``circle=True`` and ``circle=False``, ``r = A^T A x`` is computed and
    ``A_dagger(A r)`` must recover ``r`` up to a loose tolerance (absolute value of the
    signed mean error below 0.2).

    :param device: (torch.device) cpu or cuda:x
    """
    imsize = (1, 16, 16)
    width = imsize[-1]
    for circle in (True, False):
        physics = dinv.physics.Tomography(
            img_width=width, angles=width, device=device, circle=circle
        )

        x = torch.randn(imsize, device=device).unsqueeze(0)
        r = physics.A_adjoint(physics.A(x))
        residual = physics.A_dagger(physics.A(r)) - r
        assert residual.flatten().mean().abs() < 0.2
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import torch.nn
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
import deepinv as dinv
|
|
6
|
+
from deepinv.optim.data_fidelity import L2
|
|
7
|
+
from deepinv.sampling import ULA, SKRock, DiffPIR, DPS
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Sampler names exercised by test_sampling_algo (built by choose_algo)
SAMPLING_ALGOS = [
    "DDRM",
    "ULA",
    "SKRock",
]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def choose_algo(algo, likelihood, thresh_conv, sigma, sigma_prior):
    r"""
    Build the sampler named ``algo`` for the Gaussian-prior / Gaussian-noise toy problem.

    :param str algo: ``"ULA"``, ``"SKRock"`` or ``"DDRM"``.
    :param likelihood: data-fidelity term given to the Langevin samplers (unused by DDRM).
    :param float thresh_conv: convergence threshold given to the Langevin samplers
        (unused by DDRM).
    :param float sigma: measurement noise level, used to scale the Langevin step size.
    :param float sigma_prior: standard deviation of the Gaussian prior.
    :return: the configured sampler.
    :raises Exception: if ``algo`` is not recognised.
    """

    def langevin_kwargs(scale):
        # Settings shared by ULA and SKRock; the step size is ``scale`` divided by
        # 1 / sigma**2 + 1 / sigma_prior**2.
        return dict(
            max_iter=500,
            step_size=scale / (1 / sigma**2 + 1 / sigma_prior**2),
            clip=(-100, 100),
            thresh_conv=thresh_conv,
            sigma=1,
            verbose=True,
        )

    if algo == "ULA":
        return ULA(
            GaussianScore(sigma_prior),
            likelihood,
            thinning=1,
            **langevin_kwargs(0.01),
        )

    if algo == "SKRock":
        return SKRock(
            GaussianScore(sigma_prior),
            likelihood,
            inner_iter=5,
            **langevin_kwargs(1),
        )

    if algo == "DDRM":
        ddrm = dinv.sampling.DDRM(
            denoiser=GaussianDenoiser(sigma_prior),
            eta=1,
            sigmas=np.linspace(1, 0, 100),
        )
        return dinv.sampling.DiffusionSampler(ddrm, clip=(-100, 100), max_iter=500)

    raise Exception("The sampling algorithm doesnt exist")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class GaussianScore(torch.nn.Module):
    r"""
    Closed-form prior term for a zero-mean Gaussian prior of std ``sigma_prior``.

    ``forward(x, sigma)`` returns ``x / sigma_prior**2``, the gradient of
    ``||x||^2 / (2 sigma_prior^2)``; the ``sigma`` argument is ignored.
    """

    def __init__(self, sigma_prior):
        super().__init__()
        variance = sigma_prior**2
        self.sigma_prior2 = variance  # prior variance

    def forward(self, x, sigma):
        prior_variance = self.sigma_prior2
        return x / prior_variance
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class GaussianDenoiser(torch.nn.Module):
    r"""
    Closed-form denoiser for a zero-mean Gaussian prior of std ``sigma_prior``.

    ``forward(x, sigma)`` returns ``x / (1 + sigma^2 / sigma_prior^2)``. This is the
    posterior mean of a signal with this prior observed in Gaussian noise of std ``sigma``.
    """

    def __init__(self, sigma_prior):
        super().__init__()
        self.sigma_prior2 = sigma_prior**2  # prior variance

    def forward(self, x, sigma):
        shrinkage = 1 + sigma**2 / self.sigma_prior2
        return x / shrinkage
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pytest.mark.parametrize("algo", SAMPLING_ALGOS)
def test_sampling_algo(algo, imsize, device):
    r"""
    Checks a sampler on Gaussian denoising with a Gaussian prior, whose posterior is known.

    An all-ones image is corrupted with Gaussian noise of std ``sigma``. The sampler's
    mean and variance estimates are compared entry-wise with the closed-form posterior:
    mean ``y / (1 + sigma^2 / sigma_prior^2)`` and variance
    ``sigma^2 sigma_prior^2 / (sigma^2 + sigma_prior^2)``. The test passes when the
    sampler reports convergence of both estimates and more than half of the entries of
    each have a relative error below ``tol``.
    """
    # NOTE(review): ``device`` is not used; tensors are created on torch's default device.
    test_sample = torch.ones((1, *imsize))

    sigma = 1
    sigma_prior = 1
    physics = dinv.physics.Denoising()
    physics.noise_model = dinv.physics.GaussianNoise(sigma)
    y = physics(test_sample)

    convergence_crit = 0.1  # for fast tests
    likelihood = L2(sigma=sigma)
    f = choose_algo(
        algo,
        likelihood,
        thresh_conv=convergence_crit,
        sigma=sigma,
        sigma_prior=sigma_prior,
    )

    xmean, xvar = f(y, physics, seed=0)

    tol = 5  # can be lowered?
    sigma2 = sigma**2
    sigma_prior2 = sigma_prior**2

    # the posterior of a gaussian likelihood with a gaussian prior is gaussian
    post_var = (sigma2 * sigma_prior2) / (sigma2 + sigma_prior2)
    post_mean = y / (1 + sigma2 / sigma_prior2)

    # Relative error w.r.t. |post_mean|: with a signed denominator, every entry where
    # post_mean < 0 yields a negative ratio and passes the ``< tol`` check vacuously.
    mean_ok = (
        torch.sum((xmean - post_mean).abs() / post_mean.abs() < tol)
        > np.prod(xmean.shape) / 2
    )

    # post_var is a positive scalar, so no sign issue here.
    var_ok = (
        torch.sum((xvar - post_var).abs() / post_var < tol) > np.prod(xvar.shape) / 2
    )

    assert f.mean_has_converged() and f.var_has_converged() and mean_ok and var_ok
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def test_diffpir(device):
    r"""
    Smoke test for DiffPIR.

    Runs 5 iterations with a ``DiffUNet`` denoiser on a box-blurred, noisy all-ones
    (1, 3, 32, 32) image and checks that the output has the input's shape.
    """
    from deepinv.models import DiffUNet

    x = torch.ones((1, 3, 32, 32)).to(device)
    height, width = x.shape[-2], x.shape[-1]
    noise_level = 12.75 / 255.0

    box_kernel = torch.ones((1, 1, 5, 5), device=device) / 25  # 5x5 uniform blur
    physics = dinv.physics.BlurFFT(
        img_size=(3, height, width),
        filter=box_kernel,
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=noise_level),
    )
    y = physics(x)

    denoiser = DiffUNet().to(device)
    algorithm = DiffPIR(denoiser, L2(), max_iter=5, verbose=False, device=device)

    assert algorithm(y, physics).shape == x.shape
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def test_dps(device):
    r"""
    Smoke test for DPS.

    Runs 5 iterations with a ``DiffUNet`` denoiser on a box-blurred, noisy all-ones
    (1, 3, 32, 32) image and checks that the output has the input's shape.
    """
    from deepinv.models import DiffUNet

    x = torch.ones((1, 3, 32, 32)).to(device)
    height, width = x.shape[-2], x.shape[-1]
    noise_level = 12.75 / 255.0

    box_kernel = torch.ones((1, 1, 5, 5), device=device) / 25  # 5x5 uniform blur
    physics = dinv.physics.BlurFFT(
        img_size=(3, height, width),
        filter=box_kernel,
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=noise_level),
    )
    y = physics(x)

    denoiser = DiffUNet().to(device)
    algorithm = DPS(denoiser, L2(), max_iter=5, verbose=False, device=device)

    assert algorithm(y, physics).shape == x.shape
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
import deepinv as dinv
|
|
5
|
+
from deepinv.optim.prior import PnP
|
|
6
|
+
from deepinv.optim.data_fidelity import L2
|
|
7
|
+
from deepinv.unfolded import unfolded_builder, DEQ_builder
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Optimisation algorithms unrolled by test_unfolded and test_DEQ
OPTIM_ALGO = [
    "PGD",
    "HQS",
]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.mark.parametrize("unfolded_algo", OPTIM_ALGO)
def test_unfolded(unfolded_algo, imsize, dummy_dataset, device):
    r"""
    Builds an unfolded ``unfolded_algo`` model with a wavelet prior and checks its parameters.

    The model must expose at least one parameter, and every parameter must require
    gradients and belong to one of the requested trainable groups (``g_param`` or
    ``stepsize``). Skipped when ``pytorch_wavelets`` is not installed.
    """
    pytest.importorskip("pytorch_wavelets")

    # Select the data fidelity term
    data_fidelity = L2()

    # Set up the trainable denoising prior; here, the soft-threshold in a wavelet basis.
    # If the prior is initialized with a list of length max_iter,
    # then a distinct weight is trained for each PGD iteration.
    # For fixed trained model prior across iterations, initialize with a single model.
    max_iter = 30 if torch.cuda.is_available() else 20  # Number of unrolled iterations
    level = 3
    prior = [
        PnP(denoiser=dinv.models.WaveletPrior(wv="db8", level=level, device=device))
        for _ in range(max_iter)
    ]

    # Unrolled optimization algorithm parameters
    # initialization of the regularization parameter. A distinct lamb is trained for each iteration.
    lamb = [1.0] * max_iter
    # initialization of the stepsizes. A distinct stepsize is trained for each iteration.
    stepsize = [1.0] * max_iter

    sigma_denoiser_init = 0.01
    # One tensor per iteration: ``[t] * max_iter`` would repeat the *same* tensor object,
    # so all iterations would alias a single storage.
    sigma_denoiser = [
        sigma_denoiser_init * torch.ones(level, 3) for _ in range(max_iter)
    ]
    params_algo = {  # wrap all the restoration parameters in a 'params_algo' dictionary
        "stepsize": stepsize,
        "g_param": sigma_denoiser,
        "lambda": lamb,
    }

    # define which parameters from 'params_algo' are trainable
    trainable_params = ["g_param", "stepsize"]

    # Define the unfolded trainable model.
    model = unfolded_builder(
        unfolded_algo,
        params_algo=params_algo,
        trainable_params=trainable_params,
        data_fidelity=data_fidelity,
        max_iter=max_iter,
        prior=prior,
    )

    named_params = list(model.named_parameters())
    # Without this check, a model with no parameters would pass the loop below vacuously.
    assert named_params, "the unfolded model has no parameters"
    for name, param in named_params:
        assert param.requires_grad
        assert any(key in name for key in trainable_params)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@pytest.mark.parametrize("unfolded_algo", OPTIM_ALGO)
def test_DEQ(unfolded_algo, imsize, dummy_dataset, device):
    r"""
    Builds a deep-equilibrium ``unfolded_algo`` model with a wavelet prior, with and without
    Anderson acceleration, and checks it end to end.

    For each variant: the model must expose parameters that all require gradients and
    belong to ``g_param`` / ``stepsize``. A forward pass on a blurred, noisy batch must
    keep the batch shape, and a supervised MSE loss must back-propagate. Skipped when
    ``pytorch_wavelets`` is not installed.
    """
    pytest.importorskip("pytorch_wavelets")
    torch.set_grad_enabled(
        True
    )  # Disabled somewhere in previous test files, necessary for this test to pass

    # Select the data fidelity term
    data_fidelity = L2()

    # Set up the trainable denoising prior; here, the soft-threshold in a wavelet basis.
    # If the prior is initialized with a list of length max_iter,
    # then a distinct weight is trained for each PGD iteration.
    # For fixed trained model prior across iterations, initialize with a single model.
    max_iter = 30 if torch.cuda.is_available() else 20  # Number of unrolled iterations
    level = 3
    prior = [
        PnP(denoiser=dinv.models.WaveletPrior(wv="db8", level=level, device=device))
        for _ in range(max_iter)
    ]

    # Unrolled optimization algorithm parameters
    # initialization of the regularization parameter. A distinct lamb is trained for each iteration.
    lamb = [1.0] * max_iter
    # initialization of the stepsizes. A distinct stepsize is trained for each iteration.
    stepsize = [1.0] * max_iter

    sigma_denoiser_init = 0.01
    # One tensor per iteration: ``[t] * max_iter`` would repeat the *same* tensor object,
    # so all iterations would alias a single storage.
    sigma_denoiser = [
        sigma_denoiser_init * torch.ones(level, 3) for _ in range(max_iter)
    ]
    params_algo = {  # wrap all the restoration parameters in a 'params_algo' dictionary
        "stepsize": stepsize,
        "g_param": sigma_denoiser,
        "lambda": lamb,
    }

    # define which parameters from 'params_algo' are trainable
    trainable_params = ["g_param", "stepsize"]

    # Define the unfolded trainable model.
    for and_acc in [False, True]:
        # DRS, ADMM and CP algorithms are not real fixed-point algorithms on the primal variable

        model = DEQ_builder(
            unfolded_algo,
            params_algo=params_algo,
            trainable_params=trainable_params,
            data_fidelity=data_fidelity,
            max_iter=max_iter,
            prior=prior,
            anderson_acceleration=and_acc,
            anderson_acceleration_backward=and_acc,
        )

        named_params = list(model.named_parameters())
        # Without this check, a model with no parameters would pass the loop vacuously.
        assert named_params, "the DEQ model has no parameters"
        for name, param in named_params:
            assert param.requires_grad
            assert any(key in name for key in trainable_params)

        batch_size = 5
        n_channels, img_size_w, img_size_h = imsize
        noise_level = 0.01

        torch.manual_seed(0)
        test_sample = torch.randn(batch_size, n_channels, img_size_w, img_size_h).to(
            device
        )
        groundtruth_sample = torch.randn(
            batch_size, n_channels, img_size_w, img_size_h
        ).to(device)

        physics = dinv.physics.BlurFFT(
            img_size=(n_channels, img_size_w, img_size_h),
            filter=dinv.physics.blur.gaussian_blur(),
            device=device,
            noise_model=dinv.physics.GaussianNoise(sigma=noise_level),
        )

        y = physics(test_sample).type(test_sample.dtype).to(device)

        out = model(y, physics=physics)

        assert out.shape == test_sample.shape

        loss_fn = dinv.loss.SupLoss(metric=dinv.metric.mse())
        loss = loss_fn(groundtruth_sample, out)
        loss.backward()
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import deepinv
|
|
2
|
+
import torch
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@pytest.fixture
def tensorlist():
    r"""
    Provide two ``TensorList`` objects filled with ones.

    Each list holds the same (1, 1, 2, 2) all-ones tensor twice; the two lists use
    different tensors.

    :return: tuple ``(x, y)`` of ``TensorList``.
    """

    def _ones_pair():
        ones = torch.ones((1, 1, 2, 2))
        return deepinv.utils.TensorList([ones, ones])

    return _ones_pair(), _ones_pair()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def test_tensordict_sum(tensorlist):
    r"""Adding two TensorLists of ones gives twos in both entries."""
    x, y = tensorlist
    twos = 2 * torch.ones((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([twos, twos])
    result = x + y
    assert all((expected[k] == result[k]).all() for k in (0, 1))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_tensordict_mul(tensorlist):
    r"""Multiplying two TensorLists of ones gives ones in both entries."""
    x, y = tensorlist
    ones = torch.ones((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([ones, ones])
    result = x * y
    assert all((expected[k] == result[k]).all() for k in (0, 1))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def test_tensordict_div(tensorlist):
    r"""Dividing two TensorLists of ones gives ones in both entries."""
    x, y = tensorlist
    ones = torch.ones((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([ones, ones])
    result = x / y
    assert all((expected[k] == result[k]).all() for k in (0, 1))
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def test_tensordict_sub(tensorlist):
    r"""Subtracting two TensorLists of ones gives zeros in both entries."""
    x, y = tensorlist
    zeros = torch.zeros((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([zeros, zeros])
    result = x - y
    assert all((expected[k] == result[k]).all() for k in (0, 1))
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def test_tensordict_neg(tensorlist):
    r"""Negating a TensorList of ones gives minus ones in both entries."""
    x, _ = tensorlist
    minus_ones = -torch.ones((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([minus_ones, minus_ones])
    result = -x
    assert all((expected[k] == result[k]).all() for k in (0, 1))
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def test_tensordict_append(tensorlist):
    r"""Appending one two-entry TensorList of ones to another matches a four-entry list of ones at both ends."""
    x, y = tensorlist
    ones = torch.ones((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([ones] * 4)
    result = x.append(y)
    assert all((expected[k] == result[k]).all() for k in (0, -1))
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_plot():
    r"""
    Smoke test for ``deepinv.utils.plot``.

    Covers three call forms: a list of images with a list of titles, a single image
    with a single title, and a list of images without titles.
    """
    img = torch.ones((1, 1, 2, 2))
    pair = [img, img]
    calls = (
        (pair, {"titles": ["a", "b"]}),
        (img, {"titles": "a"}),
        (pair, {}),
    )
    for images, kwargs in calls:
        deepinv.utils.plot(images, **kwargs)
|