reduced-3dgs 1.8.15__cp310-cp310-win_amd64.whl → 1.8.17__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reduced-3dgs might be problematic. Click here for more details.

@@ -1,2 +1,2 @@
1
- from .trainer import BasePruner, PrunerInDensify, BasePruningTrainer, BasePrunerInDensifyTrainer
2
- from .combinations import PruningTrainer, PrunerInDensifyTrainer
1
+ from .trainer import BasePruner, BasePruningTrainer
2
+ from .combinations import BasePrunerInDensifyTrainer, PruningTrainer, PrunerInDensifyTrainer
@@ -1,8 +1,39 @@
1
1
 
2
- from gaussian_splatting import GaussianModel
2
+ from typing import List
3
+ from gaussian_splatting import Camera, GaussianModel
3
4
  from gaussian_splatting.dataset import TrainableCameraDataset
4
- from gaussian_splatting.trainer import DepthTrainerWrapper
5
- from .trainer import BasePruningTrainer, BasePrunerInDensifyTrainer
5
+ from gaussian_splatting.trainer import DepthTrainerWrapper, NoopDensifier, DensificationTrainerWrapper
6
+ from .trainer import BasePruner, BasePruningTrainer
7
+
8
+
9
def BasePrunerInDensifyTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        prune_from_iter=1000,
        prune_until_iter=15000,
        prune_interval: int = 100,
        box_size=1.,
        lambda_mercy=1.,
        mercy_minimum=3,
        mercy_type='redundancy_opacity',
        *args, **kwargs):
    """Build a trainer via ``DensificationTrainerWrapper`` whose densifier is a ``BasePruner``.

    The pruner wraps a ``NoopDensifier`` created for the model it is handed and
    receives ``dataset`` plus all pruning options by keyword. ``model``,
    ``scene_extent`` and any extra ``*args``/``**kwargs`` are forwarded to
    ``DensificationTrainerWrapper`` unchanged.
    """
    pruner_options = dict(
        prune_from_iter=prune_from_iter,
        prune_until_iter=prune_until_iter,
        prune_interval=prune_interval,
        box_size=box_size,
        lambda_mercy=lambda_mercy,
        mercy_minimum=mercy_minimum,
        mercy_type=mercy_type,
    )

    def make_pruner(model, scene_extent):
        # scene_extent is accepted to match the factory call signature but the
        # pruner itself does not use it.
        return BasePruner(NoopDensifier(model), dataset, **pruner_options)

    return DensificationTrainerWrapper(make_pruner, model, scene_extent, *args, **kwargs)
6
37
 
7
38
 
8
39
  # Depth trainer
@@ -0,0 +1 @@
1
+ from .trainer import ImportancePruningTrainerWrapper, BaseImportancePruningTrainer, ImportancePruningTrainer
@@ -0,0 +1,347 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ from typing import NamedTuple
13
+ import torch.nn as nn
14
+ import torch
15
+ from . import _C
16
+
17
+
18
def cpu_deep_copy_tuple(input_tuple):
    """Return a tuple mirroring ``input_tuple`` with every tensor replaced by a CPU clone.

    Entries that are not ``torch.Tensor`` instances are passed through as-is.
    """
    def _snapshot(entry):
        if isinstance(entry, torch.Tensor):
            return entry.cpu().clone()
        return entry

    return tuple(map(_snapshot, input_tuple))
21
+
22
def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    opacities,
    scales,
    rotations,
    cov3Ds_precomp,
    raster_settings,
):
    """Dispatch to the counting or the regular rasterization entry point.

    When ``raster_settings.f_count`` is truthy the call goes to
    ``_RasterizeGaussians.forward_count`` (invoked directly, not through
    ``apply``); otherwise it goes through ``_RasterizeGaussians.apply``.
    Both receive the same positional arguments in the same order.
    """
    if raster_settings.f_count:
        entry_point = _RasterizeGaussians.forward_count
    else:
        entry_point = _RasterizeGaussians.apply
    return entry_point(
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    )
57
+
58
# Messages printed when a kernel call fails in debug mode (kept verbatim).
_FORWARD_ERROR_MESSAGE = "\nAn error occured in forward. Please forward snapshot_fw.dump for debugging."
_BACKWARD_ERROR_MESSAGE = "\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n"


