deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,316 @@
1
+ import pytest
2
+ import torch
3
+ import numpy as np
4
+
5
+ import deepinv as dinv
6
+
7
+
8
# Linear forward operators to test (make sure they appear in find_operator as well)
OPERATORS = [
    "CS",
    "fastCS",
    "inpainting",
    "denoising",
    "deblur_fft",
    "deblur",
    "singlepixel",
    "fast_singlepixel",
    "super_resolution",
    "MRI",
    "pansharpen",
]
# Non-linear operators (see find_nonlinear_operator); only shape checks apply to these.
NONLINEAR_OPERATORS = ["haze", "blind_deblur", "lidar"]

# Noise models exercised by test_noise (built in choose_noise).
NOISES = [
    "Gaussian",
    "Poisson",
    "PoissonGaussian",
    "UniformGaussian",
    "Uniform",
    "Neighbor2Neighbor",
]
32
+
33
+
34
def find_operator(name, device):
    r"""
    Builds the linear forward operator named ``name``.

    :param name: operator name (should appear in ``OPERATORS``; note the
        "Tomography" branch below is currently NOT listed in ``OPERATORS``,
        so it is never exercised by the parametrized tests).
    :param device: (torch.device) cpu or cuda
    :return: tuple ``(p, img_size, norm)`` with the forward operator
        (deepinv.physics.Physics), the expected input shape (C, H, W), and
        the reference operator norm used by test_operators_norm.
    :raises ValueError: if ``name`` is not a known operator.
    """
    img_size = (3, 16, 8)
    norm = 1
    if name == "CS":
        m = 30
        p = dinv.physics.CompressedSensing(m=m, img_shape=img_size, device=device)
        norm = (
            1 + np.sqrt(np.prod(img_size) / m)
        ) ** 2 - 0.75  # Marcenko-Pastur law, second term is a small n correction
    elif name == "fastCS":
        p = dinv.physics.CompressedSensing(
            m=20, fast=True, channelwise=True, img_shape=img_size, device=device
        )
    elif name == "inpainting":
        p = dinv.physics.Inpainting(tensor_size=img_size, mask=0.5, device=device)
    elif name == "MRI":
        img_size = (2, 16, 8)  # 2 channels — presumably real/imag parts; confirm
        p = dinv.physics.MRI(mask=torch.ones(img_size[-2], img_size[-1]), device=device)
    elif name == "Tomography":
        img_size = (1, 16, 16)
        p = dinv.physics.Tomography(
            img_width=img_size[-1], angles=img_size[-1], device=device
        )
    elif name == "denoising":
        p = dinv.physics.Denoising(dinv.physics.GaussianNoise(0.1))
    elif name == "pansharpen":
        img_size = (3, 30, 32)
        p = dinv.physics.Pansharpen(img_size=img_size, device=device)
        norm = 0.4
    elif name == "fast_singlepixel":
        p = dinv.physics.SinglePixelCamera(
            m=20, fast=True, img_shape=img_size, device=device
        )
    elif name == "singlepixel":
        m = 20
        p = dinv.physics.SinglePixelCamera(
            m=m, fast=False, img_shape=img_size, device=device
        )
        norm = (
            1 + np.sqrt(np.prod(img_size) / m)
        ) ** 2 - 3.7  # Marcenko-Pastur law, second term is a small n correction
    elif name == "deblur":
        img_size = (3, 17, 19)
        p = dinv.physics.Blur(
            dinv.physics.blur.gaussian_blur(sigma=(2, 0.1), angle=45.0), device=device
        )
    elif name == "deblur_fft":
        img_size = (3, 17, 19)
        p = dinv.physics.BlurFFT(
            img_size=img_size,
            filter=dinv.physics.blur.gaussian_blur(sigma=(0.1, 0.5), angle=45.0),
            device=device,
        )
    elif name == "super_resolution":
        img_size = (1, 32, 32)
        factor = 2
        norm = 1 / factor**2
        p = dinv.physics.Downsampling(img_size=img_size, factor=factor, device=device)
    else:
        # ValueError is the precise exception for an invalid argument value
        # (a bare Exception forces callers into over-broad handlers).
        raise ValueError("The inverse problem chosen doesn't exist")
    return p, img_size, norm
