reduced-3dgs 1.8.19__tar.gz → 1.9.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reduced-3dgs might be problematic; see the registry's release page for more details.

Files changed (48)
  1. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/PKG-INFO +1 -1
  2. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/importance/combinations.py +19 -3
  3. reduced_3dgs-1.9.1/reduced_3dgs/importance/trainer.py +241 -0
  4. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs.egg-info/PKG-INFO +1 -1
  5. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/setup.py +1 -1
  6. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/gaussian-importance/cuda_rasterizer/forward.cu +3 -3
  7. reduced_3dgs-1.8.19/reduced_3dgs/importance/trainer.py +0 -130
  8. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/LICENSE.md +0 -0
  9. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/README.md +0 -0
  10. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/__init__.py +0 -0
  11. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/combinations.py +0 -0
  12. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/importance/__init__.py +0 -0
  13. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/pruning/__init__.py +0 -0
  14. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/pruning/combinations.py +0 -0
  15. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/pruning/trainer.py +0 -0
  16. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/quantization/__init__.py +0 -0
  17. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/quantization/abc.py +0 -0
  18. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/quantization/exclude_zeros.py +0 -0
  19. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/quantization/quantizer.py +0 -0
  20. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/quantization/wrapper.py +0 -0
  21. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/quantize.py +0 -0
  22. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/shculling/__init__.py +0 -0
  23. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/shculling/gaussian_model.py +0 -0
  24. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/shculling/trainer.py +0 -0
  25. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs/train.py +0 -0
  26. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs.egg-info/SOURCES.txt +0 -0
  27. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs.egg-info/dependency_links.txt +0 -0
  28. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs.egg-info/requires.txt +0 -0
  29. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/reduced_3dgs.egg-info/top_level.txt +0 -0
  30. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/setup.cfg +0 -0
  31. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu +0 -0
  32. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu +0 -0
  33. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu +0 -0
  34. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py +0 -0
  35. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/diff-gaussian-rasterization/ext.cpp +0 -0
  36. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/diff-gaussian-rasterization/rasterize_points.cu +0 -0
  37. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/diff-gaussian-rasterization/reduced_3dgs/kmeans.cu +0 -0
  38. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/diff-gaussian-rasterization/reduced_3dgs/redundancy_score.cu +0 -0
  39. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/diff-gaussian-rasterization/reduced_3dgs/sh_culling.cu +0 -0
  40. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/diff-gaussian-rasterization/reduced_3dgs.cu +0 -0
  41. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/gaussian-importance/cuda_rasterizer/backward.cu +0 -0
  42. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/gaussian-importance/cuda_rasterizer/rasterizer_impl.cu +0 -0
  43. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/gaussian-importance/diff_gaussian_rasterization/__init__.py +0 -0
  44. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/gaussian-importance/ext.cpp +0 -0
  45. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/gaussian-importance/rasterize_points.cu +0 -0
  46. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/simple-knn/ext.cpp +0 -0
  47. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/simple-knn/simple_knn.cu +0 -0
  48. {reduced_3dgs-1.8.19 → reduced_3dgs-1.9.1}/submodules/simple-knn/spatial.cu +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: reduced_3dgs
3
- Version: 1.8.19
3
+ Version: 1.9.1
4
4
  Summary: Refactored code for the paper "Reducing the Memory Footprint of 3D Gaussian Splatting"
5
5
  Home-page: https://github.com/yindaheng98/reduced-3dgs
6
6
  Author: yindaheng98
@@ -10,9 +10,17 @@ def BaseImportancePrunerInDensifyTrainer(
10
10
  scene_extent: float,
11
11
  dataset: List[Camera],
12
12
  *args,
13
- importance_prune_from_iter=1000,
14
- importance_prune_until_iter=15000,
15
- importance_prune_interval=100,
13
+ importance_prune_from_iter=15000,
14
+ importance_prune_until_iter=20000,
15
+ importance_prune_interval: int = 1000,
16
+ importance_prune_type="comprehensive",
17
+ importance_prune_percent=0.1,
18
+ importance_prune_thr_important_score=None,
19
+ importance_prune_thr_v_important_score=1.0,
20
+ importance_prune_thr_max_v_important_score=None,
21
+ importance_prune_thr_count=1,
22
+ importance_prune_thr_T_alpha=0.01,
23
+ importance_v_pow=0.1,
16
24
  **kwargs):
