deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,312 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from torchvision.utils import make_grid
4
+ import wandb
5
+ import math
6
+ import torch
7
+ import matplotlib.pyplot as plt
8
+ from pathlib import Path
9
+ from collections.abc import Iterable
10
+ import matplotlib
11
+ import shutil
12
+ import torchvision.transforms as T
13
+ import torchvision.transforms.functional as F
14
+
15
# Global matplotlib defaults, applied as a side effect of importing this module.
matplotlib.rcParams.update({"font.size": 17, "lines.linewidth": 2})
from matplotlib.ticker import MaxNLocator

# Render text with LaTeX only when a ``latex`` executable can be found on the PATH.
plt.rcParams["text.usetex"] = bool(shutil.which("latex"))
20
+
21
+
22
def resize_pad_square_tensor(tensor, size):
    r"""
    Resize a tensor BxCxWxH to a square tensor BxCxsizexsize with the same aspect ratio thanks to zero-padding.

    Each element of the batch is converted to a PIL image, zero-padded to a square whose side is
    the longest of its two dimensions, resized, and converted back to a tensor. The results are
    stacked along a new first dimension.

    :param torch.Tensor tensor: the tensor to resize.
    :param int size: the new size.
    :return torch.Tensor: the resized tensor.
    """

    class SquarePad:
        """Zero-pad a PIL image so that its width and height are equal."""

        def __call__(self, image):
            W, H = image.size
            max_wh = max(W, H)
            # Split the padding as evenly as possible. When the difference is odd, the
            # extra pixel goes to the right/bottom so the padded image is exactly square;
            # otherwise the subsequent resize would not yield a square output.
            left = (max_wh - W) // 2
            top = (max_wh - H) // 2
            right = max_wh - W - left
            bottom = max_wh - H - top
            # torchvision expects the padding as (left, top, right, bottom).
            padding = (left, top, right, bottom)
            return F.pad(image, padding, fill=0, padding_mode="constant")

    transform = T.Compose([T.ToPILImage(), SquarePad(), T.Resize(size), T.ToTensor()])
    return torch.stack([transform(el) for el in tensor])
43
+
44
+
45
def torch2cpu(img):
    r"""
    Convert the first image of a batch into a numpy array with values clipped to [0, 1].

    When the second dimension has size 2 (complex images, e.g. in MRI), the two channels are
    first replaced by their magnitude. The first element of the batch is then clamped,
    detached, moved to channels-last order, squeezed and returned as a numpy array.

    :param torch.Tensor img: batch of images, indexed as ``img[0, :, :, :]``.
    :return: numpy array holding the first image of the batch.
    """
    is_complex = img.shape[1] == 2  # for complex images (e.g. in MRI)
    if is_complex:
        squared = img.pow(2)
        img = squared.sum(dim=1, keepdim=True).sqrt()

    first = img[0, :, :, :]
    clipped = first.clamp(min=0.0, max=1.0).detach()
    channels_last = clipped.permute(1, 2, 0).squeeze()
    return channels_last.cpu().numpy()
58
+
59
+
60
def tensor2uint(img):
    r"""
    Convert an image tensor to a ``uint8`` numpy array.

    The tensor is squeezed, cast to float, clipped to [0, 1], scaled by 255 and rounded.
    A 3-dimensional result is transposed from channels-first to channels-last order.
    The input tensor is left unmodified.

    :param torch.Tensor img: image tensor with values expected in [0, 1].
    :return: ``numpy.ndarray`` of dtype ``uint8``.
    """
    # Clamp out of place: ``squeeze()`` returns a view and ``float()`` returns the same
    # tensor when it is already float32, so an in-place clamp would silently overwrite
    # the caller's data.
    arr = img.detach().squeeze().float().clamp(0, 1).cpu().numpy()
    if arr.ndim == 3:
        arr = np.transpose(arr, (1, 2, 0))
    return np.uint8((arr * 255.0).round())
65
+
66
+
67
def numpy2uint(img):
    r"""
    Convert a numpy image to ``uint8``.

    Values are clipped to [0, 1], multiplied by 255 and rounded before the cast.

    :param numpy.ndarray img: image with values expected in [0, 1].
    :return: ``uint8`` version of the image.
    """
    scaled = img.clip(0, 1) * 255.0
    rounded = scaled.round()
    return np.uint8(rounded)
70
+
71
+
72
def rescale_img(img, rescale_mode="min_max"):
    r"""
    Rescale an image tensor to the [0, 1] range.

    :param torch.Tensor img: image to rescale.
    :param str rescale_mode: either ``'min_max'`` (the image is linearly rescaled between 0 and 1
        using its min and max values) or ``'clip'`` (the image is clipped between 0 and 1).
    :return torch.Tensor: the rescaled image.
    :raises ValueError: if ``rescale_mode`` is neither ``'min_max'`` nor ``'clip'``.
    """
    if rescale_mode == "min_max":
        lowest, highest = img.min(), img.max()
        if highest > lowest:
            img = (img - lowest) / (highest - lowest)
        else:
            # A constant image has no dynamic range to stretch: dividing by
            # (max - min) == 0 would fill it with NaNs, so clip it instead.
            img = img.clamp(min=0.0, max=1.0)
    elif rescale_mode == "clip":
        img = img.clamp(min=0.0, max=1.0)
    else:
        raise ValueError("rescale_mode has to be either 'min_max' or 'clip'.")
    return img