102
+
103
+
104
def find_nonlinear_operator(name, device):
    r"""
    Builds a non-linear forward operator together with a compatible random input.

    :param name: operator name ("blind_deblur", "haze" or "lidar")
    :param device: (torch.device) cpu or cuda
    :return: tuple ``(p, x)`` with the forward operator (deepinv.physics.Physics)
        and a random input ``x`` of the structure the operator expects.
    :raises ValueError: if ``name`` is not a known operator.
    """
    if name == "blind_deblur":
        # Input is a TensorList — presumably (image, blur kernel); the second
        # element is 3x3 matching kernel_size=3. TODO confirm against BlindBlur.
        x = dinv.utils.TensorList(
            [
                torch.randn(1, 3, 16, 16, device=device),
                torch.randn(1, 1, 3, 3, device=device),
            ]
        )
        p = dinv.physics.BlindBlur(kernel_size=3)
    elif name == "haze":
        # Input is a 3-element TensorList (two images and a scalar).
        x = dinv.utils.TensorList(
            [
                torch.randn(1, 1, 16, 16, device=device),
                torch.randn(1, 1, 16, 16, device=device),
                torch.randn(1, device=device),
            ]
        )
        p = dinv.physics.Haze()
    elif name == "lidar":
        x = torch.rand(1, 3, 16, 16, device=device)
        p = dinv.physics.SinglePhotonLidar(device=device)
    else:
        # ValueError instead of a bare Exception for an invalid argument.
        raise ValueError("The inverse problem chosen doesn't exist")
    return p, x
135
+
136
+
137
@pytest.mark.parametrize("name", OPERATORS)
def test_operators_adjointness(name, device):
    r"""
    Checks that the linear operator satisfies <A x, y> = <x, A^T y>.
    Warning: Only test linear operators, non-linear ones will fail the test.

    :param name: operator name (see find_operator)
    :param device: (torch.device) cpu or cuda:x
    :return: asserts adjointness
    """
    physics, imsize, _ = find_operator(name, device)
    sample = torch.randn(imsize, device=device).unsqueeze(0)
    mismatch = physics.adjointness_test(sample).abs()
    assert mismatch < 1e-3
152
+
153
+
154
@pytest.mark.parametrize("name", OPERATORS)
def test_operators_norm(name, device):
    r"""
    Checks that ``compute_norm`` matches the reference norm from find_operator.
    Warning: Only test linear operators, non-linear ones will fail the test.

    :param name: operator name (see find_operator)
    :param device: (torch.device) cpu or cuda:x
    :return: asserts the estimated norm is within 0.2 of the reference
    """
    # Random-measurement operators are evaluated on CPU.
    if name in ("singlepixel", "CS"):
        device = torch.device("cpu")

    torch.manual_seed(0)
    physics, imsize, norm_ref = find_operator(name, device)
    sample = torch.randn(imsize, device=device).unsqueeze(0)
    estimated = physics.compute_norm(sample)
    assert torch.abs(estimated - norm_ref) < 0.2
173
+
174
+
175
@pytest.mark.parametrize("name", NONLINEAR_OPERATORS)
def test_nonlinear_operators(name, device):
    r"""
    Round-trip shape check for non-linear operators: applying the forward
    operator and then ``A_dagger`` must return something with the same shape
    as the input. (No accuracy is asserted — only shapes.)

    :param name: operator name (see find_nonlinear_operator)
    :param device: (torch.device) cpu or cuda:x
    :return: asserts correct shapes
    """
    physics, x = find_nonlinear_operator(name, device)
    y = physics(x)
    xhat = physics.A_dagger(y)
    assert x.shape == xhat.shape
