deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/loss/sure.py ADDED
@@ -0,0 +1,338 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
def hutch_div(y, physics, f, mc_iter=1):
    r"""
    Hutchinson Monte-Carlo estimator of the divergence of :math:`A(f(y))`.

    Draws Gaussian probe vectors :math:`b` and averages :math:`b^\top J^\top b`
    where :math:`J` is the Jacobian of :math:`y \mapsto A(f(y))`.

    :param torch.Tensor y: Measurements.
    :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
    :param torch.nn.Module f: Reconstruction network.
    :param int mc_iter: number of Monte-Carlo iterations. Default=1.
    :return: (float) divergence estimate.
    """
    inp = y.requires_grad_(True)
    out = physics.A(f(inp, physics))
    acc = 0
    for _ in range(mc_iter):
        probe = torch.randn_like(inp)
        # Vector-Jacobian product J^T b via reverse-mode autograd.
        vjp = torch.autograd.grad(out, inp, probe, retain_graph=True, create_graph=True)[
            0
        ]
        acc = acc + (probe * vjp).mean()

    return acc / mc_iter
27
+
28
+
29
def exact_div(y, physics, model):
    r"""
    Exact divergence of :math:`A(f(y))`, computed one coordinate at a time.

    Probes every (channel, row, column) position with a one-hot vector and
    accumulates the corresponding diagonal entry of the Jacobian. This costs
    one backward pass per pixel, so it is only practical on small inputs.

    :param torch.Tensor y: Measurements.
    :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
    :param torch.nn.Module model: Reconstruction network.
    :return: (float) exact divergence.
    """
    inp = y.requires_grad_(True)
    out = physics.A(model(inp, physics))
    div = 0
    _, c, h, w = inp.shape
    for ci in range(c):
        for hi in range(h):
            for wi in range(w):
                probe = torch.zeros_like(inp)
                probe[:, ci, hi, wi] = 1
                grad = torch.autograd.grad(
                    out, inp, probe, retain_graph=True, create_graph=True
                )[0]
                div += (probe * grad).sum()

    return div / (c * h * w)
54
+
55
+
56
def mc_div(y1, y, f, physics, tau):
    r"""
    Monte-Carlo (finite-difference) estimate of the divergence of :math:`A(f(y))`.

    Uses a single Gaussian probe :math:`b` and the one-sided difference
    :math:`b^\top (A(f(y+\tau b)) - A(f(y))) / \tau`.

    :param torch.Tensor y1: precomputed :math:`A(f(y))`.
    :param torch.Tensor y: Measurements.
    :param torch.nn.Module f: Reconstruction network.
    :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
    :param float tau: finite-difference step size.
    :return: (float) divergence estimate.
    """
    probe = torch.randn_like(y)
    perturbed = physics.A(f(y + probe * tau, physics))
    return (probe * (perturbed - y1) / tau).mean()
70
+
71
+
72
class SureGaussianLoss(nn.Module):
    r"""
    Stein Unbiased Risk Estimator (SURE) loss for Gaussian noise.

    The loss is designed for the following noise model:

    .. math::

        y \sim\mathcal{N}(u,\sigma^2 I) \quad \text{with}\quad u= A(x).

    The loss is computed as

    .. math::

        \frac{1}{m}\|y - A\inverse{y}\|_2^2 -\sigma^2 +\frac{2\sigma^2}{m\tau}b^{\top} \left(A\inverse{y+\tau b_i} -
        A\inverse{y}\right)

    where :math:`R` is the trainable network, :math:`A` is the forward operator,
    :math:`y` is the noisy measurement vector of size :math:`m`,
    :math:`b\sim\mathcal{N}(0,I)` and :math:`\tau\geq 0` is a hyperparameter controlling the
    Monte Carlo approximation of the divergence.

    This loss approximates the divergence of :math:`A\inverse{y}` (in the original SURE loss)
    using the Monte Carlo approximation in
    https://ieeexplore.ieee.org/abstract/document/4099398/

    If the measurement data is truly Gaussian with standard deviation :math:`\sigma`,
    this loss is an unbiased estimator of the mean squared loss :math:`\frac{1}{m}\|u-A\inverse{y}\|_2^2`
    where :math:`u` is the noiseless measurement.

    .. warning::

        The loss can be sensitive to the choice of :math:`\tau`, which should be proportional to the size of :math:`y`.
        The default value of 0.01 is adapted to :math:`y` vectors with entries in :math:`[0,1]`.

    :param float sigma: Standard deviation of the Gaussian noise.
    :param float tau: Approximation constant for the Monte Carlo approximation of the divergence.
    """

    def __init__(self, sigma, tau=1e-2):
        super().__init__()
        self.name = "SureGaussian"
        self.sigma2 = sigma**2
        self.tau = tau

    def forward(self, y, x_net, physics, model, **kwargs):
        r"""
        Computes the SURE loss.

        :param torch.Tensor y: Measurements.
        :param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
        :param torch.nn.Module model: Reconstruction network.
        :return: (float) SURE loss.
        """
        meas = physics.A(x_net)
        # Divergence term, estimated by a single Monte-Carlo probe.
        divergence = 2 * self.sigma2 * mc_div(meas, y, model, physics, self.tau)
        residual = (meas - y).pow(2).mean()
        return residual + divergence - self.sigma2
