deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import pytest
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
import deepinv as dinv
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Denoisers exercised by ``test_denoiser_1_channel`` on the single-channel
# ``imsize_1_channel`` fixture (and, through ``MODEL_LIST``, by ``test_denoiser``).
MODEL_LIST_1_CHANNEL = [
    "autoencoder",
    "drunet",
    "dncnn",
    "median",
    "tgv",
    "waveletprior",
    "waveletdict",
]

# Every denoiser exercised by ``test_denoiser`` on the default ``imsize`` fixture:
# the single-channel list above followed by the models only tested on ``imsize``.
MODEL_LIST = [
    *MODEL_LIST_1_CHANNEL,
    "bm3d",
    "gsdrunet",
    "scunet",
    "swinir",
    "tv",
    "unet",
    "waveletdict_hard",
    "waveletdict_topk",
]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def choose_denoiser(name, imsize):
    """Build the test denoiser registered under ``name``.

    Models that need an optional package (``pytorch_wavelets``, ``bm3d`` or
    ``timm``) first call ``pytest.importorskip``, so the calling test is skipped
    rather than failed when that package is missing.

    :param str name: Denoiser identifier, e.g. an entry of ``MODEL_LIST``.
    :param imsize: Image shape indexed as ``(channels, height, width)``;
        ``imsize[0]`` sets the channel count of the channel-aware models and
        ``AutoEncoder`` uses the product of the three entries as its input size.
    :return: The freshly constructed denoiser module.
    :raises ValueError: If ``name`` is not a known denoiser.
    """
    if name.startswith("waveletdict") or name == "waveletprior":
        pytest.importorskip(
            "pytorch_wavelets",
            reason="This test requires pytorch_wavelets. It should be "
            "installed with `pip install "
            "git+https://github.com/fbcotter/pytorch_wavelets.git`",
        )
    if name == "bm3d":
        pytest.importorskip(
            "bm3d",
            reason="This test requires bm3d. It should be "
            "installed with `pip install bm3d`",
        )
    if name in ("swinir", "scunet"):
        pytest.importorskip(
            "timm",
            reason="This test requires timm. It should be "
            "installed with `pip install timm`",
        )

    channels = imsize[0]
    # Constructors are wrapped in lambdas so only the requested model is built.
    factories = {
        "unet": lambda: dinv.models.UNet(in_channels=channels, out_channels=channels),
        "drunet": lambda: dinv.models.DRUNet(
            in_channels=channels, out_channels=channels
        ),
        "scunet": lambda: dinv.models.SCUNet(in_nc=channels),
        "gsdrunet": lambda: dinv.models.GSDRUNet(
            in_channels=channels, out_channels=channels
        ),
        "bm3d": lambda: dinv.models.BM3D(),
        "dncnn": lambda: dinv.models.DnCNN(in_channels=channels, out_channels=channels),
        "waveletprior": lambda: dinv.models.WaveletPrior(),
        "waveletdict": lambda: dinv.models.WaveletDict(),
        "waveletdict_hard": lambda: dinv.models.WaveletDict(non_linearity="hard"),
        "waveletdict_topk": lambda: dinv.models.WaveletDict(non_linearity="topk"),
        "tgv": lambda: dinv.models.TGV(n_it_max=10),
        "tv": lambda: dinv.models.TV(n_it_max=10),
        "median": lambda: dinv.models.MedianFilter(),
        "autoencoder": lambda: dinv.models.AutoEncoder(
            dim_input=imsize[0] * imsize[1] * imsize[2]
        ),
        "swinir": lambda: dinv.models.SwinIR(in_chans=channels),
    }

    try:
        factory = factories[name]
    except KeyError:
        # ValueError is still an Exception, so existing broad handlers keep working,
        # and the message now names the offending identifier.
        raise ValueError(f"Unknown denoiser: {name!r}") from None
    return factory()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@pytest.mark.parametrize("denoiser", MODEL_LIST)