def _pack_forward_args(means3D, sh, colors_precomp, opacities, scales, rotations, cov3Ds_precomp, raster_settings):
    """Arrange the inputs in the positional order passed to the forward ``_C`` calls."""
    return (
        raster_settings.bg,
        means3D,
        colors_precomp,
        opacities,
        scales,
        rotations,
        raster_settings.scale_modifier,
        cov3Ds_precomp,
        raster_settings.viewmatrix,
        raster_settings.projmatrix,
        raster_settings.tanfovx,
        raster_settings.tanfovy,
        raster_settings.image_height,
        raster_settings.image_width,
        sh,
        raster_settings.sh_degree,
        raster_settings.campos,
        raster_settings.prefiltered,
        raster_settings.debug,
    )


def _invoke_kernel(kernel, args, debug, dump_path, error_message):
    """Call ``kernel(*args)``; in debug mode, dump a CPU snapshot of ``args`` if it raises."""
    if not debug:
        return kernel(*args)
    cpu_args = cpu_deep_copy_tuple(args)  # Copy them before they can be corrupted
    try:
        return kernel(*args)
    except Exception:
        torch.save(cpu_args, dump_path)
        print(error_message)
        raise  # re-raise the original exception unchanged


class _RasterizeGaussians(torch.autograd.Function):
    """Autograd function around the ``_C`` rasterization calls.

    ``forward``/``backward`` form the differentiable path used through
    ``apply``. ``forward_count`` is a plain static helper that calls
    ``_C.count_gaussians`` directly and stores nothing for backward.
    """

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Run the forward call and stash what ``backward`` reads from ``ctx``.

        Returns ``(color, radii)``, or
        ``(gaussians_count, opacity_important_score, T_alpha_important_score, color, radii)``
        when ``raster_settings.f_count`` is truthy.
        """
        # Restructure arguments the way that the C++ lib expects them
        args = _pack_forward_args(
            means3D, sh, colors_precomp, opacities, scales, rotations, cov3Ds_precomp, raster_settings)

        # The counting outputs stay None on the regular path.
        gaussians_count = opacity_important_score = T_alpha_important_score = None

        if raster_settings.f_count:
            (gaussians_count, opacity_important_score, T_alpha_important_score,
             num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer) = _invoke_kernel(
                _C.count_gaussians,
                args + (raster_settings.f_count,),
                raster_settings.debug,
                "snapshot_fw.dump",
                _FORWARD_ERROR_MESSAGE,
            )
        else:
            # Invoke C++/CUDA rasterizer
            num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _invoke_kernel(
                _C.rasterize_gaussians,
                args,
                raster_settings.debug,
                "snapshot_fw.dump",
                _FORWARD_ERROR_MESSAGE,
            )

        # Keep relevant tensors for backward
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
        ctx.count = gaussians_count
        ctx.opacity_important_score = opacity_important_score
        ctx.T_alpha_important_score = T_alpha_important_score

        if raster_settings.f_count:
            # NOTE(review): this returns five values while backward() accepts
            # only two incoming gradients; rasterize_gaussians() routes f_count
            # to forward_count() instead, so confirm whether apply() is ever
            # meant to be used with f_count set.
            return gaussians_count, opacity_important_score, T_alpha_important_score, color, radii

        return color, radii

    @staticmethod
    def forward_count(
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Call ``_C.count_gaussians`` directly, without any autograd context.

        Returns ``(gaussians_count, opacity_important_score,
        T_alpha_important_score, color, radii)``.
        """
        assert raster_settings.f_count
        # Restructure arguments the way that the C++ lib expects them
        args = _pack_forward_args(
            means3D, sh, colors_precomp, opacities, scales, rotations, cov3Ds_precomp, raster_settings,
        ) + (raster_settings.f_count,)

        # Invoke C++/CUDA rasterizer
        # TODO(Kevin): pass the count in, but the output include a count list
        (gaussians_count, opacity_important_score, T_alpha_important_score,
         num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer) = _invoke_kernel(
            _C.count_gaussians,
            args,
            raster_settings.debug,
            "snapshot_fw.dump",
            _FORWARD_ERROR_MESSAGE,
        )

        return gaussians_count, opacity_important_score, T_alpha_important_score, color, radii

    @staticmethod
    def backward(ctx, grad_out_color, _):
        """Compute input gradients via ``_C.rasterize_gaussians_backward``.

        Returns one entry per ``forward`` input (after ``ctx``); the last
        entry, for ``raster_settings``, is ``None``.
        """
        # Restore necessary values from context
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors

        # Restructure args as C++ method expects them
        args = (
            raster_settings.bg,
            means3D,
            radii,
            colors_precomp,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            grad_out_color,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            geomBuffer,
            num_rendered,
            binningBuffer,
            imgBuffer,
            raster_settings.debug,
        )

        # Compute gradients for relevant tensors by invoking backward method
        (grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D,
         grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations) = _invoke_kernel(
            _C.rasterize_gaussians_backward,
            args,
            raster_settings.debug,
            "snapshot_bw.dump",
            _BACKWARD_ERROR_MESSAGE,
        )

        # Order matches forward's inputs: means3D, means2D, sh, colors_precomp,
        # opacities, scales, rotations, cov3Ds_precomp, raster_settings.
        return (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            None,
        )
