deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/loss/metric.py ADDED
@@ -0,0 +1,125 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch import autograd as autograd
5
+
6
+
7
class LpNorm(torch.nn.Module):
    r"""
    :math:`\ell_p` metric for :math:`p>0`.

    If ``onesided=False`` the metric is the standard
    :math:`d(x,y)=\|x-y\|_p^p`.

    Otherwise it is the one-sided error of
    https://ieeexplore.ieee.org/abstract/document/6418031/, defined as
    :math:`d(x,y)= \|\max(0, -x\circ y) \|_p^p` where :math:`\circ` denotes
    element-wise multiplication, i.e. only sign disagreements between
    :math:`x` and :math:`y` are penalised.

    :param float p: order of the norm (default 2).
    :param bool onesided: use the one-sided variant (default ``False``).
    """

    def __init__(self, p=2, onesided=False):
        super().__init__()
        self.p = p
        self.onesided = onesided

    def forward(self, x, y):
        # Build a non-negative residual, then average its p-th powers.
        if self.onesided:
            residual = torch.nn.functional.relu(-x * y)
        else:
            residual = (x - y).abs()
        return residual.flatten().pow(self.p).mean()
30
+
31
+
32
def mse():
    """Return the mean-squared-error criterion (:class:`torch.nn.MSELoss`)."""
    return torch.nn.MSELoss()
34
+
35
+
36
def l1():
    """Return the mean absolute error criterion (:class:`torch.nn.L1Loss`)."""
    return torch.nn.L1Loss()
38
+
39
+
40
class CharbonnierLoss(nn.Module):
    r"""
    Charbonnier loss, a smooth (differentiable-everywhere) approximation of
    the :math:`\ell_1` loss: :math:`\mathbb{E}[\sqrt{(x-y)^2 + \epsilon}]`.

    :param float eps: smoothing constant added under the square root
        (default ``1e-9``).
    """

    def __init__(self, eps=1e-9):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        residual = x - y
        return torch.sqrt(residual.pow(2) + self.eps).mean()
54
+
55
+
56
def r1_penalty(real_pred, real_img):
    """R1 regularization for the discriminator.

    Penalizes the discriminator gradient on real data alone: when the
    generator matches the data distribution and the discriminator is zero on
    the data manifold, this penalty prevents the discriminator from creating
    a non-zero gradient orthogonal to the manifold without paying a cost in
    the GAN game.

    Ref: Eq. 9 in "Which training methods for GANs do actually converge".

    :param torch.Tensor real_pred: discriminator output on real images.
    :param torch.Tensor real_img: real images (must require grad).
    :return: scalar penalty tensor.
    """
    (grad_real,) = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    # Per-sample squared gradient norm, averaged over the batch.
    per_sample = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1)
    return per_sample.mean()
72
+
73
+
74
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length regularization for the generator (StyleGAN2-style).

    Measures the norm of the Jacobian-vector product of the generator output
    w.r.t. its latents against a running mean of path lengths.

    :param torch.Tensor fake_img: generated images of shape (B, C, H, W).
    :param torch.Tensor latents: latent codes the images were generated from.
    :param float mean_path_length: running mean of previous path lengths.
    :param float decay: EMA decay for the running mean (default 0.01).
    :return: tuple (path penalty, detached mean path length, detached new running mean).
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    # Random projection direction, scaled so the JVP norm is resolution-independent.
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)
    (grad,) = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    path_lengths = grad.pow(2).sum(2).mean(1).sqrt()

    # Exponential moving average of the observed path lengths.
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
88
+
89
+
90
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for WGAN-GP.

    Samples random points on the segment between real and fake samples and
    penalizes the squared deviation of the discriminator's gradient norm
    from 1 at those points.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor. Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)
    # One mixing coefficient per sample, created directly on the same
    # device/dtype as real_data (avoids the new_tensor(torch.rand(...)) copy).
    alpha = torch.rand(
        batch_size, 1, 1, 1, device=real_data.device, dtype=real_data.dtype
    )

    # Interpolate between real_data and fake_data. `requires_grad_` replaces
    # the deprecated `autograd.Variable` wrapper (a no-op since PyTorch 0.4)
    # so gradients can be taken w.r.t. the interpolated points.
    interpolates = alpha * real_data + (1.0 - alpha) * fake_data
    interpolates = interpolates.requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    if weight is not None:
        gradients = gradients * weight

    # Penalize deviation of the per-location gradient norm from 1.
    gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    if weight is not None:
        gradients_penalty = gradients_penalty / torch.mean(weight)

    return gradients_penalty
