deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,289 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import warnings
4
+
5
+
6
class FixedPoint(nn.Module):
    r"""
    Fixed-point iterations module.

    This module implements the fixed-point iteration algorithm given a specific fixed-point iterator (e.g.
    proximal gradient iteration, the ADMM iteration, see :meth:`deepinv.optim.optim_iterators`), that is
    for :math:`k=1,2,...`

    .. math::
        \qquad (x_{k+1}, u_{k+1}) = \operatorname{FixedPoint}(x_k, u_k, f, g, A, y, ...) \hspace{2cm} (1)


    where :math:`f` is the data-fidelity term, :math:`g` is the prior, :math:`A` is the physics model, :math:`y` is the
    data.

    The iterator is called as ``iterator(X, data_fidelity, prior, params, *args)`` and must return a new iterate
    dictionary of the form ``{"est": (x, ...), "cost": F}``.

    ::

        # This example shows how to use an iterator to solve the problem
        # min_x 0.5*lambda*||Ax-y||_2^2 + ||x||_1
        # with the PGD algorithm, where A is the identity operator, lambda = 1 and y = [2, 2].

        # Create the measurement operator A
        A = torch.tensor([[1, 0], [0, 1]], dtype=torch.float64)
        A_forward = lambda v: A @ v
        A_adjoint = lambda v: A.transpose(0, 1) @ v

        # Define the physics model associated to this operator
        physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)

        # Define the measurement y
        y = torch.tensor([2, 2], dtype=torch.float64)

        # Define the data fidelity term
        data_fidelity = L2()

        # Define the prior through its proximity operator (soft-thresholding, the prox of the l1 norm)
        class L1Prior:
            def prox(self, x, g_param, gamma=1.0):
                return torch.sign(x) * torch.clamp(x.abs() - gamma * g_param, min=0.0)

        prior = L1Prior()

        # Define the parameters of the algorithm
        params = {"g_param": 1.0, "stepsize": 1.0, "lambda": 1.0}

        # Choose the iterator associated to the PGD algorithm
        iterator = PGDIteration()

        # Iterate the iterator
        x_init = torch.tensor([2, 2], dtype=torch.float64)  # Define initialisation of the algorithm
        X = {"est": (x_init,), "cost": None}  # Iterates are stored in a dictionary {'est': (x, ...), 'cost': F}

        max_iter = 50
        for it in range(max_iter):
            X = iterator(X, data_fidelity, prior, params, y, physics)

        # Return the solution
        sol = X["est"][0]  # sol = [1, 1]


    :param deepinv.optim.optim_iterators.optim_iterator iterator: function that takes as input the current iterate, as
        well as parameters of the optimization problem (prior, measurements, etc.)
    :param function update_params_fn: function that returns the parameters to be used at each iteration. Default: ``None``.
    :param function update_data_fidelity_fn: function that returns the data-fidelity term to be used at each
        iteration. Default: ``None``.
    :param function update_prior_fn: function that returns the prior to be used at each iteration. Default: ``None``.
    :param function init_iterate_fn: function that returns the initial iterate. Default: ``None``.
    :param function init_metrics_fn: function that returns the initial metrics. Default: ``None``.
    :param function update_metrics_fn: function that updates the metrics after each accepted iteration.
        Default: ``None``.
    :param function check_iteration_fn: function that performs a check on the last iteration and returns a bool indicating if we can proceed to next iteration. Default: ``None``.
    :param function check_conv_fn: function that checks the convergence after each iteration, returns a bool indicating if convergence has been reached. Default: ``None``.
    :param int max_iter: maximum number of iterations. Default: ``50``.
    :param bool early_stop: if True, the algorithm stops when the convergence criterion is reached. If no
        ``check_conv_fn`` is given, a warning is emitted and early stopping is disabled. Default: ``True``.
    :param bool anderson_acceleration: if True, the Anderson acceleration is used. Default: ``False``.
    :param int history_size: size of the history used for the Anderson acceleration. Default: ``5``.
    :param float beta_anderson_acc: momentum of the Anderson acceleration step. Default: ``1.0``.
    :param float eps_anderson_acc: regularization parameter of the Anderson acceleration step. Default: ``1e-4``.
    """

    def __init__(
        self,
        iterator=None,
        update_params_fn=None,
        update_data_fidelity_fn=None,
        update_prior_fn=None,
        init_iterate_fn=None,
        init_metrics_fn=None,
        update_metrics_fn=None,
        check_iteration_fn=None,
        check_conv_fn=None,
        max_iter=50,
        early_stop=True,
        anderson_acceleration=False,
        history_size=5,
        beta_anderson_acc=1.0,
        eps_anderson_acc=1e-4,
    ):
        super().__init__()
        self.iterator = iterator
        self.max_iter = max_iter
        self.early_stop = early_stop
        self.update_params_fn = update_params_fn
        self.update_data_fidelity_fn = update_data_fidelity_fn
        self.update_prior_fn = update_prior_fn
        self.init_iterate_fn = init_iterate_fn
        self.init_metrics_fn = init_metrics_fn
        self.update_metrics_fn = update_metrics_fn
        self.check_conv_fn = check_conv_fn
        self.check_iteration_fn = check_iteration_fn
        self.anderson_acceleration = anderson_acceleration
        self.history_size = history_size
        self.beta_anderson_acc = beta_anderson_acc
        self.eps_anderson_acc = eps_anderson_acc

        # Early stopping is meaningless without a convergence criterion: warn and disable it.
        if self.check_conv_fn is None and self.early_stop:
            warnings.warn(
                "early_stop is set to True but no check_conv_fn has been defined."
            )
            self.early_stop = False

    def init_anderson_acceleration(self, X):
        r"""
        Initialize the Anderson acceleration algorithm.

        The primal iterate ``X["est"][0]`` may have any shape whose first dimension is the batch size; each batch
        element is flattened to a vector in the history buffers.

        :param dict X: initial iterate.
        :return: tuple ``(x_hist, T_hist, H, q)`` of zero-initialised history buffers of shape
            ``(batch, history_size, n)`` (``n`` the number of entries per batch element), and the matrix ``H``
            of shape ``(batch, history_size + 1, history_size + 1)`` and vector ``q`` of shape
            ``(batch, history_size + 1, 1)`` of the linear system :math:`Hp = q`.
        """
        x = X["est"][0]
        b = x.shape[0]  # batch size
        n = x.shape[1:].numel()  # entries per batch element (d * h * w for 4-D images)
        x_hist = torch.zeros(
            b, self.history_size, n, dtype=x.dtype, device=x.device
        )  # history of iterates.
        T_hist = torch.zeros(
            b, self.history_size, n, dtype=x.dtype, device=x.device
        )  # history of T(x_k) with T the fixed point operator.
        H = torch.zeros(
            b,
            self.history_size + 1,
            self.history_size + 1,
            dtype=x.dtype,
            device=x.device,
        )  # H in the Anderson acceleration linear system Hp = q .
        # Bordered system: first row/column encode the constraint that the weights p sum to one.
        H[:, 0, 1:] = H[:, 1:, 0] = 1.0
        q = torch.zeros(
            b, self.history_size + 1, 1, dtype=x.dtype, device=x.device
        )  # q in the Anderson acceleration linear system Hp = q .
        q[:, 0] = 1
        return x_hist, T_hist, H, q

    def anderson_acceleration_step(
        self,
        it,
        X_prev,
        TX_prev,
        x_hist,
        T_hist,
        H,
        q,
        cur_data_fidelity,
        cur_prior,
        cur_params,
        *args,
    ):
        r"""
        Anderson acceleration step.

        Stores :math:`x_k` and :math:`T(x_k)` in circular history buffers, solves the regularized bordered system
        :math:`Hp = q` for mixing weights :math:`p` over the last ``min(it + 1, history_size)`` entries, and returns
        the mixed iterate.

        :param int it: current iteration.
        :param dict X_prev: previous iterate.
        :param dict TX_prev: output of the fixed-point operator evaluated at X_prev.
        :param torch.Tensor x_hist: history of the last ``history_size`` iterates (updated in place).
        :param torch.Tensor T_hist: history of the evaluations of T at the last ``history_size`` iterates, where T
            is the fixed-point operator (updated in place).
        :param torch.Tensor H: H in the Anderson acceleration linear system Hp = q (updated in place).
        :param torch.Tensor q: q in the Anderson acceleration linear system Hp = q .
        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
        :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
        :param args: arguments for the iterator.
        :return: dict ``{"est": (x, ...), "cost": F}`` where ``x`` is the accelerated primal iterate and the
            remaining entries of ``est`` are taken from ``TX_prev``.
        """
        x_prev = X_prev["est"][0]  # current iterate x_k
        Tx_prev = TX_prev["est"][0]  # fixed-point operator applied to it, T(x_k)
        b = x_prev.shape[0]  # batch size
        slot = it % self.history_size  # position in the circular history buffers
        x_hist[:, slot] = x_prev.reshape(b, -1)  # prepare history of x
        T_hist[:, slot] = Tx_prev.reshape(b, -1)  # prepare history of Tx
        m = min(it + 1, self.history_size)  # number of valid history entries
        G = T_hist[:, :m] - x_hist[:, :m]  # residuals T(x_i) - x_i
        H[:, 1 : m + 1, 1 : m + 1] = (
            torch.bmm(G, G.transpose(1, 2))
            + self.eps_anderson_acc
            * torch.eye(m, dtype=Tx_prev.dtype, device=Tx_prev.device)[None]
        )
        # Solve H p = q. The first unknown is the multiplier of the sum-to-one constraint and is discarded.
        p = torch.linalg.solve(H[:, : m + 1, : m + 1], q[:, : m + 1])[
            :, 1 : m + 1, 0
        ]
        x = (
            self.beta_anderson_acc * (p[:, None] @ T_hist[:, :m])[:, 0]
            + (1 - self.beta_anderson_acc) * (p[:, None] @ x_hist[:, :m])[:, 0]
        )  # Anderson acceleration step.
        x = x.view(x_prev.shape)
        F = (
            self.iterator.F_fn(x, cur_data_fidelity, cur_prior, cur_params, *args)
            if self.iterator.has_cost
            else None
        )
        # Keep the auxiliary iterates of T(x_k); return a tuple like the iterators do.
        est = (x, *TX_prev["est"][1:])
        return {"est": est, "cost": F}

    def forward(self, *args, compute_metrics=False, x_gt=None, **kwargs):
        r"""
        Loops over the fixed-point iterator as (1) and returns the fixed point.

        The iterates are stored in a dictionary of the form ``X = {'est': (x_k, u_k), 'cost': F_k}`` where:
            - ``est`` is a tuple containing the current primal and auxiliary iterates,
            - ``cost`` is the value of the cost function at the current iterate.

        Since the prior and parameters (stepsize, regularisation parameter, etc.) can change at each iteration,
        the prior and parameters are updated before each call to the iterator.

        If ``check_iteration_fn`` rejects an iterate, the previous iterate is restored and the same iteration index
        is retried.

        :param bool compute_metrics: if ``True``, the metrics are computed along the iterations. Default: ``False``.
        :param torch.Tensor x_gt: ground truth solution. Default: ``None``.
        :param args: optional arguments for the iterator. Commonly (y,physics) where ``y`` (torch.Tensor y) is the measurement and
            ``physics`` (deepinv.physics) is the physics model.
        :param kwargs: extra keyword arguments; accepted for interface compatibility but currently not passed to
            the iterator.
        :return tuple: ``(x,metrics)`` with ``x`` the fixed-point solution (dict) and
            ``metrics`` the computed along the iterations if ``compute_metrics`` is ``True`` or ``None``
            otherwise.
        """
        X = (
            self.init_iterate_fn(*args, F_fn=self.iterator.F_fn)
            if self.init_iterate_fn
            else None
        )
        metrics = (
            self.init_metrics_fn(X, x_gt=x_gt)
            if self.init_metrics_fn and compute_metrics
            else None
        )
        if self.anderson_acceleration:
            x_hist, T_hist, H, q = self.init_anderson_acceleration(X)
        it = 0
        while it < self.max_iter:
            cur_params = self.update_params_fn(it) if self.update_params_fn else None
            cur_data_fidelity = (
                self.update_data_fidelity_fn(it)
                if self.update_data_fidelity_fn
                else None
            )
            cur_prior = self.update_prior_fn(it) if self.update_prior_fn else None
            X_prev = X
            X = self.iterator(X_prev, cur_data_fidelity, cur_prior, cur_params, *args)
            if self.anderson_acceleration:
                X = self.anderson_acceleration_step(
                    it,
                    X_prev,
                    X,
                    x_hist,
                    T_hist,
                    H,
                    q,
                    cur_data_fidelity,
                    cur_prior,
                    cur_params,
                    *args,
                )
            check_iteration = (
                self.check_iteration_fn(X_prev, X) if self.check_iteration_fn else True
            )
            if check_iteration:
                metrics = (
                    self.update_metrics_fn(metrics, X_prev, X, x_gt=x_gt)
                    if self.update_metrics_fn and compute_metrics
                    else None
                )
                if (
                    self.early_stop
                    and (self.check_conv_fn is not None)
                    and it > 1
                    and self.check_conv_fn(it, X_prev, X)
                ):
                    break
                it += 1
            else:
                # Rejected iterate: roll back and retry the same index.
                # NOTE(review): `it` is not advanced here, so if check_iteration_fn keeps rejecting
                # (and nothing it depends on changes), this loop does not terminate.
                X = X_prev
        return X, metrics
