deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,197 @@
1
+ from deepinv.physics.forward import LinearPhysics
2
+ import torch
3
+ import numpy as np
4
+
5
+
6
def dst1(x):
    r"""
    Orthogonal Discrete Sine Transform, Type I

    The transform is performed across the last dimension of the input signal.
    Due to orthogonality (the transform matrix is symmetric and orthogonal) we have
    ``dst1(dst1(x)) = x``.

    Implementation: the signal of length ``n`` is embedded in the odd extension
    ``[0, x, 0, -flip(x)]`` of length ``2(n + 1)``. The imaginary part of
    coefficients ``1..n`` of its orthonormal real FFT equals the orthonormal DST-I
    with a global sign of ``-1``. That sign preserves both orthogonality and the
    involution property.

    :param torch.tensor x: the input signal (real-valued; it does not need to be contiguous)
    :return: (torch.tensor) the DST-I of the signal over the last dimension, with the same shape as ``x``
    """
    x_shape = x.shape

    b = int(np.prod(x_shape[:-1]))
    n = x_shape[-1]
    # reshape (instead of view) so non-contiguous inputs, e.g. transposed tensors, are accepted
    x = x.reshape(b, n)

    z = torch.zeros(b, 1, device=x.device)
    # odd extension of length 2(n + 1)
    x = torch.cat([z, x, z, -x.flip([1])], dim=1)
    x = torch.view_as_real(torch.fft.rfft(x, norm="ortho"))
    # keep the imaginary part of frequencies 1..n (drop the DC and Nyquist terms)
    x = x[:, 1:-1, 1]
    return x.reshape(x_shape)
