deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,269 @@
1
+ import pytest
2
+
3
+ import math
4
+ import torch
5
+
6
+ import deepinv
7
+ from deepinv.tests.dummy_datasets.datasets import DummyCircles
8
+ from torch.utils.data import DataLoader
9
+ import deepinv as dinv
10
+ from deepinv.loss.regularisers import JacobianSpectralNorm, FNEJacobianSpectralNorm
11
+
12
# Loss configurations exercised by test_losses below.
LOSSES = ["sup", "mcei"]
# Noise models exercised by the parametrized SURE test below.
LIST_SURE = ["Gaussian", "Poisson", "PoissonGaussian", "UniformGaussian"]
14
+
15
+
16
def test_jacobian_spectral_values(toymatrix):
    """Check the Jacobian spectral-norm regularisers on a toy linear operator."""
    # Regularisers under test: plain Jacobian norm and its FNE variant.
    jac_norm = JacobianSpectralNorm(max_iter=100, tol=1e-3, eval_mode=False, verbose=True)
    fne_jac_norm = FNEJacobianSpectralNorm(
        max_iter=100, tol=1e-3, eval_mode=False, verbose=True
    )

    # Toy example y = A @ x, with x tracked for autograd.
    x = torch.randn_like(toymatrix).requires_grad_()
    y = toymatrix @ x

    def linear_op(inp):
        return toymatrix @ inp

    norm_plain = jac_norm(y, x)
    norm_fne = fne_jac_norm(y, x, linear_op, interpolation=False)

    # NOTE(review): the expected values size(0) and 2*size(0)-1 depend on how the
    # toymatrix fixture is built in conftest — presumably a scaled identity-like matrix.
    assert math.isclose(norm_plain.item(), toymatrix.size(0), rel_tol=1e-3)
    assert math.isclose(norm_fne.item(), 2 * toymatrix.size(0) - 1, rel_tol=1e-3)
35
+
36
+
37
def choose_loss(loss_name):
    """Return the list of training losses associated with ``loss_name``.

    :param str loss_name: one of ``"mcei"``, ``"splittv"``, ``"score"``, ``"sup"``.
    :return: list of instantiated deepinv loss objects.
    :raises ValueError: if the name is not recognised.
    """
    loss = []
    if loss_name == "mcei":
        # Measurement-consistency + equivariant-imaging pair.
        loss.append(dinv.loss.MCLoss())
        loss.append(dinv.loss.EILoss(dinv.transform.Shift()))
    elif loss_name == "splittv":
        loss.append(dinv.loss.SplittingLoss(regular_mask=True, split_ratio=0.25))
        loss.append(dinv.loss.TVLoss())
    elif loss_name == "score":
        loss.append(dinv.loss.ScoreLoss(1.0))
    elif loss_name == "sup":
        loss.append(dinv.loss.SupLoss())
    else:
        # ValueError is the idiomatic error for a bad argument and, being a
        # subclass of Exception, stays compatible with existing handlers.
        raise ValueError(f"The loss '{loss_name}' doesn't exist")

    return loss
53
+
54
+
55
def choose_sure(noise_type):
    """Return a ``(SURE loss, matching noise model)`` pair for ``noise_type``.

    :param str noise_type: one of ``"Gaussian"``, ``"Poisson"``,
        ``"PoissonGaussian"``, ``"UniformGaussian"`` or ``"Neighbor2Neighbor"``.
    :return: tuple ``(loss, noise_model)``.
    :raises ValueError: if the name is not recognised.
    """
    gain = 0.1
    sigma = 0.1
    if noise_type == "PoissonGaussian":
        loss = dinv.loss.SurePGLoss(sigma=sigma, gain=gain)
        noise_model = dinv.physics.PoissonGaussianNoise(sigma=sigma, gain=gain)
    elif noise_type == "Gaussian":
        loss = dinv.loss.SureGaussianLoss(sigma=sigma)
        noise_model = dinv.physics.GaussianNoise(sigma)
    elif noise_type == "UniformGaussian":
        loss = dinv.loss.SureGaussianLoss(sigma=sigma)
        noise_model = dinv.physics.UniformGaussianNoise(
            sigma=sigma
        )  # This is equivalent to GaussianNoise when sigma is fixed
    elif noise_type == "Poisson":
        loss = dinv.loss.SurePoissonLoss(gain=gain)
        noise_model = dinv.physics.PoissonNoise(gain)
    elif noise_type == "Neighbor2Neighbor":
        loss = dinv.loss.Neighbor2Neighbor()
        noise_model = dinv.physics.PoissonNoise(gain)
    else:
        # ValueError is the idiomatic error for a bad argument and, being a
        # subclass of Exception, stays compatible with existing handlers.
        raise ValueError(f"The SURE loss '{noise_type}' doesn't exist")

    return loss, noise_model
