deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/sampling/langevin.py
@@ -0,0 +1,512 @@
+ import torch.nn as nn
+ import torch
+ import numpy as np
+ import time
+
+ import deepinv.optim
+ from tqdm import tqdm
+ from deepinv.optim.utils import check_conv
+ from deepinv.sampling.utils import Welford, projbox, refl_projbox
+
+
+ class MonteCarlo(nn.Module):
+     r"""
+     Base class for Monte Carlo sampling.
+
+     This class can be used to create new Monte Carlo samplers by simply defining their kernel inside a torch.nn.Module:
+
+     ::
+
+         # define a custom sampling kernel (possibly a Markov kernel which depends on the previous sample)
+         class MyKernel(torch.nn.Module):
+             def __init__(self, iterator_params):
+                 super().__init__()
+                 self.iterator_params = iterator_params
+
+             def forward(self, x, y, physics, likelihood, prior):
+                 # run one sampling kernel iteration
+                 new_x = f(x, y, physics, likelihood, prior, self.iterator_params)
+                 return new_x
+
+         class MySampler(MonteCarlo):
+             def __init__(self, prior, data_fidelity, iterator_params,
+                          max_iter=1e3, burnin_ratio=.1, clip=(-1, 2), verbose=True):
+                 # generate an iterator
+                 iterator = MyKernel(iterator_params)
+                 # set the params of the base class
+                 super().__init__(iterator, prior, data_fidelity, max_iter=max_iter,
+                                  burnin_ratio=burnin_ratio, clip=clip, verbose=verbose)
+
+         # create the sampler
+         sampler = MySampler(prior, data_fidelity, iterator_params)
+
+         # compute posterior mean and variance of reconstruction of measurement y
+         mean, var = sampler(y, physics)
+
+     This class computes the mean and variance of the chain using Welford's algorithm, which avoids storing the whole
+     set of Monte Carlo samples.
+
+     :param deepinv.optim.ScorePrior prior: negative log-prior based on a trained or model-based denoiser.
+     :param deepinv.optim.DataFidelity data_fidelity: negative log-likelihood function linked with the
+         noise distribution in the acquisition physics.
+     :param int max_iter: number of Monte Carlo iterations.
+     :param int thinning: thins the Monte Carlo samples by an integer :math:`\geq 1` (i.e., keeping one out of ``thinning``
+         samples to compute posterior statistics).
+     :param float burnin_ratio: percentage of iterations used for the burn-in period; should be set between 0 and 1.
+         The burn-in samples are discarded when computing the posterior statistics.
+     :param tuple clip: tuple containing the box-constraints :math:`[a,b]`.
+         If ``None``, the algorithm will not project the samples.
+     :param float thresh_conv: threshold for verifying the convergence of the mean and variance estimates.
+     :param str crit_conv: convergence criterion used by :meth:`deepinv.optim.utils.check_conv` (e.g., ``"residual"``).
+     :param function_handle g_statistic: the sampler will compute the posterior mean and variance
+         of the function g_statistic. By default, it is the identity function (lambda x: x),
+         and thus the sampler computes the posterior mean and variance.
+     :param bool verbose: prints progress of the algorithm.
+     """
+
+     def __init__(
+         self,
+         iterator: torch.nn.Module,
+         prior: deepinv.optim.ScorePrior,
+         data_fidelity: deepinv.optim.DataFidelity,
+         max_iter=1e3,
+         burnin_ratio=0.2,
+         thinning=10,
+         clip=(-1.0, 2.0),
+         thresh_conv=1e-3,
+         crit_conv="residual",
+         save_chain=False,
+         g_statistic=lambda x: x,
+         verbose=False,
+     ):
+         super(MonteCarlo, self).__init__()
+
+         self.iterator = iterator
+         self.prior = prior
+         self.likelihood = data_fidelity
+         self.C_set = clip
+         self.thinning = thinning
+         self.max_iter = int(max_iter)
+         self.thresh_conv = thresh_conv
+         self.crit_conv = crit_conv
+         self.burnin_iter = int(burnin_ratio * max_iter)
+         self.verbose = verbose
+         self.mean_convergence = False
+         self.var_convergence = False
+         self.g_function = g_statistic
+         self.save_chain = save_chain
+         self.chain = []
+
+     def forward(self, y, physics, seed=None, x_init=None):
+         r"""
+         Runs a Monte Carlo chain to obtain the posterior mean and variance of the reconstruction of the measurements y.
+
+         :param torch.tensor y: measurements
+         :param deepinv.physics.Physics physics: forward operator associated with the measurements
+         :param int seed: random seed for generating the Monte Carlo samples
+         :return: (tuple of torch.tensor) containing the posterior mean and variance.
+         """
+         with torch.no_grad():
+             if seed is not None:
+                 np.random.seed(seed)
+                 torch.manual_seed(seed)
+
+             # Algorithm parameters
+             if self.C_set:
+                 C_lower_lim = self.C_set[0]
+                 C_upper_lim = self.C_set[1]
+
+             # Initialization
+             if x_init is None:
+                 x = physics.A_adjoint(y)
+             else:
+                 x = x_init
+
+             # Monte Carlo loop
+             start_time = time.time()
+             statistics = Welford(self.g_function(x))
+
+             # initialize the previous estimates so the convergence check below
+             # is well defined even if the inner condition never triggers
+             mean_prev = statistics.mean().clone()
+             var_prev = torch.zeros_like(mean_prev)
+
+             self.mean_convergence = False
+             self.var_convergence = False
+             for it in tqdm(range(self.max_iter), disable=(not self.verbose)):
+                 x = self.iterator(
+                     x, y, physics, likelihood=self.likelihood, prior=self.prior
+                 )
+
+                 if self.C_set:
+                     x = projbox(x, C_lower_lim, C_upper_lim)
+
+                 if it >= self.burnin_iter and (it % self.thinning) == 0:
+                     # keep the estimates from the last thinning window to
+                     # measure how much the final update changes them
+                     if it >= (self.max_iter - self.thinning):
+                         mean_prev = statistics.mean().clone()
+                         var_prev = statistics.var().clone()
+                     statistics.update(self.g_function(x))
+
+                     if self.save_chain:
+                         self.chain.append(x.clone())
+
+             if self.verbose:
+                 if torch.cuda.is_available():
+                     torch.cuda.synchronize()
+                 end_time = time.time()
+                 elapsed = end_time - start_time
+                 print(
+                     f"Monte Carlo sampling finished! elapsed time={elapsed:.2f} seconds"
+                 )
+
+             if (
+                 check_conv(
+                     {"est": (mean_prev,)},
+                     {"est": (statistics.mean(),)},
+                     it,
+                     self.crit_conv,
+                     self.thresh_conv,
+                     self.verbose,
+                 )
+                 and it > 1
+             ):
+                 self.mean_convergence = True
+
+             if (
+                 check_conv(
+                     {"est": (var_prev,)},
+                     {"est": (statistics.var(),)},
+                     it,
+                     self.crit_conv,
+                     self.thresh_conv,
+                     self.verbose,
+                 )
+                 and it > 1
+             ):
+                 self.var_convergence = True
+
+         return statistics.mean(), statistics.var()
+
+     def get_chain(self):
+         r"""
+         Returns the thinned Monte Carlo samples (after burn-in iterations).
+         Requires ``save_chain=True``.
+         """
+         return self.chain
+
+     def reset(self):
+         r"""
+         Resets the Markov chain.
+         """
+         self.chain = []
+         self.mean_convergence = False
+         self.var_convergence = False
+
+     def mean_has_converged(self):
+         r"""
+         Returns a boolean indicating if the posterior mean verifies the convergence criteria.
+         """
+         return self.mean_convergence
+
+     def var_has_converged(self):
+         r"""
+         Returns a boolean indicating if the posterior variance verifies the convergence criteria.
+         """
+         return self.var_convergence
+
+
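The kernel contract is worth spelling out: any ``torch.nn.Module`` whose ``forward`` maps ``(x, y, physics, likelihood, prior)`` to the next sample can drive the chain, so even a toy kernel plugs in. A minimal illustrative sketch (the ``RandomWalkKernel`` name is hypothetical, not part of the package)::

    import torch

    class RandomWalkKernel(torch.nn.Module):
        def __init__(self, scale=0.01):
            super().__init__()
            self.scale = scale

        def forward(self, x, y, physics, likelihood, prior):
            # ignores the posterior entirely: a pure diffusion around the current state
            return x + self.scale * torch.randn_like(x)

Note also that ``g_statistic`` changes what the chain averages: with ``g_statistic=lambda x: x**2``, the sampler returns the posterior mean and variance of :math:`x^2` rather than of :math:`x`.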
+ class ULAIterator(nn.Module):
+     def __init__(self, step_size, alpha, sigma):
+         super().__init__()
+         self.step_size = step_size
+         self.alpha = alpha
+         self.noise_std = np.sqrt(2 * step_size)
+         self.sigma = sigma
+
+     def forward(self, x, y, physics, likelihood, prior):
+         noise = torch.randn_like(x) * self.noise_std
+         lhood = -likelihood.grad(x, y, physics)  # gradient of the log-likelihood
+         lprior = -prior(x, self.sigma) * self.alpha  # gradient of the log-prior, scaled by alpha
+         return x + self.step_size * (lhood + lprior) + noise
+
+
+ class ULA(MonteCarlo):
+     r"""
+     Projected Plug-and-Play Unadjusted Langevin Algorithm.
+
+     The algorithm runs the following Markov chain iteration
+     (Algorithm 2 from https://arxiv.org/abs/2103.04715):
+
+     .. math::
+
+         x_{k+1} = \Pi_{[a,b]} \left(x_{k} + \eta \nabla \log p(y|x_k) +
+         \eta \alpha \nabla \log p(x_{k}) + \sqrt{2\eta}z_{k+1} \right),
+
+     where :math:`x_{k}` is the :math:`k`-th sample of the Markov chain,
+     :math:`\log p(y|x)` is the log-likelihood function, :math:`\log p(x)` is the log-prior,
+     :math:`\eta>0` is the step size, :math:`\alpha>0` controls the amount of regularization,
+     :math:`\Pi_{[a,b]}(x)` projects the entries of :math:`x` onto the interval :math:`[a,b]` and
+     :math:`z_{k+1}\sim \mathcal{N}(0,I)` is a standard Gaussian vector.
+
+     - Projected PnP-ULA assumes that the denoiser is :math:`L`-Lipschitz differentiable.
+     - For convergence, ULA requires a step size smaller than :math:`\frac{1}{L+\|A\|_2^2}`.
+
+     :param deepinv.optim.ScorePrior, torch.nn.Module prior: negative log-prior based on a trained or model-based denoiser.
+     :param deepinv.optim.DataFidelity, torch.nn.Module data_fidelity: negative log-likelihood function linked with the
+         noise distribution in the acquisition physics.
+     :param float step_size: step size :math:`\eta>0` of the algorithm.
+         Tip: use :meth:`deepinv.physics.Physics.compute_norm()` to compute the Lipschitz constant of the forward operator.
+     :param float sigma: noise level used in the plug-and-play prior denoiser. A larger value of sigma will result in
+         a more regularized reconstruction.
+     :param float alpha: regularization parameter :math:`\alpha`.
+     :param int max_iter: number of Monte Carlo iterations.
+     :param int thinning: thins the Markov chain by an integer :math:`\geq 1` (i.e., keeping one out of ``thinning``
+         samples to compute posterior statistics).
+     :param float burnin_ratio: percentage of iterations used for the burn-in period; should be set between 0 and 1.
+         The burn-in samples are discarded when computing the posterior statistics.
+     :param tuple clip: tuple containing the box-constraints :math:`[a,b]`.
+         If ``None``, the algorithm will not project the samples.
+     :param float thresh_conv: threshold for verifying the convergence of the mean and variance estimates.
+     :param function_handle g_statistic: the sampler will compute the posterior mean and variance
+         of the function g_statistic. By default, it is the identity function (lambda x: x),
+         and thus the sampler computes the posterior mean and variance.
+     :param bool verbose: prints progress of the algorithm.
+     """
+
+     def __init__(
+         self,
+         prior,
+         data_fidelity,
+         step_size=1.0,
+         sigma=0.05,
+         alpha=1.0,
+         max_iter=1e3,
+         thinning=5,
+         burnin_ratio=0.2,
+         clip=(-1.0, 2.0),
+         thresh_conv=1e-3,
+         save_chain=False,
+         g_statistic=lambda x: x,
+         verbose=False,
+     ):
+         iterator = ULAIterator(step_size=step_size, alpha=alpha, sigma=sigma)
+         super().__init__(
+             iterator,
+             prior,
+             data_fidelity,
+             max_iter=max_iter,
+             thresh_conv=thresh_conv,
+             g_statistic=g_statistic,
+             burnin_ratio=burnin_ratio,
+             clip=clip,
+             thinning=thinning,
+             save_chain=save_chain,
+             verbose=verbose,
+         )
+
+
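For concreteness, a hedged usage sketch (the prior construction is version-dependent; compare the commented demo at the end of this file, which wraps a pretrained DnCNN in a ``ScorePrior``). The ``x`` tensor below is a placeholder ground-truth image::

    import deepinv as dinv
    from deepinv.optim.data_fidelity import L2

    sigma = 0.1  # measurement noise level
    physics = dinv.physics.Denoising()
    physics.noise_model = dinv.physics.GaussianNoise(sigma)
    y = physics(x)  # x: a ground-truth image tensor in [0, 1]

    likelihood = L2(sigma=sigma)
    sigma_den = 2 / 255  # denoiser noise level
    # prior = ...  # a ScorePrior around an L-Lipschitz denoiser, as in the demo below

    sampler = ULA(
        prior,
        likelihood,
        step_size=0.5 / (1 / sigma**2 + 1 / sigma_den**2),  # heuristic from the demo below
        sigma=sigma_den,
        alpha=0.3,
        max_iter=5000,
        burnin_ratio=0.3,
        clip=(-1, 2),
    )
    xmean, xvar = sampler(y, physics)  # posterior mean and variance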
+ class SKRockIterator(nn.Module):
+     def __init__(self, step_size, alpha, inner_iter, eta, sigma):
+         super().__init__()
+         self.step_size = step_size
+         self.alpha = alpha
+         self.eta = eta
+         self.inner_iter = inner_iter
+         self.noise_std = np.sqrt(2 * step_size)
+         self.sigma = sigma
+
+     def forward(self, x, y, physics, likelihood, prior):
+         posterior = lambda u: likelihood.grad(u, y, physics) + self.alpha * prior(
+             u, self.sigma
+         )
+
+         # Chebyshev polynomial of the first kind
+         T_s = lambda s, u: np.cosh(s * np.arccosh(u))
+         # first derivative of the Chebyshev polynomial of the first kind
+         T_prime_s = lambda s, u: s * np.sinh(s * np.arccosh(u)) / np.sqrt(u**2 - 1)
+
+         w0 = 1 + self.eta / (self.inner_iter**2)  # parameter \omega_0
+         w1 = T_s(self.inner_iter, w0) / T_prime_s(self.inner_iter, w0)  # parameter \omega_1
+         mu1 = w1 / w0  # parameter \mu_1
+         nu1 = self.inner_iter * w1 / 2  # parameter \nu_1
+         kappa1 = self.inner_iter * (w1 / w0)  # parameter \kappa_1
+
+         # sampling the variable x
+         noise = np.sqrt(2 * self.step_size) * torch.randn_like(x)  # diffusion term
+
+         # first internal iteration (s=1)
+         xts_2 = x.clone()
+         xts = (
+             x.clone()
+             - mu1 * self.step_size * posterior(x + nu1 * noise)
+             + kappa1 * noise
+         )
+
+         # s = 2,...,self.inner_iter SK-ROCK internal iterations
+         for js in range(2, self.inner_iter + 1):
+             xts_1 = xts.clone()
+             mu = 2 * w1 * T_s(js - 1, w0) / T_s(js, w0)  # parameter \mu_js
+             nu = 2 * w0 * T_s(js - 1, w0) / T_s(js, w0)  # parameter \nu_js
+             kappa = 1 - nu  # parameter \kappa_js
+             xts = -mu * self.step_size * posterior(xts) + nu * xts + kappa * xts_2
+             xts_2 = xts_1
+
+         return xts  # new sample produced by the SK-ROCK algorithm
+
+
+ class SKRock(MonteCarlo):
+     r"""
+     Plug-and-Play SKROCK algorithm.
+
+     Obtains samples of the posterior distribution using an explicit stabilised Runge-Kutta-Chebyshev stochastic
+     approximation to accelerate the standard Unadjusted Langevin Algorithm.
+
+     The algorithm was introduced in "Accelerating proximal Markov chain Monte Carlo by using an explicit stabilised method"
+     by L. Vargas, M. Pereyra and K. Zygalakis (https://arxiv.org/abs/1908.08845).
+
+     - SKROCK assumes that the denoiser is :math:`L`-Lipschitz differentiable.
+     - For convergence, SKROCK requires a step size smaller than :math:`\frac{1}{L+\|A\|_2^2}`.
+
+     :param deepinv.optim.ScorePrior, torch.nn.Module prior: negative log-prior based on a trained or model-based denoiser.
+     :param deepinv.optim.DataFidelity, torch.nn.Module data_fidelity: negative log-likelihood function linked with the
+         noise distribution in the acquisition physics.
+     :param float step_size: step size of the algorithm.
+         Tip: use :meth:`deepinv.physics.Physics.compute_norm()` to compute the Lipschitz constant of the forward operator.
+     :param float eta: :math:`\eta` SKROCK damping parameter.
+     :param float alpha: regularization parameter :math:`\alpha`.
+     :param int inner_iter: number of inner SKROCK iterations.
+     :param int max_iter: number of outer iterations.
+     :param int thinning: thins the Markov chain by an integer :math:`\geq 1` (i.e., keeping one out of ``thinning``
+         samples to compute posterior statistics).
+     :param float burnin_ratio: percentage of iterations used for the burn-in period.
+         The burn-in samples are discarded when computing the posterior statistics.
+     :param tuple clip: tuple containing the box-constraints :math:`[a,b]`.
+         If ``None``, the algorithm will not project the samples.
+     :param bool verbose: prints progress of the algorithm.
+     :param float sigma: noise level used in the plug-and-play prior denoiser. A larger value of sigma will result in
+         a more regularized reconstruction.
+     :param function_handle g_statistic: the sampler will compute the posterior mean and variance
+         of the function g_statistic. By default, it is the identity function (lambda x: x),
+         and thus the sampler computes the posterior mean and variance.
+     """
+
+     def __init__(
+         self,
+         prior: deepinv.optim.ScorePrior,
+         data_fidelity,
+         step_size=1.0,
+         inner_iter=10,
+         eta=0.05,
+         alpha=1.0,
+         max_iter=1e3,
+         burnin_ratio=0.2,
+         thinning=10,
+         clip=(-1.0, 2.0),
+         thresh_conv=1e-3,
+         save_chain=False,
+         g_statistic=lambda x: x,
+         verbose=False,
+         sigma=0.05,
+     ):
+         iterator = SKRockIterator(
+             step_size=step_size,
+             alpha=alpha,
+             inner_iter=inner_iter,
+             eta=eta,
+             sigma=sigma,
+         )
+         super().__init__(
+             iterator,
+             prior,
+             data_fidelity,
+             max_iter=max_iter,
+             thresh_conv=thresh_conv,
+             thinning=thinning,
+             burnin_ratio=burnin_ratio,
+             clip=clip,
+             g_statistic=g_statistic,
+             save_chain=save_chain,
+             verbose=verbose,
+         )
+
+
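``SKRock`` shares the ``MonteCarlo`` interface, so it is a drop-in replacement for ``ULA`` above; the commented demo at the end of this file switches between the two. A sketch reusing ``prior``, ``likelihood`` and ``sigma`` from the previous example::

    sampler = SKRock(
        prior,
        likelihood,
        step_size=0.1 * sigma**2,  # value used in the commented demo below
        inner_iter=10,             # stages of the Runge-Kutta-Chebyshev scheme
        eta=0.05,                  # damping parameter
        alpha=0.9,
        max_iter=1000,
        burnin_ratio=0.3,
        clip=(-1, 2),
    )
    xmean, xvar = sampler(y, physics)

Each outer iteration evaluates the posterior gradient ``inner_iter`` times, which is what buys the larger stable step sizes relative to plain ULA.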
+ # if __name__ == "__main__":
+ #     import deepinv as dinv
+ #     import torchvision
+ #     from deepinv.optim.data_fidelity import L2
+ #
+ #     x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
+ #     x = x.unsqueeze(0).float().to(dinv.device) / 255
+ #     # physics = dinv.physics.CompressedSensing(m=50000, fast=True, img_shape=(3, 218, 178), device=dinv.device)
+ #     # physics = dinv.physics.Denoising()
+ #     physics = dinv.physics.Inpainting(
+ #         mask=0.95, tensor_size=(3, 218, 178), device=dinv.device
+ #     )
+ #     # physics = dinv.physics.BlurFFT(filter=dinv.physics.blur.gaussian_blur(sigma=(2,2)), img_size=x.shape[1:], device=dinv.device)
+ #
+ #     sigma = 0.1
+ #     physics.noise_model = dinv.physics.GaussianNoise(sigma)
+ #
+ #     y = physics(x)
+ #
+ #     likelihood = L2(sigma=sigma)
+ #
+ #     # model_spec = {'name': 'median_filter', 'args': {'kernel_size': 3}}
+ #     model_spec = {
+ #         "name": "dncnn",
+ #         "args": {
+ #             "device": dinv.device,
+ #             "in_channels": 3,
+ #             "out_channels": 3,
+ #             "pretrained": "download_lipschitz",
+ #         },
+ #     }
+ #     # model_spec = {'name': 'waveletprior', 'args': {'wv': 'db8', 'level': 4, 'device': dinv.device}}
+ #
+ #     prior = ScorePrior(model_spec=model_spec, sigma_normalize=True)
+ #
+ #     sigma_den = 2 / 255
+ #     f = ULA(
+ #         prior,
+ #         likelihood,
+ #         max_iter=5000,
+ #         sigma=sigma_den,
+ #         burnin_ratio=0.3,
+ #         verbose=True,
+ #         alpha=0.3,
+ #         step_size=0.5 * 1 / (1 / (sigma**2) + 1 / (sigma_den**2)),
+ #         clip=(-1, 2),
+ #         save_chain=True,
+ #     )
+ #     # f = SKRock(prior, likelihood, max_iter=1000, burnin_ratio=.3, verbose=True,
+ #     #            alpha=.9, step_size=.1*(sigma**2), clip=(-1, 2))
+ #
+ #     xmean, xvar = f(y, physics)
+ #
+ #     print(str(f.mean_has_converged()))
+ #     print(str(f.var_has_converged()))
+ #
+ #     chain = f.get_chain()
+ #     distance = np.zeros((len(chain)))
+ #     for k, xhat in enumerate(chain):
+ #         dist = (xhat - xmean).pow(2).mean()
+ #         distance[k] = dist
+ #     distance = np.sort(distance)
+ #     thres = distance[int(len(distance) * 0.95)]  # 95% quantile of the chain distances
+ #     err = (x - xmean).pow(2).mean()
+ #     print(f"Confidence region: {thres:.2e}, error: {err:.2e}")
+ #
+ #     xstdn = xvar.sqrt()
+ #     xstdn_plot = xstdn.sum(dim=1).unsqueeze(1)
+ #
+ #     error = (xmean - x).abs()  # per pixel average abs. error
+ #     error_plot = error.sum(dim=1).unsqueeze(1)
+ #
+ #     print(f"Correct std: {(xstdn*3>error).sum()/np.prod(xstdn.shape)*100:.1f}%")
+ #
+ #     dinv.utils.plot(
+ #         [physics.A_adjoint(y), x, xmean, xstdn_plot, error_plot],
+ #         titles=["meas.", "ground-truth", "mean", "norm. std", "abs. error"],
+ #     )
deepinv/sampling/utils.py
@@ -0,0 +1,35 @@
+ import torch
+
+
+ class Welford:
+     r"""
+     Welford's algorithm for calculating the running mean and variance of a sequence of tensors.
+
+     https://doi.org/10.2307/1266577
+     """
+
+     def __init__(self, x):
+         self.k = 1
+         self.M = x.clone()
+         self.S = torch.zeros_like(x)
+
+     def update(self, x):
+         self.k += 1
+         Mnext = self.M + (x - self.M) / self.k
+         self.S = self.S + (x - self.M) * (x - Mnext)
+         self.M = Mnext
+
+     def mean(self):
+         return self.M
+
+     def var(self):
+         return self.S / (self.k - 1)
+
+
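A quick illustrative sanity check of the recursion: after streaming a sequence of tensors through ``Welford``, ``mean()`` and ``var()`` should match the batched statistics computed directly::

    import torch

    xs = [torch.randn(3, 8, 8) for _ in range(100)]
    w = Welford(xs[0])
    for x in xs[1:]:
        w.update(x)

    stacked = torch.stack(xs)
    assert torch.allclose(w.mean(), stacked.mean(dim=0), atol=1e-5)
    assert torch.allclose(w.var(), stacked.var(dim=0, unbiased=True), atol=1e-4)

The point of the streaming form is that only ``M`` and ``S`` are kept in memory, which is what lets the samplers in ``langevin.py`` avoid storing the whole chain.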
+ def refl_projbox(x, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
+     # reflected projection: mirror negative entries before clamping to [lower, upper]
+     x = torch.abs(x)
+     return torch.clamp(x, min=lower, max=upper)
+
+
+ def projbox(x, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
+     # orthogonal projection onto the box [lower, upper]
+     return torch.clamp(x, min=lower, max=upper)
deepinv/tests/conftest.py
@@ -0,0 +1,39 @@
+ import pytest
+
+ import torch
+
+ import deepinv as dinv
+ from deepinv.tests.dummy_datasets.datasets import DummyCircles
+
+
+ @pytest.fixture
+ def device():
+     return dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
+
+
+ @pytest.fixture
+ def toymatrix():
+     w = 50
+     A = torch.diag(torch.Tensor(range(1, w + 1)))
+     return A
+
+
+ @pytest.fixture
+ def dummy_dataset(imsize, device):
+     return DummyCircles(samples=1, imsize=imsize)
+
+
+ @pytest.fixture
+ def imsize():
+     h = 37
+     w = 31
+     c = 3
+     return c, h, w
+
+
+ @pytest.fixture
+ def imsize_1_channel():
+     h = 37
+     w = 31
+     c = 1
+     return c, h, w
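Pytest injects these fixtures by parameter name, so tests elsewhere in the suite simply declare them as arguments. A hypothetical example::

    def test_dummy_dataset_shape(dummy_dataset, imsize):
        x = dummy_dataset[0]
        assert x.shape == imsize
        assert len(dummy_dataset) == 1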
deepinv/tests/dummy_datasets/datasets.py
@@ -0,0 +1,57 @@
+ import torch
+ from torch.utils.data import Dataset
+ import numpy as np
+
+
+ def create_circular_mask(imsize, center=None, radius=None):
+     h, w = imsize
+     if center is None:  # use the middle of the image
+         center = (int(h / 2), int(w / 2))
+     if radius is None:  # use the smallest distance between the center and image walls
+         radius = min(center[0], center[1], h - center[0], w - center[1])
+
+     X, Y = np.ogrid[:h, :w]
+     dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
+     mask = dist_from_center <= radius
+     return mask
+
+
+ class DummyCircles(Dataset):
+     def __init__(self, samples, imsize=(3, 32, 28), max_circles=10, seed=1):
+         super().__init__()
+
+         self.x = torch.zeros((samples,) + imsize, dtype=torch.float32)
+
+         rng = np.random.default_rng(seed)
+
+         max_rad = max(imsize[1], imsize[2]) / 2  # largest radius, from the spatial dims (imsize is (C, H, W))
+         for i in range(samples):
+             circles = rng.integers(low=1, high=max_circles)
+
+             for c in range(circles):
+                 pos = rng.uniform(high=imsize[1:])
+                 colour = rng.random((imsize[0], 1), dtype=np.float32)
+                 r = rng.uniform(high=max_rad)
+                 mask = torch.from_numpy(
+                     create_circular_mask(imsize[1:], center=pos, radius=r)
+                 )
+                 self.x[i, :, mask] = torch.from_numpy(colour)
+
+     def __getitem__(self, index):
+         return self.x[index, :, :, :]
+
+     def __len__(self):
+         return self.x.shape[0]
+
+
+ if __name__ == "__main__":
+     device = "cuda:0"
+     imsize = (3, 23, 100)
+     dataset = DummyCircles(10, imsize=imsize)
+
+     x = dataset[0]
+
+     import matplotlib.pyplot as plt
+
+     plt.imshow(x.permute(1, 2, 0).cpu().numpy())
+     plt.show()
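Since ``DummyCircles`` is a standard ``torch.utils.data.Dataset``, it can feed a ``DataLoader`` for quick end-to-end pipeline checks; a brief illustrative sketch::

    from torch.utils.data import DataLoader

    dataset = DummyCircles(samples=8, imsize=(3, 37, 31))
    loader = DataLoader(dataset, batch_size=4, shuffle=False)
    for batch in loader:
        assert batch.shape == (4, 3, 37, 31)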