133
+
134
+
135
class SurePoissonLoss(nn.Module):
    r"""
    Stein Unbiased Risk Estimator (SURE) loss for Poisson noise.

    The loss is designed for the following noise model:

    .. math::

        y = \gamma z \quad \text{with}\quad z\sim \mathcal{P}(\frac{u}{\gamma}), \quad u=A(x).

    The loss is computed as

    .. math::

        \frac{1}{m}\|y-A\inverse{y}\|_2^2-\frac{\gamma}{m} 1^{\top}y
        +\frac{2\gamma}{m\tau}(b\odot y)^{\top} \left(A\inverse{y+\tau b}-A\inverse{y}\right)

    where :math:`R` is the trainable network, :math:`y` is the noisy measurement vector of size :math:`m`,
    :math:`b` is a Bernoulli random variable taking values of -1 and 1 each with a probability of 0.5,
    :math:`\tau` is a small positive number, and :math:`\odot` is an elementwise multiplication.

    See https://ieeexplore.ieee.org/abstract/document/6714502/ for details.
    If the measurement data is truly Poisson,
    this loss is an unbiased estimator of the mean squared loss :math:`\frac{1}{m}\|u-A\inverse{y}\|_2^2`
    where :math:`u` is the noiseless measurement.

    .. warning::

        The loss can be sensitive to the choice of :math:`\tau`, which should be proportional to the size of :math:`y`.
        The default value of 0.01 is adapted to :math:`y` vectors with entries in :math:`[0,1]`.

    :param float gain: Gain of the Poisson Noise.
    :param float tau: Approximation constant for the Monte Carlo approximation of the divergence.
    """

    def __init__(self, gain, tau=1e-3):
        super().__init__()
        self.name = "SurePoisson"
        self.gain = gain
        self.tau = tau

    def forward(self, y, x_net, physics, model, **kwargs):
        r"""
        Computes the SURE loss.

        :param torch.Tensor y: measurements.
        :param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements
        :param torch.nn.Module model: Reconstruction network
        :return: (float) SURE loss.
        """
        # Rademacher probe: entries are -1 or +1 with equal probability.
        probe = torch.rand_like(y) > 0.5
        probe = (2 * probe - 1) * 1.0

        meas = physics.A(x_net)
        perturbed = physics.A(model(y + self.tau * probe, physics))

        loss_sure = (
            (meas - y).pow(2).mean()
            - self.gain * y.mean()
            + 2.0 / self.tau * (probe * y * self.gain * (perturbed - meas)).mean()
        )
        return loss_sure