79
+
80
+
81
@pytest.mark.parametrize("noise_type", LIST_SURE)
def test_sure(noise_type, device):
    """SURE should approximate the true MSE without access to the ground truth."""
    imsize = (3, 256, 256)  # a bigger image reduces the error
    # choose backbone denoiser
    backbone = dinv.models.MedianFilter()

    # choose a reconstruction architecture
    f = dinv.models.ArtifactRemoval(backbone)

    # choose training losses
    loss, noise = choose_sure(noise_type)

    # choose noise
    torch.manual_seed(0)  # for reproducibility
    physics = dinv.physics.Denoising(noise=noise)

    batch_size = 1
    x = torch.ones((batch_size,) + imsize, device=device)
    y = physics(x)

    x_net = f(y, physics)
    # Use the `dinv` alias consistently with the rest of this file.
    mse = dinv.metric.mse()(x, x_net)
    sure = loss(y=y, x_net=x_net, physics=physics, model=f)

    # SURE is a stochastic estimate of the MSE, so only a loose bound is asserted.
    rel_error = (sure - mse).abs() / mse
    assert rel_error < 0.9
107
+
108
+
109
@pytest.fixture
def imsize():
    # Small, non-square (C, H, W) image size keeps the tests fast.
    return (3, 15, 10)
112
+
113
+
114
@pytest.fixture
def physics(imsize, device):
    # choose a forward operator
    # mask=0.5 presumably keeps ~half of the pixels — see dinv.physics.Inpainting.
    return dinv.physics.Inpainting(tensor_size=imsize, mask=0.5, device=device)
118
+
119
+
120
@pytest.fixture
def dataset(physics, tmp_path, imsize, device):
    """Generate a small HDF5 train/test dataset of dummy circle images."""
    # load dummy dataset
    save_dir = tmp_path / "dataset"
    dinv.datasets.generate_dataset(
        train_dataset=DummyCircles(samples=50, imsize=imsize),
        test_dataset=DummyCircles(samples=10, imsize=imsize),
        physics=physics,
        save_dir=save_dir,
        device=device,
    )

    # Both splits live in the same HDF5 file; the `train` flag selects the split.
    return (
        dinv.datasets.HDF5Dataset(save_dir / "dinv_dataset0.h5", train=True),
        dinv.datasets.HDF5Dataset(save_dir / "dinv_dataset0.h5", train=False),
    )
136
+
137
+
138
def test_notraining(physics, tmp_path, imsize, device):
    """Dataset generation must work with a test split only (no training data)."""
    target_dir = tmp_path / "dataset"

    dinv.datasets.generate_dataset(
        train_dataset=None,
        test_dataset=DummyCircles(samples=10, imsize=imsize),
        physics=physics,
        save_dir=target_dir,
        device=device,
    )

    test_split = dinv.datasets.HDF5Dataset(target_dir / "dinv_dataset0.h5", train=False)

    # First element of the first sample should be an image of the requested size.
    assert test_split[0][0].shape == imsize
153
+
154
+
155
@pytest.mark.parametrize("loss_name", LOSSES)
def test_losses(loss_name, tmp_path, dataset, physics, imsize, device):
    """Training with each loss should improve test PSNR over the untrained model."""
    # choose training losses
    loss = choose_loss(loss_name)

    save_dir = tmp_path / "dataset"
    # choose backbone denoiser
    backbone = dinv.models.AutoEncoder(
        dim_input=imsize[0] * imsize[1] * imsize[2], dim_mid=128, dim_hid=32
    ).to(device)

    # choose a reconstruction architecture
    model = dinv.models.ArtifactRemoval(backbone)

    # choose optimizer and scheduler
    epochs = 50
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-8)
    # Single LR decay step at 80% of training.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(epochs * 0.8))

    dataloader = DataLoader(dataset[0], batch_size=2, shuffle=True, num_workers=0)
    test_dataloader = DataLoader(dataset[1], batch_size=2, shuffle=False, num_workers=0)

    # test the untrained model
    initial_psnr = dinv.test(
        model=model,
        test_dataloader=test_dataloader,
        physics=physics,
        plot_images=False,
        device=device,
    )

    # train the network
    model = dinv.train(
        model=model,
        train_dataloader=dataloader,
        epochs=epochs,
        scheduler=scheduler,
        losses=loss,
        physics=physics,
        optimizer=optimizer,
        device=device,
        ckp_interval=int(epochs / 2),
        save_path=save_dir / "dinv_test",
        plot_images=False,
        verbose=False,
    )

    final_psnr = dinv.test(
        model=model,
        test_dataloader=test_dataloader,
        physics=physics,
        plot_images=False,
        device=device,
    )

    # dinv.test presumably returns (mean PSNR, ...) — compare the first entry.
    assert final_psnr[0] > initial_psnr[0]