189
+
190
+
191
@pytest.mark.parametrize("name", OPERATORS)
def test_pseudo_inverse(name, device):
    r"""
    Tests if a linear physics operator has a well defined pseudoinverse:
    for r in the range of A, A_dagger(A r) should recover r.
    Warning: Only test linear operators, non-linear ones will fail the test.

    :param name: operator name (see find_operator)
    :param device: (torch.device) cpu or cuda:x
    :return: asserts mean absolute error is less than 0.01
    """
    physics, imsize, _ = find_operator(name, device)
    x = torch.randn(imsize, device=device).unsqueeze(0)

    # Project x onto the range of A so exact recovery is possible.
    r = physics.A_adjoint(physics.A(x))
    y = physics.A(r)
    # Mean absolute error: abs() must come before mean(); the previous
    # mean().abs() let positive and negative residuals cancel, so a broken
    # pseudo-inverse could still report ~0 error.
    error = (physics.A_dagger(y) - r).flatten().abs().mean()
    assert error < 0.01
209
+
210
+
211
def test_MRI(device):
    r"""
    Tests the MRI operator with randomly generated acceleration masks:
    the adjoint round-trip preserves the input shape, and resetting the
    physics (re-drawing the mask) changes the measurements.

    :param device: (torch.device) cpu or cuda:x
    """
    # mask=None: a random mask is generated at the given acceleration factor.
    physics = dinv.physics.MRI(mask=None, device=device, acceleration_factor=4)
    x = torch.randn((2, 320, 320), device=device).unsqueeze(0)
    x2 = physics.A_adjoint(physics.A(x))
    assert x2.shape == x.shape

    physics = dinv.physics.MRI(mask=None, device=device, acceleration_factor=8, seed=0)
    y1 = physics.A(x)
    physics.reset()  # draw a new random mask
    y2 = physics.A(x)
    # Only compare when shapes match — presumably the re-drawn mask can change
    # the measurement size; TODO confirm against dinv.physics.MRI.reset.
    if y1.shape == y2.shape:
        error = (y1.abs() - y2.abs()).flatten().mean().abs()
        assert error > 0.0
232
+
233
+
234
def choose_noise(noise_type):
    """Return the deepinv noise model matching ``noise_type``.

    :param noise_type: one of the names listed in ``NOISES``.
    :return: a deepinv.physics noise model instance.
    :raises ValueError: if ``noise_type`` is not a known noise model.
    """
    gain = 0.1
    sigma = 0.1
    if noise_type == "PoissonGaussian":
        noise_model = dinv.physics.PoissonGaussianNoise(sigma=sigma, gain=gain)
    elif noise_type == "Gaussian":
        noise_model = dinv.physics.GaussianNoise(sigma)
    elif noise_type == "UniformGaussian":
        noise_model = dinv.physics.UniformGaussianNoise(
            sigma=sigma
        )  # This is equivalent to GaussianNoise when sigma is fixed
    elif noise_type == "Uniform":
        noise_model = dinv.physics.UniformNoise(a=gain)
    elif noise_type in ("Poisson", "Neighbor2Neighbor"):
        # Neighbor2Neighbor is trained on Poisson noise: the two branches
        # were duplicates, merged here.
        noise_model = dinv.physics.PoissonNoise(gain)
    else:
        # ValueError instead of a bare Exception for an invalid argument.
        raise ValueError("Noise model not found")

    return noise_model
