deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,163 @@
1
+ import torch
2
+ from deepinv.optim.fixed_point import FixedPoint
3
+ from deepinv.optim.optim_iterators import *
4
+ from deepinv.unfolded.unfolded import BaseUnfold
5
+ from deepinv.optim.optimizers import create_iterator
6
+ from deepinv.optim.data_fidelity import L2
7
+
8
+
9
class BaseDEQ(BaseUnfold):
    r"""
    Base class for deep equilibrium (DEQ) algorithms. Child of :class:`deepinv.unfolded.BaseUnfold`.

    Turns a fixed-point algorithm into a DEQ algorithm: the forward fixed-point iterations run
    without gradient tracking, and the gradient is obtained in the backward pass by solving
    a second fixed-point equation

    .. math::

        \begin{equation}
        v = \left(\frac{\partial \operatorname{FixedPoint}(x^\star)}{\partial x^\star} \right )^T v + u.
        \end{equation}

    where :math:`u` is the incoming gradient from the backward pass,
    and :math:`x^\star` is the equilibrium point of the forward pass.

    See `this tutorial <http://implicit-layers-tutorial.org/deep_equilibrium_models/>`_ for more details.

    For now DEQ is only possible with PGD, HQS and GD optimization algorithms.

    :param int max_iter_backward: Maximum number of backward iterations. Default: ``50``.
    :param bool anderson_acceleration_backward: if True, Anderson acceleration is used in the fixed-point
        algorithm that computes the backward pass. Default: ``False``.
    :param int history_size_backward: size of the history used for the Anderson acceleration for the backward pass. Default: ``5``.
    :param float beta_anderson_acc_backward: momentum of the Anderson acceleration step for the backward pass. Default: ``1.0``.
    :param float eps_anderson_acc_backward: regularization parameter of the Anderson acceleration step for the backward pass. Default: ``1e-4``.
    """

    def __init__(
        self,
        *args,
        max_iter_backward=50,
        anderson_acceleration_backward=False,
        history_size_backward=5,
        beta_anderson_acc_backward=1.0,
        eps_anderson_acc_backward=1e-4,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Settings of the backward fixed-point solver (stored without the ``_backward`` suffix).
        self.max_iter_backward = max_iter_backward
        self.anderson_acceleration = anderson_acceleration_backward
        self.history_size = history_size_backward
        self.beta_anderson_acc = beta_anderson_acc_backward
        self.eps_anderson_acc = eps_anderson_acc_backward

    def forward(self, y, physics, x_gt=None, compute_metrics=False):
        r"""
        Forward pass of the DEQ algorithm. Compared to :class:`deepinv.unfolded.BaseUnfold`,
        the backward pass is computed with fixed-point iterations registered as a gradient hook.

        :param torch.Tensor y: Input tensor.
        :param deepinv.Physics physics: Physics object.
        :param torch.Tensor x_gt: (optional) ground truth image, for plotting the PSNR across optim iterations.
        :param bool compute_metrics: whether to compute the metrics or not. Default: ``False``.
        :return: If ``compute_metrics`` is ``False``, returns (:class:`torch.Tensor`) the output of the algorithm.
            Else, returns (:class:`torch.Tensor`, dict) the output of the algorithm and the metrics.
        """
        # Run the fixed-point iterations without building an autograd graph.
        with torch.no_grad():
            X, metrics = self.fixed_point(
                y, physics, x_gt=x_gt, compute_metrics=compute_metrics
            )

        # One extra iteration from the returned iterate, this time with gradient tracking,
        # using the data fidelity / prior / parameters of the last iteration index.
        last_it = self.max_iter - 1
        data_fidelity_last = self.update_data_fidelity_fn(last_it)
        prior_last = self.update_prior_fn(last_it)
        params_last = self.update_params_fn(last_it)
        step = self.fixed_point.iterator
        x = step(X, data_fidelity_last, prior_last, params_last, y, physics)["est"][0]

        # A second iteration on a detached copy: ``f0`` as a function of ``x0`` provides the
        # Jacobian-vector products used by the backward solver.
        x0 = x.clone().detach().requires_grad_()
        f0 = step(
            {"est": (x0,)}, data_fidelity_last, prior_last, params_last, y, physics
        )["est"][0]

        def backward_hook(grad):
            # Solves v = J^T v + grad for v, where J^T v is evaluated by autograd on (f0, x0).
            class _BackwardIterator(OptimIterator):
                def forward(self, X, *args, **kwargs):
                    jac_t_v = torch.autograd.grad(
                        f0, x0, X["est"][0], retain_graph=True
                    )[0]
                    return {"est": (jac_t_v + grad,)}

            def init_iterate_fn(y, physics, F_fn=None):
                # The backward fixed-point iterations start from the incoming gradient.
                return {"est": (grad,)}

            backward_FP = FixedPoint(
                _BackwardIterator(),
                init_iterate_fn=init_iterate_fn,
                max_iter=self.max_iter_backward,
                check_conv_fn=self.check_conv_fn,
                anderson_acceleration=self.anderson_acceleration,
                history_size=self.history_size,
                beta_anderson_acc=self.beta_anderson_acc,
                eps_anderson_acc=self.eps_anderson_acc,
            )
            return backward_FP({"est": (grad,)}, None)[0]["est"][0]

        # The hook can only be attached to a tensor that takes part in the autograd graph.
        if x.requires_grad:
            x.register_hook(backward_hook)

        if not compute_metrics:
            return x
        return x, metrics
122
+
123
+
124
def DEQ_builder(
    iteration,
    params_algo=None,
    data_fidelity=L2(),
    prior=None,
    F_fn=None,
    g_first=False,
    **kwargs,
):
    r"""
    Helper function for building an instance of the :meth:`BaseDEQ` class.

    :param str, deepinv.optim.optim_iterators.OptimIterator iteration: either the name of the algorithm to be used,
        or directly an optim iterator.
        If an algorithm name (string), should be either ``"PGD"`` (proximal gradient descent), ``"ADMM"`` (ADMM),
        ``"HQS"`` (half-quadratic splitting), ``"CP"`` (Chambolle-Pock) or ``"DRS"`` (Douglas Rachford).
    :param dict params_algo: dictionary containing all the relevant parameters for running the algorithm,
        e.g. the stepsize, regularisation parameter, denoising standard deviation.
        Each value of the dictionary can be either Iterable (distinct value for each iteration) or
        a single float (same value for each iteration).
        Default: ``None``, which stands for ``{"stepsize": 1.0, "lambda": 1.0}``.
        See :any:`optim-params` for more details. The dictionary is copied, so the caller's object is not modified.
    :param list, deepinv.optim.DataFidelity data_fidelity: data-fidelity term.
        Either a single instance (same data-fidelity for each iteration) or a list of instances of
        :meth:`deepinv.optim.DataFidelity` (distinct data-fidelity for each iteration).
        Default: an ``L2()`` instance created once, when this function is defined.
    :param list, deepinv.optim.Prior prior: regularization prior.
        Either a single instance (same prior for each iteration) or a list of instances of
        deepinv.optim.Prior (distinct prior for each iteration). Default: `None`.
    :param callable F_fn: Custom user input cost function. default: None.
    :param bool g_first: whether to perform the step on :math:`g` before the step on :math:`f`. default: False
    :param kwargs: additional arguments to be passed to the :meth:`BaseDEQ` class.
    :return: the configured :class:`BaseDEQ` instance.
    """
    # A dict literal used as a default argument is created once and shared by every call.
    # BaseUnfold.__init__ rewrites entries of its parameter dictionary in place (presumably the
    # very object passed here - confirm against BaseOptim), so each model gets its own copy.
    if params_algo is None:
        params_algo = {"lambda": 1.0, "stepsize": 1.0}
    else:
        params_algo = dict(params_algo)

    iterator = create_iterator(iteration, prior=prior, F_fn=F_fn, g_first=g_first)
    return BaseDEQ(
        iterator,
        has_cost=iterator.has_cost,
        data_fidelity=data_fidelity,
        prior=prior,
        params_algo=params_algo,
        **kwargs,
    )
@@ -0,0 +1,87 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from deepinv.optim.optim_iterators import *
4
+ from deepinv.optim.optimizers import BaseOptim, create_iterator
5
+
6
+
7
class BaseUnfold(BaseOptim):
    r"""
    Base class for unfolded algorithms. Child of :class:`deepinv.optim.BaseOptim`.

    Turns an iterative optimization algorithm into an unfolded algorithm, i.e. an algorithm that can be
    trained end-to-end, with learnable parameters. The algorithms have the following form
    (see :meth:`deepinv.optim.optim_iterators.BaseIterator`):

    .. math::
        \begin{aligned}
        z_{k+1} &= \operatorname{step}_f(x_k, z_k, y, A, \lambda, \gamma, ...)\\
        x_{k+1} &= \operatorname{step}_g(x_k, z_k, y, A, \sigma, ...)
        \end{aligned}

    where :math:`\operatorname{step}_f` and :math:`\operatorname{step}_g` are learnable modules.
    These modules encompass trainable parameters of the algorithm (e.g. stepsize :math:`\gamma`,
    regularization parameter :math:`\lambda`, prior parameter (`g_param`) :math:`\sigma` ...)
    as well as trainable priors (e.g. a deep denoiser).

    :param list trainable_params: List of parameters to be trained. Each parameter should be a key of the
        ``params_algo`` dictionary for the :class:`deepinv.optim.optim_iterators.BaseIterator` class.
        This does not encompass the trainable weights of the prior module. Keys that are not found in the
        parameter dictionary are ignored.
    :param torch.device device: Device on which the trainable parameters are created. Default: ``"cpu"``.
    :param args: Non-keyword arguments to be passed to the :class:`deepinv.optim.BaseOptim` class.
    :param kwargs: Keyword arguments to be passed to the :class:`deepinv.optim.BaseOptim` class.
    """

    def __init__(self, *args, trainable_params=[], device="cpu", **kwargs):
        super().__init__(*args, **kwargs)

        # Every requested entry of `init_params_algo` (an iterable of values) is replaced, in place,
        # by a `nn.ParameterList` holding one learnable tensor per value.
        for name in trainable_params:
            if name not in self.init_params_algo.keys():
                continue
            learnable = [
                nn.Parameter(torch.tensor(value).to(device))
                for value in self.init_params_algo[name]
            ]
            self.init_params_algo[name] = nn.ParameterList(learnable)

        # Register the whole parameter dictionary with the module; `params_algo` starts as a copy of it.
        self.init_params_algo = nn.ParameterDict(self.init_params_algo)
        self.params_algo = self.init_params_algo.copy()

        # Wrap the prior and data-fidelity sequences in `nn.ModuleList` so that they are registered
        # as sub-modules of this module.
        self.prior = nn.ModuleList(self.prior)
        self.data_fidelity = nn.ModuleList(self.data_fidelity)
46
+
47
+
48
def unfolded_builder(
    iteration,
    params_algo=None,
    data_fidelity=None,
    prior=None,
    F_fn=None,
    g_first=False,
    **kwargs,
):
    r"""
    Helper function for building an instance of the :meth:`BaseUnfold` class.

    :param str, deepinv.optim.optim_iterators.OptimIterator iteration: either the name of the algorithm to be used,
        or directly an optim iterator.
        If an algorithm name (string), should be either ``"PGD"`` (proximal gradient descent), ``"ADMM"`` (ADMM),
        ``"HQS"`` (half-quadratic splitting), ``"CP"`` (Chambolle-Pock) or ``"DRS"`` (Douglas Rachford).
    :param dict params_algo: dictionary containing all the relevant parameters for running the algorithm,
        e.g. the stepsize, regularisation parameter, denoising standard deviation.
        Each value of the dictionary can be either Iterable (distinct value for each iteration) or
        a single float (same value for each iteration).
        Default: ``None``, which stands for ``{"stepsize": 1.0, "lambda": 1.0}``.
        See :any:`optim-params` for more details. The dictionary is copied, so the caller's object is not modified.
    :param list, deepinv.optim.DataFidelity data_fidelity: data-fidelity term.
        Either a single instance (same data-fidelity for each iteration) or a list of instances of
        :meth:`deepinv.optim.DataFidelity` (distinct data-fidelity for each iteration). Default: `None`.
    :param list, deepinv.optim.Prior prior: regularization prior.
        Either a single instance (same prior for each iteration) or a list of instances of
        deepinv.optim.Prior (distinct prior for each iteration). Default: `None`.
    :param callable F_fn: Custom user input cost function. default: None.
    :param bool g_first: whether to perform the step on :math:`g` before the step on :math:`f`. default: False
    :param kwargs: additional arguments to be passed to the :meth:`BaseUnfold` class.
    :return: the configured :class:`BaseUnfold` instance.
    """
    # A dict literal used as a default argument is created once and shared by every call.
    # BaseUnfold.__init__ rewrites entries of its parameter dictionary in place (presumably the
    # very object passed here - confirm against BaseOptim), so each model gets its own copy.
    if params_algo is None:
        params_algo = {"lambda": 1.0, "stepsize": 1.0}
    else:
        params_algo = dict(params_algo)

    iterator = create_iterator(iteration, prior=prior, F_fn=F_fn, g_first=g_first)
    return BaseUnfold(
        iterator,
        has_cost=iterator.has_cost,
        data_fidelity=data_fidelity,
        prior=prior,
        params_algo=params_algo,
        **kwargs,
    )
@@ -0,0 +1,17 @@
1
+ from .logger import AverageMeter, ProgressMeter, get_timestamp
2
+ from .nn import save_model, load_checkpoint, investigate_model
3
+ from .metric import cal_psnr, cal_mse, cal_psnr_complex
4
+ from .plotting import (
5
+ rescale_img,
6
+ plot,
7
+ torch2cpu,
8
+ plot_curves,
9
+ plot_parameters,
10
+ make_grid,
11
+ wandb_imgs,
12
+ wandb_plot_curves,
13
+ resize_pad_square_tensor,
14
+ )
15
+ from .demo import load_url_image
16
+ from .nn import get_freer_gpu, TensorList, rand_like, zeros_like, randn_like, ones_like
17
+ from .phantoms import RandomPhantomDataset, SheppLoganDataset
deepinv/utils/demo.py ADDED
@@ -0,0 +1,171 @@
1
+ import requests
2
+ import shutil
3
+ import os
4
+ import zipfile
5
+ import torch
6
+ import torchvision
7
+ import numpy as np
8
+ from torchvision import transforms
9
+ from PIL import Image
10
+ from io import BytesIO
11
+ from tqdm import tqdm
12
+
13
+
14
class MRIData(torch.utils.data.Dataset):
    """
    fastMRI dataset (knee subset), read from a single ``<root_dir>.pt`` tensor file.

    :param root_dir: path of the tensor file without its ``.pt`` extension.
    :param bool train: if ``True`` keep the first ``tag`` entries, otherwise the entries from ``tag`` onwards.
    :param sample_index: if not ``None``, keep only this sample (a leading dimension of size 1 is restored).
    :param int tag: index at which the data is split into the two subsets.
    :param callable transform: optional transform applied to every sample returned by ``__getitem__``.
    """

    def __init__(
        self, root_dir, train=True, sample_index=None, tag=900, transform=None
    ):
        volume = torch.load(str(root_dir) + ".pt").squeeze()
        self.transform = transform

        subset = volume[:tag] if train else volume[tag:, ...]

        # Stack an all-zero copy along a new dimension 1
        # (presumably the imaginary channel of a real-valued image - TODO confirm).
        self.x = torch.stack([subset, torch.zeros_like(subset)], dim=1)

        if sample_index is not None:
            self.x = self.x[sample_index].unsqueeze(0)

    def __getitem__(self, index):
        sample = self.x[index]
        if self.transform is None:
            return sample
        return self.transform(sample)

    def __len__(self):
        return len(self.x)
44
+
45
+
46
def get_git_root():
    """Return the top-level directory of the git repository containing the current directory."""
    # Imported lazily so that GitPython is only needed when this helper is called.
    import git

    repo = git.Repo(".", search_parent_directories=True)
    return repo.git.rev_parse("--show-toplevel")
52
+
53
+
54
def get_image_dataset_url(dataset_name, file_type="zip"):
    """
    Build the download URL of an image dataset archive hosted in the ``deepinv/images`` Hugging Face repository.

    :param str dataset_name: name of the dataset file, without extension.
    :param str file_type: file extension, without the leading dot. Default: ``"zip"``.
    :return: (str) the download URL.
    """
    pieces = (
        "https://huggingface.co/datasets/deepinv/images/resolve/main/",
        dataset_name,
        ".",
        file_type,
        "?download=true",
    )
    return "".join(pieces)
62
+
63
+
64
def get_degradation_url(file_name):
    """
    Build the download URL of a file hosted in the ``deepinv/degradations`` Hugging Face repository.

    :param str file_name: name of the file, including its extension.
    :return: (str) the download URL.
    """
    base = "https://huggingface.co/datasets/deepinv/degradations/resolve/main/"
    return "".join((base, file_name, "?download=true"))
70
+
71
+
72
def get_image_url(file_name):
    """
    Build the download URL of a single file hosted in the ``deepinv/images`` Hugging Face repository.

    :param str file_name: name of the file, including its extension.
    :return: (str) the download URL.
    """
    base = "https://huggingface.co/datasets/deepinv/images/resolve/main/"
    return "".join((base, file_name, "?download=true"))
78
+
79
+
80
def load_dataset(
    dataset_name, data_dir, transform, download=True, url=None, train=True
):
    """
    Load a dataset from ``data_dir``, downloading it first if it is not there yet.

    ``"fastmri_knee_singlecoil"`` is fetched as a single ``.pt`` file and wrapped in :class:`MRIData`;
    any other dataset is fetched as a ``.zip`` archive, extracted into ``data_dir`` and read with
    :class:`torchvision.datasets.ImageFolder`.

    :param str dataset_name: name of the dataset.
    :param pathlib.Path data_dir: directory holding the datasets (it is combined with ``/`` and must offer
        ``exists``/``mkdir`` on the result).
    :param callable transform: transform handed to the dataset object.
    :param bool download: whether to download the dataset when ``data_dir / dataset_name`` does not exist.
    :param str url: download URL. Default: ``None``, i.e. the URL given by :func:`get_image_dataset_url`.
    :param bool train: forwarded to :class:`MRIData`; unused for image-folder datasets.
    :return: the dataset object.
    :raises requests.HTTPError: if the server answers the download request with an error status.
    """
    dataset_dir = data_dir / dataset_name
    is_fastmri = dataset_name == "fastmri_knee_singlecoil"
    file_type = "pt" if is_fastmri else "zip"

    if download and not dataset_dir.exists():
        if url is None:
            url = get_image_dataset_url(dataset_name, file_type)
        # The file is first written next to the dataset directory, then extracted / moved.
        archive_path = str(dataset_dir) + f".{file_type}"

        with requests.get(url, stream=True) as response:
            # Fail before touching the disk: otherwise an HTTP error page would be saved as the
            # archive, and the freshly created directory would prevent any later download attempt.
            response.raise_for_status()
            dataset_dir.mkdir(parents=True, exist_ok=True)

            total_size_in_bytes = int(response.headers.get("content-length", 0))
            block_size = 1024  # 1 Kibibyte
            print("Downloading " + archive_path)
            # Context managers close the progress bar and the file even if the transfer fails.
            with tqdm(
                total=total_size_in_bytes, unit="iB", unit_scale=True
            ) as progress_bar, open(archive_path, "wb") as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)

        if file_type == "zip":
            with zipfile.ZipFile(archive_path) as zip_ref:
                zip_ref.extractall(str(data_dir))
            # remove temp file
            os.remove(archive_path)
            print(f"{dataset_name} dataset downloaded in {data_dir}")
        else:
            # Single-file datasets live inside their own directory.
            shutil.move(
                archive_path,
                str(dataset_dir / dataset_name) + f".{file_type}",
            )

    if is_fastmri:
        dataset = MRIData(
            train=train, root_dir=dataset_dir / dataset_name, transform=transform
        )
    else:
        dataset = torchvision.datasets.ImageFolder(
            root=dataset_dir, transform=transform
        )
    return dataset