246
+
247
class GaussianRasterizationSettings(NamedTuple):
    """Immutable bundle of per-render settings read by the rasterizer wrappers in this module."""

    image_height: int
    image_width: int
    tanfovx: float
    tanfovy: float
    bg: torch.Tensor
    scale_modifier: float
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor
    sh_degree: int
    campos: torch.Tensor
    prefiltered: bool
    # When truthy, kernel inputs are copied to CPU before each call and dumped on failure.
    debug: bool
    # When truthy, rasterize_gaussians takes the counting path.
    f_count: bool
261
+
262
class GaussianRasterizer(nn.Module):
    """Module front-end that validates its inputs and hands them to ``rasterize_gaussians``."""

    def __init__(self, raster_settings):
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Return ``_C.mark_visible`` for ``positions`` under the stored view/projection matrices."""
        # Mark visible points (based on frustum culling for camera) with a boolean
        with torch.no_grad():
            settings = self.raster_settings
            return _C.mark_visible(positions, settings.viewmatrix, settings.projmatrix)

    @staticmethod
    def _or_empty(value):
        """Substitute a fresh empty tensor for an omitted optional input."""
        return torch.Tensor([]) if value is None else value

    def _checked_rasterize(self, means3D, means2D, opacities, shs, colors_precomp, scales, rotations, cov3D_precomp):
        """Shared body of ``forward`` and ``forward_count``: validate, fill gaps, rasterize."""
        settings = self.raster_settings

        has_shs = shs is not None
        has_colors = colors_precomp is not None
        if has_shs == has_colors:
            raise Exception('Please provide excatly one of either SHs or precomputed colors!')

        has_cov = cov3D_precomp is not None
        has_full_pair = scales is not None and rotations is not None
        has_any_of_pair = scales is not None or rotations is not None
        if (not has_full_pair and not has_cov) or (has_any_of_pair and has_cov):
            raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        # Invoke C++/CUDA rasterization routine
        return rasterize_gaussians(
            means3D,
            means2D,
            self._or_empty(shs),
            self._or_empty(colors_precomp),
            opacities,
            self._or_empty(scales),
            self._or_empty(rotations),
            self._or_empty(cov3D_precomp),
            settings,
        )

    def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
        """Validate the optional inputs and rasterize with the stored settings.

        Raises ``Exception`` unless exactly one of ``shs``/``colors_precomp`` and
        exactly one of the (``scales``, ``rotations``) pair/``cov3D_precomp`` is given.
        """
        return self._checked_rasterize(
            means3D, means2D, opacities, shs, colors_precomp, scales, rotations, cov3D_precomp)

    def forward_count(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
        """Same validation and call as ``forward``; kept as a separate entry point."""
        return self._checked_rasterize(
            means3D, means2D, opacities, shs, colors_precomp, scales, rotations, cov3D_precomp)
347
+
@@ -0,0 +1,141 @@
1
+ import math
2
+ import torch
3
+
4
+ from gaussian_splatting import Camera, GaussianModel
5
+ from gaussian_splatting.trainer import AbstractTrainer, TrainerWrapper, BaseTrainer, Trainer
6
+ from gaussian_splatting.dataset import CameraDataset
7
+ from .diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
8
+
9
+
10
def count_render(self: GaussianModel, viewpoint_camera: Camera):
    """
    Render the scene from ``viewpoint_camera`` with the counting rasterizer.

    The camera's background colour is moved to the device of ``self._xyz``
    before being handed to the rasterizer.

    Returns a dict with the keys ``render``, ``viewspace_points``,
    ``visibility_filter`` (``radii > 0``), ``radii``, ``gaussians_count``,
    ``opacity_important_score`` and ``T_alpha_important_score``.
    """
    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(self.get_xyz, dtype=self.get_xyz.dtype, requires_grad=True, device=self._xyz.device) + 0
    try:
        screenspace_points.retain_grad()
    except Exception:
        # Best effort only: rendering must still work when the gradient cannot
        # be retained. Unlike a bare ``except:``, this lets KeyboardInterrupt
        # and SystemExit propagate.
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=viewpoint_camera.bg_color.to(self._xyz.device),
        scale_modifier=self.scale_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=self.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=self.debug,
        f_count=True,  # select the counting path of the rasterizer
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # Rasterize visible Gaussians to image, obtain their radii (on screen).
    gaussians_count, opacity_important_score, T_alpha_important_score, rendered_image, radii = rasterizer(
        means3D=self.get_xyz,
        means2D=screenspace_points,
        shs=self.get_features,
        colors_precomp=None,
        opacities=self.get_opacity,
        scales=self.get_scaling,
        rotations=self.get_rotation,
        cov3D_precomp=None)

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    return {
        "render": rendered_image,
        "viewspace_points": screenspace_points,
        "visibility_filter": radii > 0,
        "radii": radii,
        "gaussians_count": gaussians_count,
        "opacity_important_score": opacity_important_score,
        "T_alpha_important_score": T_alpha_important_score
    }
75
+
76
+
77
class ImportancePruner(TrainerWrapper):
    """Trainer wrapper that, at selected steps, renders every camera with
    ``count_render`` and sums the per-Gaussian statistics it reports."""

    def __init__(
            self, base_trainer: AbstractTrainer,
            dataset: CameraDataset,
            importance_prune_at_steps=None,
    ):
        """
        Args:
            base_trainer: trainer being wrapped.
            dataset: cameras iterated when collecting statistics.
            importance_prune_at_steps: steps at which statistics are collected;
                defaults to ``[15000]``. ``None`` is used as the sentinel so no
                list object is shared between instances.
        """
        super().__init__(base_trainer)
        self.dataset = dataset
        self.importance_prune_at_steps = [15000] if importance_prune_at_steps is None else importance_prune_at_steps

    def optim_step(self):
        """Run the wrapped optimisation step, then collect statistics if this step is selected."""
        ret = super().optim_step()
        if self.curr_step in self.importance_prune_at_steps:
            xyz = self.model.get_xyz
            n_gaussians, device = xyz.shape[0], xyz.device
            gaussian_count = torch.zeros(n_gaussians, device=device, dtype=torch.int)
            opacity_important_score = torch.zeros(n_gaussians, device=device, dtype=torch.float)
            T_alpha_important_score = torch.zeros(n_gaussians, device=device, dtype=torch.float)
            for camera in self.dataset:
                out = count_render(self.model, camera)
                gaussian_count += out["gaussians_count"]
                opacity_important_score += out["opacity_important_score"]
                T_alpha_important_score += out["T_alpha_important_score"]
            # NOTE(review): the accumulated totals are not consumed yet; no
            # pruning decision is derived from them in this version.
        return ret
100
+
101
+
102
def ImportancePruningTrainerWrapper(
        base_trainer_constructor,
        model: GaussianModel,
        scene_extent: float,
        dataset: CameraDataset,
        importance_prune_at_steps=None,
        *args, **kwargs):
    """Wrap the trainer built by ``base_trainer_constructor`` in an ``ImportancePruner``.

    ``base_trainer_constructor`` is called as
    ``base_trainer_constructor(model, scene_extent, dataset, *args, **kwargs)``.
    ``importance_prune_at_steps`` defaults to ``[15000]``; the default is
    resolved here, so no mutable list is shared through the signature.
    """
    if importance_prune_at_steps is None:
        importance_prune_at_steps = [15000]
    return ImportancePruner(
        base_trainer_constructor(model, scene_extent, dataset, *args, **kwargs),
        dataset,
        importance_prune_at_steps=importance_prune_at_steps,
    )
114
+
115
+
116
def BaseImportancePruningTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: CameraDataset,
        importance_prune_at_steps=None,
        *args, **kwargs):
    """Build a ``BaseTrainer`` wrapped by ``ImportancePruningTrainerWrapper``.

    ``dataset`` is used by the importance wrapper only; it is not forwarded to
    ``BaseTrainer``. ``importance_prune_at_steps`` defaults to ``[15000]``.
    """
    if importance_prune_at_steps is None:
        importance_prune_at_steps = [15000]
    return ImportancePruningTrainerWrapper(
        lambda model, scene_extent, dataset, *args, **kwargs: BaseTrainer(model, scene_extent, *args, **kwargs),
        model, scene_extent, dataset,
        # Passed positionally: as a keyword it would collide with the first
        # extra positional argument ("got multiple values for argument").
        importance_prune_at_steps,
        *args, **kwargs,
    )
128
+
129
+
130
def ImportancePruningTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: CameraDataset,
        importance_prune_at_steps=None,
        *args, **kwargs):
    """Build a ``Trainer`` wrapped by ``ImportancePruningTrainerWrapper``.

    ``dataset`` is used by the importance wrapper only; it is not forwarded to
    ``Trainer``. ``importance_prune_at_steps`` defaults to ``[15000]``.
    """
    if importance_prune_at_steps is None:
        importance_prune_at_steps = [15000]
    return ImportancePruningTrainerWrapper(
        lambda model, scene_extent, dataset, *args, **kwargs: Trainer(model, scene_extent, *args, **kwargs),
        model, scene_extent, dataset,
        # Passed positionally: as a keyword it would collide with the first
        # extra positional argument ("got multiple values for argument").
        importance_prune_at_steps,
        *args, **kwargs,
    )
@@ -1,7 +1,7 @@
1
1
  from typing import List
2
2
  import torch
3
3
  from gaussian_splatting import GaussianModel, Camera
4
- from gaussian_splatting.trainer import AbstractDensifier, Densifier, DensificationInstruct, DensificationTrainer
4
+ from gaussian_splatting.trainer import AbstractDensifier, DensifierWrapper, DensificationTrainer, NoopDensifier
5
5
  from reduced_3dgs.diff_gaussian_rasterization._C import sphere_ellipsoid_intersection, allocate_minimum_redundancy_value, find_minimum_projected_pixel_size
6
6
  from reduced_3dgs.simple_knn._C import distIndex2
7
7
 
@@ -79,9 +79,10 @@ def mercy_gaussians(
79
79
  return mask
80
80
 
81
81
 
82
- class BasePruner(AbstractDensifier):
82
+ class BasePruner(DensifierWrapper):
83
83
  def __init__(
84
- self, model: GaussianModel, dataset: List[Camera],
84
+ self, base_densifier: AbstractDensifier,
85
+ dataset: List[Camera],
85
86
  prune_from_iter=1000,
86
87
  prune_until_iter=15000,
87
88
  prune_interval: int = 100,
@@ -89,7 +90,7 @@ class BasePruner(AbstractDensifier):
89
90
  lambda_mercy=1.,
90
91
  mercy_minimum=3,
91
92
  mercy_type='redundancy_opacity'):
92
- self._model = model
93
+ super().__init__(base_densifier)
93
94
  self.dataset = dataset
94
95
  self.prune_from_iter = prune_from_iter
95
96
  self.prune_until_iter = prune_until_iter
@@ -99,14 +100,12 @@ class BasePruner(AbstractDensifier):
99
100
  self.mercy_minimum = mercy_minimum
100
101
  self.mercy_type = mercy_type
101
102
 
102
- @property
103
- def model(self) -> GaussianModel:
104
- return self._model
105
-
106
- def densify_and_prune(self, loss, out, camera, step: int) -> DensificationInstruct:
103
+ def densify_and_prune(self, loss, out, camera, step: int):
104
+ ret = super().densify_and_prune(loss, out, camera, step)
107
105
  if self.prune_from_iter <= step <= self.prune_until_iter and step % self.prune_interval == 0:
108
- return DensificationInstruct(remove_mask=mercy_gaussians(self.model, self.dataset, self.box_size, self.lambda_mercy, self.mercy_minimum, self.mercy_type))
109
- return DensificationInstruct()
106
+ remove_mask = mercy_gaussians(self.model, self.dataset, self.box_size, self.lambda_mercy, self.mercy_minimum, self.mercy_type)
107
+ ret = ret._replace(remove_mask=remove_mask if ret.remove_mask is None else torch.logical_or(remove_mask, ret.remove_mask))
108
+ return ret
110
109
 
111
110
 
112
111
  def BasePruningTrainer(
@@ -124,73 +123,14 @@ def BasePruningTrainer(
124
123
  return DensificationTrainer(
125
124
  model, scene_extent,
126
125
  BasePruner(
127
- model, dataset,
128
- prune_from_iter, prune_until_iter, prune_interval,
129
- box_size, lambda_mercy, mercy_minimum, mercy_type
130
- ), *args, **kwargs
131
- )
132
-
133
-
134
- class PrunerInDensify(Densifier):
135
- def __init__(
136
- self, model: GaussianModel, scene_extent, dataset: List[Camera],
137
- box_size=1.,
138
- lambda_mercy=1.,
139
- mercy_minimum=3,
140
- mercy_type='redundancy_opacity',
141
- *args, **kwargs):
142
- super().__init__(model, scene_extent, *args, **kwargs)
143
- self.dataset = dataset
144
- self.box_size = box_size
145
- self.lambda_mercy = lambda_mercy
146
- self.mercy_minimum = mercy_minimum
147
- self.mercy_type = mercy_type
148
-
149
- def prune(self) -> torch.Tensor:
150
- return torch.logical_or(mercy_gaussians(self.model, self.dataset, self.box_size, self.lambda_mercy, self.mercy_minimum, self.mercy_type), super().prune())
151
-
152
-
153
- def BasePrunerInDensifyTrainer(
154
- model: GaussianModel,
155
- scene_extent: float,
156
-
157
- dataset: List[Camera],
158
- box_size=1.,
159
- lambda_mercy=1.,
160
- mercy_minimum=3,
161
- mercy_type='redundancy_opacity',
162
-
163
- densify_from_iter=500,
164
- densify_until_iter=15000,
165
- densify_interval=100,
166
- densify_grad_threshold=0.0002,
167
- densify_opacity_threshold=0.005,
168
- densify_percent_dense=0.01,
169
- densify_percent_too_big=0.8,
170
-
171
- prune_from_iter=1000,
172
- prune_until_iter=15000,
173
- prune_interval=100,
174
- prune_screensize_threshold=20,
175
- prune_percent_too_big=1,
176
-
177
- *args, **kwargs):
178
- return DensificationTrainer(
179
- model, scene_extent,
180
- PrunerInDensify(
181
- model, scene_extent, dataset,
182
- box_size, lambda_mercy, mercy_minimum, mercy_type,
183
- densify_from_iter=densify_from_iter,
184
- densify_until_iter=densify_until_iter,
185
- densify_interval=densify_interval,
186
- densify_grad_threshold=densify_grad_threshold,
187
- densify_opacity_threshold=densify_opacity_threshold,
188
- densify_percent_dense=densify_percent_dense,
189
- densify_percent_too_big=densify_percent_too_big,
126
+ NoopDensifier(model),
127
+ dataset,
190
128
  prune_from_iter=prune_from_iter,
191
129
  prune_until_iter=prune_until_iter,
192
130
  prune_interval=prune_interval,
193
- prune_screensize_threshold=prune_screensize_threshold,
194
- prune_percent_too_big=prune_percent_too_big
131
+ box_size=box_size,
132
+ lambda_mercy=lambda_mercy,
133
+ mercy_minimum=mercy_minimum,
134
+ mercy_type=mercy_type,
195
135
  ), *args, **kwargs
196
136
  )
reduced_3dgs/train.py CHANGED
@@ -14,6 +14,7 @@ from gaussian_splatting.trainer.extensions import ScaleRegularizeTrainerWrapper
14
14
  from reduced_3dgs.quantization import AbstractQuantizer, VectorQuantizeTrainerWrapper
15
15
  from reduced_3dgs.shculling import VariableSHGaussianModel, SHCullingTrainer
16
16
  from reduced_3dgs.pruning import PruningTrainer
17
+ from reduced_3dgs.pruning.importance import ImportancePruningTrainerWrapper
17
18
  from reduced_3dgs.combinations import OpacityResetPrunerInDensifyTrainer, SHCullingDensifyTrainer, SHCullingPruneTrainer, SHCullingPruningDensifyTrainer
18
19
  from reduced_3dgs.combinations import CameraTrainableVariableSHGaussianModel, CameraSHCullingTrainer, CameraPruningTrainer
19
20
  from reduced_3dgs.combinations import CameraOpacityResetPrunerInDensifyTrainer, CameraSHCullingDensifyTrainer, CameraSHCullingPruneTrainer, CameraSHCullingPruningDensifyTrainer
@@ -27,6 +28,7 @@ basemodes = {
27
28
  "prune-shculling": SHCullingPruneTrainer,
28
29
  "densify-prune-shculling": SHCullingPruningDensifyTrainer,
29
30
  }
31
# Bind each trainer constructor at definition time via a default argument.
# A plain closure over the comprehension variable is late-binding, so every
# mode would end up wrapping the last constructor in the dict.
basemodes = {
    k: (lambda *args, _base=v, **kwargs: ImportancePruningTrainerWrapper(_base, *args, **kwargs))
    for k, v in basemodes.items()
}
30
32
  cameramodes = {
31
33
  "camera-shculling": CameraSHCullingTrainer,
32
34
  "camera-pruning": CameraPruningTrainer,
@@ -35,6 +37,7 @@ cameramodes = {
35
37
  "camera-prune-shculling": CameraSHCullingPruneTrainer,
36
38
  "camera-densify-prune-shculling": CameraSHCullingPruningDensifyTrainer,
37
39
  }
40
# Bind each trainer constructor at definition time via a default argument.
# A plain closure over the comprehension variable is late-binding, so every
# mode would end up wrapping the last constructor in the dict.
cameramodes = {
    k: (lambda *args, _base=v, **kwargs: ImportancePruningTrainerWrapper(_base, *args, **kwargs))
    for k, v in cameramodes.items()
}
38
41
 
39
42
 
40
43
  def prepare_quantizer(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: reduced_3dgs
3
- Version: 1.8.15
3
+ Version: 1.8.17
4
4
  Summary: Refactored code for the paper "Reducing the Memory Footprint of 3D Gaussian Splatting"
5
5
  Home-page: https://github.com/yindaheng98/reduced-3dgs
6
6
  Author: yindaheng98
@@ -0,0 +1,27 @@
1
+ reduced_3dgs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ reduced_3dgs/combinations.py,sha256=YDDNRGQlOuLsXnOOMiZ6mqG0cx1ICtG7jz6ObVx4OUo,4794
3
+ reduced_3dgs/quantize.py,sha256=Y44qHyFdOIqke7NoeqXmyKloS43j-al74ZiNsuZZHbM,2527
4
+ reduced_3dgs/train.py,sha256=VKz55aF-Pq-cs9KSa1UHwwpjcy1oTQPFWPhsGbAZHp4,10158
5
+ reduced_3dgs/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd,sha256=zoV665wa2AxMAID1Rz35ns35h5ASHPY6fSPcUbYaxvY,1618432
6
+ reduced_3dgs/diff_gaussian_rasterization/__init__.py,sha256=oV6JjTc-50MscX4XHeIWSgLr3l8Y25knBIs-0gRbJr4,7932
7
+ reduced_3dgs/pruning/__init__.py,sha256=ZhjYTfHUFlBoQ9NROTtoRKSACaFxU-wb5rd9EVLPnhk,147
8
+ reduced_3dgs/pruning/combinations.py,sha256=7MdBWls_ibtr8VGNXZnMQHBMl7buO6yWxaADFvH8K_M,1785
9
+ reduced_3dgs/pruning/trainer.py,sha256=bSrxEOztamJ3ap0PJI8jahRN4mGVPJWwr53vXbfRj4s,6450
10
+ reduced_3dgs/pruning/importance/__init__.py,sha256=MIYdZ0Gx_2WTNTpy8qylNaNxvMG5J8paEnqUrPNAvSw,110
11
+ reduced_3dgs/pruning/importance/trainer.py,sha256=LMy2Y_Br91lrhHtwO7NJSRqaju-ZpIofN5H1LdxQYq0,5361
12
+ reduced_3dgs/pruning/importance/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd,sha256=c8VNpWSd---FOjTu98PF1wI86kRUK_elpDFy_1XLBbU,1308672
13
+ reduced_3dgs/pruning/importance/diff_gaussian_rasterization/__init__.py,sha256=Tix8auyXBb_QFQtXrV3sLE9kdnl5zgHH0BbqcFzDp84,12850
14
+ reduced_3dgs/quantization/__init__.py,sha256=1z1xMn3yj9u7cR9JizGrI3WSyIES_Tqq6oDquvglSeo,225
15
+ reduced_3dgs/quantization/abc.py,sha256=rsi8HFRwQCltWTYiJ3BpygtQDT7hK6J01jKMOboOY8w,1910
16
+ reduced_3dgs/quantization/exclude_zeros.py,sha256=fKSgjHous4OpdI6mQi9z23if9jnbB79w2jChpxkCJWw,2381
17
+ reduced_3dgs/quantization/quantizer.py,sha256=-B07X1VrV7E1fT_NXXlZ0FLAb-c_jBUlZ8274KwHyVw,16614
18
+ reduced_3dgs/quantization/wrapper.py,sha256=cyXqfJgo9b3fS7DYXxOk5LmQudvrEhweOebFsjRnXiQ,2549
19
+ reduced_3dgs/shculling/__init__.py,sha256=nP2BejDCUdCmJNRbg0hfhHREO6jyZXwIcRiw6ttVgqo,149
20
+ reduced_3dgs/shculling/gaussian_model.py,sha256=f8QWaL09vaV9Tcf6Dngjg_Fmk1wTQPAjWhuhI_N02Y8,2877
21
+ reduced_3dgs/shculling/trainer.py,sha256=9hwR77djhZpyf-URhwKHjnLbe0ZAOS-DIw58RzkcHXQ,6369
22
+ reduced_3dgs/simple_knn/_C.cp310-win_amd64.pyd,sha256=_NQNjRHGEgAPEn-k664JdF1-Owu16qkyKGcmC9QNeww,1255424
23
+ reduced_3dgs-1.8.17.dist-info/licenses/LICENSE.md,sha256=LQ4_LAqlncGkg_mQy5ykMAFtQDSPB0eKmIEtBut0yjw,4916
24
+ reduced_3dgs-1.8.17.dist-info/METADATA,sha256=pwG3GbJX9EUJaIY5Yp9bDLeLPXKAwDFmKX_t0VBYqsE,13015
25
+ reduced_3dgs-1.8.17.dist-info/WHEEL,sha256=f7LviPjlKeAUpCmKrQrAXn06zBZT0604QctSZKHK6ZM,101
26
+ reduced_3dgs-1.8.17.dist-info/top_level.txt,sha256=PpU5aT3-baSCdqCtTaZknoB32H93UeKCkYDkRCCZMEI,13
27
+ reduced_3dgs-1.8.17.dist-info/RECORD,,
@@ -1,23 +0,0 @@
1
- reduced_3dgs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- reduced_3dgs/combinations.py,sha256=YDDNRGQlOuLsXnOOMiZ6mqG0cx1ICtG7jz6ObVx4OUo,4794
3
- reduced_3dgs/quantize.py,sha256=Y44qHyFdOIqke7NoeqXmyKloS43j-al74ZiNsuZZHbM,2527
4
- reduced_3dgs/train.py,sha256=WdxDg2MC4HCwIaGta3c7oYVPA9gb6CMqHLDD6WeOwEk,9829
5
- reduced_3dgs/diff_gaussian_rasterization/_C.cp310-win_amd64.pyd,sha256=rKopFy0xnrDeJv6BGVT5O5C7Q6fvYbMBTQB0w3ltLaU,1618432
6
- reduced_3dgs/diff_gaussian_rasterization/__init__.py,sha256=oV6JjTc-50MscX4XHeIWSgLr3l8Y25knBIs-0gRbJr4,7932
7
- reduced_3dgs/pruning/__init__.py,sha256=T6vvO5RFMFUfHl2k7kleZ_AGgDwyMBfc7jxNJ-CRULA,164
8
- reduced_3dgs/pruning/combinations.py,sha256=ZaqvmmWsl1pJJZbA_yJOJVbv-MoQHLO8Lts3KT-5r_Q,833
9
- reduced_3dgs/pruning/trainer.py,sha256=5mOfYPEO6ZOW5ZSeI33LoN4EmGYxWKz2pMQw8RMRiEU,8530
10
- reduced_3dgs/quantization/__init__.py,sha256=1z1xMn3yj9u7cR9JizGrI3WSyIES_Tqq6oDquvglSeo,225
11
- reduced_3dgs/quantization/abc.py,sha256=rsi8HFRwQCltWTYiJ3BpygtQDT7hK6J01jKMOboOY8w,1910
12
- reduced_3dgs/quantization/exclude_zeros.py,sha256=fKSgjHous4OpdI6mQi9z23if9jnbB79w2jChpxkCJWw,2381
13
- reduced_3dgs/quantization/quantizer.py,sha256=-B07X1VrV7E1fT_NXXlZ0FLAb-c_jBUlZ8274KwHyVw,16614
14
- reduced_3dgs/quantization/wrapper.py,sha256=cyXqfJgo9b3fS7DYXxOk5LmQudvrEhweOebFsjRnXiQ,2549
15
- reduced_3dgs/shculling/__init__.py,sha256=nP2BejDCUdCmJNRbg0hfhHREO6jyZXwIcRiw6ttVgqo,149
16
- reduced_3dgs/shculling/gaussian_model.py,sha256=f8QWaL09vaV9Tcf6Dngjg_Fmk1wTQPAjWhuhI_N02Y8,2877
17
- reduced_3dgs/shculling/trainer.py,sha256=9hwR77djhZpyf-URhwKHjnLbe0ZAOS-DIw58RzkcHXQ,6369
18
- reduced_3dgs/simple_knn/_C.cp310-win_amd64.pyd,sha256=KabJeO-Ax9RSS0cI_W3OoIl_oI0yawmnfQfpgsT6MPw,1255424
19
- reduced_3dgs-1.8.15.dist-info/licenses/LICENSE.md,sha256=LQ4_LAqlncGkg_mQy5ykMAFtQDSPB0eKmIEtBut0yjw,4916
20
- reduced_3dgs-1.8.15.dist-info/METADATA,sha256=0O3KJo4RTyBSVD47eD6gGTGg3MgxDSYAJLDs_-vMm3U,13015
21
- reduced_3dgs-1.8.15.dist-info/WHEEL,sha256=f7LviPjlKeAUpCmKrQrAXn06zBZT0604QctSZKHK6ZM,101
22
- reduced_3dgs-1.8.15.dist-info/top_level.txt,sha256=PpU5aT3-baSCdqCtTaZknoB32H93UeKCkYDkRCCZMEI,13
23
- reduced_3dgs-1.8.15.dist-info/RECORD,,