reduced-3dgs 1.10.0__cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reduced-3dgs might be problematic. Click here for more details.
- reduced_3dgs/__init__.py +0 -0
- reduced_3dgs/combinations.py +245 -0
- reduced_3dgs/diff_gaussian_rasterization/_C.cpython-310-x86_64-linux-gnu.so +0 -0
- reduced_3dgs/diff_gaussian_rasterization/__init__.py +235 -0
- reduced_3dgs/importance/__init__.py +3 -0
- reduced_3dgs/importance/combinations.py +63 -0
- reduced_3dgs/importance/diff_gaussian_rasterization/_C.cpython-310-x86_64-linux-gnu.so +0 -0
- reduced_3dgs/importance/diff_gaussian_rasterization/__init__.py +347 -0
- reduced_3dgs/importance/trainer.py +269 -0
- reduced_3dgs/pruning/__init__.py +2 -0
- reduced_3dgs/pruning/combinations.py +65 -0
- reduced_3dgs/pruning/trainer.py +145 -0
- reduced_3dgs/quantization/__init__.py +4 -0
- reduced_3dgs/quantization/abc.py +49 -0
- reduced_3dgs/quantization/exclude_zeros.py +41 -0
- reduced_3dgs/quantization/quantizer.py +289 -0
- reduced_3dgs/quantization/wrapper.py +67 -0
- reduced_3dgs/quantize.py +49 -0
- reduced_3dgs/shculling/__init__.py +2 -0
- reduced_3dgs/shculling/gaussian_model.py +78 -0
- reduced_3dgs/shculling/trainer.py +158 -0
- reduced_3dgs/simple_knn/_C.cpython-310-x86_64-linux-gnu.so +0 -0
- reduced_3dgs/train.py +195 -0
- reduced_3dgs-1.10.0.dist-info/LICENSE.md +93 -0
- reduced_3dgs-1.10.0.dist-info/METADATA +278 -0
- reduced_3dgs-1.10.0.dist-info/RECORD +31 -0
- reduced_3dgs-1.10.0.dist-info/WHEEL +6 -0
- reduced_3dgs-1.10.0.dist-info/top_level.txt +1 -0
- reduced_3dgs.libs/libc10-ff4eddb5.so +0 -0
- reduced_3dgs.libs/libc10_cuda-c675d3fb.so +0 -0
- reduced_3dgs.libs/libcudart-8774224f.so.12.4.127 +0 -0
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright (C) 2023, Inria
|
|
3
|
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# This software is free for non-commercial, research and evaluation use
|
|
7
|
+
# under the terms of the LICENSE.md file.
|
|
8
|
+
#
|
|
9
|
+
# For inquiries contact george.drettakis@inria.fr
|
|
10
|
+
#
|
|
11
|
+
|
|
12
|
+
from typing import NamedTuple
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import torch
|
|
15
|
+
from . import _C
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def cpu_deep_copy_tuple(input_tuple):
    """Return a tuple in which every tensor of ``input_tuple`` is replaced by a CPU clone.

    Non-tensor items are passed through unchanged. The rasterizer wrappers in this
    module use it to snapshot native-call arguments so they can be dumped with
    ``torch.save`` if the call fails in debug mode.
    """
    def _snapshot(item):
        if isinstance(item, torch.Tensor):
            return item.cpu().clone()
        return item

    return tuple(map(_snapshot, input_tuple))
|
|
21
|
+
|
|
22
|
+
def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    opacities,
    scales,
    rotations,
    cov3Ds_precomp,
    raster_settings,
):
    """Rasterize Gaussians, choosing the code path from ``raster_settings.f_count``.

    * ``f_count`` truthy: ``_RasterizeGaussians.forward_count`` is called directly
      (not through autograd) and returns
      ``(gaussians_count, opacity_important_score, T_alpha_important_score, color, radii)``.
    * otherwise: the differentiable ``_RasterizeGaussians.apply`` path runs and
      returns ``(color, radii)``.
    """
    inputs = (
        means3D, means2D, sh, colors_precomp, opacities,
        scales, rotations, cov3Ds_precomp, raster_settings,
    )
    entry_point = (
        _RasterizeGaussians.forward_count
        if raster_settings.f_count
        else _RasterizeGaussians.apply
    )
    return entry_point(*inputs)