27
+
28
+
29
class CompressedSensing(LinearPhysics):
    r"""
    Compressed Sensing forward operator. Creates a random sampling :math:`m \times n` matrix where :math:`n` is the
    number of elements of the signal, i.e., ``np.prod(img_shape)`` and ``m`` is the number of measurements.

    This class generates a random iid Gaussian matrix if ``fast=False``

    .. math::

        A_{i,j} \sim \mathcal{N}(0,\frac{1}{m})

    or a Subsampled Orthogonal with Random Signs matrix (SORS) if ``fast=True`` (see https://arxiv.org/abs/1506.03521)

    .. math::

        A = \text{diag}(m)D\text{diag}(s)

    where :math:`s\in\{-1,1\}^{n}` is a random sign flip with probability 0.5,
    :math:`D\in\mathbb{R}^{n\times n}` is a fast orthogonal transform (DST-1) and
    :math:`\text{diag}(m)\in\mathbb{R}^{m\times n}` is a random subsampling matrix, which keeps :math:`m` out of :math:`n` entries.

    It is recommended to use ``fast=True`` for image sizes bigger than 32 x 32, since the forward computation with
    ``fast=False`` has an :math:`O(mn)` complexity, whereas with ``fast=True`` it has an :math:`O(n \log n)` complexity.

    An existing operator can be loaded from a saved .pth file via ``self.load_state_dict(torch.load(save_path))``,
    in a similar fashion to :class:`torch.nn.Module`.

    .. note::

        If ``fast=False``, the squared spectral norm of the forward operator, :math:`\|A\|_2^2`, tends to
        :math:`(1+\sqrt{n/m})^2` for large :math:`n` and :math:`m` due to the `Marchenko-Pastur law
        <https://en.wikipedia.org/wiki/Marchenko%E2%80%93Pastur_distribution>`_.
        If ``fast=True``, the forward operator has a unit norm.

    :param int m: number of measurements.
    :param tuple img_shape: shape (C, H, W) of inputs.
    :param bool fast: The operator is iid Gaussian if false, otherwise A is a SORS matrix with the Discrete Sine Transform (type I).
    :param bool channelwise: Channels are processed independently using the same random forward operator.
    :param torch.type dtype: dtype used to store the forward matrix, its adjoint and pseudo-inverse
        (or the random signs when ``fast=True``). Inputs should use the same dtype.
    :param str device: Device to store the forward matrix.

    """

    def __init__(
        self,
        m,
        img_shape,
        fast=False,
        channelwise=False,
        dtype=torch.float,
        device="cpu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.name = f"CS_m{m}"
        self.img_shape = img_shape
        self.fast = fast
        self.channelwise = channelwise
        self.dtype = dtype

        # Signal length seen by the operator: one channel, or the whole image.
        if channelwise:
            n = int(np.prod(img_shape[1:]))
        else:
            n = int(np.prod(img_shape))

        if self.fast:
            self.n = n
            # Random signs s: each entry is flipped to -1 with probability 0.5.
            signs = torch.ones(n, dtype=dtype, device=device)
            signs[torch.rand_like(signs) > 0.5] = -1.0
            # Boolean mask keeping m (sorted) out of n DST coefficients.
            mask = torch.zeros(n, device=device)
            idx = np.sort(np.random.choice(n, size=m, replace=False))
            mask[torch.from_numpy(idx)] = 1
            mask = mask.type(torch.bool)

            self.D = torch.nn.Parameter(signs, requires_grad=False)
            self.mask = torch.nn.Parameter(mask, requires_grad=False)
        else:
            # Build every matrix directly in `dtype` so A, A_adjoint and A_dagger agree
            # (torch.einsum does not promote mixed dtypes).
            mat = torch.randn((m, n), dtype=dtype, device=device) / np.sqrt(m)
            self._A = torch.nn.Parameter(mat, requires_grad=False)
            self._A_dagger = torch.nn.Parameter(
                torch.linalg.pinv(mat), requires_grad=False
            )
            # Always a registered Parameter: it is saved in the state dict and moved by module.to().
            self._A_adjoint = torch.nn.Parameter(self._A.t(), requires_grad=False)

    def A(self, x):
        r"""
        Applies the forward operator :math:`y = Ax`.

        :param torch.Tensor x: input batch whose leading dimensions are (N, C).
        :return: (torch.Tensor) measurements of shape (N, C, m) if ``channelwise`` else (N, m).
        """
        N, C = x.shape[:2]
        if self.channelwise:
            x = x.reshape(N * C, -1)
        else:
            x = x.reshape(N, -1)

        if self.fast:
            y = dst1(x * self.D)[:, self.mask]
        else:
            y = torch.einsum("in, mn->im", x, self._A)

        if self.channelwise:
            y = y.view(N, C, -1)

        return y

    def A_adjoint(self, y):
        r"""
        Applies the adjoint operator :math:`x = A^{\top}y`.

        :param torch.Tensor y: measurements, as returned by :meth:`A`.
        :return: (torch.Tensor) tensor of shape (N, C, H, W) given by ``img_shape``.
        """
        N = y.shape[0]
        C, H, W = self.img_shape[0], self.img_shape[1], self.img_shape[2]

        if self.channelwise:
            N2 = N * C
            y = y.reshape(N2, -1)
        else:
            N2 = N

        if self.fast:
            # Zero-fill the unobserved coefficients, then apply the (self-inverse) DST and the sign flips.
            y2 = torch.zeros((N2, self.n), dtype=self.D.dtype, device=y.device)
            y2[:, self.mask] = y.type(y2.dtype)
            x = dst1(y2) * self.D
        else:
            x = torch.einsum("im, nm->in", y, self._A_adjoint)  # x: (N2, n)

        return x.reshape(N, C, H, W)

    def A_dagger(self, y):
        r"""
        Applies the pseudo-inverse :math:`A^{\dagger}y`.

        For ``fast=True`` the rows of :math:`A` are orthonormal (a row subset of an orthogonal
        matrix), so :math:`A^{\dagger} = A^{\top}` and the adjoint is reused.

        :param torch.Tensor y: measurements, as returned by :meth:`A`.
        :return: (torch.Tensor) tensor of shape (N, C, H, W) given by ``img_shape``.
        """
        if self.fast:
            return self.A_adjoint(y)

        N = y.shape[0]
        C, H, W = self.img_shape[0], self.img_shape[1], self.img_shape[2]

        if self.channelwise:
            y = y.reshape(N * C, -1)

        x = torch.einsum("im, nm->in", y, self._A_dagger)
        return x.reshape(N, C, H, W)
166
+
167
+
168
+ # if __name__ == "__main__":
169
+ # device = "cuda:0"
170
+ #
171
+ # # for comparing fast=True and fast=False forward matrices.
172
+ # for i in range(1):
173
+ # n = 2 ** (i + 4)
174
+ # im_size = (1, n, n)
175
+ # m = int(np.prod(im_size))
176
+ # x = torch.randn((1,) + im_size, device=device)
177
+ #
178
+ # print((dst1(dst1(x)) - x).flatten().abs().sum())
179
+ #
180
+ # physics = CompressedSensing(img_shape=im_size, m=m, fast=True, device=device)
181
+ #
182
+ # print((physics.A_adjoint(physics.A(x)) - x).flatten().abs().sum())
183
+ # print(f"adjointness: {physics.adjointness_test(x)}")
184
+ # print(f"norm: {physics.power_method(x, verbose=False)}")
185
+ # start = torch.cuda.Event(enable_timing=True)
186
+ # end = torch.cuda.Event(enable_timing=True)
187
+ # start.record()
188
+ # for j in range(100):
189
+ # y = physics.A(x)
190
+ # xhat = physics.A_dagger(y)
191
+ # end.record()
192
+ #
193
+ # # print((xhat-x).pow(2).flatten().mean())
194
+ #
195
+ # # Waits for everything to finish running
196
+ # torch.cuda.synchronize()
197
+ # print(start.elapsed_time(end))