deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from deepinv.optim.fixed_point import FixedPoint
|
|
3
|
+
from deepinv.optim.optim_iterators import *
|
|
4
|
+
from deepinv.unfolded.unfolded import BaseUnfold
|
|
5
|
+
from deepinv.optim.optimizers import create_iterator
|
|
6
|
+
from deepinv.optim.data_fidelity import L2
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseDEQ(BaseUnfold):
    r"""
    Base class for deep equilibrium (DEQ) algorithms. Child of :class:`deepinv.unfolded.BaseUnfold`.

    Enables to turn any fixed-point algorithm into a DEQ algorithm, i.e. an algorithm
    that can be virtually unrolled infinitely leveraging the implicit function theorem.
    The backward pass is performed using fixed point iterations to find solutions of the fixed-point equation

    .. math::

        \begin{equation}
        v = \left(\frac{\partial \operatorname{FixedPoint}(x^\star)}{\partial x^\star} \right )^T v + u.
        \end{equation}

    where :math:`u` is the incoming gradient from the backward pass,
    and :math:`x^\star` is the equilibrium point of the forward pass.

    See `this tutorial <http://implicit-layers-tutorial.org/deep_equilibrium_models/>`_ for more details.

    For now DEQ is only possible with PGD, HQS and GD optimization algorithms.

    :param int max_iter_backward: Maximum number of backward iterations. Default: ``50``.
    :param bool anderson_acceleration_backward: if True, the Anderson acceleration is used at iteration of fixed-point algorithm for computing the backward pass. Default: ``False``.
    :param int history_size_backward: size of the history used for the Anderson acceleration for the backward pass. Default: ``5``.
    :param float beta_anderson_acc_backward: momentum of the Anderson acceleration step for the backward pass. Default: ``1.0``.
    :param float eps_anderson_acc_backward: regularization parameter of the Anderson acceleration step for the backward pass. Default: ``1e-4``.
    """

    def __init__(
        self,
        *args,
        max_iter_backward=50,
        anderson_acceleration_backward=False,
        history_size_backward=5,
        beta_anderson_acc_backward=1.0,
        eps_anderson_acc_backward=1e-4,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Settings for the backward fixed-point solve only; the forward solve
        # uses the configuration inherited from BaseUnfold/BaseOptim.
        self.max_iter_backward = max_iter_backward
        self.anderson_acceleration = anderson_acceleration_backward
        self.history_size = history_size_backward
        self.beta_anderson_acc = beta_anderson_acc_backward
        self.eps_anderson_acc = eps_anderson_acc_backward

    def forward(self, y, physics, x_gt=None, compute_metrics=False):
        r"""
        The forward pass of the DEQ algorithm. Compared to :class:`deepinv.unfolded.BaseUnfold`, the backward algorithm is performed using fixed point iterations.

        :param torch.Tensor y: Input tensor.
        :param deepinv.Physics physics: Physics object.
        :param torch.Tensor x_gt: (optional) ground truth image, for plotting the PSNR across optim iterations.
        :param bool compute_metrics: whether to compute the metrics or not. Default: ``False``.
        :return: If ``compute_metrics`` is ``False``, returns (:class:`torch.Tensor`) the output of the algorithm.
            Else, returns (:class:`torch.Tensor`, dict) the output of the algorithm and the metrics.
        """
        with torch.no_grad():  # Perform the forward pass without gradient tracking
            X, metrics = self.fixed_point(
                y, physics, x_gt=x_gt, compute_metrics=compute_metrics
            )
        # Once, at the equilibrium point, performs one additional iteration with gradient tracking.
        cur_data_fidelity = self.update_data_fidelity_fn(self.max_iter - 1)
        cur_prior = self.update_prior_fn(self.max_iter - 1)
        cur_params = self.update_params_fn(self.max_iter - 1)
        x = self.fixed_point.iterator(
            X, cur_data_fidelity, cur_prior, cur_params, y, physics
        )["est"][0]
        # Another iteration for jacobian computation via automatic differentiation.
        # x0 is detached so that autograd.grad below differentiates a single
        # application of the iterator at the equilibrium point.
        x0 = x.clone().detach().requires_grad_()
        f0 = self.fixed_point.iterator(
            {"est": (x0,)}, cur_data_fidelity, cur_prior, cur_params, y, physics
        )["est"][0]

        # Add a backwards hook that takes the incoming backward gradient `X["est"][0]` and solves the fixed point equation
        def backward_hook(grad):
            # One step of the linear fixed-point equation v <- J^T v + u,
            # where J^T v is computed as a vector-Jacobian product via autograd.
            class backward_iterator(OptimIterator):
                def __init__(self, **kwargs):
                    super().__init__(**kwargs)

                def forward(self, X, *args, **kwargs):
                    return {
                        "est": (
                            torch.autograd.grad(f0, x0, X["est"][0], retain_graph=True)[
                                0
                            ]
                            + grad,
                        )
                    }

            # Use the :class:`deepinv.optim.fixed_point.FixedPoint` class to solve the fixed point equation
            def init_iterate_fn(y, physics, F_fn=None):
                return {"est": (grad,)}  # initialize the fixed point algorithm.

            backward_FP = FixedPoint(
                backward_iterator(),
                init_iterate_fn=init_iterate_fn,
                max_iter=self.max_iter_backward,
                check_conv_fn=self.check_conv_fn,
                anderson_acceleration=self.anderson_acceleration,
                history_size=self.history_size,
                beta_anderson_acc=self.beta_anderson_acc,
                eps_anderson_acc=self.eps_anderson_acc,
            )
            g = backward_FP({"est": (grad,)}, None)[0]["est"][0]
            return g

        # The hook replaces the gradient flowing into x with the solution of the
        # fixed-point equation, implementing implicit differentiation.
        if x.requires_grad:
            x.register_hook(backward_hook)

        if compute_metrics:
            return x, metrics
        else:
            return x
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def DEQ_builder(
    iteration,
    params_algo=None,
    data_fidelity=None,
    prior=None,
    F_fn=None,
    g_first=False,
    **kwargs,
):
    r"""
    Helper function for building an instance of the :meth:`BaseDEQ` class.

    :param str, deepinv.optim.optim_iterators.OptimIterator iteration: either the name of the algorithm to be used,
        or directly an optim iterator.
        If an algorithm name (string), should be either ``"PGD"`` (proximal gradient descent), ``"ADMM"`` (ADMM),
        ``"HQS"`` (half-quadratic splitting), ``"CP"`` (Chambolle-Pock) or ``"DRS"`` (Douglas Rachford).
    :param dict params_algo: dictionary containing all the relevant parameters for running the algorithm,
        e.g. the stepsize, regularisation parameter, denoising standard deviation.
        Each value of the dictionary can be either Iterable (distinct value for each iteration) or
        a single float (same value for each iteration).
        Default: ``{"stepsize": 1.0, "lambda": 1.0}``. See :any:`optim-params` for more details.
    :param list, deepinv.optim.DataFidelity: data-fidelity term.
        Either a single instance (same data-fidelity for each iteration) or a list of instances of
        :meth:`deepinv.optim.DataFidelity` (distinct data-fidelity for each iteration). Default: :meth:`deepinv.optim.L2`.
    :param list, deepinv.optim.Prior prior: regularization prior.
        Either a single instance (same prior for each iteration) or a list of instances of
        deepinv.optim.Prior (distinct prior for each iteration). Default: `None`.
    :param callable F_fn: Custom user input cost function. default: None.
    :param bool g_first: whether to perform the step on :math:`g` before that on :math:`f` before or not. default: False
    :param kwargs: additional arguments to be passed to the :meth:`BaseUnfold` class.
    """
    # None sentinels replace the original mutable/stateful defaults
    # (`{...}` dict and a module-level `L2()` instance shared across calls),
    # so each invocation gets fresh objects.
    if params_algo is None:
        params_algo = {"lambda": 1.0, "stepsize": 1.0}
    if data_fidelity is None:
        data_fidelity = L2()
    iterator = create_iterator(iteration, prior=prior, F_fn=F_fn, g_first=g_first)
    return BaseDEQ(
        iterator,
        has_cost=iterator.has_cost,
        data_fidelity=data_fidelity,
        prior=prior,
        params_algo=params_algo,
        **kwargs,
    )
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from deepinv.optim.optim_iterators import *
|
|
4
|
+
from deepinv.optim.optimizers import BaseOptim, create_iterator
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BaseUnfold(BaseOptim):
    r"""
    Base class for unfolded algorithms. Child of :class:`deepinv.optim.BaseOptim`.

    Enables to turn any iterative optimization algorithm into an unfolded algorithm, i.e. an algorithm
    that can be trained end-to-end, with learnable parameters. Recall that the algorithms have the
    following form (see :meth:`deepinv.optim.optim_iterators.BaseIterator`):

    .. math::
        \begin{aligned}
        z_{k+1} &= \operatorname{step}_f(x_k, z_k, y, A, \lambda, \gamma, ...)\\
        x_{k+1} &= \operatorname{step}_g(x_k, z_k, y, A, \sigma, ...)
        \end{aligned}

    where :math:`\operatorname{step}_f` and :math:`\operatorname{step}_g` are learnable modules.
    These modules encompass trainable parameters of the algorithm (e.g. stepsize :math:`\gamma`, regularization parameter :math:`\lambda`, prior parameter (`g_param`) :math:`\sigma` ...)
    as well as trainable priors (e.g. a deep denoiser).

    :param list trainable_params: List of parameters to be trained. Each parameter should be a key of the ``params_algo`` dictionary for the :class:`deepinv.optim.optim_iterators.BaseIterator` class.
        This does not encompass the trainable weights of the prior module .
    :param torch.device device: Device on which to perform the computations. Default: `torch.device("cpu")`.
    :param args: Non-keyword arguments to be passed to the :class:`deepinv.optim.BaseOptim` class.
    :param kwargs: Keyword arguments to be passed to the :class:`deepinv.optim.BaseOptim` class.
    """

    # NOTE(review): `trainable_params=[]` is a mutable default argument; it is
    # only read (never mutated) here, so the shared instance is harmless.
    def __init__(self, *args, trainable_params=[], device="cpu", **kwargs):
        super().__init__(*args, **kwargs)
        # Each parameter in `init_params_algo` is a list, which is converted to a `nn.ParameterList` if they should be trained.
        for param_key in trainable_params:
            if param_key in self.init_params_algo.keys():
                param_value = self.init_params_algo[param_key]
                self.init_params_algo[param_key] = nn.ParameterList(
                    [nn.Parameter(torch.tensor(el).to(device)) for el in param_value]
                )
        # Wrapping in nn.ParameterDict registers the trainable entries with the
        # module; params_algo is a copy so that per-run state does not clobber
        # the registered initial values.
        self.init_params_algo = nn.ParameterDict(self.init_params_algo)
        self.params_algo = self.init_params_algo.copy()
        # The prior (list of instances of :class:`deepinv.optim.Prior) is converted to a `nn.ModuleList` to be trainable.
        self.prior = nn.ModuleList(self.prior)
        self.data_fidelity = nn.ModuleList(self.data_fidelity)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def unfolded_builder(
    iteration,
    params_algo=None,
    data_fidelity=None,
    prior=None,
    F_fn=None,
    g_first=False,
    **kwargs,
):
    r"""
    Helper function for building an instance of the :meth:`BaseUnfold` class.

    :param str, deepinv.optim.optim_iterators.OptimIterator iteration: either the name of the algorithm to be used,
        or directly an optim iterator.
        If an algorithm name (string), should be either ``"PGD"`` (proximal gradient descent), ``"ADMM"`` (ADMM),
        ``"HQS"`` (half-quadratic splitting), ``"CP"`` (Chambolle-Pock) or ``"DRS"`` (Douglas Rachford).
    :param dict params_algo: dictionary containing all the relevant parameters for running the algorithm,
        e.g. the stepsize, regularisation parameter, denoising standard deviation.
        Each value of the dictionary can be either Iterable (distinct value for each iteration) or
        a single float (same value for each iteration).
        Default: ``{"stepsize": 1.0, "lambda": 1.0}``. See :any:`optim-params` for more details.
    :param list, deepinv.optim.DataFidelity: data-fidelity term.
        Either a single instance (same data-fidelity for each iteration) or a list of instances of
        :meth:`deepinv.optim.DataFidelity` (distinct data-fidelity for each iteration). Default: `None`.
    :param list, deepinv.optim.Prior prior: regularization prior.
        Either a single instance (same prior for each iteration) or a list of instances of
        deepinv.optim.Prior (distinct prior for each iteration). Default: `None`.
    :param callable F_fn: Custom user input cost function. default: None.
    :param bool g_first: whether to perform the step on :math:`g` before that on :math:`f` before or not. default: False
    :param kwargs: additional arguments to be passed to the :meth:`BaseUnfold` class.
    """
    # None sentinel replaces the original mutable `{...}` default, which was a
    # single dict shared by every call (and mutable downstream).
    if params_algo is None:
        params_algo = {"lambda": 1.0, "stepsize": 1.0}
    iterator = create_iterator(iteration, prior=prior, F_fn=F_fn, g_first=g_first)
    return BaseUnfold(
        iterator,
        has_cost=iterator.has_cost,
        data_fidelity=data_fidelity,
        prior=prior,
        params_algo=params_algo,
        **kwargs,
    )
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .logger import AverageMeter, ProgressMeter, get_timestamp
|
|
2
|
+
from .nn import save_model, load_checkpoint, investigate_model
|
|
3
|
+
from .metric import cal_psnr, cal_mse, cal_psnr_complex
|
|
4
|
+
from .plotting import (
|
|
5
|
+
rescale_img,
|
|
6
|
+
plot,
|
|
7
|
+
torch2cpu,
|
|
8
|
+
plot_curves,
|
|
9
|
+
plot_parameters,
|
|
10
|
+
make_grid,
|
|
11
|
+
wandb_imgs,
|
|
12
|
+
wandb_plot_curves,
|
|
13
|
+
resize_pad_square_tensor,
|
|
14
|
+
)
|
|
15
|
+
from .demo import load_url_image
|
|
16
|
+
from .nn import get_freer_gpu, TensorList, rand_like, zeros_like, randn_like, ones_like
|
|
17
|
+
from .phantoms import RandomPhantomDataset, SheppLoganDataset
|
deepinv/utils/demo.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
import shutil
|
|
3
|
+
import os
|
|
4
|
+
import zipfile
|
|
5
|
+
import torch
|
|
6
|
+
import torchvision
|
|
7
|
+
import numpy as np
|
|
8
|
+
from torchvision import transforms
|
|
9
|
+
from PIL import Image
|
|
10
|
+
from io import BytesIO
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MRIData(torch.utils.data.Dataset):
    """fastMRI dataset (knee subset).

    Loads a pre-saved tensor from ``<root_dir>.pt`` and exposes it as a map-style
    dataset. Each sample has two channels: the loaded data and a zero channel
    stacked on top of it (presumably real/imag parts — confirm against callers).
    """

    def __init__(
        self, root_dir, train=True, sample_index=None, tag=900, transform=None
    ):
        self.transform = transform
        data = torch.load(str(root_dir) + ".pt").squeeze()

        # Split train/test along the first axis at index `tag`.
        self.x = data[:tag] if train else data[tag:, ...]

        # Stack a zero second channel: shape becomes (N, 2, H, W).
        self.x = torch.stack([self.x, torch.zeros_like(self.x)], dim=1)

        # Optionally keep a single sample (as a batch of one).
        if sample_index is not None:
            self.x = self.x[sample_index].unsqueeze(0)

    def __getitem__(self, index):
        sample = self.x[index]
        return sample if self.transform is None else self.transform(sample)

    def __len__(self):
        return len(self.x)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_git_root():
    """Return the absolute path of the top level of the enclosing git repository."""
    import git  # local import: GitPython is only required by this helper

    repo = git.Repo(".", search_parent_directories=True)
    return repo.git.rev_parse("--show-toplevel")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_image_dataset_url(dataset_name, file_type="zip"):
    """Build the Hugging Face download URL for a deepinv image dataset archive."""
    base = "https://huggingface.co/datasets/deepinv/images/resolve/main/"
    return f"{base}{dataset_name}.{file_type}?download=true"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def get_degradation_url(file_name):
    """Build the Hugging Face download URL for a deepinv degradation file."""
    base = "https://huggingface.co/datasets/deepinv/degradations/resolve/main/"
    return f"{base}{file_name}?download=true"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_image_url(file_name):
    """Build the Hugging Face download URL for a single deepinv demo image."""
    base = "https://huggingface.co/datasets/deepinv/images/resolve/main/"
    return f"{base}{file_name}?download=true"
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def load_dataset(
    dataset_name, data_dir, transform, download=True, url=None, train=True
):
    """Download (if needed) and load a demo dataset.

    The dataset is cached under ``data_dir / dataset_name``. Zip archives are
    extracted into ``data_dir``; the fastMRI subset is shipped as a single
    ``.pt`` tensor file and moved into the dataset directory instead.

    :param str dataset_name: name of the dataset on the deepinv Hugging Face hub.
    :param pathlib.Path data_dir: local directory where datasets are cached.
    :param transform: torchvision-style transform applied to each sample.
    :param bool download: if True, fetch the dataset when not already cached.
    :param str url: optional explicit download URL (defaults to the deepinv hub).
    :param bool train: train/test split selector (only used for fastMRI).
    :return: a :class:`torch.utils.data.Dataset` (``MRIData`` or ``ImageFolder``).
    """
    dataset_dir = data_dir / dataset_name
    # fastMRI is a raw .pt tensor file; every other dataset is a zip archive.
    if dataset_name == "fastmri_knee_singlecoil":
        file_type = "pt"
    else:
        file_type = "zip"
    if download and not dataset_dir.exists():
        dataset_dir.mkdir(parents=True, exist_ok=True)
        if url is None:
            url = get_image_dataset_url(dataset_name, file_type)
        # Stream the download in chunks with a progress bar.
        response = requests.get(url, stream=True)
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        block_size = 1024  # 1 Kibibyte
        print("Downloading " + str(dataset_dir) + f".{file_type}")
        progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
        with open(str(dataset_dir) + f".{file_type}", "wb") as file:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()

        if file_type == "zip":
            with zipfile.ZipFile(str(dataset_dir) + ".zip") as zip_ref:
                zip_ref.extractall(str(data_dir))
            # remove temp file
            os.remove(str(dataset_dir) + f".{file_type}")
            print(f"{dataset_name} dataset downloaded in {data_dir}")
        else:
            # Move the downloaded .pt file inside the dataset directory so that
            # MRIData finds it at (dataset_dir / dataset_name) + ".pt".
            shutil.move(
                str(dataset_dir) + f".{file_type}",
                str(dataset_dir / dataset_name) + f".{file_type}",
            )
    if dataset_name == "fastmri_knee_singlecoil":
        dataset = MRIData(
            train=train, root_dir=dataset_dir / dataset_name, transform=transform
        )
    else:
        dataset = torchvision.datasets.ImageFolder(
            root=dataset_dir, transform=transform
        )
    return dataset
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def load_degradation(name, data_dir, index=0, download=True):
    """Download (if needed) a degradation operator (e.g. a blur kernel) and return one entry.

    :param str name: file name of the degradation (a ``.npy`` file on the deepinv hub).
    :param pathlib.Path data_dir: local directory in which the file is cached.
    :param int index: index of the operator to select from the loaded array.
    :param bool download: if True, download the file when it is not already cached.
    :return: the selected degradation as a :class:`torch.Tensor`.
    """
    path = data_dir / name
    if download and not path.exists():
        data_dir.mkdir(parents=True, exist_ok=True)
        url = get_degradation_url(name)
        # Stream the response straight to disk without buffering it in memory.
        with requests.get(url, stream=True) as r:
            with open(str(data_dir / name), "wb") as f:
                shutil.copyfileobj(r.raw, f)
        print(f"{name} degradation downloaded in {data_dir}")
    # allow_pickle is required because some degradation files store object arrays;
    # only use with trusted files (pickle can execute arbitrary code).
    deg = np.load(path, allow_pickle=True)
    deg_torch = torch.from_numpy(deg[index])
    return deg_torch
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def load_url_image(
    url=None, img_size=None, grayscale=False, resize_mode="crop", device="cpu"
):
    r"""
    Load an image from a URL and return a torch.Tensor.

    :param str url: URL of the image file.
    :param int, tuple[int] img_size: Size of the image to return.
    :param bool grayscale: Whether to convert the image to grayscale.
    :param str resize_mode: If ``img_size`` is not None, options are ``"crop"`` or ``"resize"``.
    :param str device: Device on which to load the image (gpu or cpu).
    :return: :class:`torch.Tensor` containing the image, with a leading batch dimension.
    """
    response = requests.get(url)
    img = Image.open(BytesIO(response.content))

    # Assemble the preprocessing pipeline, then apply it in one pass.
    ops = []
    if img_size is not None:
        if resize_mode == "crop":
            ops.append(transforms.CenterCrop(img_size))
        elif resize_mode == "resize":
            ops.append(transforms.Resize(img_size))
        else:
            raise ValueError(
                f"resize_mode must be either 'crop' or 'resize', got {resize_mode}"
            )
    if grayscale:
        ops.append(transforms.Grayscale())
    ops.append(transforms.ToTensor())

    pipeline = transforms.Compose(ops)
    return pipeline(img).unsqueeze(0).to(device)
|
deepinv/utils/logger.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import csv
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# utils
|
|
7
|
+
class AverageMeter(object):
    """Tracks the latest value and the running average of a stream of updates."""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt  # format spec used when rendering the average
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name}={avg" + self.fmt + "}"
        return template.format(**self.__dict__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ProgressMeter(object):
    """Prints a tab-separated progress line per epoch: surfix, timestamp,
    "[epoch/total]", each meter, then prefix."""

    def __init__(self, num_epochs, meters, surfix="", prefix=""):
        self.epoch_fmtstr = self._get_epoch_fmtstr(num_epochs)
        self.meters = meters
        self.surfix = surfix
        self.prefix = prefix

    def display(self, epoch):
        """Print one progress line for the given epoch."""
        entries = [self.surfix]
        entries += [get_timestamp()]
        entries += [self.epoch_fmtstr.format(epoch)]
        entries += [str(meter) for meter in self.meters]
        entries += [self.prefix]
        print("\t".join(entries))

    def _get_epoch_fmtstr(self, num_epochs):
        # Pad the current epoch to the width of the total count, e.g. "[  7/100]".
        # Fix: the original computed len(str(num_epochs // 1)); `// 1` is a no-op
        # on the integer epoch counts this class is used with.
        num_digits = len(str(num_epochs))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_epochs) + "]"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# --------------------------------
