deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/utils/nn.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import os
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TensorList:
    r"""
    Represents a list of :class:`torch.Tensor` with different shapes.
    It allows to sum, flatten, append, etc. lists of tensors seamlessly, in a
    similar fashion to :class:`torch.Tensor`.

    :param x: a list of :class:`torch.Tensor`, a single :class:`torch.Tensor` or a TensorList.
    """

    def __init__(self, x):
        super().__init__()

        # Normalize the input into a plain list of tensors.
        if isinstance(x, (list, TensorList)):
            self.x = list(x)
        elif isinstance(x, torch.Tensor):
            self.x = [x]
        else:
            raise TypeError("x must be a list of torch.Tensor or a single torch.Tensor")

        # Cache the per-entry shapes for convenient inspection.
        self.shape = [entry.shape for entry in self.x]

    def __len__(self):
        r"""
        Returns the number of tensors in the list.
        """
        return len(self.x)

    def __getitem__(self, item):
        r"""
        Returns the ith tensor in the list.
        """
        return self.x[item]

    def flatten(self):
        r"""
        Returns a :class:`torch.Tensor` with a flattened version of the list of tensors.
        """
        flat_parts = [entry.flatten() for entry in self.x]
        return torch.cat(flat_parts)

    def append(self, other):
        r"""
        Appends a :class:`torch.Tensor` or a list of :class:`torch.Tensor` to the list.

        """
        if isinstance(other, list):
            self.x = self.x + other
        elif isinstance(other, TensorList):
            self.x = self.x + other.x
        elif isinstance(other, torch.Tensor):
            self.x = self.x + [other]
        else:
            raise TypeError(
                "the appended item must be a list of :class:`torch.Tensor` or a single :class:`torch.Tensor`"
            )
        return self

    def _elementwise(self, other, op):
        r"""
        Apply the binary operation ``op`` entrywise.  When ``other`` is a list
        or TensorList it is paired with this list entry by entry; otherwise it
        is broadcast against every entry (scalar/tensor operand).
        """
        if isinstance(other, (list, TensorList)):
            return TensorList([op(a, b) for a, b in zip(self.x, other)])
        return TensorList([op(a, other) for a in self.x])

    def __add__(self, other):
        r"""
        Adds two TensorLists. The sizes of the tensor lists must match.
        """
        return self._elementwise(other, lambda a, b: a + b)

    def __mul__(self, other):
        r"""
        Multiply two TensorLists. The sizes of the tensor lists must match.
        """
        return self._elementwise(other, lambda a, b: a * b)

    def __truediv__(self, other):
        r"""
        Divide two TensorLists. The sizes of the tensor lists must match.
        """
        return self._elementwise(other, lambda a, b: a / b)

    def __neg__(self):
        r"""
        Negate a TensorList.
        """
        return TensorList([-entry for entry in self.x])

    def __sub__(self, other):
        r"""
        Substract two TensorLists. The sizes of the tensor lists must match.
        """
        return self._elementwise(other, lambda a, b: a - b)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def randn_like(x):
    r"""
    Returns a :class:`deepinv.utils.TensorList` or :class:`torch.Tensor`
    with the same type as x, filled with standard gaussian numbers.
    """
    # Plain tensors are handled directly; anything iterable (e.g. a
    # TensorList) is sampled entry by entry.
    if not isinstance(x, torch.Tensor):
        return TensorList([torch.randn_like(entry) for entry in x])
    return torch.randn_like(x)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def rand_like(x):
    r"""
    Returns a :class:`deepinv.utils.TensorList` or :class:`torch.Tensor`
    with the same type as x, filled with random uniform numbers in [0,1].
    """
    # Plain tensors are handled directly; anything iterable (e.g. a
    # TensorList) is sampled entry by entry.
    if not isinstance(x, torch.Tensor):
        return TensorList([torch.rand_like(entry) for entry in x])
    return torch.rand_like(x)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def zeros_like(x):
    r"""
    Returns a :class:`deepinv.utils.TensorList` or :class:`torch.Tensor`
    with the same type as x, filled with zeros.
    """
    # Plain tensors are handled directly; anything iterable (e.g. a
    # TensorList) is zeroed entry by entry.
    if not isinstance(x, torch.Tensor):
        return TensorList([torch.zeros_like(entry) for entry in x])
    return torch.zeros_like(x)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def ones_like(x):
    r"""
    Returns a :class:`deepinv.utils.TensorList` or :class:`torch.Tensor`
    with the same type as x, filled with ones.
    """
    # Plain tensors are handled directly; anything iterable (e.g. a
    # TensorList) is filled entry by entry.
    if not isinstance(x, torch.Tensor):
        return TensorList([torch.ones_like(entry) for entry in x])
    return torch.ones_like(x)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def get_freer_gpu():
    """
    Returns the GPU device with the most free memory.

    Queries ``nvidia-smi`` for the per-GPU free-memory report and picks the
    device with the largest value.  If anything fails (``nvidia-smi`` not
    available, unparsable output, no GPUs listed), falls back to the default
    ``cuda`` device.
    """
    try:
        # Same shell pipeline on both platforms; on non-posix systems it must
        # be run through bash for the pipe/redirect syntax to work.
        cmd = "nvidia-smi -q -d Memory |grep -A5 GPU|grep Free >tmp"
        if os.name == "posix":
            os.system(cmd)
        else:
            os.system(f'bash -c "{cmd}"')
        # Use a context manager so the temporary report file is always closed
        # (the previous version leaked the file handle).
        with open("tmp", "r") as f:
            memory_available = [int(line.split()[2]) for line in f.readlines()]
        idx = np.argmax(memory_available)  # raises on empty list -> fallback
        device = torch.device(f"cuda:{idx}")
        print(f"Selected GPU {idx} with {np.max(memory_available)} MB free memory ")
    except Exception:
        # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit still work.
        device = torch.device("cuda")
        print("Couldn't find free GPU")

    return device
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def save_model(
    epoch, model, optimizer, ckp_interval, epochs, loss, save_path, eval_psnr=None
):
    r"""
    Saves a training checkpoint (model/optimizer state, epoch and loss) to
    ``save_path`` every ``ckp_interval`` epochs and at the final epoch.

    :param int epoch: current epoch number (0-based).
    :param model: model whose ``state_dict`` is checkpointed.
    :param optimizer: optimizer whose ``state_dict`` is checkpointed.
    :param int ckp_interval: save every ``ckp_interval`` epochs.
    :param int epochs: total number of epochs (the last epoch always saves).
    :param loss: loss value stored alongside the weights.
    :param str save_path: directory where the checkpoint file is written.
    :param eval_psnr: optional evaluation PSNR stored in the checkpoint.
    """
    at_interval = epoch > 0 and epoch % ckp_interval == 0
    at_end = epoch + 1 == epochs
    if not (at_interval or at_end):
        return

    os.makedirs(save_path, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "loss": loss,
        "optimizer": optimizer.state_dict(),
    }
    if eval_psnr is not None:
        checkpoint["eval_psnr"] = eval_psnr

    torch.save(checkpoint, os.path.join(save_path, "ckp_{}.pth.tar".format(epoch)))
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def load_checkpoint(model, path_checkpoint, device):
    r"""
    Loads the ``state_dict`` stored in a checkpoint file into ``model``.

    :param model: model to restore (modified in place).
    :param str path_checkpoint: path of the checkpoint file.
    :param device: ``map_location`` passed to :func:`torch.load`.
    :return: the restored model (same object as ``model``).
    """
    state = torch.load(path_checkpoint, map_location=device)
    model.load_state_dict(state["state_dict"])
    return model
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def investigate_model(model, idx_max=1, check_name="iterator.g_step.g_param.0"):
    r"""
    Prints debugging information (first value and gradient norm) for selected
    trainable parameters of ``model``.

    A parameter is reported when its index is below ``idx_max`` or its name
    contains ``check_name``.

    :param model: model whose parameters are inspected.
    :param int idx_max: number of leading parameters to always report.
    :param str check_name: substring selecting extra parameters by name.
    """
    for idx, (name, param) in enumerate(model.named_parameters()):
        if param.requires_grad and (idx < idx_max or check_name in name):
            # ``param.grad`` is None before any backward pass; the previous
            # version crashed with AttributeError in that case.
            grad_norm = (
                param.grad.detach().data.norm(2) if param.grad is not None else None
            )
            print(
                name,
                param.data.flatten()[0],
                "gradient norm = ",
                grad_norm,
            )
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import torch
import torch.nn as nn
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class NeuralIteration(nn.Module):
    """
    Base class for unrolled reconstruction networks.

    Subclasses call :meth:`init` with either a single backbone block
    (weight-tied across iterations) or a list of blocks (one per iteration),
    then override :meth:`forward`.
    """

    def __init__(self):
        super(NeuralIteration, self).__init__()

    def init(self, backbone_blocks, step_size=1.0, iterations=1):
        """
        Registers the backbone block(s) and a learnable per-iteration step size.

        :param backbone_blocks: a single ``nn.Module`` or a list with one
            module per iteration (the list length must equal ``iterations``).
        :param float step_size: initial value of the learnable step sizes.
        :param int iterations: number of unrolled iterations.
        """
        self.iterations = iterations
        self.n_blocks = len(backbone_blocks) if isinstance(backbone_blocks, list) else 1

        # Use ``nn.Parameter`` (previously ``torch.nn.Parameter`` — but this
        # module only imported ``torch.nn as nn``, so ``torch.nn`` access via
        # the bare ``torch`` name raised a NameError at runtime).
        self.register_parameter(
            name="step_size",
            param=nn.Parameter(
                step_size * torch.ones(self.iterations), requires_grad=True
            ),
        )
        if self.n_blocks > 1:  # weight_tied=False (many blocks)
            assert (
                self.n_blocks == iterations
            ), "'# blocks' does not equal to 'iterations'"
            self.blocks = nn.ModuleList(
                [backbone_blocks[_] for _ in range(iterations)]
            )
        else:  # weight_tied=True (only one block)
            self.blocks = nn.ModuleList([backbone_blocks])

    def forward(self, y, physics, x_init=None):
        # Default (no iterations): plain adjoint reconstruction.
        return physics.A_adjoint(y)

    @staticmethod
    def measurement_consistency_grad(physics, x, y):
        # grad(||y-Ax||) = A^T(y-Ax); computed on the physics device, then
        # moved back to the device of x.
        return physics.A_adjoint(
            y.to(physics.device) - physics.A(x.to(physics.device))
        ).to(x.device)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class GradientDescent(NeuralIteration):
    """
    Unrolled gradient descent: each iteration adds a learned correction
    ``blocks[t](x)`` and a measurement-consistency gradient step to ``x``.
    """

    def __init__(self, backbone_blocks, step_size=1.0, iterations=1):
        super(GradientDescent, self).__init__()
        self.name = "gd"
        self.init(backbone_blocks, step_size, iterations)

    def forward(self, y, physics, x_init=None):
        # Start from the provided initialization or the adjoint estimate.
        x = physics.A_adjoint(y) if x_init is None else x_init.clone()
        for it in range(self.iterations):
            # Weight-tied networks reuse block/step 0 at every iteration.
            idx = it if self.n_blocks > 1 else 0
            correction = self.blocks[idx](x)
            grad = NeuralIteration.measurement_consistency_grad(physics, x, y)
            x = correction + x + self.step_size[idx] * grad
        return x
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ProximalGradientDescent(NeuralIteration):
    """
    Unrolled proximal gradient descent: each iteration applies a
    measurement-consistency gradient step followed by a learned proximal
    operator ``blocks[t]``.
    """

    def __init__(self, backbone_blocks, step_size=1.0, iterations=1):
        super(ProximalGradientDescent, self).__init__()
        self.name = "pgd"
        self.init(backbone_blocks, step_size, iterations)
        # Convenience alias kept for external callers.
        self.mc_grad = NeuralIteration.measurement_consistency_grad

    def forward(self, y, physics, x_init=None):
        # Start from the provided initialization or the adjoint estimate.
        x = physics.A_adjoint(y) if x_init is None else x_init.clone()
        for it in range(self.iterations):
            # Weight-tied networks reuse block/step 0 at every iteration.
            idx = it if self.n_blocks > 1 else 0
            grad = NeuralIteration.measurement_consistency_grad(physics, x, y)
            x = self.blocks[idx](x + self.step_size[idx] * grad)
        return x
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
if __name__ == "__main__":
    # Smoke test for the unrolled networks: run ProximalGradientDescent on a
    # random inpainting problem, first weight-tied (one block) and then with
    # one block per iteration.  Requires the full deepinv package.
    import torch
    import deepinv as dinv

    # Weight-tied case: a single U-Net shared across iterations.
    net = dinv.models.unet().to(dinv.device)
    physics = dinv.physics.Inpainting([32, 32], device=dinv.device)

    # Synthetic batch: 10 single-channel 32x32 images.
    x = torch.randn(10, 1, 32, 32).to(dinv.device)
    y = physics.A(x)
    fbp = physics.A_dagger(y)  # pseudo-inverse reconstruction as init
    x_rec = net(fbp)

    unroll = ProximalGradientDescent(net, step_size=1.0, iterations=1)
    x_unroll = unroll(y, physics, x_init=fbp)

    print("iterations=3")
    # Untied case: one independent U-Net per iteration.
    iterations = 3
    step_size = 1.0
    blocks = [dinv.models.unet().to(dinv.device) for _ in range(iterations)]

    physics = dinv.physics.Inpainting([32, 32], device=dinv.device)
    x = torch.randn(10, 1, 32, 32).to(dinv.device)
    y = physics.A(x)
    fbp = physics.A_dagger(y)

    unroll = ProximalGradientDescent(blocks, step_size=step_size, iterations=iterations)
    x_unroll = unroll(y, physics, x_init=fbp)

    print(f"iterations={iterations}", x.shape, y.shape, fbp.shape, x_unroll.shape)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_DPIR_params(noise_level_img):
    r"""
    Default parameters for the DPIR Plug-and-Play algorithm.

    :param float noise_level_img: Noise level of the input image.
    :return: tuple ``(lamb, sigma_denoiser, stepsize, max_iter)`` where the
        denoiser noise levels decay log-uniformly from 49/255 down to the
        image noise level over ``max_iter`` iterations.
    """
    max_iter = 8
    # Log-uniform denoiser schedule from a fixed starting level down to the
    # image noise level.
    start_level = 49.0 / 255.0
    sigma_denoiser = np.logspace(
        np.log10(start_level), np.log10(noise_level_img), max_iter
    ).astype(np.float32)
    # Step sizes scale quadratically with the (clamped) relative noise level.
    stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
    lamb = 1 / 0.23
    return lamb, list(sigma_denoiser), list(stepsize), max_iter
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_GSPnP_params(problem, noise_level_img):
    r"""
    Default parameters for the GSPnP Plug-and-Play algorithm.

    :param str problem: Type of inverse-problem problem to solve. Can be ``deblur``, ``super-resolution``, or ``inpaint``.
    :param float noise_level_img: Noise level of the input image.
    :return: tuple ``(lamb, sigma_denoiser, stepsize, max_iter)``.
    :raises ValueError: when ``problem`` is not one of the supported names.
    """
    if problem == "deblur":
        max_iter, sigma_denoiser, lamb = 500, 1.8 * noise_level_img, 1 / 0.1
    elif problem == "super-resolution":
        max_iter, sigma_denoiser, lamb = 500, 2.0 * noise_level_img, 1 / 0.065
    elif problem == "inpaint":
        # Inpainting uses a fixed denoiser level independent of the input noise.
        max_iter, sigma_denoiser, lamb = 100, 10.0 / 255, 1 / 0.1
    else:
        raise ValueError("parameters unknown with this degradation")
    # The step size is always 1 for GSPnP.
    return lamb, sigma_denoiser, 1.0, max_iter
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
# ``odl`` is an optional dependency: keep an ImportError sentinel so call
# sites can raise a helpful message lazily instead of failing at import time.
try:
    import odl
