@playcanvas/splat-transform 1.9.0 → 1.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.mjs CHANGED
@@ -742,6 +742,16 @@ class Progress {
742
742
  if (!quiet)
743
743
  impl.onProgress(this.currentNode);
744
744
  }
745
+ /**
746
+ * Cancel the current progress node, popping it from the stack without
747
+ * completing remaining steps. Use this before early exits (e.g. break)
748
+ * to keep the progress stack balanced.
749
+ */
750
+ cancel() {
751
+ if (!this.currentNode)
752
+ return;
753
+ this.currentNode = this.currentNode.parent;
754
+ }
745
755
  /**
746
756
  * Advance to the next step. Auto-increments the step counter.
747
757
  * Auto-ends when all steps are complete.
@@ -929,33 +939,74 @@ const sortMortonOrder = (dataTable, indices) => {
929
939
  generate(indices);
930
940
  };
931
941
 
942
/**
 * Quickselect ("nth_element") over an index permutation.
 *
 * Reorders `arr[lo..hi]` (inclusive) in place so that:
 *   - `arr[k]` holds the element whose key would sit at position k if the
 *     range were fully sorted by `values[arr[i]]`,
 *   - every element in `arr[lo..k-1]` has a key <= that value, and
 *   - every element in `arr[k+1..hi]` has a key >= that value.
 * Elements outside `[lo, hi]` are not touched.
 *
 * Partitioning follows Hoare/Wirth: both scans stop on keys EQUAL to the
 * pivot, so runs of duplicate keys are split evenly between the two sides.
 * A Lomuto partition (strict `<` only) sends all duplicates to one side and
 * removes a single element per pass. That is O(n^2) when many keys are equal,
 * e.g. splats that share a coordinate on one axis.
 *
 * @param {Uint32Array|number[]} arr - Permutation of row indices; modified in place.
 * @param {number} lo - First position of the range (inclusive).
 * @param {number} hi - Last position of the range (inclusive).
 * @param {number} k - Target position; expected to satisfy lo <= k <= hi.
 * @param {ArrayLike<number>} values - Keys, looked up as `values[arr[i]]`.
 */
