deepinv-0.1.0.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/sampling/diffusion.py
@@ -0,0 +1,676 @@
+ import torch.nn as nn
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+
+ import deepinv.physics
+ from deepinv.sampling.langevin import MonteCarlo
+
+
+ class DiffusionSampler(MonteCarlo):
+     r"""
+     Turns a diffusion method into a Monte Carlo sampler.
+
+     Unlike a single run of a diffusion method, the resulting sampler estimates the mean and variance of the
+     posterior distribution by running the diffusion multiple times.
+
+     :param torch.nn.Module diffusion: a diffusion model
+     :param int max_iter: the maximum number of iterations
+     :param tuple clip: the clip range
+     :param callable g_statistic: the sampler computes the mean and variance of the function ``g_statistic``;
+         by default :math:`g(x) = x`.
+     :param float thres_conv: the convergence threshold for the mean and variance
+     :param bool verbose: whether to print the progress
+     :param bool save_chain: whether to save the chain
+     """
+
+     def __init__(
+         self,
+         diffusion,
+         max_iter=100,
+         clip=(-1, 2),
+         thres_conv=1e-1,
+         g_statistic=lambda x: x,
+         verbose=True,
+         save_chain=False,
+     ):
+         # build a one-step iterator around the diffusion and set the parameters of the base class;
+         # no thinning and no burn-in are used, since each diffusion run is an independent sample
+         data_fidelity = None
+         diffusion.verbose = False
+         prior = diffusion
+
+         def iterator(x, y, physics, likelihood, prior):
+             # run one sampling kernel iteration
+             x = prior(y, physics)
+             return x
+
+         super().__init__(
+             iterator,
+             prior,
+             data_fidelity,
+             max_iter=max_iter,
+             thinning=1,
+             save_chain=save_chain,
+             burnin_ratio=0.0,
+             clip=clip,
+             verbose=verbose,
+             thresh_conv=thres_conv,
+             g_statistic=g_statistic,
+         )
+
+
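The class above only needs a module that maps `(y, physics)` to one posterior sample. A minimal, self-contained sketch of that contract follows; `ToyDiffusion` is an illustrative stand-in, not part of the package (in practice you would pass a `DDRM`, `DiffPIR` or `DPS` instance, as in the commented-out demo at the end of this file):

    import torch
    import deepinv as dinv

    # Toy stand-in for a diffusion method: any nn.Module with a `verbose`
    # attribute that maps (y, physics) to one posterior sample will do.
    class ToyDiffusion(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.verbose = False

        def forward(self, y, physics):
            # a crude "sample": the adjoint reconstruction plus a little noise
            x = physics.A_adjoint(y)
            return x + 0.01 * torch.randn_like(x)

    physics = dinv.physics.Denoising()
    physics.noise_model = dinv.physics.GaussianNoise(0.1)
    y = physics(torch.rand(1, 1, 8, 8))

    sampler = dinv.sampling.DiffusionSampler(ToyDiffusion(), max_iter=10, verbose=False)
    xmean, xvar = sampler(y, physics)  # posterior mean and variance estimates
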
+ class DDRM(nn.Module):
+     r"""
+     Denoising Diffusion Restoration Models (DDRM).
+
+     This class implements the denoising diffusion restoration model (DDRM) described in
+     https://arxiv.org/abs/2201.11793.
+
+     DDRM is a sampling method that uses a denoiser to sample from the posterior distribution of the
+     inverse problem.
+
+     It requires that the physics operator has a singular value decomposition, i.e., that it is a
+     :class:`deepinv.physics.DecomposablePhysics` instance.
+
+     :param torch.nn.Module denoiser: a denoiser model that can handle different noise levels.
+     :param list[float], numpy.array sigmas: a list of noise levels to use in the diffusion; they should be
+         in decreasing order from 1 to 0.
+     :param float eta: hyperparameter controlling the stochasticity of each sampling step.
+     :param float etab: hyperparameter controlling the weight given to the measurements when the current
+         noise level exceeds the per-entry noise-to-signal ratio.
+     :param bool verbose: if True, print progress.
+     """
+
+     def __init__(
+         self,
+         denoiser,
+         sigmas=np.linspace(1, 0, 100),
+         eta=0.85,
+         etab=1.0,
+         verbose=False,
+     ):
+         super(DDRM, self).__init__()
+         self.denoiser = denoiser
+         self.sigmas = sigmas
+         self.max_iter = len(sigmas)
+         self.eta = eta
+         self.verbose = verbose
+         self.etab = etab
+
+     def forward(self, y, physics: deepinv.physics.DecomposablePhysics, seed=None):
+         r"""
+         Runs the diffusion to obtain a random sample of the posterior distribution.
+
+         :param torch.Tensor y: the measurements.
+         :param deepinv.physics.DecomposablePhysics physics: the physics operator, which must have a
+             singular value decomposition.
+         :param int seed: the seed for the random number generator.
+         """
+         with torch.no_grad():
+             if seed is not None:
+                 np.random.seed(seed)
+                 torch.manual_seed(seed)
+
+             if hasattr(physics.noise_model, "sigma"):
+                 sigma_noise = physics.noise_model.sigma
+             else:
+                 sigma_noise = 0.01
+
+             if physics.__class__ == deepinv.physics.Denoising:
+                 mask = torch.ones_like(
+                     y
+                 )  # TODO: fix for economic SVD decompositions (e.g. Decolorize)
+             else:
+                 mask = torch.cat([physics.mask.abs()] * y.shape[0], dim=0)
+
+             c = np.sqrt(1 - self.eta**2)
+             y_bar = physics.U_adjoint(y)
+             case = mask > sigma_noise  # entries with a significant singular value
+             y_bar[case] = y_bar[case] / mask[case]
+             nsr = torch.zeros_like(mask)
+             nsr[case] = sigma_noise / mask[case]  # per-entry noise-to-signal ratio
+
+             # iteration 1: compute the initial noise
+             mean = torch.zeros_like(y_bar)
+             std = torch.ones_like(y_bar) * self.sigmas[0]
+             mean[case] = y_bar[case]
+             std[case] = (self.sigmas[0] ** 2 - nsr[case].pow(2)).sqrt()
+             x_bar = mean + std * torch.randn_like(y_bar)
+             x_bar_prev = x_bar.clone()
+
+             # denoise
+             x = self.denoiser(physics.V(x_bar), self.sigmas[0])
+
+             for t in tqdm(range(1, self.max_iter), disable=(not self.verbose)):
+                 # add noise in the transformed domain
+                 x_bar = physics.V_adjoint(x)
+
+                 case2 = torch.logical_and(case, (self.sigmas[t] < nsr))
+                 case3 = torch.logical_and(case, (self.sigmas[t] >= nsr))
+
+                 mean = (
+                     x_bar
+                     + c * self.sigmas[t] * (x_bar_prev - x_bar) / self.sigmas[t - 1]
+                 )
+                 mean[case2] = (
+                     x_bar[case2]
+                     + c * self.sigmas[t] * (y_bar[case2] - x_bar[case2]) / nsr[case2]
+                 )
+                 mean[case3] = (1.0 - self.etab) * x_bar[case3] + self.etab * y_bar[case3]
+
+                 std = torch.ones_like(x_bar) * self.eta * self.sigmas[t]
+                 std[case3] = (
+                     self.sigmas[t] ** 2 - (nsr[case3] * self.etab).pow(2)
+                 ).sqrt()
+
+                 x_bar = mean + std * torch.randn_like(x_bar)
+                 x_bar_prev = x_bar.clone()
+                 # denoise
+                 x = self.denoiser(physics.V(x_bar), self.sigmas[t])
+
+         return x
+
+
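A minimal sketch of running DDRM on a pure denoising problem. `ToyDenoiser` is an illustrative placeholder for a real pretrained denoiser such as DRUNet (see the commented-out demo at the end of this file); any module accepting `(x, sigma)` works:

    import numpy as np
    import torch
    import deepinv as dinv
    from deepinv.sampling.diffusion import DDRM

    # Illustrative placeholder: a real denoiser accepting (x, sigma) would go here.
    class ToyDenoiser(torch.nn.Module):
        def forward(self, x, sigma):
            return x.clamp(0, 1)

    physics = dinv.physics.Denoising()  # identity SVD, so DecomposablePhysics holds
    physics.noise_model = dinv.physics.GaussianNoise(0.05)
    y = physics(torch.rand(1, 1, 16, 16))

    f = DDRM(denoiser=ToyDenoiser(), sigmas=np.linspace(1, 0, 100), verbose=False)
    xhat = f(y, physics)  # one sample from the (approximate) posterior
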
+ class DiffPIR(nn.Module):
+     r"""
+     Diffusion PnP Image Restoration (DiffPIR).
+
+     This class implements the Diffusion PnP image restoration algorithm (DiffPIR) described
+     in https://arxiv.org/abs/2305.08995.
+
+     DiffPIR is based on a half-quadratic splitting (HQS) plug-and-play algorithm, where the denoiser
+     is a conditional diffusion denoiser, combined with a diffusion process. The algorithm reads as follows,
+     for :math:`t` decreasing from :math:`T` to :math:`1`:
+
+     .. math::
+         \begin{equation*}
+         \begin{aligned}
+         x_{0}^{t} &= D_{\theta}\left(x_t, \frac{\sqrt{1-\overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}}\right) \\
+         \widehat{x}_{0}^{t} &= \operatorname{prox}_{2 f(y, \cdot) /{\rho_t}}(x_{0}^{t}) \\
+         \widehat{\varepsilon} &= \left(x_t - \sqrt{\overline{\alpha}_t} \,\,
+         \widehat{x}_{0}^t\right)/\sqrt{1-\overline{\alpha}_t} \\
+         \varepsilon_t &\sim \mathcal{N}(0, \mathbf{I}) \\
+         x_{t-1} &= \sqrt{\overline{\alpha}_{t-1}} \,\, \widehat{x}_{0}^t + \sqrt{1-\overline{\alpha}_{t-1}}
+         \left(\sqrt{1-\zeta} \,\, \widehat{\varepsilon} + \sqrt{\zeta} \,\, \varepsilon_t\right),
+         \end{aligned}
+         \end{equation*}
+
+     where :math:`D_\theta(\cdot,\sigma)` is a Gaussian denoiser network with noise level :math:`\sigma`
+     and :math:`f(y, \cdot)` is the data fidelity term.
+
+     .. note::
+
+         The algorithm might require careful tuning of the hyperparameters :math:`\lambda` and :math:`\zeta`
+         to obtain optimal results.
+
+     :param torch.nn.Module model: a conditional noise estimation model
+     :param float sigma: the noise level of the data
+     :param deepinv.optim.DataFidelity data_fidelity: the data fidelity operator
+     :param int max_iter: the number of iterations to run the algorithm (default: 100)
+     :param float zeta: hyperparameter :math:`\zeta` for the sampling step (must be between 0 and 1).
+         Default: ``1.0``.
+     :param float lambda_: hyperparameter :math:`\lambda` for the data fidelity step
+         (:math:`\rho_t = \lambda \frac{\sigma_n^2}{\bar{\sigma}_t^2}` in the paper, where the optimal value
+         ranges between 3.0 and 25.0 depending on the problem). Default: ``7.0``.
+     :param bool verbose: if ``True``, print progress
+     :param str device: the device to use for the computations
+     """
+
+     def __init__(
+         self,
+         model,
+         data_fidelity,
+         sigma=0.05,
+         max_iter=100,
+         zeta=1.0,
+         lambda_=7.0,
+         verbose=False,
+         device="cpu",
+     ):
+         super(DiffPIR, self).__init__()
+         self.model = model
+         self.lambda_ = lambda_
+         self.data_fidelity = data_fidelity
+         self.max_iter = max_iter
+         self.zeta = zeta
+         self.verbose = verbose
+         self.device = device
+         self.beta_start, self.beta_end = 0.1 / 1000, 20 / 1000
+         self.num_train_timesteps = 1000
+
+         (
+             self.sqrt_1m_alphas_cumprod,
+             self.reduced_alpha_cumprod,
+             self.sqrt_alphas_cumprod,
+             self.sqrt_recip_alphas_cumprod,
+             self.sqrt_recipm1_alphas_cumprod,
+             self.betas,
+         ) = self.get_alpha_beta()
+
+         self.rhos, self.sigmas, self.seq = self.get_noise_schedule(sigma=sigma)
+
+     def get_alpha_beta(self):
+         """
+         Get the alpha and beta sequences for the algorithm. This is necessary for mapping noise levels
+         to timesteps.
+         """
+         betas = np.linspace(
+             self.beta_start, self.beta_end, self.num_train_timesteps, dtype=np.float32
+         )
+         betas = torch.from_numpy(betas).to(self.device)
+         alphas = 1.0 - betas
+         alphas_cumprod = np.cumprod(alphas.cpu(), axis=0)  # this is \overline{\alpha}_t
+
+         # useful sequences derived from alphas_cumprod
+         sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+         sqrt_1m_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+         reduced_alpha_cumprod = torch.div(
+             sqrt_1m_alphas_cumprod, sqrt_alphas_cumprod
+         )  # equivalent noise sigma on the image
+         sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
+         sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
+
+         return (
+             sqrt_1m_alphas_cumprod,
+             reduced_alpha_cumprod,
+             sqrt_alphas_cumprod,
+             sqrt_recip_alphas_cumprod,
+             sqrt_recipm1_alphas_cumprod,
+             betas,
+         )
+
+     def get_noise_schedule(self, sigma):
+         """
+         Get the noise schedule for the algorithm.
+         """
+         lambda_ = self.lambda_
+         sigmas = []
+         sigma_ks = []
+         rhos = []
+         for i in range(self.num_train_timesteps):
+             sigmas.append(self.reduced_alpha_cumprod[self.num_train_timesteps - 1 - i])
+             sigma_ks.append(
+                 self.sqrt_1m_alphas_cumprod[i] / self.sqrt_alphas_cumprod[i]
+             )
+             rhos.append(lambda_ * (sigma**2) / (sigma_ks[i] ** 2))
+         rhos, sigmas = torch.tensor(rhos).to(self.device), torch.tensor(sigmas).to(
+             self.device
+         )
+
+         seq = np.sqrt(np.linspace(0, self.num_train_timesteps**2, self.max_iter))
+         seq = [int(s) for s in list(seq)]
+         seq[-1] = seq[-1] - 1
+
+         return rhos, sigmas, seq
+
+     def find_nearest(self, array, value):
+         """
+         Find the index of the entry of ``array`` closest to ``value``.
+         """
+         array = np.asarray(array)
+         idx = (np.abs(array - value)).argmin()
+         return idx
+
+     def get_alpha_prod(
+         self, beta_start=0.1 / 1000, beta_end=20 / 1000, num_train_timesteps=1000
+     ):
+         """
+         Get the alpha sequences; this is necessary for mapping noise levels to timesteps when performing
+         pure denoising.
+         """
+         betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+         betas = torch.from_numpy(betas)  # not moved to self.device; can be done outside
+         alphas = 1.0 - betas
+         alphas_cumprod = np.cumprod(alphas.cpu(), axis=0)  # this is \overline{\alpha}_t
+
+         # useful sequences derived from alphas_cumprod
+         sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
+         sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
+         return (
+             sqrt_recip_alphas_cumprod,
+             sqrt_recipm1_alphas_cumprod,
+         )
+
+     def forward(
+         self,
+         y,
+         physics: deepinv.physics.LinearPhysics,
+         seed=None,
+         x_init=None,
+     ):
+         r"""
+         Runs the diffusion to obtain a random sample of the posterior distribution.
+
+         :param torch.Tensor y: the measurements.
+         :param deepinv.physics.LinearPhysics physics: the physics operator.
+         :param int seed: the seed for the random number generator.
+         :param torch.Tensor x_init: the initial guess for the reconstruction.
+         """
+
+         if seed is not None:
+             torch.manual_seed(seed)
+
+         if hasattr(physics.noise_model, "sigma"):
+             sigma = physics.noise_model.sigma  # overwrite the default noise level
+             self.rhos, self.sigmas, self.seq = self.get_noise_schedule(sigma=sigma)
+
+         # initialization (in the [-1, 1] range used by the diffusion model)
+         if x_init is None:  # necessary when x and y don't live in the same space
+             x = 2 * physics.A_adjoint(y) - 1
+         else:
+             x = 2 * x_init - 1
+
+         sqrt_recip_alphas_cumprod, sqrt_recipm1_alphas_cumprod = self.get_alpha_prod()
+
+         with torch.no_grad():
+             for i in range(len(self.seq)):
+                 # current noise level
+                 curr_sigma = self.sigmas[self.seq[i]].cpu().numpy()
+
+                 # time step associated with the noise level sigmas[i]
+                 t_i = self.find_nearest(self.reduced_alpha_cumprod, curr_sigma)
+
+                 # denoising step
+                 x_aux = x / 2 + 0.5
+                 denoised = 2 * self.model(x_aux, curr_sigma / 2) - 1
+                 noise_est = (
+                     sqrt_recip_alphas_cumprod[t_i] * x - denoised
+                 ) / sqrt_recipm1_alphas_cumprod[t_i]
+
+                 x0 = (
+                     self.sqrt_recip_alphas_cumprod[t_i] * x
+                     - self.sqrt_recipm1_alphas_cumprod[t_i] * noise_est
+                 )
+                 x0 = x0.clamp(-1, 1)
+
+                 if self.seq[i] != self.seq[-1]:
+                     # data fidelity step
+                     x0_p = x0 / 2 + 0.5
+                     x0_p = self.data_fidelity.prox(
+                         x0_p, y, physics, gamma=1 / (2 * self.rhos[t_i])
+                     )
+                     x0 = x0_p * 2 - 1
+
+                     # sampling step
+                     t_im1 = self.find_nearest(
+                         self.reduced_alpha_cumprod,
+                         self.sigmas[self.seq[i + 1]].cpu().numpy(),
+                     )  # time step associated with the next noise level
+                     eps = (
+                         x - self.sqrt_alphas_cumprod[t_i] * x0
+                     ) / self.sqrt_1m_alphas_cumprod[t_i]  # effective noise
+                     x = (
+                         self.sqrt_alphas_cumprod[t_im1] * x0
+                         + self.sqrt_1m_alphas_cumprod[t_im1]
+                         * np.sqrt(1 - self.zeta)
+                         * eps
+                         + self.sqrt_1m_alphas_cumprod[t_im1]
+                         * np.sqrt(self.zeta)
+                         * torch.randn_like(x)
+                     )  # sampling
+
+         out = x / 2 + 0.5  # back to the [0, 1] range
+
+         return out
+
+
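A minimal sketch of calling DiffPIR, assuming an L2 data-fidelity class is available at deepinv.optim.data_fidelity (its ``prox`` is called exactly as in the forward method above); the toy denoiser is again an illustrative placeholder:

    import torch
    import deepinv as dinv
    from deepinv.optim.data_fidelity import L2  # assumed location of the L2 fidelity
    from deepinv.sampling.diffusion import DiffPIR

    class ToyDenoiser(torch.nn.Module):
        def forward(self, x, sigma):
            return x.clamp(0, 1)  # placeholder for a conditional diffusion denoiser

    physics = dinv.physics.Denoising()
    physics.noise_model = dinv.physics.GaussianNoise(0.05)
    y = physics(torch.rand(1, 1, 16, 16))

    algo = DiffPIR(
        model=ToyDenoiser(), data_fidelity=L2(), sigma=0.05, max_iter=20, zeta=0.3
    )
    xhat = algo(y, physics)  # reconstruction in the [0, 1] range
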
+ class DPS(nn.Module):
+     r"""
+     Diffusion Posterior Sampling (DPS).
+
+     This class implements the Diffusion Posterior Sampling algorithm (DPS) described in
+     https://arxiv.org/abs/2209.14687.
+
+     DPS is an approximation of a gradient-based posterior sampling algorithm
+     which makes minimal assumptions on the forward model. The only restriction is that
+     the measurement model has to be differentiable, which is generally the case.
+
+     The algorithm reads as follows, for :math:`t` decreasing from :math:`T` to :math:`1`:
+
+     .. math::
+
+         \begin{equation*}
+         \begin{aligned}
+         \widehat{\mathbf{x}}_{t} &= D_{\theta}(\mathbf{x}_t, \sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}) \\
+         \mathbf{g}_t &= \nabla_{\mathbf{x}_t} \log p( \widehat{\mathbf{x}}_{t}(\mathbf{x}_t) | \mathbf{y} ) \\
+         \mathbf{\varepsilon}_t &\sim \mathcal{N}(0, \mathbf{I}) \\
+         \mathbf{x}_{t-1} &= a_t \,\, \mathbf{x}_t
+         + b_t \, \, \widehat{\mathbf{x}}_t
+         + \tilde{\sigma}_t \, \, \mathbf{\varepsilon}_t + \mathbf{g}_t,
+         \end{aligned}
+         \end{equation*}
+
+     where :math:`D_{\theta}(\cdot,\sigma)` is a denoising network for noise level :math:`\sigma`,
+     :math:`\eta` is a hyperparameter, and the constants :math:`\tilde{\sigma}_t, a_t, b_t` are defined as
+
+     .. math::
+         \begin{equation*}
+         \begin{aligned}
+         \tilde{\sigma}_t &= \eta \sqrt{ \left(1 - \frac{\overline{\alpha}_t}{\overline{\alpha}_{t-1}}\right)
+         \frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_t}} \\
+         a_t &= \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2}/\sqrt{1-\overline{\alpha}_t} \\
+         b_t &= \sqrt{\overline{\alpha}_{t-1}} - \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2}
+         \frac{\sqrt{\overline{\alpha}_{t}}}{\sqrt{1 - \overline{\alpha}_{t}}}.
+         \end{aligned}
+         \end{equation*}
+
+     :param torch.nn.Module model: a denoiser network that can handle different noise levels
+     :param deepinv.optim.DataFidelity data_fidelity: the data fidelity operator
+     :param int max_iter: the number of diffusion iterations to run the algorithm (default: 1000)
+     :param float eta: DDIM hyperparameter which controls the stochasticity of the sampling
+     :param bool verbose: if True, print progress
+     :param str device: the device to use for the computations
+     :param bool save_iterates: if True, return the list of iterates instead of only the final sample
+     """
+
+     def __init__(
+         self,
+         model,
+         data_fidelity,
+         max_iter=1000,
+         eta=1.0,
+         verbose=False,
+         device="cpu",
+         save_iterates=False,
+     ):
+         super(DPS, self).__init__()
+         self.model = model
+         self.model.requires_grad_(True)
+         self.data_fidelity = data_fidelity
+         self.max_iter = max_iter
+         self.eta = eta
+         self.verbose = verbose
+         self.device = device
+         self.beta_start, self.beta_end = 0.1 / 1000, 20 / 1000
+         self.num_train_timesteps = 1000
+         self.save_iterates = save_iterates
+
+         self.betas, self.alpha_cumprod = self.compute_alpha_betas()
+
+     def compute_alpha_betas(self):
+         r"""
+         Get the beta and alpha sequences for the algorithm. This is necessary for mapping noise levels
+         to timesteps.
+         """
+         betas = np.linspace(
+             self.beta_start, self.beta_end, self.num_train_timesteps, dtype=np.float32
+         )
+         betas = torch.from_numpy(betas).to(self.device)
+
+         alpha_cumprod = (
+             1 - torch.cat([torch.zeros(1).to(betas.device), betas], dim=0)
+         ).cumprod(dim=0)
+         return betas, alpha_cumprod
+
+     def get_alpha(self, alpha_cumprod, t):
+         # select \overline{\alpha}_t for a batch of timesteps (shifted by one,
+         # since alpha_cumprod has a leading entry for the initial step)
+         a = alpha_cumprod.index_select(0, t + 1).view(-1, 1, 1, 1)
+         return a
+
+     def forward(
+         self,
+         y,
+         physics: deepinv.physics.Physics,
+         seed=None,
+         x_init=None,
+     ):
+         r"""
+         Runs the diffusion to obtain a random sample of the posterior distribution.
+
+         :param torch.Tensor y: the measurements.
+         :param deepinv.physics.Physics physics: the physics operator.
+         :param int seed: the seed for the random number generator.
+         :param torch.Tensor x_init: the initial guess for the reconstruction.
+         """
+
+         if seed is not None:
+             torch.manual_seed(seed)
+
+         # initialization (in the [-1, 1] range used by the diffusion model)
+         if x_init is None:  # necessary when x and y don't live in the same space
+             x = 2 * physics.A_adjoint(y) - 1
+         else:
+             x = 2 * x_init - 1
+
+         skip = self.num_train_timesteps // self.max_iter
+         batch_size = y.shape[0]
+
+         seq = range(0, self.num_train_timesteps, skip)
+         seq_next = [-1] + list(seq[:-1])
+         time_pairs = list(zip(reversed(seq), reversed(seq_next)))
+
+         if self.save_iterates:
+             xs = [x]
+
+         xt = x.to(self.device)
+
+         for i, j in tqdm(time_pairs, disable=(not self.verbose)):
+             t = (torch.ones(batch_size) * i).to(self.device)
+             next_t = (torch.ones(batch_size) * j).to(self.device)
+
+             at = self.get_alpha(self.alpha_cumprod, t.long())
+             at_next = self.get_alpha(self.alpha_cumprod, next_t.long())
+
+             with torch.enable_grad():
+                 xt.requires_grad_(True)
+
+                 # 1. denoising step
+                 # we call the denoiser using the standard deviation instead of the time step
+                 aux_x = xt / 2 + 0.5
+                 x0_t = 2 * self.model(aux_x, (1 - at).sqrt() / at.sqrt() / 2) - 1
+                 x0_t = torch.clip(x0_t, -1.0, 1.0)  # optional
+
+                 # 2. likelihood gradient approximation (the DPS step)
+                 l2_loss = self.data_fidelity(x0_t, y, physics).sqrt().sum()
+
+             norm_grad = torch.autograd.grad(outputs=l2_loss, inputs=xt)[0]
+             norm_grad = norm_grad.detach()
+
+             c1 = ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() * self.eta
+             c2 = ((1 - at_next) - c1**2).sqrt()
+
+             # 3. noise step
+             epsilon = torch.randn_like(xt)
+
+             # 4. DDPM(IM) step
+             xt_next = (
+                 (at_next.sqrt() - c2 * at.sqrt() / (1 - at).sqrt()) * x0_t
+                 + c1 * epsilon
+                 + c2 * xt / (1 - at).sqrt()
+                 - norm_grad
+             )
+
+             if self.save_iterates:
+                 xs.append(xt_next.to("cpu"))
+             xt = xt_next.clone()
+
+         if self.save_iterates:
+             return xs
+         else:
+             return xt
+
+
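A minimal sketch of calling DPS, with the same toy modules to show the calling convention. Unlike DDRM and DiffPIR, DPS only needs a differentiable physics operator; the outer ``torch.no_grad()`` is an assumption on intended usage (the per-step likelihood gradients are handled internally via ``enable_grad``/``autograd.grad``), and the ``L2`` import path is assumed as above:

    import torch
    import deepinv as dinv
    from deepinv.optim.data_fidelity import L2  # assumed location of the L2 fidelity
    from deepinv.sampling.diffusion import DPS

    class ToyDenoiser(torch.nn.Module):
        def forward(self, x, sigma):
            return x.clamp(0, 1)  # placeholder for a pretrained denoiser

    physics = dinv.physics.Denoising()
    physics.noise_model = dinv.physics.GaussianNoise(0.05)
    y = physics(torch.rand(1, 1, 16, 16))

    algo = DPS(model=ToyDenoiser(), data_fidelity=L2(), max_iter=50)
    with torch.no_grad():  # gradients w.r.t. xt are computed explicitly inside
        xhat = algo(y, physics)  # final iterate (set save_iterates=True for the chain)
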
+ # if __name__ == "__main__":
+ #     import deepinv as dinv
+ #     from deepinv.models.denoiser import Denoiser
+ #     import torchvision
+ #     from deepinv.utils.metric import cal_psnr
+ #
+ #     device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
+ #
+ #     x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
+ #     x = x.unsqueeze(0).float().to(device) / 255
+ #
+ #     sigma_noise = 0.01
+ #     # physics = dinv.physics.Denoising()
+ #     # physics = dinv.physics.BlurFFT(img_size=x.shape[1:], filter=dinv.physics.blur.gaussian_blur(sigma=1.),
+ #     #                                device=device)
+ #     physics = dinv.physics.Decolorize()
+ #     # physics = dinv.physics.Inpainting(
+ #     #     mask=0.5, tensor_size=(3, 218, 178), device=dinv.device
+ #     # )
+ #     # physics.mask *= (torch.rand_like(physics.mask))
+ #     physics.noise_model = dinv.physics.GaussianNoise(sigma_noise)
+ #
+ #     y = physics(x)
+ #     model_spec = {
+ #         "name": "drunet",
+ #         "args": {"device": device, "pretrained": "download"},
+ #     }
+ #
+ #     denoiser = Denoiser(model_spec=model_spec)
+ #
+ #     f = DDRM(
+ #         denoiser=denoiser,
+ #         etab=1.0,
+ #         sigmas=np.linspace(1, 0, 100),
+ #         verbose=True,
+ #     )
+ #
+ #     xhat = f(y, physics)
+ #     dinv.utils.plot(
+ #         [physics.A_adjoint(y), x, xhat], titles=["meas.", "ground-truth", "xhat"]
+ #     )
+ #
+ #     print(f"PSNR 1 sample: {cal_psnr(x, xhat):.2f} dB")
+ #     # print(f'mean PSNR sample: {cal_psnr(x, denoiser(y, sigma_noise)):.2f} dB')
+ #
+ #     # sampler = dinv.sampling.DiffusionSampler(f, max_iter=10, save_chain=True, verbose=True)
+ #     # xmean, xvar = sampler(y, physics)
+ #
+ #     # chain = sampler.get_chain()
+ #     # distance = np.zeros((len(chain)))
+ #     # for k, xhat in enumerate(chain):
+ #     #     dist = (xhat - xmean).pow(2).mean()
+ #     #     distance[k] = dist
+ #     # distance = np.sort(distance)
+ #     # thres = distance[int(len(distance) * .95)]
+ #     # err = (x - xmean).pow(2).mean()
+ #     # print(f'Confidence region: {thres:.2e}, error: {err:.2e}')
+ #
+ #     # xstdn = xvar.sqrt()
+ #     # xstdn_plot = xstdn.sum(dim=1).unsqueeze(1)
+ #
+ #     # error = (xmean - x).abs()  # per pixel average abs. error
+ #     # error_plot = error.sum(dim=1).unsqueeze(1)
+ #
+ #     # print(f'Correct std: {(xstdn>error).sum()/np.prod(xstdn.shape)*100:.1f}%')
+ #     # error = (xmean - x)
+ #     # dinv.utils.plot_debug(
+ #     #     [physics.A_adjoint(y), x, xmean, xstdn_plot, error_plot],
+ #     #     titles=["meas.", "ground-truth", "mean", "std", "error"]
+ #     # )
+ #
+ #     # print(f'PSNR 1 sample: {cal_psnr(x, chain[0]):.2f} dB')
+ #     # print(f'mean PSNR sample: {cal_psnr(x, xmean):.2f} dB')
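For reference, a trimmed, runnable variant of the workflow sketched in the commented-out block above. Note that the block imports `deepinv.models.denoiser`, which does not appear in this wheel's file list, so the sketch below substitutes the toy denoiser from the earlier snippets for the pretrained DRUNet (a pretrained denoiser is what you would use in practice):

    import numpy as np
    import torch
    import deepinv as dinv
    from deepinv.sampling.diffusion import DDRM

    class ToyDenoiser(torch.nn.Module):
        def forward(self, x, sigma):
            return x.clamp(0, 1)  # stand-in for a pretrained denoiser (e.g. DRUNet)

    sigma_noise = 0.01
    physics = dinv.physics.Denoising()
    physics.noise_model = dinv.physics.GaussianNoise(sigma_noise)
    x = torch.rand(1, 3, 32, 32)
    y = physics(x)

    f = DDRM(denoiser=ToyDenoiser(), etab=1.0, sigmas=np.linspace(1, 0, 100))
    xhat = f(y, physics)  # one posterior sample

    # wrap the diffusion into a Monte Carlo sampler to get uncertainty estimates
    sampler = dinv.sampling.DiffusionSampler(f, max_iter=10, save_chain=True, verbose=False)
    xmean, xvar = sampler(y, physics)
    chain = sampler.get_chain()  # list of the individual samples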