deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/physics/blur.py
ADDED
@@ -0,0 +1,544 @@
from torchvision.transforms.functional import rotate
import torchvision
import torch.nn.functional as F
import torch
import numpy as np
import torch.fft as fft
from deepinv.physics.forward import Physics, LinearPhysics, DecomposablePhysics
from deepinv.utils import TensorList


def filter_fft(filter, img_size, real_fft=True):
    ph = int((filter.shape[2] - 1) / 2)
    pw = int((filter.shape[3] - 1) / 2)

    filt2 = torch.zeros(filter.shape[:2] + img_size[-2:], device=filter.device)

    filt2[:, : filter.shape[1], : filter.shape[2], : filter.shape[3]] = filter
    filt2 = torch.roll(filt2, shifts=(-ph, -pw), dims=(2, 3))

    return fft.rfft2(filt2) if real_fft else fft.fft2(filt2)


def gaussian_blur(sigma=(1, 1), angle=0):
    r"""
    Gaussian blur filter.

    :param float, tuple[float] sigma: standard deviation of the gaussian filter. If sigma is a float the filter is isotropic, whereas
        if sigma is a tuple of floats (sigma_x, sigma_y) the filter is anisotropic.
    :param float angle: rotation angle of the filter in degrees (only useful for anisotropic filters)
    """
    if isinstance(sigma, (int, float)):
        sigma = (sigma, sigma)

    s = max(sigma)
    c = int(s / 0.3 + 1)
    k_size = 2 * c + 1

    delta = torch.arange(k_size)

    x, y = torch.meshgrid(delta, delta, indexing="ij")
    x = x - c
    y = y - c
    filt = (x / sigma[0]).pow(2)
    filt += (y / sigma[1]).pow(2)
    filt = torch.exp(-filt / 2.0)

    filt = (
        rotate(
            filt.unsqueeze(0).unsqueeze(0),
            angle,
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        .squeeze(0)
        .squeeze(0)
    )

    filt = filt / filt.flatten().sum()

    return filt.unsqueeze(0).unsqueeze(0)
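As a quick illustration of the parameters documented above (this sketch is not part of blur.py), one can build an isotropic and a rotated anisotropic kernel and check that both are normalised:

# --- illustrative usage sketch, not part of the packaged file ---
import torch
from deepinv.physics.blur import gaussian_blur

iso = gaussian_blur(sigma=1.5)                   # isotropic kernel, shape (1, 1, k, k)
aniso = gaussian_blur(sigma=(3, 1), angle=45.0)  # anisotropic kernel rotated by 45 degrees
print(iso.shape, aniso.shape)
print(iso.sum().item(), aniso.sum().item())      # both sum to 1 (up to float precision)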
def bilinear_filter(factor=2):
    x = np.arange(start=-factor + 0.5, stop=factor, step=1) / factor
    w = 1 - np.abs(x)
    w = np.outer(w, w)
    w = w / np.sum(w)
    return torch.Tensor(w).unsqueeze(0).unsqueeze(0)


def bicubic_filter(factor=2):
    x = np.arange(start=-2 * factor + 0.5, stop=2 * factor, step=1) / factor
    a = -0.5
    x = np.abs(x)
    w = ((a + 2) * np.power(x, 3) - (a + 3) * np.power(x, 2) + 1) * (x <= 1)
    w += (
        (a * np.power(x, 3) - 5 * a * np.power(x, 2) + 8 * a * x - 4 * a)
        * (x > 1)
        * (x < 2)
    )
    w = np.outer(w, w)
    w = w / np.sum(w)
    return torch.Tensor(w).unsqueeze(0).unsqueeze(0)
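For reference (again not part of the file), the support of these anti-aliasing kernels grows with the downsampling factor: bilinear kernels span 2*factor taps per dimension and bicubic kernels 4*factor:

# --- illustrative sketch, not part of the packaged file ---
from deepinv.physics.blur import bilinear_filter, bicubic_filter

for factor in (2, 4):
    print(bilinear_filter(factor).shape)  # (1, 1, 2 * factor, 2 * factor)
    print(bicubic_filter(factor).shape)   # (1, 1, 4 * factor, 4 * factor)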
class Downsampling(LinearPhysics):
    r"""
    Downsampling operator for super-resolution problems.

    It is defined as

    .. math::

        y = S (h*x)

    where :math:`h` is a low-pass filter and :math:`S` is a subsampling operator.

    :param tuple[int] img_size: size of the input image
    :param int factor: downsampling factor
    :param torch.Tensor, str, NoneType filter: Downsampling filter. It can be 'gaussian', 'bilinear' or 'bicubic' or a
        custom ``torch.Tensor`` filter. If ``None``, no filtering is applied.
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding),
        otherwise the blurred output has the same size as the image.
    """

    def __init__(
        self,
        img_size,
        factor=2,
        filter="gaussian",
        device="cpu",
        padding="circular",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.factor = factor
        assert isinstance(factor, int), "downsampling factor should be an integer"
        self.imsize = img_size
        self.padding = padding
        if isinstance(filter, torch.Tensor):
            self.filter = filter.to(device)
        elif filter is None:
            self.filter = filter
        elif filter == "gaussian":
            self.filter = (
                gaussian_blur(sigma=(factor, factor)).requires_grad_(False).to(device)
            )
        elif filter == "bilinear":
            self.filter = bilinear_filter(self.factor).requires_grad_(False).to(device)
        elif filter == "bicubic":
            self.filter = bicubic_filter(self.factor).requires_grad_(False).to(device)
        else:
            raise Exception("The chosen downsampling filter doesn't exist")

        if self.filter is not None:
            self.Fh = filter_fft(self.filter, img_size, real_fft=False).to(device)
            self.Fhc = torch.conj(self.Fh)
            self.Fh2 = self.Fhc * self.Fh
            self.filter = torch.nn.Parameter(self.filter, requires_grad=False)
            self.Fhc = torch.nn.Parameter(self.Fhc, requires_grad=False)
            self.Fh2 = torch.nn.Parameter(self.Fh2, requires_grad=False)

    def A(self, x):
        if self.filter is not None:
            x = conv(x, self.filter, padding=self.padding)
        x = x[:, :, :: self.factor, :: self.factor]  # downsample
        return x

    def A_adjoint(self, y):
        x = torch.zeros((y.shape[0],) + self.imsize, device=y.device)
        x[:, :, :: self.factor, :: self.factor] = y  # upsample
        if self.filter is not None:
            x = conv_transpose(x, self.filter, padding=self.padding)
        return x

    def prox_l2(self, z, y, gamma, use_fft=True):
        r"""
        If the padding is circular, it computes the proximal operator with the closed-form solution of
        https://arxiv.org/abs/1510.00143.

        Otherwise, it computes it using the conjugate gradient algorithm, which can be slow if applied many times.
        """

        if use_fft and self.padding == "circular":  # Formula from (Zhao, 2016)
            z_hat = self.A_adjoint(y) + 1 / gamma * z
            Fz_hat = fft.fft2(z_hat)

            def splits(a, sf):
                """Split a into sf x sf distinct blocks.

                Args:
                    a: NxCxWxH
                    sf: split factor

                Returns:
                    b: NxCx(W/sf)x(H/sf)x(sf^2)
                """
                b = torch.stack(torch.chunk(a, sf, dim=2), dim=4)
                b = torch.cat(torch.chunk(b, sf, dim=3), dim=4)
                return b

            top = torch.mean(splits(self.Fh * Fz_hat, self.factor), dim=-1)
            below = torch.mean(splits(self.Fh2, self.factor), dim=-1) + 1 / gamma
            rc = self.Fhc * (top / below).repeat(1, 1, self.factor, self.factor)
            r = torch.real(fft.ifft2(rc))
            return (z_hat - r) * gamma
        else:
            return LinearPhysics.prox_l2(self, z, y, gamma)
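To make the class above concrete, here is a minimal usage sketch (not part of the file), assuming a 3-channel 64x64 input:

# --- illustrative usage sketch, not part of the packaged file ---
import torch
from deepinv.physics.blur import Downsampling

x = torch.rand(1, 3, 64, 64)
physics = Downsampling(img_size=(3, 64, 64), factor=2, filter="bilinear")
y = physics.A(x)               # anti-aliasing filter + subsampling, shape (1, 3, 32, 32)
x_adj = physics.A_adjoint(y)   # zero-filled upsampling + transposed convolution
x_prox = physics.prox_l2(x, y, gamma=1.0)  # closed-form prox (circular padding branch)
print(y.shape, x_adj.shape, x_prox.shape)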
def extend_filter(filter):
    b, c, h, w = filter.shape
    w_new = w
    h_new = h

    offset_w = 0
    offset_h = 0

    if w == 1:
        w_new = 3
        offset_w = 1
    elif w % 2 == 0:
        w_new += 1

    if h == 1:
        h_new = 3
        offset_h = 1
    elif h % 2 == 0:
        h_new += 1

    out = torch.zeros((b, c, h_new, w_new), device=filter.device)
    out[:, :, offset_h : h + offset_h, offset_w : w + offset_w] = filter
    return out


def conv(x, filter, padding):
    r"""
    Convolution of x and filter. The transpose of this operation is conv_transpose(x, filter, padding).

    :param torch.Tensor x: Image of size (B,C,W,H).
    :param torch.Tensor filter: Filter of size (1,C,W,H) for colour filtering or (1,1,W,H) for filtering each channel with the same filter.
    :param str padding: options are 'valid', 'circular', 'replicate' and 'reflect'. If padding='valid' the blurred output is smaller than the image (no padding), otherwise the blurred output has the same size as the image.
    """
    b, c, h, w = x.shape

    filter = filter.flip(-1).flip(
        -2
    )  # In order to perform convolution and not correlation like PyTorch native conv

    filter = extend_filter(filter)

    ph = (filter.shape[2] - 1) / 2
    pw = (filter.shape[3] - 1) / 2

    if padding == "valid":
        h_out = int(h - 2 * ph)
        w_out = int(w - 2 * pw)
    else:
        h_out = h
        w_out = w
        pw = int(pw)
        ph = int(ph)
        x = F.pad(x, (pw, pw, ph, ph), mode=padding, value=0)

    if filter.shape[1] == 1:
        y = torch.zeros((b, c, h_out, w_out), device=x.device)
        for i in range(b):
            for j in range(c):
                y[i, j, :, :] = F.conv2d(
                    x[i, j, :, :].unsqueeze(0).unsqueeze(1), filter, padding="valid"
                ).unsqueeze(1)
    else:
        y = F.conv2d(x, filter, padding="valid")

    return y


def conv_transpose(y, filter, padding):
    r"""
    Transposed convolution of y and filter. The transpose of this operation is conv(x, filter, padding).

    :param torch.Tensor y: Image of size (B,C,W,H).
    :param torch.Tensor filter: Filter of size (1,C,W,H) for colour filtering or (1,1,W,H) for filtering each channel with the same filter.
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding),
        otherwise the blurred output has the same size as the image.
    """

    b, c, h, w = y.shape

    filter = filter.flip(-1).flip(
        -2
    )  # In order to perform convolution and not correlation like PyTorch native conv

    filter = extend_filter(filter)

    ph = (filter.shape[2] - 1) / 2
    pw = (filter.shape[3] - 1) / 2

    h_out = int(h + 2 * ph)
    w_out = int(w + 2 * pw)
    pw = int(pw)
    ph = int(ph)

    x = torch.zeros((b, c, h_out, w_out), device=y.device)
    if filter.shape[1] == 1:
        for i in range(b):
            if filter.shape[0] > 1:
                f = filter[i, :, :, :].unsqueeze(0)
            else:
                f = filter

            for j in range(c):
                x[i, j, :, :] = F.conv_transpose2d(
                    y[i, j, :, :].unsqueeze(0).unsqueeze(1), f
                )
    else:
        x = F.conv_transpose2d(y, filter)

    if padding == "valid":
        out = x
    elif padding == "zero":
        out = x[:, :, ph:-ph, pw:-pw]
    elif padding == "circular":
        out = x[:, :, ph:-ph, pw:-pw]
        # sides
        out[:, :, :ph, :] += x[:, :, -ph:, pw:-pw]
        out[:, :, -ph:, :] += x[:, :, :ph, pw:-pw]
        out[:, :, :, :pw] += x[:, :, ph:-ph, -pw:]
        out[:, :, :, -pw:] += x[:, :, ph:-ph, :pw]
        # corners
        out[:, :, :ph, :pw] += x[:, :, -ph:, -pw:]
        out[:, :, -ph:, -pw:] += x[:, :, :ph, :pw]
        out[:, :, :ph, -pw:] += x[:, :, -ph:, :pw]
        out[:, :, -ph:, :pw] += x[:, :, :ph, -pw:]

    elif padding == "reflect":
        out = x[:, :, ph:-ph, pw:-pw]
        # sides
        out[:, :, 1 : 1 + ph, :] += x[:, :, :ph, pw:-pw].flip(dims=(2,))
        out[:, :, -ph - 1 : -1, :] += x[:, :, -ph:, pw:-pw].flip(dims=(2,))
        out[:, :, :, 1 : 1 + pw] += x[:, :, ph:-ph, :pw].flip(dims=(3,))
        out[:, :, :, -pw - 1 : -1] += x[:, :, ph:-ph, -pw:].flip(dims=(3,))
        # corners
        out[:, :, 1 : 1 + ph, 1 : 1 + pw] += x[:, :, :ph, :pw].flip(dims=(2, 3))
        out[:, :, -ph - 1 : -1, -pw - 1 : -1] += x[:, :, -ph:, -pw:].flip(dims=(2, 3))
        out[:, :, -ph - 1 : -1, 1 : 1 + pw] += x[:, :, -ph:, :pw].flip(dims=(2, 3))
        out[:, :, 1 : 1 + ph, -pw - 1 : -1] += x[:, :, :ph, -pw:].flip(dims=(2, 3))

    elif padding == "replicate":
        out = x[:, :, ph:-ph, pw:-pw]
        # sides
        out[:, :, 0, :] += x[:, :, :ph, pw:-pw].sum(2)
        out[:, :, -1, :] += x[:, :, -ph:, pw:-pw].sum(2)
        out[:, :, :, 0] += x[:, :, ph:-ph, :pw].sum(3)
        out[:, :, :, -1] += x[:, :, ph:-ph, -pw:].sum(3)
        # corners
        out[:, :, 0, 0] += x[:, :, :ph, :pw].sum(3).sum(2)
        out[:, :, -1, -1] += x[:, :, -ph:, -pw:].sum(3).sum(2)
        out[:, :, -1, 0] += x[:, :, -ph:, :pw].sum(3).sum(2)
        out[:, :, 0, -1] += x[:, :, :ph, -pw:].sum(3).sum(2)
    return out
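Since conv and conv_transpose are documented as transposes of each other, the standard adjointness identity <conv(x, w), y> = <x, conv_transpose(y, w)> can be checked with a short sketch (not part of the file):

# --- illustrative adjointness check, not part of the packaged file ---
import torch
from deepinv.physics.blur import conv, conv_transpose

x = torch.randn(1, 1, 32, 32)
y = torch.randn(1, 1, 32, 32)
w = torch.randn(1, 1, 5, 5)

lhs = (conv(x, w, padding="circular") * y).sum()
rhs = (x * conv_transpose(y, w, padding="circular")).sum()
print(torch.allclose(lhs, rhs, atol=1e-3))  # expected: True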
class BlindBlur(Physics):
    r"""
    Blind blur operator.

    It performs

    .. math::

        y = w*x

    where :math:`*` denotes convolution and :math:`w` is an unknown filter.
    This class uses ``torch.conv2d`` for performing the convolutions.

    The signal is described by a tuple (x, w) where the first element is the clean image and the second element
    is the blurring kernel. The measurements y are a tensor representing the convolution of x and w.

    :param int kernel_size: maximum support size of the (unknown) blurring kernels.
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding),
        otherwise the blurred output has the same size as the image.
    """

    def __init__(self, kernel_size=3, padding="circular", **kwargs):
        super().__init__(**kwargs)
        self.padding = padding

        if not isinstance(kernel_size, (list, tuple)):  # accept an int or an (h, w) pair
            self.kernel_size = [kernel_size, kernel_size]
        else:
            self.kernel_size = kernel_size

    def A(self, s):
        r"""
        :param tuple, list, deepinv.utils.TensorList s: List containing two torch.Tensor, s[0] with the image and s[1] with the filter.
        :return: (torch.Tensor) blurred measurement.
        """
        x = s[0]
        w = s[1]
        return conv(x, w, self.padding)

    def A_dagger(self, y):
        r"""
        Returns the trivial inverse where x[0] = blurry input and x[1] is a delta filter, such that
        the convolution of x[0] and x[1] is y.

        .. note::

            This trivial inverse can be useful for some reconstruction networks, such as ``deepinv.models.ArtifactRemoval``.

        :param torch.Tensor y: blurred measurement.
        :return: (deepinv.utils.TensorList) containing the trivial inverse.
        """
        x = y.clone()
        mid_h = int(self.kernel_size[0] / 2)
        mid_w = int(self.kernel_size[1] / 2)
        w = torch.zeros((y.shape[0], 1, self.kernel_size[0], self.kernel_size[1]))
        w[:, :, mid_h, mid_w] = 1.0

        return TensorList([x, w])
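A minimal sketch (not part of the file) of the tuple-valued signal convention used by BlindBlur, assuming TensorList supports list-style indexing as its name suggests:

# --- illustrative usage sketch, not part of the packaged file ---
import torch
from deepinv.physics.blur import BlindBlur, gaussian_blur

x = torch.rand(1, 1, 32, 32)
w = gaussian_blur(sigma=1.0)                  # kernel of shape (1, 1, k, k)
physics = BlindBlur(kernel_size=w.shape[-1])
y = physics.A([x, w])                         # blurry measurement from the pair (x, w)
s = physics.A_dagger(y)                       # trivial inverse: s[0] = y, s[1] = delta kernel
print(y.shape, s[0].shape, s[1].shape)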
class Blur(LinearPhysics):
    r"""
    Blur operator.

    This forward operator performs

    .. math:: y = w*x

    where :math:`*` denotes convolution and :math:`w` is a filter.

    This class uses :meth:`torch.nn.functional.conv2d` for performing the convolutions.

    :param torch.Tensor filter: Tensor of size (1, 1, H, W) or (1, C, H, W) containing the blur filter, e.g., :meth:`deepinv.physics.blur.gaussian_blur`.
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding),
        otherwise the blurred output has the same size as the image.
    :param str device: cpu or cuda.
    """

    def __init__(self, filter, padding="circular", device="cpu", **kwargs):
        super().__init__(**kwargs)
        self.padding = padding
        self.device = device
        self.filter = torch.nn.Parameter(filter, requires_grad=False).to(device)

    def A(self, x):
        return conv(x, self.filter, self.padding)

    def A_adjoint(self, y):
        return conv_transpose(y, self.filter, self.padding)
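And a corresponding sketch (not part of the file) for the non-blind Blur operator, using the gaussian_blur helper defined earlier:

# --- illustrative usage sketch, not part of the packaged file ---
import torch
from deepinv.physics.blur import Blur, gaussian_blur

x = torch.rand(1, 3, 64, 64)
physics = Blur(filter=gaussian_blur(sigma=(2, 0.5), angle=30.0), padding="circular")
y = physics.A(x)              # blurred image; same size as x since padding is circular
x_adj = physics.A_adjoint(y)  # adjoint operator via conv_transpose
print(y.shape, x_adj.shape)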
class BlurFFT(DecomposablePhysics):
    """
    FFT-based blur operator.

    It performs the operation

    .. math:: y = w*x

    where :math:`*` denotes convolution and :math:`w` is a filter.

    Blur operator based on ``torch.fft`` operations, which assumes a circular padding of the input, and allows for
    the singular value decomposition via ``deepinv.physics.DecomposablePhysics`` and has fast pseudo-inverse and prox operators.

    :param tuple img_size: Input image size in the form (C, H, W).
    :param torch.Tensor filter: torch.Tensor of size (1, 1, H, W) or (1, C, H, W) containing the blur filter, e.g.,
        :meth:`deepinv.physics.blur.gaussian_blur`.
    :param str device: cpu or cuda.
    """

    def __init__(self, img_size, filter, device="cpu", **kwargs):
        super().__init__(**kwargs)
        self.img_size = img_size

        if img_size[0] > filter.shape[1]:
            filter = filter.repeat(1, img_size[0], 1, 1)

        self.mask = filter_fft(filter, img_size).to("cpu")
        self.angle = torch.angle(self.mask)
        self.angle = torch.exp(-1j * self.angle).to(device)
        self.mask = torch.abs(self.mask).unsqueeze(-1)
        self.mask = torch.cat([self.mask, self.mask], dim=-1)

        self.mask = torch.nn.Parameter(self.mask, requires_grad=False).to(device)

    def V_adjoint(self, x):
        return torch.view_as_real(
            fft.rfft2(x, norm="ortho")
        )  # make it a true SVD (see J. Romberg notes)

    def U(self, x):
        return fft.irfft2(
            torch.view_as_complex(x) * self.angle, norm="ortho", s=self.img_size[-2:]
        )

    def U_adjoint(self, x):
        return torch.view_as_real(
            fft.rfft2(x, norm="ortho") * torch.conj(self.angle)
        )  # make it a true SVD (see J. Romberg notes)

    def V(self, x):
        return fft.irfft2(torch.view_as_complex(x), norm="ortho", s=self.img_size[-2:])
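Because BlurFFT exposes an explicit SVD through DecomposablePhysics, the forward model and the pseudo-inverse come essentially for free from the base class. A short sketch (not part of the file, and assuming A and A_dagger are indeed provided by DecomposablePhysics, as the docstring suggests):

# --- illustrative usage sketch, not part of the packaged file ---
# Assumes A and A_dagger are inherited from the DecomposablePhysics base class.
import torch
from deepinv.physics.blur import BlurFFT, gaussian_blur

x = torch.rand(1, 3, 64, 64)
physics = BlurFFT(img_size=(3, 64, 64), filter=gaussian_blur(sigma=1.0))
y = physics.A(x)                # circular blur computed entirely with FFTs
x_dagger = physics.A_dagger(y)  # fast pseudo-inverse enabled by the explicit SVD
print(y.shape, x_dagger.shape)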
# # test code
# if __name__ == "__main__":
#     device = "cuda:0"
#
#     import matplotlib.pyplot as plt
#
#     device = "cuda:0"
#     x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
#     x = x.unsqueeze(0).float().to(device) / 255
#     x = torchvision.transforms.Resize((160, 180))(x)
#
#     sigma_noise = 0.0
#     kernel = torch.zeros((1, 1, 15, 15), device=device)
#     kernel[:, :, 7, :] = 1 / 15
#     physics = Downsampling(img_size=x.shape[1:], filter="bilinear", device=device)
#     physics2 = Blur(img_size=x.shape[1:], filter=kernel, device=device)
#
#     y = physics(x)
#     y2 = physics2(x)
#
#     xhat = physics.V(physics.U_adjoint(y) / physics.mask)
#     xhat2 = physics2.A_dagger(y2)
#
#     print(xhat.shape)
#     # print(physics.adjointness_test(x))
#     print(torch.sum((y - y2).pow(2)))
#     print(torch.sum((xhat - xhat2).pow(2)))
#
#     print(torch.sum((x - xhat).pow(2)))
#     print(torch.sum((x - xhat2).pow(2)))
#
#     print(physics.compute_norm(x))
#     print(physics.adjointness_test(x))
#     xhat = physics.prox_l2(y, y, gamma=1.0)
#
#     xhat = physics.A_dagger(y)
#
#     plt.imshow(x.squeeze(0).permute(1, 2, 0).cpu().numpy())
#     plt.show()
#     plt.imshow(y.squeeze(0).permute(1, 2, 0).cpu().numpy())
#     plt.show()
#     plt.imshow(xhat.squeeze(0).permute(1, 2, 0).cpu().numpy())
#     plt.show()
#     plt.imshow(xhat2.squeeze(0).permute(1, 2, 0).cpu().numpy())
#     plt.show()
#
#     plt.imshow(physics.A(xhat).squeeze(0).permute(1, 2, 0).cpu().numpy())
#     plt.show()