gsvvcompressor 1.2.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. gsvvcompressor/__init__.py +13 -0
  2. gsvvcompressor/__main__.py +243 -0
  3. gsvvcompressor/combinations/__init__.py +84 -0
  4. gsvvcompressor/combinations/registry.py +52 -0
  5. gsvvcompressor/combinations/vq_xyz_1mask.py +89 -0
  6. gsvvcompressor/combinations/vq_xyz_1mask_zstd.py +103 -0
  7. gsvvcompressor/combinations/vq_xyz_draco.py +468 -0
  8. gsvvcompressor/combinations/vq_xyz_draco_2pass.py +156 -0
  9. gsvvcompressor/combinations/vq_xyz_zstd.py +106 -0
  10. gsvvcompressor/compress/__init__.py +5 -0
  11. gsvvcompressor/compress/zstd.py +144 -0
  12. gsvvcompressor/decoder.py +155 -0
  13. gsvvcompressor/deserializer.py +42 -0
  14. gsvvcompressor/draco/__init__.py +34 -0
  15. gsvvcompressor/draco/draco_decoder.exe +0 -0
  16. gsvvcompressor/draco/draco_encoder.exe +0 -0
  17. gsvvcompressor/draco/dracoreduced3dgs.cp310-win_amd64.pyd +0 -0
  18. gsvvcompressor/draco/interface.py +339 -0
  19. gsvvcompressor/draco/serialize.py +235 -0
  20. gsvvcompressor/draco/twopass.py +359 -0
  21. gsvvcompressor/encoder.py +122 -0
  22. gsvvcompressor/interframe/__init__.py +11 -0
  23. gsvvcompressor/interframe/combine.py +271 -0
  24. gsvvcompressor/interframe/decoder.py +99 -0
  25. gsvvcompressor/interframe/encoder.py +92 -0
  26. gsvvcompressor/interframe/interface.py +221 -0
  27. gsvvcompressor/interframe/twopass.py +226 -0
  28. gsvvcompressor/io/__init__.py +31 -0
  29. gsvvcompressor/io/bytes.py +103 -0
  30. gsvvcompressor/io/config.py +78 -0
  31. gsvvcompressor/io/gaussian_model.py +127 -0
  32. gsvvcompressor/movecameras.py +33 -0
  33. gsvvcompressor/payload.py +34 -0
  34. gsvvcompressor/serializer.py +42 -0
  35. gsvvcompressor/vq/__init__.py +15 -0
  36. gsvvcompressor/vq/interface.py +324 -0
  37. gsvvcompressor/vq/singlemask.py +127 -0
  38. gsvvcompressor/vq/twopass.py +1 -0
  39. gsvvcompressor/xyz/__init__.py +26 -0
  40. gsvvcompressor/xyz/dense.py +39 -0
  41. gsvvcompressor/xyz/interface.py +382 -0
  42. gsvvcompressor/xyz/knn.py +141 -0
  43. gsvvcompressor/xyz/quant.py +143 -0
  44. gsvvcompressor/xyz/size.py +44 -0
  45. gsvvcompressor/xyz/twopass.py +1 -0
  46. gsvvcompressor-1.2.0.dist-info/METADATA +690 -0
  47. gsvvcompressor-1.2.0.dist-info/RECORD +50 -0
  48. gsvvcompressor-1.2.0.dist-info/WHEEL +5 -0
  49. gsvvcompressor-1.2.0.dist-info/licenses/LICENSE +21 -0
  50. gsvvcompressor-1.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,382 @@
1
+ """
2
+ XYZ quantization-based inter-frame codec interface.
3
+
4
+ This module provides an inter-frame codec that only operates on the xyz
5
+ coordinates of a GaussianModel using quantization.
6
+ """
7
+
8
import sys
from dataclasses import dataclass
from typing import Optional

if sys.version_info >= (3, 11):
    from typing import Self
else:
    # typing.Self (PEP 673) only exists on Python >= 3.11, but this package is
    # shipped as a CPython 3.10 wheel. In this module `Self` is used only as a
    # return annotation, so aliasing it to Any keeps the module importable.
    from typing import Any as Self

import torch

from gaussian_splatting import GaussianModel