211
+
212
+
213
def test_sure_losses(device):
    """Hutchinson and Monte-Carlo divergence estimators should match the exact one."""
    f = dinv.models.ArtifactRemoval(dinv.models.MedianFilter())
    # test divergence

    x = torch.ones((1, 3, 16, 16), device=device) * 0.5
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(0.1))
    y = physics(x)

    y1 = f(y, physics)
    tau = 1e-4  # finite-difference step used by the Monte-Carlo estimator

    exact = dinv.loss.sure.exact_div(y, physics, f)

    error_h = 0
    error_mc = 0

    # Average the stochastic estimators over several draws to reduce variance.
    num_it = 100
    for _ in range(num_it):
        h = dinv.loss.sure.hutch_div(y, physics, f)
        mc = dinv.loss.sure.mc_div(y1, y, f, physics, tau)

        error_h += torch.abs(h - exact)
        error_mc += torch.abs(mc - exact)

    error_mc /= num_it
    error_h /= num_it

    assert error_h < 5e-2
    assert error_mc < 5e-2
245
+
246
+
247
def test_measplit(device):
    """Measurement-splitting style losses should evaluate to a positive value."""
    noise_level = 0.1
    physics = dinv.physics.Denoising()
    physics.noise_model = dinv.physics.GaussianNoise(noise_level)

    # Reconstruction network: median-filter denoiser wrapped as artifact removal.
    denoiser = dinv.models.MedianFilter()
    reconstructor = dinv.models.ArtifactRemoval(denoiser)

    # Constant test image (batch of 1, 3x32x32) and its noisy measurement.
    x = torch.ones((1,) + (3, 32, 32), device=device)
    y = physics(x)

    # Splitting loss with a regular mask, keeping half the measurements.
    splitting = dinv.loss.SplittingLoss(split_ratio=0.5, regular_mask=True)
    split_loss = splitting(y, physics, reconstructor)

    n2n = dinv.loss.Neighbor2Neighbor()
    n2n_loss = n2n(y, physics, reconstructor)

    assert split_loss > 0 and n2n_loss > 0
@@ -0,0 +1,179 @@
1
+ import pytest
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+
7
+ import deepinv as dinv
8
+ from deepinv.optim.data_fidelity import L2
9
+ from deepinv.optim.prior import PnP
10
+ from deepinv.tests.dummy_datasets.datasets import DummyCircles
11
+ from deepinv.unfolded import unfolded_builder
12
+ from deepinv.training_utils import train
13
+ from deepinv.training_utils import test as feature_test
14
+
15
+
16
def test_generate_dataset(tmp_path, imsize, device):
    """Generated HDF5 datasets should honour the requested size and image shape."""
    num_samples = 10
    cap = 10  # upper bound passed as train_datapoints to the generator
    train_set = DummyCircles(samples=num_samples, imsize=imsize)
    test_set = DummyCircles(samples=num_samples, imsize=imsize)

    forward_op = dinv.physics.Inpainting(mask=0.5, tensor_size=imsize, device=device)

    dinv.datasets.generate_dataset(
        train_set,
        forward_op,
        tmp_path,
        test_dataset=test_set,
        device=device,
        dataset_filename="dinv_dataset",
        train_datapoints=cap,
    )

    stored = dinv.datasets.HDF5Dataset(path=f"{tmp_path}/dinv_dataset0.h5", train=True)

    # The stored split is truncated to the requested number of datapoints.
    assert len(stored) == min(cap, num_samples)

    x, y = stored[0]
    assert x.shape == imsize
