deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/sampling/diffusion.py
@@ -0,0 +1,676 @@
+import torch.nn as nn
+import torch
+import numpy as np
+from tqdm import tqdm
+
+import deepinv.physics
+from deepinv.sampling.langevin import MonteCarlo
+
+
+class DiffusionSampler(MonteCarlo):
+    r"""
+    Turns a diffusion method into a Monte Carlo sampler.
+
+    Unlike diffusion methods, the resulting sampler computes the mean and variance of the distribution
+    by running the diffusion multiple times.
+
+    :param torch.nn.Module diffusion: a diffusion model
+    :param int max_iter: the maximum number of iterations
+    :param tuple clip: the range used to clip the samples
+    :param float thres_conv: the convergence threshold for the mean and variance
+    :param callable g_statistic: the function :math:`g` whose mean and variance are computed; by default :math:`g(x) = x`.
+    :param bool verbose: whether to print the progress
+    :param bool save_chain: whether to save the chain
+    """
+
+    def __init__(
+        self,
+        diffusion,
+        max_iter=100,
+        clip=(-1, 2),
+        thres_conv=1e-1,
+        g_statistic=lambda x: x,
+        verbose=True,
+        save_chain=False,
+    ):
+        # generate an iterator and set the params of the base class
+        data_fidelity = None
+        diffusion.verbose = False
+        prior = diffusion
+
+        def iterator(x, y, physics, likelihood, prior):
+            # run one sampling kernel iteration
+            x = prior(y, physics)
+            return x
+
+        super().__init__(
+            iterator,
+            prior,
+            data_fidelity,
+            max_iter=max_iter,
+            thinning=1,
+            save_chain=save_chain,
+            burnin_ratio=0.0,
+            clip=clip,
+            verbose=verbose,
+            thresh_conv=thres_conv,
+            g_statistic=g_statistic,
+        )
+
+
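For orientation (this note is not part of the package diff): wrapping one of the diffusion methods defined below into DiffusionSampler gives posterior mean and variance estimates, as in the commented-out demo at the end of this file. A minimal sketch, assuming a DRUNet denoiser is exposed as deepinv.models.DRUNet with a pretrained="download" option:

    import numpy as np
    import torch
    import deepinv as dinv
    from deepinv.sampling.diffusion import DDRM, DiffusionSampler

    denoiser = dinv.models.DRUNet(pretrained="download")  # assumed constructor

    # denoising physics with a known Gaussian noise level
    physics = dinv.physics.Denoising()
    physics.noise_model = dinv.physics.GaussianNoise(sigma=0.01)

    x = torch.rand(1, 3, 64, 64)  # toy image in [0, 1]
    y = physics(x)

    # DDRM (defined below) produces one posterior sample per call;
    # DiffusionSampler repeats it to estimate posterior statistics
    f = DDRM(denoiser=denoiser, sigmas=np.linspace(1, 0, 100))
    sampler = DiffusionSampler(f, max_iter=10, save_chain=True, verbose=True)
    xmean, xvar = sampler(y, physics)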
+class DDRM(nn.Module):
+    r"""
+    Denoising Diffusion Restoration Models (DDRM).
+
+    This class implements the denoising diffusion restoration model (DDRM) described in https://arxiv.org/abs/2201.11793.
+
+    DDRM is a sampling method that uses a denoiser to sample from the posterior distribution of the inverse problem.
+
+    It requires the physics operator to have a singular value decomposition, i.e.,
+    to be a :class:`deepinv.physics.DecomposablePhysics`.
+
+    :param torch.nn.Module denoiser: a denoiser model that can handle different noise levels.
+    :param list[float], numpy.array sigmas: a list of noise levels to use in the diffusion; they should be in decreasing
+        order from 1 to 0.
+    :param float eta: hyperparameter controlling the stochasticity of the sampling.
+    :param float etab: hyperparameter weighting the measurement-consistency step.
+    :param bool verbose: if True, print progress.
+    """
+
+    def __init__(
+        self,
+        denoiser,
+        sigmas=np.linspace(1, 0, 100),
+        eta=0.85,
+        etab=1.0,
+        verbose=False,
+    ):
+        super(DDRM, self).__init__()
+        self.denoiser = denoiser
+        self.sigmas = sigmas
+        self.max_iter = len(sigmas)
+        self.eta = eta
+        self.verbose = verbose
+        self.etab = etab
+
+    def forward(self, y, physics: deepinv.physics.DecomposablePhysics, seed=None):
+        r"""
+        Runs the diffusion to obtain a random sample of the posterior distribution.
+
+        :param torch.Tensor y: the measurements.
+        :param deepinv.physics.DecomposablePhysics physics: the physics operator, which must have a singular value
+            decomposition.
+        :param int seed: the seed for the random number generator.
+        """
+        # assert physics.__class__ == deepinv.physics.DecomposablePhysics, 'The forward operator requires a singular value decomposition'
+        with torch.no_grad():
+            if seed is not None:
+                np.random.seed(seed)
+                torch.manual_seed(seed)
+
+            if hasattr(physics.noise_model, "sigma"):
+                sigma_noise = physics.noise_model.sigma
+            else:
+                sigma_noise = 0.01
+
+            if physics.__class__ == deepinv.physics.Denoising:
+                mask = torch.ones_like(
+                    y
+                )  # TODO: fix for economic SVD decompositions (e.g. Decolorize)
+            else:
+                mask = torch.cat([physics.mask.abs()] * y.shape[0], dim=0)
+
+            c = np.sqrt(1 - self.eta**2)
+            y_bar = physics.U_adjoint(y)
+            case = mask > sigma_noise
+            y_bar[case] = y_bar[case] / mask[case]
+            nsr = torch.zeros_like(mask)
+            nsr[case] = sigma_noise / mask[case]
+
+            # iteration 1: compute the initial noise
+            mean = torch.zeros_like(y_bar)
+            std = torch.ones_like(y_bar) * self.sigmas[0]
+            mean[case] = y_bar[case]
+            std[case] = (self.sigmas[0] ** 2 - nsr[case].pow(2)).sqrt()
+            x_bar = mean + std * torch.randn_like(y_bar)
+            x_bar_prev = x_bar.clone()
+
+            # denoise
+            x = self.denoiser(physics.V(x_bar), self.sigmas[0])
+
+            for t in tqdm(range(1, self.max_iter), disable=(not self.verbose)):
+                # add noise in transformed domain
+                x_bar = physics.V_adjoint(x)
+
+                case2 = torch.logical_and(case, (self.sigmas[t] < nsr))
+                case3 = torch.logical_and(case, (self.sigmas[t] >= nsr))
+
+                # n = np.prod(mask.shape)
+                # print(f'case: {case.sum()/n*100:.2f}, case2: {case2.sum()/n*100:.2f}, case3: {case3.sum()/n*100:.2f}')
+
+                mean = (
+                    x_bar
+                    + c * self.sigmas[t] * (x_bar_prev - x_bar) / self.sigmas[t - 1]
+                )
+                mean[case2] = (
+                    x_bar[case2]
+                    + c * self.sigmas[t] * (y_bar[case2] - x_bar[case2]) / nsr[case2]
+                )
+                mean[case3] = (1.0 - self.etab) * x_bar[case3] + self.etab * y_bar[case3]
+
+                std = torch.ones_like(x_bar) * self.eta * self.sigmas[t]
+                std[case3] = (
+                    self.sigmas[t] ** 2 - (nsr[case3] * self.etab).pow(2)
+                ).sqrt()
+
+                x_bar = mean + std * torch.randn_like(x_bar)
+                x_bar_prev = x_bar.clone()
+                # denoise
+                x = self.denoiser(physics.V(x_bar), self.sigmas[t])
+
+        return x
+
+
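A usage sketch adapted from the commented-out demo at the bottom of this file (again not part of the diff; the DRUNet constructor is an assumption). DDRM needs a deepinv.physics.DecomposablePhysics, and Inpainting is one such operator:

    import numpy as np
    import torch
    import deepinv as dinv
    from deepinv.sampling.diffusion import DDRM

    denoiser = dinv.models.DRUNet(pretrained="download")  # assumed constructor

    physics = dinv.physics.Inpainting(mask=0.5, tensor_size=(3, 64, 64))
    physics.noise_model = dinv.physics.GaussianNoise(sigma=0.01)

    x = torch.rand(1, 3, 64, 64)  # toy image in [0, 1]
    y = physics(x)

    f = DDRM(denoiser=denoiser, sigmas=np.linspace(1, 0, 100), eta=0.85, etab=1.0)
    xhat = f(y, physics, seed=0)  # one sample from the posterior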
+class DiffPIR(nn.Module):
+    r"""
+    Diffusion PnP Image Restoration (DiffPIR).
+
+    This class implements the Diffusion PnP image restoration algorithm (DiffPIR) described
+    in https://arxiv.org/abs/2305.08995.
+
+    The DiffPIR algorithm is inspired by a half-quadratic splitting (HQS) plug-and-play algorithm, where the denoiser
+    is a conditional diffusion denoiser, combined with a diffusion process. The algorithm writes as follows,
+    for :math:`t` decreasing from :math:`T` to :math:`1`:
+
+    .. math::
+            \begin{equation*}
+            \begin{aligned}
+            x_{0}^{t} &= D_{\theta}(x_t, \frac{\sqrt{1-\overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}}) \\
+            \widehat{x}_{0}^{t} &= \operatorname{prox}_{2 f(y, \cdot) /{\rho_t}}(x_{0}^{t}) \\
+            \widehat{\varepsilon} &= \left(x_t - \sqrt{\overline{\alpha}_t} \,\,
+            \widehat{x}_{0}^t\right)/\sqrt{1-\overline{\alpha}_t} \\
+            \varepsilon_t &\sim \mathcal{N}(0, \mathbf{I}) \\
+            x_{t-1} &= \sqrt{\overline{\alpha}_t} \,\, \widehat{x}_{0}^t + \sqrt{1-\overline{\alpha}_t}
+            \left(\sqrt{1-\zeta} \,\, \widehat{\varepsilon} + \sqrt{\zeta} \,\, \varepsilon_t\right),
+            \end{aligned}
+            \end{equation*}
+
+    where :math:`D_\theta(\cdot,\sigma)` is a Gaussian denoiser network with noise level :math:`\sigma`
+    and :math:`f(y, \cdot)` is the data fidelity term.
+
+    .. note::
+
+            The algorithm might require careful tuning of the hyperparameters :math:`\lambda` and :math:`\zeta` to
+            obtain optimal results.
+
+    :param torch.nn.Module model: a denoiser network that can handle different noise levels.
+    :param float sigma: the noise level of the data.
+    :param deepinv.optim.DataFidelity data_fidelity: the data fidelity operator.
+    :param int max_iter: the number of iterations to run the algorithm (default: 100).
+    :param float zeta: hyperparameter :math:`\zeta` for the sampling step (must be between 0 and 1). Default: 1.0.
+    :param float lambda_: hyperparameter :math:`\lambda` for the data fidelity step
+        (:math:`\rho_t = \lambda \frac{\sigma_n^2}{\bar{\sigma}_t^2}` in the paper), whose optimal value typically
+        ranges between 3.0 and 25.0 depending on the problem. Default: ``7.0``.
+    :param bool verbose: if ``True``, print progress.
+    :param str device: the device to use for the computations.
+    """
+
+    def __init__(
+        self,
+        model,
+        data_fidelity,
+        sigma=0.05,
+        max_iter=100,
+        zeta=1.0,
+        lambda_=7.0,
+        verbose=False,
+        device="cpu",
+    ):
+        super(DiffPIR, self).__init__()
+        self.model = model
+        self.lambda_ = lambda_
+        self.data_fidelity = data_fidelity
+        self.max_iter = max_iter
+        self.zeta = zeta
+        self.verbose = verbose
+        self.device = device
+        self.beta_start, self.beta_end = 0.1 / 1000, 20 / 1000
+        self.num_train_timesteps = 1000
+
+        (
+            self.sqrt_1m_alphas_cumprod,
+            self.reduced_alpha_cumprod,
+            self.sqrt_alphas_cumprod,
+            self.sqrt_recip_alphas_cumprod,
+            self.sqrt_recipm1_alphas_cumprod,
+            self.betas,
+        ) = self.get_alpha_beta()
+
+        self.rhos, self.sigmas, self.seq = self.get_noise_schedule(sigma=sigma)
+
+    def get_alpha_beta(self):
+        """
+        Get the alpha and beta sequences for the algorithm. This is necessary for mapping noise levels to timesteps.
+        """
+        betas = np.linspace(
+            self.beta_start, self.beta_end, self.num_train_timesteps, dtype=np.float32
+        )
+        betas = torch.from_numpy(betas).to(self.device)
+        alphas = 1.0 - betas
+        alphas_cumprod = np.cumprod(alphas.cpu(), axis=0)  # This is \overline{\alpha}_t
+
+        # Useful sequences deriving from alphas_cumprod
+        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+        sqrt_1m_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+        reduced_alpha_cumprod = torch.div(
+            sqrt_1m_alphas_cumprod, sqrt_alphas_cumprod
+        )  # equivalent noise sigma on image
+        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
+        sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
+
+        return (
+            sqrt_1m_alphas_cumprod,
+            reduced_alpha_cumprod,
+            sqrt_alphas_cumprod,
+            sqrt_recip_alphas_cumprod,
+            sqrt_recipm1_alphas_cumprod,
+            betas,
+        )
+
+    def get_noise_schedule(self, sigma):
+        """
+        Get the noise schedule for the algorithm.
+        """
+        lambda_ = self.lambda_
+        sigmas = []
+        sigma_ks = []
+        rhos = []
+        for i in range(self.num_train_timesteps):
+            sigmas.append(self.reduced_alpha_cumprod[self.num_train_timesteps - 1 - i])
+            sigma_ks.append(self.sqrt_1m_alphas_cumprod[i] / self.sqrt_alphas_cumprod[i])
+            rhos.append(lambda_ * (sigma**2) / (sigma_ks[i] ** 2))
+        rhos = torch.tensor(rhos).to(self.device)
+        sigmas = torch.tensor(sigmas).to(self.device)
+
+        seq = np.sqrt(np.linspace(0, self.num_train_timesteps**2, self.max_iter))
+        seq = [int(s) for s in list(seq)]
+        seq[-1] = seq[-1] - 1
+
+        return rhos, sigmas, seq
+
+    def find_nearest(self, array, value):
+        """
+        Find the index of the nearest value in an array.
+        """
+        array = np.asarray(array)
+        idx = (np.abs(array - value)).argmin()
+        return idx
+
+    def get_alpha_prod(
+        self, beta_start=0.1 / 1000, beta_end=20 / 1000, num_train_timesteps=1000
+    ):
+        """
+        Get the alpha sequences; this is necessary for mapping noise levels to timesteps when performing pure denoising.
+        """
+        betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+        betas = torch.from_numpy(
+            betas
+        )  # kept on CPU for now; moving to self.device can be done outside
+        alphas = 1.0 - betas
+        alphas_cumprod = np.cumprod(alphas.cpu(), axis=0)  # This is \overline{\alpha}_t
+
+        # Useful sequences deriving from alphas_cumprod
+        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
+        sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
+        return (
+            sqrt_recip_alphas_cumprod,
+            sqrt_recipm1_alphas_cumprod,
+        )
+
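The helpers above exist to translate between image-domain noise levels and diffusion timesteps. A standalone sketch of that mapping (independent of the class, recomputing the same beta schedule with plain numpy):

    import numpy as np

    # training schedule: beta_start = 0.1/1000, beta_end = 20/1000, T = 1000
    betas = np.linspace(0.1 / 1000, 20 / 1000, 1000, dtype=np.float32)
    alphas_cumprod = np.cumprod(1.0 - betas)  # \overline{\alpha}_t

    # equivalent noise sigma on the image, as in reduced_alpha_cumprod
    reduced = np.sqrt(1 - alphas_cumprod) / np.sqrt(alphas_cumprod)

    # find_nearest: map a target noise level to the closest timestep
    sigma = 0.1
    t = int(np.abs(reduced - sigma).argmin())
    print(t, reduced[t])  # timestep whose equivalent noise level is nearest 0.1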
+    def forward(
+        self,
+        y,
+        physics: deepinv.physics.LinearPhysics,
+        seed=None,
+        x_init=None,
+    ):
+        r"""
+        Runs the diffusion to obtain a random sample of the posterior distribution.
+
+        :param torch.Tensor y: the measurements.
+        :param deepinv.physics.LinearPhysics physics: the physics operator.
+        :param int seed: the seed for the random number generator.
+        :param torch.Tensor x_init: the initial guess for the reconstruction.
+        """
+
+        if seed is not None:
+            torch.manual_seed(seed)
+
+        if hasattr(physics.noise_model, "sigma"):
+            sigma = physics.noise_model.sigma  # overwrite the default noise level
+            self.rhos, self.sigmas, self.seq = self.get_noise_schedule(sigma=sigma)
+
+        # Initialization in [-1, 1]; A_adjoint is necessary when x and y don't live in the same space
+        if x_init is None:
+            x = 2 * physics.A_adjoint(y) - 1
+        else:
+            x = 2 * x_init - 1
+
+        sqrt_recip_alphas_cumprod, sqrt_recipm1_alphas_cumprod = self.get_alpha_prod()
+
+        with torch.no_grad():
+            for i in range(len(self.seq)):
+                # Current noise level
+                curr_sigma = self.sigmas[self.seq[i]].cpu().numpy()
+
+                # time step associated with the noise level sigmas[i]
+                t_i = self.find_nearest(self.reduced_alpha_cumprod, curr_sigma)
+
+                # Denoising step
+                x_aux = x / 2 + 0.5
+                denoised = 2 * self.model(x_aux, curr_sigma / 2) - 1
+                noise_est = (
+                    sqrt_recip_alphas_cumprod[t_i] * x - denoised
+                ) / sqrt_recipm1_alphas_cumprod[t_i]
+
+                x0 = (
+                    self.sqrt_recip_alphas_cumprod[t_i] * x
+                    - self.sqrt_recipm1_alphas_cumprod[t_i] * noise_est
+                )
+                x0 = x0.clamp(-1, 1)
+
+                if self.seq[i] != self.seq[-1]:
+                    # Data fidelity step
+                    x0_p = x0 / 2 + 0.5
+                    x0_p = self.data_fidelity.prox(
+                        x0_p, y, physics, gamma=1 / (2 * self.rhos[t_i])
+                    )
+                    x0 = x0_p * 2 - 1
+
+                    # Sampling step
+                    t_im1 = self.find_nearest(
+                        self.reduced_alpha_cumprod,
+                        self.sigmas[self.seq[i + 1]].cpu().numpy(),
+                    )  # time step associated with the next noise level
+                    eps = (
+                        x - self.sqrt_alphas_cumprod[t_i] * x0
+                    ) / self.sqrt_1m_alphas_cumprod[t_i]  # effective noise
+                    x = (
+                        self.sqrt_alphas_cumprod[t_im1] * x0
+                        + self.sqrt_1m_alphas_cumprod[t_im1]
+                        * np.sqrt(1 - self.zeta)
+                        * eps
+                        + self.sqrt_1m_alphas_cumprod[t_im1]
+                        * np.sqrt(self.zeta)
+                        * torch.randn_like(x)
+                    )  # sampling
+
+        out = x / 2 + 0.5  # back to [0, 1] range
+
+        return out
+
+
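A minimal sketch of running DiffPIR (not part of the diff; dinv.optim.L2 is assumed to be the L2 data-fidelity exported from the deepinv/optim/data_fidelity.py module listed in the manifest, and the DRUNet constructor is an assumption):

    import torch
    import deepinv as dinv
    from deepinv.sampling.diffusion import DiffPIR

    denoiser = dinv.models.DRUNet(pretrained="download")  # assumed constructor
    data_fidelity = dinv.optim.L2()                       # assumed export

    physics = dinv.physics.Inpainting(mask=0.5, tensor_size=(3, 64, 64))
    physics.noise_model = dinv.physics.GaussianNoise(sigma=0.05)

    x = torch.rand(1, 3, 64, 64)  # toy image in [0, 1]
    y = physics(x)

    model = DiffPIR(denoiser, data_fidelity, sigma=0.05, max_iter=100, zeta=0.3)
    xhat = model(y, physics)  # reconstruction in the [0, 1] range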
+class DPS(nn.Module):
+    r"""
+    Diffusion Posterior Sampling (DPS).
+
+    This class implements the Diffusion Posterior Sampling algorithm (DPS) described in
+    https://arxiv.org/abs/2209.14687.
+
+    DPS is an approximation of a gradient-based posterior sampling algorithm,
+    which has minimal assumptions on the forward model. The only restriction is that
+    the measurement model has to be differentiable, which is generally the case.
+
+    The algorithm writes as follows, for :math:`t` decreasing from :math:`T` to :math:`1`:
+
+    .. math::
+
+            \begin{equation*}
+            \begin{aligned}
+            \widehat{\mathbf{x}}_{t} &= D_{\theta}(\mathbf{x}_t, \sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t})
+            \\
+            \mathbf{g}_t &= \nabla_{\mathbf{x}_t} \log p( \widehat{\mathbf{x}}_{t}(\mathbf{x}_t) | \mathbf{y} ) \\
+            \mathbf{\varepsilon}_t &\sim \mathcal{N}(0, \mathbf{I}) \\
+            \mathbf{x}_{t-1} &= a_t \,\, \mathbf{x}_t
+            + b_t \, \, \widehat{\mathbf{x}}_t
+            + \tilde{\sigma}_t \, \, \mathbf{\varepsilon}_t + \mathbf{g}_t,
+            \end{aligned}
+            \end{equation*}
+
+    where :math:`D_{\theta}(\cdot, \sigma)` is a denoising network for noise level :math:`\sigma`,
+    :math:`\eta` is a hyperparameter, and the constants :math:`\tilde{\sigma}_t, a_t, b_t` are defined as
+
+    .. math::
+            \begin{equation*}
+            \begin{aligned}
+            \tilde{\sigma}_t &= \eta \sqrt{ (1 - \frac{\overline{\alpha}_t}{\overline{\alpha}_{t-1}})
+            \frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_t}} \\
+            a_t &= \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2}/\sqrt{1-\overline{\alpha}_t} \\
+            b_t &= \sqrt{\overline{\alpha}_{t-1}} - \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2}
+            \frac{\sqrt{\overline{\alpha}_{t}}}{\sqrt{1 - \overline{\alpha}_{t}}}.
+            \end{aligned}
+            \end{equation*}
+
+    :param torch.nn.Module model: a denoiser network that can handle different noise levels
+    :param deepinv.optim.DataFidelity data_fidelity: the data fidelity operator
+    :param int max_iter: the number of diffusion iterations to run the algorithm (default: 1000)
+    :param float eta: DDIM hyperparameter which controls the stochasticity
+    :param bool verbose: if True, print progress
+    :param str device: the device to use for the computations
+    :param bool save_iterates: if True, return the full list of iterates instead of only the last one
+    """
+
+    def __init__(
+        self,
+        model,
+        data_fidelity,
+        max_iter=1000,
+        eta=1.0,
+        verbose=False,
+        device="cpu",
+        save_iterates=False,
+    ):
+        super(DPS, self).__init__()
+        self.model = model
+        self.model.requires_grad_(True)
+        self.data_fidelity = data_fidelity
+        self.max_iter = max_iter
+        self.eta = eta
+        self.verbose = verbose
+        self.device = device
+        self.beta_start, self.beta_end = 0.1 / 1000, 20 / 1000
+        self.num_train_timesteps = 1000
+        self.save_iterates = save_iterates
+
+        self.betas, self.alpha_cumprod = self.compute_alpha_betas()
+
+    def compute_alpha_betas(self):
+        r"""
+        Get the beta and alpha sequences for the algorithm. This is necessary for mapping noise levels to timesteps.
+        """
+        betas = np.linspace(
+            self.beta_start, self.beta_end, self.num_train_timesteps, dtype=np.float32
+        )
+        betas = torch.from_numpy(betas).to(self.device)
+
+        alpha_cumprod = (
+            1 - torch.cat([torch.zeros(1).to(betas.device), betas], dim=0)
+        ).cumprod(dim=0)
+        return betas, alpha_cumprod
+
+    def get_alpha(self, alpha_cumprod, t):
+        a = alpha_cumprod.index_select(0, t + 1).view(-1, 1, 1, 1)
+        return a
+
+    def forward(
+        self,
+        y,
+        physics: deepinv.physics.Physics,
+        seed=None,
+        x_init=None,
+    ):
+        r"""
+        Runs the diffusion to obtain a random sample of the posterior distribution.
+
+        :param torch.Tensor y: the measurements.
+        :param deepinv.physics.Physics physics: the physics operator.
+        :param int seed: the seed for the random number generator.
+        :param torch.Tensor x_init: the initial guess for the reconstruction.
+        """
+
+        if seed is not None:
+            torch.manual_seed(seed)
+
+        # Initialization in [-1, 1]; A_adjoint is necessary when x and y don't live in the same space
+        if x_init is None:
+            x = 2 * physics.A_adjoint(y) - 1
+        else:
+            x = 2 * x_init - 1
+
+        skip = self.num_train_timesteps // self.max_iter
+        batch_size = y.shape[0]
+
+        seq = range(0, self.num_train_timesteps, skip)
+        seq_next = [-1] + list(seq[:-1])
+        time_pairs = list(zip(reversed(seq), reversed(seq_next)))
+
+        if self.save_iterates:
+            xs = [x]
+
+        xt = x.to(self.device)
+
+        for i, j in tqdm(time_pairs, disable=(not self.verbose)):
+            t = (torch.ones(batch_size) * i).to(self.device)
+            next_t = (torch.ones(batch_size) * j).to(self.device)
+
+            at = self.get_alpha(self.alpha_cumprod, t.long())
+            at_next = self.get_alpha(self.alpha_cumprod, next_t.long())
+
+            with torch.enable_grad():
+                xt.requires_grad_(True)
+
+                # 1. Denoising
+                # we call the denoiser using the standard deviation instead of the time step.
+                aux_x = xt / 2 + 0.5
+                x0_t = 2 * self.model(aux_x, (1 - at).sqrt() / at.sqrt() / 2) - 1
+
+                x0_t = torch.clip(x0_t, -1.0, 1.0)  # optional
+
+                # 2. DPS guidance: gradient of the data-fidelity term
+                l2_loss = self.data_fidelity(x0_t, y, physics).sqrt().sum()
+
+            norm_grad = torch.autograd.grad(outputs=l2_loss, inputs=xt)[0]
+            norm_grad = norm_grad.detach()
+
+            c1 = ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() * self.eta
+            c2 = ((1 - at_next) - c1**2).sqrt()
+
+            # 3. noise step
+            epsilon = torch.randn_like(xt)
+
+            # 4. DDPM(IM) step
+            xt_next = (
+                (at_next.sqrt() - c2 * at.sqrt() / (1 - at).sqrt()) * x0_t
+                + c1 * epsilon
+                + c2 * xt / (1 - at).sqrt()
+                - norm_grad
+            )
+
+            if self.save_iterates:
+                xs.append(xt_next.to("cpu"))
+            xt = xt_next.clone()
+
+        if self.save_iterates:
+            return xs
+        else:
+            return xt
+
+
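DPS follows the same calling convention as DiffPIR; since the guidance differentiates through the denoiser at every step, each iteration is more expensive than in DDRM or DiffPIR. A sketch under the same assumptions as above (assumed DRUNet constructor and dinv.optim.L2 export):

    import torch
    import deepinv as dinv
    from deepinv.sampling.diffusion import DPS

    denoiser = dinv.models.DRUNet(pretrained="download")  # assumed constructor
    data_fidelity = dinv.optim.L2()                       # assumed export

    physics = dinv.physics.Inpainting(mask=0.5, tensor_size=(3, 64, 64))
    physics.noise_model = dinv.physics.GaussianNoise(sigma=0.05)

    x = torch.rand(1, 3, 64, 64)  # toy image in [0, 1]
    y = physics(x)

    model = DPS(denoiser, data_fidelity, max_iter=1000, eta=1.0)
    xhat = model(y, physics, seed=0)  # gradient-guided posterior sample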
+# if __name__ == "__main__":
+# import deepinv as dinv
+# from deepinv.models.denoiser import Denoiser
+# import torchvision
+# from deepinv.utils.metric import cal_psnr
+#
+# device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
+#
+# x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
+# x = x.unsqueeze(0).float().to(device) / 255
+#
+# sigma_noise = 0.01
+# # physics = dinv.physics.Denoising()
+#
+# # physics = dinv.physics.BlurFFT(img_size=x.shape[1:], filter=dinv.physics.blur.gaussian_blur(sigma=1.),
+# #                                device=device)
+# physics = dinv.physics.Decolorize()
+# # physics = dinv.physics.Inpainting(
+# #     mask=0.5, tensor_size=(3, 218, 178), device=dinv.device
+# # )
+# # physics.mask *= (torch.rand_like(physics.mask))
+# physics.noise_model = dinv.physics.GaussianNoise(sigma_noise)
+#
+# y = physics(x)
+# model_spec = {
+#     "name": "drunet",
+#     "args": {"device": device, "pretrained": "download"},
+# }
+#
+# denoiser = Denoiser(model_spec=model_spec)
+#
+# f = DDRM(
+#     denoiser=denoiser,
+#     etab=1.0,
+#     sigma_noise=sigma_noise,
+#     sigmas=np.linspace(1, 0, 100),
+#     verbose=True,
+# )
+#
+# xhat = f(y, physics)
+# dinv.utils.plot(
+#     [physics.A_adjoint(y), x, xhat], titles=["meas.", "ground-truth", "xhat"]
+# )
+#
+# print(f"PSNR 1 sample: {cal_psnr(x, xhat):.2f} dB")
+# # print(f'mean PSNR sample: {cal_psnr(x, denoiser(y, sigma_noise)):.2f} dB')
+#
+# # sampler = dinv.sampling.DiffusionSampler(f, max_iter=10, save_chain=True, verbose=True)
+# # xmean, xvar = sampler(y, physics)
+#
+# # chain = sampler.get_chain()
+# # distance = np.zeros((len(chain)))
+# # for k, xhat in enumerate(chain):
+# #     dist = (xhat - xmean).pow(2).mean()
+# #     distance[k] = dist
+# # distance = np.sort(distance)
+# # thres = distance[int(len(distance) * .95)]
+# # err = (x - xmean).pow(2).mean()
+# # print(f'Confidence region: {thres:.2e}, error: {err:.2e}')
+#
+# # xstdn = xvar.sqrt()
+# # xstdn_plot = xstdn.sum(dim=1).unsqueeze(1)
+#
+# # error = (xmean - x).abs()  # per pixel average abs. error
+# # error_plot = error.sum(dim=1).unsqueeze(1)
+#
+# # print(f'Correct std: {(xstdn>error).sum()/np.prod(xstdn.shape)*100:.1f}%')
+# # error = (xmean - x)
+# # dinv.utils.plot_debug(
+# #     [physics.A_adjoint(y), x, xmean, xstdn_plot, error_plot], titles=["meas.", "ground-truth", "mean", "std", "error"]
+# # )
+#
+# # print(f'PSNR 1 sample: {cal_psnr(x, chain[0]):.2f} dB')
+# # print(f'mean PSNR sample: {cal_psnr(x, xmean):.2f} dB')