@playcanvas/splat-transform 0.5.3 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.mjs CHANGED
@@ -1884,7 +1884,7 @@ let Quat$1 = class Quat {
1884
1884
  }
1885
1885
  };
1886
1886
 
1887
- var version$1 = "0.5.3";
1887
+ var version$1 = "0.6.0";
1888
1888
 
1889
1889
  class Column {
1890
1890
  name;
@@ -1909,17 +1909,6 @@ class Column {
1909
1909
  clone() {
1910
1910
  return new Column(this.name, this.data.slice());
1911
1911
  }
1912
- filter(length, filter) {
1913
- const constructor = this.data.constructor;
1914
- const data = new constructor(length);
1915
- let j = 0;
1916
- for (let i = 0; i < this.data.length; i++) {
1917
- if (filter[i]) {
1918
- data[j++] = this.data[i];
1919
- }
1920
- }
1921
- return new Column(this.name, data);
1922
- }
1923
1912
  }
1924
1913
  class DataTable {
1925
1914
  columns;
@@ -1995,22 +1984,20 @@ class DataTable {
1995
1984
  clone() {
1996
1985
  return new DataTable(this.columns.map(c => c.clone()));
1997
1986
  }
1998
- filter(predicate) {
1999
- const flags = new Uint8Array(this.numRows);
2000
- const row = {};
2001
- let numRows = 0;
2002
- for (let i = 0; i < this.numRows; i++) {
2003
- this.getRow(i, row);
2004
- flags[i] = predicate(i, row) ? 1 : 0;
2005
- numRows += flags[i];
2006
- }
2007
- if (numRows === 0) {
2008
- return null;
2009
- }
2010
- if (numRows === this.numRows) {
2011
- return this;
1987
+ // return a new table containing the rows referenced in indices
1988
+ permuteRows(indices) {
1989
+ const result = new DataTable(this.columns.map((c) => {
1990
+ const constructor = c.data.constructor;
1991
+ return new Column(c.name, new constructor(indices.length));
1992
+ }));
1993
+ for (let i = 0; i < this.numColumns; ++i) {
1994
+ const src = this.getColumn(i).data;
1995
+ const dst = result.getColumn(i).data;
1996
+ for (let j = 0; j < indices.length; j++) {
1997
+ dst[j] = src[indices[j]];
1998
+ }
2012
1999
  }
2013
- return new DataTable(this.columns.map(c => c.filter(numRows, flags)));
2000
+ return result;
2014
2001
  }
2015
2002
  }
2016
2003
 
@@ -2241,6 +2228,18 @@ const transform = (dataTable, t, r, s) => {
2241
2228
  };
2242
2229
 
2243
2230
  const shNames$2 = new Array(45).fill('').map((_, i) => `f_rest_${i}`);
2231
+ const filter = (dataTable, predicate) => {
2232
+ const indices = new Uint32Array(dataTable.numRows);
2233
+ let index = 0;
2234
+ const row = {};
2235
+ for (let i = 0; i < dataTable.numRows; i++) {
2236
+ dataTable.getRow(i, row);
2237
+ if (predicate(row, i)) {
2238
+ indices[index++] = i;
2239
+ }
2240
+ }
2241
+ return dataTable.permuteRows(indices.subarray(0, index));
2242
+ };
2244
2243
  // process a data table with standard options
2245
2244
  const process = (dataTable, processActions) => {
2246
2245
  let result = dataTable;
@@ -2257,7 +2256,7 @@ const process = (dataTable, processActions) => {
2257
2256
  transform(result, Vec3$1.ZERO, Quat$1.IDENTITY, processAction.value);
2258
2257
  break;
2259
2258
  case 'filterNaN': {
2260
- const predicate = (rowIndex, row) => {
2259
+ const predicate = (row, rowIndex) => {
2261
2260
  for (const key in row) {
2262
2261
  if (!isFinite(row[key])) {
2263
2262
  return false;
@@ -2265,21 +2264,21 @@ const process = (dataTable, processActions) => {
2265
2264
  }
2266
2265
  return true;
2267
2266
  };
2268
- result = result.filter(predicate);
2267
+ result = filter(result, predicate);
2269
2268
  break;
2270
2269
  }
2271
2270
  case 'filterByValue': {
2272
2271
  const { columnName, comparator, value } = processAction;
2273
2272
  const Predicates = {
2274
- 'lt': (rowIndex, row) => row[columnName] < value,
2275
- 'lte': (rowIndex, row) => row[columnName] <= value,
2276
- 'gt': (rowIndex, row) => row[columnName] > value,
2277
- 'gte': (rowIndex, row) => row[columnName] >= value,
2278
- 'eq': (rowIndex, row) => row[columnName] === value,
2279
- 'neq': (rowIndex, row) => row[columnName] !== value
2273
+ 'lt': (row, rowIndex) => row[columnName] < value,
2274
+ 'lte': (row, rowIndex) => row[columnName] <= value,
2275
+ 'gt': (row, rowIndex) => row[columnName] > value,
2276
+ 'gte': (row, rowIndex) => row[columnName] >= value,
2277
+ 'eq': (row, rowIndex) => row[columnName] === value,
2278
+ 'neq': (row, rowIndex) => row[columnName] !== value
2280
2279
  };
2281
- const predicate = Predicates[comparator] ?? ((rowIndex, row) => true);
2282
- result = result.filter(predicate);
2280
+ const predicate = Predicates[comparator] ?? ((row, rowIndex) => true);
2281
+ result = filter(result, predicate);
2283
2282
  break;
2284
2283
  }
2285
2284
  case 'filterBands': {
@@ -2310,6 +2309,211 @@ const process = (dataTable, processActions) => {
2310
2309
  return result;
2311
2310
  };
2312
2311
 
2312
+ // Size of a chunk in the compressed PLY format (number of splats per chunk)
2313
+ const CHUNK_SIZE$1 = 256;
2314
+ const isCompressedPly$1 = (ply) => {
2315
+ const hasShape = (dataTable, columns, type) => {
2316
+ return columns.every((name) => {
2317
+ const col = dataTable.getColumnByName(name);
2318
+ return col && col.dataType === type;
2319
+ });
2320
+ };
2321
+ const chunkProperties = [
2322
+ 'min_x',
2323
+ 'min_y',
2324
+ 'min_z',
2325
+ 'max_x',
2326
+ 'max_y',
2327
+ 'max_z',
2328
+ 'min_scale_x',
2329
+ 'min_scale_y',
2330
+ 'min_scale_z',
2331
+ 'max_scale_x',
2332
+ 'max_scale_y',
2333
+ 'max_scale_z',
2334
+ 'min_r',
2335
+ 'min_g',
2336
+ 'min_b',
2337
+ 'max_r',
2338
+ 'max_g',
2339
+ 'max_b'
2340
+ ];
2341
+ const vertexProperties = [
2342
+ 'packed_position',
2343
+ 'packed_rotation',
2344
+ 'packed_scale',
2345
+ 'packed_color'
2346
+ ];
2347
+ const numElements = ply.elements.length;
2348
+ if (numElements !== 2 && numElements !== 3)
2349
+ return false;
2350
+ const chunk = ply.elements.find(e => e.name === 'chunk');
2351
+ if (!chunk || !hasShape(chunk.dataTable, chunkProperties, 'float32'))
2352
+ return false;
2353
+ const vertex = ply.elements.find(e => e.name === 'vertex');
2354
+ if (!vertex || !hasShape(vertex.dataTable, vertexProperties, 'uint32'))
2355
+ return false;
2356
+ if (Math.ceil(vertex.dataTable.numRows / CHUNK_SIZE$1) !== chunk.dataTable.numRows) {
2357
+ return false;
2358
+ }
2359
+ // check optional spherical harmonics
2360
+ if (numElements === 3) {
2361
+ const sh = ply.elements.find(e => e.name === 'sh');
2362
+ if (!sh) {
2363
+ return false;
2364
+ }
2365
+ const shData = sh.dataTable;
2366
+ if ([9, 24, 45].indexOf(shData.numColumns) === -1) {
2367
+ return false;
2368
+ }
2369
+ for (let i = 0; i < shData.numColumns; ++i) {
2370
+ const col = shData.getColumnByName(`f_rest_${i}`);
2371
+ if (!col || col.dataType !== 'uint8') {
2372
+ return false;
2373
+ }
2374
+ }
2375
+ if (shData.numRows !== vertex.dataTable.numRows) {
2376
+ return false;
2377
+ }
2378
+ }
2379
+ return true;
2380
+ };
2381
+ // Detects the compressed PLY schema and returns a decompressed DataTable, or null if not compressed.
2382
+ const decompressPly = (ply) => {
2383
+ const chunkData = ply.elements.find(e => e.name === 'chunk').dataTable;
2384
+ const getChunk = (name) => chunkData.getColumnByName(name).data;
2385
+ const vertexData = ply.elements.find(e => e.name === 'vertex').dataTable;
2386
+ const packed_position = vertexData.getColumnByName('packed_position').data;
2387
+ const packed_rotation = vertexData.getColumnByName('packed_rotation').data;
2388
+ const packed_scale = vertexData.getColumnByName('packed_scale').data;
2389
+ const packed_color = vertexData.getColumnByName('packed_color').data;
2390
+ const min_x = getChunk('min_x');
2391
+ const min_y = getChunk('min_y');
2392
+ const min_z = getChunk('min_z');
2393
+ const max_x = getChunk('max_x');
2394
+ const max_y = getChunk('max_y');
2395
+ const max_z = getChunk('max_z');
2396
+ const min_scale_x = getChunk('min_scale_x');
2397
+ const min_scale_y = getChunk('min_scale_y');
2398
+ const min_scale_z = getChunk('min_scale_z');
2399
+ const max_scale_x = getChunk('max_scale_x');
2400
+ const max_scale_y = getChunk('max_scale_y');
2401
+ const max_scale_z = getChunk('max_scale_z');
2402
+ const min_r = getChunk('min_r');
2403
+ const min_g = getChunk('min_g');
2404
+ const min_b = getChunk('min_b');
2405
+ const max_r = getChunk('max_r');
2406
+ const max_g = getChunk('max_g');
2407
+ const max_b = getChunk('max_b');
2408
+ const numSplats = vertexData.numRows;
2409
+ const columns = [
2410
+ new Column('x', new Float32Array(numSplats)),
2411
+ new Column('y', new Float32Array(numSplats)),
2412
+ new Column('z', new Float32Array(numSplats)),
2413
+ new Column('f_dc_0', new Float32Array(numSplats)),
2414
+ new Column('f_dc_1', new Float32Array(numSplats)),
2415
+ new Column('f_dc_2', new Float32Array(numSplats)),
2416
+ new Column('opacity', new Float32Array(numSplats)),
2417
+ new Column('rot_0', new Float32Array(numSplats)),
2418
+ new Column('rot_1', new Float32Array(numSplats)),
2419
+ new Column('rot_2', new Float32Array(numSplats)),
2420
+ new Column('rot_3', new Float32Array(numSplats)),
2421
+ new Column('scale_0', new Float32Array(numSplats)),
2422
+ new Column('scale_1', new Float32Array(numSplats)),
2423
+ new Column('scale_2', new Float32Array(numSplats))
2424
+ ];
2425
+ const result = new DataTable(columns);
2426
+ const lerp = (a, b, t) => a * (1 - t) + b * t;
2427
+ const unpackUnorm = (value, bits) => {
2428
+ const t = (1 << bits) - 1;
2429
+ return (value & t) / t;
2430
+ };
2431
+ const unpack111011 = (value) => ({
2432
+ x: unpackUnorm(value >>> 21, 11),
2433
+ y: unpackUnorm(value >>> 11, 10),
2434
+ z: unpackUnorm(value, 11)
2435
+ });
2436
+ const unpack8888 = (value) => ({
2437
+ x: unpackUnorm(value >>> 24, 8),
2438
+ y: unpackUnorm(value >>> 16, 8),
2439
+ z: unpackUnorm(value >>> 8, 8),
2440
+ w: unpackUnorm(value, 8)
2441
+ });
2442
+ const unpackRot = (value) => {
2443
+ const norm = 1.0 / (Math.sqrt(2) * 0.5);
2444
+ const a = (unpackUnorm(value >>> 20, 10) - 0.5) * norm;
2445
+ const b = (unpackUnorm(value >>> 10, 10) - 0.5) * norm;
2446
+ const c = (unpackUnorm(value, 10) - 0.5) * norm;
2447
+ const m = Math.sqrt(Math.max(0, 1.0 - (a * a + b * b + c * c)));
2448
+ const which = value >>> 30;
2449
+ switch (which) {
2450
+ case 0:
2451
+ return { x: m, y: a, z: b, w: c };
2452
+ case 1:
2453
+ return { x: a, y: m, z: b, w: c };
2454
+ case 2:
2455
+ return { x: a, y: b, z: m, w: c };
2456
+ default:
2457
+ return { x: a, y: b, z: c, w: m };
2458
+ }
2459
+ };
2460
+ const SH_C0 = 0.28209479177387814;
2461
+ const ox = result.getColumnByName('x').data;
2462
+ const oy = result.getColumnByName('y').data;
2463
+ const oz = result.getColumnByName('z').data;
2464
+ const or0 = result.getColumnByName('rot_0').data;
2465
+ const or1 = result.getColumnByName('rot_1').data;
2466
+ const or2 = result.getColumnByName('rot_2').data;
2467
+ const or3 = result.getColumnByName('rot_3').data;
2468
+ const os0 = result.getColumnByName('scale_0').data;
2469
+ const os1 = result.getColumnByName('scale_1').data;
2470
+ const os2 = result.getColumnByName('scale_2').data;
2471
+ const of0 = result.getColumnByName('f_dc_0').data;
2472
+ const of1 = result.getColumnByName('f_dc_1').data;
2473
+ const of2 = result.getColumnByName('f_dc_2').data;
2474
+ const oo = result.getColumnByName('opacity').data;
2475
+ for (let i = 0; i < numSplats; ++i) {
2476
+ const ci = Math.floor(i / CHUNK_SIZE$1);
2477
+ const p = unpack111011(packed_position[i]);
2478
+ const r = unpackRot(packed_rotation[i]);
2479
+ const s = unpack111011(packed_scale[i]);
2480
+ const c = unpack8888(packed_color[i]);
2481
+ ox[i] = lerp(min_x[ci], max_x[ci], p.x);
2482
+ oy[i] = lerp(min_y[ci], max_y[ci], p.y);
2483
+ oz[i] = lerp(min_z[ci], max_z[ci], p.z);
2484
+ or0[i] = r.x;
2485
+ or1[i] = r.y;
2486
+ or2[i] = r.z;
2487
+ or3[i] = r.w;
2488
+ os0[i] = lerp(min_scale_x[ci], max_scale_x[ci], s.x);
2489
+ os1[i] = lerp(min_scale_y[ci], max_scale_y[ci], s.y);
2490
+ os2[i] = lerp(min_scale_z[ci], max_scale_z[ci], s.z);
2491
+ const cr = lerp(min_r[ci], max_r[ci], c.x);
2492
+ const cg = lerp(min_g[ci], max_g[ci], c.y);
2493
+ const cb = lerp(min_b[ci], max_b[ci], c.z);
2494
+ of0[i] = (cr - 0.5) / SH_C0;
2495
+ of1[i] = (cg - 0.5) / SH_C0;
2496
+ of2[i] = (cb - 0.5) / SH_C0;
2497
+ oo[i] = -Math.log(1 / c.w - 1);
2498
+ }
2499
+ // extract spherical harmonics
2500
+ const shElem = ply.elements.find(e => e.name === 'sh');
2501
+ if (shElem) {
2502
+ const shData = shElem.dataTable;
2503
+ for (let k = 0; k < shData.numColumns; ++k) {
2504
+ const col = shData.getColumn(k);
2505
+ const src = col.data;
2506
+ const dst = new Float32Array(numSplats);
2507
+ for (let i = 0; i < numSplats; ++i) {
2508
+ const n = (src[i] === 0) ? 0 : (src[i] === 255) ? 1 : (src[i] + 0.5) / 256;
2509
+ dst[i] = (n - 0.5) * 8;
2510
+ }
2511
+ result.addColumn(new Column(col.name, dst));
2512
+ }
2513
+ }
2514
+ return result;
2515
+ };
2516
+
2313
2517
  // Half-precision floating point decoder
2314
2518
  function decodeFloat16(encoded) {
2315
2519
  const signBit = (encoded >> 15) & 1;
@@ -3010,7 +3214,7 @@ class CompressedChunk {
3010
3214
  }
3011
3215
 
3012
3216
  // sort the compressed indices into morton order
3013
- const generateOrdering = (dataTable) => {
3217
+ const generateOrdering = (dataTable, indices) => {
3014
3218
  const cx = dataTable.getColumnByName('x').data;
3015
3219
  const cy = dataTable.getColumnByName('y').data;
3016
3220
  const cz = dataTable.getColumnByName('z').data;
@@ -3066,9 +3270,13 @@ const generateOrdering = (dataTable) => {
3066
3270
  console.log('invalid extents', xlen, ylen, zlen);
3067
3271
  return;
3068
3272
  }
3069
- const xmul = 1024 / xlen;
3070
- const ymul = 1024 / ylen;
3071
- const zmul = 1024 / zlen;
3273
+ // all points are identical
3274
+ if (xlen === 0 && ylen === 0 && zlen === 0) {
3275
+ return;
3276
+ }
3277
+ const xmul = (xlen === 0) ? 0 : 1024 / xlen;
3278
+ const ymul = (ylen === 0) ? 0 : 1024 / ylen;
3279
+ const zmul = (zlen === 0) ? 0 : 1024 / zlen;
3072
3280
  const morton = new Uint32Array(indices.length);
3073
3281
  for (let i = 0; i < indices.length; ++i) {
3074
3282
  const ri = indices[i];
@@ -3101,10 +3309,6 @@ const generateOrdering = (dataTable) => {
3101
3309
  start = end;
3102
3310
  }
3103
3311
  };
3104
- const indices = new Uint32Array(dataTable.numRows);
3105
- for (let i = 0; i < indices.length; ++i) {
3106
- indices[i] = i;
3107
- }
3108
3312
  generate(indices);
3109
3313
  return indices;
3110
3314
  };
@@ -3125,11 +3329,13 @@ const vertexProps = [
3125
3329
  'packed_color'
3126
3330
  ];
3127
3331
  const shNames$1 = new Array(45).fill('').map((_, i) => `f_rest_${i}`);
3332
+ // Size of a chunk in the compressed PLY format (number of splats per chunk)
3333
+ const CHUNK_SIZE = 256;
3128
3334
  const writeCompressedPly = async (fileHandle, dataTable) => {
3129
3335
  const shBands = { '9': 1, '24': 2, '-1': 3 }[shNames$1.findIndex(v => !dataTable.hasColumn(v))] ?? 0;
3130
3336
  const outputSHCoeffs = [0, 3, 8, 15][shBands];
3131
3337
  const numSplats = dataTable.numRows;
3132
- const numChunks = Math.ceil(numSplats / 256);
3338
+ const numChunks = Math.ceil(numSplats / CHUNK_SIZE);
3133
3339
  const shHeader = shBands ? [
3134
3340
  `element sh ${numSplats}`,
3135
3341
  new Array(outputSHCoeffs * 3).fill('').map((_, i) => `property uchar f_rest_${i}`)
@@ -3150,26 +3356,30 @@ const writeCompressedPly = async (fileHandle, dataTable) => {
3150
3356
  const splatIData = new Uint32Array(numSplats * vertexProps.length);
3151
3357
  const shData = new Uint8Array(numSplats * outputSHCoeffs * 3);
3152
3358
  // sort splats into some kind of order (morton order rn)
3153
- const sortIndices = generateOrdering(dataTable);
3359
+ const sortIndices = new Uint32Array(dataTable.numRows);
3360
+ for (let i = 0; i < sortIndices.length; ++i) {
3361
+ sortIndices[i] = i;
3362
+ }
3363
+ generateOrdering(dataTable, sortIndices);
3154
3364
  const row = {};
3155
3365
  const chunk = new CompressedChunk();
3156
3366
  for (let i = 0; i < numChunks; ++i) {
3157
- const num = Math.min(numSplats, (i + 1) * 256) - i * 256;
3367
+ const num = Math.min(numSplats, (i + 1) * CHUNK_SIZE) - i * CHUNK_SIZE;
3158
3368
  for (let j = 0; j < num; ++j) {
3159
- const index = sortIndices[i * 256 + j];
3369
+ const index = sortIndices[i * CHUNK_SIZE + j];
3160
3370
  // read splat data
3161
3371
  dataTable.getRow(index, row);
3162
3372
  // update chunk
3163
3373
  chunk.set(j, row);
3164
3374
  // quantize and write sh data
3165
- let off = (i * 256 + j) * outputSHCoeffs * 3;
3375
+ let off = (i * CHUNK_SIZE + j) * outputSHCoeffs * 3;
3166
3376
  for (let k = 0; k < outputSHCoeffs * 3; ++k) {
3167
3377
  const nvalue = row[shNames$1[k]] / 8 + 0.5;
3168
3378
  shData[off++] = Math.max(0, Math.min(255, Math.trunc(nvalue * 256)));
3169
3379
  }
3170
3380
  }
3171
3381
  // repeat the last gaussian to fill the rest of the final chunk
3172
- for (let j = num; j < 256; ++j) {
3382
+ for (let j = num; j < CHUNK_SIZE; ++j) {
3173
3383
  chunk.set(j, row);
3174
3384
  }
3175
3385
  // pack the chunk
@@ -3177,7 +3387,7 @@ const writeCompressedPly = async (fileHandle, dataTable) => {
3177
3387
  // store the float data
3178
3388
  chunkData.set(chunk.chunkData, i * 18);
3179
3389
  // write packed bits
3180
- const offset = i * 256 * 4;
3390
+ const offset = i * CHUNK_SIZE * 4;
3181
3391
  for (let j = 0; j < num; ++j) {
3182
3392
  splatIData[offset + j * 4 + 0] = chunk.position[j];
3183
3393
  splatIData[offset + j * 4 + 1] = chunk.rotation[j];
@@ -3368,11 +3578,11 @@ const writePly = async (fileHandle, plyData) => {
3368
3578
 
3369
3579
  /**
3370
3580
  * The engine version number. This is in semantic versioning format (MAJOR.MINOR.PATCH).
3371
- */ const version = '2.10.3';
3581
+ */ const version = '2.10.6';
3372
3582
  /**
3373
3583
  * The engine revision number. This is the Git hash of the last commit made to the branch
3374
3584
  * from which the engine was built.
3375
- */ const revision = '2dc84cf';
3585
+ */ const revision = '5322cb8';
3376
3586
  /**
3377
3587
  * Merge the contents of two objects into a single object.
3378
3588
  *
@@ -60807,7 +61017,7 @@ var stdDeclarationPS$1 = /* glsl */ `
60807
61017
  // parallax
60808
61018
  #ifdef STD_HEIGHT_MAP
60809
61019
  vec2 dUvOffset;
60810
- #ifdef STD_DIFFUSE_TEXTURE_ALLOCATE
61020
+ #ifdef STD_HEIGHT_TEXTURE_ALLOCATE
60811
61021
  uniform sampler2D texture_heightMap;
60812
61022
  #endif
60813
61023
  #endif
@@ -68855,7 +69065,7 @@ var stdDeclarationPS = /* wgsl */ `
68855
69065
  // parallax
68856
69066
  #ifdef STD_HEIGHT_MAP
68857
69067
  var<private> dUvOffset: vec2f;
68858
- #ifdef STD_DIFFUSE_TEXTURE_ALLOCATE
69068
+ #ifdef STD_HEIGHT_TEXTURE_ALLOCATE
68859
69069
  var texture_heightMap : texture_2d<f32>;
68860
69070
  var texture_heightMapSampler : sampler;
68861
69071
  #endif
@@ -83332,7 +83542,7 @@ const getPrimitiveType = (primitive)=>{
83332
83542
  return PRIMITIVE_TRIANGLES;
83333
83543
  }
83334
83544
  };
83335
- const generateIndices = (numVertices)=>{
83545
+ const generateIndices$1 = (numVertices)=>{
83336
83546
  const dummyIndices = new Uint16Array(numVertices);
83337
83547
  for(let i = 0; i < numVertices; i++){
83338
83548
  dummyIndices[i] = i;
@@ -83363,7 +83573,7 @@ const generateNormals = (sourceDesc, indices)=>{
83363
83573
  const numVertices = p.count;
83364
83574
  // generate indices if necessary
83365
83575
  if (!indices) {
83366
- indices = generateIndices(numVertices);
83576
+ indices = generateIndices$1(numVertices);
83367
83577
  }
83368
83578
  // generate normals
83369
83579
  const normalsTemp = calculateNormals(positions, indices);
@@ -89294,8 +89504,7 @@ ${useF16 ? 'enable f16;' : ''}
89294
89504
 
89295
89505
  struct Uniforms {
89296
89506
  numPoints: u32,
89297
- numCentroids: u32,
89298
- pointBase: u32
89507
+ numCentroids: u32
89299
89508
  };
89300
89509
 
89301
89510
  @group(0) @binding(0) var<uniform> uniforms: Uniforms;
@@ -89328,7 +89537,7 @@ fn main(
89328
89537
  @builtin(num_workgroups) num_workgroups: vec3u
89329
89538
  ) {
89330
89539
  // calculate row index for this thread point
89331
- let pointIndex = uniforms.pointBase + global_id.x + global_id.y * num_workgroups.x * 64u;
89540
+ let pointIndex = global_id.x + global_id.y * num_workgroups.x * 64u;
89332
89541
 
89333
89542
  // copy the point data from global memory
89334
89543
  var point: array<${floatType}, numColumns>;
@@ -89385,37 +89594,37 @@ fn main(
89385
89594
  const roundUp = (value, multiple) => {
89386
89595
  return Math.ceil(value / multiple) * multiple;
89387
89596
  };
89388
- const interleaveData = (dataTable, useF16) => {
89389
- const { numRows, numColumns } = dataTable;
89390
- if (useF16) {
89391
- const result = new Uint16Array(roundUp(numColumns * numRows, 2));
89597
+ const interleaveData = (result, dataTable, numRows, rowOffset) => {
89598
+ const { numColumns } = dataTable;
89599
+ if (result instanceof Uint16Array) {
89600
+ // interleave shorts
89392
89601
  for (let c = 0; c < numColumns; ++c) {
89393
89602
  const column = dataTable.columns[c];
89394
89603
  for (let r = 0; r < numRows; ++r) {
89395
- result[r * numColumns + c] = FloatPacking.float2Half(column.data[r]);
89604
+ result[r * numColumns + c] = FloatPacking.float2Half(column.data[rowOffset + r]);
89396
89605
  }
89397
89606
  }
89398
- return result;
89399
89607
  }
89400
- const result = new Float32Array(numColumns * numRows);
89401
- for (let c = 0; c < numColumns; ++c) {
89402
- const column = dataTable.columns[c];
89403
- for (let r = 0; r < numRows; ++r) {
89404
- result[r * numColumns + c] = column.data[r];
89608
+ else {
89609
+ // interleave floats
89610
+ for (let c = 0; c < numColumns; ++c) {
89611
+ const column = dataTable.columns[c];
89612
+ for (let r = 0; r < numRows; ++r) {
89613
+ result[r * numColumns + c] = column.data[rowOffset + r];
89614
+ }
89405
89615
  }
89406
89616
  }
89407
- return result;
89408
89617
  };
89409
- class GpuCluster {
89618
+ class GpuClustering {
89410
89619
  execute;
89411
89620
  destroy;
89412
- constructor(gpuDevice, points, numCentroids) {
89621
+ constructor(gpuDevice, numColumns, numCentroids) {
89413
89622
  const device = gpuDevice.app.graphicsDevice;
89414
89623
  // Check if device supports f16
89415
- const useF16 = 'supportsShaderF16' in device && device.supportsShaderF16;
89416
- const bytesPerFloat = useF16 ? 2 : 4;
89417
- const numPoints = points.numRows;
89418
- const numColumns = points.numColumns;
89624
+ const useF16 = !!('supportsShaderF16' in device && device.supportsShaderF16);
89625
+ const workgroupSize = 64;
89626
+ const workgroupsPerBatch = 1024;
89627
+ const batchSize = workgroupsPerBatch * workgroupSize;
89419
89628
  const bindGroupFormat = new BindGroupFormat(device, [
89420
89629
  new BindUniformBufferFormat('uniforms', SHADERSTAGE_COMPUTE),
89421
89630
  new BindStorageBufferFormat('pointsBuffer', SHADERSTAGE_COMPUTE, true),
@@ -89430,47 +89639,43 @@ class GpuCluster {
89430
89639
  computeUniformBufferFormats: {
89431
89640
  uniforms: new UniformBufferFormat(device, [
89432
89641
  new UniformFormat('numPoints', UNIFORMTYPE_UINT),
89433
- new UniformFormat('numCentroids', UNIFORMTYPE_UINT),
89434
- new UniformFormat('pointBase', UNIFORMTYPE_UINT)
89642
+ new UniformFormat('numCentroids', UNIFORMTYPE_UINT)
89435
89643
  ])
89436
89644
  },
89437
89645
  // @ts-ignore
89438
89646
  computeBindGroupFormat: bindGroupFormat
89439
89647
  });
89648
+ const interleavedPoints = useF16 ? new Uint16Array(roundUp(numColumns * batchSize, 2)) : new Float32Array(numColumns * batchSize);
89649
+ const interleavedCentroids = useF16 ? new Uint16Array(roundUp(numColumns * numCentroids, 2)) : new Float32Array(numColumns * numCentroids);
89650
+ const resultsData = new Uint32Array(batchSize);
89651
+ const pointsBuffer = new StorageBuffer(device, interleavedPoints.byteLength, BUFFERUSAGE_COPY_DST);
89652
+ const centroidsBuffer = new StorageBuffer(device, interleavedCentroids.byteLength, BUFFERUSAGE_COPY_DST);
89653
+ const resultsBuffer = new StorageBuffer(device, resultsData.byteLength, BUFFERUSAGE_COPY_SRC | BUFFERUSAGE_COPY_DST);
89440
89654
  const compute = new Compute(device, shader, 'compute-cluster');
89441
- const pointsBuffer = new StorageBuffer(device, useF16 ? roundUp(numColumns * numPoints, 2) * 2 : numColumns * numPoints * 4, BUFFERUSAGE_COPY_DST);
89442
- const centroidsBuffer = new StorageBuffer(device, numColumns * numCentroids * bytesPerFloat, BUFFERUSAGE_COPY_DST);
89443
- const resultsBuffer = new StorageBuffer(device, numPoints * 4, BUFFERUSAGE_COPY_SRC | BUFFERUSAGE_COPY_DST);
89444
- // interleave the points table data and write to gpu
89445
- const interleavedPoints = interleaveData(points, useF16);
89446
- pointsBuffer.write(0, interleavedPoints, 0, interleavedPoints.length);
89447
89655
  compute.setParameter('pointsBuffer', pointsBuffer);
89448
89656
  compute.setParameter('centroidsBuffer', centroidsBuffer);
89449
89657
  compute.setParameter('resultsBuffer', resultsBuffer);
89450
- this.execute = async (centroids, labels) => {
89451
- // interleave centroids and write to gpu
89452
- const interleavedCentroids = interleaveData(centroids, useF16);
89453
- centroidsBuffer.write(0, interleavedCentroids, 0, interleavedCentroids.length);
89454
- compute.setParameter('numPoints', points.numRows);
89455
- compute.setParameter('numCentroids', centroids.numRows);
89456
- // execute in batches of 1024 worksgroups
89457
- const workgroupSize = 64;
89458
- const workgroupsPerBatch = 1024;
89459
- const batchSize = workgroupsPerBatch * workgroupSize;
89658
+ this.execute = async (points, centroids, labels) => {
89659
+ const numPoints = points.numRows;
89460
89660
  const numBatches = Math.ceil(numPoints / batchSize);
89661
+ // upload centroid data to gpu
89662
+ interleaveData(interleavedCentroids, centroids, numCentroids, 0);
89663
+ centroidsBuffer.write(0, interleavedCentroids, 0, interleavedCentroids.length);
89664
+ compute.setParameter('numCentroids', numCentroids);
89461
89665
  for (let batch = 0; batch < numBatches; batch++) {
89462
89666
  const currentBatchSize = Math.min(numPoints - batch * batchSize, batchSize);
89463
89667
  const groups = Math.ceil(currentBatchSize / 64);
89464
- compute.setParameter('pointBase', batch * batchSize);
89668
+ // write this batch of point data to gpu
89669
+ interleaveData(interleavedPoints, points, currentBatchSize, batch * batchSize);
89670
+ pointsBuffer.write(0, interleavedPoints, 0, useF16 ? roundUp(numColumns * currentBatchSize, 2) : numColumns * currentBatchSize);
89671
+ compute.setParameter('numPoints', currentBatchSize);
89465
89672
  // start compute job
89466
89673
  compute.setupDispatch(groups);
89467
89674
  device.computeDispatch([compute], `cluster-dispatch-${batch}`);
89468
- // FIXME: submit call is required, but not public API
89469
- // @ts-ignore
89470
- device.submit();
89675
+ // read results from gpu and store in labels
89676
+ await resultsBuffer.read(0, currentBatchSize * 4, resultsData, true);
89677
+ labels.set(resultsData.subarray(0, currentBatchSize), batch * batchSize);
89471
89678
  }
89472
- // read results from gpu
89473
- await resultsBuffer.read(0, undefined, labels, true);
89474
89679
  };
89475
89680
  this.destroy = () => {
89476
89681
  pointsBuffer.destroy();
@@ -89550,14 +89755,14 @@ const kmeans = async (points, k, iterations, device) => {
89550
89755
  // construct centroids data table and assign initial values
89551
89756
  const centroids = new DataTable(points.columns.map(c => new Column(c.name, new Float32Array(k))));
89552
89757
  initializeCentroids(points, centroids, row);
89553
- const gpuCluster = device && new GpuCluster(device, points, k);
89758
+ const gpuClustering = device && new GpuClustering(device, points.numColumns, k);
89554
89759
  const labels = new Uint32Array(points.numRows);
89555
89760
  let converged = false;
89556
89761
  let steps = 0;
89557
89762
  console.log(`Running k-means clustering: dims=${points.numColumns} points=${points.numRows} clusters=${k} iterations=${iterations}...`);
89558
89763
  while (!converged) {
89559
- if (gpuCluster) {
89560
- await gpuCluster.execute(centroids, labels);
89764
+ if (gpuClustering) {
89765
+ await gpuClustering.execute(points, centroids, labels);
89561
89766
  }
89562
89767
  else {
89563
89768
  clusterKdTreeCpu(points, centroids, labels);
@@ -89574,17 +89779,20 @@ const kmeans = async (points, k, iterations, device) => {
89574
89779
  }
89575
89780
  stdout.write('#');
89576
89781
  }
89782
+ if (gpuClustering) {
89783
+ gpuClustering.destroy();
89784
+ }
89577
89785
  console.log(' done 🎉');
89578
89786
  return { centroids, labels };
89579
89787
  };
89580
89788
 
89581
89789
  const shNames = new Array(45).fill('').map((_, i) => `f_rest_${i}`);
89582
- const calcMinMax = (dataTable, columnNames) => {
89790
+ const calcMinMax = (dataTable, columnNames, indices) => {
89583
89791
  const columns = columnNames.map(name => dataTable.getColumnByName(name));
89584
89792
  const minMax = columnNames.map(() => [Infinity, -Infinity]);
89585
89793
  const row = {};
89586
- for (let i = 0; i < dataTable.numRows; ++i) {
89587
- const r = dataTable.getRow(i, row, columns);
89794
+ for (let i = 0; i < indices.length; ++i) {
89795
+ const r = dataTable.getRow(indices[i], row, columns);
89588
89796
  for (let j = 0; j < columnNames.length; ++j) {
89589
89797
  const value = r[columnNames[j]];
89590
89798
  if (value < minMax[j][0])
@@ -89602,10 +89810,16 @@ const logTransform = (value) => {
89602
89810
  const identity = (index, width) => {
89603
89811
  return index;
89604
89812
  };
89605
- const writeSogs = async (fileHandle, dataTable, outputFilename, shIterations = 10, shMethod) => {
89606
- // generate an optimal ordering
89607
- const sortIndices = generateOrdering(dataTable);
89608
- const numRows = dataTable.numRows;
89813
+ const generateIndices = (dataTable) => {
89814
+ const result = new Uint32Array(dataTable.numRows);
89815
+ for (let i = 0; i < result.length; ++i) {
89816
+ result[i] = i;
89817
+ }
89818
+ generateOrdering(dataTable, result);
89819
+ return result;
89820
+ };
89821
+ const writeSog = async (fileHandle, dataTable, outputFilename, shIterations = 10, shMethod, indices = generateIndices(dataTable)) => {
89822
+ const numRows = indices.length;
89609
89823
  const width = Math.ceil(Math.sqrt(numRows) / 16) * 16;
89610
89824
  const height = Math.ceil(numRows / width / 16) * 16;
89611
89825
  const channels = 4;
@@ -89623,10 +89837,10 @@ const writeSogs = async (fileHandle, dataTable, outputFilename, shIterations = 1
89623
89837
  const meansL = new Uint8Array(width * height * channels);
89624
89838
  const meansU = new Uint8Array(width * height * channels);
89625
89839
  const meansNames = ['x', 'y', 'z'];
89626
- const meansMinMax = calcMinMax(dataTable, meansNames).map(v => v.map(logTransform));
89840
+ const meansMinMax = calcMinMax(dataTable, meansNames, indices).map(v => v.map(logTransform));
89627
89841
  const meansColumns = meansNames.map(name => dataTable.getColumnByName(name));
89628
- for (let i = 0; i < dataTable.numRows; ++i) {
89629
- dataTable.getRow(sortIndices[i], row, meansColumns);
89842
+ for (let i = 0; i < indices.length; ++i) {
89843
+ dataTable.getRow(indices[i], row, meansColumns);
89630
89844
  const x = 65535 * (logTransform(row.x) - meansMinMax[0][0]) / (meansMinMax[0][1] - meansMinMax[0][0]);
89631
89845
  const y = 65535 * (logTransform(row.y) - meansMinMax[1][0]) / (meansMinMax[1][1] - meansMinMax[1][0]);
89632
89846
  const z = 65535 * (logTransform(row.z) - meansMinMax[2][0]) / (meansMinMax[2][1] - meansMinMax[2][0]);
@@ -89647,8 +89861,8 @@ const writeSogs = async (fileHandle, dataTable, outputFilename, shIterations = 1
89647
89861
  const quatNames = ['rot_0', 'rot_1', 'rot_2', 'rot_3'];
89648
89862
  const quatColumns = quatNames.map(name => dataTable.getColumnByName(name));
89649
89863
  const q = [0, 0, 0, 0];
89650
- for (let i = 0; i < dataTable.numRows; ++i) {
89651
- dataTable.getRow(sortIndices[i], row, quatColumns);
89864
+ for (let i = 0; i < indices.length; ++i) {
89865
+ dataTable.getRow(indices[i], row, quatColumns);
89652
89866
  q[0] = row.rot_0;
89653
89867
  q[1] = row.rot_1;
89654
89868
  q[2] = row.rot_2;
@@ -89688,9 +89902,9 @@ const writeSogs = async (fileHandle, dataTable, outputFilename, shIterations = 1
89688
89902
  const scales = new Uint8Array(width * height * channels);
89689
89903
  const scaleNames = ['scale_0', 'scale_1', 'scale_2'];
89690
89904
  const scaleColumns = scaleNames.map(name => dataTable.getColumnByName(name));
89691
- const scaleMinMax = calcMinMax(dataTable, scaleNames);
89692
- for (let i = 0; i < dataTable.numRows; ++i) {
89693
- dataTable.getRow(sortIndices[i], row, scaleColumns);
89905
+ const scaleMinMax = calcMinMax(dataTable, scaleNames, indices);
89906
+ for (let i = 0; i < indices.length; ++i) {
89907
+ dataTable.getRow(indices[i], row, scaleColumns);
89694
89908
  const ti = layout(i);
89695
89909
  scales[ti * 4] = 255 * (row.scale_0 - scaleMinMax[0][0]) / (scaleMinMax[0][1] - scaleMinMax[0][0]);
89696
89910
  scales[ti * 4 + 1] = 255 * (row.scale_1 - scaleMinMax[1][0]) / (scaleMinMax[1][1] - scaleMinMax[1][0]);
@@ -89702,9 +89916,9 @@ const writeSogs = async (fileHandle, dataTable, outputFilename, shIterations = 1
89702
89916
  const sh0 = new Uint8Array(width * height * channels);
89703
89917
  const sh0Names = ['f_dc_0', 'f_dc_1', 'f_dc_2', 'opacity'];
89704
89918
  const sh0Columns = sh0Names.map(name => dataTable.getColumnByName(name));
89705
- const sh0MinMax = calcMinMax(dataTable, sh0Names);
89706
- for (let i = 0; i < dataTable.numRows; ++i) {
89707
- dataTable.getRow(sortIndices[i], row, sh0Columns);
89919
+ const sh0MinMax = calcMinMax(dataTable, sh0Names, indices);
89920
+ for (let i = 0; i < indices.length; ++i) {
89921
+ dataTable.getRow(indices[i], row, sh0Columns);
89708
89922
  const ti = layout(i);
89709
89923
  sh0[ti * 4] = 255 * (row.f_dc_0 - sh0MinMax[0][0]) / (sh0MinMax[0][1] - sh0MinMax[0][0]);
89710
89924
  sh0[ti * 4 + 1] = 255 * (row.f_dc_1 - sh0MinMax[1][0]) / (sh0MinMax[1][1] - sh0MinMax[1][0]);
@@ -89753,14 +89967,18 @@ const writeSogs = async (fileHandle, dataTable, outputFilename, shIterations = 1
89753
89967
  const shColumnNames = shNames.slice(0, shCoeffs * 3);
89754
89968
  const shColumns = shColumnNames.map(name => dataTable.getColumnByName(name));
89755
89969
  // create a table with just spherical harmonics data
89970
+ // NOTE: this step should also copy the rows referenced in indices, but that's a
89971
+ // lot of duplicate data when it's unneeded (which is currently never). so that
89972
+ // means k-means is clustering the full dataset, instead of the rows referenced in
89973
+ // indices.
89756
89974
  const shDataTable = new DataTable(shColumns);
89757
- const paletteSize = Math.min(64, 2 ** Math.floor(Math.log2(dataTable.numRows / 1024))) * 1024;
89975
+ const paletteSize = Math.min(64, 2 ** Math.floor(Math.log2(indices.length / 1024))) * 1024;
89758
89976
  // calculate kmeans
89759
89977
  const gpuDevice = shMethod === 'gpu' ? await createDevice() : null;
89760
89978
  const { centroids, labels } = await kmeans(shDataTable, paletteSize, shIterations, gpuDevice);
89761
89979
  // write centroids
89762
89980
  const centroidsBuf = new Uint8Array(64 * shCoeffs * Math.ceil(centroids.numRows / 64) * channels);
89763
- const centroidsMinMax = calcMinMax(shDataTable, shColumnNames);
89981
+ const centroidsMinMax = calcMinMax(shDataTable, shColumnNames, indices);
89764
89982
  const centroidsMin = centroidsMinMax.map(v => v[0]).reduce((a, b) => Math.min(a, b));
89765
89983
  const centroidsMax = centroidsMinMax.map(v => v[1]).reduce((a, b) => Math.max(a, b));
89766
89984
  const centroidsRow = {};
@@ -89779,8 +89997,8 @@ const writeSogs = async (fileHandle, dataTable, outputFilename, shIterations = 1
89779
89997
  await write('shN_centroids.webp', centroidsBuf, 64 * shCoeffs, Math.ceil(centroids.numRows / 64));
89780
89998
  // write labels
89781
89999
  const labelsBuf = new Uint8Array(width * height * channels);
89782
- for (let i = 0; i < dataTable.numRows; ++i) {
89783
- const label = labels[sortIndices[i]];
90000
+ for (let i = 0; i < indices.length; ++i) {
90001
+ const label = labels[indices[i]];
89784
90002
  const ti = layout(i);
89785
90003
  labelsBuf[ti * 4] = label & 0xff;
89786
90004
  labelsBuf[ti * 4 + 1] = (label >> 8) & 0xff;
@@ -89789,7 +90007,7 @@ const writeSogs = async (fileHandle, dataTable, outputFilename, shIterations = 1
89789
90007
  }
89790
90008
  await write('shN_labels.webp', labelsBuf);
89791
90009
  meta.shN = {
89792
- shape: [dataTable.numRows, shCoeffs],
90010
+ shape: [indices.length, shCoeffs],
89793
90011
  dtype: 'float32',
89794
90012
  mins: centroidsMin,
89795
90013
  maxs: centroidsMax,
@@ -89815,7 +90033,16 @@ const readFile = async (filename) => {
89815
90033
  fileData = await readSplat(inputFile);
89816
90034
  }
89817
90035
  else if (lowerFilename.endsWith('.ply')) {
89818
- fileData = await readPly$1(inputFile);
90036
+ const ply = await readPly$1(inputFile);
90037
+ if (isCompressedPly$1(ply)) {
90038
+ fileData = {
90039
+ comments: ply.comments,
90040
+ elements: [{ name: 'vertex', dataTable: decompressPly(ply) }]
90041
+ };
90042
+ }
90043
+ else {
90044
+ fileData = ply;
90045
+ }
89819
90046
  }
89820
90047
  else {
89821
90048
  await inputFile.close();
@@ -89829,8 +90056,8 @@ const getOutputFormat = (filename) => {
89829
90056
  if (lowerFilename.endsWith('.csv')) {
89830
90057
  return 'csv';
89831
90058
  }
89832
- else if (lowerFilename.endsWith('.json')) {
89833
- return 'json';
90059
+ else if (lowerFilename.endsWith('meta.json')) {
90060
+ return 'sog';
89834
90061
  }
89835
90062
  else if (lowerFilename.endsWith('.compressed.ply')) {
89836
90063
  return 'compressed-ply';
@@ -89862,8 +90089,8 @@ const writeFile = async (filename, dataTable, options) => {
89862
90089
  case 'csv':
89863
90090
  await writeCsv(outputFile, dataTable);
89864
90091
  break;
89865
- case 'json':
89866
- await writeSogs(outputFile, dataTable, filename, options.iterations, options.gpu ? 'gpu' : 'cpu');
90092
+ case 'sog':
90093
+ await writeSog(outputFile, dataTable, filename, options.iterations, options.gpu ? 'gpu' : 'cpu');
89867
90094
  break;
89868
90095
  case 'compressed-ply':
89869
90096
  await writeCompressedPly(outputFile, dataTable);
@@ -90035,7 +90262,7 @@ const parseArguments = () => {
90035
90262
  });
90036
90263
  break;
90037
90264
  case 'filterByValue': {
90038
- const parts = t.value.split(',').map(p => p.trim());
90265
+ const parts = t.value.split(',').map((p) => p.trim());
90039
90266
  if (parts.length !== 3) {
90040
90267
  throw new Error(`Invalid filterByValue value: ${t.value}`);
90041
90268
  }
@@ -90048,7 +90275,7 @@ const parseArguments = () => {
90048
90275
  break;
90049
90276
  }
90050
90277
  case 'filterBands': {
90051
- const shBands = parseNumber(t.value);
90278
+ const shBands = parseInteger(t.value);
90052
90279
  if (![0, 1, 2, 3].includes(shBands)) {
90053
90280
  throw new Error(`Invalid filterBands value: ${t.value}. Must be 0, 1, 2, or 3.`);
90054
90281
  }
@@ -90143,7 +90370,7 @@ const main = async () => {
90143
90370
  }
90144
90371
  catch (err) {
90145
90372
  // handle errors
90146
- console.error(`error: ${err.message}`);
90373
+ console.error(err);
90147
90374
  exit(1);
90148
90375
  }
90149
90376
  const endTime = hrtime(startTime);