def test_denoiser(imsize, device, denoiser):
    """Check that each denoiser of ``MODEL_LIST`` returns an output shaped like its input.

    A constant image of shape ``imsize`` (plus a batch axis) is corrupted by
    ``GaussianNoise(noise_level)`` and denoised with that same level.
    """
    net = choose_denoiser(denoiser, imsize).to(device)

    torch.manual_seed(0)
    noise_level = 0.2
    noisy_op = dinv.physics.Denoising(dinv.physics.GaussianNoise(noise_level))
    clean = torch.ones(imsize, device=device)[None]
    noisy = noisy_op(clean)
    estimate = net(noisy, noise_level)

    assert estimate.shape == clean.shape
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def test_equivariant(imsize, device):
    """Check ``EquivariantDenoiser`` on a real backbone and on the identity map.

    1. Wrapping a DRUNet with random roto-flips must preserve the input shape.
    2. Wrapping an identity "denoiser" must return its input unchanged for every
       transform in ("rotations", "flips", "rotoflips"), with and without
       random transform sampling.
    """
    # Part 1: shape check with a real backbone.
    backbone = dinv.models.DRUNet(in_channels=imsize[0], out_channels=imsize[0])
    wrapped = dinv.models.EquivariantDenoiser(
        backbone, transform="rotoflips", random=True
    ).to(device)

    torch.manual_seed(0)
    sigma = 0.2
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
    clean = torch.ones(imsize, device=device)[None]
    noisy = physics(clean)
    estimate = wrapped(noisy, sigma)

    assert estimate.shape == clean.shape

    # Part 2: averaging transformed copies of an identity map is still the identity.
    class DummyIdentity(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, sigma):
            return x

    identity = DummyIdentity()

    cases = [
        (transform, use_random)
        for transform in ("rotations", "flips", "rotoflips")
        for use_random in (True, False)
    ]
    ones = torch.ones(imsize, device=device)[None]
    for transform, use_random in cases:
        wrapped_id = dinv.models.EquivariantDenoiser(
            identity, transform=transform, random=use_random
        ).to(device)

        # Fresh noise each case, drawn in the same order as before.
        noisy = physics(ones)
        assert torch.allclose(noisy, wrapped_id(noisy, sigma))
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@pytest.mark.parametrize("denoiser", MODEL_LIST_1_CHANNEL)
def test_denoiser_1_channel(imsize_1_channel, device, denoiser):
    """Check shape preservation of the single-channel-capable denoisers.

    Same protocol as ``test_denoiser``, but on the ``imsize_1_channel`` fixture
    and restricted to ``MODEL_LIST_1_CHANNEL``.
    """
    net = choose_denoiser(denoiser, imsize_1_channel).to(device)

    torch.manual_seed(0)
    noise_level = 0.2
    corrupt = dinv.physics.Denoising(dinv.physics.GaussianNoise(noise_level))
    target = torch.ones(imsize_1_channel, device=device)[None]
    observed = corrupt(target)

    restored = net(observed, noise_level)

    assert restored.shape == target.shape
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def test_drunet_inputs(imsize_1_channel, device):
    """Check that DRUNet accepts the supported forms of the noise level ``sigma``.

    Covered cases: a Python float, a 1-D tensor with one value per batch
    element, and a 0-d tensor. These are run on image sides that are and are
    not multiples of 8.
    """
    channels = imsize_1_channel[0]
    f = dinv.models.DRUNet(in_channels=channels, out_channels=channels, device=device)

    torch.manual_seed(0)
    sigma = 0.2
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
    x = torch.ones(imsize_1_channel, device=device).unsqueeze(0)
    y = physics(x)

    # Case 1: sigma is a float
    x_hat = f(y, sigma)
    assert x_hat.shape == x.shape

    # Case 2: sigma is a tensor with batch dimension; image sides (31, 37)
    # are not multiples of 8
    batch_size = 3
    x = torch.ones((batch_size, channels, 31, 37), device=device)
    y = physics(x)
    sigma_tensor = torch.tensor([sigma] * batch_size, device=device)
    x_hat = f(y, sigma_tensor)
    assert x_hat.shape == x.shape

    # Case 3: image sides (32, 40) are multiples of 8
    x = torch.ones((batch_size, channels, 32, 40), device=device)
    y = physics(x)
    x_hat = f(y, sigma_tensor)
    assert x_hat.shape == x.shape

    # Case 4: sigma is a tensor with no dimension (0-d)
    sigma_tensor = torch.tensor(sigma, device=device)
    x_hat = f(y, sigma_tensor)
    assert x_hat.shape == x.shape
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def test_diffunetmodel(imsize, device):
    """Smoke-test ``DiffUNet`` in both of its calling conventions.

    This model is not strictly a denoiser. The Ho et al. diffusion model only
    works for color, square images with powers of two in w, h. The smallest
    size accepted so far is (3, 32, 32); results at that size are probably not
    meaningful, since the model was trained at 256x256. The ``imsize`` fixture
    is therefore not used for the input.
    """
    from deepinv.models import DiffUNet

    net = DiffUNet().to(device)

    torch.manual_seed(0)
    sigma = 0.2
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
    # Smallest size tested so far: a single 3x32x32 image.
    x = torch.ones((3, 32, 32), device=device)[None]
    y = physics(x)

    # Timestep convention: an arbitrary fixed timestep (1); this only checks
    # that inference runs. The output is cut to its first 3 channels before
    # comparing shapes (NOTE(review): presumably extra channels are predicted).
    timestep = torch.tensor([1]).to(device)
    out = net(y, timestep)[:, :3, ...]
    assert out.shape == x.shape

    # Noise-level convention (denoise_forward) with a float sigma.
    assert net(y, sigma).shape == x.shape

    # type_t outside ['noise_level', 'timestep'] must raise.
    with pytest.raises(Exception):
        net(y, sigma, type_t="wrong_type")
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def test_PDNet(imsize_1_channel, device):
    """Build a learned primal-dual (PDNet) unfolded model and check its output shape.

    PDNet is an unfolded algorithm, so it is tested on its own here. The primal
    and dual steps of a ``CPIteration`` are redirected to trainable blocks, with
    one primal/dual pair per unfolded layer.
    """
    from deepinv.optim.optimizers import CPIteration, fStep, gStep
    from deepinv.optim import Prior, DataFidelity
    from deepinv.models import PDNet_PrimalBlock, PDNet_DualBlock
    from deepinv.unfolded import unfolded_builder

    sigma = 0.2
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma))
    x = torch.ones(imsize_1_channel, device=device).unsqueeze(0)
    y = physics(x)

    class fStepPDNet(fStep):
        r"""
        Dual update of the PDNet algorithm, written as a proximal operator of
        the data fidelity term; that proximal mapping is a trainable model.
        """

        def forward(self, x, w, cur_data_fidelity, y, *args):
            r"""
            :param torch.Tensor x: Current first variable :math:`u`.
            :param torch.Tensor w: Current second variable :math:`A z`.
            :param deepinv.optim.data_fidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data fidelity term.
            :param torch.Tensor y: Input data.
            """
            return cur_data_fidelity.prox(x, w, y)

    class gStepPDNet(gStep):
        r"""
        Primal update of the PDNet algorithm, written as a proximal operator of
        the prior term; that proximal mapping is a trainable model.
        """

        def forward(self, x, w, cur_prior, *args):
            r"""
            :param torch.Tensor x: Current first variable :math:`x`.
            :param torch.Tensor w: Current second variable :math:`A^\top u`.
            :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
            """
            return cur_prior.prox(x, w)

    class PDNetIteration(CPIteration):
        r"""Single iteration of learned primal dual.

        Only the f/g steps are replaced; ``forward`` is inherited from ``CPIteration``.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.g_step = gStepPDNet(**kwargs)
            self.f_step = fStepPDNet(**kwargs)

    class PDNetPrior(Prior):
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def prox(self, x, w):
            # The primal block is fed channel 0 of w.
            return self.model(x, w[:, 0:1, :, :])

    class PDNetDataFid(DataFidelity):
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def prox(self, x, w, y):
            # The dual block is fed channel 1 of w.
            return self.model(x, w[:, 1:2, :, :], y)

    max_iter = 5  # number of unfolded layers

    # Independent trainable modules per layer; dual blocks are built before primal ones.
    data_fidelity = [
        PDNetDataFid(model=PDNet_DualBlock().to(device)) for _ in range(max_iter)
    ]
    prior = [PDNetPrior(model=PDNet_PrimalBlock().to(device)) for _ in range(max_iter)]

    n_primal, n_dual = 5, 5  # channel multiplicity of the extended primal / dual spaces

    def custom_init(y, physics):
        primal0 = physics.A_dagger(y).repeat(1, n_primal, 1, 1)
        dual0 = torch.zeros_like(y).repeat(1, n_dual, 1, 1)
        return {"est": (primal0, primal0, dual0)}

    def custom_output(X):
        return X["est"][0][:, 1, :, :].unsqueeze(1)

    model = unfolded_builder(
        iteration=PDNetIteration(),
        params_algo={"K": physics.A, "K_adjoint": physics.A_adjoint, "beta": 1.0},
        data_fidelity=data_fidelity,
        prior=prior,
        max_iter=max_iter,
        custom_init=custom_init,
        get_output=custom_output,
    )

    assert model(y, physics).shape == x.shape
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
@pytest.mark.parametrize(
    "denoiser, dep",
    [
        ("BM3D", "bm3d"),
        ("SCUNet", "timm"),
        ("SwinIR", "timm"),
        ("WaveletPrior", "pytorch_wavelets"),
        ("WaveletDict", "pytorch_wavelets"),
    ],
)
def test_optional_dependencies(denoiser, dep):
    """Check that a model whose optional dependency is missing fails helpfully.

    Constructing ``dinv.models.<denoiser>`` without ``dep`` installed must raise
    an ``ImportError`` whose message matches ``pip install .*<dep>``.
    """
    import importlib.util

    # Skip the test if the optional dependency is installed. Ask the import
    # system instead of checking ``sys.modules``: a package can be installed
    # without having been imported yet in this session.
    if importlib.util.find_spec(dep) is not None:
        pytest.skip(f"Optional dependency {dep} is installed.")

    klass = getattr(dinv.models, denoiser)
    with pytest.raises(ImportError, match=f"pip install .*{dep}"):
        klass()
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
# def test_dip(imsize, device): TODO: fix this test
|
|
361
|
+
# torch.manual_seed(0)
|
|
362
|
+
# channels = 64
|
|
363
|
+
# physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(0.2))
|
|
364
|
+
# f = dinv.models.DeepImagePrior(
|
|
365
|
+
# generator=dinv.models.ConvDecoder(imsize, layers=3, channels=channels).to(
|
|
366
|
+
# device
|
|
367
|
+
# ),
|
|
368
|
+
# input_size=(channels, imsize[1], imsize[2]),
|
|
369
|
+
# iterations=30,
|
|
370
|
+
# )
|
|
371
|
+
# x = torch.ones(imsize, device=device).unsqueeze(0)
|
|
372
|
+
# y = physics(x)
|
|
373
|
+
# mse_in = (y - x).pow(2).mean()
|
|
374
|
+
# x_net = f(y, physics)
|
|
375
|
+
# mse_out = (x_net - x).pow(2).mean()
|
|
376
|
+
#
|
|
377
|
+
# assert mse_out < mse_in
|