deepinv 0.1.0.dev0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/PKG-INFO +33 -33
  2. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/README.rst +25 -25
  3. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/__about__.py +3 -5
  4. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/__init__.py +3 -12
  5. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/datasets/__init__.py +1 -0
  6. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/datasets/datagenerator.py +23 -20
  7. deepinv-0.2.0/deepinv/datasets/patch_dataset.py +29 -0
  8. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/loss/__init__.py +3 -1
  9. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/loss/ei.py +2 -1
  10. deepinv-0.2.0/deepinv/loss/loss.py +29 -0
  11. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/loss/mc.py +2 -2
  12. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/loss/measplit.py +6 -5
  13. deepinv-0.2.0/deepinv/loss/metric.py +242 -0
  14. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/loss/moi.py +2 -2
  15. deepinv-0.2.0/deepinv/loss/r2r.py +70 -0
  16. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/loss/regularisers.py +18 -16
  17. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/loss/score.py +2 -2
  18. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/loss/sup.py +2 -1
  19. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/loss/sure.py +21 -52
  20. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/loss/tv.py +14 -6
  21. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/GSPnP.py +21 -19
  22. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/PDNet.py +32 -26
  23. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/__init__.py +5 -3
  24. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/artifactremoval.py +1 -1
  25. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/bm3d.py +14 -4
  26. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/diffunet.py +3 -3
  27. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/dip.py +3 -3
  28. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/drunet.py +8 -66
  29. deepinv-0.2.0/deepinv/models/epll.py +65 -0
  30. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/equivariant.py +1 -1
  31. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/median.py +1 -0
  32. deepinv-0.2.0/deepinv/models/restormer.py +739 -0
  33. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/scunet.py +1 -1
  34. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/swinir.py +10 -10
  35. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/tgv.py +78 -79
  36. deepinv-0.2.0/deepinv/models/tv.py +164 -0
  37. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/unet.py +25 -4
  38. deepinv-0.2.0/deepinv/models/utils.py +75 -0
  39. deepinv-0.2.0/deepinv/models/wavdict.py +373 -0
  40. deepinv-0.2.0/deepinv/optim/__init__.py +27 -0
  41. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/data_fidelity.py +194 -112
  42. deepinv-0.2.0/deepinv/optim/dpir.py +55 -0
  43. deepinv-0.2.0/deepinv/optim/epll.py +195 -0
  44. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/fixed_point.py +36 -50
  45. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/optim_iterators/__init__.py +2 -3
  46. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/optim_iterators/admm.py +11 -9
  47. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/optim_iterators/drs.py +10 -8
  48. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/optim_iterators/gradient_descent.py +4 -4
  49. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/optim_iterators/hqs.py +12 -10
  50. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/optim_iterators/optim_iterator.py +4 -4
  51. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/optim_iterators/pgd.py +15 -17
  52. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/optim_iterators/primal_dual_CP.py +18 -17
  53. deepinv-0.2.0/deepinv/optim/optim_iterators/spectral_methods.py +119 -0
  54. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/optimizers.py +73 -51
  55. deepinv-0.2.0/deepinv/optim/prior.py +615 -0
  56. deepinv-0.2.0/deepinv/optim/utils.py +372 -0
  57. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/physics/__init__.py +11 -2
  58. deepinv-0.2.0/deepinv/physics/blur.py +574 -0
  59. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/physics/compressed_sensing.py +29 -7
  60. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/physics/forward.py +305 -75
  61. deepinv-0.2.0/deepinv/physics/functional/__init__.py +21 -0
  62. deepinv-0.2.0/deepinv/physics/functional/convolution.py +293 -0
  63. deepinv-0.2.0/deepinv/physics/functional/downsampling.py +5 -0
  64. deepinv-0.2.0/deepinv/physics/functional/hist.py +228 -0
  65. deepinv-0.2.0/deepinv/physics/functional/interp.py +193 -0
  66. deepinv-0.2.0/deepinv/physics/functional/multiplier.py +40 -0
  67. deepinv-0.2.0/deepinv/physics/functional/product_convolution.py +73 -0
  68. deepinv-0.1.0.dev0/deepinv/physics/tomography.py → deepinv-0.2.0/deepinv/physics/functional/radon.py +24 -61
  69. deepinv-0.2.0/deepinv/physics/generator/__init__.py +9 -0
  70. deepinv-0.2.0/deepinv/physics/generator/base.py +129 -0
  71. deepinv-0.2.0/deepinv/physics/generator/blur.py +633 -0
  72. deepinv-0.2.0/deepinv/physics/generator/mri.py +73 -0
  73. deepinv-0.2.0/deepinv/physics/generator/noise.py +57 -0
  74. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/physics/haze.py +3 -3
  75. deepinv-0.2.0/deepinv/physics/inpainting.py +88 -0
  76. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/physics/lidar.py +4 -4
  77. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/physics/mri.py +26 -71
  78. deepinv-0.2.0/deepinv/physics/noise.py +301 -0
  79. deepinv-0.2.0/deepinv/physics/phase_retrieval.py +136 -0
  80. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/physics/range.py +16 -1
  81. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/physics/remote_sensing.py +37 -27
  82. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/physics/singlepixel.py +21 -4
  83. deepinv-0.2.0/deepinv/physics/tomography.py +102 -0
  84. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/sampling/__init__.py +1 -1
  85. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/sampling/diffusion.py +51 -1
  86. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/sampling/langevin.py +4 -4
  87. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/tests/conftest.py +3 -1
  88. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/tests/dummy_datasets/datasets.py +3 -0
  89. deepinv-0.2.0/deepinv/tests/test_generators.py +345 -0
  90. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/tests/test_loss.py +7 -9
  91. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/tests/test_loss_train.py +26 -21
  92. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/tests/test_models.py +198 -13
  93. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/tests/test_optim.py +156 -42
  94. deepinv-0.2.0/deepinv/tests/test_physics.py +549 -0
  95. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/tests/test_sampling.py +1 -1
  96. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/tests/test_unfolded.py +4 -4
  97. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/tests/test_utils.py +39 -5
  98. deepinv-0.2.0/deepinv/training/__init__.py +2 -0
  99. deepinv-0.2.0/deepinv/training/testing.py +212 -0
  100. deepinv-0.2.0/deepinv/training/trainer.py +756 -0
  101. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/transform/__init__.py +1 -0
  102. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/transform/rotate.py +8 -2
  103. deepinv-0.2.0/deepinv/transform/scale.py +74 -0
  104. deepinv-0.2.0/deepinv/transform/shift.py +46 -0
  105. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/unfolded/deep_equilibrium.py +1 -1
  106. deepinv-0.2.0/deepinv/unfolded/unfolded.py +162 -0
  107. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/utils/__init__.py +3 -1
  108. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/utils/demo.py +33 -3
  109. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/utils/logger.py +20 -7
  110. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/utils/metric.py +28 -13
  111. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/utils/nn.py +75 -20
  112. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/utils/optimization.py +1 -0
  113. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/utils/parameters.py +4 -21
  114. deepinv-0.2.0/deepinv/utils/patch_extractor.py +58 -0
  115. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/utils/plotting.py +187 -21
  116. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv.egg-info/PKG-INFO +33 -33
  117. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv.egg-info/SOURCES.txt +28 -1
  118. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv.egg-info/requires.txt +7 -1
  119. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/pyproject.toml +12 -10
  120. deepinv-0.1.0.dev0/deepinv/loss/metric.py +0 -125
  121. deepinv-0.1.0.dev0/deepinv/models/tv.py +0 -146
  122. deepinv-0.1.0.dev0/deepinv/models/utils.py +0 -22
  123. deepinv-0.1.0.dev0/deepinv/models/wavdict.py +0 -231
  124. deepinv-0.1.0.dev0/deepinv/optim/__init__.py +0 -5
  125. deepinv-0.1.0.dev0/deepinv/optim/prior.py +0 -288
  126. deepinv-0.1.0.dev0/deepinv/optim/utils.py +0 -80
  127. deepinv-0.1.0.dev0/deepinv/physics/blur.py +0 -544
  128. deepinv-0.1.0.dev0/deepinv/physics/inpainting.py +0 -48
  129. deepinv-0.1.0.dev0/deepinv/physics/noise.py +0 -180
  130. deepinv-0.1.0.dev0/deepinv/tests/test_physics.py +0 -316
  131. deepinv-0.1.0.dev0/deepinv/training_utils.py +0 -529
  132. deepinv-0.1.0.dev0/deepinv/transform/shift.py +0 -26
  133. deepinv-0.1.0.dev0/deepinv/unfolded/unfolded.py +0 -87
  134. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/LICENSE +0 -0
  135. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/ae.py +0 -0
  136. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/models/dncnn.py +0 -0
  137. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/optim/optim_iterators/utils.py +0 -0
  138. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/sampling/utils.py +0 -0
  139. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/unfolded/__init__.py +0 -0
  140. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv/utils/phantoms.py +0 -0
  141. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv.egg-info/dependency_links.txt +0 -0
  142. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/deepinv.egg-info/top_level.txt +0 -0
  143. {deepinv-0.1.0.dev0 → deepinv-0.2.0}/setup.cfg +0 -0