deepinv/loss/moi.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
class MOILoss(nn.Module):
    r"""
    Multi-operator imaging loss.

    This loss can be used to learn when signals are observed via multiple
    (possibly incomplete) forward operators :math:`\{A_g\}_{g=1}^{G}`,
    i.e., :math:`y_i = A_{g_i}x_i` where :math:`g_i\in \{1,\dots,G\}`
    (see https://arxiv.org/abs/2201.12151).

    The measurement consistency loss is defined as

    .. math::

        \| \hat{x} - \inverse{\hat{x},A_g} \|^2

    where :math:`\hat{x}=\inverse{y,A_s}` is a reconstructed signal (observed
    via operator :math:`A_s`) and :math:`A_g` is a forward operator sampled
    uniformly at random from :math:`\{A_g\}_{g=1}^{G}`.

    By default, the error is computed using the MSE metric, however any other
    metric (e.g., :math:`\ell_1`) can be used as well.

    :param list physics_list: list of physics containing the :math:`G`
        different forward operators.
    :param torch.nn.Module metric: metric used for computing data consistency,
        which is set as the mean squared error by default.
    :param bool apply_noise: if ``True``, the augmented measurement is computed
        with the full sensing model (noise and sensor model), otherwise it is
        generated as the noiseless :math:`\forw{\hat{x}}`.
    :param float weight: total weight of the loss.
    """

    def __init__(
        self, physics_list, metric=torch.nn.MSELoss(), apply_noise=True, weight=1.0
    ):
        super().__init__()
        self.name = "moi"
        self.physics_list = physics_list
        self.metric = metric
        self.weight = weight
        self.noise = apply_noise

    def forward(self, x_net, model, **kwargs):
        r"""
        Computes the MOI loss.

        :param torch.Tensor x_net: Reconstructed image :math:`\inverse{y}`.
        :param torch.nn.Module model: Reconstruction function.
        :return: (torch.Tensor) loss.
        """
        # Pick one operator uniformly at random from the available set.
        idx = np.random.randint(len(self.physics_list))
        physics = self.physics_list[idx]

        # Re-measure the reconstruction, with or without the noise model.
        y = physics(x_net) if self.noise else physics.A(x_net)

        x_rec = model(y, physics)
        return self.weight * self.metric(x_rec, x_net)
