nerfstudio-gnt 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. nerfstudio_gnt-0.0.1/GNT/config.py +197 -0
  2. nerfstudio_gnt-0.0.1/GNT/eval.py +236 -0
  3. nerfstudio_gnt-0.0.1/GNT/gnt/__init__.py +0 -0
  4. nerfstudio_gnt-0.0.1/GNT/gnt/criterion.py +22 -0
  5. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/__init__.py +31 -0
  6. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/colmap_read_model.py +316 -0
  7. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/create_training_dataset.py +123 -0
  8. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/data_utils.py +267 -0
  9. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/data_verifier.py +141 -0
  10. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/deepvoxels.py +142 -0
  11. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/google_scanned_objects.py +117 -0
  12. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/ibrnet_collected.py +152 -0
  13. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/llff.py +144 -0
  14. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/llff_data_utils.py +393 -0
  15. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/llff_render.py +110 -0
  16. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/llff_test.py +158 -0
  17. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/nerf_synthetic.py +159 -0
  18. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/nerf_synthetic_render.py +160 -0
  19. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/nmr_dataset.py +170 -0
  20. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/realestate.py +147 -0
  21. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/shiny.py +182 -0
  22. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/shiny_data_utils.py +407 -0
  23. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/shiny_render.py +135 -0
  24. nerfstudio_gnt-0.0.1/GNT/gnt/data_loaders/spaces_dataset.py +473 -0
  25. nerfstudio_gnt-0.0.1/GNT/gnt/feature_network.py +321 -0
  26. nerfstudio_gnt-0.0.1/GNT/gnt/model.py +168 -0
  27. nerfstudio_gnt-0.0.1/GNT/gnt/projection.py +133 -0
  28. nerfstudio_gnt-0.0.1/GNT/gnt/render_image.py +107 -0
  29. nerfstudio_gnt-0.0.1/GNT/gnt/render_ray.py +283 -0
  30. nerfstudio_gnt-0.0.1/GNT/gnt/sample_ray.py +157 -0
  31. nerfstudio_gnt-0.0.1/GNT/gnt/transformer_network.py +309 -0
  32. nerfstudio_gnt-0.0.1/GNT/render.py +193 -0
  33. nerfstudio_gnt-0.0.1/GNT/train.py +319 -0
  34. nerfstudio_gnt-0.0.1/GNT/utils.py +301 -0
  35. nerfstudio_gnt-0.0.1/GNTConfig.py +19 -0
  36. nerfstudio_gnt-0.0.1/GNTDataManager.py +436 -0
  37. nerfstudio_gnt-0.0.1/GNTModel.py +310 -0
  38. nerfstudio_gnt-0.0.1/GNTPipeline.py +71 -0
  39. nerfstudio_gnt-0.0.1/GNTTrainer.py +32 -0
  40. nerfstudio_gnt-0.0.1/PKG-INFO +64 -0
  41. nerfstudio_gnt-0.0.1/README.md +55 -0
  42. nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/PKG-INFO +64 -0
  43. nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/SOURCES.txt +47 -0
  44. nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/dependency_links.txt +1 -0
  45. nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/entry_points.txt +2 -0
  46. nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/requires.txt +2 -0
  47. nerfstudio_gnt-0.0.1/nerfstudio_gnt.egg-info/top_level.txt +6 -0
  48. nerfstudio_gnt-0.0.1/pyproject.toml +23 -0
  49. nerfstudio_gnt-0.0.1/setup.cfg +4 -0
