deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/utils/nn.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch
2
+ import os
3
+ import numpy as np
4
+
5
+
6
class TensorList:
    r"""

    Represents a list of :class:`torch.Tensor` with different shapes.
    It allows to sum, flatten, append, etc. lists of tensors seamlessly, in a
    similar fashion to :class:`torch.Tensor`.

    Arithmetic (``+``, ``-``, ``*``, ``/``) is applied entry-wise. When the other
    operand is a TensorList or a list, it must have the same length and entries
    are combined pairwise. Any other operand (e.g. a scalar or a single
    :class:`torch.Tensor`) is combined with every entry. Reflected forms such as
    ``2 * x``, ``1 - x`` or ``sum([x1, x2])`` are also supported.

    :param x: a list of :class:`torch.Tensor`, a single :class:`torch.Tensor` or a TensorList.
    """

    def __init__(self, x):
        super().__init__()

        if isinstance(x, (list, TensorList)):
            self.x = list(x)
        elif isinstance(x, torch.Tensor):
            self.x = [x]
        else:
            raise TypeError("x must be a list of torch.Tensor or a single torch.Tensor")

    @property
    def shape(self):
        r"""
        Returns the list of shapes of the tensors.

        It is computed on access, so it stays up to date after :meth:`append`.
        """
        return [xi.shape for xi in self.x]

    def __len__(self):
        r"""
        Returns the number of tensors in the list.
        """
        return len(self.x)

    def __getitem__(self, item):
        r"""
        Returns the ith tensor in the list.
        """
        return self.x[item]

    def __iter__(self):
        r"""
        Iterates over the tensors in the list.
        """
        return iter(self.x)

    def __repr__(self):
        return f"TensorList({self.x!r})"

    def flatten(self):
        r"""
        Returns a :class:`torch.Tensor` with a flattened version of the list of tensors.
        """
        return torch.cat([xi.flatten() for xi in self.x])

    def append(self, other):
        r"""
        Appends a :class:`torch.Tensor` or a list of :class:`torch.Tensor` to the list.

        :return: ``self``, so that calls can be chained.
        """
        if isinstance(other, list):
            self.x += other
        elif isinstance(other, TensorList):
            self.x += other.x
        elif isinstance(other, torch.Tensor):
            self.x.append(other)
        else:
            raise TypeError(
                "the appended item must be a list of :class:`torch.Tensor` or a single :class:`torch.Tensor`"
            )
        return self

    def _elementwise(self, other, op):
        r"""
        Applies ``op(xi, oi)`` pairwise when ``other`` is list-like, else ``op(xi, other)``.

        :raises ValueError: if ``other`` is a list/TensorList whose length differs from ``self``.
        """
        if isinstance(other, (list, TensorList)):
            # Enforce the documented size contract instead of letting zip truncate silently.
            if len(other) != len(self.x):
                raise ValueError(
                    f"TensorList size mismatch: {len(self.x)} vs {len(other)}"
                )
            return TensorList([op(xi, oi) for xi, oi in zip(self.x, other)])
        return TensorList([op(xi, other) for xi in self.x])

    def __add__(self, other):
        r"""

        Adds two TensorLists. The sizes of the tensor lists must match.

        """
        return self._elementwise(other, lambda a, b: a + b)

    def __radd__(self, other):
        r"""
        Reflected addition (``other + self``), e.g. used by the builtin :func:`sum`.
        """
        return self._elementwise(other, lambda a, b: b + a)

    def __mul__(self, other):
        r"""

        Multiply two TensorLists. The sizes of the tensor lists must match.

        """
        return self._elementwise(other, lambda a, b: a * b)

    def __rmul__(self, other):
        r"""
        Reflected multiplication (``other * self``).
        """
        return self._elementwise(other, lambda a, b: b * a)

    def __truediv__(self, other):
        r"""

        Divide two TensorLists. The sizes of the tensor lists must match.

        """
        return self._elementwise(other, lambda a, b: a / b)

    def __rtruediv__(self, other):
        r"""
        Reflected division (``other / self``).
        """
        return self._elementwise(other, lambda a, b: b / a)

    def __neg__(self):
        r"""

        Negate a TensorList.
        """
        return TensorList([-xi for xi in self.x])

    def __sub__(self, other):
        r"""

        Subtract two TensorLists. The sizes of the tensor lists must match.

        """
        return self._elementwise(other, lambda a, b: a - b)

    def __rsub__(self, other):
        r"""
        Reflected subtraction (``other - self``).
        """
        return self._elementwise(other, lambda a, b: b - a)
