deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/optim/optim_iterators/hqs.py
@@ -0,0 +1,74 @@
+from .optim_iterator import OptimIterator, fStep, gStep
+
+
+class HQSIteration(OptimIterator):
+    r"""
+    Single iteration of half-quadratic splitting.
+
+    Class for a single iteration of the Half-Quadratic Splitting (HQS) algorithm for minimising
+    :math:`\lambda f(x) + g(x)`. The iteration is given by
+
+    .. math::
+        \begin{equation*}
+        \begin{aligned}
+        u_{k} &= \operatorname{prox}_{\gamma \lambda f}(x_k) \\
+        x_{k+1} &= \operatorname{prox}_{\sigma g}(u_k),
+        \end{aligned}
+        \end{equation*}
+
+    where :math:`\gamma` and :math:`\sigma` are step-sizes. Note that this algorithm does not converge to
+    a minimizer of :math:`\lambda f(x) + g(x)`, but instead to a minimizer of
+    :math:`\gamma \lambda\, {}^1f + \sigma g`, where :math:`{}^1f` denotes the Moreau envelope of :math:`f`.
+    """
+
+    def __init__(self, **kwargs):
+        super(HQSIteration, self).__init__(**kwargs)
+        self.g_step = gStepHQS(**kwargs)
+        self.f_step = fStepHQS(**kwargs)
+        self.requires_prox_g = True
+
+
+class fStepHQS(fStep):
+    r"""
+    HQS fStep module.
+    """
+
+    def __init__(self, **kwargs):
+        super(fStepHQS, self).__init__(**kwargs)
+
+    def forward(self, x, cur_data_fidelity, cur_params, y, physics):
+        r"""
+        Single proximal step on the data-fidelity term :math:`f`.
+
+        :param torch.Tensor x: Current iterate :math:`x_k`.
+        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
+        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
+        :param torch.Tensor y: Input data.
+        :param deepinv.physics physics: Instance of the physics modeling the data-fidelity term.
+        """
+        return cur_data_fidelity.prox(
+            x, y, physics, gamma=cur_params["lambda"] * cur_params["stepsize"]
+        )
+
+
+class gStepHQS(gStep):
+    r"""
+    HQS gStep module.
+    """
+
+    def __init__(self, **kwargs):
+        super(gStepHQS, self).__init__(**kwargs)
+
+    def forward(self, x, cur_prior, cur_params):
+        r"""
+        Single proximal step on the prior term :math:`g`.
+
+        :param torch.Tensor x: Current iterate :math:`x_k`.
+        :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
+        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
+        """
+        return cur_prior.prox(x, cur_params["g_param"], gamma=cur_params["stepsize"])
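To make the two proximal steps concrete, the following standalone sketch (hypothetical code, not from the package) runs the HQS iteration on a toy denoising problem, where the L2 data-fidelity f(x) = 0.5*||x - y||^2 has the closed-form prox (x + gamma*lam*y)/(1 + gamma*lam) and the l1 prior's prox is soft-thresholding:

```python
import torch

# Toy HQS loop (sketch): f(x) = 0.5 * ||x - y||^2, g(x) = ||x||_1.
y = torch.randn(16)                # noisy observation
lam, gamma, sigma = 1.0, 1.0, 0.1  # lambda and the two step-sizes
x = torch.zeros_like(y)
for _ in range(50):
    u = (x + gamma * lam * y) / (1 + gamma * lam)              # prox step on f
    x = torch.sign(u) * torch.clamp(u.abs() - sigma, min=0.0)  # prox step on g
```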
deepinv/optim/optim_iterators/optim_iterator.py
@@ -0,0 +1,141 @@
+import torch
+import torch.nn as nn
+from deepinv.optim.data_fidelity import L2
+
+
+class OptimIterator(nn.Module):
+    r"""
+    Base class for all :meth:`Optim` iterators.
+
+    An optim iterator is an object that implements a fixed point iteration for minimizing the sum of two functions
+    :math:`F = \lambda f + g`, where :math:`f` is a data-fidelity term that will be modeled by an instance of physics
+    and :math:`g` is a regularizer. The fixed point iteration takes the form
+
+    .. math::
+        \qquad (x_{k+1}, z_{k+1}) = \operatorname{FixedPoint}(x_k, z_k, f, g, A, y, ...)
+
+    where :math:`x` is a "primal" variable converging to the solution of the minimization problem, and
+    :math:`z` is a "dual" variable.
+
+    .. note::
+        By an abuse of terminology, we call "primal" and "dual" the variables that are updated
+        at each step. They may correspond to the actual primal and dual variables of the underlying
+        optimization problem (for instance in the case of the PD algorithm), but not necessarily
+        (for instance in the case of the PGD algorithm).
+
+    The implementation of the fixed point algorithm in :meth:`deepinv.optim` is split into two steps, alternating between
+    a step on :math:`f` and a step on :math:`g`, that is for :math:`k=1,2,...`
+
+    .. math::
+        z_{k+1} = \operatorname{step}_f(x_k, z_k, y, A, ...)\\
+        x_{k+1} = \operatorname{step}_g(x_k, z_k, y, A, ...)
+
+    where :math:`\operatorname{step}_f` and :math:`\operatorname{step}_g` are the steps on :math:`f` and :math:`g` respectively.
+
+    :param bool g_first: If True, the algorithm starts with a step on :math:`g` and finishes with a step on :math:`f`.
+    :param F_fn: Function returning the value of :math:`F` at the current iterate. Default: None.
+    :param bool has_cost: If True, the value of :math:`F` is computed at each iteration. Default: False.
+    """
+
+    def __init__(self, g_first=False, F_fn=None, has_cost=False):
+        super(OptimIterator, self).__init__()
+        self.g_first = g_first
+        self.F_fn = F_fn
+        self.has_cost = has_cost
+        if self.F_fn is None:
+            self.has_cost = False
+        self.f_step = fStep(g_first=self.g_first)
+        self.g_step = gStep(g_first=self.g_first)
+        self.requires_grad_g = False
+        self.requires_prox_g = False
+
+    def relaxation_step(self, u, v, beta):
+        r"""
+        Performs a relaxation step of the form :math:`\beta u + (1-\beta) v`.
+
+        :param torch.Tensor u: First tensor.
+        :param torch.Tensor v: Second tensor.
+        :param float beta: Relaxation parameter.
+        :return: Relaxed tensor.
+        """
+        return beta * u + (1 - beta) * v
+
+    def forward(self, X, cur_data_fidelity, cur_prior, cur_params, y, physics):
+        r"""
+        General form of a single iteration of splitting algorithms for minimizing :math:`F = \lambda f + g`,
+        alternating between a step on :math:`f` and a step on :math:`g`.
+        The primal and dual variables as well as the estimated cost at the current iterate are stored in a
+        dictionary ``X`` of the form ``{'est': (x, z), 'cost': F}``.
+
+        :param dict X: Dictionary containing the current iterate and the estimated cost.
+        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
+        :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
+        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
+        :param torch.Tensor y: Input data.
+        :param deepinv.physics physics: Instance of the physics modeling the observation.
+        :return: Dictionary ``{"est": (x, z), "cost": F}`` containing the updated current iterate and the estimated current cost.
+        """
+        x_prev = X["est"][0]
+        if not self.g_first:
+            z = self.f_step(x_prev, cur_data_fidelity, cur_params, y, physics)
+            x = self.g_step(z, cur_prior, cur_params)
+        else:
+            z = self.g_step(x_prev, cur_prior, cur_params)
+            x = self.f_step(z, cur_data_fidelity, cur_params, y, physics)
+        x = self.relaxation_step(x, x_prev, cur_params["beta"])
+        F = (
+            self.F_fn(x, cur_data_fidelity, cur_prior, cur_params, y, physics)
+            if self.has_cost
+            else None
+        )
+        return {"est": (x, z), "cost": F}
+
+
+class fStep(nn.Module):
+    r"""
+    Module for the single iteration steps on the data-fidelity term :math:`f`.
+
+    :param bool g_first: If True, the algorithm starts with a step on :math:`g` and finishes with a step on :math:`f`. Default: False.
+    :param kwargs: Additional keyword arguments.
+    """
+
+    def __init__(self, g_first=False, **kwargs):
+        super(fStep, self).__init__()
+        self.g_first = g_first
+
+    def forward(self, x, cur_data_fidelity, cur_params, y, physics):
+        r"""
+        Single iteration step on the data-fidelity term :math:`f`.
+
+        :param torch.Tensor x: Current iterate.
+        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
+        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
+        :param torch.Tensor y: Input data.
+        :param deepinv.physics physics: Instance of the physics modeling the observation.
+        """
+        pass
+
+
+class gStep(nn.Module):
+    r"""
+    Module for the single iteration steps on the prior term :math:`g`.
+
+    :param bool g_first: If True, the algorithm starts with a step on :math:`g` and finishes with a step on :math:`f`. Default: False.
+    :param kwargs: Additional keyword arguments.
+    """
+
+    def __init__(self, g_first=False, **kwargs):
+        super(gStep, self).__init__()
+        self.g_first = g_first
+
+    def forward(self, x, cur_prior, cur_params):
+        r"""
+        Single iteration step on the prior term :math:`g`.
+
+        :param torch.Tensor x: Current iterate.
+        :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
+        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
+        """
+        pass
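The ``{"est": (x, z), "cost": F}`` dictionary returned by ``forward`` is the contract consumed by the package's fixed-point driver (``deepinv/optim/fixed_point.py``). A minimal sketch of such a driver, assuming the ``iterator`` and its companion objects are built elsewhere, could look like:

```python
def fixed_point(iterator, x_init, data_fidelity, prior, params, y, physics,
                max_iter=100, tol=1e-5):
    # Repeatedly apply the OptimIterator until the primal variable stagnates.
    X = {"est": (x_init, x_init), "cost": None}
    for _ in range(max_iter):
        x_prev = X["est"][0]
        X = iterator(X, data_fidelity, prior, params, y, physics)
        if (X["est"][0] - x_prev).norm() < tol * (x_prev.norm() + 1e-12):
            break
    return X["est"][0]
```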
deepinv/optim/optim_iterators/pgd.py
@@ -0,0 +1,91 @@
+from .optim_iterator import OptimIterator, fStep, gStep
+from .utils import gradient_descent_step
+
+
+class PGDIteration(OptimIterator):
+    r"""
+    Iterator for proximal gradient descent.
+
+    Class for a single iteration of the Proximal Gradient Descent (PGD) algorithm for minimising
+    :math:`\lambda f(x) + g(x)`.
+
+    The iteration is given by
+
+    .. math::
+        \begin{equation*}
+        \begin{aligned}
+        u_{k} &= x_k - \lambda \gamma \nabla f(x_k) \\
+        x_{k+1} &= \operatorname{prox}_{\gamma g}(u_k),
+        \end{aligned}
+        \end{equation*}
+
+    where :math:`\gamma` is a step-size that should satisfy :math:`\lambda \gamma \leq 2/\operatorname{Lip}(\nabla f)`,
+    with :math:`\operatorname{Lip}(\nabla f)` the Lipschitz constant of :math:`\nabla f`.
+    """
+
+    def __init__(self, **kwargs):
+        super(PGDIteration, self).__init__(**kwargs)
+        self.g_step = gStepPGD(**kwargs)
+        self.f_step = fStepPGD(**kwargs)
+        if self.g_first:
+            self.requires_grad_g = True
+        else:
+            self.requires_prox_g = True
+
+
+class fStepPGD(fStep):
+    r"""
+    PGD fStep module.
+    """
+
+    def __init__(self, **kwargs):
+        super(fStepPGD, self).__init__(**kwargs)
+
+    def forward(self, x, cur_data_fidelity, cur_params, y, physics):
+        r"""
+        Single PGD iteration step on the data-fidelity term :math:`f`.
+
+        :param torch.Tensor x: Current iterate :math:`x_k`.
+        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
+        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
+        :param torch.Tensor y: Input data.
+        :param deepinv.physics physics: Instance of the physics modeling the data-fidelity term.
+        """
+        if not self.g_first:
+            # Standard ordering: gradient step on f.
+            grad = (
+                cur_params["lambda"]
+                * cur_params["stepsize"]
+                * cur_data_fidelity.grad(x, y, physics)
+            )
+            return gradient_descent_step(x, grad)
+        else:
+            # Reversed ordering: proximal step on f instead.
+            return cur_data_fidelity.prox(
+                x, y, physics, gamma=cur_params["lambda"] * cur_params["stepsize"]
+            )
+
+
+class gStepPGD(gStep):
+    r"""
+    PGD gStep module.
+    """
+
+    def __init__(self, **kwargs):
+        super(gStepPGD, self).__init__(**kwargs)
+
+    def forward(self, x, cur_prior, cur_params):
+        r"""
+        Single iteration step on the prior term :math:`g`.
+
+        :param torch.Tensor x: Current iterate :math:`x_k`.
+        :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
+        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
+        """
+        if not self.g_first:
+            # Standard ordering: proximal step on g.
+            return cur_prior.prox(
+                x, cur_params["g_param"], gamma=cur_params["stepsize"]
+            )
+        else:
+            # Reversed ordering: gradient step on g instead.
+            grad = cur_params["stepsize"] * cur_prior.grad(x, cur_params["g_param"])
+            return gradient_descent_step(x, grad)
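To see the step-size condition at work, here is a standalone toy sketch (hypothetical code, not from the package) of PGD on a least-squares problem with an l1 prior, where Lip(grad f) = ||A||_2^2 is computed explicitly:

```python
import torch

torch.manual_seed(0)
A = torch.randn(20, 50)
y = torch.randn(20)
lam = 1.0
L = torch.linalg.matrix_norm(A, ord=2) ** 2  # Lip(grad f) for f = 0.5*||Ax - y||^2
gamma = 1.0 / (lam * L)                      # safely below 2 / (lam * Lip(grad f))
x = torch.zeros(50)
for _ in range(200):
    u = x - lam * gamma * (A.T @ (A @ x - y))                        # gradient step on f
    x = torch.sign(u) * torch.clamp(u.abs() - gamma * 0.1, min=0.0)  # prox of 0.1*||.||_1
```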
deepinv/optim/optim_iterators/primal_dual_CP.py
@@ -0,0 +1,145 @@
+import torch
+
+from .optim_iterator import OptimIterator, fStep, gStep
+
+
+class CPIteration(OptimIterator):
+    r"""
+    Iterator for Chambolle-Pock.
+
+    Class for a single iteration of the `Chambolle-Pock <https://hal.science/hal-00490826/document>`_ Primal-Dual (PD)
+    algorithm for minimising :math:`\lambda F(Kx) + G(x)` or :math:`\lambda F(x) + G(Kx)` for generic functions :math:`F` and :math:`G`.
+    Our implementation corresponds to Algorithm 1 of `<https://hal.science/hal-00490826/document>`_.
+
+    If the attribute ``g_first`` is set to ``False`` (the default), the iteration is given by
+
+    .. math::
+        \begin{equation*}
+        \begin{aligned}
+        u_{k+1} &= \operatorname{prox}_{\sigma (\lambda F)^*}(u_k + \sigma K z_k) \\
+        x_{k+1} &= \operatorname{prox}_{\tau G}(x_k - \tau K^\top u_{k+1}) \\
+        z_{k+1} &= x_{k+1} + \beta(x_{k+1} - x_k) \\
+        \end{aligned}
+        \end{equation*}
+
+    where :math:`(\lambda F)^*` is the Fenchel-Legendre conjugate of :math:`\lambda F`, :math:`\beta > 0` is a
+    relaxation parameter, and :math:`\sigma` and :math:`\tau` are step-sizes that should satisfy
+    :math:`\sigma \tau \|K\|^2 \leq 1`.
+
+    If the attribute ``g_first`` is set to ``True``, the roles of :math:`F` and :math:`G` are swapped in the above iteration.
+
+    In particular, setting :math:`F = \distancename`, :math:`K = A` and :math:`G = \regname`, the above algorithm solves
+
+    .. math::
+
+        \begin{equation*}
+        \underset{x}{\operatorname{min}} \,\, \lambda \distancename(Ax, y) + \regname(x)
+        \end{equation*}
+
+    with a splitting on :math:`\distancename`, with no differentiability assumption needed on :math:`\distancename`
+    or :math:`\regname`, nor any invertibility assumption on :math:`A`.
+
+    Note that the algorithm requires an initialization of the three variables :math:`x_0`, :math:`z_0` and :math:`u_0`.
+    """
+
+    def __init__(self, **kwargs):
+        super(CPIteration, self).__init__(**kwargs)
+        self.g_step = gStepCP(**kwargs)
+        self.f_step = fStepCP(**kwargs)
+
+    def forward(self, X, cur_data_fidelity, cur_prior, cur_params, y, physics):
+        r"""
+        Single iteration of the Chambolle-Pock algorithm.
+
+        :param dict X: Dictionary containing the current iterate and the estimated cost.
+        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
+        :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
+        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
+        :param torch.Tensor y: Input data.
+        :param deepinv.physics physics: Instance of the physics modeling the data-fidelity term.
+        :return: Dictionary ``{"est": (x, z, u), "cost": F}`` containing the updated current iterate and the estimated current cost.
+        """
+        x_prev, z_prev, u_prev = X["est"]  # x: primal, z: relaxed primal, u: dual
+        # The linear operator K and its adjoint default to the identity.
+        K = lambda x: cur_params["K"](x) if "K" in cur_params else x
+        K_adjoint = (
+            lambda x: cur_params["K_adjoint"](x) if "K_adjoint" in cur_params else x
+        )
+        if self.g_first:
+            u = self.g_step(u_prev, K(z_prev), cur_prior, cur_params)
+            x = self.f_step(
+                x_prev, K_adjoint(u), cur_data_fidelity, y, physics, cur_params
+            )
+        else:
+            u = self.f_step(
+                u_prev, K(z_prev), cur_data_fidelity, y, physics, cur_params
+            )
+            x = self.g_step(x_prev, K_adjoint(u), cur_prior, cur_params)
+        z = x + cur_params["beta"] * (x - x_prev)
+        F = (
+            self.F_fn(x, cur_data_fidelity, cur_prior, cur_params, y, physics)
+            if self.has_cost
+            else None
+        )
+        return {"est": (x, z, u), "cost": F}
+
+
+class fStepCP(fStep):
+    r"""
+    Chambolle-Pock fStep module.
+    """
+
+    def __init__(self, **kwargs):
+        super(fStepCP, self).__init__(**kwargs)
+
+    def forward(self, x, w, cur_data_fidelity, y, physics, cur_params):
+        r"""
+        Single Chambolle-Pock iteration step on the data-fidelity term :math:`\lambda f`.
+
+        :param torch.Tensor x: Current first variable: :math:`x` if ``g_first`` is ``True``, :math:`u` otherwise.
+        :param torch.Tensor w: Current second variable: :math:`A^\top u` if ``g_first`` is ``True``, :math:`A z` otherwise.
+        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
+        :param torch.Tensor y: Input data.
+        :param deepinv.physics physics: Instance of the physics modeling the data-fidelity term.
+        :param dict cur_params: Dictionary containing the current fStep parameters (keys ``"stepsize_dual"`` (or ``"stepsize"``) and ``"lambda"``).
+        """
+        if self.g_first:
+            # Primal proximal step on lambda*f.
+            p = x - cur_params["stepsize"] * w
+            return cur_data_fidelity.prox(
+                p, y, physics, gamma=cur_params["stepsize"] * cur_params["lambda"]
+            )
+        else:
+            # Dual proximal step on the conjugate of lambda*f.
+            p = x + cur_params["stepsize_dual"] * w
+            return cur_data_fidelity.prox_d_conjugate(
+                p, y, gamma=cur_params["stepsize_dual"], lamb=cur_params["lambda"]
+            )
+
+
+class gStepCP(gStep):
+    r"""
+    Chambolle-Pock gStep module.
+    """
+
+    def __init__(self, **kwargs):
+        super(gStepCP, self).__init__(**kwargs)
+
+    def forward(self, x, w, cur_prior, cur_params):
+        r"""
+        Single Chambolle-Pock iteration step on the prior term :math:`g`.
+
+        :param torch.Tensor x: Current first variable: :math:`u` if ``g_first`` is ``True``, :math:`x` otherwise.
+        :param torch.Tensor w: Current second variable: :math:`A z` if ``g_first`` is ``True``, :math:`A^\top u` otherwise.
+        :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
+        :param dict cur_params: Dictionary containing the current gStep parameters (keys ``"stepsize"`` (or ``"stepsize_dual"``) and ``"g_param"``).
+        """
+        if self.g_first:
+            # Dual proximal step on the conjugate of g.
+            p = x + cur_params["stepsize_dual"] * w
+            return cur_prior.prox_conjugate(
+                p, cur_params["g_param"], gamma=cur_params["stepsize_dual"]
+            )
+        else:
+            # Primal proximal step on g.
+            p = x - cur_params["stepsize"] * w
+            return cur_prior.prox(
+                p, cur_params["g_param"], gamma=cur_params["stepsize"]
+            )
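The conjugate proximal operators used above (``prox_d_conjugate`` on the data-fidelity and ``prox_conjugate`` on the prior) are typically obtained from the prox of the function itself via the Moreau identity, prox_{sigma*F*}(p) = p - sigma * prox_{F/sigma}(p/sigma). A standalone sketch (the actual deepinv signatures may differ):

```python
import torch

def prox_of_conjugate(p, sigma, prox_f):
    # Moreau identity: prox_f(v, gamma) must return prox_{gamma*F}(v).
    return p - sigma * prox_f(p / sigma, 1.0 / sigma)

# Example with F = ||.||_1: its conjugate is the indicator of the l-infinity
# unit ball, so the result is the projection of p onto [-1, 1] for any sigma > 0.
soft = lambda v, g: torch.sign(v) * torch.clamp(v.abs() - g, min=0.0)
p = torch.tensor([-2.0, 0.3, 1.5])
print(prox_of_conjugate(p, 0.7, soft))  # tensor([-1.0000,  0.3000,  1.0000])
```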
deepinv/optim/optim_iterators/utils.py
@@ -0,0 +1,17 @@
+def gradient_descent_step(x, grad, bregman_potential="L2"):
+    r"""
+    Performs a single (mirror) gradient descent step for the given Bregman potential.
+
+    For the default ``"L2"`` potential this is the standard update :math:`x - \mathrm{grad}`;
+    for the Burg entropy potential :math:`h(x) = -\log(x)` it is :math:`x/(1 + x\,\mathrm{grad})`.
+
+    :param torch.Tensor x: Current iterate.
+    :param torch.Tensor grad: Gradient to apply, already scaled by the step-size.
+    :param str bregman_potential: Bregman potential used in the Bregman divergence, either ``"L2"`` or ``"Burg_entropy"``.
+    """
+    if bregman_potential == "L2":
+        grad_step = x - grad
+    elif bregman_potential == "Burg_entropy":
+        grad_step = x / (1 + x * grad)
+    else:
+        raise ValueError(
+            f"Gradient descent with Bregman potential {bregman_potential} not implemented"
+        )
+    return grad_step
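As a sanity check of the ``Burg_entropy`` branch: with potential h(x) = -log(x), the mirror map is h'(x) = -1/x, and solving h'(x_next) = h'(x) - grad for x_next gives exactly x / (1 + x * grad). A minimal standalone verification:

```python
import torch

x = torch.rand(5) + 0.1     # Burg entropy requires positive iterates
grad = torch.rand(5) * 0.1
x_next = x / (1 + x * grad)
# The mirror-descent optimality condition holds exactly:
assert torch.allclose(-1.0 / x_next, -1.0 / x - grad)
```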