nerfstudio-gnt 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nerfstudio_gnt-0.0.1/GNT/config.py +197 -0
- nerfstudio_gnt-0.0.1/GNT/eval.py +236 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/__init__.py +0 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/criterion.py +22 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/__init__.py +31 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/colmap_read_model.py +316 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/create_training_dataset.py +123 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/data_utils.py +267 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/data_verifier.py +141 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/deepvoxels.py +142 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/google_scanned_objects.py +117 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/ibrnet_collected.py +152 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/llff.py +144 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/llff_data_utils.py +393 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/llff_render.py +110 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/llff_test.py +158 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/nerf_synthetic.py +159 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/nerf_synthetic_render.py +160 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/nmr_dataset.py +170 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/realestate.py +147 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/shiny.py +182 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/shiny_data_utils.py +407 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/shiny_render.py +135 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/spaces_dataset.py +473 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/feature_network.py +321 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/model.py +168 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/projection.py +133 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/render_image.py +107 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/render_ray.py +283 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/sample_ray.py +157 -0
- nerfstudio_gnt-0.0.1/GNT/gnt/transformer_network.py +309 -0
- nerfstudio_gnt-0.0.1/GNT/render.py +193 -0
- nerfstudio_gnt-0.0.1/GNT/train.py +319 -0
- nerfstudio_gnt-0.0.1/GNT/utils.py +301 -0
- nerfstudio_gnt-0.0.1/GNTConfig.py +19 -0
- nerfstudio_gnt-0.0.1/GNTDataManager.py +436 -0
- nerfstudio_gnt-0.0.1/GNTModel.py +310 -0
- nerfstudio_gnt-0.0.1/GNTPipeline.py +71 -0
- nerfstudio_gnt-0.0.1/GNTTrainer.py +32 -0
- nerfstudio_gnt-0.0.1/PKG-INFO +64 -0
- nerfstudio_gnt-0.0.1/README.md +55 -0
- nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/PKG-INFO +64 -0
- nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/SOURCES.txt +47 -0
- nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/dependency_links.txt +1 -0
- nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/entry_points.txt +2 -0
- nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/requires.txt +2 -0
- nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/top_level.txt +6 -0
- nerfstudio_gnt-0.0.1/pyproject.toml +23 -0
- nerfstudio_gnt-0.0.1/setup.cfg +4 -0
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
import configargparse
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def config_parser():
    """Build the argument parser shared by the GNT scripts.

    Options can come from the command line or from a file passed via
    ``--config``. The parser is returned un-parsed so that callers may
    register further options before calling ``parse_args()``.

    Returns:
        configargparse.ArgumentParser: parser with every option registered.
    """

    def boolean(value):
        """Convert a command-line / config-file token into a real bool.

        ``type=bool`` would treat every non-empty string (including
        "False") as True, so the token is interpreted explicitly instead.
        The empty string keeps mapping to False, as ``bool("")`` did.
        Raising ValueError makes argparse report an "invalid boolean value".
        """
        if isinstance(value, bool):
            return value
        token = str(value).strip().lower()
        if token in ("1", "true", "t", "yes", "y", "on"):
            return True
        if token in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise ValueError("expected a boolean, got {!r}".format(value))

    parser = configargparse.ArgumentParser()
    # general
    parser.add_argument("--config", is_config_file=True, help="config file path")
    parser.add_argument(
        "--rootdir",
        type=str,
        default="./",
        help="the path to the project root directory. Replace this path with yours!",
    )
    parser.add_argument("--expname", type=str, help="experiment name")
    parser.add_argument("--distributed", action="store_true", help="if use distributed training")
    parser.add_argument("--local_rank", type=int, default=0, help="rank for distributed training")
    parser.add_argument(
        "-j",
        "--workers",
        default=8,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 8)",
    )

    ########## dataset options ##########
    ## train and eval dataset
    parser.add_argument(
        "--train_dataset",
        type=str,
        default="ibrnet_collected",
        help="the training dataset, should either be a single dataset, "
        'or multiple datasets connected with "+", for example, ibrnet_collected+llff+spaces',
    )
    parser.add_argument(
        "--dataset_weights",
        nargs="+",
        type=float,
        default=[],
        help="the weights for training datasets, valid when multiple datasets are used.",
    )
    parser.add_argument(
        "--train_scenes",
        nargs="+",
        default=[],
        help="optional, specify a subset of training scenes from training dataset",
    )
    parser.add_argument(
        "--eval_dataset", type=str, default="llff_test", help="the dataset to evaluate"
    )
    parser.add_argument(
        "--eval_scenes",
        nargs="+",
        default=[],
        help="optional, specify a subset of scenes from eval_dataset to evaluate",
    )
    ## others
    parser.add_argument(
        "--testskip",
        type=int,
        default=8,
        help="will load 1/N images from test/val sets, "
        "useful for large datasets like deepvoxels or nerf_synthetic",
    )

    ########## model options ##########
    ## ray sampling options
    parser.add_argument(
        "--sample_mode",
        type=str,
        default="uniform",
        help="how to sample pixels from images for training:" "uniform|center",
    )
    parser.add_argument(
        "--center_ratio", type=float, default=0.8, help="the ratio of center crop to keep"
    )
    parser.add_argument(
        "--N_rand",
        type=int,
        default=32 * 16,
        help="batch size (number of random rays per gradient step)",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=1024 * 4,
        help="number of rays processed in parallel, decrease if running out of memory",
    )

    ## model options
    parser.add_argument(
        "--coarse_feat_dim", type=int, default=32, help="2D feature dimension for coarse level"
    )
    parser.add_argument(
        "--fine_feat_dim", type=int, default=32, help="2D feature dimension for fine level"
    )
    parser.add_argument(
        "--num_source_views",
        type=int,
        default=10,
        help="the number of input source views for each target view",
    )
    parser.add_argument(
        "--rectify_inplane_rotation", action="store_true", help="if rectify inplane rotation"
    )
    parser.add_argument("--coarse_only", action="store_true", help="use coarse network only")
    parser.add_argument(
        "--anti_alias_pooling", type=int, default=1, help="if use anti-alias pooling"
    )
    parser.add_argument("--trans_depth", type=int, default=4, help="number of transformer layers")
    parser.add_argument("--netwidth", type=int, default=64, help="network intermediate dimension")
    parser.add_argument(
        "--single_net",
        # Explicit converter: with type=bool, "--single_net False" parsed as True.
        type=boolean,
        default=True,
        help="use single network for both coarse and/or fine sampling",
    )

    ########## checkpoints ##########
    parser.add_argument(
        "--no_reload", action="store_true", help="do not reload weights from saved ckpt"
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="",
        help="specific weights npy file to reload for coarse network",
    )
    parser.add_argument(
        "--no_load_opt", action="store_true", help="do not load optimizer when reloading"
    )
    parser.add_argument(
        "--no_load_scheduler", action="store_true", help="do not load scheduler when reloading"
    )

    ########### iterations & learning rate options ##########
    parser.add_argument("--n_iters", type=int, default=250000, help="num of iterations")
    parser.add_argument(
        "--lrate_feature", type=float, default=1e-3, help="learning rate for feature extractor"
    )
    parser.add_argument("--lrate_gnt", type=float, default=5e-4, help="learning rate for gnt")
    parser.add_argument(
        "--lrate_decay_factor",
        type=float,
        default=0.5,
        help="decay learning rate by a factor every specified number of steps",
    )
    parser.add_argument(
        "--lrate_decay_steps",
        type=int,
        default=50000,
        help="decay learning rate by a factor every specified number of steps",
    )

    ########## rendering options ##########
    parser.add_argument(
        "--N_samples", type=int, default=64, help="number of coarse samples per ray"
    )
    parser.add_argument(
        "--N_importance", type=int, default=64, help="number of important samples per ray"
    )
    parser.add_argument(
        "--inv_uniform", action="store_true", help="if True, will uniformly sample inverse depths"
    )
    parser.add_argument(
        "--det", action="store_true", help="deterministic sampling for coarse and fine samples"
    )
    parser.add_argument(
        "--white_bkgd",
        action="store_true",
        help="apply the trick to avoid fitting to white background",
    )
    parser.add_argument(
        "--render_stride",
        type=int,
        default=1,
        help="render with large stride for validation to save time",
    )

    ########## logging/saving options ##########
    parser.add_argument("--i_print", type=int, default=100, help="frequency of terminal printout")
    parser.add_argument(
        "--i_img", type=int, default=500, help="frequency of tensorboard image logging"
    )
    parser.add_argument(
        "--i_weights", type=int, default=10000, help="frequency of weight ckpt saving"
    )

    ########## evaluation options ##########
    parser.add_argument(
        "--llffhold",
        type=int,
        default=8,
        help="will take every 1/N images as LLFF test set, paper uses 8",
    )

    return parser
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
import shutil
|
|
4
|
+
import torch
|
|
5
|
+
import torch.utils.data.distributed
|
|
6
|
+
|
|
7
|
+
from torch.utils.data import DataLoader
|
|
8
|
+
|
|
9
|
+
from gnt.data_loaders import dataset_dict
|
|
10
|
+
from gnt.render_image import render_single_image
|
|
11
|
+
from gnt.model import GNTModel
|
|
12
|
+
from gnt.sample_ray import RaySamplerSingleImage
|
|
13
|
+
from utils import img_HWC2CHW, colorize, img2psnr, lpips, ssim
|
|
14
|
+
import config
|
|
15
|
+
import torch.distributed as dist
|
|
16
|
+
from gnt.projection import Projector
|
|
17
|
+
from gnt.data_loaders.create_training_dataset import create_training_dataset
|
|
18
|
+
import imageio
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG so each worker id gets its own stream.

    The new seed is the first word of the current global RNG state
    offset by ``worker_id``.
    """
    base_seed = np.random.get_state()[1][0]
    np.random.seed(base_seed + worker_id)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def synchronize():
    """
    Barrier across all processes when running distributed.

    Does nothing when torch.distributed is unavailable, has not been
    initialized, or the world consists of a single process.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() != 1:
        dist.barrier()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@torch.no_grad()