113
+
114
+
115
def randn_like(x):
    r"""
    Draws standard Gaussian noise with the same layout as ``x``.

    :param x: a :class:`torch.Tensor`, or an iterable of tensors such as a
        :class:`deepinv.utils.TensorList`.
    :return: a :class:`torch.Tensor` if ``x`` is a tensor, otherwise a
        :class:`deepinv.utils.TensorList` holding one noise tensor per entry of ``x``.
    """
    if not isinstance(x, torch.Tensor):
        return TensorList(list(map(torch.randn_like, x)))
    return torch.randn_like(x)
124
+
125
+
126
def rand_like(x):
    r"""
    Draws uniform random numbers in [0,1] with the same layout as ``x``.

    :param x: a :class:`torch.Tensor`, or an iterable of tensors such as a
        :class:`deepinv.utils.TensorList`.
    :return: a :class:`torch.Tensor` if ``x`` is a tensor, otherwise a
        :class:`deepinv.utils.TensorList` holding one random tensor per entry of ``x``.
    """
    if not isinstance(x, torch.Tensor):
        return TensorList(list(map(torch.rand_like, x)))
    return torch.rand_like(x)
135
+
136
+
137
def zeros_like(x):
    r"""
    Builds zero-filled tensors with the same layout as ``x``.

    :param x: a :class:`torch.Tensor`, or an iterable of tensors such as a
        :class:`deepinv.utils.TensorList`.
    :return: a :class:`torch.Tensor` if ``x`` is a tensor, otherwise a
        :class:`deepinv.utils.TensorList` holding one zero tensor per entry of ``x``.
    """
    if not isinstance(x, torch.Tensor):
        return TensorList(list(map(torch.zeros_like, x)))
    return torch.zeros_like(x)
146
+
147
+
148
def ones_like(x):
    r"""
    Builds one-filled tensors with the same layout as ``x``.

    :param x: a :class:`torch.Tensor`, or an iterable of tensors such as a
        :class:`deepinv.utils.TensorList`.
    :return: a :class:`torch.Tensor` if ``x`` is a tensor, otherwise a
        :class:`deepinv.utils.TensorList` holding one all-ones tensor per entry of ``x``.
    """
    if not isinstance(x, torch.Tensor):
        return TensorList(list(map(torch.ones_like, x)))
    return torch.ones_like(x)
157
+
158
+
159
def get_freer_gpu():
    """
    Returns the GPU device with the most free memory.

    Free memory is queried with ``nvidia-smi``. If the query fails for an expected
    reason, the function falls back to the default ``cuda`` device. Expected failures
    are: executable missing, non-zero exit code, timeout, or empty/unparsable output.

    :return: :class:`torch.device` ``cuda:<idx>`` of the GPU with the most free memory,
        or ``cuda`` when it could not be determined.
    """
    import subprocess

    try:
        # One integer (free memory in MiB) per GPU line. Running the executable
        # directly needs no shell, no bash/grep and no temporary file.
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
            timeout=30,
        )
        memory_available = [int(value) for value in result.stdout.split()]
        # np.argmax raises ValueError on an empty list (no GPU reported).
        idx = int(np.argmax(memory_available))
        device = torch.device(f"cuda:{idx}")
        print(f"Selected GPU {idx} with {np.max(memory_available)} MB free memory ")
    except (OSError, subprocess.SubprocessError, ValueError):
        device = torch.device("cuda")
        print("Couldn't find free GPU")

    return device
179
+
180
+
181
def save_model(
    epoch, model, optimizer, ckp_interval, epochs, loss, save_path, eval_psnr=None
):
    r"""
    Saves a training checkpoint if ``epoch`` is a checkpoint epoch.

    A checkpoint is written in two cases:

    - ``epoch > 0`` and ``epoch`` is a multiple of ``ckp_interval``;
    - the last epoch, i.e. ``epoch + 1 == epochs``, in every case.

    A falsy ``ckp_interval`` (``0`` or ``None``) disables the periodic checkpoints.
    In that case only the last epoch is saved, instead of raising on the modulo.

    :param int epoch: current epoch index (0-based).
    :param torch.nn.Module model: model whose ``state_dict`` is saved.
    :param torch.optim.Optimizer optimizer: optimizer whose ``state_dict`` is saved.
    :param int ckp_interval: save every ``ckp_interval`` epochs.
    :param int epochs: total number of epochs.
    :param loss: loss value stored alongside the weights.
    :param str save_path: directory to write into (created if missing).
    :param eval_psnr: optional evaluation PSNR stored under ``"eval_psnr"``.
    """
    periodic = bool(ckp_interval) and epoch > 0 and epoch % ckp_interval == 0
    if not (periodic or epoch + 1 == epochs):
        return

    os.makedirs(save_path, exist_ok=True)
    state = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "loss": loss,
        "optimizer": optimizer.state_dict(),
    }
    if eval_psnr is not None:
        state["eval_psnr"] = eval_psnr
    torch.save(state, os.path.join(save_path, "ckp_{}.pth.tar".format(epoch)))
