refcolorrestore 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,61 @@
1
+ from typing import Tuple
2
+ import torch
3
+ import os
4
+ from gaussian_splatting import GaussianModel
5
+ from gaussian_splatting.prepare import prepare_dataset, prepare_gaussians
6
+ from extrinterp import ExtrinsicInterpolator
7
+ from refcolorrestore.dataset import Extrinsic2DualCameraDataset, DualCamera2RestorationDataset, DualCameraDataset
8
+
9
+
10
def prepare_rendering(
        sh_degree: int, source: str, device: str, n: int, window_size: int,
        trainable_camera: bool = False,
        load_ply: str = None, load_ply_gt: str = None,
        load_camera: str = None,
        use_intrinsics: int | dict = 0,
        downsample: int = 4,
) -> Tuple[DualCameraDataset, GaussianModel, GaussianModel]:
    """Build the interpolated dual-camera dataset and the two Gaussian models.

    Args:
        sh_degree: spherical-harmonics degree passed to ``prepare_gaussians``.
        source: scene source directory passed to the dataset/gaussian loaders.
        device: device string passed to the loaders.
        n, window_size: forwarded to ``ExtrinsicInterpolator``.
        trainable_camera, load_camera: forwarded to ``prepare_dataset`` /
            ``prepare_gaussians``.
        load_ply: point cloud of the color-distorted model.
        load_ply_gt: point cloud of the ground-truth model.
        use_intrinsics: either an index into the source dataset whose camera's
            ``image_height``, ``image_width``, ``FoVx`` and ``FoVy`` are reused,
            or a dict of keyword arguments for ``Extrinsic2DualCameraDataset``.
        downsample: HR/LR resolution factor for ``Extrinsic2DualCameraDataset``.
            A ``"downsample"`` key inside a ``use_intrinsics`` dict takes
            precedence over this argument.

    Returns:
        ``(dual_camera_dataset, color_distorted_gaussians, ground_truth_gaussians)``.

    Raises:
        ValueError: if ``use_intrinsics`` is neither an int nor a dict.
    """
    dataset = prepare_dataset(source=source, device=device, trainable_camera=trainable_camera, load_camera=load_camera, load_depth=False)
    if isinstance(use_intrinsics, int):
        # Index the dataset once and reuse that camera's size and field of view.
        reference_camera = dataset[use_intrinsics]
        intrinsics = dict(
            image_height=reference_camera.image_height, image_width=reference_camera.image_width,
            FoVx=reference_camera.FoVx, FoVy=reference_camera.FoVy)
    elif isinstance(use_intrinsics, dict):
        intrinsics = use_intrinsics
    else:
        raise ValueError("Invalid use_intrinsics format")
    # Build a new dict (the caller's dict is not mutated); explicit dict entries win.
    intrinsics = {"downsample": downsample, **intrinsics}
    dataset = Extrinsic2DualCameraDataset(ExtrinsicInterpolator(dataset=dataset, n=n, window_size=window_size), **intrinsics)
    gaussians = prepare_gaussians(sh_degree=sh_degree, source=source, device=device, trainable_camera=trainable_camera, load_ply=load_ply)
    gaussians_gt = prepare_gaussians(sh_degree=sh_degree, source=source, device=device, trainable_camera=trainable_camera, load_ply=load_ply_gt)
    return dataset, gaussians, gaussians_gt


