deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/loss/metric.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from torch import autograd as autograd
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LpNorm(torch.nn.Module):
    r"""
    :math:`\ell_p` metric for :math:`p>0`.

    If ``onesided=False`` then the metric is defined as

    .. math::

        d(x,y)=\frac{1}{n}\|x-y\|_p^p,

    where :math:`n` is the total number of entries, i.e. the :math:`p`-th power of the
    element-wise absolute error, averaged over all entries.

    Otherwise it is the one-sided error https://ieeexplore.ieee.org/abstract/document/6418031/, defined as

    .. math::

        d(x,y)=\frac{1}{n}\|\max(-x\circ y, 0)\|_p^p,

    where :math:`\circ` denotes element-wise multiplication and the maximum is taken element-wise,
    so only entries where :math:`x` and :math:`y` have opposite signs contribute.

    :param float p: exponent :math:`p` of the metric (default ``2``).
    :param bool onesided: if ``True``, compute the one-sided error instead of the
        plain :math:`\ell_p` error (default ``False``).
    """

    def __init__(self, p=2, onesided=False):
        super().__init__()
        self.p = p
        self.onesided = onesided

    def forward(self, x, y):
        r"""
        Computes the metric between ``x`` and ``y``.

        :param torch.Tensor x: first input.
        :param torch.Tensor y: second input (combined element-wise with ``x``).
        :return: (torch.Tensor) scalar metric value, averaged over all entries.
        """
        if self.onesided:
            # relu(-x * y) is non-zero only where x and y disagree in sign
            return torch.nn.functional.relu(-x * y).flatten().pow(self.p).mean()
        else:
            return (x - y).flatten().abs().pow(self.p).mean()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def mse():
    r"""
    Build a mean-squared-error criterion.

    :return: (torch.nn.MSELoss) a newly constructed loss module with default settings.
    """
    criterion = nn.MSELoss()
    return criterion
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def l1():
    r"""
    Build a mean-absolute-error (:math:`\ell_1`) criterion.

    :return: (torch.nn.L1Loss) a newly constructed loss module with default settings.
    """
    criterion = nn.L1Loss()
    return criterion
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class CharbonnierLoss(nn.Module):
    r"""
    Charbonnier loss, a smooth approximation of the :math:`\ell_1` error:

    .. math::

        \frac{1}{n}\sum_{i=1}^{n} \sqrt{(x_i - y_i)^2 + \epsilon}

    :param float eps: constant :math:`\epsilon` added under the square root (default ``1e-9``).
    """

    def __init__(self, eps=1e-9):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        r"""
        Computes the Charbonnier loss between ``x`` and ``y``.

        :param torch.Tensor x: first input.
        :param torch.Tensor y: second input.
        :return: (torch.Tensor) scalar loss averaged over all entries.
        """
        residual = x - y
        return torch.sqrt(residual * residual + self.eps).mean()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def r1_penalty(real_pred, real_img):
    r"""R1 regularization for the discriminator.

    The core idea is to penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution and the
    discriminator is equal to 0 on the data manifold, the gradient penalty
    ensures that the discriminator cannot create a non-zero gradient
    orthogonal to the data manifold without suffering a loss in the GAN game.

    This function returns the batch mean of the per-sample squared gradient
    norm :math:`\|\nabla_x D(x)\|_2^2`; the :math:`\gamma/2` factor of the
    reference formulation is not applied here.

    Ref:
        Eq. 9 in "Which training methods for GANs do actually converge?"

    :param torch.Tensor real_pred: discriminator output computed from ``real_img``.
    :param torch.Tensor real_img: real samples (batch dimension first); must
        require grad and have been used to compute ``real_pred``.
    :return: (torch.Tensor) scalar penalty. It is built with
        ``create_graph=True`` so it can itself be backpropagated.
    """
    grad_real = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )[0]
    # reshape (not view): the gradient may be non-contiguous (e.g. for
    # channels_last inputs), in which case view() raises.
    per_sample = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1)
    return per_sample.mean()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    r"""Path length regularization for a generator.

    Projects ``fake_img`` onto random Gaussian noise scaled by
    ``1 / sqrt(fake_img.shape[2] * fake_img.shape[3])``. It then differentiates
    that projection with respect to ``latents`` and penalizes the deviation of the
    per-sample path lengths from an exponential moving average of their mean.

    :param torch.Tensor fake_img: generated images; dims 2 and 3 are used as the
        spatial size (presumably ``(B, C, H, W)`` — confirm with callers).
    :param torch.Tensor latents: latent codes used to produce ``fake_img``; must
        require grad. The gradient is reduced over dim 2 (sum of squares) and then
        dim 1 (mean) (presumably ``(B, n_latent, latent_dim)`` — confirm).
    :param mean_path_length: current value of the moving average of path lengths.
    :param float decay: update rate of the moving average.
    :return: tuple ``(path_penalty, detached mean path length, updated moving
        average (detached))``.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)

    latent_grad = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )[0]
    # per-sample path length
    path_lengths = latent_grad.pow(2).sum(2).mean(1).sqrt()

    # exponential moving average of the mean path length
    updated_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - updated_mean).pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), updated_mean.detach()
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.

    The discriminator is evaluated on random per-sample interpolations
    ``alpha * real_data + (1 - alpha) * fake_data`` with ``alpha`` drawn
    uniformly in ``[0, 1)``. The interpolations are detached from the graphs of
    ``real_data``/``fake_data``, so only the discriminator receives gradients.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data, batch dimension first (any rank).
        fake_data (Tensor): Fake input data, same shape as ``real_data``.
        weight (Tensor): Weight tensor multiplied element-wise with the
            gradients; the penalty is then divided by ``weight.mean()``.
            Default: None.
    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)
    # One coefficient per sample, broadcast over the remaining dims. It is drawn
    # directly with the data's device/dtype: no CPU->device copy and no
    # Tensor.new_tensor(Tensor) warning.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)

    # interpolate between real_data and fake_data as a fresh leaf requiring grad
    # (same effect as the deprecated autograd.Variable(..., requires_grad=True))
    interpolates = alpha * real_data + (1.0 - alpha) * fake_data
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    if weight is not None:
        gradients = gradients * weight

    # NOTE(review): the norm is taken over dim 1 only (e.g. channels), giving one
    # norm per remaining position rather than one per sample as in the original
    # WGAN-GP formulation — confirm this is intended.
    gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    if weight is not None:
        gradients_penalty /= torch.mean(weight)

    return gradients_penalty
|
deepinv/loss/moi.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MOILoss(nn.Module):
    r"""
    Multi-operator imaging loss

    This loss can be used to learn when signals are observed via multiple (possibly incomplete)
    forward operators :math:`\{A_g\}_{g=1}^{G}`,
    i.e., :math:`y_i = A_{g_i}x_i` where :math:`g_i\in \{1,\dots,G\}` (see https://arxiv.org/abs/2201.12151).


    The measurement consistency loss is defined as

    .. math::

        \| \hat{x} - \inverse{A_g\hat{x},A_g} \|^2

    where :math:`\hat{x}=\inverse{y,A_s}` is a reconstructed signal (observed via operator :math:`A_s`) and
    :math:`A_g` is a forward operator sampled at random from a set :math:`\{A_g\}_{g=1}^{G}`.

    By default, the error is computed using the MSE metric, however any other metric (e.g., :math:`\ell_1`)
    can be used as well.

    :param list physics_list: list containing the :math:`G` different forward operators.
    :param torch.nn.Module metric: metric used for computing data consistency.
        If ``None`` (default), a new :class:`torch.nn.MSELoss` is created for this instance.
    :param bool apply_noise: if ``True``, the augmented measurement is computed with the full sensing model
        :math:`\sensor{\noise{\forw{\hat{x}}}}` (i.e., noise and sensor model),
        otherwise it is generated as :math:`\forw{\hat{x}}`.
    :param float weight: total weight of the loss
    """

    def __init__(self, physics_list, metric=None, apply_noise=True, weight=1.0):
        super(MOILoss, self).__init__()
        if metric is None:
            # a fresh default per instance rather than one module object shared by
            # every instance through the default argument
            metric = torch.nn.MSELoss()
        self.name = "moi"
        self.physics_list = physics_list
        self.metric = metric
        self.weight = weight
        self.noise = apply_noise

    def forward(self, x_net, model, **kwargs):
        r"""
        Computes the MOI loss.

        A forward operator is drawn uniformly at random from ``self.physics_list``
        (using NumPy's global random state).

        :param torch.Tensor x_net: Reconstructed image :math:`\inverse{y}`.
        :param torch.nn.Module model: Reconstruction function, called as ``model(y, physics)``.
        :return: (torch.Tensor) loss.
        """
        j = np.random.randint(len(self.physics_list))
        physics = self.physics_list[j]

        if self.noise:
            y = physics(x_net)
        else:
            y = physics.A(x_net)

        x2 = model(y, physics)

        return self.weight * self.metric(x2, x_net)
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class JacobianSpectralNorm(nn.Module):
    r"""
    Computes the spectral norm of the Jacobian.

    Given a function :math:`f:\mathbb{R}^n\to\mathbb{R}^n`, this module computes the spectral
    norm of the Jacobian of :math:`f` in :math:`x`, i.e.

    .. math::
        \|\frac{df}{du}(x)\|_2.

    This spectral norm is computed with a power method leveraging jacobian vector products, as proposed in `<https://arxiv.org/abs/2012.13247v2>`_.

    :param int max_iter: maximum number of iterations of the power method.
    :param float tol: tolerance for the convergence of the power method; iterations stop once the
        change of the estimate between two iterations is below ``tol``.
    :param bool eval_mode: set to ``True`` if one does not want to backpropagate through the spectral norm,
        set to ``False`` (default) otherwise.
    :param bool verbose: whether to print computation details or not.


    Example of usage:

    ::

        import torch
        from deepinv.loss.regularisers import JacobianSpectralNorm

        reg_l2 = JacobianSpectralNorm(max_iter=10, tol=1e-3, eval_mode=False, verbose=True)

        A = torch.diag(torch.Tensor(range(1, 51)))  # creates a diagonal matrix with largest eigenvalue = 50
        x = torch.randn_like(A).requires_grad_()
        out = A @ x

        regval = reg_l2(out, x)
        print(regval) # >> returns approx 50
    """

    def __init__(self, max_iter=10, tol=1e-3, eval_mode=False, verbose=False):
        super(JacobianSpectralNorm, self).__init__()
        self.name = "jsn"
        self.max_iter = max_iter
        self.tol = tol
        # NOTE: this instance attribute shadows nn.Module.eval(), so calling
        # .eval() directly on this module raises TypeError. The name is kept for
        # backward compatibility with code reading ``.eval``.
        self.eval = eval_mode
        self.verbose = verbose

    def forward(self, y, x, **kwargs):
        r"""
        Computes the spectral norm of the Jacobian of :math:`f` in :math:`x`.

        .. warning::
            The input :math:`x` must have requires_grad=True before evaluating :math:`f`.

        :param torch.Tensor y: output of the function :math:`f` at :math:`x`.
        :param torch.Tensor x: input of the function :math:`f`.
        :return: (torch.Tensor) tensor of shape ``(1,)`` holding the spectral norm estimate.
        """
        u = torch.randn_like(x)
        u = u / torch.norm(u.flatten(), p=2)

        zold = torch.zeros_like(u)

        for it in range(self.max_iter):
            # Double backward trick. From https://gist.github.com/apaszke/c7257ac04cb8debb82221764f6d117ad
            # The inner grad gives the vector-Jacobian product J^T w as a function of w;
            # differentiating it w.r.t. w in direction u gives the Jacobian-vector product J u.
            w = torch.ones_like(y, requires_grad=True)
            v = torch.autograd.grad(
                torch.autograd.grad(y, x, w, create_graph=True),
                w,
                u,
                create_graph=not self.eval,
            )[0]  # v = J u

            (v,) = torch.autograd.grad(y, x, v, retain_graph=True, create_graph=True)
            # v = J^T J u

            # Rayleigh quotient of J^T J, whose square root approximates ||J||_2
            z = torch.dot(u.flatten(), v.flatten()) / torch.norm(u, p=2) ** 2

            if it > 0:
                rel_var = torch.norm(z - zold)
                # stop on convergence regardless of verbosity (previously the
                # break was only reached when verbose=True)
                if rel_var < self.tol:
                    if self.verbose:
                        print(
                            "Power iteration converged at iteration: ",
                            it,
                            ", val: ",
                            z.sqrt().item(),
                            ", relvar :",
                            rel_var,
                        )
                    break
            zold = z.detach().clone()

            u = v / torch.norm(v.flatten(), p=2)

            if self.eval:
                w.detach_()
                v.detach_()
                u.detach_()

        return z.view(-1).sqrt()
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class FNEJacobianSpectralNorm(nn.Module):
    r"""
    Computes the Firm-Nonexpansiveness Jacobian spectral norm.

    Given a function :math:`f:\mathbb{R}^n\to\mathbb{R}^n`, this module computes the spectral
    norm of the Jacobian of :math:`2f-\operatorname{Id}` (where :math:`\operatorname{Id}` denotes the
    identity) in :math:`x`, i.e.

    .. math::
        \|\frac{d(2f-\operatorname{Id})}{du}(x)\|_2,

    as proposed in `<https://arxiv.org/abs/2012.13247v2>`_.
    This spectral norm is computed with the :class:`deepinv.loss.JacobianSpectralNorm` module.

    :param int max_iter: maximum number of iterations of the power method.
    :param float tol: tolerance for the convergence of the power method.
    :param bool verbose: whether to print computation details or not.
    :param bool eval_mode: set to ``True`` if one does not want to backpropagate through the spectral norm,
        set to ``False`` (default) otherwise.

    """

    def __init__(self, max_iter=10, tol=1e-3, verbose=False, eval_mode=False):
        super(FNEJacobianSpectralNorm, self).__init__()
        self.spectral_norm_module = JacobianSpectralNorm(
            max_iter=max_iter, tol=tol, verbose=verbose, eval_mode=eval_mode
        )

    def forward(
        self, y_in, x_in, model, *args_model, interpolation=False, **kwargs_model
    ):
        r"""
        Computes the Firm-Nonexpansiveness (FNE) Jacobian spectral norm of a model.

        :param torch.Tensor y_in: input of the model (by default).
        :param torch.Tensor x_in: an additional point, only used when ``interpolation=True``.
        :param torch.nn.Module model: neural network, or function, of which we want to compute the FNE Jacobian spectral norm.
        :param `*args_model`: additional arguments of the model.
        :param bool interpolation: whether to input to model a random per-sample interpolation between y_in and x_in
            instead of y_in (default is `False`).
        :param `**kwargs_model`: additional keyword arguments of the model.
        :return: (torch.Tensor) spectral norm estimate returned by the underlying JacobianSpectralNorm module.
        """

        if interpolation:
            # One interpolation weight per sample, broadcast over the remaining dims.
            # It is drawn on y_in's device directly. It needs no grad: the Jacobian
            # is taken w.r.t. x only.
            eta_shape = (y_in.size(0),) + (1,) * (y_in.dim() - 1)
            eta = torch.rand(eta_shape, device=y_in.device)
            x = eta * y_in.detach() + (1 - eta) * x_in.detach()
        else:
            x = y_in

        # NOTE: when interpolation=False, this sets requires_grad on the caller's y_in.
        x.requires_grad_()
        x_out = model(x, *args_model, **kwargs_model)

        y = 2.0 * x_out - x

        return self.spectral_norm_module(y, x)
|
deepinv/loss/score.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ScoreLoss(nn.Module):
    r"""
    Approximates score of a distribution.

    It approximates the score of the measurement distribution :math:`S(y)\approx \nabla \log p(y)`
    https://proceedings.neurips.cc/paper_files/paper/2021/file/077b83af57538aa183971a2fe0971ec1-Paper.pdf.

    The score loss is defined as

    .. math::

        \| \epsilon - \sigma S(y+ \sigma \epsilon) \|^2

    where :math:`\epsilon` is sampled from :math:`N(0,I)` and the scalar
    :math:`\sigma` is sampled from :math:`N(0,\delta^2)` (one draw per call, so it may be negative).

    :param float delta: hyperparameter :math:`\delta` controlling the level of noise.
    """

    def __init__(self, delta):
        super(ScoreLoss, self).__init__()
        self.name = "score"
        self.metric = torch.nn.MSELoss()
        self.delta = delta

    def forward(self, y, model, **kwargs):
        r"""
        Computes the Score loss.

        :param torch.Tensor y: measurements.
        :param torch.nn.Module model: score network :math:`S`, called with a single tensor argument.
        :return: (torch.Tensor) loss.
        """
        # NumPy has no top-level ``randn``; the standard normal sampler lives in
        # ``np.random`` (the previous ``np.randn()`` raised AttributeError).
        std = np.random.randn() * self.delta
        noise = torch.randn_like(y)
        return self.metric(noise, std * model(std * noise + y))
|
deepinv/loss/sup.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class SupLoss(nn.Module):
    r"""
    Standard supervised loss

    The supervised loss is defined as

    .. math::

        \|x-\inverse{y}\|^2

    where :math:`\inverse{y}` is the reconstructed signal and :math:`x` is the ground truth target.

    By default, the error is computed using the MSE metric, however any other metric (e.g., :math:`\ell_1`)
    can be used as well.

    :param torch.nn.Module metric: metric used for computing data consistency.
        If ``None`` (default), a new :class:`torch.nn.MSELoss` is created for this instance.
    """

    def __init__(self, metric=None):
        super(SupLoss, self).__init__()
        if metric is None:
            # a fresh default per instance rather than one module object shared by
            # every instance through the default argument
            metric = torch.nn.MSELoss()
        self.name = "supervised"
        self.metric = metric

    def forward(self, x, x_net, **kwargs):
        r"""
        Computes the loss.

        :param torch.Tensor x: Target (ground-truth) image.
        :param torch.Tensor x_net: Reconstructed image :math:`\inverse{y}`.
        :return: (torch.Tensor) loss, computed as ``metric(x_net, x)``.
        """
        return self.metric(x_net, x)
|