deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/physics/lidar.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from deepinv.physics.forward import Physics
|
|
3
|
+
from deepinv.physics.noise import PoissonNoise
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SinglePhotonLidar(Physics):
    r"""
    Single photon lidar operator for depth ranging.

    See https://ieeexplore.ieee.org/abstract/document/9127841 for a review of this imaging method.

    The forward operator is given by

    .. math::
        y_{i,j,t} = \mathcal{P}(h(t-d_{i,j}) r_{i,j} + b_{i,j})

    where :math:`\mathcal{P}` is the Poisson noise model, :math:`h(t)` is a Gaussian impulse response function at
    time :math:`t`, :math:`d_{i,j}` is the depth of the scene at pixel :math:`(i,j)`,
    :math:`r_{i,j}` is the intensity of the scene at pixel :math:`(i,j)` and :math:`b_{i,j}` is the background noise
    at pixel :math:`(i,j)`.

    For a pixel grid of size (H,W) and batch size B, the signals have size (B, 3, H, W), where the first channel
    contains the depth of the scene :math:`d`, the second channel contains the intensity of the scene :math:`r` and
    the third channel contains the per pixel background noise levels :math:`b`.

    :param float sigma: Standard deviation of the Gaussian impulse response function.
    :param int bins: Number of histogram bins per pixel.
    :param str device: Device to use (gpu or cpu).
    """

    def __init__(self, sigma=1.0, bins=50, device="cpu"):
        super().__init__()

        self.T = bins
        # 1D vector of time-bin indices 0..bins-1; equivalent to the previous
        # torch.meshgrid(torch.arange(bins), indexing="ij")[0], without the
        # `indexing` keyword that requires a recent torch version.
        self.grid = torch.arange(bins).to(device)
        self.sigma = torch.nn.Parameter(
            torch.tensor(sigma, device=device), requires_grad=False
        )
        self.noise_model = PoissonNoise()

        # Discretized Gaussian impulse response centered at 3*sigma,
        # truncated to the first 6*sigma bins and normalized to unit mass.
        h = ((self.grid - 3 * sigma) / self.sigma).pow(2)
        h = torch.exp(-h / 2.0)
        h = h[: int(6 * sigma)]
        h = h / h.sum()
        self.irf = h.unsqueeze(0).unsqueeze(0)  # (1, 1, K) kernel for conv1d
        # reshape bin grid to (1, T, 1, 1) so it broadcasts over (B, 1, H, W)
        self.grid = self.grid.unsqueeze(0).unsqueeze(2).unsqueeze(3)

    def A(self, x):
        r"""
        Applies the forward operator.

        Input is of size (B, 3, H, W) and output is of size (B, bins, H, W)

        :param torch.tensor x: tensor containing the depth, intensity and background noise levels.
        """
        # Keep a singleton bin dimension so the (1, T, 1, 1) grid broadcasts
        # against (B, 1, H, W) for any batch size. (The previous code dropped
        # the channel dim, which broke broadcasting whenever B > 1.)
        d = x[:, 0, :, :].unsqueeze(1)  # depth
        r = x[:, 1, :, :].unsqueeze(1)  # intensity
        b = x[:, 2, :, :].unsqueeze(1)  # background level

        h = ((self.grid - d) / self.sigma).pow(2)
        h = torch.exp(-h / 2.0)
        h = h / h.sum(dim=1, keepdim=True)  # per-pixel normalized Gaussian
        y = r * h + b
        return y

    def A_dagger(self, y):
        r"""
        Applies Matched filtering to find the peaks.

        Input is of size (B, bins, H, W), output of size (B, 3, H, W).

        :param torch.tensor y: measurements
        """
        B, T, H, W = y.shape

        # reshape to (B*H*W, 1, T) so each pixel histogram is a 1D signal
        y = y.permute(0, 2, 3, 1).reshape(B * H * W, 1, T)

        # Matched filter: correlate each histogram with the impulse response
        x = torch.nn.functional.conv1d(y, self.irf, padding="same")

        # Peak location of the filtered histogram gives the depth estimate
        _, x = torch.max(x, dim=-1, keepdim=True)
        x = x.type(torch.float32)
        # compensate for the IRF center (3*sigma) and the conv padding offset
        offset = self.irf.shape[-1] // 2
        x -= 3 * self.sigma - offset - 0.5

        # Bins within +-4 sigma of the detected peak are treated as signal,
        # everything outside as background.
        mask = torch.ones_like(y)
        grid = self.grid.squeeze(-1).squeeze(-1)  # (1, T)

        mask *= (x - 4 * self.sigma) < grid
        mask *= (x + 4 * self.sigma) > grid

        b = (y * (1 - mask)).sum(dim=-1, keepdim=True)  # background counts
        r = y.sum(dim=-1, keepdim=True) - b  # signal counts (intensity)
        b /= T  # average background per bin

        x = torch.stack([x, r, b], dim=-1)

        # reshape to (B, 3, H, W)
        x = x.reshape(B, H, W, 3).permute(0, 3, 1, 2)

        return x
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# if __name__ == "__main__":
|
|
104
|
+
# import matplotlib.pyplot as plt
|
|
105
|
+
# import deepinv as dinv
|
|
106
|
+
#
|
|
107
|
+
# bins = 40
|
|
108
|
+
# device = "cuda:0"
|
|
109
|
+
# physics = SinglePhotonLidar(bins=bins, device=device)
|
|
110
|
+
#
|
|
111
|
+
# x = torch.ones((1, 3, 2, 4), device=device)
|
|
112
|
+
# x[:, 0, :, :] *= bins / 2
|
|
113
|
+
# x[:, 1, :, :] *= 300
|
|
114
|
+
# x[:, 2, :, :] *= 1
|
|
115
|
+
#
|
|
116
|
+
# y = physics(x)
|
|
117
|
+
# xhat = physics.A_dagger(y)
|
|
118
|
+
#
|
|
119
|
+
# y0 = y[0, :, 0, 0].detach().cpu().numpy()
|
|
120
|
+
# plt.plot(y0)
|
|
121
|
+
# plt.show()
|
|
122
|
+
#
|
|
123
|
+
# print(f"MSE {dinv.utils.cal_mse(x, xhat)}")
|
deepinv/physics/mri.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.fft
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
from deepinv.physics.forward import DecomposablePhysics
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MRI(DecomposablePhysics):
    r"""
    Single-coil accelerated magnetic resonance imaging.

    The linear operator operates in 2D slices and is defined as

    .. math::

        y = SFx

    where :math:`S` applies a mask (subsampling operator), and :math:`F` is the 2D discrete Fourier Transform.
    This operator has a simple singular value decomposition, so it inherits the structure of
    :meth:`deepinv.physics.DecomposablePhysics` and thus has a fast pseudo-inverse and prox operators.

    The complex images :math:`x` and measurements :math:`y` should be of size (B, 2, H, W) where the first channel corresponds to the real part
    and the second channel corresponds to the imaginary part.

    :param torch.tensor mask: the mask values should be binary.
        The mask size should be of the form (H,W) where H is the image height and W is the image width.
    :param tuple image_size: image size used to sample a mask when ``mask`` is ``None``.
    :param int acceleration_factor: acceleration factor (4 or 8) used when sampling a mask.
    :param torch.device device: cpu or gpu.
    :param int seed: random seed used for mask generation.
    """

    def __init__(
        self,
        mask=None,
        image_size=(320, 320),
        acceleration_factor=4,
        device="cpu",
        seed=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.device = device
        self.image_size = image_size

        if mask is not None:
            mask = mask.to(device).unsqueeze(0).unsqueeze(0)
        else:
            mask = (
                self.sample_mask(
                    image_size=image_size,
                    acceleration_factor=acceleration_factor,
                    seed=seed,
                )
                .unsqueeze(0)
                .unsqueeze(0)
            )

        # duplicate the mask over the two (real, imaginary) channels
        self.mask = torch.nn.Parameter(
            torch.cat([mask, mask], dim=1), requires_grad=False
        )

    def reset(self, **kwargs):
        r"""
        Resets the physics, i.e. re-samples a new mask and new noise realization (if any).
        """
        super().reset(**kwargs)
        mask = (
            self.sample_mask(image_size=self.image_size, **kwargs)
            .unsqueeze(0)
            .unsqueeze(0)
        )

        self.mask = torch.nn.Parameter(
            torch.cat([mask, mask], dim=1), requires_grad=False
        )

    def V_adjoint(self, x):  # (B, 2, H, W) -> (B, 2, H, W)
        # forward 2D FFT; channels are moved last to match fft2c_new's layout
        y = fft2c_new(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return y

    def U(self, x):
        # subsampling: keep only the Fourier coefficients selected by the mask
        return x[:, self.mask.squeeze(0) > 0]

    def U_adjoint(self, x):
        # zero-fill: scatter the measurements back into the masked positions
        _, c, h, w = self.mask.shape
        out = torch.zeros((x.shape[0], c, h, w), device=x.device)
        out[:, self.mask.squeeze(0) > 0] = x
        return out

    def V(self, x):  # (B, 2, H, W) -> (B, 2, H, W)
        # inverse 2D FFT, inverse of V_adjoint
        x = x.permute(0, 2, 3, 1)
        return ifft2c_new(x).permute(0, 3, 1, 2)

    def sample_mask(self, image_size=(320, 320), acceleration_factor=4, seed=None):
        r"""
        Create a mask of vertical lines.

        :param tuple image_size: image size.
        :param int acceleration_factor: acceleration factor (only 4 and 8 are supported).
        :param int seed: random seed.
        :return: mask of size (H, W) with values in {0, 1}.
        :raises ValueError: if ``acceleration_factor`` is not 4 or 8.
        """
        if seed is not None:
            np.random.seed(seed)
        if acceleration_factor == 4:
            central_lines_percent = 0.08
            side_lines_percent = 0.25 - central_lines_percent
        elif acceleration_factor == 8:
            central_lines_percent = 0.04
            side_lines_percent = 0.125 - central_lines_percent
        else:
            # Previously any other factor fell through to an undefined-name
            # crash further down; fail fast with a clear message instead.
            raise ValueError(
                f"Unsupported acceleration_factor {acceleration_factor}, expected 4 or 8."
            )
        num_lines_center = int(central_lines_percent * image_size[-1])
        num_lines_side = int(side_lines_percent * image_size[-1])

        mask = torch.zeros(image_size)
        # NOTE(review): line counts are computed from image_size[-1] (width)
        # while the indices below are centered on / drawn from image_size[0]
        # (height); these only agree for square images — confirm intent.
        center_line_indices = torch.linspace(
            image_size[0] // 2 - num_lines_center // 2,
            image_size[0] // 2 + num_lines_center // 2 + 1,
            steps=50,  # fixed step count; duplicate integer indices are harmless
            dtype=torch.long,
        )
        mask[:, center_line_indices] = 1
        random_line_indices = np.random.choice(
            image_size[0], size=(num_lines_side // 2,), replace=False
        )
        mask[:, random_line_indices] = 1
        return mask.float().to(self.device)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
#
|
|
129
|
+
# reference: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/fftc.py
|
|
130
|
+
def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    r"""
    Apply a centered 2 dimensional Fast Fourier Transform.

    :param torch.tensor data: Complex valued input data containing at least 3 dimensions:
        dimensions -2 & -1 are spatial dimensions and dimension -3 has size
        2 (real/imaginary parts). All other dimensions are assumed to be batch dimensions.
    :param str norm: Normalization mode. See ``torch.fft.fft``.
    :return: (torch.tensor) the FFT of the input.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")

    # center the signal, transform, then re-center the spectrum
    shifted = ifftshift(data, dim=[-3, -2])
    spectrum = torch.fft.fftn(  # type: ignore
        torch.view_as_complex(shifted), dim=(-2, -1), norm=norm
    )
    return fftshift(torch.view_as_real(spectrum), dim=[-3, -2])
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    r"""
    Apply centered 2-dimensional Inverse Fast Fourier Transform.

    :param torch.tensor data: Complex valued input data containing at least 3 dimensions:
        dimensions -2 & -1 are spatial dimensions and dimension -3 has size
        2 (real/imaginary parts). All other dimensions are assumed to be batch dimensions.
    :param str norm: Normalization mode. See ``torch.fft.ifft``.
    :return: (torch.tensor) the IFFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")

    # center the spectrum, inverse-transform, then re-center the signal
    data = ifftshift(data, dim=[-3, -2])
    data = torch.view_as_real(
        torch.fft.ifftn(  # type: ignore
            torch.view_as_complex(data), dim=(-2, -1), norm=norm
        )
    )
    data = fftshift(data, dim=[-3, -2])

    return data
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# Helper functions
|
|
179
|
+
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Roll a tensor along a single dimension.

    :param torch.Tensor x: input tensor.
    :param int shift: amount to roll by (may be negative or larger than the dim).
    :param int dim: dimension along which to roll.
    :return: rolled version of ``x``.
    """
    size = x.size(dim)
    shift %= size
    if not shift:
        return x

    # move the trailing `shift` slice in front of the remaining elements
    tail = x.narrow(dim, size - shift, shift)
    head = x.narrow(dim, 0, size - shift)
    return torch.cat((tail, head), dim=dim)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def roll(x: torch.Tensor, shift: List[int], dim: List[int]) -> torch.Tensor:
    """
    Roll a tensor along several dimensions (torch counterpart of ``np.roll``).

    :param torch.Tensor x: input tensor.
    :param shift: roll amount for each dimension in ``dim``.
    :param dim: dimensions to roll, matched element-wise with ``shift``.
    :return: rolled version of ``x``.
    :raises ValueError: if ``shift`` and ``dim`` have different lengths.
    """
    if len(shift) != len(dim):
        raise ValueError("len(shift) must match len(dim)")

    # apply the single-dimension roll once per (amount, axis) pair
    for amount, axis in zip(shift, dim):
        x = roll_one_dim(x, amount, axis)
    return x
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors

    :param torch.Tensor x: input tensor.
    :param dim: dimensions to fftshift; all dimensions when ``None``.
    :return: fftshifted version of ``x``.
    """
    if dim is None:
        # this weird code (instead of list(range(...))) is necessary for
        # torch.jit.script typing
        dim = [0] * (x.dim())
        for i in range(1, x.dim()):
            dim[i] = i

    # also necessary for torch.jit.script; each dim is shifted by half its size
    shift = [0] * len(dim)
    for i, dim_num in enumerate(dim):
        shift[i] = x.shape[dim_num] // 2

    return roll(x, shift, dim)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors

    :param torch.Tensor x: input tensor.
    :param dim: dimensions to ifftshift; all dimensions when ``None``.
    :return: ifftshifted version of ``x``.
    """
    if dim is None:
        # this weird code (instead of list(range(...))) is necessary for
        # torch.jit.script typing
        dim = [0] * (x.dim())
        for i in range(1, x.dim()):
            dim[i] = i

    # also necessary for torch.jit.script; rounding up makes this the exact
    # inverse of fftshift for odd-sized dimensions
    shift = [0] * len(dim)
    for i, dim_num in enumerate(dim):
        shift[i] = (x.shape[dim_num] + 1) // 2

    return roll(x, shift, dim)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
# if __name__ == "__main__":
|
|
265
|
+
# # deepinv test
|
|
266
|
+
# from deepinv.tests.test_physics import (
|
|
267
|
+
# test_operators_norm,
|
|
268
|
+
# test_operators_adjointness,
|
|
269
|
+
# test_pseudo_inverse,
|
|
270
|
+
# device,
|
|
271
|
+
# )
|
|
272
|
+
# import deepinv as dinv
|
|
273
|
+
# from fastmri.data import subsample
|
|
274
|
+
#
|
|
275
|
+
# imsize = (25, 32)
|
|
276
|
+
# # Create a mask function
|
|
277
|
+
# mask_func = subsample.RandomMaskFunc(center_fractions=[0.08], accelerations=[4])
|
|
278
|
+
# m = mask_func.sample_mask((imsize[1], imsize[0]), offset=None)
|
|
279
|
+
#
|
|
280
|
+
# # mask = torch.ones((imsize[0], 1)) * (m[0] + m[1]).permute(1, 0)
|
|
281
|
+
# mask = torch.ones(imsize)
|
|
282
|
+
# mask[mask > 1] = 1
|
|
283
|
+
#
|
|
284
|
+
# sigma = 0.1
|
|
285
|
+
# # physics = MRI(mask=mask, device=dinv.device)
|
|
286
|
+
# physics = dinv.physics.Denoising()
|
|
287
|
+
# physics.noise_model = dinv.physics.GaussianNoise(sigma)
|
|
288
|
+
#
|
|
289
|
+
# # choose a reconstruction architecture
|
|
290
|
+
# backbone = dinv.models.MedianFilter()
|
|
291
|
+
#
|
|
292
|
+
# class denoiser(torch.nn.Module):
|
|
293
|
+
# def __init__(self):
|
|
294
|
+
# super().__init__()
|
|
295
|
+
#
|
|
296
|
+
# def forward(self, x, sigma=None):
|
|
297
|
+
# return x
|
|
298
|
+
#
|
|
299
|
+
# f = dinv.models.ArtifactRemoval(backbone)
|
|
300
|
+
#
|
|
301
|
+
# batch_size = 1
|
|
302
|
+
#
|
|
303
|
+
# for tau in np.logspace(-5, 3, 1):
|
|
304
|
+
# x = torch.ones((batch_size, 2) + imsize, device=dinv.device)
|
|
305
|
+
# y = physics(x)
|
|
306
|
+
#
|
|
307
|
+
# # choose training losses
|
|
308
|
+
# loss = dinv.loss.SureGaussianLoss(sigma, tau=tau)
|
|
309
|
+
# x_net = f(y, physics)
|
|
310
|
+
# mse = dinv.metric.mse()(physics.A(x), physics.A(x_net))
|
|
311
|
+
# sure = loss(y, x_net, physics, f)
|
|
312
|
+
#
|
|
313
|
+
# print(f"tau:{tau:.2e} mse: {mse:.2e}, sure: {sure:.2e}")
|
|
314
|
+
# rel_error = (sure - mse).abs() / mse
|
|
315
|
+
# print(f"rel_error: {rel_error:.2e}")
|
|
316
|
+
#
|
|
317
|
+
# d = physics.A_adjoint(y)
|
|
318
|
+
# dinv.utils.plot([d.sum(1).unsqueeze(1), x.sum(1).unsqueeze(1)])
|
|
319
|
+
#
|
|
320
|
+
# print("adjoint test....")
|
|
321
|
+
# test_operators_adjointness(
|
|
322
|
+
# "MRI", (2, 320, 320), dinv.device
|
|
323
|
+
# ) # pass, tensor(0., device='cuda:0')
|
|
324
|
+
# print("norm test....")
|
|
325
|
+
# test_operators_norm("MRI", (2, 320, 320), dinv.device) # pass
|
|
326
|
+
# print("pinv test....")
|
|
327
|
+
# test_pseudo_inverse("MRI", (2, 320, 320), dinv.device) # pass
|
|
328
|
+
#
|
|
329
|
+
# print("pass all...")
|
deepinv/physics/noise.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class GaussianNoise(torch.nn.Module):
    r"""
    Gaussian noise :math:`y=z+\epsilon` where :math:`\epsilon\sim \mathcal{N}(0,I\sigma^2)`.

    It can be added to a physics operator in its construction or by setting the ``noise_model``
    attribute of the physics operator.


    ::

        >>> from deepinv.physics import Denoising, GaussianNoise
        >>> import torch
        >>> physics = Denoising()
        >>> physics.noise_model = GaussianNoise()
        >>> x = torch.rand(1, 1, 2, 2)
        >>> y = physics(x)

    :param float sigma: Standard deviation of the noise.

    """

    def __init__(self, sigma=0.1):
        super().__init__()
        # frozen parameter so sigma follows .to(device) and state_dict
        self.sigma = torch.nn.Parameter(torch.tensor(sigma), requires_grad=False)

    def forward(self, x):
        r"""
        Adds Gaussian noise to the measurements x.

        :param torch.Tensor x: measurements
        :returns: noisy measurements
        """
        noise = torch.randn_like(x) * self.sigma
        return x + noise
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class UniformGaussianNoise(torch.nn.Module):
    r"""
    Gaussian noise :math:`y=z+\epsilon` where
    :math:`\epsilon\sim \mathcal{N}(0,I\sigma^2)` and
    :math:`\sigma \sim\mathcal{U}(\sigma_{\text{min}}, \sigma_{\text{max}})`

    It can be added to a physics operator in its construction or by setting:

    ::

        >>> physics.noise_model = UniformGaussianNoise()


    :param float sigma_min: minimum standard deviation of the noise.
    :param float sigma_max: maximum standard deviation of the noise.
    :param float, torch.Tensor sigma: standard deviation of the noise.
        If ``None``, the noise level is re-sampled uniformly at random
        in :math:`[\sigma_{\text{min}}, \sigma_{\text{max}}]` on every forward pass. Default: ``None``.

    """

    def __init__(self, sigma_min=0.0, sigma_max=0.5, sigma=None):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma = sigma  # fixed noise level; None means "sample per call"

    def forward(self, x):
        r"""
        Adds the noise to measurements x.

        :param torch.Tensor x: measurements
        :returns: noisy measurements.
        """
        if self.sigma is None:
            # Draw one sigma per batch element, broadcastable over the
            # remaining dims. Kept local (NOT stored on self) so every call
            # re-samples, as documented — the previous code cached the first
            # draw in self.sigma, freezing the noise level for all later calls.
            sigma = (
                torch.rand((x.shape[0], 1) + (1,) * (x.dim() - 2), device=x.device)
                * (self.sigma_max - self.sigma_min)
                + self.sigma_min
            )
        else:
            sigma = self.sigma
        noise = torch.randn_like(x) * sigma
        return x + noise
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class PoissonNoise(torch.nn.Module):
    r"""
    Poisson noise :math:`y = \mathcal{P}(\frac{x}{\gamma})`
    with gain :math:`\gamma>0`.

    If ``normalize=True``, the output is rescaled by the gain, i.e., :math:`\tilde{y} = \gamma y`.

    It can be added to a physics operator in its construction or by setting:
    ::

        >>> physics.noise_model = PoissonNoise()

    :param float gain: gain of the noise.
    :param bool normalize: rescale the output by the gain.

    """

    def __init__(self, gain=1.0, normalize=True):
        super().__init__()
        # frozen parameters so they follow .to(device) and serialization
        self.normalize = torch.nn.Parameter(
            torch.tensor(normalize), requires_grad=False
        )
        self.gain = torch.nn.Parameter(torch.tensor(gain), requires_grad=False)

    def forward(self, x):
        r"""
        Adds Poisson noise to the measurements x.

        :param torch.Tensor x: measurements
        :returns: noisy measurements
        """
        counts = torch.poisson(x / self.gain)
        return counts * self.gain if self.normalize else counts
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class PoissonGaussianNoise(torch.nn.Module):
    r"""
    Poisson-Gaussian noise :math:`y = \gamma z + \epsilon` where :math:`z\sim\mathcal{P}(\frac{x}{\gamma})`
    and :math:`\epsilon\sim\mathcal{N}(0, I \sigma^2)`.

    It can be added to a physics operator by setting

    ::

        >>> physics.noise_model = PoissonGaussianNoise()

    :param float gain: gain of the noise.
    :param float sigma: Standard deviation of the noise.

    """

    def __init__(self, gain=1.0, sigma=0.1):
        super().__init__()
        # frozen parameters so they follow .to(device) and serialization
        self.gain = torch.nn.Parameter(torch.tensor(gain), requires_grad=False)
        self.sigma = torch.nn.Parameter(torch.tensor(sigma), requires_grad=False)

    def forward(self, x):
        r"""
        Adds Poisson-Gaussian noise to the measurements x.

        :param torch.Tensor x: measurements
        :returns: noisy measurements
        """
        # Poisson component, rescaled back to the input range
        counts = torch.poisson(x / self.gain) * self.gain
        # additive Gaussian component
        return counts + torch.randn_like(x) * self.sigma
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class UniformNoise(torch.nn.Module):
    r"""
    Uniform noise :math:`y = x + \epsilon` where :math:`\epsilon\sim\mathcal{U}(-a,a)`.

    It can be added to a physics operator by setting
    ::

        >>> physics.noise_model = UniformNoise()

    :param float a: amplitude of the noise.
    """

    def __init__(self, a=0.1):
        super().__init__()
        # frozen parameter so the amplitude follows .to(device) and state_dict
        self.a = torch.nn.Parameter(torch.tensor(a), requires_grad=False)

    def forward(self, x):
        r"""
        Adds uniform noise to the measurements x.

        :param torch.Tensor x: measurements
        :returns: noisy measurements
        """
        # rescale U(0, 1) samples to U(-a, a)
        eps = (torch.rand_like(x) - 0.5) * 2 * self.a
        return x + eps
|
deepinv/physics/range.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from deepinv.physics.forward import DecomposablePhysics
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Decolorize(DecomposablePhysics):
    r"""
    Converts RGB images to grayscale.

    Follows the `rec601 <https://en.wikipedia.org/wiki/Rec._601>`_ convention.

    Signals must be tensors with 3 colour (RGB) channels, i.e. [*,3,*,*].
    The measurements are grayscale images.

    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mask = 1.0  # identity singular values

    def V_adjoint(self, x):
        # rec601 luma: weighted sum of the colour channels -> one gray channel
        red, green, blue = x[:, 0, :, :], x[:, 1, :, :], x[:, 2, :, :]
        gray = red * 0.2989 + green * 0.5870 + blue * 0.1140
        return gray.unsqueeze(1)

    def V(self, y):
        # spread the gray channel back over R, G, B with the same weights
        return torch.cat([y * 0.2989, y * 0.5870, y * 0.1140], dim=1)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# # test code
|
|
29
|
+
# if __name__ == "__main__":
|
|
30
|
+
# device = "cuda:0"
|
|
31
|
+
#
|
|
32
|
+
# import deepinv as dinv
|
|
33
|
+
# import matplotlib.pyplot as plt
|
|
34
|
+
# import torchvision
|
|
35
|
+
#
|
|
36
|
+
# dinv.device = "cpu"
|
|
37
|
+
#
|
|
38
|
+
# x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
|
|
39
|
+
# x = x.unsqueeze(0).float().to(dinv.device) / 255
|
|
40
|
+
# x = torchvision.transforms.Resize((128, 128))(x)
|
|
41
|
+
#
|
|
42
|
+
# physics = Decolorize()
|
|
43
|
+
#
|
|
44
|
+
# y = physics(x)
|
|
45
|
+
#
|
|
46
|
+
# print(physics.adjointness_test(x))
|
|
47
|
+
# print(physics.compute_norm(x))
|
|
48
|
+
# xhat = physics.A_adjoint(y)
|
|
49
|
+
#
|
|
50
|
+
# plot_results = False # set to True to plot results
|
|
51
|
+
#
|
|
52
|
+
# if plot_results:
|
|
53
|
+
# dinv.utils.plot([x, xhat, y])
|