if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("--sh_degree", default=3, type=int)
    parser.add_argument("-s", "--source", required=True, type=str)
    parser.add_argument("-d", "--destination", required=True, type=str)
    parser.add_argument("-i", "--iteration", required=True, type=int)
    parser.add_argument("--load_camera", default=None, type=str)
    parser.add_argument("--mode", choices=["base", "camera"], default="base")
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--destination_gt", required=True, type=str)
    parser.add_argument("--iteration_gt", required=True, type=int)
    parser.add_argument("--interp_n", required=True, type=int)
    parser.add_argument("--interp_window_size", type=int, default=3)
    parser.add_argument("--use_intrinsics", type=str, default="0", help="Use intrinsics for rendering, can be an integer index or a dict with keys: image_height, image_width, FoVx, FoVy")
    parser.add_argument("--downsample", default=4, type=int)
    parser.add_argument("--data_dir", required=True, type=str)
    args = parser.parse_args()
    load_ply = os.path.join(args.destination, "point_cloud", "iteration_" + str(args.iteration), "point_cloud.ply")
    load_ply_gt = os.path.join(args.destination_gt, "point_cloud", "iteration_" + str(args.iteration_gt), "point_cloud.ply")
    with torch.no_grad():
        cameras, gaussians, gaussians_gt = prepare_rendering(
            sh_degree=args.sh_degree, source=args.source, device=args.device,
            n=args.interp_n, window_size=args.interp_window_size,
            trainable_camera=args.mode == "camera",
            load_ply=load_ply, load_ply_gt=load_ply_gt, load_camera=args.load_camera,
            # NOTE(review): eval() executes the argument as Python code; acceptable for a
            # local CLI, but never pass untrusted input here.
            use_intrinsics=eval(args.use_intrinsics),
            # Previously parsed but never used: the factor was always the default 4.
            downsample=args.downsample,
        )
        dataset = DualCamera2RestorationDataset(cameras=cameras, color_distorted_gaussians=gaussians, ground_truth_gaussians=gaussians_gt)
        os.makedirs(args.data_dir, exist_ok=True)
        dataset.save_dataset(args.data_dir)
@@ -0,0 +1,3 @@
1
+ from .abc import DualCameraDataset, RestorationDataset, RestorationTuple, SavedRestorationDataset
2
+ from .interp import Extrinsic2DualCameraDataset
3
+ from .dataset import DualCamera2RestorationDataset
@@ -0,0 +1,82 @@
1
+ import os
2
+ from abc import abstractmethod
3
+ from typing import NamedTuple, Tuple
4
+ from tqdm import tqdm
5
+
6
+ import torch
7
+ import torchvision
8
+
9
+ from gaussian_splatting.camera import Camera
10
+
11
+
12
class DualCameraDataset:
    """Interface for datasets whose items are pairs of cameras.

    The implementations in this package return ``(low_res_camera,
    high_res_camera)`` from ``__getitem__``.

    NOTE: the class does not derive from ``abc.ABC`` (no ``ABCMeta``), so the
    ``@abstractmethod`` markers are not enforced at instantiation time; they
    only document which methods a subclass is expected to provide.
    """

    @abstractmethod
    def to(self, device) -> 'DualCameraDataset':
        """Move the dataset to ``device``; this default is a no-op returning ``self``."""
        return self

    @abstractmethod
    def __len__(self) -> int:
        """Number of camera pairs; must be provided by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def __getitem__(self, idx) -> Tuple[Camera, Camera]:
        """Camera pair at ``idx``; must be provided by subclasses."""
        raise NotImplementedError()
25
+
26
+
27
RestorationTuple = NamedTuple(
    "RestorationTuple",
    [
        ("color_distorted", torch.FloatTensor),
        ("reference", torch.FloatTensor),
    ],
)
RestorationTuple.__doc__ = (
    "Network input of one restoration sample: the color-distorted image and "
    "the reference image it is restored against."
)
30
+
31
+
32
class RestorationDataset:
    """Interface for datasets of ``(RestorationTuple, ground_truth)`` samples.

    Subclasses provide ``__len__`` and ``__getitem__``. ``save_dataset`` then
    writes every sample as PNG files into the ``distorted/``, ``reference/``
    and ``groundtruth/`` sub-directories, named by 5-digit zero-padded index.

    NOTE: the class does not derive from ``abc.ABC``, so ``@abstractmethod``
    is not enforced; it only documents what subclasses must implement.
    """

    @abstractmethod
    def to(self, device) -> 'RestorationDataset':
        """Move the dataset to ``device``; this default is a no-op returning ``self``."""
        return self

    @abstractmethod
    def __len__(self) -> int:
        """Number of samples; must be provided by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def __getitem__(self, idx) -> Tuple[RestorationTuple, torch.FloatTensor]:
        """Sample ``idx`` as ``(RestorationTuple, ground_truth)``; must be provided by subclasses."""
        raise NotImplementedError()

    def save_image_tuple(self, idx, image_dir):
        """Write sample ``idx`` to ``<image_dir>/{distorted,reference,groundtruth}/<idx:05d>.png``."""
        restoration, gt = self[idx]
        file_name = '{0:05d}'.format(idx) + ".png"
        outputs = (
            ('distorted', restoration.color_distorted),
            ('reference', restoration.reference),
            ('groundtruth', gt),
        )
        for sub_dir, image in outputs:
            target_dir = os.path.join(image_dir, sub_dir)
            os.makedirs(target_dir, exist_ok=True)
            torchvision.utils.save_image(image, os.path.join(target_dir, file_name))

    def save_dataset(self, image_dir):
        """Save samples ``0 .. len(self) - 1`` under ``image_dir``, showing a progress bar."""
        os.makedirs(image_dir, exist_ok=True)
        indices = range(len(self))
        for idx in tqdm(indices, desc="Saving images"):
            self.save_image_tuple(idx, image_dir)
