deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/models/dip.py ADDED
@@ -0,0 +1,214 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from deepinv.loss import MCLoss
5
+ from tqdm import tqdm
6
+
7
+
8
def add_module(self, module):
    # Append ``module`` to this container under a 1-based positional name
    # ("1", "2", ...), matching the naming convention of the original
    # ConvDecoder reference implementation.
    # NOTE(review): ``len(self)`` assumes the target defines ``__len__``
    # (true for nn.Sequential, not for a plain nn.Module) — confirm callers.
    self.add_module(str(len(self) + 1), module)


# Monkey-patch: expose the helper as ``Module.add`` so ConvDecoder below can
# build its nn.Sequential via ``self.net.add(...)``.
# NOTE(review): this patches the global torch class and therefore affects
# every nn.Module in the process, not just this file.
torch.nn.Module.add = add_module
13
+
14
+
15
class ConvDecoder(nn.Module):
    r"""
    Convolutional decoder network.

    The architecture was introduced in `"Accelerated MRI with Un-trained Neural Networks" <https://arxiv.org/abs/2007.02471>`_,
    and it is well suited as a deep image prior (see :class:`deepinv.models.DeepImagePrior`).

    :param tuple img_shape: shape ``(C, H, W)`` of the output image.
    :param tuple in_size: spatial size ``(H, W)`` of the input tensor.
    :param int layers: number of layers in the network.
    :param int channels: number of channels in the network.

    """

    # Code adapted from https://github.com/MLI-lab/ConvDecoder/tree/master by Darestani and Heckel.
    def __init__(self, img_shape, in_size=(4, 4), layers=7, channels=256):
        super(ConvDecoder, self).__init__()

        out_size = img_shape[1:]
        output_channels = img_shape[0]

        # parameter setup: every convolution is 3x3 with stride 1; all the
        # spatial scaling is done by the nearest-neighbour Upsample layers.
        kernel_size = 3
        stride = 1

        # per-axis up-sampling factor from one layer to the next, chosen so
        # that (layers - 1) successive upsamplings map in_size onto out_size
        scale_x, scale_y = (
            (out_size[0] / in_size[0]) ** (1.0 / (layers - 1)),
            (out_size[1] / in_size[1]) ** (1.0 / (layers - 1)),
        )
        hidden_size = [
            (
                int(np.ceil(scale_x**n * in_size[0])),
                int(np.ceil(scale_y**n * in_size[1])),
            )
            for n in range(1, (layers - 1))
        ] + [out_size]

        # hidden layers: upsample -> conv -> ReLU -> batch-norm
        self.net = nn.Sequential()
        for i in range(layers - 1):
            self.net.add(nn.Upsample(size=hidden_size[i], mode="nearest"))
            conv = nn.Conv2d(
                channels,
                channels,
                kernel_size,
                stride,
                padding=(kernel_size - 1) // 2,
                bias=True,
            )
            self.net.add(conv)
            self.net.add(nn.ReLU())
            self.net.add(nn.BatchNorm2d(channels, affine=True))
        # final conv block. Fix: the original used ``strides[i]`` here, which
        # silently relied on the leftover loop variable from the loop above
        # (and is a NameError if that loop never runs); the stride is always 1.
        self.net.add(
            nn.Conv2d(
                channels,
                channels,
                kernel_size,
                stride,
                padding=(kernel_size - 1) // 2,
                bias=True,
            )
        )
        self.net.add(nn.ReLU())
        self.net.add(nn.BatchNorm2d(channels, affine=True))
        # 1x1 projection to the requested number of output channels
        self.net.add(nn.Conv2d(channels, output_channels, 1, 1, padding=0, bias=True))

    def forward(self, x, scale_out=1):
        r"""
        Run the decoder.

        :param torch.Tensor x: input of shape ``(B, channels, in_size[0], in_size[1])``.
        :param float scale_out: multiplicative factor applied to the output.
        """
        return self.net(x) * scale_out