197
+
198
+
199
def load_checkpoint(model, path_checkpoint, device):
    r"""
    Loads into ``model`` the weights stored under the ``"state_dict"`` key of a checkpoint.

    :func:`torch.load` relies on pickle, so only load checkpoints from trusted sources.

    :param torch.nn.Module model: model whose parameters are overwritten in place.
    :param str path_checkpoint: path to the checkpoint file (e.g. written by :func:`save_model`).
    :param device: forwarded to :func:`torch.load` as ``map_location``.
    :return: the same ``model`` instance, with the loaded weights.
    """
    weights = torch.load(path_checkpoint, map_location=device)["state_dict"]
    model.load_state_dict(weights)
    return model
203
+
204
+
205
def investigate_model(model, idx_max=1, check_name="iterator.g_step.g_param.0"):
    r"""
    Prints the first value and the gradient norm of selected trainable parameters.

    This is a debugging helper. A parameter is reported when it requires gradients
    and either of these holds:

    - its position in ``model.named_parameters()`` is below ``idx_max``;
    - its name contains ``check_name``.

    Parameters that have no gradient yet (e.g. before the first backward pass)
    are reported with a gradient norm of ``None`` instead of raising.

    :param torch.nn.Module model: model to inspect.
    :param int idx_max: number of leading parameters to report.
    :param str check_name: substring selecting additional parameters by name.
    """
    for idx, (name, param) in enumerate(model.named_parameters()):
        if not param.requires_grad or not (idx < idx_max or check_name in name):
            continue
        grad_norm = None if param.grad is None else param.grad.detach().data.norm(2)
        print(
            name,
            param.data.flatten()[0],
            "gradient norm = ",
            grad_norm,
        )
deepinv/utils/optimization.py ADDED
@@ -0,0 +1,108 @@
1
import torch
import torch.nn as nn
2
+
3
+
4
class NeuralIteration(nn.Module):
    r"""
    Base class for unrolled iterative reconstruction networks.

    Subclasses call :meth:`init` from their constructor and override :meth:`forward`.
    """

    def __init__(self):
        super(NeuralIteration, self).__init__()

    def init(self, backbone_blocks, step_size=1.0, iterations=1):
        r"""
        Registers the backbone blocks and the learnable per-iteration step sizes.

        :param backbone_blocks: either a single :class:`torch.nn.Module`, shared by all
            iterations (weight tying), or a list/tuple/:class:`torch.nn.ModuleList` of
            modules. A sequence must contain either one module (weight tying) or exactly
            ``iterations`` modules (one per iteration).
        :param float step_size: initial value of every entry of the ``step_size`` parameter.
        :param int iterations: number of unrolled iterations.
        :raises ValueError: if a sequence of blocks has a length other than 1 or ``iterations``.
        """
        self.iterations = iterations
        is_sequence = isinstance(backbone_blocks, (list, tuple, nn.ModuleList))
        self.n_blocks = len(backbone_blocks) if is_sequence else 1

        self.register_parameter(
            name="step_size",
            param=torch.nn.Parameter(
                step_size * torch.ones(self.iterations), requires_grad=True
            ),
        )
        if not is_sequence:
            # weight_tied=True: a single block reused at every iteration
            self.blocks = torch.nn.ModuleList([backbone_blocks])
        else:
            if self.n_blocks != 1 and self.n_blocks != iterations:
                raise ValueError("'# blocks' does not equal to 'iterations'")
            # Either one block (weight-tied) or one block per iteration (weight_tied=False).
            # Wrapping the sequence itself avoids nesting a one-element list inside ModuleList.
            self.blocks = torch.nn.ModuleList(list(backbone_blocks))

    def forward(self, y, physics, x_init=None):
        r"""
        Default reconstruction: applies the adjoint operator to the measurements.
        """
        return physics.A_adjoint(y)

    @staticmethod
    def measurement_consistency_grad(physics, x, y):
        r"""
        Computes ``A_adjoint(y - A(x))``.

        This is the *negative* gradient of :math:`\frac{1}{2}\|y - Ax\|_2^2` with respect
        to ``x`` (assuming ``A_adjoint`` is the adjoint of ``A``). Adding a positive
        multiple of it to ``x`` is therefore a descent step on the data-fidelity term.

        ``x`` and ``y`` are moved to ``physics.device`` for the computation, and the
        result is moved back to ``x.device``.
        """
        return physics.A_adjoint(
            y.to(physics.device) - physics.A(x.to(physics.device))
        ).to(x.device)