@@ -0,0 +1,197 @@
1
+ import configargparse
2
+
3
+
4
def config_parser():
    """Build the configargparse parser holding every GNT train/eval/render option.

    Returns:
        configargparse.ArgumentParser: parser that accepts both command-line
        flags and a config file supplied via ``--config``.
    """

    def str2bool(value):
        # Bug fix: the original used type=bool for --single_net, but
        # bool("False") is True, so any non-empty string on the command line
        # silently enabled the flag. Parse textual booleans explicitly.
        if isinstance(value, bool):
            return value
        if value.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if value.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise configargparse.ArgumentTypeError("boolean value expected, got %r" % value)

    parser = configargparse.ArgumentParser()
    # general
    parser.add_argument("--config", is_config_file=True, help="config file path")
    parser.add_argument(
        "--rootdir",
        type=str,
        default="./",
        help="the path to the project root directory. Replace this path with yours!",
    )
    parser.add_argument("--expname", type=str, help="experiment name")
    parser.add_argument("--distributed", action="store_true", help="if use distributed training")
    parser.add_argument("--local_rank", type=int, default=0, help="rank for distributed training")
    parser.add_argument(
        "-j",
        "--workers",
        default=8,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 8)",
    )

    ########## dataset options ##########
    ## train and eval dataset
    parser.add_argument(
        "--train_dataset",
        type=str,
        default="ibrnet_collected",
        help="the training dataset, should either be a single dataset, "
        'or multiple datasets connected with "+", for example, ibrnet_collected+llff+spaces',
    )
    parser.add_argument(
        "--dataset_weights",
        nargs="+",
        type=float,
        default=[],
        help="the weights for training datasets, valid when multiple datasets are used.",
    )
    parser.add_argument(
        "--train_scenes",
        nargs="+",
        default=[],
        help="optional, specify a subset of training scenes from training dataset",
    )
    parser.add_argument(
        "--eval_dataset", type=str, default="llff_test", help="the dataset to evaluate"
    )
    parser.add_argument(
        "--eval_scenes",
        nargs="+",
        default=[],
        help="optional, specify a subset of scenes from eval_dataset to evaluate",
    )
    ## others
    parser.add_argument(
        "--testskip",
        type=int,
        default=8,
        help="will load 1/N images from test/val sets, "
        "useful for large datasets like deepvoxels or nerf_synthetic",
    )

    ########## model options ##########
    ## ray sampling options
    parser.add_argument(
        "--sample_mode",
        type=str,
        default="uniform",
        help="how to sample pixels from images for training:" "uniform|center",
    )
    parser.add_argument(
        "--center_ratio", type=float, default=0.8, help="the ratio of center crop to keep"
    )
    parser.add_argument(
        "--N_rand",
        type=int,
        default=32 * 16,
        help="batch size (number of random rays per gradient step)",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=1024 * 4,
        help="number of rays processed in parallel, decrease if running out of memory",
    )

    ## model options
    parser.add_argument(
        "--coarse_feat_dim", type=int, default=32, help="2D feature dimension for coarse level"
    )
    parser.add_argument(
        "--fine_feat_dim", type=int, default=32, help="2D feature dimension for fine level"
    )
    parser.add_argument(
        "--num_source_views",
        type=int,
        default=10,
        help="the number of input source views for each target view",
    )
    parser.add_argument(
        "--rectify_inplane_rotation", action="store_true", help="if rectify inplane rotation"
    )
    parser.add_argument("--coarse_only", action="store_true", help="use coarse network only")
    parser.add_argument(
        "--anti_alias_pooling", type=int, default=1, help="if use anti-alias pooling"
    )
    parser.add_argument("--trans_depth", type=int, default=4, help="number of transformer layers")
    parser.add_argument("--netwidth", type=int, default=64, help="network intermediate dimension")
    parser.add_argument(
        "--single_net",
        type=str2bool,  # was type=bool; see str2bool above
        default=True,
        help="use single network for both coarse and/or fine sampling",
    )

    ########## checkpoints ##########
    parser.add_argument(
        "--no_reload", action="store_true", help="do not reload weights from saved ckpt"
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="",
        help="specific weights npy file to reload for coarse network",
    )
    parser.add_argument(
        "--no_load_opt", action="store_true", help="do not load optimizer when reloading"
    )
    parser.add_argument(
        "--no_load_scheduler", action="store_true", help="do not load scheduler when reloading"
    )

    ########### iterations & learning rate options ##########
    parser.add_argument("--n_iters", type=int, default=250000, help="num of iterations")
    parser.add_argument(
        "--lrate_feature", type=float, default=1e-3, help="learning rate for feature extractor"
    )
    parser.add_argument("--lrate_gnt", type=float, default=5e-4, help="learning rate for gnt")
    parser.add_argument(
        "--lrate_decay_factor",
        type=float,
        default=0.5,
        help="decay learning rate by a factor every specified number of steps",
    )
    parser.add_argument(
        "--lrate_decay_steps",
        type=int,
        default=50000,
        help="decay learning rate by a factor every specified number of steps",
    )

    ########## rendering options ##########
    parser.add_argument(
        "--N_samples", type=int, default=64, help="number of coarse samples per ray"
    )
    parser.add_argument(
        "--N_importance", type=int, default=64, help="number of important samples per ray"
    )
    parser.add_argument(
        "--inv_uniform", action="store_true", help="if True, will uniformly sample inverse depths"
    )
    parser.add_argument(
        "--det", action="store_true", help="deterministic sampling for coarse and fine samples"
    )
    parser.add_argument(
        "--white_bkgd",
        action="store_true",
        help="apply the trick to avoid fitting to white background",
    )
    parser.add_argument(
        "--render_stride",
        type=int,
        default=1,
        help="render with large stride for validation to save time",
    )

    ########## logging/saving options ##########
    parser.add_argument("--i_print", type=int, default=100, help="frequency of terminal printout")
    parser.add_argument(
        "--i_img", type=int, default=500, help="frequency of tensorboard image logging"
    )
    parser.add_argument(
        "--i_weights", type=int, default=10000, help="frequency of weight ckpt saving"
    )

    ########## evaluation options ##########
    parser.add_argument(
        "--llffhold",
        type=int,
        default=8,
        help="will take every 1/N images as LLFF test set, paper uses 8",
    )

    return parser