|
|
57
|
+
|
|
58
|
+
class _RasterizeGaussians(torch.autograd.Function):
    """Bridge between PyTorch autograd and the native ``_C`` rasterizer.

    * ``forward`` / ``backward`` form the differentiable render pass used through
      ``_RasterizeGaussians.apply``.
    * ``forward_count`` runs the native counting pass (``_C.count_gaussians``)
      without autograd; :func:`rasterize_gaussians` calls it directly when
      ``raster_settings.f_count`` is set.

    In debug mode every native call first snapshots its arguments to the CPU and,
    if the call raises, writes the snapshot with ``torch.save`` before re-raising.
    """

    # Dump file / console message used when a forward native call fails in debug mode.
    _FORWARD_DUMP = "snapshot_fw.dump"
    _FORWARD_ERROR = "\nAn error occured in forward. Please forward snapshot_fw.dump for debugging."
    # Dump file / console message used when the backward native call fails in debug mode.
    _BACKWARD_DUMP = "snapshot_bw.dump"
    _BACKWARD_ERROR = "\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n"

    @staticmethod
    def _native_forward_args(means3D, sh, colors_precomp, opacities, scales, rotations,
                             cov3Ds_precomp, raster_settings):
        """Pack forward inputs in the positional order the native forward functions take."""
        return (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.debug,
        )

    @staticmethod
    def _call_native(fn, args, debug, dump_path, message):
        """Return ``fn(*args)``; in debug mode dump a CPU copy of ``args`` if it raises.

        The snapshot is taken before the call so the dump holds the arguments as
        they were passed in, even if the native code modifies them.
        """
        if not debug:
            return fn(*args)
        cpu_args = cpu_deep_copy_tuple(args)  # Copy them before they can be corrupted
        try:
            return fn(*args)
        except Exception:
            torch.save(cpu_args, dump_path)
            print(message)
            raise

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Render with the native rasterizer and save what ``backward`` needs.

        ``means2D`` is not passed to the native call; it is an input only so that
        ``backward`` can return a gradient for it.

        :return: ``(color, radii)``, or, when ``raster_settings.f_count`` is set,
            ``(gaussians_count, opacity_important_score, T_alpha_important_score, color, radii)``.
        """
        cls = _RasterizeGaussians
        args = cls._native_forward_args(
            means3D, sh, colors_precomp, opacities, scales, rotations, cov3Ds_precomp, raster_settings)

        gaussians_count = opacity_important_score = T_alpha_important_score = None
        if raster_settings.f_count:
            (gaussians_count, opacity_important_score, T_alpha_important_score,
             num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer) = cls._call_native(
                _C.count_gaussians, args + (raster_settings.f_count,), raster_settings.debug,
                cls._FORWARD_DUMP, cls._FORWARD_ERROR)
        else:
            # Invoke C++/CUDA rasterizer
            num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = cls._call_native(
                _C.rasterize_gaussians, args, raster_settings.debug,
                cls._FORWARD_DUMP, cls._FORWARD_ERROR)

        # Keep relevant tensors for backward
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
        ctx.count = gaussians_count
        ctx.opacity_important_score = opacity_important_score
        ctx.T_alpha_important_score = T_alpha_important_score

        if raster_settings.f_count:
            return gaussians_count, opacity_important_score, T_alpha_important_score, color, radii
        return color, radii

    @staticmethod
    def forward_count(
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Run the native counting pass outside autograd (requires ``raster_settings.f_count``).

        :return: ``(gaussians_count, opacity_important_score, T_alpha_important_score, color, radii)``.
        """
        assert raster_settings.f_count
        cls = _RasterizeGaussians
        args = cls._native_forward_args(
            means3D, sh, colors_precomp, opacities, scales, rotations, cov3Ds_precomp, raster_settings,
        ) + (raster_settings.f_count,)
        (gaussians_count, opacity_important_score, T_alpha_important_score,
         _num_rendered, color, radii, _geomBuffer, _binningBuffer, _imgBuffer) = cls._call_native(
            _C.count_gaussians, args, raster_settings.debug,
            cls._FORWARD_DUMP, cls._FORWARD_ERROR)
        return gaussians_count, opacity_important_score, T_alpha_important_score, color, radii

    @staticmethod
    def backward(ctx, grad_out_color, *grad_rest):
        """Compute input gradients with ``_C.rasterize_gaussians_backward``.

        Autograd passes one gradient per output of ``forward``: two in normal mode
        ``(color, radii)``, five in counting mode. Only the colour gradient is used.

        :return: gradients for ``(means3D, means2D, sh, colors_precomp, opacities,
            scales, rotations, cov3Ds_precomp, raster_settings)``; the last is ``None``.
        """
        # Restore necessary values from context
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors

        if raster_settings.f_count:
            # Counting mode returned (count, opacity_score, T_alpha_score, color, radii),
            # so the colour gradient is the fourth one, i.e. grad_rest[2].
            # NOTE(review): assumes _C.count_gaussians fills geom/binning/img buffers
            # compatible with rasterize_gaussians_backward — confirm in the C++ source.
            grad_out_color = grad_rest[2]

        # Restructure args as C++ method expects them
        args = (raster_settings.bg,
                means3D,
                radii,
                colors_precomp,
                scales,
                rotations,
                raster_settings.scale_modifier,
                cov3Ds_precomp,
                raster_settings.viewmatrix,
                raster_settings.projmatrix,
                raster_settings.tanfovx,
                raster_settings.tanfovy,
                grad_out_color,
                sh,
                raster_settings.sh_degree,
                raster_settings.campos,
                geomBuffer,
                num_rendered,
                binningBuffer,
                imgBuffer,
                raster_settings.debug)

        # Compute gradients for relevant tensors by invoking backward method
        cls = _RasterizeGaussians
        (grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D,
         grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations) = cls._call_native(
            _C.rasterize_gaussians_backward, args, raster_settings.debug,
            cls._BACKWARD_DUMP, cls._BACKWARD_ERROR)

        return (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            None,
        )