|
|
54
|
+
# logger
|
|
55
|
+
# --------------------------------
|
|
56
|
+
def get_timestamp():
    """Return the current local time formatted as ``YY-MM-DD-HH:MM:SS``."""
    now = datetime.now()
    return now.strftime("%y-%m-%d-%H:%M:%S")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class LOG(object):
    """Thin CSV logger: opens ``<filepath>/<filename>.csv``, writes the header,
    and appends one row per :meth:`record` call."""

    def __init__(self, filepath, filename, field_name):
        self.filepath = filepath
        self.filename = filename
        self.field_name = field_name  # column names, in row order

        self.logfile, self.logwriter = csv_log(
            file_name=os.path.join(filepath, filename + ".csv"), field_name=field_name
        )
        self.logwriter.writeheader()

    def record(self, *args):
        """Write one row; positional args map to ``field_name`` in order.

        Raises ``IndexError`` when fewer args than fields are given (same
        behavior as the original index loop).
        """
        # Fix: the original built the row in a local named `dict`, shadowing
        # the builtin; use a comprehension with a non-shadowing name instead.
        row = {name: args[i] for i, name in enumerate(self.field_name)}
        self.logwriter.writerow(row)

    def close(self):
        """Close the underlying CSV file."""
        self.logfile.close()

    def print(self, msg):
        """Echo a message to stdout with a timestamp prefix."""
        logT(msg)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def csv_log(file_name, field_name):
    """Open ``file_name`` for writing and return ``(file, csv.DictWriter)``.

    The caller is responsible for writing the header and closing the file.

    :param str file_name: path of the CSV file to create (truncated if present).
    :param list field_name: column names for the DictWriter.
    :raises ValueError: if ``file_name`` or ``field_name`` is None.
    """
    # Fix: input validation used `assert`, which is stripped under `python -O`;
    # raise an explicit exception instead.
    if file_name is None:
        raise ValueError("file_name must not be None")
    if field_name is None:
        raise ValueError("field_name must not be None")
    # newline="" is the documented way to open files for the csv module
    # (prevents spurious blank lines on Windows).
    logfile = open(file_name, "w", newline="")
    logwriter = csv.DictWriter(logfile, fieldnames=field_name)
    return logfile, logwriter
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def logT(*args, **kwargs):
    """Print a message prefixed with the current timestamp; forwards to ``print``."""
    stamp = get_timestamp()
    print(stamp, *args, **kwargs)
