nerfstudio-pixelnerf 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
PixelNeRFConfig.py ADDED
@@ -0,0 +1,37 @@
+ """nerfstudio-pixelnerf/PixelNeRFConfig.py"""
+
+ from PixelNeRFDataManager import PixelNeRFDataManagerConfig
+ from PixelNeRFModel import PixelNeRFModelConfig
+ from nerfstudio.configs.base_config import ViewerConfig
+ from nerfstudio.engine.optimizers import AdamOptimizerConfig
+ from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig
+ from nerfstudio.engine.trainer import TrainerConfig
+ from nerfstudio.plugins.types import MethodSpecification
+ from PixelNerfPipeline import PixelNerfPipelineConfig
+
+
+ PixelNeRF = MethodSpecification(
+     config=TrainerConfig(
+         method_name="pixel-nerf",
+         steps_per_eval_batch=500,
+         steps_per_save=2000,
+         max_num_iterations=300000,
+         mixed_precision=True,
+         pipeline=PixelNerfPipelineConfig(
+             datamanager=PixelNeRFDataManagerConfig(),
+             model=PixelNeRFModelConfig(),
+         ),
+         optimizers={
+             "network": {
+                 "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
+                 "scheduler": ExponentialDecaySchedulerConfig(
+                     lr_final=5e-5,
+                     max_steps=300000,
+                 ),
+             },
+         },
+         viewer=ViewerConfig(num_rays_per_chunk=1 << 11),
+         vis="tensorboard",
+     ),
+     description="Configuration for the PixelNeRF method",
+ )
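Note: the entry point registered later in this diff under [nerfstudio.method_configs] exposes the specification above to nerfstudio's CLI, so once the wheel is installed the method should be trainable with `ns-train pixel-nerf --data <path>`. A minimal sketch of driving the same TrainerConfig programmatically instead, assuming nerfstudio's standard config workflow (set_timestamp and setup are nerfstudio APIs; the dataset path is hypothetical):

    from pathlib import Path
    from PixelNeRFConfig import PixelNeRF

    trainer_config = PixelNeRF.config                                 # the TrainerConfig defined above
    trainer_config.pipeline.datamanager.data = Path("data/my_scene")  # hypothetical dataset path (normally given via --data)
    trainer_config.set_timestamp()                                    # nerfstudio names the run by timestamp
    trainer = trainer_config.setup()                                  # instantiates the Trainer
    trainer.setup()                                                   # builds pipeline, datamanager, and model
    trainer.train()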
PixelNeRFDataManager.py ADDED
@@ -0,0 +1,86 @@
+ from dataclasses import dataclass, field
+ from typing import Dict, Literal, Tuple, Type, Union
+ from nerfstudio.cameras.rays import RayBundle
+ from nerfstudio.data.datamanagers.base_datamanager import (
+     VanillaDataManager,
+     VanillaDataManagerConfig,
+ )
+
+ import random
+ import torch
+
+
+ @dataclass
+ class PixelNeRFDataManagerConfig(VanillaDataManagerConfig):
+     """Configuration for the PixelNeRF data manager.
+
+     Args:
+         _target: The target class to instantiate, in this case PixelNeRFDataManager.
+     """
+     _target: Type = field(default_factory=lambda: PixelNeRFDataManager, init=False)
+     num_source_views: int = 3
+     """Number of source views sampled from the dataset to condition the pixelNeRF model. The paper typically uses 3 views, but this number can be experimented with."""
+
+
+ class PixelNeRFDataManager(VanillaDataManager):
+     config: PixelNeRFDataManagerConfig
+
+     def __init__(
+         self,
+         config: PixelNeRFDataManagerConfig,
+         device: Union[torch.device, str] = "cpu",
+         test_mode: Literal["test", "val", "inference"] = "val",
+         **kwargs,
+     ):
+         # Call super() to inherit all of Nerfstudio's dataset loading.
+         super().__init__(
+             config=config, device=device, test_mode=test_mode, **kwargs
+         )
+
+     def _sample_source_views(self, num_views: int) -> Dict[str, torch.Tensor]:
+         """
+         Randomly samples N images from the dataset to act as context (source views)
+         and formats them for pixelNeRF.
+         """
+         dataset = self.train_dataset
+         # A smarter sampling strategy could be used here; for now, N views are drawn
+         # uniformly at random from the dataset (a sketch of one alternative follows this file).
+         indices = random.sample(range(len(dataset)), num_views)
+
+         src_rgbs = []
+         src_poses = []
+         focals = []
+         cs = []
+
+         for idx in indices:
+             data = dataset[idx]
+             src_rgbs.append(data["image"])
+             camera = dataset.cameras[idx]
+             c2w_3x4 = camera.camera_to_worlds
+             c2w_4x4 = torch.cat([
+                 c2w_3x4,
+                 torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=c2w_3x4.device)
+             ], dim=0)
+             src_poses.append(c2w_4x4)
+
+             focals.append(torch.tensor([camera.fx.item(), camera.fy.item()]))
+             cs.append(torch.tensor([camera.cx.item(), camera.cy.item()]))
+
+         return {
+             "src_rgbs": torch.stack(src_rgbs).unsqueeze(0),
+             "src_cameras": torch.stack(src_poses).unsqueeze(0),
+             "focal": torch.stack(focals),
+             "c": torch.stack(cs),
+         }
+
+     def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
+         # Let the parent class sample a ray bundle and pixel batch, then attach the
+         # source-view context that pixelNeRF needs as ray metadata.
+         ray_bundle, batch = super().next_train(step)
+         source_data = self._sample_source_views(self.config.num_source_views)
+         ray_bundle.metadata.update(source_data)
+         return ray_bundle, batch
+
+     def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
+         """Same logic as next_train, but drawing from the eval sampler."""
+         ray_bundle, batch = super().next_eval(step)
+         source_data = self._sample_source_views(self.config.num_source_views)
+         ray_bundle.metadata.update(source_data)
+         return ray_bundle, batch
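Note: as flagged in _sample_source_views above, uniform random sampling of source views is a placeholder. A minimal sketch of one alternative, picking the views whose camera centers are nearest to a reference camera (the helper below is hypothetical and not part of this package; it only relies on the nerfstudio Cameras fields already used above):

    import torch

    def sample_nearest_source_views(dataset, ref_idx: int, num_views: int) -> list[int]:
        # Camera centers are the translation column of the (N, 3, 4) camera-to-world matrices.
        centers = dataset.cameras.camera_to_worlds[:, :3, 3]        # (N, 3)
        dists = torch.linalg.norm(centers - centers[ref_idx], dim=-1)
        dists[ref_idx] = float("inf")                               # never pick the reference view itself
        return torch.topk(dists, k=num_views, largest=False).indices.tolist()

Such a helper could replace random.sample in _sample_source_views, with ref_idx taken from the target view being rendered.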
PixelNeRFModel.py ADDED
@@ -0,0 +1,199 @@
+ import sys
+ import os
+ from pathlib import Path
+
+ # Make the vendored pixelnerf sources importable before the pixelnerf imports below.
+ pixelnerf_src_root = str(Path(__file__).parent / "pixelnerf" / "src")
+
+ if pixelnerf_src_root not in sys.path:
+     sys.path.insert(0, pixelnerf_src_root)
+
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Type, Tuple, cast
+ from nerfstudio.cameras.cameras import Cameras
+ from nerfstudio.cameras.rays import RayBundle
+ from nerfstudio.data.scene_box import SceneBox
+ from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
+ from nerfstudio.models.base_model import Model, ModelConfig
+ from torch.nn import Parameter
+ from pixelnerf.src.model.models import PixelNeRFNet
+ from pixelnerf.src.render import NeRFRenderer
+ from dotmap import DotMap
+
+ import torch
+ import nerfstudio.utils.profiler as profiler
+
+
+ @dataclass
+ class PixelNeRFModelConfig(ModelConfig):
+     _target: Type = field(default_factory=lambda: PixelNeRFModel, init=False)
+     ckpt_path: Optional[str] = None
+     """Path to a .pth checkpoint file to load the network weights from. If not provided, will look for .pth files in the output directory and load the latest one."""
+     no_reload: bool = False
+     """If True, will not attempt to load from any checkpoint and will always train from scratch."""
+     out_dir: str = "outputs"
+     """Directory to scan for .pth checkpoints and to write checkpoints and logs to."""
+     exp_name: str = "default_exp"
+     """Name of the experiment, used as a subdirectory under out_dir for checkpoints and logs."""
+     encoder: Dict[str, Any] = field(
+         default_factory=lambda: {
+             "backbone": "resnet34",
+             "pretrained": True,
+             "num_layers": 4,
+         },
+         metadata={
+             "help": "Configuration for the pixelNeRF encoder. Currently the paper's default configuration."
+         },
+     )
+     renderer: Dict[str, Any] = field(
+         default_factory=lambda: {
+             "n_coarse": 64,
+             "n_fine": 32,
+             "n_fine_depth": 64,
+             "depth_std": 0.01,
+             "white_bkgd": False,
+         },
+         metadata={
+             "help": "Configuration for the pixelNeRF renderer. Currently the paper's default configuration."
+         },
+     )
+     lindisp: bool = False
+     """Whether to sample linearly in disparity (inverse depth) rather than in depth. The paper handles this during dataset preprocessing, so it is kept as a config option but defaults to False since it is not commonly used in NeRF implementations."""
+
+
+ class PixelNeRFModel(Model):
+     config: PixelNeRFModelConfig
+
+     def __init__(self, config, scene_box=None, num_train_data=0, **kwargs):
+         if scene_box is None:
+             scene_box = SceneBox(
+                 aabb=torch.tensor([[-1, -1, -1], [1, 1, 1]], dtype=torch.float32)
+             )
+         super().__init__(
+             config=config, scene_box=scene_box, num_train_data=num_train_data, **kwargs
+         )
+
+     def populate_modules(self):
+         super().populate_modules()
+         self.net = PixelNeRFNet(self.config)
+
+         self.renderer = NeRFRenderer.from_conf(
+             self.config.renderer,
+             lindisp=self.config.lindisp,
+         )
+
+         gpus = None
+         if torch.cuda.is_available():
+             print(f"Using {torch.cuda.device_count()} GPUs for parallelization")
+             gpus = list(range(torch.cuda.device_count()))
+         # Bind the network to the renderer; without CUDA it is bound without
+         # DataParallel so the renderer can still be called on CPU.
+         self.renderer = self.renderer.bind_parallel(self.net, gpus=gpus).eval()
+
+         if self.config.no_reload:
+             print("Not loading from ckpt, training from scratch...")
+         else:
+             self.load_from_ckpt(self.config.out_dir, force_latest=False)
+
+     def get_param_groups(self) -> Dict[str, List[Parameter]]:
+         return {"network": list(self.net.parameters())}
+
+     def get_training_callbacks(
+         self, training_callback_attributes: TrainingCallbackAttributes
+     ) -> List[TrainingCallback]:
+         return []
+
+     def get_loss_dict(
+         self, outputs, batch, metrics_dict=None
+     ) -> Dict[str, torch.Tensor]:
+         """MSE loss on the coarse pass plus, when present, the fine pass, as in the paper."""
+         loss = torch.nn.functional.mse_loss(outputs["rgb_coarse"], batch["rgb"])
+         if "rgb_fine" in outputs:
+             loss = loss + torch.nn.functional.mse_loss(
+                 outputs["rgb_fine"], batch["rgb"]
+             )
+         return {"rgb_loss": loss}
+
+     def get_metrics(
+         self, outputs, batch
+     ) -> Dict[str, torch.Tensor]:
+         """The paper only reports PSNR; additional metrics can be added here."""
+         pred = outputs.get("rgb_fine", outputs["rgb_coarse"])
+         gt = batch["rgb"].to(pred.device)
+         psnr = -10.0 * torch.log10(torch.mean((pred - gt) ** 2).clamp_min(1e-10))
+         return {"psnr": psnr}
+
+     def get_image_metrics_and_images(
+         self, outputs, batch
+     ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]:
+         pred = outputs.get("rgb_fine", outputs["rgb_coarse"])
+         gt = batch["rgb"].to(pred.device)
+         psnr = -10.0 * torch.log10(torch.mean((pred - gt) ** 2).clamp_min(1e-10))
+         return {"psnr": float(psnr.item())}, {"rgb": pred, "rgb_gt": gt}
+
+     def load_from_ckpt(self, out_folder, force_latest=False):
+         if not os.path.exists(out_folder):
+             print("No ckpts found, training from scratch...")
+             return 0
+         ckpts = sorted(
+             [
+                 os.path.join(out_folder, f)
+                 for f in os.listdir(out_folder)
+                 if f.endswith(".pth")
+             ]
+         )
+         if self.config.ckpt_path and not force_latest:
+             if os.path.isfile(self.config.ckpt_path):
+                 ckpts = [self.config.ckpt_path]
+         if ckpts and not self.config.no_reload:
+             fpath = ckpts[-1]
+             self.net.load_state_dict(torch.load(fpath, map_location="cpu"))
+             print(f"Reloading from {fpath}")
+             # Assumes checkpoint names end in a six-digit step count, e.g. *_000123.pth.
+             return int(fpath[-10:-4])
+         print("No ckpts found, training from scratch...")
+         return 0
+
+     @profiler.time_function
+     def get_outputs(self, ray_bundle: RayBundle | Cameras) -> Dict[str, torch.Tensor | List]:
+         assert isinstance(ray_bundle, RayBundle)
+         device = next(self.net.parameters()).device
+         metadata = ray_bundle.metadata or {}
+
+         for key in ("src_rgbs", "src_cameras", "focal", "c"):
+             if key not in metadata:
+                 raise KeyError(f"Missing metadata key '{key}': the pipeline must inject source views")
+
+         src_images = metadata["src_rgbs"].squeeze(0).permute(0, 3, 1, 2).to(device)  # (NS, 3, H, W)
+         src_poses = metadata["src_cameras"].squeeze(0).to(device)  # (NS, 4, 4)
+         focal = metadata["focal"].to(device)  # (NS, 2)
+         c = metadata["c"].to(device)  # (NS, 2)
+
+         # Condition the pixelNeRF network on the source views before rendering.
+         self.net.encode(
+             src_images.unsqueeze(0),
+             src_poses.unsqueeze(0),
+             focal,
+             c=c,
+         )
+
+         rays = torch.cat([
+             ray_bundle.origins.to(device),
+             ray_bundle.directions.to(device),
+             ray_bundle.nears.to(device),
+             ray_bundle.fars.to(device),
+         ], dim=-1).unsqueeze(0)
+
+         render_dict = DotMap(self.renderer(rays, want_weights=True))
+
+         outputs: Dict[str, torch.Tensor | List] = {
+             "rgb_coarse": render_dict.coarse.rgb.squeeze(0),
+             "depth_coarse": render_dict.coarse.depth.squeeze(0),
+             "weights_coarse": render_dict.coarse.weights.squeeze(0),
+         }
+         if len(render_dict.fine) > 0:
+             outputs["rgb_fine"] = render_dict.fine.rgb.squeeze(0)
+             outputs["depth_fine"] = render_dict.fine.depth.squeeze(0)
+             outputs["weights_fine"] = render_dict.fine.weights.squeeze(0)
+
+         outputs["rgb"] = outputs.get("rgb_fine", outputs["rgb_coarse"])
+         outputs["accumulation"] = outputs.get("weights_fine", outputs["weights_coarse"]).sum(dim=-1)
+         outputs["depth"] = outputs.get("depth_fine", outputs["depth_coarse"])
+
+         return cast(Dict[str, torch.Tensor | List], outputs)
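Note: a minimal sketch (field names taken from PixelNeRFModelConfig above; the checkpoint path is hypothetical) of pointing the model at existing pixelNeRF weights instead of training the network from scratch:

    model = PixelNeRFModelConfig(
        ckpt_path="outputs/pixel_nerf_000500.pth",  # hypothetical checkpoint; load_from_ckpt expects a six-digit step suffix
        no_reload=False,                            # lets populate_modules call load_from_ckpt
        out_dir="outputs",                          # directory scanned for *.pth files when ckpt_path is unset
    )

This PixelNeRFModelConfig instance would be passed to PixelNerfPipelineConfig in PixelNeRFConfig.py in place of the default.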
nerfstudio_pixelnerf-0.0.1.dist-info/METADATA ADDED
@@ -0,0 +1,10 @@
+ Metadata-Version: 2.4
+ Name: nerfstudio-pixelnerf
+ Version: 0.0.1
+ Summary: Unofficial Implementation of `pixelNeRF: Neural Radiance Fields from One or Few Images` [Yu et al.] for NeRFStudio
+ Requires-Python: >=3.11
+ Description-Content-Type: text/markdown
+ Requires-Dist: dotmap>=1.3.30
+ Requires-Dist: lpips>=0.1.4
+ Requires-Dist: nerfstudio
+ Requires-Dist: pyhocon>=0.3.63
nerfstudio_pixelnerf-0.0.1.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ PixelNeRFConfig.py,sha256=_kGrr4nNlV8xaMN8s7xoJW7pbevM8laYaGZWU_BblRU,1320
+ PixelNeRFDataManager.py,sha256=9GA6_mQ9TcfR-CLITTQKvm-f77CgfgG9dWusxttW7U4,3300
+ PixelNeRFModel.py,sha256=oEyq9RJuiHdtTJ-CqKttdy6dR9rcmH7J8g2UMTNUYr4,8184
+ nerfstudio_pixelnerf-0.0.1.dist-info/METADATA,sha256=Ihw9SP-o0AmoHR-Ekma2cz_R2DSB6jPH6E0Q19AKFSQ,367
+ nerfstudio_pixelnerf-0.0.1.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
+ nerfstudio_pixelnerf-0.0.1.dist-info/entry_points.txt,sha256=DUIH2xvV26nSkjyi9DswfxnhEcc9njfxCsqo-qUBcmk,66
+ nerfstudio_pixelnerf-0.0.1.dist-info/top_level.txt,sha256=UZJrWmW6hLBUIi4B78LvwwrgHuNZ0AJRAnTysU9RpcA,70
+ nerfstudio_pixelnerf-0.0.1.dist-info/RECORD,,
nerfstudio_pixelnerf-0.0.1.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (82.0.1)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
nerfstudio_pixelnerf-0.0.1.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [nerfstudio.method_configs]
+ pixelnerf = PixelNeRFConfig:PixelNeRF
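The entry above registers the MethodSpecification under nerfstudio's plugin group, which is how ns-train discovers the pixel-nerf method. A minimal sketch (standard-library importlib.metadata; the printed result is the expected outcome, not a verified one) of checking the registration once the wheel is installed:

    from importlib.metadata import entry_points

    methods = entry_points(group="nerfstudio.method_configs")
    print(sorted(ep.name for ep in methods))  # should include "pixelnerf"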
nerfstudio_pixelnerf-0.0.1.dist-info/top_level.txt ADDED
@@ -0,0 +1,4 @@
+ PixelNeRFConfig
+ PixelNeRFDataManager
+ PixelNeRFModel
+ PixelNeRFPipeline