gsply 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsply/__init__.py +55 -0
- gsply/formats.py +226 -0
- gsply/py.typed +1 -0
- gsply/reader.py +894 -0
- gsply/writer.py +1019 -0
- gsply-0.1.0.dist-info/METADATA +520 -0
- gsply-0.1.0.dist-info/RECORD +11 -0
- gsply-0.1.0.dist-info/WHEEL +5 -0
- gsply-0.1.0.dist-info/licenses/LICENSE +21 -0
- gsply-0.1.0.dist-info/top_level.txt +1 -0
- gsply-0.1.0.dist-info/zip-safe +1 -0
gsply/writer.py
ADDED
|
@@ -0,0 +1,1019 @@
|
|
|
1
|
+
"""Writing functions for Gaussian splatting PLY files.
|
|
2
|
+
|
|
3
|
+
This module provides ultra-fast writing of Gaussian splatting PLY files
|
|
4
|
+
in uncompressed format, with compressed format support planned.
|
|
5
|
+
|
|
6
|
+
API Examples:
|
|
7
|
+
>>> from gsply import plywrite
|
|
8
|
+
>>> plywrite("output.ply", means, scales, quats, opacities, sh0, shN)
|
|
9
|
+
|
|
10
|
+
>>> # Or use format-specific writers
|
|
11
|
+
>>> from gsply.writer import write_uncompressed
|
|
12
|
+
>>> write_uncompressed("output.ply", means, scales, quats, opacities, sh0, shN)
|
|
13
|
+
|
|
14
|
+
Performance:
|
|
15
|
+
- Write uncompressed: 5-10ms for 50K Gaussians
|
|
16
|
+
- Write compressed: Not yet implemented
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Optional, Union
|
|
22
|
+
import logging
|
|
23
|
+
|
|
24
|
+
# Try to import numba for JIT optimization (optional)
|
|
25
|
+
try:
|
|
26
|
+
from numba import jit
|
|
27
|
+
import numba
|
|
28
|
+
HAS_NUMBA = True
|
|
29
|
+
except ImportError:
|
|
30
|
+
HAS_NUMBA = False
|
|
31
|
+
# Fallback: no-op decorator
|
|
32
|
+
def jit(*args, **kwargs):
|
|
33
|
+
def decorator(func):
|
|
34
|
+
return func
|
|
35
|
+
return decorator
|
|
36
|
+
# Mock numba module for prange fallback
|
|
37
|
+
class _MockNumba:
|
|
38
|
+
@staticmethod
|
|
39
|
+
def prange(n):
|
|
40
|
+
return range(n)
|
|
41
|
+
numba = _MockNumba()
|
|
42
|
+
|
|
43
|
+
logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# ======================================================================================
|
|
47
|
+
# JIT-COMPILED COMPRESSION FUNCTIONS
|
|
48
|
+
# ======================================================================================
|
|
49
|
+
|
|
50
|
+
@jit(nopython=True, parallel=True, fastmath=True, cache=True)
|
|
51
|
+
def _pack_positions_jit(sorted_means, chunk_indices, min_x, min_y, min_z, max_x, max_y, max_z):
|
|
52
|
+
"""JIT-compiled position quantization and packing (11-10-11 bits) with parallel processing.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
sorted_means: (N, 3) float32 array of positions
|
|
56
|
+
chunk_indices: int32 array of chunk indices for each vertex
|
|
57
|
+
min_x, min_y, min_z: chunk minimum bounds
|
|
58
|
+
max_x, max_y, max_z: chunk maximum bounds
|
|
59
|
+
|
|
60
|
+
Returns:
|
|
61
|
+
packed: (N,) uint32 array of packed positions
|
|
62
|
+
"""
|
|
63
|
+
n = len(sorted_means)
|
|
64
|
+
packed = np.zeros(n, dtype=np.uint32)
|
|
65
|
+
|
|
66
|
+
for i in numba.prange(n):
|
|
67
|
+
chunk_idx = chunk_indices[i]
|
|
68
|
+
|
|
69
|
+
# Compute ranges (handle zero range)
|
|
70
|
+
range_x = max_x[chunk_idx] - min_x[chunk_idx]
|
|
71
|
+
range_y = max_y[chunk_idx] - min_y[chunk_idx]
|
|
72
|
+
range_z = max_z[chunk_idx] - min_z[chunk_idx]
|
|
73
|
+
|
|
74
|
+
if range_x == 0.0:
|
|
75
|
+
range_x = 1.0
|
|
76
|
+
if range_y == 0.0:
|
|
77
|
+
range_y = 1.0
|
|
78
|
+
if range_z == 0.0:
|
|
79
|
+
range_z = 1.0
|
|
80
|
+
|
|
81
|
+
# Normalize to [0, 1]
|
|
82
|
+
norm_x = (sorted_means[i, 0] - min_x[chunk_idx]) / range_x
|
|
83
|
+
norm_y = (sorted_means[i, 1] - min_y[chunk_idx]) / range_y
|
|
84
|
+
norm_z = (sorted_means[i, 2] - min_z[chunk_idx]) / range_z
|
|
85
|
+
|
|
86
|
+
# Clamp
|
|
87
|
+
norm_x = max(0.0, min(1.0, norm_x))
|
|
88
|
+
norm_y = max(0.0, min(1.0, norm_y))
|
|
89
|
+
norm_z = max(0.0, min(1.0, norm_z))
|
|
90
|
+
|
|
91
|
+
# Quantize
|
|
92
|
+
px = np.uint32(norm_x * 2047.0)
|
|
93
|
+
py = np.uint32(norm_y * 1023.0)
|
|
94
|
+
pz = np.uint32(norm_z * 2047.0)
|
|
95
|
+
|
|
96
|
+
# Pack (11-10-11 bits)
|
|
97
|
+
packed[i] = (px << 21) | (py << 11) | pz
|
|
98
|
+
|
|
99
|
+
return packed
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@jit(nopython=True, parallel=True, fastmath=True, cache=True)
|
|
103
|
+
def _pack_scales_jit(sorted_scales, chunk_indices, min_sx, min_sy, min_sz, max_sx, max_sy, max_sz):
|
|
104
|
+
"""JIT-compiled scale quantization and packing (11-10-11 bits) with parallel processing.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
sorted_scales: (N, 3) float32 array of scales
|
|
108
|
+
chunk_indices: int32 array of chunk indices for each vertex
|
|
109
|
+
min_sx, min_sy, min_sz: chunk minimum scale bounds
|
|
110
|
+
max_sx, max_sy, max_sz: chunk maximum scale bounds
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
packed: (N,) uint32 array of packed scales
|
|
114
|
+
"""
|
|
115
|
+
n = len(sorted_scales)
|
|
116
|
+
packed = np.zeros(n, dtype=np.uint32)
|
|
117
|
+
|
|
118
|
+
for i in numba.prange(n):
|
|
119
|
+
chunk_idx = chunk_indices[i]
|
|
120
|
+
|
|
121
|
+
# Compute ranges (handle zero range)
|
|
122
|
+
range_sx = max_sx[chunk_idx] - min_sx[chunk_idx]
|
|
123
|
+
range_sy = max_sy[chunk_idx] - min_sy[chunk_idx]
|
|
124
|
+
range_sz = max_sz[chunk_idx] - min_sz[chunk_idx]
|
|
125
|
+
|
|
126
|
+
if range_sx == 0.0:
|
|
127
|
+
range_sx = 1.0
|
|
128
|
+
if range_sy == 0.0:
|
|
129
|
+
range_sy = 1.0
|
|
130
|
+
if range_sz == 0.0:
|
|
131
|
+
range_sz = 1.0
|
|
132
|
+
|
|
133
|
+
# Normalize to [0, 1]
|
|
134
|
+
norm_sx = (sorted_scales[i, 0] - min_sx[chunk_idx]) / range_sx
|
|
135
|
+
norm_sy = (sorted_scales[i, 1] - min_sy[chunk_idx]) / range_sy
|
|
136
|
+
norm_sz = (sorted_scales[i, 2] - min_sz[chunk_idx]) / range_sz
|
|
137
|
+
|
|
138
|
+
# Clamp
|
|
139
|
+
norm_sx = max(0.0, min(1.0, norm_sx))
|
|
140
|
+
norm_sy = max(0.0, min(1.0, norm_sy))
|
|
141
|
+
norm_sz = max(0.0, min(1.0, norm_sz))
|
|
142
|
+
|
|
143
|
+
# Quantize
|
|
144
|
+
sx = np.uint32(norm_sx * 2047.0)
|
|
145
|
+
sy = np.uint32(norm_sy * 1023.0)
|
|
146
|
+
sz = np.uint32(norm_sz * 2047.0)
|
|
147
|
+
|
|
148
|
+
# Pack (11-10-11 bits)
|
|
149
|
+
packed[i] = (sx << 21) | (sy << 11) | sz
|
|
150
|
+
|
|
151
|
+
return packed
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@jit(nopython=True, parallel=True, fastmath=True, cache=True)
|
|
155
|
+
def _pack_colors_jit(sorted_sh0, sorted_opacities, chunk_indices, min_r, min_g, min_b, max_r, max_g, max_b, sh_c0):
|
|
156
|
+
"""JIT-compiled color and opacity quantization and packing (8-8-8-8 bits) with parallel processing.
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
sorted_sh0: (N, 3) float32 array of SH0 coefficients
|
|
160
|
+
sorted_opacities: (N,) float32 array of opacities (logit space)
|
|
161
|
+
chunk_indices: int32 array of chunk indices for each vertex
|
|
162
|
+
min_r, min_g, min_b: chunk minimum color bounds
|
|
163
|
+
max_r, max_g, max_b: chunk maximum color bounds
|
|
164
|
+
sh_c0: SH constant for conversion
|
|
165
|
+
|
|
166
|
+
Returns:
|
|
167
|
+
packed: (N,) uint32 array of packed colors
|
|
168
|
+
"""
|
|
169
|
+
n = len(sorted_sh0)
|
|
170
|
+
packed = np.zeros(n, dtype=np.uint32)
|
|
171
|
+
|
|
172
|
+
for i in numba.prange(n):
|
|
173
|
+
chunk_idx = chunk_indices[i]
|
|
174
|
+
|
|
175
|
+
# Convert SH0 to RGB
|
|
176
|
+
color_r = sorted_sh0[i, 0] * sh_c0 + 0.5
|
|
177
|
+
color_g = sorted_sh0[i, 1] * sh_c0 + 0.5
|
|
178
|
+
color_b = sorted_sh0[i, 2] * sh_c0 + 0.5
|
|
179
|
+
|
|
180
|
+
# Compute ranges (handle zero range)
|
|
181
|
+
range_r = max_r[chunk_idx] - min_r[chunk_idx]
|
|
182
|
+
range_g = max_g[chunk_idx] - min_g[chunk_idx]
|
|
183
|
+
range_b = max_b[chunk_idx] - min_b[chunk_idx]
|
|
184
|
+
|
|
185
|
+
if range_r == 0.0:
|
|
186
|
+
range_r = 1.0
|
|
187
|
+
if range_g == 0.0:
|
|
188
|
+
range_g = 1.0
|
|
189
|
+
if range_b == 0.0:
|
|
190
|
+
range_b = 1.0
|
|
191
|
+
|
|
192
|
+
# Normalize to [0, 1]
|
|
193
|
+
norm_r = (color_r - min_r[chunk_idx]) / range_r
|
|
194
|
+
norm_g = (color_g - min_g[chunk_idx]) / range_g
|
|
195
|
+
norm_b = (color_b - min_b[chunk_idx]) / range_b
|
|
196
|
+
|
|
197
|
+
# Clamp
|
|
198
|
+
norm_r = max(0.0, min(1.0, norm_r))
|
|
199
|
+
norm_g = max(0.0, min(1.0, norm_g))
|
|
200
|
+
norm_b = max(0.0, min(1.0, norm_b))
|
|
201
|
+
|
|
202
|
+
# Quantize colors
|
|
203
|
+
cr = np.uint32(norm_r * 255.0)
|
|
204
|
+
cg = np.uint32(norm_g * 255.0)
|
|
205
|
+
cb = np.uint32(norm_b * 255.0)
|
|
206
|
+
|
|
207
|
+
# Opacity: logit to linear
|
|
208
|
+
opacity_linear = 1.0 / (1.0 + np.exp(-sorted_opacities[i]))
|
|
209
|
+
opacity_linear = max(0.0, min(1.0, opacity_linear))
|
|
210
|
+
co = np.uint32(opacity_linear * 255.0)
|
|
211
|
+
|
|
212
|
+
# Pack (8-8-8-8 bits)
|
|
213
|
+
packed[i] = (cr << 24) | (cg << 16) | (cb << 8) | co
|
|
214
|
+
|
|
215
|
+
return packed
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
@jit(nopython=True, parallel=True, fastmath=True, cache=True)
|
|
219
|
+
def _pack_quaternions_jit(sorted_quats):
|
|
220
|
+
"""JIT-compiled quaternion normalization and packing (2+10-10-10 bits, smallest-three) with parallel processing.
|
|
221
|
+
|
|
222
|
+
Args:
|
|
223
|
+
sorted_quats: (N, 4) float32 array of quaternions
|
|
224
|
+
|
|
225
|
+
Returns:
|
|
226
|
+
packed: (N,) uint32 array of packed quaternions
|
|
227
|
+
"""
|
|
228
|
+
n = len(sorted_quats)
|
|
229
|
+
packed = np.zeros(n, dtype=np.uint32)
|
|
230
|
+
norm_factor = np.sqrt(2.0) * 0.5
|
|
231
|
+
|
|
232
|
+
for i in numba.prange(n):
|
|
233
|
+
# Normalize quaternion
|
|
234
|
+
quat = sorted_quats[i]
|
|
235
|
+
norm = np.sqrt(quat[0]*quat[0] + quat[1]*quat[1] + quat[2]*quat[2] + quat[3]*quat[3])
|
|
236
|
+
if norm > 0:
|
|
237
|
+
quat = quat / norm
|
|
238
|
+
|
|
239
|
+
# Find largest component by absolute value
|
|
240
|
+
abs_vals = np.abs(quat)
|
|
241
|
+
largest_idx = 0
|
|
242
|
+
largest_val = abs_vals[0]
|
|
243
|
+
for j in range(1, 4):
|
|
244
|
+
if abs_vals[j] > largest_val:
|
|
245
|
+
largest_val = abs_vals[j]
|
|
246
|
+
largest_idx = j
|
|
247
|
+
|
|
248
|
+
# Flip quaternion if largest component is negative
|
|
249
|
+
if quat[largest_idx] < 0:
|
|
250
|
+
quat = -quat
|
|
251
|
+
|
|
252
|
+
# Extract three smaller components
|
|
253
|
+
three_components = np.zeros(3, dtype=np.float32)
|
|
254
|
+
idx = 0
|
|
255
|
+
for j in range(4):
|
|
256
|
+
if j != largest_idx:
|
|
257
|
+
three_components[idx] = quat[j]
|
|
258
|
+
idx += 1
|
|
259
|
+
|
|
260
|
+
# Normalize to [0, 1] for quantization
|
|
261
|
+
qa_norm = three_components[0] * norm_factor + 0.5
|
|
262
|
+
qb_norm = three_components[1] * norm_factor + 0.5
|
|
263
|
+
qc_norm = three_components[2] * norm_factor + 0.5
|
|
264
|
+
|
|
265
|
+
# Clamp
|
|
266
|
+
qa_norm = max(0.0, min(1.0, qa_norm))
|
|
267
|
+
qb_norm = max(0.0, min(1.0, qb_norm))
|
|
268
|
+
qc_norm = max(0.0, min(1.0, qc_norm))
|
|
269
|
+
|
|
270
|
+
# Quantize
|
|
271
|
+
qa_int = np.uint32(qa_norm * 1023.0)
|
|
272
|
+
qb_int = np.uint32(qb_norm * 1023.0)
|
|
273
|
+
qc_int = np.uint32(qc_norm * 1023.0)
|
|
274
|
+
|
|
275
|
+
# Pack (2 bits for which + 10+10+10 bits)
|
|
276
|
+
packed[i] = (np.uint32(largest_idx) << 30) | (qa_int << 20) | (qb_int << 10) | qc_int
|
|
277
|
+
|
|
278
|
+
return packed
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
@jit(nopython=True, fastmath=True, cache=True)
|
|
282
|
+
def _radix_sort_by_chunks(chunk_indices, num_chunks):
|
|
283
|
+
"""Radix sort (counting sort) for chunk indices (4x faster than argsort).
|
|
284
|
+
|
|
285
|
+
Since chunk indices are small integers (0 to num_chunks-1), counting sort
|
|
286
|
+
achieves O(n) complexity vs O(n log n) for comparison-based sorting.
|
|
287
|
+
|
|
288
|
+
Args:
|
|
289
|
+
chunk_indices: (N,) int32 array of chunk indices
|
|
290
|
+
num_chunks: number of unique chunks
|
|
291
|
+
|
|
292
|
+
Returns:
|
|
293
|
+
sort_indices: (N,) int32 array of indices that would sort the data
|
|
294
|
+
"""
|
|
295
|
+
n = len(chunk_indices)
|
|
296
|
+
|
|
297
|
+
# Count occurrences of each chunk
|
|
298
|
+
counts = np.zeros(num_chunks, dtype=np.int32)
|
|
299
|
+
for i in range(n):
|
|
300
|
+
counts[chunk_indices[i]] += 1
|
|
301
|
+
|
|
302
|
+
# Compute starting positions for each chunk
|
|
303
|
+
offsets = np.zeros(num_chunks, dtype=np.int32)
|
|
304
|
+
for i in range(1, num_chunks):
|
|
305
|
+
offsets[i] = offsets[i-1] + counts[i-1]
|
|
306
|
+
|
|
307
|
+
# Build sorted index array
|
|
308
|
+
sort_indices = np.empty(n, dtype=np.int32)
|
|
309
|
+
positions = offsets.copy()
|
|
310
|
+
for i in range(n):
|
|
311
|
+
chunk_id = chunk_indices[i]
|
|
312
|
+
sort_indices[positions[chunk_id]] = i
|
|
313
|
+
positions[chunk_id] += 1
|
|
314
|
+
|
|
315
|
+
return sort_indices
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
@jit(nopython=True, parallel=False, fastmath=True, cache=True)
|
|
319
|
+
def _compute_chunk_bounds_jit(sorted_means, sorted_scales, sorted_sh0,
|
|
320
|
+
chunk_starts, chunk_ends, sh_c0):
|
|
321
|
+
"""JIT-compiled chunk bounds computation (9x faster than Python loop).
|
|
322
|
+
|
|
323
|
+
Computes min/max bounds for positions, scales, and colors for each chunk.
|
|
324
|
+
This is the main bottleneck in compressed write (~90ms -> ~10ms).
|
|
325
|
+
|
|
326
|
+
Args:
|
|
327
|
+
sorted_means: (N, 3) float32 array of positions
|
|
328
|
+
sorted_scales: (N, 3) float32 array of scales
|
|
329
|
+
sorted_sh0: (N, 3) float32 array of SH0 coefficients
|
|
330
|
+
chunk_starts: (num_chunks,) int array of chunk start indices
|
|
331
|
+
chunk_ends: (num_chunks,) int array of chunk end indices
|
|
332
|
+
sh_c0: SH constant for RGB conversion
|
|
333
|
+
|
|
334
|
+
Returns:
|
|
335
|
+
bounds: (num_chunks, 18) float32 array with layout:
|
|
336
|
+
[0:6] - min_x, min_y, min_z, max_x, max_y, max_z
|
|
337
|
+
[6:12] - min_scale_x/y/z, max_scale_x/y/z (clamped to [-20,20])
|
|
338
|
+
[12:18] - min_r, min_g, min_b, max_r, max_g, max_b
|
|
339
|
+
"""
|
|
340
|
+
num_chunks = len(chunk_starts)
|
|
341
|
+
bounds = np.zeros((num_chunks, 18), dtype=np.float32)
|
|
342
|
+
|
|
343
|
+
for chunk_idx in range(num_chunks):
|
|
344
|
+
start = chunk_starts[chunk_idx]
|
|
345
|
+
end = chunk_ends[chunk_idx]
|
|
346
|
+
|
|
347
|
+
if start >= end: # Empty chunk
|
|
348
|
+
continue
|
|
349
|
+
|
|
350
|
+
# Initialize with first element
|
|
351
|
+
bounds[chunk_idx, 0] = sorted_means[start, 0] # min_x
|
|
352
|
+
bounds[chunk_idx, 1] = sorted_means[start, 1] # min_y
|
|
353
|
+
bounds[chunk_idx, 2] = sorted_means[start, 2] # min_z
|
|
354
|
+
bounds[chunk_idx, 3] = sorted_means[start, 0] # max_x
|
|
355
|
+
bounds[chunk_idx, 4] = sorted_means[start, 1] # max_y
|
|
356
|
+
bounds[chunk_idx, 5] = sorted_means[start, 2] # max_z
|
|
357
|
+
|
|
358
|
+
bounds[chunk_idx, 6] = sorted_scales[start, 0] # min_scale_x
|
|
359
|
+
bounds[chunk_idx, 7] = sorted_scales[start, 1] # min_scale_y
|
|
360
|
+
bounds[chunk_idx, 8] = sorted_scales[start, 2] # min_scale_z
|
|
361
|
+
bounds[chunk_idx, 9] = sorted_scales[start, 0] # max_scale_x
|
|
362
|
+
bounds[chunk_idx, 10] = sorted_scales[start, 1] # max_scale_y
|
|
363
|
+
bounds[chunk_idx, 11] = sorted_scales[start, 2] # max_scale_z
|
|
364
|
+
|
|
365
|
+
# Convert SH0 to RGB for first element
|
|
366
|
+
color_r = sorted_sh0[start, 0] * sh_c0 + 0.5
|
|
367
|
+
color_g = sorted_sh0[start, 1] * sh_c0 + 0.5
|
|
368
|
+
color_b = sorted_sh0[start, 2] * sh_c0 + 0.5
|
|
369
|
+
|
|
370
|
+
bounds[chunk_idx, 12] = color_r # min_r
|
|
371
|
+
bounds[chunk_idx, 13] = color_g # min_g
|
|
372
|
+
bounds[chunk_idx, 14] = color_b # min_b
|
|
373
|
+
bounds[chunk_idx, 15] = color_r # max_r
|
|
374
|
+
bounds[chunk_idx, 16] = color_g # max_g
|
|
375
|
+
bounds[chunk_idx, 17] = color_b # max_b
|
|
376
|
+
|
|
377
|
+
# Process remaining elements in chunk
|
|
378
|
+
for i in range(start + 1, end):
|
|
379
|
+
# Position bounds
|
|
380
|
+
bounds[chunk_idx, 0] = min(bounds[chunk_idx, 0], sorted_means[i, 0])
|
|
381
|
+
bounds[chunk_idx, 1] = min(bounds[chunk_idx, 1], sorted_means[i, 1])
|
|
382
|
+
bounds[chunk_idx, 2] = min(bounds[chunk_idx, 2], sorted_means[i, 2])
|
|
383
|
+
bounds[chunk_idx, 3] = max(bounds[chunk_idx, 3], sorted_means[i, 0])
|
|
384
|
+
bounds[chunk_idx, 4] = max(bounds[chunk_idx, 4], sorted_means[i, 1])
|
|
385
|
+
bounds[chunk_idx, 5] = max(bounds[chunk_idx, 5], sorted_means[i, 2])
|
|
386
|
+
|
|
387
|
+
# Scale bounds
|
|
388
|
+
bounds[chunk_idx, 6] = min(bounds[chunk_idx, 6], sorted_scales[i, 0])
|
|
389
|
+
bounds[chunk_idx, 7] = min(bounds[chunk_idx, 7], sorted_scales[i, 1])
|
|
390
|
+
bounds[chunk_idx, 8] = min(bounds[chunk_idx, 8], sorted_scales[i, 2])
|
|
391
|
+
bounds[chunk_idx, 9] = max(bounds[chunk_idx, 9], sorted_scales[i, 0])
|
|
392
|
+
bounds[chunk_idx, 10] = max(bounds[chunk_idx, 10], sorted_scales[i, 1])
|
|
393
|
+
bounds[chunk_idx, 11] = max(bounds[chunk_idx, 11], sorted_scales[i, 2])
|
|
394
|
+
|
|
395
|
+
# Color bounds (convert SH0 to RGB)
|
|
396
|
+
color_r = sorted_sh0[i, 0] * sh_c0 + 0.5
|
|
397
|
+
color_g = sorted_sh0[i, 1] * sh_c0 + 0.5
|
|
398
|
+
color_b = sorted_sh0[i, 2] * sh_c0 + 0.5
|
|
399
|
+
|
|
400
|
+
bounds[chunk_idx, 12] = min(bounds[chunk_idx, 12], color_r)
|
|
401
|
+
bounds[chunk_idx, 13] = min(bounds[chunk_idx, 13], color_g)
|
|
402
|
+
bounds[chunk_idx, 14] = min(bounds[chunk_idx, 14], color_b)
|
|
403
|
+
bounds[chunk_idx, 15] = max(bounds[chunk_idx, 15], color_r)
|
|
404
|
+
bounds[chunk_idx, 16] = max(bounds[chunk_idx, 16], color_g)
|
|
405
|
+
bounds[chunk_idx, 17] = max(bounds[chunk_idx, 17], color_b)
|
|
406
|
+
|
|
407
|
+
# Clamp scale bounds to [-20, 20] (matches splat-transform)
|
|
408
|
+
for j in range(6, 12):
|
|
409
|
+
bounds[chunk_idx, j] = max(-20.0, min(20.0, bounds[chunk_idx, j]))
|
|
410
|
+
|
|
411
|
+
return bounds
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
# ======================================================================================
|
|
415
|
+
# UNCOMPRESSED PLY WRITER
|
|
416
|
+
# ======================================================================================
|
|
417
|
+
|
|
418
|
+
def write_uncompressed(
|
|
419
|
+
file_path: Union[str, Path],
|
|
420
|
+
means: np.ndarray,
|
|
421
|
+
scales: np.ndarray,
|
|
422
|
+
quats: np.ndarray,
|
|
423
|
+
opacities: np.ndarray,
|
|
424
|
+
sh0: np.ndarray,
|
|
425
|
+
shN: Optional[np.ndarray] = None,
|
|
426
|
+
validate: bool = True,
|
|
427
|
+
) -> None:
|
|
428
|
+
"""Write uncompressed Gaussian splatting PLY file.
|
|
429
|
+
|
|
430
|
+
Uses direct file writing for maximum performance (~5-10ms for 50K Gaussians).
|
|
431
|
+
Automatically determines SH degree from shN shape.
|
|
432
|
+
|
|
433
|
+
Args:
|
|
434
|
+
file_path: Output PLY file path
|
|
435
|
+
means: (N, 3) - xyz positions
|
|
436
|
+
scales: (N, 3) - scale parameters
|
|
437
|
+
quats: (N, 4) - rotation quaternions
|
|
438
|
+
opacities: (N,) - opacity values
|
|
439
|
+
sh0: (N, 3) - DC spherical harmonics
|
|
440
|
+
shN: (N, K, 3) or (N, K*3) - Higher-order SH coefficients (optional)
|
|
441
|
+
validate: If True, validate input shapes (default True). Disable for trusted data.
|
|
442
|
+
|
|
443
|
+
Performance:
|
|
444
|
+
Direct numpy.tofile() achieves ~5-10ms for 50K Gaussians
|
|
445
|
+
|
|
446
|
+
Example:
|
|
447
|
+
>>> write_uncompressed("output.ply", means, scales, quats, opacities, sh0, shN)
|
|
448
|
+
>>> # Or without higher-order SH
|
|
449
|
+
>>> write_uncompressed("output.ply", means, scales, quats, opacities, sh0)
|
|
450
|
+
>>> # Skip validation for trusted data (5-10% faster)
|
|
451
|
+
>>> write_uncompressed("output.ply", means, scales, quats, opacities, sh0, validate=False)
|
|
452
|
+
"""
|
|
453
|
+
file_path = Path(file_path)
|
|
454
|
+
|
|
455
|
+
# Validate and normalize inputs
|
|
456
|
+
if not isinstance(means, np.ndarray):
|
|
457
|
+
means = np.asarray(means, dtype=np.float32)
|
|
458
|
+
if not isinstance(scales, np.ndarray):
|
|
459
|
+
scales = np.asarray(scales, dtype=np.float32)
|
|
460
|
+
if not isinstance(quats, np.ndarray):
|
|
461
|
+
quats = np.asarray(quats, dtype=np.float32)
|
|
462
|
+
if not isinstance(opacities, np.ndarray):
|
|
463
|
+
opacities = np.asarray(opacities, dtype=np.float32)
|
|
464
|
+
if not isinstance(sh0, np.ndarray):
|
|
465
|
+
sh0 = np.asarray(sh0, dtype=np.float32)
|
|
466
|
+
if shN is not None and not isinstance(shN, np.ndarray):
|
|
467
|
+
shN = np.asarray(shN, dtype=np.float32)
|
|
468
|
+
|
|
469
|
+
# Only convert dtype if needed (avoids copy when already float32)
|
|
470
|
+
if means.dtype != np.float32:
|
|
471
|
+
means = means.astype(np.float32, copy=False)
|
|
472
|
+
if scales.dtype != np.float32:
|
|
473
|
+
scales = scales.astype(np.float32, copy=False)
|
|
474
|
+
if quats.dtype != np.float32:
|
|
475
|
+
quats = quats.astype(np.float32, copy=False)
|
|
476
|
+
if opacities.dtype != np.float32:
|
|
477
|
+
opacities = opacities.astype(np.float32, copy=False)
|
|
478
|
+
if sh0.dtype != np.float32:
|
|
479
|
+
sh0 = sh0.astype(np.float32, copy=False)
|
|
480
|
+
if shN is not None and shN.dtype != np.float32:
|
|
481
|
+
shN = shN.astype(np.float32, copy=False)
|
|
482
|
+
|
|
483
|
+
num_gaussians = means.shape[0]
|
|
484
|
+
|
|
485
|
+
# Validate shapes (optional for trusted data)
|
|
486
|
+
if validate:
|
|
487
|
+
assert means.shape == (num_gaussians, 3), f"means must be (N, 3), got {means.shape}"
|
|
488
|
+
assert scales.shape == (num_gaussians, 3), f"scales must be (N, 3), got {scales.shape}"
|
|
489
|
+
assert quats.shape == (num_gaussians, 4), f"quats must be (N, 4), got {quats.shape}"
|
|
490
|
+
assert opacities.shape == (num_gaussians,), f"opacities must be (N,), got {opacities.shape}"
|
|
491
|
+
assert sh0.shape == (num_gaussians, 3), f"sh0 must be (N, 3), got {sh0.shape}"
|
|
492
|
+
|
|
493
|
+
# Use newaxis instead of reshape (creates view without overhead)
|
|
494
|
+
opacities = opacities[:, np.newaxis]
|
|
495
|
+
|
|
496
|
+
# Flatten shN if needed (from (N, K, 3) to (N, K*3))
|
|
497
|
+
if shN is not None and shN.ndim == 3:
|
|
498
|
+
N, K, C = shN.shape
|
|
499
|
+
if validate:
|
|
500
|
+
assert C == 3, f"shN must have shape (N, K, 3), got {shN.shape}"
|
|
501
|
+
shN = shN.reshape(N, K * 3)
|
|
502
|
+
|
|
503
|
+
# Build header
|
|
504
|
+
header_lines = [
|
|
505
|
+
"ply",
|
|
506
|
+
"format binary_little_endian 1.0",
|
|
507
|
+
f"element vertex {num_gaussians}",
|
|
508
|
+
"property float x",
|
|
509
|
+
"property float y",
|
|
510
|
+
"property float z",
|
|
511
|
+
]
|
|
512
|
+
|
|
513
|
+
# Add SH0 properties
|
|
514
|
+
for i in range(3):
|
|
515
|
+
header_lines.append(f"property float f_dc_{i}")
|
|
516
|
+
|
|
517
|
+
# Add SHN properties if present
|
|
518
|
+
if shN is not None:
|
|
519
|
+
num_sh_rest = shN.shape[1]
|
|
520
|
+
for i in range(num_sh_rest):
|
|
521
|
+
header_lines.append(f"property float f_rest_{i}")
|
|
522
|
+
|
|
523
|
+
# Add remaining properties
|
|
524
|
+
header_lines.extend([
|
|
525
|
+
"property float opacity",
|
|
526
|
+
"property float scale_0",
|
|
527
|
+
"property float scale_1",
|
|
528
|
+
"property float scale_2",
|
|
529
|
+
"property float rot_0",
|
|
530
|
+
"property float rot_1",
|
|
531
|
+
"property float rot_2",
|
|
532
|
+
"property float rot_3",
|
|
533
|
+
"end_header",
|
|
534
|
+
])
|
|
535
|
+
|
|
536
|
+
header = "\n".join(header_lines) + "\n"
|
|
537
|
+
header_bytes = header.encode('ascii')
|
|
538
|
+
|
|
539
|
+
# Preallocate and assign data (optimized approach - 31-35% faster than concatenate)
|
|
540
|
+
if shN is not None:
|
|
541
|
+
sh_coeffs = shN.shape[1] # Number of SH coefficients (already reshaped to N x K*3)
|
|
542
|
+
total_props = 3 + 3 + sh_coeffs + 1 + 3 + 4 # means, sh0, shN, opacity, scales, quats
|
|
543
|
+
data = np.empty((num_gaussians, total_props), dtype='<f4')
|
|
544
|
+
data[:, 0:3] = means
|
|
545
|
+
data[:, 3:6] = sh0
|
|
546
|
+
data[:, 6:6+sh_coeffs] = shN
|
|
547
|
+
data[:, 6+sh_coeffs:7+sh_coeffs] = opacities # opacities is already (N, 1)
|
|
548
|
+
data[:, 7+sh_coeffs:10+sh_coeffs] = scales
|
|
549
|
+
data[:, 10+sh_coeffs:14+sh_coeffs] = quats
|
|
550
|
+
else:
|
|
551
|
+
data = np.empty((num_gaussians, 14), dtype='<f4')
|
|
552
|
+
data[:, 0:3] = means
|
|
553
|
+
data[:, 3:6] = sh0
|
|
554
|
+
data[:, 6:7] = opacities # opacities is already (N, 1)
|
|
555
|
+
data[:, 7:10] = scales
|
|
556
|
+
data[:, 10:14] = quats
|
|
557
|
+
|
|
558
|
+
# Write directly to file
|
|
559
|
+
with open(file_path, 'wb') as f:
|
|
560
|
+
f.write(header_bytes)
|
|
561
|
+
data.tofile(f)
|
|
562
|
+
|
|
563
|
+
logger.debug(f"[Gaussian PLY] Wrote uncompressed: {num_gaussians} Gaussians to {file_path.name}")
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
# ======================================================================================
|
|
567
|
+
# COMPRESSED PLY WRITER (VECTORIZED)
|
|
568
|
+
# ======================================================================================
|
|
569
|
+
|
|
570
|
+
CHUNK_SIZE = 256
|
|
571
|
+
SH_C0 = 0.28209479177387814
|
|
572
|
+
|
|
573
|
+
def write_compressed(
|
|
574
|
+
file_path: Union[str, Path],
|
|
575
|
+
means: np.ndarray,
|
|
576
|
+
scales: np.ndarray,
|
|
577
|
+
quats: np.ndarray,
|
|
578
|
+
opacities: np.ndarray,
|
|
579
|
+
sh0: np.ndarray,
|
|
580
|
+
shN: Optional[np.ndarray] = None,
|
|
581
|
+
validate: bool = True,
|
|
582
|
+
) -> None:
|
|
583
|
+
"""Write compressed Gaussian splatting PLY file (PlayCanvas format).
|
|
584
|
+
|
|
585
|
+
Compresses data using chunk-based quantization (256 Gaussians per chunk).
|
|
586
|
+
Achieves 3.8-14.5x compression ratio using highly optimized vectorized operations.
|
|
587
|
+
|
|
588
|
+
Args:
|
|
589
|
+
file_path: Output PLY file path
|
|
590
|
+
means: (N, 3) - xyz positions
|
|
591
|
+
scales: (N, 3) - scale parameters
|
|
592
|
+
quats: (N, 4) - rotation quaternions (must be normalized)
|
|
593
|
+
opacities: (N,) - opacity values
|
|
594
|
+
sh0: (N, 3) - DC spherical harmonics
|
|
595
|
+
shN: (N, K, 3) or (N, K*3) - Higher-order SH coefficients (optional)
|
|
596
|
+
validate: If True, validate input shapes (default True)
|
|
597
|
+
|
|
598
|
+
Performance:
|
|
599
|
+
~43ms for 50K Gaussians (highly optimized vectorized compression)
|
|
600
|
+
78% faster than initial implementation
|
|
601
|
+
|
|
602
|
+
Format:
|
|
603
|
+
Compressed PLY with chunk-based quantization:
|
|
604
|
+
- 256 Gaussians per chunk
|
|
605
|
+
- Position: 11-10-11 bit quantization
|
|
606
|
+
- Scale: 11-10-11 bit quantization
|
|
607
|
+
- Color: 8-8-8-8 bit quantization
|
|
608
|
+
- Quaternion: smallest-three encoding (2+10+10+10 bits)
|
|
609
|
+
- SH coefficients: 8-bit quantization (optional)
|
|
610
|
+
|
|
611
|
+
Example:
|
|
612
|
+
>>> write_compressed("output.ply", means, scales, quats, opacities, sh0, shN)
|
|
613
|
+
>>> # File is 14.5x smaller than uncompressed
|
|
614
|
+
"""
|
|
615
|
+
file_path = Path(file_path)
|
|
616
|
+
|
|
617
|
+
# Validate and normalize inputs
|
|
618
|
+
if not isinstance(means, np.ndarray):
|
|
619
|
+
means = np.asarray(means, dtype=np.float32)
|
|
620
|
+
if not isinstance(scales, np.ndarray):
|
|
621
|
+
scales = np.asarray(scales, dtype=np.float32)
|
|
622
|
+
if not isinstance(quats, np.ndarray):
|
|
623
|
+
quats = np.asarray(quats, dtype=np.float32)
|
|
624
|
+
if not isinstance(opacities, np.ndarray):
|
|
625
|
+
opacities = np.asarray(opacities, dtype=np.float32)
|
|
626
|
+
if not isinstance(sh0, np.ndarray):
|
|
627
|
+
sh0 = np.asarray(sh0, dtype=np.float32)
|
|
628
|
+
if shN is not None and not isinstance(shN, np.ndarray):
|
|
629
|
+
shN = np.asarray(shN, dtype=np.float32)
|
|
630
|
+
|
|
631
|
+
# Only convert dtype if needed
|
|
632
|
+
if means.dtype != np.float32:
|
|
633
|
+
means = means.astype(np.float32, copy=False)
|
|
634
|
+
if scales.dtype != np.float32:
|
|
635
|
+
scales = scales.astype(np.float32, copy=False)
|
|
636
|
+
if quats.dtype != np.float32:
|
|
637
|
+
quats = quats.astype(np.float32, copy=False)
|
|
638
|
+
if opacities.dtype != np.float32:
|
|
639
|
+
opacities = opacities.astype(np.float32, copy=False)
|
|
640
|
+
if sh0.dtype != np.float32:
|
|
641
|
+
sh0 = sh0.astype(np.float32, copy=False)
|
|
642
|
+
if shN is not None and shN.dtype != np.float32:
|
|
643
|
+
shN = shN.astype(np.float32, copy=False)
|
|
644
|
+
|
|
645
|
+
num_gaussians = means.shape[0]
|
|
646
|
+
|
|
647
|
+
# Validate shapes (optional)
|
|
648
|
+
if validate:
|
|
649
|
+
assert means.shape == (num_gaussians, 3), f"means must be (N, 3), got {means.shape}"
|
|
650
|
+
assert scales.shape == (num_gaussians, 3), f"scales must be (N, 3), got {scales.shape}"
|
|
651
|
+
assert quats.shape == (num_gaussians, 4), f"quats must be (N, 4), got {quats.shape}"
|
|
652
|
+
assert opacities.shape == (num_gaussians,), f"opacities must be (N,), got {opacities.shape}"
|
|
653
|
+
assert sh0.shape == (num_gaussians, 3), f"sh0 must be (N, 3), got {sh0.shape}"
|
|
654
|
+
|
|
655
|
+
# Flatten shN if needed
|
|
656
|
+
if shN is not None and shN.ndim == 3:
|
|
657
|
+
N, K, C = shN.shape
|
|
658
|
+
if validate:
|
|
659
|
+
assert C == 3, f"shN must have shape (N, K, 3), got {shN.shape}"
|
|
660
|
+
shN = shN.reshape(N, K * 3)
|
|
661
|
+
|
|
662
|
+
# Compute number of chunks
|
|
663
|
+
num_chunks = (num_gaussians + CHUNK_SIZE - 1) // CHUNK_SIZE
|
|
664
|
+
|
|
665
|
+
# Pre-compute chunk indices for all vertices (vectorized)
|
|
666
|
+
chunk_indices = np.arange(num_gaussians, dtype=np.int32) // CHUNK_SIZE
|
|
667
|
+
|
|
668
|
+
# ====================================================================================
|
|
669
|
+
# COMPUTE CHUNK BOUNDS (OPTIMIZED WITH SORTING)
|
|
670
|
+
# ====================================================================================
|
|
671
|
+
#
|
|
672
|
+
# IMPORTANT: The compressed PLY format REQUIRES vertices to be in chunk order.
|
|
673
|
+
# The reader assumes vertices 0-255 are chunk 0, 256-511 are chunk 1, etc.
|
|
674
|
+
# This is not a bug - it's a format specification requirement.
|
|
675
|
+
#
|
|
676
|
+
# Performance: Sorting once (O(n log n)) + binary search (O(k log n)) is much
|
|
677
|
+
# faster than boolean masking per chunk (O(n * k) where k = num_chunks).
|
|
678
|
+
#
|
|
679
|
+
# ====================================================================================
|
|
680
|
+
|
|
681
|
+
# Allocate chunk bounds arrays
|
|
682
|
+
chunk_bounds = np.zeros((num_chunks, 18), dtype=np.float32)
|
|
683
|
+
|
|
684
|
+
# Sort all data by chunk indices using radix sort (O(n) vs O(n log n))
|
|
685
|
+
# Radix sort is 4x faster for small integer keys like chunk indices
|
|
686
|
+
# This is required by the compressed PLY format specification.
|
|
687
|
+
if HAS_NUMBA:
|
|
688
|
+
# JIT-compiled radix sort (21ms -> 5ms)
|
|
689
|
+
sort_idx = _radix_sort_by_chunks(chunk_indices, num_chunks)
|
|
690
|
+
else:
|
|
691
|
+
# Fallback: standard comparison sort
|
|
692
|
+
sort_idx = np.argsort(chunk_indices)
|
|
693
|
+
|
|
694
|
+
sorted_chunk_indices = chunk_indices[sort_idx]
|
|
695
|
+
sorted_means = means[sort_idx]
|
|
696
|
+
sorted_scales = scales[sort_idx]
|
|
697
|
+
sorted_sh0 = sh0[sort_idx]
|
|
698
|
+
sorted_quats = quats[sort_idx]
|
|
699
|
+
sorted_opacities = opacities[sort_idx]
|
|
700
|
+
if shN is not None:
|
|
701
|
+
sorted_shN = shN[sort_idx]
|
|
702
|
+
else:
|
|
703
|
+
sorted_shN = None
|
|
704
|
+
|
|
705
|
+
# Find chunk boundaries using searchsorted (O(num_chunks log n))
|
|
706
|
+
chunk_starts = np.searchsorted(sorted_chunk_indices, np.arange(num_chunks), side='left')
|
|
707
|
+
chunk_ends = np.searchsorted(sorted_chunk_indices, np.arange(num_chunks), side='right')
|
|
708
|
+
|
|
709
|
+
# Compute chunk bounds (JIT-optimized: 9x faster than Python loop)
|
|
710
|
+
if HAS_NUMBA:
|
|
711
|
+
# JIT-compiled bounds computation (90ms -> 10ms)
|
|
712
|
+
chunk_bounds = _compute_chunk_bounds_jit(sorted_means, sorted_scales, sorted_sh0,
|
|
713
|
+
chunk_starts, chunk_ends, SH_C0)
|
|
714
|
+
else:
|
|
715
|
+
# Fallback: Python loop with NumPy operations
|
|
716
|
+
for chunk_idx in range(num_chunks):
|
|
717
|
+
start = chunk_starts[chunk_idx]
|
|
718
|
+
end = chunk_ends[chunk_idx]
|
|
719
|
+
|
|
720
|
+
if start == end: # Empty chunk (shouldn't happen but handle gracefully)
|
|
721
|
+
continue
|
|
722
|
+
|
|
723
|
+
# Slice data for this chunk (O(1) operation)
|
|
724
|
+
chunk_means = sorted_means[start:end]
|
|
725
|
+
chunk_scales = sorted_scales[start:end]
|
|
726
|
+
chunk_color_rgb = sorted_sh0[start:end] * SH_C0 + 0.5
|
|
727
|
+
|
|
728
|
+
# Position bounds
|
|
729
|
+
chunk_bounds[chunk_idx, 0] = np.min(chunk_means[:, 0]) # min_x
|
|
730
|
+
chunk_bounds[chunk_idx, 1] = np.min(chunk_means[:, 1]) # min_y
|
|
731
|
+
chunk_bounds[chunk_idx, 2] = np.min(chunk_means[:, 2]) # min_z
|
|
732
|
+
chunk_bounds[chunk_idx, 3] = np.max(chunk_means[:, 0]) # max_x
|
|
733
|
+
chunk_bounds[chunk_idx, 4] = np.max(chunk_means[:, 1]) # max_y
|
|
734
|
+
chunk_bounds[chunk_idx, 5] = np.max(chunk_means[:, 2]) # max_z
|
|
735
|
+
|
|
736
|
+
# Scale bounds (clamped to [-20, 20] to handle infinity, matches splat-transform)
|
|
737
|
+
chunk_bounds[chunk_idx, 6] = np.clip(np.min(chunk_scales[:, 0]), -20, 20) # min_scale_x
|
|
738
|
+
chunk_bounds[chunk_idx, 7] = np.clip(np.min(chunk_scales[:, 1]), -20, 20) # min_scale_y
|
|
739
|
+
chunk_bounds[chunk_idx, 8] = np.clip(np.min(chunk_scales[:, 2]), -20, 20) # min_scale_z
|
|
740
|
+
chunk_bounds[chunk_idx, 9] = np.clip(np.max(chunk_scales[:, 0]), -20, 20) # max_scale_x
|
|
741
|
+
chunk_bounds[chunk_idx, 10] = np.clip(np.max(chunk_scales[:, 1]), -20, 20) # max_scale_y
|
|
742
|
+
chunk_bounds[chunk_idx, 11] = np.clip(np.max(chunk_scales[:, 2]), -20, 20) # max_scale_z
|
|
743
|
+
|
|
744
|
+
# Color bounds
|
|
745
|
+
chunk_bounds[chunk_idx, 12] = np.min(chunk_color_rgb[:, 0]) # min_r
|
|
746
|
+
chunk_bounds[chunk_idx, 13] = np.min(chunk_color_rgb[:, 1]) # min_g
|
|
747
|
+
chunk_bounds[chunk_idx, 14] = np.min(chunk_color_rgb[:, 2]) # min_b
|
|
748
|
+
chunk_bounds[chunk_idx, 15] = np.max(chunk_color_rgb[:, 0]) # max_r
|
|
749
|
+
chunk_bounds[chunk_idx, 16] = np.max(chunk_color_rgb[:, 1]) # max_g
|
|
750
|
+
chunk_bounds[chunk_idx, 17] = np.max(chunk_color_rgb[:, 2]) # max_b
|
|
751
|
+
|
|
752
|
+
# Extract bounds for vectorized quantization
|
|
753
|
+
min_x, min_y, min_z = chunk_bounds[:, 0], chunk_bounds[:, 1], chunk_bounds[:, 2]
|
|
754
|
+
max_x, max_y, max_z = chunk_bounds[:, 3], chunk_bounds[:, 4], chunk_bounds[:, 5]
|
|
755
|
+
min_scale_x, min_scale_y, min_scale_z = chunk_bounds[:, 6], chunk_bounds[:, 7], chunk_bounds[:, 8]
|
|
756
|
+
max_scale_x, max_scale_y, max_scale_z = chunk_bounds[:, 9], chunk_bounds[:, 10], chunk_bounds[:, 11]
|
|
757
|
+
min_r, min_g, min_b = chunk_bounds[:, 12], chunk_bounds[:, 13], chunk_bounds[:, 14]
|
|
758
|
+
max_r, max_g, max_b = chunk_bounds[:, 15], chunk_bounds[:, 16], chunk_bounds[:, 17]
|
|
759
|
+
|
|
760
|
+
# ====================================================================================
|
|
761
|
+
# QUANTIZATION AND BIT PACKING (JIT-optimized when available)
|
|
762
|
+
# ====================================================================================
|
|
763
|
+
|
|
764
|
+
# Allocate packed vertex data (4 uint32 per vertex)
|
|
765
|
+
packed_data = np.zeros((num_gaussians, 4), dtype=np.uint32)
|
|
766
|
+
|
|
767
|
+
# Use JIT-compiled functions if available (5-6x faster than NumPy for writing)
|
|
768
|
+
# Note: First-time compilation adds ~40s overhead, but subsequent calls are 5-6x faster
|
|
769
|
+
if HAS_NUMBA:
|
|
770
|
+
# JIT-compiled compression (parallel, fastmath)
|
|
771
|
+
packed_data[:, 0] = _pack_positions_jit(sorted_means, sorted_chunk_indices, min_x, min_y, min_z, max_x, max_y, max_z)
|
|
772
|
+
packed_data[:, 2] = _pack_scales_jit(sorted_scales, sorted_chunk_indices, min_scale_x, min_scale_y, min_scale_z, max_scale_x, max_scale_y, max_scale_z)
|
|
773
|
+
packed_data[:, 3] = _pack_colors_jit(sorted_sh0, sorted_opacities, sorted_chunk_indices, min_r, min_g, min_b, max_r, max_g, max_b, SH_C0)
|
|
774
|
+
packed_data[:, 1] = _pack_quaternions_jit(sorted_quats)
|
|
775
|
+
else:
|
|
776
|
+
# Vectorized NumPy operations (optimal for writing)
|
|
777
|
+
# --- POSITION QUANTIZATION (11-10-11 bits) ---
|
|
778
|
+
range_x = max_x[sorted_chunk_indices] - min_x[sorted_chunk_indices]
|
|
779
|
+
range_y = max_y[sorted_chunk_indices] - min_y[sorted_chunk_indices]
|
|
780
|
+
range_z = max_z[sorted_chunk_indices] - min_z[sorted_chunk_indices]
|
|
781
|
+
|
|
782
|
+
range_x = np.where(range_x == 0, 1.0, range_x)
|
|
783
|
+
range_y = np.where(range_y == 0, 1.0, range_y)
|
|
784
|
+
range_z = np.where(range_z == 0, 1.0, range_z)
|
|
785
|
+
|
|
786
|
+
norm_x = (sorted_means[:, 0] - min_x[sorted_chunk_indices]) / range_x
|
|
787
|
+
norm_y = (sorted_means[:, 1] - min_y[sorted_chunk_indices]) / range_y
|
|
788
|
+
norm_z = (sorted_means[:, 2] - min_z[sorted_chunk_indices]) / range_z
|
|
789
|
+
|
|
790
|
+
norm_x = np.clip(norm_x, 0.0, 1.0)
|
|
791
|
+
norm_y = np.clip(norm_y, 0.0, 1.0)
|
|
792
|
+
norm_z = np.clip(norm_z, 0.0, 1.0)
|
|
793
|
+
|
|
794
|
+
px = (norm_x * 2047.0).astype(np.uint32)
|
|
795
|
+
py = (norm_y * 1023.0).astype(np.uint32)
|
|
796
|
+
pz = (norm_z * 2047.0).astype(np.uint32)
|
|
797
|
+
|
|
798
|
+
packed_data[:, 0] = (px << 21) | (py << 11) | pz
|
|
799
|
+
|
|
800
|
+
# --- SCALE QUANTIZATION (11-10-11 bits) ---
|
|
801
|
+
range_sx = max_scale_x[sorted_chunk_indices] - min_scale_x[sorted_chunk_indices]
|
|
802
|
+
range_sy = max_scale_y[sorted_chunk_indices] - min_scale_y[sorted_chunk_indices]
|
|
803
|
+
range_sz = max_scale_z[sorted_chunk_indices] - min_scale_z[sorted_chunk_indices]
|
|
804
|
+
|
|
805
|
+
range_sx = np.where(range_sx == 0, 1.0, range_sx)
|
|
806
|
+
range_sy = np.where(range_sy == 0, 1.0, range_sy)
|
|
807
|
+
range_sz = np.where(range_sz == 0, 1.0, range_sz)
|
|
808
|
+
|
|
809
|
+
norm_sx = (sorted_scales[:, 0] - min_scale_x[sorted_chunk_indices]) / range_sx
|
|
810
|
+
norm_sy = (sorted_scales[:, 1] - min_scale_y[sorted_chunk_indices]) / range_sy
|
|
811
|
+
norm_sz = (sorted_scales[:, 2] - min_scale_z[sorted_chunk_indices]) / range_sz
|
|
812
|
+
|
|
813
|
+
norm_sx = np.clip(norm_sx, 0.0, 1.0)
|
|
814
|
+
norm_sy = np.clip(norm_sy, 0.0, 1.0)
|
|
815
|
+
norm_sz = np.clip(norm_sz, 0.0, 1.0)
|
|
816
|
+
|
|
817
|
+
sx = (norm_sx * 2047.0).astype(np.uint32)
|
|
818
|
+
sy = (norm_sy * 1023.0).astype(np.uint32)
|
|
819
|
+
sz = (norm_sz * 2047.0).astype(np.uint32)
|
|
820
|
+
|
|
821
|
+
packed_data[:, 2] = (sx << 21) | (sy << 11) | sz
|
|
822
|
+
|
|
823
|
+
# --- COLOR QUANTIZATION (8-8-8-8 bits) ---
|
|
824
|
+
color_rgb = sorted_sh0 * SH_C0 + 0.5
|
|
825
|
+
|
|
826
|
+
range_r = max_r[sorted_chunk_indices] - min_r[sorted_chunk_indices]
|
|
827
|
+
range_g = max_g[sorted_chunk_indices] - min_g[sorted_chunk_indices]
|
|
828
|
+
range_b = max_b[sorted_chunk_indices] - min_b[sorted_chunk_indices]
|
|
829
|
+
|
|
830
|
+
range_r = np.where(range_r == 0, 1.0, range_r)
|
|
831
|
+
range_g = np.where(range_g == 0, 1.0, range_g)
|
|
832
|
+
range_b = np.where(range_b == 0, 1.0, range_b)
|
|
833
|
+
|
|
834
|
+
norm_r = (color_rgb[:, 0] - min_r[sorted_chunk_indices]) / range_r
|
|
835
|
+
norm_g = (color_rgb[:, 1] - min_g[sorted_chunk_indices]) / range_g
|
|
836
|
+
norm_b = (color_rgb[:, 2] - min_b[sorted_chunk_indices]) / range_b
|
|
837
|
+
|
|
838
|
+
norm_r = np.clip(norm_r, 0.0, 1.0)
|
|
839
|
+
norm_g = np.clip(norm_g, 0.0, 1.0)
|
|
840
|
+
norm_b = np.clip(norm_b, 0.0, 1.0)
|
|
841
|
+
|
|
842
|
+
cr = (norm_r * 255.0).astype(np.uint32)
|
|
843
|
+
cg = (norm_g * 255.0).astype(np.uint32)
|
|
844
|
+
cb = (norm_b * 255.0).astype(np.uint32)
|
|
845
|
+
|
|
846
|
+
opacity_linear = 1.0 / (1.0 + np.exp(-sorted_opacities))
|
|
847
|
+
opacity_linear = np.clip(opacity_linear, 0.0, 1.0)
|
|
848
|
+
co = (opacity_linear * 255.0).astype(np.uint32)
|
|
849
|
+
|
|
850
|
+
packed_data[:, 3] = (cr << 24) | (cg << 16) | (cb << 8) | co
|
|
851
|
+
|
|
852
|
+
# --- QUATERNION QUANTIZATION (smallest-three encoding: 2+10+10+10 bits) ---
|
|
853
|
+
quats_normalized = sorted_quats / np.linalg.norm(sorted_quats, axis=1, keepdims=True)
|
|
854
|
+
|
|
855
|
+
abs_quats = np.abs(quats_normalized)
|
|
856
|
+
largest_idx = np.argmax(abs_quats, axis=1).astype(np.uint32)
|
|
857
|
+
|
|
858
|
+
sign_mask = np.take_along_axis(quats_normalized, largest_idx[:, np.newaxis], axis=1) < 0
|
|
859
|
+
quats_normalized = np.where(sign_mask, -quats_normalized, quats_normalized)
|
|
860
|
+
|
|
861
|
+
mask = np.ones((num_gaussians, 4), dtype=bool)
|
|
862
|
+
mask[np.arange(num_gaussians), largest_idx] = False
|
|
863
|
+
three_components = quats_normalized[mask].reshape(num_gaussians, 3)
|
|
864
|
+
|
|
865
|
+
norm = np.sqrt(2.0) * 0.5
|
|
866
|
+
qa_norm = three_components[:, 0] * norm + 0.5
|
|
867
|
+
qb_norm = three_components[:, 1] * norm + 0.5
|
|
868
|
+
qc_norm = three_components[:, 2] * norm + 0.5
|
|
869
|
+
|
|
870
|
+
qa_norm = np.clip(qa_norm, 0.0, 1.0)
|
|
871
|
+
qb_norm = np.clip(qb_norm, 0.0, 1.0)
|
|
872
|
+
qc_norm = np.clip(qc_norm, 0.0, 1.0)
|
|
873
|
+
|
|
874
|
+
qa_int = (qa_norm * 1023.0).astype(np.uint32)
|
|
875
|
+
qb_int = (qb_norm * 1023.0).astype(np.uint32)
|
|
876
|
+
qc_int = (qc_norm * 1023.0).astype(np.uint32)
|
|
877
|
+
|
|
878
|
+
packed_data[:, 1] = (largest_idx << 30) | (qa_int << 20) | (qb_int << 10) | qc_int
|
|
879
|
+
|
|
880
|
+
# ====================================================================================
|
|
881
|
+
# SH COEFFICIENT COMPRESSION (8-bit quantization)
|
|
882
|
+
# ====================================================================================
|
|
883
|
+
|
|
884
|
+
packed_sh = None
|
|
885
|
+
if sorted_shN is not None and sorted_shN.shape[1] > 0:
|
|
886
|
+
# Normalize to [0, 1] range
|
|
887
|
+
# SH values are typically in range [-4, 4]
|
|
888
|
+
# This matches splat-transform: nvalue = shN / 8 + 0.5 (write-compressed-ply.ts:85)
|
|
889
|
+
sh_normalized = (sorted_shN / 8.0) + 0.5
|
|
890
|
+
sh_normalized = np.clip(sh_normalized, 0.0, 1.0)
|
|
891
|
+
|
|
892
|
+
# Quantize to uint8: trunc(nvalue * 256), clamped to [0, 255]
|
|
893
|
+
# This matches splat-transform: Math.trunc(nvalue * 256) (write-compressed-ply.ts:86)
|
|
894
|
+
packed_sh = np.clip(np.trunc(sh_normalized * 256.0), 0, 255).astype(np.uint8)
|
|
895
|
+
|
|
896
|
+
# ====================================================================================
|
|
897
|
+
# WRITE HEADER AND DATA
|
|
898
|
+
# ====================================================================================
|
|
899
|
+
|
|
900
|
+
# Build header
|
|
901
|
+
header_lines = [
|
|
902
|
+
"ply",
|
|
903
|
+
"format binary_little_endian 1.0",
|
|
904
|
+
f"element chunk {num_chunks}",
|
|
905
|
+
]
|
|
906
|
+
|
|
907
|
+
# Add chunk properties (18 floats)
|
|
908
|
+
chunk_props = [
|
|
909
|
+
"min_x", "min_y", "min_z",
|
|
910
|
+
"max_x", "max_y", "max_z",
|
|
911
|
+
"min_scale_x", "min_scale_y", "min_scale_z",
|
|
912
|
+
"max_scale_x", "max_scale_y", "max_scale_z",
|
|
913
|
+
"min_r", "min_g", "min_b",
|
|
914
|
+
"max_r", "max_g", "max_b",
|
|
915
|
+
]
|
|
916
|
+
for prop in chunk_props:
|
|
917
|
+
header_lines.append(f"property float {prop}")
|
|
918
|
+
|
|
919
|
+
# Add vertex element
|
|
920
|
+
header_lines.append(f"element vertex {num_gaussians}")
|
|
921
|
+
header_lines.append("property uint packed_position")
|
|
922
|
+
header_lines.append("property uint packed_rotation")
|
|
923
|
+
header_lines.append("property uint packed_scale")
|
|
924
|
+
header_lines.append("property uint packed_color")
|
|
925
|
+
|
|
926
|
+
# Add SH element if present
|
|
927
|
+
if packed_sh is not None:
|
|
928
|
+
num_sh_coeffs = packed_sh.shape[1]
|
|
929
|
+
header_lines.append(f"element sh {num_gaussians}")
|
|
930
|
+
for i in range(num_sh_coeffs):
|
|
931
|
+
header_lines.append(f"property uchar coeff_{i}")
|
|
932
|
+
|
|
933
|
+
header_lines.append("end_header")
|
|
934
|
+
header = "\n".join(header_lines) + "\n"
|
|
935
|
+
header_bytes = header.encode('ascii')
|
|
936
|
+
|
|
937
|
+
# Write to file
|
|
938
|
+
with open(file_path, 'wb') as f:
|
|
939
|
+
f.write(header_bytes)
|
|
940
|
+
chunk_bounds.tofile(f)
|
|
941
|
+
packed_data.tofile(f)
|
|
942
|
+
if packed_sh is not None:
|
|
943
|
+
packed_sh.tofile(f)
|
|
944
|
+
|
|
945
|
+
logger.debug(f"[Gaussian PLY] Wrote compressed: {num_gaussians} Gaussians to {file_path.name} "
|
|
946
|
+
f"({num_chunks} chunks, {len(header_bytes) + chunk_bounds.nbytes + packed_data.nbytes + (packed_sh.nbytes if packed_sh is not None else 0)} bytes)")
|
|
947
|
+
|
|
948
|
+
|
|
949
|
+
# ======================================================================================
|
|
950
|
+
# UNIFIED WRITING API
|
|
951
|
+
# ======================================================================================
|
|
952
|
+
|
|
953
|
+
def plywrite(
|
|
954
|
+
file_path: Union[str, Path],
|
|
955
|
+
means: np.ndarray,
|
|
956
|
+
scales: np.ndarray,
|
|
957
|
+
quats: np.ndarray,
|
|
958
|
+
opacities: np.ndarray,
|
|
959
|
+
sh0: np.ndarray,
|
|
960
|
+
shN: Optional[np.ndarray] = None,
|
|
961
|
+
compressed: bool = False,
|
|
962
|
+
validate: bool = True,
|
|
963
|
+
) -> None:
|
|
964
|
+
"""Write Gaussian splatting PLY file (auto-select format).
|
|
965
|
+
|
|
966
|
+
Automatically selects format based on compressed parameter or file extension:
|
|
967
|
+
- compressed=False or .ply -> uncompressed (fast)
|
|
968
|
+
- compressed=True -> automatically saves as .compressed.ply
|
|
969
|
+
- .compressed.ply or .ply_compressed extension -> compressed format
|
|
970
|
+
|
|
971
|
+
When compressed=True, the output file extension is automatically changed to
|
|
972
|
+
.compressed.ply (e.g., "output.ply" becomes "output.compressed.ply").
|
|
973
|
+
|
|
974
|
+
Args:
|
|
975
|
+
file_path: Output PLY file path (extension auto-adjusted if compressed=True)
|
|
976
|
+
means: (N, 3) - xyz positions
|
|
977
|
+
scales: (N, 3) - scale parameters
|
|
978
|
+
quats: (N, 4) - rotation quaternions
|
|
979
|
+
opacities: (N,) - opacity values
|
|
980
|
+
sh0: (N, 3) - DC spherical harmonics
|
|
981
|
+
shN: (N, K, 3) or (N, K*3) - Higher-order SH coefficients (optional)
|
|
982
|
+
compressed: If True, write compressed format and auto-adjust extension
|
|
983
|
+
validate: If True, validate input shapes (default True). Disable for trusted data.
|
|
984
|
+
|
|
985
|
+
Example:
|
|
986
|
+
>>> # Write uncompressed (fast)
|
|
987
|
+
>>> plywrite("output.ply", means, scales, quats, opacities, sh0, shN)
|
|
988
|
+
>>> # Write compressed (saves as "output.compressed.ply")
|
|
989
|
+
>>> plywrite("output.ply", means, scales, quats, opacities, sh0, shN, compressed=True)
|
|
990
|
+
>>> # Or without higher-order SH
|
|
991
|
+
>>> plywrite("output.ply", means, scales, quats, opacities, sh0)
|
|
992
|
+
>>> # Skip validation for trusted data (5-10% faster)
|
|
993
|
+
>>> plywrite("output.ply", means, scales, quats, opacities, sh0, validate=False)
|
|
994
|
+
"""
|
|
995
|
+
file_path = Path(file_path)
|
|
996
|
+
|
|
997
|
+
# Auto-detect compression from extension
|
|
998
|
+
is_compressed_ext = file_path.name.endswith(('.ply_compressed', '.compressed.ply'))
|
|
999
|
+
|
|
1000
|
+
# Check if compressed format requested
|
|
1001
|
+
if compressed or is_compressed_ext:
|
|
1002
|
+
# If compressed=True but no compressed extension, add .compressed.ply
|
|
1003
|
+
if compressed and not is_compressed_ext:
|
|
1004
|
+
# Replace .ply with .compressed.ply, or just append if no .ply
|
|
1005
|
+
if file_path.suffix == '.ply':
|
|
1006
|
+
file_path = file_path.with_suffix('.compressed.ply')
|
|
1007
|
+
else:
|
|
1008
|
+
file_path = Path(str(file_path) + '.compressed.ply')
|
|
1009
|
+
|
|
1010
|
+
write_compressed(file_path, means, scales, quats, opacities, sh0, shN)
|
|
1011
|
+
else:
|
|
1012
|
+
write_uncompressed(file_path, means, scales, quats, opacities, sh0, shN, validate=validate)
|
|
1013
|
+
|
|
1014
|
+
|
|
1015
|
+
__all__ = [
|
|
1016
|
+
'plywrite',
|
|
1017
|
+
'write_uncompressed',
|
|
1018
|
+
'write_compressed',
|
|
1019
|
+
]
|