deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/physics/compressed_sensing.py
@@ -0,0 +1,197 @@
+from deepinv.physics.forward import LinearPhysics
+import torch
+import numpy as np
+
+
+def dst1(x):
+    r"""
+    Orthogonal Discrete Sine Transform, Type I.
+
+    The transform is performed across the last dimension of the input signal.
+    Due to orthogonality, we have ``dst1(dst1(x)) = x``.
+
+    :param torch.Tensor x: the input signal.
+    :return: (torch.Tensor) the DST-I of the signal over the last dimension.
+    """
+    x_shape = x.shape
+
+    b = int(np.prod(x_shape[:-1]))
+    n = x_shape[-1]
+    x = x.view(-1, n)
+
+    # Build the odd extension [0, x, 0, -reversed(x)] and keep the imaginary
+    # part of its orthonormal FFT, which gives the DST-I coefficients.
+    z = torch.zeros(b, 1, device=x.device)
+    x = torch.cat([z, x, z, -x.flip([1])], dim=1)
+    x = torch.view_as_real(torch.fft.rfft(x, norm="ortho"))
+    x = x[:, 1:-1, 1]
+    return x.view(*x_shape)
+
+
+class CompressedSensing(LinearPhysics):
+    r"""
+    Compressed Sensing forward operator. Creates a random sampling :math:`m \times n` matrix, where :math:`n` is the
+    number of elements of the signal, i.e., ``np.prod(img_shape)``, and :math:`m` is the number of measurements.
+
+    This class generates a random iid Gaussian matrix if ``fast=False``
+
+    .. math::
+
+        A_{i,j} \sim \mathcal{N}\left(0, \frac{1}{m}\right),
+
+    or a Subsampled Orthogonal with Random Signs (SORS) matrix if ``fast=True`` (see https://arxiv.org/abs/1506.03521)
+
+    .. math::
+
+        A = \text{diag}(m) D \text{diag}(s),
+
+    where :math:`s \in \{-1,1\}^{n}` is a random sign flip with probability 0.5,
+    :math:`D \in \mathbb{R}^{n \times n}` is a fast orthogonal transform (the DST-I) and
+    :math:`\text{diag}(m) \in \mathbb{R}^{m \times n}` is a random subsampling matrix, which keeps :math:`m` out of the :math:`n` entries.
+
+    It is recommended to use ``fast=True`` for images larger than 32 x 32 pixels, since the forward computation with
+    ``fast=False`` has :math:`O(mn)` complexity, whereas with ``fast=True`` it has :math:`O(n \log n)` complexity.
+
+    An existing operator can be loaded from a saved ``.pth`` file via ``self.load_state_dict(save_path)``,
+    in a similar fashion to :class:`torch.nn.Module`.
+
+    .. note::
+
+        If ``fast=False``, the squared norm of the forward operator tends to :math:`(1+\sqrt{n/m})^2` for large
+        :math:`n` and :math:`m`, due to the `Marchenko-Pastur law
+        <https://en.wikipedia.org/wiki/Marchenko%E2%80%93Pastur_distribution>`_.
+        If ``fast=True``, the forward operator has unit norm.
+
+    :param int m: number of measurements.
+    :param tuple img_shape: shape ``(C, H, W)`` of the input images.
+    :param bool fast: if ``False``, the operator is iid Gaussian; if ``True``, :math:`A` is a SORS matrix based on the Discrete Sine Transform (type I).
+    :param bool channelwise: if ``True``, the channels are processed independently using the same random forward operator.
+    :param torch.dtype dtype: dtype in which the forward matrix is stored.
+    :param str device: device on which the forward matrix is stored.
+    """
+
+    def __init__(
+        self,
+        m,
+        img_shape,
+        fast=False,
+        channelwise=False,
+        dtype=torch.float,
+        device="cpu",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.name = f"CS_m{m}"
+        self.img_shape = img_shape
+        self.fast = fast
+        self.channelwise = channelwise
+        self.dtype = dtype
+
+        if channelwise:
+            n = int(np.prod(img_shape[1:]))
+        else:
+            n = int(np.prod(img_shape))
+
+        if self.fast:
+            self.n = n
+            self.D = torch.ones(self.n, device=device)
+            self.D[torch.rand_like(self.D) > 0.5] = -1.0
+            self.mask = torch.zeros(self.n, device=device)
+            idx = np.sort(np.random.choice(self.n, size=m, replace=False))
+            self.mask[torch.from_numpy(idx)] = 1
+            self.mask = self.mask.type(torch.bool)
+
+            self.D = torch.nn.Parameter(self.D, requires_grad=False)
+            self.mask = torch.nn.Parameter(self.mask, requires_grad=False)
+        else:
+            self._A = torch.randn((m, n), device=device) / np.sqrt(m)
+            self._A_dagger = torch.linalg.pinv(self._A)
+            self._A = torch.nn.Parameter(self._A, requires_grad=False)
+            self._A_dagger = torch.nn.Parameter(self._A_dagger, requires_grad=False)
+            self._A_adjoint = (
+                torch.nn.Parameter(self._A.t(), requires_grad=False)
+                .type(dtype)
+                .to(device)
+            )
+
+    def A(self, x):
+        N, C = x.shape[:2]
+        if self.channelwise:
+            x = x.reshape(N * C, -1)
+        else:
+            x = x.reshape(N, -1)
+
+        if self.fast:
+            y = dst1(x * self.D)[:, self.mask]
+        else:
+            y = torch.einsum("in, mn->im", x, self._A)
+
+        if self.channelwise:
+            y = y.view(N, C, -1)
+
+        return y
+
+    def A_adjoint(self, y):
+        N = y.shape[0]
+        C, H, W = self.img_shape[0], self.img_shape[1], self.img_shape[2]
+
+        if self.channelwise:
+            N2 = N * C
+            y = y.view(N2, -1)
+        else:
+            N2 = N
+
+        if self.fast:
+            y2 = torch.zeros((N2, self.n), device=y.device)
+            y2[:, self.mask] = y.type(y2.dtype)
+            x = dst1(y2) * self.D
+        else:
+            x = torch.einsum("im, nm->in", y, self._A_adjoint)  # x: (N2, n)
+
+        x = x.view(N, C, H, W)
+        return x
+
+    def A_dagger(self, y):
+        if self.fast:
+            return self.A_adjoint(y)
+        else:
+            N = y.shape[0]
+            C, H, W = self.img_shape[0], self.img_shape[1], self.img_shape[2]
+
+            if self.channelwise:
+                y = y.reshape(N * C, -1)
+
+            x = torch.einsum("im, nm->in", y, self._A_dagger)
+            x = x.reshape(N, C, H, W)
+            return x
+
+
+# if __name__ == "__main__":
+#     device = "cuda:0"
+#
+#     # for comparing fast=True and fast=False forward matrices.
+#     for i in range(1):
+#         n = 2 ** (i + 4)
+#         im_size = (1, n, n)
+#         m = int(np.prod(im_size))
+#         x = torch.randn((1,) + im_size, device=device)
+#
+#         print((dst1(dst1(x)) - x).flatten().abs().sum())
+#
+#         physics = CompressedSensing(img_shape=im_size, m=m, fast=True, device=device)
+#
+#         print((physics.A_adjoint(physics.A(x)) - x).flatten().abs().sum())
+#         print(f"adjointness: {physics.adjointness_test(x)}")
+#         print(f"norm: {physics.power_method(x, verbose=False)}")
+#         start = torch.cuda.Event(enable_timing=True)
+#         end = torch.cuda.Event(enable_timing=True)
+#         start.record()
+#         for j in range(100):
+#             y = physics.A(x)
+#             xhat = physics.A_dagger(y)
+#         end.record()
+#
+#         # print((xhat-x).pow(2).flatten().mean())
+#
+#         # Waits for everything to finish running
+#         torch.cuda.synchronize()
+#         print(start.elapsed_time(end))