deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,563 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import warnings
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
from deepinv.optim.fixed_point import FixedPoint
|
|
6
|
+
from collections.abc import Iterable
|
|
7
|
+
from deepinv.utils import cal_psnr
|
|
8
|
+
from deepinv.optim.optim_iterators import *
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseOptim(nn.Module):
|
|
12
|
+
r"""
|
|
13
|
+
Class for optimization algorithms, consists in iterating a fixed-point operator.
|
|
14
|
+
|
|
15
|
+
Module solving the problem
|
|
16
|
+
|
|
17
|
+
.. math::
|
|
18
|
+
\begin{equation}
|
|
19
|
+
\label{eq:min_prob}
|
|
20
|
+
\tag{1}
|
|
21
|
+
\underset{x}{\arg\min} \quad \lambda \datafid{x}{y} + \reg{x},
|
|
22
|
+
\end{equation}
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
where the first term :math:`\datafidname:\xset\times\yset \mapsto \mathbb{R}_{+}` enforces data-fidelity, the second
|
|
26
|
+
term :math:`\regname:\xset\mapsto \mathbb{R}_{+}` acts as a regularization and
|
|
27
|
+
:math:`\lambda > 0` is a regularization parameter. More precisely, the data-fidelity term penalizes the discrepancy
|
|
28
|
+
between the data :math:`y` and the forward operator :math:`A` applied to the variable :math:`x`, as
|
|
29
|
+
|
|
30
|
+
.. math::
|
|
31
|
+
\datafid{x}{y} = \distance{Ax}{y}
|
|
32
|
+
|
|
33
|
+
where :math:`\distance{\cdot}{\cdot}` is a distance function, and where :math:`A:\xset\mapsto \yset` is the forward
|
|
34
|
+
operator (see :meth:`deepinv.physics.Physics`)
|
|
35
|
+
|
|
36
|
+
Optimization algorithms for minimising the problem above can be written as fixed point algorithms,
|
|
37
|
+
i.e. for :math:`k=1,2,...`
|
|
38
|
+
|
|
39
|
+
.. math::
|
|
40
|
+
\qquad (x_{k+1}, z_{k+1}) = \operatorname{FixedPoint}(x_k, z_k, f, g, A, y, ...)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
where :math:`x_k` is a variable converging to the solution of the minimization problem, and
|
|
44
|
+
:math:`z_k` is an additional variable that may be required in the computation of the fixed point operator.
|
|
45
|
+
|
|
46
|
+
The :func:`optim_builder` function can be used to instantiate this class with a specific fixed point operator.
|
|
47
|
+
|
|
48
|
+
If the algorithm is minimizing an explicit and fixed cost function :math:`F(x) = \lambda \datafid{x}{y} + \reg{x}`,
|
|
49
|
+
the value of the cost function is computed along the iterations and can be used for convergence criterion.
|
|
50
|
+
Moreover, backtracking can be used to adapt the stepsize at each iteration. Backtracking consists in choosing
|
|
51
|
+
the largest stepsize :math:`\tau` such that, at each iteration, sufficient decrease of the cost function :math:`F` is achieved.
|
|
52
|
+
More precisely, Given :math:`\gamma \in (0,1/2)` and :math:`\eta \in (0,1)` and an initial stepsize :math:`\tau > 0`,
|
|
53
|
+
the following update rule is applied at each iteration :math:`k`:
|
|
54
|
+
|
|
55
|
+
.. math::
|
|
56
|
+
\text{ while } F(x_k) - F(x_{k+1}) < \frac{\gamma}{\tau} || x_{k+1} - x_k ||^2 \text{ do } \tau \leftarrow \eta \tau
|
|
57
|
+
|
|
58
|
+
The variable ``params_algo`` is a dictionary containing all the relevant parameters for running the algorithm.
|
|
59
|
+
If the value associated with the key is a float, the algorithm will use the same parameter across all iterations.
|
|
60
|
+
If the value is list of length max_iter, the algorithm will use the corresponding parameter at each iteration.
|
|
61
|
+
|
|
62
|
+
The variable ``data_fidelity`` is a list of instances of :meth:`deepinv.optim.DataFidelity` (or a single instance).
|
|
63
|
+
If a single instance, the same data-fidelity is used at each iteration. If a list, the data-fidelity can change at each iteration.
|
|
64
|
+
The same holds for the variable ``prior`` which is a list of instances of :meth:`deepinv.optim.Prior` (or a single instance).
|
|
65
|
+
|
|
66
|
+
::
|
|
67
|
+
|
|
68
|
+
# This minimal example shows how to use the BaseOptim class to solve the problem
|
|
69
|
+
# min_x 0.5 \lambda ||Ax-y||_2^2 + ||x||_1
|
|
70
|
+
# with the PGD algorithm, where A is the identity operator, lambda = 1 and y = [2, 2].
|
|
71
|
+
|
|
72
|
+
# Create the measurement operator A
|
|
73
|
+
A = torch.tensor([[1, 0], [0, 1]], dtype=torch.float64)
|
|
74
|
+
A_forward = lambda v: A @ v
|
|
75
|
+
A_adjoint = lambda v: A.transpose(0, 1) @ v
|
|
76
|
+
|
|
77
|
+
# Define the physics model associated to this operator
|
|
78
|
+
physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)
|
|
79
|
+
|
|
80
|
+
# Define the measurement y
|
|
81
|
+
y = torch.tensor([2, 2], dtype=torch.float64)
|
|
82
|
+
|
|
83
|
+
# Define the data fidelity term
|
|
84
|
+
data_fidelity = dinv.optim.data_fidelity.L2()
|
|
85
|
+
|
|
86
|
+
# Define the prior
|
|
87
|
+
prior = dinv.optim.Prior(g = lambda x, *args: torch.norm(x, p=1))
|
|
88
|
+
|
|
89
|
+
# Define the parameters of the algorithm
|
|
90
|
+
params_algo = {"stepsize": 0.5, "lambda": 1.0}
|
|
91
|
+
|
|
92
|
+
# Define the fixed-point iterator
|
|
93
|
+
iterator = dinv.optim.optim_iterators.PGDIteration()
|
|
94
|
+
|
|
95
|
+
# Define the optimization algorithm
|
|
96
|
+
optimalgo = dinv.optim.BaseOptim(iterator,
|
|
97
|
+
data_fidelity=data_fidelity,
|
|
98
|
+
params_algo=params_algo,
|
|
99
|
+
prior=prior)
|
|
100
|
+
|
|
101
|
+
# Run the optimization algorithm
|
|
102
|
+
xhat = optimalgo(y, physics)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
:param deepinv.optim.optim_iterators.OptimIterator iterator: Fixed-point iterator of the optimization algorithm of interest.
|
|
106
|
+
:param dict params_algo: dictionary containing all the relevant parameters for running the algorithm,
|
|
107
|
+
e.g. the stepsize, regularisation parameter, denoising standard deviation.
|
|
108
|
+
Each value of the dictionary can be either Iterable (distinct value for each iteration) or
|
|
109
|
+
a single float (same value for each iteration).
|
|
110
|
+
Default: `{"stepsize": 1.0, "lambda": 1.0}`. See :any:`optim-params` for more details.
|
|
111
|
+
:param list, deepinv.optim.DataFidelity: data-fidelity term.
|
|
112
|
+
Either a single instance (same data-fidelity for each iteration) or a list of instances of
|
|
113
|
+
:meth:`deepinv.optim.DataFidelity` (distinct data-fidelity for each iteration). Default: `None`.
|
|
114
|
+
:param list, deepinv.optim.Prior: regularization prior.
|
|
115
|
+
Either a single instance (same prior for each iteration) or a list of instances of
|
|
116
|
+
:meth:`deepinv.optim.Prior` (distinct prior for each iteration). Default: ``None``.
|
|
117
|
+
:param int max_iter: maximum number of iterations of the optimization algorithm. Default: 100.
|
|
118
|
+
:param str crit_conv: convergence criterion to be used for claiming convergence, either ``"residual"`` (residual
|
|
119
|
+
of the iterate norm) or `"cost"` (on the cost function). Default: ``"residual"``
|
|
120
|
+
:param float thres_conv: value of the threshold for claiming convergence. Default: ``1e-05``.
|
|
121
|
+
:param bool early_stop: whether to stop the algorithm once the convergence criterion is reached. Default: ``False``.
|
|
122
|
+
:param bool has_cost: whether the algorithm has an explicit cost function or not. Default: `False`.
|
|
123
|
+
:param dict custom_metrics: dictionary containing custom metrics to be computed at each iteration.
|
|
124
|
+
:param bool backtracking: whether to apply a backtracking strategy for stepsize selection. Default: ``False``.
|
|
125
|
+
:param float gamma_backtracking: :math:`\gamma` parameter in the backtracking selection. Default: ``0.1``.
|
|
126
|
+
:param float eta_backtracking: :math:`\eta` parameter in the backtracking selection. Default: ``0.9``.
|
|
127
|
+
:param function custom_init: initializes the algorithm with ``custom_init(y, physics)``.
|
|
128
|
+
:param function get_output: get the image output given the current dictionary update containing primal and auxiliary variables ``X = {('est' : (primal, aux)}``. Default : ``X['est'][0]``.
|
|
129
|
+
If ``custom_init`` is ``None`` (default value), the algorithm is initialized with :math:`A^Ty`. Default: ``None``.
|
|
130
|
+
:param bool anderson_acceleration: whether to use Anderson acceleration for accelerating the forward fixed-point iterations. Default: ``False``.
|
|
131
|
+
:param int history_size: size of the history of iterates used for Anderson acceleration. Default: ``5``.
|
|
132
|
+
:param float beta_anderson_acc: momentum of the Anderson acceleration step. Default: ``1.0``.
|
|
133
|
+
:param float eps_anderson_acc: regularization parameter of the Anderson acceleration step. Default: ``1e-4``.
|
|
134
|
+
:param bool verbose: whether to print relevant information of the algorithm during its run,
|
|
135
|
+
such as convergence criterion at each iterate. Default: ``False``.
|
|
136
|
+
:return: a torch model that solves the optimization problem.
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
def __init__(
    self,
    iterator,
    params_algo={"lambda": 1.0, "stepsize": 1.0},
    data_fidelity=None,
    prior=None,
    max_iter=100,
    crit_conv="residual",
    thres_conv=1e-5,
    early_stop=False,
    has_cost=False,
    backtracking=False,
    gamma_backtracking=0.1,
    eta_backtracking=0.9,
    custom_metrics=None,
    custom_init=None,
    get_output=lambda X: X["est"][0],
    anderson_acceleration=False,
    history_size=5,
    beta_anderson_acc=1.0,
    eps_anderson_acc=1e-4,
    verbose=False,
):
    r"""
    Store the algorithm configuration, normalize ``params_algo``, ``prior`` and ``data_fidelity``
    into per-iteration sequences, and build the underlying :class:`FixedPoint` module.

    ``params_algo`` is never modified in place: the normalization (adding the ``g_param`` and ``beta``
    defaults, wrapping scalar values in one-element lists) is done on a shallow copy, so that neither
    the dictionary owned by the caller nor the shared default dictionary is altered.

    :raises ValueError: if a parameter is given as an iterable with more than one element
        but fewer than ``max_iter`` elements.
    """
    super(BaseOptim, self).__init__()

    self.early_stop = early_stop
    self.crit_conv = crit_conv
    self.verbose = verbose
    self.max_iter = max_iter
    self.backtracking = backtracking
    self.gamma_backtracking = gamma_backtracking
    self.eta_backtracking = eta_backtracking
    self.has_converged = False
    self.thres_conv = thres_conv
    self.custom_metrics = custom_metrics
    self.custom_init = custom_init
    self.get_output = get_output
    self.has_cost = has_cost

    # Work on a shallow copy: the default value of ``params_algo`` is a single dict object shared
    # by every call, and a user-provided dict should not be rewritten behind the user's back.
    params_algo = params_algo.copy()

    # By default ``params_algo`` should contain a prior ``g_param`` parameter, set by default to ``None``.
    if "g_param" not in params_algo.keys():
        params_algo["g_param"] = None

    # By default ``params_algo`` should contain a relaxation ``beta`` parameter, set by default to 1.
    if "beta" not in params_algo.keys():
        params_algo["beta"] = 1.0

    # By default, each parameter in ``params_algo`` is a list.
    # If given as a single number, we convert it to a list of 1 element.
    # If given as a list of more than 1 element, it should have length ``max_iter``.
    # Iterate over a snapshot of the items since values are reassigned inside the loop.
    for key, value in list(params_algo.items()):
        if not isinstance(value, Iterable):
            params_algo[key] = [value]
        elif 1 < len(value) < self.max_iter:
            raise ValueError(
                f"The number of elements in the parameter {key} is inferior to max_iter."
            )

    # If ``stepsize`` is a list of more than 1 element, backtracking is impossible.
    if (
        "stepsize" in params_algo.keys()
        and len(params_algo["stepsize"]) > 1
        and self.backtracking
    ):
        self.backtracking = False
        warnings.warn(
            "Backtracking impossible when stepsize is predefined as a list. Setting backtracking to False."
        )

    # If no cost function, backtracking is impossible.
    if not self.has_cost and self.backtracking:
        self.backtracking = False
        warnings.warn(
            "Backtracking impossible when no cost function is given. Setting backtracking to False."
        )

    # Keep track of initial parameters in case they are changed during optimization (e.g. backtracking).
    self.init_params_algo = params_algo

    # ``self.prior`` is a sequence of priors (one per iteration, or a single shared one):
    # a non-iterable ``prior`` is wrapped in a one-element list.
    if not isinstance(prior, Iterable):
        self.prior = [prior]
    else:
        self.prior = prior

    # Same convention for ``self.data_fidelity``.
    if not isinstance(data_fidelity, Iterable):
        self.data_fidelity = [data_fidelity]
    else:
        self.data_fidelity = data_fidelity

    # Initialize the fixed-point module; it receives the bound methods of this object as callbacks.
    self.fixed_point = FixedPoint(
        iterator=iterator,
        update_params_fn=self.update_params_fn,
        update_data_fidelity_fn=self.update_data_fidelity_fn,
        update_prior_fn=self.update_prior_fn,
        check_iteration_fn=self.check_iteration_fn,
        check_conv_fn=self.check_conv_fn,
        init_metrics_fn=self.init_metrics_fn,
        init_iterate_fn=self.init_iterate_fn,
        update_metrics_fn=self.update_metrics_fn,
        max_iter=max_iter,
        early_stop=early_stop,
        anderson_acceleration=anderson_acceleration,
        history_size=history_size,
        beta_anderson_acc=beta_anderson_acc,
        eps_anderson_acc=eps_anderson_acc,
    )
|
|
247
|
+
|
|
248
|
+
def update_params_fn(self, it):
    r"""
    Build the dictionary of parameter values to be used at iteration ``it``.

    Every entry of ``self.params_algo`` is a sequence: when it holds more than one element the
    ``it``-th element is picked, otherwise its single element is used for all iterations.

    :param int it: iteration number.
    :return: a dictionary mapping each parameter name to its value at iteration ``it``.
    """
    selected = {}
    for name, schedule in self.params_algo.items():
        if len(schedule) > 1:
            selected[name] = schedule[it]
        else:
            selected[name] = schedule[0]
    return selected
|
|
261
|
+
|
|
262
|
+
def update_prior_fn(self, it):
    r"""
    Pick the prior to be used at iteration ``it``.

    When ``self.prior`` holds more than one element, the ``it``-th one is returned;
    otherwise its single element is shared by all iterations.

    :param int it: iteration number.
    :return: the prior of iteration ``it``.
    """
    if len(self.prior) > 1:
        return self.prior[it]
    return self.prior[0]
|
|
272
|
+
|
|
273
|
+
def update_data_fidelity_fn(self, it):
    r"""
    Pick the data-fidelity term to be used at iteration ``it``.

    When ``self.data_fidelity`` holds more than one element, the ``it``-th one is returned;
    otherwise its single element is shared by all iterations.

    :param int it: iteration number.
    :return: the data-fidelity term of iteration ``it``.
    """
    if len(self.data_fidelity) > 1:
        return self.data_fidelity[it]
    return self.data_fidelity[0]
|
|
287
|
+
|
|
288
|
+
def init_iterate_fn(self, y, physics, F_fn=None):
    r"""
    Build the first iterate of the algorithm.

    The result is a dictionary ``X = {'est': (x_0, u_0), 'cost': F_0}`` where ``est`` holds the first
    primal and auxiliary iterates and ``cost`` is the value of the cost function at the first primal
    iterate (``None`` when the algorithm has no cost or no ``F_fn`` is supplied).

    Without ``custom_init``, both iterates are set to ``physics.A_adjoint(y)``; otherwise the
    dictionary returned by ``custom_init(y, physics)`` is used. The running parameters
    ``self.params_algo`` are also reset to a shallow copy of the initial ones.

    :param torch.Tensor y: measurement vector.
    :param deepinv.physics physics: physics of the problem.
    :param F_fn: function that computes the cost function.
    :return: a dictionary containing the first iterate of the algorithm.
    """
    # Reset parameters to their initial values (they may have been changed, e.g. by backtracking).
    self.params_algo = self.init_params_algo.copy()

    if not self.custom_init:
        primal_0, aux_0 = physics.A_adjoint(y), physics.A_adjoint(y)
        first_iterate = {"est": (primal_0, aux_0)}
    else:
        first_iterate = self.custom_init(y, physics)

    cost_0 = None
    if self.has_cost and F_fn is not None:
        cost_0 = F_fn(
            first_iterate["est"][0],
            self.update_data_fidelity_fn(0),
            self.update_prior_fn(0),
            self.update_params_fn(0),
            y,
            physics,
        )
    first_iterate["cost"] = cost_0
    return first_iterate
|
|
326
|
+
|
|
327
|
+
def init_metrics_fn(self, X_init, x_gt=None):
    r"""
    Create the metrics dictionary for a new run.

    Each metric is stored as a list of lists: ``metrics[metric_name][i]`` is the history of
    ``metric_name`` for batch element ``i``, one value per iteration. The batch size is read from the
    first dimension of ``self.get_output(X_init)`` and stored in ``self.batch_size``.

    :param dict X_init: dictionary containing the primal and auxiliary initial iterates.
    :param torch.Tensor x_gt: ground truth image, required for PSNR computation. Default: ``None``.
    :return dict: a dictionary containing the initialized metrics.
    """
    first_output = self.get_output(X_init)
    n_batch = first_output.shape[0]
    self.batch_size = n_batch

    def empty_histories():
        # One independent (unshared) empty history per batch element.
        return [[] for _ in range(n_batch)]

    metrics = {}
    if x_gt is None:
        metrics["psnr"] = empty_histories()
    else:
        # With a ground truth, the PSNR of the initial iterate opens each history.
        metrics["psnr"] = [
            [cal_psnr(first_output[b], x_gt[b])] for b in range(n_batch)
        ]
    if self.has_cost:
        metrics["cost"] = empty_histories()
    metrics["residual"] = empty_histories()
    if self.custom_metrics is not None:
        for name in self.custom_metrics.keys():
            metrics[name] = empty_histories()
    return metrics
|
|
354
|
+
|
|
355
|
+
def update_metrics_fn(self, metrics, X_prev, X, x_gt=None):
    r"""
    Append the metrics of the current iteration, for every batch element.

    For each batch element, the relative residual between the previous and current outputs is
    recorded, together with the PSNR (when ``x_gt`` is given), the cost (when the algorithm has one)
    and every custom metric. Nothing is computed when ``metrics`` is ``None``.

    :param dict metrics: dictionary containing the metrics. Each metric is computed for each batch.
    :param dict X_prev: dictionary containing the primal and dual previous iterates.
    :param dict X: dictionary containing the current primal and dual iterates.
    :param torch.Tensor x_gt: ground truth image, required for PSNR computation. Default: None.
    :return dict: a dictionary containing the updated metrics.
    """
    if metrics is None:
        return metrics

    prev_out = self.get_output(X_prev)
    cur_out = self.get_output(X)
    for b in range(self.batch_size):
        step_norm = (prev_out[b] - cur_out[b]).norm()
        # The small constant guards against a division by zero for a null iterate.
        relative = step_norm / (cur_out[b].norm() + 1e-06)
        metrics["residual"][b].append(relative.detach().cpu().item())

        if x_gt is not None:
            metrics["psnr"][b].append(cal_psnr(cur_out[b], x_gt[b]))

        if self.has_cost:
            cost_b = X["cost"][b]
            metrics["cost"][b].append(cost_b.detach().cpu().item())

        if self.custom_metrics is not None:
            for name, metric_fn in self.custom_metrics.items():
                # The custom function receives the full per-batch history of its own metric.
                history = metrics[name]
                target = history[b]
                target.append(metric_fn(history, prev_out[b], cur_out[b]))
    return metrics
|
|
392
|
+
|
|
393
|
+
def check_iteration_fn(self, X_prev, X):
    r"""
    Performs stepsize backtracking.

    When backtracking is active, the batch-averaged decrease of the cost between ``X_prev`` and ``X``
    is compared with ``gamma_backtracking / stepsize`` times the batch-averaged squared distance
    between the two outputs. If the decrease is too small, the stepsize stored in
    ``self.params_algo`` is multiplied by ``eta_backtracking`` and the iteration is rejected.

    :param dict X_prev: dictionary containing the primal and dual previous iterates.
    :param dict X: dictionary containing the current primal and dual iterates.
    :return bool: ``False`` if the stepsize has been reduced, ``True`` otherwise
        (including when backtracking does not apply).
    """
    # Guard clause: nothing to check without backtracking, a cost, or a previous iterate.
    if not (self.backtracking and self.has_cost and X_prev is not None):
        return True

    prev_out = self.get_output(X_prev)
    cur_out = self.get_output(X)
    # Flatten everything but the first dimension to get one norm per batch element.
    prev_flat = prev_out.reshape((prev_out.shape[0], -1))
    cur_flat = cur_out.reshape((cur_out.shape[0], -1))

    cost_prev, cost_cur = X_prev["cost"], X["cost"]
    cost_decrease = (cost_prev - cost_cur).mean()
    squared_step = (torch.norm(cur_flat - prev_flat, p=2, dim=-1) ** 2).mean()

    tau = self.params_algo["stepsize"][0]
    if cost_decrease < (self.gamma_backtracking / tau) * squared_step:
        self.params_algo["stepsize"] = [self.eta_backtracking * tau]
        if self.verbose:
            print(
                f'Backtraking : new stepsize = {self.params_algo["stepsize"][0]:.3f}'
            )
        return False
    return True
|
|
423
|
+
|
|
424
|
+
def check_conv_fn(self, it, X_prev, X):
    r"""
    Checks the convergence of the algorithm.

    The criterion, averaged over the batch, is either the relative change of the output
    (``crit_conv == "residual"``) or the relative change of the cost (``crit_conv == "cost"``).
    When it falls below ``thres_conv``, ``self.has_converged`` is set to ``True``.

    :param int it: iteration number.
    :param dict X_prev: dictionary containing the primal and dual previous iterates.
    :param dict X: dictionary containing the current primal and dual iterates.
    :return bool: ``True`` if the algorithm has converged, ``False`` otherwise.
    :raises ValueError: if ``crit_conv`` is neither ``"residual"`` nor ``"cost"``.
    """
    criterion = self.crit_conv
    if criterion == "residual":
        prev_out = self.get_output(X_prev)
        cur_out = self.get_output(X)
        # One row per batch element, so that norms are taken per element.
        prev_flat = prev_out.reshape((prev_out.shape[0], -1))
        cur_flat = cur_out.reshape((cur_out.shape[0], -1))
        change = (prev_flat - cur_flat).norm(p=2, dim=-1)
        scale = cur_flat.norm(p=2, dim=-1) + 1e-06
        crit_cur = (change / scale).mean()
    elif criterion == "cost":
        cost_prev = X_prev["cost"]
        cost_cur = X["cost"]
        change = (cost_prev - cost_cur).norm(dim=-1)
        scale = cost_cur.norm(dim=-1) + 1e-06
        crit_cur = (change / scale).mean()
    else:
        raise ValueError("convergence criteria not implemented")

    if crit_cur < self.thres_conv:
        self.has_converged = True
        if self.verbose:
            print(
                f"Iteration {it}, current converge crit. = {crit_cur:.2E}, objective = {self.thres_conv:.2E} \r"
            )
        return True
    return False
|
|
456
|
+
|
|
457
|
+
def forward(self, y, physics, x_gt=None, compute_metrics=False):
    r"""
    Runs the fixed-point iteration algorithm and extracts its output.

    The measurement and physics are handed to ``self.fixed_point``; the reconstruction is then read
    from the returned iterate with ``self.get_output``.

    :param torch.Tensor y: measurement vector.
    :param deepinv.physics physics: physics of the problem for the acquisition of ``y``.
    :param torch.Tensor x_gt: (optional) ground truth image, for plotting the PSNR across optim iterations.
    :param bool compute_metrics: whether to compute the metrics or not. Default: ``False``.
    :return: If ``compute_metrics`` is ``False``, returns (torch.Tensor) the output of the algorithm.
        Else, returns (torch.Tensor, dict) the output of the algorithm and the metrics.
    """
    final_iterate, metrics = self.fixed_point(
        y, physics, x_gt=x_gt, compute_metrics=compute_metrics
    )
    estimate = self.get_output(final_iterate)
    return (estimate, metrics) if compute_metrics else estimate
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def create_iterator(iteration, prior=None, F_fn=None, g_first=False):
    r"""
    Helper function for creating an iterator, instance of the :meth:`deepinv.optim.optim_iterators.OptimIterator` class,
    corresponding to the chosen minimization algorithm.

    :param str, deepinv.optim.optim_iterators.OptimIterator iteration: either the name of the algorithm to be used,
        or directly an optim iterator.
        If an algorithm name (string), should be either ``"PGD"`` (proximal gradient descent), ``"ADMM"`` (ADMM),
        ``"HQS"`` (half-quadratic splitting), ``"CP"`` (Chambolle-Pock) or ``"DRS"`` (Douglas Rachford).
    :param list, deepinv.optim.Prior prior: regularization prior.
        Either a single instance (same prior for each iteration) or a list/tuple of instances of
        deepinv.optim.Prior (distinct prior for each iteration). Default: ``None``, in which case
        no explicit prior is assumed and no default cost function is built.
    :param callable F_fn: Custom user input cost function. default: None.
    :param bool g_first: whether to perform the step on :math:`g` before that on :math:`f` before or not. Default: False
    :return: the iterator built from the algorithm name, or ``iteration`` itself when it is not a string.
    """
    # Whether the prior exposes an explicit function g. For a sequence of priors, the first one decides.
    # ``prior`` defaults to ``None``: reading ``None.explicit_prior`` would raise an AttributeError,
    # so a missing prior is treated as "not explicit".
    if prior is None:
        explicit_prior = False
    elif isinstance(prior, (list, tuple)):
        explicit_prior = prior[0].explicit_prior
    else:
        explicit_prior = prior.explicit_prior

    # If no custom objective function F_fn is given but g is explicitly given, we have an explicit objective function.
    if F_fn is None and explicit_prior:

        def F_fn(x, data_fidelity, prior, cur_params, y, physics):
            return cur_params["lambda"] * data_fidelity(x, y, physics) + prior(
                x, cur_params["g_param"]
            )

        has_cost = True  # boolean to indicate if there is a cost function to evaluate along the iterations
    else:
        # NOTE: this branch is also taken when a custom ``F_fn`` is supplied (existing behavior, kept as is).
        has_cost = False

    # Create an instance of :class:`deepinv.optim.optim_iterators.OptimIterator`.
    if isinstance(iteration, str):
        # If the name of the algorithm is given as a string, the corresponding class
        # (named ``<iteration>Iteration``) is looked up in this module and instantiated.
        iterator_fn = str_to_class(iteration + "Iteration")
        return iterator_fn(g_first=g_first, F_fn=F_fn, has_cost=has_cost)

    # If the iteration is directly given as an instance of OptimIterator, nothing to do.
    return iteration
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def optim_builder(
    iteration,
    params_algo={"lambda": 1.0, "stepsize": 1.0},
    data_fidelity=None,
    prior=None,
    F_fn=None,
    g_first=False,
    **kwargs,
):
    r"""
    Helper function for building an instance of the :meth:`BaseOptim` class.

    :param str, deepinv.optim.optim_iterators.OptimIterator iteration: either the name of the algorithm to be used,
        or directly an optim iterator.
        If an algorithm name (string), should be either ``"PGD"`` (proximal gradient descent), ``"ADMM"`` (ADMM),
        ``"HQS"`` (half-quadratic splitting), ``"CP"`` (Chambolle-Pock) or ``"DRS"`` (Douglas Rachford).
    :param dict params_algo: dictionary containing all the relevant parameters for running the algorithm,
        e.g. the stepsize, regularisation parameter, denoising standard deviation.
        Each value of the dictionary can be either Iterable (distinct value for each iteration) or
        a single float (same value for each iteration). See :any:`optim-params` for more details.
        Default: ``{"stepsize": 1.0, "lambda": 1.0}``. The dictionary is copied (shallowly) before
        being handed to :meth:`BaseOptim`, so the object passed by the caller is left untouched.
    :param list, deepinv.optim.DataFidelity: data-fidelity term.
        Either a single instance (same data-fidelity for each iteration) or a list of instances of
        :meth:`deepinv.optim.DataFidelity` (distinct data-fidelity for each iteration). Default: `None`.
    :param list, deepinv.optim.Prior prior: regularization prior.
        Either a single instance (same prior for each iteration) or a list of instances of
        deepinv.optim.Prior (distinct prior for each iteration). Default: `None`.
    :param callable F_fn: Custom user input cost function. default: `None`.
    :param bool g_first: whether to perform the step on :math:`g` before that on :math:`f` before or not. default: `False`
    :param kwargs: additional arguments to be passed to the :meth:`BaseOptim` class.
    :return: an instance of the :meth:`BaseOptim` class.

    """
    iterator = create_iterator(iteration, prior=prior, F_fn=F_fn, g_first=g_first)
    return BaseOptim(
        iterator,
        # The iterator carries the information of whether a cost function is available.
        has_cost=iterator.has_cost,
        data_fidelity=data_fidelity,
        prior=prior,
        # Shallow copy: the default dict is shared between calls and must not be rewritten,
        # nor should the dictionary owned by the caller.
        params_algo=params_algo.copy(),
        **kwargs,
    )
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def str_to_class(classname):
    r"""
    Look up the attribute called ``classname`` on the current module and return it.

    :param str classname: name of the object to retrieve from this module's namespace.
    :return: the object bound to ``classname`` in this module.
    """
    this_module = sys.modules[__name__]
    return getattr(this_module, classname)
|