62
+
63
+
64
class SavedRestorationDataset(RestorationDataset):
    """Read back samples written by ``RestorationDataset.save_dataset``.

    Expects ``distorted/``, ``reference/`` and ``groundtruth/`` sub-directories
    of ``data_dir`` holding ``00000.png``, ``00001.png``, ... Images are loaded
    with ``torchvision.io.read_image`` and divided by 255.0.
    """

    def __init__(self, data_dir: str):
        self.color_distorted_dir = os.path.join(data_dir, 'distorted')
        self.reference_dir = os.path.join(data_dir, 'reference')
        self.ground_truth_dir = os.path.join(data_dir, 'groundtruth')

    def _frame_paths(self, idx):
        """Return the (distorted, reference, ground-truth) PNG paths of frame ``idx``."""
        file_name = '{0:05d}'.format(idx) + ".png"
        return tuple(
            os.path.join(directory, file_name)
            for directory in (self.color_distorted_dir, self.reference_dir, self.ground_truth_dir)
        )

    def __len__(self) -> int:
        """Count frames 0, 1, 2, ... until the first index missing any of its three PNGs."""
        count = 0
        while all(os.path.exists(path) for path in self._frame_paths(count)):
            count += 1
        return count

    def __getitem__(self, idx) -> Tuple[RestorationTuple, torch.FloatTensor]:
        distorted_path, reference_path, ground_truth_path = self._frame_paths(idx)
        color_distorted = torchvision.io.read_image(distorted_path) / 255.0
        reference = torchvision.io.read_image(reference_path) / 255.0
        ground_truth = torchvision.io.read_image(ground_truth_path) / 255.0
        return RestorationTuple(color_distorted, reference), ground_truth
@@ -0,0 +1,30 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+
5
+ from gaussian_splatting import GaussianModel
6
+
7
+ from .abc import DualCameraDataset, RestorationDataset, RestorationTuple
8
+
9
+
10
class DualCamera2RestorationDataset(RestorationDataset):
    """Render restoration samples from a dual-camera dataset and two Gaussian models.

    For each ``(lr_camera, hr_camera)`` pair:
      * ``color_distorted`` is the color-distorted model rendered at ``hr_camera``;
      * ``ground_truth`` is the ground-truth model rendered at ``hr_camera``;
      * ``reference`` is the ground-truth model rendered at ``lr_camera``.
    """

    def __init__(self, cameras: DualCameraDataset, color_distorted_gaussians: GaussianModel, ground_truth_gaussians: GaussianModel):
        self.cameras = cameras
        self.ground_truth_gaussians = ground_truth_gaussians
        self.color_distorted_gaussians = color_distorted_gaussians

    def to(self, device) -> 'DualCamera2RestorationDataset':
        """Move the cameras and both models to ``device`` and return ``self``."""
        self.cameras = self.cameras.to(device)
        self.ground_truth_gaussians = self.ground_truth_gaussians.to(device)
        self.color_distorted_gaussians = self.color_distorted_gaussians.to(device)
        return self

    def __len__(self) -> int:
        return len(self.cameras)

    def __getitem__(self, idx) -> Tuple[RestorationTuple, torch.FloatTensor]:
        camera_lr, camera_hr = self.cameras[idx]
        distorted = self.color_distorted_gaussians(camera_hr)['render']
        target = self.ground_truth_gaussians(camera_hr)['render']
        guide = self.ground_truth_gaussians(camera_lr)['render']
        sample = RestorationTuple(color_distorted=distorted, reference=guide)
        return sample, target
