reduced-3dgs 1.10.0__cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reduced-3dgs might be problematic; see the registry's advisory page for details.

Files changed (31):
  1. reduced_3dgs/__init__.py +0 -0
  2. reduced_3dgs/combinations.py +245 -0
  3. reduced_3dgs/diff_gaussian_rasterization/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  4. reduced_3dgs/diff_gaussian_rasterization/__init__.py +235 -0
  5. reduced_3dgs/importance/__init__.py +3 -0
  6. reduced_3dgs/importance/combinations.py +63 -0
  7. reduced_3dgs/importance/diff_gaussian_rasterization/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  8. reduced_3dgs/importance/diff_gaussian_rasterization/__init__.py +347 -0
  9. reduced_3dgs/importance/trainer.py +269 -0
  10. reduced_3dgs/pruning/__init__.py +2 -0
  11. reduced_3dgs/pruning/combinations.py +65 -0
  12. reduced_3dgs/pruning/trainer.py +145 -0
  13. reduced_3dgs/quantization/__init__.py +4 -0
  14. reduced_3dgs/quantization/abc.py +49 -0
  15. reduced_3dgs/quantization/exclude_zeros.py +41 -0
  16. reduced_3dgs/quantization/quantizer.py +289 -0
  17. reduced_3dgs/quantization/wrapper.py +67 -0
  18. reduced_3dgs/quantize.py +49 -0
  19. reduced_3dgs/shculling/__init__.py +2 -0
  20. reduced_3dgs/shculling/gaussian_model.py +78 -0
  21. reduced_3dgs/shculling/trainer.py +158 -0
  22. reduced_3dgs/simple_knn/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  23. reduced_3dgs/train.py +195 -0
  24. reduced_3dgs-1.10.0.dist-info/LICENSE.md +93 -0
  25. reduced_3dgs-1.10.0.dist-info/METADATA +278 -0
  26. reduced_3dgs-1.10.0.dist-info/RECORD +31 -0
  27. reduced_3dgs-1.10.0.dist-info/WHEEL +6 -0
  28. reduced_3dgs-1.10.0.dist-info/top_level.txt +1 -0
  29. reduced_3dgs.libs/libc10-ff4eddb5.so +0 -0
  30. reduced_3dgs.libs/libc10_cuda-c675d3fb.so +0 -0
  31. reduced_3dgs.libs/libcudart-8774224f.so.12.4.127 +0 -0
File without changes
@@ -0,0 +1,245 @@
1
+ from typing import List
2
+ from gaussian_splatting import GaussianModel, CameraTrainableGaussianModel, Camera
3
+ from gaussian_splatting.dataset import CameraDataset, TrainableCameraDataset
4
+ from gaussian_splatting.trainer import OpacityResetDensificationTrainer
5
+ # from gaussian_splatting.trainer import BaseOpacityResetDensificationTrainer as OpacityResetDensificationTrainer
6
+ from gaussian_splatting.trainer import OpacityResetTrainerWrapper, CameraTrainerWrapper, NoopDensifier, DepthTrainerWrapper
7
+ from .shculling import VariableSHGaussianModel, SHCullingTrainerWrapper
8
+ from .shculling import SHCullingTrainer
9
+ # from .shculling import BaseSHCullingTrainer as SHCullingTrainer
10
+ from .pruning import PruningTrainerWrapper, PrunerInDensifyTrainerWrapper
11
+ # from .pruning import BasePruningTrainer as PruningTrainer, BasePrunerInDensifyTrainer as PrunerInDensifyTrainer
12
+ from .importance import ImportancePruner
13
+
14
+
15
def BaseFullPruningTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        *args,
        importance_prune_from_iter=15000,
        importance_prune_until_iter=20000,
        importance_prune_interval: int = 1000,
        importance_score_resize=None,
        importance_prune_type="comprehensive",
        importance_prune_percent=0.1,
        importance_prune_thr_important_score=None,
        importance_prune_thr_v_important_score=3.0,
        importance_prune_thr_max_v_important_score=None,
        importance_prune_thr_count=1,
        importance_prune_thr_T_alpha=1.0,
        importance_prune_thr_T_alpha_avg=0.001,
        importance_v_pow=0.1,
        **kwargs):
    """Create a PruningTrainerWrapper whose pruner is an ImportancePruner.

    Every ``importance_*`` keyword is forwarded unchanged to ``ImportancePruner``.
    ``model``, ``scene_extent``, ``dataset`` and the remaining ``*args`` /
    ``**kwargs`` are forwarded to ``PruningTrainerWrapper``.
    """
    pruner_options = dict(
        importance_prune_from_iter=importance_prune_from_iter,
        importance_prune_until_iter=importance_prune_until_iter,
        importance_prune_interval=importance_prune_interval,
        importance_score_resize=importance_score_resize,
        importance_prune_type=importance_prune_type,
        importance_prune_percent=importance_prune_percent,
        importance_prune_thr_important_score=importance_prune_thr_important_score,
        importance_prune_thr_v_important_score=importance_prune_thr_v_important_score,
        importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
        importance_prune_thr_count=importance_prune_thr_count,
        importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
        importance_prune_thr_T_alpha_avg=importance_prune_thr_T_alpha_avg,
        importance_v_pow=importance_v_pow,
    )

    def make_pruner(model, scene_extent, dataset):
        # The pruner is layered over NoopDensifier(model); scene_extent is accepted but unused.
        return ImportancePruner(NoopDensifier(model), dataset, **pruner_options)

    return PruningTrainerWrapper(make_pruner, model, scene_extent, dataset, *args, **kwargs)