17
25
  return DensificationTrainerWrapper(
18
26
  lambda model, scene_extent: ImportancePruner(
@@ -21,6 +29,14 @@ def BaseImportancePrunerInDensifyTrainer(
21
29
  importance_prune_from_iter=importance_prune_from_iter,
22
30
  importance_prune_until_iter=importance_prune_until_iter,
23
31
  importance_prune_interval=importance_prune_interval,
32
+ importance_prune_type=importance_prune_type,
33
+ importance_prune_percent=importance_prune_percent,
34
+ importance_prune_thr_important_score=importance_prune_thr_important_score,
35
+ importance_prune_thr_v_important_score=importance_prune_thr_v_important_score,
36
+ importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
37
+ importance_prune_thr_count=importance_prune_thr_count,
38
+ importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
39
+ importance_v_pow=importance_v_pow,
24
40
  ),
25
41
  model,
26
42
  scene_extent,
@@ -0,0 +1,241 @@
1
+ import math
2
+ from typing import Callable, List
3
+ import torch
4
+
5
+ from gaussian_splatting import Camera, GaussianModel
6
+ from gaussian_splatting.trainer import AbstractDensifier, DensifierWrapper, DensificationTrainer, NoopDensifier
7
+ from gaussian_splatting.dataset import CameraDataset
8
+ from .diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
9
+
10
+
11
def count_render(self: GaussianModel, viewpoint_camera: Camera):
    """
    Rasterize the model from one camera with per-Gaussian counting enabled.

    The counting rasterizer (``f_count=True``) returns, in addition to the
    rendered image and radii, three per-Gaussian statistics that the
    importance pruner accumulates across cameras. The camera's background
    colour is moved onto the model's device before rasterization.

    :param self: Gaussian model to render (this is a free function taking the
        model explicitly).
    :param viewpoint_camera: Camera supplying FoV, image size, view/projection
        matrices, camera centre and background colour.
    :return: dict with keys "render", "viewspace_points", "visibility_filter",
        "radii", "gaussians_count", "opacity_important_score" and
        "T_alpha_important_score".
    """
    xyz = self.get_xyz
    device = self._xyz.device

    # Zero tensor through which autograd can expose gradients of the 2D
    # (screen-space) means.
    screenspace_points = torch.zeros_like(xyz, dtype=xyz.dtype, requires_grad=True, device=device) + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # With grad mode disabled (e.g. inside torch.no_grad()) the "+ 0"
        # result does not require grad and retain_grad() raises; there is
        # nothing to retain in that case.
        pass

    # Set up rasterization configuration.
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=viewpoint_camera.bg_color.to(device),
        scale_modifier=self.scale_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=self.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=self.debug,
        f_count=True,
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # Rasterize visible Gaussians to the image and collect counting statistics.
    gaussians_count, opacity_important_score, T_alpha_important_score, rendered_image, radii = rasterizer(
        means3D=xyz,
        means2D=screenspace_points,
        shs=self.get_features,
        colors_precomp=None,
        opacities=self.get_opacity,
        scales=self.get_scaling,
        rotations=self.get_rotation,
        cov3D_precomp=None)

    # "visibility_filter" marks Gaussians with a positive screen-space radius;
    # frustum-culled or zero-radius Gaussians were not visible from this camera.
    return {
        "render": rendered_image,
        "viewspace_points": screenspace_points,
        "visibility_filter": radii > 0,
        "radii": radii,
        "gaussians_count": gaussians_count,
        "opacity_important_score": opacity_important_score,
        "T_alpha_important_score": T_alpha_important_score
    }


def prune_list(model: GaussianModel, dataset: CameraDataset):
    """
    Accumulate per-Gaussian importance statistics over every camera in ``dataset``.

    The statistics are only used to build a boolean pruning mask, so the
    renders and running sums are computed under ``torch.no_grad()``. This
    avoids recording autograd history for each per-camera render and for the
    accumulators.

    :param model: Gaussian model to render.
    :param dataset: Iterable of cameras; each is rendered once with ``count_render``.
    :return: Tuple ``(gaussian_count, opacity_important_score,
        T_alpha_important_score)``. Each is a 1-D tensor with one entry per
        Gaussian, on the model's device: ``torch.int`` for the count and
        ``torch.float`` for the two scores. All entries are zero if
        ``dataset`` is empty.
    """
    num_points = model.get_xyz.shape[0]
    device = model.get_xyz.device
    with torch.no_grad():
        gaussian_count = torch.zeros(num_points, device=device, dtype=torch.int)
        opacity_important_score = torch.zeros(num_points, device=device, dtype=torch.float)
        T_alpha_important_score = torch.zeros(num_points, device=device, dtype=torch.float)
        for camera in dataset:
            out = count_render(model, camera)
            gaussian_count += out["gaussians_count"]
            opacity_important_score += out["opacity_important_score"]
            T_alpha_important_score += out["T_alpha_important_score"]
    return gaussian_count, opacity_important_score, T_alpha_important_score


# Return importance scores re-weighted by the adaptive volume measure described in the paper.
def calculate_v_imp_score(gaussians: GaussianModel, imp_list, v_pow):
    """
    Weight per-Gaussian importance by a normalized volume term.

    Each Gaussian's volume is the product of its scales (over dim 1 of
    ``get_scaling``). Volumes are divided by a reference volume and raised to
    ``v_pow``, and the result is multiplied by ``imp_list``. The reference
    volume is the entry at position ``int(0.9 * N)`` of the volumes sorted in
    descending order, so roughly 90% of Gaussians are at least that large.

    :param gaussians: Object exposing ``get_scaling`` (per-Gaussian scale vectors).
    :param imp_list: Per-Gaussian importance scores (a tensor multiplied element-wise).
    :param v_pow: Exponent applied to the volume ratios.
    :return: Tensor of adjusted scores, one per Gaussian. It is empty when
        there are no Gaussians.
    """
    volume = torch.prod(gaussians.get_scaling, dim=1)
    if volume.numel() == 0:
        # No Gaussians: there is no reference volume to index into.
        return volume * imp_list
    reference_index = int(len(volume) * 0.9)
    sorted_volume, _ = torch.sort(volume, descending=True)
    reference_volume = sorted_volume[reference_index]
    v_list = torch.pow(volume / reference_volume, v_pow)
    return v_list * imp_list


def score2mask(percent, import_score: torch.Tensor, threshold=None):
    """
    Return a boolean mask marking scores at or below a cut-off.

    The cut-off is the value at index ``int(percent * (N - 1))`` of the scores
    sorted in ascending order. When ``threshold`` is given and is smaller,
    it is used as the cut-off instead.

    :param percent: Fraction in [0, 1] selecting the percentile used as cut-off.
    :param import_score: 1-D tensor of per-Gaussian scores.
    :param threshold: Optional absolute upper bound on the cut-off.
    :return: Boolean tensor shaped like ``import_score``; True marks entries to prune.
    :raises ValueError: If ``percent`` lies outside [0, 1]. A negative value
        would otherwise index from the top of the sorted scores.
    """
    if not 0.0 <= percent <= 1.0:
        raise ValueError(f"percent must be within [0, 1], got {percent!r}")
    if import_score.numel() == 0:
        # Nothing to rank; indexing the empty sorted tensor would fail.
        return torch.zeros_like(import_score, dtype=torch.bool)
    sorted_scores, _ = torch.sort(import_score, dim=0)
    cutoff = sorted_scores[int(percent * (sorted_scores.shape[0] - 1))]
    if threshold is not None:
        cutoff = min(threshold, cutoff)
    return import_score <= cutoff


+ def prune_gaussians(
120
+ gaussians: GaussianModel, dataset: CameraDataset,
121
+ prune_type="comprehensive",
122
+ prune_percent=0.1,
123
+ prune_thr_important_score=None,
124
+ prune_thr_v_important_score=1.0,
125
+ prune_thr_max_v_important_score=None,
126
+ prune_thr_count=1,
127
+ prune_thr_T_alpha=0.01,
128
+ v_pow=0.1):
129
+ gaussian_list, opacity_imp_list, T_alpha_imp_list = prune_list(gaussians, dataset)
130
+ match prune_type:
131
+ case "important_score":
132
+ mask = score2mask(prune_percent, opacity_imp_list, prune_thr_important_score)
133
+ case "v_important_score":
134
+ v_list = calculate_v_imp_score(gaussians, opacity_imp_list, v_pow)
135
+ mask = score2mask(prune_percent, v_list, prune_thr_v_important_score)
136
+ case "max_v_important_score":
137
+ v_list = opacity_imp_list * torch.max(gaussians.get_scaling, dim=1)[0]
138
+ mask = score2mask(prune_percent, v_list, prune_thr_max_v_important_score)
139
+ case "count":
140
+ mask = score2mask(prune_percent, gaussian_list, prune_thr_count)
141
+ case "T_alpha":
142
+ # new importance score defined by doji
143
+ mask = score2mask(prune_percent, T_alpha_imp_list, prune_thr_T_alpha)
144
+ case "comprehensive":
145
+ mask = torch.zeros_like(gaussian_list, dtype=torch.bool)
146
+ if prune_thr_important_score is not None:
147
+ mask |= score2mask(prune_percent, opacity_imp_list, prune_thr_important_score)
148
+ if prune_thr_v_important_score is not None:
149
+ v_list = calculate_v_imp_score(gaussians, opacity_imp_list, v_pow)
150
+ mask |= score2mask(prune_percent, v_list, prune_thr_v_important_score)
151
+ if prune_thr_max_v_important_score is not None:
152
+ v_list = opacity_imp_list * torch.max(gaussians.get_scaling, dim=1)[0]
153
+ mask |= score2mask(prune_percent, v_list, prune_thr_max_v_important_score)
154
+ if prune_thr_count is not None:
155
+ mask |= score2mask(prune_percent, gaussian_list, prune_thr_count)
156
+ if prune_thr_T_alpha is not None:
157
+ mask |= score2mask(prune_percent, T_alpha_imp_list, prune_thr_T_alpha)
158
+ case _:
159
+ raise Exception("Unsupportive prunning method")
160
+ return mask
161
+
162
+
163
class ImportancePruner(DensifierWrapper):
    """
    Densifier wrapper that also removes low-importance Gaussians on a schedule.

    The wrapped densifier always runs first. On steps inside
    [importance_prune_from_iter, importance_prune_until_iter] that are
    multiples of importance_prune_interval, ``prune_gaussians`` is evaluated
    over ``dataset``. Its mask is then OR-ed into the wrapped result's
    ``remove_mask``, or used directly if that mask is None.
    """

    def __init__(
        self, base_densifier: AbstractDensifier,
        dataset: CameraDataset,
        importance_prune_from_iter=15000,
        importance_prune_until_iter=20000,
        importance_prune_interval: int = 1000,
        importance_prune_type="comprehensive",
        importance_prune_percent=0.1,
        importance_prune_thr_important_score=None,
        importance_prune_thr_v_important_score=1.0,
        importance_prune_thr_max_v_important_score=None,
        importance_prune_thr_count=1,
        importance_prune_thr_T_alpha=0.01,
        importance_v_pow=0.1
    ):
        super().__init__(base_densifier)
        # Cameras rendered to gather importance statistics.
        self.dataset = dataset

        # Pruning schedule (inclusive window plus step interval).
        self.importance_prune_from_iter = importance_prune_from_iter
        self.importance_prune_until_iter = importance_prune_until_iter
        self.importance_prune_interval = importance_prune_interval

        # Criterion selection and its parameters, forwarded to prune_gaussians.
        self.prune_type = importance_prune_type
        self.prune_percent = importance_prune_percent
        self.prune_thr_important_score = importance_prune_thr_important_score
        self.prune_thr_v_important_score = importance_prune_thr_v_important_score
        self.prune_thr_max_v_important_score = importance_prune_thr_max_v_important_score
        self.prune_thr_count = importance_prune_thr_count
        self.prune_thr_T_alpha = importance_prune_thr_T_alpha
        self.v_pow = importance_v_pow

    def densify_and_prune(self, loss, out, camera, step: int):
        result = super().densify_and_prune(loss, out, camera, step)

        within_window = self.importance_prune_from_iter <= step <= self.importance_prune_until_iter
        if not (within_window and step % self.importance_prune_interval == 0):
            return result

        importance_mask = prune_gaussians(
            self.model, self.dataset,
            prune_type=self.prune_type,
            prune_percent=self.prune_percent,
            prune_thr_important_score=self.prune_thr_important_score,
            prune_thr_v_important_score=self.prune_thr_v_important_score,
            prune_thr_max_v_important_score=self.prune_thr_max_v_important_score,
            prune_thr_count=self.prune_thr_count,
            prune_thr_T_alpha=self.prune_thr_T_alpha,
            v_pow=self.v_pow,
        )
        if result.remove_mask is not None:
            # Keep whatever the wrapped densifier already wanted removed.
            importance_mask = torch.logical_or(result.remove_mask, importance_mask)
        return result._replace(remove_mask=importance_mask)


def BaseImportancePruningTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        *args,
        importance_prune_from_iter=15000,
        importance_prune_until_iter=20000,
        importance_prune_interval: int = 1000,
        importance_prune_type="comprehensive",
        importance_prune_percent=0.1,
        importance_prune_thr_important_score=None,
        importance_prune_thr_v_important_score=1.0,
        importance_prune_thr_max_v_important_score=None,
        importance_prune_thr_count=1,
        importance_prune_thr_T_alpha=0.01,
        importance_v_pow=0.1,
        **kwargs):
    """
    Build a DensificationTrainer whose only densifier is importance pruning.

    An ``ImportancePruner`` is wrapped around a ``NoopDensifier`` for
    ``model``, so no densification happens and only Gaussian removal does.
    Extra positional and keyword arguments go to ``DensificationTrainer``.
    """
    pruner = ImportancePruner(
        NoopDensifier(model),
        dataset,
        importance_prune_from_iter=importance_prune_from_iter,
        importance_prune_until_iter=importance_prune_until_iter,
        importance_prune_interval=importance_prune_interval,
        importance_prune_type=importance_prune_type,
        importance_prune_percent=importance_prune_percent,
        importance_prune_thr_important_score=importance_prune_thr_important_score,
        importance_prune_thr_v_important_score=importance_prune_thr_v_important_score,
        importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
        importance_prune_thr_count=importance_prune_thr_count,
        importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
        importance_v_pow=importance_v_pow,
    )
    return DensificationTrainer(model, scene_extent, pruner, *args, **kwargs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: reduced_3dgs
3
- Version: 1.8.19
3
+ Version: 1.9.1
4
4
  Summary: Refactored code for the paper "Reducing the Memory Footprint of 3D Gaussian Splatting"
5
5
  Home-page: https://github.com/yindaheng98/reduced-3dgs
6
6
  Author: yindaheng98
@@ -60,7 +60,7 @@ if os.name == 'nt':
60
60
 
61
61
  setup(
62
62
  name="reduced_3dgs",
63
- version='1.8.19',
63
+ version='1.9.1',
64
64
  author='yindaheng98',
65
65
  author_email='yindaheng98@gmail.com',
66
66
  url='https://github.com/yindaheng98/reduced-3dgs',
@@ -471,9 +471,9 @@ renderCUDA_count(
471
471
  }
472
472
 
473
473
  //add count
474
- gaussian_count[collected_id[j]]++;
475
- opacity_important_score[collected_id[j]] += con_o.w; //opacity
476
- T_alpha_important_score[collected_id[j]] += alpha * T;
474
+ atomicAdd(&gaussian_count[collected_id[j]], 1);
475
+ atomicAdd(&opacity_important_score[collected_id[j]], con_o.w); //opacity
476
+ atomicAdd(&T_alpha_important_score[collected_id[j]], alpha * T);
477
477
 
478
478
  // Eq. (3) from 3D Gaussian splatting paper.
479
479
  for (int ch = 0; ch < CHANNELS; ch++)
@@ -1,130 +0,0 @@
1
- import math
2
- from typing import Callable, List
3
- import torch
4
-
5
- from gaussian_splatting import Camera, GaussianModel
6
- from gaussian_splatting.trainer import AbstractDensifier, DensifierWrapper, DensificationTrainer, NoopDensifier
7
- from gaussian_splatting.dataset import CameraDataset
8
- from .diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
9
-
10
-
11
- def count_render(self: GaussianModel, viewpoint_camera: Camera):
12
- """
13
- Render the scene.
14
-
15
- Background tensor (bg_color) must be on GPU!
16
- """
17
- # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
18
- screenspace_points = torch.zeros_like(self.get_xyz, dtype=self.get_xyz.dtype, requires_grad=True, device=self._xyz.device) + 0
19
- try:
20
- screenspace_points.retain_grad()
21
- except:
22
- pass
23
-
24
- # Set up rasterization configuration
25
- tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
26
- tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
27
-
28
- raster_settings = GaussianRasterizationSettings(
29
- image_height=int(viewpoint_camera.image_height),
30
- image_width=int(viewpoint_camera.image_width),
31
- tanfovx=tanfovx,
32
- tanfovy=tanfovy,
33
- bg=viewpoint_camera.bg_color.to(self._xyz.device),
34
- scale_modifier=self.scale_modifier,
35
- viewmatrix=viewpoint_camera.world_view_transform,
36
- projmatrix=viewpoint_camera.full_proj_transform,
37
- sh_degree=self.active_sh_degree,
38
- campos=viewpoint_camera.camera_center,
39
- prefiltered=False,
40
- debug=self.debug,
41
- f_count=True,
42
- )
43
-
44
- rasterizer = GaussianRasterizer(raster_settings=raster_settings)
45
- means3D = self.get_xyz
46
- means2D = screenspace_points
47
- opacity = self.get_opacity
48
-
49
- scales = self.get_scaling
50
- rotations = self.get_rotation
51
-
52
- shs = self.get_features
53
-
54
- # Rasterize visible Gaussians to image, obtain their radii (on screen).
55
- gaussians_count, opacity_important_score, T_alpha_important_score, rendered_image, radii = rasterizer(
56
- means3D=means3D,
57
- means2D=means2D,
58
- shs=shs,
59
- colors_precomp=None,
60
- opacities=opacity,
61
- scales=scales,
62
- rotations=rotations,
63
- cov3D_precomp=None)
64
-
65
- # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
66
- # They will be excluded from value updates used in the splitting criteria.
67
- return {
68
- "render": rendered_image,
69
- "viewspace_points": screenspace_points,
70
- "visibility_filter": radii > 0,
71
- "radii": radii,
72
- "gaussians_count": gaussians_count,
73
- "opacity_important_score": opacity_important_score,
74
- "T_alpha_important_score": T_alpha_important_score
75
- }
76
-
77
-
78
- def prune_gaussians(model: GaussianModel, dataset: CameraDataset):
79
- gaussian_count = torch.zeros(model.get_xyz.shape[0], device=model.get_xyz.device, dtype=torch.int)
80
- opacity_important_score = torch.zeros(model.get_xyz.shape[0], device=model.get_xyz.device, dtype=torch.float)
81
- T_alpha_important_score = torch.zeros(model.get_xyz.shape[0], device=model.get_xyz.device, dtype=torch.float)
82
- for camera in dataset:
83
- out = count_render(model, camera)
84
- gaussian_count += out["gaussians_count"]
85
- opacity_important_score += out["opacity_important_score"]
86
- T_alpha_important_score += out["T_alpha_important_score"]
87
- return None
88
-
89
-
90
- class ImportancePruner(DensifierWrapper):
91
- def __init__(
92
- self, base_densifier: AbstractDensifier,
93
- dataset: CameraDataset,
94
- importance_prune_from_iter=15000,
95
- importance_prune_until_iter=20000,
96
- importance_prune_interval: int = 1000,
97
- ):
98
- super().__init__(base_densifier)
99
- self.dataset = dataset
100
- self.importance_prune_from_iter = importance_prune_from_iter
101
- self.importance_prune_until_iter = importance_prune_until_iter
102
- self.importance_prune_interval = importance_prune_interval
103
-
104
- def densify_and_prune(self, loss, out, camera, step: int):
105
- ret = super().densify_and_prune(loss, out, camera, step)
106
- if self.importance_prune_from_iter <= step <= self.importance_prune_until_iter and step % self.importance_prune_interval == 0:
107
- remove_mask = prune_gaussians(self.model, self.dataset)
108
- ret = ret._replace(remove_mask=remove_mask if ret.remove_mask is None else torch.logical_or(ret.remove_mask, remove_mask))
109
- return ret
110
-
111
-
112
- def BaseImportancePruningTrainer(
113
- model: GaussianModel,
114
- scene_extent: float,
115
- dataset: List[Camera],
116
- *args,
117
- importance_prune_from_iter=1000,
118
- importance_prune_until_iter=15000,
119
- importance_prune_interval: int = 100,
120
- **kwargs):
121
- return DensificationTrainer(
122
- model, scene_extent,
123
- ImportancePruner(
124
- NoopDensifier(model),
125
- dataset,
126
- importance_prune_from_iter=importance_prune_from_iter,
127
- importance_prune_until_iter=importance_prune_until_iter,
128
- importance_prune_interval=importance_prune_interval,
129
- ), *args, **kwargs
130
- )
File without changes
File without changes
File without changes