+ return X, metrics
@@ -0,0 +1,9 @@
1
+ from .optim_iterator import OptimIterator
2
+ from .optim_iterator import fStep, gStep
3
+ from .admm import ADMMIteration
4
+ from .pgd import PGDIteration
5
+ from .primal_dual_CP import CPIteration
6
+ from .hqs import HQSIteration
7
+ from .drs import DRSIteration
8
+ from .gradient_descent import GDIteration
9
+ from .optim_iterator import OptimIterator, fStep, gStep
@@ -0,0 +1,117 @@
1
+ import torch
2
+ from .optim_iterator import OptimIterator, fStep, gStep
3
+
4
+
5
class ADMMIteration(OptimIterator):
    r"""
    Iterator for the alternating direction method of multipliers (ADMM).

    Each call performs one ADMM iteration for minimising :math:`\lambda f(x) + g(x)`.

    With ``g_first=False`` (the default) the update reads
    (`see this paper <https://www.nowpublishers.com/article/Details/MAL-016>`_):

    .. math::
        \begin{equation*}
        \begin{aligned}
        u_{k+1} &= \operatorname{prox}_{\gamma \lambda f}(x_k - z_k) \\
        x_{k+1} &= \operatorname{prox}_{\gamma g}(u_{k+1} + z_k) \\
        z_{k+1} &= z_k + \beta (u_{k+1} - x_{k+1})
        \end{aligned}
        \end{equation*}

    where :math:`\gamma>0` is a stepsize and :math:`\beta>0` is a relaxation parameter.

    With ``g_first=True`` the roles of :math:`f` and :math:`g` are swapped in the iteration above.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.g_step = gStepADMM(**kwargs)
        self.f_step = fStepADMM(**kwargs)
        self.requires_prox_g = True

    def forward(self, X, cur_data_fidelity, cur_prior, cur_params, y, physics):
        r"""
        Run one ADMM iteration.

        :param dict X: Dictionary containing the current iterate ``(x, z)`` and the estimated cost.
        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
        :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm (reads ``"beta"``).
        :param torch.Tensor y: Input data.
        :param deepinv.physics physics: Instance of the physics modeling the observation.
        :return: Dictionary `{"est": (x, z), "cost": F}` with the new iterate and the cost (``None`` unless
            ``has_cost``).
        """
        x, z = X["est"]
        if x.shape != z.shape:
            # z is a scaled dual variable living in the primal space; an initialisation of another
            # shape (e.g. a true dual variable) is replaced by zeros.
            z = torch.zeros_like(x)

        if not self.g_first:
            u = self.f_step(x, z, cur_data_fidelity, cur_params, y, physics)
            x_next = self.g_step(u, z, cur_prior, cur_params)
        else:
            u = self.g_step(x, z, cur_prior, cur_params)
            x_next = self.f_step(u, z, cur_data_fidelity, cur_params, y, physics)

        z_next = z + cur_params["beta"] * (u - x_next)

        cost = None
        if self.has_cost:
            cost = self.F_fn(
                x_next, cur_data_fidelity, cur_prior, cur_params, y, physics
            )
        return {"est": (x_next, z_next), "cost": cost}
66
+
67
+
68
class fStepADMM(fStep):
    r"""
    Data-fidelity step of ADMM: a proximal step on :math:`\lambda f`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, z, cur_data_fidelity, cur_params, y, physics):
        r"""
        Proximal step on the data-fidelity term :math:`\lambda f`.

        The prox is evaluated at ``x + z`` when ``g_first`` is set and at ``x - z`` otherwise, with
        ``gamma = lambda * stepsize``.

        :param torch.Tensor x: current first variable
        :param torch.Tensor z: current second variable
        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
        :param torch.Tensor y: Input data.
        :param deepinv.physics physics: Instance of the physics modeling the observation.
        """
        point = x + z if self.g_first else x - z
        gamma = cur_params["lambda"] * cur_params["stepsize"]
        return cur_data_fidelity.prox(point, y, physics, gamma=gamma)