86
+
87
+
88
class DeepImagePrior(torch.nn.Module):
    r"""

    Deep Image Prior reconstruction.

    This method is based on the paper `"Deep Image Prior" by Ulyanov et al. (2018)
    <https://arxiv.org/abs/1711.10925>`_, and reconstructs
    an image by minimizing the loss function

    .. math::

        \min_{\theta} \|y-Af_{\theta}(z)\|^2

    where :math:`z` is a random input and :math:`f_{\theta}` is a convolutional decoder network with parameters
    :math:`\theta`. The minimization should be stopped early to avoid overfitting. The method uses the Adam
    optimizer.

    .. note::

        This method only works with certain convolutional decoder networks. We recommend using the
        network :class:`deepinv.models.ConvDecoder`.


    .. note::

        The number of iterations and learning rate are set to the values used in the original paper. However, these
        values may not be optimal for all problems. We recommend experimenting with different values.

    :param torch.nn.Module generator: Convolutional decoder network.
    :param list, tuple input_size: Size (C,H,W) of the input noise vector :math:`z`.
    :param int iterations: Number of optimization iterations.
    :param float learning_rate: Learning rate of the Adam optimizer.
    :param bool verbose: If ``True``, print progress.
    :param bool re_init: If ``True``, re-initialize the network parameters before each reconstruction.

    """

    def __init__(
        self,
        generator,
        input_size,
        iterations=2500,
        learning_rate=1e-2,
        verbose=False,
        re_init=False,
    ):
        super().__init__()
        self.generator = generator
        self.max_iter = int(iterations)
        self.lr = learning_rate
        self.loss = MCLoss()  # measurement-consistency loss ||y - A x||^2
        self.verbose = verbose
        self.re_init = re_init
        self.input_size = input_size

    def forward(self, y, physics):
        r"""
        Reconstruct an image from the measurement :math:`y`. The reconstruction is performed by solving a
        minimization problem.

        .. warning::

            The optimization is run for every test batch. Thus, this method can be slow when tested on a large
            number of test batches.

        :param torch.Tensor y: Measurement.
        :param deepinv.physics.Physics physics: Physics model.
        """
        if self.re_init:
            # Fix: iterate over *all* submodules, not only the direct
            # children. ``children()`` yields just the top-level containers
            # (e.g. the nn.Sequential inside ConvDecoder), which themselves
            # have no ``reset_parameters``, so the original code silently
            # skipped re-initialization.
            for layer in self.generator.modules():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()

        self.generator.requires_grad_(True)
        # fixed random latent input z, one per call (batch dimension added)
        z = torch.randn(self.input_size, device=y.device).unsqueeze(0)
        optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr)

        for it in tqdm(range(self.max_iter), disable=(not self.verbose)):
            x = self.generator(z)
            error = self.loss(y, x, physics)
            optimizer.zero_grad()
            error.backward()
            optimizer.step()

        # return the reconstruction from the (early-stopped) optimized network
        return self.generator(z)