80
+
81
+
82
def plot(
    img_list,
    titles=None,
    save_dir=None,
    tight=True,
    max_imgs=4,
    rescale_mode="min_max",
    show=True,
    return_fig=False,
):
    r"""
    Plots a list of images.

    The images should be of shape [B,C,H,W], where B is the batch size, C is the number of channels,
    H is the height and W is the width. The images are plotted in a grid, where the number of rows is B
    and the number of columns is the length of the list. If B is bigger than ``max_imgs``, only the first
    ``max_imgs`` images of the batch are plotted. Images with two channels are treated as complex images
    and their magnitude is plotted.

    Example usage:

    ::

        import torch
        from deepinv.utils import plot
        img = torch.rand(4, 3, 256, 256)
        plot([img, img, img], titles=["img1", "img2", "img3"], save_dir="test")

    :param list[torch.Tensor], torch.Tensor img_list: list of images to plot or single image.
    :param list[str] titles: list of titles for each image, has to be same length as img_list.
    :param str save_dir: directory where the plot is saved. It is created if it does not exist. The whole
        grid is saved as ``images.png`` and every image as ``<title>_<row>.png`` (``img_<column>_<row>.png``
        when no titles are given).
    :param bool tight: use tight layout.
    :param int max_imgs: maximum number of images to plot.
    :param str rescale_mode: rescale mode, either 'min_max' (images are linearly rescaled between 0 and 1 using their min and max values) or 'clip' (images are clipped between 0 and 1).
    :param bool show: show the image plot.
    :param bool return_fig: return the figure object.
    :return: the matplotlib figure if ``return_fig`` is ``True``, otherwise ``None``.
    """
    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    if isinstance(img_list, torch.Tensor):
        img_list = [img_list]

    if isinstance(titles, str):
        titles = [titles]

    # One list of numpy images per entry of img_list (i.e. per grid column).
    imgs = []
    for im in img_list:
        col_imgs = []
        for i in range(min(im.shape[0], max_imgs)):
            if im.shape[1] == 2:  # for complex images
                pimg = (
                    im[i, :, :, :]
                    .pow(2)
                    .sum(dim=0)
                    .sqrt()
                    .unsqueeze(0)
                    .type(torch.float32)
                )
            else:
                pimg = im[i, :, :, :].type(torch.float32)
            pimg = rescale_img(pimg, rescale_mode=rescale_mode)
            col_imgs.append(pimg.detach().permute(1, 2, 0).squeeze().cpu().numpy())
        imgs.append(col_imgs)

    # Size the grid on the tallest column so that batches of different sizes all fit.
    n_cols = len(imgs)
    n_rows = max(len(col_imgs) for col_imgs in imgs)
    fig, axs = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * 2, n_rows * 2),
        squeeze=False,
    )

    # Hide every axis up front so that cells left without an image stay blank.
    for ax in axs.flat:
        ax.axis("off")

    for i, col_imgs in enumerate(imgs):
        for r, img in enumerate(col_imgs):
            axs[r, i].imshow(img, cmap="gray")
            if titles and r == 0:
                axs[r, i].set_title(titles[i], size=9)
    if tight:
        plt.subplots_adjust(hspace=0.01, wspace=0.05)
    if save_dir:
        plt.savefig(save_dir / "images.png", dpi=1200)
        for i, col_imgs in enumerate(imgs):
            # Fall back to a positional name when the caller gave no titles.
            stem = titles[i] if titles else f"img_{i}"
            for r, img in enumerate(col_imgs):
                plt.imsave(save_dir / (stem + "_" + str(r) + ".png"), img, cmap="gray")
    if show:
        plt.show()

    if return_fig:
        return fig
174
+
175
+
176
def plot_curves(metrics, save_dir=None, show=True):
    r"""
    Plots the metrics of a Plug-and-Play algorithm.

    One subplot is drawn per entry of ``metrics``. Each entry is indexed as ``metrics[name][b]``, giving
    the sequence of values plotted for batch element ``b``. The ``"residual"`` and ``"cost"`` metrics are
    drawn on a logarithmic y-axis; all the others use a linear axis. Entries with no values are left empty.

    :param dict metrics: dictionary of metrics to plot.
    :param str save_dir: directory where the plot is saved as ``curves.png``. It is created if needed.
    :param bool show: show the image plot.
    """
    if not metrics:
        # Nothing to draw: a figure with zero columns cannot be created.
        return

    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    n_metrics = len(metrics)
    # squeeze=False keeps ``axs`` two-dimensional even when a single metric is plotted.
    fig, axs = plt.subplots(1, n_metrics, figsize=(6 * n_metrics, 4), squeeze=False)
    use_tex = plt.rcParams["text.usetex"]

    for i, (metric_name, metric_val) in enumerate(metrics.items()):
        if len(metric_val) == 0:
            continue
        ax = axs[0, i]
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        if metric_name == "residual":
            label = (
                r"Residual $\frac{||x_{k+1} - x_k||}{||x_k||}$"
                if use_tex
                else "residual"
            )
            log_scale = True
        elif metric_name == "psnr":
            label = r"$PSNR(x_k)$" if use_tex else "PSNR"
            log_scale = False
        elif metric_name == "cost":
            label = r"$F(x_k)$" if use_tex else "F"
            log_scale = True
        else:
            label = metric_name
            log_scale = False
        draw = ax.semilogy if log_scale else ax.plot
        for b, curve in enumerate(metric_val):
            draw(curve, "-o", label=f"batch {b+1}")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.set_title(label)
        ax.legend()
    plt.subplots_adjust(hspace=0.1)
    if save_dir:
        plt.savefig(save_dir / "curves.png")
    if show:
        plt.show()
