gaussian-splatting 1.17.5__cp310-cp310-win_amd64.whl → 1.19.4__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gaussian_splatting/camera.py +9 -2
- gaussian_splatting/dataset/camera_trainable.py +4 -4
- gaussian_splatting/dataset/colmap/dataset.py +12 -10
- gaussian_splatting/dataset/dataset.py +3 -2
- gaussian_splatting/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd +0 -0
- gaussian_splatting/prepare.py +11 -3
- gaussian_splatting/render.py +12 -4
- gaussian_splatting/simple_knn/_C.cp310-win_amd64.pyd +0 -0
- gaussian_splatting/train.py +26 -6
- gaussian_splatting/trainer/abc.py +8 -0
- gaussian_splatting/trainer/base.py +37 -10
- gaussian_splatting/trainer/depth.py +0 -4
- {gaussian_splatting-1.17.5.dist-info → gaussian_splatting-1.19.4.dist-info}/METADATA +14 -11
- {gaussian_splatting-1.17.5.dist-info → gaussian_splatting-1.19.4.dist-info}/RECORD +17 -17
- {gaussian_splatting-1.17.5.dist-info → gaussian_splatting-1.19.4.dist-info}/WHEEL +0 -0
- {gaussian_splatting-1.17.5.dist-info → gaussian_splatting-1.19.4.dist-info}/licenses/LICENSE.md +0 -0
- {gaussian_splatting-1.17.5.dist-info → gaussian_splatting-1.19.4.dist-info}/top_level.txt +0 -0
gaussian_splatting/camera.py
CHANGED
@@ -52,6 +52,7 @@ def camera2dict(camera: Camera, id):
         'ground_truth_image_mask_path': camera.ground_truth_image_mask_path.replace("\\", "/") if camera.ground_truth_image_mask_path else None,
         'ground_truth_depth_path': camera.ground_truth_depth_path.replace("\\", "/") if camera.ground_truth_depth_path else None,
         'ground_truth_depth_mask_path': camera.ground_truth_depth_mask_path.replace("\\", "/") if camera.ground_truth_depth_mask_path else None,
+        "img_name": os.path.basename(camera.ground_truth_image_path),  # necessary for SIBR_gaussianViewer_app
     }
     return camera_entry

@@ -132,7 +133,7 @@ def build_camera(
     )


-def dict2camera(camera_dict, load_depth=False, device="cuda", custom_data: dict = {}):
+def dict2camera(camera_dict, load_mask=True, load_depth=True, device="cuda", custom_data: dict = {}):
     C2W = torch.zeros((4, 4), device=device)
     C2W[:3, 3] = torch.tensor(camera_dict['position'], dtype=torch.float, device=device)
     C2W[:3, :3] = torch.tensor(camera_dict['rotation'], dtype=torch.float, device=device)
@@ -140,6 +141,12 @@ def dict2camera(camera_dict, load_depth=False, device="cuda", custom_data: dict
     Rt = torch.linalg.inv(C2W)
     T = Rt[:3, 3]
     R = Rt[:3, :3]
+    if load_mask and ('ground_truth_image_mask_path' not in camera_dict or camera_dict['ground_truth_image_mask_path'] is None):
+        logging.warning(f"Value of key 'ground_truth_image_mask_path' is not a valid path, skipping mask loading.")
+    if load_depth and ('ground_truth_depth_path' not in camera_dict or camera_dict['ground_truth_depth_path'] is None):
+        logging.warning(f"Value of key 'ground_truth_depth_path' is not a valid path, skipping depth loading.")
+    if load_depth and ('ground_truth_depth_mask_path' not in camera_dict or camera_dict['ground_truth_depth_mask_path'] is None):
+        logging.warning(f"Value of key 'ground_truth_depth_mask_path' is not a valid path, skipping depth mask loading.")
     return build_camera(
         image_width=camera_dict['width'],
         image_height=camera_dict['height'],
@@ -148,7 +155,7 @@ def dict2camera(camera_dict, load_depth=False, device="cuda", custom_data: dict
         R=R,
         T=T,
         image_path=camera_dict['ground_truth_image_path'] if 'ground_truth_image_path' in camera_dict else None,
-        image_mask_path=camera_dict['ground_truth_image_mask_path'] if 'ground_truth_image_mask_path' in camera_dict else None,
+        image_mask_path=camera_dict['ground_truth_image_mask_path'] if (load_mask and 'ground_truth_image_mask_path' in camera_dict) else None,
         depth_path=camera_dict['ground_truth_depth_path'] if (load_depth and 'ground_truth_depth_path' in camera_dict) else None,
         depth_mask_path=camera_dict['ground_truth_depth_mask_path'] if (load_depth and 'ground_truth_depth_mask_path' in camera_dict) else None,
         device=device,
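The new `load_mask` and `load_depth` keywords default to `True`, so callers with RGB-only ground truth now opt out explicitly. A minimal usage sketch, assuming a cameras.json produced by this package (the path is a placeholder):

```python
# Hypothetical usage of dict2camera with the new keywords.
import json
from gaussian_splatting.camera import dict2camera

with open("output/cameras.json") as f:  # placeholder path
    camera_dicts = json.load(f)

# Skip mask and depth loading for an RGB-only dataset.
cameras = [dict2camera(d, load_mask=False, load_depth=False, device="cuda") for d in camera_dicts]
```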
gaussian_splatting/dataset/camera_trainable.py
CHANGED
@@ -81,8 +81,8 @@ class TrainableCameraDataset(CameraDataset):
             json.dump(cameras, f, indent=2)

     @classmethod
-    def from_json(cls, path, load_depth=False):
-        cameras = JSONCameraDataset(path, load_depth=load_depth)
+    def from_json(cls, path, load_mask=True, load_depth=True):
+        cameras = JSONCameraDataset(path, load_mask=load_mask, load_depth=load_depth)
         exposures = [(torch.tensor(camera['exposure'], dtype=torch.float) if 'exposure' in camera else torch.eye(3, 4)) for camera in cameras.json_cameras]
         return cls(cameras, exposures)

@@ -91,8 +91,8 @@ class FixedTrainableCameraDataset(JSONCameraDataset):
     # Same as TrainableCameraDataset, but is fixed
     # Used for loading cameras saved by TrainableCameraDataset

-    def __init__(self, path, load_depth=False):
-        super().__init__(path, load_depth=load_depth)
+    def __init__(self, path, load_mask=True, load_depth=True):
+        super().__init__(path, load_mask=load_mask, load_depth=load_depth)
         self.load_exposures()

     def to(self, device):
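A sketch of restoring trainable cameras with the new keywords; the path is a placeholder and the import path is assumed to mirror the file layout shown above:

```python
from gaussian_splatting.dataset.camera_trainable import TrainableCameraDataset

# Load image masks but skip depth maps when restoring cameras saved by a previous run.
dataset = TrainableCameraDataset.from_json("output/cameras.json", load_mask=True, load_depth=False)
```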
gaussian_splatting/dataset/colmap/dataset.py
CHANGED
@@ -27,7 +27,7 @@ class ColmapCamera(NamedTuple):
     depth_mask_path: str


-def parse_colmap_camera(cameras, images, image_dir, depth_dir=None) -> List[ColmapCamera]:
+def parse_colmap_camera(cameras, images, image_dir, load_mask=True, depth_dir=None) -> List[ColmapCamera]:
     parsed_cameras = []
     for _, key in enumerate(cameras):
         extr = cameras[key]
@@ -49,9 +49,11 @@ def parse_colmap_camera(cameras, images, image_dir, depth_dir=None) -> List[Colm
            raise ValueError("Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!")

         image_path = os.path.join(image_dir, extr.name)
-        image_mask_path = os.path.join(image_dir, os.path.splitext(extr.name)[0] + '_mask.tiff')
-        if not os.path.exists(image_mask_path):
-            image_mask_path = os.path.splitext(image_mask_path)[0] + '.png'
+        image_mask_path = None
+        if load_mask:
+            image_mask_path = os.path.join(image_dir, os.path.splitext(extr.name)[0] + '_mask.tiff')
+            if not os.path.exists(image_mask_path):
+                image_mask_path = os.path.splitext(image_mask_path)[0] + '.png'
         depth_path, depth_mask_path = None, None
         if depth_dir is not None:
             depth_path = os.path.join(depth_dir, os.path.splitext(extr.name)[0] + '.tiff')
@@ -72,7 +74,7 @@ def parse_colmap_camera(cameras, images, image_dir, depth_dir=None) -> List[Colm
     return parsed_cameras


-def read_colmap_cameras(colmap_folder, load_depth=False) -> List[ColmapCamera]:
+def read_colmap_cameras(colmap_folder, load_mask=True, load_depth=True) -> List[ColmapCamera]:
     path = colmap_folder
     image_dir = os.path.join(path, "images")
     try:
@@ -86,13 +88,13 @@ def read_colmap_cameras(colmap_folder, load_depth=False) -> List[ColmapCamera]:
     cam_extrinsics = read_images_text(cameras_extrinsic_file)
     cam_intrinsics = read_cameras_text(cameras_intrinsic_file)
     depth_dir = os.path.join(path, "depths") if load_depth else None
-    return parse_colmap_camera(cam_extrinsics, cam_intrinsics, image_dir, depth_dir)
+    return parse_colmap_camera(cam_extrinsics, cam_intrinsics, image_dir, load_mask=load_mask, depth_dir=depth_dir)


 class ColmapCameraDataset(CameraDataset):
-    def __init__(self, colmap_folder, load_depth=False):
+    def __init__(self, colmap_folder, load_mask=True, load_depth=True):
         super().__init__()
-        self.raw_cameras = read_colmap_cameras(colmap_folder, load_depth=load_depth)
+        self.raw_cameras = read_colmap_cameras(colmap_folder, load_mask=load_mask, load_depth=load_depth)
         self.cameras = [build_camera(**cam._asdict()) for cam in self.raw_cameras]

     def to(self, device):
@@ -106,5 +108,5 @@ class ColmapCameraDataset(CameraDataset):
         return self.cameras[idx]


-def ColmapTrainableCameraDataset(colmap_folder, load_depth=False):
-    return TrainableCameraDataset(ColmapCameraDataset(colmap_folder, load_depth=load_depth))
+def ColmapTrainableCameraDataset(colmap_folder, load_mask=True, load_depth=True):
+    return TrainableCameraDataset(ColmapCameraDataset(colmap_folder, load_mask=load_mask, load_depth=load_depth))
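The mask lookup above follows a simple naming convention next to each image. A small sketch of that path arithmetic (pure `os.path`, the image name is only an example; separators are shown POSIX-style):

```python
import os.path

image_dir, name = "images", "0001.jpg"  # example COLMAP image entry
mask = os.path.join(image_dir, os.path.splitext(name)[0] + '_mask.tiff')
fallback = os.path.splitext(mask)[0] + '.png'  # tried when the .tiff mask is absent

print(mask)      # images/0001_mask.tiff
print(fallback)  # images/0001_mask.png
```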
gaussian_splatting/dataset/dataset.py
CHANGED
@@ -35,9 +35,10 @@ class CameraDataset:


 class JSONCameraDataset(CameraDataset):
-    def __init__(self, path, load_depth=False):
+    def __init__(self, path, load_mask=True, load_depth=True):
         with open(path, 'r') as f:
             self.json_cameras = json.load(f)
+        self.load_mask = load_mask
         self.load_depth = load_depth
         self.load_cameras()

@@ -51,7 +52,7 @@ class JSONCameraDataset(CameraDataset):
         return self.load_cameras(device=device)

     def load_cameras(self, device=None):
-        self.cameras = [dict2camera(camera, load_depth=self.load_depth, device=device) for camera in self.json_cameras]
+        self.cameras = [dict2camera(camera, load_mask=self.load_mask, load_depth=self.load_depth, device=device) for camera in self.json_cameras]
         return self

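A minimal sketch of the JSON-backed dataset with the new switch; the path is a placeholder and the module path is assumed to mirror the file layout:

```python
from gaussian_splatting.dataset.dataset import JSONCameraDataset

# RGB-only loading: skip both masks and depth maps, then move cameras to the GPU.
dataset = JSONCameraDataset("output/cameras.json", load_mask=False, load_depth=False).to("cuda")
```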
gaussian_splatting/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd
CHANGED
Binary file
gaussian_splatting/prepare.py
CHANGED
@@ -6,11 +6,19 @@ from .trainer import *
 from .trainer.extensions import ScaleRegularizeTrainerWrapper


-def prepare_dataset(source: str, device: str, trainable_camera: bool = False, load_camera: str = None, load_depth=False) -> CameraDataset:
+def prepare_dataset(source: str, device: str, trainable_camera: bool = False, load_camera: str = None, load_mask=True, load_depth=True) -> CameraDataset:
     if trainable_camera:
-        dataset = (TrainableCameraDataset.from_json(load_camera, load_depth=load_depth) if load_camera else ColmapTrainableCameraDataset(source, load_depth=load_depth)).to(device)
+        dataset = (
+            TrainableCameraDataset.from_json(load_camera, load_mask=load_mask, load_depth=load_depth)
+            if load_camera else
+            ColmapTrainableCameraDataset(source, load_mask=load_mask, load_depth=load_depth)
+        ).to(device)
     else:
-        dataset = (FixedTrainableCameraDataset(load_camera, load_depth=load_depth) if load_camera else ColmapCameraDataset(source, load_depth=load_depth)).to(device)
+        dataset = (
+            FixedTrainableCameraDataset(load_camera, load_mask=load_mask, load_depth=load_depth)
+            if load_camera else
+            ColmapCameraDataset(source, load_mask=load_mask, load_depth=load_depth)
+        ).to(device)
     return dataset

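A sketch of the updated `prepare_dataset` entry point; the source folder is a placeholder:

```python
from gaussian_splatting.prepare import prepare_dataset

dataset = prepare_dataset(
    source="data/my_scene",   # hypothetical COLMAP folder containing images/ (and depths/ if load_depth)
    device="cuda",
    trainable_camera=False,
    load_camera=None,         # or a cameras.json saved by a previous run
    load_mask=True,
    load_depth=False,
)
```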
gaussian_splatting/render.py
CHANGED
@@ -12,8 +12,11 @@ from gaussian_splatting.utils.lpipsPyTorch import lpips
 from gaussian_splatting.prepare import prepare_dataset, prepare_gaussians


-def prepare_rendering(
-
+def prepare_rendering(
+        sh_degree: int, source: str, device: str,
+        trainable_camera: bool = False, load_ply: str = None, load_camera: str = None,
+        load_mask=True, load_depth=True) -> Tuple[CameraDataset, GaussianModel]:
+    dataset = prepare_dataset(source=source, device=device, trainable_camera=trainable_camera, load_camera=load_camera, load_mask=load_mask, load_depth=load_depth)
     gaussians = prepare_gaussians(sh_degree=sh_degree, source=source, device=device, trainable_camera=trainable_camera, load_ply=load_ply)
     return dataset, gaussians

@@ -54,13 +57,16 @@ def rendering(
     gt_path = os.path.join(save, "gt")
     makedirs(render_path, exist_ok=True)
     makedirs(gt_path, exist_ok=True)
-    pbar = tqdm(dataset, desc="Rendering
+    pbar = tqdm(dataset, dynamic_ncols=True, desc="Rendering")
     with open(os.path.join(save, "quality.csv"), "w") as f:
         f.write("name,psnr,ssim,lpips\n")
         for idx, camera in enumerate(pbar):
             out = gaussians(camera)
             rendering = out["render"]
             gt = camera.ground_truth_image
+            if camera.ground_truth_image_mask is not None:
+                gt *= camera.ground_truth_image_mask
+                rendering *= camera.ground_truth_image_mask
             psnr_value = psnr(rendering, gt).mean().item()
             ssim_value = ssim(rendering, gt).mean().item()
             lpips_value = lpips(rendering, gt).mean().item()
@@ -93,6 +99,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_camera", default=None, type=str)
     parser.add_argument("--mode", choices=["base", "camera"], default="base")
     parser.add_argument("--device", default="cuda", type=str)
+    parser.add_argument("--no_image_mask", action="store_true")
     parser.add_argument("--no_rescale_depth_gt", action="store_true")
     parser.add_argument("--save_depth_pcd", action="store_true")
     args = parser.parse_args()
@@ -101,5 +108,6 @@ if __name__ == "__main__":
     with torch.no_grad():
         dataset, gaussians = prepare_rendering(
             sh_degree=args.sh_degree, source=args.source, device=args.device, trainable_camera=args.mode == "camera",
-            load_ply=load_ply, load_camera=args.load_camera,
+            load_ply=load_ply, load_camera=args.load_camera,
+            load_mask=not args.no_image_mask, load_depth=args.save_depth_pcd)
     rendering(dataset, gaussians, save, save_pcd=args.save_depth_pcd, rescale_depth_gt=not args.no_rescale_depth_gt)
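A sketch of preparing a rendering run with masks disabled, mirroring the new `--no_image_mask` flag; the sh_degree and paths are placeholders:

```python
import torch
from gaussian_splatting.render import prepare_rendering

with torch.no_grad():
    dataset, gaussians = prepare_rendering(
        sh_degree=3, source="data/my_scene", device="cuda",
        load_ply="output/point_cloud/iteration_30000/point_cloud.ply",  # hypothetical checkpoint path
        load_mask=False, load_depth=False)
```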
gaussian_splatting/simple_knn/_C.cp310-win_amd64.pyd
CHANGED
Binary file
gaussian_splatting/train.py
CHANGED
@@ -12,8 +12,12 @@ from gaussian_splatting.trainer import AbstractTrainer
 from gaussian_splatting.prepare import basemodes, shliftmodes, prepare_dataset, prepare_gaussians, prepare_trainer


-def prepare_training(
-
+def prepare_training(
+        sh_degree: int, source: str, device: str, mode: str,
+        trainable_camera: bool = False, load_ply: str = None, load_camera: str = None,
+        load_mask=True, load_depth=True,
+        with_scale_reg=False, configs={}) -> Tuple[CameraDataset, GaussianModel, AbstractTrainer]:
+    dataset = prepare_dataset(source=source, device=device, trainable_camera=trainable_camera, load_camera=load_camera, load_mask=load_mask, load_depth=load_depth)
     gaussians = prepare_gaussians(sh_degree=sh_degree, source=source, device=device, trainable_camera=trainable_camera, load_ply=load_ply)
     trainer = prepare_trainer(gaussians=gaussians, dataset=dataset, mode=mode, trainable_camera=trainable_camera, load_ply=load_ply, with_scale_reg=with_scale_reg, configs=configs)
     return dataset, gaussians, trainer
@@ -27,24 +31,37 @@ def save_cfg_args(destination: str, sh_degree: int, source: str):

 def training(dataset: CameraDataset, gaussians: GaussianModel, trainer: AbstractTrainer, destination: str, iteration: int, save_iterations: List[int], device: str):
     shutil.rmtree(os.path.join(destination, "point_cloud"), ignore_errors=True)  # remove the previous point cloud
-    pbar = tqdm(range(1, iteration+1))
+    pbar = tqdm(range(1, iteration+1), dynamic_ncols=True, desc="Training")
     epoch = list(range(len(dataset)))
     epoch_psnr = torch.empty(3, 0, device=device)
+    epoch_maskpsnr = torch.empty(3, 0, device=device)
     ema_loss_for_log = 0.0
     avg_psnr_for_log = 0.0
+    avg_maskpsnr_for_log = 0.0
     for step in pbar:
         epoch_idx = step % len(dataset)
         if epoch_idx == 0:
             avg_psnr_for_log = epoch_psnr.mean().item()
+            avg_maskpsnr_for_log = epoch_maskpsnr.mean().item()
             epoch_psnr = torch.empty(3, 0, device=device)
+            epoch_maskpsnr = torch.empty(3, 0, device=device)
             random.shuffle(epoch)
         idx = epoch[epoch_idx]
         loss, out = trainer.step(dataset[idx])
         with torch.no_grad():
+            ground_truth_image = dataset[idx].ground_truth_image
+            rendered_image = out["render"]
+            epoch_psnr = torch.concat([epoch_psnr, psnr(rendered_image, ground_truth_image)], dim=1)
+            if dataset[idx].ground_truth_image_mask is not None:
+                ground_truth_maskimage = ground_truth_image * dataset[idx].ground_truth_image_mask
+                rendered_maskimage = rendered_image * dataset[idx].ground_truth_image_mask
+                epoch_maskpsnr = torch.concat([epoch_maskpsnr, psnr(rendered_maskimage, ground_truth_maskimage)], dim=1)
             ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
-            epoch_psnr = torch.concat([epoch_psnr, psnr(out["render"], dataset[idx].ground_truth_image)], dim=1)
             if step % 10 == 0:
-
+                postfix = {'epoch': step // len(dataset), 'loss': ema_loss_for_log, 'psnr': avg_psnr_for_log, 'n': gaussians._xyz.shape[0]}
+                if avg_maskpsnr_for_log > 0:
+                    postfix['mask_psnr'] = avg_maskpsnr_for_log
+                pbar.set_postfix(postfix)
         if step in save_iterations:
             save_path = os.path.join(destination, "point_cloud", "iteration_" + str(step))
             os.makedirs(save_path, exist_ok=True)
@@ -65,6 +82,7 @@ if __name__ == "__main__":
     parser.add_argument("-i", "--iteration", default=30000, type=int)
     parser.add_argument("-l", "--load_ply", default=None, type=str)
     parser.add_argument("--load_camera", default=None, type=str)
+    parser.add_argument("--no_image_mask", action="store_true")
     parser.add_argument("--no_depth_data", action="store_true")
     parser.add_argument("--with_scale_reg", action="store_true")
     parser.add_argument("--mode", choices=sorted(list(set(list(basemodes.keys()) + list(shliftmodes.keys())))), default="base")
@@ -78,7 +96,9 @@ if __name__ == "__main__":
     configs = {o.split("=", 1)[0]: eval(o.split("=", 1)[1]) for o in args.option}
     dataset, gaussians, trainer = prepare_training(
         sh_degree=args.sh_degree, source=args.source, device=args.device, mode=args.mode, trainable_camera="camera" in args.mode,
-        load_ply=args.load_ply, load_camera=args.load_camera,
+        load_ply=args.load_ply, load_camera=args.load_camera,
+        load_mask=not args.no_image_mask, load_depth=not args.no_depth_data,
+        with_scale_reg=args.with_scale_reg, configs=configs)
     dataset.save_cameras(os.path.join(args.destination, "cameras.json"))
     torch.cuda.empty_cache()
     training(
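The same switches thread through `prepare_training`; a sketch with placeholder paths (the other keyword arguments keep their defaults):

```python
from gaussian_splatting.train import prepare_training

dataset, gaussians, trainer = prepare_training(
    sh_degree=3, source="data/my_scene", device="cuda", mode="base",
    load_mask=True,     # equivalent to not passing --no_image_mask
    load_depth=False,   # equivalent to passing --no_depth_data
)
```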
gaussian_splatting/trainer/abc.py
CHANGED
@@ -33,6 +33,10 @@ class AbstractTrainer(ABC):
     def schedulers(self) -> Dict[str, Callable[[int], float]]:
         raise ValueError("Schedulers is not set")

+    @abstractmethod
+    def preprocess(self, camera: Camera) -> Camera:
+        pass
+
     @abstractmethod
     def loss(self, out: dict, camera: Camera) -> torch.Tensor:
         pass
@@ -49,6 +53,7 @@ class AbstractTrainer(ABC):

     def step(self, camera: Camera):
         self.update_learning_rate()
+        camera = self.preprocess(camera)
         out = self.model(camera)
         loss = self.loss(out, camera)
         loss.backward()
@@ -100,6 +105,9 @@ class TrainerWrapper(AbstractTrainer):
     def schedulers(self) -> Dict[str, Callable[[int], float]]:
         return self.base_trainer.schedulers

+    def preprocess(self, camera: Camera) -> Camera:
+        return self.base_trainer.preprocess(camera)
+
     def loss(self, out: dict, camera: Camera) -> torch.Tensor:
         return self.base_trainer.loss(out, camera)

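Since `preprocess()` is now part of the `AbstractTrainer` contract and runs on every `step()`, extension wrappers can hook it. A hypothetical sketch (only the new hook is shown; the wrapper name and its use are illustrative, not from the package):

```python
from gaussian_splatting.trainer.abc import TrainerWrapper

class CameraPreprocessWrapper(TrainerWrapper):
    """Hypothetical wrapper that piggybacks on the new preprocess() hook."""

    def preprocess(self, camera):
        camera = self.base_trainer.preprocess(camera)  # keep the base behaviour (e.g. bg_color handling)
        return camera  # custom per-step camera edits would go here
```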
gaussian_splatting/trainer/base.py
CHANGED
@@ -22,13 +22,23 @@ class BaseTrainer(AbstractTrainer):
         opacity_lr=0.025,
         scaling_lr=0.005,
         rotation_lr=0.001,
-
-
+        mask_mode="none",
+        # "none"=do not use mask
+        # "ignore"=loss of the masked area will be set to 0
+        # "bg_color"=fill the masked area of ground truth with the bg_color for rendering
+        bg_color=None,
+        # None=do not change bg_color
+        # "random"=set bg_color to random color
+        # tuple(float, float, float)=set bg_color to the given color
     ):
         super().__init__()
         self.lambda_dssim = lambda_dssim
-
-
+        assert mask_mode in ["none", "ignore", "bg_color"], f"Unknown mask policy: {mask_mode}"
+        assert bg_color is None or bg_color == "random" or (
+            isinstance(bg_color, tuple) and len(bg_color) == 3 and all(isinstance(c, float) for c in bg_color)
+        ), f"bg_color must be 'random' or a RGB value tuple(float, float, float), got {bg_color}"
+        self.mask_mode = mask_mode
+        self.bg_color = bg_color
         params = [
             {'params': [model._xyz], 'lr': position_lr_init * scene_extent, "name": "xyz"},
             {'params': [model._features_dc], 'lr': feature_lr, "name": "f_dc"},
@@ -71,18 +81,35 @@ class BaseTrainer(AbstractTrainer):
     def schedulers(self) -> Dict[str, Callable[[int], float]]:
         return self._schedulers

+    def preprocess(self, camera: Camera) -> Camera:
+        if self.bg_color == "random":
+            camera = camera._replace(bg_color=torch.rand_like(camera.bg_color))
+        elif isinstance(self.bg_color, tuple):
+            camera = camera._replace(bg_color=torch.tensor(self.bg_color, device=camera.bg_color.device, dtype=camera.bg_color.dtype))
+        elif self.bg_color is None:
+            pass
+        else:
+            raise ValueError(f"bg_color must be 'random' or a tuple(int, int, int), got {self.bg_color}")
+        return camera
+
     def loss(self, out: dict, camera: Camera) -> torch.Tensor:
         render = out["render"]
         gt = camera.ground_truth_image
         mask = camera.ground_truth_image_mask
-
-
+        match self.mask_mode:
+            case "none":
+                pass
+            case "ignore":
+                assert mask is not None, "Mask is required for 'ignore' mask policy"
                 render = render * mask.unsqueeze(0)
                 gt = gt * mask.unsqueeze(0)
-
-
-
-
+            case "bg_color":
+                assert mask is not None, "Mask is required for 'bg_color' mask policy"
+                # bg_color after postprocess
+                bg_color = camera.postprocess(camera, camera.bg_color.unsqueeze(-1).unsqueeze(-1)).clamp(0.0, 1.0)
+                gt = gt * mask.unsqueeze(0) + (1 - mask.unsqueeze(0)) * bg_color
+            case _:
+                raise ValueError(f"Unknown mask policy: {self.mask_mode}")
         Ll1 = l1_loss(render, gt)
         ssim_value = ssim(render, gt)
         loss = (1.0 - self.lambda_dssim) * Ll1 + self.lambda_dssim * (1.0 - ssim_value)
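For reference, the "bg_color" policy above fills the masked-out region of the ground truth with the (post-processed) background colour before the L1/SSIM loss is computed. A plain-PyTorch illustration of that blend, independent of the package API:

```python
import torch

gt = torch.rand(3, 4, 4)                      # toy ground-truth image, C x H x W
mask = (torch.rand(4, 4) > 0.5).float()       # 1 = keep pixel, 0 = masked out
bg_color = torch.tensor([1.0, 1.0, 1.0]).view(3, 1, 1)

# Same blend as the "bg_color" branch: masked pixels are replaced by the background colour.
gt_filled = gt * mask.unsqueeze(0) + (1 - mask.unsqueeze(0)) * bg_color
assert gt_filled.shape == (3, 4, 4)
```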
gaussian_splatting/trainer/depth.py
CHANGED
@@ -75,10 +75,6 @@ class DepthTrainer(TrainerWrapper):
         invdepth = out["depth"].squeeze(0)
         invdepth_gt = camera.ground_truth_depth
         mask = camera.ground_truth_depth_mask
-        if mask is None:
-            mask = camera.ground_truth_image_mask
-        elif camera.ground_truth_image_mask is not None:
-            mask = mask * camera.ground_truth_image_mask
         assert invdepth.shape == invdepth_gt.shape, f"invdepth shape {invdepth.shape} does not match gt depth shape {invdepth_gt.shape}"
         if self.depth_resize is not None:
             height, width = invdepth.shape[-2:]
{gaussian_splatting-1.17.5.dist-info → gaussian_splatting-1.19.4.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gaussian_splatting
-Version: 1.17.5
+Version: 1.19.4
 Summary: Refactored python training and inference code for 3D Gaussian Splatting
 Home-page: https://github.com/yindaheng98/gaussian-splatting
 Author: yindaheng98
@@ -13,6 +13,8 @@ Requires-Dist: plyfile
 Requires-Dist: tifffile
 Requires-Dist: numpy
 Requires-Dist: opencv-python
+Requires-Dist: pillow
+Requires-Dist: open3d
 Dynamic: author
 Dynamic: author-email
 Dynamic: classifier
@@ -46,15 +48,6 @@ We **refactored the original code following the standard Python package structure**
 * [Pytorch](https://pytorch.org/) (>= v2.4 recommended)
 * [CUDA Toolkit](https://developer.nvidia.com/cuda-12-4-0-download-archive) (12.4 recommended, match with PyTorch version)

-### Local Install
-
-```shell
-git clone --recursive https://github.com/yindaheng98/gaussian-splatting
-cd gaussian-splatting
-pip install tqdm plyfile tifffile
-pip install --target . --upgrade . --no-deps
-```
-
 ### PyPI Install

 ```shell
@@ -63,7 +56,17 @@ pip install --upgrade gaussian-splatting
 or
 build latest from source:
 ```shell
-pip install
+pip install wheel setuptools
+pip install --upgrade git+https://github.com/yindaheng98/gaussian-splatting.git@master --no-build-isolation
+```
+
+### Development Install
+
+```shell
+git clone --recursive https://github.com/yindaheng98/gaussian-splatting
+cd gaussian-splatting
+pip install tqdm plyfile tifffile numpy opencv-python pillow open3d
+pip install --target . --upgrade . --no-deps
 ```

 ## Quick Start
{gaussian_splatting-1.17.5.dist-info → gaussian_splatting-1.19.4.dist-info}/RECORD
RENAMED
@@ -1,26 +1,26 @@
 gaussian_splatting/__init__.py,sha256=CiOZMcyPTAaKtEuMZUhEda_Ad4_RUhmIstB-A3iuOJY,131
-gaussian_splatting/camera.py,sha256=
+gaussian_splatting/camera.py,sha256=vo7mu6lyFpIhDqOAgNiJuPan8t_nDJn5cJkAYygLFcA,9243
 gaussian_splatting/camera_trainable.py,sha256=nI6hFFRV2ev7VwLlKUbzEdN9zUmngYZAANGLr1p1yBA,3841
 gaussian_splatting/gaussian_model.py,sha256=_Dy_dDa2prALhVgg428a-O8-8PODg3c_JPkOJJ8X4o8,13275
-gaussian_splatting/prepare.py,sha256=
-gaussian_splatting/render.py,sha256=
-gaussian_splatting/train.py,sha256=
+gaussian_splatting/prepare.py,sha256=rgwdDhPU-sS57ZgTLKwzognbS1hfR2JtjZqlC4FqPbI,3241
+gaussian_splatting/render.py,sha256=RhZoILWxtkuDEhiL2cQsvZWNBmkyDAROYdvrQaFhPkc,6295
+gaussian_splatting/train.py,sha256=JTKK9M0bfyr8d3UK-Fr2MxgiIYF0YheH0uOdPEACBnU,6262
 gaussian_splatting/dataset/__init__.py,sha256=-runuT-61P0YVpfV_WXqwUZM1oY0N012YH13Bt3rzSU,138
-gaussian_splatting/dataset/camera_trainable.py,sha256=
-gaussian_splatting/dataset/dataset.py,sha256=
+gaussian_splatting/dataset/camera_trainable.py,sha256=Kd8v-_ZJ9dLIQ2QyVOXbmouYf5QjbgOgHNRHVpkgCms,5041
+gaussian_splatting/dataset/dataset.py,sha256=0tmIZ5P7kOEdABiEAXPznkRN91e5rcT5VsAzOLoOuEM,2392
 gaussian_splatting/dataset/colmap/__init__.py,sha256=YEYT2k2WJSqrkkZq4KAJYS9UMgqU6W6TJaeHLRc1CM4,213
-gaussian_splatting/dataset/colmap/dataset.py,sha256=
+gaussian_splatting/dataset/colmap/dataset.py,sha256=0UBQ6ynOqElHZSphJ-MSbYQqCwwYZaAXl1y9AY5YKuY,4720
 gaussian_splatting/dataset/colmap/params_init.py,sha256=6_6gZ0Wl4aZrps2PJ_U234sxW5D-vOTfwioVa1FWC-E,1802
 gaussian_splatting/dataset/colmap/read_write_model.py,sha256=TenI7ai5UV7Ksg2vAXvJWnYFwOOo1tlS_633RfCLuQU,23137
-gaussian_splatting/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd,sha256=
+gaussian_splatting/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd,sha256=Ty4zoYipuoG31tUALZawVhJk7m9kfNAFIrN9aPUoojs,1287680
 gaussian_splatting/diff_gaussian_rasterization/__init__.py,sha256=a9D0IZiPx-Mk1795hSq54T-NYT4MtEN_MZrxeMhw0Eo,6705
-gaussian_splatting/simple_knn/_C.cp310-win_amd64.pyd,sha256=
+gaussian_splatting/simple_knn/_C.cp310-win_amd64.pyd,sha256=m5AuAQx7PV2F9yyi-Gl7oTK8yQ8Uukg39UkFv7HmC98,1156608
 gaussian_splatting/trainer/__init__.py,sha256=962fEY8A0spSQn5de_d_LkPOjA1PYKrLbuAkxwZo7mI,940
-gaussian_splatting/trainer/abc.py,sha256=
-gaussian_splatting/trainer/base.py,sha256=
+gaussian_splatting/trainer/abc.py,sha256=_gcqmEobhSOdZnMyNb2oKS6cZJ-Mg3oYL4xJ5Y3_oic,4262
+gaussian_splatting/trainer/base.py,sha256=fngLruQ9hMSNLFbc_5woG7jm6cidpoZ0dzk_zImRaE4,4851
 gaussian_splatting/trainer/camera_trainable.py,sha256=TBQXn2f578qeizPz6tgqFm-GRvttv9duuB1xx7_J9TQ,4567
 gaussian_splatting/trainer/combinations.py,sha256=7NX4fXdDOx8ri1_mgAaWNx-YVdo5XsqMlr9qy-Ll2MM,5329
-gaussian_splatting/trainer/depth.py,sha256=
+gaussian_splatting/trainer/depth.py,sha256=PxWBSNxzoQcRfCFI_yJnJMS6s8qFWn81CXK6O6ffXL0,7059
 gaussian_splatting/trainer/opacity_reset.py,sha256=KfxDyWBNocETGcqCRTdE1n3t63HmjChaAuIP3OTIWtg,2615
 gaussian_splatting/trainer/sh_lift.py,sha256=Hwcn_cRzXZChESpTL83ZmR608ewCR2OzItt-wZtRpak,1220
 gaussian_splatting/trainer/densifier/__init__.py,sha256=cg4aGUolq5ayWtoqQP_BEmHE4NOD5ZuzCluRclJS61I,359
@@ -45,8 +45,8 @@ gaussian_splatting/utils/lpipsPyTorch/modules/__init__.py,sha256=47DEQpj8HBSa-_T
 gaussian_splatting/utils/lpipsPyTorch/modules/lpips.py,sha256=YScu0oXIEstCCjJVRItS_R_csUw70sBMFuP8Syl2UdI,1187
 gaussian_splatting/utils/lpipsPyTorch/modules/networks.py,sha256=kqIebq7dAhHypTXweFVEf_RDbN7_Zv7O3MlD-CfRvpg,2788
 gaussian_splatting/utils/lpipsPyTorch/modules/utils.py,sha256=TDcem3E3HqDNN2MT8qlOL_BKVHeO4HRE77JxF-kOWk8,915
-gaussian_splatting-1.17.5.dist-info/licenses/LICENSE.md,sha256=
-gaussian_splatting-1.17.5.dist-info/METADATA,sha256=
-gaussian_splatting-1.17.5.dist-info/WHEEL,sha256=
-gaussian_splatting-1.17.5.dist-info/top_level.txt,sha256=
-gaussian_splatting-1.17.5.dist-info/RECORD,,
+gaussian_splatting-1.19.4.dist-info/licenses/LICENSE.md,sha256=bMuRQKn0u485mx8JBBTJ5Simc-aWHaQsxmoB6jsg5oE,4752
+gaussian_splatting-1.19.4.dist-info/METADATA,sha256=JyAP_8QWzwTb_H9KkJeTQnVpervcFxSFzJZwRu3YmgI,17183
+gaussian_splatting-1.19.4.dist-info/WHEEL,sha256=KUuBC6lxAbHCKilKua8R9W_TM71_-9Sg5uEP3uDWcoU,101
+gaussian_splatting-1.19.4.dist-info/top_level.txt,sha256=uaYrPYXRHhpybgCnsoazTcdhpzZGnLT_vd5eoRzBWWI,19
+gaussian_splatting-1.19.4.dist-info/RECORD,,
{gaussian_splatting-1.17.5.dist-info → gaussian_splatting-1.19.4.dist-info}/WHEEL
RENAMED
File without changes
{gaussian_splatting-1.17.5.dist-info → gaussian_splatting-1.19.4.dist-info}/licenses/LICENSE.md
RENAMED
File without changes
{gaussian_splatting-1.17.5.dist-info → gaussian_splatting-1.19.4.dist-info}/top_level.txt
RENAMED
File without changes