reduced_3dgs-1.9.1-cp310-cp310-win_amd64.whl → reduced_3dgs-1.10.12-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,9 +17,19 @@ def BaseFullPruningTrainer(
  scene_extent: float,
  dataset: List[Camera],
  *args,
- importance_prune_from_iter=1000,
- importance_prune_until_iter=15000,
- importance_prune_interval: int = 100,
+ importance_prune_from_iter=15000,
+ importance_prune_until_iter=20000,
+ importance_prune_interval: int = 1000,
+ importance_score_resize=None,
+ importance_prune_type="comprehensive",
+ importance_prune_percent=0.1,
+ importance_prune_thr_important_score=None,
+ importance_prune_thr_v_important_score=3.0,
+ importance_prune_thr_max_v_important_score=None,
+ importance_prune_thr_count=1,
+ importance_prune_thr_T_alpha=1.0,
+ importance_prune_thr_T_alpha_avg=0.001,
+ importance_v_pow=0.1,
  **kwargs):
  return PruningTrainerWrapper(
  lambda model, scene_extent, dataset: ImportancePruner(
@@ -28,6 +38,16 @@ def BaseFullPruningTrainer(
  importance_prune_from_iter=importance_prune_from_iter,
  importance_prune_until_iter=importance_prune_until_iter,
  importance_prune_interval=importance_prune_interval,
+ importance_score_resize=importance_score_resize,
+ importance_prune_type=importance_prune_type,
+ importance_prune_percent=importance_prune_percent,
+ importance_prune_thr_important_score=importance_prune_thr_important_score,
+ importance_prune_thr_v_important_score=importance_prune_thr_v_important_score,
+ importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
+ importance_prune_thr_count=importance_prune_thr_count,
+ importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
+ importance_prune_thr_T_alpha_avg=importance_prune_thr_T_alpha_avg,
+ importance_v_pow=importance_v_pow,
  ),
  model, scene_extent, dataset,
  *args, **kwargs
@@ -39,9 +59,19 @@ def BaseFullPrunerInDensifyTrainer(
  scene_extent: float,
  dataset: List[Camera],
  *args,
- importance_prune_from_iter=1000,
- importance_prune_until_iter=15000,
- importance_prune_interval: int = 100,
+ importance_prune_from_iter=15000,
+ importance_prune_until_iter=20000,
+ importance_prune_interval: int = 1000,
+ importance_score_resize=None,
+ importance_prune_type="comprehensive",
+ importance_prune_percent=0.1,
+ importance_prune_thr_important_score=None,
+ importance_prune_thr_v_important_score=3.0,
+ importance_prune_thr_max_v_important_score=None,
+ importance_prune_thr_count=1,
+ importance_prune_thr_T_alpha=1.0,
+ importance_prune_thr_T_alpha_avg=0.001,
+ importance_v_pow=0.1,
  **kwargs):
  return PrunerInDensifyTrainerWrapper(
  lambda model, scene_extent, dataset: ImportancePruner(
@@ -50,6 +80,16 @@ def BaseFullPrunerInDensifyTrainer(
  importance_prune_from_iter=importance_prune_from_iter,
  importance_prune_until_iter=importance_prune_until_iter,
  importance_prune_interval=importance_prune_interval,
+ importance_score_resize=importance_score_resize,
+ importance_prune_type=importance_prune_type,
+ importance_prune_percent=importance_prune_percent,
+ importance_prune_thr_important_score=importance_prune_thr_important_score,
+ importance_prune_thr_v_important_score=importance_prune_thr_v_important_score,
+ importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
+ importance_prune_thr_count=importance_prune_thr_count,
+ importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
+ importance_prune_thr_T_alpha_avg=importance_prune_thr_T_alpha_avg,
+ importance_v_pow=importance_v_pow,
  ),
  model, scene_extent, dataset,
  *args, **kwargs
@@ -57,11 +97,17 @@ def BaseFullPrunerInDensifyTrainer(


  def DepthFullPruningTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
- return DepthTrainerWrapper(BaseFullPruningTrainer, model, scene_extent, *args, dataset=dataset, **kwargs)
+ return DepthTrainerWrapper(
+ BaseFullPruningTrainer,
+ model, scene_extent, dataset,
+ *args, **kwargs)


  def DepthFullPrunerInDensifyTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
- return DepthTrainerWrapper(BaseFullPrunerInDensifyTrainer, model, scene_extent, *args, dataset=dataset, **kwargs)
+ return DepthTrainerWrapper(
+ BaseFullPrunerInDensifyTrainer,
+ model, scene_extent, dataset,
+ *args, **kwargs)


  def OpacityResetPruningTrainer(
@@ -13,13 +13,15 @@ def BaseImportancePrunerInDensifyTrainer(
  importance_prune_from_iter=15000,
  importance_prune_until_iter=20000,
  importance_prune_interval: int = 1000,
+ importance_score_resize=None,
  importance_prune_type="comprehensive",
  importance_prune_percent=0.1,
  importance_prune_thr_important_score=None,
- importance_prune_thr_v_important_score=1.0,
+ importance_prune_thr_v_important_score=3.0,
  importance_prune_thr_max_v_important_score=None,
  importance_prune_thr_count=1,
- importance_prune_thr_T_alpha=0.01,
+ importance_prune_thr_T_alpha=1.0,
+ importance_prune_thr_T_alpha_avg=0.001,
  importance_v_pow=0.1,
  **kwargs):
  return DensificationTrainerWrapper(
@@ -29,6 +31,7 @@ def BaseImportancePrunerInDensifyTrainer(
  importance_prune_from_iter=importance_prune_from_iter,
  importance_prune_until_iter=importance_prune_until_iter,
  importance_prune_interval=importance_prune_interval,
+ importance_score_resize=importance_score_resize,
  importance_prune_type=importance_prune_type,
  importance_prune_percent=importance_prune_percent,
  importance_prune_thr_important_score=importance_prune_thr_important_score,
@@ -36,6 +39,7 @@ def BaseImportancePrunerInDensifyTrainer(
  importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
  importance_prune_thr_count=importance_prune_thr_count,
  importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
+ importance_prune_thr_T_alpha_avg=importance_prune_thr_T_alpha_avg,
  importance_v_pow=importance_v_pow,
  ),
  model,
@@ -48,11 +52,17 @@ def BaseImportancePrunerInDensifyTrainer(


  def DepthImportancePruningTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
- return DepthTrainerWrapper(BaseImportancePruningTrainer, model, scene_extent, *args, dataset=dataset, **kwargs)
+ return DepthTrainerWrapper(
+ BaseImportancePruningTrainer,
+ model, scene_extent, dataset,
+ *args, **kwargs)


  def DepthImportancePrunerInDensifyTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
- return DepthTrainerWrapper(BaseImportancePrunerInDensifyTrainer, model, scene_extent, *args, dataset=dataset, **kwargs)
+ return DepthTrainerWrapper(
+ BaseImportancePrunerInDensifyTrainer,
+ model, scene_extent, dataset,
+ *args, **kwargs)


  ImportancePruningTrainer = DepthImportancePruningTrainer
@@ -1,8 +1,9 @@
  import math
- from typing import Callable, List
+ from typing import List
  import torch

  from gaussian_splatting import Camera, GaussianModel
+ from gaussian_splatting.camera import build_camera
  from gaussian_splatting.trainer import AbstractDensifier, DensifierWrapper, DensificationTrainer, NoopDensifier
  from gaussian_splatting.dataset import CameraDataset
  from .diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
@@ -75,11 +76,20 @@ def count_render(self: GaussianModel, viewpoint_camera: Camera):
  }


- def prune_list(model: GaussianModel, dataset: CameraDataset):
+ def prune_list(model: GaussianModel, dataset: CameraDataset, resize=None):
  gaussian_count = torch.zeros(model.get_xyz.shape[0], device=model.get_xyz.device, dtype=torch.int)
  opacity_important_score = torch.zeros(model.get_xyz.shape[0], device=model.get_xyz.device, dtype=torch.float)
  T_alpha_important_score = torch.zeros(model.get_xyz.shape[0], device=model.get_xyz.device, dtype=torch.float)
  for camera in dataset:
+ if resize is not None:
+ height, width = camera.image_height, camera.image_width
+ scale = resize / max(height, width)
+ height, width = int(height * scale), int(width * scale)
+ camera = build_camera(
+ image_height=height, image_width=width,
+ FoVx=camera.FoVx, FoVy=camera.FoVy,
+ R=camera.R, T=camera.T,
+ device=camera.R.device)
  out = count_render(model, camera)
  gaussian_count += out["gaussians_count"]
  opacity_important_score += out["opacity_important_score"]
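For reference, a minimal usage sketch of the new `resize` argument (only the `prune_list` signature above is taken from the package; `model` and `dataset` are assumed to exist). Capping the longest image side keeps the counting render cheap, while the default `resize=None` preserves the pre-1.10 full-resolution behaviour.

```python
# Hedged sketch: score importance on renders capped at 1024 px on the longest side.
gaussian_count, opacity_score, T_alpha_score = prune_list(model, dataset, resize=1024)

# resize=None (the default) scores every camera at its original resolution.
gaussian_count, opacity_score, T_alpha_score = prune_list(model, dataset)
```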
@@ -118,15 +128,17 @@ def score2mask(percent, import_score: list, threshold=None):

  def prune_gaussians(
  gaussians: GaussianModel, dataset: CameraDataset,
+ resize=None,
  prune_type="comprehensive",
  prune_percent=0.1,
  prune_thr_important_score=None,
- prune_thr_v_important_score=1.0,
+ prune_thr_v_important_score=None,
  prune_thr_max_v_important_score=None,
- prune_thr_count=1,
- prune_thr_T_alpha=0.01,
+ prune_thr_count=None,
+ prune_thr_T_alpha=None,
+ prune_thr_T_alpha_avg=None,
  v_pow=0.1):
- gaussian_list, opacity_imp_list, T_alpha_imp_list = prune_list(gaussians, dataset)
+ gaussian_list, opacity_imp_list, T_alpha_imp_list = prune_list(gaussians, dataset, resize)
  match prune_type:
  case "important_score":
  mask = score2mask(prune_percent, opacity_imp_list, prune_thr_important_score)
@@ -141,6 +153,10 @@ def prune_gaussians(
  case "T_alpha":
  # new importance score defined by doji
  mask = score2mask(prune_percent, T_alpha_imp_list, prune_thr_T_alpha)
+ case "T_alpha_avg":
+ v_list = T_alpha_imp_list / gaussian_list
+ v_list[gaussian_list <= 0] = 0
+ mask = score2mask(prune_percent, v_list, prune_thr_T_alpha_avg)
  case "comprehensive":
  mask = torch.zeros_like(gaussian_list, dtype=torch.bool)
  if prune_thr_important_score is not None:
@@ -155,6 +171,10 @@ def prune_gaussians(
  mask |= score2mask(prune_percent, gaussian_list, prune_thr_count)
  if prune_thr_T_alpha is not None:
  mask |= score2mask(prune_percent, T_alpha_imp_list, prune_thr_T_alpha)
+ if prune_thr_T_alpha_avg is not None:
+ v_list = T_alpha_imp_list / gaussian_list
+ v_list[gaussian_list <= 0] = 0
+ mask |= score2mask(prune_percent, v_list, prune_thr_T_alpha_avg)
  case _:
  raise Exception("Unsupportive prunning method")
  return mask
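The new `T_alpha_avg` branch divides the accumulated T_alpha importance by the per-Gaussian hit count and zeroes Gaussians that were never counted. A standalone sketch of the same computation (the function name is illustrative, not part of the package):

```python
import torch

def t_alpha_avg_score(T_alpha_imp_list: torch.Tensor, gaussian_list: torch.Tensor) -> torch.Tensor:
    # Average accumulated T_alpha importance per counted hit; Gaussians that were
    # never rendered get score 0, which makes them pruning candidates.
    v_list = T_alpha_imp_list / gaussian_list
    v_list[gaussian_list <= 0] = 0
    return v_list
```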
@@ -167,26 +187,29 @@ class ImportancePruner(DensifierWrapper):
  importance_prune_from_iter=15000,
  importance_prune_until_iter=20000,
  importance_prune_interval: int = 1000,
+ importance_score_resize=None,
  importance_prune_type="comprehensive",
  importance_prune_percent=0.1,
  importance_prune_thr_important_score=None,
- importance_prune_thr_v_important_score=1.0,
+ importance_prune_thr_v_important_score=3.0,
  importance_prune_thr_max_v_important_score=None,
  importance_prune_thr_count=1,
- importance_prune_thr_T_alpha=0.01,
- importance_v_pow=0.1
- ):
+ importance_prune_thr_T_alpha=1,
+ importance_prune_thr_T_alpha_avg=0.001,
+ importance_v_pow=0.1):
  super().__init__(base_densifier)
  self.dataset = dataset
  self.importance_prune_from_iter = importance_prune_from_iter
  self.importance_prune_until_iter = importance_prune_until_iter
  self.importance_prune_interval = importance_prune_interval
+ self.resize = importance_score_resize
  self.prune_percent = importance_prune_percent
  self.prune_thr_important_score = importance_prune_thr_important_score
  self.prune_thr_v_important_score = importance_prune_thr_v_important_score
  self.prune_thr_max_v_important_score = importance_prune_thr_max_v_important_score
  self.prune_thr_count = importance_prune_thr_count
  self.prune_thr_T_alpha = importance_prune_thr_T_alpha
+ self.prune_thr_T_alpha_avg = importance_prune_thr_T_alpha_avg
  self.v_pow = importance_v_pow
  self.prune_type = importance_prune_type

@@ -195,10 +218,11 @@ class ImportancePruner(DensifierWrapper):
  if self.importance_prune_from_iter <= step <= self.importance_prune_until_iter and step % self.importance_prune_interval == 0:
  remove_mask = prune_gaussians(
  self.model, self.dataset,
+ self.resize,
  self.prune_type, self.prune_percent,
  self.prune_thr_important_score, self.prune_thr_v_important_score,
  self.prune_thr_max_v_important_score, self.prune_thr_count,
- self.prune_thr_T_alpha, self.v_pow,
+ self.prune_thr_T_alpha, self.prune_thr_T_alpha_avg, self.v_pow,
  )
  ret = ret._replace(remove_mask=remove_mask if ret.remove_mask is None else torch.logical_or(ret.remove_mask, remove_mask))
  return ret
@@ -212,13 +236,15 @@ def BaseImportancePruningTrainer(
  importance_prune_from_iter=15000,
  importance_prune_until_iter=20000,
  importance_prune_interval: int = 1000,
+ importance_score_resize=None,
  importance_prune_type="comprehensive",
  importance_prune_percent=0.1,
  importance_prune_thr_important_score=None,
- importance_prune_thr_v_important_score=1.0,
+ importance_prune_thr_v_important_score=3.0,
  importance_prune_thr_max_v_important_score=None,
  importance_prune_thr_count=1,
- importance_prune_thr_T_alpha=0.01,
+ importance_prune_thr_T_alpha=1.0,
+ importance_prune_thr_T_alpha_avg=0.001,
  importance_v_pow=0.1,
  **kwargs):
  return DensificationTrainer(
@@ -229,6 +255,7 @@ def BaseImportancePruningTrainer(
  importance_prune_from_iter=importance_prune_from_iter,
  importance_prune_until_iter=importance_prune_until_iter,
  importance_prune_interval=importance_prune_interval,
+ importance_score_resize=importance_score_resize,
  importance_prune_type=importance_prune_type,
  importance_prune_percent=importance_prune_percent,
  importance_prune_thr_important_score=importance_prune_thr_important_score,
@@ -236,6 +263,7 @@ def BaseImportancePruningTrainer(
  importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
  importance_prune_thr_count=importance_prune_thr_count,
  importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
+ importance_prune_thr_T_alpha_avg=importance_prune_thr_T_alpha_avg,
  importance_v_pow=importance_v_pow,
  ), *args, **kwargs
  )
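As a rough usage sketch, the new keyword arguments are exposed directly on the trainer constructors (the leading positional parameters are not shown in this hunk and are assumed here; the values are examples only):

```python
# Hedged sketch: construct the importance-pruning trainer with the new 1.10 options.
trainer = BaseImportancePruningTrainer(
    model, scene_extent, dataset,
    importance_score_resize=1024,            # score on downscaled renders
    importance_prune_type="comprehensive",
    importance_prune_thr_T_alpha_avg=0.001,  # new averaged T_alpha threshold
)
```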
@@ -0,0 +1,105 @@
+ from gaussian_splatting import GaussianModel
+ from gaussian_splatting.dataset import CameraDataset
+ from gaussian_splatting.dataset.colmap import colmap_init
+ from gaussian_splatting.trainer import AbstractTrainer
+ from gaussian_splatting.trainer.extensions import ScaleRegularizeTrainerWrapper
+ from reduced_3dgs.quantization import VectorQuantizeTrainerWrapper
+ from reduced_3dgs.shculling import VariableSHGaussianModel, SHCullingTrainer
+ from reduced_3dgs.pruning import PruningTrainer
+ from reduced_3dgs.combinations import PrunerInDensifyTrainer, SHCullingDensificationTrainer, SHCullingPruningTrainer, SHCullingPrunerInDensifyTrainer
+ from reduced_3dgs.combinations import CameraTrainableVariableSHGaussianModel, CameraSHCullingTrainer, CameraPruningTrainer
+ from reduced_3dgs.combinations import CameraPrunerInDensifyTrainer, CameraSHCullingDensifyTrainer, CameraSHCullingPruningTrainer, CameraSHCullingPruningDensifyTrainer
+
+
+ def prepare_gaussians(sh_degree: int, source: str, device: str, trainable_camera: bool = False, load_ply: str = None) -> GaussianModel:
+ if trainable_camera:
+ gaussians = CameraTrainableVariableSHGaussianModel(sh_degree).to(device)
+ gaussians.load_ply(load_ply) if load_ply else colmap_init(gaussians, source)
+ else:
+ gaussians = VariableSHGaussianModel(sh_degree).to(device)
+ gaussians.load_ply(load_ply) if load_ply else colmap_init(gaussians, source)
+ return gaussians
+
+
+ modes = {
+ "shculling": SHCullingTrainer,
+ "pruning": PruningTrainer,
+ "densify-pruning": PrunerInDensifyTrainer,
+ "densify-shculling": SHCullingDensificationTrainer,
+ "prune-shculling": SHCullingPruningTrainer,
+ "densify-prune-shculling": SHCullingPrunerInDensifyTrainer,
+ "camera-shculling": CameraSHCullingTrainer,
+ "camera-pruning": CameraPruningTrainer,
+ "camera-densify-pruning": CameraPrunerInDensifyTrainer,
+ "camera-densify-shculling": CameraSHCullingDensifyTrainer,
+ "camera-prune-shculling": CameraSHCullingPruningTrainer,
+ "camera-densify-prune-shculling": CameraSHCullingPruningDensifyTrainer,
+ }
+
+
+ def prepare_quantizer(
+ gaussians: GaussianModel,
+ scene_extent: float,
+ dataset: CameraDataset,
+ base_constructor,
+ load_quantized: str = None,
+
+ num_clusters=256,
+ num_clusters_rotation_re=None,
+ num_clusters_rotation_im=None,
+ num_clusters_opacity=None,
+ num_clusters_scaling=None,
+ num_clusters_features_dc=None,
+ num_clusters_features_rest=[],
+
+ quantize_from_iter=5000,
+ quantize_until_iter=30000,
+ quantize_interval=1000,
+ **configs):
+ trainer = VectorQuantizeTrainerWrapper(
+ base_constructor(
+ gaussians,
+ scene_extent=scene_extent,
+ dataset=dataset,
+ **configs
+ ),
+
+ num_clusters=num_clusters,
+ num_clusters_rotation_re=num_clusters_rotation_re,
+ num_clusters_rotation_im=num_clusters_rotation_im,
+ num_clusters_opacity=num_clusters_opacity,
+ num_clusters_scaling=num_clusters_scaling,
+ num_clusters_features_dc=num_clusters_features_dc,
+ num_clusters_features_rest=num_clusters_features_rest,
+
+ quantize_from_iter=quantize_from_iter,
+ quantize_until_iter=quantize_until_iter,
+ quantize_interval=quantize_interval,
+ )
+ if load_quantized:
+ trainer.quantizer.load_quantized(load_quantized)
+ return trainer, trainer.quantizer
+
+
+ def prepare_trainer(gaussians: GaussianModel, dataset: CameraDataset, mode: str, with_scale_reg=False, quantize: bool = False, load_quantized: str = None, configs={}) -> AbstractTrainer:
+ constructor = modes[mode]
+ if with_scale_reg:
+ constructor = lambda *args, **kwargs: ScaleRegularizeTrainerWrapper(modes[mode], *args, **kwargs)
+ if quantize:
+ trainer, quantizer = prepare_quantizer(
+ gaussians,
+ scene_extent=dataset.scene_extent(),
+ dataset=dataset,
+ base_constructor=modes[mode],
+ load_quantized=load_quantized,
+ **configs
+ )
+ else:
+ trainer = constructor(
+ gaussians,
+ scene_extent=dataset.scene_extent(),
+ dataset=dataset,
+ **configs
+ )
+ quantizer = None
+ return trainer, quantizer
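A hedged end-to-end sketch of how the new `reduced_3dgs/prepare.py` helpers compose (the function names, the `"densify-prune-shculling"` mode and the `quantize` flag come from the file above; the source path and the pre-existing `dataset` are assumptions):

```python
from reduced_3dgs.prepare import prepare_gaussians, prepare_trainer

# `dataset` is assumed to be a CameraDataset prepared elsewhere (e.g. via prepare_dataset).
gaussians = prepare_gaussians(sh_degree=3, source="data/scene", device="cuda")
trainer, quantizer = prepare_trainer(
    gaussians, dataset,
    mode="densify-prune-shculling",
    quantize=True,  # wraps the base trainer with VectorQuantizeTrainerWrapper
)
```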
@@ -23,7 +23,7 @@ def PrunerInDensifyTrainerWrapper(
  return SplitCloneDensifierTrainerWrapper(
  lambda model, scene_extent: BasePruner(
  noargs_base_densifier_constructor(model, scene_extent, dataset),
- dataset,
+ scene_extent, dataset,
  prune_from_iter=prune_from_iter,
  prune_until_iter=prune_until_iter,
  prune_interval=prune_interval,
@@ -54,11 +54,17 @@ def BasePrunerInDensifyTrainer(


  def DepthPruningTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
- return DepthTrainerWrapper(BasePruningTrainer, model, scene_extent, *args, dataset=dataset, **kwargs)
+ return DepthTrainerWrapper(
+ BasePruningTrainer,
+ model, scene_extent, dataset,
+ *args, **kwargs)


  def DepthPrunerInDensifyTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
- return DepthTrainerWrapper(BasePrunerInDensifyTrainer, model, scene_extent, *args, dataset=dataset, **kwargs)
+ return DepthTrainerWrapper(
+ BasePrunerInDensifyTrainer,
+ model, scene_extent, dataset,
+ *args, **kwargs)


  PruningTrainer = DepthPruningTrainer
@@ -1,7 +1,7 @@
  from typing import Callable, List
  import torch
  from gaussian_splatting import GaussianModel, Camera
- from gaussian_splatting.trainer import AbstractDensifier, DensifierWrapper, DensificationTrainer, NoopDensifier
+ from gaussian_splatting.trainer import AbstractDensifier, OpacityPruner, DensificationTrainer, NoopDensifier
  from reduced_3dgs.diff_gaussian_rasterization._C import sphere_ellipsoid_intersection, allocate_minimum_redundancy_value, find_minimum_projected_pixel_size
  from reduced_3dgs.simple_knn._C import distIndex2

@@ -79,33 +79,28 @@ def mercy_gaussians(
  return mask


- class BasePruner(DensifierWrapper):
+ class BasePruner(OpacityPruner):
  def __init__(
  self, base_densifier: AbstractDensifier,
+ scene_extent,
  dataset: List[Camera],
- prune_from_iter=1000,
- prune_until_iter=15000,
- prune_interval: int = 100,
+ *args,
  box_size=1.,
  lambda_mercy=1.,
  mercy_minimum=3,
- mercy_type='redundancy_opacity'):
- super().__init__(base_densifier)
+ mercy_type='redundancy_opacity',
+ **kwargs):
+ super().__init__(base_densifier, scene_extent, *args, **kwargs)
  self.dataset = dataset
- self.prune_from_iter = prune_from_iter
- self.prune_until_iter = prune_until_iter
- self.prune_interval = prune_interval
  self.box_size = box_size
  self.lambda_mercy = lambda_mercy
  self.mercy_minimum = mercy_minimum
  self.mercy_type = mercy_type

- def densify_and_prune(self, loss, out, camera, step: int):
- ret = super().densify_and_prune(loss, out, camera, step)
- if self.prune_from_iter <= step <= self.prune_until_iter and step % self.prune_interval == 0:
- remove_mask = mercy_gaussians(self.model, self.dataset, self.box_size, self.lambda_mercy, self.mercy_minimum, self.mercy_type)
- ret = ret._replace(remove_mask=remove_mask if ret.remove_mask is None else torch.logical_or(remove_mask, ret.remove_mask))
- return ret
+ def prune(self) -> torch.Tensor:
+ remove_mask = mercy_gaussians(self.model, self.dataset, self.box_size, self.lambda_mercy, self.mercy_minimum, self.mercy_type)
+ prune_mask = torch.logical_or(super().prune(), remove_mask)
+ return prune_mask


  def PruningTrainerWrapper(
@@ -126,7 +121,7 @@ def PruningTrainerWrapper(
  model, scene_extent,
  BasePruner(
  noargs_base_densifier_constructor(model, scene_extent, dataset),
- dataset,
+ scene_extent, dataset,
  prune_from_iter=prune_from_iter,
  prune_until_iter=prune_until_iter,
  prune_interval=prune_interval,
@@ -4,7 +4,15 @@ from typing import Dict, Tuple
  import torch
  import torch.nn as nn
  import numpy as np
- from sklearn.cluster import MiniBatchKMeans as KMeans
+ try:
+ from cuml.cluster import KMeans
+ kmeans_init = 'k-means||'
+ except ImportError:
+ print("Cuml not found, using sklearn's MiniBatchKMeans for quantization.")
+ from sklearn.cluster import MiniBatchKMeans
+ from functools import partial
+ KMeans = partial(MiniBatchKMeans, batch_size=256 * os.cpu_count())
+ kmeans_init = 'k-means++'
  from gaussian_splatting import GaussianModel
  from plyfile import PlyData, PlyElement
  import numpy as np
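With this fallback, `KMeans` resolves to cuML's GPU implementation when available and otherwise to a partially applied `MiniBatchKMeans`; `generate_codebook` then calls it the same way in both cases. A hedged sketch mirroring that call (`values` is an illustrative tensor, and `os` is assumed to be imported elsewhere in the module, which the `batch_size` expression above requires):

```python
# Hedged sketch: either backend is constructed and fitted identically downstream.
kmeans = KMeans(n_clusters=256, tol=1e-4, max_iter=100,
                init=kmeans_init, random_state=0, n_init="auto", verbose=1)
ids = kmeans.fit_predict(values.cpu().numpy())   # cluster id per attribute row
centers = kmeans.cluster_centers_                # codebook entries
```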
@@ -65,9 +73,8 @@ class VectorQuantizer(AbstractQuantizer):
  def generate_codebook(self, values: torch.Tensor, num_clusters, init_codebook=None):
  kmeans = KMeans(
  n_clusters=num_clusters, tol=self.tol, max_iter=self.max_iter,
- init='k-means++' if init_codebook is None else init_codebook.cpu().numpy(),
- random_state=0, n_init="auto", verbose=0,
- batch_size=256 * os.cpu_count()
+ init=kmeans_init if init_codebook is None else init_codebook.cpu().numpy(),
+ random_state=0, n_init="auto", verbose=1,
  )
  ids = torch.tensor(kmeans.fit_predict(values.cpu().numpy()), device=values.device)
  centers = torch.tensor(kmeans.cluster_centers_, dtype=values.dtype, device=values.device)
@@ -119,10 +126,12 @@ class VectorQuantizer(AbstractQuantizer):
  return self.one_nearst(model._opacity.detach(), codebook)

  def produce_clusters_scaling(self, model: GaussianModel, *args, **kwargs):
- return self.generate_codebook(model.get_scaling.detach(), self.num_clusters_scaling, *args, **kwargs)
+ centers, ids = self.generate_codebook(model.get_scaling.detach(), self.num_clusters_scaling, *args, **kwargs)
+ centers_log = model.scaling_inverse_activation(centers)
+ return centers_log, ids

  def find_nearest_cluster_id_scaling(self, model: GaussianModel, codebook: torch.Tensor):
- return self.one_nearst(model.get_scaling.detach(), codebook)
+ return self.one_nearst(model.get_scaling.detach(), model.scaling_activation(codebook))

  def produce_clusters(self, model: GaussianModel, init_codebook_dict={}):
  codebook_dict: Dict[str, torch.Tensor] = {}
@@ -163,7 +172,7 @@ class VectorQuantizer(AbstractQuantizer):

  def dequantize(self, model: GaussianModel, ids_dict: Dict[str, torch.Tensor], codebook_dict: Dict[str, torch.Tensor], xyz: torch.Tensor = None, replace=False) -> GaussianModel:
  opacity = codebook_dict["opacity"][ids_dict["opacity"], ...]
- scaling = model.scaling_inverse_activation(codebook_dict["scaling"][ids_dict["scaling"], ...])
+ scaling = codebook_dict["scaling"][ids_dict["scaling"], ...]

  rotation = torch.cat((
  codebook_dict["rotation_re"][ids_dict["rotation_re"], ...],
@@ -204,8 +213,7 @@ class VectorQuantizer(AbstractQuantizer):
  ids_dict = self.find_nearest_cluster_id(model, self._codebook_dict)
  return ids_dict, codebook_dict

- def save_quantized(self, model: GaussianModel, ply_path: str):
- ids_dict, codebook_dict = self.quantize(model, update_codebook=False)
+ def ply_dtype(self, max_sh_degree: int):
  dtype_full = [
  ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
  ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
@@ -215,13 +223,16 @@ class VectorQuantizer(AbstractQuantizer):
  ('scale', self.force_code_dtype or compute_uint_dtype(self.num_clusters_scaling)),
  ('f_dc', self.force_code_dtype or compute_uint_dtype(self.num_clusters_features_dc)),
  ]
- for sh_degree in range(model.max_sh_degree):
+ for sh_degree in range(max_sh_degree):
  force_code_dtype = self.force_code_dtype or compute_uint_dtype(self.num_clusters_features_rest[sh_degree])
  dtype_full.extend([
  (f'f_rest_{sh_degree}_0', force_code_dtype),
  (f'f_rest_{sh_degree}_1', force_code_dtype),
  (f'f_rest_{sh_degree}_2', force_code_dtype),
  ])
+ return dtype_full
+
+ def ply_data(self, model: GaussianModel, ids_dict: Dict[str, torch.Tensor]):
  data_full = [
  *np.array_split(model._xyz.detach().cpu().numpy(), 3, axis=1),
  *np.array_split(torch.zeros_like(model._xyz).detach().cpu().numpy(), 3, axis=1),
@@ -234,6 +245,12 @@ class VectorQuantizer(AbstractQuantizer):
  for sh_degree in range(model.max_sh_degree):
  features_rest = ids_dict[f'features_rest_{sh_degree}'].cpu().numpy()
  data_full.extend(np.array_split(features_rest, 3, axis=1))
+ return data_full
+
+ def save_quantized(self, model: GaussianModel, ply_path: str):
+ ids_dict, codebook_dict = self.quantize(model, update_codebook=False)
+ dtype_full = self.ply_dtype(model.max_sh_degree)
+ data_full = self.ply_data(model, ids_dict)

  elements = np.rec.fromarrays([data.squeeze(-1) for data in data_full], dtype=dtype_full)
  el = PlyElement.describe(elements, 'vertex')
@@ -252,36 +269,46 @@ class VectorQuantizer(AbstractQuantizer):

  PlyData([el, *cb]).write(ply_path)

- def load_quantized(self, model: GaussianModel, ply_path: str) -> GaussianModel:
- plydata = PlyData.read(ply_path)
-
+ def parse_ids(self, plydata: PlyData, max_sh_degree: int, device: torch.device) -> Dict[str, torch.Tensor]:
  ids_dict = {}
  elements = plydata['vertex']
- kwargs = dict(dtype=torch.long, device=model._xyz.device)
+ kwargs = dict(dtype=torch.long, device=device)
  ids_dict["rotation_re"] = torch.tensor(elements["rot_re"].copy(), **kwargs)
  ids_dict["rotation_im"] = torch.tensor(elements["rot_im"].copy(), **kwargs)
  ids_dict["opacity"] = torch.tensor(elements["opacity"].copy(), **kwargs)
  ids_dict["scaling"] = torch.tensor(elements["scale"].copy(), **kwargs)
  ids_dict["features_dc"] = torch.tensor(elements["f_dc"].copy(), **kwargs).unsqueeze(-1)
- for sh_degree in range(model.max_sh_degree):
+ for sh_degree in range(max_sh_degree):
  ids_dict[f'features_rest_{sh_degree}'] = torch.tensor(np.stack([elements[f'f_rest_{sh_degree}_{ch}'] for ch in range(3)], axis=1), **kwargs)
+ return ids_dict

+ def parse_codebook(self, plydata: PlyData, max_sh_degree: int, device: torch.device) -> Dict[str, torch.Tensor]:
  codebook_dict = {}
- kwargs = dict(dtype=torch.float32, device=model._xyz.device)
+ kwargs = dict(dtype=torch.float32, device=device)
  codebook_dict["rotation_re"] = torch.tensor(plydata["codebook_rot_re"]["rot_re"], **kwargs).unsqueeze(-1)
  codebook_dict["rotation_im"] = torch.tensor(np.stack([plydata["codebook_rot_im"][f'rot_im_{ch}'] for ch in range(3)], axis=1), **kwargs)
  codebook_dict["opacity"] = torch.tensor(plydata["codebook_opacity"]["opacity"], **kwargs).unsqueeze(-1)
  codebook_dict["scaling"] = torch.tensor(np.stack([plydata["codebook_scaling"][f'scaling_{ch}'] for ch in range(3)], axis=1), **kwargs)
  codebook_dict["features_dc"] = torch.tensor(np.stack([plydata["codebook_f_dc"][f'f_dc_{ch}'] for ch in range(3)], axis=1), **kwargs)
- for sh_degree in range(model.max_sh_degree):
+ for sh_degree in range(max_sh_degree):
  n_channels = (sh_degree + 2) ** 2 - (sh_degree + 1) ** 2
  codebook_dict[f'features_rest_{sh_degree}'] = torch.tensor(np.stack([plydata[f"codebook_f_rest_{sh_degree}"][f'f_rest_{sh_degree}_{ch}'] for ch in range(n_channels)], axis=1), **kwargs)
+ return codebook_dict

- self._codebook_dict = codebook_dict
-
+ def parse_xyz(self, plydata: PlyData, device: torch.device) -> torch.Tensor:
+ elements = plydata['vertex']
+ kwargs = dict(dtype=torch.float32, device=device)
  xyz = torch.stack([
  torch.tensor(elements["x"].copy(), **kwargs),
  torch.tensor(elements["y"].copy(), **kwargs),
  torch.tensor(elements["z"].copy(), **kwargs),
  ], dim=1)
+ return xyz
+
+ def load_quantized(self, model: GaussianModel, ply_path: str) -> GaussianModel:
+ plydata = PlyData.read(ply_path)
+ ids_dict = self.parse_ids(plydata, model.max_sh_degree, model._xyz.device)
+ codebook_dict = self.parse_codebook(plydata, model.max_sh_degree, model._xyz.device)
+ xyz = self.parse_xyz(plydata, model._xyz.device)
+ self._codebook_dict = codebook_dict
  return self.dequantize(model, ids_dict, codebook_dict, xyz=xyz, replace=True)
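The save/load path is now factored into `ply_dtype`/`ply_data` and `parse_ids`/`parse_codebook`/`parse_xyz`, but the public round-trip stays the same; a hedged usage sketch (the output path is illustrative):

```python
# Hedged sketch: `quantizer` is a VectorQuantizer, `model` a GaussianModel.
quantizer.save_quantized(model, "output/point_cloud/quantized.ply")
model = quantizer.load_quantized(model, "output/point_cloud/quantized.ply")
```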
reduced_3dgs/quantize.py CHANGED
@@ -13,6 +13,7 @@ def copy_not_exists(source, destination):


  def quantize(source, destination, iteration, sh_degree, device, **kwargs):
+ os.makedirs(destination, exist_ok=True)
  copy_not_exists(os.path.join(source, "cfg_args"), os.path.join(destination, "cfg_args"))
  copy_not_exists(os.path.join(source, "cameras.json"), os.path.join(destination, "cameras.json"))

reduced_3dgs/train.py CHANGED
@@ -1,150 +1,65 @@
  import os
  import random
  import shutil
- from typing import List, Tuple
+ from typing import List
  import torch
  from tqdm import tqdm
- from argparse import Namespace
  from gaussian_splatting import GaussianModel
- from gaussian_splatting.dataset import CameraDataset, JSONCameraDataset, TrainableCameraDataset
+ from gaussian_splatting.dataset import CameraDataset
  from gaussian_splatting.utils import psnr
- from gaussian_splatting.dataset.colmap import ColmapCameraDataset, ColmapTrainableCameraDataset, colmap_init
  from gaussian_splatting.trainer import AbstractTrainer
- from gaussian_splatting.trainer.extensions import ScaleRegularizeTrainerWrapper
- from reduced_3dgs.quantization import AbstractQuantizer, VectorQuantizeTrainerWrapper
- from reduced_3dgs.shculling import VariableSHGaussianModel, SHCullingTrainer
- from reduced_3dgs.pruning import PruningTrainer
- from reduced_3dgs.combinations import PrunerInDensifyTrainer, SHCullingDensificationTrainer, SHCullingPruningTrainer, SHCullingPrunerInDensifyTrainer
- from reduced_3dgs.combinations import CameraTrainableVariableSHGaussianModel, CameraSHCullingTrainer, CameraPruningTrainer
- from reduced_3dgs.combinations import CameraPrunerInDensifyTrainer, CameraSHCullingDensifyTrainer, CameraSHCullingPruningTrainer, CameraSHCullingPruningDensifyTrainer
-
-
- basemodes = {
- "shculling": SHCullingTrainer,
- "pruning": PruningTrainer,
- "densify-pruning": PrunerInDensifyTrainer,
- "densify-shculling": SHCullingDensificationTrainer,
- "prune-shculling": SHCullingPruningTrainer,
- "densify-prune-shculling": SHCullingPrunerInDensifyTrainer,
- }
- cameramodes = {
- "camera-shculling": CameraSHCullingTrainer,
- "camera-pruning": CameraPruningTrainer,
- "camera-densify-pruning": CameraPrunerInDensifyTrainer,
- "camera-densify-shculling": CameraSHCullingDensifyTrainer,
- "camera-prune-shculling": CameraSHCullingPruningTrainer,
- "camera-densify-prune-shculling": CameraSHCullingPruningDensifyTrainer,
- }
-
-
- def prepare_quantizer(
- gaussians: GaussianModel,
- scene_extent: float,
- dataset: CameraDataset,
- base_constructor,
- load_quantized: str = None,
-
- num_clusters=256,
- num_clusters_rotation_re=None,
- num_clusters_rotation_im=None,
- num_clusters_opacity=None,
- num_clusters_scaling=None,
- num_clusters_features_dc=None,
- num_clusters_features_rest=[],
-
- quantize_from_iter=5000,
- quantize_until_iter=30000,
- quantize_interval=1000,
- **configs):
- trainer = VectorQuantizeTrainerWrapper(
- base_constructor(
- gaussians,
- scene_extent=scene_extent,
- dataset=dataset,
- **configs
- ),
-
- num_clusters=num_clusters,
- num_clusters_rotation_re=num_clusters_rotation_re,
- num_clusters_rotation_im=num_clusters_rotation_im,
- num_clusters_opacity=num_clusters_opacity,
- num_clusters_scaling=num_clusters_scaling,
- num_clusters_features_dc=num_clusters_features_dc,
- num_clusters_features_rest=num_clusters_features_rest,
-
- quantize_from_iter=quantize_from_iter,
- quantize_until_iter=quantize_until_iter,
- quantize_interval=quantize_interval,
- )
- if load_quantized:
- trainer.quantizer.load_quantized(load_quantized)
- return trainer, trainer.quantizer
-
-
- def prepare_training(sh_degree: int, source: str, device: str, mode: str, load_ply: str = None, load_camera: str = None, load_depth=False, with_scale_reg=False, quantize: bool = False, load_quantized: str = None, configs={}) -> Tuple[CameraDataset, GaussianModel, AbstractTrainer]:
- quantizer = None
- if mode in basemodes:
- gaussians = VariableSHGaussianModel(sh_degree).to(device)
- gaussians.load_ply(load_ply) if load_ply else colmap_init(gaussians, source)
- dataset = (JSONCameraDataset(load_camera, load_depth=load_depth) if load_camera else ColmapCameraDataset(source, load_depth=load_depth)).to(device)
- modes = basemodes
- elif mode in cameramodes:
- gaussians = CameraTrainableVariableSHGaussianModel(sh_degree).to(device)
- gaussians.load_ply(load_ply) if load_ply else colmap_init(gaussians, source)
- dataset = (TrainableCameraDataset.from_json(load_camera, load_depth=load_depth) if load_camera else ColmapTrainableCameraDataset(source, load_depth=load_depth)).to(device)
- modes = cameramodes
- else:
- raise ValueError(f"Unknown mode: {mode}")
- constructor = modes[mode]
- if with_scale_reg:
- constructor = lambda *args, **kwargs: ScaleRegularizeTrainerWrapper(modes[mode], *args, **kwargs)
- if quantize:
- trainer, quantizer = prepare_quantizer(
- gaussians,
- scene_extent=dataset.scene_extent(),
- dataset=dataset,
- base_constructor=modes[mode],
- load_quantized=load_quantized,
- **configs
- )
- else:
- trainer = constructor(
- gaussians,
- scene_extent=dataset.scene_extent(),
- dataset=dataset,
- **configs
- )
+ from gaussian_splatting.prepare import prepare_dataset
+ from gaussian_splatting.train import save_cfg_args
+ from reduced_3dgs.quantization import AbstractQuantizer
+ from reduced_3dgs.prepare import modes, prepare_gaussians, prepare_trainer
+
+
+ def prepare_training(
+ sh_degree: int, source: str, device: str, mode: str,
+ trainable_camera: bool = False, load_ply: str = None, load_camera: str = None,
+ load_mask=True, load_depth=True,
+ with_scale_reg=False, quantize: bool = False, load_quantized: str = None, configs={}):
+ dataset = prepare_dataset(source=source, device=device, trainable_camera=trainable_camera, load_camera=load_camera, load_mask=load_mask, load_depth=load_depth)
+ gaussians = prepare_gaussians(sh_degree=sh_degree, source=source, device=device, trainable_camera=trainable_camera, load_ply=load_ply)
+ trainer, quantizer = prepare_trainer(gaussians=gaussians, dataset=dataset, mode=mode, with_scale_reg=with_scale_reg, quantize=quantize, load_quantized=load_quantized, configs=configs)
  return dataset, gaussians, trainer, quantizer


- def save_cfg_args(destination: str, sh_degree: int, source: str):
- os.makedirs(destination, exist_ok=True)
- with open(os.path.join(destination, "cfg_args"), 'w') as cfg_log_f:
- cfg_log_f.write(str(Namespace(sh_degree=sh_degree, source_path=source)))
-
-
  def training(dataset: CameraDataset, gaussians: GaussianModel, trainer: AbstractTrainer, quantizer: AbstractQuantizer, destination: str, iteration: int, save_iterations: List[int], device: str, empty_cache_every_step=False):
  shutil.rmtree(os.path.join(destination, "point_cloud"), ignore_errors=True) # remove the previous point cloud
- pbar = tqdm(range(1, iteration+1))
+ pbar = tqdm(range(1, iteration+1), dynamic_ncols=True, desc="Training")
  epoch = list(range(len(dataset)))
  epoch_psnr = torch.empty(3, 0, device=device)
+ epoch_maskpsnr = torch.empty(3, 0, device=device)
  ema_loss_for_log = 0.0
  avg_psnr_for_log = 0.0
+ avg_maskpsnr_for_log = 0.0
  for step in pbar:
  epoch_idx = step % len(dataset)
  if epoch_idx == 0:
  avg_psnr_for_log = epoch_psnr.mean().item()
+ avg_maskpsnr_for_log = epoch_maskpsnr.mean().item()
  epoch_psnr = torch.empty(3, 0, device=device)
+ epoch_maskpsnr = torch.empty(3, 0, device=device)
  random.shuffle(epoch)
  idx = epoch[epoch_idx]
  loss, out = trainer.step(dataset[idx])
  if empty_cache_every_step:
  torch.cuda.empty_cache()
  with torch.no_grad():
+ ground_truth_image = dataset[idx].ground_truth_image
+ rendered_image = out["render"]
+ epoch_psnr = torch.concat([epoch_psnr, psnr(rendered_image, ground_truth_image)], dim=1)
+ if dataset[idx].ground_truth_image_mask is not None:
+ ground_truth_maskimage = ground_truth_image * dataset[idx].ground_truth_image_mask
+ rendered_maskimage = rendered_image * dataset[idx].ground_truth_image_mask
+ epoch_maskpsnr = torch.concat([epoch_maskpsnr, psnr(rendered_maskimage, ground_truth_maskimage)], dim=1)
  ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
- epoch_psnr = torch.concat([epoch_psnr, psnr(out["render"], dataset[idx].ground_truth_image)], dim=1)
  if step % 10 == 0:
- pbar.set_postfix({'epoch': step // len(dataset), 'loss': ema_loss_for_log, 'psnr': avg_psnr_for_log, 'n': gaussians._xyz.shape[0]})
+ postfix = {'epoch': step // len(dataset), 'loss': ema_loss_for_log, 'psnr': avg_psnr_for_log, 'n': gaussians._xyz.shape[0]}
+ if avg_maskpsnr_for_log > 0:
+ postfix['psnr (mask)'] = avg_maskpsnr_for_log
+ pbar.set_postfix(postfix)
  if step in save_iterations:
  save_path = os.path.join(destination, "point_cloud", "iteration_" + str(step))
  os.makedirs(save_path, exist_ok=True)
@@ -161,7 +76,7 @@ def training(dataset: CameraDataset, gaussians: GaussianModel, trainer: Abstract


  if __name__ == "__main__":
- from argparse import ArgumentParser, Namespace
+ from argparse import ArgumentParser
  parser = ArgumentParser()
  parser.add_argument("--sh_degree", default=3, type=int)
  parser.add_argument("-s", "--source", required=True, type=str)
@@ -170,10 +85,11 @@ if __name__ == "__main__":
  parser.add_argument("-l", "--load_ply", default=None, type=str)
  parser.add_argument("--load_camera", default=None, type=str)
  parser.add_argument("--quantize", action='store_true')
+ parser.add_argument("--no_image_mask", action="store_true")
  parser.add_argument("--no_depth_data", action='store_true')
  parser.add_argument("--with_scale_reg", action="store_true")
  parser.add_argument("--load_quantized", default=None, type=str)
- parser.add_argument("--mode", choices=list(basemodes.keys()) + list(cameramodes.keys()), default="densify-prune-shculling")
+ parser.add_argument("--mode", choices=list(modes), default="densify-prune-shculling")
  parser.add_argument("--save_iterations", nargs="+", type=int, default=[7000, 30000])
  parser.add_argument("--device", default="cuda", type=str)
  parser.add_argument("--empty_cache_every_step", action='store_true')
@@ -184,8 +100,10 @@ if __name__ == "__main__":

  configs = {o.split("=", 1)[0]: eval(o.split("=", 1)[1]) for o in args.option}
  dataset, gaussians, trainer, quantizer = prepare_training(
- sh_degree=args.sh_degree, source=args.source, device=args.device, mode=args.mode,
- load_ply=args.load_ply, load_camera=args.load_camera, load_depth=not args.no_depth_data, with_scale_reg=args.with_scale_reg,
+ sh_degree=args.sh_degree, source=args.source, device=args.device, mode=args.mode, trainable_camera="camera" in args.mode,
+ load_ply=args.load_ply, load_camera=args.load_camera,
+ load_mask=not args.no_image_mask, load_depth=not args.no_depth_data,
+ with_scale_reg=args.with_scale_reg,
  quantize=args.quantize, load_quantized=args.load_quantized, configs=configs)
  dataset.save_cameras(os.path.join(args.destination, "cameras.json"))
  torch.cuda.empty_cache()
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: reduced_3dgs
- Version: 1.9.1
+ Version: 1.10.12
  Summary: Refactored code for the paper "Reducing the Memory Footprint of 3D Gaussian Splatting"
  Home-page: https://github.com/yindaheng98/reduced-3dgs
  Author: yindaheng98
@@ -8,13 +8,8 @@ Author-email: yindaheng98@gmail.com
  Classifier: Programming Language :: Python :: 3
  Description-Content-Type: text/markdown
  License-File: LICENSE.md
- Requires-Dist: tqdm
- Requires-Dist: plyfile
- Requires-Dist: scikit-learn
- Requires-Dist: torch
- Requires-Dist: torchvision
- Requires-Dist: numpy
  Requires-Dist: gaussian-splatting
+ Requires-Dist: scikit-learn
  Dynamic: author
  Dynamic: author-email
  Dynamic: classifier
@@ -40,47 +35,35 @@ This repository contains the **refactored Python code for [Reduced-3DGS](https:/

  * [Pytorch](https://pytorch.org/) (v2.4 or higher recommended)
  * [CUDA Toolkit](https://developer.nvidia.com/cuda-12-4-0-download-archive) (12.4 recommended, should match with PyTorch version)
+ * (Optional) [cuML](https://github.com/rapidsai/cuml) for faster vector quantization

- ## Install (PyPI)
-
+ (Optional) If you have trouble with [`gaussian-splatting`](https://github.com/yindaheng98/gaussian-splatting), try to install it from source:
  ```sh
- pip install --upgrade reduced-3dgs
+ pip install wheel setuptools
+ pip install --upgrade git+https://github.com/yindaheng98/gaussian-splatting.git@master --no-build-isolation
  ```

- ## Install (Build from source)
-
- ```sh
- pip install --upgrade git+https://github.com/yindaheng98/reduced-3dgs.git@main
- ```
- If you have trouble with [`gaussian-splatting`](https://github.com/yindaheng98/gaussian-splatting), you can install it from source:
- ```sh
- pip install --upgrade git+https://github.com/yindaheng98/gaussian-splatting.git@master
- ```
+ ## PyPI Install

- ## Install (Development)
-
- Install [`gaussian-splatting`](https://github.com/yindaheng98/gaussian-splatting).
- You can download the wheel from [PyPI](https://pypi.org/project/gaussian-splatting/):
  ```shell
- pip install --upgrade gaussian-splatting
+ pip install --upgrade reduced-3dgs
  ```
- Alternatively, install the latest version from the source:
- ```sh
- pip install --upgrade git+https://github.com/yindaheng98/gaussian-splatting.git@master
+ or
+ build latest from source:
+ ```shell
+ pip install wheel setuptools
+ pip install --upgrade git+https://github.com/yindaheng98/reduced-3dgs.git@main --no-build-isolation
  ```

+ ### Development Install
+
  ```shell
  git clone --recursive https://github.com/yindaheng98/reduced-3dgs
  cd reduced-3dgs
- pip install tqdm plyfile scikit-learn numpy tifffile triton xformers
+ pip install scikit-learn
  pip install --target . --upgrade --no-deps .
  ```

- (Optional) If you prefer not to install `gaussian-splatting` in your environment, you can install it in your `reduced-3dgs` directory:
- ```sh
- pip install --target . --no-deps --upgrade git+https://github.com/yindaheng98/gaussian-splatting.git@master
- ```
-
  ## Quick Start

  1. Download the dataset (T&T+DB COLMAP dataset, size 650MB):
@@ -0,0 +1,29 @@
+ reduced_3dgs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ reduced_3dgs/combinations.py,sha256=dRrtTIQLggUP22Olp0DVSlD0C8YDgnrpTlrCcGZyWM0,9441
+ reduced_3dgs/prepare.py,sha256=MFUUckRHKfgcva4ZOBxfPFyE95N-OlCQLplpmEPuzOk,4440
+ reduced_3dgs/quantize.py,sha256=BVqBb2tQgiP3hap5-OByD8VELtJJGfEeFzaVFyzCJZU,2572
+ reduced_3dgs/train.py,sha256=P9GCoEVl_VGldKuA3ycj1nea5rTc3xqu_f02a8oe-Lc,6734
+ reduced_3dgs/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd,sha256=T3bwC1BQvTG47Ha-VCGjtTOPU81HYv7TCsa0lYAqVF4,1610752
+ reduced_3dgs/diff_gaussian_rasterization/__init__.py,sha256=oV6JjTc-50MscX4XHeIWSgLr3l8Y25knBIs-0gRbJr4,7932
+ reduced_3dgs/importance/__init__.py,sha256=neJsbY5cLikEGBQGdR4MjwCQ5VWVikT1357DwL0EtWU,289
+ reduced_3dgs/importance/combinations.py,sha256=F9IYDZ6iquZw5Djn2An88dlp2MUGieAyuZyVaDZdSts,2962
+ reduced_3dgs/importance/trainer.py,sha256=Sj4ORvoYtFT7z3hifzFZDfhFyqumHraXyk3vMVtk0AU,12661
+ reduced_3dgs/importance/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd,sha256=TLEsXJXzHr3fYCzE-P17gGkpoQ_X80LsvqCeT9vzchQ,1302016
+ reduced_3dgs/importance/diff_gaussian_rasterization/__init__.py,sha256=Tix8auyXBb_QFQtXrV3sLE9kdnl5zgHH0BbqcFzDp84,12850
+ reduced_3dgs/pruning/__init__.py,sha256=E_YxJ9cDV_B6EJbYUBEcuRYMIht_C72rI1VJUXFCLpM,201
+ reduced_3dgs/pruning/combinations.py,sha256=xO39d08DEsDnKI5np_1xQ4H29O486Q_cThLfWhlH810,2385
+ reduced_3dgs/pruning/trainer.py,sha256=JJml-uYfDfUpbsjRNZbIvnUYYslVgFXkhejbkYSo0s4,6542
+ reduced_3dgs/quantization/__init__.py,sha256=1z1xMn3yj9u7cR9JizGrI3WSyIES_Tqq6oDquvglSeo,225
+ reduced_3dgs/quantization/abc.py,sha256=rsi8HFRwQCltWTYiJ3BpygtQDT7hK6J01jKMOboOY8w,1910
+ reduced_3dgs/quantization/exclude_zeros.py,sha256=fKSgjHous4OpdI6mQi9z23if9jnbB79w2jChpxkCJWw,2381
+ reduced_3dgs/quantization/quantizer.py,sha256=2myyBsYPGbzjZhjpKnYMYcvk6INqeds7wRdBRbIPki4,17948
+ reduced_3dgs/quantization/wrapper.py,sha256=cyXqfJgo9b3fS7DYXxOk5LmQudvrEhweOebFsjRnXiQ,2549
+ reduced_3dgs/shculling/__init__.py,sha256=nP2BejDCUdCmJNRbg0hfhHREO6jyZXwIcRiw6ttVgqo,149
+ reduced_3dgs/shculling/gaussian_model.py,sha256=f8QWaL09vaV9Tcf6Dngjg_Fmk1wTQPAjWhuhI_N02Y8,2877
+ reduced_3dgs/shculling/trainer.py,sha256=9hwR77djhZpyf-URhwKHjnLbe0ZAOS-DIw58RzkcHXQ,6369
+ reduced_3dgs/simple_knn/_C.cp310-win_amd64.pyd,sha256=elGT2b5IdsO5qGhW0XBj_5qdUyJBNveMgEM6yVdoHiY,1248768
+ reduced_3dgs-1.10.12.dist-info/licenses/LICENSE.md,sha256=LQ4_LAqlncGkg_mQy5ykMAFtQDSPB0eKmIEtBut0yjw,4916
+ reduced_3dgs-1.10.12.dist-info/METADATA,sha256=7cQgH63LmWATezhRqXDPA6ytT_Ld9mYXKppGqvR8QTw,12404
+ reduced_3dgs-1.10.12.dist-info/WHEEL,sha256=KUuBC6lxAbHCKilKua8R9W_TM71_-9Sg5uEP3uDWcoU,101
+ reduced_3dgs-1.10.12.dist-info/top_level.txt,sha256=PpU5aT3-baSCdqCtTaZknoB32H93UeKCkYDkRCCZMEI,13
+ reduced_3dgs-1.10.12.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.0.0)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: false
  Tag: cp310-cp310-win_amd64

@@ -1,28 +0,0 @@
- reduced_3dgs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- reduced_3dgs/combinations.py,sha256=FrZLxd3AZlHPSq_WJeEdGWH5zh40rAuV5txxr8HsSPY,7031
- reduced_3dgs/quantize.py,sha256=Y44qHyFdOIqke7NoeqXmyKloS43j-al74ZiNsuZZHbM,2527
- reduced_3dgs/train.py,sha256=jXHdXk05o_ebHjx_VBzcY6fRNn9EdKve6Tf5YC5an0o,9803
- reduced_3dgs/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd,sha256=_DFCX447pmdRjDE8JsGkGYCo5-Q5YYk8ODuAql4tMUU,1618432
- reduced_3dgs/diff_gaussian_rasterization/__init__.py,sha256=oV6JjTc-50MscX4XHeIWSgLr3l8Y25knBIs-0gRbJr4,7932
- reduced_3dgs/importance/__init__.py,sha256=neJsbY5cLikEGBQGdR4MjwCQ5VWVikT1357DwL0EtWU,289
- reduced_3dgs/importance/combinations.py,sha256=Q2WqwXNuclPWBsw15aR14xQ72JewVeZo9igMeaSfuf8,2693
- reduced_3dgs/importance/trainer.py,sha256=ozzjPSXIpHNP4H1mkXlyAv5KJXWt1k4NKCikx9E9S1E,11151
- reduced_3dgs/importance/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd,sha256=VrYLQY0VPX5jte3KXwGYyp7tuvDECMNVi0VaWfZp8yE,1308160
- reduced_3dgs/importance/diff_gaussian_rasterization/__init__.py,sha256=Tix8auyXBb_QFQtXrV3sLE9kdnl5zgHH0BbqcFzDp84,12850
- reduced_3dgs/pruning/__init__.py,sha256=E_YxJ9cDV_B6EJbYUBEcuRYMIht_C72rI1VJUXFCLpM,201
- reduced_3dgs/pruning/combinations.py,sha256=UivTfbSMmaWYVi9E4OF-_AZA-WBWniMiX-wKUftezF8,2331
- reduced_3dgs/pruning/trainer.py,sha256=898m5-7AZFmzQJtyMdQcInOZtDGsAeM3OAplHsm3oSY,6948
- reduced_3dgs/quantization/__init__.py,sha256=1z1xMn3yj9u7cR9JizGrI3WSyIES_Tqq6oDquvglSeo,225
- reduced_3dgs/quantization/abc.py,sha256=rsi8HFRwQCltWTYiJ3BpygtQDT7hK6J01jKMOboOY8w,1910
- reduced_3dgs/quantization/exclude_zeros.py,sha256=fKSgjHous4OpdI6mQi9z23if9jnbB79w2jChpxkCJWw,2381
- reduced_3dgs/quantization/quantizer.py,sha256=-B07X1VrV7E1fT_NXXlZ0FLAb-c_jBUlZ8274KwHyVw,16614
- reduced_3dgs/quantization/wrapper.py,sha256=cyXqfJgo9b3fS7DYXxOk5LmQudvrEhweOebFsjRnXiQ,2549
- reduced_3dgs/shculling/__init__.py,sha256=nP2BejDCUdCmJNRbg0hfhHREO6jyZXwIcRiw6ttVgqo,149
- reduced_3dgs/shculling/gaussian_model.py,sha256=f8QWaL09vaV9Tcf6Dngjg_Fmk1wTQPAjWhuhI_N02Y8,2877
- reduced_3dgs/shculling/trainer.py,sha256=9hwR77djhZpyf-URhwKHjnLbe0ZAOS-DIw58RzkcHXQ,6369
- reduced_3dgs/simple_knn/_C.cp310-win_amd64.pyd,sha256=pdtHUA8FOmQkP2L9f0yJs8b3KBcjyVjdwAwv2Om0rY4,1255424
- reduced_3dgs-1.9.1.dist-info/licenses/LICENSE.md,sha256=LQ4_LAqlncGkg_mQy5ykMAFtQDSPB0eKmIEtBut0yjw,4916
- reduced_3dgs-1.9.1.dist-info/METADATA,sha256=TIrUop8Y3QjsjnVAxNEGADwBMtVArYSbYhhzMYLWMYc,13014
- reduced_3dgs-1.9.1.dist-info/WHEEL,sha256=f7LviPjlKeAUpCmKrQrAXn06zBZT0604QctSZKHK6ZM,101
- reduced_3dgs-1.9.1.dist-info/top_level.txt,sha256=PpU5aT3-baSCdqCtTaZknoB32H93UeKCkYDkRCCZMEI,13
- reduced_3dgs-1.9.1.dist-info/RECORD,,