@@ -0,0 +1,236 @@
1
+ import os
2
+ import numpy as np
3
+ import shutil
4
+ import torch
5
+ import torch.utils.data.distributed
6
+
7
+ from torch.utils.data import DataLoader
8
+
9
+ from gnt.data_loaders import dataset_dict
10
+ from gnt.render_image import render_single_image
11
+ from gnt.model import GNTModel
12
+ from gnt.sample_ray import RaySamplerSingleImage
13
+ from utils import img_HWC2CHW, colorize, img2psnr, lpips, ssim
14
+ import config
15
+ import torch.distributed as dist
16
+ from gnt.projection import Projector
17
+ from gnt.data_loaders.create_training_dataset import create_training_dataset
18
+ import imageio
19
+
20
+
21
def worker_init_fn(worker_id):
    """Give each DataLoader worker a distinct, reproducible numpy seed."""
    base_seed = np.random.get_state()[1][0]
    np.random.seed(base_seed + worker_id)
23
+
24
+
25
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    # No-op unless a multi-process distributed group is actually running.
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()
38
+
39
+
40
@torch.no_grad()
def eval(args):
    """Evaluate a GNT checkpoint and report PSNR / LPIPS / SSIM.

    Iterates over the training set (default) or the validation set
    (``--run_val``), renders every view via :func:`log_view`, writes the
    renderings under ``<rootdir>/out/<expname>``, and prints per-image and
    average metrics.

    NOTE(review): the name shadows the builtin ``eval``; kept unchanged so
    the ``__main__`` entry point and external callers keep working.
    """
    device = "cuda:{}".format(args.local_rank)
    out_folder = os.path.join(args.rootdir, "out", args.expname)
    print("outputs will be saved to {}".format(out_folder))
    os.makedirs(out_folder, exist_ok=True)

    # save the args and config files so the run is reproducible
    f = os.path.join(out_folder, "args.txt")
    with open(f, "w") as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write("{} = {}\n".format(arg, attr))

    if args.config is not None:
        f = os.path.join(out_folder, "config.txt")
        if not os.path.isfile(f):
            shutil.copy(args.config, f)

    if not args.run_val:
        # create training dataset
        dataset, sampler = create_training_dataset(args)
        # currently only support batch_size=1 (i.e., one set of target and source views) for each GPU node
        # please use distributed parallel on multiple GPUs to train multiple target views per batch
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            worker_init_fn=lambda _: np.random.seed(),
            num_workers=args.workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=True if sampler is None else False,
        )
        iterator = iter(loader)
    else:
        # create validation dataset
        dataset = dataset_dict[args.eval_dataset](args, "validation", scenes=args.eval_scenes)
        loader = DataLoader(dataset, batch_size=1)
        iterator = iter(loader)

    # Create GNT model
    model = GNTModel(
        args, load_opt=not args.no_load_opt, load_scheduler=not args.no_load_scheduler
    )
    # create projector
    projector = Projector(device=device)

    indx = 0
    psnr_scores = []
    lpips_scores = []
    ssim_scores = []
    while True:
        try:
            data = next(iterator)
        except StopIteration:
            # Bug fix: the original bare ``except`` also swallowed real
            # failures (CUDA OOM, corrupt data, ...) and silently ended the
            # evaluation early; only exhaustion of the loader ends the loop.
            break
        if args.local_rank == 0:
            tmp_ray_sampler = RaySamplerSingleImage(data, device, render_stride=args.render_stride)
            H, W = tmp_ray_sampler.H, tmp_ray_sampler.W
            gt_img = tmp_ray_sampler.rgb.reshape(H, W, 3)
            psnr_curr_img, lpips_curr_img, ssim_curr_img = log_view(
                indx,
                args,
                model,
                tmp_ray_sampler,
                projector,
                gt_img,
                render_stride=args.render_stride,
                prefix="val/" if args.run_val else "train/",
                out_folder=out_folder,
                ret_alpha=args.N_importance > 0,
                single_net=args.single_net,
            )
            psnr_scores.append(psnr_curr_img)
            lpips_scores.append(lpips_curr_img)
            ssim_scores.append(ssim_curr_img)
            torch.cuda.empty_cache()
            indx += 1
    print("Average PSNR: ", np.mean(psnr_scores))
    print("Average LPIPS: ", np.mean(lpips_scores))
    print("Average SSIM: ", np.mean(ssim_scores))