|
|
246
|
+
|
|
247
|
+
class GaussianRasterizationSettings(NamedTuple):
    """Immutable bundle of per-view parameters consumed by the rasterizer wrappers."""

    # Output image size.
    image_height: int
    image_width: int
    # Tangent of half the horizontal / vertical field of view.
    tanfovx: float
    tanfovy: float
    # Background colour tensor handed to the native rasterizer.
    bg: torch.Tensor
    scale_modifier: float
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor
    sh_degree: int
    campos: torch.Tensor
    prefiltered: bool
    # When true, native-call arguments are snapshotted and dumped on failure.
    debug: bool
    # When true, rasterize_gaussians runs the counting / importance-score pass.
    f_count: bool
|
|
261
|
+
|
|
262
|
+
class GaussianRasterizer(nn.Module):
    """``nn.Module`` front end for :func:`rasterize_gaussians`.

    Holds a :class:`GaussianRasterizationSettings` and, before rasterizing, checks
    that exactly one colour source (SHs or precomputed colours) and exactly one
    covariance source (scale/rotation pair or precomputed 3D covariance) is given.
    """

    def __init__(self, raster_settings):
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Return ``_C.mark_visible(positions, viewmatrix, projmatrix)`` computed without autograd."""
        # Mark visible points (based on frustum culling for camera) with a boolean
        with torch.no_grad():
            raster_settings = self.raster_settings
            visible = _C.mark_visible(
                positions,
                raster_settings.viewmatrix,
                raster_settings.projmatrix)

        return visible

    @staticmethod
    def _fill_optional_inputs(shs, colors_precomp, scales, rotations, cov3D_precomp):
        """Validate the mutually exclusive inputs and replace missing ones with ``torch.Tensor([])``.

        :return: ``(shs, colors_precomp, scales, rotations, cov3D_precomp)`` with no ``None``.
        :raises Exception: if not exactly one of ``shs`` / ``colors_precomp`` is given, or
            if not exactly one of the ``scales`` + ``rotations`` pair / ``cov3D_precomp`` is given.
        """
        if (shs is None) == (colors_precomp is None):
            raise Exception('Please provide excatly one of either SHs or precomputed colors!')

        if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None):
            raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        # Absent inputs are handed to the rasterizer as empty tensors rather than None.
        return tuple(
            torch.Tensor([]) if value is None else value
            for value in (shs, colors_precomp, scales, rotations, cov3D_precomp)
        )

    def _rasterize(self, means3D, means2D, opacities, shs, colors_precomp, scales, rotations, cov3D_precomp):
        """Shared implementation of :meth:`forward` and :meth:`forward_count`."""
        raster_settings = self.raster_settings
        shs, colors_precomp, scales, rotations, cov3D_precomp = self._fill_optional_inputs(
            shs, colors_precomp, scales, rotations, cov3D_precomp)

        # Invoke C++/CUDA rasterization routine
        return rasterize_gaussians(
            means3D,
            means2D,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            raster_settings,
        )

    def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
        """Rasterize with this module's settings.

        :return: the result of :func:`rasterize_gaussians`: ``(color, radii)``, or the
            5-tuple counting result when ``raster_settings.f_count`` is set.
        :raises Exception: on an invalid combination of optional inputs.
        """
        return self._rasterize(means3D, means2D, opacities, shs, colors_precomp, scales, rotations, cov3D_precomp)

    def forward_count(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
        """Same as :meth:`forward`.

        The counting pass is selected by ``raster_settings.f_count``, not by which
        of the two methods is called.
        """
        return self._rasterize(means3D, means2D, opacities, shs, colors_precomp, scales, rotations, cov3D_precomp)
|
|
347
|
+
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import List
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from gaussian_splatting import Camera, GaussianModel
|
|
6
|
+
from gaussian_splatting.camera import build_camera
|
|
7
|
+
from gaussian_splatting.trainer import AbstractDensifier, DensifierWrapper, DensificationTrainer, NoopDensifier
|
|
8
|
+
from gaussian_splatting.dataset import CameraDataset
|
|
9
|
+
from .diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def count_render(self: GaussianModel, viewpoint_camera: Camera):
    """Run the counting rasterization pass for a single camera.

    Builds :class:`GaussianRasterizationSettings` with ``f_count=True`` so the
    rasterizer also returns per-Gaussian statistics. The background colour is
    moved to the model's device, so it may live on any device.

    :param self: the Gaussian model to render (written as a free function taking the model).
    :param viewpoint_camera: camera providing FoV, resolution, transforms and ``bg_color``.
    :return: dict with ``render``, ``viewspace_points``, ``visibility_filter``
        (``radii > 0``), ``radii``, ``gaussians_count``, ``opacity_important_score``
        and ``T_alpha_important_score``.
    """
    # Zero tensor returned as "viewspace_points"; requires_grad so it can collect
    # screen-space mean gradients when autograd is active.
    screenspace_points = torch.zeros_like(self.get_xyz, dtype=self.get_xyz.dtype, requires_grad=True, device=self._xyz.device) + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # retain_grad() raises when the tensor does not require grad (e.g. when
        # autograd is disabled); the counting pass does not need it.
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=viewpoint_camera.bg_color.to(self._xyz.device),
        scale_modifier=self.scale_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=self.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=self.debug,
        f_count=True,
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)
    means3D = self.get_xyz
    means2D = screenspace_points
    opacity = self.get_opacity

    scales = self.get_scaling
    rotations = self.get_rotation

    shs = self.get_features

    # Rasterize visible Gaussians to image, obtain their radii (on screen) and the
    # per-Gaussian counting statistics (f_count=True selects the 5-tuple result).
    gaussians_count, opacity_important_score, T_alpha_important_score, rendered_image, radii = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=None,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=None)

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    return {
        "render": rendered_image,
        "viewspace_points": screenspace_points,
        "visibility_filter": radii > 0,
        "radii": radii,
        "gaussians_count": gaussians_count,
        "opacity_important_score": opacity_important_score,
        "T_alpha_important_score": T_alpha_important_score
    }
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def prune_list(model: GaussianModel, dataset: CameraDataset, resize=None):
    """Accumulate per-Gaussian counting statistics over every camera in ``dataset``.

    :param model: Gaussian model to evaluate.
    :param dataset: iterable of cameras.
    :param resize: if not None, each camera is rebuilt so that its longer side is
        ``resize`` pixels (aspect ratio kept, sizes truncated to int) before rendering.
    :return: ``(gaussian_count, opacity_important_score, T_alpha_important_score)``,
        one entry per Gaussian: an int32 tensor and two float32 tensors, summed over all cameras.
    """
    xyz = model.get_xyz
    num_gaussians, device = xyz.shape[0], xyz.device
    gaussian_count = torch.zeros(num_gaussians, device=device, dtype=torch.int)
    opacity_important_score = torch.zeros(num_gaussians, device=device, dtype=torch.float)
    T_alpha_important_score = torch.zeros(num_gaussians, device=device, dtype=torch.float)

    # Only statistics are accumulated here and nothing is back-propagated, so
    # avoid building autograd graphs for the model's activated parameters per camera.
    with torch.no_grad():
        for camera in dataset:
            if resize is not None:
                height, width = camera.image_height, camera.image_width
                scale = resize / max(height, width)
                height, width = int(height * scale), int(width * scale)
                camera = build_camera(
                    image_height=height, image_width=width,
                    FoVx=camera.FoVx, FoVy=camera.FoVy,
                    R=camera.R, T=camera.T,
                    device=camera.R.device)
            out = count_render(model, camera)
            gaussian_count += out["gaussians_count"]
            opacity_important_score += out["opacity_important_score"]
            T_alpha_important_score += out["T_alpha_important_score"]
    return gaussian_count, opacity_important_score, T_alpha_important_score
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
# return importance score with adaptive volume measure described in paper
|
|
101
|
+
def calculate_v_imp_score(gaussians: GaussianModel, imp_list, v_pow, volume_rank_ratio=0.9):
    """Weight importance scores by a normalised Gaussian volume term.

    Each Gaussian's volume is the product of its scales (``get_scaling`` reduced
    over dim 1). It is divided by a reference volume: the element at position
    ``int(N * volume_rank_ratio)`` of the volumes sorted in descending order. With
    the default 0.9 that is a volume near the small end. The ratio is raised to
    ``v_pow`` and multiplied by ``imp_list``.

    :param gaussians: object exposing ``get_scaling`` (2-D, one row per Gaussian).
    :param imp_list: importance score per Gaussian (tensor broadcastable with the volumes).
    :param v_pow: exponent applied to the volume ratio.
    :param volume_rank_ratio: position of the reference volume in the descending sort,
        as a fraction of N in [0, 1]; clamped to the last element. Default 0.9 keeps the
        previous behaviour.
    :return: tensor of volume-weighted scores, one per Gaussian (empty when there are none).
    """
    volume = torch.prod(gaussians.get_scaling, dim=1)
    if volume.numel() == 0:
        # No Gaussians: nothing to rank, and indexing the sorted volumes would raise.
        return volume.new_zeros(volume.shape)
    # For ratios < 1 the clamp never triggers (int(0.9 * N) <= N - 1).
    index = min(int(len(volume) * volume_rank_ratio), len(volume) - 1)
    sorted_volume, _ = torch.sort(volume, descending=True)
    reference_volume = sorted_volume[index]
    v_list = torch.pow(volume / reference_volume, v_pow)
    return v_list * imp_list
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def score2mask(percent, import_score: torch.Tensor, threshold=None):
    """Return a boolean mask marking low-scoring entries for pruning.

    The cut-off is the value at position ``int(percent * (N - 1))`` of the scores
    sorted ascending (roughly the ``percent`` quantile). When ``threshold`` is
    given, the smaller of the two is used. Entries ``<=`` the cut-off are ``True``.

    :param percent: fraction in [0, 1] selecting the quantile position.
    :param import_score: 1-D tensor of per-Gaussian scores.
    :param threshold: optional absolute upper bound on the cut-off.
    :return: bool tensor shaped like ``import_score`` (empty if it is empty).
    """
    if import_score.numel() == 0:
        # No scores to rank; indexing the sorted tensor below would raise IndexError.
        return torch.zeros_like(import_score, dtype=torch.bool)
    sorted_tensor, _ = torch.sort(import_score, dim=0)
    index_nth_percentile = int(percent * (sorted_tensor.shape[0] - 1))
    value_nth_percentile = sorted_tensor[index_nth_percentile]
    thr = value_nth_percentile if threshold is None else min(threshold, value_nth_percentile)
    return import_score <= thr
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def prune_gaussians(
|
|
130
|
+
gaussians: GaussianModel, dataset: CameraDataset,
|
|
131
|
+
resize=None,
|
|
132
|
+
prune_type="comprehensive",
|
|
133
|
+
prune_percent=0.1,
|
|
134
|
+
prune_thr_important_score=None,
|
|
135
|
+
prune_thr_v_important_score=None,
|
|
136
|
+
prune_thr_max_v_important_score=None,
|
|
137
|
+
prune_thr_count=None,
|
|
138
|
+
prune_thr_T_alpha=None,
|
|
139
|
+
prune_thr_T_alpha_avg=None,
|
|
140
|
+
v_pow=0.1):
|
|
141
|
+
gaussian_list, opacity_imp_list, T_alpha_imp_list = prune_list(gaussians, dataset, resize)
|
|
142
|
+
match prune_type:
|
|
143
|
+
case "important_score":
|
|
144
|
+
mask = score2mask(prune_percent, opacity_imp_list, prune_thr_important_score)
|
|
145
|
+
case "v_important_score":
|
|
146
|
+
v_list = calculate_v_imp_score(gaussians, opacity_imp_list, v_pow)
|
|
147
|
+
mask = score2mask(prune_percent, v_list, prune_thr_v_important_score)
|
|
148
|
+
case "max_v_important_score":
|
|
149
|
+
v_list = opacity_imp_list * torch.max(gaussians.get_scaling, dim=1)[0]
|
|
150
|
+
mask = score2mask(prune_percent, v_list, prune_thr_max_v_important_score)
|
|
151
|
+
case "count":
|
|
152
|
+
mask = score2mask(prune_percent, gaussian_list, prune_thr_count)
|
|
153
|
+
case "T_alpha":
|
|
154
|
+
# new importance score defined by doji
|
|
155
|
+
mask = score2mask(prune_percent, T_alpha_imp_list, prune_thr_T_alpha)
|
|
156
|
+
case "T_alpha_avg":
|
|
157
|
+
v_list = T_alpha_imp_list / gaussian_list
|
|
158
|
+
v_list[gaussian_list <= 0] = 0
|
|
159
|
+
mask = score2mask(prune_percent, v_list, prune_thr_T_alpha_avg)
|
|
160
|
+
case "comprehensive":
|
|
161
|
+
mask = torch.zeros_like(gaussian_list, dtype=torch.bool)
|
|
162
|
+
if prune_thr_important_score is not None:
|
|
163
|
+
mask |= score2mask(prune_percent, opacity_imp_list, prune_thr_important_score)
|
|
164
|
+
if prune_thr_v_important_score is not None:
|
|
165
|
+
v_list = calculate_v_imp_score(gaussians, opacity_imp_list, v_pow)
|
|
166
|
+
mask |= score2mask(prune_percent, v_list, prune_thr_v_important_score)
|
|
167
|
+
if prune_thr_max_v_important_score is not None:
|
|
168
|
+
v_list = opacity_imp_list * torch.max(gaussians.get_scaling, dim=1)[0]
|
|
169
|
+
mask |= score2mask(prune_percent, v_list, prune_thr_max_v_important_score)
|
|
170
|
+
if prune_thr_count is not None:
|
|
171
|
+
mask |= score2mask(prune_percent, gaussian_list, prune_thr_count)
|
|
172
|
+
if prune_thr_T_alpha is not None:
|
|
173
|
+
mask |= score2mask(prune_percent, T_alpha_imp_list, prune_thr_T_alpha)
|
|
174
|
+
if prune_thr_T_alpha_avg is not None:
|
|
175
|
+
v_list = T_alpha_imp_list / gaussian_list
|
|
176
|
+
v_list[gaussian_list <= 0] = 0
|
|
177
|
+
mask |= score2mask(prune_percent, v_list, prune_thr_T_alpha_avg)
|
|
178
|
+
case _:
|
|
179
|
+
raise Exception("Unsupportive prunning method")
|
|
180
|
+
return mask
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class ImportancePruner(DensifierWrapper):
    """Densifier wrapper that additionally prunes Gaussians by importance score.

    The wrapped densifier runs first. On steps ``s`` with
    ``importance_prune_from_iter <= s <= importance_prune_until_iter`` and
    ``s % importance_prune_interval == 0``, :func:`prune_gaussians` is evaluated
    over ``dataset``. Its mask is OR-ed into the returned ``remove_mask``, or used
    as-is when the wrapped densifier produced none. The remaining ``importance_*``
    arguments are forwarded to :func:`prune_gaussians`.
    """

    def __init__(
            self, base_densifier: AbstractDensifier,
            dataset: CameraDataset,
            importance_prune_from_iter=15000,
            importance_prune_until_iter=20000,
            importance_prune_interval: int = 1000,
            importance_score_resize=None,
            importance_prune_type="comprehensive",
            importance_prune_percent=0.1,
            importance_prune_thr_important_score=None,
            importance_prune_thr_v_important_score=3.0,
            importance_prune_thr_max_v_important_score=None,
            importance_prune_thr_count=1,
            importance_prune_thr_T_alpha=1,
            importance_prune_thr_T_alpha_avg=0.001,
            importance_v_pow=0.1):
        super().__init__(base_densifier)
        self.dataset = dataset

        # Schedule: which steps trigger an importance-pruning pass.
        self.importance_prune_from_iter = importance_prune_from_iter
        self.importance_prune_until_iter = importance_prune_until_iter
        self.importance_prune_interval = importance_prune_interval

        # Scoring / thresholding options passed through to prune_gaussians.
        self.resize = importance_score_resize
        self.prune_type = importance_prune_type
        self.prune_percent = importance_prune_percent
        self.prune_thr_important_score = importance_prune_thr_important_score
        self.prune_thr_v_important_score = importance_prune_thr_v_important_score
        self.prune_thr_max_v_important_score = importance_prune_thr_max_v_important_score
        self.prune_thr_count = importance_prune_thr_count
        self.prune_thr_T_alpha = importance_prune_thr_T_alpha
        self.prune_thr_T_alpha_avg = importance_prune_thr_T_alpha_avg
        self.v_pow = importance_v_pow

    def densify_and_prune(self, loss, out, camera, step: int):
        result = super().densify_and_prune(loss, out, camera, step)

        within_window = self.importance_prune_from_iter <= step <= self.importance_prune_until_iter
        if not (within_window and step % self.importance_prune_interval == 0):
            return result

        importance_mask = prune_gaussians(
            self.model, self.dataset,
            resize=self.resize,
            prune_type=self.prune_type,
            prune_percent=self.prune_percent,
            prune_thr_important_score=self.prune_thr_important_score,
            prune_thr_v_important_score=self.prune_thr_v_important_score,
            prune_thr_max_v_important_score=self.prune_thr_max_v_important_score,
            prune_thr_count=self.prune_thr_count,
            prune_thr_T_alpha=self.prune_thr_T_alpha,
            prune_thr_T_alpha_avg=self.prune_thr_T_alpha_avg,
            v_pow=self.v_pow,
        )
        if result.remove_mask is None:
            merged_mask = importance_mask
        else:
            merged_mask = torch.logical_or(result.remove_mask, importance_mask)
        return result._replace(remove_mask=merged_mask)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def BaseImportancePruningTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        *args,
        importance_prune_from_iter=15000,
        importance_prune_until_iter=20000,
        importance_prune_interval: int = 1000,
        importance_score_resize=None,
        importance_prune_type="comprehensive",
        importance_prune_percent=0.1,
        importance_prune_thr_important_score=None,
        importance_prune_thr_v_important_score=3.0,
        importance_prune_thr_max_v_important_score=None,
        importance_prune_thr_count=1,
        importance_prune_thr_T_alpha=1.0,
        importance_prune_thr_T_alpha_avg=0.001,
        importance_v_pow=0.1,
        **kwargs):
    """Build a ``DensificationTrainer`` whose densifier only does importance pruning.

    The densifier is an :class:`ImportancePruner` wrapped around a ``NoopDensifier``.
    The ``importance_*`` options configure the pruner. Remaining positional and
    keyword arguments go to ``DensificationTrainer``.
    """
    pruner = ImportancePruner(
        NoopDensifier(model),
        dataset,
        importance_prune_from_iter=importance_prune_from_iter,
        importance_prune_until_iter=importance_prune_until_iter,
        importance_prune_interval=importance_prune_interval,
        importance_score_resize=importance_score_resize,
        importance_prune_type=importance_prune_type,
        importance_prune_percent=importance_prune_percent,
        importance_prune_thr_important_score=importance_prune_thr_important_score,
        importance_prune_thr_v_important_score=importance_prune_thr_v_important_score,
        importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
        importance_prune_thr_count=importance_prune_thr_count,
        importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
        importance_prune_thr_T_alpha_avg=importance_prune_thr_T_alpha_avg,
        importance_v_pow=importance_v_pow,
    )
    return DensificationTrainer(model, scene_extent, pruner, *args, **kwargs)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
|
|
2
|
+
from typing import Callable, List
|
|
3
|
+
from gaussian_splatting import Camera, GaussianModel
|
|
4
|
+
from gaussian_splatting.dataset import TrainableCameraDataset
|
|
5
|
+
from gaussian_splatting.trainer import AbstractDensifier, DepthTrainerWrapper, NoopDensifier, SplitCloneDensifierTrainerWrapper
|
|
6
|
+
from .trainer import BasePruner, BasePruningTrainer
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def PrunerInDensifyTrainerWrapper(
        noargs_base_densifier_constructor: Callable[[GaussianModel, float, List[Camera]], AbstractDensifier],
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        *args,
        prune_from_iter=1000,
        prune_until_iter=15000,
        prune_interval: int = 100,
        box_size=1.,
        lambda_mercy=1.,
        mercy_minimum=3,
        mercy_type='redundancy_opacity',
        **kwargs):
    """Wrap a ``BasePruner`` (around a caller-built base densifier) in split/clone densification.

    ``noargs_base_densifier_constructor(model, scene_extent, dataset)`` builds the
    densifier that ``BasePruner`` wraps. The pruning options configure
    ``BasePruner``. Remaining arguments go to ``SplitCloneDensifierTrainerWrapper``.
    """
    pruner_options = dict(
        prune_from_iter=prune_from_iter,
        prune_until_iter=prune_until_iter,
        prune_interval=prune_interval,
        box_size=box_size,
        lambda_mercy=lambda_mercy,
        mercy_minimum=mercy_minimum,
        mercy_type=mercy_type,
    )

    def make_pruner(model, scene_extent):
        base_densifier = noargs_base_densifier_constructor(model, scene_extent, dataset)
        return BasePruner(base_densifier, scene_extent, dataset, **pruner_options)

    return SplitCloneDensifierTrainerWrapper(make_pruner, model, scene_extent, *args, **kwargs)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def BasePrunerInDensifyTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        *args, **kwargs):
    """:func:`PrunerInDensifyTrainerWrapper` with a ``NoopDensifier`` as the base densifier."""
    def build_noop_densifier(model, scene_extent, dataset):
        return NoopDensifier(model)

    return PrunerInDensifyTrainerWrapper(
        build_noop_densifier, model, scene_extent, dataset, *args, **kwargs)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Depth trainer
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def DepthPruningTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
    """``BasePruningTrainer`` wrapped by ``DepthTrainerWrapper``.

    ``dataset`` is forwarded to the wrapper by keyword; everything else passes through.
    """
    trainer_factory = BasePruningTrainer
    return DepthTrainerWrapper(
        trainer_factory,
        model,
        scene_extent,
        *args,
        dataset=dataset,
        **kwargs,
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def DepthPrunerInDensifyTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
    """:func:`BasePrunerInDensifyTrainer` wrapped by ``DepthTrainerWrapper``.

    ``dataset`` is forwarded to the wrapper by keyword; everything else passes through.
    """
    trainer_factory = BasePrunerInDensifyTrainer
    return DepthTrainerWrapper(
        trainer_factory,
        model,
        scene_extent,
        *args,
        dataset=dataset,
        **kwargs,
    )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Default public names: both resolve to the DepthTrainerWrapper-based variants.
PruningTrainer, PrunerInDensifyTrainer = DepthPruningTrainer, DepthPrunerInDensifyTrainer
|