deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/loss/sure.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def hutch_div(y, physics, f, mc_iter=1):
    r"""
    Hutchinson-style Monte Carlo estimate of the divergence of :math:`y \mapsto A(f(y))`.

    Every iteration draws a standard Gaussian probe ``b`` with the shape of ``y``, back-propagates it
    through ``physics.A(f(y, physics))`` and accumulates the mean of the elementwise product between
    ``b`` and the resulting vector-Jacobian product. The accumulated value is averaged over the probes.

    .. note::

        ``y`` is flagged in place as requiring gradients.

    :param torch.Tensor y: Measurements.
    :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
    :param torch.nn.Module f: Reconstruction network, called as ``f(y, physics)``.
    :param int mc_iter: number of random probes that are averaged. Default=1.
    :return: (torch.Tensor) scalar divergence estimate (mean over the entries of ``y``).
    """
    y_in = y.requires_grad_(True)
    forward_out = physics.A(f(y_in, physics))

    total = 0
    for _ in range(mc_iter):
        probe = torch.randn_like(y_in)
        # vector-Jacobian product J^T probe; the graph is kept so that it can be reused by the
        # next probe and so that the estimate itself remains differentiable
        (vjp,) = torch.autograd.grad(
            forward_out, y_in, probe, retain_graph=True, create_graph=True
        )
        total += (probe * vjp).mean()

    return total / mc_iter
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def exact_div(y, physics, model):
    r"""
    Exact divergence of :math:`y \mapsto A(f(y))`, computed one Jacobian diagonal entry at a time.

    For every (channel, row, column) position a one-hot selector is back-propagated through
    ``physics.A(model(y, physics))`` and the matching entry of the result is accumulated. The selector
    is set for all batch elements at once, and the sum is divided by ``c * h * w`` (the batch size is
    not part of the normalisation).

    The cost is one backward pass per position, so this is only practical for small inputs.

    .. note::

        ``y`` is unpacked as a 4-dimensional tensor and is flagged in place as requiring gradients.

    :param torch.Tensor y: Measurements.
    :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
    :param torch.nn.Module model: Reconstruction network, called as ``model(y, physics)``.
    :return: (torch.Tensor) exact divergence divided by ``c * h * w``.
    """
    y_in = y.requires_grad_(True)
    forward_out = physics.A(model(y_in, physics))
    _, c, h, w = y_in.shape

    total = 0
    positions = ((i, j, k) for i in range(c) for j in range(h) for k in range(w))
    for i, j, k in positions:
        selector = torch.zeros_like(y_in)
        selector[:, i, j, k] = 1
        (vjp,) = torch.autograd.grad(
            forward_out, y_in, selector, retain_graph=True, create_graph=True
        )
        total += (selector * vjp).sum()

    return total / (c * h * w)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def mc_div(y1, y, f, physics, tau):
    r"""
    Monte-Carlo (finite difference) estimation of the divergence of :math:`y \mapsto A(f(y))`.

    A single standard Gaussian direction ``b`` is drawn, the network is re-evaluated at
    ``y + tau * b`` and the estimate is the mean of ``b * (A(f(y + tau * b)) - y1) / tau``.

    :param torch.Tensor y1: Reference output ``physics.A(f(y, physics))`` at the unperturbed measurements.
    :param torch.Tensor y: Measurements.
    :param torch.nn.Module f: Reconstruction network, called as ``f(y, physics)``.
    :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
    :param float tau: Step of the finite difference; it is used as a divisor and must be non-zero.
    :return: (torch.Tensor) scalar divergence estimate (mean over the entries of ``y``).
    """
    direction = torch.randn_like(y)
    perturbed = physics.A(f(y + direction * tau, physics))
    return (direction * (perturbed - y1) / tau).mean()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class SureGaussianLoss(nn.Module):
    r"""
    SURE loss for Gaussian noise

    The loss is designed for the following noise model:

    .. math::

        y \sim\mathcal{N}(u,\sigma^2 I) \quad \text{with}\quad u= A(x).

    The loss is computed as

    .. math::

        \frac{1}{m}\|y - A\inverse{y}\|_2^2 -\sigma^2 +\frac{2\sigma^2}{m\tau}b^{\top} \left(A\inverse{y+\tau b} -
        A\inverse{y}\right)

    where :math:`R` is the trainable network, :math:`A` is the forward operator,
    :math:`y` is the noisy measurement vector of size :math:`m`,
    :math:`b\sim\mathcal{N}(0,I)` and :math:`\tau > 0` is a hyperparameter controlling the
    Monte Carlo approximation of the divergence.

    This loss approximates the divergence of :math:`A\inverse{y}` (in the original SURE loss)
    using the Monte Carlo approximation in
    https://ieeexplore.ieee.org/abstract/document/4099398/

    If the measurement data is truly Gaussian with standard deviation :math:`\sigma`,
    this loss is an unbiased estimator of the mean squared loss :math:`\frac{1}{m}\|u-A\inverse{y}\|_2^2`
    where :math:`u` is the noiseless measurement.

    .. warning::

        The loss can be sensitive to the choice of :math:`\tau`, which should be proportional to the size of :math:`y`.
        The default value of 0.01 is adapted to :math:`y` vectors with entries in :math:`[0,1]`.

    :param float sigma: Standard deviation of the Gaussian noise.
    :param float tau: Approximation constant for the Monte Carlo approximation of the divergence.
    """

    def __init__(self, sigma, tau=1e-2):
        super().__init__()
        self.name = "SureGaussian"
        # only the variance is needed by the loss
        self.sigma2 = sigma**2
        self.tau = tau

    def forward(self, y, x_net, physics, model, **kwargs):
        r"""
        Computes the SURE Loss.

        :param torch.Tensor y: Measurements.
        :param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
        :param torch.nn.Module model: Reconstruction network, re-evaluated once at perturbed measurements.
        :return: (torch.Tensor) scalar SURE loss.
        """
        y1 = physics.A(x_net)

        # divergence term (one extra network evaluation inside mc_div)
        divergence_term = 2 * self.sigma2 * mc_div(y1, y, model, physics, self.tau)
        # measurement consistency term
        data_fit = (y1 - y).pow(2).mean()

        return data_fit + divergence_term - self.sigma2
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class SurePoissonLoss(nn.Module):
    r"""
    SURE loss for Poisson noise

    The loss is designed for the following noise model:

    .. math::

        y = \gamma z \quad \text{with}\quad z\sim \mathcal{P}(\frac{u}{\gamma}), \quad u=A(x).

    The loss is computed as

    .. math::

        \frac{1}{m}\|y-A\inverse{y}\|_2^2-\frac{\gamma}{m} 1^{\top}y
        +\frac{2\gamma}{m\tau}(b\odot y)^{\top} \left(A\inverse{y+\tau b}-A\inverse{y}\right)

    where :math:`R` is the trainable network, :math:`y` is the noisy measurement vector of size :math:`m`,
    :math:`b` is a Bernoulli random variable taking values of -1 and 1 each with a probability of 0.5,
    :math:`\tau` is a small positive number, and :math:`\odot` is an elementwise multiplication.

    See https://ieeexplore.ieee.org/abstract/document/6714502/ for details.
    If the measurement data is truly Poisson
    this loss is an unbiased estimator of the mean squared loss :math:`\frac{1}{m}\|u-A\inverse{y}\|_2^2`
    where :math:`u` is the noiseless measurement.

    .. warning::

        The loss can be sensitive to the choice of :math:`\tau`, which should be proportional to the size of :math:`y`.
        The default value of 0.001 is intended for :math:`y` vectors with entries in :math:`[0,1]`.

    :param float gain: Gain of the Poisson Noise.
    :param float tau: Approximation constant for the Monte Carlo approximation of the divergence.
    """

    def __init__(self, gain, tau=1e-3):
        super().__init__()
        self.name = "SurePoisson"
        self.gain = gain
        self.tau = tau

    def forward(self, y, x_net, physics, model, **kwargs):
        r"""
        Computes the SURE loss.

        :param torch.Tensor y: measurements.
        :param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements
        :param torch.nn.Module model: Reconstruction network, re-evaluated once at perturbed measurements.
        :return: (torch.Tensor) scalar SURE loss.
        """
        # random sign vector b with entries in {-1, +1}
        coin = torch.rand_like(y) > 0.5
        b = (2 * coin - 1) * 1.0

        y1 = physics.A(x_net)
        y2 = physics.A(model(y + self.tau * b, physics))

        # every term is a mean over the entries, i.e. the 1/m factor of the formula
        data_fit = (y1 - y).pow(2).mean()
        gain_offset = self.gain * y.mean()
        divergence_term = 2.0 / self.tau * (b * y * self.gain * (y2 - y1)).mean()

        return data_fit - gain_offset + divergence_term
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class SurePGLoss(nn.Module):
    r"""
    SURE loss for Poisson-Gaussian noise

    The loss is designed for the following noise model:

    .. math::

        y = \gamma z + \epsilon

    where :math:`u = A(x)`, :math:`z \sim \mathcal{P}\left(\frac{u}{\gamma}\right)`,
    and :math:`\epsilon \sim \mathcal{N}(0, \sigma^2 I)`.

    The loss, as implemented in :meth:`forward`, is computed as

    .. math::

        & \frac{1}{m}\|y-A\inverse{y}\|_2^2-\frac{\gamma}{m} 1^{\top}y-\sigma^2
        +\frac{2}{m\tau_1}(b\odot (\gamma y + \sigma^2 I))^{\top} \left(A\inverse{y+\tau_1 b}-A\inverse{y} \right) \\\\
        & -\frac{2\gamma \sigma^2}{m\tau_2^2}c^{\top} \left( A\inverse{y+\tau_2 c} + A\inverse{y-\tau_2 c} - 2A\inverse{y} \right)

    where :math:`R` is the trainable network, :math:`y` is the noisy measurement vector of size :math:`m`,
    :math:`b` is a Bernoulli random variable taking values of -1 and 1 each with a probability of 0.5,
    :math:`c` is a two-valued random variable taking the value :math:`-\sqrt{(1-p)/p}` with probability
    :math:`p` and :math:`\sqrt{p/(1-p)}` otherwise (with :math:`p=0.7236`),
    :math:`\tau_1, \tau_2` are small positive numbers, and :math:`\odot` is an elementwise multiplication.

    If the measurement data is truly Poisson-Gaussian
    this loss is an unbiased estimator of the mean squared loss :math:`\frac{1}{m}\|u-A\inverse{y}\|_2^2`
    where :math:`u` is the noiseless measurement.

    See https://ieeexplore.ieee.org/abstract/document/6714502/ for details.

    .. warning::

        The loss can be sensitive to the choice of :math:`\tau_1` and :math:`\tau_2`, which should be
        proportional to the size of :math:`y`.
        The default values (0.001 and 0.01) are intended for :math:`y` vectors with entries in :math:`[0,1]`.

    :param float sigma: Standard deviation of the Gaussian noise.
    :param float gain: Gain of the Poisson Noise.
    :param float tau1: Approximation constant for the first-order Monte Carlo term.
    :param float tau2: Approximation constant for the second-order Monte Carlo term.
    """

    def __init__(self, sigma, gain, tau1=1e-3, tau2=1e-2):
        super().__init__()
        self.name = "SurePG"
        # self.sure_loss_weight = sure_loss_weight
        self.sigma2 = sigma**2
        self.gain = gain
        self.tau1 = tau1
        self.tau2 = tau2

    def forward(self, y, x_net, physics, model, **kwargs):
        r"""
        Computes the SURE loss.

        :param torch.Tensor y: measurements.
        :param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements
        :param torch.nn.Module model: Reconstruction network, re-evaluated three times at perturbed measurements.
        :return: (torch.Tensor) scalar SURE loss.
        """
        # b1: random signs in {-1, +1}
        coin = torch.rand_like(y) > 0.5
        b1 = (2 * coin - 1) * 1.0

        # b2: two-valued direction used by the second-order finite difference
        p = 0.7236  # .5 + .5*np.sqrt(1/5.)
        high_value = np.sqrt(p / (1 - p))
        low_value = -np.sqrt((1 - p) / p)
        b2 = torch.ones_like(b1) * high_value
        b2[torch.rand_like(b2) < p] = low_value

        meas1 = physics.A(x_net)
        meas2 = physics.A(model(y + self.tau1 * b1, physics))
        meas2p = physics.A(model(y + self.tau2 * b2, physics))
        meas2n = physics.A(model(y - self.tau2 * b2, physics))

        # every term is a mean over the entries, i.e. the 1/m factor of the formula
        loss_mc = (meas1 - y).pow(2).mean()

        weighted_b1 = b1 * (self.gain * y + self.sigma2)
        loss_div1 = 2 / self.tau1 * (weighted_b1 * (meas2 - meas1)).mean()

        offset = -self.gain * y.mean() - self.sigma2

        second_difference = meas2p + meas2n - 2 * meas1
        loss_div2 = (
            -2 * self.sigma2 * self.gain / (self.tau2**2) * (b2 * second_difference).mean()
        )

        return loss_mc + loss_div1 + loss_div2 + offset
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
# if __name__ == "__main__":
|
|
306
|
+
# from deepinv.models import Denoiser
|
|
307
|
+
# import deepinv as dinv
|
|
308
|
+
#
|
|
309
|
+
# model_spec = {
|
|
310
|
+
# "name": "waveletprior",
|
|
311
|
+
# "args": {"wv": "db8", "level": 3, "device": dinv.device},
|
|
312
|
+
# }
|
|
313
|
+
# f = dinv.models.ArtifactRemoval(Denoiser(model_spec))
|
|
314
|
+
# # test divergence
|
|
315
|
+
#
|
|
316
|
+
# x = torch.ones((1, 3, 16, 16), device=dinv.device) * 0.5
|
|
317
|
+
# physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(0.1))
|
|
318
|
+
# y = physics(x)
|
|
319
|
+
#
|
|
320
|
+
# y1 = f(y, physics)
|
|
321
|
+
# tau = 1e-4
|
|
322
|
+
#
|
|
323
|
+
# exact = exact_div(y, physics, f)
|
|
324
|
+
#
|
|
325
|
+
# error_h = 0
|
|
326
|
+
# error_mc = 0
|
|
327
|
+
# for i in range(100):
|
|
328
|
+
# h = hutch_div(y, physics, f)
|
|
329
|
+
# mc = mc_div(y1, y, f, physics, tau)
|
|
330
|
+
#
|
|
331
|
+
# error_h += torch.abs(h - exact)
|
|
332
|
+
# error_mc += torch.abs(mc - exact)
|
|
333
|
+
#
|
|
334
|
+
# error_mc /= 100
|
|
335
|
+
# error_h /= 100
|
|
336
|
+
#
|
|
337
|
+
# print(f"error_h: {error_h}")
|
|
338
|
+
# print(f"error_mc: {error_mc}")
|
deepinv/loss/tv.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TVLoss(nn.Module):
    r"""
    Total variation loss (:math:`\ell_2` norm).

    It computes the loss :math:`\|D\hat{x}\|_2^2`,
    where :math:`D` is a normalized linear operator that computes the vertical and horizontal first order differences
    of the reconstructed image :math:`\hat{x}`.

    :param float weight: scalar weight for the TV loss.
    """

    def __init__(self, weight=1.0):
        super().__init__()
        self.tv_loss_weight = weight
        self.name = "tv"

    def forward(self, x_net, **kwargs):
        r"""
        Computes the TV loss.

        The squared vertical and horizontal differences are summed, each sum is divided by the number of
        difference entries per batch element, and the total is divided by the batch size.

        :param torch.Tensor x_net: reconstructed image, indexed as (batch, channel, height, width).
        :return: (torch.Tensor) loss.
        """
        n_batch = x_net.shape[0]

        # first order differences along the height (dim 2) and the width (dim 3)
        diff_h = x_net[:, :, 1:, :] - x_net[:, :, :-1, :]
        diff_w = x_net[:, :, :, 1:] - x_net[:, :, :, :-1]

        energy_h = diff_h.pow(2).sum() / self.tensor_size(diff_h)
        energy_w = diff_w.pow(2).sum() / self.tensor_size(diff_w)

        return self.tv_loss_weight * 2 * (energy_h + energy_w) / n_batch

    @staticmethod
    def tensor_size(t):
        """Return the number of entries of ``t`` per batch element (product of dimensions 1, 2 and 3)."""
        dims = t.shape
        return dims[1] * dims[2] * dims[3]
