refcolorrestore 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- refcolorrestore/__init__.py +0 -0
- refcolorrestore/build_dataset.py +61 -0
- refcolorrestore/dataset/__init__.py +3 -0
- refcolorrestore/dataset/abc.py +82 -0
- refcolorrestore/dataset/dataset.py +30 -0
- refcolorrestore/dataset/interp.py +46 -0
- refcolorrestore/prepare.py +28 -0
- refcolorrestore/restore.py +56 -0
- refcolorrestore/shelf/__init__.py +15 -0
- refcolorrestore/shelf/arch.py +168 -0
- refcolorrestore/shelf/archnerf.py +67 -0
- refcolorrestore/shelf/modules.py +238 -0
- refcolorrestore/shelf/srresnet_arch.py +214 -0
- refcolorrestore/shelf/srvgg_arch.py +141 -0
- refcolorrestore/shelf/utils.py +48 -0
- refcolorrestore/train.py +53 -0
- refcolorrestore-1.2.0.dist-info/METADATA +47 -0
- refcolorrestore-1.2.0.dist-info/RECORD +21 -0
- refcolorrestore-1.2.0.dist-info/WHEEL +5 -0
- refcolorrestore-1.2.0.dist-info/licenses/LICENSE +21 -0
- refcolorrestore-1.2.0.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
import torch
|
|
3
|
+
import os
|
|
4
|
+
from gaussian_splatting import GaussianModel
|
|
5
|
+
from gaussian_splatting.prepare import prepare_dataset, prepare_gaussians
|
|
6
|
+
from extrinterp import ExtrinsicInterpolator
|
|
7
|
+
from refcolorrestore.dataset import Extrinsic2DualCameraDataset, DualCamera2RestorationDataset, DualCameraDataset
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def prepare_rendering(
        sh_degree: int, source: str, device: str, n: int, window_size: int,
        trainable_camera: bool = False,
        load_ply: str = None, load_ply_gt: str = None,
        load_camera: str = None,
        use_intrinsics: int | dict = 0
) -> Tuple[DualCameraDataset, GaussianModel, GaussianModel]:
    """Build the dual-camera dataset and the two Gaussian models to render from.

    Args:
        sh_degree: Forwarded to ``prepare_gaussians`` for both models.
        source: Forwarded to ``prepare_dataset`` and ``prepare_gaussians``.
        device: Forwarded to ``prepare_dataset`` and ``prepare_gaussians``.
        n: Forwarded to ``ExtrinsicInterpolator``.
        window_size: Forwarded to ``ExtrinsicInterpolator``.
        trainable_camera: Forwarded to ``prepare_dataset`` and
            ``prepare_gaussians``.
        load_ply: Point cloud path for the first returned model.
        load_ply_gt: Point cloud path for the second (ground truth) model.
        load_camera: Forwarded to ``prepare_dataset``.
        use_intrinsics: Either an index into the loaded camera dataset whose
            ``image_height``/``image_width``/``FoVx``/``FoVy`` are reused, or
            a dict of keyword arguments passed as-is to
            ``Extrinsic2DualCameraDataset``.

    Returns:
        ``(dataset, gaussians, gaussians_gt)``.

    Raises:
        ValueError: If ``use_intrinsics`` is neither an ``int`` nor a ``dict``.
    """
    base_cameras = prepare_dataset(
        source=source, device=device, trainable_camera=trainable_camera,
        load_camera=load_camera, load_depth=False)

    if isinstance(use_intrinsics, int):
        # Fetch the reference camera once instead of once per attribute.
        reference_camera = base_cameras[use_intrinsics]
        intrinsics = dict(
            image_height=reference_camera.image_height,
            image_width=reference_camera.image_width,
            FoVx=reference_camera.FoVx,
            FoVy=reference_camera.FoVy,
        )
    elif isinstance(use_intrinsics, dict):
        intrinsics = use_intrinsics
    else:
        raise ValueError("Invalid use_intrinsics format")

    interpolated = ExtrinsicInterpolator(dataset=base_cameras, n=n, window_size=window_size)
    dataset = Extrinsic2DualCameraDataset(interpolated, **intrinsics)

    gaussians = prepare_gaussians(
        sh_degree=sh_degree, source=source, device=device,
        trainable_camera=trainable_camera, load_ply=load_ply)
    gaussians_gt = prepare_gaussians(
        sh_degree=sh_degree, source=source, device=device,
        trainable_camera=trainable_camera, load_ply=load_ply_gt)
    return dataset, gaussians, gaussians_gt
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
if __name__ == "__main__":
    # Command-line entry point: render a restoration dataset from two trained
    # point clouds and save it under --data_dir.
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--sh_degree", default=3, type=int)
    parser.add_argument("-s", "--source", required=True, type=str)
    parser.add_argument("-d", "--destination", required=True, type=str)
    parser.add_argument("-i", "--iteration", required=True, type=int)
    parser.add_argument("--load_camera", default=None, type=str)
    parser.add_argument("--mode", choices=["base", "camera"], default="base")
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--destination_gt", required=True, type=str)
    parser.add_argument("--iteration_gt", required=True, type=int)
    parser.add_argument("--interp_n", required=True, type=int)
    parser.add_argument("--interp_window_size", type=int, default=3)
    parser.add_argument("--use_intrinsics", type=str, default="0", help="Use intrinsics for rendering, can be an integer index or a dict with keys: image_height, image_width, FoVx, FoVy")
    # NOTE(review): --downsample is parsed but never read below.
    parser.add_argument("--downsample", default=4, type=int)
    parser.add_argument("--data_dir", required=True, type=str)
    args = parser.parse_args()

    def point_cloud_path(root, iteration):
        # <root>/point_cloud/iteration_<N>/point_cloud.ply
        return os.path.join(root, "point_cloud", "iteration_" + str(iteration), "point_cloud.ply")

    load_ply = point_cloud_path(args.destination, args.iteration)
    load_ply_gt = point_cloud_path(args.destination_gt, args.iteration_gt)

    with torch.no_grad():
        # NOTE(review): eval() executes arbitrary Python taken from the
        # command line. Only run this with trusted arguments; consider
        # ast.literal_eval if plain literals are all that is needed.
        intrinsics = eval(args.use_intrinsics)
        cameras, gaussians, gaussians_gt = prepare_rendering(
            sh_degree=args.sh_degree, source=args.source, device=args.device,
            n=args.interp_n, window_size=args.interp_window_size,
            trainable_camera=args.mode == "camera",
            load_ply=load_ply, load_ply_gt=load_ply_gt, load_camera=args.load_camera,
            use_intrinsics=intrinsics
        )
        dataset = DualCamera2RestorationDataset(
            cameras=cameras,
            color_distorted_gaussians=gaussians,
            ground_truth_gaussians=gaussians_gt)
        os.makedirs(args.data_dir, exist_ok=True)
        dataset.save_dataset(args.data_dir)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from typing import NamedTuple, Tuple
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torchvision
|
|
8
|
+
|
|
9
|
+
from gaussian_splatting.camera import Camera
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DualCameraDataset:
    """Interface for an indexable dataset that yields pairs of cameras.

    NOTE: the methods are tagged with ``abstractmethod`` but the class does
    not use ``ABCMeta``, so the decorators do not prevent instantiation.
    """

    @abstractmethod
    def to(self, device) -> 'DualCameraDataset':
        """Move the dataset to ``device``; this base version returns ``self`` unchanged."""
        return self

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of camera pairs."""
        raise NotImplementedError()

    @abstractmethod
    def __getitem__(self, idx) -> Tuple[Camera, Camera]:
        """Return the camera pair stored at ``idx``."""
        raise NotImplementedError()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class RestorationTuple(NamedTuple):
    """The two input images a restoration model receives for one sample."""

    # Image whose colors are to be restored.
    color_distorted: torch.FloatTensor
    # Reference image supplied alongside the distorted one.
    reference: torch.FloatTensor
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class RestorationDataset:
    """Interface for datasets yielding ``(RestorationTuple, ground_truth)`` samples.

    Also provides helpers that dump every sample to PNG files on disk.
    """

    @abstractmethod
    def to(self, device) -> 'RestorationDataset':
        """Move the dataset to ``device``; this base version returns ``self`` unchanged."""
        return self

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of samples."""
        raise NotImplementedError()

    @abstractmethod
    def __getitem__(self, idx) -> Tuple[RestorationTuple, torch.FloatTensor]:
        """Return ``(RestorationTuple, ground_truth)`` for sample ``idx``."""
        raise NotImplementedError()

    def save_image_tuple(self, idx, image_dir):
        """Write sample ``idx`` as three PNGs under ``image_dir``.

        The images go to ``distorted/``, ``reference/`` and ``groundtruth/``
        (created when missing), each named with the zero-padded 5-digit index.
        """
        restoration, gt = self[idx]
        filename = '{0:05d}'.format(idx) + ".png"
        outputs = (
            ('distorted', restoration.color_distorted),
            ('reference', restoration.reference),
            ('groundtruth', gt),
        )
        for subdir, image in outputs:
            target_dir = os.path.join(image_dir, subdir)
            os.makedirs(target_dir, exist_ok=True)
            torchvision.utils.save_image(image, os.path.join(target_dir, filename))

    def save_dataset(self, image_dir):
        """Write every sample of the dataset under ``image_dir`` with a progress bar."""
        os.makedirs(image_dir, exist_ok=True)
        progress = tqdm(range(len(self)), desc="Saving images")
        for sample_idx in progress:
            self.save_image_tuple(sample_idx, image_dir)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class SavedRestorationDataset(RestorationDataset):
    """Restoration dataset read back from PNG files on disk.

    Expects ``distorted/``, ``reference/`` and ``groundtruth/`` subdirectories
    of ``data_dir`` holding files named ``00000.png``, ``00001.png``, ...
    """

    def __init__(self, data_dir: str):
        """Record the three image directories under ``data_dir``; no I/O is done here."""
        self.color_distorted_dir = os.path.join(data_dir, 'distorted')
        self.reference_dir = os.path.join(data_dir, 'reference')
        self.ground_truth_dir = os.path.join(data_dir, 'groundtruth')

    def __len__(self) -> int:
        """Count consecutive indices from 0 whose PNG exists in all three directories.

        The filesystem is probed on every call.
        """
        directories = (self.color_distorted_dir, self.reference_dir, self.ground_truth_dir)
        count = 0
        while True:
            filename = '{0:05d}'.format(count) + ".png"
            if not all(os.path.exists(os.path.join(d, filename)) for d in directories):
                return count
            count += 1

    def __getitem__(self, idx) -> Tuple[RestorationTuple, torch.FloatTensor]:
        """Load sample ``idx`` and scale pixel values by ``1 / 255``."""
        filename = '{0:05d}'.format(idx) + ".png"

        def load(directory):
            return torchvision.io.read_image(os.path.join(directory, filename)) / 255.0

        color_distorted = load(self.color_distorted_dir)
        reference = load(self.reference_dir)
        ground_truth = load(self.ground_truth_dir)
        return RestorationTuple(color_distorted, reference), ground_truth
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from gaussian_splatting import GaussianModel
|
|
6
|
+
|
|
7
|
+
from .abc import DualCameraDataset, RestorationDataset, RestorationTuple
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DualCamera2RestorationDataset(RestorationDataset):
    """Restoration dataset rendered on the fly from two Gaussian models.

    Each sample is produced by rendering a camera pair from ``cameras`` with
    the color-distorted model and the ground-truth model.
    """

    def __init__(self, cameras: DualCameraDataset, color_distorted_gaussians: GaussianModel, ground_truth_gaussians: GaussianModel):
        """Store the camera-pair dataset and the two models to render from."""
        self.cameras = cameras
        self.ground_truth_gaussians = ground_truth_gaussians
        self.color_distorted_gaussians = color_distorted_gaussians

    def to(self, device) -> 'DualCamera2RestorationDataset':
        """Move the cameras and both models to ``device`` and return ``self``."""
        for attr in ("cameras", "ground_truth_gaussians", "color_distorted_gaussians"):
            setattr(self, attr, getattr(self, attr).to(device))
        return self

    def __len__(self) -> int:
        """Return the number of camera pairs."""
        return len(self.cameras)

    def __getitem__(self, idx) -> Tuple[RestorationTuple, torch.FloatTensor]:
        """Render sample ``idx``.

        The distorted image and the ground truth are rendered with the second
        camera of the pair; the reference is rendered from the ground-truth
        model with the first camera of the pair.
        """
        low_res_camera, high_res_camera = self.cameras[idx]
        distorted_image = self.color_distorted_gaussians(high_res_camera)['render']
        target_image = self.ground_truth_gaussians(high_res_camera)['render']
        reference_image = self.ground_truth_gaussians(low_res_camera)['render']
        return RestorationTuple(distorted_image, reference_image), target_image
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from gaussian_splatting import Camera
|
|
5
|
+
from extrinterp import ExtrinsicDataset
|
|
6
|
+
|
|
7
|
+
from .abc import DualCameraDataset
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Extrinsic2DualCameraDataset(DualCameraDataset):
    """Turn a dataset of extrinsics into (low-res, high-res) camera pairs.

    Every extrinsic is converted to two cameras that share the same field of
    view and differ only in image size.
    """

    def __init__(
            self,
            base_dataset: ExtrinsicDataset,
            FoVx: float = 90.0*torch.pi/180, FoVy: float = 90.0*torch.pi/180,
            image_height: int = 1000, image_width: int = 1000, downsample=4):
        """Store intrinsics and the source dataset.

        Args:
            base_dataset: Indexable dataset whose items expose ``to_camera``
                and ``R``.
            FoVx: Horizontal field of view passed to ``to_camera``.
            FoVy: Vertical field of view passed to ``to_camera``.
            image_height: High-res height; rounded down to a multiple of 4.
            image_width: High-res width; rounded down to a multiple of 4.
            downsample: Integer divisor applied to the high-res size to get
                the low-res size.
        """
        self.cameras = base_dataset
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.downsample = downsample
        # NOTE(review): the multiple of 4 is hard-coded and does not follow
        # `downsample`; confirm whether that is intended.
        self.image_height = image_height // 4 * 4
        self.image_width = image_width // 4 * 4

    def __len__(self):
        """Return the number of items in the source dataset."""
        return len(self.cameras)

    def __getitem__(self, idx) -> Tuple[Camera, Camera]:
        """Return ``(camera_lr, camera_hr)`` built from the extrinsic at ``idx``."""
        # Index the source dataset once; the item supplies both the pose and
        # the device, so there is no need to fetch it again per camera.
        extr = self.cameras[idx]
        shared = dict(FoVx=self.FoVx, FoVy=self.FoVy, device=extr.R.device)
        camera_hr = extr.to_camera(
            image_height=self.image_height,
            image_width=self.image_width,
            **shared
        )
        camera_lr = extr.to_camera(
            image_height=self.image_height // self.downsample,
            image_width=self.image_width // self.downsample,
            **shared
        )
        return camera_lr, camera_hr

    def to(self, device):
        """Move the source dataset to ``device`` and return ``self``."""
        self.cameras = self.cameras.to(device)
        return self
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from torch.utils.data import DataLoader
|
|
4
|
+
|
|
5
|
+
from refcolorrestore.dataset import SavedRestorationDataset
|
|
6
|
+
from refcolorrestore.shelf import model_dict
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def load_dataset(
        image_dir: str,
        batch_size: int, num_workers: int, shuffle=False) -> DataLoader:
    """Wrap the saved restoration dataset at ``image_dir`` in a ``DataLoader``.

    Args:
        image_dir: Directory handed to ``SavedRestorationDataset``.
        batch_size: Samples per batch.
        num_workers: Number of loader worker processes.
        shuffle: Whether the loader shuffles samples.

    Returns:
        A ``DataLoader`` with pinned memory that keeps the last partial batch.
    """
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=shuffle,
        drop_last=False,
    )
    return DataLoader(SavedRestorationDataset(data_dir=image_dir), **loader_options)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def build_model(model: str, load_path: str = None, **kwargs) -> nn.Module:
    """Instantiate a registered model and optionally load saved weights.

    Args:
        model: Key into ``model_dict`` selecting the model factory.
        load_path: Path of a saved state dict; skipped when empty or ``None``.
        **kwargs: Forwarded to the model factory.

    Returns:
        The constructed network, with weights loaded if ``load_path`` was given.

    Raises:
        KeyError: If ``model`` is not a key of ``model_dict``.
    """
    net = model_dict[model](**kwargs)
    if load_path:
        # Map tensors to CPU so a checkpoint saved from a CUDA device can be
        # read on a host without that device. load_state_dict copies values
        # into the existing parameters, so the network's own device is kept.
        state_dict = torch.load(load_path, map_location="cpu")
        net.load_state_dict(state_dict)
    return net
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import shutil
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from torch.utils.data import DataLoader
|
|
5
|
+
import os
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
import torchvision
|
|
8
|
+
from gaussian_splatting.utils import psnr, ssim
|
|
9
|
+
from gaussian_splatting.utils.lpipsPyTorch import lpips
|
|
10
|
+
|
|
11
|
+
from refcolorrestore.shelf import model_dict
|
|
12
|
+
from refcolorrestore.prepare import load_dataset, build_model
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def rendering(net: nn.Module, dataset: DataLoader, render_path: str, log_path: str, device: str, save_image: bool):
    """Run ``net`` over ``dataset``, log per-batch metrics and optionally save outputs.

    Args:
        net: Model called as ``net(color_distorted, reference)``.
        dataset: Loader yielding ``(tuple, ground_truth)`` batches, where the
            tuple has ``color_distorted`` and ``reference`` attributes.
        render_path: Output image directory; deleted and recreated first.
        log_path: CSV file receiving one ``frame,psnr,ssim,lpips`` row per batch.
        device: Device the model and tensors are moved to.
        save_image: When true, save each output as ``<index>.png`` in
            ``render_path``.
    """
    net = net.eval().to(device)
    shutil.rmtree(render_path, ignore_errors=True)
    os.makedirs(render_path, exist_ok=True)
    pbar = tqdm(dataset, desc="Rendering progress")
    # Keep a single handle open rather than reopening the file per frame.
    with open(log_path, "w") as log_file:
        log_file.write("frame,psnr,ssim,lpips\n")
        # Flush before iteration starts so nothing is left buffered when the
        # loader spawns its workers.
        log_file.flush()
        for idx, (tup, ground_truth) in enumerate(pbar):
            restore = net(tup.color_distorted.to(device), tup.reference.to(device))
            ground_truth = ground_truth.to(device)
            scores = {
                "PSNR": psnr(restore, ground_truth).mean().item(),
                "SSIM": ssim(restore, ground_truth).mean().item(),
                "LPIPS": lpips(restore, ground_truth).mean().item(),
            }
            pbar.set_postfix(scores)
            log_file.write(f"{idx},{scores['PSNR']},{scores['SSIM']},{scores['LPIPS']}\n")
            # Flush per row so partial results survive an interrupted run.
            log_file.flush()
            if save_image:
                torchvision.utils.save_image(restore, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
if __name__ == "__main__":
    # Command-line entry point: load a trained model, run it over a saved
    # restoration dataset, and write metrics (and optionally images).
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("-s", "--source", required=True, type=str)
    parser.add_argument("-d", "--destination", required=True, type=str)
    parser.add_argument("--model", choices=list(model_dict.keys()), default="dualresnet")
    parser.add_argument("--model_destination", required=True, type=str)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("-o", "--option", default=[], action='append', type=str)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--num_workers", default=4, type=int)
    # NOTE(review): --epoch is parsed but never read below.
    parser.add_argument("--epoch", default=30, type=int)
    parser.add_argument("--save_image", action='store_true')
    args = parser.parse_args()

    # Each -o option has the form key=value and becomes a keyword argument of
    # the model factory.
    # NOTE(review): eval() executes arbitrary Python taken from the command
    # line. Only run this with trusted arguments; consider ast.literal_eval
    # if plain literals are all that is needed.
    configs = {}
    for option in args.option:
        key, raw_value = option.split("=", 1)
        configs[key] = eval(raw_value)

    # Do not shuffle: the loop index is used as the frame number in the CSV
    # and as the output image name, so it has to follow dataset order.
    loader = load_dataset(image_dir=args.source, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
    model_path = os.path.join(args.model_destination, f"{args.model}.pth")
    model = build_model(model=args.model, load_path=model_path, **configs)
    with torch.no_grad():
        rendering(
            model, loader,
            os.path.join(args.destination, args.model),
            os.path.join(args.destination, args.model + '.csv'),
            device=args.device, save_image=args.save_image)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .arch import rt4ksr_rep, rt4krestore_rep
|
|
2
|
+
from .archnerf import rt4kdual_rep
|
|
3
|
+
from .srvgg_arch import srvgg, restorevgg, dualvgg
|
|
4
|
+
from .srresnet_arch import srresnet, restoreresnet, dualresnet
|
|
5
|
+
# Registry mapping a model name (as used by the --model command-line flag)
# to the factory function that builds it.
model_dict = {
    "rt4ksr": rt4ksr_rep,
    "rt4krestore": rt4krestore_rep,
    "rt4kdual": rt4kdual_rep,
    "srvgg": srvgg,
    "restoreresnet": restoreresnet,
    "dualresnet": dualresnet,
    "srresnet": srresnet,
    "restorevgg": restorevgg,
    "dualvgg": dualvgg,
}
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
import torchvision
|
|
5
|
+
|
|
6
|
+
from .modules import *
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
####################################
|
|
10
|
+
# DEFINE MODEL INSTANCES
|
|
11
|
+
####################################
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class RT4KSR_Rep(nn.Module):
    """Super-resolution network operating on a 2x pixel-unshuffled input.

    ``is_train`` selects ``ResBlock``/``SimplifiedNAFBlock`` when true and
    ``RepResBlock``/``SimplifiedRepNAFBlock`` otherwise.
    """

    def __init__(self,
                 num_channels_in,
                 num_channels_out,
                 num_feats,
                 num_blocks,
                 upscale,
                 act,
                 eca_gamma,
                 is_train,
                 forget,
                 layernorm,
                 residual) -> None:
        """Build the head, high-frequency branch, body, tail and upsampler.

        Args:
            num_channels_in: Channels of the input image.
            num_channels_out: Channels of the output image.
            num_feats: Feature channels used throughout the network.
            num_blocks: Number of NAF blocks in the body.
            upscale: Upscaling factor relative to the input image.
            act: Activation module shared by the branch and body blocks.
            eca_gamma: Forwarded to every NAF block.
            is_train: Chooses the block variants (see class docstring).
            forget: When true, the tail receives
                ``gamma * deep_feats + hf_feats`` instead of ``deep_feats``.
            layernorm: Forwarded to every NAF block.
            residual: Forwarded to every NAF block.
        """
        super().__init__()
        self.forget = forget
        self.gamma = nn.Parameter(torch.zeros(1))
        self.gaussian = torchvision.transforms.GaussianBlur(kernel_size=5, sigma=1)
        self.upscale = upscale

        self.down = nn.PixelUnshuffle(2)
        self.up = nn.PixelShuffle(2)
        # The unshuffle multiplies the channel count by 2**2.
        self.head = nn.Sequential(nn.Conv2d(num_channels_in * (2**2), num_feats, 3, padding=1))

        def make_res_block():
            if is_train:
                return ResBlock(num_feats, ratio=2)
            return RepResBlock(num_feats)

        naf_block_cls = SimplifiedNAFBlock if is_train else SimplifiedRepNAFBlock

        # High-frequency branch: one residual block followed by the activation.
        self.hfb = nn.Sequential(make_res_block(), act)

        self.body = nn.Sequential(*[
            naf_block_cls(in_c=num_feats, act=act, exp=2, eca_gamma=eca_gamma, layernorm=layernorm, residual=residual)
            for _ in range(num_blocks)
        ])

        self.tail = nn.Sequential(LayerNorm2d(num_feats), make_res_block())

        # Undo the 2x unshuffle and apply `upscale` in one pixel shuffle.
        shuffle_factor = upscale * 2
        self.upsample = nn.Sequential(
            nn.Conv2d(num_feats, num_channels_out * ((2 * upscale) ** 2), 3, padding=1),
            nn.PixelShuffle(shuffle_factor)
        )

    def forward(self, _, out_lr):  # only lr is used
        """Upscale ``out_lr``; the first argument only supplies the output size.

        Returns:
            The result resized to the spatial size of the first argument and
            clamped to ``[0, 1]``.
        """
        image = out_lr
        # stage 1: high-frequency residual of the input
        high_freq = image - self.gaussian(image)

        # unshuffle to save computation
        image_unsh = self.down(image)
        high_freq_unsh = self.down(high_freq)

        feats_hf = self.head(high_freq_unsh)
        feats_lr = self.head(image_unsh)

        # stage 2
        deep = self.body(feats_lr)
        hf_out = self.hfb(feats_hf)

        # stage 3: hf_out is only consumed when `forget` is enabled
        tail_input = self.gamma * deep + hf_out if self.forget else deep
        deep = self.tail(tail_input)

        out = self.upsample(deep)
        out += F.interpolate(image, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        resized = F.interpolate(out, size=_.shape[2:], mode='bilinear', align_corners=False)
        return resized.clamp(0, 1.)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class RT4KRestore_Rep(RT4KSR_Rep):
    """Same-resolution variant of ``RT4KSR_Rep`` (``upscale`` fixed to 1).

    It processes the first forward argument and ignores the second.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``RT4KSR_Rep`` with ``upscale=1``."""
        super().__init__(*args, upscale=1, **kwargs)

    def forward(self, x, _):  # only x is used
        """Restore ``x`` and return it at its original spatial size, clamped to ``[0, 1]``."""
        original_size = x.shape[-2:]
        # The 2x pixel unshuffle needs even height and width.
        even_size = (x.shape[-2]//2*2, x.shape[-1]//2*2)
        x = F.interpolate(x, size=even_size, mode='bilinear', align_corners=False)

        # stage 1: high-frequency residual of the input
        high_freq = x - self.gaussian(x)

        # unshuffle to save computation
        image_unsh = self.down(x)
        high_freq_unsh = self.down(high_freq)

        feats_hf = self.head(high_freq_unsh)
        feats_lr = self.head(image_unsh)

        # stage 2
        deep = self.body(feats_lr)
        hf_out = self.hfb(feats_hf)

        # stage 3: hf_out is only consumed when `forget` is enabled
        tail_input = self.gamma * deep + hf_out if self.forget else deep
        deep = self.tail(tail_input)

        out = self.upsample(deep)
        out += x
        resized = F.interpolate(out, size=original_size, mode='bilinear', align_corners=False)
        return resized.clamp(0, 1.)
|
|
123
|
+
|
|
124
|
+
####################################
|
|
125
|
+
# RETURN INITIALIZED MODEL INSTANCES
|
|
126
|
+
####################################
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def rt4ksr_rep(
    act_type: str = "gelu",
    feature_channels: int = 24,
    num_blocks: int = 4,
    scale: int = 4,
    is_train: bool = True,
):
    """Build an ``RT4KSR_Rep`` for 3-channel input and output.

    Args:
        act_type: Name passed to ``activation`` to create the activation module.
        feature_channels: Feature channels of the network.
        num_blocks: Number of body blocks.
        scale: Upscaling factor.
        is_train: Forwarded to the model as ``is_train``.

    Returns:
        The model, built with ``eca_gamma=0``, ``forget=False``,
        ``layernorm=True`` and ``residual=True``.
    """
    settings = dict(
        num_channels_in=3,
        num_channels_out=3,
        num_feats=feature_channels,
        num_blocks=num_blocks,
        upscale=scale,
        act=activation(act_type),
        eca_gamma=0,
        forget=False,
        is_train=is_train,
        layernorm=True,
        residual=True,
    )
    return RT4KSR_Rep(**settings)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def rt4krestore_rep(
    act_type: str = "gelu",
    feature_channels: int = 24,
    num_blocks: int = 4,
    is_train: bool = True,
):
    """Build an ``RT4KRestore_Rep`` for 3-channel input and output.

    Args:
        act_type: Name passed to ``activation`` to create the activation module.
        feature_channels: Feature channels of the network.
        num_blocks: Number of body blocks.
        is_train: Forwarded to the model as ``is_train``.

    Returns:
        The model, built with ``eca_gamma=0``, ``forget=False``,
        ``layernorm=True`` and ``residual=True``.
    """
    settings = dict(
        num_channels_in=3,
        num_channels_out=3,
        num_feats=feature_channels,
        num_blocks=num_blocks,
        act=activation(act_type),
        eca_gamma=0,
        forget=False,
        is_train=is_train,
        layernorm=True,
        residual=True,
    )
    return RT4KRestore_Rep(**settings)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
from .arch import *
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class RT4KDual_Rep(RT4KSR_Rep):
    """Two-input variant of ``RT4KSR_Rep`` (``upscale`` fixed to 1).

    The second input is resized to the first one's size, the two are
    concatenated along channels, and the network output is added to the
    resized second input.
    """

    def __init__(self, num_channels_color, num_channels_gray, *args, **kwargs):
        """Create the network with ``num_channels_color + num_channels_gray`` input channels."""
        super().__init__(*args, **kwargs, upscale=1, num_channels_in=num_channels_color + num_channels_gray)
        self.num_channels_color = num_channels_color
        self.num_channels_gray = num_channels_gray

    def forward(self, x, lr):
        """Restore ``x`` using ``lr``; returns a tensor at ``x``'s size clamped to ``[0, 1]``."""
        original_size = x.shape[-2:]
        # The 2x pixel unshuffle needs even height and width.
        even_size = (x.shape[-2]//2*2, x.shape[-1]//2*2)
        x = F.interpolate(x, size=even_size, mode='bilinear', align_corners=False)
        base = F.interpolate(lr, size=even_size, mode='bicubic', align_corners=False)
        x = torch.cat([base, x], dim=1)

        # The remaining stages mirror RT4KSR_Rep.forward.
        # stage 1: high-frequency residual of the stacked input
        high_freq = x - self.gaussian(x)

        # unshuffle to save computation
        stacked_unsh = self.down(x)
        high_freq_unsh = self.down(high_freq)

        feats_hf = self.head(high_freq_unsh)
        feats_lr = self.head(stacked_unsh)

        # stage 2
        deep = self.body(feats_lr)
        hf_out = self.hfb(feats_hf)

        # stage 3: hf_out is only consumed when `forget` is enabled
        tail_input = self.gamma * deep + hf_out if self.forget else deep
        deep = self.tail(tail_input)

        out = self.upsample(deep)
        out += base
        resized = F.interpolate(out, size=original_size, mode='bilinear', align_corners=False)
        return resized.clamp(0, 1.)
|
|
43
|
+
|
|
44
|
+
####################################
|
|
45
|
+
# RETURN INITIALIZED MODEL INSTANCES
|
|
46
|
+
####################################
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def rt4kdual_rep(
    act_type: str = "gelu",
    feature_channels: int = 24,
    num_blocks: int = 4,
    is_train: bool = True,
):
    """Build an ``RT4KDual_Rep`` taking two 3-channel inputs and producing 3 channels.

    Args:
        act_type: Name passed to ``activation`` to create the activation module.
        feature_channels: Feature channels of the network.
        num_blocks: Number of body blocks.
        is_train: Forwarded to the model as ``is_train``.

    Returns:
        The model, built with ``eca_gamma=0``, ``forget=False``,
        ``layernorm=True`` and ``residual=True``.
    """
    settings = dict(
        num_channels_color=3,
        num_channels_gray=3,
        num_channels_out=3,
        num_feats=feature_channels,
        num_blocks=num_blocks,
        act=activation(act_type),
        eca_gamma=0,
        forget=False,
        is_train=is_train,
        layernorm=True,
        residual=True,
    )
    return RT4KDual_Rep(**settings)
|