204
+
205
+
206
class SurePGLoss(nn.Module):
    r"""
    Stein Unbiased Risk Estimator (SURE) loss for mixed Poisson-Gaussian noise.

    The loss is designed for the following noise model:

    .. math::

        y = \gamma z + \epsilon

    where :math:`u = A(x)`, :math:`z \sim \mathcal{P}\left(\frac{u}{\gamma}\right)`,
    and :math:`\epsilon \sim \mathcal{N}(0, \sigma^2 I)`.

    The loss is computed as

    .. math::

        & \frac{1}{m}\|y-A\inverse{y}\|_2^2-\frac{\gamma}{m} 1^{\top}y-\sigma^2
        +\frac{2}{m\tau_1}(b\odot (\gamma y + \sigma^2 I))^{\top} \left(A\inverse{y+\tau b}-A\inverse{y} \right) \\\\
        & +\frac{2\gamma \sigma^2}{m\tau_2^2}c^{\top} \left( A\inverse{y+\tau c} + A\inverse{y-\tau c} - 2A\inverse{y} \right)

    where :math:`R` is the trainable network, :math:`y` is the noisy measurement vector,
    :math:`b` is a Bernoulli random variable taking values of -1 and 1 each with a probability of 0.5,
    :math:`\tau` is a small positive number, and :math:`\odot` is an elementwise multiplication.

    If the measurement data is truly Poisson-Gaussian,
    this loss is an unbiased estimator of the mean squared loss :math:`\frac{1}{m}\|u-A\inverse{y}\|_2^2`
    where :math:`u` is the noiseless measurement.

    See https://ieeexplore.ieee.org/abstract/document/6714502/ for details.

    .. warning::

        The loss can be sensitive to the choice of :math:`\tau`, which should be proportional to the size of :math:`y`.
        The default value of 0.01 is adapted to :math:`y` vectors with entries in :math:`[0,1]`.

    :param float sigma: Standard deviation of the Gaussian noise.
    :param float gain: Gain of the Poisson Noise.
    :param float tau1: Step size of the first-order finite-difference divergence term.
    :param float tau2: Step size of the second-order finite-difference term.
    """

    def __init__(self, sigma, gain, tau1=1e-3, tau2=1e-2):
        super().__init__()
        self.name = "SurePG"
        self.sigma2 = sigma**2
        self.gain = gain
        self.tau1 = tau1
        self.tau2 = tau2

    def forward(self, y, x_net, physics, model, **kwargs):
        r"""
        Computes the SURE loss.

        :param torch.Tensor y: measurements.
        :param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements
        :param torch.nn.Module model: Reconstruction network
        :return: (float) SURE loss.
        """
        # First probe: Rademacher (+1/-1 with equal probability).
        probe1 = torch.rand_like(y) > 0.5
        probe1 = (2 * probe1 - 1) * 1.0

        # Second probe: zero-mean two-point distribution, p = .5 + .5*sqrt(1/5)
        p = 0.7236
        probe2 = torch.ones_like(probe1) * np.sqrt(p / (1 - p))
        probe2[torch.rand_like(probe2) < p] = -np.sqrt((1 - p) / p)

        meas = physics.A(x_net)
        meas_fwd = physics.A(model(y + self.tau1 * probe1, physics))
        meas_plus = physics.A(model(y + self.tau2 * probe2, physics))
        meas_minus = physics.A(model(y - self.tau2 * probe2, physics))

        loss_mc = (meas - y).pow(2).mean()

        # First-order (divergence) correction term.
        loss_div1 = (
            2
            / self.tau1
            * ((probe1 * (self.gain * y + self.sigma2)) * (meas_fwd - meas)).mean()
        )

        offset = -self.gain * y.mean() - self.sigma2

        # Second-order correction term (central finite difference).
        loss_div2 = (
            -2
            * self.sigma2
            * self.gain
            / (self.tau2**2)
            * (probe2 * (meas_plus + meas_minus - 2 * meas)).mean()
        )

        return loss_mc + loss_div1 + loss_div2 + offset