@@ -0,0 +1,46 @@
1
+ from typing import Tuple
2
+ import torch
3
+
4
+ from gaussian_splatting import Camera
5
+ from extrinterp import ExtrinsicDataset
6
+
7
+ from .abc import DualCameraDataset
8
+
9
+
10
class Extrinsic2DualCameraDataset(DualCameraDataset):
    """Turn a dataset of extrinsics into ``(low_res_camera, high_res_camera)`` pairs.

    Both cameras of a pair share the extrinsics and the field of view. The high-res
    camera uses ``image_height`` x ``image_width`` (rounded down, see below), and
    the low-res camera uses that size integer-divided by ``downsample``.

    Args:
        base_dataset: indexable dataset whose items provide ``to_camera(...)``
            and an ``R`` tensor whose device is reused for the cameras.
        FoVx, FoVy: horizontal / vertical field of view (default ``90*pi/180``).
        image_height, image_width: requested high-res size. It is rounded down to a
            multiple of ``lcm(4, downsample)``, which keeps the previous
            multiple-of-4 alignment and makes ``hr == lr * downsample`` exact.
        downsample: positive integer factor between the high- and low-res cameras.

    Raises:
        ValueError: if ``downsample`` is smaller than 1.
    """

    def __init__(
            self,
            base_dataset: ExtrinsicDataset,
            FoVx: float = 90.0*torch.pi/180, FoVy: float = 90.0*torch.pi/180,
            image_height: int = 1000, image_width: int = 1000, downsample=4):
        import math
        if downsample < 1:
            raise ValueError(f"downsample must be a positive integer, got {downsample!r}")
        self.cameras = base_dataset
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.downsample = downsample
        # The rounding step used to be hard-coded to 4, which only matched the default
        # factor. For other factors, the low-res size times downsample did not equal
        # the high-res size.
        step = math.lcm(4, downsample)
        self.image_height = image_height // step * step
        self.image_width = image_width // step * step

    def __len__(self):
        return len(self.cameras)

    def __getitem__(self, idx) -> Tuple[Camera, Camera]:
        """Return ``(camera_lr, camera_hr)`` for the extrinsics at ``idx``."""
        # Index the base dataset once: it may compute the extrinsics on every access.
        extr = self.cameras[idx]
        device = extr.R.device
        camera_hr = extr.to_camera(
            image_height=self.image_height,
            image_width=self.image_width,
            FoVx=self.FoVx,
            FoVy=self.FoVy,
            device=device
        )
        camera_lr = extr.to_camera(
            image_height=self.image_height // self.downsample,
            image_width=self.image_width // self.downsample,
            FoVx=self.FoVx,
            FoVy=self.FoVy,
            device=device
        )
        return camera_lr, camera_hr

    def to(self, device):
        """Move the base dataset to ``device`` and return ``self``."""
        self.cameras = self.cameras.to(device)
        return self
@@ -0,0 +1,28 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader
4
+
5
+ from refcolorrestore.dataset import SavedRestorationDataset
6
+ from refcolorrestore.shelf import model_dict
7
+
8
+
9
def load_dataset(
        image_dir: str,
        batch_size: int, num_workers: int, shuffle=False) -> DataLoader:
    """Wrap the PNG dataset saved under ``image_dir`` in a ``DataLoader``.

    Memory pinning is always enabled and the last partial batch is kept
    (``drop_last=False``).
    """
    return DataLoader(
        SavedRestorationDataset(data_dir=image_dir),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=shuffle,
        drop_last=False,
    )
22
+
23
+
24
def build_model(model: str, load_path: str = None, **kwargs) -> nn.Module:
    """Instantiate ``model_dict[model](**kwargs)`` and optionally load its weights.

    Args:
        model: key into ``model_dict``.
        load_path: optional path to a state dict saved with ``torch.save``.
        **kwargs: forwarded to the model factory.

    Returns:
        The constructed module; weights are loaded when ``load_path`` is truthy.

    Raises:
        KeyError: if ``model`` is not a key of ``model_dict``.
    """
    net = model_dict[model](**kwargs)
    if load_path:
        # map_location="cpu" lets checkpoints saved from a GPU process load on
        # CPU-only machines. load_state_dict copies the values into the module's
        # existing parameters, so the module's device is the same either way.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(load_path, map_location="cpu")
        net.load_state_dict(state_dict)
    return net
