deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,123 @@
1
+ import torch
2
+ from deepinv.physics.forward import Physics
3
+ from deepinv.physics.noise import PoissonNoise
4
+
5
+
6
class SinglePhotonLidar(Physics):
    r"""
    Single photon lidar operator for depth ranging.

    See https://ieeexplore.ieee.org/abstract/document/9127841 for a review of this imaging method.

    The forward operator is given by

    .. math::
        y_{i,j,t} = \mathcal{P}(h(t-d_{i,j}) r_{i,j} + b_{i,j})

    where :math:`\mathcal{P}` is the Poisson noise model, :math:`h(t)` is a Gaussian impulse response function at
    time :math:`t`, :math:`d_{i,j}` is the depth of the scene at pixel :math:`(i,j)`,
    :math:`r_{i,j}` is the intensity of the scene at pixel :math:`(i,j)` and :math:`b_{i,j}` is the background noise
    at pixel :math:`(i,j)`.

    For a pixel grid of size (H,W) and batch size B, the signals have size (B, 3, H, W), where the first channel
    contains the depth of the scene :math:`d`, the second channel contains the intensity of the scene :math:`r` and
    the third channel contains the per pixel background noise levels :math:`b`.

    :param float sigma: Standard deviation of the Gaussian impulse response function.
    :param int bins: Number of histogram bins per pixel.
    :param str device: Device to use (gpu or cpu).
    """

    def __init__(self, sigma=1.0, bins=50, device="cpu"):
        super().__init__()

        self.T = bins
        # time-bin indices 0..bins-1
        self.grid = torch.arange(bins).to(device)
        self.sigma = torch.nn.Parameter(
            torch.tensor(sigma, device=device), requires_grad=False
        )
        self.noise_model = PoissonNoise()

        # Gaussian impulse response centred at 3*sigma, truncated to the first
        # int(6*sigma) bins and normalised to unit sum.
        h = ((self.grid - 3 * sigma) / self.sigma).pow(2)
        h = torch.exp(-h / 2.0)
        h = h[: int(6 * sigma)]
        h = h / h.sum()
        self.irf = h.unsqueeze(0).unsqueeze(0)  # (1, 1, K) kernel for conv1d
        self.grid = self.grid.unsqueeze(0).unsqueeze(2).unsqueeze(3)  # (1, T, 1, 1)

    def A(self, x):
        r"""
        Applies the forward operator.

        Input is of size (B, 3, H, W) and output is of size (B, bins, H, W).

        :param torch.tensor x: tensor containing the depth, intensity and background noise levels.
        """
        # Slice with ranges so each map keeps its channel dim, i.e. (B, 1, H, W).
        # This broadcasts correctly against the (1, T, 1, 1) time grid for any
        # batch size B (integer indexing would broadcast B against T).
        depth = x[:, 0:1, :, :]
        intensity = x[:, 1:2, :, :]
        background = x[:, 2:3, :, :]

        h = ((self.grid - depth) / self.sigma).pow(2)
        h = torch.exp(-h / 2.0)
        h = h / h.sum(dim=1, keepdim=True)  # normalise each pixel's response over time
        y = intensity * h + background
        return y

    def A_dagger(self, y):
        r"""
        Applies matched filtering to find the peaks.

        Input is of size (B, bins, H, W), output of size (B, 3, H, W).

        :param torch.tensor y: measurements
        """
        B, T, H, W = y.shape

        # reshape to (B*H*W, 1, T): one histogram per pixel
        y = y.permute(0, 2, 3, 1).reshape(B * H * W, 1, T)

        # Apply irf using convolution (matched filter)
        x = torch.nn.functional.conv1d(y, self.irf, padding="same")

        # Find peak value in each channel -> depth estimate of shape (B*H*W, 1, 1)
        _, x = torch.max(x, dim=-1, keepdim=True)
        x = x.type(torch.float32)
        offset = self.irf.shape[-1] // 2
        # correction for the irf centre (3*sigma) and the conv padding offset
        x -= 3 * self.sigma - offset - 0.5

        # bins within +-4 sigma of the estimated depth are treated as signal
        mask = torch.ones_like(y)
        grid = self.grid.squeeze(-1).squeeze(-1)  # (1, T)

        mask *= (x - 4 * self.sigma) < grid
        mask *= (x + 4 * self.sigma) > grid

        # background: counts outside the window; intensity: remaining counts
        b = (y * (1 - mask)).sum(dim=-1, keepdim=True)
        r = y.sum(dim=-1, keepdim=True) - b
        b /= T  # averaged over all T bins

        x = torch.stack([x, r, b], dim=-1)

        # reshape to (B, 3, H, W)
        x = x.reshape(B, H, W, 3).permute(0, 3, 1, 2)

        return x