@@ -1,8 +1,7 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: deepinv
3
- Version: 0.1.0.dev0
3
+ Version: 0.2.0
4
4
  Summary: Pytorch library for solving inverse problems with deep learning
5
- Author: Matthieu Terris, Samuel Hurault, Dongdong Chen
6
5
  Author-email: Julian Tachella <tachellajulian@gmail.com>
7
6
  License: BSD 3-Clause
8
7
  Project-URL: Homepage, https://deepinv.github.io/
@@ -15,11 +14,6 @@ Classifier: Intended Audience :: Science/Research
15
14
  Classifier: License :: OSI Approved :: BSD License
16
15
  Classifier: Operating System :: OS Independent
17
16
  Classifier: Programming Language :: Python :: 3
18
- Classifier: Programming Language :: Python :: 3.8
19
- Classifier: Programming Language :: Python :: 3.9
20
- Classifier: Programming Language :: Python :: 3.10
21
- Classifier: Programming Language :: Python :: 3.11
22
- Classifier: Programming Language :: Python :: 3.12
23
17
  Classifier: Topic :: Utilities
24
18
  Classifier: Topic :: Scientific/Engineering
25
19
  Classifier: Topic :: Software Development :: Libraries
@@ -29,11 +23,11 @@ License-File: LICENSE
29
23
  Requires-Dist: numpy