255
+
256
+
257
@pytest.mark.parametrize("noise_type", NOISES)
def test_noise(device, noise_type):
    r"""
    Tests that each noise model yields measurements of the input's shape,
    and that UniformGaussian noise changes after a reset.
    """
    physics = dinv.physics.DecomposablePhysics()
    physics.noise_model = choose_noise(noise_type)
    clean = torch.ones((1, 12, 7), device=device).unsqueeze(0)

    # Call physics(x) rather than physics.A(x): only the full forward call
    # applies the noise model (A does not encapsulate noise).
    noisy = physics(clean)
    assert noisy.shape == clean.shape

    if noise_type == "UniformGaussian":
        # Resetting re-draws the random noise parameters, so a second
        # measurement of the same image must differ.
        physics.reset()
        renoised = physics(clean)
        assert (noisy - renoised).flatten().abs().sum() > 0.0
276
+
277
+
278
def test_reset_noise(device):
    r"""
    Tests that the reset function re-draws the random noise parameters, so
    two measurements of the same image differ.

    :param device: (torch.device) cpu or cuda:x
    :return: asserts error is > 0
    """
    physics = dinv.physics.DecomposablePhysics()
    # sigma=None draws the noise level at random (should be 20/255 — to check).
    physics.noise_model = dinv.physics.UniformGaussianNoise(sigma=None)
    clean = torch.ones((1, 12, 7), device=device).unsqueeze(0)

    # Use physics(x), not physics.A(x): only the full forward applies noise.
    first = physics(clean)
    physics.reset()
    second = physics(clean)
    assert (first - second).flatten().abs().sum() > 0.0
298
+
299
+
300
def test_tomography(device):
    r"""
    Tests tomography operator which does not have a numerically precise adjoint:
    the pseudo-inverse is only required to recover the range-projected image
    within a loose tolerance, with and without the circular field of view.

    :param device: (torch.device) cpu or cuda:x
    """
    for circle in [True, False]:
        imsize = (1, 16, 16)
        physics = dinv.physics.Tomography(
            img_width=imsize[-1], angles=imsize[-1], device=device, circle=circle
        )

        # Project onto the range of A before inverting.
        x = torch.randn(imsize, device=device).unsqueeze(0)
        r = physics.A_adjoint(physics.A(x))
        y = physics.A(r)
        # NOTE(review): mean() of the *signed* residual is taken before abs(),
        # so positive/negative errors can cancel and understate the true error.
        # Consider .abs().mean() — but that may need a looser tolerance here,
        # since A_dagger is only approximate for tomography; verify first.
        error = (physics.A_dagger(y) - r).flatten().mean().abs()
        assert error < 0.2
@@ -0,0 +1,158 @@
1
+ import pytest
2
+ import torch.nn
3
+ import numpy as np
4
+
5
+ import deepinv as dinv
6
+ from deepinv.optim.data_fidelity import L2
7
+ from deepinv.sampling import ULA, SKRock, DiffPIR, DPS
8
+
9
+
10
# Sampling algorithms exercised by test_sampling_algo (built in choose_algo).
SAMPLING_ALGOS = ["DDRM", "ULA", "SKRock"]
11
+
12
+
13
def choose_algo(algo, likelihood, thresh_conv, sigma, sigma_prior):
    """Build the sampling algorithm ``algo`` for a Gaussian prior/likelihood pair.

    :param algo: one of "ULA", "SKRock", "DDRM".
    :param likelihood: data-fidelity (negative log-likelihood) term.
    :param thresh_conv: convergence threshold passed to the MCMC samplers.
    :param sigma: noise standard deviation of the likelihood.
    :param sigma_prior: standard deviation of the Gaussian prior.
    :return: a sampler object callable as ``f(y, physics, seed=...)``.
    :raises ValueError: if ``algo`` is not a known algorithm name.
    """
    if algo == "ULA":
        out = ULA(
            GaussianScore(sigma_prior),
            likelihood,
            max_iter=500,
            thinning=1,
            # smaller step size than SKRock (0.01 factor) — ULA is less stable
            step_size=0.01 / (1 / sigma**2 + 1 / sigma_prior**2),
            clip=(-100, 100),
            thresh_conv=thresh_conv,
            sigma=1,
            verbose=True,
        )
    elif algo == "SKRock":
        out = SKRock(
            GaussianScore(sigma_prior),
            likelihood,
            max_iter=500,
            inner_iter=5,
            step_size=1 / (1 / sigma**2 + 1 / sigma_prior**2),
            clip=(-100, 100),
            thresh_conv=thresh_conv,
            sigma=1,
            verbose=True,
        )
    elif algo == "DDRM":
        diff = dinv.sampling.DDRM(
            denoiser=GaussianDenoiser(sigma_prior),
            eta=1,
            sigmas=np.linspace(1, 0, 100),
        )
        out = dinv.sampling.DiffusionSampler(diff, clip=(-100, 100), max_iter=500)
    else:
        # ValueError instead of a bare Exception; also fixes the "doesnt" typo.
        raise ValueError("The sampling algorithm doesn't exist")

    return out
