deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# import DeepInv
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ArtifactRemoval(nn.Module):
    r"""
    Artifact removal architecture :math:`\phi(A^{\top}y)`.

    The architecture is inspired by the FBPConvNet approach of https://arxiv.org/pdf/1611.03679
    where a deep network :math:`\phi` is used to improve the linear reconstruction :math:`A^{\top}y`.

    :param torch.nn.Module backbone_net: Base network :math:`\phi`, can be pretrained or not.
    :param bool pinv: If ``True`` uses pseudo-inverse :math:`A^{\dagger}y` instead of the default transpose.
    :param str ckpt_path: Optional path to a state dict that is loaded (with ``strict=True``) into
        ``backbone_net``; when given, the backbone is also switched to evaluation mode.
    :param torch.device device: cpu or gpu.
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None):
        super().__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net

        if ckpt_path is not None:
            # map_location allows e.g. a checkpoint saved on GPU to be loaded on a CPU-only
            # machine; with device=None this is torch.load's default behaviour.
            state_dict = torch.load(ckpt_path, map_location=device)
            self.backbone_net.load_state_dict(state_dict, strict=True)
            self.backbone_net.eval()

        # A "UNetRes" backbone (matched by class name, so it need not be imported here)
        # is frozen: none of its parameters receive gradients.
        if type(self.backbone_net).__name__ == "UNetRes":
            for _, param in self.backbone_net.named_parameters():
                param.requires_grad = False

        if device is not None:
            self.backbone_net = self.backbone_net.to(device)

    def forward(self, y, physics, **kwargs):
        r"""
        Reconstructs a signal estimate from measurements y

        :param torch.tensor y: measurements
        :param deepinv.physics.Physics physics: forward operator (may be wrapped in
            ``torch.nn.DataParallel``)
        :param kwargs: when the backbone is a ``UNetRes``, must contain ``sigma``: the value used to
            fill the extra noise-level channel appended to the network input.
        :return: the output of ``backbone_net`` applied to the linear reconstruction, called with the
            physics noise level (``physics.noise_model.sigma``) or ``None`` if there is none.
        """
        if isinstance(physics, nn.DataParallel):
            physics = physics.module

        y_in = physics.A_dagger(y) if self.pinv else physics.A_adjoint(y)

        if type(self.backbone_net).__name__ == "UNetRes":
            # Build the noise-level map with y_in's dtype AND device; a CPU FloatTensor would
            # make torch.cat fail when y_in lives on a GPU.
            noise_level_map = y_in.new_empty(
                (y_in.size(0), 1, y_in.size(2), y_in.size(3))
            ).fill_(kwargs["sigma"])
            y_in = torch.cat((y_in, noise_level_map), 1)

        # Physics objects without a noise model (or whose noise model has no sigma)
        # yield sigma=None instead of raising AttributeError.
        noise_model = getattr(physics, "noise_model", None)
        sigma = getattr(noise_model, "sigma", None)

        return self.backbone_net(y_in, sigma)
|
deepinv/models/bm3d.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
|
|
5
|
+
# Compat for optional dependency on BM3D: if the package cannot be imported, keep an
# ImportError instance under the name ``bm3d`` so that BM3D.__init__ can raise a helpful
# message lazily, instead of this module failing at import time.
try:
    import bm3d
except Exception as _bm3d_import_error:  # not a bare except: lets KeyboardInterrupt/SystemExit through
    bm3d = ImportError("The bm3d package is not installed.")
    # Keep the real failure (e.g. a broken binary dependency) visible in the traceback chain.
    bm3d.__cause__ = _bm3d_import_error
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BM3D(nn.Module):
    """
    BM3D denoiser.


    This module wraps the BM3D denoiser from the `BM3D python package <https://pypi.org/project/bm3d/>`_.
    The denoiser is applied sequentially to each noisy image in the batch.

    The BM3D denoiser was introduced in "Image denoising by sparse 3D transform-domain collaborative filtering", by
    Dabov et al., IEEE Transactions on Image Processing (2007).

    :raises ImportError: at construction time if the ``bm3d`` package could not be imported.
    """

    def __init__(self):
        super().__init__()
        if isinstance(bm3d, ImportError):
            raise ImportError(
                "BM3D denoiser not available. Please install the bm3d package with `pip install bm3d`."
            ) from bm3d

    def forward(self, x, sigma):
        r"""
        Run the denoiser on image with noise level :math:`\sigma`.

        Each image ``x[i]`` is copied to the CPU as a numpy array with its first axis moved last,
        denoised with ``bm3d.bm3d``, and written back into an output tensor created with
        ``torch.zeros_like(x)`` (so it has the same shape, dtype and device as ``x``).

        :param torch.Tensor x: batch of noisy images, indexed as ``x[i, :, :, :]``.
        :param float, torch.Tensor sigma: noise level passed to ``bm3d.bm3d``. Either a scalar
            (float or one-element tensor) shared by the whole batch, or a tensor holding one value
            per image.
        :return: (torch.Tensor) the denoised batch.
        :raises ValueError: if ``sigma`` is a tensor whose number of elements is neither 1 nor the
            batch size.
        """
        out = torch.zeros_like(x)

        for i, sigma_i in enumerate(self._per_image_sigma(sigma, x.shape[0])):
            out[i, :, :, :] = array2tensor(
                bm3d.bm3d(tensor2array(x[i, :, :, :]), sigma_i)
            )
        return out

    @staticmethod
    def _per_image_sigma(sigma, batch_size):
        """Expand ``sigma`` into a list holding one noise level per image of the batch."""
        if not isinstance(sigma, torch.Tensor):
            # Non-tensor noise levels are forwarded unchanged to every image.
            return [sigma] * batch_size
        # bm3d works on CPU numpy data, so turn tensor noise levels into Python floats.
        values = sigma.detach().cpu().flatten().tolist()
        if len(values) == 1:
            return values * batch_size
        if len(values) != batch_size:
            raise ValueError(
                f"sigma has {len(values)} values but the batch contains {batch_size} images."
            )
        return values
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def tensor2array(img):
    """Detach a rank-3 tensor, copy it to the CPU and return it as a numpy array whose
    first axis has been moved last (e.g. channels-first -> channels-last)."""
    return img.detach().cpu().numpy().transpose(1, 2, 0)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def array2tensor(img):
    """Wrap a rank-3 numpy array as a torch tensor (sharing memory via ``torch.from_numpy``)
    and move its last axis first (e.g. channels-last -> channels-first)."""
    as_tensor = torch.from_numpy(img)
    last_axis_first = as_tensor.permute(2, 0, 1)
    return last_axis_first
|