303
+
304
+
305
+ # if __name__ == "__main__":
306
+ # from deepinv.models import Denoiser
307
+ # import deepinv as dinv
308
+ #
309
+ # model_spec = {
310
+ # "name": "waveletprior",
311
+ # "args": {"wv": "db8", "level": 3, "device": dinv.device},
312
+ # }
313
+ # f = dinv.models.ArtifactRemoval(Denoiser(model_spec))
314
+ # # test divergence
315
+ #
316
+ # x = torch.ones((1, 3, 16, 16), device=dinv.device) * 0.5
317
+ # physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(0.1))
318
+ # y = physics(x)
319
+ #
320
+ # y1 = f(y, physics)
321
+ # tau = 1e-4
322
+ #
323
+ # exact = exact_div(y, physics, f)
324
+ #
325
+ # error_h = 0
326
+ # error_mc = 0
327
+ # for i in range(100):
328
+ # h = hutch_div(y, physics, f)
329
+ # mc = mc_div(y1, y, f, physics, tau)
330
+ #
331
+ # error_h += torch.abs(h - exact)
332
+ # error_mc += torch.abs(mc - exact)
333
+ #
334
+ # error_mc /= 100
335
+ # error_h /= 100
336
+ #
337
+ # print(f"error_h: {error_h}")
338
+ # print(f"error_mc: {error_mc}")
deepinv/loss/tv.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class TVLoss(nn.Module):
    r"""
    Total variation loss (:math:`\ell_2` norm).

    It computes the loss :math:`\|D\hat{x}\|_2^2`,
    where :math:`D` is a normalized linear operator that computes the vertical and horizontal first order differences
    of the reconstructed image :math:`\hat{x}`.

    :param float weight: scalar weight for the TV loss.
    """

    def __init__(self, weight=1.0):
        super().__init__()
        self.tv_loss_weight = weight
        self.name = "tv"

    def forward(self, x_net, **kwargs):
        r"""
        Computes the TV loss.

        :param torch.Tensor x_net: reconstructed image of shape (N, C, H, W).
        :return: (torch.Tensor) loss.
        """
        batch = x_net.size()[0]
        height, width = x_net.size()[2], x_net.size()[3]
        # Vertical and horizontal first-order differences.
        diff_v = x_net[:, :, 1:, :] - x_net[:, :, : height - 1, :]
        diff_h = x_net[:, :, :, 1:] - x_net[:, :, :, : width - 1]
        # Normalize each direction by its number of per-sample elements.
        tv_v = torch.pow(diff_v, 2).sum() / self.tensor_size(x_net[:, :, 1:, :])
        tv_h = torch.pow(diff_h, 2).sum() / self.tensor_size(x_net[:, :, :, 1:])
        return self.tv_loss_weight * 2 * (tv_v + tv_h) / batch

    @staticmethod
    def tensor_size(t):
        # Number of elements per batch sample (C * H * W).
        return t.size()[1] * t.size()[2] * t.size()[3]
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .utils import get_weights_url
4
+
5
+
6
class StudentGrad(nn.Module):
    """Thin wrapper that exposes a denoiser as a plain ``(x, sigma)`` module.

    :param nn.Module denoiser: denoiser network called as ``denoiser(x, sigma)``.
    """

    def __init__(self, denoiser):
        super().__init__()
        self.model = denoiser

    def forward(self, x, sigma):
        # Delegate directly to the wrapped denoiser.
        return self.model(x, sigma)
+ return self.model(x, sigma)
13
+
14
+
15
class GSPnP(nn.Module):
    r"""
    Gradient Step module to use a denoiser architecture as a Gradient Step Denoiser.
    See https://arxiv.org/pdf/2110.03220.pdf.
    Code from https://github.com/samuro95/GSPnP.

    :param nn.Module denoiser: Denoiser model.
    :param float alpha: Relaxation parameter.
    :param bool train: if ``True``, keep the autograd graph alive after computing the
        potential gradient (needed when training through the denoiser).
    """

    def __init__(self, denoiser, alpha=1.0, train=False):
        super().__init__()
        self.student_grad = StudentGrad(denoiser)
        self.alpha = alpha
        # Stored as ``training_mode`` rather than ``self.train``: assigning a bool to
        # ``self.train`` would shadow ``nn.Module.train()`` and make subsequent
        # ``model.train()`` calls fail with "'bool' object is not callable".
        self.training_mode = train

    def potential(self, x, sigma):
        r"""
        Regularization potential :math:`g(x) = \frac{\alpha}{2}\|x - N(x)\|_2^2`,
        where :math:`N` is the wrapped denoiser.

        :param torch.Tensor x: Input image.
        :param float sigma: Denoiser level :math:`\sigma` (std).
        """
        N = self.student_grad(x, sigma)
        return (
            0.5
            * self.alpha
            * torch.norm((x - N).view(x.shape[0], -1), p=2, dim=-1) ** 2
        )

    def potential_grad(self, x, sigma):
        r"""
        Calculate :math:`\nabla g` the gradient of the regularizer :math:`g` at input :math:`x`.

        :param torch.Tensor x: Input image.
        :param float sigma: Denoiser level :math:`\sigma` (std).
        """
        torch.set_grad_enabled(True)
        x = x.float()
        x = x.requires_grad_()
        N = self.student_grad(x, sigma)
        # Vector-Jacobian product J_N^T (x - N) via autograd.
        JN = torch.autograd.grad(
            N, x, grad_outputs=x - N, create_graph=True, only_inputs=True
        )[0]
        if not self.training_mode:
            # In eval mode, restore the no-grad state after the local backward pass.
            torch.set_grad_enabled(False)
        Dg = x - N - JN
        return self.alpha * Dg

    def forward(self, x, sigma):
        r"""
        Denoising with Gradient Step Denoiser: :math:`\hat{x} = x - \nabla g(x)`.

        :param torch.Tensor x: Input image.
        :param float sigma: Denoiser level (std).
        """
        Dg = self.potential_grad(x, sigma)
        x_hat = x - Dg
        return x_hat
