deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,512 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
import time as time
|
|
5
|
+
|
|
6
|
+
import deepinv.optim
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
from deepinv.optim.utils import check_conv
|
|
9
|
+
from deepinv.sampling.utils import Welford, projbox, refl_projbox
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MonteCarlo(nn.Module):
    r"""
    Base class for Monte Carlo sampling.

    This class can be used to create new Monte Carlo samplers, by only defining their kernel inside a torch.nn.Module:

    ::

        # define custom sampling kernel (possibly a Markov kernel which depends on the previous sample).
        class MyKernel(torch.nn.Module):
            def __init__(self, iterator_params):
                super().__init__()
                self.iterator_params = iterator_params

            def forward(self, x, y, physics, likelihood, prior):
                # run one sampling kernel iteration
                new_x = f(x, y, physics, likelihood, prior, self.iterator_params)
                return new_x

        class MySampler(MonteCarlo):
            def __init__(self, prior, data_fidelity, iterator_params,
                         max_iter=1e3, burnin_ratio=.1, clip=(-1,2), verbose=True):
                # generate an iterator
                iterator = MyKernel(iterator_params)
                # set the params of the base class
                super().__init__(iterator, prior, data_fidelity, max_iter=max_iter,
                                 burnin_ratio=burnin_ratio, clip=clip, verbose=verbose)

        # create the sampler
        sampler = MySampler(prior, data_fidelity, iterator_params)

        # compute posterior mean and variance of reconstruction of measurement y
        mean, var = sampler(y, physics)


    This class computes the mean and variance of the chain using Welford's algorithm, which avoids storing the whole
    set of Monte Carlo samples. Note that the running statistics are initialized with ``g_statistic`` of the
    initial state of the chain, so the initial state also contributes to the returned mean and variance.

    :param torch.nn.Module iterator: sampling kernel, called as
        ``iterator(x, y, physics, likelihood=data_fidelity, prior=prior)`` and returning the next sample.
    :param deepinv.optim.ScorePrior prior: negative log-prior based on a trained or model-based denoiser.
    :param deepinv.optim.DataFidelity data_fidelity: negative log-likelihood function linked with the
        noise distribution in the acquisition physics.
    :param int max_iter: number of Monte Carlo iterations.
    :param float burnin_ratio: fraction of iterations used for the burn-in period, should be set between 0 and 1.
        Samples produced during burn-in are not used to compute the posterior statistics.
    :param int thinning: thins the Monte Carlo samples by an integer :math:`\geq 1` (i.e., keeping one out of ``thinning``
        samples to compute posterior statistics).
    :param tuple clip: Tuple containing the box-constraints :math:`[a,b]`.
        If ``None``, the algorithm will not project the samples.
    :param float thresh_conv: Threshold for verifying the convergence of the mean and variance estimates.
    :param str crit_conv: convergence criterion passed to ``check_conv`` (default ``"residual"``).
    :param bool save_chain: if ``True``, the thinned post-burn-in samples are stored and can be retrieved with
        :meth:`get_chain`.
    :param function_handle g_statistic: The sampler will compute the posterior mean and variance
        of the function g_statistic. By default, it is the identity function (lambda x: x),
        and thus the sampler computes the posterior mean and variance.
    :param bool verbose: prints progress of the algorithm.

    """

    def __init__(
        self,
        iterator: torch.nn.Module,
        prior: deepinv.optim.ScorePrior,
        data_fidelity: deepinv.optim.DataFidelity,
        max_iter=1e3,
        burnin_ratio=0.2,
        thinning=10,
        clip=(-1.0, 2.0),
        thresh_conv=1e-3,
        crit_conv="residual",
        save_chain=False,
        g_statistic=lambda x: x,
        verbose=False,
    ):
        super().__init__()

        self.iterator = iterator
        self.prior = prior
        self.likelihood = data_fidelity
        self.C_set = clip
        self.thinning = thinning
        self.max_iter = int(max_iter)
        self.thresh_conv = thresh_conv
        self.crit_conv = crit_conv
        self.burnin_iter = int(burnin_ratio * max_iter)
        self.verbose = verbose
        self.mean_convergence = False
        self.var_convergence = False
        self.g_function = g_statistic
        self.save_chain = save_chain
        self.chain = []

    def forward(self, y, physics, seed=None, x_init=None):
        r"""
        Runs a Monte Carlo chain to obtain the posterior mean and variance of the reconstruction of the measurements y.

        :param torch.Tensor y: Measurements
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements
        :param int seed: Random seed for generating the Monte Carlo samples (seeds both NumPy and PyTorch).
        :param torch.Tensor x_init: Initial state of the chain. If ``None``, ``physics.A_adjoint(y)`` is used.
        :return: (tuple of torch.Tensor) containing the posterior mean and variance.
        """
        with torch.no_grad():
            if seed is not None:
                np.random.seed(seed)
                torch.manual_seed(seed)

            # Algorithm parameters
            if self.C_set:
                C_lower_lim = self.C_set[0]
                C_upper_lim = self.C_set[1]

            # Initialization
            if x_init is None:
                x = physics.A_adjoint(y)
            else:
                x = x_init

            # Monte Carlo loop
            start_time = time.time()
            statistics = Welford(self.g_function(x))

            self.mean_convergence = False
            self.var_convergence = False

            # Snapshots of the statistics taken just before the last thinned update. They stay ``None``
            # when no sample is collected in the final thinning window (e.g. burnin_ratio close to 1,
            # or max_iter == 0); the convergence check is then skipped instead of failing.
            mean_prev = None
            var_prev = None
            for it in tqdm(range(self.max_iter), disable=(not self.verbose)):
                x = self.iterator(
                    x, y, physics, likelihood=self.likelihood, prior=self.prior
                )

                if self.C_set:
                    x = projbox(x, C_lower_lim, C_upper_lim)

                if it >= self.burnin_iter and (it % self.thinning) == 0:
                    if it >= (self.max_iter - self.thinning):
                        mean_prev = statistics.mean().clone()
                        var_prev = statistics.var().clone()
                    statistics.update(self.g_function(x))

                    if self.save_chain:
                        self.chain.append(x.clone())

            if self.verbose:
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                end_time = time.time()
                elapsed = end_time - start_time
                print(
                    f"Monte Carlo sampling finished! elapsed time={elapsed:.2f} seconds"
                )

            # ``mean_prev`` is only set inside the loop, so ``it`` is bound whenever it is not None.
            if mean_prev is not None:
                if (
                    check_conv(
                        {"est": (mean_prev,)},
                        {"est": (statistics.mean(),)},
                        it,
                        self.crit_conv,
                        self.thresh_conv,
                        self.verbose,
                    )
                    and it > 1
                ):
                    self.mean_convergence = True

                if (
                    check_conv(
                        {"est": (var_prev,)},
                        {"est": (statistics.var(),)},
                        it,
                        self.crit_conv,
                        self.thresh_conv,
                        self.verbose,
                    )
                    and it > 1
                ):
                    self.var_convergence = True

        return statistics.mean(), statistics.var()

    def get_chain(self):
        r"""
        Returns the thinned Monte Carlo samples (after burn-in iterations).

        Requires ``save_chain=True``. Samples accumulate across calls to ``forward`` until :meth:`reset` is called.
        """
        return self.chain

    def reset(self):
        r"""
        Resets the Markov chain.
        """
        self.chain = []
        self.mean_convergence = False
        self.var_convergence = False

    def mean_has_converged(self):
        r"""
        Returns a boolean indicating if the posterior mean verifies the convergence criteria.
        """
        return self.mean_convergence

    def var_has_converged(self):
        r"""
        Returns a boolean indicating if the posterior variance verifies the convergence criteria.
        """
        return self.var_convergence
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class ULAIterator(nn.Module):
    r"""
    Single step of the Unadjusted Langevin Algorithm (no box projection is applied here).

    Given the current state ``x``, returns
    ``x + step_size * (-likelihood.grad(x, y, physics) - alpha * prior(x, sigma)) + sqrt(2 * step_size) * z``
    where ``z`` is standard Gaussian noise.

    :param float step_size: step size of the Langevin update.
    :param float alpha: weight of the prior term.
    :param float sigma: noise level forwarded to ``prior``.
    """

    def __init__(self, step_size, alpha, sigma):
        super().__init__()
        self.step_size = step_size
        self.alpha = alpha
        self.sigma = sigma
        # standard deviation of the diffusion term
        self.noise_std = np.sqrt(2 * step_size)

    def forward(self, x, y, physics, likelihood, prior):
        # The Gaussian draw happens first so the random stream is consumed in the same order.
        diffusion = torch.randn_like(x) * self.noise_std
        data_drift = -likelihood.grad(x, y, physics)
        prior_drift = -prior(x, self.sigma) * self.alpha
        drift = data_drift + prior_drift
        return x + self.step_size * drift + diffusion
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class ULA(MonteCarlo):
    r"""
    Projected Plug-and-Play Unadjusted Langevin Algorithm.

    The algorithm runs the following Markov chain iteration
    (Algorithm 2 from https://arxiv.org/abs/2103.04715):

    .. math::

        x_{k+1} = \Pi_{[a,b]} \left(x_{k} + \eta \nabla \log p(y|A,x_k) +
        \eta \alpha \nabla \log p(x_{k}) + \sqrt{2\eta}z_{k+1} \right).

    where :math:`x_{k}` is the :math:`k` th sample of the Markov chain,
    :math:`\log p(y|x)` is the log-likelihood function, :math:`\log p(x)` is the log-prior,
    :math:`\eta>0` is the step size, :math:`\alpha>0` controls the amount of regularization,
    :math:`\Pi_{[a,b]}(x)` projects the entries of :math:`x` to the interval :math:`[a,b]` and
    :math:`z\sim \mathcal{N}(0,I)` is a standard Gaussian vector.


    - Projected PnP-ULA assumes that the denoiser is :math:`L`-Lipschitz differentiable
    - For convergence, ULA requires step_size smaller than :math:`\frac{1}{L+\|A\|_2^2}`


    :param deepinv.optim.ScorePrior, torch.nn.Module prior: negative log-prior based on a trained or model-based denoiser.
    :param deepinv.optim.DataFidelity, torch.nn.Module data_fidelity: negative log-likelihood function linked with the
        noise distribution in the acquisition physics.
    :param float step_size: step size :math:`\eta>0` of the algorithm.
        Tip: use :meth:`deepinv.physics.Physics.compute_norm()` to compute the Lipschitz constant of the forward operator.
    :param float sigma: noise level used in the plug-and-play prior denoiser. A larger value of sigma will result in
        a more regularized reconstruction.
    :param float alpha: regularization parameter :math:`\alpha`
    :param int max_iter: number of Monte Carlo iterations.
    :param int thinning: Thins the Markov Chain by an integer :math:`\geq 1` (i.e., keeping one out of ``thinning``
        samples to compute posterior statistics).
    :param float burnin_ratio: fraction of iterations used for the burn-in period, should be set between 0 and 1.
        Samples produced during burn-in are not used to compute the posterior statistics.
    :param tuple clip: Tuple containing the box-constraints :math:`[a,b]`.
        If ``None``, the algorithm will not project the samples.
    :param float thresh_conv: Threshold for verifying the convergence of the mean and variance estimates.
    :param bool save_chain: if ``True``, the thinned post-burn-in samples are stored (see ``get_chain``).
    :param function_handle g_statistic: The sampler will compute the posterior mean and variance
        of the function g_statistic. By default, it is the identity function (lambda x: x),
        and thus the sampler computes the posterior mean and variance.
    :param bool verbose: prints progress of the algorithm.
    :param str crit_conv: convergence criterion forwarded to the base class (default ``"residual"``).

    """

    def __init__(
        self,
        prior,
        data_fidelity,
        step_size=1.0,
        sigma=0.05,
        alpha=1.0,
        max_iter=1e3,
        thinning=5,
        burnin_ratio=0.2,
        clip=(-1.0, 2.0),
        thresh_conv=1e-3,
        save_chain=False,
        g_statistic=lambda x: x,
        verbose=False,
        crit_conv="residual",
    ):
        iterator = ULAIterator(step_size=step_size, alpha=alpha, sigma=sigma)
        super().__init__(
            iterator,
            prior,
            data_fidelity,
            max_iter=max_iter,
            thresh_conv=thresh_conv,
            crit_conv=crit_conv,
            g_statistic=g_statistic,
            burnin_ratio=burnin_ratio,
            clip=clip,
            thinning=thinning,
            save_chain=save_chain,
            verbose=verbose,
        )
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class SKRockIterator(nn.Module):
    r"""
    One outer iteration of the SK-ROCK scheme with ``inner_iter`` internal stages.

    The drift uses ``likelihood.grad(u, y, physics) + alpha * prior(u, sigma)`` and the diffusion term is
    ``sqrt(2 * step_size) * z`` with ``z`` standard Gaussian noise. The stage coefficients are built from
    first-kind Chebyshev polynomials evaluated at ``1 + eta / inner_iter**2``.

    :param float step_size: step size of the scheme.
    :param float alpha: weight of the prior term.
    :param int inner_iter: number of internal SK-ROCK stages.
    :param float eta: damping parameter.
    :param float sigma: noise level forwarded to ``prior``.
    """

    def __init__(self, step_size, alpha, inner_iter, eta, sigma):
        super().__init__()
        self.step_size = step_size
        self.alpha = alpha
        self.eta = eta
        self.inner_iter = inner_iter
        self.noise_std = np.sqrt(2 * step_size)
        self.sigma = sigma

    def forward(self, x, y, physics, likelihood, prior):
        stages = self.inner_iter
        h = self.step_size

        def drift_grad(u):
            # gradient of the data term plus the weighted prior term
            return likelihood.grad(u, y, physics) + self.alpha * prior(u, self.sigma)

        def cheb(order, u):
            # first-kind Chebyshev polynomial T_order(u), hyperbolic form
            return np.cosh(order * np.arccosh(u))

        def cheb_prime(order, u):
            # derivative of the first-kind Chebyshev polynomial T'_order(u)
            return order * np.sinh(order * np.arccosh(u)) / np.sqrt(u**2 - 1)

        omega0 = 1 + self.eta / (stages**2)
        omega1 = cheb(stages, omega0) / cheb_prime(stages, omega0)
        mu1 = omega1 / omega0
        nu1 = stages * omega1 / 2
        kappa1 = stages * (omega1 / omega0)

        # diffusion term shared by the first stage
        z = np.sqrt(2 * h) * torch.randn_like(x)

        # stage s = 1
        prev = x.clone()
        curr = x.clone() - mu1 * h * drift_grad(x + nu1 * z) + kappa1 * z

        # stages s = 2, ..., inner_iter (three-term recursion)
        for stage in range(2, stages + 1):
            saved = curr.clone()
            mu = 2 * omega1 * cheb(stage - 1, omega0) / cheb(stage, omega0)
            nu = 2 * omega0 * cheb(stage - 1, omega0) / cheb(stage, omega0)
            kappa = 1 - nu
            curr = -mu * h * drift_grad(curr) + nu * curr + kappa * prev
            prev = saved

        return curr
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
class SKRock(MonteCarlo):
    r"""
    Plug-and-Play SKROCK algorithm.

    Obtains samples of the posterior distribution using an orthogonal Runge-Kutta-Chebyshev stochastic
    approximation to accelerate the standard Unadjusted Langevin Algorithm.

    The algorithm was introduced in "Accelerating proximal Markov chain Monte Carlo by using an explicit stabilised method"
    by L. Vargas, M. Pereyra and K. Zygalakis (https://arxiv.org/abs/1908.08845)

    - SKROCK assumes that the denoiser is :math:`L`-Lipschitz differentiable
    - For convergence, SKROCK requires step_size smaller than :math:`\frac{1}{L+\|A\|_2^2}`

    :param deepinv.optim.ScorePrior, torch.nn.Module prior: negative log-prior based on a trained or model-based denoiser.
    :param deepinv.optim.DataFidelity, torch.nn.Module data_fidelity: negative log-likelihood function linked with the
        noise distribution in the acquisition physics.
    :param float step_size: Step size of the algorithm. Tip: use :meth:`deepinv.physics.Physics.compute_norm()`
        to compute the Lipschitz constant of the forward operator.
    :param float eta: :math:`\eta` SKROCK damping parameter.
    :param float alpha: regularization parameter :math:`\alpha`.
    :param int inner_iter: Number of inner SKROCK iterations.
    :param int max_iter: Number of outer iterations.
    :param int thinning: Thins the Markov Chain by an integer :math:`\geq 1` (i.e., keeping one out of ``thinning``
        samples to compute posterior statistics).
    :param float burnin_ratio: fraction of iterations used for the burn-in period, should be set between 0 and 1.
        Samples produced during burn-in are not used to compute the posterior statistics.
    :param tuple clip: Tuple containing the box-constraints :math:`[a,b]`.
        If ``None``, the algorithm will not project the samples.
    :param float thresh_conv: Threshold for verifying the convergence of the mean and variance estimates.
    :param bool save_chain: if ``True``, the thinned post-burn-in samples are stored (see ``get_chain``).
    :param function_handle g_statistic: The sampler will compute the posterior mean and variance
        of the function g_statistic. By default, it is the identity function (lambda x: x),
        and thus the sampler computes the posterior mean and variance.
    :param bool verbose: prints progress of the algorithm.
    :param float sigma: noise level used in the plug-and-play prior denoiser. A larger value of sigma will result in
        a more regularized reconstruction.
    :param str crit_conv: convergence criterion forwarded to the base class (default ``"residual"``).

    """

    def __init__(
        self,
        prior: deepinv.optim.ScorePrior,
        data_fidelity,
        step_size=1.0,
        inner_iter=10,
        eta=0.05,
        alpha=1.0,
        max_iter=1e3,
        burnin_ratio=0.2,
        thinning=10,
        clip=(-1.0, 2.0),
        thresh_conv=1e-3,
        save_chain=False,
        g_statistic=lambda x: x,
        verbose=False,
        sigma=0.05,
        crit_conv="residual",
    ):
        iterator = SKRockIterator(
            step_size=step_size,
            alpha=alpha,
            inner_iter=inner_iter,
            eta=eta,
            sigma=sigma,
        )
        super().__init__(
            iterator,
            prior,
            data_fidelity,
            max_iter=max_iter,
            thresh_conv=thresh_conv,
            crit_conv=crit_conv,
            thinning=thinning,
            burnin_ratio=burnin_ratio,
            clip=clip,
            g_statistic=g_statistic,
            save_chain=save_chain,
            verbose=verbose,
        )
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
# if __name__ == "__main__":
|
|
436
|
+
# import deepinv as dinv
|
|
437
|
+
# import torchvision
|
|
438
|
+
# from deepinv.optim.data_fidelity import L2
|
|
439
|
+
#
|
|
440
|
+
# x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
|
|
441
|
+
# x = x.unsqueeze(0).float().to(dinv.device) / 255
|
|
442
|
+
# # physics = dinv.physics.CompressedSensing(m=50000, fast=True, img_shape=(3, 218, 178), device=dinv.device)
|
|
443
|
+
# # physics = dinv.physics.Denoising()
|
|
444
|
+
# physics = dinv.physics.Inpainting(
|
|
445
|
+
# mask=0.95, tensor_size=(3, 218, 178), device=dinv.device
|
|
446
|
+
# )
|
|
447
|
+
# # physics = dinv.physics.BlurFFT(filter=dinv.physics.blur.gaussian_blur(sigma=(2,2)), img_size=x.shape[1:], device=dinv.device)
|
|
448
|
+
#
|
|
449
|
+
# sigma = 0.1
|
|
450
|
+
# physics.noise_model = dinv.physics.GaussianNoise(sigma)
|
|
451
|
+
#
|
|
452
|
+
# y = physics(x)
|
|
453
|
+
#
|
|
454
|
+
# likelihood = L2(sigma=sigma)
|
|
455
|
+
#
|
|
456
|
+
# # model_spec = {'name': 'median_filter', 'args': {'kernel_size': 3}}
|
|
457
|
+
# model_spec = {
|
|
458
|
+
# "name": "dncnn",
|
|
459
|
+
# "args": {
|
|
460
|
+
# "device": dinv.device,
|
|
461
|
+
# "in_channels": 3,
|
|
462
|
+
# "out_channels": 3,
|
|
463
|
+
# "pretrained": "download_lipschitz",
|
|
464
|
+
# },
|
|
465
|
+
# }
|
|
466
|
+
# # model_spec = {'name': 'waveletprior', 'args': {'wv': 'db8', 'level': 4, 'device': dinv.device}}
|
|
467
|
+
#
|
|
468
|
+
# prior = ScorePrior(model_spec=model_spec, sigma_normalize=True)
|
|
469
|
+
#
|
|
470
|
+
# sigma_den = 2 / 255
|
|
471
|
+
# f = ULA(
|
|
472
|
+
# prior,
|
|
473
|
+
# likelihood,
|
|
474
|
+
# max_iter=5000,
|
|
475
|
+
# sigma=sigma_den,
|
|
476
|
+
# burnin_ratio=0.3,
|
|
477
|
+
# verbose=True,
|
|
478
|
+
# alpha=0.3,
|
|
479
|
+
# step_size=0.5 * 1 / (1 / (sigma**2) + 1 / (sigma_den**2)),
|
|
480
|
+
# clip=(-1, 2),
|
|
481
|
+
# save_chain=True,
|
|
482
|
+
# )
|
|
483
|
+
# # f = SKRock(prior, likelihood, max_iter=1000, burnin_ratio=.3, verbose=True,
|
|
484
|
+
# # alpha=.9, step_size=.1*(sigma**2), clip=(-1, 2))
|
|
485
|
+
#
|
|
486
|
+
# xmean, xvar = f(y, physics)
|
|
487
|
+
#
|
|
488
|
+
# print(str(f.mean_has_converged()))
|
|
489
|
+
# print(str(f.var_has_converged()))
|
|
490
|
+
#
|
|
491
|
+
# chain = f.get_chain()
|
|
492
|
+
# distance = np.zeros((len(chain)))
|
|
493
|
+
# for k, xhat in enumerate(chain):
|
|
494
|
+
# dist = (xhat - xmean).pow(2).mean()
|
|
495
|
+
# distance[k] = dist
|
|
496
|
+
# distance = np.sort(distance)
|
|
497
|
+
# thres = distance[int(len(distance) * 0.95)] #
|
|
498
|
+
# err = (x - xmean).pow(2).mean()
|
|
499
|
+
# print(f"Confidence region: {thres:.2e}, error: {err:.2e}")
|
|
500
|
+
#
|
|
501
|
+
# xstdn = xvar.sqrt()
|
|
502
|
+
# xstdn_plot = xstdn.sum(dim=1).unsqueeze(1)
|
|
503
|
+
#
|
|
504
|
+
# error = (xmean - x).abs() # per pixel average abs. error
|
|
505
|
+
# error_plot = error.sum(dim=1).unsqueeze(1)
|
|
506
|
+
#
|
|
507
|
+
# print(f"Correct std: {(xstdn*3>error).sum()/np.prod(xstdn.shape)*100:.1f}%")
|
|
508
|
+
#
|
|
509
|
+
# dinv.utils.plot(
|
|
510
|
+
# [physics.A_adjoint(y), x, xmean, xstdn_plot, error_plot],
|
|
511
|
+
# titles=["meas.", "ground-truth", "mean", "norm. std", "abs. error"],
|
|
512
|
+
# )
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Welford:
    r"""
    Welford's algorithm for calculating mean and variance

    https://doi.org/10.2307/1266577

    Keeps a running mean and an unbiased (``k - 1`` normalized) variance of a stream of tensors,
    updated one sample at a time without storing the samples.

    :param torch.Tensor x: first sample of the stream.
    """

    def __init__(self, x):
        self.k = 1
        self.M = x.clone()
        self.S = torch.zeros_like(x)

    def update(self, x):
        r"""Incorporates the new sample ``x`` into the running statistics."""
        self.k += 1
        delta = x - self.M
        new_mean = self.M + delta / self.k
        self.S = self.S + delta * (x - new_mean)
        self.M = new_mean

    def mean(self):
        r"""Running mean of all samples seen so far."""
        return self.M

    def var(self):
        r"""Unbiased running variance; with a single sample the division by zero gives non-finite values."""
        return self.S / (self.k - 1)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def refl_projbox(x, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
    r"""
    Folds ``x`` with an absolute value, then clamps the result to ``[lower, upper]``.

    The absolute value reflects about 0 (not about ``lower``), so this behaves as a reflection at the
    lower boundary only when ``lower == 0``.
    """
    folded = torch.abs(x)
    return folded.clamp(min=lower, max=upper)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def projbox(x, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
    r"""Projects ``x`` onto the box ``[lower, upper]`` by elementwise clamping."""
    projected = x.clamp(min=lower, max=upper)
    return projected
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
import deepinv as dinv
|
|
6
|
+
from deepinv.tests.dummy_datasets.datasets import DummyCircles
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.fixture
def device():
    # Device used by the tests: a GPU chosen by dinv.utils.get_freer_gpu() when CUDA is available, else the CPU.
    if not torch.cuda.is_available():
        return "cpu"
    return dinv.utils.get_freer_gpu()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@pytest.fixture
def toymatrix():
    # 50 x 50 diagonal matrix with diagonal entries 1, 2, ..., 50 (default floating dtype).
    size = 50
    diagonal = torch.arange(1, size + 1, dtype=torch.get_default_dtype())
    return torch.diag(diagonal)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.fixture
def dummy_dataset(imsize, device):
    # A single-sample DummyCircles dataset of size ``imsize``; ``device`` is requested but not used here.
    dataset = DummyCircles(imsize=imsize, samples=1)
    return dataset
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.fixture
def imsize():
    # (channels, height, width) of the 3-channel test images.
    channels, height, width = 3, 37, 31
    return channels, height, width
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.fixture
def imsize_1_channel():
    # (channels, height, width) of the single-channel test images.
    channels, height, width = 1, 37, 31
    return channels, height, width
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.utils.data import Dataset
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def create_circular_mask(imsize, center=None, radius=None):
    r"""
    Boolean disc mask on an ``h x w`` grid.

    :param tuple imsize: grid size ``(h, w)``.
    :param center: ``(row, col)`` of the disc center; defaults to ``(int(h / 2), int(w / 2))``.
    :param radius: disc radius; defaults to the smallest distance from the center to the grid borders.
    :return: boolean ``numpy.ndarray`` of shape ``(h, w)``, ``True`` inside the disc (boundary included).
    """
    h, w = imsize
    center = (int(h / 2), int(w / 2)) if center is None else center
    if radius is None:
        radius = min(center[0], center[1], h - center[0], w - center[1])

    rows, cols = np.ogrid[:h, :w]
    row_offset = rows - center[0]
    col_offset = cols - center[1]
    distance = np.sqrt(row_offset**2 + col_offset**2)
    return distance <= radius
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DummyCircles(Dataset):
    r"""
    Synthetic dataset of images made of randomly placed, randomly coloured discs on a zero background.

    :param int samples: number of images.
    :param tuple imsize: image size ``(channels, height, width)``.
    :param int max_circles: exclusive upper bound on the number of discs per image (the count is drawn with
        ``rng.integers(low=1, high=max_circles)``, so it must be at least 2).
    :param int seed: seed of the NumPy random generator, making the dataset reproducible.
    """

    def __init__(self, samples, imsize=(3, 32, 28), max_circles=10, seed=1):
        super().__init__()

        # accept any sequence (e.g. a list) for the image size
        imsize = tuple(imsize)
        self.x = torch.zeros((samples,) + imsize, dtype=torch.float32)

        rng = np.random.default_rng(seed)

        # Largest disc radius: half of the largest *spatial* dimension.
        # imsize[0] is the number of channels and must not enter this bound.
        max_rad = max(imsize[1:]) / 2
        for i in range(samples):
            circles = rng.integers(low=1, high=max_circles)

            for _ in range(circles):
                pos = rng.uniform(high=imsize[1:])
                colour = rng.random((imsize[0], 1), dtype=np.float32)
                r = rng.uniform(high=max_rad)
                mask = torch.from_numpy(
                    create_circular_mask(imsize[1:], center=pos, radius=r)
                )
                # paint every channel of the disc with its colour (broadcast over the masked pixels)
                self.x[i, :, mask] = torch.from_numpy(colour)

    def __getitem__(self, index):
        return self.x[index, :, :, :]

    def __len__(self):
        return self.x.shape[0]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
if __name__ == "__main__":
    # Quick visual check: build a small dataset and display its first image.
    import matplotlib.pyplot as plt

    imsize = (3, 23, 100)
    dataset = DummyCircles(10, imsize=imsize)

    x = dataset[0]

    # (C, H, W) -> (H, W, C) for matplotlib
    plt.imshow(x.permute(1, 2, 0).cpu().numpy())
    plt.show()
|