123
+
124
+
125
def load_degradation(name, data_dir, index=0, download=True):
    """
    Load one entry of a degradation file as a :class:`torch.Tensor`, downloading the file if needed.

    :param str name: name of the degradation file.
    :param pathlib.Path data_dir: directory in which the file is stored.
    :param index: index of the entry to return from the loaded array. Default: ``0``.
    :param bool download: whether to download the file when ``data_dir / name`` does not exist.
    :return: (:class:`torch.Tensor`) the selected entry.
    :raises requests.HTTPError: if the server answers the download request with an error status.
    """
    path = data_dir / name
    if download and not path.exists():
        data_dir.mkdir(parents=True, exist_ok=True)
        url = get_degradation_url(name)
        with requests.get(url, stream=True) as r:
            # Check the status before creating the file: an error body saved at `path` would be
            # taken for a valid download on every later call.
            r.raise_for_status()
            # `r.raw` is the undecoded stream; ask urllib3 to undo any gzip/deflate content encoding.
            r.raw.decode_content = True
            with open(str(path), "wb") as f:
                shutil.copyfileobj(r.raw, f)
        print(f"{name} degradation downloaded in {data_dir}")
    # NOTE(review): allow_pickle=True unpickles objects stored in the file, which can execute
    # arbitrary code; only use this with files from a trusted source.
    deg = np.load(path, allow_pickle=True)
    deg_torch = torch.from_numpy(deg[index])
    return deg_torch