49
+
50
+
51
class GaussianScore(torch.nn.Module):
    """Exact score of a zero-mean isotropic Gaussian prior.

    For p(x) proportional to exp(-||x||^2 / (2 * sigma_prior^2)), the gradient of
    the negative log-prior is x / sigma_prior^2.
    """

    def __init__(self, sigma_prior):
        super().__init__()
        # Stored squared so forward() is a single division.
        self.sigma_prior2 = sigma_prior**2

    def forward(self, x, sigma):
        # `sigma` is accepted for interface compatibility with learned score
        # networks but does not affect the exact Gaussian score.
        score = x / self.sigma_prior2
        return score
58
+
59
+
60
class GaussianDenoiser(torch.nn.Module):
    """MMSE denoiser under a zero-mean isotropic Gaussian prior.

    For a Gaussian prior with variance sigma_prior^2 and Gaussian noise of
    standard deviation sigma, the posterior mean is a simple shrinkage:
    x / (1 + sigma^2 / sigma_prior^2).
    """

    def __init__(self, sigma_prior):
        super().__init__()
        # Stored squared: forward() only ever needs the variance.
        self.sigma_prior2 = sigma_prior**2

    def forward(self, x, sigma):
        shrink = 1 + sigma**2 / self.sigma_prior2
        return x / shrink
67
+
68
+
69
@pytest.mark.parametrize("algo", SAMPLING_ALGOS)
def test_sampling_algo(algo, imsize, device):
    r"""
    Tests a sampling algorithm on Gaussian denoising, where the posterior
    (Gaussian prior + Gaussian likelihood) is Gaussian with closed-form mean
    and variance to compare the sample statistics against.
    """
    test_sample = torch.ones((1, *imsize))

    sigma = 1
    sigma_prior = 1
    physics = dinv.physics.Denoising()
    physics.noise_model = dinv.physics.GaussianNoise(sigma)
    y = physics(test_sample)

    convergence_crit = 0.1  # for fast tests
    likelihood = L2(sigma=sigma)
    f = choose_algo(
        algo,
        likelihood,
        thresh_conv=convergence_crit,
        sigma=sigma,
        sigma_prior=sigma_prior,
    )

    # Posterior mean and (per-pixel) variance estimated by sampling.
    xmean, xvar = f(y, physics, seed=0)

    tol = 5  # can be lowered?
    sigma2 = sigma**2
    sigma_prior2 = sigma_prior**2

    # the posterior of a gaussian likelihood with a gaussian prior is gaussian
    post_var = (sigma2 * sigma_prior2) / (sigma2 + sigma_prior2)
    post_mean = y / (1 + sigma2 / sigma_prior2)

    # Loose checks: require the relative error to be below tol on more than
    # half of the pixels (sampling estimates are noisy).
    mean_ok = (
        torch.sum((xmean - post_mean).abs() / post_mean < tol)
        > np.prod(xmean.shape) / 2
    )

    var_ok = (
        torch.sum((xvar - post_var).abs() / post_var < tol) > np.prod(xvar.shape) / 2
    )

    assert f.mean_has_converged() and f.var_has_converged() and mean_ok and var_ok