const nthElement = (arr, lo, hi, k, values) => {
    while (lo < hi) {
        // Median-of-three pivot VALUE, taken from the ends and the middle of the range.
        const mid = (lo + hi) >> 1;
        const va = values[arr[lo]];
        const vb = values[arr[mid]];
        const vc = values[arr[hi]];
        let pivot;
        if ((vb - va) * (vc - vb) >= 0) {
            pivot = vb;
        }
        else if ((va - vb) * (vc - va) >= 0) {
            pivot = va;
        }
        else {
            pivot = vc;
        }

        let i = lo;
        let j = hi;
        while (i <= j) {
            // The pivot value occurs inside the range (and every swap leaves a
            // stopper behind), so neither scan can leave [lo, hi]. A NaN pivot
            // makes both scans stop immediately, which still guarantees progress.
            while (values[arr[i]] < pivot) {
                i++;
            }
            while (pivot < values[arr[j]]) {
                j--;
            }
            if (i <= j) {
                const tmp = arr[i];
                arr[i] = arr[j];
                arr[j] = tmp;
                i++;
                j--;
            }
        }

        // Now arr[lo..j] <= pivot <= arr[i..hi] and j < i. Any position strictly
        // between j and i holds a key equal to the pivot. Keep only the side that
        // contains k. If k lies strictly between j and i, both updates fire and
        // the loop ends.
        if (j < k) {
            lo = i;
        }
        if (k < i) {
            hi = j;
        }
    }
};
932
977
  class KdTree {
933
978
  centroids;
934
979
  root;
980
+ colData;
935
981
  constructor(centroids) {
936
- const build = (indices, depth) => {
937
- const { centroids } = this;
938
- const values = centroids.columns[depth % centroids.numColumns].data;
939
- indices.sort((a, b) => values[a] - values[b]);
940
- if (indices.length === 1) {
941
- return {
942
- index: indices[0],
943
- count: 1
944
- };
945
- }
946
- else if (indices.length === 2) {
982
+ const numCols = centroids.numColumns;
983
+ const colData = centroids.columns.map(c => c.data);
984
+ const indices = new Uint32Array(centroids.numRows);
985
+ for (let i = 0; i < indices.length; ++i) {
986
+ indices[i] = i;
987
+ }
988
+ const build = (lo, hi, depth) => {
989
+ const count = hi - lo + 1;
990
+ if (count === 1) {
991
+ return { index: indices[lo], count: 1 };
992
+ }
993
+ const values = colData[depth % numCols];
994
+ if (count === 2) {
995
+ if (values[indices[lo]] > values[indices[hi]]) {
996
+ const tmp = indices[lo];
997
+ indices[lo] = indices[hi];
998
+ indices[hi] = tmp;
999
+ }
947
1000
  return {
948
- index: indices[0],
1001
+ index: indices[lo],
949
1002
  count: 2,
950
- right: {
951
- index: indices[1],
952
- count: 1
953
- }
1003
+ right: { index: indices[hi], count: 1 }
954
1004
  };
955
1005
  }
956
- const mid = indices.length >> 1;
957
- const left = build(indices.subarray(0, mid), depth + 1);
958
- const right = build(indices.subarray(mid + 1), depth + 1);
1006
+ const mid = lo + (count >> 1);
1007
+ nthElement(indices, lo, hi, mid, values);
1008
+ const left = build(lo, mid - 1, depth + 1);
1009
+ const right = build(mid + 1, hi, depth + 1);
959
1010
  return {
960
1011
  index: indices[mid],
961
1012
  count: 1 + left.count + right.count,
@@ -963,48 +1014,39 @@ class KdTree {
963
1014
  right
964
1015
  };
965
1016
  };
966
- const indices = new Uint32Array(centroids.numRows);
967
- for (let i = 0; i < indices.length; ++i) {
968
- indices[i] = i;
969
- }
970
1017
  this.centroids = centroids;
971
- this.root = build(indices, 0);
1018
+ this.colData = colData;
1019
+ this.root = build(0, indices.length - 1, 0);
972
1020
  }
973
1021
  findNearest(point, filterFunc) {
974
- const { centroids } = this;
975
- const { numColumns } = centroids;
976
- const calcDistance = (index) => {
977
- let l = 0;
978
- for (let i = 0; i < numColumns; ++i) {
979
- const v = centroids.columns[i].data[index] - point[i];
980
- l += v * v;
981
- }
982
- return l;
983
- };
1022
+ const colData = this.colData;
1023
+ const numCols = colData.length;
984
1024
  let mind = Infinity;
985
1025
  let mini = -1;
986
1026
  let cnt = 0;
987
- const recurse = (node, depth) => {
988
- const axis = depth % numColumns;
989
- const distance = point[axis] - centroids.columns[axis].data[node.index];
1027
+ const recurse = (node, axis) => {
1028
+ const distance = point[axis] - colData[axis][node.index];
990
1029
  const next = (distance > 0) ? node.right : node.left;
1030
+ const nextAxis = axis + 1 < numCols ? axis + 1 : 0;
991
1031
  cnt++;
992
1032
  if (next) {
993
- recurse(next, depth + 1);
1033
+ recurse(next, nextAxis);
994
1034
  }
995
- // check index
996
1035
  if (!filterFunc || filterFunc(node.index)) {
997
- const thisd = calcDistance(node.index);
1036
+ let thisd = 0;
1037
+ for (let c = 0; c < numCols; c++) {
1038
+ const v = colData[c][node.index] - point[c];
1039
+ thisd += v * v;
1040
+ }
998
1041
  if (thisd < mind) {
999
1042
  mind = thisd;
1000
1043
  mini = node.index;
1001
1044
  }
1002
1045
  }
1003
- // check the other side
1004
1046
  if (distance * distance < mind) {
1005
1047
  const other = next === node.right ? node.left : node.right;
1006
1048
  if (other) {
1007
- recurse(other, depth + 1);
1049
+ recurse(other, nextAxis);
1008
1050
  }
1009
1051
  }
1010
1052
  };
@@ -1016,16 +1058,8 @@ class KdTree {
1016
1058
  return { indices: new Int32Array(0), distances: new Float32Array(0) };
1017
1059
  }
1018
1060
  k = Math.min(k, this.centroids.numRows);
1019
- const { centroids } = this;
1020
- const { numColumns } = centroids;
1021
- const calcDistance = (index) => {
1022
- let l = 0;
1023
- for (let i = 0; i < numColumns; ++i) {
1024
- const v = centroids.columns[i].data[index] - point[i];
1025
- l += v * v;
1026
- }
1027
- return l;
1028
- };
1061
+ const colData = this.colData;
1062
+ const numCols = colData.length;
1029
1063
  // Bounded max-heap: stores (distance, index) pairs sorted so the
1030
1064
  // farthest element is at position 0, enabling O(1) pruning bound.
1031
1065
  const heapDist = new Float32Array(k).fill(Infinity);
@@ -1033,14 +1067,12 @@ class KdTree {
1033
1067
  let heapSize = 0;
1034
1068
  const heapPush = (dist, idx) => {
1035
1069
  if (heapSize < k) {
1036
- // Heap not full yet -- insert via sift-up
1037
1070
  let pos = heapSize++;
1038
1071
  heapDist[pos] = dist;
1039
1072
  heapIdx[pos] = idx;
1040
1073
  while (pos > 0) {
1041
1074
  const parent = (pos - 1) >> 1;
1042
1075
  if (heapDist[parent] < heapDist[pos]) {
1043
- // swap
1044
1076
  const td = heapDist[parent];
1045
1077
  heapDist[parent] = heapDist[pos];
1046
1078
  heapDist[pos] = td;
@@ -1055,7 +1087,6 @@ class KdTree {
1055
1087
  }
1056
1088
  }
1057
1089
  else if (dist < heapDist[0]) {
1058
- // Replace root (farthest) and sift-down
1059
1090
  heapDist[0] = dist;
1060
1091
  heapIdx[0] = idx;
1061
1092
  let pos = 0;
@@ -1079,22 +1110,26 @@ class KdTree {
1079
1110
  }
1080
1111
  }
1081
1112
  };
1082
- const recurse = (node, depth) => {
1083
- const axis = depth % numColumns;
1084
- const distance = point[axis] - centroids.columns[axis].data[node.index];
1113
+ const recurse = (node, axis) => {
1114
+ const distance = point[axis] - colData[axis][node.index];
1085
1115
  const next = (distance > 0) ? node.right : node.left;
1116
+ const nextAxis = axis + 1 < numCols ? axis + 1 : 0;
1086
1117
  if (next) {
1087
- recurse(next, depth + 1);
1118
+ recurse(next, nextAxis);
1088
1119
  }
1089
1120
  if (!filterFunc || filterFunc(node.index)) {
1090
- const thisd = calcDistance(node.index);
1121
+ let thisd = 0;
1122
+ for (let c = 0; c < numCols; c++) {
1123
+ const v = colData[c][node.index] - point[c];
1124
+ thisd += v * v;
1125
+ }
1091
1126
  heapPush(thisd, node.index);
1092
1127
  }
1093
1128
  const bound = heapSize < k ? Infinity : heapDist[0];
1094
1129
  if (distance * distance < bound) {
1095
1130
  const other = next === node.right ? node.left : node.right;
1096
1131
  if (other) {
1097
- recurse(other, depth + 1);
1132
+ recurse(other, nextAxis);
1098
1133
  }
1099
1134
  }
1100
1135
  };
@@ -1129,6 +1164,54 @@ const OPACITY_PRUNE_THRESHOLD = 0.1;
1129
1164
  const KNN_K = 16;
1130
1165
  const MC_SAMPLES = 1;
1131
1166
  const EPS_COV = 1e-8;
1167
+ const PROGRESS_TICKS = 100;
1168
/**
 * Stable LSD radix sort of indices by Float32 key (used for edge costs).
 *
 * Each finite key is mapped to an order-preserving uint32:
 *   - negative floats have all bits flipped;
 *   - non-negative floats get their sign bit set.
 * `-0` is first normalized to `+0`, so the two compare equal and fall back to
 * index order, like a `(a, b) => keys[a] - keys[b] || a - b` comparator sort.
 * The encoded keys are then sorted with up to four 8-bit counting passes.
 * A pass in which every key falls into the same bucket is an identity
 * permutation and is skipped.
 *
 * Non-finite keys (±Infinity, NaN) are dropped.
 *
 * @param {Uint32Array|number[]} out - Receives the sorted indices in positions
 *     [0, returned count). Must have room for at least `count` entries.
 * @param {number} count - Number of leading keys to consider (indices 0..count-1).
 * @param {Float32Array} keys - Sort keys. Reinterpreted through its underlying
 *     buffer as raw IEEE-754 bits, so it must be a Float32Array.
 * @returns {number} Number of finite keys, i.e. valid entries written to `out`.
 */
const radixSortIndicesByFloat = (out, count, keys) => {
    const keyBits = new Uint32Array(keys.buffer, keys.byteOffset, keys.length);
    const sortKeys = new Uint32Array(count);
    let validCount = 0;
    for (let i = 0; i < count; i++) {
        let bits = keyBits[i];
        // Exponent all ones => Infinity or NaN; such entries are not sortable.
        if ((bits & 0x7F800000) === 0x7F800000) {
            continue;
        }
        // Treat -0 as +0 so that equal zero keys keep index order.
        if (bits === 0x80000000) {
            bits = 0;
        }
        sortKeys[validCount] = (bits & 0x80000000) ? (~bits >>> 0) : ((bits | 0x80000000) >>> 0);
        out[validCount] = i;
        validCount++;
    }
    if (validCount <= 1) {
        return validCount;
    }

    const n = validCount;
    const counts = new Uint32Array(256);
    let srcIdx = out;
    let srcKeys = sortKeys;
    let dstIdx = new Uint32Array(n);
    let dstKeys = new Uint32Array(n);

    for (let shift = 0; shift < 32; shift += 8) {
        counts.fill(0);
        for (let i = 0; i < n; i++) {
            counts[(srcKeys[i] >>> shift) & 0xFF]++;
        }
        // Every key shares this byte: the pass would not move anything.
        if (counts[(srcKeys[0] >>> shift) & 0xFF] === n) {
            continue;
        }
        // Exclusive prefix sum turns bucket counts into start offsets.
        let sum = 0;
        for (let b = 0; b < 256; b++) {
            const c = counts[b];
            counts[b] = sum;
            sum += c;
        }
        for (let i = 0; i < n; i++) {
            const key = srcKeys[i];
            const pos = counts[(key >>> shift) & 0xFF]++;
            dstIdx[pos] = srcIdx[i];
            dstKeys[pos] = key;
        }
        // Ping-pong: this pass's output becomes the next pass's input.
        const tmpIdx = srcIdx;
        srcIdx = dstIdx;
        dstIdx = tmpIdx;
        const tmpKeys = srcKeys;
        srcKeys = dstKeys;
        dstKeys = tmpKeys;
    }

    // With skipped passes the result may end up in the scratch buffer.
    if (srcIdx !== out) {
        for (let i = 0; i < n; i++) {
            out[i] = srcIdx[i];
        }
    }
    return validCount;
};
1132
1215
  // ---------- sigmoid / logit ----------
1133
1216
  const sigmoid$1 = (x) => 1 / (1 + Math.exp(-x));
1134
1217
  const logit = (p) => {
@@ -1616,7 +1699,6 @@ const simplifyGaussians = (dataTable, targetCount) => {
1616
1699
  let current;
1617
1700
  if (keptIndices.length < N && keptIndices.length > targetCount) {
1618
1701
  current = dataTable.permuteRows(keptIndices);
1619
- logger.debug(`opacity pruning: ${N} -> ${current.numRows} splats (threshold=${pruneThreshold.toFixed(4)})`);
1620
1702
  }
1621
1703
  else {
1622
1704
  current = dataTable;
@@ -1627,7 +1709,8 @@ const simplifyGaussians = (dataTable, targetCount) => {
1627
1709
  while (current.numRows > targetCount) {
1628
1710
  const n = current.numRows;
1629
1711
  const kEff = Math.min(Math.max(1, KNN_K), Math.max(1, n - 1));
1630
- logger.debug(`merging iteration: ${n} -> ${targetCount} splats`);
1712
+ logger.progress.begin(5);
1713
+ logger.progress.step('Building KD-tree');
1631
1714
  const cx = current.getColumnByName('x').data;
1632
1715
  const cy = current.getColumnByName('y').data;
1633
1716
  const cz = current.getColumnByName('z').data;
@@ -1640,16 +1723,21 @@ const simplifyGaussians = (dataTable, targetCount) => {
1640
1723
  const cr2 = current.getColumnByName('rot_2').data;
1641
1724
  const cr3 = current.getColumnByName('rot_3').data;
1642
1725
  const cache = buildPerSplatCache(n, cx, cy, cz, cop, cs0, cs1, cs2, cr0, cr1, cr2, cr3);
1643
- // Build KNN graph
1644
1726
  const posTable = new DataTable([
1645
1727
  new Column('x', cx instanceof Float32Array ? cx : new Float32Array(cx)),
1646
1728
  new Column('y', cy instanceof Float32Array ? cy : new Float32Array(cy)),
1647
1729
  new Column('z', cz instanceof Float32Array ? cz : new Float32Array(cz))
1648
1730
  ]);
1649
1731
  const kdTree = new KdTree(posTable);
1650
- const edgeSet = new Set();
1651
- const edges = [];
1732
+ logger.progress.step('Finding nearest neighbors');
1733
+ let edgeCapacity = Math.ceil(n * kEff / 2);
1734
+ let edgeU = new Uint32Array(edgeCapacity);
1735
+ let edgeV = new Uint32Array(edgeCapacity);
1736
+ let edgeCount = 0;
1652
1737
  const queryPoint = new Float32Array(3);
1738
+ const knnInterval = Math.max(1, Math.ceil(n / PROGRESS_TICKS));
1739
+ const knnTicks = Math.ceil(n / knnInterval);
1740
+ logger.progress.begin(knnTicks);
1653
1741
  for (let i = 0; i < n; i++) {
1654
1742
  queryPoint[0] = cx[i];
1655
1743
  queryPoint[1] = cy[i];
@@ -1657,46 +1745,58 @@ const simplifyGaussians = (dataTable, targetCount) => {
1657
1745
  const knn = kdTree.findKNearest(queryPoint, kEff + 1);
1658
1746
  for (let ki = 0; ki < knn.indices.length; ki++) {
1659
1747
  const j = knn.indices[ki];
1660
- if (j === i || j < 0)
1748
+ if (j <= i)
1661
1749
  continue;
1662
- const u = Math.min(i, j);
1663
- const v = Math.max(i, j);
1664
- const key = `${u},${v}`;
1665
- if (!edgeSet.has(key)) {
1666
- edgeSet.add(key);
1667
- edges.push([u, v]);
1750
+ if (edgeCount === edgeCapacity) {
1751
+ edgeCapacity *= 2;
1752
+ const newU = new Uint32Array(edgeCapacity);
1753
+ const newV = new Uint32Array(edgeCapacity);
1754
+ newU.set(edgeU);
1755
+ newV.set(edgeV);
1756
+ edgeU = newU;
1757
+ edgeV = newV;
1668
1758
  }
1759
+ edgeU[edgeCount] = i;
1760
+ edgeV[edgeCount] = j;
1761
+ edgeCount++;
1669
1762
  }
1763
+ if ((i + 1) % knnInterval === 0)
1764
+ logger.progress.step();
1670
1765
  }
1671
- if (edges.length === 0)
1766
+ if (n % knnInterval !== 0)
1767
+ logger.progress.step();
1768
+ if (edgeCount === 0) {
1769
+ logger.progress.cancel();
1672
1770
  break;
1673
- // Compute edge costs
1771
+ }
1772
+ logger.progress.step('Computing edge costs');
1674
1773
  const appData = [];
1675
1774
  for (let ai = 0; ai < allAppearanceCols.length; ai++) {
1676
1775
  const col = current.getColumnByName(allAppearanceCols[ai]);
1677
1776
  if (col)
1678
1777
  appData.push(col.data);
1679
1778
  }
1680
- const costs = new Float32Array(edges.length);
1681
- for (let e = 0; e < edges.length; e++) {
1682
- costs[e] = computeEdgeCost(edges[e][0], edges[e][1], cx, cy, cz, cache, Z, appData, appData.length);
1683
- }
1684
- // Greedy disjoint pair selection
1685
- const valid = [];
1686
- for (let i = 0; i < edges.length; i++) {
1687
- if (Number.isFinite(costs[i]))
1688
- valid.push(i);
1689
- }
1690
- valid.sort((a, b) => {
1691
- const d = costs[a] - costs[b];
1692
- return d !== 0 ? d : a - b;
1693
- });
1694
1779
  const mergesNeeded = n - targetCount;
1780
+ const costs = new Float32Array(edgeCount);
1781
+ const costInterval = Math.max(1, Math.ceil(edgeCount / PROGRESS_TICKS));
1782
+ const costTicks = Math.ceil(edgeCount / costInterval);
1783
+ logger.progress.begin(costTicks);
1784
+ for (let e = 0; e < edgeCount; e++) {
1785
+ costs[e] = computeEdgeCost(edgeU[e], edgeV[e], cx, cy, cz, cache, Z, appData, appData.length);
1786
+ if ((e + 1) % costInterval === 0)
1787
+ logger.progress.step();
1788
+ }
1789
+ if (edgeCount % costInterval !== 0)
1790
+ logger.progress.step();
1791
+ logger.progress.step('Merging splats');
1792
+ // Sort and greedy disjoint pair selection
1793
+ const sorted = new Uint32Array(edgeCount);
1794
+ const validCount = radixSortIndicesByFloat(sorted, edgeCount, costs);
1695
1795
  const used = new Uint8Array(n);
1696
1796
  const pairs = [];
1697
- for (let t = 0; t < valid.length; t++) {
1698
- const e = valid[t];
1699
- const u = edges[e][0], v = edges[e][1];
1797
+ for (let t = 0; t < validCount; t++) {
1798
+ const e = sorted[t];
1799
+ const u = edgeU[e], v = edgeV[e];
1700
1800
  if (used[u] || used[v])
1701
1801
  continue;
1702
1802
  used[u] = 1;
@@ -1705,9 +1805,10 @@ const simplifyGaussians = (dataTable, targetCount) => {
1705
1805
  if (pairs.length >= mergesNeeded)
1706
1806
  break;
1707
1807
  }
1708
- if (pairs.length === 0)
1808
+ if (pairs.length === 0) {
1809
+ logger.progress.cancel();
1709
1810
  break;
1710
- logger.debug(`selected ${pairs.length} merge pairs from ${edges.length} edges`);
1811
+ }
1711
1812
  // Mark which indices are consumed by merging
1712
1813
  const usedSet = new Uint8Array(n);
1713
1814
  for (let p = 0; p < pairs.length; p++) {
@@ -1735,7 +1836,7 @@ const simplifyGaussians = (dataTable, targetCount) => {
1735
1836
  newTable.columns[c].data[dst] = cols[c].data[src];
1736
1837
  }
1737
1838
  }
1738
- // Merge pairs -- cache column refs and handled set once
1839
+ // Merge pairs
1739
1840
  const mergeOut = {
1740
1841
  mu: new Float64Array(3),
1741
1842
  sc: new Float64Array(3),
@@ -1763,6 +1864,9 @@ const simplifyGaussians = (dataTable, targetCount) => {
1763
1864
  .filter(col => !handledCols.has(col.name))
1764
1865
  .map(col => ({ src: col, dst: newTable.getColumnByName(col.name) }))
1765
1866
  .filter(pair => pair.dst);
1867
+ const mergeInterval = Math.max(1, Math.ceil(pairs.length / PROGRESS_TICKS));
1868
+ const mergeTicks = Math.ceil(pairs.length / mergeInterval);
1869
+ logger.progress.begin(mergeTicks);
1766
1870
  for (let p = 0; p < pairs.length; p++, dst++) {
1767
1871
  const pi = pairs[p][0], pj = pairs[p][1];
1768
1872
  momentMatch(pi, pj, cx, cy, cz, cop, cs0, cs1, cs2, cr0, cr1, cr2, cr3, mergeOut, appData, appData.length);
@@ -1785,7 +1889,12 @@ const simplifyGaussians = (dataTable, targetCount) => {
1785
1889
  for (let u = 0; u < unhandledColPairs.length; u++) {
1786
1890
  unhandledColPairs[u].dst.data[dst] = unhandledColPairs[u].src.data[dominant];
1787
1891
  }
1892
+ if ((p + 1) % mergeInterval === 0)
1893
+ logger.progress.step();
1788
1894
  }
1895
+ if (pairs.length % mergeInterval !== 0)
1896
+ logger.progress.step();
1897
+ logger.progress.step('Finalizing');
1789
1898
  current = newTable;
1790
1899
  }
1791
1900
  return current;
@@ -5613,7 +5722,7 @@ class CompressedChunk {
5613
5722
  }
5614
5723
  }
5615
5724
 
5616
- var version = "1.9.0";
5725
+ var version = "1.9.1";
5617
5726
 
5618
5727
  const generatedByString = `Generated by splat-transform ${version}`;
5619
5728
  const chunkProps = [