def eval(args):
    """Render every view of the selected split and report average metrics.

    Steps:
      1. Create ``<rootdir>/out/<expname>``, dump all arguments to
         ``args.txt`` and copy the config file to ``config.txt`` unless a
         copy already exists.
      2. Build the training loader (``args.run_val`` false) or the
         validation loader for ``args.eval_dataset`` (``args.run_val`` true).
      3. Create the GNT model and the projector.
      4. On ``local_rank == 0``, render each batch through ``log_view`` and
         collect its PSNR / LPIPS / SSIM.
      5. Print the mean of each metric.

    Args:
        args: parsed options from ``config.config_parser()`` plus ``run_val``.
    """
    device = "cuda:{}".format(args.local_rank)
    out_folder = os.path.join(args.rootdir, "out", args.expname)
    print("outputs will be saved to {}".format(out_folder))
    os.makedirs(out_folder, exist_ok=True)

    # save the args and config files
    args_path = os.path.join(out_folder, "args.txt")
    with open(args_path, "w") as file:
        for arg in sorted(vars(args)):
            file.write("{} = {}\n".format(arg, getattr(args, arg)))

    if args.config is not None:
        config_copy = os.path.join(out_folder, "config.txt")
        # Keep an existing copy untouched so earlier runs are not overwritten.
        if not os.path.isfile(config_copy):
            shutil.copy(args.config, config_copy)

    if not args.run_val:
        # create training dataset
        dataset, sampler = create_training_dataset(args)
        # currently only support batch_size=1 (i.e., one set of target and source views) for each GPU node
        # please use distributed parallel on multiple GPUs to train multiple target views per batch
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            worker_init_fn=lambda _: np.random.seed(),
            num_workers=args.workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=sampler is None,
        )
    else:
        # create validation dataset
        dataset = dataset_dict[args.eval_dataset](args, "validation", scenes=args.eval_scenes)
        loader = DataLoader(dataset, batch_size=1)
    iterator = iter(loader)

    # Create GNT model
    model = GNTModel(
        args, load_opt=not args.no_load_opt, load_scheduler=not args.no_load_scheduler
    )
    # create projector
    projector = Projector(device=device)

    psnr_scores = []
    lpips_scores = []
    ssim_scores = []
    # Plain iteration ends only when the loader is exhausted; any other
    # error (e.g. a failing dataset item) now propagates instead of being
    # swallowed and silently truncating the evaluation.
    for indx, data in enumerate(iterator):
        if args.local_rank == 0:
            tmp_ray_sampler = RaySamplerSingleImage(data, device, render_stride=args.render_stride)
            H, W = tmp_ray_sampler.H, tmp_ray_sampler.W
            gt_img = tmp_ray_sampler.rgb.reshape(H, W, 3)
            psnr_curr_img, lpips_curr_img, ssim_curr_img = log_view(
                indx,
                args,
                model,
                tmp_ray_sampler,
                projector,
                gt_img,
                render_stride=args.render_stride,
                prefix="val/" if args.run_val else "train/",
                out_folder=out_folder,
                ret_alpha=args.N_importance > 0,
                single_net=args.single_net,
            )
            psnr_scores.append(psnr_curr_img)
            lpips_scores.append(lpips_curr_img)
            ssim_scores.append(ssim_curr_img)
            torch.cuda.empty_cache()
    print("Average PSNR: ", np.mean(psnr_scores))
    print("Average LPIPS: ", np.mean(lpips_scores))
    print("Average SSIM: ", np.mean(ssim_scores))
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _save_view_image(image, out_folder, prefix, global_step, suffix):
    """Write one image tensor to ``<out_folder>/<prefix[:-1]>_<step>_<suffix>.png``.

    ``image`` has its first axis moved last (``permute(1, 2, 0)``) before
    it is converted to NumPy and written; presumably it is channel-first.
    """
    array = image.permute(1, 2, 0).detach().cpu().numpy()
    filename = os.path.join(
        out_folder, prefix[:-1] + "_{:03d}_{}.png".format(global_step, suffix)
    )
    imageio.imwrite(filename, array)


