gsply 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gsply/writer.py ADDED
@@ -0,0 +1,1019 @@
1
+ """Writing functions for Gaussian splatting PLY files.
2
+
3
+ This module provides ultra-fast writing of Gaussian splatting PLY files
4
+ in both uncompressed and compressed (PlayCanvas) formats.
5
+
6
+ API Examples:
7
+ >>> from gsply import plywrite
8
+ >>> plywrite("output.ply", means, scales, quats, opacities, sh0, shN)
9
+
10
+ >>> # Or use format-specific writers
11
+ >>> from gsply.writer import write_uncompressed
12
+ >>> write_uncompressed("output.ply", means, scales, quats, opacities, sh0, shN)
13
+
14
+ Performance:
15
+ - Write uncompressed: 5-10ms for 50K Gaussians
16
+ - Write compressed: ~43ms for 50K Gaussians
17
+ """
18
+
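Editorial note: since this module provides both writers, a minimal usage sketch of the unified API follows. The array names, toy sizes, and random data are illustrative only; shapes follow the docstrings in this file, and `compressed=True` causes plywrite to save to a `.compressed.ply` path.

    import numpy as np
    from gsply import plywrite

    # Toy inputs: N Gaussians with degree-0 SH only (shapes per the docstrings below)
    N = 1000
    means = np.random.randn(N, 3).astype(np.float32)
    scales = np.random.randn(N, 3).astype(np.float32)
    quats = np.random.randn(N, 4).astype(np.float32)
    quats /= np.linalg.norm(quats, axis=1, keepdims=True)   # compressed writer expects normalized quaternions
    opacities = np.random.randn(N).astype(np.float32)        # logit-space opacities
    sh0 = np.random.randn(N, 3).astype(np.float32)

    plywrite("scene.ply", means, scales, quats, opacities, sh0)                   # uncompressed
    plywrite("scene.ply", means, scales, quats, opacities, sh0, compressed=True)  # writes scene.compressed.ply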
19
+ import numpy as np
20
+ from pathlib import Path
21
+ from typing import Optional, Union
22
+ import logging
23
+
24
+ # Try to import numba for JIT optimization (optional)
25
+ try:
26
+ from numba import jit
27
+ import numba
28
+ HAS_NUMBA = True
29
+ except ImportError:
30
+ HAS_NUMBA = False
31
+ # Fallback: no-op decorator
32
+ def jit(*args, **kwargs):
33
+ def decorator(func):
34
+ return func
35
+ return decorator
36
+ # Mock numba module for prange fallback
37
+ class _MockNumba:
38
+ @staticmethod
39
+ def prange(n):
40
+ return range(n)
41
+ numba = _MockNumba()
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ # ======================================================================================
47
+ # JIT-COMPILED COMPRESSION FUNCTIONS
48
+ # ======================================================================================
49
+
50
+ @jit(nopython=True, parallel=True, fastmath=True, cache=True)
51
+ def _pack_positions_jit(sorted_means, chunk_indices, min_x, min_y, min_z, max_x, max_y, max_z):
52
+ """JIT-compiled position quantization and packing (11-10-11 bits) with parallel processing.
53
+
54
+ Args:
55
+ sorted_means: (N, 3) float32 array of positions
56
+ chunk_indices: int32 array of chunk indices for each vertex
57
+ min_x, min_y, min_z: chunk minimum bounds
58
+ max_x, max_y, max_z: chunk maximum bounds
59
+
60
+ Returns:
61
+ packed: (N,) uint32 array of packed positions
62
+ """
63
+ n = len(sorted_means)
64
+ packed = np.zeros(n, dtype=np.uint32)
65
+
66
+ for i in numba.prange(n):
67
+ chunk_idx = chunk_indices[i]
68
+
69
+ # Compute ranges (handle zero range)
70
+ range_x = max_x[chunk_idx] - min_x[chunk_idx]
71
+ range_y = max_y[chunk_idx] - min_y[chunk_idx]
72
+ range_z = max_z[chunk_idx] - min_z[chunk_idx]
73
+
74
+ if range_x == 0.0:
75
+ range_x = 1.0
76
+ if range_y == 0.0:
77
+ range_y = 1.0
78
+ if range_z == 0.0:
79
+ range_z = 1.0
80
+
81
+ # Normalize to [0, 1]
82
+ norm_x = (sorted_means[i, 0] - min_x[chunk_idx]) / range_x
83
+ norm_y = (sorted_means[i, 1] - min_y[chunk_idx]) / range_y
84
+ norm_z = (sorted_means[i, 2] - min_z[chunk_idx]) / range_z
85
+
86
+ # Clamp
87
+ norm_x = max(0.0, min(1.0, norm_x))
88
+ norm_y = max(0.0, min(1.0, norm_y))
89
+ norm_z = max(0.0, min(1.0, norm_z))
90
+
91
+ # Quantize
92
+ px = np.uint32(norm_x * 2047.0)
93
+ py = np.uint32(norm_y * 1023.0)
94
+ pz = np.uint32(norm_z * 2047.0)
95
+
96
+ # Pack (11-10-11 bits)
97
+ packed[i] = (px << 21) | (py << 11) | pz
98
+
99
+ return packed
100
+
101
+
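Editorial sketch: the inverse of the 11-10-11 packing above, shown only to document the bit layout (the same layout is used for packed scales below). `_unpack_position_sketch` is a hypothetical helper, not part of gsply; `mins`/`maxs` are the per-chunk bounds used during packing.

    def _unpack_position_sketch(packed: int, mins, maxs):
        """Illustrative inverse of the 11-10-11 position packing (per-chunk bounds)."""
        px = (packed >> 21) & 0x7FF   # upper 11 bits
        py = (packed >> 11) & 0x3FF   # middle 10 bits
        pz = packed & 0x7FF           # lower 11 bits
        x = mins[0] + (px / 2047.0) * (maxs[0] - mins[0])
        y = mins[1] + (py / 1023.0) * (maxs[1] - mins[1])
        z = mins[2] + (pz / 2047.0) * (maxs[2] - mins[2])
        return x, y, z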
102
+ @jit(nopython=True, parallel=True, fastmath=True, cache=True)
103
+ def _pack_scales_jit(sorted_scales, chunk_indices, min_sx, min_sy, min_sz, max_sx, max_sy, max_sz):
104
+ """JIT-compiled scale quantization and packing (11-10-11 bits) with parallel processing.
105
+
106
+ Args:
107
+ sorted_scales: (N, 3) float32 array of scales
108
+ chunk_indices: int32 array of chunk indices for each vertex
109
+ min_sx, min_sy, min_sz: chunk minimum scale bounds
110
+ max_sx, max_sy, max_sz: chunk maximum scale bounds
111
+
112
+ Returns:
113
+ packed: (N,) uint32 array of packed scales
114
+ """
115
+ n = len(sorted_scales)
116
+ packed = np.zeros(n, dtype=np.uint32)
117
+
118
+ for i in numba.prange(n):
119
+ chunk_idx = chunk_indices[i]
120
+
121
+ # Compute ranges (handle zero range)
122
+ range_sx = max_sx[chunk_idx] - min_sx[chunk_idx]
123
+ range_sy = max_sy[chunk_idx] - min_sy[chunk_idx]
124
+ range_sz = max_sz[chunk_idx] - min_sz[chunk_idx]
125
+
126
+ if range_sx == 0.0:
127
+ range_sx = 1.0
128
+ if range_sy == 0.0:
129
+ range_sy = 1.0
130
+ if range_sz == 0.0:
131
+ range_sz = 1.0
132
+
133
+ # Normalize to [0, 1]
134
+ norm_sx = (sorted_scales[i, 0] - min_sx[chunk_idx]) / range_sx
135
+ norm_sy = (sorted_scales[i, 1] - min_sy[chunk_idx]) / range_sy
136
+ norm_sz = (sorted_scales[i, 2] - min_sz[chunk_idx]) / range_sz
137
+
138
+ # Clamp
139
+ norm_sx = max(0.0, min(1.0, norm_sx))
140
+ norm_sy = max(0.0, min(1.0, norm_sy))
141
+ norm_sz = max(0.0, min(1.0, norm_sz))
142
+
143
+ # Quantize
144
+ sx = np.uint32(norm_sx * 2047.0)
145
+ sy = np.uint32(norm_sy * 1023.0)
146
+ sz = np.uint32(norm_sz * 2047.0)
147
+
148
+ # Pack (11-10-11 bits)
149
+ packed[i] = (sx << 21) | (sy << 11) | sz
150
+
151
+ return packed
152
+
153
+
154
+ @jit(nopython=True, parallel=True, fastmath=True, cache=True)
155
+ def _pack_colors_jit(sorted_sh0, sorted_opacities, chunk_indices, min_r, min_g, min_b, max_r, max_g, max_b, sh_c0):
156
+ """JIT-compiled color and opacity quantization and packing (8-8-8-8 bits) with parallel processing.
157
+
158
+ Args:
159
+ sorted_sh0: (N, 3) float32 array of SH0 coefficients
160
+ sorted_opacities: (N,) float32 array of opacities (logit space)
161
+ chunk_indices: int32 array of chunk indices for each vertex
162
+ min_r, min_g, min_b: chunk minimum color bounds
163
+ max_r, max_g, max_b: chunk maximum color bounds
164
+ sh_c0: SH constant for conversion
165
+
166
+ Returns:
167
+ packed: (N,) uint32 array of packed colors
168
+ """
169
+ n = len(sorted_sh0)
170
+ packed = np.zeros(n, dtype=np.uint32)
171
+
172
+ for i in numba.prange(n):
173
+ chunk_idx = chunk_indices[i]
174
+
175
+ # Convert SH0 to RGB
176
+ color_r = sorted_sh0[i, 0] * sh_c0 + 0.5
177
+ color_g = sorted_sh0[i, 1] * sh_c0 + 0.5
178
+ color_b = sorted_sh0[i, 2] * sh_c0 + 0.5
179
+
180
+ # Compute ranges (handle zero range)
181
+ range_r = max_r[chunk_idx] - min_r[chunk_idx]
182
+ range_g = max_g[chunk_idx] - min_g[chunk_idx]
183
+ range_b = max_b[chunk_idx] - min_b[chunk_idx]
184
+
185
+ if range_r == 0.0:
186
+ range_r = 1.0
187
+ if range_g == 0.0:
188
+ range_g = 1.0
189
+ if range_b == 0.0:
190
+ range_b = 1.0
191
+
192
+ # Normalize to [0, 1]
193
+ norm_r = (color_r - min_r[chunk_idx]) / range_r
194
+ norm_g = (color_g - min_g[chunk_idx]) / range_g
195
+ norm_b = (color_b - min_b[chunk_idx]) / range_b
196
+
197
+ # Clamp
198
+ norm_r = max(0.0, min(1.0, norm_r))
199
+ norm_g = max(0.0, min(1.0, norm_g))
200
+ norm_b = max(0.0, min(1.0, norm_b))
201
+
202
+ # Quantize colors
203
+ cr = np.uint32(norm_r * 255.0)
204
+ cg = np.uint32(norm_g * 255.0)
205
+ cb = np.uint32(norm_b * 255.0)
206
+
207
+ # Opacity: logit to linear
208
+ opacity_linear = 1.0 / (1.0 + np.exp(-sorted_opacities[i]))
209
+ opacity_linear = max(0.0, min(1.0, opacity_linear))
210
+ co = np.uint32(opacity_linear * 255.0)
211
+
212
+ # Pack (8-8-8-8 bits)
213
+ packed[i] = (cr << 24) | (cg << 16) | (cb << 8) | co
214
+
215
+ return packed
216
+
217
+
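Editorial sketch: decoding the 8-8-8-8 color/opacity word above. Colors come back relative to the chunk color bounds; opacity is stored in linear [0, 1] space, so the logit-space value is only recoverable up to quantization. `_unpack_color_sketch` is a hypothetical helper, not part of gsply.

    def _unpack_color_sketch(packed: int, min_rgb, max_rgb):
        """Illustrative inverse of the 8-8-8-8 color/opacity packing (per-chunk color bounds)."""
        r = min_rgb[0] + (((packed >> 24) & 0xFF) / 255.0) * (max_rgb[0] - min_rgb[0])
        g = min_rgb[1] + (((packed >> 16) & 0xFF) / 255.0) * (max_rgb[1] - min_rgb[1])
        b = min_rgb[2] + (((packed >> 8) & 0xFF) / 255.0) * (max_rgb[2] - min_rgb[2])
        opacity_linear = (packed & 0xFF) / 255.0
        return r, g, b, opacity_linear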
218
+ @jit(nopython=True, parallel=True, fastmath=True, cache=True)
219
+ def _pack_quaternions_jit(sorted_quats):
220
+ """JIT-compiled quaternion normalization and packing (2+10-10-10 bits, smallest-three) with parallel processing.
221
+
222
+ Args:
223
+ sorted_quats: (N, 4) float32 array of quaternions
224
+
225
+ Returns:
226
+ packed: (N,) uint32 array of packed quaternions
227
+ """
228
+ n = len(sorted_quats)
229
+ packed = np.zeros(n, dtype=np.uint32)
230
+ norm_factor = np.sqrt(2.0) * 0.5
231
+
232
+ for i in numba.prange(n):
233
+ # Normalize quaternion
234
+ quat = sorted_quats[i]
235
+ norm = np.sqrt(quat[0]*quat[0] + quat[1]*quat[1] + quat[2]*quat[2] + quat[3]*quat[3])
236
+ if norm > 0:
237
+ quat = quat / norm
238
+
239
+ # Find largest component by absolute value
240
+ abs_vals = np.abs(quat)
241
+ largest_idx = 0
242
+ largest_val = abs_vals[0]
243
+ for j in range(1, 4):
244
+ if abs_vals[j] > largest_val:
245
+ largest_val = abs_vals[j]
246
+ largest_idx = j
247
+
248
+ # Flip quaternion if largest component is negative
249
+ if quat[largest_idx] < 0:
250
+ quat = -quat
251
+
252
+ # Extract three smaller components
253
+ three_components = np.zeros(3, dtype=np.float32)
254
+ idx = 0
255
+ for j in range(4):
256
+ if j != largest_idx:
257
+ three_components[idx] = quat[j]
258
+ idx += 1
259
+
260
+ # Normalize to [0, 1] for quantization
261
+ qa_norm = three_components[0] * norm_factor + 0.5
262
+ qb_norm = three_components[1] * norm_factor + 0.5
263
+ qc_norm = three_components[2] * norm_factor + 0.5
264
+
265
+ # Clamp
266
+ qa_norm = max(0.0, min(1.0, qa_norm))
267
+ qb_norm = max(0.0, min(1.0, qb_norm))
268
+ qc_norm = max(0.0, min(1.0, qc_norm))
269
+
270
+ # Quantize
271
+ qa_int = np.uint32(qa_norm * 1023.0)
272
+ qb_int = np.uint32(qb_norm * 1023.0)
273
+ qc_int = np.uint32(qc_norm * 1023.0)
274
+
275
+ # Pack (2 bits for which + 10+10+10 bits)
276
+ packed[i] = (np.uint32(largest_idx) << 30) | (qa_int << 20) | (qb_int << 10) | qc_int
277
+
278
+ return packed
279
+
280
+
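Editorial sketch: decoding the smallest-three word above. The two index bits select which component was dropped, and the unit-norm constraint recovers it. `_unpack_quaternion_sketch` is a hypothetical helper, not part of gsply.

    import math

    def _unpack_quaternion_sketch(packed: int):
        """Illustrative inverse of the smallest-three (2+10+10+10 bit) quaternion packing."""
        norm_factor = math.sqrt(2.0) * 0.5
        largest_idx = (packed >> 30) & 0x3
        # Undo the [0, 1] mapping used during packing: stored = q * norm_factor + 0.5
        qa = (((packed >> 20) & 0x3FF) / 1023.0 - 0.5) / norm_factor
        qb = (((packed >> 10) & 0x3FF) / 1023.0 - 0.5) / norm_factor
        qc = ((packed & 0x3FF) / 1023.0 - 0.5) / norm_factor
        quat = [0.0, 0.0, 0.0, 0.0]
        rest = [qa, qb, qc]
        k = 0
        for j in range(4):
            if j != largest_idx:
                quat[j] = rest[k]
                k += 1
        # The dropped (largest) component is recovered from the unit-norm constraint
        quat[largest_idx] = math.sqrt(max(0.0, 1.0 - (qa * qa + qb * qb + qc * qc)))
        return quat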
281
+ @jit(nopython=True, fastmath=True, cache=True)
282
+ def _radix_sort_by_chunks(chunk_indices, num_chunks):
283
+ """Radix sort (counting sort) for chunk indices (4x faster than argsort).
284
+
285
+ Since chunk indices are small integers (0 to num_chunks-1), counting sort
286
+ achieves O(n) complexity vs O(n log n) for comparison-based sorting.
287
+
288
+ Args:
289
+ chunk_indices: (N,) int32 array of chunk indices
290
+ num_chunks: number of unique chunks
291
+
292
+ Returns:
293
+ sort_indices: (N,) int32 array of indices that would sort the data
294
+ """
295
+ n = len(chunk_indices)
296
+
297
+ # Count occurrences of each chunk
298
+ counts = np.zeros(num_chunks, dtype=np.int32)
299
+ for i in range(n):
300
+ counts[chunk_indices[i]] += 1
301
+
302
+ # Compute starting positions for each chunk
303
+ offsets = np.zeros(num_chunks, dtype=np.int32)
304
+ for i in range(1, num_chunks):
305
+ offsets[i] = offsets[i-1] + counts[i-1]
306
+
307
+ # Build sorted index array
308
+ sort_indices = np.empty(n, dtype=np.int32)
309
+ positions = offsets.copy()
310
+ for i in range(n):
311
+ chunk_id = chunk_indices[i]
312
+ sort_indices[positions[chunk_id]] = i
313
+ positions[chunk_id] += 1
314
+
315
+ return sort_indices
316
+
317
+
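Editorial note: because the writer derives chunk indices as arange(N) // CHUNK_SIZE, they are already non-decreasing, and this stable counting sort yields the same permutation as a stable argsort. A quick illustrative sanity check (not part of the package) could look like:

    import numpy as np

    chunk_indices = np.arange(1000, dtype=np.int32) // 256          # 4 chunks for 1000 Gaussians
    expected = np.argsort(chunk_indices, kind="stable")
    assert np.array_equal(_radix_sort_by_chunks(chunk_indices, 4), expected)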
318
+ @jit(nopython=True, parallel=False, fastmath=True, cache=True)
319
+ def _compute_chunk_bounds_jit(sorted_means, sorted_scales, sorted_sh0,
320
+ chunk_starts, chunk_ends, sh_c0):
321
+ """JIT-compiled chunk bounds computation (9x faster than Python loop).
322
+
323
+ Computes min/max bounds for positions, scales, and colors for each chunk.
324
+ This is the main bottleneck in compressed write (~90ms -> ~10ms).
325
+
326
+ Args:
327
+ sorted_means: (N, 3) float32 array of positions
328
+ sorted_scales: (N, 3) float32 array of scales
329
+ sorted_sh0: (N, 3) float32 array of SH0 coefficients
330
+ chunk_starts: (num_chunks,) int array of chunk start indices
331
+ chunk_ends: (num_chunks,) int array of chunk end indices
332
+ sh_c0: SH constant for RGB conversion
333
+
334
+ Returns:
335
+ bounds: (num_chunks, 18) float32 array with layout:
336
+ [0:6] - min_x, min_y, min_z, max_x, max_y, max_z
337
+ [6:12] - min_scale_x/y/z, max_scale_x/y/z (clamped to [-20,20])
338
+ [12:18] - min_r, min_g, min_b, max_r, max_g, max_b
339
+ """
340
+ num_chunks = len(chunk_starts)
341
+ bounds = np.zeros((num_chunks, 18), dtype=np.float32)
342
+
343
+ for chunk_idx in range(num_chunks):
344
+ start = chunk_starts[chunk_idx]
345
+ end = chunk_ends[chunk_idx]
346
+
347
+ if start >= end: # Empty chunk
348
+ continue
349
+
350
+ # Initialize with first element
351
+ bounds[chunk_idx, 0] = sorted_means[start, 0] # min_x
352
+ bounds[chunk_idx, 1] = sorted_means[start, 1] # min_y
353
+ bounds[chunk_idx, 2] = sorted_means[start, 2] # min_z
354
+ bounds[chunk_idx, 3] = sorted_means[start, 0] # max_x
355
+ bounds[chunk_idx, 4] = sorted_means[start, 1] # max_y
356
+ bounds[chunk_idx, 5] = sorted_means[start, 2] # max_z
357
+
358
+ bounds[chunk_idx, 6] = sorted_scales[start, 0] # min_scale_x
359
+ bounds[chunk_idx, 7] = sorted_scales[start, 1] # min_scale_y
360
+ bounds[chunk_idx, 8] = sorted_scales[start, 2] # min_scale_z
361
+ bounds[chunk_idx, 9] = sorted_scales[start, 0] # max_scale_x
362
+ bounds[chunk_idx, 10] = sorted_scales[start, 1] # max_scale_y
363
+ bounds[chunk_idx, 11] = sorted_scales[start, 2] # max_scale_z
364
+
365
+ # Convert SH0 to RGB for first element
366
+ color_r = sorted_sh0[start, 0] * sh_c0 + 0.5
367
+ color_g = sorted_sh0[start, 1] * sh_c0 + 0.5
368
+ color_b = sorted_sh0[start, 2] * sh_c0 + 0.5
369
+
370
+ bounds[chunk_idx, 12] = color_r # min_r
371
+ bounds[chunk_idx, 13] = color_g # min_g
372
+ bounds[chunk_idx, 14] = color_b # min_b
373
+ bounds[chunk_idx, 15] = color_r # max_r
374
+ bounds[chunk_idx, 16] = color_g # max_g
375
+ bounds[chunk_idx, 17] = color_b # max_b
376
+
377
+ # Process remaining elements in chunk
378
+ for i in range(start + 1, end):
379
+ # Position bounds
380
+ bounds[chunk_idx, 0] = min(bounds[chunk_idx, 0], sorted_means[i, 0])
381
+ bounds[chunk_idx, 1] = min(bounds[chunk_idx, 1], sorted_means[i, 1])
382
+ bounds[chunk_idx, 2] = min(bounds[chunk_idx, 2], sorted_means[i, 2])
383
+ bounds[chunk_idx, 3] = max(bounds[chunk_idx, 3], sorted_means[i, 0])
384
+ bounds[chunk_idx, 4] = max(bounds[chunk_idx, 4], sorted_means[i, 1])
385
+ bounds[chunk_idx, 5] = max(bounds[chunk_idx, 5], sorted_means[i, 2])
386
+
387
+ # Scale bounds
388
+ bounds[chunk_idx, 6] = min(bounds[chunk_idx, 6], sorted_scales[i, 0])
389
+ bounds[chunk_idx, 7] = min(bounds[chunk_idx, 7], sorted_scales[i, 1])
390
+ bounds[chunk_idx, 8] = min(bounds[chunk_idx, 8], sorted_scales[i, 2])
391
+ bounds[chunk_idx, 9] = max(bounds[chunk_idx, 9], sorted_scales[i, 0])
392
+ bounds[chunk_idx, 10] = max(bounds[chunk_idx, 10], sorted_scales[i, 1])
393
+ bounds[chunk_idx, 11] = max(bounds[chunk_idx, 11], sorted_scales[i, 2])
394
+
395
+ # Color bounds (convert SH0 to RGB)
396
+ color_r = sorted_sh0[i, 0] * sh_c0 + 0.5
397
+ color_g = sorted_sh0[i, 1] * sh_c0 + 0.5
398
+ color_b = sorted_sh0[i, 2] * sh_c0 + 0.5
399
+
400
+ bounds[chunk_idx, 12] = min(bounds[chunk_idx, 12], color_r)
401
+ bounds[chunk_idx, 13] = min(bounds[chunk_idx, 13], color_g)
402
+ bounds[chunk_idx, 14] = min(bounds[chunk_idx, 14], color_b)
403
+ bounds[chunk_idx, 15] = max(bounds[chunk_idx, 15], color_r)
404
+ bounds[chunk_idx, 16] = max(bounds[chunk_idx, 16], color_g)
405
+ bounds[chunk_idx, 17] = max(bounds[chunk_idx, 17], color_b)
406
+
407
+ # Clamp scale bounds to [-20, 20] (matches splat-transform)
408
+ for j in range(6, 12):
409
+ bounds[chunk_idx, j] = max(-20.0, min(20.0, bounds[chunk_idx, j]))
410
+
411
+ return bounds
412
+
413
+
414
+ # ======================================================================================
415
+ # UNCOMPRESSED PLY WRITER
416
+ # ======================================================================================
417
+
418
+ def write_uncompressed(
419
+ file_path: Union[str, Path],
420
+ means: np.ndarray,
421
+ scales: np.ndarray,
422
+ quats: np.ndarray,
423
+ opacities: np.ndarray,
424
+ sh0: np.ndarray,
425
+ shN: Optional[np.ndarray] = None,
426
+ validate: bool = True,
427
+ ) -> None:
428
+ """Write uncompressed Gaussian splatting PLY file.
429
+
430
+ Uses direct file writing for maximum performance (~5-10ms for 50K Gaussians).
431
+ Automatically determines SH degree from shN shape.
432
+
433
+ Args:
434
+ file_path: Output PLY file path
435
+ means: (N, 3) - xyz positions
436
+ scales: (N, 3) - scale parameters
437
+ quats: (N, 4) - rotation quaternions
438
+ opacities: (N,) - opacity values
439
+ sh0: (N, 3) - DC spherical harmonics
440
+ shN: (N, K, 3) or (N, K*3) - Higher-order SH coefficients (optional)
441
+ validate: If True, validate input shapes (default True). Disable for trusted data.
442
+
443
+ Performance:
444
+ Direct numpy.tofile() achieves ~5-10ms for 50K Gaussians
445
+
446
+ Example:
447
+ >>> write_uncompressed("output.ply", means, scales, quats, opacities, sh0, shN)
448
+ >>> # Or without higher-order SH
449
+ >>> write_uncompressed("output.ply", means, scales, quats, opacities, sh0)
450
+ >>> # Skip validation for trusted data (5-10% faster)
451
+ >>> write_uncompressed("output.ply", means, scales, quats, opacities, sh0, validate=False)
452
+ """
453
+ file_path = Path(file_path)
454
+
455
+ # Validate and normalize inputs
456
+ if not isinstance(means, np.ndarray):
457
+ means = np.asarray(means, dtype=np.float32)
458
+ if not isinstance(scales, np.ndarray):
459
+ scales = np.asarray(scales, dtype=np.float32)
460
+ if not isinstance(quats, np.ndarray):
461
+ quats = np.asarray(quats, dtype=np.float32)
462
+ if not isinstance(opacities, np.ndarray):
463
+ opacities = np.asarray(opacities, dtype=np.float32)
464
+ if not isinstance(sh0, np.ndarray):
465
+ sh0 = np.asarray(sh0, dtype=np.float32)
466
+ if shN is not None and not isinstance(shN, np.ndarray):
467
+ shN = np.asarray(shN, dtype=np.float32)
468
+
469
+ # Only convert dtype if needed (avoids copy when already float32)
470
+ if means.dtype != np.float32:
471
+ means = means.astype(np.float32, copy=False)
472
+ if scales.dtype != np.float32:
473
+ scales = scales.astype(np.float32, copy=False)
474
+ if quats.dtype != np.float32:
475
+ quats = quats.astype(np.float32, copy=False)
476
+ if opacities.dtype != np.float32:
477
+ opacities = opacities.astype(np.float32, copy=False)
478
+ if sh0.dtype != np.float32:
479
+ sh0 = sh0.astype(np.float32, copy=False)
480
+ if shN is not None and shN.dtype != np.float32:
481
+ shN = shN.astype(np.float32, copy=False)
482
+
483
+ num_gaussians = means.shape[0]
484
+
485
+ # Validate shapes (optional for trusted data)
486
+ if validate:
487
+ assert means.shape == (num_gaussians, 3), f"means must be (N, 3), got {means.shape}"
488
+ assert scales.shape == (num_gaussians, 3), f"scales must be (N, 3), got {scales.shape}"
489
+ assert quats.shape == (num_gaussians, 4), f"quats must be (N, 4), got {quats.shape}"
490
+ assert opacities.shape == (num_gaussians,), f"opacities must be (N,), got {opacities.shape}"
491
+ assert sh0.shape == (num_gaussians, 3), f"sh0 must be (N, 3), got {sh0.shape}"
492
+
493
+ # Add a trailing axis so opacities become an (N, 1) column view (no copy)
494
+ opacities = opacities[:, np.newaxis]
495
+
496
+ # Flatten shN if needed (from (N, K, 3) to (N, K*3))
497
+ if shN is not None and shN.ndim == 3:
498
+ N, K, C = shN.shape
499
+ if validate:
500
+ assert C == 3, f"shN must have shape (N, K, 3), got {shN.shape}"
501
+ shN = shN.reshape(N, K * 3)
502
+
503
+ # Build header
504
+ header_lines = [
505
+ "ply",
506
+ "format binary_little_endian 1.0",
507
+ f"element vertex {num_gaussians}",
508
+ "property float x",
509
+ "property float y",
510
+ "property float z",
511
+ ]
512
+
513
+ # Add SH0 properties
514
+ for i in range(3):
515
+ header_lines.append(f"property float f_dc_{i}")
516
+
517
+ # Add SHN properties if present
518
+ if shN is not None:
519
+ num_sh_rest = shN.shape[1]
520
+ for i in range(num_sh_rest):
521
+ header_lines.append(f"property float f_rest_{i}")
522
+
523
+ # Add remaining properties
524
+ header_lines.extend([
525
+ "property float opacity",
526
+ "property float scale_0",
527
+ "property float scale_1",
528
+ "property float scale_2",
529
+ "property float rot_0",
530
+ "property float rot_1",
531
+ "property float rot_2",
532
+ "property float rot_3",
533
+ "end_header",
534
+ ])
535
+
536
+ header = "\n".join(header_lines) + "\n"
537
+ header_bytes = header.encode('ascii')
538
+
539
+ # Preallocate and assign data (optimized approach - 31-35% faster than concatenate)
540
+ if shN is not None:
541
+ sh_coeffs = shN.shape[1] # Number of SH coefficients (already reshaped to N x K*3)
542
+ total_props = 3 + 3 + sh_coeffs + 1 + 3 + 4 # means, sh0, shN, opacity, scales, quats
543
+ data = np.empty((num_gaussians, total_props), dtype='<f4')
544
+ data[:, 0:3] = means
545
+ data[:, 3:6] = sh0
546
+ data[:, 6:6+sh_coeffs] = shN
547
+ data[:, 6+sh_coeffs:7+sh_coeffs] = opacities # opacities is already (N, 1)
548
+ data[:, 7+sh_coeffs:10+sh_coeffs] = scales
549
+ data[:, 10+sh_coeffs:14+sh_coeffs] = quats
550
+ else:
551
+ data = np.empty((num_gaussians, 14), dtype='<f4')
552
+ data[:, 0:3] = means
553
+ data[:, 3:6] = sh0
554
+ data[:, 6:7] = opacities # opacities is already (N, 1)
555
+ data[:, 7:10] = scales
556
+ data[:, 10:14] = quats
557
+
558
+ # Write directly to file
559
+ with open(file_path, 'wb') as f:
560
+ f.write(header_bytes)
561
+ data.tofile(f)
562
+
563
+ logger.debug(f"[Gaussian PLY] Wrote uncompressed: {num_gaussians} Gaussians to {file_path.name}")
564
+
565
+
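Editorial note: the header plus the (N, 14 + K*3) float32 block written above makes the on-disk size easy to predict. A sketch of that sanity check (hypothetical helper, not part of the package), where `sh_coeffs` is the flattened shN width (K * 3), or 0 when shN is omitted:

    def _expected_uncompressed_size(num_gaussians: int, sh_coeffs: int, header_len: int) -> int:
        """Illustrative size estimate: header bytes + N rows of (14 + sh_coeffs) float32 values."""
        return header_len + num_gaussians * (14 + sh_coeffs) * 4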
566
+ # ======================================================================================
567
+ # COMPRESSED PLY WRITER (VECTORIZED)
568
+ # ======================================================================================
569
+
570
+ CHUNK_SIZE = 256
571
+ SH_C0 = 0.28209479177387814
572
+
573
+ def write_compressed(
574
+ file_path: Union[str, Path],
575
+ means: np.ndarray,
576
+ scales: np.ndarray,
577
+ quats: np.ndarray,
578
+ opacities: np.ndarray,
579
+ sh0: np.ndarray,
580
+ shN: Optional[np.ndarray] = None,
581
+ validate: bool = True,
582
+ ) -> None:
583
+ """Write compressed Gaussian splatting PLY file (PlayCanvas format).
584
+
585
+ Compresses data using chunk-based quantization (256 Gaussians per chunk).
586
+ Achieves 3.8-14.5x compression ratio using highly optimized vectorized operations.
587
+
588
+ Args:
589
+ file_path: Output PLY file path
590
+ means: (N, 3) - xyz positions
591
+ scales: (N, 3) - scale parameters
592
+ quats: (N, 4) - rotation quaternions (must be normalized)
593
+ opacities: (N,) - opacity values
594
+ sh0: (N, 3) - DC spherical harmonics
595
+ shN: (N, K, 3) or (N, K*3) - Higher-order SH coefficients (optional)
596
+ validate: If True, validate input shapes (default True)
597
+
598
+ Performance:
599
+ ~43ms for 50K Gaussians (highly optimized vectorized compression)
600
+ 78% faster than initial implementation
601
+
602
+ Format:
603
+ Compressed PLY with chunk-based quantization:
604
+ - 256 Gaussians per chunk
605
+ - Position: 11-10-11 bit quantization
606
+ - Scale: 11-10-11 bit quantization
607
+ - Color: 8-8-8-8 bit quantization
608
+ - Quaternion: smallest-three encoding (2+10+10+10 bits)
609
+ - SH coefficients: 8-bit quantization (optional)
610
+
611
+ Example:
612
+ >>> write_compressed("output.ply", means, scales, quats, opacities, sh0, shN)
613
+ >>> # Output is typically 3.8-14.5x smaller than the uncompressed PLY
614
+ """
615
+ file_path = Path(file_path)
616
+
617
+ # Validate and normalize inputs
618
+ if not isinstance(means, np.ndarray):
619
+ means = np.asarray(means, dtype=np.float32)
620
+ if not isinstance(scales, np.ndarray):
621
+ scales = np.asarray(scales, dtype=np.float32)
622
+ if not isinstance(quats, np.ndarray):
623
+ quats = np.asarray(quats, dtype=np.float32)
624
+ if not isinstance(opacities, np.ndarray):
625
+ opacities = np.asarray(opacities, dtype=np.float32)
626
+ if not isinstance(sh0, np.ndarray):
627
+ sh0 = np.asarray(sh0, dtype=np.float32)
628
+ if shN is not None and not isinstance(shN, np.ndarray):
629
+ shN = np.asarray(shN, dtype=np.float32)
630
+
631
+ # Only convert dtype if needed
632
+ if means.dtype != np.float32:
633
+ means = means.astype(np.float32, copy=False)
634
+ if scales.dtype != np.float32:
635
+ scales = scales.astype(np.float32, copy=False)
636
+ if quats.dtype != np.float32:
637
+ quats = quats.astype(np.float32, copy=False)
638
+ if opacities.dtype != np.float32:
639
+ opacities = opacities.astype(np.float32, copy=False)
640
+ if sh0.dtype != np.float32:
641
+ sh0 = sh0.astype(np.float32, copy=False)
642
+ if shN is not None and shN.dtype != np.float32:
643
+ shN = shN.astype(np.float32, copy=False)
644
+
645
+ num_gaussians = means.shape[0]
646
+
647
+ # Validate shapes (optional)
648
+ if validate:
649
+ assert means.shape == (num_gaussians, 3), f"means must be (N, 3), got {means.shape}"
650
+ assert scales.shape == (num_gaussians, 3), f"scales must be (N, 3), got {scales.shape}"
651
+ assert quats.shape == (num_gaussians, 4), f"quats must be (N, 4), got {quats.shape}"
652
+ assert opacities.shape == (num_gaussians,), f"opacities must be (N,), got {opacities.shape}"
653
+ assert sh0.shape == (num_gaussians, 3), f"sh0 must be (N, 3), got {sh0.shape}"
654
+
655
+ # Flatten shN if needed
656
+ if shN is not None and shN.ndim == 3:
657
+ N, K, C = shN.shape
658
+ if validate:
659
+ assert C == 3, f"shN must have shape (N, K, 3), got {shN.shape}"
660
+ shN = shN.reshape(N, K * 3)
661
+
662
+ # Compute number of chunks
663
+ num_chunks = (num_gaussians + CHUNK_SIZE - 1) // CHUNK_SIZE
664
+
665
+ # Pre-compute chunk indices for all vertices (vectorized)
666
+ chunk_indices = np.arange(num_gaussians, dtype=np.int32) // CHUNK_SIZE
667
+
668
+ # ====================================================================================
669
+ # COMPUTE CHUNK BOUNDS (OPTIMIZED WITH SORTING)
670
+ # ====================================================================================
671
+ #
672
+ # IMPORTANT: The compressed PLY format REQUIRES vertices to be in chunk order.
673
+ # The reader assumes vertices 0-255 are chunk 0, 256-511 are chunk 1, etc.
674
+ # This is not a bug - it's a format specification requirement.
675
+ #
676
+ # Performance: Sorting once (O(n log n)) + binary search (O(k log n)) is much
677
+ # faster than boolean masking per chunk (O(n * k) where k = num_chunks).
678
+ #
679
+ # ====================================================================================
680
+
681
+ # Allocate chunk bounds arrays
682
+ chunk_bounds = np.zeros((num_chunks, 18), dtype=np.float32)
683
+
684
+ # Sort all data by chunk indices using radix sort (O(n) vs O(n log n))
685
+ # Radix sort is 4x faster for small integer keys like chunk indices
686
+ # This is required by the compressed PLY format specification.
687
+ if HAS_NUMBA:
688
+ # JIT-compiled radix sort (21ms -> 5ms)
689
+ sort_idx = _radix_sort_by_chunks(chunk_indices, num_chunks)
690
+ else:
691
+ # Fallback: standard comparison sort
692
+ sort_idx = np.argsort(chunk_indices)
693
+
694
+ sorted_chunk_indices = chunk_indices[sort_idx]
695
+ sorted_means = means[sort_idx]
696
+ sorted_scales = scales[sort_idx]
697
+ sorted_sh0 = sh0[sort_idx]
698
+ sorted_quats = quats[sort_idx]
699
+ sorted_opacities = opacities[sort_idx]
700
+ if shN is not None:
701
+ sorted_shN = shN[sort_idx]
702
+ else:
703
+ sorted_shN = None
704
+
705
+ # Find chunk boundaries using searchsorted (O(num_chunks log n))
706
+ chunk_starts = np.searchsorted(sorted_chunk_indices, np.arange(num_chunks), side='left')
707
+ chunk_ends = np.searchsorted(sorted_chunk_indices, np.arange(num_chunks), side='right')
708
+
709
+ # Compute chunk bounds (JIT-optimized: 9x faster than Python loop)
710
+ if HAS_NUMBA:
711
+ # JIT-compiled bounds computation (90ms -> 10ms)
712
+ chunk_bounds = _compute_chunk_bounds_jit(sorted_means, sorted_scales, sorted_sh0,
713
+ chunk_starts, chunk_ends, SH_C0)
714
+ else:
715
+ # Fallback: Python loop with NumPy operations
716
+ for chunk_idx in range(num_chunks):
717
+ start = chunk_starts[chunk_idx]
718
+ end = chunk_ends[chunk_idx]
719
+
720
+ if start == end: # Empty chunk (shouldn't happen but handle gracefully)
721
+ continue
722
+
723
+ # Slice data for this chunk (O(1) operation)
724
+ chunk_means = sorted_means[start:end]
725
+ chunk_scales = sorted_scales[start:end]
726
+ chunk_color_rgb = sorted_sh0[start:end] * SH_C0 + 0.5
727
+
728
+ # Position bounds
729
+ chunk_bounds[chunk_idx, 0] = np.min(chunk_means[:, 0]) # min_x
730
+ chunk_bounds[chunk_idx, 1] = np.min(chunk_means[:, 1]) # min_y
731
+ chunk_bounds[chunk_idx, 2] = np.min(chunk_means[:, 2]) # min_z
732
+ chunk_bounds[chunk_idx, 3] = np.max(chunk_means[:, 0]) # max_x
733
+ chunk_bounds[chunk_idx, 4] = np.max(chunk_means[:, 1]) # max_y
734
+ chunk_bounds[chunk_idx, 5] = np.max(chunk_means[:, 2]) # max_z
735
+
736
+ # Scale bounds (clamped to [-20, 20] to handle infinity, matches splat-transform)
737
+ chunk_bounds[chunk_idx, 6] = np.clip(np.min(chunk_scales[:, 0]), -20, 20) # min_scale_x
738
+ chunk_bounds[chunk_idx, 7] = np.clip(np.min(chunk_scales[:, 1]), -20, 20) # min_scale_y
739
+ chunk_bounds[chunk_idx, 8] = np.clip(np.min(chunk_scales[:, 2]), -20, 20) # min_scale_z
740
+ chunk_bounds[chunk_idx, 9] = np.clip(np.max(chunk_scales[:, 0]), -20, 20) # max_scale_x
741
+ chunk_bounds[chunk_idx, 10] = np.clip(np.max(chunk_scales[:, 1]), -20, 20) # max_scale_y
742
+ chunk_bounds[chunk_idx, 11] = np.clip(np.max(chunk_scales[:, 2]), -20, 20) # max_scale_z
743
+
744
+ # Color bounds
745
+ chunk_bounds[chunk_idx, 12] = np.min(chunk_color_rgb[:, 0]) # min_r
746
+ chunk_bounds[chunk_idx, 13] = np.min(chunk_color_rgb[:, 1]) # min_g
747
+ chunk_bounds[chunk_idx, 14] = np.min(chunk_color_rgb[:, 2]) # min_b
748
+ chunk_bounds[chunk_idx, 15] = np.max(chunk_color_rgb[:, 0]) # max_r
749
+ chunk_bounds[chunk_idx, 16] = np.max(chunk_color_rgb[:, 1]) # max_g
750
+ chunk_bounds[chunk_idx, 17] = np.max(chunk_color_rgb[:, 2]) # max_b
751
+
752
+ # Extract bounds for vectorized quantization
753
+ min_x, min_y, min_z = chunk_bounds[:, 0], chunk_bounds[:, 1], chunk_bounds[:, 2]
754
+ max_x, max_y, max_z = chunk_bounds[:, 3], chunk_bounds[:, 4], chunk_bounds[:, 5]
755
+ min_scale_x, min_scale_y, min_scale_z = chunk_bounds[:, 6], chunk_bounds[:, 7], chunk_bounds[:, 8]
756
+ max_scale_x, max_scale_y, max_scale_z = chunk_bounds[:, 9], chunk_bounds[:, 10], chunk_bounds[:, 11]
757
+ min_r, min_g, min_b = chunk_bounds[:, 12], chunk_bounds[:, 13], chunk_bounds[:, 14]
758
+ max_r, max_g, max_b = chunk_bounds[:, 15], chunk_bounds[:, 16], chunk_bounds[:, 17]
759
+
760
+ # ====================================================================================
761
+ # QUANTIZATION AND BIT PACKING (JIT-optimized when available)
762
+ # ====================================================================================
763
+
764
+ # Allocate packed vertex data (4 uint32 per vertex)
765
+ packed_data = np.zeros((num_gaussians, 4), dtype=np.uint32)
766
+
767
+ # Use JIT-compiled functions if available (5-6x faster than NumPy for writing)
768
+ # Note: First-time compilation adds ~40s overhead, but subsequent calls are 5-6x faster
769
+ if HAS_NUMBA:
770
+ # JIT-compiled compression (parallel, fastmath)
771
+ packed_data[:, 0] = _pack_positions_jit(sorted_means, sorted_chunk_indices, min_x, min_y, min_z, max_x, max_y, max_z)
772
+ packed_data[:, 2] = _pack_scales_jit(sorted_scales, sorted_chunk_indices, min_scale_x, min_scale_y, min_scale_z, max_scale_x, max_scale_y, max_scale_z)
773
+ packed_data[:, 3] = _pack_colors_jit(sorted_sh0, sorted_opacities, sorted_chunk_indices, min_r, min_g, min_b, max_r, max_g, max_b, SH_C0)
774
+ packed_data[:, 1] = _pack_quaternions_jit(sorted_quats)
775
+ else:
776
+ # Fallback: vectorized NumPy operations (used when Numba is unavailable)
777
+ # --- POSITION QUANTIZATION (11-10-11 bits) ---
778
+ range_x = max_x[sorted_chunk_indices] - min_x[sorted_chunk_indices]
779
+ range_y = max_y[sorted_chunk_indices] - min_y[sorted_chunk_indices]
780
+ range_z = max_z[sorted_chunk_indices] - min_z[sorted_chunk_indices]
781
+
782
+ range_x = np.where(range_x == 0, 1.0, range_x)
783
+ range_y = np.where(range_y == 0, 1.0, range_y)
784
+ range_z = np.where(range_z == 0, 1.0, range_z)
785
+
786
+ norm_x = (sorted_means[:, 0] - min_x[sorted_chunk_indices]) / range_x
787
+ norm_y = (sorted_means[:, 1] - min_y[sorted_chunk_indices]) / range_y
788
+ norm_z = (sorted_means[:, 2] - min_z[sorted_chunk_indices]) / range_z
789
+
790
+ norm_x = np.clip(norm_x, 0.0, 1.0)
791
+ norm_y = np.clip(norm_y, 0.0, 1.0)
792
+ norm_z = np.clip(norm_z, 0.0, 1.0)
793
+
794
+ px = (norm_x * 2047.0).astype(np.uint32)
795
+ py = (norm_y * 1023.0).astype(np.uint32)
796
+ pz = (norm_z * 2047.0).astype(np.uint32)
797
+
798
+ packed_data[:, 0] = (px << 21) | (py << 11) | pz
799
+
800
+ # --- SCALE QUANTIZATION (11-10-11 bits) ---
801
+ range_sx = max_scale_x[sorted_chunk_indices] - min_scale_x[sorted_chunk_indices]
802
+ range_sy = max_scale_y[sorted_chunk_indices] - min_scale_y[sorted_chunk_indices]
803
+ range_sz = max_scale_z[sorted_chunk_indices] - min_scale_z[sorted_chunk_indices]
804
+
805
+ range_sx = np.where(range_sx == 0, 1.0, range_sx)
806
+ range_sy = np.where(range_sy == 0, 1.0, range_sy)
807
+ range_sz = np.where(range_sz == 0, 1.0, range_sz)
808
+
809
+ norm_sx = (sorted_scales[:, 0] - min_scale_x[sorted_chunk_indices]) / range_sx
810
+ norm_sy = (sorted_scales[:, 1] - min_scale_y[sorted_chunk_indices]) / range_sy
811
+ norm_sz = (sorted_scales[:, 2] - min_scale_z[sorted_chunk_indices]) / range_sz
812
+
813
+ norm_sx = np.clip(norm_sx, 0.0, 1.0)
814
+ norm_sy = np.clip(norm_sy, 0.0, 1.0)
815
+ norm_sz = np.clip(norm_sz, 0.0, 1.0)
816
+
817
+ sx = (norm_sx * 2047.0).astype(np.uint32)
818
+ sy = (norm_sy * 1023.0).astype(np.uint32)
819
+ sz = (norm_sz * 2047.0).astype(np.uint32)
820
+
821
+ packed_data[:, 2] = (sx << 21) | (sy << 11) | sz
822
+
823
+ # --- COLOR QUANTIZATION (8-8-8-8 bits) ---
824
+ color_rgb = sorted_sh0 * SH_C0 + 0.5
825
+
826
+ range_r = max_r[sorted_chunk_indices] - min_r[sorted_chunk_indices]
827
+ range_g = max_g[sorted_chunk_indices] - min_g[sorted_chunk_indices]
828
+ range_b = max_b[sorted_chunk_indices] - min_b[sorted_chunk_indices]
829
+
830
+ range_r = np.where(range_r == 0, 1.0, range_r)
831
+ range_g = np.where(range_g == 0, 1.0, range_g)
832
+ range_b = np.where(range_b == 0, 1.0, range_b)
833
+
834
+ norm_r = (color_rgb[:, 0] - min_r[sorted_chunk_indices]) / range_r
835
+ norm_g = (color_rgb[:, 1] - min_g[sorted_chunk_indices]) / range_g
836
+ norm_b = (color_rgb[:, 2] - min_b[sorted_chunk_indices]) / range_b
837
+
838
+ norm_r = np.clip(norm_r, 0.0, 1.0)
839
+ norm_g = np.clip(norm_g, 0.0, 1.0)
840
+ norm_b = np.clip(norm_b, 0.0, 1.0)
841
+
842
+ cr = (norm_r * 255.0).astype(np.uint32)
843
+ cg = (norm_g * 255.0).astype(np.uint32)
844
+ cb = (norm_b * 255.0).astype(np.uint32)
845
+
846
+ opacity_linear = 1.0 / (1.0 + np.exp(-sorted_opacities))
847
+ opacity_linear = np.clip(opacity_linear, 0.0, 1.0)
848
+ co = (opacity_linear * 255.0).astype(np.uint32)
849
+
850
+ packed_data[:, 3] = (cr << 24) | (cg << 16) | (cb << 8) | co
851
+
852
+ # --- QUATERNION QUANTIZATION (smallest-three encoding: 2+10+10+10 bits) ---
853
+ quats_normalized = sorted_quats / np.linalg.norm(sorted_quats, axis=1, keepdims=True)
854
+
855
+ abs_quats = np.abs(quats_normalized)
856
+ largest_idx = np.argmax(abs_quats, axis=1).astype(np.uint32)
857
+
858
+ sign_mask = np.take_along_axis(quats_normalized, largest_idx[:, np.newaxis], axis=1) < 0
859
+ quats_normalized = np.where(sign_mask, -quats_normalized, quats_normalized)
860
+
861
+ mask = np.ones((num_gaussians, 4), dtype=bool)
862
+ mask[np.arange(num_gaussians), largest_idx] = False
863
+ three_components = quats_normalized[mask].reshape(num_gaussians, 3)
864
+
865
+ norm = np.sqrt(2.0) * 0.5
866
+ qa_norm = three_components[:, 0] * norm + 0.5
867
+ qb_norm = three_components[:, 1] * norm + 0.5
868
+ qc_norm = three_components[:, 2] * norm + 0.5
869
+
870
+ qa_norm = np.clip(qa_norm, 0.0, 1.0)
871
+ qb_norm = np.clip(qb_norm, 0.0, 1.0)
872
+ qc_norm = np.clip(qc_norm, 0.0, 1.0)
873
+
874
+ qa_int = (qa_norm * 1023.0).astype(np.uint32)
875
+ qb_int = (qb_norm * 1023.0).astype(np.uint32)
876
+ qc_int = (qc_norm * 1023.0).astype(np.uint32)
877
+
878
+ packed_data[:, 1] = (largest_idx << 30) | (qa_int << 20) | (qb_int << 10) | qc_int
879
+
880
+ # ====================================================================================
881
+ # SH COEFFICIENT COMPRESSION (8-bit quantization)
882
+ # ====================================================================================
883
+
884
+ packed_sh = None
885
+ if sorted_shN is not None and sorted_shN.shape[1] > 0:
886
+ # Normalize to [0, 1] range
887
+ # SH values are typically in range [-4, 4]
888
+ # This matches splat-transform: nvalue = shN / 8 + 0.5 (write-compressed-ply.ts:85)
889
+ sh_normalized = (sorted_shN / 8.0) + 0.5
890
+ sh_normalized = np.clip(sh_normalized, 0.0, 1.0)
891
+
892
+ # Quantize to uint8: trunc(nvalue * 256), clamped to [0, 255]
893
+ # This matches splat-transform: Math.trunc(nvalue * 256) (write-compressed-ply.ts:86)
894
+ packed_sh = np.clip(np.trunc(sh_normalized * 256.0), 0, 255).astype(np.uint8)
895
+
896
+ # ====================================================================================
897
+ # WRITE HEADER AND DATA
898
+ # ====================================================================================
899
+
900
+ # Build header
901
+ header_lines = [
902
+ "ply",
903
+ "format binary_little_endian 1.0",
904
+ f"element chunk {num_chunks}",
905
+ ]
906
+
907
+ # Add chunk properties (18 floats)
908
+ chunk_props = [
909
+ "min_x", "min_y", "min_z",
910
+ "max_x", "max_y", "max_z",
911
+ "min_scale_x", "min_scale_y", "min_scale_z",
912
+ "max_scale_x", "max_scale_y", "max_scale_z",
913
+ "min_r", "min_g", "min_b",
914
+ "max_r", "max_g", "max_b",
915
+ ]
916
+ for prop in chunk_props:
917
+ header_lines.append(f"property float {prop}")
918
+
919
+ # Add vertex element
920
+ header_lines.append(f"element vertex {num_gaussians}")
921
+ header_lines.append("property uint packed_position")
922
+ header_lines.append("property uint packed_rotation")
923
+ header_lines.append("property uint packed_scale")
924
+ header_lines.append("property uint packed_color")
925
+
926
+ # Add SH element if present
927
+ if packed_sh is not None:
928
+ num_sh_coeffs = packed_sh.shape[1]
929
+ header_lines.append(f"element sh {num_gaussians}")
930
+ for i in range(num_sh_coeffs):
931
+ header_lines.append(f"property uchar coeff_{i}")
932
+
933
+ header_lines.append("end_header")
934
+ header = "\n".join(header_lines) + "\n"
935
+ header_bytes = header.encode('ascii')
936
+
937
+ # Write to file
938
+ with open(file_path, 'wb') as f:
939
+ f.write(header_bytes)
940
+ chunk_bounds.tofile(f)
941
+ packed_data.tofile(f)
942
+ if packed_sh is not None:
943
+ packed_sh.tofile(f)
944
+
945
+ logger.debug(f"[Gaussian PLY] Wrote compressed: {num_gaussians} Gaussians to {file_path.name} "
946
+ f"({num_chunks} chunks, {len(header_bytes) + chunk_bounds.nbytes + packed_data.nbytes + (packed_sh.nbytes if packed_sh is not None else 0)} bytes)")
947
+
948
+
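Editorial note: the compressed payload size follows directly from the layout above: 18 float32 bounds per chunk, four packed uint32 words per Gaussian, and one uint8 per SH coefficient when shN is present. A sketch of that estimate (hypothetical helper, not part of the package):

    def _expected_compressed_payload(num_gaussians: int, num_chunks: int, sh_coeffs: int = 0) -> int:
        """Illustrative payload size in bytes, excluding the ASCII header."""
        return num_chunks * 18 * 4 + num_gaussians * 4 * 4 + num_gaussians * sh_coeffs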
949
+ # ======================================================================================
950
+ # UNIFIED WRITING API
951
+ # ======================================================================================
952
+
953
+ def plywrite(
954
+ file_path: Union[str, Path],
955
+ means: np.ndarray,
956
+ scales: np.ndarray,
957
+ quats: np.ndarray,
958
+ opacities: np.ndarray,
959
+ sh0: np.ndarray,
960
+ shN: Optional[np.ndarray] = None,
961
+ compressed: bool = False,
962
+ validate: bool = True,
963
+ ) -> None:
964
+ """Write Gaussian splatting PLY file (auto-select format).
965
+
966
+ Automatically selects format based on compressed parameter or file extension:
967
+ - compressed=False with a plain .ply extension -> uncompressed (fast)
968
+ - compressed=True -> automatically saves as .compressed.ply
969
+ - .compressed.ply or .ply_compressed extension -> compressed format
970
+
971
+ When compressed=True, the output file extension is automatically changed to
972
+ .compressed.ply (e.g., "output.ply" becomes "output.compressed.ply").
973
+
974
+ Args:
975
+ file_path: Output PLY file path (extension auto-adjusted if compressed=True)
976
+ means: (N, 3) - xyz positions
977
+ scales: (N, 3) - scale parameters
978
+ quats: (N, 4) - rotation quaternions
979
+ opacities: (N,) - opacity values
980
+ sh0: (N, 3) - DC spherical harmonics
981
+ shN: (N, K, 3) or (N, K*3) - Higher-order SH coefficients (optional)
982
+ compressed: If True, write compressed format and auto-adjust extension
983
+ validate: If True, validate input shapes (default True). Disable for trusted data.
984
+
985
+ Example:
986
+ >>> # Write uncompressed (fast)
987
+ >>> plywrite("output.ply", means, scales, quats, opacities, sh0, shN)
988
+ >>> # Write compressed (saves as "output.compressed.ply")
989
+ >>> plywrite("output.ply", means, scales, quats, opacities, sh0, shN, compressed=True)
990
+ >>> # Or without higher-order SH
991
+ >>> plywrite("output.ply", means, scales, quats, opacities, sh0)
992
+ >>> # Skip validation for trusted data (5-10% faster)
993
+ >>> plywrite("output.ply", means, scales, quats, opacities, sh0, validate=False)
994
+ """
995
+ file_path = Path(file_path)
996
+
997
+ # Auto-detect compression from extension
998
+ is_compressed_ext = file_path.name.endswith(('.ply_compressed', '.compressed.ply'))
999
+
1000
+ # Check if compressed format requested
1001
+ if compressed or is_compressed_ext:
1002
+ # If compressed=True but no compressed extension, add .compressed.ply
1003
+ if compressed and not is_compressed_ext:
1004
+ # Replace .ply with .compressed.ply, or just append if no .ply
1005
+ if file_path.suffix == '.ply':
1006
+ file_path = file_path.with_suffix('.compressed.ply')
1007
+ else:
1008
+ file_path = Path(str(file_path) + '.compressed.ply')
1009
+
1010
+ write_compressed(file_path, means, scales, quats, opacities, sh0, shN, validate=validate)
1011
+ else:
1012
+ write_uncompressed(file_path, means, scales, quats, opacities, sh0, shN, validate=validate)
1013
+
1014
+
1015
+ __all__ = [
1016
+ 'plywrite',
1017
+ 'write_uncompressed',
1018
+ 'write_compressed',
1019
+ ]