@@ -0,0 +1,155 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class JacobianSpectralNorm(nn.Module):
    r"""
    Computes the spectral norm of the Jacobian.

    Given a function :math:`f:\mathbb{R}^n\to\mathbb{R}^n`, this module computes the spectral
    norm of the Jacobian of :math:`f` in :math:`x`, i.e.

    .. math::
        \|\frac{df}{du}(x)\|_2.

    This spectral norm is computed with a power method leveraging jacobian vector products, as proposed in `<https://arxiv.org/abs/2012.13247v2>`_.

    :param int max_iter: maximum number of iterations of the power method.
    :param float tol: tolerance for the convergence of the power method.
    :param bool eval_mode: set to `False` if one does not want to backpropagate through the spectral norm (default), set to `True` otherwise.
    :param bool verbose: whether to print computation details or not.

    Example of usage:

    ::

        import torch
        from deepinv.loss.regularisers import JacobianSpectralNorm

        reg_l2 = JacobianSpectralNorm(max_iter=10, tol=1e-3, eval_mode=False, verbose=True)

        A = torch.diag(torch.Tensor(range(1, 51)))  # creates a diagonal matrix with largest eigenvalue = 50
        x = torch.randn_like(A).requires_grad_()
        out = A @ x

        regval = reg_l2(out, x)
        print(regval)  # >> returns approx 50
    """

    def __init__(self, max_iter=10, tol=1e-3, eval_mode=False, verbose=False):
        super(JacobianSpectralNorm, self).__init__()
        self.name = "jsn"
        self.max_iter = max_iter
        self.tol = tol
        self.eval = eval_mode
        self.verbose = verbose

    def forward(self, y, x, **kwargs):
        """
        Computes the spectral norm of the Jacobian of :math:`f` in :math:`x`.

        .. warning::
            The input :math:`x` must have requires_grad=True before evaluating :math:`f`.

        :param torch.Tensor y: output of the function :math:`f` at :math:`x`.
        :param torch.Tensor x: input of the function :math:`f`.
        """
        # Random unit-norm starting vector for the power iteration.
        u = torch.randn_like(x)
        u = u / torch.norm(u.flatten(), p=2)

        zold = torch.zeros_like(u)

        for it in range(self.max_iter):
            # Double backward trick. From https://gist.github.com/apaszke/c7257ac04cb8debb82221764f6d117ad
            w = torch.ones_like(y, requires_grad=True)
            v = torch.autograd.grad(
                torch.autograd.grad(y, x, w, create_graph=True),
                w,
                u,
                create_graph=not self.eval,
            )[
                0
            ]  # v = A(u)

            # v <- A^T(A(u)), one step of the power method on A^T A.
            (v,) = torch.autograd.grad(y, x, v, retain_graph=True, create_graph=True)

            # Rayleigh quotient: current estimate of the largest eigenvalue of A^T A.
            z = torch.dot(u.flatten(), v.flatten()) / torch.norm(u, p=2) ** 2

            if it > 0:
                rel_var = torch.norm(z - zold)
                # BUGFIX: convergence must trigger the break regardless of
                # `verbose`; previously the break was reachable only when
                # verbose=True, so the default configuration always ran
                # all `max_iter` iterations.
                if rel_var < self.tol:
                    if self.verbose:
                        print(
                            "Power iteration converged at iteration: ",
                            it,
                            ", val: ",
                            z.sqrt().item(),
                            ", relvar :",
                            rel_var,
                        )
                    break
            zold = z.detach().clone()

            u = v / torch.norm(v.flatten(), p=2)

            if self.eval:
                # Drop graph references when the result is not backpropagated.
                w.detach_()
                v.detach_()
                u.detach_()

        # Spectral norm of A is the square root of the top eigenvalue of A^T A.
        return z.view(-1).sqrt()
