gsvvcompressor 1.2.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsvvcompressor/__init__.py +13 -0
- gsvvcompressor/__main__.py +243 -0
- gsvvcompressor/combinations/__init__.py +84 -0
- gsvvcompressor/combinations/registry.py +52 -0
- gsvvcompressor/combinations/vq_xyz_1mask.py +89 -0
- gsvvcompressor/combinations/vq_xyz_1mask_zstd.py +103 -0
- gsvvcompressor/combinations/vq_xyz_draco.py +468 -0
- gsvvcompressor/combinations/vq_xyz_draco_2pass.py +156 -0
- gsvvcompressor/combinations/vq_xyz_zstd.py +106 -0
- gsvvcompressor/compress/__init__.py +5 -0
- gsvvcompressor/compress/zstd.py +144 -0
- gsvvcompressor/decoder.py +155 -0
- gsvvcompressor/deserializer.py +42 -0
- gsvvcompressor/draco/__init__.py +34 -0
- gsvvcompressor/draco/draco_decoder.exe +0 -0
- gsvvcompressor/draco/draco_encoder.exe +0 -0
- gsvvcompressor/draco/dracoreduced3dgs.cp311-win_amd64.pyd +0 -0
- gsvvcompressor/draco/interface.py +339 -0
- gsvvcompressor/draco/serialize.py +235 -0
- gsvvcompressor/draco/twopass.py +359 -0
- gsvvcompressor/encoder.py +122 -0
- gsvvcompressor/interframe/__init__.py +11 -0
- gsvvcompressor/interframe/combine.py +271 -0
- gsvvcompressor/interframe/decoder.py +99 -0
- gsvvcompressor/interframe/encoder.py +92 -0
- gsvvcompressor/interframe/interface.py +221 -0
- gsvvcompressor/interframe/twopass.py +226 -0
- gsvvcompressor/io/__init__.py +31 -0
- gsvvcompressor/io/bytes.py +103 -0
- gsvvcompressor/io/config.py +78 -0
- gsvvcompressor/io/gaussian_model.py +127 -0
- gsvvcompressor/movecameras.py +33 -0
- gsvvcompressor/payload.py +34 -0
- gsvvcompressor/serializer.py +42 -0
- gsvvcompressor/vq/__init__.py +15 -0
- gsvvcompressor/vq/interface.py +324 -0
- gsvvcompressor/vq/singlemask.py +127 -0
- gsvvcompressor/vq/twopass.py +1 -0
- gsvvcompressor/xyz/__init__.py +26 -0
- gsvvcompressor/xyz/dense.py +39 -0
- gsvvcompressor/xyz/interface.py +382 -0
- gsvvcompressor/xyz/knn.py +141 -0
- gsvvcompressor/xyz/quant.py +143 -0
- gsvvcompressor/xyz/size.py +44 -0
- gsvvcompressor/xyz/twopass.py +1 -0
- gsvvcompressor-1.2.0.dist-info/METADATA +690 -0
- gsvvcompressor-1.2.0.dist-info/RECORD +50 -0
- gsvvcompressor-1.2.0.dist-info/WHEEL +5 -0
- gsvvcompressor-1.2.0.dist-info/licenses/LICENSE +21 -0
- gsvvcompressor-1.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
"""
|
|
2
|
+
XYZ quantization-based inter-frame codec interface.
|
|
3
|
+
|
|
4
|
+
This module provides an inter-frame codec that only operates on the xyz
|
|
5
|
+
coordinates of a GaussianModel using quantization.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Optional, Self
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from gaussian_splatting import GaussianModel
|
|
14
|
+
|
|
15
|
+
from ..payload import Payload
|
|
16
|
+
from ..interframe import InterframeEncoderInitConfig, InterframeCodecContext, InterframeCodecInterface
|
|
17
|
+
from .quant import XYZQuantConfig, compute_quant_config, quantize_xyz, dequantize_xyz
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class XYZQuantInterframeCodecConfig(InterframeEncoderInitConfig):
    """
    Configuration parameters for XYZ quantization-based inter-frame codec.

    This dataclass holds the initialization settings for xyz coordinate
    quantization. The step-size parameters (k, sample_size, seed, quantile,
    alpha, min_step, max_step) are forwarded to compute_quant_config on the
    keyframe; tolerance is read only by the encoder when diffing frames.

    Attributes:
        k: Which nearest neighbor to use for step size estimation (1 = nearest).
        sample_size: Number of points to sample for NN estimation.
        seed: Random seed for reproducible sampling.
        quantile: Quantile of NN distances to use for dense scale estimation.
        alpha: Scaling factor for step size.
        min_step: Optional minimum step size.
        max_step: Optional maximum step size.
        tolerance: Tolerance for inter-frame change detection. Only coordinates
            with absolute difference > tolerance (compared in quantized integer
            units) are considered changed.
    """
    k: int = 1
    sample_size: Optional[int] = 10000
    seed: Optional[int] = 42
    quantile: float = 0.05
    alpha: float = 0.2
    min_step: Optional[float] = None
    max_step: Optional[float] = None
    tolerance: int = 0
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class XYZQuantInterframeCodecContext(InterframeCodecContext):
    """
    Context data for XYZ quantization-based inter-frame encoding/decoding.

    This dataclass holds the quantization state including the quantization
    configuration and quantized xyz coordinates.

    Attributes:
        quant_config: The quantization configuration (step_size, origin).
        quantized_xyz: The quantized xyz coordinates, shape (N, 3), dtype int32.
        tolerance: Tolerance for inter-frame change detection (encoder-only,
            not used during decoding).
    """
    quant_config: XYZQuantConfig
    quantized_xyz: torch.Tensor  # shape (N, 3), dtype int32
    tolerance: int = 0  # encoder-only; decode paths leave the default
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class XYZQuantKeyframePayload(Payload):
    """
    Keyframe payload for the XYZ quantization codec.

    Carries everything a decoder needs to bootstrap: the quantization
    configuration plus the complete quantized coordinate tensor.

    Attributes:
        quant_config: The quantization configuration.
        quantized_xyz: The quantized xyz coordinates.
    """
    quant_config: XYZQuantConfig
    quantized_xyz: torch.Tensor

    def to(self, device) -> Self:
        """
        Return a copy of this payload with every tensor on the given device.

        Args:
            device: The target device (e.g., 'cpu', 'cuda', torch.device).

        Returns:
            A new XYZQuantKeyframePayload instance on the target device.
        """
        # The config's origin is a tensor and must travel with the payload;
        # step_size is a plain float and is copied as-is.
        moved_config = XYZQuantConfig(
            step_size=self.quant_config.step_size,
            origin=self.quant_config.origin.to(device),
        )
        return XYZQuantKeyframePayload(
            quant_config=moved_config,
            quantized_xyz=self.quantized_xyz.to(device),
        )
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass
class XYZQuantInterframePayload(Payload):
    """
    Delta payload for the XYZ quantization codec.

    Non-key frames transmit only the rows that changed, together with a
    boolean mask saying where they go. The quantization configuration itself
    is inherited from the keyframe context.

    Attributes:
        xyz_mask: Boolean tensor indicating which xyz rows changed, shape (N,).
        quantized_xyz: Only the changed quantized xyz values (sparse), shape (M, 3).
    """
    xyz_mask: torch.Tensor
    quantized_xyz: torch.Tensor

    def to(self, device) -> Self:
        """
        Return a copy of this payload with every tensor on the given device.

        Args:
            device: The target device (e.g., 'cpu', 'cuda', torch.device).

        Returns:
            A new XYZQuantInterframePayload instance on the target device.
        """
        moved_mask = self.xyz_mask.to(device)
        moved_values = self.quantized_xyz.to(device)
        return XYZQuantInterframePayload(xyz_mask=moved_mask, quantized_xyz=moved_values)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class XYZQuantInterframeCodecInterface(InterframeCodecInterface):
    """
    XYZ quantization-based inter-frame encoding/decoding interface.

    This interface uses coordinate quantization to compress the xyz coordinates
    of GaussianModel. The keyframe computes the quantization configuration,
    and subsequent frames use the same configuration to quantize their coordinates.

    Only operates on xyz coordinates; other GaussianModel attributes are not modified.
    """

    @staticmethod
    def _apply_delta(
        payload: XYZQuantInterframePayload,
        prev_context: XYZQuantInterframeCodecContext,
        tolerance: int = 0,
    ) -> XYZQuantInterframeCodecContext:
        """
        Apply a delta payload on top of a previous context.

        Shared by decode_interframe and decode_interframe_for_encode so the
        mask-application logic lives in exactly one place.

        Args:
            payload: The delta payload containing changed quantized xyz with mask.
            prev_context: The previous frame's context (supplies quant_config).
            tolerance: Tolerance to carry into the new context. Encoder-only;
                the plain decode path leaves the default 0.

        Returns:
            The reconstructed context for the current frame.
        """
        # Clone so the previous context's tensor is never mutated in place.
        new_quantized_xyz = prev_context.quantized_xyz.clone()
        new_quantized_xyz[payload.xyz_mask] = payload.quantized_xyz

        return XYZQuantInterframeCodecContext(
            quant_config=prev_context.quant_config,
            quantized_xyz=new_quantized_xyz,
            tolerance=tolerance,
        )

    def decode_interframe(
        self,
        payload: XYZQuantInterframePayload,
        prev_context: XYZQuantInterframeCodecContext,
    ) -> XYZQuantInterframeCodecContext:
        """
        Decode a delta payload to reconstruct the next frame's context.

        Applies the changed xyz values from the payload to the previous context.

        Args:
            payload: The delta payload containing changed quantized xyz with mask.
            prev_context: The context of the previous frame (contains quant_config).

        Returns:
            The reconstructed context for the current frame.
        """
        # tolerance is encoder-only; the decoder leaves it at the default.
        return self._apply_delta(payload, prev_context)

    def encode_interframe(
        self,
        prev_context: XYZQuantInterframeCodecContext,
        next_context: XYZQuantInterframeCodecContext,
    ) -> XYZQuantInterframePayload:
        """
        Encode the difference between two consecutive frames.

        Compares prev and next contexts to find changed xyz coordinates and stores
        only the changed values with their corresponding mask.

        Args:
            prev_context: The context of the previous frame.
            next_context: The context of the next frame.

        Returns:
            A payload containing only changed quantized xyz with mask.
        """
        prev_xyz = prev_context.quantized_xyz
        next_xyz = next_context.quantized_xyz
        tolerance = prev_context.tolerance

        # A point counts as changed when any of its coordinates moved by more
        # than the tolerance (compared in quantized integer units).
        diff = (prev_xyz - next_xyz).abs()
        mask = (diff > tolerance).any(dim=-1)
        changed_xyz = next_xyz[mask]

        return XYZQuantInterframePayload(
            xyz_mask=mask,
            quantized_xyz=changed_xyz,
        )

    def decode_keyframe(self, payload: XYZQuantKeyframePayload) -> XYZQuantInterframeCodecContext:
        """
        Decode a keyframe payload to create initial context.

        Args:
            payload: The keyframe payload containing quant_config and quantized xyz.

        Returns:
            The context for the first/key frame.
        """
        return XYZQuantInterframeCodecContext(
            quant_config=payload.quant_config,
            quantized_xyz=payload.quantized_xyz,
            # tolerance is encoder-only, decoder doesn't use it
        )

    def decode_keyframe_for_encode(
        self,
        payload: XYZQuantKeyframePayload,
        context: XYZQuantInterframeCodecContext,
    ) -> XYZQuantInterframeCodecContext:
        """
        Decode a keyframe payload during encoding to avoid error accumulation.

        Since the encode/decode round-trip is lossless for XYZ quantization
        (the quantized data is simply copied), we can reuse the original context.
        This preserves context.tolerance which is needed by the encoder.

        Args:
            payload: The keyframe payload that was just encoded.
            context: The original context used for encoding this keyframe.

        Returns:
            The reconstructed context (same as original for this codec).
        """
        # Round-trip is lossless, reuse original context (including tolerance)
        return context

    def decode_interframe_for_encode(
        self,
        payload: XYZQuantInterframePayload,
        prev_context: XYZQuantInterframeCodecContext,
    ) -> XYZQuantInterframeCodecContext:
        """
        Decode an interframe payload during encoding to avoid error accumulation.

        Same delta application as decode_interframe, but preserves tolerance
        from prev_context since it's needed by the encoder for subsequent frames.

        Args:
            payload: The interframe payload that was just encoded.
            prev_context: The previous frame's context (reconstructed version).

        Returns:
            The reconstructed context with tolerance preserved.
        """
        return self._apply_delta(payload, prev_context, tolerance=prev_context.tolerance)

    def encode_keyframe(self, context: XYZQuantInterframeCodecContext) -> XYZQuantKeyframePayload:
        """
        Encode the first frame as a keyframe.

        Args:
            context: The context of the first frame.

        Returns:
            A payload containing the quant_config and quantized xyz.
        """
        return XYZQuantKeyframePayload(
            quant_config=context.quant_config,
            quantized_xyz=context.quantized_xyz,
            # tolerance is encoder-only, not needed in payload
        )

    def keyframe_to_context(
        self,
        frame: GaussianModel,
        init_config: XYZQuantInterframeCodecConfig,
    ) -> XYZQuantInterframeCodecContext:
        """
        Convert a keyframe to a XYZQuantInterframeCodecContext.

        Computes quantization configuration from the frame's xyz coordinates
        and quantizes them.

        Args:
            frame: The GaussianModel frame to convert.
            init_config: Configuration parameters for quantization.

        Returns:
            The corresponding XYZQuantInterframeCodecContext representation.
        """
        xyz = frame.get_xyz

        # Compute quantization configuration from xyz coordinates
        quant_config = compute_quant_config(
            points=xyz,
            k=init_config.k,
            sample_size=init_config.sample_size,
            seed=init_config.seed,
            quantile=init_config.quantile,
            alpha=init_config.alpha,
            min_step=init_config.min_step,
            max_step=init_config.max_step,
        )

        # Quantize xyz coordinates
        quantized_xyz = quantize_xyz(xyz, quant_config)

        return XYZQuantInterframeCodecContext(
            quant_config=quant_config,
            quantized_xyz=quantized_xyz,
            tolerance=init_config.tolerance,
        )

    def interframe_to_context(
        self,
        frame: GaussianModel,
        prev_context: XYZQuantInterframeCodecContext,
    ) -> XYZQuantInterframeCodecContext:
        """
        Convert a frame to a XYZQuantInterframeCodecContext using the previous context's config.

        Uses the quantization config from the previous context to quantize xyz coordinates.

        Args:
            frame: The GaussianModel frame to convert.
            prev_context: The context from the previous frame.

        Returns:
            The corresponding XYZQuantInterframeCodecContext representation.
        """
        xyz = frame.get_xyz

        # Quantize xyz using the existing quant_config from keyframe
        quantized_xyz = quantize_xyz(xyz, prev_context.quant_config)

        return XYZQuantInterframeCodecContext(
            quant_config=prev_context.quant_config,
            quantized_xyz=quantized_xyz,
            tolerance=prev_context.tolerance,
        )

    def context_to_frame(
        self,
        context: XYZQuantInterframeCodecContext,
        frame: GaussianModel,
    ) -> GaussianModel:
        """
        Convert a XYZQuantInterframeCodecContext back to a GaussianModel frame.

        Dequantizes the xyz coordinates and sets them on the frame.
        Only modifies xyz coordinates; other attributes are not touched.

        Args:
            context: The XYZQuantInterframeCodecContext to convert.
            frame: An empty GaussianModel or one from previous pipeline steps.
                This frame will be modified in-place with the xyz data.

        Returns:
            The modified GaussianModel with the xyz data.
        """
        # Dequantize xyz coordinates
        xyz = dequantize_xyz(
            context.quantized_xyz,
            context.quant_config,
            dtype=frame.get_xyz.dtype,
        )

        # Set xyz on the frame. NOTE(review): writes the private _xyz slot
        # directly; assumes GaussianModel stores coordinates there — confirm
        # against the gaussian_splatting package if its internals change.
        frame._xyz = torch.nn.Parameter(xyz.to(frame.get_xyz.device))

        return frame
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""
|
|
2
|
+
K-nearest neighbor distance computation for 3D point clouds.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
from scipy.spatial import cKDTree
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def compute_nn_distances(
    points: torch.Tensor,
    k: int = 1,
    sample_size: Optional[int] = None,
    seed: Optional[int] = None,
) -> torch.Tensor:
    """
    Compute k-th nearest neighbor distances for points.

    Builds a scipy cKDTree over all points (O(N log N)) and queries it either
    with every point or with a random subset, which keeps the cost manageable
    for very large point clouds.

    Args:
        points: Point coordinates, shape (N, 3).
        k: Which nearest neighbor distance to compute (1 = nearest, 2 = second nearest, etc.).
        sample_size: If provided, randomly sample this many points for distance estimation.
            If None or >= N, use all points.
        seed: Random seed for reproducible sampling.

    Returns:
        Tensor of k-th nearest neighbor distances, shape (M,) where M = min(sample_size, N).

    Raises:
        ValueError: If there are not enough points to compute k-th neighbor (n_points <= k).
    """
    total = points.shape[0]
    if total <= k:
        raise ValueError(
            f"Not enough points to compute {k}-th nearest neighbor. "
            f"Got {total} points, need at least {k + 1}."
        )

    coords = points.detach().cpu().numpy().astype(np.float64)

    # Optionally query from a random subset; the tree still covers all points.
    queries = coords
    if sample_size is not None and sample_size < total:
        picked = np.random.default_rng(seed).choice(total, size=sample_size, replace=False)
        queries = coords[picked]

    # Query k+1 neighbors: column 0 is each query point itself (distance 0),
    # so the k-th true neighbor sits at column k.
    all_dists, _ = cKDTree(coords).query(queries, k=k + 1)
    kth_dists = all_dists[:, k]

    return torch.from_numpy(kth_dists).to(device=points.device, dtype=points.dtype)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Lightweight self-test / demo: runs only when this module is executed
# directly (python -m ... or python knn.py); never on import.
if __name__ == "__main__":
    import time

    print("=" * 60)
    print("Testing compute_nn_distances")
    print("=" * 60)

    # Test 1: Simple grid points — on a unit-spaced cube every point's
    # nearest neighbor is an edge neighbor at distance 1.
    print("\n[Test 1] Simple 2x2x2 grid with spacing 1.0")
    grid_points = torch.tensor([
        [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
        [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1],
    ], dtype=torch.float32)
    nn_dist = compute_nn_distances(grid_points, k=1)
    print(f" Points shape: {grid_points.shape}")
    print(f" NN distances: {nn_dist}")
    print(f" Expected: all 1.0 (grid spacing)")
    assert torch.allclose(nn_dist, torch.ones(8)), "Test 1 failed!"
    print(" PASSED")

    # Test 2: Different k values — each cube corner has 3 edge neighbors,
    # all at distance 1, so the second-nearest is also 1.
    print("\n[Test 2] k=2 on grid (second nearest neighbor)")
    nn_dist_k2 = compute_nn_distances(grid_points, k=2)
    print(f" NN distances (k=2): {nn_dist_k2}")
    print(f" Expected: all 1.0 (each point has multiple neighbors at distance 1)")
    assert torch.allclose(nn_dist_k2, torch.ones(8)), "Test 2 failed!"
    print(" PASSED")

    # Test 3: k=4 should give sqrt(2) for corner points (face diagonal)
    print("\n[Test 3] k=4 on grid (fourth nearest neighbor)")
    nn_dist_k4 = compute_nn_distances(grid_points, k=4)
    print(f" NN distances (k=4): {nn_dist_k4}")
    print(f" Expected: all sqrt(2) ~ 1.414 (face diagonal neighbors)")
    # Each corner has 3 edge neighbors at dist 1, then 3 face-diagonal neighbors at sqrt(2)
    assert torch.allclose(nn_dist_k4, torch.full((8,), 2**0.5), atol=1e-5), "Test 3 failed!"
    print(" PASSED")

    # Test 4: Sampling — only the number of returned distances is asserted;
    # the means are printed for manual sanity-checking.
    print("\n[Test 4] Sampling from larger point cloud")
    large_points = torch.rand(1000, 3)
    nn_dist_full = compute_nn_distances(large_points, k=1)
    nn_dist_sampled = compute_nn_distances(large_points, k=1, sample_size=100, seed=42)
    print(f" Full cloud: {len(nn_dist_full)} distances")
    print(f" Sampled: {len(nn_dist_sampled)} distances")
    print(f" Full mean: {nn_dist_full.mean():.4f}, Sampled mean: {nn_dist_sampled.mean():.4f}")
    assert len(nn_dist_sampled) == 100, "Test 4 failed!"
    print(" PASSED")

    # Test 5: Error on insufficient points (n_points <= k must raise ValueError)
    print("\n[Test 5] Error when n_points <= k")
    try:
        compute_nn_distances(torch.rand(2, 3), k=2)
        print(" FAILED - should have raised ValueError")
    except ValueError as e:
        print(f" Caught expected error: {e}")
        print(" PASSED")

    # Test 6: Reproducibility with seed — identical seeds must select the
    # same sample and therefore produce identical distances.
    print("\n[Test 6] Reproducibility with seed")
    dist1 = compute_nn_distances(large_points, k=1, sample_size=50, seed=123)
    dist2 = compute_nn_distances(large_points, k=1, sample_size=50, seed=123)
    assert torch.allclose(dist1, dist2), "Test 6 failed!"
    print(" Same seed gives same results: PASSED")

    # Test 7: Performance test — timing only, no assertions.
    print("\n[Test 7] Performance test")
    for n in [10_000, 100_000, 500_000]:
        points = torch.rand(n, 3)

        start = time.perf_counter()
        _ = compute_nn_distances(points, k=1, sample_size=10000, seed=42)
        elapsed = time.perf_counter() - start
        print(f" {n:>7,} points (sampled 10k): {elapsed:.3f}s")

    print("\n" + "=" * 60)
    print("All tests passed!")
    print("=" * 60)
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""
|
|
2
|
+
XYZ coordinate quantization and dequantization.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from .knn import compute_nn_distances
|
|
11
|
+
from .dense import compute_dense_scale
|
|
12
|
+
from .size import compute_step_size
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class XYZQuantConfig:
    """
    Configuration for XYZ coordinate quantization.

    Produced by compute_quant_config and consumed by quantize_xyz /
    dequantize_xyz.

    Attributes:
        step_size: The quantization step size (delta). NOTE(review): assumed
            non-zero — quantize_xyz divides by it.
        origin: The origin offset for quantization (median of coordinates).
    """
    step_size: float
    origin: torch.Tensor  # shape (3,)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def compute_quant_config(
    points: torch.Tensor,
    k: int = 1,
    sample_size: Optional[int] = 10000,
    seed: Optional[int] = 42,
    quantile: float = 0.05,
    alpha: float = 0.2,
    min_step: Optional[float] = None,
    max_step: Optional[float] = None,
) -> XYZQuantConfig:
    """
    Derive a quantization configuration from a point cloud.

    The step size is obtained by chaining three helpers:
        1. compute_nn_distances() - k-th nearest neighbor distances
        2. compute_dense_scale() - low quantile of those distances
        3. compute_step_size() - alpha * dense_scale, bounded by
           min_step / max_step when those are provided
    The origin is the per-axis median of the points.

    Args:
        points: Point coordinates, shape (N, 3).
        k: Which nearest neighbor to use (1 = nearest).
        sample_size: Number of points to sample for NN estimation.
            Set to None to use all points. Default 10000 balances speed and accuracy.
        seed: Random seed for reproducible sampling.
        quantile: Quantile of NN distances to use (e.g., 0.01 or 0.05).
            Lower values target denser regions.
        alpha: Scaling factor for step size (typical range: 0.1 to 0.25).
            Smaller values preserve more detail but increase data size.
        min_step: Optional minimum step size to prevent extremely fine quantization.
        max_step: Optional maximum step size to prevent overly coarse quantization.

    Returns:
        XYZQuantConfig with step size and origin (median of coordinates).
    """
    center = points.median(dim=0).values
    neighbor_dists = compute_nn_distances(points, k=k, sample_size=sample_size, seed=seed)
    scale = compute_dense_scale(neighbor_dists, quantile=quantile)
    delta = compute_step_size(scale, alpha=alpha, min_step=min_step, max_step=max_step)
    return XYZQuantConfig(step_size=delta, origin=center)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def quantize_xyz(
    points: torch.Tensor,
    config: XYZQuantConfig,
) -> torch.Tensor:
    """
    Map floating-point XYZ coordinates onto the integer quantization grid.

    Computes:
        quantized = round((points - origin) / step_size)
    and casts the result to int32.

    Args:
        points: Point coordinates, shape (N, 3).
        config: Quantization configuration (step size and origin).

    Returns:
        Quantized coordinates as integers, shape (N, 3), dtype torch.int32.
    """
    on_grid = points.sub(config.origin).div(config.step_size)
    return on_grid.round().to(torch.int32)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def dequantize_xyz(
    quantized: torch.Tensor,
    config: XYZQuantConfig,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Map integer grid coordinates back to floating-point XYZ.

    Computes:
        points = quantized * step_size + origin
    with the multiplication performed in the requested dtype.

    Args:
        quantized: Quantized coordinates, shape (N, 3), dtype torch.int32.
        config: Quantization configuration (step size and origin).
        dtype: Output dtype for the dequantized coordinates.

    Returns:
        Dequantized coordinates, shape (N, 3).
    """
    scaled = quantized.to(dtype).mul(config.step_size)
    return scaled.add(config.origin)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def estimate_quantization_error(
    points: torch.Tensor,
    config: XYZQuantConfig,
) -> dict:
    """
    Report round-trip quantization error statistics for a point cloud.

    Quantizes and dequantizes the points with the given configuration and
    summarizes the absolute reconstruction error. Useful for debugging and
    validating quantization parameters.

    Args:
        points: Original point coordinates, shape (N, 3).
        config: Quantization configuration.

    Returns:
        Dictionary with error statistics:
        - 'max_error': Maximum absolute error across all coordinates.
        - 'mean_error': Mean absolute error.
        - 'rmse': Root mean squared error.
        - 'relative_max_error': Max error relative to step size.
    """
    round_trip = dequantize_xyz(quantize_xyz(points, config), config, dtype=points.dtype)
    abs_err = (points - round_trip).abs()
    peak = abs_err.max().item()

    return {
        'max_error': peak,
        'mean_error': abs_err.mean().item(),
        'rmse': (abs_err ** 2).mean().sqrt().item(),
        'relative_max_error': peak / config.step_size,
    }
|