40
+
41
+
42
# Full list of algorithms supported by the unfolded builder; only PGD is run
# to keep CI fast. Uncomment for a full sweep.
# optim_algos = [
#     "PGD",
#     "HQS",
#     "DRS",
#     "ADMM",
#     "CP",
# ]

optim_algos = ["PGD"]
51
+
52
+
53
@pytest.mark.parametrize("name_algo", optim_algos)
def test_optim_algo(name_algo, imsize, device):
    """Build, train and evaluate an unfolded optimization network end to end."""
    # This test uses WaveletPrior, which requires pytorch_wavelets
    # TODO: we could use a dummy trainable denoiser with a linear layer instead
    pytest.importorskip("pytorch_wavelets")

    # pths
    BASE_DIR = Path(".")
    CKPT_DIR = BASE_DIR / "ckpts"

    # Select the data fidelity term
    data_fidelity = L2()

    # Set up the trainable denoising prior; here, the soft-threshold in a wavelet basis.
    # If the prior is initialized with a list of length max_iter,
    # then a distinct weight is trained for each PGD iteration.
    # For fixed trained model prior across iterations, initialize with a single model.
    max_iter = 30 if torch.cuda.is_available() else 3  # Number of unrolled iterations
    level = 3
    prior = [
        PnP(denoiser=dinv.models.WaveletPrior(wv="db8", level=level, device=device))
        for i in range(max_iter)
    ]

    # Unrolled optimization algorithm parameters
    lamb = [
        1.0
    ] * max_iter  # initialization of the regularization parameter. A distinct lamb is trained for each iteration.
    stepsize = [
        1.0
    ] * max_iter  # initialization of the stepsizes. A distinct stepsize is trained for each iteration.

    sigma_denoiser = [0.01 * torch.ones(level, 1)] * max_iter
    # sigma_denoiser = [torch.Tensor([sigma_denoiser_init])]*max_iter
    params_algo = {  # wrap all the restoration parameters in a 'params_algo' dictionary
        "stepsize": stepsize,
        "g_param": sigma_denoiser,
        "lambda": lamb,
    }

    # define which parameters from 'params_algo' are trainable
    trainable_params = ["g_param", "stepsize"]

    # Define the unfolded trainable model.

    # Because the CP algorithm uses more than 2 variables, we need to define a custom initialization.
    if name_algo == "CP":

        def custom_init(y, physics):
            # Primal variable initialized from the adjoint; dual from y itself.
            x_init = physics.A_adjoint(y)
            u_init = y
            return {"est": (x_init, x_init, u_init)}

        params_algo["sigma"] = 1.0
    else:
        custom_init = None

    model_unfolded = unfolded_builder(
        name_algo,
        params_algo=params_algo,
        trainable_params=trainable_params,
        data_fidelity=data_fidelity,
        max_iter=max_iter,
        prior=prior,
        custom_init=custom_init,
    )

    # Every registered parameter must require grad and belong to the declared set.
    for idx, (name, param) in enumerate(model_unfolded.named_parameters()):
        assert param.requires_grad
        assert (trainable_params[0] in name) or (trainable_params[1] in name)

    N = 10
    train_dataset = DummyCircles(samples=N, imsize=imsize)
    test_dataset = DummyCircles(samples=N, imsize=imsize)

    physics = dinv.physics.Inpainting(mask=0.5, tensor_size=imsize, device=device)

    train_dataloader = DataLoader(
        train_dataset, batch_size=2, num_workers=1, shuffle=True
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=2, num_workers=1, shuffle=False
    )

    epochs = 1
    losses = [dinv.loss.SupLoss(metric=dinv.metric.mse())]
    optimizer = torch.optim.Adam(model_unfolded.parameters(), lr=1e-3, weight_decay=0.0)

    trained_unfolded_model = train(
        model=model_unfolded,
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        epochs=epochs,
        losses=losses,
        physics=physics,
        optimizer=optimizer,
        device=device,
        save_path=str(CKPT_DIR),
        verbose=True,
        wandb_vis=False,
    )

    results = feature_test(
        model=trained_unfolded_model,
        test_dataloader=test_dataloader,
        physics=physics,
        device=device,
        plot_images=False,
        verbose=True,
        wandb_vis=False,
    )

    # Now check that training with online measurements works as well
    train(
        model=model_unfolded,
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        epochs=epochs,
        losses=losses,
        physics=physics,
        optimizer=optimizer,
        device=device,
        save_path=str(CKPT_DIR),
        verbose=True,
        wandb_vis=False,
        online_measurements=True,
    )