55
+
56
+
57
def BaseFullPrunerInDensifyTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        *args,
        importance_prune_from_iter=15000,
        importance_prune_until_iter=20000,
        importance_prune_interval: int = 1000,
        importance_score_resize=None,
        importance_prune_type="comprehensive",
        importance_prune_percent=0.1,
        importance_prune_thr_important_score=None,
        importance_prune_thr_v_important_score=3.0,
        importance_prune_thr_max_v_important_score=None,
        importance_prune_thr_count=1,
        importance_prune_thr_T_alpha=1.0,
        importance_prune_thr_T_alpha_avg=0.001,
        importance_v_pow=0.1,
        **kwargs):
    """Create a PrunerInDensifyTrainerWrapper whose pruner is an ImportancePruner.

    Every ``importance_*`` keyword is forwarded unchanged to ``ImportancePruner``.
    ``model``, ``scene_extent``, ``dataset`` and the remaining ``*args`` /
    ``**kwargs`` are forwarded to ``PrunerInDensifyTrainerWrapper``.
    """
    pruner_options = dict(
        importance_prune_from_iter=importance_prune_from_iter,
        importance_prune_until_iter=importance_prune_until_iter,
        importance_prune_interval=importance_prune_interval,
        importance_score_resize=importance_score_resize,
        importance_prune_type=importance_prune_type,
        importance_prune_percent=importance_prune_percent,
        importance_prune_thr_important_score=importance_prune_thr_important_score,
        importance_prune_thr_v_important_score=importance_prune_thr_v_important_score,
        importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
        importance_prune_thr_count=importance_prune_thr_count,
        importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
        importance_prune_thr_T_alpha_avg=importance_prune_thr_T_alpha_avg,
        importance_v_pow=importance_v_pow,
    )

    def make_pruner(model, scene_extent, dataset):
        # The pruner is layered over NoopDensifier(model); scene_extent is accepted but unused.
        return ImportancePruner(NoopDensifier(model), dataset, **pruner_options)

    return PrunerInDensifyTrainerWrapper(make_pruner, model, scene_extent, dataset, *args, **kwargs)
97
+
98
+
99
def DepthFullPruningTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
    """Wrap ``BaseFullPruningTrainer`` in ``DepthTrainerWrapper``.

    ``dataset`` is handed to the wrapper as a keyword argument; all other
    arguments are forwarded unchanged.
    """
    return DepthTrainerWrapper(
        BaseFullPruningTrainer,
        model,
        scene_extent,
        *args,
        dataset=dataset,
        **kwargs,
    )