|
deepinv/models/GSPnP.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from .utils import get_weights_url
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class StudentGrad(nn.Module):
    r"""
    Thin wrapper that registers a denoiser as the ``model`` sub-module and forwards calls to it.

    :param nn.Module denoiser: Denoiser model, called as ``denoiser(x, sigma)``.
    """

    def __init__(self, denoiser):
        super().__init__()
        self.model = denoiser

    def forward(self, x, sigma):
        r"""
        Apply the wrapped denoiser.

        :param torch.tensor x: Input image
        :param float sigma: Denoiser level (std)
        """
        return self.model(x, sigma)


class GSPnP(nn.Module):
    r"""
    Gradient Step module to use a denoiser architecture as a Gradient Step Denoiser.
    See https://arxiv.org/pdf/2110.03220.pdf.
    Code from https://github.com/samuro95/GSPnP.

    The regularizer is :math:`g(x) = \frac{\alpha}{2}\|x - N(x)\|_2^2` where :math:`N` is the wrapped
    denoiser, and the module output is :math:`x - \nabla g(x)`.

    :param nn.Module denoiser: Denoiser model.
    :param float alpha: Relaxation parameter
    :param bool train: if ``True``, the outputs stay attached to the autograd graph so that the denoiser can
        be trained through the gradient step; if ``False`` the outputs are detached.
    """

    def __init__(self, denoiser, alpha=1.0, train=False):
        super().__init__()
        self.student_grad = StudentGrad(denoiser)
        self.alpha = alpha
        # The flag must NOT be stored as ``self.train``: that would shadow ``nn.Module.train`` and make
        # ``.train()`` / ``.eval()`` (also when called on a parent module) fail with
        # "'bool' object is not callable".
        self.keep_graph = bool(train)

    def potential(self, x, sigma):
        r"""
        Evaluate the regularizer :math:`g(x) = \frac{\alpha}{2}\|x - N(x)\|_2^2` for every batch element.

        :param torch.tensor x: Input image
        :param float sigma: Denoiser level :math:`\sigma` (std)
        :return: (torch.Tensor) one value per batch element.
        """
        N = self.student_grad(x, sigma)
        residual = (x - N).reshape(x.shape[0], -1)
        return 0.5 * self.alpha * torch.norm(residual, p=2, dim=-1) ** 2

    def potential_grad(self, x, sigma):
        r"""
        Calculate :math:`\nabla g` the gradient of the regularizer :math:`g` at input :math:`x`.

        Autograd is enabled only locally (the caller's grad mode is restored on exit) and the caller's
        tensor is never modified.

        :param torch.tensor x: Input image
        :param float sigma: Denoiser level :math:`\sigma` (std)
        """
        with torch.enable_grad():
            x = x.float()
            if not x.requires_grad:
                # differentiate w.r.t. a private leaf instead of flagging the caller's tensor
                x = x.detach().requires_grad_()
            N = self.student_grad(x, sigma)
            # J_N(x)^T (x - N(x)); a differentiable graph is only built when training through the step
            JN = torch.autograd.grad(
                N, x, grad_outputs=x - N, create_graph=self.keep_graph
            )[0]
            Dg = x - N - JN
        if not self.keep_graph:
            Dg = Dg.detach()
        return self.alpha * Dg

    def forward(self, x, sigma):
        r"""
        Denoising with Gradient Step Denoiser

        :param torch.tensor x: Input image
        :param float sigma: Denoiser level (std)
        """
        Dg = self.potential_grad(x, sigma)
        if not self.keep_graph:
            # outside of training the result carries no autograd history
            x = x.detach()
        x_hat = x - Dg
        return x_hat
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def GSDRUNet(
    alpha=1.0,
    in_channels=3,
    out_channels=3,
    nb=2,
    nc=[64, 128, 256, 512],
    act_mode="E",
    pretrained=None,
    train=False,
    device=torch.device("cpu"),
):
    """
    Gradient Step Denoiser with DRUNet architecture

    :param float alpha: Relaxation parameter
    :param int in_channels: Number of input channels
    :param int out_channels: Number of output channels
    :param int nb: Number of blocks in the DRUNet
    :param list nc: Number of channels in the DRUNet
    :param str act_mode: activation mode, "R" for ReLU, "L" for LeakyReLU "E" for ELU and "S" for Softplus.
    :param str pretrained: use a pretrained network. If ``pretrained=None``, no weights are loaded here and the
        network keeps the initialization performed by ``DRUNet``. If ``pretrained='download'``, the weights
        will be downloaded from an online repository (only available for the default architecture).
        Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights; the file must
        contain a ``"state_dict"`` entry.
        See :ref:`pretrained-weights <pretrained-weights>` for more details.
    :param bool train: training or testing mode.
    :param str device: gpu or cpu.
    :return: (GSPnP) gradient step denoiser wrapping a DRUNet.
    """
    from deepinv.models.drunet import DRUNet

    denoiser = DRUNet(
        in_channels=in_channels,
        out_channels=out_channels,
        # hand over a copy so that the shared default list can never be modified downstream
        nc=list(nc),
        nb=nb,
        act_mode=act_mode,
        # the weights of the whole gradient-step model are loaded below, not by DRUNet itself
        pretrained=None,
        train=train,
        device=device,
    )
    GSmodel = GSPnP(denoiser, alpha=alpha, train=train)

    if pretrained:
        # checkpoints are first mapped to CPU storage
        to_cpu = lambda storage, loc: storage
        if pretrained == "download":
            url = get_weights_url(model_name="gradientstep", file_name="GSDRUNet.ckpt")
            checkpoint = torch.hub.load_state_dict_from_url(
                url,
                map_location=to_cpu,
                file_name="GSDRUNet.ckpt",
            )
        else:
            # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
            checkpoint = torch.load(pretrained, map_location=to_cpu)
        ckpt = checkpoint["state_dict"]

        # strict=False: keys that do not match the model are silently ignored
        GSmodel.load_state_dict(ckpt, strict=False)
    return GSmodel