122
+
123
+
124
@torch.no_grad()
def log_view(
    global_step,
    args,
    model,
    ray_sampler,
    projector,
    gt_img,
    render_stride=1,
    prefix="",
    out_folder="",
    ret_alpha=False,
    single_net=True,
):
    """Render a single view, save RGB/depth PNGs, and return its metrics.

    Returns:
        tuple: ``(psnr, lpips, ssim)`` for the rendered view against gt_img.
    """
    model.switch_to_eval()
    with torch.no_grad():
        ray_batch = ray_sampler.get_all()
        if model.feature_net is not None:
            # src_rgbs arrives as (1, V, H, W, 3); feature net wants NCHW
            featmaps = model.feature_net(ray_batch["src_rgbs"].squeeze(0).permute(0, 3, 1, 2))
        else:
            featmaps = [None, None]
        ret = render_single_image(
            ray_sampler=ray_sampler,
            ray_batch=ray_batch,
            model=model,
            projector=projector,
            chunk_size=args.chunk_size,
            N_samples=args.N_samples,
            inv_uniform=args.inv_uniform,
            det=True,
            N_importance=args.N_importance,
            white_bkgd=args.white_bkgd,
            render_stride=render_stride,
            featmaps=featmaps,
            ret_alpha=ret_alpha,
            single_net=single_net,
        )

    average_im = ray_sampler.src_rgbs.cpu().mean(dim=(0, 1))

    if args.render_stride != 1:
        # renderer produced a strided image; subsample GT/average to match
        gt_img = gt_img[::render_stride, ::render_stride]
        average_im = average_im[::render_stride, ::render_stride]

    rgb_gt = img_HWC2CHW(gt_img)
    average_im = img_HWC2CHW(average_im)

    rgb_coarse = img_HWC2CHW(ret["outputs_coarse"]["rgb"].detach().cpu())
    if "depth" in ret["outputs_coarse"].keys():
        depth_pred = ret["outputs_coarse"]["depth"].detach().cpu()
        depth_coarse = img_HWC2CHW(colorize(depth_pred, cmap_name="jet"))
    else:
        depth_coarse = None

    # Bug fix: ``depth_fine`` was left undefined when outputs_fine existed
    # but carried no "depth" key, crashing with NameError at the
    # ``if depth_fine is not None`` check below. Initialize both to None.
    rgb_fine = None
    depth_fine = None
    if ret["outputs_fine"] is not None:
        rgb_fine = img_HWC2CHW(ret["outputs_fine"]["rgb"].detach().cpu())
        if "depth" in ret["outputs_fine"].keys():
            depth_pred = ret["outputs_fine"]["depth"].detach().cpu()
            depth_fine = img_HWC2CHW(colorize(depth_pred, cmap_name="jet"))

    rgb_coarse = rgb_coarse.permute(1, 2, 0).detach().cpu().numpy()
    filename = os.path.join(out_folder, prefix[:-1] + "_{:03d}_coarse.png".format(global_step))
    imageio.imwrite(filename, rgb_coarse)

    if depth_coarse is not None:
        depth_coarse = depth_coarse.permute(1, 2, 0).detach().cpu().numpy()
        filename = os.path.join(
            out_folder, prefix[:-1] + "_{:03d}_coarse_depth.png".format(global_step)
        )
        imageio.imwrite(filename, depth_coarse)

    if rgb_fine is not None:
        rgb_fine = rgb_fine.permute(1, 2, 0).detach().cpu().numpy()
        filename = os.path.join(out_folder, prefix[:-1] + "_{:03d}_fine.png".format(global_step))
        imageio.imwrite(filename, rgb_fine)

    if depth_fine is not None:
        depth_fine = depth_fine.permute(1, 2, 0).detach().cpu().numpy()
        filename = os.path.join(
            out_folder, prefix[:-1] + "_{:03d}_fine_depth.png".format(global_step)
        )
        imageio.imwrite(filename, depth_fine)

    # write scalar: prefer the fine output when the fine network ran
    pred_rgb = (
        ret["outputs_fine"]["rgb"]
        if ret["outputs_fine"] is not None
        else ret["outputs_coarse"]["rgb"]
    )
    pred_rgb = torch.clip(pred_rgb, 0.0, 1.0)
    lpips_curr_img = lpips(pred_rgb, gt_img, format="HWC").item()
    ssim_curr_img = ssim(pred_rgb, gt_img, format="HWC").item()
    psnr_curr_img = img2psnr(pred_rgb.detach().cpu(), gt_img)
    print(prefix + "psnr_image: ", psnr_curr_img)
    print(prefix + "lpips_image: ", lpips_curr_img)
    print(prefix + "ssim_image: ", ssim_curr_img)
    return psnr_curr_img, lpips_curr_img, ssim_curr_img
