deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/optim/prior.py ADDED
@@ -0,0 +1,288 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from deepinv.optim.utils import gradient_descent
5
+
6
+
7
class Prior(nn.Module):
    r"""
    Prior term :math:`g(x)`.

    This is the base class for the prior term :math:`g(x)`. Similarly to the
    :meth:`deepinv.optim.DataFidelity` class, this class comes with methods for computing
    :math:`\operatorname{prox}_{g}` and :math:`\nabla g`.
    To implement a custom explicit prior, overwrite :math:`g` (do not forget to specify
    `self.explicit_prior = True`).

    This base class is also used to implement implicit priors. For instance, in PnP methods,
    the method computing the proximity operator is overwritten by a method performing
    denoising. For an implicit prior, overwrite `grad` or `prox`.

    .. note::

        The methods for computing the proximity operator and the gradient of the prior rely
        on automatic differentiation. These methods should not be used when the prior is not
        differentiable, although they will not raise an error.

    :param callable g: Prior function :math:`g(x)`.
    """

    def __init__(self, g=None):
        super().__init__()
        self._g = g
        # The prior is "explicit" when the function g itself is provided.
        self.explicit_prior = self._g is not None

    def g(self, x, *args, **kwargs):
        r"""
        Computes the prior :math:`g(x)`.

        :param torch.Tensor x: Variable :math:`x` at which the prior is computed.
        :return: (torch.Tensor) prior :math:`g(x)`.
        """
        return self._g(x, *args, **kwargs)

    def forward(self, x, *args, **kwargs):
        r"""
        Computes the prior :math:`g(x)`.

        :param torch.Tensor x: Variable :math:`x` at which the prior is computed.
        :return: (torch.Tensor) prior :math:`g(x)`.
        """
        return self.g(x, *args, **kwargs)

    def grad(self, x, *args, **kwargs):
        r"""
        Calculates the gradient of the prior term :math:`g` at :math:`x`.
        By default, the gradient is computed using automatic differentiation.

        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
        :return: (torch.Tensor) gradient :math:`\nabla_x g`, computed in :math:`x`.
        """
        with torch.enable_grad():
            x = x.requires_grad_()
            grad = torch.autograd.grad(
                self.g(x, *args, **kwargs), x, create_graph=True, only_inputs=True
            )[0]
        return grad

    def prox(
        self,
        x,
        *args,
        gamma=1.0,
        stepsize_inter=1.0,
        max_iter_inter=50,
        tol_inter=1e-3,
        **kwargs,
    ):
        r"""
        Calculates the proximity operator of :math:`g` at :math:`x`. By default, the proximity
        operator is computed using internal gradient descent on the proximal objective
        :math:`z \mapsto \gamma g(z) + \frac{1}{2}\|z - x\|^2`.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param float gamma: stepsize of the proximity operator.
        :param float stepsize_inter: stepsize used for internal gradient descent.
        :param int max_iter_inter: maximal number of iterations for internal gradient descent.
        :param float tol_inter: internal gradient descent has converged when the L2 distance
            between two consecutive iterates is smaller than tol_inter.
        :return: (torch.Tensor) proximity operator
            :math:`\operatorname{prox}_{\gamma g}(x)`, computed in :math:`x`.
        """
        # Gradient of the proximal objective gamma * g(z) + 0.5 * ||z - x||^2.
        grad = lambda z: gamma * self.grad(z, *args, **kwargs) + (z - x)
        return gradient_descent(
            grad, x, step_size=stepsize_inter, max_iter=max_iter_inter, tol=tol_inter
        )

    def prox_conjugate(self, x, *args, gamma=1.0, lamb=1.0, **kwargs):
        r"""
        Calculates the proximity operator of the convex conjugate :math:`(\lambda g)^*` at
        :math:`x`, using the Moreau decomposition
        :math:`\operatorname{prox}_{\gamma (\lambda g)^*}(x) = x - \gamma \operatorname{prox}_{(\lambda/\gamma) g}(x/\gamma)`.

        .. warning::

            Only valid for convex :math:`g`.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param float gamma: stepsize of the proximity operator.
        :param float lamb: :math:`\lambda` parameter in front of :math:`g`.
        :return: (torch.Tensor) proximity operator
            :math:`\operatorname{prox}_{\gamma (\lambda g)^*}(x)`, computed in :math:`x`.
        """
        # Bug fix: lamb / gamma is the prox stepsize and must be passed as the `gamma=`
        # keyword; previously it was passed positionally, landing in *args and being
        # forwarded to self.grad instead.
        return x - gamma * self.prox(x / gamma, *args, gamma=lamb / gamma, **kwargs)
