deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/models/tgv.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TGV(nn.Module):
    r"""
    Proximal operator of (2nd order) Total Generalised Variation operator.

    (see K. Bredies, K. Kunisch, and T. Pock, "Total generalized variation," SIAM J. Imaging Sci., 3(3), 492-526, 2010.)

    This algorithm converges to the unique image :math:`x` (and the auxiliary vector field :math:`r`) minimizing

    .. math::

        \underset{x, r}{\arg\min} \; \frac{1}{2}\|x-y\|_2^2 + \lambda_1 \|r\|_{1,2} + \lambda_2 \|J(Dx-r)\|_{1,F}

    where :math:`D` maps an image to its gradient field and :math:`J` maps a vector field to its Jacobian.
    For a large value of :math:`\lambda_2`, the TGV behaves like the TV.
    For a small value, it behaves like the :math:`\ell_1`-Frobenius norm of the Hessian.

    The problem is solved with an over-relaxed Chambolle-Pock algorithm (see L. Condat, "A primal-dual splitting method
    for convex optimization involving Lipschitzian, proximable and linear composite terms", J. Optimization Theory and
    Applications, vol. 158, no. 2, pp. 460-479, 2013).

    Code (and description) adapted from Laurent Condat's matlab version (https://lcondat.github.io/software.html) and
    Daniil Smolyakov's `code <https://github.com/RoundedGlint585/TGVDenoising/blob/master/TGV%20WithoutHist.ipynb>`_.

    The iterates ``x2``, ``r2`` and ``u2`` are stored on the module. The first call to :meth:`forward` always
    re-initialises them (``restart`` starts as ``True``); later calls with an input of the same shape continue from
    the stored iterates. Set ``restart = True`` to force a fresh start.

    :param bool verbose: Whether to print computation details or not. Default: False.
    :param int n_it_max: Maximum number of iterations. Default: 1000.
    :param float crit: Convergence criterion (threshold on the relative change of the iterate). Default: 1e-5.
    :param torch.tensor, None x2: Primary variable. Default: None.
    :param torch.tensor, None u2: Dual variable. Default: None.
    :param torch.tensor, None r2: Auxiliary variable. Default: None.
    """

    def __init__(
        self, verbose=False, n_it_max=1000, crit=1e-5, x2=None, u2=None, r2=None
    ):
        super(TGV, self).__init__()

        self.verbose = verbose
        self.n_it_max = n_it_max
        self.crit = crit
        self.restart = True

        self.tau = 0.01  # > 0

        self.rho = 1.99  # in 1,2
        self.sigma = 1 / self.tau / 72

        self.x2 = x2
        self.r2 = r2
        self.u2 = u2

        # True only if the most recent call to forward met the stopping criterion.
        self.has_converged = False

    def prox_tau_fx(self, x, y):
        r"""Return :math:`(x + \tau y) / (1 + \tau)`, the primal step for the quadratic data term."""
        return (x + self.tau * y) / (1 + self.tau)

    def prox_tau_fr(self, r, lambda1):
        r"""
        Shrink every last-axis vector of ``r`` towards zero.

        Vectors whose Euclidean norm is at most ``tau * lambda1`` are mapped to zero; the others have their norm
        reduced by ``tau * lambda1``.
        """
        left = torch.sqrt(torch.sum(r**2, axis=-1)) / (self.tau * lambda1)
        one = torch.tensor([1], device=left.device).type(left.dtype)
        return r - r / torch.maximum(left, one).unsqueeze(-1)

    def prox_sigma_g_conj(self, u, lambda2):
        r"""
        Rescale every last-axis vector of ``u`` so that its Euclidean norm does not exceed ``lambda2``.

        Vectors already inside the ball are returned unchanged.
        """
        norms = torch.sqrt(torch.sum(u**2, axis=-1)) / lambda2
        one = torch.tensor([1], device=u.device).type(u.dtype)
        return u / torch.maximum(norms, one).unsqueeze(-1)

    def forward(self, y, ths=None):
        r"""
        Run the primal-dual iterations on ``y`` and return the primal iterate.

        :param torch.Tensor y: Input image. The difference operators used here unpack its shape as
            ``(b, c, h, w)``, so a 4-D tensor is required.
        :param float, torch.Tensor ths: Regularisation strength; the two weights are set to
            :math:`\lambda_1 = 0.1\,\text{ths}` and :math:`\lambda_2 = 0.15\,\text{ths}`. Required.
        :return: ``self.x2`` (the stored primal iterate itself, not a copy).
        :raises ValueError: if ``ths`` is None.
        """
        # Previously a missing `ths` surfaced later as an UnboundLocalError on lambda1.
        if ths is None:
            raise ValueError(
                "TGV.forward requires a regularisation parameter `ths`, got None."
            )

        if self.restart or self.x2 is None or self.x2.shape != y.shape:
            self.x2 = y.clone()
            self.r2 = torch.zeros((*self.x2.shape, 2), device=self.x2.device).type(
                self.x2.dtype
            )
            self.u2 = torch.zeros((*self.x2.shape, 4), device=self.x2.device).type(
                self.x2.dtype
            )
            self.restart = False

        lambda1 = ths * 0.1
        lambda2 = ths * 0.15

        # The flag describes this call only; do not leak the status of a previous call.
        self.has_converged = False

        # Only needed for the diagnostics printed in verbose mode.
        cy = (y**2).sum() / 2 if self.verbose else None
        primalcostlowerbound = 0

        for it in range(self.n_it_max):
            x_prev = self.x2.clone()
            tmp = self.tau * epsilonT(self.u2)
            x = self.prox_tau_fx(self.x2 - nablaT(tmp), y)
            r = self.prox_tau_fr(self.r2 + tmp, lambda1)
            u = self.prox_sigma_g_conj(
                self.u2
                + self.sigma * epsilon(nabla(2 * x - self.x2) - (2 * r - self.r2)),
                lambda2,
            )
            # Over-relaxation step with parameter rho.
            self.x2 = self.x2 + self.rho * (x - self.x2)
            self.r2 = self.r2 + self.rho * (r - self.r2)
            self.u2 = self.u2 + self.rho * (u - self.u2)

            rel_err = torch.linalg.norm(
                x_prev.flatten() - self.x2.flatten()
            ) / torch.linalg.norm(self.x2.flatten() + 1e-12)

            if it > 1 and rel_err < self.crit:
                self.has_converged = True
                if self.verbose:
                    print("TGV prox reached convergence")
                break

            if self.verbose and it % 100 == 0:
                # Value of the documented objective at (x, r), including the 1/2 on the data term.
                primalcost = (
                    0.5 * torch.linalg.norm(x.flatten() - y.flatten()) ** 2
                    + lambda1 * torch.sum(torch.sqrt(torch.sum(r**2, axis=-1)))
                    + lambda2
                    * torch.sum(
                        torch.sqrt(torch.sum(epsilon(nabla(x) - r) ** 2, axis=-1))
                    )
                )
                # To check feasibility: the value will be <= lambda1 only at convergence. Since u is not
                # feasible, the dual cost is not reliable: the gap = primalcost - dualcost can be < 0 and
                # cannot be used as stopping criterion.
                max_norm = torch.max(torch.sqrt(torch.sum(epsilonT(u) ** 2, axis=-1)))
                # u3 is a scaled version of u, which is feasible, so its dual cost is a valid, but very
                # rough, lower bound of the primal cost.
                u3 = u / torch.maximum(
                    max_norm / lambda1,
                    torch.tensor([1], device=max_norm.device).type(max_norm.dtype),
                )
                dualcost2 = cy - torch.sum((y - nablaT(epsilonT(u3))) ** 2) / 2.0
                # Keep (and display) the best lower bound computed so far.
                primalcostlowerbound = max(primalcostlowerbound, dualcost2.item())
                print(
                    "Iter: ",
                    it,
                    " Primal cost: ",
                    primalcost.item(),
                    " Lower bound: ",
                    primalcostlowerbound,
                    " Rel err:",
                    rel_err,
                )
        else:
            # Loop ended without meeting the stopping criterion.
            if self.verbose:
                print(
                    "The algorithm did not converge, stopped after "
                    + str(self.n_it_max)
                    + " iterations."
                )

        return self.x2
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def nabla(I):
    r"""
    Forward finite differences of a 4-D tensor along its last two axes.

    :param torch.Tensor I: tensor whose shape unpacks as ``(b, c, h, w)``.
    :return: tensor of shape ``(b, c, h, w, 2)`` with the same dtype and device. Component 0 holds
        ``I[i + 1, j] - I[i, j]`` (zero on the last row), component 1 holds ``I[i, j + 1] - I[i, j]``
        (zero on the last column).
    """
    b, c, h, w = I.shape
    grad = torch.zeros((b, c, h, w, 2), dtype=I.dtype, device=I.device)
    # Difference along axis 2; the last row keeps its zero initialisation.
    grad[:, :, :-1, :, 0] = I[:, :, 1:] - I[:, :, :-1]
    # Difference along axis 3; the last column keeps its zero initialisation.
    grad[:, :, :, :-1, 1] = I[..., 1:] - I[..., :-1]
    return grad
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def nablaT(G):
    r"""
    Transpose of :func:`nabla`: map a ``(b, c, h, w, 2)`` field back to a ``(b, c, h, w)`` tensor.

    Each finite difference taken by ``nabla`` is scattered back with the opposite roles of source and
    destination, which yields the transposed operator.

    :param torch.Tensor G: tensor whose leading shape unpacks as ``(b, c, h, w)`` and whose last axis
        holds (at least) the two difference components.
    :return: tensor of shape ``(b, c, h, w)`` with the same dtype and device as ``G``.
    """
    b, c, h, w = G.shape[:-1]
    out = torch.zeros((b, c, h, w), dtype=G.dtype, device=G.device)
    along_rows = G[:, :, :-1, :, 0]
    along_cols = G[..., :-1, 1]
    # Accumulate in the same order as the forward operator wrote its terms.
    out[:, :, :-1].sub_(along_rows)
    out[:, :, 1:].add_(along_rows)
    out[..., :-1].sub_(along_cols)
    out[..., 1:].add_(along_cols)
    return out
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
# # ADJOINTNESS TEST
|
|
190
|
+
# u = torch.randn((4, 3, 100,100)).type(torch.DoubleTensor)
|
|
191
|
+
# Au = nabla(u)
|
|
192
|
+
# v = torch.randn(*Au.shape).type(Au.dtype)
|
|
193
|
+
# Atv = nablaT(v)
|
|
194
|
+
# e = v.flatten()@Au.flatten()-Atv.flatten()@u.flatten()
|
|
195
|
+
# print('Adjointness test (should be small): ', e)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def epsilon(I):  # Simplified
    r"""
    Finite differences of a two-component field, producing four components per pixel.

    With ``p = I[..., 0]`` and ``q = I[..., 1]`` the output components are:

    * 0: backward difference of ``p`` along axis 2 (first row is ``p`` itself),
    * 1: backward difference of ``p`` along axis 3 (first column is zero),
    * 2: backward difference of ``q`` along axis 3 (first column is ``q`` itself),
    * 3: forward difference of ``q`` along axis 2 (last row is zero).

    :param torch.Tensor I: tensor whose shape unpacks as ``(b, c, h, w, k)``; components 0 and 1 of the last axis are read.
    :return: tensor of shape ``(b, c, h, w, 4)`` with the same dtype and device as ``I``.
    """
    b, c, h, w, _ = I.shape
    p = I[..., 0]
    q = I[..., 1]
    out = torch.zeros((b, c, h, w, 4), dtype=I.dtype, device=I.device)

    out[:, :, :1, :, 0] = p[:, :, :1]
    out[:, :, 1:, :, 0] = p[:, :, 1:] - p[:, :, :-1]

    out[..., 1:, 1] = p[..., 1:] - p[..., :-1]

    out[..., :1, 2] = q[..., :1]
    out[..., 1:, 2] = q[..., 1:] - q[..., :-1]

    out[:, :, :-1, :, 3] = q[:, :, 1:] - q[:, :, :-1]
    return out
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def epsilonT(G):
    r"""
    Transpose of :func:`epsilon`: map a ``(b, c, h, w, 4)`` tensor back to a two-component field.

    Every difference written by ``epsilon`` is scattered back with source and destination swapped.

    :param torch.Tensor G: tensor whose shape unpacks as ``(b, c, h, w, k)``; components 0 to 3 of the last axis are read.
    :return: tensor of shape ``(b, c, h, w, 2)`` with the same dtype and device as ``G``.
    """
    b, c, h, w, _ = G.shape
    out = torch.zeros((b, c, h, w, 2), dtype=G.dtype, device=G.device)
    # Views on the two output components; the in-place updates below write through to `out`.
    p = out[..., 0]
    q = out[..., 1]
    g0, g1, g2, g3 = G[..., 0], G[..., 1], G[..., 2], G[..., 3]

    # Contributions to component 0 (same accumulation order as the forward operator's terms).
    p[:, :, :-1].sub_(g0[:, :, 1:])
    p.add_(g0)
    p[..., :-1].sub_(g1[..., 1:])
    p[..., 1:].add_(g1[..., 1:])

    # Contributions to component 1.
    q[..., :-1].sub_(g2[..., 1:])
    q.add_(g2)
    q[:, :, :-1].sub_(g3[:, :, :-1])
    q[:, :, 1:].add_(g3[:, :, :-1])
    return out
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
# # ADJOINTNESS TEST
|
|
227
|
+
# u = torch.randn((2, 3,100,100,2)).type(torch.DoubleTensor)
|
|
228
|
+
# Au = epsilon(u)
|
|
229
|
+
# v = torch.randn(*Au.shape).type(Au.dtype)
|
|
230
|
+
# Atv = epsilonT(v)
|
|
231
|
+
# e = v.flatten()@Au.flatten()-Atv.flatten()@u.flatten()
|
|
232
|
+
# print('Adjointness test (should be small): ', e)
|
deepinv/models/tv.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TV(nn.Module):
    r"""
    Proximal operator of the isotropic Total Variation operator.

    This algorithm converges to the unique image :math:`x` that is the solution of

    .. math::

        \underset{x}{\arg\min} \; \frac{1}{2}\|x-y\|_2^2 + \lambda \|Dx\|_{1,2},

    where :math:`D` maps an image to its gradient field.

    The problem is solved with an over-relaxed Chambolle-Pock algorithm (see L. Condat, "A primal-dual splitting method
    for convex optimization involving Lipschitzian, proximable and linear composite terms", J. Optimization Theory and
    Applications, vol. 158, no. 2, pp. 460-479, 2013).

    Code (and description) adapted from Laurent Condat's matlab version (https://lcondat.github.io/software.html) and
    Daniil Smolyakov's `code <https://github.com/RoundedGlint585/TGVDenoising/blob/master/TGV%20WithoutHist.ipynb>`_.

    The iterates ``x2`` and ``u2`` are stored on the module. The first call to :meth:`forward` always re-initialises
    them (``restart`` starts as ``True``); later calls with an input of the same shape continue from the stored
    iterates. Set ``restart = True`` to force a fresh start.

    :param bool verbose: Whether to print computation details or not. Default: False.
    :param int n_it_max: Maximum number of iterations. Default: 1000.
    :param float crit: Convergence criterion (threshold on the relative change of the iterate). Default: 1e-5.
    :param torch.tensor, None x2: Primary variable. Default: None.
    :param torch.tensor, None u2: Dual variable. Default: None.
    """

    def __init__(
        self,
        verbose=False,
        n_it_max=1000,
        crit=1e-5,
        x2=None,
        u2=None,
    ):
        super(TV, self).__init__()

        self.verbose = verbose
        self.n_it_max = n_it_max
        self.crit = crit
        self.restart = True

        self.tau = 0.01  # > 0

        self.rho = 1.99  # in 1,2
        self.sigma = 1 / self.tau / 72

        self.x2 = x2
        self.u2 = u2

        # True only if the most recent call to forward met the stopping criterion.
        self.has_converged = False

    def prox_tau_fx(self, x, y):
        r"""Return :math:`(x + \tau y) / (1 + \tau)`, the primal step for the quadratic data term."""
        return (x + self.tau * y) / (1 + self.tau)

    def prox_sigma_g_conj(self, u, lambda2):
        r"""
        Rescale every last-axis vector of ``u`` so that its Euclidean norm does not exceed ``lambda2``.

        Vectors already inside the ball are returned unchanged.
        """
        norms = torch.sqrt(torch.sum(u**2, axis=-1)) / lambda2
        one = torch.tensor([1], device=u.device).type(u.dtype)
        return u / torch.maximum(norms, one).unsqueeze(-1)

    def forward(self, y, ths=None):
        r"""
        Run the primal-dual iterations on ``y`` and return the primal iterate.

        :param torch.Tensor y: Input image. The difference operators used here unpack its shape as
            ``(b, c, h, w)``, so a 4-D tensor is required.
        :param float, torch.Tensor ths: Regularisation parameter :math:`\lambda`. Required.
        :return: ``self.x2`` (the stored primal iterate itself, not a copy).
        :raises ValueError: if ``ths`` is None.
        """
        # Previously a missing `ths` surfaced later as an UnboundLocalError on lambd.
        if ths is None:
            raise ValueError(
                "TV.forward requires a regularisation parameter `ths`, got None."
            )

        if self.restart or self.x2 is None or self.x2.shape != y.shape:
            self.x2 = y.clone()
            self.u2 = torch.zeros((*self.x2.shape, 2), device=self.x2.device).type(
                self.x2.dtype
            )
            self.restart = False

        lambd = ths

        # The flag describes this call only; it was never updated before.
        self.has_converged = False

        for it in range(self.n_it_max):
            x_prev = self.x2.clone()

            x = self.prox_tau_fx(self.x2 - self.tau * nablaT(self.u2), y)
            u = self.prox_sigma_g_conj(
                self.u2 + self.sigma * nabla(2 * x - self.x2), lambd
            )
            # Over-relaxation step with parameter rho.
            self.x2 = self.x2 + self.rho * (x - self.x2)
            self.u2 = self.u2 + self.rho * (u - self.u2)

            rel_err = torch.linalg.norm(
                x_prev.flatten() - self.x2.flatten()
            ) / torch.linalg.norm(self.x2.flatten() + 1e-12)

            if it > 1 and rel_err < self.crit:
                self.has_converged = True
                if self.verbose:
                    print("TV prox reached convergence")
                break

            if it % 100 == 0 and self.verbose:
                # Value of the documented objective at the current (relaxed) iterate.
                primalcost = 0.5 * torch.linalg.norm(
                    self.x2.flatten() - y.flatten()
                ) ** 2 + lambd * torch.sum(
                    torch.sqrt(torch.sum(nabla(self.x2) ** 2, axis=-1))
                )
                print("Iter ", it, "primal cost :", primalcost.item())

        return self.x2
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def nabla(I):
    r"""
    Discrete gradient of a 4-D tensor, computed with forward differences.

    :param torch.Tensor I: tensor whose shape unpacks as ``(b, c, h, w)``.
    :return: tensor of shape ``(b, c, h, w, 2)`` with the dtype and device of ``I``; the last axis stores the
        difference along axis 2 (index 0) and along axis 3 (index 1). Entries on the last row (index 0) and
        last column (index 1) are zero.
    """
    b, c, h, w = I.shape
    field = torch.zeros((b, c, h, w, 2), dtype=I.dtype, device=I.device)
    d_rows = I[:, :, 1:] - I[:, :, :-1]
    d_cols = I[..., 1:] - I[..., :-1]
    field[:, :, :-1, :, 0] = d_rows
    field[:, :, :, :-1, 1] = d_cols
    return field
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def nablaT(G):
    r"""
    Transposed discrete gradient: map a ``(b, c, h, w, 2)`` field to a ``(b, c, h, w)`` tensor.

    The roles of source and destination in each line of :func:`nabla` are swapped to obtain the transposed
    operator.

    :param torch.Tensor G: tensor whose leading shape unpacks as ``(b, c, h, w)``; components 0 and 1 of the
        last axis are read.
    :return: tensor of shape ``(b, c, h, w)`` with the dtype and device of ``G``.
    """
    b, c, h, w = G.shape[:-1]
    image = torch.zeros((b, c, h, w), dtype=G.dtype, device=G.device)
    rows_part = G[:, :, :-1, :, 0]
    cols_part = G[..., :-1, 1]
    # Same accumulation order as the four scatter lines of the forward operator.
    image[:, :, :-1] -= rows_part
    image[:, :, 1:] += rows_part
    image[..., :-1] -= cols_part
    image[..., 1:] += cols_part
    return image
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# # ADJOINTNESS TEST
|
|
141
|
+
# u = torch.randn((4, 3, 100,100)).type(torch.DoubleTensor)
|
|
142
|
+
# Au = nabla(u)
|
|
143
|
+
# v = torch.randn(*Au.shape).type(Au.dtype)
|
|
144
|
+
# Atv = nablaT(v)
|
|
145
|
+
# e = v.flatten()@Au.flatten()-Atv.flatten()@u.flatten()
|
|
146
|
+
# print('Adjointness test (should be small): ', e)
|