101
+
102
+
103
class FNEJacobianSpectralNorm(nn.Module):
    r"""
    Computes the Firm-Nonexpansiveness Jacobian spectral norm.

    Given a function :math:`f:\mathbb{R}^n\to\mathbb{R}^n`, this module computes the spectral
    norm of the Jacobian of :math:`2f-\operatorname{Id}` (where :math:`\operatorname{Id}` denotes the
    identity) in :math:`x`, i.e.

    .. math::
        \|\frac{d(2f-\operatorname{Id})}{du}(x)\|_2,

    as proposed in `<https://arxiv.org/abs/2012.13247v2>`_.
    This spectral norm is computed with the :meth:`deepinv.loss.JacobianSpectralNorm` module.

    :param int max_iter: maximum number of iterations of the power method.
    :param float tol: tolerance for the convergence of the power method.
    :param bool eval_mode: set to `False` if one does not want to backpropagate through the spectral norm (default), set to `True` otherwise.
    :param bool verbose: whether to print computation details or not.
    """

    def __init__(self, max_iter=10, tol=1e-3, verbose=False, eval_mode=False):
        super().__init__()
        self.spectral_norm_module = JacobianSpectralNorm(
            max_iter=max_iter, tol=tol, verbose=verbose, eval_mode=eval_mode
        )

    def forward(
        self, y_in, x_in, model, *args_model, interpolation=False, **kwargs_model
    ):
        r"""
        Computes the Firm-Nonexpansiveness (FNE) Jacobian spectral norm of a model.

        :param torch.Tensor y_in: input of the model (by default).
        :param torch.Tensor x_in: an additional point of the model (by default).
        :param torch.nn.Module model: neural network, or function, of which we want to compute the FNE Jacobian spectral norm.
        :param `*args_model`: additional arguments of the model.
        :param bool interpolation: whether to input to model an interpolation between y_in and x_in instead of y_in (default is `False`).
        :param `**kwargs_model`: additional keyword arguments of the model.
        """
        if interpolation:
            # Evaluate at a random convex combination of the two points,
            # with one coefficient per batch element.
            eta = torch.rand(y_in.size(0), 1, 1, 1, requires_grad=True).to(y_in.device)
            x = eta * y_in.detach() + (1 - eta) * x_in.detach()
        else:
            x = y_in

        x.requires_grad_()
        x_out = model(x, *args_model, **kwargs_model)

        # Spectral norm of the Jacobian of 2f - Id evaluated at x.
        residual_map = 2.0 * x_out - x
        return self.spectral_norm_module(residual_map, x)
deepinv/loss/score.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+
6
class ScoreLoss(nn.Module):
    r"""
    Approximates score of a distribution.

    It approximates the score of the measurement distribution :math:`S(y)\approx \nabla \log p(y)`
    https://proceedings.neurips.cc/paper_files/paper/2021/file/077b83af57538aa183971a2fe0971ec1-Paper.pdf.

    The score loss is defined as

    .. math::

        \| \epsilon - \sigma S(y+ \sigma \epsilon) \|^2

    where :math:`\epsilon` is sampled from :math:`N(0,I)` and
    :math:`\sigma` is sampled from :math:`N(0,I\delta^2)`.

    :param float delta: hyperparameter :math:`\delta` controlling the level of noise.
    """

    def __init__(self, delta):
        super(ScoreLoss, self).__init__()
        self.name = "score"
        self.metric = torch.nn.MSELoss()
        self.delta = delta

    def forward(self, y, model, **kwargs):
        r"""
        Computes the Score loss.

        :param torch.Tensor y: measurements.
        :param torch.nn.Module model: Reconstruction function.
        :return: (torch.Tensor) loss.
        """
        # BUGFIX: `np.randn()` does not exist and raised AttributeError;
        # the correct call for a standard-normal draw is `np.random.randn()`.
        std = np.random.randn() * self.delta
        noise = torch.randn_like(y)
        return self.metric(noise, std * model(std * noise + y))
deepinv/loss/sup.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class SupLoss(nn.Module):
    r"""
    Standard supervised loss.

    The supervised loss is defined as

    .. math::

        \|x-\inverse{y}\|^2

    where :math:`\inverse{y}` is the reconstructed signal and :math:`x` is the
    ground truth target.

    By default, the error is computed using the MSE metric, however any other
    metric (e.g., :math:`\ell_1`) can be used as well.

    :param torch.nn.Module metric: metric used for computing data consistency,
        which is set as the mean squared error by default.
    """

    def __init__(self, metric=torch.nn.MSELoss()):
        super().__init__()
        self.name = "supervised"
        self.metric = metric

    def forward(self, x, x_net, **kwargs):
        r"""
        Computes the loss.

        :param torch.Tensor x: Target (ground-truth) image.
        :param torch.Tensor x_net: Reconstructed image :math:`\inverse{y}`.
        :return: (torch.Tensor) loss.
        """
        return self.metric(x_net, x)