reduced-3dgs 1.10.0__cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reduced-3dgs might be problematic. Click here for more details.

Files changed (31)
  1. reduced_3dgs/__init__.py +0 -0
  2. reduced_3dgs/combinations.py +245 -0
  3. reduced_3dgs/diff_gaussian_rasterization/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  4. reduced_3dgs/diff_gaussian_rasterization/__init__.py +235 -0
  5. reduced_3dgs/importance/__init__.py +3 -0
  6. reduced_3dgs/importance/combinations.py +63 -0
  7. reduced_3dgs/importance/diff_gaussian_rasterization/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  8. reduced_3dgs/importance/diff_gaussian_rasterization/__init__.py +347 -0
  9. reduced_3dgs/importance/trainer.py +269 -0
  10. reduced_3dgs/pruning/__init__.py +2 -0
  11. reduced_3dgs/pruning/combinations.py +65 -0
  12. reduced_3dgs/pruning/trainer.py +145 -0
  13. reduced_3dgs/quantization/__init__.py +4 -0
  14. reduced_3dgs/quantization/abc.py +49 -0
  15. reduced_3dgs/quantization/exclude_zeros.py +41 -0
  16. reduced_3dgs/quantization/quantizer.py +289 -0
  17. reduced_3dgs/quantization/wrapper.py +67 -0
  18. reduced_3dgs/quantize.py +49 -0
  19. reduced_3dgs/shculling/__init__.py +2 -0
  20. reduced_3dgs/shculling/gaussian_model.py +78 -0
  21. reduced_3dgs/shculling/trainer.py +158 -0
  22. reduced_3dgs/simple_knn/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  23. reduced_3dgs/train.py +195 -0
  24. reduced_3dgs-1.10.0.dist-info/LICENSE.md +93 -0
  25. reduced_3dgs-1.10.0.dist-info/METADATA +278 -0
  26. reduced_3dgs-1.10.0.dist-info/RECORD +31 -0
  27. reduced_3dgs-1.10.0.dist-info/WHEEL +6 -0
  28. reduced_3dgs-1.10.0.dist-info/top_level.txt +1 -0
  29. reduced_3dgs.libs/libc10-ff4eddb5.so +0 -0
  30. reduced_3dgs.libs/libc10_cuda-c675d3fb.so +0 -0
  31. reduced_3dgs.libs/libcudart-8774224f.so.12.4.127 +0 -0
@@ -0,0 +1,347 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ from typing import NamedTuple
13
+ import torch.nn as nn
14
+ import torch
15
+ from . import _C
16
+
17
+
18
def cpu_deep_copy_tuple(input_tuple):
    """Return a tuple in which every tensor of ``input_tuple`` is replaced by a CPU clone.

    Non-tensor items are passed through unchanged, and the item order is preserved.
    The result is used as a debug snapshot that can be written with ``torch.save``.
    """
    def _snapshot(item):
        if not isinstance(item, torch.Tensor):
            return item
        return item.cpu().clone()

    return tuple(map(_snapshot, input_tuple))
21
+
22
def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    opacities,
    scales,
    rotations,
    cov3Ds_precomp,
    raster_settings,
):
    """Dispatch either to the differentiable rasterizer or to the counting pass.

    When ``raster_settings.f_count`` is truthy, ``_RasterizeGaussians.forward_count``
    is called directly. This bypasses ``autograd.Function.apply``. The call returns
    ``(gaussians_count, opacity_important_score, T_alpha_important_score, color, radii)``.

    Otherwise the call goes through ``_RasterizeGaussians.apply``, so a backward is
    attached, and it returns ``(color, radii)``.
    """
    call_args = (
        means3D, means2D, sh, colors_precomp, opacities,
        scales, rotations, cov3Ds_precomp, raster_settings,
    )
    if not raster_settings.f_count:
        return _RasterizeGaussians.apply(*call_args)
    return _RasterizeGaussians.forward_count(*call_args)
