deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,377 @@
1
+ import sys
2
+ import pytest
3
+ import torch
4
+
5
+ import deepinv as dinv
6
+
7
+
8
# Denoisers exercised by the shape smoke-tests below.  The first list holds the
# models that accept single-channel (grayscale) input; the full list extends it
# with color-capable / optional-dependency models.
MODEL_LIST_1_CHANNEL = [
    "autoencoder",
    "drunet",
    "dncnn",
    "median",
    "tgv",
    "waveletprior",
    "waveletdict",
]
MODEL_LIST = [
    *MODEL_LIST_1_CHANNEL,
    "bm3d",
    "gsdrunet",
    "scunet",
    "swinir",
    "tv",
    "unet",
    "waveletdict_hard",
    "waveletdict_topk",
]
27
+
28
+
29
def choose_denoiser(name, imsize):
    """Instantiate the denoiser identified by ``name`` for images of shape ``imsize``.

    :param str name: one of the keys in ``MODEL_LIST``.
    :param tuple imsize: image shape ``(C, H, W)`` used to size channel-dependent models.
    :return: an instantiated (untrained-config) deepinv denoiser module.
    :raises ValueError: if ``name`` is not a known denoiser key.
    """
    # Skip (rather than fail) the calling test when an optional backend is absent.
    if name.startswith("waveletdict") or name == "waveletprior":
        pytest.importorskip(
            "pytorch_wavelets",
            reason="This test requires pytorch_wavelets. It should be "
            "installed with `pip install "
            "git+https://github.com/fbcotter/pytorch_wavelets.git`",
        )
    if name == "bm3d":
        pytest.importorskip(
            "bm3d",
            reason="This test requires bm3d. It should be "
            "installed with `pip install bm3d`",
        )
    if name in ("swinir", "scunet"):
        pytest.importorskip(
            "timm",
            reason="This test requires timm. It should be "
            "installed with `pip install timm`",
        )

    # Lazy factories: only the requested model is actually constructed.
    factories = {
        "unet": lambda: dinv.models.UNet(in_channels=imsize[0], out_channels=imsize[0]),
        "drunet": lambda: dinv.models.DRUNet(
            in_channels=imsize[0], out_channels=imsize[0]
        ),
        "scunet": lambda: dinv.models.SCUNet(in_nc=imsize[0]),
        "gsdrunet": lambda: dinv.models.GSDRUNet(
            in_channels=imsize[0], out_channels=imsize[0]
        ),
        "bm3d": lambda: dinv.models.BM3D(),
        "dncnn": lambda: dinv.models.DnCNN(
            in_channels=imsize[0], out_channels=imsize[0]
        ),
        "waveletprior": lambda: dinv.models.WaveletPrior(),
        "waveletdict": lambda: dinv.models.WaveletDict(),
        "waveletdict_hard": lambda: dinv.models.WaveletDict(non_linearity="hard"),
        "waveletdict_topk": lambda: dinv.models.WaveletDict(non_linearity="topk"),
        "tgv": lambda: dinv.models.TGV(n_it_max=10),
        "tv": lambda: dinv.models.TV(n_it_max=10),
        "median": lambda: dinv.models.MedianFilter(),
        "autoencoder": lambda: dinv.models.AutoEncoder(
            dim_input=imsize[0] * imsize[1] * imsize[2]
        ),
        "swinir": lambda: dinv.models.SwinIR(in_chans=imsize[0]),
    }

    factory = factories.get(name)
    if factory is None:
        # ValueError is a subclass of Exception, so existing callers catching the
        # old generic Exception still work; the message now names the bad key.
        raise ValueError(f"Unknown denoiser {name!r}")
    return factory()
84
+
85
+
86
@pytest.mark.parametrize("denoiser", MODEL_LIST)
def test_denoiser(imsize, device, denoiser):
    """Smoke-test every denoiser in MODEL_LIST: the output shape must match the input."""
    model = choose_denoiser(denoiser, imsize).to(device)

    torch.manual_seed(0)
    noise_level = 0.2
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(noise_level))
    clean = torch.ones(imsize, device=device).unsqueeze(0)
    noisy = physics(clean)
    restored = model(noisy, noise_level)

    assert restored.shape == clean.shape