+ return x_hat
68
+
69
+
70
def GSDRUNet(
    alpha=1.0,
    in_channels=3,
    out_channels=3,
    nb=2,
    nc=[64, 128, 256, 512],
    act_mode="E",
    pretrained=None,
    train=False,
    device=torch.device("cpu"),
):
    """
    Gradient Step Denoiser with DRUNet architecture.

    :param float alpha: Relaxation parameter.
    :param int in_channels: Number of input channels.
    :param int out_channels: Number of output channels.
    :param int nb: Number of blocks in the DRUNet.
    :param list nc: Number of channels in the DRUNet.
    :param str act_mode: activation mode, "R" for ReLU, "L" for LeakyReLU, "E" for ELU and "S" for Softplus.
    :param str pretrained: use a pretrained network. If ``pretrained=None``, the weights will be initialized
        at random using Pytorch's default initialization. If ``pretrained='download'``, the weights will be
        downloaded from an online repository (only available for the default architecture).
        Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights.
        See :ref:`pretrained-weights <pretrained-weights>` for more details.
    :param bool train: training or testing mode.
    :param str device: gpu or cpu.
    """
    # Local import avoids a circular dependency at module load time.
    from deepinv.models.drunet import DRUNet

    denoiser = DRUNet(
        in_channels=in_channels,
        out_channels=out_channels,
        nb=nb,
        nc=nc,
        act_mode=act_mode,
        pretrained=None,
        train=train,
        device=device,
    )
    GSmodel = GSPnP(denoiser, alpha=alpha, train=train)
    if pretrained:
        if pretrained == "download":
            url = get_weights_url(model_name="gradientstep", file_name="GSDRUNet.ckpt")
            ckpt = torch.hub.load_state_dict_from_url(
                url,
                map_location=lambda storage, loc: storage,
                file_name="GSDRUNet.ckpt",
            )["state_dict"]
        else:
            ckpt = torch.load(pretrained, map_location=lambda storage, loc: storage)[
                "state_dict"
            ]
        # strict=False: the checkpoint may not cover every submodule key.
        GSmodel.load_state_dict(ckpt, strict=False)
    return GSmodel
@@ -0,0 +1,109 @@
1
+ # This is an implementation of https://arxiv.org/abs/1707.06474
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
def init_weights(m):
    """Initialize ``nn.Linear`` layers in-place: Xavier-uniform weights, zero bias.

    Intended to be used with ``nn.Module.apply``; modules of any other type are
    left untouched.
    NOTE(review): the PDNet blocks apply this to ``nn.Conv2d`` layers, which are
    silently skipped here — confirm whether convolutions were meant to be
    initialized as well.

    :param torch.nn.Module m: module to (conditionally) initialize.
    """
    if isinstance(m, nn.Linear):
        # xavier_uniform_ is the non-deprecated, explicitly in-place initializer.
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # layers built with bias=False have no bias tensor
            m.bias.data.fill_(0.0)