30
24
  Requires-Dist: matplotlib
31
25
  Requires-Dist: hdf5storage
26
+ Requires-Dist: tqdm
32
27
  Requires-Dist: torch
33
28
  Requires-Dist: torchvision
34
29
  Requires-Dist: einops
35
30
  Requires-Dist: wandb
36
- Requires-Dist: fastmri
37
31
  Provides-Extra: test
38
32
  Requires-Dist: pytest; extra == "test"
39
33
  Requires-Dist: pytest-cov; extra == "test"
@@ -43,9 +37,15 @@ Requires-Dist: sphinx; extra == "doc"
43
37
  Requires-Dist: sphinx_gallery; extra == "doc"
44
38
  Requires-Dist: sphinx_rtd_theme; extra == "doc"
45
39
  Requires-Dist: sphinxemoji; extra == "doc"
40
+ Requires-Dist: sphinx_copybutton; extra == "doc"
41
+ Requires-Dist: sphinx_autoapi; extra == "doc"
46
42
  Provides-Extra: denoisers
47
43
  Requires-Dist: bm3d; extra == "denoisers"
48
44
  Requires-Dist: timm; extra == "denoisers"
45
+ Requires-Dist: PyWavelets; extra == "denoisers"
46
+ Requires-Dist: ptwt; extra == "denoisers"
47
+ Requires-Dist: FrEIA; extra == "denoisers"
48
+ Requires-Dist: pyiqa; extra == "denoisers"
49
49
 
50
50
  .. image:: https://github.com/deepinv/deepinv/raw/main/docs/source/figures/deepinv_logolarge.png
51
51
  :width: 500px
@@ -58,18 +58,18 @@ Requires-Dist: timm; extra == "denoisers"
58
58
 
59
59
  Introduction
60
60
  ------------
61
- Deep Inverse is an open-source pytorch library for solving imaging inverse problems using deep learning. The goal of ``deepinv`` is to accelerate the development of deep learning based methods for imaging inverse problems, by combining popular learning-based reconstruction approaches in a common and simplified framework, standarizing forward imaging models and simplifying the creation of imaging datasets.
61
+ Deep Inverse is an open-source pytorch library for solving imaging inverse problems using deep learning. The goal of ``deepinv`` is to accelerate the development of deep learning based methods for imaging inverse problems, by combining popular learning-based reconstruction approaches in a common and simplified framework, standardizing forward imaging models and simplifying the creation of imaging datasets.
62
62
 
63
- With ``deepinv`` you can:
63
+ ``deepinv`` features
64
64
 
65
65
 