98
+
99
+
100
def test_equivariant(imsize, device):
    """Check the equivariant wrapper runs on a real denoiser and is exact on identity."""
    # Part 1: wrapping DRUNet must be compatible and keep the output shape unchanged.
    base_denoiser = dinv.models.DRUNet(in_channels=imsize[0], out_channels=imsize[0])

    wrapped = dinv.models.EquivariantDenoiser(
        base_denoiser, transform="rotoflips", random=True
    ).to(device)

    torch.manual_seed(0)
    sigma = 0.2
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
    clean = torch.ones(imsize, device=device).unsqueeze(0)
    noisy = physics(clean)
    restored = wrapped(noisy, sigma)

    assert restored.shape == clean.shape

    # Part 2: with an identity denoiser, transform followed by its inverse must
    # reproduce the input exactly for every transform/random combination.
    class DummyIdentity(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, sigma):
            # Pass-through: ignores the noise level entirely.
            return x

    identity_model = DummyIdentity()

    for transform in ["rotations", "flips", "rotoflips"]:
        for random in [True, False]:
            wrapped_identity = dinv.models.EquivariantDenoiser(
                identity_model, transform=transform, random=random
            ).to(device)

            clean = torch.ones(imsize, device=device).unsqueeze(0)
            noisy = physics(clean)
            out = wrapped_identity(noisy, sigma)

            assert torch.allclose(noisy, out)
140
+
141
+
142
@pytest.mark.parametrize("denoiser", MODEL_LIST_1_CHANNEL)
def test_denoiser_1_channel(imsize_1_channel, device, denoiser):
    """Same shape smoke-test as ``test_denoiser``, restricted to grayscale-capable models."""
    model = choose_denoiser(denoiser, imsize_1_channel).to(device)

    torch.manual_seed(0)
    noise_level = 0.2
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(noise_level))
    clean = torch.ones(imsize_1_channel, device=device).unsqueeze(0)
    noisy = physics(clean)

    restored = model(noisy, noise_level)

    assert restored.shape == clean.shape
155
+
156
+
157
def test_drunet_inputs(imsize_1_channel, device):
    """DRUNet must accept the noise level as float, batched tensor, or 0-dim tensor."""
    denoiser = dinv.models.DRUNet(
        in_channels=imsize_1_channel[0], out_channels=imsize_1_channel[0], device=device
    )

    torch.manual_seed(0)
    sigma = 0.2
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
    x = torch.ones(imsize_1_channel, device=device).unsqueeze(0)
    y = physics(x)

    # Case 1: sigma is a plain float.
    assert denoiser(y, sigma).shape == x.shape

    # Case 2: sigma is a torch tensor with a batch dimension.
    batch_size = 3
    x = torch.ones((batch_size, 1, 31, 37), device=device)
    y = physics(x)
    sigma_tensor = torch.tensor([sigma] * batch_size).to(device)
    assert denoiser(y, sigma_tensor).shape == x.shape

    # Case 3: image whose spatial dimensions are multiples of 8.
    x = torch.ones((3, 1, 32, 40), device=device)
    y = physics(x)
    assert denoiser(y, sigma_tensor).shape == x.shape

    # Case 4: sigma is a zero-dimensional tensor.
    sigma_tensor = torch.tensor(sigma).to(device)
    assert denoiser(y, sigma_tensor).shape == x.shape
190
+
191
+
192
def test_diffunetmodel(imsize, device):
    """Run DiffUNet inference in both timestep and noise-level modes.

    This model is a bit different from the others: it is not strictly a denoiser.
    The Ho et al. diffusion model only works for color, square images with
    power-of-two width/height.  The smallest accepted size so far is (3, 32, 32),
    though results at that size are probably not meaningful since the network was
    trained at 256x256.
    """
    from deepinv.models import DiffUNet

    model = DiffUNet().to(device)

    torch.manual_seed(0)
    sigma = 0.2
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
    # Smallest size the architecture accepts.
    x = torch.ones((3, 32, 32), device=device).unsqueeze(0)
    y = physics(x)

    # Arbitrary timestep; the goal is only to check that inference runs.
    timestep = torch.tensor([1]).to(device)
    out = model(y, timestep)
    # Keep only the first 3 channels so the shape comparison matches the input.
    out = out[:, :3, ...]

    assert out.shape == x.shape

    # Noise-level mode (the denoise_forward path) must also work.
    assert model(y, sigma).shape == x.shape

    # type_t values outside ['noise_level', 'timestep'] must be rejected.
    with pytest.raises(Exception):
        model(y, sigma, type_t="wrong_type")
