deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/optim/prior.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from deepinv.optim.utils import gradient_descent
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Prior(nn.Module):
    r"""
    Prior term :math:`g(x)`.

    This is the base class for the prior term :math:`g(x)`. Similarly to the :meth:`deepinv.optim.DataFidelity` class,
    this class comes with methods for computing
    :math:`\operatorname{prox}_{g}` and :math:`\nabla g`.
    To implement a custom prior, for an explicit prior, overwrite :math:`g` (do not forget to specify
    `self.explicit_prior = True`).

    This base class is also used to implement implicit priors. For instance, in PnP methods, the method computing the
    proximity operator is overwritten by a method performing denoising. For an implicit prior, overwrite `grad`
    or `prox`.


    .. note::

        The methods for computing the proximity operator and the gradient of the prior rely on automatic
        differentiation. These methods should not be used when the prior is not differentiable, although they will
        not raise an error.


    :param callable g: Prior function :math:`g(x)`.
    """

    def __init__(self, g=None):
        super().__init__()
        self._g = g
        # A prior is "explicit" exactly when a callable g was provided.
        self.explicit_prior = self._g is not None

    def g(self, x, *args, **kwargs):
        r"""
        Computes the prior :math:`g(x)`.

        :param torch.tensor x: Variable :math:`x` at which the prior is computed.
        :return: (torch.tensor) prior :math:`g(x)`.
        """
        return self._g(x, *args, **kwargs)

    def forward(self, x, *args, **kwargs):
        r"""
        Computes the prior :math:`g(x)`.

        :param torch.tensor x: Variable :math:`x` at which the prior is computed.
        :return: (torch.tensor) prior :math:`g(x)`.
        """
        return self.g(x, *args, **kwargs)

    def grad(self, x, *args, **kwargs):
        r"""
        Calculates the gradient of the prior term :math:`g` at :math:`x`.
        By default, the gradient is computed using automatic differentiation.

        If :math:`g` returns one value per batch element instead of a scalar, the gradient of the sum of these
        values is returned (each output entry is back-propagated with weight one).

        :param torch.tensor x: Variable :math:`x` at which the gradient is computed.
        :return: (torch.tensor) gradient :math:`\nabla_x g`, computed in :math:`x`.
        """
        with torch.enable_grad():
            # NOTE: requires_grad_ is in-place, so the flag is also set on the caller's tensor.
            x = x.requires_grad_()
            value = self.g(x, *args, **kwargs)
            # Explicit grad_outputs keeps the scalar case unchanged and makes non-scalar
            # (e.g. per-batch-element) outputs differentiable instead of raising.
            grad = torch.autograd.grad(
                value, x, grad_outputs=torch.ones_like(value), create_graph=True
            )[0]
        return grad

    def prox(
        self,
        x,
        *args,
        gamma=1.0,
        stepsize_inter=1.0,
        max_iter_inter=50,
        tol_inter=1e-3,
        **kwargs,
    ):
        r"""
        Calculates the proximity operator of :math:`g` at :math:`x`. By default, the proximity operator is computed using internal gradient descent.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param float gamma: stepsize of the proximity operator.
        :param float stepsize_inter: stepsize used for internal gradient descent
        :param int max_iter_inter: maximal number of iterations for internal gradient descent.
        :param float tol_inter: tolerance passed to the internal gradient descent as its stopping threshold.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma g}(x)`, computed in :math:`x`.
        """

        def objective_grad(z):
            # Gradient of z -> gamma * g(z) + 0.5 * ||z - x||^2.
            return gamma * self.grad(z, *args, **kwargs) + (z - x)

        return gradient_descent(
            objective_grad,
            x,
            step_size=stepsize_inter,
            max_iter=max_iter_inter,
            tol=tol_inter,
        )

    def prox_conjugate(self, x, *args, gamma=1.0, lamb=1.0, **kwargs):
        r"""
        Calculates the proximity operator of the convex conjugate :math:`(\lambda g)^*` at :math:`x`, using the Moreau formula

        .. math::

            \operatorname{prox}_{\gamma (\lambda g)^*}(x) = x - \gamma \operatorname{prox}_{(\lambda / \gamma) g}(x / \gamma).

        .. warning::

            Only valid for convex :math:`g`.

        Additional positional and keyword arguments are forwarded unchanged to :meth:`prox`.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param float gamma: stepsize of the proximity operator.
        :param float lamb: :math:`\lambda` parameter in front of :math:`g`.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma (\lambda g)^*}(x)`, computed in :math:`x`.
        """
        # lamb / gamma is the stepsize of the inner prox; it must be passed as `gamma`,
        # not positionally, otherwise it is mistaken for a parameter of g / the denoiser.
        return x - gamma * self.prox(x / gamma, *args, gamma=lamb / gamma, **kwargs)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class PnP(Prior):
    r"""
    Plug-and-play (PnP) prior, defined implicitly through its proximity operator
    :math:`\operatorname{prox}_{\gamma g}(x) = \operatorname{D}_{\sigma}(x)`.


    :param callable denoiser: Denoiser :math:`\operatorname{D}_{\sigma}`, called as ``denoiser(x, sigma_denoiser)``.
    """

    def __init__(self, denoiser, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The prior is only available through the denoiser, never as a closed-form g.
        self.explicit_prior = False
        self.denoiser = denoiser

    def prox(self, x, sigma_denoiser, *args, **kwargs):
        r"""
        Evaluates the proximity operator of the PnP prior :math:`g` at :math:`x` by running the denoiser.

        Extra positional and keyword arguments are accepted and ignored.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param float sigma_denoiser: noise level parameter of the denoiser.
        :return: (torch.tensor) proximity operator at :math:`x`.
        """
        denoised = self.denoiser(x, sigma_denoiser)
        return denoised
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class RED(Prior):
    r"""
    Regularization-by-Denoising (RED) prior, defined implicitly through its gradient
    :math:`\nabla g(x) = \operatorname{Id} - \operatorname{D}_{\sigma}(x)`.


    :param callable denoiser: Denoiser :math:`\operatorname{D}_{\sigma}`, called as ``denoiser(x, sigma_denoiser)``.
    """

    def __init__(self, denoiser, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The prior is only available through the denoiser, never as a closed-form g.
        self.explicit_prior = False
        self.denoiser = denoiser

    def grad(self, x, sigma_denoiser, *args, **kwargs):
        r"""
        Calculates the gradient of the prior term :math:`g` at :math:`x` as the denoising residual
        :math:`x - \operatorname{D}_{\sigma}(x)`.

        Extra positional and keyword arguments are accepted and ignored.

        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
        :param float sigma_denoiser: noise level parameter of the denoiser.
        :return: (:class:`torch.Tensor`) gradient :math:`\nabla_x g`, computed in :math:`x`.
        """
        denoised = self.denoiser(x, sigma_denoiser)
        return x - denoised
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class ScorePrior(Prior):
    r"""
    Score via MMSE denoiser :math:`\nabla g(x)=\left(x-\operatorname{D}_{\sigma}(x)\right)/\sigma^2`.

    This approximates the score of a distribution using Tweedie's formula, i.e.,

    .. math::

        - \nabla \log p_{\sigma}(x) \propto \left(x-D(x,\sigma)\right)/\sigma^2

    where :math:`p_{\sigma} = p*\mathcal{N}(0,I\sigma^2)` is the prior convolved with a Gaussian kernel,
    :math:`D(\cdot,\sigma)` is a (trained or model-based) denoiser with noise level :math:`\sigma`,
    which is typically set to a low value.

    .. note::

        If :math:`\sigma=1`, this prior is equal to :class:`deepinv.optim.RED`, which is defined in
        `Regularization by Denoising (RED) <https://arxiv.org/abs/1611.02862>`_ and doesn't require the normalization.


    .. note::

        This class can also be used with maximum-a-posteriori (MAP) denoisers,
        but :math:`p_{\sigma}(x)` is not given by the convolution with a Gaussian kernel, but rather
        given by the Moreau-Yosida envelope of :math:`p(x)`, i.e.,

        .. math::

            p_{\sigma}(x)=e^{- \inf_z \left(-\log p(z) + \frac{1}{2\sigma}\|x-z\|^2 \right)}.


    :param callable denoiser: Denoiser :math:`D(\cdot,\sigma)`, called as ``denoiser(x, sigma)``.
    """

    def __init__(self, denoiser, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The prior is only available through the denoiser, never as a closed-form g.
        self.explicit_prior = False
        self.denoiser = denoiser

    def forward(self, x, sigma):
        r"""
        Computes the denoiser-based score estimate :math:`\left(x-D(x,\sigma)\right)/\sigma^2`.

        :param torch.Tensor x: the input tensor.
        :param float sigma: the noise level.
        :return: the denoising residual scaled by :math:`1/\sigma^2`.
        """
        scale = 1 / sigma**2
        residual = x - self.denoiser(x, sigma)
        return scale * residual
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class Tikhonov(Prior):
    r"""
    Tikhonov regularizer :math:`g(x) = \frac{1}{2}\| x \|_2^2`.

    All methods take an optional regularization parameter :math:`\tau` (``ths``) that scales the regularizer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.explicit_prior = True

    def g(self, x, ths=1.0):
        r"""
        Computes the Tikhonov regularizer :math:`g(x) = \frac{\tau}{2}\| x \|_2^2`.

        The norm is taken separately for each element along the first dimension of :math:`x`.

        :param torch.Tensor x: Variable :math:`x` at which the prior is computed.
        :param float ths: regularization parameter :math:`\tau`.
        :return: (torch.Tensor) prior :math:`g(x)`.
        """
        flat = x.contiguous().view(x.shape[0], -1)
        return 0.5 * ths * torch.norm(flat, p=2, dim=-1) ** 2

    def grad(self, x, ths=1.0):
        r"""
        Calculates the gradient :math:`\tau x` of the Tikhonov regularization term
        :math:`\frac{\tau}{2}\| x \|_2^2` at :math:`x`.

        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
        :param float ths: regularization parameter :math:`\tau`. The default ``1.0`` gives the gradient :math:`x`.
        :return: (torch.Tensor) gradient at :math:`x`.
        """
        # Scale by ths so that grad stays consistent with g and prox.
        return ths * x

    def prox(self, x, ths=1.0, gamma=1.0):
        r"""
        Calculates the proximity operator of the Tikhonov regularization term :math:`\gamma \tau g` at :math:`x`,
        i.e. :math:`x / (1 + \gamma \tau)`.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param float ths: regularization parameter :math:`\tau`.
        :param float gamma: stepsize of the proximity operator.
        :return: (torch.Tensor) proximity operator at :math:`x`.
        """
        return (1 / (ths * gamma + 1)) * x
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class L1Prior(Prior):
    r"""
    :math:`\ell_1` prior :math:`g(x) = \| x \|_1`.

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.explicit_prior = True

    def g(self, x, ths=1.0):
        r"""
        Computes the regularizer :math:`g(x) = \tau\| x \|_1`.

        The norm is taken separately for each element along the first dimension of :math:`x`.

        :param torch.Tensor x: Variable :math:`x` at which the prior is computed.
        :param float ths: threshold parameter :math:`\tau`.
        :return: (torch.Tensor) prior :math:`g(x)`.
        """
        flat = x.contiguous().view(x.shape[0], -1)
        l1_norm = torch.norm(flat, p=1, dim=-1)
        return ths * l1_norm

    def prox(self, x, ths=1.0, gamma=1.0):
        r"""
        Calculates the proximity operator of the l1 regularization term :math:`g` at :math:`x`.

        This is the soft-thresholding operator

        .. math::
            \operatorname{prox}_{\gamma \tau g}(x) = \operatorname{sign}(x) \max(|x| - \gamma \tau, 0)


        where :math:`\tau` is the threshold parameter and :math:`\gamma` is a stepsize.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param float ths: threshold parameter :math:`\tau`.
        :param float gamma: stepsize of the proximity operator.
        :return: (torch.Tensor) proximity operator at :math:`x`.
        """
        signs = torch.sign(x)
        # Shrink every magnitude by the threshold and clip negative results to zero.
        shrunk = torch.max(torch.abs(x) - ths * gamma, torch.zeros_like(x))
        return signs * shrunk
|
deepinv/optim/utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from deepinv.utils import zeros_like
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def check_conv(X_prev, X, it, crit_conv="residual", thres_conv=1e-3, verbose=False):
    """
    Tests whether two consecutive iterates satisfy the chosen convergence criterion.

    With ``crit_conv="residual"`` the criterion is the relative change ``||X_prev - X|| / (||X|| + 1e-6)``;
    iterates given as dictionaries are read through ``["est"][0]``.
    With ``crit_conv="cost"`` the same relative change is computed on the ``"cost"`` entries of both iterates.

    :param X_prev: previous iterate (tensor, or dictionary for the residual criterion; dictionary for the cost criterion).
    :param X: current iterate, same convention as ``X_prev``.
    :param int it: iteration index, only used in the verbose message.
    :param str crit_conv: convergence criterion, ``"residual"`` or ``"cost"``.
    :param float thres_conv: threshold below which the criterion counts as converged.
    :param bool verbose: if ``True``, prints a message when convergence is detected.
    :return: ``True`` if the criterion is strictly below ``thres_conv``, ``False`` otherwise.
    :raises ValueError: if ``crit_conv`` is not one of the supported criteria.
    """
    if crit_conv == "residual":
        previous = X_prev["est"][0] if isinstance(X_prev, dict) else X_prev
        current = X["est"][0] if isinstance(X, dict) else X
        crit_cur = (previous - current).norm() / (current.norm() + 1e-06)
    elif crit_conv == "cost":
        cost_prev, cost_cur = X_prev["cost"], X["cost"]
        crit_cur = (cost_prev - cost_cur).norm() / (cost_cur.norm() + 1e-06)
    else:
        raise ValueError("convergence criteria not implemented")

    # Guard clause: nothing is printed unless convergence was reached.
    if not crit_cur < thres_conv:
        return False

    if verbose:
        print(
            f"Iteration {it}, current converge crit. = {crit_cur:.2E}, objective = {thres_conv:.2E} \r"
        )
    return True
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def conjugate_gradient(A, b, max_iter=1e2, tol=1e-5):
    """
    Standard conjugate gradient algorithm to solve Ax=b
    see: http://en.wikipedia.org/wiki/Conjugate_gradient_method

    The iterations start from x=0 and stop as soon as the norm of the residual b - Ax falls below ``tol``,
    or after ``max_iter`` iterations.

    :param A: Linear operator as a callable function, has to be square!
    :param b: input tensor
    :param max_iter: maximum number of CG iterations
    :param tol: absolute tolerance for stopping the CG algorithm.
    :return: torch tensor x verifying Ax=b

    """

    def inner(s1, s2):
        # Inner product accumulated over all entries (batch included).
        return (s1 * s2).flatten().sum()

    x = zeros_like(b)

    # With x = 0 the initial residual is b itself.
    r = b
    p = r
    rsold = inner(r, r)

    # A zero right-hand side is already solved by x = 0. Without this guard the
    # first step size would be 0 / 0 and the result would be filled with NaN.
    if rsold == 0:
        return x

    for _ in range(int(max_iter)):
        Ap = A(p)
        alpha = rsold / inner(p, Ap)
        x = x + p * alpha
        r = r + Ap * (-alpha)
        rsnew = inner(r, r)
        if rsnew.sqrt() < tol:
            break
        p = r + p * (rsnew / rsold)
        rsold = rsnew

    return x
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def gradient_descent(grad_f, x, step_size=1.0, max_iter=1e2, tol=1e-5):
    """
    Standard gradient descent algorithm to solve min_x f(x)

    :param grad_f: gradient of the function to be minimized, as a callable function.
    :param x: input tensor, used as the initial iterate.
    :param step_size: (constant) step size of the gradient descent algorithm.
    :param max_iter: maximum number of iterations
    :param tol: stopping threshold handed to ``check_conv`` (with its default residual criterion, a relative
        change between two consecutive iterates).
    :return: torch tensor x, the last iterate computed.

    """
    n_steps = int(max_iter)
    k = 0
    while k < n_steps:
        previous = x
        x = previous - grad_f(previous) * step_size
        # Stop as soon as two consecutive iterates are close enough.
        if check_conv(previous, x, k, thres_conv=tol):
            break
        k += 1
    return x
|
|
deepinv/physics/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from .inpainting import Inpainting
|
|
2
|
+
from .compressed_sensing import CompressedSensing
|
|
3
|
+
from .blur import Blur, BlindBlur, Downsampling, BlurFFT
|
|
4
|
+
from .range import Decolorize
|
|
5
|
+
from .haze import Haze
|
|
6
|
+
from .forward import Denoising, Physics, LinearPhysics, DecomposablePhysics
|
|
7
|
+
from .noise import (
|
|
8
|
+
GaussianNoise,
|
|
9
|
+
PoissonNoise,
|
|
10
|
+
PoissonGaussianNoise,
|
|
11
|
+
UniformNoise,
|
|
12
|
+
UniformGaussianNoise,
|
|
13
|
+
)
|
|
14
|
+
from .mri import MRI
|
|
15
|
+
from .tomography import Tomography
|
|
16
|
+
from .lidar import SinglePhotonLidar
|
|
17
|
+
from .singlepixel import SinglePixelCamera
|
|
18
|
+
from .remote_sensing import Pansharpen
|