109
+
110
+
111
def test_diffpir(device):
    """Smoke-test DiffPIR on a small deblurring problem (shape check only)."""
    from deepinv.models import DiffUNet

    clean = torch.ones((1, 3, 32, 32)).to(device)

    noise_sigma = 12.75 / 255.0  # noise level

    # 5x5 uniform blur with additive Gaussian noise.
    blur = dinv.physics.BlurFFT(
        img_size=(3, clean.shape[-2], clean.shape[-1]),
        filter=torch.ones((1, 1, 5, 5), device=device) / 25,
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=noise_sigma),
    )

    measurement = blur(clean)

    denoiser = DiffUNet().to(device)
    data_fidelity = L2()

    # Very few iterations: this only checks the pipeline runs end to end.
    sampler = DiffPIR(denoiser, data_fidelity, max_iter=5, verbose=False, device=device)

    restored = sampler(measurement, blur)
    assert restored.shape == clean.shape
134
+
135
+
136
def test_dps(device):
    """Smoke-test DPS on a small deblurring problem (shape check only)."""
    from deepinv.models import DiffUNet

    clean = torch.ones((1, 3, 32, 32)).to(device)

    noise_sigma = 12.75 / 255.0  # noise level

    # 5x5 uniform blur with additive Gaussian noise.
    blur = dinv.physics.BlurFFT(
        img_size=(3, clean.shape[-2], clean.shape[-1]),
        filter=torch.ones((1, 1, 5, 5), device=device) / 25,
        device=device,
        noise_model=dinv.physics.GaussianNoise(sigma=noise_sigma),
    )

    measurement = blur(clean)

    denoiser = DiffUNet().to(device)
    data_fidelity = L2()

    # Very few iterations: this only checks the pipeline runs end to end.
    sampler = DPS(denoiser, data_fidelity, max_iter=5, verbose=False, device=device)

    restored = sampler(measurement, blur)
    assert restored.shape == clean.shape
@@ -0,0 +1,158 @@
1
+ import pytest
2
+ import torch
3
+
4
+ import deepinv as dinv
5
+ from deepinv.optim.prior import PnP
6
+ from deepinv.optim.data_fidelity import L2
7
+ from deepinv.unfolded import unfolded_builder, DEQ_builder
8
+
9
+
10
# Unfolded / deep-equilibrium optimization algorithms exercised by the tests below.
OPTIM_ALGO = ["PGD", "HQS"]
11
+
12
+
13
@pytest.mark.parametrize("unfolded_algo", OPTIM_ALGO)
def test_unfolded(unfolded_algo, imsize, dummy_dataset, device):
    r"""
    Builds an unfolded network for the given algorithm and checks that exactly
    the parameters declared in ``trainable_params`` are trainable.
    """
    pytest.importorskip("pytorch_wavelets")

    # Select the data fidelity term
    data_fidelity = L2()

    # Set up the trainable denoising prior; here, the soft-threshold in a wavelet basis.
    # If the prior is initialized with a list of length max_iter,
    # then a distinct weight is trained for each PGD iteration.
    # For fixed trained model prior across iterations, initialize with a single model.
    max_iter = 30 if torch.cuda.is_available() else 20  # Number of unrolled iterations
    level = 3
    prior = [
        PnP(denoiser=dinv.models.WaveletPrior(wv="db8", level=level, device=device))
        for i in range(max_iter)
    ]

    # Unrolled optimization algorithm parameters
    lamb = [
        1.0
    ] * max_iter  # initialization of the regularization parameter. A distinct lamb is trained for each iteration.
    stepsize = [
        1.0
    ] * max_iter  # initialization of the stepsizes. A distinct stepsize is trained for each iteration.

    sigma_denoiser_init = 0.01
    # One (level, 3) threshold tensor per iteration (per wavelet level/orientation).
    sigma_denoiser = [sigma_denoiser_init * torch.ones(level, 3)] * max_iter
    # sigma_denoiser = [torch.Tensor([sigma_denoiser_init])]*max_iter
    params_algo = {  # wrap all the restoration parameters in a 'params_algo' dictionary
        "stepsize": stepsize,
        "g_param": sigma_denoiser,
        "lambda": lamb,
    }

    trainable_params = [
        "g_param",
        "stepsize",
    ]  # define which parameters from 'params_algo' are trainable

    # Define the unfolded trainable model.
    model = unfolded_builder(
        unfolded_algo,
        params_algo=params_algo,
        trainable_params=trainable_params,
        data_fidelity=data_fidelity,
        max_iter=max_iter,
        prior=prior,
    )

    # Every registered parameter must be trainable and belong to one of the
    # declared trainable parameter groups.
    for idx, (name, param) in enumerate(model.named_parameters()):
        assert param.requires_grad
        assert (trainable_params[0] in name) or (trainable_params[1] in name)