226
+
227
+
228
def wandb_imgs(imgs, captions, n_plot):
    r"""
    Build one ``wandb.Image`` per entry of ``imgs``.

    For each entry, the first ``n_plot`` images are arranged in a grid with
    ``int(sqrt(n_plot)) + 1`` images per row and tagged with the caption of the same index.

    :param imgs: sequence of image batches.
    :param captions: sequence of captions, indexed like ``imgs``.
    :param int n_plot: number of images of each batch put in the grid.
    :return: list of ``wandb.Image`` objects.
    """
    return [
        wandb.Image(
            make_grid(imgs[k][:n_plot], nrow=int(math.sqrt(n_plot)) + 1),
            caption=captions[k],
        )
        for k in range(len(imgs))
    ]
238
+
239
+
240
def wandb_plot_curves(metrics, batch_idx=0, step=0):
    r"""
    Log every non-empty metric to Weights & Biases as a line-series plot.

    Each entry ``metrics[name]`` holds one curve per image; the x-axis spans the length of the
    first curve. Empty entries are skipped.

    :param dict metrics: dictionary of metrics to log.
    :param int batch_idx: batch index used in the name and title of the plots.
    :param int step: step passed to ``wandb.log``.
    """
    for metric_name, metric_val in metrics.items():
        if len(metric_val) == 0:
            continue
        n_curves = len(metric_val)
        n_iter = len(metric_val[0])
        tag = f"{metric_name} batch {batch_idx}"
        series = wandb.plot.line_series(
            xs=range(n_iter),
            ys=metric_val,
            keys=[f"image {j}" for j in range(n_curves)],
            title=tag,
            xname="iteration",
        )
        wandb.log({tag: series}, step=step)