107
+
108
+
109
class PnP(Prior):
    r"""
    Plug-and-play prior :math:`\operatorname{prox}_{\gamma g}(x) = \operatorname{D}_{\sigma}(x)`.

    The regularization is implicit: no explicit function :math:`g` is available, and a
    denoiser :math:`\operatorname{D}_{\sigma}` plays the role of the proximity operator.

    :param callable denoiser: Denoiser :math:`\operatorname{D}_{\sigma}`.
    """

    def __init__(self, denoiser, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.explicit_prior = False  # only the prox (denoiser) is available
        self.denoiser = denoiser

    def prox(self, x, sigma_denoiser, *args, **kwargs):
        r"""
        Uses denoising as the proximity operator of the PnP prior :math:`g` at :math:`x`.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param float sigma_denoiser: noise level parameter of the denoiser.
        :return: (torch.Tensor) proximity operator at :math:`x`.
        """
        return self.denoiser(x, sigma_denoiser)
131
+
132
+
133
class RED(Prior):
    r"""
    Regularization-by-Denoising (RED) prior :math:`\nabla g(x) = \operatorname{Id} - \operatorname{D}_{\sigma}(x)`.

    The regularization is implicit: no explicit function :math:`g` is available, and the
    denoiser residual plays the role of the gradient.

    :param callable denoiser: Denoiser :math:`\operatorname{D}_{\sigma}`.
    """

    def __init__(self, denoiser, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.explicit_prior = False  # only the gradient is available
        self.denoiser = denoiser

    def grad(self, x, sigma_denoiser, *args, **kwargs):
        r"""
        Computes the RED gradient, i.e. the residual between :math:`x` and its denoised
        version.

        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
        :param float sigma_denoiser: noise level parameter of the denoiser.
        :return: (:class:`torch.Tensor`) gradient :math:`\nabla_x g`, computed in :math:`x`.
        """
        return x - self.denoiser(x, sigma_denoiser)
155
+
156
+
157
class ScorePrior(Prior):
    r"""
    Score via MMSE denoiser :math:`\nabla g(x)=\left(x-\operatorname{D}_{\sigma}(x)\right)/\sigma^2`.

    This approximates the score of a distribution using Tweedie's formula, i.e.,

    .. math::

        - \nabla \log p_{\sigma}(x) \propto \left(x-D(x,\sigma)\right)/\sigma^2

    where :math:`p_{\sigma} = p*\mathcal{N}(0,I\sigma^2)` is the prior convolved with a
    Gaussian kernel, :math:`D(\cdot,\sigma)` is a (trained or model-based) denoiser with
    noise level :math:`\sigma`, which is typically set to a low value.

    .. note::

        If :math:`\sigma=1`, this prior is equal to :class:`deepinv.optim.RED`, which is
        defined in `Regularization by Denoising (RED) <https://arxiv.org/abs/1611.02862>`_
        and doesn't require the normalization.

    .. note::

        This class can also be used with maximum-a-posteriori (MAP) denoisers,
        but :math:`p_{\sigma}(x)` is not given by the convolution with a Gaussian kernel,
        but rather given by the Moreau-Yosida envelope of :math:`p(x)`, i.e.,

        .. math::

            p_{\sigma}(x)=e^{- \inf_z \left(-\log p(z) + \frac{1}{2\sigma}\|x-z\|^2 \right)}.

    :param callable denoiser: Denoiser :math:`\operatorname{D}_{\sigma}`.
    """

    def __init__(self, denoiser, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.explicit_prior = False  # the prior is only known through its score
        self.denoiser = denoiser

    def forward(self, x, sigma):
        r"""
        Computes the score :math:`\left(x - \operatorname{D}_{\sigma}(x)\right)/\sigma^2`.

        :param torch.Tensor x: the input tensor.
        :param float sigma: the noise level.
        """
        residual = x - self.denoiser(x, sigma)
        return (1 / sigma**2) * residual
203
+
204
+
205
class Tikhonov(Prior):
    r"""
    Tikhonov regularizer :math:`g(x) = \frac{1}{2}\| x \|_2^2`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.explicit_prior = True  # g is available in closed form

    def g(self, x, ths=1.0):
        r"""
        Computes the Tikhonov regularizer :math:`g(x) = \frac{\tau}{2}\| x \|_2^2`.

        :param torch.Tensor x: Variable :math:`x` at which the prior is computed.
        :param float ths: regularization parameter :math:`\tau`.
        :return: (torch.Tensor) prior :math:`g(x)`.
        """
        # Squared L2 norm, computed batch-wise over the flattened trailing dimensions.
        flat = x.contiguous().view(x.shape[0], -1)
        return 0.5 * ths * torch.norm(flat, p=2, dim=-1) ** 2

    def grad(self, x):
        r"""
        Calculates the gradient of the Tikhonov regularization term :math:`g` at :math:`x`.

        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
        :return: (torch.Tensor) gradient at :math:`x`.
        """
        return x

    def prox(self, x, ths=1.0, gamma=1.0):
        r"""
        Calculates the proximity operator of the Tikhonov regularization term
        :math:`\gamma \tau g` at :math:`x`, which is a simple shrinkage in closed form.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param float ths: regularization parameter :math:`\tau`.
        :param float gamma: stepsize of the proximity operator.
        :return: (torch.Tensor) proximity operator at :math:`x`.
        """
        return (1 / (ths * gamma + 1)) * x
247
+
248
+
249
class L1Prior(Prior):
    r"""
    :math:`\ell_1` prior :math:`g(x) = \| x \|_1`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.explicit_prior = True  # g is available in closed form

    def g(self, x, ths=1.0):
        r"""
        Computes the regularizer :math:`g(x) = \tau\| x \|_1`.

        :param torch.Tensor x: Variable :math:`x` at which the prior is computed.
        :param float ths: threshold parameter :math:`\tau`.
        :return: (torch.Tensor) prior :math:`g(x)`.
        """
        # L1 norm, computed batch-wise over the flattened trailing dimensions.
        flat = x.contiguous().view(x.shape[0], -1)
        return ths * torch.norm(flat, p=1, dim=-1)

    def prox(self, x, ths=1.0, gamma=1.0):
        r"""
        Calculates the proximity operator of the l1 regularization term :math:`g` at
        :math:`x`, i.e. soft-thresholding:

        .. math::
            \operatorname{prox}_{\gamma \tau g}(x) = \operatorname{sign}(x) \max(|x| - \gamma \tau, 0)

        where :math:`\tau` is the threshold parameter and :math:`\gamma` is a stepsize.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param float ths: threshold parameter :math:`\tau`.
        :param float gamma: stepsize of the proximity operator.
        :return: (torch.Tensor) proximity operator at :math:`x`.
        """
        shrunk = torch.clamp(torch.abs(x) - ths * gamma, min=0)
        return torch.sign(x) * shrunk
deepinv/optim/utils.py ADDED
@@ -0,0 +1,80 @@
1
+ from deepinv.utils import zeros_like
2
+
3
+
4
def check_conv(X_prev, X, it, crit_conv="residual", thres_conv=1e-3, verbose=False):
    """
    Checks whether an iterative algorithm has converged.

    :param X_prev: previous iterate; a tensor, or a dict holding the iterate under
        ``"est"`` (tuple) or the cost under ``"cost"``.
    :param X: current iterate, same structure as X_prev.
    :param int it: iteration index (only used in the verbose message).
    :param str crit_conv: convergence criterion, either ``"residual"`` or ``"cost"``.
    :param float thres_conv: threshold under which the criterion counts as converged.
    :param bool verbose: if True, print a message when convergence is reached.
    :return: (bool) True if the convergence criterion is below the threshold.
    """
    if crit_conv == "residual":
        if isinstance(X_prev, dict):
            X_prev = X_prev["est"][0]
        if isinstance(X, dict):
            X = X["est"][0]
        # Relative change between consecutive iterates (epsilon avoids division by zero).
        crit_cur = (X_prev - X).norm() / (X.norm() + 1e-06)
    elif crit_conv == "cost":
        # Relative change of the cost function value.
        crit_cur = (X_prev["cost"] - X["cost"]).norm() / (X["cost"].norm() + 1e-06)
    else:
        raise ValueError("convergence criteria not implemented")
    if crit_cur < thres_conv:
        if verbose:
            print(
                f"Iteration {it}, current converge crit. = {crit_cur:.2E}, objective = {thres_conv:.2E} \r"
            )
        return True
    return False
25
+
26
+
27
def conjugate_gradient(A, b, max_iter=1e2, tol=1e-5):
    """
    Standard conjugate gradient algorithm to solve Ax=b.

    See: http://en.wikipedia.org/wiki/Conjugate_gradient_method

    :param A: Linear operator as a callable function, has to be square (and should be
        symmetric positive definite for CG to converge).
    :param b: input tensor (right-hand side).
    :param max_iter: maximum number of CG iterations.
    :param tol: absolute tolerance (on the residual norm) for stopping the CG algorithm.
    :return: torch tensor x verifying Ax=b.
    """

    def dot(s1, s2):
        # Euclidean inner product over all entries.
        return (s1 * s2).flatten().sum()

    x = zeros_like(b)

    # Initial residual for the starting guess x = 0.
    r = b
    p = r
    rsold = dot(r, r)

    # Bug fix: if the initial residual is already below tolerance (e.g. b == 0),
    # entering the loop would compute alpha = 0/0 and fill x with NaNs.
    if rsold.sqrt() < tol:
        return x

    for _ in range(int(max_iter)):
        Ap = A(p)
        alpha = rsold / dot(p, Ap)
        x = x + p * alpha
        r = r - Ap * alpha
        rsnew = dot(r, r)
        if rsnew.sqrt() < tol:
            break
        p = r + p * (rsnew / rsold)
        rsold = rsnew

    return x
62
+
63
+
64
def gradient_descent(grad_f, x, step_size=1.0, max_iter=1e2, tol=1e-5):
    """
    Standard gradient descent algorithm to solve min_x f(x).

    :param grad_f: gradient of the function to be minimized, as a callable function.
    :param x: input tensor (initial iterate).
    :param step_size: (constant) step size of the gradient descent algorithm.
    :param max_iter: maximum number of iterations.
    :param tol: absolute tolerance for stopping the algorithm.
    :return: torch tensor x verifying min_x f(x).
    """
    for it in range(int(max_iter)):
        x_prev = x
        x = x_prev - grad_f(x_prev) * step_size
        # Stop as soon as the relative change between iterates falls below tol.
        if check_conv(x_prev, x, it, thres_conv=tol):
            break
    return x
@@ -0,0 +1,18 @@
1
+ from .inpainting import Inpainting
2
+ from .compressed_sensing import CompressedSensing
3
+ from .blur import Blur, BlindBlur, Downsampling, BlurFFT
4
+ from .range import Decolorize
5
+ from .haze import Haze
6
+ from .forward import Denoising, Physics, LinearPhysics, DecomposablePhysics
7
+ from .noise import (
8
+ GaussianNoise,
9
+ PoissonNoise,
10
+ PoissonGaussianNoise,
11
+ UniformNoise,
12
+ UniformGaussianNoise,
13
+ )
14
+ from .mri import MRI
15
+ from .tomography import Tomography
16
+ from .lidar import SinglePhotonLidar
17
+ from .singlepixel import SinglePixelCamera
18
+ from .remote_sensing import Pansharpen