57
+
58
class _RasterizeGaussians(torch.autograd.Function):
    """autograd.Function bridging Python to the compiled ``_C`` rasterizer.

    ``forward``/``backward`` wrap ``_C.rasterize_gaussians`` and
    ``_C.rasterize_gaussians_backward``. ``forward_count`` is a plain static method
    (no ``ctx``) that calls ``_C.count_gaussians``. Because it has no ``ctx``, calling
    it does not register a backward through this Function.

    The positional order of every ``args`` tuple below must match the C++
    extension's signature exactly.
    """
    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Rasterize and save what ``backward`` needs.

        Returns ``(color, radii)``. If ``raster_settings.f_count`` is truthy, it returns
        ``(gaussians_count, opacity_important_score, T_alpha_important_score, color, radii)``.

        NOTE(review): ``backward`` below accepts only two incoming gradients. Reaching
        this method via ``apply`` with ``f_count`` set (five outputs) would therefore
        fail in backward. Counting is expected to go through ``forward_count`` instead.
        """

        # Restructure arguments the way that the C++ lib expects them
        args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.debug
        )
        # Pre-initialise every output so the ctx bookkeeping below works on both branches
        # (the count outputs stay None on the plain rasterization path).
        gaussians_count, opacity_important_score, T_alpha_important_score, num_rendered, color, radii, geomBuffer, \
            binningBuffer, imgBuffer = None, None, None, None, None, None, None, None, None

        if raster_settings.f_count:
            # Counting path: the flag is appended as a trailing argument, and
            # _C.count_gaussians returns three extra leading values.
            args = args + (raster_settings.f_count,)
            if raster_settings.debug:
                cpu_args = cpu_deep_copy_tuple(args)
                try:
                    gaussians_count, opacity_important_score, T_alpha_important_score, num_rendered, color, radii,\
                        geomBuffer, binningBuffer, imgBuffer = _C.count_gaussians(*args)
                except Exception as ex:
                    # Dump the CPU copy of the inputs so the failure can be reproduced offline.
                    torch.save(cpu_args, "snapshot_fw.dump")
                    print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
                    raise ex
            else:
                gaussians_count, opacity_important_score, T_alpha_important_score, num_rendered, color, radii,\
                    geomBuffer, binningBuffer, imgBuffer = _C.count_gaussians(*args)
        else:
            # Invoke C++/CUDA rasterizer
            if raster_settings.debug:
                cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
                try:
                    num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
                except Exception as ex:
                    torch.save(cpu_args, "snapshot_fw.dump")
                    print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
                    raise ex
            else:
                num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)

        # Keep relevant tensors for backward
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        # geomBuffer/binningBuffer/imgBuffer come from the C++ forward and are handed back
        # unchanged to the C++ backward; the order here must match the unpacking in backward().
        ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
        # Stored on ctx but not read by backward() in this class.
        ctx.count = gaussians_count
        ctx.opacity_important_score = opacity_important_score
        ctx.T_alpha_important_score = T_alpha_important_score

        if raster_settings.f_count:
            return gaussians_count, opacity_important_score, T_alpha_important_score, color, radii

        return color, radii

    @staticmethod
    def forward_count(
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Run only the counting rasterization via ``_C.count_gaussians``.

        This is a plain static method: there is no ``ctx``, nothing is saved and no
        backward is defined for it.

        :raises AssertionError: if ``raster_settings.f_count`` is falsy.
        :return: ``(gaussians_count, opacity_important_score, T_alpha_important_score, color, radii)``.
        """
        assert(raster_settings.f_count)
        # Restructure arguments the way that the C++ lib expects them
        args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.debug,
            raster_settings.f_count
        )
        # gaussians_count, important_score, num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = None, None, None, None, None, None, None, None
        # Invoke C++/CUDA rasterizer
        # TODO(Kevin): pass the count in, but the output include a count list
        if raster_settings.debug:
            cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
            try:
                gaussians_count, opacity_important_score, T_alpha_important_score, num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.count_gaussians(*args)
            except Exception as ex:
                torch.save(cpu_args, "snapshot_fw.dump")
                print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
                raise ex
        else:
            gaussians_count, opacity_important_score, T_alpha_important_score, num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.count_gaussians(*args)

        # num_rendered and the buffers are discarded: nothing is kept for a backward pass.
        return gaussians_count, opacity_important_score, T_alpha_important_score, color, radii

    @staticmethod
    def backward(ctx, grad_out_color, _):
        """Compute gradients for the inputs of ``forward`` via ``_C.rasterize_gaussians_backward``.

        ``_`` is the incoming gradient for ``radii``; it is ignored.

        :return: one gradient per ``forward`` input, in the same order. The last entry is
            ``None`` for ``raster_settings``.
        """

        # Restore necessary values from context
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors

        # Restructure args as C++ method expects them
        args = (raster_settings.bg,
                means3D,
                radii,
                colors_precomp,
                scales,
                rotations,
                raster_settings.scale_modifier,
                cov3Ds_precomp,
                raster_settings.viewmatrix,
                raster_settings.projmatrix,
                raster_settings.tanfovx,
                raster_settings.tanfovy,
                grad_out_color,
                sh,
                raster_settings.sh_degree,
                raster_settings.campos,
                geomBuffer,
                num_rendered,
                binningBuffer,
                imgBuffer,
                raster_settings.debug)

        # Compute gradients for relevant tensors by invoking backward method
        if raster_settings.debug:
            cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
            try:
                grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
            except Exception as ex:
                torch.save(cpu_args, "snapshot_bw.dump")
                print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n")
                raise ex
        else:
            grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)

        # Reorder to match the positional inputs of forward(): means3D, means2D, sh,
        # colors_precomp, opacities, scales, rotations, cov3Ds_precomp, raster_settings.
        grads = (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            None,
        )

        return grads