38
+
39
+
40
class GradientDescent(NeuralIteration):
    r"""
    Unrolled residual gradient descent.

    Each iteration updates ``x <- block(x) + x + step * A_adjoint(y - A(x))``. The
    iteration starts from ``x_init`` (cloned) or, if it is not given, from ``A_adjoint(y)``.

    :param backbone_blocks: a module or list of modules, as accepted by ``NeuralIteration.init``.
    :param float step_size: initial step size.
    :param int iterations: number of unrolled iterations.
    """

    def __init__(self, backbone_blocks, step_size=1.0, iterations=1):
        super(GradientDescent, self).__init__()
        self.name = "gd"
        self.init(backbone_blocks, step_size, iterations)

    def forward(self, y, physics, x_init=None):
        x = physics.A_adjoint(y) if x_init is None else x_init.clone()
        weight_tied = self.n_blocks == 1
        for it in range(self.iterations):
            # With a single (weight-tied) block, both the block and the step size are
            # taken at index 0 for every iteration.
            # NOTE(review): this leaves step_size[1:] unused in the tied case — confirm intended.
            k = 0 if weight_tied else it
            denoised = self.blocks[k](x)
            x = (
                denoised
                + x
                + self.step_size[k]
                * NeuralIteration.measurement_consistency_grad(physics, x, y)
            )
        return x
58
+
59
+
60
class ProximalGradientDescent(NeuralIteration):
    r"""
    Unrolled proximal gradient descent, with learned blocks in place of proximal operators.

    Each iteration updates ``x <- block(x + step * A_adjoint(y - A(x)))``. The
    iteration starts from ``x_init`` (cloned) or, if it is not given, from ``A_adjoint(y)``.

    :param backbone_blocks: a module or list of modules, as accepted by ``NeuralIteration.init``.
    :param float step_size: initial step size.
    :param int iterations: number of unrolled iterations.
    """

    def __init__(self, backbone_blocks, step_size=1.0, iterations=1):
        super(ProximalGradientDescent, self).__init__()
        self.name = "pgd"
        self.init(backbone_blocks, step_size, iterations)
        # Alias of the data-consistency step, kept as an attribute; forward calls the
        # static method directly.
        self.mc_grad = NeuralIteration.measurement_consistency_grad

    def forward(self, y, physics, x_init=None):
        if x_init is None:
            x = physics.A_adjoint(y)
        else:
            x = x_init.clone()
        for it in range(self.iterations):
            # Weight-tied case: block and step size are both taken at index 0.
            k = 0 if self.n_blocks == 1 else it
            data_step = NeuralIteration.measurement_consistency_grad(physics, x, y)
            x = self.blocks[k](x + self.step_size[k] * data_step)
        return x
78
+
79
+
80
if __name__ == "__main__":
    # Smoke test: runs the unrolled ProximalGradientDescent on a random inpainting problem,
    # first with a single shared block and then with one block per iteration.
    import torch
    import deepinv as dinv

    # NOTE(review): `dinv.models.unet` may refer to the `unet` submodule rather than a
    # model class (e.g. `UNet`), which would not be callable; likewise confirm that
    # `dinv.device` is defined by the package — TODO confirm before running.
    net = dinv.models.unet().to(dinv.device)
    physics = dinv.physics.Inpainting([32, 32], device=dinv.device)

    x = torch.randn(10, 1, 32, 32).to(dinv.device)
    y = physics.A(x)
    # Initial estimate passed as x_init below.
    fbp = physics.A_dagger(y)
    x_rec = net(fbp)

    # Weight-tied case: a single block, one iteration.
    unroll = ProximalGradientDescent(net, step_size=1.0, iterations=1)
    x_unroll = unroll(y, physics, x_init=fbp)

    # Untied case: one independent block per iteration.
    print("iterations=3")
    iterations = 3
    step_size = 1.0
    blocks = [dinv.models.unet().to(dinv.device) for _ in range(iterations)]

    physics = dinv.physics.Inpainting([32, 32], device=dinv.device)
    x = torch.randn(10, 1, 32, 32).to(dinv.device)
    y = physics.A(x)
    fbp = physics.A_dagger(y)

    unroll = ProximalGradientDescent(blocks, step_size=step_size, iterations=iterations)
    x_unroll = unroll(y, physics, x_init=fbp)

    print(f"iterations={iterations}", x.shape, y.shape, fbp.shape, x_unroll.shape)