137
+
138
+
139
def load_url_image(
    url=None, img_size=None, grayscale=False, resize_mode="crop", device="cpu"
):
    r"""
    Load an image from a URL and return a torch.Tensor.

    :param str url: URL of the image file.
    :param int, tuple[int] img_size: Size of the image to return. If ``None``, the image is not resized.
    :param bool grayscale: Whether to convert the image to grayscale.
    :param str resize_mode: If ``img_size`` is not None, options are ``"crop"`` or ``"resize"``.
    :param str device: Device on which to load the image (gpu or cpu).
    :return: :class:`torch.Tensor` containing the image, with a leading batch dimension of size 1.
    :raises requests.HTTPError: if the server answers with an error status.
    :raises ValueError: if ``img_size`` is given and ``resize_mode`` is neither ``"crop"`` nor ``"resize"``.
    """
    response = requests.get(url)
    # Report a failed request as such, rather than as an image decoding error on the error page.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))

    transform_list = []
    if img_size is not None:
        if resize_mode == "crop":
            transform_list.append(transforms.CenterCrop(img_size))
        elif resize_mode == "resize":
            transform_list.append(transforms.Resize(img_size))
        else:
            raise ValueError(
                f"resize_mode must be either 'crop' or 'resize', got {resize_mode}"
            )
    if grayscale:
        transform_list.append(transforms.Grayscale())
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)

    x = transform(img).unsqueeze(0).to(device)
    return x