66
- * Large collection of `predefined imaging operators <https://deepinv.github.io/deepinv/deepinv.physics.html>`_ (MRI, CT, deblurring, inpainting, etc.)
67
- * `Training losses <https://deepinv.github.io/deepinv/deepinv.loss.html>`_ for inverse problems (self-supervised learning, regularization, etc.).
68
- * Many `pretrained deep denoisers <https://deepinv.github.io/deepinv/deepinv.models.html>`_ which can be used for `plug-and-play restoration <https://deepinv.github.io/deepinv/deepinv.pnp.html>`_.
69
- * Framework for `building datasets <https://deepinv.github.io/deepinv/deepinv.datasets.html>`_ for inverse problems.
70
- * Easy-to-build `unfolded architectures <https://deepinv.github.io/deepinv/deepinv.unfolded.html>`_ (ADMM, forward-backward, deep equilibrium, etc.).
71
- * `Sampling algorithms <https://deepinv.github.io/deepinv/deepinv.sampling.html>`_ for uncertainty quantification (Langevin, diffusion, etc.).
72
- * A large number of well-explained `examples <https://deepinv.github.io/deepinv/auto_examples/index.html>`_, from basics to state-of-the-art methods.
66
+ * A large collection of `predefined imaging operators <https://deepinv.github.io/deepinv/deepinv.physics.html>`_ (MRI, CT, deblurring, inpainting, etc.)
67
+ * `Training losses <https://deepinv.github.io/deepinv/deepinv.loss.html>`_ for inverse problems (self-supervised learning, regularization, etc.)
68
+ * Many `pretrained deep denoisers <https://deepinv.github.io/deepinv/deepinv.models.html>`_ which can be used for `plug-and-play restoration <https://deepinv.github.io/deepinv/deepinv.pnp.html>`_
69
+ * A framework for `building datasets <https://deepinv.github.io/deepinv/deepinv.datasets.html>`_ for inverse problems
70
+ * Easy-to-build `unfolded architectures <https://deepinv.github.io/deepinv/deepinv.unfolded.html>`_ (ADMM, forward-backward, deep equilibrium, etc.)
71
+ * `Sampling algorithms <https://deepinv.github.io/deepinv/deepinv.sampling.html>`_ for uncertainty quantification (Langevin, diffusion, etc.)
72
+ * A large number of well-explained `examples <https://deepinv.github.io/deepinv/auto_examples/index.html>`_, from basics to state-of-the-art methods
73
73
 
74
74
  .. image:: https://github.com/deepinv/deepinv/raw/main/docs/source/figures/deepinv_schematic.png
75
75
  :width: 1000px
@@ -103,22 +103,22 @@ Try out the following plug-and-play image inpainting example:
103
103
 
104
104
  .. code-block:: python
105
105
 
106
- import deepinv as dinv
107
- from deepinv.utils import load_url_image
108
-
109
- url = ("https://huggingface.co/datasets/deepinv/images/resolve/main/cameraman.png?download=true")
110
- x = load_url_image(url=url, img_size=512, grayscale=True, device='cpu')
111
-
112
- physics = dinv.physics.Inpainting((1, 512, 512), mask = 0.5, \
113
- noise_model=dinv.physics.GaussianNoise(sigma=0.01))
114
-
115
- data_fidelity = dinv.optim.data_fidelity.L2()
116
- prior = dinv.optim.prior.PnP(denoiser=dinv.models.MedianFilter())
117
- model = dinv.optim.optim_builder(iteration="HQS", prior=prior, data_fidelity=data_fidelity, \
118
- params_algo={"stepsize": 1.0, "g_param": 0.1, "lambda": 2.})
119
- y = physics(x)
120
- x_hat = model(y, physics)
121
- dinv.utils.plot([x, y, x_hat], ["signal", "measurement", "estimate"], rescale_mode='clip')
106
+ >>> import deepinv as dinv
107
+ >>> from deepinv.utils import load_url_image
108
+ >>>
109
+ >>> url = ("https://huggingface.co/datasets/deepinv/images/resolve/main/cameraman.png?download=true")
110
+ >>> x = load_url_image(url=url, img_size=512, grayscale=True, device='cpu')
111
+ >>>
112
+ >>> physics = dinv.physics.Inpainting((1, 512, 512), mask = 0.5, \
113
+ >>> noise_model=dinv.physics.GaussianNoise(sigma=0.01))
114
+ >>>
115
+ >>> data_fidelity = dinv.optim.data_fidelity.L2()
116
+ >>> prior = dinv.optim.prior.PnP(denoiser=dinv.models.MedianFilter())
117
+ >>> model = dinv.optim.optim_builder(iteration="HQS", prior=prior, data_fidelity=data_fidelity, \
118
+ >>> params_algo={"stepsize": 1.0, "g_param": 0.1})
119
+ >>> y = physics(x)
120
+ >>> x_hat = model(y, physics)
121
+ >>> dinv.utils.plot([x, y, x_hat], ["signal", "measurement", "estimate"], rescale_mode='clip')
122
122
 
123
123
 
124
124
  Also try out `one of the examples <https://deepinv.github.io/deepinv/auto_examples/index.html>`_ to get started.
@@ -9,18 +9,18 @@
9
9
 