@torch.no_grad()
def log_view(
    global_step,
    args,
    model,
    ray_sampler,
    projector,
    gt_img,
    render_stride=1,
    prefix="",
    out_folder="",
    ret_alpha=False,
    single_net=True,
):
    """Render one view, save the predictions as PNGs and score them.

    Args:
        global_step: integer index embedded in the output file names.
        args: options object; ``chunk_size``, ``N_samples``, ``inv_uniform``,
            ``N_importance`` and ``white_bkgd`` are read.
        model: GNT model; switched to eval mode, ``feature_net`` may be None.
        ray_sampler: provides ``get_all()`` for the full ray batch.
        projector: forwarded to ``render_single_image``.
        gt_img: ground-truth image, subsampled by ``render_stride``.
        render_stride: pixel stride used for rendering and for ``gt_img``.
        prefix: label printed with the metrics; its last character is
            dropped to form the file-name stem (e.g. "val/" -> "val").
        out_folder: directory receiving the PNG files.
        ret_alpha, single_net: forwarded to ``render_single_image``.

    Returns:
        tuple: ``(psnr, lpips, ssim)`` for this view, computed on the fine
        prediction when available and on the coarse one otherwise.
    """
    model.switch_to_eval()
    # Gradients are already disabled by the decorator for the whole body.
    ray_batch = ray_sampler.get_all()
    if model.feature_net is not None:
        featmaps = model.feature_net(ray_batch["src_rgbs"].squeeze(0).permute(0, 3, 1, 2))
    else:
        featmaps = [None, None]
    ret = render_single_image(
        ray_sampler=ray_sampler,
        ray_batch=ray_batch,
        model=model,
        projector=projector,
        chunk_size=args.chunk_size,
        N_samples=args.N_samples,
        inv_uniform=args.inv_uniform,
        det=True,
        N_importance=args.N_importance,
        white_bkgd=args.white_bkgd,
        render_stride=render_stride,
        featmaps=featmaps,
        ret_alpha=ret_alpha,
        single_net=single_net,
    )

    # Subsample the ground truth with the same stride that was used for
    # rendering (the parameter, not args.render_stride) so shapes agree.
    if render_stride != 1:
        gt_img = gt_img[::render_stride, ::render_stride]

    outputs_coarse = ret["outputs_coarse"]
    outputs_fine = ret["outputs_fine"]

    rgb_coarse = img_HWC2CHW(outputs_coarse["rgb"].detach().cpu())
    depth_coarse = None
    if "depth" in outputs_coarse.keys():
        depth_pred = outputs_coarse["depth"].detach().cpu()
        depth_coarse = img_HWC2CHW(colorize(depth_pred, cmap_name="jet"))

    # Initialise both so a fine output without a "depth" entry no longer
    # leaves depth_fine unbound.
    rgb_fine = None
    depth_fine = None
    if outputs_fine is not None:
        rgb_fine = img_HWC2CHW(outputs_fine["rgb"].detach().cpu())
        if "depth" in outputs_fine.keys():
            depth_pred = outputs_fine["depth"].detach().cpu()
            depth_fine = img_HWC2CHW(colorize(depth_pred, cmap_name="jet"))

    _save_view_image(rgb_coarse, out_folder, prefix, global_step, "coarse")
    if depth_coarse is not None:
        _save_view_image(depth_coarse, out_folder, prefix, global_step, "coarse_depth")
    if rgb_fine is not None:
        _save_view_image(rgb_fine, out_folder, prefix, global_step, "fine")
    if depth_fine is not None:
        _save_view_image(depth_fine, out_folder, prefix, global_step, "fine_depth")

    # write scalar
    pred_rgb = outputs_fine["rgb"] if outputs_fine is not None else outputs_coarse["rgb"]
    pred_rgb = torch.clip(pred_rgb, 0.0, 1.0)
    lpips_curr_img = lpips(pred_rgb, gt_img, format="HWC").item()
    ssim_curr_img = ssim(pred_rgb, gt_img, format="HWC").item()
    psnr_curr_img = img2psnr(pred_rgb.detach().cpu(), gt_img)
    print(prefix + "psnr_image: ", psnr_curr_img)
    print(prefix + "lpips_image: ", lpips_curr_img)
    print(prefix + "ssim_image: ", ssim_curr_img)
    return psnr_curr_img, lpips_curr_img, ssim_curr_img
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
if __name__ == "__main__":
    # Shared options plus the evaluation-only --run_val switch.
    parser = config.config_parser()
    parser.add_argument("--run_val", action="store_true", help="run on val set")
    args = parser.parse_args()

    if args.distributed:
        # Select this process's GPU, join the NCCL group, then wait for peers.
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    eval(args)
|
|
File without changes
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
from utils import img2mse
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Criterion(nn.Module):
    """Training loss module: MSE between predicted and ground-truth RGB."""

    def __init__(self):
        super().__init__()

    def forward(self, outputs, ray_batch, scalars_to_log):
        """
        Compute the training loss.

        Reads ``outputs["rgb"]`` and ``ray_batch["rgb"]``. If ``outputs``
        has a "mask" entry it is cast to float and handed to ``img2mse``,
        otherwise ``None`` is passed. ``scalars_to_log`` is returned
        unchanged alongside the loss.
        """
        prediction = outputs["rgb"]
        mask = outputs["mask"].float() if "mask" in outputs else None
        target = ray_batch["rgb"]
        return img2mse(prediction, target, mask), scalars_to_log
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from .google_scanned_objects import *
|
|
2
|
+
from .realestate import *
|
|
3
|
+
from .deepvoxels import *
|
|
4
|
+
from .realestate import *
|
|
5
|
+
from .llff import *
|
|
6
|
+
from .llff_test import *
|
|
7
|
+
from .ibrnet_collected import *
|
|
8
|
+
from .realestate import *
|
|
9
|
+
from .spaces_dataset import *
|
|
10
|
+
from .nerf_synthetic import *
|
|
11
|
+
from .shiny import *
|
|
12
|
+
from .llff_render import *
|
|
13
|
+
from .shiny_render import *
|
|
14
|
+
from .nerf_synthetic_render import *
|
|
15
|
+
from .nmr_dataset import *
|
|
16
|
+
|
|
17
|
+
# Registry mapping a dataset name to its dataset class; entries keep the
# original insertion order.
dataset_dict = dict(
    spaces=SpacesFreeDataset,
    google_scanned=GoogleScannedDataset,
    realestate=RealEstateDataset,
    deepvoxels=DeepVoxelsDataset,
    nerf_synthetic=NerfSyntheticDataset,
    llff=LLFFDataset,
    ibrnet_collected=IBRNetCollectedDataset,
    llff_test=LLFFTestDataset,
    shiny=ShinyDataset,
    llff_render=LLFFRenderDataset,
    shiny_render=ShinyRenderDataset,
    nerf_synthetic_render=NerfSyntheticRenderDataset,
    nmr=NMRDataset,
)
|