@@ -0,0 +1,93 @@
1
+ import os
2
+ import csv
3
+ from datetime import datetime
4
+
5
+
6
+ # utils
7
class AverageMeter(object):
    """
    Keeps the latest value and the running (weighted) average of a quantity.

    :param str name: label used when the meter is printed.
    :param str fmt: format spec (including the leading ``:``) applied to the average. Default: ``":f"``.
    """

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build e.g. "{name}={avg:f}" and fill it from the instance attributes.
        template = f"{{name}}={{avg{self.fmt}}}"
        return template.format(**vars(self))
30
+
31
+
32
class ProgressMeter(object):
    """
    Prints one tab-separated progress line per epoch.

    :param int num_epochs: total number of epochs, used to size the ``[epoch/total]`` field.
    :param list meters: objects whose ``str()`` is printed on each line.
    :param str surfix: text printed at the start of the line.
    :param str prefix: text printed at the end of the line.
    """

    def __init__(self, num_epochs, meters, surfix="", prefix=""):
        self.epoch_fmtstr = self._get_epoch_fmtstr(num_epochs)
        self.meters = meters
        self.surfix = surfix
        self.prefix = prefix

    def display(self, epoch):
        """Print the progress line for ``epoch``."""
        entries = [
            self.surfix,
            get_timestamp(),
            self.epoch_fmtstr.format(epoch),
            *[str(meter) for meter in self.meters],
            self.prefix,
        ]
        print("\t".join(entries))

    def _get_epoch_fmtstr(self, num_epochs):
        # The epoch field is as wide as the number of digits of the total, e.g. "[{:3d}/100]".
        width = len(str(num_epochs // 1))
        slot = f"{{:{width}d}}"
        return f"[{slot}/{slot.format(num_epochs)}]"
51
+
52
+
53
+ # --------------------------------
54
+ # logger
55
+ # --------------------------------
56
def get_timestamp():
    """Return the current local time formatted as ``yy-mm-dd-HH:MM:SS``."""
    return f"{datetime.now():%y-%m-%d-%H:%M:%S}"
58
+
59
+
60
class LOG(object):
    """
    CSV logger writing one row per call to :meth:`record`.

    The file ``<filepath>/<filename>.csv`` is opened for writing and the header row is written
    as soon as the object is created.

    :param str filepath: directory of the log file.
    :param str filename: name of the log file, without the ``.csv`` extension.
    :param list field_name: column names.
    """

    def __init__(self, filepath, filename, field_name):
        self.filepath = filepath
        self.filename = filename
        self.field_name = field_name

        csv_path = os.path.join(filepath, filename + ".csv")
        self.logfile, self.logwriter = csv_log(
            file_name=csv_path, field_name=field_name
        )
        self.logwriter.writeheader()

    def record(self, *args):
        """Write one row; the i-th positional argument goes to the i-th column."""
        # Indexing `args` (rather than zipping) keeps the IndexError when too few values are given;
        # surplus values are ignored.
        row = {column: args[i] for i, column in enumerate(self.field_name)}
        self.logwriter.writerow(row)

    def close(self):
        """Close the underlying file."""
        self.logfile.close()

    def print(self, msg):
        """Print ``msg`` preceded by a timestamp."""
        logT(msg)
82
+
83
+
84
def csv_log(file_name, field_name):
    """
    Open ``file_name`` for writing and attach a :class:`csv.DictWriter` to it.

    The caller owns the returned file object and is responsible for closing it.

    :param str file_name: path of the CSV file; an existing file is overwritten.
    :param list field_name: column names handed to the writer.
    :return: tuple ``(logfile, logwriter)`` with the open file and the writer.
    """
    assert file_name is not None
    assert field_name is not None
    # The csv module requires files opened with newline=""; without it, platforms that translate
    # "\n" on write (e.g. Windows) end up with "\r\r\n" line endings, i.e. blank rows.
    logfile = open(file_name, "w", newline="")
    logwriter = csv.DictWriter(logfile, fieldnames=field_name)
    return logfile, logwriter
90
+
91
+
92
def logT(*args, **kwargs):
    """Like :func:`print`, with the current timestamp as first printed item."""
    stamp = get_timestamp()
    print(stamp, *args, **kwargs)
@@ -0,0 +1,87 @@
1
+ import torch
2
+
3
+
4
def norm(a):
    """
    Euclidean norm over dimensions 2 and 3 of ``a``, with those two dimensions kept as size 1.

    :param torch.Tensor a: tensor with at least 4 dimensions.
    :return: (:class:`torch.Tensor`) the norms.
    """
    energy = torch.sum(torch.sum(torch.pow(a, 2), dim=3), dim=2)
    root = torch.sqrt(energy)
    return torch.unsqueeze(torch.unsqueeze(root, 2), 3)
6
+
7
+
8
def cal_angle(a, b):
    """
    Angle between ``a`` and ``b`` (both treated as flat vectors), divided by pi.

    :param torch.Tensor a: first tensor.
    :param torch.Tensor b: second tensor, multiplied element-wise with ``a``.
    :return: (:class:`numpy.ndarray`) zero-dimensional array holding the normalized angle.
    """
    length_a = torch.sqrt(torch.sum(torch.flatten(a * a)))
    length_b = torch.sqrt(torch.sum(torch.flatten(b * b)))
    cosine = torch.sum(torch.flatten(a * b)) / (length_a * length_b)
    angle = torch.acos(cosine) / 3.14159265359
    return angle.detach().cpu().numpy()
14
+
15
+
16
def cal_psnr(a, b, max_pixel=1, normalize=False):
    r"""
    Computes the peak signal-to-noise ratio (PSNR)

    If the tensors have size (N, C, H, W), then the PSNR is computed as

    .. math::
        \text{PSNR} = \frac{20}{N} \log_{10} \frac{MAX_I}{\sqrt{\|a- b\|^2_2 / (CHW) }}

    where :math:`MAX_I` is the maximum possible pixel value of the image (e.g. 1.0 for a
    normalized image), and :math:`a` and :math:`b` are the estimate and reference images.
    The PSNR is computed per batch element and then averaged; a zero error is replaced by
    ``1e-10`` so that the result stays finite.

    :param torch.Tensor a: tensor estimate. If a ``list`` or ``tuple`` is given, only the first
        element of ``a`` and of ``b`` is used.
    :param torch.Tensor b: tensor reference
    :param float max_pixel: maximum pixel value
    :param bool normalize: if ``True``, a is normalized to have the same norm as b.
    :return: (float) the PSNR averaged over the batch.
    """
    with torch.no_grad():
        # Exact type test on purpose: list/tuple subclasses are not unwrapped.
        if type(a) in (list, tuple):
            a, b = a[0], b[0]

        estimate = a / norm(a) * norm(b) if normalize else a

        squared_error = torch.pow(estimate - b, 2).reshape(estimate.shape[0], -1)
        mse = torch.mean(squared_error, dim=1)
        mse[mse == 0] = 1e-10  # avoid a division by zero for identical images
        psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))

    return psnr.mean().detach().cpu().item()