66
+
67
+
68
@pytest.mark.parametrize("unfolded_algo", OPTIM_ALGO)
def test_DEQ(unfolded_algo, imsize, dummy_dataset, device):
    r"""
    Builds a deep-equilibrium (DEQ) model for the given algorithm (with and
    without Anderson acceleration), checks the trainable parameters, runs a
    forward pass on a small deblurring problem and backpropagates a loss.
    """
    pytest.importorskip("pytorch_wavelets")
    torch.set_grad_enabled(
        True
    )  # Disabled somewhere in previous test files, necessary for this test to pass

    # Select the data fidelity term
    data_fidelity = L2()

    # Set up the trainable denoising prior; here, the soft-threshold in a wavelet basis.
    # If the prior is initialized with a list of length max_iter,
    # then a distinct weight is trained for each PGD iteration.
    # For fixed trained model prior across iterations, initialize with a single model.
    max_iter = 30 if torch.cuda.is_available() else 20  # Number of unrolled iterations
    level = 3
    prior = [
        PnP(denoiser=dinv.models.WaveletPrior(wv="db8", level=level, device=device))
        for i in range(max_iter)
    ]

    # Unrolled optimization algorithm parameters
    lamb = [
        1.0
    ] * max_iter  # initialization of the regularization parameter. A distinct lamb is trained for each iteration.
    stepsize = [
        1.0
    ] * max_iter  # initialization of the stepsizes. A distinct stepsize is trained for each iteration.

    sigma_denoiser_init = 0.01
    sigma_denoiser = [sigma_denoiser_init * torch.ones(level, 3)] * max_iter
    # sigma_denoiser = [torch.Tensor([sigma_denoiser_init])]*max_iter
    params_algo = {  # wrap all the restoration parameters in a 'params_algo' dictionary
        "stepsize": stepsize,
        "g_param": sigma_denoiser,
        "lambda": lamb,
    }

    trainable_params = [
        "g_param",
        "stepsize",
    ]  # define which parameters from 'params_algo' are trainable

    # Define the unfolded trainable model.
    for and_acc in [False, True]:
        # DRS, ADMM and CP algorithms are not real fixed-point algorithms on the primal variable

        model = DEQ_builder(
            unfolded_algo,
            params_algo=params_algo,
            trainable_params=trainable_params,
            data_fidelity=data_fidelity,
            max_iter=max_iter,
            prior=prior,
            anderson_acceleration=and_acc,
            anderson_acceleration_backward=and_acc,
        )

        for idx, (name, param) in enumerate(model.named_parameters()):
            assert param.requires_grad
            assert (trainable_params[0] in name) or (trainable_params[1] in name)

        # batch_size, n_channels, img_size_w, img_size_h = 5, imsize
        batch_size = 5
        n_channels, img_size_w, img_size_h = imsize
        noise_level = 0.01

        torch.manual_seed(0)
        test_sample = torch.randn(batch_size, n_channels, img_size_w, img_size_h).to(
            device
        )
        groundtruth_sample = torch.randn(
            batch_size, n_channels, img_size_w, img_size_h
        ).to(device)

        physics = dinv.physics.BlurFFT(
            img_size=(n_channels, img_size_w, img_size_h),
            filter=dinv.physics.blur.gaussian_blur(),
            device=device,
            noise_model=dinv.physics.GaussianNoise(sigma=noise_level),
        )

        y = physics(test_sample).type(test_sample.dtype).to(device)

        out = model(y, physics=physics)

        assert out.shape == test_sample.shape

        # Backpropagate through the DEQ fixed point to check gradients flow.
        loss_fn = dinv.loss.SupLoss(metric=dinv.metric.mse())
        loss = loss_fn(groundtruth_sample, out)
        loss.backward()
