deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/__about__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
__all__ = [
|
|
2
|
+
"__title__",
|
|
3
|
+
"__summary__",
|
|
4
|
+
"__url__",
|
|
5
|
+
"__version__",
|
|
6
|
+
"__author__",
|
|
7
|
+
"__email__",
|
|
8
|
+
"__license__",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
__title__ = "deepinv"
|
|
12
|
+
__summary__ = "Deep Learning for Inverse Problems Library for PyTorch"
|
|
13
|
+
__url__ = "https://github.com/edongdongchen/deepinv"
|
|
14
|
+
__version__ = "0.0.1"
|
|
15
|
+
__author__ = "Dongdong Chen, Julian Tachella"
|
|
16
|
+
__email__ = "echendongdong@gmail.com"
|
|
17
|
+
__license__ = "BSD 3-Clause Clear"
|
deepinv/__init__.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from .__about__ import *
import torch

# Re-export the package metadata defined in ``__about__.py``.
__all__ = [
    "__title__",
    "__summary__",
    "__url__",
    "__version__",
    "__author__",
    "__email__",
    "__license__",
]

# Public submodules are imported eagerly so that e.g. ``deepinv.models`` is
# available right after ``import deepinv``, and appended to ``__all__`` so
# that ``from deepinv import *`` exposes them.
#
# Fix: the original listed "iterative" in ``__all__`` (after a duplicate
# ``from deepinv import models``) even though no ``iterative`` attribute
# exists, which made ``from deepinv import *`` raise AttributeError.
from deepinv import models

__all__ += ["models"]

from deepinv import optim

__all__ += ["optim"]

from deepinv import loss

__all__ += ["loss"]

from deepinv import utils

__all__ += ["utils"]

from deepinv import physics

__all__ += ["physics"]

from deepinv import datasets

__all__ += ["datasets"]

from deepinv import transform

__all__ += ["transform"]

from deepinv import sampling

__all__ += ["sampling"]

from deepinv.loss import metric

__all__ += ["metric"]

from deepinv import unfolded

__all__ += ["unfolded"]

from deepinv.training_utils import train, test

# GLOBAL PROPERTY
# Default floating-point precision used across the library.
dtype = torch.float
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .datagenerator import generate_dataset, HDF5Dataset
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
from tqdm import tqdm
|
|
2
|
+
import os
|
|
3
|
+
import h5py
|
|
4
|
+
import torch
|
|
5
|
+
from torch.utils.data import DataLoader, Subset
|
|
6
|
+
from torch.utils import data
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class HDF5Dataset(data.Dataset):
    r"""
    DeepInverse HDF5 dataset with signal/measurement pairs.

    :param str path: Path to the folder containing the dataset (one or multiple HDF5 files).
    :param bool train: Set to ``True`` for training and ``False`` for testing.
    """

    def __init__(self, path, train=True):
        super().__init__()
        self.data_info = []
        self.data_cache = {}
        self.unsupervised = False

        archive = h5py.File(path, "r")
        if not train:
            # Test split always stores both signals and measurements.
            self.x = archive["x_test"]
            self.y = archive["y_test"]
        else:
            if "x_train" in archive:
                self.x = archive["x_train"]
            else:
                # No ground-truth signals stored: measurement-only dataset.
                self.unsupervised = True
            self.y = archive["y_train"]

    def __getitem__(self, index):
        # Measurements are always available; cast to float32 tensors.
        y = torch.from_numpy(self.y[index]).type(torch.float)
        # Without ground truth, the measurement stands in for the signal.
        if self.unsupervised:
            x = y
        else:
            x = torch.from_numpy(self.x[index]).type(torch.float)
        return x, y

    def __len__(self):
        return len(self.y)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def generate_dataset(
    train_dataset,
    physics,
    save_dir,
    test_dataset=None,
    device="cpu",
    train_datapoints=None,
    test_datapoints=None,
    dataset_filename="dinv_dataset",
    batch_size=4,
    num_workers=0,
    supervised=True,
):
    r"""
    Generates dataset of signal/measurement pairs from base dataset.

    It generates the measurement data using the forward operator provided by the user.
    The dataset is saved in HD5 format and can be easily loaded using the HD5Dataset class.
    The generated dataset contains a train and test splits.

    :param torch.data.Dataset train_dataset: base dataset (e.g., MNIST, CelebA, etc.)
        with images used for generating associated measurements
        via the chosen forward operator. The generated dataset is saved in HD5 format and can be easily loaded using the
        HD5Dataset class.
    :param deepinv.physics.Physics physics: Forward operator used to generate the measurement data.
        It can be either a single operator or a list of forward operators. In the latter case, the dataset will be
        assigned evenly across operators.
    :param str save_dir: folder where the dataset and forward operator will be saved.
    :param torch.data.Dataset test_dataset: if included, the function will also generate measurements associated to the
        test dataset.
    :param torch.device device: which indicates cpu or gpu.
    :param int, None train_datapoints: Desired number of datapoints in the training dataset. If set to ``None``, it will use the
        number of datapoints in the base dataset. This is useful for generating a larger train dataset via data
        augmentation (which should be chosen in the train_dataset).
    :param int, None test_datapoints: Desired number of datapoints in the test dataset. If set to ``None``, it will use the
        number of datapoints in the base test dataset.
    :param str dataset_filename: desired filename of the dataset.
    :param int batch_size: batch size for generating the measurement data
        (it only affects the speed of the generating process)
    :param int num_workers: number of workers for generating the measurement data
        (it only affects the speed of the generating process)
    :param bool supervised: Generates supervised pairs (x,y) of measurements and signals.
        If set to ``False``, it will generate a training dataset with measurements only (y)
        and a test dataset with pairs (x,y)

    :return: path of the generated HDF5 file (a list of paths when multiple operators are given).
    :raises ValueError: if both ``train_dataset`` and ``test_dataset`` are ``None``.
    """
    # NOTE(review): this check looks for ``save_dir/dataset_filename`` but the
    # files are actually written as ``{save_dir}/{dataset_filename}{g}.h5``
    # below, so the warning likely never fires — confirm intended path.
    if os.path.exists(os.path.join(save_dir, dataset_filename)):
        print(
            "WARNING: Dataset already exists, this will overwrite the previous dataset."
        )

    if test_dataset is None and train_dataset is None:
        raise ValueError("No train or test datasets provided.")

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Normalize ``physics`` to a list; G is the number of forward operators,
    # and the base dataset is split evenly across them.
    if not (type(physics) in [list, tuple]):
        physics = [physics]
        G = 1
    else:
        G = len(physics)

    if train_dataset is not None:
        if train_datapoints is not None:
            datapoints = int(train_datapoints)
        else:
            datapoints = len(train_dataset)

        # n_train may exceed len(train_dataset) on purpose (data augmentation:
        # the same base images are passed through the pipeline multiple times).
        n_train = datapoints  # min(len(train_dataset), datapoints)
        n_train_g = int(n_train / G)  # training datapoints per operator
        n_dataset_g = int(min(len(train_dataset), datapoints) / G)  # base images per operator

    if test_dataset is not None:
        test_datapoints = (
            test_datapoints if test_datapoints is not None else len(test_dataset)
        )
        n_test = min(len(test_dataset), test_datapoints)
        n_test_g = int(n_test / G)  # test datapoints per operator

    hf_paths = []

    # One HDF5 file (and one saved physics state dict) per operator.
    for g in range(G):
        hf_path = f"{save_dir}/{dataset_filename}{g}.h5"
        hf_paths.append(hf_path)
        hf = h5py.File(hf_path, "w")

        hf.attrs["operator"] = physics[g].__class__.__name__

        # Probe one sample to determine the signal/measurement shapes.
        if train_dataset is not None:
            x = train_dataset[0]
        elif test_dataset is not None:
            x = test_dataset[0]

        # Datasets may return (image, label) pairs; keep only the image.
        x = x[0] if isinstance(x, list) or isinstance(x, tuple) else x
        x = x.to(device).unsqueeze(0)

        # choose operator and generate measurement
        y = physics[g](x)

        # Persist the operator parameters so the dataset can be reloaded
        # together with the exact physics used to generate it.
        torch.save(physics[g].state_dict(), f"{save_dir}/physics{g}.pt")

        if train_dataset is not None:
            # Pre-allocate the HDF5 datasets with the probed shapes.
            hf.create_dataset("y_train", (n_train_g,) + y.shape[1:], dtype="float")
            if supervised:
                hf.create_dataset("x_train", (n_train_g,) + x.shape[1:], dtype="float")

            if G > 1:
                print(
                    f"Computing train measurement vectors from base dataset of operator {g + 1} out of {G}..."
                )
            else:
                print("Computing train measurement vectors from base dataset...")

            index = 0

            # Multiple passes over the base subset when n_train_g exceeds the
            # number of base images (augmentation transforms resample each
            # epoch — presumably why the DataLoader is rebuilt per pass).
            epochs = int(n_train_g / len(train_dataset)) + 1
            for e in tqdm(range(epochs)):
                train_dataloader = DataLoader(
                    Subset(
                        train_dataset,
                        indices=list(range(g * n_dataset_g, (g + 1) * n_dataset_g)),
                    ),
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory=False if device == "cpu" else True,
                )

                for i, x in enumerate(train_dataloader):
                    x = x[0] if isinstance(x, list) or isinstance(x, tuple) else x
                    x = x.to(device)

                    # choose operator and generate measurement
                    y = physics[g](x)

                    # Add new data to it
                    bsize = x.size()[0]

                    # Truncate the last batch so exactly n_train_g samples
                    # are written.
                    if bsize + index > n_train_g:
                        bsize = n_train_g - index

                    hf["y_train"][index : index + bsize] = y[:bsize, :].to("cpu").numpy()
                    if supervised:
                        hf["x_train"][index : index + bsize] = (
                            x[:bsize, :, :, :].to("cpu").numpy()
                        )
                    index = index + bsize

        if test_dataset is not None:
            index = 0
            test_dataloader = DataLoader(
                Subset(
                    test_dataset, indices=list(range(g * n_test_g, (g + 1) * n_test_g))
                ),
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory=True,
            )

            if G > 1:
                print(
                    f"Computing test measurement vectors from base dataset of operator {g + 1} out of {G}..."
                )
            else:
                print("Computing test measurement vectors from base dataset...")

            for i, x in enumerate(tqdm(test_dataloader)):
                x = x[0] if isinstance(x, list) or isinstance(x, tuple) else x
                x = x.to(device)

                # choose operator
                y = physics[g](x)

                # Test shapes are only known after the first batch, so the
                # HDF5 datasets are created lazily here.
                if i == 0:  # create dict
                    hf.create_dataset(
                        "x_test", (n_test_g,) + x.shape[1:], dtype="float"
                    )
                    hf.create_dataset(
                        "y_test", (n_test_g,) + y.shape[1:], dtype="float"
                    )

                # Add new data to it
                bsize = x.size()[0]
                hf["x_test"][index : index + bsize] = x.to("cpu").numpy()
                hf["y_test"][index : index + bsize] = y.to("cpu").numpy()
                index = index + bsize
        hf.close()

    print("Dataset has been saved in " + str(save_dir))

    return hf_paths[0] if G == 1 else hf_paths
|
deepinv/loss/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from deepinv.loss.mc import MCLoss
|
|
2
|
+
from deepinv.loss.ei import EILoss
|
|
3
|
+
from deepinv.loss.moi import MOILoss
|
|
4
|
+
from deepinv.loss.sup import SupLoss
|
|
5
|
+
from deepinv.loss.score import ScoreLoss
|
|
6
|
+
from deepinv.loss.tv import TVLoss
|
|
7
|
+
from deepinv.loss.sure import SureGaussianLoss, SurePoissonLoss, SurePGLoss
|
|
8
|
+
from deepinv.loss.regularisers import JacobianSpectralNorm, FNEJacobianSpectralNorm
|
|
9
|
+
from deepinv.loss.measplit import SplittingLoss, Neighbor2Neighbor
|
|
10
|
+
from deepinv.loss.metric import LpNorm, CharbonnierLoss
|
deepinv/loss/ei.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class EILoss(nn.Module):
    r"""
    Equivariant imaging self-supervised loss.

    Assumes that the set of signals is invariant to a group of transformations
    (rotations, translations, etc.) in order to learn from incomplete
    measurement data alone, see https://arxiv.org/pdf/2103.14756.pdf.

    The EI loss is defined as

    .. math::

        \| T_g \hat{x} - \inverse{\forw{T_g \hat{x}}}\|^2


    where :math:`\hat{x}=\inverse{y}` is a reconstructed signal and
    :math:`T_g` is a transformation sampled at random from a group :math:`g\sim\group`.

    By default, the error is computed using the MSE metric, however any other metric (e.g., :math:`\ell_1`)
    can be used as well.

    :param deepinv.Transform, torchvision.transforms transform: Transform to generate the virtually
        augmented measurement. It can be any torch-differentiable function (e.g., a ``torch.nn.Module``).
    :param torch.nn.Module metric: Metric used to compute the error between the reconstructed augmented measurement and the reference
        image.
    :param bool apply_noise: if ``True``, the augmented measurement is computed with the full sensing model
        :math:`\sensor{\noise{\forw{\hat{x}}}}` (i.e., noise and sensor model),
        otherwise is generated as :math:`\forw{\hat{x}}`.
    :param float weight: Weight of the loss.
    :param bool no_grad: if ``True``, the gradient does not propagate through :math:`T_g`. Default: ``True``.
    """

    def __init__(
        self,
        transform,
        metric=torch.nn.MSELoss(),
        apply_noise=True,
        weight=1.0,
        no_grad=True,
    ):
        super(EILoss, self).__init__()
        self.name = "ei"
        self.metric = metric
        self.weight = weight
        self.T = transform
        self.noise = apply_noise
        self.no_grad = no_grad

    def forward(self, x_net, physics, model, **kwargs):
        r"""
        Computes the EI loss

        :param torch.Tensor x_net: Reconstructed image :math:`\inverse{y}`.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
        :param torch.nn.Module model: Reconstruction function.
        :return: (torch.Tensor) loss.
        """
        # Apply the group transformation, optionally detaching it from the
        # computational graph.
        if self.no_grad:
            with torch.no_grad():
                x_aug = self.T(x_net)
        else:
            x_aug = self.T(x_net)

        # Simulate a measurement of the transformed image, with or without
        # the noise/sensor model.
        y_aug = physics(x_aug) if self.noise else physics.A(x_aug)

        # Reconstruct from the augmented measurement and compare to the
        # transformed reference.
        x_rec = model(y_aug, physics)

        return self.weight * self.metric(x_rec, x_aug)
|
deepinv/loss/mc.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class MCLoss(nn.Module):
    r"""
    Measurement consistency loss

    This loss enforces that the reconstructions are measurement-consistent, i.e., :math:`y=\forw{\inverse{y}}`.

    The measurement consistency loss is defined as

    .. math::

        \|y-\forw{\inverse{y}}\|^2

    where :math:`\inverse{y}` is the reconstructed signal and :math:`A` is a forward operator.

    By default, the error is computed using the MSE metric, however any other metric (e.g., :math:`\ell_1`)
    can be used as well.

    :param torch.nn.Module metric: metric used for computing data consistency, which is set as the mean squared error by default.
    """

    def __init__(self, metric=torch.nn.MSELoss()):
        super(MCLoss, self).__init__()
        self.name = "mc"
        self.metric = metric

    def forward(self, y, x_net, physics, **kwargs):
        r"""
        Computes the measurement splitting loss

        :param torch.Tensor y: measurements.
        :param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`.
        :param deepinv.physics.Physics physics: forward operator associated with the measurements.
        :return: (torch.Tensor) loss.
        """
        # Re-measure the reconstruction and compare against the observed data.
        y_pred = physics.A(x_net)
        return self.metric(y_pred, y)
|
deepinv/loss/measplit.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from deepinv.physics import Inpainting
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SplittingLoss(torch.nn.Module):
    r"""
    Measurement splitting loss.

    Splits the measurement and forward operator (of size :math:`m`)
    into two smaller pairs :math:`(y_1,A_1)` (of size :math:`m_1`) and :math:`(y_2,A_2)` (of size :math:`m_2`) ,
    to compute the self-supervised loss:

    .. math::

        \frac{m}{m_2}\| y_2 - A_2 \inversef{y_1,A_1}\|^2

    where :math:`R` is the trainable network. See https://pubmed.ncbi.nlm.nih.gov/32614100/.

    By default, the error is computed using the MSE metric, however any other metric (e.g., :math:`\ell_1`)
    can be used as well.

    :param torch.nn.Module metric: metric used for computing data consistency,
        which is set as the mean squared error by default.
    :param float split_ratio: splitting ratio, should be between 0 and 1. The size of :math:`y_1` increases
        with the splitting ratio.
    :param bool regular_mask: If ``True``, it will use a regular mask, otherwise it uses a random mask.
    """

    def __init__(self, metric=torch.nn.MSELoss(), split_ratio=0.9, regular_mask=False):
        super().__init__()
        self.name = "ms"
        self.metric = metric
        self.regular_mask = regular_mask
        self.split_ratio = split_ratio

    def forward(self, y, physics, model, **kwargs):
        r"""
        Computes the measurement splitting loss

        :param torch.Tensor y: Measurements.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
        :param torch.nn.Module model: Reconstruction function.
        :return: (torch.Tensor) loss.
        """
        meas_shape = y.size()[1:]

        # Build a binary splitting mask over the measurement entries.
        mask = torch.ones(meas_shape).to(y.device)
        if self.regular_mask:
            # Regular grid: drop one entry every `stride` pixels in each
            # spatial dimension, at a random offset.
            stride = int(1 / (1 - self.split_ratio))
            start = np.random.randint(stride)
            mask[..., start::stride, start::stride] = 0.0
        else:
            # Bernoulli mask keeping roughly `split_ratio` of the entries.
            mask[torch.rand_like(mask) > self.split_ratio] = 0

        # Complementary inpainting operators selecting the kept / held-out
        # measurement entries.
        keep = Inpainting(meas_shape, mask)
        drop = Inpainting(meas_shape, 1 - mask)

        # Compose with the original forward operator.
        physics1 = keep * physics  # A_1 = P*A
        physics2 = drop * physics  # A_2 = (I-P)*A

        # Split the measurements accordingly.
        y1 = keep.A(y)
        y2 = drop.A(y)

        # Reconstruct from the kept part, evaluate on the held-out part,
        # and renormalize by the held-out fraction.
        residual = self.metric(physics2.A(model(y1, physics1)), y2)
        return residual / (1 - self.split_ratio)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class Neighbor2Neighbor(torch.nn.Module):
    r"""
    Neighbor2Neighbor loss.

    Splits the noisy measurements using two masks :math:`A_1` and :math:`A_2`, each choosing a different neighboring
    map (see details in `"Neighbor2Neighbor: Self-Supervised Denoising from Single Noisy Images"
    <https://openaccess.thecvf.com/content/CVPR2021/papers/Huang_Neighbor2Neighbor_Self-Supervised_Denoising_From_Single_Noisy_Images_CVPR_2021_paper.pdf>`_).

    The self-supervised loss is computed as:

    .. math::

        \| A_2 y - R(A_1 y)\|^2 + \gamma \| A_2 y - R(A_1 y) - (A_2 R(y) - A_1 R(y))\|^2

    where :math:`R` is the trainable denoiser network, :math:`\gamma>0` is a regularization parameter
    and no gradient is propagated when computing :math:`R(y)`.

    By default, the error is computed using the MSE metric, however any other metric (e.g., :math:`\ell_1`)
    can be used as well.

    The code has been adapted from the repository https://github.com/TaoHuang2018/Neighbor2Neighbor.

    :param torch.nn.Module metric: metric used for computing data consistency,
        which is set as the mean squared error by default.
    :param float gamma: regularization parameter :math:`\gamma`.
    """

    def __init__(self, metric=torch.nn.MSELoss(), gamma=2.0):
        super().__init__()
        self.name = "neigh2neigh"
        self.metric = metric
        self.gamma = gamma

    def space_to_depth(self, x, block_size):
        # Fold each (block_size x block_size) spatial block of x into the
        # channel dimension: (N, C, H, W) -> (N, C*block_size^2, H/bs, W/bs).
        n, c, h, w = x.size()
        unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size)
        return unfolded_x.view(n, c * block_size**2, h // block_size, w // block_size)

    def generate_mask_pair(self, img):
        # Builds two complementary boolean masks selecting, inside every 2x2
        # cell of the image, a random pair of *neighboring* pixels (one for
        # each mask). Returned masks are flat, of length n*(h/2)*(w/2)*4.
        # prepare masks (N x C x H/2 x W/2)
        n, c, h, w = img.shape
        mask1 = torch.zeros(
            size=(n * h // 2 * w // 2 * 4,), dtype=torch.bool, device=img.device
        )
        mask2 = torch.zeros(
            size=(n * h // 2 * w // 2 * 4,), dtype=torch.bool, device=img.device
        )
        # prepare random mask pairs
        # The 8 rows enumerate the ordered pairs of adjacent positions within
        # a 2x2 cell (positions indexed 0..3); diagonal pairs are excluded.
        idx_pair = torch.tensor(
            [[0, 1], [0, 2], [1, 3], [2, 3], [1, 0], [2, 0], [3, 1], [3, 2]],
            dtype=torch.int64,
            device=img.device,
        )
        # One random pair choice per 2x2 cell (uses the global torch RNG).
        rd_idx = torch.zeros(
            size=(n * h // 2 * w // 2,), dtype=torch.int64, device=img.device
        )
        torch.randint(low=0, high=8, size=(n * h // 2 * w // 2,), out=rd_idx)
        rd_pair_idx = idx_pair[rd_idx]
        # Offset each cell's local position indices into the flat mask.
        rd_pair_idx += torch.arange(
            start=0,
            end=n * h // 2 * w // 2 * 4,
            step=4,
            dtype=torch.int64,
            device=img.device,
        ).reshape(-1, 1)
        # get masks
        mask1[rd_pair_idx[:, 0]] = 1
        mask2[rd_pair_idx[:, 1]] = 1
        return mask1, mask2

    def generate_subimages(self, img, mask):
        # Downsamples img to half resolution by keeping, in each 2x2 cell,
        # the single pixel selected by the flat boolean mask.
        n, c, h, w = img.shape
        subimage = torch.zeros(
            n, c, h // 2, w // 2, dtype=img.dtype, layout=img.layout, device=img.device
        )
        # per channel
        for i in range(c):
            img_per_channel = self.space_to_depth(img[:, i : i + 1, :, :], block_size=2)
            img_per_channel = img_per_channel.permute(0, 2, 3, 1).reshape(-1)
            subimage[:, i : i + 1, :, :] = (
                img_per_channel[mask].reshape(n, h // 2, w // 2, 1).permute(0, 3, 1, 2)
            )
        return subimage

    def forward(self, y, physics, model, **kwargs):
        r"""
        Computes the neighbor2neighbor loss.


        :param torch.Tensor y: Measurements.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
        :param torch.nn.Module model: Reconstruction function.
        :return: (torch.Tensor) loss.
        """
        # NOTE(review): assert statements are stripped under ``python -O``;
        # consider raising ValueError for these input checks instead.
        assert len(y.shape) == 4, "Input measurements should be images"
        assert (
            y.shape[2] % 2 == 0 and y.shape[3] % 2 == 0
        ), "Image dimensions should be even"

        # Two complementary neighbor subsamplings of the noisy input.
        mask1, mask2 = self.generate_mask_pair(y)

        y1 = self.generate_subimages(y, mask1)
        xhat1 = model(y1, physics)
        y2 = self.generate_subimages(y, mask2)

        # Full-resolution reconstruction, detached so no gradient flows
        # through the regularization term's R(y).
        xhat = model(y, physics).detach()
        y1_hat = self.generate_subimages(xhat, mask1)
        y2_hat = self.generate_subimages(xhat, mask2)

        # Data term + gamma-weighted consistency regularizer.
        loss_n2n = self.metric(xhat1, y2) + self.gamma * self.metric(
            xhat1 - y1_hat, y2 - y2_hat
        )

        return loss_n2n
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
# if __name__ == "__main__":
|
|
193
|
+
# import deepinv as dinv
|
|
194
|
+
#
|
|
195
|
+
# sigma = 0.1
|
|
196
|
+
# physics = dinv.physics.Denoising()
|
|
197
|
+
# physics.noise_model = dinv.physics.GaussianNoise(sigma)
|
|
198
|
+
#
|
|
199
|
+
# # choose a reconstruction architecture
|
|
200
|
+
# backbone = dinv.models.MedianFilter()
|
|
201
|
+
# f = dinv.models.ArtifactRemoval(backbone)
|
|
202
|
+
# batch_size = 1
|
|
203
|
+
# imsize = (3, 128, 128)
|
|
204
|
+
#
|
|
205
|
+
# for split_ratio in np.linspace(0.7, 0.99, 10):
|
|
206
|
+
# x = torch.ones((batch_size,) + imsize, device=dinv.device)
|
|
207
|
+
# y = physics(x)
|
|
208
|
+
#
|
|
209
|
+
# # choose training losses
|
|
210
|
+
# loss = SplittingLoss(split_ratio=split_ratio, regular_mask=True)
|
|
211
|
+
# x_net = f(y, physics)
|
|
212
|
+
# mse = dinv.metric.mse()(physics.A(x), physics.A(x_net))
|
|
213
|
+
# split_loss = loss(y, physics, f)
|
|
214
|
+
#
|
|
215
|
+
# print(
|
|
216
|
+
# f"split_ratio:{split_ratio:.2f} mse: {mse:.2e}, split-loss: {split_loss:.2e}"
|
|
217
|
+
# )
|
|
218
|
+
# rel_error = (split_loss - mse).abs() / mse
|
|
219
|
+
# print(f"rel_error: {rel_error:.2f}")
|