from ..payload import Payload
from ..interframe import InterframeEncoderInitConfig, InterframeCodecContext, InterframeCodecInterface
from .quant import XYZQuantConfig, compute_quant_config, quantize_xyz, dequantize_xyz
19
+
20
@dataclass
class XYZQuantInterframeCodecConfig(InterframeEncoderInitConfig):
    """
    Settings used to initialise the XYZ quantization inter-frame codec.

    The step-size fields are forwarded to ``compute_quant_config`` when a
    keyframe is turned into a codec context. ``tolerance`` is stored in that
    context and read by the encoder when it builds inter-frame deltas.

    Attributes:
        k: Index of the nearest neighbor used for step-size estimation
            (1 means the closest neighbor).
        sample_size: Number of points sampled for the NN estimate
            (None means every point is used).
        seed: Seed for the sampling RNG, for reproducible results.
        quantile: Quantile of the NN distances taken as the dense scale.
        alpha: Multiplier applied to the dense scale to obtain the step size.
        min_step: Optional lower bound for the step size.
        max_step: Optional upper bound for the step size.
        tolerance: Change threshold, in quantized grid units. A point counts as
            changed when any of its coordinates differs from the previous
            frame by strictly more than this value.
    """

    # Step-size estimation parameters (passed to compute_quant_config).
    k: int = 1
    sample_size: Optional[int] = 10000
    seed: Optional[int] = 42
    quantile: float = 0.05
    alpha: float = 0.2
    min_step: Optional[float] = None
    max_step: Optional[float] = None
    # Inter-frame change detection (used by the encoder only).
    tolerance: int = 0
47
+
48
+
49
@dataclass
class XYZQuantInterframeCodecContext(InterframeCodecContext):
    """
    Per-frame state shared by the XYZ quantization encoder and decoder.

    Attributes:
        quant_config: Grid definition (step size and origin) fixed by the
            keyframe and reused by every following frame.
        quantized_xyz: Integer grid coordinates of the frame, shape (N, 3),
            dtype int32.
        tolerance: Change threshold for the encoder's delta computation. The
            decoder never reads it, so decoded contexts keep the default 0.
    """

    quant_config: XYZQuantConfig
    quantized_xyz: torch.Tensor  # (N, 3) int32 grid coordinates
    tolerance: int = 0  # only meaningful on the encoder side
66
+
67
+
68
@dataclass
class XYZQuantKeyframePayload(Payload):
    """
    Keyframe payload of the XYZ quantization codec.

    Holds everything needed to start decoding a sequence: the quantization
    configuration and the quantized coordinates of every point.

    Attributes:
        quant_config: Step size and origin used for the whole sequence.
        quantized_xyz: Quantized coordinates of all points.
    """

    quant_config: XYZQuantConfig
    quantized_xyz: torch.Tensor

    def to(self, device) -> Self:
        """
        Return a copy of this payload whose tensors live on ``device``.

        Args:
            device: Destination device ('cpu', 'cuda', a torch.device, ...).

        Returns:
            A new XYZQuantKeyframePayload. ``step_size`` is copied unchanged;
            ``origin`` and ``quantized_xyz`` are moved to ``device``.
        """
        moved_config = XYZQuantConfig(
            step_size=self.quant_config.step_size,
            origin=self.quant_config.origin.to(device),
        )
        moved_xyz = self.quantized_xyz.to(device)
        return XYZQuantKeyframePayload(
            quant_config=moved_config,
            quantized_xyz=moved_xyz,
        )
99
+
100
+
101
@dataclass
class XYZQuantInterframePayload(Payload):
    """
    Delta payload of the XYZ quantization codec for non-key frames.

    Only points whose quantized coordinates changed are stored. The
    quantization configuration is not repeated; the decoder takes it from the
    previous context.

    Attributes:
        xyz_mask: Boolean tensor of shape (N,); True marks a changed point.
        quantized_xyz: Quantized coordinates of the changed points only,
            shape (M, 3), where M is the number of True entries in the mask.
    """

    xyz_mask: torch.Tensor
    quantized_xyz: torch.Tensor

    def to(self, device) -> Self:
        """
        Return a copy of this payload whose tensors live on ``device``.

        Args:
            device: Destination device ('cpu', 'cuda', a torch.device, ...).

        Returns:
            A new XYZQuantInterframePayload with both tensors on ``device``.
        """
        moved_mask = self.xyz_mask.to(device)
        moved_xyz = self.quantized_xyz.to(device)
        return XYZQuantInterframePayload(xyz_mask=moved_mask, quantized_xyz=moved_xyz)
