reduced-3dgs 1.10.1.tar.gz → 1.10.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (49)
  1. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/PKG-INFO +1 -1
  2. reduced_3dgs-1.10.3/reduced_3dgs/prepare.py +105 -0
  3. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/quantization/quantizer.py +11 -3
  4. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/quantize.py +1 -0
  5. reduced_3dgs-1.10.3/reduced_3dgs/train.py +93 -0
  6. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs.egg-info/PKG-INFO +1 -1
  7. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs.egg-info/SOURCES.txt +1 -0
  8. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/setup.py +1 -1
  9. reduced_3dgs-1.10.1/reduced_3dgs/train.py +0 -200
  10. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/LICENSE.md +0 -0
  11. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/README.md +0 -0
  12. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/__init__.py +0 -0
  13. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/combinations.py +0 -0
  14. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/importance/__init__.py +0 -0
  15. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/importance/combinations.py +0 -0
  16. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/importance/trainer.py +0 -0
  17. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/pruning/__init__.py +0 -0
  18. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/pruning/combinations.py +0 -0
  19. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/pruning/trainer.py +0 -0
  20. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/quantization/__init__.py +0 -0
  21. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/quantization/abc.py +0 -0
  22. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/quantization/exclude_zeros.py +0 -0
  23. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/quantization/wrapper.py +0 -0
  24. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/shculling/__init__.py +0 -0
  25. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/shculling/gaussian_model.py +0 -0
  26. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs/shculling/trainer.py +0 -0
  27. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs.egg-info/dependency_links.txt +0 -0
  28. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs.egg-info/requires.txt +0 -0
  29. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/reduced_3dgs.egg-info/top_level.txt +0 -0
  30. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/setup.cfg +0 -0
  31. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu +0 -0
  32. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu +0 -0
  33. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu +0 -0
  34. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py +0 -0
  35. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/diff-gaussian-rasterization/ext.cpp +0 -0
  36. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/diff-gaussian-rasterization/rasterize_points.cu +0 -0
  37. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/diff-gaussian-rasterization/reduced_3dgs/kmeans.cu +0 -0
  38. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/diff-gaussian-rasterization/reduced_3dgs/redundancy_score.cu +0 -0
  39. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/diff-gaussian-rasterization/reduced_3dgs/sh_culling.cu +0 -0
  40. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/diff-gaussian-rasterization/reduced_3dgs.cu +0 -0
  41. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/gaussian-importance/cuda_rasterizer/backward.cu +0 -0
  42. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/gaussian-importance/cuda_rasterizer/forward.cu +0 -0
  43. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/gaussian-importance/cuda_rasterizer/rasterizer_impl.cu +0 -0
  44. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/gaussian-importance/diff_gaussian_rasterization/__init__.py +0 -0
  45. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/gaussian-importance/ext.cpp +0 -0
  46. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/gaussian-importance/rasterize_points.cu +0 -0
  47. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/simple-knn/ext.cpp +0 -0
  48. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/simple-knn/simple_knn.cu +0 -0
  49. {reduced_3dgs-1.10.1 → reduced_3dgs-1.10.3}/submodules/simple-knn/spatial.cu +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: reduced_3dgs
- Version: 1.10.1
+ Version: 1.10.3
  Summary: Refactored code for the paper "Reducing the Memory Footprint of 3D Gaussian Splatting"
  Home-page: https://github.com/yindaheng98/reduced-3dgs
  Author: yindaheng98
@@ -0,0 +1,105 @@
+ from gaussian_splatting import GaussianModel
+ from gaussian_splatting.dataset import CameraDataset
+ from gaussian_splatting.dataset.colmap import colmap_init
+ from gaussian_splatting.trainer import AbstractTrainer
+ from gaussian_splatting.trainer.extensions import ScaleRegularizeTrainerWrapper
+ from reduced_3dgs.quantization import VectorQuantizeTrainerWrapper
+ from reduced_3dgs.shculling import VariableSHGaussianModel, SHCullingTrainer
+ from reduced_3dgs.pruning import PruningTrainer
+ from reduced_3dgs.combinations import PrunerInDensifyTrainer, SHCullingDensificationTrainer, SHCullingPruningTrainer, SHCullingPrunerInDensifyTrainer
+ from reduced_3dgs.combinations import CameraTrainableVariableSHGaussianModel, CameraSHCullingTrainer, CameraPruningTrainer
+ from reduced_3dgs.combinations import CameraPrunerInDensifyTrainer, CameraSHCullingDensifyTrainer, CameraSHCullingPruningTrainer, CameraSHCullingPruningDensifyTrainer
+
+
+ def prepare_gaussians(sh_degree: int, source: str, device: str, trainable_camera: bool = False, load_ply: str = None) -> GaussianModel:
+     if trainable_camera:
+         gaussians = CameraTrainableVariableSHGaussianModel(sh_degree).to(device)
+         gaussians.load_ply(load_ply) if load_ply else colmap_init(gaussians, source)
+     else:
+         gaussians = VariableSHGaussianModel(sh_degree).to(device)
+         gaussians.load_ply(load_ply) if load_ply else colmap_init(gaussians, source)
+     return gaussians
+
+
+ modes = {
+     "shculling": SHCullingTrainer,
+     "pruning": PruningTrainer,
+     "densify-pruning": PrunerInDensifyTrainer,
+     "densify-shculling": SHCullingDensificationTrainer,
+     "prune-shculling": SHCullingPruningTrainer,
+     "densify-prune-shculling": SHCullingPrunerInDensifyTrainer,
+     "camera-shculling": CameraSHCullingTrainer,
+     "camera-pruning": CameraPruningTrainer,
+     "camera-densify-pruning": CameraPrunerInDensifyTrainer,
+     "camera-densify-shculling": CameraSHCullingDensifyTrainer,
+     "camera-prune-shculling": CameraSHCullingPruningTrainer,
+     "camera-densify-prune-shculling": CameraSHCullingPruningDensifyTrainer,
+ }
+
+
+ def prepare_quantizer(
+         gaussians: GaussianModel,
+         scene_extent: float,
+         dataset: CameraDataset,
+         base_constructor,
+         load_quantized: str = None,
+
+         num_clusters=256,
+         num_clusters_rotation_re=None,
+         num_clusters_rotation_im=None,
+         num_clusters_opacity=None,
+         num_clusters_scaling=None,
+         num_clusters_features_dc=None,
+         num_clusters_features_rest=[],
+
+         quantize_from_iter=5000,
+         quantize_until_iter=30000,
+         quantize_interval=1000,
+         **configs):
+     trainer = VectorQuantizeTrainerWrapper(
+         base_constructor(
+             gaussians,
+             scene_extent=scene_extent,
+             dataset=dataset,
+             **configs
+         ),
+
+         num_clusters=num_clusters,
+         num_clusters_rotation_re=num_clusters_rotation_re,
+         num_clusters_rotation_im=num_clusters_rotation_im,
+         num_clusters_opacity=num_clusters_opacity,
+         num_clusters_scaling=num_clusters_scaling,
+         num_clusters_features_dc=num_clusters_features_dc,
+         num_clusters_features_rest=num_clusters_features_rest,
+
+         quantize_from_iter=quantize_from_iter,
+         quantize_until_iter=quantize_until_iter,
+         quantize_interval=quantize_interval,
+     )
+     if load_quantized:
+         trainer.quantizer.load_quantized(load_quantized)
+     return trainer, trainer.quantizer
+
+
+ def prepare_trainer(gaussians: GaussianModel, dataset: CameraDataset, mode: str, with_scale_reg=False, quantize: bool = False, load_quantized: str = None, configs={}) -> AbstractTrainer:
+     constructor = modes[mode]
+     if with_scale_reg:
+         constructor = lambda *args, **kwargs: ScaleRegularizeTrainerWrapper(modes[mode], *args, **kwargs)
+     if quantize:
+         trainer, quantizer = prepare_quantizer(
+             gaussians,
+             scene_extent=dataset.scene_extent(),
+             dataset=dataset,
+             base_constructor=modes[mode],
+             load_quantized=load_quantized,
+             **configs
+         )
+     else:
+         trainer = constructor(
+             gaussians,
+             scene_extent=dataset.scene_extent(),
+             dataset=dataset,
+             **configs
+         )
+         quantizer = None
+     return trainer, quantizer
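
The new reduced_3dgs/prepare.py collects the model, trainer and quantizer construction that previously lived inside train.py. A minimal usage sketch, assuming a COLMAP-style source directory and the prepare_dataset helper that train.py imports from gaussian_splatting.prepare; the path string and cluster count are illustrative only:

from gaussian_splatting.prepare import prepare_dataset
from reduced_3dgs.prepare import modes, prepare_gaussians, prepare_trainer

device = "cuda"
source = "./data/my_scene"  # hypothetical COLMAP-style dataset directory

# Build the camera dataset and the variable-SH Gaussian model.
dataset = prepare_dataset(source=source, device=device, trainable_camera=False, load_camera=None, load_depth=False)
gaussians = prepare_gaussians(sh_degree=3, source=source, device=device, trainable_camera=False, load_ply=None)

# Any key of the `modes` dict selects the trainer combination; quantize=True
# wraps it in VectorQuantizeTrainerWrapper via prepare_quantizer.
trainer, quantizer = prepare_trainer(
    gaussians, dataset,
    mode="densify-prune-shculling",
    quantize=True,
    configs={"num_clusters": 256},  # forwarded as **configs to prepare_quantizer
)

Compared with 1.10.1, the base and camera-trainable modes are merged into the single `modes` dict, so `prepare_trainer` no longer has to dispatch between `basemodes` and `cameramodes` (see the removed train.py further below).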
@@ -206,8 +206,7 @@ class VectorQuantizer(AbstractQuantizer):
          ids_dict = self.find_nearest_cluster_id(model, self._codebook_dict)
          return ids_dict, codebook_dict

-     def save_quantized(self, model: GaussianModel, ply_path: str):
-         ids_dict, codebook_dict = self.quantize(model, update_codebook=False)
+     def ply_dtype(self, max_sh_degree: int):
          dtype_full = [
              ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
              ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
@@ -217,13 +216,16 @@ class VectorQuantizer(AbstractQuantizer):
              ('scale', self.force_code_dtype or compute_uint_dtype(self.num_clusters_scaling)),
              ('f_dc', self.force_code_dtype or compute_uint_dtype(self.num_clusters_features_dc)),
          ]
-         for sh_degree in range(model.max_sh_degree):
+         for sh_degree in range(max_sh_degree):
              force_code_dtype = self.force_code_dtype or compute_uint_dtype(self.num_clusters_features_rest[sh_degree])
              dtype_full.extend([
                  (f'f_rest_{sh_degree}_0', force_code_dtype),
                  (f'f_rest_{sh_degree}_1', force_code_dtype),
                  (f'f_rest_{sh_degree}_2', force_code_dtype),
              ])
+         return dtype_full
+
+     def ply_data(self, model: GaussianModel, ids_dict: Dict[str, torch.Tensor]):
          data_full = [
              *np.array_split(model._xyz.detach().cpu().numpy(), 3, axis=1),
              *np.array_split(torch.zeros_like(model._xyz).detach().cpu().numpy(), 3, axis=1),
@@ -236,6 +238,12 @@ class VectorQuantizer(AbstractQuantizer):
          for sh_degree in range(model.max_sh_degree):
              features_rest = ids_dict[f'features_rest_{sh_degree}'].cpu().numpy()
              data_full.extend(np.array_split(features_rest, 3, axis=1))
+         return data_full
+
+     def save_quantized(self, model: GaussianModel, ply_path: str):
+         ids_dict, codebook_dict = self.quantize(model, update_codebook=False)
+         dtype_full = self.ply_dtype(model.max_sh_degree)
+         data_full = self.ply_data(model, ids_dict)

          elements = np.rec.fromarrays([data.squeeze(-1) for data in data_full], dtype=dtype_full)
          el = PlyElement.describe(elements, 'vertex')
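
This refactor splits the PLY schema (`ply_dtype`) and the per-vertex payload (`ply_data`) out of `save_quantized`, which now just composes them; the write logic after the shown hunks is unchanged. A rough sketch of the composition, assuming an already-configured `VectorQuantizer` instance `quantizer` and a `GaussianModel` `model`; how the codebook is persisted is outside the shown hunks, so the sketch only writes the vertex element to an illustrative path:

import numpy as np
from plyfile import PlyData, PlyElement

# Mirrors the new save_quantized: quantize, then build schema and columns separately.
ids_dict, codebook_dict = quantizer.quantize(model, update_codebook=False)
dtype_full = quantizer.ply_dtype(model.max_sh_degree)  # list of (field, dtype) pairs
data_full = quantizer.ply_data(model, ids_dict)        # matching per-vertex columns

elements = np.rec.fromarrays([data.squeeze(-1) for data in data_full], dtype=dtype_full)
el = PlyElement.describe(elements, 'vertex')
PlyData([el]).write("point_cloud_quantized.ply")       # codebook handling omitted here

A plausible motivation is that subclasses and external loaders can now reuse the field layout without going through `save_quantized` itself.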
@@ -13,6 +13,7 @@ def copy_not_exists(source, destination):


  def quantize(source, destination, iteration, sh_degree, device, **kwargs):
+     os.makedirs(destination, exist_ok=True)
      copy_not_exists(os.path.join(source, "cfg_args"), os.path.join(destination, "cfg_args"))
      copy_not_exists(os.path.join(source, "cameras.json"), os.path.join(destination, "cameras.json"))

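The single added line guards the copies below it: `shutil.copy`-style helpers raise `FileNotFoundError` when the destination's parent directory does not exist, so the destination tree is now created up front. A toy, self-contained reproduction of the pattern; the `copy_not_exists` body here is an illustrative stand-in, only its signature comes from the hunk header:

import os
import shutil
import tempfile

def copy_not_exists(source, destination):
    # Stand-in for the helper defined earlier in quantize.py.
    if not os.path.exists(destination):
        shutil.copy(source, destination)

workdir = tempfile.mkdtemp()
source_dir = os.path.join(workdir, "full")
destination = os.path.join(workdir, "quantized")
os.makedirs(source_dir)
open(os.path.join(source_dir, "cfg_args"), "w").close()

os.makedirs(destination, exist_ok=True)  # the 1.10.3 addition; no-op if it already exists
copy_not_exists(os.path.join(source_dir, "cfg_args"), os.path.join(destination, "cfg_args"))  # would raise without the makedirs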
@@ -0,0 +1,93 @@
+ import os
+ import random
+ import shutil
+ from typing import List
+ import torch
+ from tqdm import tqdm
+ from gaussian_splatting import GaussianModel
+ from gaussian_splatting.dataset import CameraDataset
+ from gaussian_splatting.utils import psnr
+ from gaussian_splatting.trainer import AbstractTrainer
+ from gaussian_splatting.prepare import prepare_dataset
+ from gaussian_splatting.train import save_cfg_args
+ from reduced_3dgs.quantization import AbstractQuantizer
+ from reduced_3dgs.prepare import modes, prepare_gaussians, prepare_trainer
+
+
+ def prepare_training(sh_degree: int, source: str, device: str, mode: str, trainable_camera: bool = False, load_ply: str = None, load_camera: str = None, load_depth=False, with_scale_reg=False, quantize: bool = False, load_quantized: str = None, configs={}):
+     dataset = prepare_dataset(source=source, device=device, trainable_camera=trainable_camera, load_camera=load_camera, load_depth=load_depth)
+     gaussians = prepare_gaussians(sh_degree=sh_degree, source=source, device=device, trainable_camera=trainable_camera, load_ply=load_ply)
+     trainer, quantizer = prepare_trainer(gaussians=gaussians, dataset=dataset, mode=mode, with_scale_reg=with_scale_reg, quantize=quantize, load_quantized=load_quantized, configs=configs)
+     return dataset, gaussians, trainer, quantizer
+
+
+ def training(dataset: CameraDataset, gaussians: GaussianModel, trainer: AbstractTrainer, quantizer: AbstractQuantizer, destination: str, iteration: int, save_iterations: List[int], device: str, empty_cache_every_step=False):
+     shutil.rmtree(os.path.join(destination, "point_cloud"), ignore_errors=True)  # remove the previous point cloud
+     pbar = tqdm(range(1, iteration+1))
+     epoch = list(range(len(dataset)))
+     epoch_psnr = torch.empty(3, 0, device=device)
+     ema_loss_for_log = 0.0
+     avg_psnr_for_log = 0.0
+     for step in pbar:
+         epoch_idx = step % len(dataset)
+         if epoch_idx == 0:
+             avg_psnr_for_log = epoch_psnr.mean().item()
+             epoch_psnr = torch.empty(3, 0, device=device)
+             random.shuffle(epoch)
+         idx = epoch[epoch_idx]
+         loss, out = trainer.step(dataset[idx])
+         if empty_cache_every_step:
+             torch.cuda.empty_cache()
+         with torch.no_grad():
+             ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+             epoch_psnr = torch.concat([epoch_psnr, psnr(out["render"], dataset[idx].ground_truth_image)], dim=1)
+             if step % 10 == 0:
+                 pbar.set_postfix({'epoch': step // len(dataset), 'loss': ema_loss_for_log, 'psnr': avg_psnr_for_log, 'n': gaussians._xyz.shape[0]})
+             if step in save_iterations:
+                 save_path = os.path.join(destination, "point_cloud", "iteration_" + str(step))
+                 os.makedirs(save_path, exist_ok=True)
+                 gaussians.save_ply(os.path.join(save_path, "point_cloud.ply"))
+                 dataset.save_cameras(os.path.join(destination, "cameras.json"))
+                 if quantizer:
+                     quantizer.save_quantized(gaussians, os.path.join(save_path, "point_cloud_quantized.ply"))
+     save_path = os.path.join(destination, "point_cloud", "iteration_" + str(iteration))
+     os.makedirs(save_path, exist_ok=True)
+     gaussians.save_ply(os.path.join(save_path, "point_cloud.ply"))
+     dataset.save_cameras(os.path.join(destination, "cameras.json"))
+     if quantizer:
+         quantizer.save_quantized(gaussians, os.path.join(save_path, "point_cloud_quantized.ply"))
+
+
+ if __name__ == "__main__":
+     from argparse import ArgumentParser
+     parser = ArgumentParser()
+     parser.add_argument("--sh_degree", default=3, type=int)
+     parser.add_argument("-s", "--source", required=True, type=str)
+     parser.add_argument("-d", "--destination", required=True, type=str)
+     parser.add_argument("-i", "--iteration", default=30000, type=int)
+     parser.add_argument("-l", "--load_ply", default=None, type=str)
+     parser.add_argument("--load_camera", default=None, type=str)
+     parser.add_argument("--quantize", action='store_true')
+     parser.add_argument("--no_depth_data", action='store_true')
+     parser.add_argument("--with_scale_reg", action="store_true")
+     parser.add_argument("--load_quantized", default=None, type=str)
+     parser.add_argument("--mode", choices=list(modes), default="densify-prune-shculling")
+     parser.add_argument("--save_iterations", nargs="+", type=int, default=[7000, 30000])
+     parser.add_argument("--device", default="cuda", type=str)
+     parser.add_argument("--empty_cache_every_step", action='store_true')
+     parser.add_argument("-o", "--option", default=[], action='append', type=str)
+     args = parser.parse_args()
+     save_cfg_args(args.destination, args.sh_degree, args.source)
+     torch.autograd.set_detect_anomaly(False)
+
+     configs = {o.split("=", 1)[0]: eval(o.split("=", 1)[1]) for o in args.option}
+     dataset, gaussians, trainer, quantizer = prepare_training(
+         sh_degree=args.sh_degree, source=args.source, device=args.device, mode=args.mode, trainable_camera="camera" in args.mode,
+         load_ply=args.load_ply, load_camera=args.load_camera, load_depth=not args.no_depth_data, with_scale_reg=args.with_scale_reg,
+         quantize=args.quantize, load_quantized=args.load_quantized, configs=configs)
+     dataset.save_cameras(os.path.join(args.destination, "cameras.json"))
+     torch.cuda.empty_cache()
+     training(
+         dataset=dataset, gaussians=gaussians, trainer=trainer, quantizer=quantizer,
+         destination=args.destination, iteration=args.iteration, save_iterations=args.save_iterations,
+         device=args.device)
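
Two details of the rewritten `__main__` block are worth illustrating: extra trainer/quantizer settings are passed as repeated `-o key=value` flags whose values go through `eval()`, and the camera-trainable variants are now selected simply by the mode name containing "camera". A sketch of how the `-o` options become `configs`; the option strings are examples, but the parameter names come from `prepare_quantizer` above:

# Mirrors the dict comprehension in __main__.
options = ["num_clusters=256", "num_clusters_features_rest=[64, 32, 16]"]  # e.g. from repeated -o flags
configs = {o.split("=", 1)[0]: eval(o.split("=", 1)[1]) for o in options}
assert configs == {"num_clusters": 256, "num_clusters_features_rest": [64, 32, 16]}

Because the values are eval-ed, lists and numbers pass through unchanged. The entry point itself would typically be launched as something like `python -m reduced_3dgs.train -s <source> -d <destination> --mode densify-prune-shculling --quantize`; the module path is an assumption based on the package layout.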
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: reduced_3dgs
- Version: 1.10.1
+ Version: 1.10.3
  Summary: Refactored code for the paper "Reducing the Memory Footprint of 3D Gaussian Splatting"
  Home-page: https://github.com/yindaheng98/reduced-3dgs
  Author: yindaheng98
@@ -3,6 +3,7 @@ README.md
  setup.py
  reduced_3dgs/__init__.py
  reduced_3dgs/combinations.py
+ reduced_3dgs/prepare.py
  reduced_3dgs/quantize.py
  reduced_3dgs/train.py
  reduced_3dgs.egg-info/PKG-INFO
@@ -60,7 +60,7 @@ if os.name == 'nt':

  setup(
      name="reduced_3dgs",
-     version='1.10.1',
+     version='1.10.3',
      author='yindaheng98',
      author_email='yindaheng98@gmail.com',
      url='https://github.com/yindaheng98/reduced-3dgs',
@@ -1,200 +0,0 @@
- import os
- import random
- import shutil
- from typing import List, Tuple
- import torch
- from tqdm import tqdm
- from gaussian_splatting import GaussianModel
- from gaussian_splatting.dataset import CameraDataset
- from gaussian_splatting.utils import psnr
- from gaussian_splatting.dataset.colmap import colmap_init
- from gaussian_splatting.trainer import AbstractTrainer
- from gaussian_splatting.trainer.extensions import ScaleRegularizeTrainerWrapper
- from gaussian_splatting.train import prepare_dataset, save_cfg_args
- from reduced_3dgs.quantization import AbstractQuantizer, VectorQuantizeTrainerWrapper
- from reduced_3dgs.shculling import VariableSHGaussianModel, SHCullingTrainer
- from reduced_3dgs.pruning import PruningTrainer
- from reduced_3dgs.combinations import PrunerInDensifyTrainer, SHCullingDensificationTrainer, SHCullingPruningTrainer, SHCullingPrunerInDensifyTrainer
- from reduced_3dgs.combinations import CameraTrainableVariableSHGaussianModel, CameraSHCullingTrainer, CameraPruningTrainer
- from reduced_3dgs.combinations import CameraPrunerInDensifyTrainer, CameraSHCullingDensifyTrainer, CameraSHCullingPruningTrainer, CameraSHCullingPruningDensifyTrainer
-
-
- basemodes = {
-     "shculling": SHCullingTrainer,
-     "pruning": PruningTrainer,
-     "densify-pruning": PrunerInDensifyTrainer,
-     "densify-shculling": SHCullingDensificationTrainer,
-     "prune-shculling": SHCullingPruningTrainer,
-     "densify-prune-shculling": SHCullingPrunerInDensifyTrainer,
- }
- cameramodes = {
-     "camera-shculling": CameraSHCullingTrainer,
-     "camera-pruning": CameraPruningTrainer,
-     "camera-densify-pruning": CameraPrunerInDensifyTrainer,
-     "camera-densify-shculling": CameraSHCullingDensifyTrainer,
-     "camera-prune-shculling": CameraSHCullingPruningTrainer,
-     "camera-densify-prune-shculling": CameraSHCullingPruningDensifyTrainer,
- }
-
-
- def prepare_quantizer(
-         gaussians: GaussianModel,
-         scene_extent: float,
-         dataset: CameraDataset,
-         base_constructor,
-         load_quantized: str = None,
-
-         num_clusters=256,
-         num_clusters_rotation_re=None,
-         num_clusters_rotation_im=None,
-         num_clusters_opacity=None,
-         num_clusters_scaling=None,
-         num_clusters_features_dc=None,
-         num_clusters_features_rest=[],
-
-         quantize_from_iter=5000,
-         quantize_until_iter=30000,
-         quantize_interval=1000,
-         **configs):
-     trainer = VectorQuantizeTrainerWrapper(
-         base_constructor(
-             gaussians,
-             scene_extent=scene_extent,
-             dataset=dataset,
-             **configs
-         ),
-
-         num_clusters=num_clusters,
-         num_clusters_rotation_re=num_clusters_rotation_re,
-         num_clusters_rotation_im=num_clusters_rotation_im,
-         num_clusters_opacity=num_clusters_opacity,
-         num_clusters_scaling=num_clusters_scaling,
-         num_clusters_features_dc=num_clusters_features_dc,
-         num_clusters_features_rest=num_clusters_features_rest,
-
-         quantize_from_iter=quantize_from_iter,
-         quantize_until_iter=quantize_until_iter,
-         quantize_interval=quantize_interval,
-     )
-     if load_quantized:
-         trainer.quantizer.load_quantized(load_quantized)
-     return trainer, trainer.quantizer
-
-
- def prepare_gaussians(sh_degree: int, source: str, device: str, trainable_camera: bool = False, load_ply: str = None) -> Tuple[CameraDataset, GaussianModel, AbstractTrainer]:
-     if trainable_camera:
-         gaussians = CameraTrainableVariableSHGaussianModel(sh_degree).to(device)
-         gaussians.load_ply(load_ply) if load_ply else colmap_init(gaussians, source)
-     else:
-         gaussians = VariableSHGaussianModel(sh_degree).to(device)
-         gaussians.load_ply(load_ply) if load_ply else colmap_init(gaussians, source)
-     return gaussians
-
-
- def prepare_trainer(gaussians: GaussianModel, dataset: CameraDataset, mode: str, with_scale_reg=False, quantize: bool = False, load_quantized: str = None, configs={}) -> AbstractTrainer:
-     if mode in basemodes:
-         modes = basemodes
-     elif mode in cameramodes:
-         modes = cameramodes
-     else:
-         raise ValueError(f"Unknown mode: {mode}")
-     constructor = modes[mode]
-     if with_scale_reg:
-         constructor = lambda *args, **kwargs: ScaleRegularizeTrainerWrapper(modes[mode], *args, **kwargs)
-     if quantize:
-         trainer, quantizer = prepare_quantizer(
-             gaussians,
-             scene_extent=dataset.scene_extent(),
-             dataset=dataset,
-             base_constructor=modes[mode],
-             load_quantized=load_quantized,
-             **configs
-         )
-     else:
-         trainer = constructor(
-             gaussians,
-             scene_extent=dataset.scene_extent(),
-             dataset=dataset,
-             **configs
-         )
-         quantizer = None
-     return trainer, quantizer
-
-
- def prepare_training(sh_degree: int, source: str, device: str, mode: str, load_ply: str = None, load_camera: str = None, load_depth=False, with_scale_reg=False, quantize: bool = False, load_quantized: str = None, configs={}) -> Tuple[CameraDataset, GaussianModel, AbstractTrainer]:
-     dataset = prepare_dataset(source=source, device=device, trainable_camera=mode in cameramodes, load_camera=load_camera, load_depth=load_depth)
-     gaussians = prepare_gaussians(sh_degree=sh_degree, source=source, device=device, trainable_camera=mode in cameramodes, load_ply=load_ply)
-     trainer, quantizer = prepare_trainer(gaussians=gaussians, dataset=dataset, mode=mode, with_scale_reg=with_scale_reg, quantize=quantize, load_quantized=load_quantized, configs=configs)
-     return dataset, gaussians, trainer, quantizer
-
-
- def training(dataset: CameraDataset, gaussians: GaussianModel, trainer: AbstractTrainer, quantizer: AbstractQuantizer, destination: str, iteration: int, save_iterations: List[int], device: str, empty_cache_every_step=False):
-     shutil.rmtree(os.path.join(destination, "point_cloud"), ignore_errors=True)  # remove the previous point cloud
-     pbar = tqdm(range(1, iteration+1))
-     epoch = list(range(len(dataset)))
-     epoch_psnr = torch.empty(3, 0, device=device)
-     ema_loss_for_log = 0.0
-     avg_psnr_for_log = 0.0
-     for step in pbar:
-         epoch_idx = step % len(dataset)
-         if epoch_idx == 0:
-             avg_psnr_for_log = epoch_psnr.mean().item()
-             epoch_psnr = torch.empty(3, 0, device=device)
-             random.shuffle(epoch)
-         idx = epoch[epoch_idx]
-         loss, out = trainer.step(dataset[idx])
-         if empty_cache_every_step:
-             torch.cuda.empty_cache()
-         with torch.no_grad():
-             ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
-             epoch_psnr = torch.concat([epoch_psnr, psnr(out["render"], dataset[idx].ground_truth_image)], dim=1)
-             if step % 10 == 0:
-                 pbar.set_postfix({'epoch': step // len(dataset), 'loss': ema_loss_for_log, 'psnr': avg_psnr_for_log, 'n': gaussians._xyz.shape[0]})
-             if step in save_iterations:
-                 save_path = os.path.join(destination, "point_cloud", "iteration_" + str(step))
-                 os.makedirs(save_path, exist_ok=True)
-                 gaussians.save_ply(os.path.join(save_path, "point_cloud.ply"))
-                 dataset.save_cameras(os.path.join(destination, "cameras.json"))
-                 if quantizer:
-                     quantizer.save_quantized(gaussians, os.path.join(save_path, "point_cloud_quantized.ply"))
-     save_path = os.path.join(destination, "point_cloud", "iteration_" + str(iteration))
-     os.makedirs(save_path, exist_ok=True)
-     gaussians.save_ply(os.path.join(save_path, "point_cloud.ply"))
-     dataset.save_cameras(os.path.join(destination, "cameras.json"))
-     if quantizer:
-         quantizer.save_quantized(gaussians, os.path.join(save_path, "point_cloud_quantized.ply"))
-
-
- if __name__ == "__main__":
-     from argparse import ArgumentParser, Namespace
-     parser = ArgumentParser()
-     parser.add_argument("--sh_degree", default=3, type=int)
-     parser.add_argument("-s", "--source", required=True, type=str)
-     parser.add_argument("-d", "--destination", required=True, type=str)
-     parser.add_argument("-i", "--iteration", default=30000, type=int)
-     parser.add_argument("-l", "--load_ply", default=None, type=str)
-     parser.add_argument("--load_camera", default=None, type=str)
-     parser.add_argument("--quantize", action='store_true')
-     parser.add_argument("--no_depth_data", action='store_true')
-     parser.add_argument("--with_scale_reg", action="store_true")
-     parser.add_argument("--load_quantized", default=None, type=str)
-     parser.add_argument("--mode", choices=list(basemodes.keys()) + list(cameramodes.keys()), default="densify-prune-shculling")
-     parser.add_argument("--save_iterations", nargs="+", type=int, default=[7000, 30000])
-     parser.add_argument("--device", default="cuda", type=str)
-     parser.add_argument("--empty_cache_every_step", action='store_true')
-     parser.add_argument("-o", "--option", default=[], action='append', type=str)
-     args = parser.parse_args()
-     save_cfg_args(args.destination, args.sh_degree, args.source)
-     torch.autograd.set_detect_anomaly(False)
-
-     configs = {o.split("=", 1)[0]: eval(o.split("=", 1)[1]) for o in args.option}
-     dataset, gaussians, trainer, quantizer = prepare_training(
-         sh_degree=args.sh_degree, source=args.source, device=args.device, mode=args.mode,
-         load_ply=args.load_ply, load_camera=args.load_camera, load_depth=not args.no_depth_data, with_scale_reg=args.with_scale_reg,
-         quantize=args.quantize, load_quantized=args.load_quantized, configs=configs)
-     dataset.save_cameras(os.path.join(args.destination, "cameras.json"))
-     torch.cuda.empty_cache()
-     training(
-         dataset=dataset, gaussians=gaussians, trainer=trainer, quantizer=quantizer,
-         destination=args.destination, iteration=args.iteration, save_iterations=args.save_iterations,
-         device=args.device)