except Exception:
    # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit are not
    # swallowed; any other failure simply means odl is unusable.
    odl = ImportError("The odl package is not installed.")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def random_shapes(interior=False):
    """
    Generate random shape parameters.
    Taken from https://github.com/adler-j/adler/blob/master/adler/odl/phantom.py

    Returns a 6-tuple ``(value, axis_1, axis_2, x_0, y_0, rotation)``
    describing one random ellipse.
    """
    # Center: inside the unit square when ``interior``, otherwise anywhere in
    # [-1, 1]^2.  (RNG call order is kept identical to the original.)
    if interior:
        x_0 = np.random.rand() - 0.5
        y_0 = np.random.rand() - 0.5
    else:
        x_0 = 2 * np.random.rand() - 1.0
        y_0 = 2 * np.random.rand() - 1.0

    value = (np.random.rand() - 0.5) * np.random.exponential(0.4)
    axis_1 = np.random.exponential() * 0.2
    axis_2 = np.random.exponential() * 0.2
    rotation = np.random.rand() * 2 * np.pi
    return value, axis_1, axis_2, x_0, y_0, rotation
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def random_phantom(spc, n_ellipse=50, interior=False):
    """
    Generate a random ellipsoid phantom.
    Taken from https://github.com/adler-j/adler/blob/master/adler/odl/phantom.py

    :param spc: odl discretized space the phantom is generated on.
    :param int n_ellipse: mean number of ellipses (Poisson distributed).
    :param bool interior: restrict ellipse centers to the interior region.
    :raises ImportError: when the optional ``odl`` package is missing.
    """
    if isinstance(odl, ImportError):
        raise ImportError(
            "odl is needed to use generate random phantoms. "
            "It should be installed with `python3 -m pip install"
            " https://github.com/odlgroup/odl/archive/master.zip`"
        ) from odl
    count = np.random.poisson(n_ellipse)
    ellipses = [random_shapes(interior=interior) for _ in range(count)]
    return odl.phantom.ellipsoid_phantom(spc, ellipses)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class RandomPhantomDataset(torch.utils.data.Dataset):
    """
    Dataset of random ellipsoid phantoms. The phantoms are generated on the fly.
    The phantoms are generated using the odl library (https://odlgroup.github.io/odl/).

    :param int size: Size of the phantom (square) image.
    :param int n_data: Number of phantoms to generate per sample.
    :param transform: Transformation to apply to the output image.
    :param float length: Length of the dataset. Useful for iterating the data-loader for a certain nb of iterations.
    """

    def __init__(self, size=128, n_data=1, transform=None, length=np.inf):
        # Fail early with a helpful message when odl is missing (consistent
        # with SheppLoganDataset); previously this surfaced as a confusing
        # AttributeError on the ImportError sentinel below.
        if isinstance(odl, ImportError):
            raise ImportError(
                "odl is needed to use generate random phantoms. "
                "It should be installed with `python3 -m pip install"
                " https://github.com/odlgroup/odl/archive/master.zip`"
            ) from odl
        self.space = odl.uniform_discr(
            [-64, -64], [64, 64], [size, size], dtype="float32"
        )
        self.transform = transform
        self.n_data = n_data
        self.length = length

    def __len__(self):
        # NOTE: with the np.inf default, calling ``len()`` raises TypeError
        # (non-integer); iterate the data-loader directly in that case.
        return self.length

    def __getitem__(self, index):
        """
        :return tuple : A tuple (phantom, 0) where phantom is a torch tensor of shape (n_data, size, size).
        """
        phantom_np = np.array([random_phantom(self.space) for i in range(self.n_data)])
        phantom = torch.from_numpy(phantom_np).float()
        if self.transform is not None:
            phantom = self.transform(phantom)
        return phantom, 0
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class SheppLoganDataset(torch.utils.data.Dataset):
    """
    Dataset for the single Shepp-Logan phantom. The dataset has length 1.
    """

    @staticmethod
    def _require_odl():
        # The odl package is optional; raise a helpful error when missing.
        if isinstance(odl, ImportError):
            raise ImportError(
                "odl is needed to use generate the Shepp Logan phantom. "
                "It should be installed with `python3 -m pip install"
                " https://github.com/odlgroup/odl/archive/master.zip`"
            ) from odl

    def __init__(self, size=128, n_data=1, transform=None):
        self._require_odl()
        self.space = odl.uniform_discr(
            [-64, -64], [64, 64], [size, size], dtype="float32"
        )
        self.transform = transform
        self.n_data = n_data

    def __len__(self):
        return 1

    def __getitem__(self, index):
        self._require_odl()
        samples = [odl.phantom.shepp_logan(self.space, True) for _ in range(self.n_data)]
        phantom = torch.from_numpy(np.array(samples)).float()
        if self.transform is not None:
            phantom = self.transform(phantom)
        return phantom, 0
|