|
deepinv/models/PDNet.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# This is an implementation of https://arxiv.org/abs/1707.06474
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def init_weights(m):
    r"""
    Xavier-uniform initialization for fully connected layers, meant to be used with ``module.apply``.

    Only :class:`torch.nn.Linear` modules are modified: the weight is drawn from a Xavier uniform
    distribution and the bias, when the layer has one, is set to zero. Every other module type
    (including convolutions) is left untouched.

    :param torch.nn.Module m: module visited by ``apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    # in-place variant; ``xavier_uniform`` without the trailing underscore is deprecated
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:  # layers built with bias=False have no bias tensor
        nn.init.zeros_(m.bias)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PDNet_PrimalBlock(nn.Module):
    def __init__(self, in_channels=6, out_channels=5, depth=3, bias=True, nf=32):
        r"""
        Primal block for the Primal-Dual unfolding model (PDNet) from https://arxiv.org/abs/1707.06474.

        Primal variables are images of shape (batch_size, in_channels, height, width). The input of each
        primal block is the concatenation of the current primal variable and the backprojected dual variable along
        the channel dimension. The output of each primal block is the updated primal variable, obtained by
        adding the output of the convolutional stack to the input primal variable.

        The stack is made of ``depth`` 3x3 convolutions with a PReLU after each one except the last.

        :param int in_channels: number of input channels. Default: 6.
        :param int out_channels: number of output channels. Default: 5.
        :param int depth: number of convolutional layers in the block. Default: 3.
        :param bool bias: whether to use bias in convolutional layers. Default: True.
        :param int nf: number of features in the convolutional layers. Default: 32.
        """
        super().__init__()

        self.depth = depth
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=bias)

        # first layer: concatenated input -> nf features
        self.in_conv = nn.Conv2d(in_channels, nf, **conv_kwargs)
        self.in_conv.apply(init_weights)

        # depth - 2 hidden layers: nf -> nf features
        hidden_convs = [nn.Conv2d(nf, nf, **conv_kwargs) for _ in range(self.depth - 2)]
        self.conv_list = nn.ModuleList(hidden_convs)
        self.conv_list.apply(init_weights)

        # last layer: nf features -> primal update
        self.out_conv = nn.Conv2d(nf, out_channels, **conv_kwargs)
        self.out_conv.apply(init_weights)

        self.nl_list = nn.ModuleList([nn.PReLU() for _ in range(self.depth - 1)])

    def forward(self, x, Atu):
        r"""
        Update the primal variable.

        :param torch.Tensor x: current primal variable.
        :param torch.Tensor Atu: backprojected dual variable, concatenated with ``x`` along dimension 1.
        :return: (torch.Tensor) updated primal variable, with the shape of ``x``.
        """
        stacked = torch.cat((x, Atu), dim=1)
        features = self.nl_list[0](self.in_conv(stacked))

        # hidden layer number idx is followed by activation number idx + 1
        for idx in range(self.depth - 2):
            features = self.nl_list[idx + 1](self.conv_list[idx](features))

        # residual connection on the primal variable
        return self.out_conv(features) + x
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class PDNet_DualBlock(nn.Module):
    def __init__(self, in_channels=7, out_channels=5, depth=3, bias=True, nf=32):
        r"""
        Dual block for the Primal-Dual unfolding model (PDNet) from https://arxiv.org/abs/1707.06474.

        Dual variables are images of shape (batch_size, in_channels, height, width). The input of each
        dual block is the concatenation of the current dual variable with the projected primal variable and
        the measurements. The output of each dual block is the updated dual variable, obtained by
        adding the output of the convolutional stack to the input dual variable.

        The stack is made of ``depth`` 3x3 convolutions with a PReLU after each one except the last.

        :param int in_channels: number of input channels. Default: 7.
        :param int out_channels: number of output channels. Default: 5.
        :param int depth: number of convolutional layers in the block. Default: 3.
        :param bool bias: whether to use bias in convolutional layers. Default: True.
        :param int nf: number of features in the convolutional layers. Default: 32.
        """
        super().__init__()

        self.depth = depth
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=bias)

        # first layer: concatenated input -> nf features
        self.in_conv = nn.Conv2d(in_channels, nf, **conv_kwargs)
        self.in_conv.apply(init_weights)

        # depth - 2 hidden layers: nf -> nf features
        hidden_convs = [nn.Conv2d(nf, nf, **conv_kwargs) for _ in range(self.depth - 2)]
        self.conv_list = nn.ModuleList(hidden_convs)
        self.conv_list.apply(init_weights)

        # last layer: nf features -> dual update
        self.out_conv = nn.Conv2d(nf, out_channels, **conv_kwargs)
        self.out_conv.apply(init_weights)

        self.nl_list = nn.ModuleList([nn.PReLU() for _ in range(self.depth - 1)])

    def forward(self, u, Ax_cur, y):
        r"""
        Update the dual variable.

        :param torch.Tensor u: current dual variable.
        :param torch.Tensor Ax_cur: projected primal variable.
        :param torch.Tensor y: measurements.
        :return: (torch.Tensor) updated dual variable, with the shape of ``u``.
        """
        stacked = torch.cat((u, Ax_cur, y), dim=1)
        features = self.nl_list[0](self.in_conv(stacked))

        # hidden layer number idx is followed by activation number idx + 1
        for idx in range(self.depth - 2):
            features = self.nl_list[idx + 1](self.conv_list[idx](features))

        # residual connection on the dual variable
        return self.out_conv(features) + u
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .drunet import DRUNet
|
|
2
|
+
from .scunet import SCUNet
|
|
3
|
+
from .ae import AutoEncoder
|
|
4
|
+
from .unet import UNet
|
|
5
|
+
from .dncnn import DnCNN
|
|
6
|
+
from .artifactremoval import ArtifactRemoval
|
|
7
|
+
from .tgv import TGV as TGV
|
|
8
|
+
from .tv import TV as TV
|
|
9
|
+
from .wavdict import WaveletPrior, WaveletDict
|
|
10
|
+
from .GSPnP import GSDRUNet
|
|
11
|
+
from .median import MedianFilter
|
|
12
|
+
from .dip import DeepImagePrior, ConvDecoder
|
|
13
|
+
from .diffunet import DiffUNet
|
|
14
|
+
from .swinir import SwinIR
|
|
15
|
+
from .PDNet import PDNet_PrimalBlock, PDNet_DualBlock
|
|
16
|
+
from .bm3d import BM3D
|
|
17
|
+
from .equivariant import EquivariantDenoiser
|
deepinv/models/ae.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class AutoEncoder(torch.nn.Module):
    r"""
    Simple fully connected autoencoder network.

    Simple architecture that can be used for debugging or fast prototyping.

    The input is flattened per batch element, passed through a two-layer encoder and a two-layer decoder
    (ReLU in between), and reshaped back to the input shape.

    :param int dim_input: total number of elements (pixels) of the input.
    :param int dim_mid: number of features in intermediate layer.
    :param int dim_hid: latent space dimension.
    :param bool residual: use a residual connection between input and output.

    """

    def __init__(self, dim_input, dim_mid=1000, dim_hid=32, residual=True):
        super().__init__()
        self.residual = residual

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(dim_input, dim_mid),
            torch.nn.ReLU(),
            torch.nn.Linear(dim_mid, dim_hid),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(dim_hid, dim_mid),
            torch.nn.ReLU(),
            torch.nn.Linear(dim_mid, dim_input),
        )

    def forward(self, x, sigma=None):
        r"""
        Encode and decode the input.

        :param torch.Tensor x: input whose first dimension is the batch; the remaining dimensions must
            contain ``dim_input`` elements in total.
        :param sigma: unused by this model.
        :return: (torch.Tensor) reconstruction with the same shape as ``x``.
        """
        in_shape = x.shape
        # reshape (not view) so that non-contiguous inputs, e.g. permuted tensors, are accepted
        flat = x.reshape(in_shape[0], -1)

        decoded = self.decoder(self.encoder(flat))

        if self.residual:
            decoded = decoded + flat

        return decoded.reshape(in_shape)
|