deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,56 @@
1
+ # import DeepInv
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
class ArtifactRemoval(nn.Module):
    r"""
    Artifact removal architecture :math:`\phi(A^{\top}y)`.

    The architecture is inspired by the FBPConvNet approach of https://arxiv.org/pdf/1611.03679
    where a deep network :math:`\phi` is used to improve the linear reconstruction :math:`A^{\top}y`.

    :param torch.nn.Module backbone_net: Base network :math:`\phi`, can be pretrained or not.
    :param bool pinv: If ``True`` uses pseudo-inverse :math:`A^{\dagger}y` instead of the default transpose.
    :param str ckpt_path: Optional path to a state dict that is loaded (strictly) into ``backbone_net``.
        When given, the backbone is also switched to evaluation mode.
    :param torch.device device: cpu or gpu. If ``None``, the backbone stays on its current device.
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None):
        super(ArtifactRemoval, self).__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net

        if ckpt_path is not None:
            # Map to CPU so checkpoints saved on a GPU can also be loaded on CPU-only
            # machines; load_state_dict copies the values into the backbone's own parameters.
            state_dict = torch.load(ckpt_path, map_location="cpu")
            self.backbone_net.load_state_dict(state_dict, strict=True)
            self.backbone_net.eval()

        # Checked by class name (no import of UNetRes needed); such backbones are frozen.
        if type(self.backbone_net).__name__ == "UNetRes":
            for param in self.backbone_net.parameters():
                param.requires_grad = False

        if device is not None:
            self.backbone_net = self.backbone_net.to(device)

    def forward(self, y, physics, **kwargs):
        r"""
        Reconstructs a signal estimate from measurements y

        :param torch.Tensor y: measurements
        :param deepinv.physics.Physics physics: forward operator (a ``torch.nn.DataParallel``
            wrapper is unwrapped first)
        :param kwargs: for ``UNetRes`` backbones, must contain ``sigma``, the value used to fill the
            extra noise-level channel appended to the linear reconstruction.
        :return: output of ``backbone_net`` applied to the (possibly augmented) linear reconstruction,
            called with the physics noise model's ``sigma`` (or ``None`` if unavailable).
        """
        if isinstance(physics, nn.DataParallel):
            physics = physics.module

        y_in = physics.A_dagger(y) if self.pinv else physics.A_adjoint(y)

        if type(self.backbone_net).__name__ == "UNetRes":
            # Build the noise-level map directly on y_in's device and dtype so the
            # concatenation also works when y_in lives on a GPU.
            noise_level_map = torch.empty(
                (y_in.size(0), 1, y_in.size(2), y_in.size(3)),
                dtype=y_in.dtype,
                device=y_in.device,
            ).fill_(kwargs["sigma"])
            y_in = torch.cat((y_in, noise_level_map), 1)

        # Tolerate physics objects without a noise model, or noise models without sigma.
        noise_model = getattr(physics, "noise_model", None)
        sigma = getattr(noise_model, "sigma", None)

        return self.backbone_net(y_in, sigma)
deepinv/models/bm3d.py ADDED
@@ -0,0 +1,57 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
# Compat for optional dependency on BM3D.
# Only import/load failures are treated as "not installed"; a bare ``except:`` would also
# swallow KeyboardInterrupt/SystemExit. OSError covers a failure to load the package's
# native libraries.
try:
    import bm3d
except (ImportError, OSError):
    bm3d = ImportError("The bm3d package is not installed.")
10
+
11
+
12
class BM3D(nn.Module):
    r"""
    BM3D denoiser.

    This module wraps the BM3D denoiser from the `BM3D python package <https://pypi.org/project/bm3d/>`_.
    The denoiser is applied sequentially to each noisy image in the batch.

    The BM3D denoiser was introduced in "Image denoising by sparse 3D transform-domain collaborative filtering", by
    Dabov et al., IEEE Transactions on Image Processing (2007).

    :raises ImportError: at construction time if the ``bm3d`` package could not be imported.
    """

    def __init__(self):
        super().__init__()
        if isinstance(bm3d, ImportError):
            raise ImportError(
                "BM3D denoiser not available. Please install the bm3d package with `pip install bm3d`."
            ) from bm3d

    def forward(self, x, sigma):
        r"""
        Run the denoiser on image with noise level :math:`\sigma`.

        Each image of the batch is detached, moved to the CPU and converted to a channel-last
        NumPy array. It is then denoised by ``bm3d.bm3d`` and written back into an output tensor
        that has the same shape, dtype and device as ``x``. No gradient flows through this module.

        :param torch.Tensor x: batch of noisy images, indexed as ``x[i, :, :, :]`` (a 4D tensor;
            presumably ``(B, C, H, W)``).
        :param float sigma: noise level, forwarded unchanged to ``bm3d.bm3d`` as its second argument.
        :return: (torch.Tensor) the denoised batch.
        """
        out = torch.zeros_like(x)
        for i in range(x.shape[0]):
            denoised = bm3d.bm3d(tensor2array(x[i, :, :, :]), sigma)
            # Slice assignment casts the NumPy-backed result to out's dtype and device.
            out[i, :, :, :] = array2tensor(denoised)
        return out
48
+
49
+
50
def tensor2array(img):
    """
    Convert a 3D channel-first tensor into a channel-last NumPy array.

    The tensor is detached from the autograd graph and copied to the CPU first.

    :param torch.Tensor img: tensor indexed as (C, H, W).
    :return: (numpy.ndarray) array indexed as (H, W, C).
    """
    arr = img.detach().cpu().numpy()
    return arr.transpose(1, 2, 0)
54
+
55
+
56
def array2tensor(img):
    """
    Convert a channel-last NumPy array into a channel-first torch tensor.

    ``torch.from_numpy`` shares memory with ``img``; the result is a permuted view of it.

    :param numpy.ndarray img: array indexed as (H, W, C).
    :return: (torch.Tensor) tensor indexed as (C, H, W).
    """
    tensor = torch.from_numpy(img)
    return tensor.permute(2, 0, 1)