256
+
257
+
258
def plot_parameters(model, init_params=None, save_dir=None, show=True):
    r"""
    Plot the parameters of the model before and after training.
    This can be used after training Unfolded optimization models.

    For every parameter, one value per index is plotted; tensor-valued entries are reduced to a scalar
    (see ``get_param`` below). Initial values, when given, are drawn with dashed lines in the same color
    as the learned values.

    :param torch.nn.Module model: the model whose parameters are plotted. The parameters are contained in the dictionary
        ``params_algo`` attribute of the model.
    :param dict init_params: the initial parameters of the model, before training. Defaults to ``None``.
        The dictionary is not modified.
    :param str, Path save_dir: the directory where to save the plot (as ``parameters.png``). It is created
        if it does not exist. Defaults to ``None``.
    :param bool show: whether to show the plot. Defaults to ``True``.
    """

    color = ["b", "g", "r", "c", "m", "y", "k", "w"]

    fig, ax = plt.subplots(figsize=(7, 7))

    # Normalise the initial parameters into a private dict: scalars become one-element
    # lists, ``None`` becomes an empty dict, and the caller's dictionary is left untouched.
    if init_params is None:
        init_values = {}
    else:
        init_values = {
            key: value if isinstance(value, Iterable) else [value]
            for key, value in init_params.items()
        }

    def get_param(param):
        """Reduce a parameter entry to a plain number."""
        if torch.is_tensor(param):
            if len(param.shape) > 0:
                return param[0].mean().item()
            return param.item()
        return param

    n_values = 0  # longest parameter sequence seen, used to place the x ticks
    for i, name_param in enumerate(model.params_algo):
        # Cycle through the palette so that more than len(color) parameters do not fail.
        line_color = color[i % len(color)]
        value = [get_param(param) for param in model.params_algo[name_param]]
        n_values = max(n_values, len(value))
        if name_param in init_values:
            value_init = [get_param(param) for param in init_values[name_param]]
            ax.plot(value_init, "--o", label="init. " + name_param, color=line_color)
        ax.plot(value, "-o", label="learned " + name_param, color=line_color)

    # Set labels and title
    ax.set_facecolor("white")
    ax.set_xticks(np.arange(n_values, step=5))
    ax.set_xlabel("Layer index")
    ax.set_ylabel("Value")
    ax.grid(True, linestyle="-", alpha=0.5, color="lightgray")
    ax.tick_params(color="lightgray")
    ax.legend()

    # Save before showing: once the window opened by ``plt.show()`` is closed the
    # figure may be gone and the saved file would be blank.
    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_dir / "parameters.png")
    if show:
        plt.show()
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2023, deepinv
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,159 @@
1
+ Metadata-Version: 2.1
2
+ Name: deepinv
3
+ Version: 0.1.0.dev0
4
+ Summary: Pytorch library for solving inverse problems with deep learning
5
+ Author: Matthieu Terris, Samuel Hurault, Dongdong Chen
6
+ Author-email: Julian Tachella <tachellajulian@gmail.com>
7
+ License: BSD 3-Clause
8
+ Project-URL: Homepage, https://deepinv.github.io/
9
+ Project-URL: Source, https://github.com/deepinv/deepinv
10
+ Project-URL: Tracker, https://github.com/deepinv/deepinv/issues
11
+ Platform: any
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: License :: OSI Approved :: BSD License
16
+ Classifier: Operating System :: OS Independent
17
+ Classifier: Programming Language :: Python :: 3
18
+ Classifier: Programming Language :: Python :: 3.8
19
+ Classifier: Programming Language :: Python :: 3.9
20
+ Classifier: Programming Language :: Python :: 3.10
21
+ Classifier: Programming Language :: Python :: 3.11
22
+ Classifier: Programming Language :: Python :: 3.12
23
+ Classifier: Topic :: Utilities
24
+ Classifier: Topic :: Scientific/Engineering
25
+ Classifier: Topic :: Software Development :: Libraries
26
+ Requires-Python: >=3.8
27
+ Description-Content-Type: text/x-rst
28
+ License-File: LICENSE
29
+ Requires-Dist: numpy
30
+ Requires-Dist: matplotlib
31
+ Requires-Dist: hdf5storage
32
+ Requires-Dist: torch
33
+ Requires-Dist: torchvision
34
+ Requires-Dist: einops
35
+ Requires-Dist: wandb
36
+ Requires-Dist: fastmri
37
+ Provides-Extra: denoisers
38
+ Requires-Dist: bm3d ; extra == 'denoisers'
39
+ Requires-Dist: timm ; extra == 'denoisers'
40
+ Provides-Extra: doc
41
+ Requires-Dist: sphinx ; extra == 'doc'
42
+ Requires-Dist: sphinx-gallery ; extra == 'doc'
43
+ Requires-Dist: sphinx-rtd-theme ; extra == 'doc'
44
+ Requires-Dist: sphinxemoji ; extra == 'doc'
45
+ Provides-Extra: test
46
+ Requires-Dist: pytest ; extra == 'test'
47
+ Requires-Dist: pytest-cov ; extra == 'test'
48
+ Requires-Dist: coverage ; extra == 'test'
49
+
50
+ .. image:: https://github.com/deepinv/deepinv/raw/main/docs/source/figures/deepinv_logolarge.png
51
+ :width: 500px
52
+ :alt: deepinv logo
53
+ :align: center
54
+
55
+
56
+ |Test Status| |Docs Status| |Python 3.6+| |codecov| |Black| |discord| |colab|
57
+
58
+
59
+ Introduction
60
+ ------------
61
+ Deep Inverse is an open-source pytorch library for solving imaging inverse problems using deep learning. The goal of ``deepinv`` is to accelerate the development of deep learning based methods for imaging inverse problems, by combining popular learning-based reconstruction approaches in a common and simplified framework, standardizing forward imaging models and simplifying the creation of imaging datasets.
62
+
63
+ ``deepinv`` provides:
64
+
65
+
66
+ * Large collection of `predefined imaging operators <https://deepinv.github.io/deepinv/deepinv.physics.html>`_ (MRI, CT, deblurring, inpainting, etc.)
67
+ * `Training losses <https://deepinv.github.io/deepinv/deepinv.loss.html>`_ for inverse problems (self-supervised learning, regularization, etc.).
68
+ * Many `pretrained deep denoisers <https://deepinv.github.io/deepinv/deepinv.models.html>`_ which can be used for `plug-and-play restoration <https://deepinv.github.io/deepinv/deepinv.pnp.html>`_.
69
+ * Framework for `building datasets <https://deepinv.github.io/deepinv/deepinv.datasets.html>`_ for inverse problems.
70
+ * Easy-to-build `unfolded architectures <https://deepinv.github.io/deepinv/deepinv.unfolded.html>`_ (ADMM, forward-backward, deep equilibrium, etc.).
71
+ * `Sampling algorithms <https://deepinv.github.io/deepinv/deepinv.sampling.html>`_ for uncertainty quantification (Langevin, diffusion, etc.).
72
+ * A large number of well-explained `examples <https://deepinv.github.io/deepinv/auto_examples/index.html>`_, from basics to state-of-the-art methods.
73
+
74
+ .. image:: https://github.com/deepinv/deepinv/raw/main/docs/source/figures/deepinv_schematic.png
75
+ :width: 1000px
76
+ :alt: deepinv schematic
77
+ :align: center
78
+
79
+
80
+ Documentation
81
+ -------------
82
+
83
+ Read the documentation and examples at `https://deepinv.github.io <https://deepinv.github.io>`_.
84
+
85
+ Install
86
+ -------
87
+
88
+ To install the latest stable release of ``deepinv``, you can simply do:
89
+
90
+ .. code-block:: bash
91
+
92
+ pip install deepinv
93
+
94
+ You can also install the latest version of ``deepinv`` directly from github:
95
+
96
+ .. code-block:: bash
97
+
98
+ pip install git+https://github.com/deepinv/deepinv.git#egg=deepinv
99
+
100
+ Getting Started
101
+ ---------------
102
+ Try out the following plug-and-play image inpainting example:
103
+
104
+ .. code-block:: python
105
+
106
+ import deepinv as dinv
107
+ from deepinv.utils import load_url_image
108
+
109
+ url = ("https://huggingface.co/datasets/deepinv/images/resolve/main/cameraman.png?download=true")
110
+ x = load_url_image(url=url, img_size=512, grayscale=True, device='cpu')
111
+
112
+ physics = dinv.physics.Inpainting((1, 512, 512), mask = 0.5, \
113
+ noise_model=dinv.physics.GaussianNoise(sigma=0.01))
114
+
115
+ data_fidelity = dinv.optim.data_fidelity.L2()
116
+ prior = dinv.optim.prior.PnP(denoiser=dinv.models.MedianFilter())
117
+ model = dinv.optim.optim_builder(iteration="HQS", prior=prior, data_fidelity=data_fidelity, \
118
+ params_algo={"stepsize": 1.0, "g_param": 0.1, "lambda": 2.})
119
+ y = physics(x)
120
+ x_hat = model(y, physics)
121
+ dinv.utils.plot([x, y, x_hat], ["signal", "measurement", "estimate"], rescale_mode='clip')
122
+
123
+
124
+ Also try out `one of the examples <https://deepinv.github.io/deepinv/auto_examples/index.html>`_ to get started.
125
+
126
+ Contributing
127
+ ------------
128
+
129
+ DeepInverse is a community-driven project and welcomes contributions of all forms.
130
+ We are ultimately aiming for a comprehensive library of inverse problems and deep learning,
131
+ and we need your help to get there!
132
+ The preferred way to contribute to ``deepinv`` is to fork the `main
133
+ repository <https://github.com/deepinv/deepinv/>`_ on GitHub,
134
+ then submit a "Pull Request" (PR). See our `contributing guide <https://deepinv.github.io/deepinv/deepinv.contributing.html>`_
135
+ for more details.
136
+
137
+
138
+ Finding help
139
+ ------------
140
+
141
+ If you have any questions or suggestions, please join the conversation in our
142
+ `Discord server <https://discord.gg/qBqY5jKw3p>`_. The recommended way to get in touch with the developers is to open an issue on the
143
+ `issue tracker <https://github.com/deepinv/deepinv/issues>`_.
144
+
145
+
146
+ .. |Black| image:: https://img.shields.io/badge/code%20style-black-000000.svg
147
+ :target: https://github.com/psf/black
148
+ .. |Test Status| image:: https://github.com/deepinv/deepinv/actions/workflows/test.yml/badge.svg
149
+ :target: https://github.com/deepinv/deepinv/actions/workflows/test.yml
150
+ .. |Docs Status| image:: https://github.com/deepinv/deepinv/actions/workflows/documentation.yml/badge.svg
151
+ :target: https://github.com/deepinv/deepinv/actions/workflows/documentation.yml
152
+ .. |Python 3.6+| image:: https://img.shields.io/badge/python-3.6%2B-blue
153
+ :target: https://www.python.org/downloads/release/python-360/
154
+ .. |codecov| image:: https://codecov.io/gh/deepinv/deepinv/branch/main/graph/badge.svg?token=77JRvUhQzh
155
+ :target: https://codecov.io/gh/deepinv/deepinv
156
+ .. |discord| image:: https://dcbadge.vercel.app/api/server/qBqY5jKw3p?style=flat
157
+ :target: https://discord.gg/qBqY5jKw3p
158
+ .. |colab| image:: https://colab.research.google.com/assets/colab-badge.svg
159
+ :target: https://colab.research.google.com/drive/1XhCO5S1dYN3eKm4NEkczzVU7ZLBuE42J
@@ -0,0 +1,97 @@
1
+ deepinv/__about__.py,sha256=w8__dQ5D_ge5R2CUQgFvMpcOR0AbIEvZRV3CurjDnZE,443
2
+ deepinv/__init__.py,sha256=tXWYmAQ0qcZNMa2Czjc-5k4ufLLRTic4iy5MEFOCQgw,1173
3
+ deepinv/training_utils.py,sha256=e9Q74q9i6CALe1X3puJYUa3NZJZunve5PMMvfYG87BI,21241
4
+ deepinv/datasets/__init__.py,sha256=pTyi2rZw4QUs5l9GwSRjkoubrWXJd7EEvLs7HZfoV2o,57
5
+ deepinv/datasets/datagenerator.py,sha256=dI2a0_Lfa28OMwxYQas89KEhIRgLUYSHvGTP1bmnVSI,8784
6
+ deepinv/loss/__init__.py,sha256=ge4IWNAz7gPnzn2jml3byUj5OfHfQZY2qazL8QdsoA0,503
7
+ deepinv/loss/ei.py,sha256=57nLB4b2YMuOCMiGap30me6WbuUpX0dq3riJG4DSULA,2681
8
+ deepinv/loss/mc.py,sha256=uQa0p56f1TjzzxWi77o-2cEn-sFaCsUvRDwM8VWXviU,1306
9
+ deepinv/loss/measplit.py,sha256=FJRanev2VDiPpn5uP_ATCcG3wayKZq7hvng3F3Kogu8,8170
10
+ deepinv/loss/metric.py,sha256=X0kOzp1--PYR-QRfcO1bFS-P6iBCAQaXN-wQtP-n07E,3764
11
+ deepinv/loss/moi.py,sha256=PC3YxSO02NCvIMt9MUrl5sQ53R6c4BZmssgMI8xQG2o,2306
12
+ deepinv/loss/regularisers.py,sha256=RgaqpmNAOqas90CKwStVijSdrMvMH2QOoqbVr4W0rkY,5756
13
+ deepinv/loss/score.py,sha256=C4D9O3JwAYQWJkCxzcWK1nU8Zc-kXdHByJ62dEI8iZ4,1222
14
+ deepinv/loss/sup.py,sha256=T6dtt1fqqlNSho0dhPM7mMEGPoXbWlK86OlGdLTYUUg,1063
15
+ deepinv/loss/sure.py,sha256=1rSyrjwNcnEmM-nrrzG_nxDh8rUGb80xavlAB9PP8hI,11893
16
+ deepinv/loss/tv.py,sha256=7DEALAF4pIC1YQ2ztLOqJOqyobDS5gI5ykp4GS66SaA,1331
17
+ deepinv/models/GSPnP.py,sha256=fquGK8wwcSO3FBTPp9bIdlUzXQgltW1UZFen4vYkttk,4377
18
+ deepinv/models/PDNet.py,sha256=7mggqLQloPKdXjs936EjV7ewPXyQhyfh1cLsKlRHAAc,4217
19
+ deepinv/models/__init__.py,sha256=lEk5q4pE3Ud1BtQ_JqGKuPnDmcdmNXQhb391mfVfIuU,561
20
+ deepinv/models/ae.py,sha256=vySCaddG0HIWWDOj35rVnYt61o3xR7ynNvB0rjxI7RY,1253
21
+ deepinv/models/artifactremoval.py,sha256=YMt8Vz4fgDlicT8DJffHN5D9kcjJ7WQKytvWrmX6Cnw,2081
22
+ deepinv/models/bm3d.py,sha256=-7HIsb1NSc7GHMyB4Tpr8n39RglbSErp0jGyL_ww7BQ,1488
23
+ deepinv/models/diffunet.py,sha256=4tpVcxBogsQh5tZ-TVu8IULeQHCuA2KY-_ZNcXOvjBo,35458
24
+ deepinv/models/dip.py,sha256=teNyf9lvqt2D8WlCFZOCrVBv4SUmBHQednYH8gZKlLk,6731
25
+ deepinv/models/dncnn.py,sha256=5z1AuA1hsa52KjVWvIDUQSZNYbPgHTduHbULuCsEchY,5304
26
+ deepinv/models/drunet.py,sha256=oQtwZuxA-3hDm1iQ2qjWHkm9m2V7FUgF9LhhYbR-7_8,20515
27
+ deepinv/models/equivariant.py,sha256=EmjW7tSdK0fBq8BAEHFgYPaMfbFoq9IortEhyQcXwZ4,5223
28
+ deepinv/models/median.py,sha256=DToa-JD6tl_bzUBevIEjfgezsUO0_dHrkfpS90X_qkc,1838
29
+ deepinv/models/scunet.py,sha256=vesyn2QC-i0TPMSOo4wZVclFpK5jcrPmVMXGOJ80Mg8,15946
30
+ deepinv/models/swinir.py,sha256=aOLQB12ESV329Hxil7biL2TgAGpBwfDTnIQuVbkfcBI,41374
31
+ deepinv/models/tgv.py,sha256=TifN7rRbXvkeQZlzT7eTMnErrfKDjyq2w2txZ80SETM,8935
32
+ deepinv/models/tv.py,sha256=iFnunUbj8ZMyeUdQBSXG_vZ4r_f9NeehPxzEsx92w-c,4870
33
+ deepinv/models/unet.py,sha256=VHCv-90KY8-jLC9Vz3YM4UejqwaLbVL3IwwGmfJUkTA,10671
34
+ deepinv/models/utils.py,sha256=fWes8C8sNx-AWfKG6hcG2Ud1YB-nns2io7DTl9H-RU0,425
35
+ deepinv/models/wavdict.py,sha256=u3sbWXuTKxrSNxhP4UIRjD9HeLT832gpDFQV72KYMSs,9035
36
+ deepinv/optim/__init__.py,sha256=FzUNmtRF6ikAEQX6nBjw4u9loko50VhcXkqwJ9D8BH8,289
37
+ deepinv/optim/data_fidelity.py,sha256=bx7n7jAkSw2waAJzd47m6SaRaWtteA3P72XOxn0oi9Q,23428
38
+ deepinv/optim/fixed_point.py,sha256=h6enJ02a1Jbo-A9cmQge2snmr6SZr9yUSY_38cS4-uQ,12348
39
+ deepinv/optim/optimizers.py,sha256=pDsgYK2taef7RpHydUdvPyUgnfsYypK-Hiuix-dUgHo,27210
40
+ deepinv/optim/prior.py,sha256=GbSZATTydZUriSXK2YvwEzARkRqd7k5KTnT6yfhnzDo,10536
41
+ deepinv/optim/utils.py,sha256=0iH5S3QXaVVoQdjzlcnmv1OA14W4VhbNXQE_35FcBR8,2413
42
+ deepinv/optim/optim_iterators/__init__.py,sha256=VL4nhMnqlvZKRFoJ6JwGQETaNbgwDOOf4j9KPc93TVQ,343
43
+ deepinv/optim/optim_iterators/admm.py,sha256=bz3_RoCHIt0xWEdyenHnjkYaV5-OKa6Cv_LNzBfnPRM,4626
44
+ deepinv/optim/optim_iterators/drs.py,sha256=mw3QbNH1CXpJBqGtoiburWKxIo_xZv7E-6CVAuv6QqM,4479
45
+ deepinv/optim/optim_iterators/gradient_descent.py,sha256=GFY_a-fnOoRP6QdoKz0NShAkVNOZfmJL7XVI4gf9sII,3346
46
+ deepinv/optim/optim_iterators/hqs.py,sha256=BukQnb7uyDYgTmMMo1uIMGa3C1Qfu1wMWutbHQCtHtc,2502
47
+ deepinv/optim/optim_iterators/optim_iterator.py,sha256=R-vWQ-xmAGWCnwtnSU-UaKlT0KemeGeoFYYRjxvL_Fw,6215
48
+ deepinv/optim/optim_iterators/pgd.py,sha256=S_bqjLkhk5OomtPADkaOUUApotV867WJMYgDWUvTtmI,3059
49
+ deepinv/optim/optim_iterators/primal_dual_CP.py,sha256=VqYsEkfKtrkTe5QCGI5v80osLANWqoMWzsdVYOP0mo8,6482
50
+ deepinv/optim/optim_iterators/utils.py,sha256=DV0v8tIhMg61x3fCBmr_-fCrf6o1x0koIBgF17wjygk,652
51
+ deepinv/physics/__init__.py,sha256=lFziroY4avj_oTc-VlmplZGsG9fJDsucbjP6KMRJB_U,576
52
+ deepinv/physics/blur.py,sha256=oQtAm6lFfe8D2aYZdYGLJDJOsh7V4HEGqM3QkCBtqzs,18580
53
+ deepinv/physics/compressed_sensing.py,sha256=xyn-y0wTj45I4ZlQdwurMLgon9xw4ShcyJOPS2_rTPc,6734
54
+ deepinv/physics/forward.py,sha256=MOSZllqCBNZgAnAecMYbtKtkwKT_KRq1logMtjqdorM,19275
55
+ deepinv/physics/haze.py,sha256=IAPPgvqfwcMxPHvD6bEZkrQDDsX8fqn_smuJkZLh_o8,2041
56
+ deepinv/physics/inpainting.py,sha256=JZZbpC6T1u33PLiSTWm1CxkK0YlN4mN-5yvvFUJTAz4,1886
57
+ deepinv/physics/lidar.py,sha256=NtxIVRwVOhPCZzR5dz5W569KULE4yThdOTVqHK_7YHo,4033
58
+ deepinv/physics/mri.py,sha256=Tsgf3xRXNge8p1Vf39INab8MsoXtNZskUtW2rMo4pMA,10595
59
+ deepinv/physics/noise.py,sha256=AKiJwyww79Kp-ifSGqZGMZAkWpC7aRIh6Od4CRhT_Rk,5128
60
+ deepinv/physics/range.py,sha256=uRyVkdTRmj4fGps95v5ANXfFt7B5G3ZQzRUInNJa118,1377
61
+ deepinv/physics/remote_sensing.py,sha256=io8nMunzTfpVqbi6R5m1SwJH90l-lQWKjsVUq_cSWc0,4211
62
+ deepinv/physics/singlepixel.py,sha256=UwYRbADZBcr9Pz2KG5MHNAvIS1P94l18u0lKv2sGqbE,6951
63
+ deepinv/physics/tomography.py,sha256=iac16rpYtkMdztRGzreJlVduNVy3NOTj0PQIPAvmNoM,10688
64
+ deepinv/sampling/__init__.py,sha256=zjl6vOe5WbYOv4qN9SoOPAohWAQOhEWVasbgORuwXfQ,106
65
+ deepinv/sampling/diffusion.py,sha256=ojpZpA_1l-OF4iAUeHpCrdFuPEyocoqLqEPzvbHI7Ps,25223
66
+ deepinv/sampling/langevin.py,sha256=Jl-PdU45LXN3pZsKODjjQL-YU0QxAf5tL6VkycVps_M,19732
67
+ deepinv/sampling/utils.py,sha256=XngAkD5bxOYWjWqWJ4j0USQsF4SSgfT7ObS5c3iGivI,796
68
+ deepinv/tests/conftest.py,sha256=rRDUpRZJhB-DoPaLzBfE_n9-6epwWFZhCBOTBcx6MUQ,613
69
+ deepinv/tests/test_loss.py,sha256=HeK3X2qHTQXerHub5U4-7yznf6LSY7cxLnxETAcBN-U,7894
70
+ deepinv/tests/test_loss_train.py,sha256=OhWwygE3T_vDLFQj8RSH8rWnEkZOWJqUkJU7hL4lmmU,5408
71
+ deepinv/tests/test_models.py,sha256=ilOifkKgss-QRX7LoQ9KoCyQ7T84dh8CAzeP5dsn8_8,12277
72
+ deepinv/tests/test_optim.py,sha256=O-B3pab_j0YXmYecOd1yheUdn9y6QozB-ndyDD11B3U,22137
73
+ deepinv/tests/test_physics.py,sha256=zwnARYFmT8SBSPuuqI-f5slOpHREvxc6_a1bFuiX-fY,10152
74
+ deepinv/tests/test_sampling.py,sha256=uX2BGV3li8kq0fbxGuSD6ieyhYtTwBntFJ1I-jaNk1Y,4199
75
+ deepinv/tests/test_unfolded.py,sha256=gQ_gyeGiTYTpAeL4YnAdxmSX0KUs0JkJ4_Y9XdT_2qI,5670
76
+ deepinv/tests/test_utils.py,sha256=tKk2MzIdsfIrOLynmUH1CkxCR5H66kcV872jNunm0Bw,1701
77
+ deepinv/tests/dummy_datasets/datasets.py,sha256=h8ZAOQ8E4IViVoPS7xrXdtjtvIpYttLcpSPtSMAQvq4,1730
78
+ deepinv/transform/__init__.py,sha256=WpWFMIsmWgldzkgCe2bGLEowNJnuF6oRl5sASxrK4b0,52
79
+ deepinv/transform/rotate.py,sha256=O8JKmXJxAJb4JLqavetINE7PxRrrvwauiPYFCbLQ31s,1238
80
+ deepinv/transform/shift.py,sha256=4yCt_HiE9vwLwsav2wGYonc7F7DhYsuaPTvJDHSxBFU,765
81
+ deepinv/unfolded/__init__.py,sha256=7qI4_I3JMqM4NJx703-FiPoU3DiXX3VcnsGVH1mR-_4,102
82
+ deepinv/unfolded/deep_equilibrium.py,sha256=1uS3iTs_wbD0wuQ7tGu29ELW5K3b9O_-OcYRtHZAr-E,8074
83
+ deepinv/unfolded/unfolded.py,sha256=X1mTbRiqWWAS5LBedotqP13-G47FtY4Hgv8lHIorxo4,4963
84
+ deepinv/utils/__init__.py,sha256=PTtDkE1oys0E83IQlkmy6oK75Az-UqiIq36zq1Tn1mM,555
85
+ deepinv/utils/demo.py,sha256=OmTtqwVM5ckzAupqEZLS78xX2YpJtFyr4eIdkXi6QbU,5386
86
+ deepinv/utils/logger.py,sha256=B5wS8Q2j6taMUNREJjOsqEvd_Igk6xtnHjTtCOq8B2A,2459
87
+ deepinv/utils/metric.py,sha256=qI28-1LppWHVeXM1m4LAEiuBoqieXouVT_qeNqJQZsI,2484
88
+ deepinv/utils/nn.py,sha256=24xpuTlbw-CO-e0LHQUSo1xjVoIh_TQyF9rlQs1gJkc,6529
89
+ deepinv/utils/optimization.py,sha256=VxBw3vkNDVGMXC_N2Cg65RJNV1rS1FcjnP43_PxIrn4,3702
90
+ deepinv/utils/parameters.py,sha256=-CoD7vneBFbgQnqL2ccdwvs4nOK6OfvFLQpoH4QPdk0,1360
91
+ deepinv/utils/phantoms.py,sha256=9bxl3Zsd7AQAWP3JKFy4R7VkC_BSeahbdNG-K3Uxhmc,4003
92
+ deepinv/utils/plotting.py,sha256=Tt0yBu_eF1A5gMc9wO4ABrC7-E5rF6Bs09zNm-pC-tk,10560
93
+ deepinv-0.1.0.dev0.dist-info/LICENSE,sha256=Up-MehfxmWSgdq040Ig_uPlexDvIiFdvgg1RQeEn8ME,1494
94
+ deepinv-0.1.0.dev0.dist-info/METADATA,sha256=szfX8-Ga9YxDW1hBRZXFwiHLG2KSjOzw6BoXrDh4_-w,7176
95
+ deepinv-0.1.0.dev0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
96
+ deepinv-0.1.0.dev0.dist-info/top_level.txt,sha256=ZwPaH-khG7KX_p-hgNQEL4gY4pz5UzV76VbOY7zVfSI,8
97
+ deepinv-0.1.0.dev0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: bdist_wheel (0.42.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ deepinv