10
10
  Introduction
11
11
  ------------
12
- Deep Inverse is an open-source pytorch library for solving imaging inverse problems using deep learning. The goal of ``deepinv`` is to accelerate the development of deep learning based methods for imaging inverse problems, by combining popular learning-based reconstruction approaches in a common and simplified framework, standarizing forward imaging models and simplifying the creation of imaging datasets.
12
+ Deep Inverse is an open-source pytorch library for solving imaging inverse problems using deep learning. The goal of ``deepinv`` is to accelerate the development of deep learning based methods for imaging inverse problems, by combining popular learning-based reconstruction approaches in a common and simplified framework, standardizing forward imaging models and simplifying the creation of imaging datasets.
13
13
 
14
- With ``deepinv`` you can:
14
+ ``deepinv`` features
15
15
 
16
16
 
17
- * Large collection of `predefined imaging operators <https://deepinv.github.io/deepinv/deepinv.physics.html>`_ (MRI, CT, deblurring, inpainting, etc.)
18
- * `Training losses <https://deepinv.github.io/deepinv/deepinv.loss.html>`_ for inverse problems (self-supervised learning, regularization, etc.).
19
- * Many `pretrained deep denoisers <https://deepinv.github.io/deepinv/deepinv.models.html>`_ which can be used for `plug-and-play restoration <https://deepinv.github.io/deepinv/deepinv.pnp.html>`_.
20
- * Framework for `building datasets <https://deepinv.github.io/deepinv/deepinv.datasets.html>`_ for inverse problems.
21
- * Easy-to-build `unfolded architectures <https://deepinv.github.io/deepinv/deepinv.unfolded.html>`_ (ADMM, forward-backward, deep equilibrium, etc.).
22
- * `Sampling algorithms <https://deepinv.github.io/deepinv/deepinv.sampling.html>`_ for uncertainty quantification (Langevin, diffusion, etc.).
23
- * A large number of well-explained `examples <https://deepinv.github.io/deepinv/auto_examples/index.html>`_, from basics to state-of-the-art methods.
17
+ * A large collection of `predefined imaging operators <https://deepinv.github.io/deepinv/deepinv.physics.html>`_ (MRI, CT, deblurring, inpainting, etc.)
18
+ * `Training losses <https://deepinv.github.io/deepinv/deepinv.loss.html>`_ for inverse problems (self-supervised learning, regularization, etc.)
19
+ * Many `pretrained deep denoisers <https://deepinv.github.io/deepinv/deepinv.models.html>`_ which can be used for `plug-and-play restoration <https://deepinv.github.io/deepinv/deepinv.pnp.html>`_
20
+ * A framework for `building datasets <https://deepinv.github.io/deepinv/deepinv.datasets.html>`_ for inverse problems
21
+ * Easy-to-build `unfolded architectures <https://deepinv.github.io/deepinv/deepinv.unfolded.html>`_ (ADMM, forward-backward, deep equilibrium, etc.)
22
+ * `Sampling algorithms <https://deepinv.github.io/deepinv/deepinv.sampling.html>`_ for uncertainty quantification (Langevin, diffusion, etc.)
23
+ * A large number of well-explained `examples <https://deepinv.github.io/deepinv/auto_examples/index.html>`_, from basics to state-of-the-art methods
24
24
 
25
25
  .. image:: https://github.com/deepinv/deepinv/raw/main/docs/source/figures/deepinv_schematic.png
26
26
  :width: 1000px
@@ -54,22 +54,22 @@ Try out the following plug-and-play image inpainting example:
54
54
 
55
55
  .. code-block:: python
56
56
 