@@ -0,0 +1,56 @@
1
+ import shutil
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.data import DataLoader
5
+ import os
6
+ from tqdm import tqdm
7
+ import torchvision
8
+ from gaussian_splatting.utils import psnr, ssim
9
+ from gaussian_splatting.utils.lpipsPyTorch import lpips
10
+
11
+ from refcolorrestore.shelf import model_dict
12
+ from refcolorrestore.prepare import load_dataset, build_model
13
+
14
+
15
def rendering(net: nn.Module, dataset: DataLoader, render_path: str, log_path: str, device: str, save_image: bool):
    """Evaluate ``net`` on every sample of ``dataset`` and log per-frame metrics.

    Each item of ``dataset`` is ``(RestorationTuple, ground_truth)``, and the
    network is called as ``net(color_distorted, reference)``.

    For every image in a batch:
      * a CSV row ``frame,psnr,ssim,lpips`` is written to ``log_path``;
      * when ``save_image`` is true, the restored image is saved as
        ``<render_path>/<frame:05d>.png``.

    Frames are numbered consecutively in loader order, so a non-shuffled loader
    keeps them aligned with dataset indices. ``render_path`` is deleted and
    recreated first.
    """
    net = net.eval().to(device)
    shutil.rmtree(render_path, ignore_errors=True)
    os.makedirs(render_path, exist_ok=True)
    pbar = tqdm(dataset, desc="Rendering progress")
    frame = 0
    # Open the log once. Line buffering flushes each row as soon as it is written,
    # so finished rows survive a crash (as the old reopen-per-frame approach did).
    with open(log_path, "w", buffering=1) as log:
        log.write("frame,psnr,ssim,lpips\n")
        for tup, ground_truth in pbar:
            restore = net(tup.color_distorted.to(device), tup.reference.to(device))
            ground_truth = ground_truth.to(device)
            # Score and save each image on its own. With batch_size > 1, one row/PNG
            # per batch would mix several frames under a single index.
            # Assumes restore/ground_truth are (batch, C, H, W) - TODO confirm.
            for b in range(restore.shape[0]):
                pred = restore[b:b + 1]
                target = ground_truth[b:b + 1]
                scores = {
                    "PSNR": psnr(pred, target).mean().item(),
                    "SSIM": ssim(pred, target).mean().item(),
                    "LPIPS": lpips(pred, target).mean().item(),
                }
                pbar.set_postfix(scores)
                log.write(f"{frame},{scores['PSNR']},{scores['SSIM']},{scores['LPIPS']}\n")
                if save_image:
                    torchvision.utils.save_image(pred, os.path.join(render_path, '{0:05d}'.format(frame) + ".png"))
                frame += 1