94
+
95
+
96
class gStepADMM(gStep):
    r"""
    Prior step of ADMM: a proximal step on :math:`g`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, z, cur_prior, cur_params):
        r"""
        Proximal step on the prior term :math:`g`.

        The prox is evaluated at ``x - z`` when ``g_first`` is set and at ``x + z`` otherwise, with
        ``gamma = stepsize``.

        :param torch.Tensor x: current first variable
        :param torch.Tensor z: current second variable
        :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
        """
        point = x - z if self.g_first else x + z
        g_param = cur_params["g_param"]
        step = cur_params["stepsize"]
        return cur_prior.prox(point, g_param, gamma=step)
@@ -0,0 +1,115 @@
1
+ import torch
2
+
3
+ from .optim_iterator import OptimIterator, fStep, gStep
4
+
5
+
6
class DRSIteration(OptimIterator):
    r"""
    Iterator for Douglas-Rachford Splitting (DRS).

    Each call performs one DRS iteration for minimising :math:`\lambda f(x) + g(x)`.

    With ``g_first=False`` (the default) the update reads

    .. math::
        \begin{equation*}
        \begin{aligned}
        u_{k+1} &= \operatorname{prox}_{\gamma \lambda f}(z_k) \\
        x_{k+1} &= \operatorname{prox}_{\gamma g}(2 u_{k+1} - z_k) \\
        z_{k+1} &= z_k + \beta (x_{k+1} - u_{k+1})
        \end{aligned}
        \end{equation*}

    where :math:`\gamma>0` is a stepsize and :math:`\beta>0` is a relaxation parameter.

    With ``g_first=True`` the roles of :math:`f` and :math:`g` are swapped in the iteration above.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.g_step = gStepDRS(**kwargs)
        self.f_step = fStepDRS(**kwargs)
        self.requires_prox_g = True

    def forward(self, X, cur_data_fidelity, cur_prior, cur_params, y, physics):
        r"""
        Run one DRS iteration.

        :param dict X: Dictionary containing the current iterate ``(x, z)`` and the estimated cost.
        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
        :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm (reads ``"beta"``).
        :param torch.Tensor y: Input data.
        :param deepinv.physics physics: Instance of the physics modeling the observation.
        :return: Dictionary `{"est": (x, z), "cost": F}` with the new iterate and the cost (``None`` unless
            ``has_cost``).
        """
        x, z = X["est"]
        if x.shape != z.shape:
            # z lives in the primal space; an initialisation of another shape is replaced by zeros.
            z = torch.zeros_like(x)

        if not self.g_first:
            u = self.f_step(x, z, cur_data_fidelity, cur_params, y, physics)
            x_next = self.g_step(u, z, cur_prior, cur_params)
        else:
            u = self.g_step(x, z, cur_prior, cur_params)
            x_next = self.f_step(u, z, cur_data_fidelity, cur_params, y, physics)

        z_next = z + cur_params["beta"] * (x_next - u)

        cost = None
        if self.has_cost:
            cost = self.F_fn(
                x_next, cur_data_fidelity, cur_prior, cur_params, y, physics
            )
        return {"est": (x_next, z_next), "cost": cost}