57
- import deepinv as dinv
58
- from deepinv.utils import load_url_image
59
-
60
- url = ("https://huggingface.co/datasets/deepinv/images/resolve/main/cameraman.png?download=true")
61
- x = load_url_image(url=url, img_size=512, grayscale=True, device='cpu')
62
-
63
- physics = dinv.physics.Inpainting((1, 512, 512), mask = 0.5, \
64
- noise_model=dinv.physics.GaussianNoise(sigma=0.01))
65
-
66
- data_fidelity = dinv.optim.data_fidelity.L2()
67
- prior = dinv.optim.prior.PnP(denoiser=dinv.models.MedianFilter())
68
- model = dinv.optim.optim_builder(iteration="HQS", prior=prior, data_fidelity=data_fidelity, \
69
- params_algo={"stepsize": 1.0, "g_param": 0.1, "lambda": 2.})
70
- y = physics(x)
71
- x_hat = model(y, physics)
72
- dinv.utils.plot([x, y, x_hat], ["signal", "measurement", "estimate"], rescale_mode='clip')
57
+ >>> import deepinv as dinv
58
+ >>> from deepinv.utils import load_url_image
59
+ >>>
60
+ >>> url = ("https://huggingface.co/datasets/deepinv/images/resolve/main/cameraman.png?download=true")
61
+ >>> x = load_url_image(url=url, img_size=512, grayscale=True, device='cpu')
62
+ >>>
63
+ >>> physics = dinv.physics.Inpainting((1, 512, 512), mask = 0.5, \
64
+ >>> noise_model=dinv.physics.GaussianNoise(sigma=0.01))
65
+ >>>
66
+ >>> data_fidelity = dinv.optim.data_fidelity.L2()
67
+ >>> prior = dinv.optim.prior.PnP(denoiser=dinv.models.MedianFilter())
68
+ >>> model = dinv.optim.optim_builder(iteration="HQS", prior=prior, data_fidelity=data_fidelity, \
69
+ >>> params_algo={"stepsize": 1.0, "g_param": 0.1})
70
+ >>> y = physics(x)
71
+ >>> x_hat = model(y, physics)
72
+ >>> dinv.utils.plot([x, y, x_hat], ["signal", "measurement", "estimate"], rescale_mode='clip')
73
73
 
74
74
 
75
75
  Also try out `one of the examples <https://deepinv.github.io/deepinv/auto_examples/index.html>`_ to get started.
@@ -4,14 +4,12 @@ __all__ = [
4
4
  "__url__",
5
5
  "__version__",
6
6
  "__author__",
7
- "__email__",
8
7
  "__license__",
9
8
  ]
10
9
 
11
10
  __title__ = "deepinv"
12
11
  __summary__ = "Deep Learning for Inverse Problems Library for PyTorch"
13
- __url__ = "https://github.com/edongdongchen/deepinv"
14
- __version__ = "0.0.1"
15
- __author__ = "Dongdong Chen, Julian Tachella"
16
- __email__ = "echendongdong@gmail.com"
12
+ __version__ = "0.2.0"
13
+ __author__ = "Julian Tachella, Samuel Hurault, Matthieu Terris, Dongdong Chen"
17
14
  __license__ = "BSD 3-Clause Clear"
15
+ __url__ = "https://deepinv.github.io/"
@@ -7,7 +7,6 @@ __all__ = [
7
7
  "__url__",
8
8
  "__version__",
9
9
  "__author__",
10
- "__email__",
11
10
  "__license__",
12
11
  ]
13
12
 
@@ -55,17 +54,9 @@ from deepinv import unfolded
55
54
 
56
55
  __all__ += ["unfolded"]
57
56
 
58
- from deepinv.training_utils import train, test
57
+ from deepinv.training import train, test, Trainer
58
+
59
+ __all__ += ["training"]
59
60
 
60
61
  # GLOBAL PROPERTY
61
62
  dtype = torch.float
62
-
63
- # if torch.cuda.is_available():
64
- # try:
65
- # free_gpu_id = get_freer_gpu()
66
- # device = torch.device(f"cuda:{free_gpu_id}")
67
- # except:
68
- # device = torch.device("cuda")
69
- # print("unable to get GPU info")
70
- # else:
71
- # device = "cpu"
@@ -1 +1,2 @@
1
1
  from .datagenerator import generate_dataset, HDF5Dataset
2
+ from .patch_dataset import PatchDataset
@@ -56,6 +56,8 @@ def generate_dataset(
56
56
  batch_size=4,
57
57
  num_workers=0,
58
58
  supervised=True,
59
+ verbose=True,
60
+ show_progress_bar=False,
59
61
  ):
60
62
  r"""
61
63
  Generates dataset of signal/measurement pairs from base dataset.
@@ -88,6 +90,9 @@ def generate_dataset(
88
90
  :param bool supervised: Generates supervised pairs (x,y) of measurements and signals.
89
91
  If set to ``False``, it will generate a training dataset with measurements only (y)
90
92
  and a test dataset with pairs (x,y)
93
+ :param bool verbose: Output progress information in the console.
94
+ :param bool show_progress_bar: Show progress bar during the generation
95
+ of the dataset (if verbose is set to True).
91
96
 
92
97
  """
93
98
  if os.path.exists(os.path.join(save_dir, dataset_filename)):