224
+
225
+
226
def test_PDNet(imsize_1_channel, device):
    """Build and run the learned primal-dual (PDNet) unfolded algorithm end to end.

    PDNet is expressed as a Chambolle-Pock iteration whose primal/dual proximal
    maps are replaced by trainable blocks; the output must have the input shape.
    """
    # Tests the PDNet algorithm - this is an unfolded algorithm so it is tested on its own here.
    from deepinv.optim.optimizers import CPIteration, fStep, gStep
    from deepinv.optim import Prior, DataFidelity
    from deepinv.models import PDNet_PrimalBlock, PDNet_DualBlock
    from deepinv.unfolded import unfolded_builder

    sigma = 0.2
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
    x = torch.ones(imsize_1_channel, device=device).unsqueeze(0)
    y = physics(x)

    class PDNetIteration(CPIteration):
        r"""Single iteration of learned primal dual.
        We only redefine the fStep and gStep classes.
        The forward method is inherited from the CPIteration class.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.g_step = gStepPDNet(**kwargs)
            self.f_step = fStepPDNet(**kwargs)

    class fStepPDNet(fStep):
        r"""
        Dual update of the PDNet algorithm.
        We write it as a proximal operator of the data fidelity term.
        This proximal mapping is to be replaced by a trainable model.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, x, w, cur_data_fidelity, y, *args):
            r"""
            :param torch.Tensor x: Current first variable :math:`u`.
            :param torch.Tensor w: Current second variable :math:`A z`.
            :param deepinv.optim.data_fidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data fidelity term.
            :param torch.Tensor y: Input data.
            """
            return cur_data_fidelity.prox(x, w, y)

    class gStepPDNet(gStep):
        r"""
        Primal update of the PDNet algorithm.
        We write it as a proximal operator of the prior term.
        This proximal mapping is to be replaced by a trainable model.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, x, w, cur_prior, *args):
            r"""
            :param torch.Tensor x: Current first variable :math:`x`.
            :param torch.Tensor w: Current second variable :math:`A^\top u`.
            :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
            """
            return cur_prior.prox(x, w)

    class PDNetPrior(Prior):
        # Wraps a trainable primal block as a deepinv Prior whose prox is the network.
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def prox(self, x, w):
            # NOTE(review): channel 0 of w is fed to the primal block — presumably
            # matching the layout produced by custom_init below; verify against PDNet_PrimalBlock.
            return self.model(x, w[:, 0:1, :, :])

    class PDNetDataFid(DataFidelity):
        # Wraps a trainable dual block as a deepinv DataFidelity whose prox is the network.
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def prox(self, x, w, y):
            # NOTE(review): channel 1 of w is fed to the dual block — see note in PDNetPrior.
            return self.model(x, w[:, 1:2, :, :], y)

    # Unrolled optimization algorithm parameters
    max_iter = 5  # number of unfolded layers

    # Set up the data fidelity term. Each layer has its own data fidelity module.
    data_fidelity = [
        PDNetDataFid(model=PDNet_DualBlock().to(device)) for i in range(max_iter)
    ]

    # Set up the trainable prior. Each layer has its own prior module.
    prior = [PDNetPrior(model=PDNet_PrimalBlock().to(device)) for i in range(max_iter)]

    n_primal = 5  # extend the primal space
    n_dual = 5  # extend the dual space

    def custom_init(y, physics):
        # Replicate the pseudo-inverse estimate across the extended primal channels,
        # and start the dual variable at zero across the extended dual channels.
        x0 = physics.A_dagger(y).repeat(1, n_primal, 1, 1)
        u0 = torch.zeros_like(y).repeat(1, n_dual, 1, 1)
        return {"est": (x0, x0, u0)}

    def custom_output(X):
        # Extract the single output image (channel 1 of the extended primal variable).
        return X["est"][0][:, 1, :, :].unsqueeze(1)

    # Define the unfolded trainable model.
    model = unfolded_builder(
        iteration=PDNetIteration(),
        params_algo={"K": physics.A, "K_adjoint": physics.A_adjoint, "beta": 1.0},
        data_fidelity=data_fidelity,
        prior=prior,
        max_iter=max_iter,
        custom_init=custom_init,
        get_output=custom_output,
    )

    x_hat = model(y, physics)

    assert x_hat.shape == x.shape
338
+
339
+
340
@pytest.mark.parametrize(
    "denoiser, dep",
    [
        ("BM3D", "bm3d"),
        ("SCUNet", "timm"),
        ("SwinIR", "timm"),
        ("WaveletPrior", "pytorch_wavelets"),
        ("WaveletDict", "pytorch_wavelets"),
    ],
)
def test_optional_dependencies(denoiser, dep):
    """A missing optional backend must raise an ImportError with a pip-install hint."""
    # The check is only meaningful when the dependency is absent from this environment.
    if dep in sys.modules:
        pytest.skip(f"Optional dependency {dep} is installed.")

    model_class = getattr(dinv.models, denoiser)
    with pytest.raises(ImportError, match=f"pip install .*{dep}"):
        model_class()
358
+
359
+
360
+ # def test_dip(imsize, device): TODO: fix this test
361
+ # torch.manual_seed(0)
362
+ # channels = 64
363
+ # physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(0.2))
364
+ # f = dinv.models.DeepImagePrior(
365
+ # generator=dinv.models.ConvDecoder(imsize, layers=3, channels=channels).to(
366
+ # device
367
+ # ),
368
+ # input_size=(channels, imsize[1], imsize[2]),
369
+ # iterations=30,
370
+ # )
371
+ # x = torch.ones(imsize, device=device).unsqueeze(0)
372
+ # y = physics(x)
373
+ # mse_in = (y - x).pow(2).mean()
374
+ # x_net = f(y, physics)
375
+ # mse_out = (x_net - x).pow(2).mean()
376
+ #
377
+ # assert mse_out < mse_in