35
+
36
+
37
if __name__ == "__main__":
    from argparse import ArgumentParser
    parser = ArgumentParser()
    parser.add_argument("-s", "--source", required=True, type=str)
    parser.add_argument("-d", "--destination", required=True, type=str)
    parser.add_argument("--model", choices=list(model_dict.keys()), default="dualresnet")
    parser.add_argument("--model_destination", required=True, type=str)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("-o", "--option", default=[], action='append', type=str)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--num_workers", default=4, type=int)
    # Not used by this script; kept so existing command lines still parse.
    parser.add_argument("--epoch", default=30, type=int)
    parser.add_argument("--save_image", action='store_true')
    args = parser.parse_args()
    # NOTE(review): eval() executes each "-o key=value" value as Python code;
    # fine for a local CLI, but never feed it untrusted input.
    configs = {o.split("=", 1)[0]: eval(o.split("=", 1)[1]) for o in args.option}
    # shuffle=False: rendering numbers frames by their position in the loader, so a
    # shuffled loader would make CSV rows and saved PNG names no longer correspond
    # to the dataset's frame indices.
    loader = load_dataset(image_dir=args.source, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
    model_path = os.path.join(args.model_destination, f"{args.model}.pth")
    model = build_model(model=args.model, load_path=model_path, **configs)
    with torch.no_grad():
        rendering(model, loader, os.path.join(args.destination, args.model), os.path.join(args.destination, args.model + '.csv'), device=args.device, save_image=args.save_image)
@@ -0,0 +1,15 @@
1
+ from .arch import rt4ksr_rep, rt4krestore_rep
2
+ from .archnerf import rt4kdual_rep
3
+ from .srvgg_arch import srvgg, restorevgg, dualvgg
4
+ from .srresnet_arch import srresnet, restoreresnet, dualresnet
5
# Registry mapping the CLI model names to their factory functions. Insertion order
# matters: scripts build their argparse ``choices`` from ``model_dict.keys()``.
model_dict = dict(
    rt4ksr=rt4ksr_rep,
    rt4krestore=rt4krestore_rep,
    rt4kdual=rt4kdual_rep,
    srvgg=srvgg,
    restoreresnet=restoreresnet,
    dualresnet=dualresnet,
    srresnet=srresnet,
    restorevgg=restorevgg,
    dualvgg=dualvgg,
)
@@ -0,0 +1,168 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ from .modules import *
7
+
8
+
9
+ ####################################
10
+ # DEFINE MODEL INSTANCES
11
+ ####################################
12
+
13
+
14
class RT4KSR_Rep(nn.Module):
    """RT4KSR-style super-resolution network built from the blocks in ``.modules``.

    Forward pipeline:
      * ``hf`` is the high-frequency part of the input: the input minus a 5x5
        Gaussian blur with sigma 1.
      * The input and ``hf`` are 2x pixel-unshuffled and passed through the same
        ``head`` convolution.
      * Image features go through ``body``; high-frequency features go through ``hfb``.
      * ``tail`` and ``upsample`` produce a residual, which is added to a bilinear
        ``upscale``-x resize of the input.
      * The result is resized to the first argument's spatial size and clamped to [0, 1].

    Args:
        num_channels_in: channels of the processed input image.
        num_channels_out: channels of the output image.
        num_feats: feature width of the internal blocks.
        num_blocks: number of blocks in ``body``.
        upscale: scale factor of the network output relative to its input.
            The input is unshuffled by 2, and ``upsample`` pixel-shuffles by
            ``2 * upscale``.
        act: activation module. It is appended after the ``hfb`` block and passed
            to every body block.
        eca_gamma, layernorm, residual: forwarded to every body block.
        is_train: builds ``ResBlock`` / ``SimplifiedNAFBlock`` when true and the
            ``Rep*`` variants otherwise (presumably the re-parameterised inference
            forms - TODO confirm in ``.modules``).
        forget: when true, ``tail`` receives ``gamma * body_features + hfb_features``.
            Otherwise the ``hfb`` branch is still computed but its output is unused.

    NOTE: subclasses reuse the attribute names assigned in ``__init__``. The
    construction order also fixes parameter registration order and the order of
    random weight initialisation.
    """
    def __init__(self,
                 num_channels_in,
                 num_channels_out,
                 num_feats,
                 num_blocks,
                 upscale,
                 act,
                 eca_gamma,
                 is_train,
                 forget,
                 layernorm,
                 residual) -> None:
        super().__init__()
        self.forget = forget
        # Weight of the body features when ``forget`` is set; starts at zero.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.gaussian = torchvision.transforms.GaussianBlur(kernel_size=5, sigma=1)
        self.upscale = upscale

        self.down = nn.PixelUnshuffle(2)
        self.up = nn.PixelShuffle(2)  # not referenced by forward()
        # Input channels are multiplied by 4 because of the 2x pixel-unshuffle.
        self.head = nn.Sequential(nn.Conv2d(num_channels_in * (2**2), num_feats, 3, padding=1))

        # High-frequency branch: one residual block followed by the activation.
        hfb = []
        if is_train:
            hfb.append(ResBlock(num_feats, ratio=2))
        else:
            hfb.append((RepResBlock(num_feats)))
        hfb.append(act)
        self.hfb = nn.Sequential(*hfb)

        # Main branch: ``num_blocks`` NAF-style blocks sharing the same ``act`` module.
        body = []
        for i in range(num_blocks):
            if is_train:
                body.append(SimplifiedNAFBlock(in_c=num_feats, act=act, exp=2, eca_gamma=eca_gamma, layernorm=layernorm, residual=residual))
            else:
                body.append(SimplifiedRepNAFBlock(in_c=num_feats, act=act, exp=2, eca_gamma=eca_gamma, layernorm=layernorm, residual=residual))

        self.body = nn.Sequential(*body)

        tail = [LayerNorm2d(num_feats)]
        if is_train:
            tail.append(ResBlock(num_feats, ratio=2))
        else:
            tail.append(RepResBlock(num_feats))
        self.tail = nn.Sequential(*tail)

        # Undo the 2x unshuffle and apply ``upscale`` in a single pixel-shuffle.
        self.upsample = nn.Sequential(
            nn.Conv2d(num_feats, num_channels_out * ((2 * upscale) ** 2), 3, padding=1),
            nn.PixelShuffle(upscale*2)
        )

    def forward(self, _, out_lr):  # only the shape of the first input is used; out_lr is processed
        x = out_lr
        # stage 1: high-frequency component (input minus its Gaussian blur)
        hf = x - self.gaussian(x)

        # unshuffle to save computation
        x_unsh = self.down(x)
        hf_unsh = self.down(hf)

        # the same head convolution is applied to both inputs
        shallow_feats_hf = self.head(hf_unsh)
        shallow_feats_lr = self.head(x_unsh)

        # stage 2
        deep_feats = self.body(shallow_feats_lr)
        hf_feats = self.hfb(shallow_feats_hf)

        # stage 3
        if self.forget:
            deep_feats = self.tail(self.gamma * deep_feats + hf_feats)
        else:
            deep_feats = self.tail(deep_feats)

        out = self.upsample(deep_feats)
        # global residual: bilinear upscale of the input
        out += F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        # match the spatial size of the first argument and clamp to the valid range
        return F.interpolate(out, size=_.shape[2:], mode='bilinear', align_corners=False).clamp(0, 1.)
91
+
92
+
93
class RT4KRestore_Rep(RT4KSR_Rep):
    """Same-resolution restoration variant of ``RT4KSR_Rep`` (``upscale`` fixed to 1).

    Only the FIRST forward argument is processed. The second argument is accepted
    for call compatibility with the dual-input models and is ignored. The output
    has the first input's spatial size and is clamped to [0, 1].
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, upscale=1, **kwargs)

    def forward(self, x, _):  # only x (the first input) is used
        out_size = x.shape[-2:]
        # PixelUnshuffle(2) needs even height/width: resize down to the nearest even size.
        even_size = (x.shape[-2] // 2 * 2, x.shape[-1] // 2 * 2)
        img = F.interpolate(x, size=even_size, mode='bilinear', align_corners=False)

        # stage 1: high-frequency component (image minus its Gaussian blur)
        high_freq = img - self.gaussian(img)

        # 2x unshuffle, then the shared head convolution on both signals
        img_unsh = self.down(img)
        high_freq_unsh = self.down(high_freq)
        feats_hf = self.head(high_freq_unsh)
        feats_img = self.head(img_unsh)

        # stage 2
        deep = self.body(feats_img)
        hf_feats = self.hfb(feats_hf)

        # stage 3: optionally mix in the high-frequency branch before the tail
        fused = self.gamma * deep + hf_feats if self.forget else deep
        deep = self.tail(fused)

        out = self.upsample(deep)
        out += img  # global residual on the (even-sized) input
        return F.interpolate(out, size=out_size, mode='bilinear', align_corners=False).clamp(0, 1.)
123
+
124
+ ####################################
125
+ # RETURN INITIALIZED MODEL INSTANCES
126
+ ####################################
127
+
128
+
129
def rt4ksr_rep(
        act_type: str = "gelu",
        feature_channels: int = 24,
        num_blocks: int = 4,
        scale: int = 4,
        is_train: bool = True,
        eca_gamma=0,
        forget: bool = False,
        layernorm: bool = True,
        residual: bool = True,
):
    """Build a 3-in / 3-out channel ``RT4KSR_Rep`` super-resolution model.

    Args:
        act_type: name passed to ``activation`` to build the activation module.
        feature_channels: internal feature width (``num_feats``).
        num_blocks: number of body blocks.
        scale: upscaling factor (``upscale``).
        is_train: build the training-time blocks (true) or the ``Rep*`` variants.
        eca_gamma, forget, layernorm, residual: forwarded unchanged to
            ``RT4KSR_Rep``. The defaults are the values that used to be hard-coded;
            exposing them lets keyword configs (e.g. ``-o forget=True``) set them.

    Returns:
        The constructed ``RT4KSR_Rep`` module.
    """
    act = activation(act_type)
    model = RT4KSR_Rep(num_channels_in=3,
                       num_channels_out=3,
                       num_feats=feature_channels,
                       num_blocks=num_blocks,
                       upscale=scale,
                       act=act,
                       eca_gamma=eca_gamma,
                       forget=forget,
                       is_train=is_train,
                       layernorm=layernorm,
                       residual=residual)
    return model
149
+
150
+
151
def rt4krestore_rep(
        act_type: str = "gelu",
        feature_channels: int = 24,
        num_blocks: int = 4,
        is_train: bool = True,
        eca_gamma=0,
        forget: bool = False,
        layernorm: bool = True,
        residual: bool = True,
):
    """Build a 3-in / 3-out channel ``RT4KRestore_Rep`` (same-resolution) model.

    Args:
        act_type: name passed to ``activation`` to build the activation module.
        feature_channels: internal feature width (``num_feats``).
        num_blocks: number of body blocks.
        is_train: build the training-time blocks (true) or the ``Rep*`` variants.
        eca_gamma, forget, layernorm, residual: forwarded unchanged to the model.
            The defaults are the values that used to be hard-coded.

    Returns:
        The constructed ``RT4KRestore_Rep`` module.
    """
    act = activation(act_type)
    model = RT4KRestore_Rep(num_channels_in=3,
                            num_channels_out=3,
                            num_feats=feature_channels,
                            num_blocks=num_blocks,
                            act=act,
                            eca_gamma=eca_gamma,
                            forget=forget,
                            is_train=is_train,
                            layernorm=layernorm,
                            residual=residual)
    return model
@@ -0,0 +1,67 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from .arch import *
5
+
6
+
7
class RT4KDual_Rep(RT4KSR_Rep):
    """Dual-input restoration network built on the ``RT4KSR_Rep`` layers.

    ``forward(x, lr)`` works as follows:
      * ``x`` is resized to even height/width (bilinear).
      * ``lr`` is resized to the same size (bicubic).
      * The two are stacked channel-wise, ``lr`` first.
      * A residual is predicted and added to the resized ``lr``.
      * The output is resized back to ``x``'s size and clamped to [0, 1].

    ``upscale`` is fixed to 1, and ``num_channels_in`` is
    ``num_channels_color + num_channels_gray``.
    """

    def __init__(self, num_channels_color, num_channels_gray, *args, **kwargs):
        super().__init__(*args, **kwargs, upscale=1, num_channels_in=num_channels_color + num_channels_gray)
        self.num_channels_color = num_channels_color
        self.num_channels_gray = num_channels_gray

    def forward(self, x, lr):
        out_size = x.shape[-2:]
        # PixelUnshuffle(2) needs even height/width.
        even_size = (x.shape[-2] // 2 * 2, x.shape[-1] // 2 * 2)
        target = F.interpolate(x, size=even_size, mode='bilinear', align_corners=False)
        base = F.interpolate(lr, size=even_size, mode='bicubic', align_corners=False)
        stacked = torch.cat([base, target], dim=1)

        # Same feature pipeline as RT4KSR_Rep.forward.
        # stage 1: high-frequency component of the stacked input
        high_freq = stacked - self.gaussian(stacked)

        # 2x unshuffle, then the shared head convolution on both signals
        stacked_unsh = self.down(stacked)
        high_freq_unsh = self.down(high_freq)
        feats_hf = self.head(high_freq_unsh)
        feats_img = self.head(stacked_unsh)

        # stage 2
        deep = self.body(feats_img)
        hf_feats = self.hfb(feats_hf)

        # stage 3: optionally mix in the high-frequency branch before the tail
        fused = self.gamma * deep + hf_feats if self.forget else deep
        deep = self.tail(fused)

        out = self.upsample(deep)
        out += base  # residual on the resized second input
        return F.interpolate(out, size=out_size, mode='bilinear', align_corners=False).clamp(0, 1.)
43
+
44
+ ####################################
45
+ # RETURN INITIALIZED MODEL INSTANCES
46
+ ####################################
47
+
48
+
49
def rt4kdual_rep(
        act_type: str = "gelu",
        feature_channels: int = 24,
        num_blocks: int = 4,
        is_train: bool = True,
        eca_gamma=0,
        forget: bool = False,
        layernorm: bool = True,
        residual: bool = True,
):
    """Build an ``RT4KDual_Rep`` with two 3-channel inputs and a 3-channel output.

    Args:
        act_type: name passed to ``activation`` to build the activation module.
        feature_channels: internal feature width (``num_feats``).
        num_blocks: number of body blocks.
        is_train: build the training-time blocks (true) or the ``Rep*`` variants.
        eca_gamma, forget, layernorm, residual: forwarded unchanged to the model.
            The defaults are the values that used to be hard-coded.

    Returns:
        The constructed ``RT4KDual_Rep`` module.
    """
    act = activation(act_type)
    model = RT4KDual_Rep(num_channels_color=3,
                         num_channels_gray=3,
                         num_channels_out=3,
                         num_feats=feature_channels,
                         num_blocks=num_blocks,
                         act=act,
                         eca_gamma=eca_gamma,
                         forget=forget,
                         is_train=is_train,
                         layernorm=layernorm,
                         residual=residual)
    return model