@@ -152,17 +157,14 @@ def generate_dataset(
152
157
  if supervised:
153
158
  hf.create_dataset("x_train", (n_train_g,) + x.shape[1:], dtype="float")
154
159
 
155
- if G > 1:
156
- print(
157
- f"Computing train measurement vectors from base dataset of operator {g + 1} out of {G}..."
158
- )
159
- else:
160
- print("Computing train measurement vectors from base dataset...")
161
-
162
160
  index = 0
163
161
 
164
162
  epochs = int(n_train_g / len(train_dataset)) + 1
165
- for e in tqdm(range(epochs)):
163
+ for e in (progress_bar := tqdm(range(epochs), ncols=150, disable=(not verbose or not show_progress_bar))):
164
+
165
+ desc = f"Generating dataset operator {g + 1}" if G > 1 else "Generating train dataset"
166
+ progress_bar.set_description(desc)
167
+
166
168
  train_dataloader = DataLoader(
167
169
  Subset(
168
170
  train_dataset,
@@ -173,7 +175,11 @@ def generate_dataset(
173
175
  pin_memory=False if device == "cpu" else True,
174
176
  )
175
177
 
176
- for i, x in enumerate(train_dataloader):
178
+ batches = len(train_dataloader) - int(train_dataloader.drop_last)
179
+ iterator = iter(train_dataloader)
180
+ for _ in range(batches):
181
+
182
+ x = next(iterator)
177
183
  x = x[0] if isinstance(x, list) or isinstance(x, tuple) else x
178
184
  x = x.to(device)
179
185
 
@@ -186,9 +192,9 @@ def generate_dataset(
186
192
  if bsize + index > n_train_g:
187
193
  bsize = n_train_g - index
188
194
 
189
- hf["y_train"][index : index + bsize] = y[:bsize, :].to("cpu").numpy()
195
+ hf["y_train"][index: index + bsize] = y[:bsize, :].to("cpu").numpy()
190
196
  if supervised:
191
- hf["x_train"][index : index + bsize] = (
197
+ hf["x_train"][index: index + bsize] = (
192
198
  x[:bsize, :, :, :].to("cpu").numpy()
193
199
  )
194
200
  index = index + bsize
@@ -204,14 +210,11 @@ def generate_dataset(
204
210
  pin_memory=True,
205
211
  )
206
212
 
207
- if G > 1:
208
- print(
209
- f"Computing test measurement vectors from base dataset of operator {g + 1} out of {G}..."
210
- )
211
- else:
212
- print("Computing test measurement vectors from base dataset...")
213
+ batches = len(test_dataloader) - int(test_dataloader.drop_last)
214
+ iterator = iter(test_dataloader)
215
+ for i in range(batches):
213
216
 
214
- for i, x in enumerate(tqdm(test_dataloader)):
217
+ x = next(iterator)
215
218
  x = x[0] if isinstance(x, list) or isinstance(x, tuple) else x
216
219
  x = x.to(device)
217
220
 
@@ -228,8 +231,8 @@ def generate_dataset(
228
231
 
229
232
  # Add new data to it
230
233
  bsize = x.size()[0]
231
- hf["x_test"][index : index + bsize] = x.to("cpu").numpy()
232
- hf["y_test"][index : index + bsize] = y.to("cpu").numpy()
234
+ hf["x_test"][index: index + bsize] = x.to("cpu").numpy()
235
+ hf["y_test"][index: index + bsize] = y.to("cpu").numpy()
233
236
  index = index + bsize
234
237
  hf.close()
235
238
 
@@ -0,0 +1,29 @@
1
+ from torch.utils import data
2
+
3
+
4
+ class PatchDataset(data.Dataset):
5
+ r"""
6
+ Builds the dataset of all patches from a tensor of images.
7
+
8
+ :param torch.Tensor imgs: Tensor of images, size: batch size x channels x height x width
9
+ :param int patch_size: size of patches
10
+ :param callable: data augmentation. callable object, None for no augmentation.
11
+ """
12
+ def __init__(self, imgs, patch_size=6, transforms=None):
13
+ self.imgs = imgs
14
+ self.patch_size = patch_size
15
+ self.patches_per_image = (self.imgs.shape[2]-patch_size+1)*(self.imgs.shape[3]-patch_size+1)
16
+ self.transforms = transforms
17
+
18
+ def __len__(self):
19
+ return self.imgs.shape[0]*self.patches_per_image
20
+
21
+ def __getitem__(self, idx):
22
+ idx_img = idx // self.patches_per_image
23
+ idx_in_img = idx % self.patches_per_image
24
+ idx_x = idx_in_img // (self.imgs.shape[3]-self.patch_size+1)
25
+ idx_y = idx_in_img % (self.imgs.shape[3]-self.patch_size+1)
26
+ patch = self.imgs[idx_img, :, idx_x:idx_x+self.patch_size, idx_y:idx_y+self.patch_size]
27
+ if self.transforms and False:
28
+ patch = self.transforms(patch)
29
+ return patch.reshape(-1), idx
@@ -4,7 +4,9 @@ from deepinv.loss.moi import MOILoss
4
4
  from deepinv.loss.sup import SupLoss
5
5
  from deepinv.loss.score import ScoreLoss
6
6
  from deepinv.loss.tv import TVLoss
7
+ from deepinv.loss.r2r import R2RLoss
7
8
  from deepinv.loss.sure import SureGaussianLoss, SurePoissonLoss, SurePGLoss
8
9
  from deepinv.loss.regularisers import JacobianSpectralNorm, FNEJacobianSpectralNorm
9
10
  from deepinv.loss.measplit import SplittingLoss, Neighbor2Neighbor
10
- from deepinv.loss.metric import LpNorm, CharbonnierLoss
11
+ from deepinv.loss.metric import LpNorm, PSNR, SSIM, LPIPS, NIQE
12
+ from deepinv.loss.loss import Loss
@@ -1,8 +1,9 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
+ from deepinv.loss.loss import Loss
3
4
 
4
5
 
5
- class EILoss(nn.Module):
6
+ class EILoss(Loss):
6
7
  r"""
7
8
  Equivariant imaging self-supervised loss.
8
9
 
@@ -0,0 +1,29 @@
1
+ import torch
2
+
3
+
4
+ class Loss(torch.nn.Module):
5
+ r"""
6
+ Base class for all loss/metric functions.
7
+
8
+ Sets a template for the loss functions, whose forward method must follow the input parameters in
9
+ :meth:`deepinv.loss.Loss.forward`.
10
+ """
11
+
12
+ def __init__(self):
13
+ super(Loss, self).__init__()
14
+
15
+ def forward(self, x_net, x, y, physics, model, **kwargs):
16
+ r"""
17
+ Computes the loss.
18
+
19
+ :param torch.Tensor x_net: Reconstructed image :math:`\inverse{y}`.
20
+ :param torch.Tensor x: Reference image.
21
+ :param torch.Tensor y: Measurement.
22
+ :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
23
+ :param torch.nn.Module model: Reconstruction function.
24
+
25
+ :return: (torch.Tensor) loss, the tensor size might be (1,) or (batch size,).
26
+ """
27
+ raise NotImplementedError(
28
+ "The method 'forward' must be implemented in the subclass."
29
+ )
@@ -1,8 +1,8 @@
1
1
  import torch
2
- import torch.nn as nn
2
+ from deepinv.loss.loss import Loss
3
3
 
4
4
 
5
- class MCLoss(nn.Module):
5
+ class MCLoss(Loss):
6
6
  r"""
7
7
  Measurement consistency loss
8
8
 
@@ -1,9 +1,10 @@
1
1
  import torch
2
2
  from deepinv.physics import Inpainting
3
3
  import numpy as np
4
+ from deepinv.loss.loss import Loss
4
5
 
5
6
 
6
- class SplittingLoss(torch.nn.Module):
7
+ class SplittingLoss(Loss):
7
8
  r"""
8
9
  Measurement splitting loss.
9
10
 
@@ -13,7 +14,7 @@ class SplittingLoss(torch.nn.Module):
13
14
 
14
15
  .. math::
15
16
 
16
- \frac{m}{m_2}\| y_2 - A_2 \inversef{y_1,A_1}\|^2
17
+ \frac{m}{m_2}\| y_2 - A_2 \inversef{y_1}{A_1}\|^2
17
18
 
18
19
  where :math:`R` is the trainable network. See https://pubmed.ncbi.nlm.nih.gov/32614100/.
19
20
 
@@ -55,8 +56,8 @@ class SplittingLoss(torch.nn.Module):
55
56
  mask[..., start::stride, start::stride] = 0.0
56
57
 
57
58
  # create inpainting masks
58
- inp = Inpainting(tsize, mask)
59
- inp2 = Inpainting(tsize, 1 - mask)
59
+ inp = Inpainting(tsize, mask, device=y.device)
60
+ inp2 = Inpainting(tsize, 1 - mask, device=y.device)
60
61
 
61
62
  # concatenate operators
62
63
  physics1 = inp * physics # A_1 = P*A
@@ -72,7 +73,7 @@ class SplittingLoss(torch.nn.Module):
72
73
  return loss_ms
73
74
 
74
75
 
75
- class Neighbor2Neighbor(torch.nn.Module):
76
+ class Neighbor2Neighbor(Loss):
76
77
  r"""
77
78
  Neighbor2Neighbor loss.
78
79