130
+
131
+
132
class XYZQuantInterframeCodecInterface(InterframeCodecInterface):
    """
    XYZ quantization-based inter-frame encoding/decoding interface.

    The xyz coordinates of a GaussianModel are snapped to a uniform grid. The
    keyframe computes the grid (step size and origin) and transmits every
    quantized point. Each later frame is quantized on the same grid, and only
    the points whose quantized coordinates changed are transmitted.

    Only xyz coordinates are handled; other GaussianModel attributes are not
    modified.
    """

    @staticmethod
    def _apply_changes(
        prev_context: XYZQuantInterframeCodecContext,
        payload: XYZQuantInterframePayload,
    ) -> torch.Tensor:
        """
        Return prev_context's quantized xyz with the payload's changed rows applied.

        The previous tensor is cloned first, so prev_context is never mutated.
        """
        new_quantized_xyz = prev_context.quantized_xyz.clone()
        new_quantized_xyz[payload.xyz_mask] = payload.quantized_xyz
        return new_quantized_xyz

    def decode_interframe(
        self,
        payload: XYZQuantInterframePayload,
        prev_context: XYZQuantInterframeCodecContext,
    ) -> XYZQuantInterframeCodecContext:
        """
        Decode a delta payload to reconstruct the next frame's context.

        Applies the changed xyz values from the payload to the previous context.

        Args:
            payload: The delta payload containing changed quantized xyz with mask.
            prev_context: The context of the previous frame (contains quant_config).

        Returns:
            The reconstructed context for the current frame. Its tolerance is
            left at the default because the decoder never uses it.
        """
        return XYZQuantInterframeCodecContext(
            quant_config=prev_context.quant_config,
            quantized_xyz=self._apply_changes(prev_context, payload),
            # tolerance is encoder-only, decoder doesn't use it
        )

    def encode_interframe(
        self,
        prev_context: XYZQuantInterframeCodecContext,
        next_context: XYZQuantInterframeCodecContext,
    ) -> XYZQuantInterframePayload:
        """
        Encode the difference between two consecutive frames.

        A point is marked as changed when any of its quantized coordinates
        differs from prev_context by strictly more than prev_context.tolerance.
        Only the changed points (from next_context) and the mask are stored.

        Args:
            prev_context: The context of the previous frame.
            next_context: The context of the next frame.

        Returns:
            A payload containing only changed quantized xyz with mask.
        """
        prev_xyz = prev_context.quantized_xyz
        next_xyz = next_context.quantized_xyz
        tolerance = prev_context.tolerance

        # Per-point change flag: True if any of x/y/z moved by more than tolerance.
        diff = (prev_xyz - next_xyz).abs()
        mask = (diff > tolerance).any(dim=-1)
        changed_xyz = next_xyz[mask]

        return XYZQuantInterframePayload(
            xyz_mask=mask,
            quantized_xyz=changed_xyz,
        )

    def decode_keyframe(self, payload: XYZQuantKeyframePayload) -> XYZQuantInterframeCodecContext:
        """
        Decode a keyframe payload to create initial context.

        Args:
            payload: The keyframe payload containing quant_config and quantized xyz.

        Returns:
            The context for the first/key frame (tolerance left at its default).
        """
        return XYZQuantInterframeCodecContext(
            quant_config=payload.quant_config,
            quantized_xyz=payload.quantized_xyz,
            # tolerance is encoder-only, decoder doesn't use it
        )

    def decode_keyframe_for_encode(
        self,
        payload: XYZQuantKeyframePayload,
        context: XYZQuantInterframeCodecContext,
    ) -> XYZQuantInterframeCodecContext:
        """
        Decode a keyframe payload during encoding to avoid error accumulation.

        The keyframe payload stores the context's quantized data unchanged, so
        the round trip is lossless and the original context can be reused
        as-is. Reusing it also preserves context.tolerance, which the encoder
        needs.

        Args:
            payload: The keyframe payload that was just encoded.
            context: The original context used for encoding this keyframe.

        Returns:
            The reconstructed context (the same object as ``context``).
        """
        # Round-trip is lossless, reuse original context (including tolerance)
        return context

    def decode_interframe_for_encode(
        self,
        payload: XYZQuantInterframePayload,
        prev_context: XYZQuantInterframeCodecContext,
    ) -> XYZQuantInterframeCodecContext:
        """
        Decode an interframe payload during encoding to avoid error accumulation.

        Same reconstruction as decode_interframe, but prev_context.tolerance is
        carried over because the encoder needs it for subsequent frames.

        Args:
            payload: The interframe payload that was just encoded.
            prev_context: The previous frame's context (reconstructed version).

        Returns:
            The reconstructed context with tolerance preserved.
        """
        return XYZQuantInterframeCodecContext(
            quant_config=prev_context.quant_config,
            quantized_xyz=self._apply_changes(prev_context, payload),
            tolerance=prev_context.tolerance,  # preserve for encoder
        )

    def encode_keyframe(self, context: XYZQuantInterframeCodecContext) -> XYZQuantKeyframePayload:
        """
        Encode the first frame as a keyframe.

        Args:
            context: The context of the first frame.

        Returns:
            A payload containing the quant_config and quantized xyz. The
            encoder-only tolerance is not included.
        """
        return XYZQuantKeyframePayload(
            quant_config=context.quant_config,
            quantized_xyz=context.quantized_xyz,
        )

    def keyframe_to_context(
        self,
        frame: GaussianModel,
        init_config: XYZQuantInterframeCodecConfig,
    ) -> XYZQuantInterframeCodecContext:
        """
        Convert a keyframe to a XYZQuantInterframeCodecContext.

        Computes the quantization configuration from the frame's xyz
        coordinates and quantizes them.

        Args:
            frame: The GaussianModel frame to convert.
            init_config: Configuration parameters for quantization.

        Returns:
            The corresponding XYZQuantInterframeCodecContext representation.
        """
        # The model's xyz is a trainable nn.Parameter. Detach it so the
        # quantization origin (and thus the context/payload) does not carry
        # autograd history back to the source model.
        xyz = frame.get_xyz.detach()

        quant_config = compute_quant_config(
            points=xyz,
            k=init_config.k,
            sample_size=init_config.sample_size,
            seed=init_config.seed,
            quantile=init_config.quantile,
            alpha=init_config.alpha,
            min_step=init_config.min_step,
            max_step=init_config.max_step,
        )
        quantized_xyz = quantize_xyz(xyz, quant_config)

        return XYZQuantInterframeCodecContext(
            quant_config=quant_config,
            quantized_xyz=quantized_xyz,
            tolerance=init_config.tolerance,
        )

    def interframe_to_context(
        self,
        frame: GaussianModel,
        prev_context: XYZQuantInterframeCodecContext,
    ) -> XYZQuantInterframeCodecContext:
        """
        Convert a frame to a XYZQuantInterframeCodecContext using the previous context's config.

        The keyframe's quantization config (carried in prev_context) is reused,
        so all frames share the same grid.

        Args:
            frame: The GaussianModel frame to convert.
            prev_context: The context from the previous frame.

        Returns:
            The corresponding XYZQuantInterframeCodecContext representation.
        """
        # Detach for the same reason as in keyframe_to_context: no autograd
        # graph is needed to produce integer grid coordinates.
        xyz = frame.get_xyz.detach()
        quantized_xyz = quantize_xyz(xyz, prev_context.quant_config)

        return XYZQuantInterframeCodecContext(
            quant_config=prev_context.quant_config,
            quantized_xyz=quantized_xyz,
            tolerance=prev_context.tolerance,
        )

    def context_to_frame(
        self,
        context: XYZQuantInterframeCodecContext,
        frame: GaussianModel,
    ) -> GaussianModel:
        """
        Convert a XYZQuantInterframeCodecContext back to a GaussianModel frame.

        Dequantizes the xyz coordinates and sets them on the frame, matching
        the dtype and device of the frame's current xyz. Only xyz coordinates
        are modified; other attributes are not touched.

        Args:
            context: The XYZQuantInterframeCodecContext to convert.
            frame: An empty GaussianModel or one from previous pipeline steps.
                This frame will be modified in-place with the xyz data.

        Returns:
            The modified GaussianModel with the xyz data.
        """
        current_xyz = frame.get_xyz
        xyz = dequantize_xyz(
            context.quantized_xyz,
            context.quant_config,
            dtype=current_xyz.dtype,
        )
        frame._xyz = torch.nn.Parameter(xyz.to(current_xyz.device))
        return frame