10
+
11
+
12
class PDNet_PrimalBlock(nn.Module):
    def __init__(self, in_channels=6, out_channels=5, depth=3, bias=True, nf=32):
        r"""
        Primal block for the Primal-Dual unfolding model (PDNet) from https://arxiv.org/abs/1707.06474.

        Primal variables are images of shape (batch_size, in_channels, height, width). The input of each
        primal block is the concatenation of the current primal variable and the backprojected dual variable along
        the channel dimension. The output of each primal block is the current primal variable.

        :param int in_channels: number of input channels. Default: 6.
        :param int out_channels: number of output channels. Default: 5.
        :param int depth: number of convolutional layers in the block. Default: 3.
        :param bool bias: whether to use bias in convolutional layers. Default: True.
        :param int nf: number of features in the convolutional layers. Default: 32.
        """
        super().__init__()

        self.depth = depth

        # Input conv maps the concatenated variables to the feature space.
        self.in_conv = nn.Conv2d(
            in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.in_conv.apply(init_weights)
        # Hidden convs keep the feature dimension fixed.
        self.conv_list = nn.ModuleList(
            nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
            for _ in range(self.depth - 2)
        )
        self.conv_list.apply(init_weights)
        self.out_conv = nn.Conv2d(
            nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.out_conv.apply(init_weights)

        # One PReLU per conv layer except the output conv.
        self.nl_list = nn.ModuleList(nn.PReLU() for _ in range(self.depth - 1))

    def forward(self, x, Atu):
        # Stack the primal variable with the backprojected dual variable.
        features = self.in_conv(torch.cat((x, Atu), dim=1))
        features = self.nl_list[0](features)

        for conv, nl in zip(self.conv_list, self.nl_list[1:]):
            features = nl(conv(features))

        # Residual connection on the primal variable.
        return self.out_conv(features) + x
60
+
61
+
62
class PDNet_DualBlock(nn.Module):
    def __init__(self, in_channels=7, out_channels=5, depth=3, bias=True, nf=32):
        r"""
        Dual block for the Primal-Dual unfolding model (PDNet) from https://arxiv.org/abs/1707.06474.

        Dual variables are images of shape (batch_size, in_channels, height, width). The input of each
        dual block is the concatenation of the current dual variable with the projected primal variable and
        the measurements. The output of each dual block is the current dual variable.

        :param int in_channels: number of input channels. Default: 7.
        :param int out_channels: number of output channels. Default: 5.
        :param int depth: number of convolutional layers in the block. Default: 3.
        :param bool bias: whether to use bias in convolutional layers. Default: True.
        :param int nf: number of features in the convolutional layers. Default: 32.
        """
        super().__init__()

        self.depth = depth

        # Input conv maps the concatenated variables to the feature space.
        self.in_conv = nn.Conv2d(
            in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.in_conv.apply(init_weights)
        # Hidden convs keep the feature dimension fixed.
        self.conv_list = nn.ModuleList(
            nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
            for _ in range(self.depth - 2)
        )
        self.conv_list.apply(init_weights)
        self.out_conv = nn.Conv2d(
            nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.out_conv.apply(init_weights)

        # One PReLU per conv layer except the output conv.
        self.nl_list = nn.ModuleList(nn.PReLU() for _ in range(self.depth - 1))

    def forward(self, u, Ax_cur, y):
        # Stack the dual variable with the projected primal variable and measurements.
        features = self.in_conv(torch.cat((u, Ax_cur, y), dim=1))
        features = self.nl_list[0](features)

        for conv, nl in zip(self.conv_list, self.nl_list[1:]):
            features = nl(conv(features))

        # Residual connection on the dual variable.
        return self.out_conv(features) + u
@@ -0,0 +1,17 @@
1
+ from .drunet import DRUNet
2
+ from .scunet import SCUNet
3
+ from .ae import AutoEncoder
4
+ from .unet import UNet
5
+ from .dncnn import DnCNN
6
+ from .artifactremoval import ArtifactRemoval
7
+ from .tgv import TGV as TGV
8
+ from .tv import TV as TV
9
+ from .wavdict import WaveletPrior, WaveletDict
10
+ from .GSPnP import GSDRUNet
11
+ from .median import MedianFilter
12
+ from .dip import DeepImagePrior, ConvDecoder
13
+ from .diffunet import DiffUNet
14
+ from .swinir import SwinIR
15
+ from .PDNet import PDNet_PrimalBlock, PDNet_DualBlock
16
+ from .bm3d import BM3D
17
+ from .equivariant import EquivariantDenoiser
deepinv/models/ae.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+
3
+
4
class AutoEncoder(torch.nn.Module):
    r"""
    Simple fully connected autoencoder network.

    Simple architecture that can be used for debugging or fast prototyping.

    :param int dim_input: total number of elements (pixels) of the input.
    :param int dim_mid: number of features in the intermediate layers.
    :param int dim_hid: latent space dimension.
    :param bool residual: use a residual connection between input and output.
    """

    def __init__(self, dim_input, dim_mid=1000, dim_hid=32, residual=True):
        super().__init__()
        self.residual = residual

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(dim_input, dim_mid),
            torch.nn.ReLU(),
            torch.nn.Linear(dim_mid, dim_hid),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(dim_hid, dim_mid),
            torch.nn.ReLU(),
            torch.nn.Linear(dim_mid, dim_input),
        )

    def forward(self, x, sigma=None):
        r"""
        Encode and decode the (flattened) input image.

        :param torch.Tensor x: input image of shape (N, C, H, W), where C*H*W must equal ``dim_input``.
        :param sigma: unused; kept so the module matches the denoiser ``(x, sigma)`` call signature.
        :return: (torch.Tensor) reconstructed image with the same shape as ``x``.
        """
        N, C, H, W = x.shape
        flat = x.view(N, -1)

        encoded = self.encoder(flat)
        decoded = self.decoder(encoded)

        if self.residual:
            decoded = decoded + flat

        return decoded.view(N, C, H, W)
+ return decoded