deepinv/utils/parameters.py ADDED
@@ -0,0 +1,43 @@
1
+ import numpy as np
2
+
3
+
4
def get_DPIR_params(noise_level_img):
    r"""
    Default parameters for the DPIR Plug-and-Play algorithm.

    The denoiser noise levels form a geometric sequence of ``max_iter = 8`` values,
    from ``49/255`` down to ``noise_level_img``. The step size of each iteration is
    the squared ratio between that level and ``max(0.01, noise_level_img)``.

    :param float noise_level_img: Noise level of the input image.
    :return: tuple ``(lamb, sigma_denoiser, stepsize, max_iter)``. ``sigma_denoiser``
        and ``stepsize`` are lists of ``max_iter`` values (float32 denoiser levels).
    """
    max_iter = 8
    log_start, log_stop = np.log10(49.0 / 255.0), np.log10(noise_level_img)
    sigmas = np.logspace(log_start, log_stop, max_iter).astype(np.float32)
    reference_level = max(0.01, noise_level_img)
    steps = (sigmas / reference_level) ** 2
    return 1 / 0.23, list(sigmas), list(steps), max_iter
19
+
20
+
21
def get_GSPnP_params(problem, noise_level_img):
    r"""
    Default parameters for the GSPnP Plug-and-Play algorithm.

    :param str problem: Type of inverse problem to solve. Can be ``deblur``, ``super-resolution``, or ``inpaint``.
    :param float noise_level_img: Noise level of the input image (unused for ``inpaint``).
    :return: tuple ``(lamb, sigma_denoiser, stepsize, max_iter)``; ``stepsize`` is always ``1.0``.
    :raises ValueError: if ``problem`` is not one of the supported names.
    """
    # Each preset: (problem name, max_iter, sigma_denoiser as a function of the noise level, lamb).
    presets = (
        ("deblur", 500, lambda noise: 1.8 * noise, 1 / 0.1),
        ("super-resolution", 500, lambda noise: 2.0 * noise, 1 / 0.065),
        ("inpaint", 100, lambda noise: 10.0 / 255, 1 / 0.1),
    )
    for name, max_iter, sigma_of, lamb in presets:
        if problem == name:
            return lamb, sigma_of(noise_level_img), 1.0, max_iter
    raise ValueError("parameters unknown with this degradation")
deepinv/utils/phantoms.py ADDED
@@ -0,0 +1,115 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ try:
5
+ import odl
6
+ except:
7
+ odl = ImportError("The odl package is not installed.")
8
+
9
+
10
def random_shapes(interior=False):
    """
    Generate random shape parameters.
    Taken from https://github.com/adler-j/adler/blob/master/adler/odl/phantom.py

    All values are drawn from NumPy's global random state, in a fixed order.

    :param bool interior: if ``True``, the center ``(x_0, y_0)`` is drawn uniformly in
        ``[-0.5, 0.5)^2``; otherwise it is drawn in ``[-1, 1)^2``.
    :return: a 6-tuple, in the row layout passed to ``odl.phantom.ellipsoid_phantom``
        by :func:`random_phantom`. Its entries are:

        - a random intensity ``(U - 0.5) * Exp(scale=0.4)``;
        - two sizes ``0.2 * Exp(1)``;
        - the center ``x_0``, ``y_0``;
        - an angle in ``[0, 2*pi)``.
    """
    spread = 1.0 if interior else 2.0
    half = spread / 2
    # The center coordinates are drawn first, x_0 then y_0.
    x_0, y_0 = (spread * np.random.rand() - half for _ in range(2))

    intensity = (np.random.rand() - 0.5) * np.random.exponential(0.4)
    size_1 = np.random.exponential() * 0.2
    size_2 = np.random.exponential() * 0.2
    angle = np.random.rand() * 2 * np.pi
    return (intensity, size_1, size_2, x_0, y_0, angle)