@@ -0,0 +1,141 @@
1
+ """
2
+ K-nearest neighbor distance computation for 3D point clouds.
3
+ """
4
+
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+ from scipy.spatial import cKDTree
10
+
11
+
12
def compute_nn_distances(
    points: torch.Tensor,
    k: int = 1,
    sample_size: Optional[int] = None,
    seed: Optional[int] = None,
) -> torch.Tensor:
    """
    Compute k-th nearest neighbor distances for points.

    Uses scipy's cKDTree for efficient O(N log N) computation.
    Optionally samples a subset of points to reduce computation for very large
    point clouds. The KD-tree is always built over all points; only the query
    set is subsampled.

    Args:
        points: Point coordinates, shape (N, 3).
        k: Which nearest neighbor distance to compute (1 = nearest, 2 = second
            nearest, etc.). Must be >= 1.
        sample_size: If provided, randomly sample this many points for distance
            estimation. If None or >= N, use all points.
        seed: Random seed for reproducible sampling.

    Returns:
        Tensor of k-th nearest neighbor distances, shape (M,) where
        M = min(sample_size, N), on the same device and with the same dtype as
        ``points``.

    Raises:
        ValueError: If ``k < 1``, or if there are not enough points to compute
            the k-th neighbor (n_points <= k).
    """
    # With k == 0 the query below would ask for a single neighbor, and
    # cKDTree.query squeezes that result to 1-D, so the column lookup would
    # fail with an opaque IndexError. Reject invalid k up front.
    if k < 1:
        raise ValueError(f"k must be >= 1 (1 = nearest neighbor), got {k}.")

    device = points.device
    points_np = points.detach().cpu().numpy().astype(np.float64)
    n_points = points_np.shape[0]

    if n_points <= k:
        raise ValueError(
            f"Not enough points to compute {k}-th nearest neighbor. "
            f"Got {n_points} points, need at least {k + 1}."
        )

    # Optionally subsample the query points for large point clouds
    if sample_size is not None and sample_size < n_points:
        rng = np.random.default_rng(seed)
        indices = rng.choice(n_points, size=sample_size, replace=False)
        query_points = points_np[indices]
    else:
        query_points = points_np

    tree = cKDTree(points_np)
    # Query k+1 neighbors because the first neighbor is the point itself
    distances, _ = tree.query(query_points, k=k + 1)

    # Column k is the k-th neighbor (column 0 is the point itself at distance 0)
    nn_distances = distances[:, k]

    return torch.from_numpy(nn_distances).to(device=device, dtype=points.dtype)