101
+
102
+
103
+ # if __name__ == "__main__":
104
+ # import matplotlib.pyplot as plt
105
+ # import deepinv as dinv
106
+ #
107
+ # bins = 40
108
+ # device = "cuda:0"
109
+ # physics = SinglePhotonLidar(bins=bins, device=device)
110
+ #
111
+ # x = torch.ones((1, 3, 2, 4), device=device)
112
+ # x[:, 0, :, :] *= bins / 2
113
+ # x[:, 1, :, :] *= 300
114
+ # x[:, 2, :, :] *= 1
115
+ #
116
+ # y = physics(x)
117
+ # xhat = physics.A_dagger(y)
118
+ #
119
+ # y0 = y[0, :, 0, 0].detach().cpu().numpy()
120
+ # plt.plot(y0)
121
+ # plt.show()
122
+ #
123
+ # print(f"MSE {dinv.utils.cal_mse(x, xhat)}")
deepinv/physics/mri.py ADDED
@@ -0,0 +1,329 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.fft
4
+ from typing import List, Optional
5
+ from deepinv.physics.forward import DecomposablePhysics
6
+
7
+
8
class MRI(DecomposablePhysics):
    r"""
    Single-coil accelerated magnetic resonance imaging.

    The linear operator operates in 2D slices and is defined as

    .. math::

        y = SFx

    where :math:`S` applies a mask (subsampling operator), and :math:`F` is the 2D discrete Fourier Transform.
    This operator has a simple singular value decomposition, so it inherits the structure of
    :meth:`deepinv.physics.DecomposablePhysics` and thus has a fast pseudo-inverse and prox operators.

    The complex images :math:`x` and measurements :math:`y` should be of size (B, 2, H, W) where the first channel
    corresponds to the real part and the second channel corresponds to the imaginary part.

    :param torch.tensor mask: the mask values should be binary.
        The mask size should be of the form (H,W) where H is the image height and W is the image width.
        If ``None``, a mask is sampled with :meth:`sample_mask`.
    :param tuple image_size: image size (H, W) used when sampling a mask.
    :param int acceleration_factor: acceleration factor used when sampling a mask (4 or 8).
    :param torch.device device: cpu or gpu.
    :param int seed: seed for ``numpy.random`` used when sampling the mask.
    """

    def __init__(
        self,
        mask=None,
        image_size=(320, 320),
        acceleration_factor=4,
        device="cpu",
        seed=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.device = device
        self.image_size = image_size
        # Remembered so that reset() resamples with the same acceleration by default.
        self.acceleration_factor = acceleration_factor

        if mask is not None:
            mask = mask.to(device)
        else:
            mask = self.sample_mask(
                image_size=image_size,
                acceleration_factor=acceleration_factor,
                seed=seed,
            )

        self.mask = self._mask_parameter(mask)

    @staticmethod
    def _mask_parameter(mask):
        # (H, W) mask -> frozen (1, 2, H, W) Parameter shared by real and imaginary channels.
        mask = mask.unsqueeze(0).unsqueeze(0)
        return torch.nn.Parameter(torch.cat([mask, mask], dim=1), requires_grad=False)

    def reset(self, **kwargs):
        r"""
        Resets the physics, i.e. re-samples a new mask and new noise realization (if any).

        ``acceleration_factor`` (defaults to the value given at construction) and ``seed``
        are taken from ``kwargs`` for the new mask; all ``kwargs`` are forwarded to the parent reset.
        """
        super().reset(**kwargs)
        mask = self.sample_mask(
            image_size=self.image_size,
            acceleration_factor=kwargs.get(
                "acceleration_factor", self.acceleration_factor
            ),
            seed=kwargs.get("seed"),
        )
        self.mask = self._mask_parameter(mask)

    def V_adjoint(self, x):  # image (B, 2, H, W) -> k-space (B, 2, H, W)
        y = fft2c_new(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return y

    def U(self, x):
        # keep only the sampled k-space entries: (B, 2, H, W) -> (B, N)
        return x[:, self.mask.squeeze(0) > 0]

    def U_adjoint(self, x):
        # scatter the sampled entries back into a zero-filled (B, 2, H, W) tensor
        _, c, h, w = self.mask.shape
        out = torch.zeros((x.shape[0], c, h, w), dtype=x.dtype, device=x.device)
        out[:, self.mask.squeeze(0) > 0] = x
        return out

    def V(self, x):  # k-space (B, 2, H, W) -> image (B, 2, H, W)
        x = x.permute(0, 2, 3, 1)
        return ifft2c_new(x).permute(0, 3, 1, 2)

    def sample_mask(self, image_size=(320, 320), acceleration_factor=4, seed=None):
        r"""
        Create a mask of vertical lines.

        A contiguous band of central columns is always sampled, plus columns chosen at
        random with the global ``numpy.random`` generator.

        :param tuple image_size: image size (H, W).
        :param int acceleration_factor: acceleration factor, 4 or 8.
        :param int seed: random seed passed to ``numpy.random.seed``.
        :raises ValueError: if ``acceleration_factor`` is not 4 or 8.
        :return: mask of size (H, W) with values in {0, 1}.
        """
        # acceleration -> (central fraction, reference fraction); the side-line
        # fraction is reference - central.
        line_fractions = {4: (0.08, 0.25), 8: (0.04, 0.125)}
        if acceleration_factor not in line_fractions:
            raise ValueError(
                f"Unsupported acceleration_factor={acceleration_factor}; "
                f"expected one of {sorted(line_fractions)}."
            )
        if seed is not None:
            np.random.seed(seed)

        central_lines_percent, reference_percent = line_fractions[acceleration_factor]
        width = image_size[-1]  # lines are columns, so count and index along W
        num_lines_center = int(central_lines_percent * width)
        side_lines_percent = reference_percent - central_lines_percent
        num_lines_side = int(side_lines_percent * width)

        mask = torch.zeros(image_size)
        # Every column of the central band, clipped to valid column indices.
        start = max(width // 2 - num_lines_center // 2, 0)
        stop = min(width // 2 + num_lines_center // 2 + 2, width)
        mask[:, torch.arange(start, stop)] = 1
        random_line_indices = np.random.choice(
            width, size=(num_lines_side // 2,), replace=False
        )
        mask[:, random_line_indices] = 1
        return mask.float().to(self.device)
126
+
127
+
128
+ #
129
+ # reference: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/fftc.py
130
def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    r"""
    Apply centered 2 dimensional Fast Fourier Transform.

    :param torch.tensor data: Complex valued input data containing at least 3 dimensions:
        dimensions -3 & -2 are spatial dimensions and dimension -1 has size
        2 (real and imaginary parts). All other dimensions are assumed to be batch dimensions.
    :param str norm: Normalization mode. See ``torch.fft.fft``.
    :raises ValueError: if the last dimension does not have size 2.
    :return: (torch.tensor) the FFT of the input, same layout as ``data``.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")

    data = ifftshift(data, dim=[-3, -2])
    # view_as_complex needs a unit stride on the last dim; a permuted input that
    # was not copied by the shift (e.g. 1x1 spatial size) would otherwise fail.
    data = torch.view_as_real(
        torch.fft.fftn(  # type: ignore
            torch.view_as_complex(data.contiguous()), dim=(-2, -1), norm=norm
        )
    )
    data = fftshift(data, dim=[-3, -2])

    return data
151
+
152
+
153
def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    r"""
    Apply centered 2-dimensional Inverse Fast Fourier Transform.

    :param torch.tensor data: Complex valued input data containing at least 3 dimensions:
        dimensions -3 & -2 are spatial dimensions and dimension -1 has size
        2 (real and imaginary parts). All other dimensions are assumed to be batch dimensions.
    :param str norm: Normalization mode. See ``torch.fft.ifft``.
    :raises ValueError: if the last dimension does not have size 2.
    :return: (torch.tensor) the IFFT of the input, same layout as ``data``.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")

    data = ifftshift(data, dim=[-3, -2])
    # view_as_complex needs a unit stride on the last dim; a permuted input that
    # was not copied by the shift (e.g. 1x1 spatial size) would otherwise fail.
    data = torch.view_as_real(
        torch.fft.ifftn(  # type: ignore
            torch.view_as_complex(data.contiguous()), dim=(-2, -1), norm=norm
        )
    )
    data = fftshift(data, dim=[-3, -2])

    return data
176
+
177
+
178
+ # Helper functions
179
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Circularly shift ``x`` along a single dimension (like ``np.roll`` for one axis).

    :param torch.Tensor x: input tensor.
    :param int shift: number of positions to roll; reduced modulo the dimension size.
    :param int dim: dimension along which to roll.
    :return: (torch.Tensor) the rolled tensor, or ``x`` itself when the effective shift is zero.
    """
    length = x.size(dim)
    offset = shift % length
    if offset == 0:
        return x

    # the last `offset` entries wrap around to the front
    tail = x.narrow(dim, length - offset, offset)
    head = x.narrow(dim, 0, length - offset)
    return torch.cat((tail, head), dim=dim)
197
+
198
+
199
def roll(x: torch.Tensor, shift: List[int], dim: List[int]) -> torch.Tensor:
    """
    Circularly shift a tensor along several dimensions (like ``np.roll`` with tuples).

    :param torch.Tensor x: input tensor.
    :param list shift: roll amount for each entry of ``dim``.
    :param list dim: dimensions to roll, applied in order.
    :raises ValueError: if ``shift`` and ``dim`` have different lengths.
    :return: (torch.Tensor) the rolled tensor.
    """
    if len(shift) != len(dim):
        raise ValueError("len(shift) must match len(dim)")

    rolled = x
    for amount, axis in zip(shift, dim):
        rolled = roll_one_dim(rolled, amount, axis)
    return rolled
216
+
217
+
218
def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Move the zero-frequency entry to the centre, like ``np.fft.fftshift`` for tensors.

    :param torch.Tensor x: input tensor.
    :param list dim: dimensions to shift; every dimension of ``x`` when ``None``.
    :return: (torch.Tensor) the fftshifted tensor.
    """
    if dim is None:
        dim = [axis for axis in range(x.dim())]

    # forward shift by floor(n / 2) along each requested dimension
    shift = [x.shape[axis] // 2 for axis in dim]
    return roll(x, shift, dim)
239
+
240
+
241
def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Undo :func:`fftshift`, like ``np.fft.ifftshift`` for tensors.

    :param torch.Tensor x: input tensor.
    :param list dim: dimensions to shift; every dimension of ``x`` when ``None``.
    :return: (torch.Tensor) the ifftshifted tensor.
    """
    if dim is None:
        dim = [axis for axis in range(x.dim())]

    # forward shift by ceil(n / 2), the inverse of fftshift's floor(n / 2)
    shift = [(x.shape[axis] + 1) // 2 for axis in dim]
    return roll(x, shift, dim)
262
+
263
+
264
+ # if __name__ == "__main__":
265
+ # # deepinv test
266
+ # from deepinv.tests.test_physics import (
267
+ # test_operators_norm,
268
+ # test_operators_adjointness,
269
+ # test_pseudo_inverse,
270
+ # device,
271
+ # )
272
+ # import deepinv as dinv
273
+ # from fastmri.data import subsample
274
+ #
275
+ # imsize = (25, 32)
276
+ # # Create a mask function
277
+ # mask_func = subsample.RandomMaskFunc(center_fractions=[0.08], accelerations=[4])
278
+ # m = mask_func.sample_mask((imsize[1], imsize[0]), offset=None)
279
+ #
280
+ # # mask = torch.ones((imsize[0], 1)) * (m[0] + m[1]).permute(1, 0)
281
+ # mask = torch.ones(imsize)
282
+ # mask[mask > 1] = 1
283
+ #
284
+ # sigma = 0.1
285
+ # # physics = MRI(mask=mask, device=dinv.device)
286
+ # physics = dinv.physics.Denoising()
287
+ # physics.noise_model = dinv.physics.GaussianNoise(sigma)
288
+ #
289
+ # # choose a reconstruction architecture
290
+ # backbone = dinv.models.MedianFilter()
291
+ #
292
+ # class denoiser(torch.nn.Module):
293
+ # def __init__(self):
294
+ # super().__init__()
295
+ #
296
+ # def forward(self, x, sigma=None):
297
+ # return x
298
+ #
299
+ # f = dinv.models.ArtifactRemoval(backbone)
300
+ #
301
+ # batch_size = 1
302
+ #
303
+ # for tau in np.logspace(-5, 3, 1):
304
+ # x = torch.ones((batch_size, 2) + imsize, device=dinv.device)
305
+ # y = physics(x)
306
+ #
307
+ # # choose training losses
308
+ # loss = dinv.loss.SureGaussianLoss(sigma, tau=tau)
309
+ # x_net = f(y, physics)
310
+ # mse = dinv.metric.mse()(physics.A(x), physics.A(x_net))
311
+ # sure = loss(y, x_net, physics, f)
312
+ #
313
+ # print(f"tau:{tau:.2e} mse: {mse:.2e}, sure: {sure:.2e}")
314
+ # rel_error = (sure - mse).abs() / mse
315
+ # print(f"rel_error: {rel_error:.2e}")
316
+ #
317
+ # d = physics.A_adjoint(y)
318
+ # dinv.utils.plot([d.sum(1).unsqueeze(1), x.sum(1).unsqueeze(1)])
319
+ #
320
+ # print("adjoint test....")
321
+ # test_operators_adjointness(
322
+ # "MRI", (2, 320, 320), dinv.device
323
+ # ) # pass, tensor(0., device='cuda:0')
324
+ # print("norm test....")
325
+ # test_operators_norm("MRI", (2, 320, 320), dinv.device) # pass
326
+ # print("pinv test....")
327
+ # test_pseudo_inverse("MRI", (2, 320, 320), dinv.device) # pass
328
+ #
329
+ # print("pass all...")
@@ -0,0 +1,180 @@
1
+ import torch
2
+
3
+
4
class GaussianNoise(torch.nn.Module):
    r"""
    Additive Gaussian noise :math:`y=z+\epsilon` with :math:`\epsilon\sim \mathcal{N}(0,I\sigma^2)`.

    Attach it to a physics operator when constructing the operator, or assign it to the
    operator's ``noise_model`` attribute.

    ::

        >>> from deepinv.physics import Denoising, GaussianNoise
        >>> import torch
        >>> physics = Denoising()
        >>> physics.noise_model = GaussianNoise()
        >>> x = torch.rand(1, 1, 2, 2)
        >>> y = physics(x)

    :param float sigma: Standard deviation of the noise.
    """

    def __init__(self, sigma=0.1):
        super().__init__()
        std = torch.tensor(sigma)
        # frozen parameter: moves with the module, never trained
        self.sigma = torch.nn.Parameter(std, requires_grad=False)

    def forward(self, x):
        r"""
        Corrupt the measurements ``x`` with Gaussian noise.

        :param torch.Tensor x: measurements
        :returns: noisy measurements
        """
        perturbation = self.sigma * torch.randn_like(x)
        return x + perturbation
38
+
39
+
40
class UniformGaussianNoise(torch.nn.Module):
    r"""
    Gaussian noise :math:`y=z+\epsilon` where
    :math:`\epsilon\sim \mathcal{N}(0,I\sigma^2)` and
    :math:`\sigma \sim\mathcal{U}(\sigma_{\text{min}}, \sigma_{\text{max}})`

    It can be added to a physics operator in its construction or by setting:

    ::

        >>> physics.noise_model = UniformGaussianNoise()


    :param float sigma_min: minimum standard deviation of the noise.
    :param float sigma_max: maximum standard deviation of the noise.
    :param float, torch.Tensor sigma: standard deviation of the noise.
        If ``None``, a new noise level is sampled uniformly at random (one per batch element)
        in :math:`[\sigma_{\text{min}}, \sigma_{\text{max}}]` at every forward pass, and the most
        recently sampled value is stored in ``self.sigma``. Default: ``None``.

    """

    def __init__(self, sigma_min=0.0, sigma_max=0.5, sigma=None):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma = sigma
        # Last sigma drawn by forward(). It lets forward() distinguish a value it
        # sampled itself (redrawn on the next call) from a user-fixed one (kept).
        self._sampled_sigma = None

    def forward(self, x):
        r"""
        Adds the noise to measurements x.

        :param torch.Tensor x: measurements
        :returns: noisy measurements.
        """
        # Previously the first sample was stored in self.sigma and reused forever
        # (and broke on a different batch size); now it is redrawn each call.
        if self.sigma is None or self.sigma is self._sampled_sigma:
            # one noise level per batch element, broadcastable over remaining dims
            sigma = (
                torch.rand((x.shape[0], 1) + (1,) * (x.dim() - 2))
                * (self.sigma_max - self.sigma_min)
                + self.sigma_min
            )
            self.sigma = sigma.to(x.device)
            self._sampled_sigma = self.sigma
        noise = torch.randn_like(x) * self.sigma
        return x + noise
83
+
84
+
85
class PoissonNoise(torch.nn.Module):
    r"""

    Poisson noise :math:`y = \mathcal{P}(\frac{x}{\gamma})`
    with gain :math:`\gamma>0`.

    If ``normalize=True``, the output is multiplied by the gain, i.e., :math:`\tilde{y} = \gamma y`,
    which puts it back on the scale of the input :math:`x`.

    It can be added to a physics operator in its construction or by setting:
    ::

        >>> physics.noise_model = PoissonNoise()

    :param float gain: gain of the noise.
    :param bool normalize: multiply the output by the gain.

    """

    def __init__(self, gain=1.0, normalize=True):
        super().__init__()
        # stored as frozen Parameters so they follow the module and are never trained
        self.normalize = torch.nn.Parameter(
            torch.tensor(normalize), requires_grad=False
        )
        self.gain = torch.nn.Parameter(torch.tensor(gain), requires_grad=False)

    def forward(self, x):
        r"""
        Adds the noise to measurements x

        :param torch.Tensor x: measurements (used as Poisson rates after division by the gain)
        :returns: noisy measurements
        """
        y = torch.poisson(x / self.gain)
        if self.normalize:
            y *= self.gain
        return y
121
+
122
+
123
class PoissonGaussianNoise(torch.nn.Module):
    r"""
    Mixed Poisson-Gaussian noise :math:`y = \gamma z + \epsilon` with
    :math:`z\sim\mathcal{P}(\frac{x}{\gamma})` and :math:`\epsilon\sim\mathcal{N}(0, I \sigma^2)`.

    Attach it to a physics operator by assigning it to ``noise_model``:

    ::

        >>> physics.noise_model = PoissonGaussianNoise()

    :param float gain: gain of the noise.
    :param float sigma: Standard deviation of the noise.

    """

    def __init__(self, gain=1.0, sigma=0.1):
        super().__init__()
        gain_t = torch.tensor(gain)
        sigma_t = torch.tensor(sigma)
        self.gain = torch.nn.Parameter(gain_t, requires_grad=False)
        self.sigma = torch.nn.Parameter(sigma_t, requires_grad=False)

    def forward(self, x):
        r"""
        Corrupt the measurements ``x`` with scaled Poisson counts plus Gaussian noise.

        :param torch.Tensor x: measurements
        :returns: noisy measurements
        """
        # Poisson draw first, then the Gaussian draw (RNG consumption order matters)
        shot = torch.poisson(x / self.gain) * self.gain
        read = torch.randn_like(x) * self.sigma
        return shot + read
155
+
156
+
157
class UniformNoise(torch.nn.Module):
    r"""
    Additive uniform noise :math:`y = x + \epsilon` with :math:`\epsilon\sim\mathcal{U}(-a,a)`.

    Attach it to a physics operator by assigning it to ``noise_model``:
    ::

        >>> physics.noise_model = UniformNoise()

    :param float a: amplitude of the noise.
    """

    def __init__(self, a=0.1):
        super().__init__()
        amplitude = torch.tensor(a)
        self.a = torch.nn.Parameter(amplitude, requires_grad=False)

    def forward(self, x):
        r"""
        Corrupt the measurements ``x`` with uniform noise in ``[-a, a)``.

        :param torch.Tensor x: measurements
        :returns: noisy measurements
        """
        centered = torch.rand_like(x) - 0.5  # uniform on [-0.5, 0.5)
        return x + centered * 2 * self.a
@@ -0,0 +1,53 @@
1
+ import torch
2
+ from deepinv.physics.forward import DecomposablePhysics
3
+
4
+
5
class Decolorize(DecomposablePhysics):
    r"""
    Grayscale conversion of RGB images.

    Uses the `rec601 <https://en.wikipedia.org/wiki/Rec._601>`_ luma weights
    (0.2989, 0.5870, 0.1140) for the red, green and blue channels.

    Input signals must have 3 colour (RGB) channels, i.e. shape [*,3,*,*];
    measurements are single-channel grayscale images of shape [*,1,*,*].

    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # unit singular values: the weighting is carried entirely by V / V_adjoint
        self.mask = 1.0

    def V_adjoint(self, x):
        # weighted sum over the channel dimension, keeping it as a size-1 dim
        red, green, blue = x[:, 0, :, :], x[:, 1, :, :], x[:, 2, :, :]
        gray = red * 0.2989 + green * 0.5870 + blue * 0.1140
        return gray.unsqueeze(1)

    def V(self, y):
        # NOTE(review): V is scaled by the luma weights, so it is not unit-norm
        # (sum of squared weights ~= 0.447). If DecomposablePhysics assumes an
        # orthonormal V as in an SVD, A_dagger / prox may be off by that factor — confirm.
        channels = [y * 0.2989, y * 0.5870, y * 0.1140]
        return torch.cat(channels, dim=1)
26
+
27
+
28
+ # # test code
29
+ # if __name__ == "__main__":
30
+ # device = "cuda:0"
31
+ #
32
+ # import deepinv as dinv
33
+ # import matplotlib.pyplot as plt
34
+ # import torchvision
35
+ #
36
+ # dinv.device = "cpu"
37
+ #
38
+ # x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
39
+ # x = x.unsqueeze(0).float().to(dinv.device) / 255
40
+ # x = torchvision.transforms.Resize((128, 128))(x)
41
+ #
42
+ # physics = Decolorize()
43
+ #
44
+ # y = physics(x)
45
+ #
46
+ # print(physics.adjointness_test(x))
47
+ # print(physics.compute_norm(x))
48
+ # xhat = physics.A_adjoint(y)
49
+ #
50
+ # plot_results = False # set to True to plot results
51
+ #
52
+ # if plot_results:
53
+ # dinv.utils.plot([x, xhat, y])