173
+
174
+
175
+ # test code
176
+ # if __name__ == "__main__":
177
+ # device = "cuda:0"
178
+ # import torchvision
179
+ # import deepinv as dinv
180
+ #
181
+ # device = dinv.utils.get_freer_gpu()
182
+ #
183
+ # x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
184
+ # x = x.unsqueeze(0).float().to(device) / 255
185
+ # x = torchvision.transforms.Resize((128, 128))(x)
186
+ #
187
+ # physics = dinv.physics.Inpainting(
188
+ # tensor_size=x.shape[1:],
189
+ # device=device,
190
+ # noise_model=dinv.physics.GaussianNoise(sigma=0.05),
191
+ # )
192
+ #
193
+ # y = physics(x)
194
+ #
195
+ # iterations = 1000
196
+ # lr = 1e-2
197
+ # channels = 256
198
+ # in_size = [8, 8]
199
+ # backbone = ConvDecoder(
200
+ # img_shape=x.shape[1:], in_size=in_size, channels=channels
201
+ # ).to(device)
202
+ #
203
+ # model = DeepImagePrior(
204
+ # backbone,
205
+ # learning_rate=lr,
206
+ # re_init=True,
207
+ # iterations=iterations,
208
+ # verbose=True,
209
+ # input_size=[channels] + in_size,
210
+ # ).to(device)
211
+ #
212
+ # x_hat = model(y, physics)
213
+ #
214
+ # dinv.utils.plot([x, y, x_hat], titles=["GT", "Meas.", "Recon."])
@@ -0,0 +1,131 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ from .utils import get_weights_url
4
+ import math
5
+
6
+
7
class DnCNN(nn.Module):
    r"""
    DnCNN convolutional denoiser.

    The architecture was introduced by Zhang et al. in https://arxiv.org/abs/1608.03981 and is composed of a series of
    convolutional layers with ReLU activation functions. The number of layers can be specified by the user. Unlike the
    original paper, this implementation does not include batch normalization layers.

    The network can be initialized with pretrained weights, which can be downloaded from an online repository. The
    pretrained weights are trained with the default parameters of the network, i.e. 20 layers, 64 channels and biases.

    :param int in_channels: input image channels
    :param int out_channels: output image channels
    :param int depth: number of convolutional layers
    :param bool bias: use bias in the convolutional layers
    :param int nf: number of channels per convolutional layer
    :param str, None pretrained: use a pretrained network. If ``pretrained=None``, the weights will be initialized at random
        using Pytorch's default initialization. If ``pretrained='download'``, the weights will be downloaded from an
        online repository (only available for architecture with depth 20, 64 channels and biases).
        It is possible to download weights trained via the regularization method in https://epubs.siam.org/doi/abs/10.1137/20M1387961
        using ``pretrained='download_lipschitz'``.
        Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights.
        See :ref:`pretrained-weights <pretrained-weights>` for more details.
    :param bool train: training or testing mode
    :param str device: gpu or cpu
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        depth=20,
        bias=True,
        nf=64,
        pretrained="download",
        train=False,
        device="cpu",
    ):
        super(DnCNN, self).__init__()

        self.depth = depth

        # first layer: input channels -> nf feature channels
        self.in_conv = nn.Conv2d(
            in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias
        )
        # depth - 2 intermediate nf -> nf convolutions
        self.conv_list = nn.ModuleList(
            [
                nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
                for _ in range(self.depth - 2)
            ]
        )
        # last layer: nf -> output channels (no nonlinearity afterwards)
        self.out_conv = nn.Conv2d(
            nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )

        # one ReLU after every convolution except the last
        self.nl_list = nn.ModuleList([nn.ReLU() for _ in range(self.depth - 1)])

        if pretrained is not None:
            if pretrained.startswith("download"):
                name = ""
                # online weights only exist for the default architecture
                # (depth 20, biases enabled, 1 or 3 channels)
                if bias and depth == 20:
                    if pretrained == "download_lipschitz":
                        if in_channels == 3 and out_channels == 3:
                            name = "dncnn_sigma2_lipschitz_color.pth"
                        elif in_channels == 1 and out_channels == 1:
                            name = "dncnn_sigma2_lipschitz_gray.pth"
                    else:
                        if in_channels == 3 and out_channels == 3:
                            name = "dncnn_sigma2_color.pth"
                        elif in_channels == 1 and out_channels == 1:
                            name = "dncnn_sigma2_gray.pth"

                if name == "":
                    # fix: raise a specific exception type instead of the
                    # generic ``Exception`` (ValueError is a subclass, so
                    # existing ``except Exception`` handlers still match)
                    raise ValueError(
                        "No pretrained weights were found online that match the chosen architecture"
                    )
                url = get_weights_url(model_name="dncnn", file_name=name)
                ckpt = torch.hub.load_state_dict_from_url(
                    url, map_location=lambda storage, loc: storage, file_name=name
                )
            else:
                # ``pretrained`` is interpreted as a local checkpoint path
                ckpt = torch.load(pretrained, map_location=lambda storage, loc: storage)
            self.load_state_dict(ckpt, strict=True)

            if not train:
                # freeze the pretrained network for inference-only use
                self.eval()
                for _, v in self.named_parameters():
                    v.requires_grad = False
        else:
            # no pretrained weights: apply the DnCNN Kaiming initialization
            self.apply(weights_init_kaiming)

        if device is not None:
            self.to(device)

    def forward(self, x, sigma=None):
        r"""
        Run the denoiser on noisy image. The noise level is not used in this denoiser.

        :param torch.Tensor x: noisy image
        :param float sigma: noise level (not used)
        """
        x1 = self.nl_list[0](self.in_conv(x))

        for i in range(self.depth - 2):
            x1 = self.nl_list[i + 1](self.conv_list[i](x1))

        # residual connection: the stack predicts a correction added to x
        return self.out_conv(x1) + x
119
+
120
+
121
def weights_init_kaiming(m):
    r"""
    Initialize a module following the DnCNN training recipe.

    Conv and Linear weights get Kaiming-normal (fan-in) initialization;
    BatchNorm weights are drawn from :math:`N(0, 2/(9 \cdot 64))` clamped to
    ``[-0.025, 0.025]`` and BatchNorm biases are set to zero. Other module
    types are left untouched. Intended for ``module.apply(weights_init_kaiming)``.

    :param torch.nn.Module m: module to initialize.
    """
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
    elif classname.find("Linear") != -1:
        nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
    elif classname.find("BatchNorm") != -1:
        m.weight.data.normal_(mean=0, std=math.sqrt(2.0 / 9.0 / 64.0)).clamp_(
            -0.025, 0.025
        )
        # fix: ``nn.init.constant`` was deprecated in PyTorch 0.4 and removed
        # from recent releases; use the in-place ``constant_`` variant.
        nn.init.constant_(m.bias.data, 0.0)