63
+
64
+
65
if __name__ == "__main__":
    import time

    print("=" * 60)
    print("Testing compute_nn_distances")
    print("=" * 60)

    # Seed torch's RNG so the random point clouds (and the printed statistics)
    # are reproducible between runs.
    torch.manual_seed(0)

    # Test 1: Simple grid points
    print("\n[Test 1] Simple 2x2x2 grid with spacing 1.0")
    grid_points = torch.tensor([
        [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
        [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1],
    ], dtype=torch.float32)
    nn_dist = compute_nn_distances(grid_points, k=1)
    print(f" Points shape: {grid_points.shape}")
    print(f" NN distances: {nn_dist}")
    print(" Expected: all 1.0 (grid spacing)")
    assert torch.allclose(nn_dist, torch.ones(8)), "Test 1 failed!"
    print(" PASSED")

    # Test 2: Different k values
    print("\n[Test 2] k=2 on grid (second nearest neighbor)")
    nn_dist_k2 = compute_nn_distances(grid_points, k=2)
    print(f" NN distances (k=2): {nn_dist_k2}")
    print(" Expected: all 1.0 (each point has multiple neighbors at distance 1)")
    assert torch.allclose(nn_dist_k2, torch.ones(8)), "Test 2 failed!"
    print(" PASSED")

    # Test 3: k=4 should give sqrt(2) for corner points (face diagonal)
    print("\n[Test 3] k=4 on grid (fourth nearest neighbor)")
    nn_dist_k4 = compute_nn_distances(grid_points, k=4)
    print(f" NN distances (k=4): {nn_dist_k4}")
    print(" Expected: all sqrt(2) ~ 1.414 (face diagonal neighbors)")
    # Each corner has 3 edge neighbors at dist 1, then 3 face-diagonal neighbors at sqrt(2)
    assert torch.allclose(nn_dist_k4, torch.full((8,), 2**0.5), atol=1e-5), "Test 3 failed!"
    print(" PASSED")

    # Test 4: Sampling
    print("\n[Test 4] Sampling from larger point cloud")
    large_points = torch.rand(1000, 3)
    nn_dist_full = compute_nn_distances(large_points, k=1)
    nn_dist_sampled = compute_nn_distances(large_points, k=1, sample_size=100, seed=42)
    print(f" Full cloud: {len(nn_dist_full)} distances")
    print(f" Sampled: {len(nn_dist_sampled)} distances")
    print(f" Full mean: {nn_dist_full.mean():.4f}, Sampled mean: {nn_dist_sampled.mean():.4f}")
    assert len(nn_dist_sampled) == 100, "Test 4 failed!"
    print(" PASSED")

    # Test 5: Error on insufficient points
    print("\n[Test 5] Error when n_points <= k")
    try:
        compute_nn_distances(torch.rand(2, 3), k=2)
    except ValueError as e:
        print(f" Caught expected error: {e}")
        print(" PASSED")
    else:
        # Fail loudly so the final "All tests passed!" banner is never printed
        # after this check has failed.
        print(" FAILED - should have raised ValueError")
        raise AssertionError("Test 5 failed!")

    # Test 6: Reproducibility with seed
    print("\n[Test 6] Reproducibility with seed")
    dist1 = compute_nn_distances(large_points, k=1, sample_size=50, seed=123)
    dist2 = compute_nn_distances(large_points, k=1, sample_size=50, seed=123)
    assert torch.allclose(dist1, dist2), "Test 6 failed!"
    print(" Same seed gives same results: PASSED")

    # Test 7: Performance test (timing only, no assertions)
    print("\n[Test 7] Performance test")
    for n in [10_000, 100_000, 500_000]:
        points = torch.rand(n, 3)

        start = time.perf_counter()
        _ = compute_nn_distances(points, k=1, sample_size=10000, seed=42)
        elapsed = time.perf_counter() - start
        print(f" {n:>7,} points (sampled 10k): {elapsed:.3f}s")

    print("\n" + "=" * 60)
    print("All tests passed!")
    print("=" * 60)
@@ -0,0 +1,143 @@
1
+ """
2
+ XYZ coordinate quantization and dequantization.
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ import torch
9
+
10
+ from .knn import compute_nn_distances
11
+ from .dense import compute_dense_scale
12
+ from .size import compute_step_size
13
+
14
+
15
@dataclass
class XYZQuantConfig:
    """
    Parameters of the uniform grid used to quantize xyz coordinates.

    A point ``p`` maps to ``round((p - origin) / step_size)``, and a grid index
    ``q`` maps back to ``q * step_size + origin``.

    Attributes:
        step_size: Grid spacing (delta), shared by all three axes.
        origin: Per-axis grid offset, shape (3,). compute_quant_config sets it
            to the coordinate-wise median of the points.
    """

    step_size: float
    origin: torch.Tensor  # (3,) per-axis grid offset
26
+
27
+
28
def compute_quant_config(
    points: torch.Tensor,
    k: int = 1,
    sample_size: Optional[int] = 10000,
    seed: Optional[int] = 42,
    quantile: float = 0.05,
    alpha: float = 0.2,
    min_step: Optional[float] = None,
    max_step: Optional[float] = None,
) -> XYZQuantConfig:
    """
    Compute quantization configuration from points.

    Internally computes step_size via:
    1. compute_nn_distances() - k-th nearest neighbor distances
    2. compute_dense_scale() - low quantile of NN distances
    3. compute_step_size() - alpha * dense_scale

    Args:
        points: Point coordinates, shape (N, 3).
        k: Which nearest neighbor to use (1 = nearest).
        sample_size: Number of points to sample for NN estimation.
            Set to None to use all points. Default 10000 balances speed and accuracy.
        seed: Random seed for reproducible sampling.
        quantile: Quantile of NN distances to use (e.g., 0.01 or 0.05).
            Lower values target denser regions.
        alpha: Scaling factor for step size (typical range: 0.1 to 0.25).
            Smaller values preserve more detail but increase data size.
        min_step: Optional minimum step size to prevent extremely fine quantization.
        max_step: Optional maximum step size to prevent overly coarse quantization.

    Returns:
        XYZQuantConfig with step size and origin (median of coordinates). The
        origin is detached from any autograd graph ``points`` belongs to.

    Raises:
        ValueError: If the resulting step size is zero or not finite. quantize_xyz
            divides by it, so such a step would produce inf/NaN coordinates.
    """
    import math

    nn_distances = compute_nn_distances(points, k=k, sample_size=sample_size, seed=seed)
    dense_scale = compute_dense_scale(nn_distances, quantile=quantile)
    step_size = compute_step_size(dense_scale, alpha=alpha, min_step=min_step, max_step=max_step)

    # A zero step can arise when enough sampled points coincide that the low
    # NN-distance quantile is 0. Quantizing with it would divide by zero and
    # cast inf/NaN to int32, so fail early with an actionable message.
    if step_size == 0 or not math.isfinite(step_size):
        raise ValueError(
            f"Computed quantization step size is {step_size!r}; it must be finite and "
            f"non-zero. Many coincident points can make the dense NN-distance scale 0 "
            f"- consider setting min_step."
        )

    # points may be a trainable parameter (e.g. a GaussianModel's xyz). Detach
    # it so the stored origin carries no autograd history; compute_nn_distances
    # detaches its input for the same reason.
    origin = points.detach().median(dim=0).values
    return XYZQuantConfig(step_size=step_size, origin=origin)
67
+
68
+
69
def quantize_xyz(
    points: torch.Tensor,
    config: XYZQuantConfig,
) -> torch.Tensor:
    """
    Map xyz coordinates onto the integer grid described by ``config``.

    Each coordinate becomes ``round((p - origin) / step_size)``.

    Args:
        points: Coordinates to quantize, shape (N, 3).
        config: Grid definition (step size and origin).

    Returns:
        Integer grid coordinates, shape (N, 3), dtype torch.int32.
    """
    offsets = points - config.origin
    grid_units = offsets / config.step_size
    return grid_units.round().to(torch.int32)
89
+
90
+
91
def dequantize_xyz(
    quantized: torch.Tensor,
    config: XYZQuantConfig,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Dequantize XYZ coordinates back to floating point.

    Applies the formula:
        points = quantized * step_size + origin

    Args:
        quantized: Quantized coordinates, shape (N, 3), dtype torch.int32.
        config: Quantization configuration (step size and origin).
        dtype: Output dtype for the dequantized coordinates.

    Returns:
        Dequantized coordinates, shape (N, 3), with dtype ``dtype`` on the
        device of ``quantized``.
    """
    # Cast the origin to the requested dtype and device first. Otherwise type
    # promotion against a float32 origin silently upcasts a float16/bfloat16
    # result to float32, and a CPU/GPU mismatch between the tensors raises.
    origin = config.origin.to(device=quantized.device, dtype=dtype)
    return quantized.to(dtype) * config.step_size + origin
111
+
112
+
113
def estimate_quantization_error(
    points: torch.Tensor,
    config: XYZQuantConfig,
) -> dict:
    """
    Measure how far a quantize/dequantize round trip moves the points.

    Intended for debugging and for checking that quantization parameters are
    sensible.

    Args:
        points: Original coordinates, shape (N, 3).
        config: Quantization configuration to evaluate.

    Returns:
        Dictionary of Python floats:
        - 'max_error': Largest absolute per-coordinate error.
        - 'mean_error': Mean absolute per-coordinate error.
        - 'rmse': Root mean squared per-coordinate error.
        - 'relative_max_error': 'max_error' divided by the step size.
    """
    round_trip = dequantize_xyz(quantize_xyz(points, config), config, dtype=points.dtype)
    abs_err = (points - round_trip).abs()

    worst = abs_err.max().item()
    return {
        'max_error': worst,
        'mean_error': abs_err.mean().item(),
        'rmse': abs_err.pow(2).mean().sqrt().item(),
        'relative_max_error': worst / config.step_size,
    }
+ }