101
+
102
+
103
def DepthFullPrunerInDensifyTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
    """Wrap ``BaseFullPrunerInDensifyTrainer`` in ``DepthTrainerWrapper``.

    ``dataset`` is handed to the wrapper as a keyword argument; all other
    arguments are forwarded unchanged.
    """
    return DepthTrainerWrapper(
        BaseFullPrunerInDensifyTrainer,
        model,
        scene_extent,
        *args,
        dataset=dataset,
        **kwargs,
    )
105
+
106
+
107
def OpacityResetPruningTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: CameraDataset,
        *args, **kwargs):
    """Wrap a ``DepthFullPruningTrainer`` factory in ``OpacityResetTrainerWrapper``.

    ``dataset`` is not passed to the wrapper. It is captured here and injected
    when the wrapper invokes the factory.
    """
    def build_depth_pruning_trainer(model, scene_extent, *inner_args, **inner_kwargs):
        return DepthFullPruningTrainer(model, scene_extent, dataset, *inner_args, **inner_kwargs)

    return OpacityResetTrainerWrapper(build_depth_pruning_trainer, model, scene_extent, *args, **kwargs)
117
+
118
+
119
def OpacityResetPrunerInDensifyTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: CameraDataset,
        *args, **kwargs):
    """Wrap a ``DepthFullPrunerInDensifyTrainer`` factory in ``OpacityResetTrainerWrapper``.

    ``dataset`` is not passed to the wrapper. It is captured here and injected
    when the wrapper invokes the factory.
    """
    def build_depth_pruner_in_densify_trainer(model, scene_extent, *inner_args, **inner_kwargs):
        return DepthFullPrunerInDensifyTrainer(model, scene_extent, dataset, *inner_args, **inner_kwargs)

    return OpacityResetTrainerWrapper(build_depth_pruner_in_densify_trainer, model, scene_extent, *args, **kwargs)
129
+
130
+
131
# Public default names resolve to the OpacityReset* compositions.
PruningTrainer, PrunerInDensifyTrainer = (
    OpacityResetPruningTrainer,
    OpacityResetPrunerInDensifyTrainer,
)
133
+
134
+
135
def SHCullingDensificationTrainer(
        model: VariableSHGaussianModel,
        scene_extent: float,
        dataset: CameraDataset,
        *args, **kwargs):
    """Wrap an ``OpacityResetDensificationTrainer`` factory in ``SHCullingTrainerWrapper``."""
    def build_densification_trainer(model, scene_extent, dataset, *inner_args, **inner_kwargs):
        # ``dataset`` is accepted to match the factory call shape but is not forwarded.
        return OpacityResetDensificationTrainer(model, scene_extent, *inner_args, **inner_kwargs)

    return SHCullingTrainerWrapper(build_densification_trainer, model, scene_extent, dataset, *args, **kwargs)
145
+
146
+
147
def SHCullingPruningTrainer(
        model: VariableSHGaussianModel,
        scene_extent: float,
        dataset: CameraDataset,
        *args, **kwargs):
    """Wrap ``OpacityResetPruningTrainer`` in ``SHCullingTrainerWrapper``."""
    positional = (model, scene_extent, dataset, *args)
    return SHCullingTrainerWrapper(OpacityResetPruningTrainer, *positional, **kwargs)
157
+
158
+
159
def SHCullingPrunerInDensifyTrainer(
        model: VariableSHGaussianModel,
        scene_extent: float,
        dataset: CameraDataset,
        *args, **kwargs):
    """Wrap ``OpacityResetPrunerInDensifyTrainer`` in ``SHCullingTrainerWrapper``."""
    positional = (model, scene_extent, dataset, *args)
    return SHCullingTrainerWrapper(OpacityResetPrunerInDensifyTrainer, *positional, **kwargs)
