deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,123 @@
1
+ from deepinv.physics.noise import GaussianNoise
2
+ from deepinv.physics.forward import LinearPhysics
3
+ from deepinv.physics.blur import Downsampling
4
+ from deepinv.physics.range import Decolorize
5
+ from deepinv.utils import TensorList
6
+
7
+
8
class Pansharpen(LinearPhysics):
    """
    Pansharpening forward operator.

    The measurements consist of a high resolution grayscale image and a low resolution RGB image, and
    are represented using :class:`deepinv.utils.TensorList`, where the first element is the RGB image and the second
    element is the grayscale image.

    By default, the downsampling is done with a gaussian filter with standard deviation equal to the downsampling,
    however, the user can provide a custom downsampling filter.

    It is possible to assign a different noise model to the RGB and grayscale images.

    Example usage:

    ::

        import deepinv
        import torch

        x = torch.randn(1, 3, 256, 256)
        physics = deepinv.physics.Pansharpen(img_size=x.shape[1:], device=x.device)

        y = physics(x)  # returns a TensorList with the RGB and grayscale images

        x_adj = physics.A_adjoint(y)
        x_pinv = physics.A_dagger(y)

        deepinv.utils.plot([y[0], y[1], x_adj, x_pinv, x], titles=['low res color', 'high res gray',
                                                                   'A_adjoint', 'A_dagger', 'x'])

    :param tuple[int] img_size: size of the input image.
    :param int factor: downsampling factor.
    :param torch.nn.Module noise_color: noise model for the RGB image. If ``None`` (default), a noiseless
        ``GaussianNoise(sigma=0.0)`` is used.
    :param torch.nn.Module noise_gray: noise model for the grayscale image. If ``None`` (default), a
        ``GaussianNoise(sigma=0.05)`` is used.
    :param torch.Tensor, str, NoneType filter: Downsampling filter. It can be 'gaussian', 'bilinear' or 'bicubic' or a
        custom ``torch.Tensor`` filter. If ``None``, no filtering is applied.
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding)
        otherwise the blurred output has the same size as the image.
    """

    def __init__(
        self,
        img_size,
        factor=4,
        noise_color=None,
        noise_gray=None,
        filter="gaussian",
        device="cpu",
        padding="circular",
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.downsampling = Downsampling(
            img_size=img_size,
            factor=factor,
            filter=filter,
            device=device,
            padding=padding,
        )

        # Build the default noise models per instance rather than as default
        # argument values: a default evaluated at `def` time would be a single
        # module shared (and possibly mutated) by every Pansharpen object.
        self.noise_color = (
            noise_color if noise_color is not None else GaussianNoise(sigma=0.0)
        )
        self.noise_gray = (
            noise_gray if noise_gray is not None else GaussianNoise(sigma=0.05)
        )
        self.colorize = Decolorize()

    def A(self, x):
        """Noiseless forward: (downsampled RGB, decolorized grayscale)."""
        return TensorList([self.downsampling(x), self.colorize(x)])

    def A_adjoint(self, y):
        """Adjoint: sum of the adjoints applied to each measurement element."""
        return self.downsampling.A_adjoint(y[0]) + self.colorize.A_adjoint(y[1])

    def forward(self, x):
        """Noisy forward: apply the per-measurement noise models after A."""
        return TensorList(
            [self.noise_color(self.downsampling(x)), self.noise_gray(self.colorize(x))]
        )
85
+
86
+
87
+ # test code
88
+ # if __name__ == "__main__":
89
+ # device = "cuda:0"
90
+ # import torch
91
+ # import torchvision
92
+ # import deepinv
93
+ #
94
+ # device = "cuda:0"
95
+ #
96
+ # x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
97
+ # x = x.unsqueeze(0).float().to(device) / 255
98
+ # x = torchvision.transforms.Resize((160, 180))(x)
99
+ #
100
+ # class Toy(LinearPhysics):
101
+ # def __init__(self, **kwargs):
102
+ # super().__init__(**kwargs)
103
+ # self.A = lambda x: x * 2
104
+ # self.A_adjoint = lambda x: x * 2
105
+ #
106
+ # sigma_noise = 0.1
107
+ # kernel = torch.zeros((1, 1, 15, 15), device=device)
108
+ # kernel[:, :, 7, :] = 1 / 15
109
+ # # physics = deepinv.physics.BlurFFT(img_size=x.shape[1:], filter=kernel, device=device)
110
+ # physics = Pansharpen(factor=8, img_size=x.shape[1:], device=device)
111
+ #
112
+ # y = physics(x)
113
+ #
114
+ # xhat2 = physics.A_adjoint(y)
115
+ # xhat1 = physics.A_dagger(y)
116
+ #
117
+ # physics.compute_norm(x)
118
+ # physics.adjointness_test(x)
119
+ #
120
+ # deepinv.utils.plot(
121
+ # [y[0], y[1], xhat2, xhat1, x],
122
+ # titles=["low res color", "high res gray", "A_adjoint", "A_dagger", "x"],
123
+ # )
@@ -0,0 +1,218 @@
1
+ from deepinv.physics.forward import DecomposablePhysics
2
+ import torch
3
+ import numpy as np
4
+
5
+
6
def hadamard_1d(u, normalize=True):
    """
    Multiply H_n @ u where H_n is the Hadamard matrix of dimension n x n.
    n must be a power of 2.

    Parameters:
        u: Tensor of shape (..., n)
        normalize: if True, divide the result by 2^{m/2} where m = log_2(n).
    Returns:
        product: Tensor of shape (..., n)
    """
    n = u.shape[-1]
    m = int(np.log2(n))
    assert n == 1 << m, "n must be a power of 2"
    x = u[..., np.newaxis]
    # m butterfly stages of the fast Walsh-Hadamard transform; the stage index
    # itself is unused, so iterate a plain range instead of a reversed one.
    for _ in range(m):
        x = torch.cat(
            (x[..., ::2, :] + x[..., 1::2, :], x[..., ::2, :] - x[..., 1::2, :]),
            dim=-1,
        )
    return x.squeeze(-2) / 2 ** (m / 2) if normalize else x.squeeze(-2)
26
+
27
+
28
def hadamard_2d(x):
    """
    Computes the 2-D Hadamard transform by applying the 1-D transform
    along each of the last two axes in turn.
    """
    rows = hadamard_1d(x)
    cols = hadamard_1d(rows.transpose(-1, -2))
    return cols.transpose(-1, -2)
34
+
35
+
36
class SinglePixelCamera(DecomposablePhysics):
    r"""
    Single pixel imaging camera.

    Linear imaging operator with binary entries.

    If ``fast=True``, the operator uses a 2D subsampled hadamard transform, which keeps the first :math:`m` modes
    according to the `sequency ordering <https://en.wikipedia.org/wiki/Walsh_matrix#Sequency_ordering>`_.
    In this case, the images should have a size which is a power of 2.

    If ``fast=False``, the operator is a random iid binary matrix with equal probability of :math:`1/\sqrt{m}` or
    :math:`-1/\sqrt{m}`.

    Both options allow for an efficient singular value decomposition (see :meth:`deepinv.physics.DecomposablePhysics`).
    The operator is always applied independently across channels.

    It is recommended to use ``fast=True`` for image sizes bigger than 32 x 32, since the forward computation with
    ``fast=False`` has an :math:`O(mn)` complexity, whereas with ``fast=True`` it has an :math:`O(n \log n)` complexity.

    An existing operator can be loaded from a saved ``.pth`` file via ``self.load_state_dict(save_path)``,
    in a similar fashion to :meth:`torch.nn.Module`.

    :param int m: number of single pixel measurements per acquisition.
    :param tuple img_shape: shape (C, H, W) of images.
    :param bool fast: The operator is iid binary if false, otherwise A is a 2D subsampled hadamard transform.
    :param str device: Device to store the forward matrix.

    """

    def __init__(
        self, m, img_shape, fast=True, device="cpu", dtype=torch.float32, **kwargs
    ):
        super().__init__(**kwargs)
        self.name = f"spcamera_m{m}"
        self.img_shape = img_shape
        self.fast = fast
        self.device = device

        if self.fast:
            C, H, W = img_shape
            # Split the measurement budget between row and column sequencies.
            mi = min(int(np.sqrt(m)), H)
            mj = min(m - mi, W)

            # Lowest-sequency Hadamard mode indices to keep.
            revi = get_permutation_list(H)[:mi]
            revj = get_permutation_list(W)[:mj]

            assert H == 1 << int(np.log2(H)), "image height must be a power of 2"
            assert W == 1 << int(np.log2(W)), "image width must be a power of 2"

            # Binary mask selecting the kept modes in the 2D Hadamard domain.
            mask = torch.zeros(img_shape).unsqueeze(0)
            for i in range(len(revi)):
                for j in range(len(revj)):
                    mask[0, :, revi[i], revj[j]] = 1

            mask = mask.to(device)
            self.mask = torch.nn.Parameter(mask, requires_grad=False)

        else:
            n = int(np.prod(img_shape[1:]))
            A = torch.ones((m, n), device=device)
            # Flip each entry to -1 with probability 1/2 using a uniform draw.
            # (A Gaussian draw compared to 0.5, as previously used, flips only
            # ~31% of the entries, contradicting the documented equal
            # probability of +/- 1/sqrt(m).)
            A[torch.rand_like(A) > 0.5] = -1.0
            A /= np.sqrt(m)  # normalize
            # Precompute the SVD so A and its adjoint/pseudo-inverse can be
            # applied through the DecomposablePhysics machinery.
            u, mask, vh = torch.linalg.svd(A, full_matrices=False)

            self.mask = mask.to(device).unsqueeze(0).type(dtype)
            self.vh = vh.to(device).type(dtype)
            self.u = u.to(device).type(dtype)

            self.u = torch.nn.Parameter(self.u, requires_grad=False)
            self.vh = torch.nn.Parameter(self.vh, requires_grad=False)
            self.mask = torch.nn.Parameter(self.mask, requires_grad=False)

    def V_adjoint(self, x):
        """Right singular vectors (adjoint): Hadamard transform if fast,
        otherwise multiplication by Vh from the precomputed SVD."""
        if self.fast:
            y = hadamard_2d(x)
        else:
            N, C = x.shape[0], self.img_shape[0]
            x = x.reshape(N, C, -1)
            y = torch.einsum("ijk, mk->ijm", x, self.vh)
        return y

    def V(self, y):
        """Right singular vectors: inverse Hadamard transform if fast,
        otherwise multiplication by Vh^T with reshape back to image shape."""
        if self.fast:
            x = hadamard_2d(y)
        else:
            N = y.shape[0]
            C, H, W = self.img_shape[0], self.img_shape[1], self.img_shape[2]
            x = torch.einsum("ijk, km->ijm", y, self.vh)
            x = x.reshape(N, C, H, W)
        return x

    def U_adjoint(self, x):
        """Left singular vectors (adjoint); identity in the fast case."""
        if self.fast:
            out = x
        else:
            out = torch.einsum("ijk, km->ijm", x, self.u)
        return out

    def U(self, x):
        """Left singular vectors; identity in the fast case."""
        if self.fast:
            out = x
        else:
            out = torch.einsum("ijk, mk->ijm", x, self.u)
        return out
140
+
141
+
142
def gray_decode(n):
    """Convert a Gray-code value to its binary index.

    The binary value is the XOR of all right-shifts of the Gray code:
    b = g ^ (g >> 1) ^ (g >> 2) ^ ...
    """
    decoded = 0
    while n:
        decoded ^= n
        n >>= 1
    return decoded
148
+
149
+
150
def reverse(n, numbits):
    """Reverse the lowest ``numbits`` bits of ``n``."""
    out = 0
    for _ in range(numbits):
        out = (out << 1) | (n & 1)
        n >>= 1
    return out
152
+
153
+
154
def get_permutation_list(n):
    """Return the index permutation from sequency order to natural
    (Hadamard) order for a transform of size ``n`` (a power of 2).

    Composes a bit reversal with a Gray-code decoding, the standard
    conversion between the two orderings of the Walsh-Hadamard basis.
    """
    numbits = int(np.log2(n))
    bit_reversed = np.array([reverse(k, numbits) for k in range(n)])
    return np.array([bit_reversed[gray_decode(k)] for k in range(n)])
164
+
165
+
166
+ # test code
167
+ # if __name__ == "__main__":
168
+ # import matplotlib.pyplot as plt
169
+ # import deepinv as dinv
170
+ # import torchvision
171
+ #
172
+ # device = "cuda:0"
173
+ # x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
174
+ # x = x.unsqueeze(0).float().to(device) / 255
175
+ # x = torchvision.transforms.Resize((16, 8))(x)
176
+ #
177
+ # m = 20
178
+ # physics = SinglePixelCamera(m, (3, 16, 8), fast=False, device=device)
179
+ #
180
+ # y = physics(x)
181
+ #
182
+ # xhat = physics.A_adjoint(y)
183
+ #
184
+ # dinv.utils.plot([x, xhat])
185
+ #
186
+ # print(physics.adjointness_test(x))
187
+ # print(physics.compute_norm(x))
188
+ # # mi = min(int(np.sqrt(m)), x.shape[-2])
189
+ # # mj = min(m - mi, x.shape[-2])
190
+ # #
191
+ # # revi = get_permutation_list(x.shape[-2])[:mi]
192
+ # # revj = get_permutation_list(x.shape[-1])[:mj]
193
+ # #
194
+ # # mask = torch.zeros_like(x)
195
+ # # for i in range(len(revi)):
196
+ # # for j in range(len(revj)):
197
+ # # mask[0, :, revi[i], revj[j]] = 1
198
+ # #
199
+ # # # generate low pass hadamard mask
200
+ # # f = hadamard_2d(x)
201
+ # # f = f * mask
202
+ # # out = hadamard_2d(f)
203
+ # #
204
+ # # dinv.utils.plot_batch([x, out, f])
205
+ # #
206
+ # # rev = get_permutation_list(8)
207
+ # # imgs = []
208
+ # # for i in range(8):
209
+ # # y = torch.zeros((8, 1, 8, 8), device=dinv.device)
210
+ # # for j in range(8):
211
+ # # x = torch.zeros((8, 8), device=dinv.device)
212
+ # # x[rev[i], rev[j]] = 1
213
+ # # x = hadamard_2d(x)
214
+ # # y[j, 0, :, :] = x
215
+ # #
216
+ # # imgs.append(y)
217
+ # #
218
+ # # dinv.utils.plot_batch(imgs)
@@ -0,0 +1,321 @@
1
+ import torch
2
+ import numpy as np
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from deepinv.physics.forward import LinearPhysics
6
+
7
+ if torch.__version__ > "1.2.0":
8
+ affine_grid = lambda theta, size: F.affine_grid(theta, size, align_corners=True)
9
+ grid_sample = lambda input, grid, mode="bilinear": F.grid_sample(
10
+ input, grid, align_corners=True, mode=mode
11
+ )
12
+ else:
13
+ affine_grid = F.affine_grid
14
+ grid_sample = F.grid_sample
15
+
16
+ # constants
17
+ PI = 4 * torch.ones(1).atan()
18
+ SQRT2 = (2 * torch.ones(1)).sqrt()
19
+
20
+
21
def fftfreq(n):
    """Return the DFT sample frequencies for a signal of length ``n``.

    Equivalent to ``torch.fft.fftfreq``: the frequencies
    ``[0, 1, ..., (n-1)//2, -(n//2), ..., -1] / n`` as a float tensor.
    """
    n_nonneg = (n - 1) // 2 + 1
    freqs = torch.cat(
        (torch.arange(0, n_nonneg), torch.arange(-(n // 2), 0))
    ).float()
    return freqs * (1.0 / n)
30
+
31
+
32
def deg2rad(x):
    """Convert a tensor of degrees to radians, on the input's device/dtype."""
    pi = 4 * torch.ones(1, device=x.device, dtype=x.dtype).atan()
    return x * pi / 180
34
+
35
+
36
class AbstractFilter(nn.Module):
    """Base class for frequency-domain filters applied to sinograms.

    Subclasses implement :meth:`create_filter`, which shapes the base ramp
    frequency response built by :meth:`_get_fourier_filter` (the identity
    for a plain ramp filter).
    """

    def __init__(self, device="cpu", dtype=torch.float):
        # device/dtype used when constructing the filter frequency response.
        super().__init__()
        self.device = device
        self.dtype = dtype

    def forward(self, x):
        """Filter ``x`` along dim 2 and return a tensor of the same shape.

        NOTE(review): the transposes below suggest ``x`` is a sinogram of
        shape (N, C, detector, angles) with filtering acting along the
        detector axis — confirm against the caller (IRadon.forward).
        """
        input_size = x.shape[2]
        # Pad dim 2 to at least 64 samples and to a power of two >= 2x the
        # input size, so the circular FFT convolution below does not wrap
        # around into the signal.
        projection_size_padded = max(
            64, int(2 ** (2 * torch.tensor(input_size)).float().log2().ceil())
        )
        pad_width = projection_size_padded - input_size
        padded_tensor = F.pad(x, (0, 0, 0, pad_width))
        f = self._get_fourier_filter(padded_tensor.shape[2]).to(x.device)
        fourier_filter = self.create_filter(f)
        fourier_filter = fourier_filter.unsqueeze(-2)

        # FFT along dim 2 (moved last via transpose), multiply by the filter
        # response in view_as_real form, then inverse FFT and keep the real
        # component.
        projection = (
            torch.view_as_real(torch.fft.fft(padded_tensor.transpose(2, 3))).transpose(
                2, 3
            )
            * fourier_filter
        )
        result = torch.view_as_real(
            torch.fft.ifft(torch.view_as_complex(projection).transpose(2, 3))
        )[..., 0]
        # Undo the transpose and drop the padding.
        result = result.transpose(2, 3)[:, :, :input_size, :]

        return result

    def _get_fourier_filter(self, size):
        """Return the base ramp response of length ``size`` as a
        view_as_real tensor of shape (size, 2).

        The response is the FFT of the discrete spatial-domain ramp kernel
        (center tap 1/4, odd taps -1/(pi*n)^2, even taps 0), rather than an
        ideal |frequency| profile sampled directly.
        """
        # Odd tap offsets 1, 3, 5, ... and their mirror around the end.
        n = torch.cat(
            [torch.arange(1, size / 2 + 1, 2), torch.arange(size / 2 - 1, 0, -2)]
        )

        f = torch.zeros(size, dtype=self.dtype, device=self.device)
        f[0] = 0.25
        f[1::2] = -1 / (PI * n) ** 2

        fourier_filter = torch.view_as_real(torch.fft.fft(f, dim=-1))
        # Copy the real part into the imaginary slot so that multiplying a
        # view_as_real spectrum scales both components by the same factor.
        fourier_filter[:, 1] = fourier_filter[:, 0]

        return 2 * fourier_filter

    def create_filter(self, f):
        """Shape the base ramp response ``f``; implemented by subclasses."""
        raise NotImplementedError
82
+
83
+
84
class RampFilter(AbstractFilter):
    """Ramp filter used for filtered back-projection.

    The base class already builds the ramp frequency response, so this
    subclass applies it unchanged. The previous explicit ``__init__`` only
    forwarded ``**kwargs`` to the base class and has been removed; the
    inherited constructor is identical.
    """

    def create_filter(self, f):
        # Identity: keep the ramp response computed by AbstractFilter.
        return f
90
+
91
+
92
class Radon(nn.Module):
    """Differentiable Radon transform (sinogram computation).

    For each projection angle, the image is resampled on a rotated grid with
    ``grid_sample`` and summed along one axis, producing one column of the
    sinogram.
    """

    def __init__(
        self,
        in_size=None,
        theta=None,
        circle=False,
        dtype=torch.float,
        device=torch.device("cpu"),
    ):
        """
        :param in_size: optional image width/height; if given, the rotated
            sampling grids are precomputed here instead of on the first
            forward call.
        :param theta: 1D tensor of projection angles in degrees; defaults to
            0..179.
        :param circle: if True, assume the image content fits in the
            inscribed circle and skip the diagonal padding.
        :param dtype: dtype of the output sinogram and the sampling grids.
        :param device: device on which to store precomputed grids.
        """
        super().__init__()
        self.circle = circle
        self.theta = theta
        if theta is None:
            self.theta = torch.arange(180)
        self.dtype = dtype
        self.all_grids = None
        if in_size is not None:
            self.all_grids = self._create_grids(self.theta, in_size, circle).to(device)

    def forward(self, x):
        """Compute the sinogram of a square image batch ``x`` of shape
        (N, C, W, W); output has shape (N, C, W', n_angles) where W' is the
        (possibly diagonal-padded) width."""
        N, C, W, H = x.shape
        assert W == H, "Input image must be square"

        if (
            self.all_grids is None
        ):  # if in_size was not given, we have to create the grid online.
            self.all_grids = self._create_grids(
                self.theta, W, self.circle, device=x.device
            )

        if not self.circle:
            # Pad the image up to its diagonal so no content leaves the frame
            # when the sampling grid is rotated.
            diagonal = SQRT2 * W
            pad = int((diagonal - W).ceil())
            new_center = (W + pad) // 2
            old_center = W // 2
            pad_before = new_center - old_center
            pad_width = (pad_before, pad - pad_before)
            x = F.pad(x, (pad_width[0], pad_width[1], pad_width[0], pad_width[1]))

        N, C, W, _ = x.shape
        out = torch.zeros(N, C, W, len(self.theta), device=x.device, dtype=self.dtype)

        for i in range(len(self.theta)):
            # Rotate by theta[i], then integrate (sum) along dim 2 to get the
            # i-th projection.
            rotated = grid_sample(x, self.all_grids[i].repeat(N, 1, 1, 1).to(x.device))
            out[..., i] = rotated.sum(2)
        return out

    def _create_grids(self, angles, grid_size, circle, device="cpu"):
        """Precompute one rotated sampling grid per angle, stacked along a
        new leading dimension."""
        if not circle:
            grid_size = int((SQRT2 * grid_size).ceil())
        all_grids = []
        for theta in angles:
            theta = deg2rad(theta)
            # 2x3 affine matrix of a pure rotation by theta.
            R = torch.tensor(
                [[[theta.cos(), theta.sin(), 0], [-theta.sin(), theta.cos(), 0]]],
                dtype=self.dtype,
                device=device,
            )
            all_grids.append(affine_grid(R, torch.Size([1, 1, grid_size, grid_size])))
        return torch.stack(all_grids)
152
+
153
+
154
class IRadon(nn.Module):
    """Inverse Radon transform via (filtered) back-projection.

    Each sinogram column is smeared back across the image along its
    projection direction using ``grid_sample``; with ``use_filter=True`` the
    sinogram is first filtered by a :class:`RampFilter` (standard FBP).
    """

    def __init__(
        self,
        in_size=None,
        theta=None,
        circle=False,
        use_filter=True,
        out_size=None,
        dtype=torch.float,
        device=torch.device("cpu"),
    ):
        """
        :param in_size: optional reconstruction width/height; if given, the
            back-projection grids are precomputed here.
        :param theta: 1D tensor of projection angles in degrees; defaults to
            0..179.
        :param circle: if True, zero out pixels outside the inscribed circle
            and skip the diagonal padding logic.
        :param use_filter: if True, apply a ramp filter before
            back-projecting (FBP); otherwise plain back-projection.
        :param out_size: optional final output size; the reconstruction is
            zero-padded symmetrically to reach it.
        :param dtype: dtype of the reconstruction and grids.
        :param device: device for grids and the output.
        """
        super().__init__()
        self.circle = circle
        self.device = device
        self.theta = theta if theta is not None else torch.arange(180).to(self.device)
        self.out_size = out_size
        self.in_size = in_size
        self.dtype = dtype
        self.ygrid, self.xgrid, self.all_grids = None, None, None
        if in_size is not None:
            self.ygrid, self.xgrid = self._create_yxgrid(in_size, circle)
            self.all_grids = self._create_grids(self.theta, in_size, circle).to(
                self.device
            )
        # When filtering is disabled, self.filter is the identity.
        self.filter = (
            RampFilter(dtype=self.dtype, device=self.device)
            if use_filter
            else lambda x: x
        )

    def forward(self, x, filtering=True):
        """Back-project the sinogram ``x``; ``filtering=False`` skips the
        ramp filter (used to realize the adjoint rather than the
        pseudo-inverse)."""
        it_size = x.shape[2]
        ch_size = x.shape[1]

        if self.in_size is None:
            # Recover the original image size from the padded sinogram width.
            self.in_size = (
                int((it_size / SQRT2).floor()) if not self.circle else it_size
            )
        # if None in [self.ygrid, self.xgrid, self.all_grids]:
        if self.ygrid is None or self.xgrid is None or self.all_grids is None:
            # Grids were not precomputed in __init__; build them lazily.
            self.ygrid, self.xgrid = self._create_yxgrid(self.in_size, self.circle)
            self.all_grids = self._create_grids(self.theta, self.in_size, self.circle)

        x = self.filter(x) if filtering else x

        reco = torch.zeros(
            x.shape[0], ch_size, it_size, it_size, device=self.device, dtype=self.dtype
        )
        # Accumulate the back-projection of every angle.
        for i_theta in range(len(self.theta)):
            reco += grid_sample(
                x, self.all_grids[i_theta].repeat(reco.shape[0], 1, 1, 1)
            )

        if not self.circle:
            # Crop away the diagonal padding added by the forward Radon
            # transform (negative pads in F.pad remove rows/columns).
            W = self.in_size
            diagonal = it_size
            pad = int(torch.tensor(diagonal - W, dtype=torch.float).ceil())
            new_center = (W + pad) // 2
            old_center = W // 2
            pad_before = new_center - old_center
            pad_width = (pad_before, pad - pad_before)
            reco = F.pad(
                reco, (-pad_width[0], -pad_width[1], -pad_width[0], -pad_width[1])
            )

        if self.circle:
            # Zero out everything outside the inscribed reconstruction circle.
            reconstruction_circle = (self.xgrid**2 + self.ygrid**2) <= 1
            reconstruction_circle = reconstruction_circle.repeat(
                x.shape[0], ch_size, 1, 1
            )
            reco[~reconstruction_circle] = 0.0

        # Scale by pi / (2 * n_angles), the FBP discretization factor.
        reco = reco * PI.item() / (2 * len(self.theta))

        if self.out_size is not None:
            pad = (self.out_size - self.in_size) // 2
            reco = F.pad(reco, (pad, pad, pad, pad))

        return reco

    def _create_yxgrid(self, in_size, circle):
        """Return the (y, x) meshgrid of normalized coordinates in [-1, 1]."""
        if not circle:
            in_size = int((SQRT2 * in_size).ceil())
        unitrange = torch.linspace(-1, 1, in_size, dtype=self.dtype, device=self.device)
        return torch.meshgrid(unitrange, unitrange, indexing="ij")

    def _XYtoT(self, theta):
        """Project the (x, y) grid onto the detector axis for angle theta."""
        T = self.xgrid * (deg2rad(theta)).cos() - self.ygrid * (deg2rad(theta)).sin()
        return T

    def _create_grids(self, angles, grid_size, circle):
        """Build one sampling grid per angle that reads, for each output
        pixel, the sinogram value at (angle index, detector coordinate)."""
        if not circle:
            grid_size = int((SQRT2 * grid_size).ceil())
        all_grids = []
        for i_theta in range(len(angles)):
            # X selects the angle column, mapped to normalized [-1, 1].
            X = (
                torch.ones(grid_size, dtype=self.dtype, device=self.device)
                .view(-1, 1)
                .repeat(1, grid_size)
                * i_theta
                * 2.0
                / (len(angles) - 1)
                - 1.0
            )
            # Y is the detector coordinate of each pixel for this angle.
            Y = self._XYtoT(angles[i_theta])
            all_grids.append(
                torch.cat((X.unsqueeze(-1), Y.unsqueeze(-1)), dim=-1).unsqueeze(0)
            )
        return torch.stack(all_grids)
263
+
264
+
265
class Tomography(LinearPhysics):
    r"""
    (Computed) Tomography operator.

    The Radon transform is the integral transform which takes a square image :math:`x` defined on the plane to a function
    :math:`y=Rx` defined on the (two-dimensional) space of lines in the plane, whose value at a particular line is equal
    to the line integral of the function over that line.

    .. note::

        The pseudo-inverse is computed using the filtered back-projection algorithm with a Ramp filter.
        This is not the exact linear pseudo-inverse of the Radon transform, but it is a good approximation which is
        robust to noise.

    .. warning::

        The adjoint operator has small numerical errors due to interpolation.

    :param int img_width: width/height of the square image input.
    :param int, torch.tensor angles: If the type is ``int``, the angles are sampled uniformly between 0 and 180 degrees
        (endpoint excluded, matching the implementation below).
        If the type is ``torch.tensor``, the angles are the ones provided (e.g., ``torch.linspace(0, 180, steps=10)``).
    :param bool circle: If ``True`` both forward and backward projection will be restricted to pixels inside a circle
        inscribed in the square image.
    :param str device: gpu or cpu.
    """

    def __init__(
        self,
        img_width,
        angles,
        circle=False,
        device=torch.device("cpu"),
        dtype=torch.float,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # A numeric `angles` means "this many equispaced views in [0, 180)".
        if isinstance(angles, (int, float)):
            theta = torch.linspace(0, 180, steps=angles + 1, device=device)[:-1]
        else:
            theta = angles.to(device)

        self.radon = Radon(
            img_width, theta, circle=circle, device=device, dtype=dtype
        ).to(device)
        self.iradon = IRadon(
            img_width, theta, circle=circle, device=device, dtype=dtype
        ).to(device)

    def A(self, x):
        """Apply the Radon transform (compute the sinogram)."""
        return self.radon(x)

    def A_dagger(self, y):
        """Approximate pseudo-inverse: filtered back-projection."""
        return self.iradon(y)

    def A_adjoint(self, y):
        """Approximate adjoint: unfiltered back-projection."""
        return self.iradon(y, filtering=False)
@@ -0,0 +1,2 @@
1
+ from .langevin import MonteCarlo, ULA, SKRock
2
+ from .diffusion import DDRM, DiffusionSampler, DiffPIR, DPS