48
+
49
+
50
def cal_mse(a, b):
    """
    Computes the mean squared error (MSE) over all elements, without gradient tracking.

    :param torch.Tensor a: first tensor.
    :param torch.Tensor b: second tensor.
    :return: (:class:`torch.Tensor`) zero-dimensional tensor holding the MSE.
    """
    with torch.no_grad():
        return torch.mean((a - b) ** 2)
55
+
56
+
57
def cal_psnr_complex(a, b):
    """
    PSNR between the magnitudes of two complex images stored as two channels.

    The channel dimension is first moved last, so that the last dimension of each tensor
    is 2 (real, imag), before taking the magnitude.

    :param a: shape [N,2,H,W]
    :param b: shape [N,2,H,W]
    :return: psnr value
    """
    magnitudes = [complex_abs(t.permute(0, 2, 3, 1)) for t in (a, b)]
    return cal_psnr(*magnitudes)
67
+
68
+
69
def complex_abs(data):
    """
    Compute the absolute value of a complex valued input tensor.

    :param torch.Tensor data: a complex valued tensor, where the size of the final dimension
        should be 2.
    :return: (:class:`torch.Tensor`) absolute value of data, without the final dimension.
    """
    assert data.size(-1) == 2
    squared = data**2
    return torch.sqrt(torch.sum(squared, dim=-1))
80
+
81
+
82
def norm_psnr(a, b, complex=False):
    """
    PSNR between ``a`` and ``b`` after rescaling each of them to the range [0, 1].

    Each tensor is min-max normalized with its own global minimum and maximum.

    :param torch.Tensor a: tensor estimate.
    :param torch.Tensor b: tensor reference.
    :param bool complex: if ``True``, the rescaled tensors are compared with
        :func:`cal_psnr_complex` (two-channel real/imag layout), otherwise with :func:`cal_psnr`.
    :return: psnr value
    """
    a_scaled = (a - a.min()) / (a.max() - a.min())
    b_scaled = (b - b.min()) / (b.max() - b.min())
    # `cal_psnr` takes no `complex` argument (forwarding it raised a TypeError on every call);
    # the complex case is handled by the dedicated function instead.
    if complex:
        return cal_psnr_complex(a_scaled, b_scaled)
    return cal_psnr(a_scaled, b_scaled)