deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
import deepinv
|
|
7
|
+
from deepinv.tests.dummy_datasets.datasets import DummyCircles
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
import deepinv as dinv
|
|
10
|
+
from deepinv.loss.regularisers import JacobianSpectralNorm, FNEJacobianSpectralNorm
|
|
11
|
+
|
|
12
|
+
# Loss configurations exercised by ``test_losses``; each name is resolved
# into concrete loss objects by ``choose_loss``.
LOSSES = [
    "sup",
    "mcei",
]
# Noise types for which ``test_sure`` checks the SURE estimate; each name is
# resolved into a ``(loss, noise_model)`` pair by ``choose_sure``.
LIST_SURE = [
    "Gaussian",
    "Poisson",
    "PoissonGaussian",
    "UniformGaussian",
]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def test_jacobian_spectral_values(toymatrix):
    """Check both Jacobian spectral-norm regularisers on a linear toy model.

    The "network" is the linear map ``x -> toymatrix @ x``. With
    ``n = toymatrix.size(0)``, the assertions expect:

    - ``JacobianSpectralNorm`` to return ``n``;
    - ``FNEJacobianSpectralNorm`` to return ``2 * n - 1``.

    This presumably relies on how the ``toymatrix`` fixture is built
    (TODO: confirm against conftest).
    """
    reg_kwargs = dict(max_iter=100, tol=1e-3, eval_mode=False, verbose=True)
    spectral_reg = JacobianSpectralNorm(**reg_kwargs)
    fne_spectral_reg = FNEJacobianSpectralNorm(**reg_kwargs)

    # Toy example y = A @ x, with gradients tracked with respect to x.
    x = torch.randn_like(toymatrix).requires_grad_()
    y = toymatrix @ x

    def linear_model(inp):
        return toymatrix @ inp

    n = toymatrix.size(0)
    spectral_value = spectral_reg(y, x).item()
    fne_value = fne_spectral_reg(y, x, linear_model, interpolation=False).item()

    assert math.isclose(spectral_value, n, rel_tol=1e-3)
    assert math.isclose(fne_value, 2 * n - 1, rel_tol=1e-3)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def choose_loss(loss_name):
    """Return the list of training losses associated with ``loss_name``.

    Supported names:

    * ``"mcei"``    -- ``MCLoss()`` followed by ``EILoss`` with a ``Shift`` transform
    * ``"splittv"`` -- ``SplittingLoss(regular_mask=True, split_ratio=0.25)`` and ``TVLoss()``
    * ``"score"``   -- ``ScoreLoss(1.0)``
    * ``"sup"``     -- ``SupLoss()``

    :param str loss_name: identifier of the loss configuration.
    :return: list of loss objects, in the order they should be applied.
    :raises ValueError: if ``loss_name`` is not one of the names above.
    """
    if loss_name == "mcei":
        return [dinv.loss.MCLoss(), dinv.loss.EILoss(dinv.transform.Shift())]
    if loss_name == "splittv":
        return [
            dinv.loss.SplittingLoss(regular_mask=True, split_ratio=0.25),
            dinv.loss.TVLoss(),
        ]
    if loss_name == "score":
        return [dinv.loss.ScoreLoss(1.0)]
    if loss_name == "sup":
        return [dinv.loss.SupLoss()]
    # ValueError subclasses Exception, so callers catching Exception still work,
    # and the message now says which name was rejected.
    raise ValueError(f"The loss {loss_name!r} doesn't exist")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def choose_sure(noise_type):
    """Return a ``(loss, noise_model)`` pair for the requested noise type.

    The SURE losses and the noise models are built from the same
    ``sigma = 0.1`` / ``gain = 0.1`` values, so each loss is configured for
    the noise that is actually simulated.

    :param str noise_type: one of ``"PoissonGaussian"``, ``"Gaussian"``,
        ``"UniformGaussian"``, ``"Poisson"`` or ``"Neighbor2Neighbor"``.
    :return: tuple ``(loss, noise_model)``.
    :raises ValueError: if ``noise_type`` is not recognised.
    """
    gain = 0.1
    sigma = 0.1
    if noise_type == "PoissonGaussian":
        loss = dinv.loss.SurePGLoss(sigma=sigma, gain=gain)
        noise_model = dinv.physics.PoissonGaussianNoise(sigma=sigma, gain=gain)
    elif noise_type == "Gaussian":
        loss = dinv.loss.SureGaussianLoss(sigma=sigma)
        noise_model = dinv.physics.GaussianNoise(sigma)
    elif noise_type == "UniformGaussian":
        loss = dinv.loss.SureGaussianLoss(sigma=sigma)
        # This is equivalent to GaussianNoise when sigma is fixed
        noise_model = dinv.physics.UniformGaussianNoise(sigma=sigma)
    elif noise_type == "Poisson":
        loss = dinv.loss.SurePoissonLoss(gain=gain)
        noise_model = dinv.physics.PoissonNoise(gain)
    elif noise_type == "Neighbor2Neighbor":
        loss = dinv.loss.Neighbor2Neighbor()
        noise_model = dinv.physics.PoissonNoise(gain)
    else:
        # ValueError subclasses Exception (backward compatible) and names the
        # offending input.
        raise ValueError(f"The SURE loss for noise type {noise_type!r} doesn't exist")

    return loss, noise_model
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@pytest.mark.parametrize("noise_type", LIST_SURE)
def test_sure(noise_type, device):
    """Check that each SURE-type loss roughly matches the oracle MSE.

    A median-filter reconstruction of a noisy all-ones image is scored twice:

    - with the oracle MSE against the clean image;
    - with the unsupervised loss returned by ``choose_sure``.

    Their relative gap must stay below 0.9.
    """
    # Backbone denoiser wrapped in a reconstruction architecture.
    model = dinv.models.ArtifactRemoval(dinv.models.MedianFilter())

    # Training loss and the noise it is designed for.
    sure_loss, noise_model = choose_sure(noise_type)

    torch.manual_seed(0)  # for a reproducible noise realisation
    physics = dinv.physics.Denoising(noise=noise_model)

    batch_size = 1
    image_shape = (3, 256, 256)  # a bigger image reduces the error
    x = torch.ones((batch_size,) + image_shape, device=device)
    y = physics(x)

    x_net = model(y, physics)
    oracle_mse = dinv.metric.mse()(x, x_net)
    sure_estimate = sure_loss(y=y, x_net=x_net, physics=physics, model=model)

    relative_gap = (sure_estimate - oracle_mse).abs() / oracle_mse
    assert relative_gap < 0.9
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@pytest.fixture
def imsize():
    """Image shape shared by the dataset-based tests."""
    size = (3, 15, 10)
    return size
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@pytest.fixture
def physics(imsize, device):
    """Forward operator: inpainting with ``mask=0.5`` on tensors of shape ``imsize``."""
    operator = dinv.physics.Inpainting(mask=0.5, tensor_size=imsize, device=device)
    return operator
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@pytest.fixture
def dataset(physics, tmp_path, imsize, device):
    """Generate a small HDF5 dataset of dummy circles and load both splits.

    The file is written under ``tmp_path / "dataset"``: 50 training samples
    and 10 test samples.

    :return: ``(train_split, test_split)``, both read back from
        ``dinv_dataset0.h5``.
    """
    out_dir = tmp_path / "dataset"
    dinv.datasets.generate_dataset(
        train_dataset=DummyCircles(samples=50, imsize=imsize),
        test_dataset=DummyCircles(samples=10, imsize=imsize),
        physics=physics,
        save_dir=out_dir,
        device=device,
    )

    h5_path = out_dir / "dinv_dataset0.h5"
    train_split = dinv.datasets.HDF5Dataset(h5_path, train=True)
    test_split = dinv.datasets.HDF5Dataset(h5_path, train=False)
    return train_split, test_split
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def test_notraining(physics, tmp_path, imsize, device):
    """Check that a dataset can be generated with ``train_dataset=None``.

    The resulting test split is read back, and the first element of its
    first item must have shape ``imsize``.
    """
    out_dir = tmp_path / "dataset"
    dinv.datasets.generate_dataset(
        train_dataset=None,
        test_dataset=DummyCircles(samples=10, imsize=imsize),
        physics=physics,
        save_dir=out_dir,
        device=device,
    )

    test_split = dinv.datasets.HDF5Dataset(out_dir / "dinv_dataset0.h5", train=False)
    first_item = test_split[0]
    assert first_item[0].shape == imsize
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
@pytest.mark.parametrize("loss_name", LOSSES)
def test_losses(loss_name, tmp_path, dataset, physics, imsize, device):
    """Train a small auto-encoder with each loss and check that it improves.

    The first value returned by ``dinv.test`` (presumably the PSNR) must be
    higher after training than before.
    """
    losses = choose_loss(loss_name)

    ckpt_root = tmp_path / "dataset"

    # Fully connected auto-encoder acting on the flattened image.
    flat_dim = imsize[0] * imsize[1] * imsize[2]
    backbone = dinv.models.AutoEncoder(dim_input=flat_dim, dim_mid=128, dim_hid=32).to(
        device
    )
    model = dinv.models.ArtifactRemoval(backbone)

    epochs = 50
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-8)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(epochs * 0.8))

    train_split, test_split = dataset[0], dataset[1]
    train_loader = DataLoader(train_split, batch_size=2, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_split, batch_size=2, shuffle=False, num_workers=0)

    eval_kwargs = dict(
        test_dataloader=test_loader,
        physics=physics,
        plot_images=False,
        device=device,
    )

    # Baseline score of the untrained model.
    initial_psnr = dinv.test(model=model, **eval_kwargs)

    model = dinv.train(
        model=model,
        train_dataloader=train_loader,
        epochs=epochs,
        scheduler=scheduler,
        losses=losses,
        physics=physics,
        optimizer=optimizer,
        device=device,
        ckp_interval=epochs // 2,
        save_path=ckpt_root / "dinv_test",
        plot_images=False,
        verbose=False,
    )

    final_psnr = dinv.test(model=model, **eval_kwargs)
    assert final_psnr[0] > initial_psnr[0]
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def test_sure_losses(device):
    """Compare the stochastic divergence estimators with the exact divergence.

    ``hutch_div`` and ``mc_div`` are compared with ``exact_div``. Their mean
    absolute deviation over 100 draws must each stay below 5e-2.
    """
    denoiser = dinv.models.ArtifactRemoval(dinv.models.MedianFilter())

    x = torch.ones((1, 3, 16, 16), device=device) * 0.5
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(0.1))
    y = physics(x)

    y_net = denoiser(y, physics)
    tau = 1e-4

    exact = dinv.loss.sure.exact_div(y, physics, denoiser)

    num_draws = 100
    hutch_err = 0
    mc_err = 0
    for _ in range(num_draws):
        hutch_est = dinv.loss.sure.hutch_div(y, physics, denoiser)
        mc_est = dinv.loss.sure.mc_div(y_net, y, denoiser, physics, tau)
        hutch_err += torch.abs(hutch_est - exact)
        mc_err += torch.abs(mc_est - exact)

    mc_err /= num_draws
    hutch_err /= num_draws

    assert hutch_err < 5e-2
    assert mc_err < 5e-2
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def test_measplit(device):
    """Smoke-test ``SplittingLoss`` and ``Neighbor2Neighbor``.

    Both losses are evaluated on a noisy all-ones image with a median-filter
    reconstruction, and both values must be strictly positive.
    """
    noise_level = 0.1
    physics = dinv.physics.Denoising()
    physics.noise_model = dinv.physics.GaussianNoise(noise_level)

    # Reconstruction architecture.
    model = dinv.models.ArtifactRemoval(dinv.models.MedianFilter())

    batch_size = 1
    image_shape = (3, 32, 32)
    x = torch.ones((batch_size,) + image_shape, device=device)
    y = physics(x)

    splitting = dinv.loss.SplittingLoss(split_ratio=0.5, regular_mask=True)
    split_value = splitting(y, physics, model)

    neighbor2neighbor = dinv.loss.Neighbor2Neighbor()
    n2n_value = neighbor2neighbor(y, physics, model)

    assert split_value > 0 and n2n_value > 0
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch.utils.data import DataLoader
|
|
6
|
+
|
|
7
|
+
import deepinv as dinv
|
|
8
|
+
from deepinv.optim.data_fidelity import L2
|
|
9
|
+
from deepinv.optim.prior import PnP
|
|
10
|
+
from deepinv.tests.dummy_datasets.datasets import DummyCircles
|
|
11
|
+
from deepinv.unfolded import unfolded_builder
|
|
12
|
+
from deepinv.training_utils import train
|
|
13
|
+
from deepinv.training_utils import test as feature_test
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def test_generate_dataset(tmp_path, imsize, device):
    """Generate an HDF5 dataset and check its length and sample shape.

    The test expects ``train_datapoints`` to cap the number of stored training
    samples, i.e. the train split must contain ``min(max_N, N)`` items.
    """
    N = 10
    # NOTE(review): with max_N == N the cap itself is not exercised; a smaller
    # max_N would test truncation (confirm generate_dataset semantics first).
    max_N = 10
    train_dataset = DummyCircles(samples=N, imsize=imsize)
    test_dataset = DummyCircles(samples=N, imsize=imsize)

    physics = dinv.physics.Inpainting(mask=0.5, tensor_size=imsize, device=device)

    dinv.datasets.generate_dataset(
        train_dataset,
        physics,
        tmp_path,
        test_dataset=test_dataset,
        device=device,
        dataset_filename="dinv_dataset",
        train_datapoints=max_N,
    )

    # Join the path with pathlib rather than string formatting, matching the
    # other dataset tests (HDF5Dataset is also given Path objects there).
    dataset = dinv.datasets.HDF5Dataset(path=tmp_path / "dinv_dataset0.h5", train=True)

    assert len(dataset) == min(max_N, N)

    x, _ = dataset[0]  # the measurement is not checked here
    assert x.shape == imsize
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Unfolded algorithms exercised by ``test_optim_algo``. Only "PGD" is
# parametrized for now; "HQS", "DRS", "ADMM" and "CP" are disabled.
# ``test_optim_algo`` already contains the CP-specific initialisation needed
# if CP is re-enabled.
optim_algos = ["PGD"]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@pytest.mark.parametrize("name_algo", optim_algos)
def test_optim_algo(name_algo, imsize, device):
    """Build an unfolded ``name_algo`` network with a wavelet PnP prior and train it.

    The test proceeds in four steps:

    1. check that every parameter is trainable and belongs to
       ``trainable_params``;
    2. run one epoch of supervised training;
    3. run an evaluation pass;
    4. run one more epoch with ``online_measurements=True``.
    """
    import tempfile

    # This test uses WaveletPrior, which requires pytorch_wavelets
    # TODO: we could use a dummy trainable denoiser with a linear layer instead
    pytest.importorskip("pytorch_wavelets")

    # Select the data fidelity term
    data_fidelity = L2()

    # Trainable denoising prior: soft-thresholding in a wavelet basis.
    # Given a list of length max_iter, a distinct prior is trained for each
    # iteration; initialise with a single model to share it across iterations.
    max_iter = 30 if torch.cuda.is_available() else 3  # number of unrolled iterations
    level = 3
    prior = [
        PnP(denoiser=dinv.models.WaveletPrior(wv="db8", level=level, device=device))
        for _ in range(max_iter)
    ]

    # Unrolled optimisation parameters: one entry per iteration so that a
    # distinct value is trained for each iteration.
    lamb = [1.0] * max_iter  # regularisation parameter
    stepsize = [1.0] * max_iter
    # Build a fresh tensor per iteration: ``[t] * max_iter`` repeats the *same*
    # tensor object, so the per-iteration entries could alias one storage
    # instead of being trained independently.
    sigma_denoiser = [0.01 * torch.ones(level, 1) for _ in range(max_iter)]
    params_algo = {  # all restoration parameters, keyed as unfolded_builder expects
        "stepsize": stepsize,
        "g_param": sigma_denoiser,
        "lambda": lamb,
    }

    # Which entries of ``params_algo`` are trainable.
    trainable_params = ["g_param", "stepsize"]

    if name_algo == "CP":
        # The CP algorithm uses more than 2 variables, so it needs a custom
        # initialization.
        def custom_init(y, physics):
            x_init = physics.A_adjoint(y)
            u_init = y
            return {"est": (x_init, x_init, u_init)}

        params_algo["sigma"] = 1.0
    else:
        custom_init = None

    model_unfolded = unfolded_builder(
        name_algo,
        params_algo=params_algo,
        trainable_params=trainable_params,
        data_fidelity=data_fidelity,
        max_iter=max_iter,
        prior=prior,
        custom_init=custom_init,
    )

    # Every registered parameter must be trainable and stem from one of the
    # requested trainable entries.
    for name, param in model_unfolded.named_parameters():
        assert param.requires_grad
        assert any(key in name for key in trainable_params)

    N = 10
    train_dataset = DummyCircles(samples=N, imsize=imsize)
    test_dataset = DummyCircles(samples=N, imsize=imsize)

    physics = dinv.physics.Inpainting(mask=0.5, tensor_size=imsize, device=device)

    train_dataloader = DataLoader(
        train_dataset, batch_size=2, num_workers=1, shuffle=True
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=2, num_workers=1, shuffle=False
    )

    epochs = 1
    losses = [dinv.loss.SupLoss(metric=dinv.metric.mse())]
    optimizer = torch.optim.Adam(model_unfolded.parameters(), lr=1e-3, weight_decay=0.0)

    # Write checkpoints to a throw-away directory instead of ./ckpts, so the
    # test leaves no files in the current working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_dir = str(Path(tmp_dir) / "ckpts")

        trained_unfolded_model = train(
            model=model_unfolded,
            train_dataloader=train_dataloader,
            eval_dataloader=test_dataloader,
            epochs=epochs,
            losses=losses,
            physics=physics,
            optimizer=optimizer,
            device=device,
            save_path=ckpt_dir,
            verbose=True,
            wandb_vis=False,
        )

        feature_test(
            model=trained_unfolded_model,
            test_dataloader=test_dataloader,
            physics=physics,
            device=device,
            plot_images=False,
            verbose=True,
            wandb_vis=False,
        )

        # Now check that training with online measurements works as well
        train(
            model=model_unfolded,
            train_dataloader=train_dataloader,
            eval_dataloader=test_dataloader,
            epochs=epochs,
            losses=losses,
            physics=physics,
            optimizer=optimizer,
            device=device,
            save_path=ckpt_dir,
            verbose=True,
            wandb_vis=False,
            online_measurements=True,
        )
|