30
+
31
+
32
def random_phantom(spc, n_ellipse=50, interior=False):
    """
    Generate a random ellipsoid phantom.
    Taken from https://github.com/adler-j/adler/blob/master/adler/odl/phantom.py

    :param spc: odl space on which the phantom is generated (e.g. the
        ``odl.uniform_discr`` space built by ``RandomPhantomDataset``).
    :param int n_ellipse: mean of the Poisson-distributed number of ellipses.
    :param bool interior: forwarded to :func:`random_shapes`.
    :return: the phantom returned by ``odl.phantom.ellipsoid_phantom``.
    :raises ImportError: if odl is not installed.
    """
    if isinstance(odl, ImportError):
        raise ImportError(
            "odl is needed to use generate random phantoms. "
            "It should be installed with `python3 -m pip install"
            " https://github.com/odlgroup/odl/archive/master.zip`"
        ) from odl
    # The ellipse count is drawn before the ellipse parameters.
    n_shapes = np.random.poisson(n_ellipse)
    ellipses = [random_shapes(interior=interior) for _ in range(n_shapes)]
    return odl.phantom.ellipsoid_phantom(spc, ellipses)
46
+
47
+
48
class RandomPhantomDataset(torch.utils.data.Dataset):
    """
    Dataset of random ellipsoid phantoms. The phantoms are generated on the fly.
    The phantoms are generated using the odl library (https://odlgroup.github.io/odl/).

    :param int size: Size of the phantom (square) image.
    :param int n_data: Number of phantoms to generate per sample.
    :param transform: Transformation to apply to the output image.
    :param float length: Length of the dataset. Useful for iterating the data-loader for a certain nb of iterations.
    :raises ImportError: if odl is not installed.
    """

    def __init__(self, size=128, n_data=1, transform=None, length=np.inf):
        # Fail early with an actionable message. Otherwise the next line would raise an
        # AttributeError on the ImportError placeholder bound to ``odl`` when it is missing.
        if isinstance(odl, ImportError):
            raise ImportError(
                "odl is needed to use generate random phantoms. "
                "It should be installed with `python3 -m pip install"
                " https://github.com/odlgroup/odl/archive/master.zip`"
            ) from odl
        self.space = odl.uniform_discr(
            [-64, -64], [64, 64], [size, size], dtype="float32"
        )
        self.transform = transform
        self.n_data = n_data
        self.length = length

    def __len__(self):
        # NOTE(review): with the default ``length=np.inf`` (or any float), the builtin
        # ``len()`` cannot use this value and raises; pass an integer ``length`` when the
        # dataset is used with a sampler that needs its length.
        return self.length

    def __getitem__(self, index):
        """
        :return tuple : A tuple (phantom, 0) where phantom is a torch tensor of shape (n_data, size, size).
        """
        phantom_np = np.array([random_phantom(self.space) for _ in range(self.n_data)])
        phantom = torch.from_numpy(phantom_np).float()
        if self.transform is not None:
            phantom = self.transform(phantom)
        return phantom, 0
79
+
80
+
81
class SheppLoganDataset(torch.utils.data.Dataset):
    """
    Dataset for the single Shepp-Logan phantom. The dataset has length 1.

    :param int size: Size of the phantom (square) image.
    :param int n_data: Number of phantoms stacked in each sample.
    :param transform: Transformation to apply to the output image.
    :raises ImportError: if odl is not installed.
    """

    def __init__(self, size=128, n_data=1, transform=None):
        self._require_odl()
        self.space = odl.uniform_discr(
            [-64, -64], [64, 64], [size, size], dtype="float32"
        )
        self.transform = transform
        self.n_data = n_data

    @staticmethod
    def _require_odl():
        """Raises an informative ImportError when the optional odl dependency is missing."""
        if isinstance(odl, ImportError):
            raise ImportError(
                "odl is needed to use generate the Shepp Logan phantom. "
                "It should be installed with `python3 -m pip install"
                " https://github.com/odlgroup/odl/archive/master.zip`"
            ) from odl

    def __len__(self):
        return 1

    def __getitem__(self, index):
        """
        :return tuple: ``(phantom, 0)``, where ``phantom`` is a float tensor stacking
            ``n_data`` outputs of ``odl.phantom.shepp_logan(self.space, True)``.
        """
        self._require_odl()
        stacked = np.array(
            [odl.phantom.shepp_logan(self.space, True) for _ in range(self.n_data)]
        )
        phantom = torch.from_numpy(stacked).float()
        if self.transform is None:
            return phantom, 0
        return self.transform(phantom), 0