deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from deepinv.physics.noise import GaussianNoise
|
|
2
|
+
from deepinv.physics.forward import LinearPhysics
|
|
3
|
+
from deepinv.physics.blur import Downsampling
|
|
4
|
+
from deepinv.physics.range import Decolorize
|
|
5
|
+
from deepinv.utils import TensorList
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Pansharpen(LinearPhysics):
    """
    Pansharpening forward operator.

    The measurements consist of a high resolution grayscale image and a low resolution RGB image, and
    are represented using :class:`deepinv.utils.TensorList`, where the first element is the RGB image and the second
    element is the grayscale image.

    By default, the downsampling is done with a gaussian filter with standard deviation equal to the downsampling
    factor, however, the user can provide a custom downsampling filter.

    It is possible to assign a different noise model to the RGB and grayscale images.

    Example usage:

    ::

        import deepinv
        import torch

        x = torch.randn(1, 3, 256, 256)
        physics = deepinv.physics.Pansharpen(img_size=x.shape[1:], device=x.device)

        y = physics(x) # returns a TensorList with the RGB and grayscale images

        x_adj = physics.A_adjoint(y)
        x_pinv = physics.A_dagger(y)

        deepinv.utils.plot([y[0], y[1], x_adj, x_pinv, x], titles=['low res color', 'high res gray',
                            'A_adjoint', 'A_dagger', 'x'])

    :param tuple[int] img_size: size of the input image.
    :param int factor: downsampling factor.
    :param torch.nn.Module noise_color: noise model for the RGB image. If ``None`` (default), a new
        ``GaussianNoise(sigma=0.0)`` is created for this operator.
    :param torch.nn.Module noise_gray: noise model for the grayscale image. If ``None`` (default), a new
        ``GaussianNoise(sigma=0.05)`` is created for this operator.
    :param torch.Tensor, str, NoneType filter: Downsampling filter. It can be 'gaussian', 'bilinear' or 'bicubic' or a
        custom ``torch.Tensor`` filter. If ``None``, no filtering is applied.
    :param str device: device on which the downsampling operator is built.
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding)
        otherwise the blurred output has the same size as the image.
    """

    def __init__(
        self,
        img_size,
        factor=4,
        noise_color=None,
        noise_gray=None,
        filter="gaussian",
        device="cpu",
        padding="circular",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Create the noise modules per operator. A module instance used as a default argument is
        # evaluated once and would be shared (and registered as a submodule) by every operator
        # built with the defaults, so changing one would silently change the others.
        if noise_color is None:
            noise_color = GaussianNoise(sigma=0.0)
        if noise_gray is None:
            noise_gray = GaussianNoise(sigma=0.05)

        self.downsampling = Downsampling(
            img_size=img_size,
            factor=factor,
            filter=filter,
            device=device,
            padding=padding,
        )

        self.noise_color = noise_color
        self.noise_gray = noise_gray
        # Despite the attribute name, this operator maps the image to grayscale (Decolorize);
        # the name is kept for backward compatibility.
        self.colorize = Decolorize()

    def A(self, x):
        """Noiseless measurements: ``TensorList([low res color image, high res grayscale image])``."""
        return TensorList([self.downsampling(x), self.colorize(x)])

    def A_adjoint(self, y):
        """Adjoint: sum of the downsampling adjoint applied to ``y[0]`` and the decolorize adjoint applied to ``y[1]``."""
        return self.downsampling.A_adjoint(y[0]) + self.colorize.A_adjoint(y[1])

    def forward(self, x):
        """Noisy measurements, with a separate noise model for the color and the grayscale component."""
        return TensorList(
            [self.noise_color(self.downsampling(x)), self.noise_gray(self.colorize(x))]
        )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# test code
|
|
88
|
+
# if __name__ == "__main__":
|
|
89
|
+
# device = "cuda:0"
|
|
90
|
+
# import torch
|
|
91
|
+
# import torchvision
|
|
92
|
+
# import deepinv
|
|
93
|
+
#
|
|
94
|
+
# device = "cuda:0"
|
|
95
|
+
#
|
|
96
|
+
# x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
|
|
97
|
+
# x = x.unsqueeze(0).float().to(device) / 255
|
|
98
|
+
# x = torchvision.transforms.Resize((160, 180))(x)
|
|
99
|
+
#
|
|
100
|
+
# class Toy(LinearPhysics):
|
|
101
|
+
# def __init__(self, **kwargs):
|
|
102
|
+
# super().__init__(**kwargs)
|
|
103
|
+
# self.A = lambda x: x * 2
|
|
104
|
+
# self.A_adjoint = lambda x: x * 2
|
|
105
|
+
#
|
|
106
|
+
# sigma_noise = 0.1
|
|
107
|
+
# kernel = torch.zeros((1, 1, 15, 15), device=device)
|
|
108
|
+
# kernel[:, :, 7, :] = 1 / 15
|
|
109
|
+
# # physics = deepinv.physics.BlurFFT(img_size=x.shape[1:], filter=kernel, device=device)
|
|
110
|
+
# physics = Pansharpen(factor=8, img_size=x.shape[1:], device=device)
|
|
111
|
+
#
|
|
112
|
+
# y = physics(x)
|
|
113
|
+
#
|
|
114
|
+
# xhat2 = physics.A_adjoint(y)
|
|
115
|
+
# xhat1 = physics.A_dagger(y)
|
|
116
|
+
#
|
|
117
|
+
# physics.compute_norm(x)
|
|
118
|
+
# physics.adjointness_test(x)
|
|
119
|
+
#
|
|
120
|
+
# deepinv.utils.plot(
|
|
121
|
+
# [y[0], y[1], xhat2, xhat1, x],
|
|
122
|
+
# titles=["low res color", "high res gray", "A_adjoint", "A_dagger", "x"],
|
|
123
|
+
# )
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
from deepinv.physics.forward import DecomposablePhysics
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def hadamard_1d(u, normalize=True):
    """
    Compute ``H_n @ u`` along the last dimension, where ``H_n`` is the n x n Hadamard matrix
    (natural order). ``n`` must be a power of 2.

    Parameters:
        u: Tensor of shape (..., n)
        normalize: if True, divide the result by 2^{m/2} where m = log_2(n), which makes the
            transform orthonormal.
    Returns:
        product: Tensor of shape (..., n)
    """
    size = u.shape[-1]
    levels = int(np.log2(size))
    assert size == 1 << levels, "n must be a power of 2"

    # Butterfly passes: each pass combines even/odd rows into sums and differences.
    buf = u.unsqueeze(-1)
    for _ in range(levels):
        even = buf[..., ::2, :]
        odd = buf[..., 1::2, :]
        buf = torch.cat((even + odd, even - odd), dim=-1)

    out = buf.squeeze(-2)
    if normalize:
        out = out / 2 ** (levels / 2)
    return out
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def hadamard_2d(x):
    """
    Separable 2D Hadamard transform: apply :func:`hadamard_1d` along the last dimension, then along
    the second-to-last one. Both trailing sizes must be powers of 2.
    """
    along_last = hadamard_1d(x)
    along_both = hadamard_1d(along_last.transpose(-1, -2))
    return along_both.transpose(-1, -2)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class SinglePixelCamera(DecomposablePhysics):
    r"""
    Single pixel imaging camera.

    Linear imaging operator with binary entries.

    If ``fast=True``, the operator uses a 2D subsampled hadamard transform, which keeps the lowest-sequency modes
    according to the `sequency ordering <https://en.wikipedia.org/wiki/Walsh_matrix#Sequency_ordering>`_:
    a block of :math:`m_i \times m_j` modes with :math:`m_i = \min(\lfloor\sqrt{m}\rfloor, H)` and
    :math:`m_j = \min(\lfloor m / m_i \rfloor, W)`, i.e. at most :math:`m` modes.
    In this case, the images should have a size which is a power of 2.

    If ``fast=False``, the operator is a random iid binary matrix with equal probability of :math:`1/\sqrt{m}` or
    :math:`-1/\sqrt{m}`.

    Both options allow for an efficient singular value decomposition (see :meth:`deepinv.physics.DecomposablePhysics`)
    The operator is always applied independently across channels.

    It is recommended to use ``fast=True`` for image sizes bigger than 32 x 32, since the forward computation with
    ``fast=False`` has an :math:`O(mn)` complexity, whereas with ``fast=True`` it has an :math:`O(n \log n)` complexity.

    An existing operator can be loaded from a saved ``.pth`` file via
    ``self.load_state_dict(torch.load(save_path))``, in a similar fashion to :meth:`torch.nn.Module`.

    :param int m: number of single pixel measurements per acquisition.
    :param tuple img_shape: shape (C, H, W) of images.
    :param bool fast: The operator is iid binary if false, otherwise A is a 2D subsampled hadamard transform.
    :param str device: Device to store the forward matrix.
    :param torch.dtype dtype: dtype of the stored operator (mask and, if ``fast=False``, singular vectors).
    """

    def __init__(
        self, m, img_shape, fast=True, device="cpu", dtype=torch.float32, **kwargs
    ):
        super().__init__(**kwargs)
        self.name = f"spcamera_m{m}"
        self.img_shape = img_shape
        self.fast = fast
        self.device = device

        if self.fast:
            _, H, W = img_shape

            # Validate first: get_permutation_list is only meaningful for powers of 2.
            assert H == 1 << int(np.log2(H)), "image height must be a power of 2"
            assert W == 1 << int(np.log2(W)), "image width must be a power of 2"

            # Keep an mi x mj block of the lowest-sequency modes, with mi * mj <= m.
            mi = min(int(np.sqrt(m)), H)
            mj = min(m // mi, W) if mi > 0 else 0

            # Natural-order (Hadamard) row/column indices of the kept sequency modes.
            rows = torch.as_tensor(get_permutation_list(H)[:mi], dtype=torch.long)
            cols = torch.as_tensor(get_permutation_list(W)[:mj], dtype=torch.long)

            mask = torch.zeros(img_shape, dtype=dtype).unsqueeze(0)
            # Same modes for every channel: every (row, col) pair of the outer-product grid.
            mask[0, :, rows[:, None], cols[None, :]] = 1

            self.mask = torch.nn.Parameter(mask.to(device), requires_grad=False)

        else:
            n = int(np.prod(img_shape[1:]))
            A = torch.ones((m, n), device=device)
            # Flip each entry to -1 with probability 1/2 (uniform draw), as documented.
            A[torch.rand_like(A) > 0.5] = -1.0
            A /= np.sqrt(m)  # normalize
            u, mask, vh = torch.linalg.svd(A, full_matrices=False)

            u = u.to(device).type(dtype)
            vh = vh.to(device).type(dtype)
            mask = mask.to(device).unsqueeze(0).type(dtype)

            self.u = torch.nn.Parameter(u, requires_grad=False)
            self.vh = torch.nn.Parameter(vh, requires_grad=False)
            self.mask = torch.nn.Parameter(mask, requires_grad=False)

    def V_adjoint(self, x):
        """Apply :math:`V^{\top}`: 2D Hadamard transform (fast) or projection on the rows of ``vh``."""
        if self.fast:
            y = hadamard_2d(x)
        else:
            N, C = x.shape[0], self.img_shape[0]
            x = x.reshape(N, C, -1)
            y = torch.einsum("ijk, mk->ijm", x, self.vh)
        return y

    def V(self, y):
        """Apply :math:`V`: 2D Hadamard transform (fast, self-inverse) or ``y @ vh`` reshaped to the image shape."""
        if self.fast:
            x = hadamard_2d(y)
        else:
            N = y.shape[0]
            C, H, W = self.img_shape[0], self.img_shape[1], self.img_shape[2]
            x = torch.einsum("ijk, km->ijm", y, self.vh)
            x = x.reshape(N, C, H, W)
        return x

    def U_adjoint(self, x):
        """Apply :math:`U^{\top}`: identity (fast) or ``x @ u``."""
        if self.fast:
            out = x
        else:
            out = torch.einsum("ijk, km->ijm", x, self.u)
        return out

    def U(self, x):
        """Apply :math:`U`: identity (fast) or ``x @ u^T``."""
        if self.fast:
            out = x
        else:
            out = torch.einsum("ijk, mk->ijm", x, self.u)
        return out
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def gray_decode(n):
    """Invert the binary-reflected Gray code: return ``k`` such that ``k ^ (k >> 1) == n``."""
    decoded = n
    shift = 1
    # XOR together n >> 1, n >> 2, ... until the shifted value vanishes.
    while n >> shift:
        decoded ^= n >> shift
        shift += 1
    return decoded
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def reverse(n, numbits):
    """Reverse the order of the lowest ``numbits`` bits of ``n``."""
    flipped = 0
    for bit in range(numbits):
        if (n >> bit) & 1:
            flipped += 1 << (numbits - 1 - bit)
    return flipped
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def get_permutation_list(n):
    """
    Map each sequency index to the corresponding natural-order Hadamard index.

    Entry ``s`` is the index, in the natural order produced by :func:`hadamard_1d`, of the Hadamard row
    with ``s`` sign changes (Walsh sequency ordering). That index is the bit reversal of the Gray code
    ``s ^ (s >> 1)``. The previous version used the inverse Gray code, which gives the wrong order for
    ``n >= 8``. For example, with ``n = 8`` it returned ``[0, 4, 6, 2, 7, 3, 1, 5]`` instead of
    ``[0, 4, 6, 2, 3, 7, 5, 1]``.

    :param int n: transform size, expected to be a power of 2.
    :return: integer ``numpy`` array of length ``n``.
    """
    order = np.zeros((n), dtype=int)
    if n > 0:
        numbits = np.log2(n).astype(int)
        for s in range(n):
            order[s] = reverse(s ^ (s >> 1), numbits)
    return order
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# test code
|
|
167
|
+
# if __name__ == "__main__":
|
|
168
|
+
# import matplotlib.pyplot as plt
|
|
169
|
+
# import deepinv as dinv
|
|
170
|
+
# import torchvision
|
|
171
|
+
#
|
|
172
|
+
# device = "cuda:0"
|
|
173
|
+
# x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
|
|
174
|
+
# x = x.unsqueeze(0).float().to(device) / 255
|
|
175
|
+
# x = torchvision.transforms.Resize((16, 8))(x)
|
|
176
|
+
#
|
|
177
|
+
# m = 20
|
|
178
|
+
# physics = SinglePixelCamera(m, (3, 16, 8), fast=False, device=device)
|
|
179
|
+
#
|
|
180
|
+
# y = physics(x)
|
|
181
|
+
#
|
|
182
|
+
# xhat = physics.A_adjoint(y)
|
|
183
|
+
#
|
|
184
|
+
# dinv.utils.plot([x, xhat])
|
|
185
|
+
#
|
|
186
|
+
# print(physics.adjointness_test(x))
|
|
187
|
+
# print(physics.compute_norm(x))
|
|
188
|
+
# # mi = min(int(np.sqrt(m)), x.shape[-2])
|
|
189
|
+
# # mj = min(m - mi, x.shape[-2])
|
|
190
|
+
# #
|
|
191
|
+
# # revi = get_permutation_list(x.shape[-2])[:mi]
|
|
192
|
+
# # revj = get_permutation_list(x.shape[-1])[:mj]
|
|
193
|
+
# #
|
|
194
|
+
# # mask = torch.zeros_like(x)
|
|
195
|
+
# # for i in range(len(revi)):
|
|
196
|
+
# # for j in range(len(revj)):
|
|
197
|
+
# # mask[0, :, revi[i], revj[j]] = 1
|
|
198
|
+
# #
|
|
199
|
+
# # # generate low pass hadamard mask
|
|
200
|
+
# # f = hadamard_2d(x)
|
|
201
|
+
# # f = f * mask
|
|
202
|
+
# # out = hadamard_2d(f)
|
|
203
|
+
# #
|
|
204
|
+
# # dinv.utils.plot_batch([x, out, f])
|
|
205
|
+
# #
|
|
206
|
+
# # rev = get_permutation_list(8)
|
|
207
|
+
# # imgs = []
|
|
208
|
+
# # for i in range(8):
|
|
209
|
+
# # y = torch.zeros((8, 1, 8, 8), device=dinv.device)
|
|
210
|
+
# # for j in range(8):
|
|
211
|
+
# # x = torch.zeros((8, 8), device=dinv.device)
|
|
212
|
+
# # x[rev[i], rev[j]] = 1
|
|
213
|
+
# # x = hadamard_2d(x)
|
|
214
|
+
# # y[j, 0, :, :] = x
|
|
215
|
+
# #
|
|
216
|
+
# # imgs.append(y)
|
|
217
|
+
# #
|
|
218
|
+
# # dinv.utils.plot_batch(imgs)
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from torch import nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from deepinv.physics.forward import LinearPhysics
|
|
6
|
+
|
|
7
|
+
if torch.__version__ > "1.2.0":
    # Pass ``align_corners=True`` explicitly so the sampling convention does not depend on the
    # library default.
    def affine_grid(theta, size):
        return F.affine_grid(theta, size, align_corners=True)

    def grid_sample(input, grid, mode="bilinear"):
        return F.grid_sample(input, grid, align_corners=True, mode=mode)

else:
    affine_grid = F.affine_grid
    grid_sample = F.grid_sample
|
|
15
|
+
|
|
16
|
+
# Module-level constants, stored as 1-element tensors of the default dtype.
PI = torch.ones(1).atan() * 4
SQRT2 = torch.full((1,), 2.0).sqrt()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def fftfreq(n):
    """Torch counterpart of ``numpy.fft.fftfreq(n)``: sample frequencies for a length-``n`` DFT (unit spacing)."""
    step = 1.0 / n
    cutoff = (n - 1) // 2 + 1
    nonnegative = torch.arange(0, cutoff)
    negative = torch.arange(-(n // 2), 0)
    freqs = torch.cat((nonnegative, negative)).to(torch.get_default_dtype())
    return freqs * step
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def deg2rad(x):
    """
    Convert angles from degrees to radians.

    Pi is built as a 1-element tensor on ``x``'s device and dtype, so a 0-dim input yields a result of shape ``(1,)``.
    """
    quarter_pi = torch.ones(1, device=x.device, dtype=x.dtype).atan()
    scaled = x * 4
    return scaled * quarter_pi / 180
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class AbstractFilter(nn.Module):
    """
    Base class for filters applied in the Fourier domain along dimension 2 of a sinogram.

    Subclasses implement :meth:`create_filter`, which receives the response returned by
    :meth:`_get_fourier_filter` and returns the response actually applied.

    :param device: device on which the filter response is built.
    :param dtype: dtype of the filter response.
    """

    def __init__(self, device="cpu", dtype=torch.float):
        super().__init__()
        self.device = device
        self.dtype = dtype

    def forward(self, x):
        """
        Filter ``x`` along dimension 2 (each slice along dimension 3 is filtered independently).

        :param torch.Tensor x: 4D tensor; dimension 2 is the axis being filtered.
        :return: real tensor with the same shape as ``x``.
        """
        length = x.shape[2]
        # Zero-pad to the next power of two >= 2 * length (at least 64).
        next_pow2 = int(2 ** (2 * torch.tensor(length)).float().log2().ceil())
        padded_length = max(64, next_pow2)
        padded = F.pad(x, (0, 0, 0, padded_length - length))

        response = self._get_fourier_filter(padded.shape[2]).to(x.device)
        response = self.create_filter(response).unsqueeze(-2)

        # FFT along dimension 2 (moved last for torch.fft), stored as (real, imag) pairs.
        spectrum = torch.view_as_real(torch.fft.fft(padded.transpose(2, 3))).transpose(2, 3)
        filtered = spectrum * response

        back = torch.fft.ifft(torch.view_as_complex(filtered).transpose(2, 3))
        real_part = torch.view_as_real(back)[..., 0]
        return real_part.transpose(2, 3)[:, :, :length, :]

    def _get_fourier_filter(self, size):
        """Return ``2 * FFT`` of the spatial ramp kernel of length ``size`` as a ``(size, 2)`` real view, imaginary column set to the real one."""
        odd = torch.cat(
            [torch.arange(1, size / 2 + 1, 2), torch.arange(size / 2 - 1, 0, -2)]
        )

        kernel = torch.zeros(size, dtype=self.dtype, device=self.device)
        kernel[0] = 0.25
        kernel[1::2] = -1 / (PI * odd) ** 2

        response = torch.view_as_real(torch.fft.fft(kernel, dim=-1))
        # Copy the real part into the imaginary slot so both components of the spectrum are scaled by it.
        response[:, 1] = response[:, 0]

        return 2 * response

    def create_filter(self, f):
        raise NotImplementedError
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class RampFilter(AbstractFilter):
    """Ramp filter: applies the response of :meth:`AbstractFilter._get_fourier_filter` unchanged.

    Constructor keyword arguments (``device``, ``dtype``) are forwarded to :class:`AbstractFilter`.
    """

    def create_filter(self, f):
        # No windowing: the plain ramp response is used.
        return f
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class Radon(nn.Module):
    """
    Discrete Radon transform: the image is rotated for each angle and summed along dimension 2.

    :param int in_size: width/height of the square input. If ``None``, the rotation grids are built
        from the first input seen by :meth:`forward`.
    :param torch.Tensor theta: projection angles in degrees (defaults to ``torch.arange(180)``).
    :param bool circle: if ``False``, the image is zero-padded to ``ceil(sqrt(2) * width)`` before rotating;
        if ``True``, no padding is applied.
    :param torch.dtype dtype: dtype of the rotation grids and of the output sinogram.
    :param torch.device device: device of the precomputed rotation grids.
    """

    def __init__(
        self,
        in_size=None,
        theta=None,
        circle=False,
        dtype=torch.float,
        device=torch.device("cpu"),
    ):
        super().__init__()
        self.circle = circle
        self.theta = torch.arange(180) if theta is None else theta
        self.dtype = dtype
        self.all_grids = (
            None
            if in_size is None
            else self._create_grids(self.theta, in_size, circle).to(device)
        )

    def forward(self, x):
        batch, channels, rows, cols = x.shape
        assert rows == cols, "Input image must be square"

        if self.all_grids is None:
            # in_size was not given at construction: build the grids for this input size.
            self.all_grids = self._create_grids(
                self.theta, rows, self.circle, device=x.device
            )

        if not self.circle:
            # Zero-pad up to the diagonal so that no content leaves the frame when rotating.
            extra = int((SQRT2 * rows - rows).ceil())
            before = (rows + extra) // 2 - rows // 2
            after = extra - before
            x = F.pad(x, (before, after, before, after))

        side = x.shape[2]
        n_angles = len(self.theta)
        sinogram = torch.zeros(
            batch, channels, side, n_angles, device=x.device, dtype=self.dtype
        )
        for k in range(n_angles):
            grid = self.all_grids[k].repeat(batch, 1, 1, 1).to(x.device)
            # One projection: sum the rotated image along dimension 2.
            sinogram[..., k] = grid_sample(x, grid).sum(2)
        return sinogram

    def _create_grids(self, angles, grid_size, circle, device="cpu"):
        """Return the stacked sampling grids of the rotations by each angle (degrees)."""
        if not circle:
            grid_size = int((SQRT2 * grid_size).ceil())
        shape = torch.Size([1, 1, grid_size, grid_size])
        grids = []
        for angle in angles:
            rad = deg2rad(angle)
            cos, sin = rad.cos(), rad.sin()
            rotation = torch.tensor(
                [[[cos, sin, 0], [-sin, cos, 0]]],
                dtype=self.dtype,
                device=device,
            )
            grids.append(affine_grid(rotation, shape))
        return torch.stack(grids)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class IRadon(nn.Module):
    r"""
    (Filtered) back-projection of a sinogram onto a square image grid.

    :param int in_size: width/height of the reconstructed square image. If ``None``, it is inferred from the
        sinogram height in :meth:`forward`: divided by :math:`\sqrt{2}` when ``circle=False``, unchanged otherwise.
    :param torch.Tensor theta: projection angles in degrees (defaults to ``torch.arange(180)``).
    :param bool circle: if ``True``, pixels outside the inscribed circle are set to zero; if ``False``, the
        border added to reach the image diagonal is cropped.
    :param bool use_filter: if ``True``, a :class:`RampFilter` is applied before back-projection.
    :param int out_size: if given, the reconstruction is zero-padded by ``(out_size - in_size) // 2`` pixels
        on each side.
    :param torch.dtype dtype: dtype of the sampling grids and of the reconstruction.
    :param torch.device device: device of the sampling grids and of the reconstruction.
    """

    def __init__(
        self,
        in_size=None,
        theta=None,
        circle=False,
        use_filter=True,
        out_size=None,
        dtype=torch.float,
        device=torch.device("cpu"),
    ):
        super().__init__()
        self.circle = circle
        self.device = device
        self.theta = theta if theta is not None else torch.arange(180).to(self.device)
        self.out_size = out_size
        self.in_size = in_size
        self.dtype = dtype
        self.ygrid, self.xgrid, self.all_grids = None, None, None
        if in_size is not None:
            self.ygrid, self.xgrid = self._create_yxgrid(in_size, circle)
            self.all_grids = self._create_grids(self.theta, in_size, circle).to(
                self.device
            )
        self.filter = (
            RampFilter(dtype=self.dtype, device=self.device)
            if use_filter
            else lambda x: x
        )

    def forward(self, x, filtering=True):
        """
        Back-project the sinogram ``x``.

        :param torch.Tensor x: sinogram; dimension 2 is the detector axis and dimension 3 the angle axis
            (the layout produced by :class:`Radon` in this module — assumed for external inputs).
        :param bool filtering: if ``False``, ``self.filter`` is skipped (plain back-projection).
        :return: reconstruction of shape ``(N, C, in_size, in_size)`` (or ``out_size`` if set).
        """
        it_size = x.shape[2]
        ch_size = x.shape[1]

        if self.in_size is None:
            self.in_size = (
                int((it_size / SQRT2).floor()) if not self.circle else it_size
            )
        if self.ygrid is None or self.xgrid is None or self.all_grids is None:
            self.ygrid, self.xgrid = self._create_yxgrid(self.in_size, self.circle)
            self.all_grids = self._create_grids(self.theta, self.in_size, self.circle)

        x = self.filter(x) if filtering else x

        reco = torch.zeros(
            x.shape[0], ch_size, it_size, it_size, device=self.device, dtype=self.dtype
        )
        for i_theta in range(len(self.theta)):
            reco += grid_sample(
                x, self.all_grids[i_theta].repeat(reco.shape[0], 1, 1, 1)
            )

        if not self.circle:
            # Crop the border added to reach the diagonal (same split as Radon's zero-padding).
            W = self.in_size
            diagonal = it_size
            pad = int(torch.tensor(diagonal - W, dtype=torch.float).ceil())
            new_center = (W + pad) // 2
            old_center = W // 2
            pad_before = new_center - old_center
            pad_width = (pad_before, pad - pad_before)
            reco = F.pad(
                reco, (-pad_width[0], -pad_width[1], -pad_width[0], -pad_width[1])
            )

        if self.circle:
            reconstruction_circle = (self.xgrid**2 + self.ygrid**2) <= 1
            reconstruction_circle = reconstruction_circle.repeat(
                x.shape[0], ch_size, 1, 1
            )
            reco[~reconstruction_circle] = 0.0

        reco = reco * PI.item() / (2 * len(self.theta))

        if self.out_size is not None:
            pad = (self.out_size - self.in_size) // 2
            reco = F.pad(reco, (pad, pad, pad, pad))

        return reco

    def _create_yxgrid(self, in_size, circle):
        if not circle:
            in_size = int((SQRT2 * in_size).ceil())
        unitrange = torch.linspace(-1, 1, in_size, dtype=self.dtype, device=self.device)
        return torch.meshgrid(unitrange, unitrange, indexing="ij")

    def _XYtoT(self, theta):
        T = self.xgrid * (deg2rad(theta)).cos() - self.ygrid * (deg2rad(theta)).sin()
        return T

    def _create_grids(self, angles, grid_size, circle):
        if not circle:
            grid_size = int((SQRT2 * grid_size).ceil())
        # The angle index i maps to the normalised sinogram column 2 * i / (n_angles - 1) - 1. With a
        # single angle this denominator is 0 and the grid becomes 0 / 0 = NaN. In that case the sinogram
        # is one column wide, and with align_corners=True any coordinate samples that column, so use 1.
        denom = max(len(angles) - 1, 1)
        all_grids = []
        for i_theta in range(len(angles)):
            X = (
                torch.ones(grid_size, dtype=self.dtype, device=self.device)
                .view(-1, 1)
                .repeat(1, grid_size)
                * i_theta
                * 2.0
                / denom
                - 1.0
            )
            Y = self._XYtoT(angles[i_theta])
            all_grids.append(
                torch.cat((X.unsqueeze(-1), Y.unsqueeze(-1)), dim=-1).unsqueeze(0)
            )
        return torch.stack(all_grids)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class Tomography(LinearPhysics):
    r"""
    (Computed) Tomography operator.

    The Radon transform is the integral transform which takes a square image :math:`x` defined on the plane to a function
    :math:`y=Rx` defined on the (two-dimensional) space of lines in the plane, whose value at a particular line is equal
    to the line integral of the function over that line.

    .. note::

        The pseudo-inverse is computed using the filtered back-projection algorithm with a Ramp filter.
        This is not the exact linear pseudo-inverse of the Radon transform, but it is a good approximation which is
        robust to noise.

    .. warning::

        The adjoint operator has small numerical errors due to interpolation.

    :param int img_width: width/height of the square image input.
    :param int, torch.tensor angles: If the type is ``int``, ``angles`` equispaced angles are sampled in
        :math:`[0, 180)` degrees. Otherwise, the angles in degrees are the ones provided, as a ``torch.tensor``
        or anything accepted by :func:`torch.as_tensor` (e.g., ``torch.linspace(0, 180, steps=10)``).
    :param bool circle: If ``True`` both forward and backward projection will be restricted to pixels inside a circle
        inscribed in the square image.
    :param str device: gpu or cpu.
    :param torch.dtype dtype: dtype of the sampling grids and outputs.
    """

    def __init__(
        self,
        img_width,
        angles,
        circle=False,
        device=torch.device("cpu"),
        dtype=torch.float,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if isinstance(angles, (int, float)):
            # linspace needs an integer number of steps; the endpoint 180 is dropped.
            theta = torch.linspace(0, 180, steps=int(angles) + 1, device=device)[:-1]
        else:
            theta = torch.as_tensor(angles, device=device)

        self.radon = Radon(
            img_width, theta, circle=circle, device=device, dtype=dtype
        ).to(device)
        self.iradon = IRadon(
            img_width, theta, circle=circle, device=device, dtype=dtype
        ).to(device)

    def A(self, x):
        """Forward operator: Radon transform of ``x``."""
        return self.radon(x)

    def A_dagger(self, y):
        """Approximate pseudo-inverse: filtered back-projection of ``y``."""
        return self.iradon(y)

    def A_adjoint(self, y):
        """Adjoint: unfiltered back-projection of ``y``."""
        return self.iradon(y, filtering=False)
|