@@ -0,0 +1,68 @@
1
+ import deepinv
2
+ import torch
3
+ import pytest
4
+
5
+
6
@pytest.fixture
def tensorlist():
    """Two TensorLists, each holding the same all-ones (1, 1, 2, 2) tensor twice."""
    ones_a = torch.ones((1, 1, 2, 2))
    ones_b = torch.ones((1, 1, 2, 2))
    first = deepinv.utils.TensorList([ones_a, ones_a])
    second = deepinv.utils.TensorList([ones_b, ones_b])
    return first, second
13
+
14
+
15
def test_tensordict_sum(tensorlist):
    """TensorList + TensorList adds elementwise on every member tensor."""
    x, y = tensorlist
    two = torch.ones((1, 1, 2, 2)) * 2
    expected = deepinv.utils.TensorList([two, two])
    result = x + y
    assert (expected[0] == result[0]).all() and (expected[1] == result[1]).all()
21
+
22
+
23
def test_tensordict_mul(tensorlist):
    """TensorList * TensorList multiplies elementwise on every member tensor."""
    x, y = tensorlist
    one = torch.ones((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([one, one])
    result = x * y
    assert (expected[0] == result[0]).all() and (expected[1] == result[1]).all()
29
+
30
+
31
def test_tensordict_div(tensorlist):
    """TensorList / TensorList divides elementwise on every member tensor."""
    x, y = tensorlist
    one = torch.ones((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([one, one])
    result = x / y
    assert (expected[0] == result[0]).all() and (expected[1] == result[1]).all()
37
+
38
+
39
def test_tensordict_sub(tensorlist):
    """TensorList - TensorList subtracts elementwise on every member tensor."""
    x, y = tensorlist
    zero = torch.zeros((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([zero, zero])
    result = x - y
    assert (expected[0] == result[0]).all() and (expected[1] == result[1]).all()
45
+
46
+
47
def test_tensordict_neg(tensorlist):
    """Unary minus negates every member tensor of a TensorList."""
    x, y = tensorlist
    minus_one = -torch.ones((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([minus_one, minus_one])
    result = -x
    assert (expected[0] == result[0]).all() and (expected[1] == result[1]).all()
53
+
54
+
55
def test_tensordict_append(tensorlist):
    """Appending one TensorList to another concatenates their member tensors."""
    x, y = tensorlist
    one = torch.ones((1, 1, 2, 2))
    expected = deepinv.utils.TensorList([one, one, one, one])
    result = x.append(y)
    assert (expected[0] == result[0]).all() and (expected[-1] == result[-1]).all()
61
+
62
+
63
def test_plot():
    """Smoke-test deepinv.utils.plot with a list of images, titles, and a bare tensor."""
    img = torch.ones((1, 1, 2, 2))
    deepinv.utils.plot([img, img], titles=["a", "b"])
    deepinv.utils.plot(img, titles="a")
    deepinv.utils.plot([img, img])