|
deepinv/utils/metric.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def norm(a):
    """L2 norm over the two trailing dims of a (N, C, H, W) tensor; result is (N, C, 1, 1)."""
    squared = a.pow(2).sum(dim=3).sum(dim=2)
    return squared.sqrt().unsqueeze(2).unsqueeze(3)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def cal_angle(a, b):
    """Angle between tensors a and b, in units of pi, returned as a numpy scalar."""
    a_flat = a.flatten()
    b_flat = b.flatten()
    cosine = (a_flat * b_flat).sum() / (
        a_flat.pow(2).sum().sqrt() * b_flat.pow(2).sum().sqrt()
    )
    angle = cosine.acos() / 3.14159265359
    return angle.detach().cpu().numpy()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def cal_psnr(a, b, max_pixel=1, normalize=False):
    r"""
    Computes the peak signal-to-noise ratio (PSNR)

    If the tensors have size (N, C, H, W), then the PSNR is computed as

    .. math::
        \text{PSNR} = \frac{20}{N} \log_{10} \frac{MAX_I}{\sqrt{\|a- b\|^2_2 / (CHW) }}

    where :math:`MAX_I` is the maximum possible pixel value of the image (e.g. 1.0 for a
    normalized image), and :math:`a` and :math:`b` are the estimate and reference images.

    :param torch.Tensor a: tensor estimate
    :param torch.Tensor b: tensor reference
    :param float max_pixel: maximum pixel value
    :param bool normalize: if ``True``, a is normalized to have the same norm as b.
    """
    with torch.no_grad():
        # Unwrap (estimate, reference) pairs passed as lists/tuples.
        # Idiom fix: use isinstance rather than `type(a) is list or ...`.
        if isinstance(a, (list, tuple)):
            a = a[0]
            b = b[0]

        if normalize:
            an = a / norm(a) * norm(b)
        else:
            an = a

        # Per-sample MSE over all non-batch dimensions.
        mse = (an - b).pow(2).reshape(an.shape[0], -1).mean(dim=1)
        mse[mse == 0] = 1e-10  # avoid log10 of infinity when images are identical
        psnr = 20 * torch.log10(max_pixel / mse.sqrt())

        return psnr.mean().detach().cpu().item()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def cal_mse(a, b):
    """Computes the mean squared error (MSE) between tensors a and b."""
    with torch.no_grad():
        diff = a - b
        result = diff.pow(2).mean()
    return result
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def cal_psnr_complex(a, b):
    """
    PSNR between two complex-valued images stored as (real, imag) channels.

    Channels are moved to the last dimension, the complex modulus is taken,
    and the standard PSNR is computed on the magnitudes.

    :param a: shape [N,2,H,W]
    :param b: shape [N,2,H,W]
    :return: psnr value
    """
    mag_a = complex_abs(a.permute(0, 2, 3, 1))
    mag_b = complex_abs(b.permute(0, 2, 3, 1))
    return cal_psnr(mag_a, mag_b)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def complex_abs(data):
    """
    Compute the absolute value of a complex valued input tensor.

    :param torch.Tensor data: complex tensor whose final dimension has size 2 (real, imag).
    :return: tensor of moduli, with the final dimension reduced away.
    """
    assert data.size(-1) == 2
    return data.pow(2).sum(dim=-1).sqrt()
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def norm_psnr(a, b, complex=False):
    """
    PSNR between min-max normalized versions of a and b.

    Bug fix: the original forwarded ``complex=complex`` to :func:`cal_psnr`,
    which has no such parameter, so every call raised ``TypeError``. Complex
    inputs are now routed to :func:`cal_psnr_complex` instead.

    :param torch.Tensor a: estimate; rescaled to [0, 1] before comparison.
    :param torch.Tensor b: reference; rescaled to [0, 1] before comparison.
    :param bool complex: if True, treat inputs as complex [N,2,H,W] tensors.
    """
    an = (a - a.min()) / (a.max() - a.min())
    bn = (b - b.min()) / (b.max() - b.min())
    return cal_psnr_complex(an, bn) if complex else cal_psnr(an, bn)
|