169
+
170
+
171
class CameraTrainableVariableSHGaussianModel(VariableSHGaussianModel):
    """VariableSHGaussianModel that renders with CameraTrainableGaussianModel's forward.

    ``forward`` calls ``CameraTrainableGaussianModel.forward`` as a plain
    function with this instance as ``self``; it does not go through ``super()``.
    NOTE(review): this relies on the instance providing whatever attributes
    that borrowed method reads. Confirm against CameraTrainableGaussianModel.
    """

    def forward(self, camera: Camera):
        borrowed_forward = CameraTrainableGaussianModel.forward
        return borrowed_forward(self, camera)
174
+
175
+
176
def CameraSHCullingTrainer(
        model: CameraTrainableVariableSHGaussianModel,
        scene_extent: float,
        dataset: TrainableCameraDataset,
        *args, **kwargs):
    """Wrap ``SHCullingTrainer`` in ``CameraTrainerWrapper``; all arguments pass through."""
    inner_trainer = SHCullingTrainer
    return CameraTrainerWrapper(inner_trainer, model, scene_extent, dataset, *args, **kwargs)
186
+
187
+
188
def CameraPruningTrainer(
        model: CameraTrainableVariableSHGaussianModel,
        scene_extent: float,
        dataset: TrainableCameraDataset,
        *args, **kwargs):
    """Wrap ``OpacityResetPruningTrainer`` in ``CameraTrainerWrapper``; all arguments pass through."""
    inner_trainer = OpacityResetPruningTrainer
    return CameraTrainerWrapper(inner_trainer, model, scene_extent, dataset, *args, **kwargs)
198
+
199
+
200
def CameraPrunerInDensifyTrainer(
        model: CameraTrainableVariableSHGaussianModel,
        scene_extent: float,
        dataset: TrainableCameraDataset,
        *args, **kwargs):
    """Wrap ``OpacityResetPrunerInDensifyTrainer`` in ``CameraTrainerWrapper``; all arguments pass through."""
    inner_trainer = OpacityResetPrunerInDensifyTrainer
    return CameraTrainerWrapper(inner_trainer, model, scene_extent, dataset, *args, **kwargs)
210
+
211
+
212
def CameraSHCullingDensifyTrainer(
        model: CameraTrainableVariableSHGaussianModel,
        scene_extent: float,
        dataset: TrainableCameraDataset,
        *args, **kwargs):
    """Wrap ``SHCullingDensificationTrainer`` in ``CameraTrainerWrapper``; all arguments pass through."""
    inner_trainer = SHCullingDensificationTrainer
    return CameraTrainerWrapper(inner_trainer, model, scene_extent, dataset, *args, **kwargs)
222
+
223
+
224
def CameraSHCullingPruningTrainer(
        model: CameraTrainableVariableSHGaussianModel,
        scene_extent: float,
        dataset: TrainableCameraDataset,
        *args, **kwargs):
    """Wrap ``SHCullingPruningTrainer`` in ``CameraTrainerWrapper``; all arguments pass through."""
    inner_trainer = SHCullingPruningTrainer
    return CameraTrainerWrapper(inner_trainer, model, scene_extent, dataset, *args, **kwargs)
234
+
235
+
236
def CameraSHCullingPruningDensifyTrainer(
        model: CameraTrainableVariableSHGaussianModel,
        scene_extent: float,
        dataset: TrainableCameraDataset,
        *args, **kwargs):
    """Wrap ``SHCullingPrunerInDensifyTrainer`` in ``CameraTrainerWrapper``; all arguments pass through."""
    inner_trainer = SHCullingPrunerInDensifyTrainer
    return CameraTrainerWrapper(inner_trainer, model, scene_extent, dataset, *args, **kwargs)
@@ -0,0 +1,235 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ from typing import NamedTuple
13
+ import torch.nn as nn
14
+ import torch
15
+ from . import _C
16
+
17
def cpu_deep_copy_tuple(input_tuple):
    """Return a tuple mirroring *input_tuple* with every tensor cloned onto the CPU.

    Non-tensor items are passed through as-is (not copied). ``clone()`` is
    needed because ``Tensor.cpu()`` returns the same object when the tensor
    already lives on the CPU.
    """
    return tuple(
        element.cpu().clone() if isinstance(element, torch.Tensor) else element
        for element in input_tuple
    )