64
+
65
+
66
class fStepDRS(fStep):
    r"""
    Data-fidelity step of DRS: a proximal step on :math:`\lambda f`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, z, cur_data_fidelity, cur_params, y, physics):
        r"""
        Proximal step on the data-fidelity term :math:`\lambda f`.

        The prox is evaluated at the reflection ``2 * x - z`` when ``g_first`` is set and at ``z`` otherwise,
        with ``gamma = lambda * stepsize``.

        :param torch.Tensor x: Current first variable.
        :param torch.Tensor z: Current second variable.
        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
        :param torch.Tensor y: Input data.
        :param deepinv.physics physics: Instance of the physics modeling the data-fidelity term.
        """
        point = 2 * x - z if self.g_first else z
        gamma = cur_params["lambda"] * cur_params["stepsize"]
        return cur_data_fidelity.prox(point, y, physics, gamma=gamma)
92
+
93
+
94
class gStepDRS(gStep):
    r"""
    Prior step of DRS: a proximal step on :math:`g`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, z, cur_prior, cur_params):
        r"""
        Proximal step on the prior term :math:`g`.

        The prox is evaluated at ``z`` when ``g_first`` is set and at the reflection ``2 * x - z`` otherwise,
        with ``gamma = stepsize``.

        :param torch.Tensor x: Current first variable.
        :param torch.Tensor z: Current second variable.
        :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
        """
        point = z if self.g_first else 2 * x - z
        g_param = cur_params["g_param"]
        step = cur_params["stepsize"]
        return cur_prior.prox(point, g_param, gamma=step)
