deepinv 0.1.0.dev0__py3-none-any.whl

Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/models/tgv.py ADDED
@@ -0,0 +1,232 @@
import torch
import torch.nn as nn


class TGV(nn.Module):
    r"""
    Proximal operator of the (2nd order) Total Generalised Variation operator.

    (See K. Bredies, K. Kunisch, and T. Pock, "Total generalized variation," SIAM J. Imaging Sci., 3(3), 492-526, 2010.)

    This algorithm converges to the unique image :math:`x` (and the auxiliary vector field :math:`r`) minimizing

    .. math::

        \underset{x, r}{\arg\min} \; \frac{1}{2}\|x-y\|_2^2 + \lambda_1 \|r\|_{1,2} + \lambda_2 \|J(Dx-r)\|_{1,F}

    where :math:`D` maps an image to its gradient field and :math:`J` maps a vector field to its Jacobian.
    For a large value of :math:`\lambda_2`, the TGV behaves like the TV.
    For a small value, it behaves like the :math:`\ell_1`-Frobenius norm of the Hessian.

    The problem is solved with an over-relaxed Chambolle-Pock algorithm (see L. Condat, "A primal-dual splitting method
    for convex optimization involving Lipschitzian, proximable and linear composite terms", J. Optimization Theory and
    Applications, vol. 158, no. 2, pp. 460-479, 2013).

    Code (and description) adapted from Laurent Condat's matlab version (https://lcondat.github.io/software.html) and
    Daniil Smolyakov's `code <https://github.com/RoundedGlint585/TGVDenoising/blob/master/TGV%20WithoutHist.ipynb>`_.

    :param bool verbose: Whether to print computation details or not. Default: False.
    :param int n_it_max: Maximum number of iterations. Default: 1000.
    :param float crit: Convergence criterion. Default: 1e-5.
    :param torch.Tensor, None x2: Primary variable. Default: None.
    :param torch.Tensor, None u2: Dual variable. Default: None.
    :param torch.Tensor, None r2: Auxiliary variable. Default: None.
    """

    def __init__(
        self, verbose=False, n_it_max=1000, crit=1e-5, x2=None, u2=None, r2=None
    ):
        super(TGV, self).__init__()

        self.verbose = verbose
        self.n_it_max = n_it_max
        self.crit = crit
        self.restart = True

        self.tau = 0.01  # step size, > 0
        self.rho = 1.99  # relaxation parameter, in [1, 2)
        self.sigma = 1 / self.tau / 72  # dual step size

        self.x2 = x2
        self.r2 = r2
        self.u2 = u2

        self.has_converged = False

    def prox_tau_fx(self, x, y):
        return (x + self.tau * y) / (1 + self.tau)

    def prox_tau_fr(self, r, lambda1):
        left = torch.sqrt(torch.sum(r**2, axis=-1)) / (self.tau * lambda1)
        tmp = r - r / (
            torch.maximum(
                left, torch.tensor([1], device=left.device).type(left.dtype)
            ).unsqueeze(-1)
        )
        return tmp

    def prox_sigma_g_conj(self, u, lambda2):
        return u / (
            torch.maximum(
                torch.sqrt(torch.sum(u**2, axis=-1)) / lambda2,
                torch.tensor([1], device=u.device).type(u.dtype),
            ).unsqueeze(-1)
        )

    def forward(self, y, ths=None):
        restart = self.restart or self.x2 is None or self.x2.shape != y.shape

        if restart:
            self.x2 = y.clone()
            self.r2 = torch.zeros((*self.x2.shape, 2), device=self.x2.device).type(
                self.x2.dtype
            )
            self.u2 = torch.zeros((*self.x2.shape, 4), device=self.x2.device).type(
                self.x2.dtype
            )
            self.restart = False

        # `ths` scales the two regularisation weights; the original code left
        # lambda1/lambda2 undefined when ths was None, so we default it to 1.0
        if ths is None:
            ths = 1.0
        lambda1 = ths * 0.1
        lambda2 = ths * 0.15

        cy = (y**2).sum() / 2
        primalcostlowerbound = 0

        for it in range(self.n_it_max):
            x_prev = self.x2.clone()
            tmp = self.tau * epsilonT(self.u2)
            x = self.prox_tau_fx(self.x2 - nablaT(tmp), y)
            r = self.prox_tau_fr(self.r2 + tmp, lambda1)
            u = self.prox_sigma_g_conj(
                self.u2
                + self.sigma * epsilon(nabla(2 * x - self.x2) - (2 * r - self.r2)),
                lambda2,
            )
            self.x2 = self.x2 + self.rho * (x - self.x2)
            self.r2 = self.r2 + self.rho * (r - self.r2)
            self.u2 = self.u2 + self.rho * (u - self.u2)

            rel_err = torch.linalg.norm(
                x_prev.flatten() - self.x2.flatten()
            ) / (torch.linalg.norm(self.x2.flatten()) + 1e-12)

            if it > 1 and rel_err < self.crit:
                self.has_converged = True
                if self.verbose:
                    print("TGV prox reached convergence")
                break

            if self.verbose and it % 100 == 0:
                primalcost = (
                    torch.linalg.norm(x.flatten() - y.flatten()) ** 2
                    + lambda1 * torch.sum(torch.sqrt(torch.sum(r**2, axis=-1)))
                    + lambda2
                    * torch.sum(
                        torch.sqrt(torch.sum(epsilon(nabla(x) - r) ** 2, axis=-1))
                    )
                )
                # dualcost = cy - ((y - nablaT(epsilonT(u))) ** 2).sum() / 2.0
                # Feasibility check: the value below is <= lambda1 only at
                # convergence. Since u is not feasible, the dual cost is not
                # reliable: the gap = primalcost - dualcost can be < 0 and
                # cannot be used as a stopping criterion.
                tmp = torch.max(torch.sqrt(torch.sum(epsilonT(u) ** 2, axis=-1)))
                # u3 is a scaled version of u, which is feasible, so its dual
                # cost is a valid (but very rough) lower bound on the primal cost.
                u3 = u / torch.maximum(
                    tmp / lambda1, torch.tensor([1], device=tmp.device).type(tmp.dtype)
                )
                dualcost2 = cy - torch.sum((y - nablaT(epsilonT(u3))) ** 2) / 2.0
                # we display the best value of dualcost2 computed so far
                primalcostlowerbound = max(primalcostlowerbound, dualcost2.item())
                print(
                    "Iter: ",
                    it,
                    " Primal cost: ",
                    primalcost.item(),
                    " Lower bound: ",
                    primalcostlowerbound,
                    " Rel err:",
                    rel_err,
                )

            if it == self.n_it_max - 1:
                if self.verbose:
                    print(
                        "The algorithm did not converge, stopped after "
                        + str(it + 1)
                        + " iterations."
                    )

        return self.x2


def nabla(I):
    b, c, h, w = I.shape
    G = torch.zeros((b, c, h, w, 2), device=I.device).type(I.dtype)
    G[:, :, :-1, :, 0] = G[:, :, :-1, :, 0] - I[:, :, :-1]
    G[:, :, :-1, :, 0] = G[:, :, :-1, :, 0] + I[:, :, 1:]
    G[:, :, :, :-1, 1] = G[:, :, :, :-1, 1] - I[..., :-1]
    G[:, :, :, :-1, 1] = G[:, :, :, :-1, 1] + I[..., 1:]
    return G


def nablaT(G):
    b, c, h, w = G.shape[:-1]
    # note that we just reversed left and right sides of each line
    # to obtain the transposed operator
    I = torch.zeros((b, c, h, w), device=G.device).type(G.dtype)
    I[:, :, :-1] = I[:, :, :-1] - G[:, :, :-1, :, 0]
    I[:, :, 1:] = I[:, :, 1:] + G[:, :, :-1, :, 0]
    I[..., :-1] = I[..., :-1] - G[..., :-1, 1]
    I[..., 1:] = I[..., 1:] + G[..., :-1, 1]
    return I


# # ADJOINTNESS TEST
# u = torch.randn((4, 3, 100, 100)).type(torch.DoubleTensor)
# Au = nabla(u)
# v = torch.randn(*Au.shape).type(Au.dtype)
# Atv = nablaT(v)
# e = v.flatten() @ Au.flatten() - Atv.flatten() @ u.flatten()
# print('Adjointness test (should be small): ', e)


def epsilon(I):  # Simplified
    b, c, h, w, _ = I.shape
    G = torch.zeros((b, c, h, w, 4), device=I.device).type(I.dtype)
    G[:, :, 1:, :, 0] = G[:, :, 1:, :, 0] - I[:, :, :-1, :, 0]  # xdy
    G[..., 0] = G[..., 0] + I[..., 0]
    G[..., 1:, 1] = G[..., 1:, 1] - I[..., :-1, 0]  # xdx
    G[..., 1:, 1] = G[..., 1:, 1] + I[..., 1:, 0]
    G[..., 1:, 2] = G[..., 1:, 2] - I[..., :-1, 1]  # ydx
    G[..., 2] = G[..., 2] + I[..., 1]
    G[:, :, :-1, :, 3] = G[:, :, :-1, :, 3] - I[:, :, :-1, :, 1]  # ydy
    G[:, :, :-1, :, 3] = G[:, :, :-1, :, 3] + I[:, :, 1:, :, 1]
    return G


def epsilonT(G):
    b, c, h, w, _ = G.shape
    I = torch.zeros((b, c, h, w, 2), device=G.device).type(G.dtype)
    I[:, :, :-1, :, 0] = I[:, :, :-1, :, 0] - G[:, :, 1:, :, 0]
    I[..., 0] = I[..., 0] + G[..., 0]
    I[..., :-1, 0] = I[..., :-1, 0] - G[..., 1:, 1]
    I[..., 1:, 0] = I[..., 1:, 0] + G[..., 1:, 1]
    I[..., :-1, 1] = I[..., :-1, 1] - G[..., 1:, 2]
    I[..., 1] = I[..., 1] + G[..., 2]
    I[:, :, :-1, :, 1] = I[:, :, :-1, :, 1] - G[:, :, :-1, :, 3]
    I[:, :, 1:, :, 1] = I[:, :, 1:, :, 1] + G[:, :, :-1, :, 3]
    return I


# # ADJOINTNESS TEST
# u = torch.randn((2, 3, 100, 100, 2)).type(torch.DoubleTensor)
# Au = epsilon(u)
# v = torch.randn(*Au.shape).type(Au.dtype)
# Atv = epsilonT(v)
# e = v.flatten() @ Au.flatten() - Atv.flatten() @ u.flatten()
# print('Adjointness test (should be small): ', e)
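
Calling the class above evaluates the TGV proximal operator at y, so it can be used directly as a denoiser. A minimal usage sketch (the noise level and image size are illustrative; the import path follows the wheel layout listed above, and `ths` scales the two regularisation weights as lambda1 = 0.1*ths and lambda2 = 0.15*ths, per the forward code):

import torch
from deepinv.models.tgv import TGV

# noisy test image of shape (batch, channels, height, width)
x = torch.rand(1, 1, 64, 64)
y = x + 0.1 * torch.randn_like(x)

denoiser = TGV(verbose=False, n_it_max=500)
x_hat = denoiser(y, ths=0.1)  # one prox evaluation; warm-starts on repeated calls
print(x_hat.shape)  # torch.Size([1, 1, 64, 64])

Note that the state variables x2, r2, u2 persist between calls, so a second call on a same-shaped input resumes from the previous solution rather than restarting from y.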
deepinv/models/tv.py ADDED
@@ -0,0 +1,146 @@
import torch
import torch.nn as nn


class TV(nn.Module):
    r"""
    Proximal operator of the isotropic Total Variation operator.

    This algorithm converges to the unique image :math:`x` that is the solution of

    .. math::

        \underset{x}{\arg\min} \; \frac{1}{2}\|x-y\|_2^2 + \lambda \|Dx\|_{1,2},

    where :math:`D` maps an image to its gradient field.

    The problem is solved with an over-relaxed Chambolle-Pock algorithm (see L. Condat, "A primal-dual splitting method
    for convex optimization involving Lipschitzian, proximable and linear composite terms", J. Optimization Theory and
    Applications, vol. 158, no. 2, pp. 460-479, 2013).

    Code (and description) adapted from Laurent Condat's matlab version (https://lcondat.github.io/software.html) and
    Daniil Smolyakov's `code <https://github.com/RoundedGlint585/TGVDenoising/blob/master/TGV%20WithoutHist.ipynb>`_.

    :param bool verbose: Whether to print computation details or not. Default: False.
    :param int n_it_max: Maximum number of iterations. Default: 1000.
    :param float crit: Convergence criterion. Default: 1e-5.
    :param torch.Tensor, None x2: Primary variable. Default: None.
    :param torch.Tensor, None u2: Dual variable. Default: None.
    """

    def __init__(
        self,
        verbose=False,
        n_it_max=1000,
        crit=1e-5,
        x2=None,
        u2=None,
    ):
        super(TV, self).__init__()

        self.verbose = verbose
        self.n_it_max = n_it_max
        self.crit = crit
        self.restart = True

        self.tau = 0.01  # step size, > 0
        self.rho = 1.99  # relaxation parameter, in [1, 2)
        self.sigma = 1 / self.tau / 72  # dual step size

        self.x2 = x2
        self.u2 = u2

        self.has_converged = False

    def prox_tau_fx(self, x, y):
        return (x + self.tau * y) / (1 + self.tau)

    def prox_sigma_g_conj(self, u, lambda2):
        return u / (
            torch.maximum(
                torch.sqrt(torch.sum(u**2, axis=-1)) / lambda2,
                torch.tensor([1], device=u.device).type(u.dtype),
            ).unsqueeze(-1)
        )

    def forward(self, y, ths=None):
        restart = self.restart or self.x2 is None or self.x2.shape != y.shape

        if restart:
            self.x2 = y.clone()
            self.u2 = torch.zeros((*self.x2.shape, 2), device=self.x2.device).type(
                self.x2.dtype
            )
            self.restart = False

        # the original code left `lambd` undefined when ths was None,
        # so we default the threshold to 1.0
        lambd = ths if ths is not None else 1.0

        for it in range(self.n_it_max):
            x_prev = self.x2.clone()

            x = self.prox_tau_fx(self.x2 - self.tau * nablaT(self.u2), y)
            u = self.prox_sigma_g_conj(
                self.u2 + self.sigma * nabla(2 * x - self.x2), lambd
            )
            self.x2 = self.x2 + self.rho * (x - self.x2)
            self.u2 = self.u2 + self.rho * (u - self.u2)

            rel_err = torch.linalg.norm(
                x_prev.flatten() - self.x2.flatten()
            ) / (torch.linalg.norm(self.x2.flatten()) + 1e-12)

            if it > 1 and rel_err < self.crit:
                self.has_converged = True
                if self.verbose:
                    print("TV prox reached convergence")
                break

            if it % 100 == 0 and self.verbose:
                primalcost = 0.5 * torch.linalg.norm(
                    self.x2.flatten() - y.flatten()
                ) ** 2 + lambd * torch.sum(
                    torch.sqrt(torch.sum(nabla(self.x2) ** 2, axis=-1))
                )
                dualcost = (y**2).sum() / 2 - torch.sum(
                    (y - nablaT(self.u2)) ** 2
                ) / 2.0
                print(
                    "Iter ", it,
                    "primal cost :", primalcost.item(),
                    "dual cost :", dualcost.item(),
                )

        return self.x2


def nabla(I):
    b, c, h, w = I.shape
    G = torch.zeros((b, c, h, w, 2), device=I.device).type(I.dtype)
    G[:, :, :-1, :, 0] = G[:, :, :-1, :, 0] - I[:, :, :-1]
    G[:, :, :-1, :, 0] = G[:, :, :-1, :, 0] + I[:, :, 1:]
    G[:, :, :, :-1, 1] = G[:, :, :, :-1, 1] - I[..., :-1]
    G[:, :, :, :-1, 1] = G[:, :, :, :-1, 1] + I[..., 1:]
    return G


def nablaT(G):
    b, c, h, w = G.shape[:-1]
    # note that we just reversed left and right sides of each line
    # to obtain the transposed operator
    I = torch.zeros((b, c, h, w), device=G.device).type(G.dtype)
    I[:, :, :-1] = I[:, :, :-1] - G[:, :, :-1, :, 0]
    I[:, :, 1:] = I[:, :, 1:] + G[:, :, :-1, :, 0]
    I[..., :-1] = I[..., :-1] - G[..., :-1, 1]
    I[..., 1:] = I[..., 1:] + G[..., :-1, 1]
    return I


# # ADJOINTNESS TEST
# u = torch.randn((4, 3, 100, 100)).type(torch.DoubleTensor)
# Au = nabla(u)
# v = torch.randn(*Au.shape).type(Au.dtype)
# Atv = nablaT(v)
# e = v.flatten() @ Au.flatten() - Atv.flatten() @ u.flatten()
# print('Adjointness test (should be small): ', e)
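
The commented-out ADJOINTNESS TEST blocks in both files verify that nablaT is the adjoint of nabla (and epsilonT of epsilon), i.e. that <nabla(u), v> = <u, nablaT(v)> for random u and v, which is what the Chambolle-Pock iterations above rely on. A self-contained, runnable version of the check for the TV operators (double precision keeps the residual near machine epsilon):

import torch
from deepinv.models.tv import nabla, nablaT

u = torch.randn(4, 3, 100, 100, dtype=torch.float64)
Au = nabla(u)    # forward-difference gradient field, shape (4, 3, 100, 100, 2)
v = torch.randn_like(Au)
Atv = nablaT(v)  # transposed operator

# <Au, v> - <u, Atv> should vanish up to floating-point error
e = (Au.flatten() @ v.flatten() - u.flatten() @ Atv.flatten()).item()
print(f"Adjointness residual (should be ~1e-12): {e:.2e}")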