224
+
225
+
226
if __name__ == "__main__":
    # Build the shared GNT parser and add the eval-only flag.
    parser = config.config_parser()
    parser.add_argument("--run_val", action="store_true", help="run on val set")
    args = parser.parse_args()

    if args.distributed:
        # One process per GPU; rank and world size come from the environment
        # (init_method="env://"), NCCL backend for GPU collectives.
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    eval(args)
File without changes
@@ -0,0 +1,22 @@
1
+ import torch.nn as nn
2
+ from utils import img2mse
3
+
4
+
5
class Criterion(nn.Module):
    """Photometric training loss: masked MSE between rendered and GT rays."""

    def __init__(self):
        super().__init__()

    def forward(self, outputs, ray_batch, scalars_to_log):
        """
        training criterion
        """
        predicted = outputs["rgb"]
        # the renderer may attach a validity mask; weight the MSE by it
        mask = outputs["mask"].float() if "mask" in outputs else None
        target = ray_batch["rgb"]
        loss = img2mse(predicted, target, mask)
        return loss, scalars_to_log
@@ -0,0 +1,31 @@
1
# Bug fix: ``from .realestate import *`` was repeated three times; each
# dataset module is now imported exactly once.
from .google_scanned_objects import *
from .realestate import *
from .deepvoxels import *
from .llff import *
from .llff_test import *
from .ibrnet_collected import *
from .spaces_dataset import *
from .nerf_synthetic import *
from .shiny import *
from .llff_render import *
from .shiny_render import *
from .nerf_synthetic_render import *
from .nmr_dataset import *

# Maps the dataset names accepted on the command line (--train_dataset /
# --eval_dataset) to the dataset classes exported by the modules above.
dataset_dict = {
    "spaces": SpacesFreeDataset,
    "google_scanned": GoogleScannedDataset,
    "realestate": RealEstateDataset,
    "deepvoxels": DeepVoxelsDataset,
    "nerf_synthetic": NerfSyntheticDataset,
    "llff": LLFFDataset,
    "ibrnet_collected": IBRNetCollectedDataset,
    "llff_test": LLFFTestDataset,
    "shiny": ShinyDataset,
    "llff_render": LLFFRenderDataset,
    "shiny_render": ShinyRenderDataset,
    "nerf_synthetic_render": NerfSyntheticRenderDataset,
    "nmr": NMRDataset,
}