@@ -0,0 +1,90 @@
1
+ from .optim_iterator import OptimIterator, fStep, gStep
2
+ from .utils import gradient_descent_step
3
+
4
+
5
class GDIteration(OptimIterator):
    r"""
    Iterator for Gradient Descent.

    Class for a single iteration of the gradient descent (GD) algorithm for minimising :math:`\lambda f(x) + g(x)`.

    The iteration is given by


    .. math::
        \begin{equation*}
        \begin{aligned}
        v_{k} &= \lambda \nabla f(x_k) + \nabla g(x_k) \\
        x_{k+1} &= x_k-\gamma v_{k}
        \end{aligned}
        \end{equation*}


    where :math:`\gamma` is a stepsize.
    """

    def __init__(self, **kwargs):
        super(GDIteration, self).__init__(**kwargs)
        self.g_step = gStepGD(**kwargs)
        self.f_step = fStepGD(**kwargs)
        self.requires_grad_g = True

    def forward(self, X, cur_data_fidelity, cur_prior, cur_params, y, physics):
        r"""
        Single gradient descent iteration on the objective :math:`\lambda f(x) + g(x)`.

        :param dict X: Dictionary containing the current iterate :math:`x_k`.
        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
        :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm (``"stepsize"`` is
            read here).
        :param torch.Tensor y: Input data.
        :param deepinv.physics physics: Instance of the physics modeling the observation.
        :return: Dictionary `{"est": (x, ), "cost": F}` containing the updated current iterate and the estimated current cost.
        """
        x_prev = X["est"][0]
        # Scaled descent direction gamma * (grad g + lambda * grad f).
        grad = cur_params["stepsize"] * (
            self.g_step(x_prev, cur_prior, cur_params)
            + self.f_step(x_prev, cur_data_fidelity, cur_params, y, physics)
        )
        x = gradient_descent_step(x_prev, grad)
        # Same cost signature as the other iterators and FixedPoint:
        # F_fn(x, data_fidelity, prior, params, y, physics).
        F = (
            self.F_fn(x, cur_data_fidelity, cur_prior, cur_params, y, physics)
            if self.has_cost
            else None
        )
        return {"est": (x,), "cost": F}
51
+
52
+
53
class fStepGD(fStep):
    r"""
    Data-fidelity step of gradient descent: returns the scaled gradient :math:`\lambda \nabla f(x)`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, cur_data_fidelity, cur_params, y, physics):
        r"""
        Gradient of the weighted data-fidelity term :math:`\lambda f` at ``x``.

        :param torch.Tensor x: current iterate :math:`x_k`.
        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
        :param torch.Tensor y: Input data.
        :param deepinv.physics physics: Instance of the physics modeling the data-fidelity term.
        :return: ``cur_params["lambda"] * cur_data_fidelity.grad(x, y, physics)``.
        """
        weight = cur_params["lambda"]
        data_grad = cur_data_fidelity.grad(x, y, physics)
        return weight * data_grad
72
+
73
+
74
class gStepGD(gStep):
    r"""
    Prior step of gradient descent: returns the gradient of the prior :math:`\nabla g(x)`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, cur_prior, cur_params):
        r"""
        Gradient of the prior term :math:`g` at ``x``.

        :param torch.Tensor x: Current iterate :math:`x_k`.
        :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
        :return: ``cur_prior.grad(x, cur_params["g_param"])``.
        """
        g_param = cur_params["g_param"]
        return cur_prior.grad(x, g_param)