20
+
21
def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    degrees,
    colors_precomp,
    opacities,
    scales,
    rotations,
    cov3Ds_precomp,
    raster_settings,
    lambda_sh_sparsity
):
    """Functional entry point for ``_RasterizeGaussians.apply``.

    All arguments are forwarded positionally in the order the autograd
    Function's ``forward`` declares them.
    """
    inputs = (
        means3D, means2D, sh, degrees, colors_precomp, opacities,
        scales, rotations, cov3Ds_precomp, raster_settings, lambda_sh_sparsity,
    )
    return _RasterizeGaussians.apply(*inputs)
47
+
48
class _RasterizeGaussians(torch.autograd.Function):
    """Autograd bridge to the compiled ``_C`` Gaussian rasterizer.

    ``forward`` packs its inputs into the positional order expected by
    ``_C.rasterize_gaussians``. ``backward`` does the same for
    ``_C.rasterize_gaussians_backward``. Both tuple orders are a contract with
    the extension module and must not be reordered.
    """

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        degrees,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
        lambda_sh_sparsity
    ):
        """Run the forward rasterizer and stash what ``backward`` needs.

        ``means2D`` is not sent to the extension. It only receives
        ``grad_means2D`` in ``backward``.

        Returns:
            ``(color, radii)`` as returned by ``_C.rasterize_gaussians``.

        Raises:
            Re-raises any exception from the extension. When
            ``raster_settings.debug`` is set, the inputs are first saved to
            ``snapshot_fw.dump`` in the current working directory.
        """

        # Restructure arguments the way that the C++ lib expects them
        args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            degrees,
            raster_settings.campos,
            raster_settings.prefiltered,
            True
            # raster_settings.debug
            # NOTE(review): this final flag is hard-coded to True, while
            # backward passes raster_settings.debug in the same position.
            # Presumably it is the extension's debug switch; confirm whether
            # forcing it on here is intentional.
        )

        # Invoke C++/CUDA rasterizer
        if raster_settings.debug:
            cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
            try:
                num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
            except Exception as ex:
                # Persist the CPU copy of the inputs so the failure can be replayed.
                torch.save(cpu_args, "snapshot_fw.dump")
                print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
                raise ex
        else:
            num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)

        # Keep relevant tensors for backward
        # The save_for_backward order below must match the unpacking order in backward.
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer, degrees)
        ctx.lambda_sh_sparsity = lambda_sh_sparsity
        return color, radii

    @staticmethod
    def backward(ctx, grad_out_color, _):
        """Compute gradients via ``_C.rasterize_gaussians_backward``.

        The incoming gradient for the second output (``radii``) is ignored.
        ``lambda_sh_sparsity`` is handed to the extension; presumably it
        weights an SH-sparsity term inside the kernel (confirm in the C++ code).

        Returns:
            One entry per ``forward`` input (11 in total), in ``forward``'s
            argument order. ``degrees``, ``raster_settings`` and
            ``lambda_sh_sparsity`` get ``None``.

        Raises:
            Re-raises any exception from the extension. When
            ``raster_settings.debug`` is set, the inputs are first saved to
            ``snapshot_bw.dump`` in the current working directory.
        """

        # Restore necessary values from context
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        lambda_sh_sparsity = ctx.lambda_sh_sparsity
        colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer, degrees = ctx.saved_tensors

        # Restructure args as C++ method expects them
        # (opacities are not passed here; grad_opacities comes back from the extension).
        args = (raster_settings.bg,
                means3D,
                radii,
                colors_precomp,
                scales,
                rotations,
                raster_settings.scale_modifier,
                cov3Ds_precomp,
                raster_settings.viewmatrix,
                raster_settings.projmatrix,
                raster_settings.tanfovx,
                raster_settings.tanfovy,
                grad_out_color,
                sh,
                degrees,
                raster_settings.campos,
                geomBuffer,
                num_rendered,
                binningBuffer,
                imgBuffer,
                lambda_sh_sparsity,
                raster_settings.debug)

        # Compute gradients for relevant tensors by invoking backward method
        if raster_settings.debug:
            cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
            try:
                grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
            except Exception as ex:
                # Persist the CPU copy of the inputs so the failure can be replayed.
                torch.save(cpu_args, "snapshot_bw.dump")
                print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n")
                raise ex
        else:
            grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)

        # Order must mirror forward's inputs:
        # means3D, means2D, sh, degrees, colors_precomp, opacities, scales,
        # rotations, cov3Ds_precomp, raster_settings, lambda_sh_sparsity.
        grads = (
            grad_means3D,
            grad_means2D,
            grad_sh,
            None,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            None,
            None,
        )

        return grads