246
+
247
class GaussianRasterizationSettings(NamedTuple):
    """Immutable bundle of per-camera parameters handed to the rasterizer.

    Field order is significant because instances may be built positionally.
    ``f_count`` selects the counting path, which also returns per-Gaussian
    statistics. It defaults to ``False`` so that code written for the plain
    rasterizer (which has no such flag) can still construct settings.
    """
    image_height: int
    image_width: int
    tanfovx: float
    tanfovy: float
    bg: torch.Tensor
    scale_modifier: float
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor
    sh_degree: int
    campos: torch.Tensor
    prefiltered: bool
    debug: bool
    # When true, the counting rasterizer (_C.count_gaussians) is used.
    f_count: bool = False
261
+
262
class GaussianRasterizer(nn.Module):
    """``nn.Module`` front end for the compiled Gaussian rasterizer.

    The rasterization configuration is fixed at construction time.
    ``forward`` rasterizes according to ``raster_settings.f_count``.
    ``forward_count`` always runs the counting pass and returns
    ``(gaussians_count, opacity_important_score, T_alpha_important_score, color, radii)``.
    """

    def __init__(self, raster_settings):
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Return the per-point visibility flags computed by ``_C.mark_visible`` (no autograd)."""
        # Mark visible points (based on frustum culling for camera) with a boolean
        with torch.no_grad():
            raster_settings = self.raster_settings
            visible = _C.mark_visible(
                positions,
                raster_settings.viewmatrix,
                raster_settings.projmatrix)

        return visible

    @staticmethod
    def _prepare_inputs(shs, colors_precomp, scales, rotations, cov3D_precomp):
        """Validate the optional inputs and replace the unused ones with empty tensors.

        Exactly one of ``shs`` / ``colors_precomp`` must be given. Exactly one of the
        ``(scales, rotations)`` pair / ``cov3D_precomp`` must be given.

        :raises Exception: if either exclusivity rule is violated.
        :return: ``(shs, colors_precomp, scales, rotations, cov3D_precomp)`` with every
            ``None`` replaced by ``torch.Tensor([])``.
        """
        if (shs is None) == (colors_precomp is None):
            raise Exception('Please provide exactly one of either SHs or precomputed colors!')

        if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None):
            raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        # Missing inputs are passed on as empty tensors rather than None
        # (presumably because the extension's arguments are typed as tensors).
        def empty_if_none(t):
            return torch.Tensor([]) if t is None else t

        return tuple(empty_if_none(t) for t in (shs, colors_precomp, scales, rotations, cov3D_precomp))

    def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
        """Rasterize with the stored settings.

        Returns ``(color, radii)``, or the 5-tuple of the counting pass when
        ``raster_settings.f_count`` is set.
        """
        raster_settings = self.raster_settings

        shs, colors_precomp, scales, rotations, cov3D_precomp = self._prepare_inputs(
            shs, colors_precomp, scales, rotations, cov3D_precomp)

        # Invoke C++/CUDA rasterization routine
        return rasterize_gaussians(
            means3D,
            means2D,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            raster_settings,
        )

    def forward_count(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
        """Run the counting pass regardless of the stored ``f_count`` flag.

        Returns ``(gaussians_count, opacity_important_score, T_alpha_important_score, color, radii)``.
        """
        shs, colors_precomp, scales, rotations, cov3D_precomp = self._prepare_inputs(
            shs, colors_precomp, scales, rotations, cov3D_precomp)

        raster_settings = self.raster_settings
        # Previously this silently fell back to plain rasterization (2-tuple) when
        # f_count was False; force counting mode so the return shape matches the name.
        if not raster_settings.f_count:
            raster_settings = raster_settings._replace(f_count=True)

        # Invoke C++/CUDA rasterization routine
        return rasterize_gaussians(
            means3D,
            means2D,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            raster_settings,
        )
347
+
@@ -0,0 +1,269 @@
1
+ import math
2
+ from typing import List
3
+ import torch
4
+
5
+ from gaussian_splatting import Camera, GaussianModel
6
+ from gaussian_splatting.camera import build_camera
7
+ from gaussian_splatting.trainer import AbstractDensifier, DensifierWrapper, DensificationTrainer, NoopDensifier
8
+ from gaussian_splatting.dataset import CameraDataset
9
+ from .diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
10
+
11
+
12
def count_render(self: GaussianModel, viewpoint_camera: Camera):
    """
    Rasterize the model from one camera in counting mode.

    The rasterizer is configured with ``f_count=True``. Besides the image and radii,
    it therefore returns per-Gaussian statistics: a hit count, an opacity-based
    importance score and a T_alpha importance score.

    The background colour is read from ``viewpoint_camera.bg_color`` and moved to the
    model's device, so it does not need to be on the GPU beforehand.

    :param self: Gaussian model to rasterize. This is a free function; the model is
        passed explicitly.
    :param viewpoint_camera: camera providing FoV, image size, transforms and bg_color.
    :return: dict with keys ``render``, ``viewspace_points``, ``visibility_filter``,
        ``radii``, ``gaussians_count``, ``opacity_important_score`` and
        ``T_alpha_important_score``.
    """
    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(self.get_xyz, dtype=self.get_xyz.dtype, requires_grad=True, device=self._xyz.device) + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() raises RuntimeError when the tensor does not require grad
        # (e.g. under torch.no_grad()); keeping the gradient is best-effort here.
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=viewpoint_camera.bg_color.to(self._xyz.device),
        scale_modifier=self.scale_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=self.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=self.debug,
        f_count=True,  # selects the counting path and its 5-tuple return
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)
    means3D = self.get_xyz
    means2D = screenspace_points
    opacity = self.get_opacity

    scales = self.get_scaling
    rotations = self.get_rotation

    shs = self.get_features

    # Rasterize visible Gaussians to image, obtain their radii (on screen) and the
    # per-Gaussian counting statistics.
    gaussians_count, opacity_important_score, T_alpha_important_score, rendered_image, radii = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=None,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=None)

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    return {
        "render": rendered_image,
        "viewspace_points": screenspace_points,
        "visibility_filter": radii > 0,
        "radii": radii,
        "gaussians_count": gaussians_count,
        "opacity_important_score": opacity_important_score,
        "T_alpha_important_score": T_alpha_important_score
    }
77
+
78
+
79
def prune_list(model: GaussianModel, dataset: CameraDataset, resize=None):
    """
    Accumulate per-Gaussian counting statistics over every camera in ``dataset``.

    :param model: Gaussian model to score.
    :param dataset: iterable of cameras to render in counting mode.
    :param resize: if given, each camera is rebuilt with ``build_camera`` so that its
        longer side is scaled to ``resize`` pixels. The aspect ratio is kept and sizes
        are truncated to int; FoV, R and T are copied from the original camera.
    :return: ``(gaussian_count, opacity_important_score, T_alpha_important_score)``,
        each a 1-D tensor with one entry per Gaussian, summed over all cameras.
    """
    num_gaussians = model.get_xyz.shape[0]
    device = model.get_xyz.device
    gaussian_count = torch.zeros(num_gaussians, device=device, dtype=torch.int)
    opacity_important_score = torch.zeros(num_gaussians, device=device, dtype=torch.float)
    T_alpha_important_score = torch.zeros(num_gaussians, device=device, dtype=torch.float)
    # The results are only summed; no gradients are needed, so skip recording
    # autograd history for every camera's render.
    with torch.no_grad():
        for camera in dataset:
            if resize is not None:
                height, width = camera.image_height, camera.image_width
                scale = resize / max(height, width)
                height, width = int(height * scale), int(width * scale)
                # NOTE(review): the rebuilt camera receives only size, FoV, R, T and device;
                # confirm build_camera provides a suitable bg_color (count_render reads it).
                camera = build_camera(
                    image_height=height, image_width=width,
                    FoVx=camera.FoVx, FoVy=camera.FoVy,
                    R=camera.R, T=camera.T,
                    device=camera.R.device)
            out = count_render(model, camera)
            gaussian_count += out["gaussians_count"]
            opacity_important_score += out["opacity_important_score"]
            T_alpha_important_score += out["T_alpha_important_score"]
    return gaussian_count, opacity_important_score, T_alpha_important_score
98
+
99
+
100
def calculate_v_imp_score(gaussians: GaussianModel, imp_list, v_pow, percentile: float = 0.9):
    """
    Return the importance score with the adaptive volume measure.

    Each Gaussian's volume (product of its scaling components) is normalised by a
    reference volume. That reference is the entry at position
    ``int(n * percentile)`` of the volumes sorted in descending order. The
    normalised volume is raised to ``v_pow`` and multiplied by ``imp_list``.

    :param gaussians: object exposing ``get_scaling`` (one row of scales per Gaussian).
    :param imp_list: per-Gaussian importance scores.
    :param v_pow: exponent applied to the normalised volumes.
    :param percentile: relative position of the reference volume in the descending
        sort. The default 0.9 reproduces the original behaviour.
    :return: per-Gaussian volume-weighted importance scores (empty for an empty model).
    """
    # Calculate the volume of each Gaussian component
    volume = torch.prod(gaussians.get_scaling, dim=1)
    if volume.numel() == 0:
        # Nothing to score; indexing the sorted volumes below would raise IndexError.
        return volume * imp_list
    # Clamp so any percentile in [0, 1] yields a valid index (0.9 never needs it).
    index = max(0, min(int(len(volume) * percentile), len(volume) - 1))
    sorted_volume, _ = torch.sort(volume, descending=True)
    kth_percent_largest = sorted_volume[index]
    # Calculate v_list
    v_list = torch.pow(volume / kth_percent_largest, v_pow)
    return v_list * imp_list
118
+
119
+
120
def score2mask(percent, import_score: torch.Tensor, threshold=None):
    """
    Build a boolean prune mask from a 1-D score tensor.

    The cut-off is the score at position ``int(percent * (n - 1))`` of the ascending
    sort. If ``threshold`` is given, the smaller of that value and ``threshold`` is
    used. Entries whose score is ``<=`` the cut-off are marked ``True``.

    :param percent: fraction in [0, 1] selecting the percentile cut-off.
    :param import_score: per-Gaussian scores.
    :param threshold: optional absolute upper bound on the cut-off.
    :return: boolean tensor shaped like ``import_score`` (empty for empty input).
    """
    if import_score.numel() == 0:
        # No scores: nothing to prune (indexing the sorted tensor would raise).
        return torch.zeros_like(import_score, dtype=torch.bool)
    sorted_tensor, _ = torch.sort(import_score, dim=0)
    index_nth_percentile = int(percent * (sorted_tensor.shape[0] - 1))
    value_nth_percentile = sorted_tensor[index_nth_percentile]
    thr = value_nth_percentile if threshold is None else min(threshold, value_nth_percentile)
    return import_score <= thr
127
+
128
+
129
+ def prune_gaussians(
130
+ gaussians: GaussianModel, dataset: CameraDataset,
131
+ resize=None,
132
+ prune_type="comprehensive",
133
+ prune_percent=0.1,
134
+ prune_thr_important_score=None,
135
+ prune_thr_v_important_score=None,
136
+ prune_thr_max_v_important_score=None,
137
+ prune_thr_count=None,
138
+ prune_thr_T_alpha=None,
139
+ prune_thr_T_alpha_avg=None,
140
+ v_pow=0.1):
141
+ gaussian_list, opacity_imp_list, T_alpha_imp_list = prune_list(gaussians, dataset, resize)
142
+ match prune_type:
143
+ case "important_score":
144
+ mask = score2mask(prune_percent, opacity_imp_list, prune_thr_important_score)
145
+ case "v_important_score":
146
+ v_list = calculate_v_imp_score(gaussians, opacity_imp_list, v_pow)
147
+ mask = score2mask(prune_percent, v_list, prune_thr_v_important_score)
148
+ case "max_v_important_score":
149
+ v_list = opacity_imp_list * torch.max(gaussians.get_scaling, dim=1)[0]
150
+ mask = score2mask(prune_percent, v_list, prune_thr_max_v_important_score)
151
+ case "count":
152
+ mask = score2mask(prune_percent, gaussian_list, prune_thr_count)
153
+ case "T_alpha":
154
+ # new importance score defined by doji
155
+ mask = score2mask(prune_percent, T_alpha_imp_list, prune_thr_T_alpha)
156
+ case "T_alpha_avg":
157
+ v_list = T_alpha_imp_list / gaussian_list
158
+ v_list[gaussian_list <= 0] = 0
159
+ mask = score2mask(prune_percent, v_list, prune_thr_T_alpha_avg)
160
+ case "comprehensive":
161
+ mask = torch.zeros_like(gaussian_list, dtype=torch.bool)
162
+ if prune_thr_important_score is not None:
163
+ mask |= score2mask(prune_percent, opacity_imp_list, prune_thr_important_score)
164
+ if prune_thr_v_important_score is not None:
165
+ v_list = calculate_v_imp_score(gaussians, opacity_imp_list, v_pow)
166
+ mask |= score2mask(prune_percent, v_list, prune_thr_v_important_score)
167
+ if prune_thr_max_v_important_score is not None:
168
+ v_list = opacity_imp_list * torch.max(gaussians.get_scaling, dim=1)[0]
169
+ mask |= score2mask(prune_percent, v_list, prune_thr_max_v_important_score)
170
+ if prune_thr_count is not None:
171
+ mask |= score2mask(prune_percent, gaussian_list, prune_thr_count)
172
+ if prune_thr_T_alpha is not None:
173
+ mask |= score2mask(prune_percent, T_alpha_imp_list, prune_thr_T_alpha)
174
+ if prune_thr_T_alpha_avg is not None:
175
+ v_list = T_alpha_imp_list / gaussian_list
176
+ v_list[gaussian_list <= 0] = 0
177
+ mask |= score2mask(prune_percent, v_list, prune_thr_T_alpha_avg)
178
+ case _:
179
+ raise Exception("Unsupportive prunning method")
180
+ return mask
181
+
182
+
183
class ImportancePruner(DensifierWrapper):
    """Densifier wrapper that also removes Gaussians judged unimportant.

    On steps inside ``[importance_prune_from_iter, importance_prune_until_iter]``
    that are multiples of ``importance_prune_interval``, it calls ``prune_gaussians``
    on ``self.model`` over ``dataset``. The resulting mask is ORed into the
    ``remove_mask`` returned by the wrapped densifier.
    """

    def __init__(
            self, base_densifier: AbstractDensifier,
            dataset: CameraDataset,
            importance_prune_from_iter=15000,
            importance_prune_until_iter=20000,
            importance_prune_interval: int = 1000,
            importance_score_resize=None,
            importance_prune_type="comprehensive",
            importance_prune_percent=0.1,
            importance_prune_thr_important_score=None,
            importance_prune_thr_v_important_score=3.0,
            importance_prune_thr_max_v_important_score=None,
            importance_prune_thr_count=1,
            importance_prune_thr_T_alpha=1,
            importance_prune_thr_T_alpha_avg=0.001,
            importance_v_pow=0.1):
        super().__init__(base_densifier)
        self.dataset = dataset

        # Schedule
        self.importance_prune_from_iter = importance_prune_from_iter
        self.importance_prune_until_iter = importance_prune_until_iter
        self.importance_prune_interval = importance_prune_interval

        # Scoring options forwarded to prune_gaussians
        self.resize = importance_score_resize
        self.prune_percent = importance_prune_percent
        self.prune_thr_important_score = importance_prune_thr_important_score
        self.prune_thr_v_important_score = importance_prune_thr_v_important_score
        self.prune_thr_max_v_important_score = importance_prune_thr_max_v_important_score
        self.prune_thr_count = importance_prune_thr_count
        self.prune_thr_T_alpha = importance_prune_thr_T_alpha
        self.prune_thr_T_alpha_avg = importance_prune_thr_T_alpha_avg
        self.v_pow = importance_v_pow
        self.prune_type = importance_prune_type

    def densify_and_prune(self, loss, out, camera, step: int):
        """Run the wrapped densifier, then merge in the importance-pruning mask when scheduled."""
        result = super().densify_and_prune(loss, out, camera, step)

        in_window = self.importance_prune_from_iter <= step <= self.importance_prune_until_iter
        if not in_window or step % self.importance_prune_interval != 0:
            return result

        importance_mask = prune_gaussians(
            self.model, self.dataset,
            resize=self.resize,
            prune_type=self.prune_type,
            prune_percent=self.prune_percent,
            prune_thr_important_score=self.prune_thr_important_score,
            prune_thr_v_important_score=self.prune_thr_v_important_score,
            prune_thr_max_v_important_score=self.prune_thr_max_v_important_score,
            prune_thr_count=self.prune_thr_count,
            prune_thr_T_alpha=self.prune_thr_T_alpha,
            prune_thr_T_alpha_avg=self.prune_thr_T_alpha_avg,
            v_pow=self.v_pow,
        )
        if result.remove_mask is None:
            merged = importance_mask
        else:
            merged = torch.logical_or(result.remove_mask, importance_mask)
        return result._replace(remove_mask=merged)
229
+
230
+
231
def BaseImportancePruningTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        *args,
        importance_prune_from_iter=15000,
        importance_prune_until_iter=20000,
        importance_prune_interval: int = 1000,
        importance_score_resize=None,
        importance_prune_type="comprehensive",
        importance_prune_percent=0.1,
        importance_prune_thr_important_score=None,
        importance_prune_thr_v_important_score=3.0,
        importance_prune_thr_max_v_important_score=None,
        importance_prune_thr_count=1,
        importance_prune_thr_T_alpha=1.0,
        importance_prune_thr_T_alpha_avg=0.001,
        importance_v_pow=0.1,
        **kwargs):
    """Build a ``DensificationTrainer`` whose densifier only performs importance pruning.

    A ``NoopDensifier`` is wrapped in an ``ImportancePruner`` configured from the
    ``importance_*`` keywords. Remaining ``*args``/``**kwargs`` go to
    ``DensificationTrainer``.
    """
    pruner_options = dict(
        importance_prune_from_iter=importance_prune_from_iter,
        importance_prune_until_iter=importance_prune_until_iter,
        importance_prune_interval=importance_prune_interval,
        importance_score_resize=importance_score_resize,
        importance_prune_type=importance_prune_type,
        importance_prune_percent=importance_prune_percent,
        importance_prune_thr_important_score=importance_prune_thr_important_score,
        importance_prune_thr_v_important_score=importance_prune_thr_v_important_score,
        importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
        importance_prune_thr_count=importance_prune_thr_count,
        importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
        importance_prune_thr_T_alpha_avg=importance_prune_thr_T_alpha_avg,
        importance_v_pow=importance_v_pow,
    )
    pruner = ImportancePruner(NoopDensifier(model), dataset, **pruner_options)
    return DensificationTrainer(model, scene_extent, pruner, *args, **kwargs)
@@ -0,0 +1,2 @@
1
+ from .trainer import BasePruner, BasePruningTrainer, PruningTrainerWrapper
2
+ from .combinations import BasePrunerInDensifyTrainer, PruningTrainer, PrunerInDensifyTrainer, PrunerInDensifyTrainerWrapper
@@ -0,0 +1,65 @@
1
+
2
+ from typing import Callable, List
3
+ from gaussian_splatting import Camera, GaussianModel
4
+ from gaussian_splatting.dataset import TrainableCameraDataset
5
+ from gaussian_splatting.trainer import AbstractDensifier, DepthTrainerWrapper, NoopDensifier, SplitCloneDensifierTrainerWrapper
6
+ from .trainer import BasePruner, BasePruningTrainer
7
+
8
+
9
def PrunerInDensifyTrainerWrapper(
        noargs_base_densifier_constructor: Callable[[GaussianModel, float, List[Camera]], AbstractDensifier],
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        *args,
        prune_from_iter=1000,
        prune_until_iter=15000,
        prune_interval: int = 100,
        box_size=1.,
        lambda_mercy=1.,
        mercy_minimum=3,
        mercy_type='redundancy_opacity',
        **kwargs):
    """Build a split/clone densification trainer whose densifier is wrapped by ``BasePruner``.

    ``noargs_base_densifier_constructor(model, scene_extent, dataset)`` builds the inner
    densifier. The ``prune_*``/``box_size``/``mercy_*`` keywords configure the pruner.
    Remaining ``*args``/``**kwargs`` go to ``SplitCloneDensifierTrainerWrapper``.
    """
    pruner_options = dict(
        prune_from_iter=prune_from_iter,
        prune_until_iter=prune_until_iter,
        prune_interval=prune_interval,
        box_size=box_size,
        lambda_mercy=lambda_mercy,
        mercy_minimum=mercy_minimum,
        mercy_type=mercy_type,
    )

    # Parameter names mirror the original lambda in case they are passed by keyword.
    def build_pruner(model, scene_extent):
        inner = noargs_base_densifier_constructor(model, scene_extent, dataset)
        return BasePruner(inner, scene_extent, dataset, **pruner_options)

    return SplitCloneDensifierTrainerWrapper(build_pruner, model, scene_extent, *args, **kwargs)
39
+
40
+
41
def BasePrunerInDensifyTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        *args, **kwargs):
    """``PrunerInDensifyTrainerWrapper`` with a ``NoopDensifier`` as the inner densifier."""
    def noop_densifier(model, scene_extent, dataset):
        return NoopDensifier(model)

    return PrunerInDensifyTrainerWrapper(noop_densifier, model, scene_extent, dataset, *args, **kwargs)
51
+
52
+
53
+ # Depth trainer
54
+
55
+
56
def DepthPruningTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
    """Wrap ``BasePruningTrainer`` with ``DepthTrainerWrapper``, passing ``dataset`` by keyword."""
    return DepthTrainerWrapper(
        BasePruningTrainer,
        model,
        scene_extent,
        *args,
        dataset=dataset,
        **kwargs,
    )
58
+
59
+
60
def DepthPrunerInDensifyTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
    """Wrap ``BasePrunerInDensifyTrainer`` with ``DepthTrainerWrapper``, passing ``dataset`` by keyword."""
    return DepthTrainerWrapper(
        BasePrunerInDensifyTrainer,
        model,
        scene_extent,
        *args,
        dataset=dataset,
        **kwargs,
    )
62
+
63
+
64
# Public defaults: the depth-supervised variants.
PruningTrainer, PrunerInDensifyTrainer = DepthPruningTrainer, DepthPrunerInDensifyTrainer