168
+
169
class GaussianRasterizationSettings(NamedTuple):
    """Immutable per-render settings read by the rasterizer.

    Field order defines the positional constructor and must not change.
    ``sh_degree`` is carried here, but this module's rasterization calls pass
    the separate ``degrees`` argument to the extension instead.
    """

    # Output image size.
    image_height: int
    image_width: int
    # Field-of-view tangents.
    tanfovx: float
    tanfovy: float
    # Background passed as the first extension argument.
    bg: torch.Tensor
    scale_modifier: float
    # Camera transforms and position.
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor
    sh_degree: int
    campos: torch.Tensor
    # Flags.
    prefiltered: bool
    debug: bool
182
+
183
class GaussianRasterizer(nn.Module):
    """``nn.Module`` front-end for the differentiable Gaussian rasterizer.

    Stores a settings object and, in ``forward``, validates the mutually
    exclusive colour and geometry inputs before dispatching to
    ``rasterize_gaussians``.
    """

    def __init__(self, raster_settings):
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Mark visible points (frustum culling for the stored camera) with a boolean.

        Runs under ``torch.no_grad()`` and uses only the stored view and
        projection matrices.
        """
        settings = self.raster_settings
        with torch.no_grad():
            visible = _C.mark_visible(positions, settings.viewmatrix, settings.projmatrix)
        return visible

    def forward(self, means3D, means2D, opacities, shs = None, degrees = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None, lambda_sh_sparsity=0.):
        """Validate the inputs and rasterize.

        Exactly one colour source (``shs`` or ``colors_precomp``) and exactly
        one geometry source (a complete ``scales``/``rotations`` pair, or
        ``cov3D_precomp``) must be supplied.

        Raises:
            Exception: if either of those exclusivity rules is violated.
        """
        # Colour source: SH coefficients XOR precomputed colours.
        if (shs is None) == (colors_precomp is None):
            raise Exception('Please provide excatly one of either SHs or precomputed colors!')

        # Geometry source: a full scale/rotation pair XOR a precomputed covariance.
        if cov3D_precomp is None:
            geometry_invalid = scales is None or rotations is None
        else:
            geometry_invalid = scales is not None or rotations is not None
        if geometry_invalid:
            raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        # Absent optional inputs become empty tensors before the call
        # (presumably the extension expects a tensor in every slot).
        shs = torch.Tensor([]) if shs is None else shs
        colors_precomp = torch.Tensor([]) if colors_precomp is None else colors_precomp
        scales = torch.Tensor([]) if scales is None else scales
        rotations = torch.Tensor([]) if rotations is None else rotations
        cov3D_precomp = torch.Tensor([]) if cov3D_precomp is None else cov3D_precomp

        # Invoke C++/CUDA rasterization routine
        return rasterize_gaussians(
            means3D, means2D, shs, degrees, colors_precomp, opacities,
            scales, rotations, cov3D_precomp, self.raster_settings, lambda_sh_sparsity,
        )
235
+
@@ -0,0 +1,3 @@
1
+ from .trainer import ImportancePruner, BaseImportancePruningTrainer
2
+ from .combinations import BaseImportancePrunerInDensifyTrainer, DepthImportancePruningTrainer, DepthImportancePrunerInDensifyTrainer
3
+ from .combinations import ImportancePruningTrainer, ImportancePrunerInDensifyTrainer
@@ -0,0 +1,63 @@
1
+ from typing import List
2
+ from gaussian_splatting import Camera, GaussianModel
3
+ from gaussian_splatting.dataset import TrainableCameraDataset
4
+ from gaussian_splatting.trainer import DepthTrainerWrapper, NoopDensifier, DensificationTrainerWrapper
5
+ from .trainer import ImportancePruner, BaseImportancePruningTrainer
6
+
7
+
8
def BaseImportancePrunerInDensifyTrainer(
        model: GaussianModel,
        scene_extent: float,
        dataset: List[Camera],
        *args,
        importance_prune_from_iter=15000,
        importance_prune_until_iter=20000,
        importance_prune_interval: int = 1000,
        importance_score_resize=None,
        importance_prune_type="comprehensive",
        importance_prune_percent=0.1,
        importance_prune_thr_important_score=None,
        importance_prune_thr_v_important_score=3.0,
        importance_prune_thr_max_v_important_score=None,
        importance_prune_thr_count=1,
        importance_prune_thr_T_alpha=1.0,
        importance_prune_thr_T_alpha_avg=0.001,
        importance_v_pow=0.1,
        **kwargs):
    """Create a DensificationTrainerWrapper whose densifier is an ImportancePruner.

    Every ``importance_*`` keyword is forwarded unchanged to ``ImportancePruner``.
    ``dataset`` is captured here rather than passed to the wrapper. ``model``,
    ``scene_extent`` and the remaining ``*args`` / ``**kwargs`` go to
    ``DensificationTrainerWrapper``.
    """
    pruner_options = dict(
        importance_prune_from_iter=importance_prune_from_iter,
        importance_prune_until_iter=importance_prune_until_iter,
        importance_prune_interval=importance_prune_interval,
        importance_score_resize=importance_score_resize,
        importance_prune_type=importance_prune_type,
        importance_prune_percent=importance_prune_percent,
        importance_prune_thr_important_score=importance_prune_thr_important_score,
        importance_prune_thr_v_important_score=importance_prune_thr_v_important_score,
        importance_prune_thr_max_v_important_score=importance_prune_thr_max_v_important_score,
        importance_prune_thr_count=importance_prune_thr_count,
        importance_prune_thr_T_alpha=importance_prune_thr_T_alpha,
        importance_prune_thr_T_alpha_avg=importance_prune_thr_T_alpha_avg,
        importance_v_pow=importance_v_pow,
    )

    def make_pruner(model, scene_extent):
        # The pruner is layered over NoopDensifier(model); scene_extent is accepted but unused.
        return ImportancePruner(NoopDensifier(model), dataset, **pruner_options)

    return DensificationTrainerWrapper(make_pruner, model, scene_extent, *args, **kwargs)
49
+
50
+
51
+ # Depth trainer
52
+
53
+
54
def DepthImportancePruningTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
    """Wrap ``BaseImportancePruningTrainer`` in ``DepthTrainerWrapper``.

    ``dataset`` is handed to the wrapper as a keyword argument; all other
    arguments are forwarded unchanged.
    """
    return DepthTrainerWrapper(
        BaseImportancePruningTrainer,
        model,
        scene_extent,
        *args,
        dataset=dataset,
        **kwargs,
    )
56
+
57
+
58
def DepthImportancePrunerInDensifyTrainer(model: GaussianModel, scene_extent: float, dataset: TrainableCameraDataset, *args, **kwargs):
    """Wrap ``BaseImportancePrunerInDensifyTrainer`` in ``DepthTrainerWrapper``.

    ``dataset`` is handed to the wrapper as a keyword argument; all other
    arguments are forwarded unchanged.
    """
    return DepthTrainerWrapper(
        BaseImportancePrunerInDensifyTrainer,
        model,
        scene_extent,
        *args,
        dataset=dataset,
        **kwargs,
    )
60
+
61
+
62
# Public default names resolve to the Depth* compositions.
ImportancePruningTrainer, ImportancePrunerInDensifyTrainer = (
    DepthImportancePruningTrainer,
    DepthImportancePrunerInDensifyTrainer,
)