@playcanvas/splat-transform 1.9.0 → 1.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -745,6 +745,16 @@ class Progress {
745
745
  if (!quiet)
746
746
  impl.onProgress(this.currentNode);
747
747
  }
748
+ /**
749
+ * Cancel the current progress node, popping it from the stack without
750
+ * completing remaining steps. Use this before early exits (e.g. break)
751
+ * to keep the progress stack balanced.
752
+ */
753
+ cancel() {
754
+ if (!this.currentNode)
755
+ return;
756
+ this.currentNode = this.currentNode.parent;
757
+ }
748
758
  /**
749
759
  * Advance to the next step. Auto-increments the step counter.
750
760
  * Auto-ends when all steps are complete.
@@ -932,33 +942,74 @@ const sortMortonOrder = (dataTable, indices) => {
932
942
  generate(indices);
933
943
  };
934
944
 
945
/**
 * Quickselect (nth_element) over a range of an index array.
 *
 * Reorders `arr[lo..hi]` (inclusive) in place so that `arr[k]` holds the index
 * whose key `values[arr[k]]` would sit at position k if the range were sorted
 * ascending by `values[arr[i]]`. Every position in [lo, k) then has a key <=
 * that value, and every position in (k, hi] has a key >= it. Positions outside
 * [lo, hi] are not touched.
 *
 * Uses a median-of-three pivot and a three-way (less / equal / greater)
 * partition. The three-way split matters for duplicate-heavy keys. With a
 * two-way partition, a run of equal keys is peeled off one element per pass,
 * which degrades to O(n^2). Here the whole equal run is settled in one pass.
 *
 * @param {Uint32Array|number[]} arr - Indices into `values`; reordered in place.
 * @param {number} lo - First position of the range (inclusive).
 * @param {number} hi - Last position of the range (inclusive).
 * @param {number} k - Target position; expected to satisfy lo <= k <= hi.
 * @param {ArrayLike<number>} values - Keys, looked up as `values[arr[i]]`.
 */
const nthElement = (arr, lo, hi, k, values) => {
    while (lo < hi) {
        // Median-of-three pivot value taken from the first, middle and last entries.
        const mid = (lo + hi) >> 1;
        const va = values[arr[lo]];
        const vb = values[arr[mid]];
        const vc = values[arr[hi]];
        let pivotVal;
        if ((vb - va) * (vc - vb) >= 0) {
            pivotVal = vb;
        } else if ((va - vb) * (vc - va) >= 0) {
            pivotVal = va;
        } else {
            pivotVal = vc;
        }

        // Dutch-national-flag partition. Invariant:
        //   [lo, lt) < pivot, [lt, i) == pivot, [i, gt] unvisited, (gt, hi] > pivot.
        // Keys that compare neither < nor > (e.g. NaN) stay in the middle band.
        // The pivot's own element is always in the band, so the band is never
        // empty and each iteration strictly shrinks the range.
        let lt = lo;
        let gt = hi;
        let i = lo;
        while (i <= gt) {
            const v = values[arr[i]];
            if (v < pivotVal) {
                const tmp = arr[lt];
                arr[lt] = arr[i];
                arr[i] = tmp;
                lt++;
                i++;
            } else if (v > pivotVal) {
                const tmp = arr[gt];
                arr[gt] = arr[i];
                arr[i] = tmp;
                gt--;
            } else {
                i++;
            }
        }

        if (k < lt) {
            hi = lt - 1;
        } else if (k > gt) {
            lo = gt + 1;
        } else {
            // k falls inside the band of keys equal to the pivot: done.
            return;
        }
    }
};
935
980
  class KdTree {
936
981
  centroids;
937
982
  root;
983
+ colData;
938
984
  constructor(centroids) {
939
- const build = (indices, depth) => {
940
- const { centroids } = this;
941
- const values = centroids.columns[depth % centroids.numColumns].data;
942
- indices.sort((a, b) => values[a] - values[b]);
943
- if (indices.length === 1) {
944
- return {
945
- index: indices[0],
946
- count: 1
947
- };
948
- }
949
- else if (indices.length === 2) {
985
+ const numCols = centroids.numColumns;
986
+ const colData = centroids.columns.map(c => c.data);
987
+ const indices = new Uint32Array(centroids.numRows);
988
+ for (let i = 0; i < indices.length; ++i) {
989
+ indices[i] = i;
990
+ }
991
+ const build = (lo, hi, depth) => {
992
+ const count = hi - lo + 1;
993
+ if (count === 1) {
994
+ return { index: indices[lo], count: 1 };
995
+ }
996
+ const values = colData[depth % numCols];
997
+ if (count === 2) {
998
+ if (values[indices[lo]] > values[indices[hi]]) {
999
+ const tmp = indices[lo];
1000
+ indices[lo] = indices[hi];
1001
+ indices[hi] = tmp;
1002
+ }
950
1003
  return {
951
- index: indices[0],
1004
+ index: indices[lo],
952
1005
  count: 2,
953
- right: {
954
- index: indices[1],
955
- count: 1
956
- }
1006
+ right: { index: indices[hi], count: 1 }
957
1007
  };
958
1008
  }
959
- const mid = indices.length >> 1;
960
- const left = build(indices.subarray(0, mid), depth + 1);
961
- const right = build(indices.subarray(mid + 1), depth + 1);
1009
+ const mid = lo + (count >> 1);
1010
+ nthElement(indices, lo, hi, mid, values);
1011
+ const left = build(lo, mid - 1, depth + 1);
1012
+ const right = build(mid + 1, hi, depth + 1);
962
1013
  return {
963
1014
  index: indices[mid],
964
1015
  count: 1 + left.count + right.count,
@@ -966,48 +1017,39 @@ class KdTree {
966
1017
  right
967
1018
  };
968
1019
  };
969
- const indices = new Uint32Array(centroids.numRows);
970
- for (let i = 0; i < indices.length; ++i) {
971
- indices[i] = i;
972
- }
973
1020
  this.centroids = centroids;
974
- this.root = build(indices, 0);
1021
+ this.colData = colData;
1022
+ this.root = build(0, indices.length - 1, 0);
975
1023
  }
976
1024
  findNearest(point, filterFunc) {
977
- const { centroids } = this;
978
- const { numColumns } = centroids;
979
- const calcDistance = (index) => {
980
- let l = 0;
981
- for (let i = 0; i < numColumns; ++i) {
982
- const v = centroids.columns[i].data[index] - point[i];
983
- l += v * v;
984
- }
985
- return l;
986
- };
1025
+ const colData = this.colData;
1026
+ const numCols = colData.length;
987
1027
  let mind = Infinity;
988
1028
  let mini = -1;
989
1029
  let cnt = 0;
990
- const recurse = (node, depth) => {
991
- const axis = depth % numColumns;
992
- const distance = point[axis] - centroids.columns[axis].data[node.index];
1030
+ const recurse = (node, axis) => {
1031
+ const distance = point[axis] - colData[axis][node.index];
993
1032
  const next = (distance > 0) ? node.right : node.left;
1033
+ const nextAxis = axis + 1 < numCols ? axis + 1 : 0;
994
1034
  cnt++;
995
1035
  if (next) {
996
- recurse(next, depth + 1);
1036
+ recurse(next, nextAxis);
997
1037
  }
998
- // check index
999
1038
  if (!filterFunc || filterFunc(node.index)) {
1000
- const thisd = calcDistance(node.index);
1039
+ let thisd = 0;
1040
+ for (let c = 0; c < numCols; c++) {
1041
+ const v = colData[c][node.index] - point[c];
1042
+ thisd += v * v;
1043
+ }
1001
1044
  if (thisd < mind) {
1002
1045
  mind = thisd;
1003
1046
  mini = node.index;
1004
1047
  }
1005
1048
  }
1006
- // check the other side
1007
1049
  if (distance * distance < mind) {
1008
1050
  const other = next === node.right ? node.left : node.right;
1009
1051
  if (other) {
1010
- recurse(other, depth + 1);
1052
+ recurse(other, nextAxis);
1011
1053
  }
1012
1054
  }
1013
1055
  };
@@ -1019,16 +1061,8 @@ class KdTree {
1019
1061
  return { indices: new Int32Array(0), distances: new Float32Array(0) };
1020
1062
  }
1021
1063
  k = Math.min(k, this.centroids.numRows);
1022
- const { centroids } = this;
1023
- const { numColumns } = centroids;
1024
- const calcDistance = (index) => {
1025
- let l = 0;
1026
- for (let i = 0; i < numColumns; ++i) {
1027
- const v = centroids.columns[i].data[index] - point[i];
1028
- l += v * v;
1029
- }
1030
- return l;
1031
- };
1064
+ const colData = this.colData;
1065
+ const numCols = colData.length;
1032
1066
  // Bounded max-heap: stores (distance, index) pairs sorted so the
1033
1067
  // farthest element is at position 0, enabling O(1) pruning bound.
1034
1068
  const heapDist = new Float32Array(k).fill(Infinity);
@@ -1036,14 +1070,12 @@ class KdTree {
1036
1070
  let heapSize = 0;
1037
1071
  const heapPush = (dist, idx) => {
1038
1072
  if (heapSize < k) {
1039
- // Heap not full yet -- insert via sift-up
1040
1073
  let pos = heapSize++;
1041
1074
  heapDist[pos] = dist;
1042
1075
  heapIdx[pos] = idx;
1043
1076
  while (pos > 0) {
1044
1077
  const parent = (pos - 1) >> 1;
1045
1078
  if (heapDist[parent] < heapDist[pos]) {
1046
- // swap
1047
1079
  const td = heapDist[parent];
1048
1080
  heapDist[parent] = heapDist[pos];
1049
1081
  heapDist[pos] = td;
@@ -1058,7 +1090,6 @@ class KdTree {
1058
1090
  }
1059
1091
  }
1060
1092
  else if (dist < heapDist[0]) {
1061
- // Replace root (farthest) and sift-down
1062
1093
  heapDist[0] = dist;
1063
1094
  heapIdx[0] = idx;
1064
1095
  let pos = 0;
@@ -1082,22 +1113,26 @@ class KdTree {
1082
1113
  }
1083
1114
  }
1084
1115
  };
1085
- const recurse = (node, depth) => {
1086
- const axis = depth % numColumns;
1087
- const distance = point[axis] - centroids.columns[axis].data[node.index];
1116
+ const recurse = (node, axis) => {
1117
+ const distance = point[axis] - colData[axis][node.index];
1088
1118
  const next = (distance > 0) ? node.right : node.left;
1119
+ const nextAxis = axis + 1 < numCols ? axis + 1 : 0;
1089
1120
  if (next) {
1090
- recurse(next, depth + 1);
1121
+ recurse(next, nextAxis);
1091
1122
  }
1092
1123
  if (!filterFunc || filterFunc(node.index)) {
1093
- const thisd = calcDistance(node.index);
1124
+ let thisd = 0;
1125
+ for (let c = 0; c < numCols; c++) {
1126
+ const v = colData[c][node.index] - point[c];
1127
+ thisd += v * v;
1128
+ }
1094
1129
  heapPush(thisd, node.index);
1095
1130
  }
1096
1131
  const bound = heapSize < k ? Infinity : heapDist[0];
1097
1132
  if (distance * distance < bound) {
1098
1133
  const other = next === node.right ? node.left : node.right;
1099
1134
  if (other) {
1100
- recurse(other, depth + 1);
1135
+ recurse(other, nextAxis);
1101
1136
  }
1102
1137
  }
1103
1138
  };
@@ -1132,6 +1167,54 @@ const OPACITY_PRUNE_THRESHOLD = 0.1;
1132
1167
  const KNN_K = 16;
1133
1168
  const MC_SAMPLES = 1;
1134
1169
  const EPS_COV = 1e-8;
1170
+ const PROGRESS_TICKS = 100;
1171
/**
 * Stable ascending sort of the positions 0..count-1 by their Float32 key.
 *
 * Positions whose key is NaN or +/-Infinity (exponent bits all ones) are
 * dropped. Each remaining float's bit pattern is mapped to a uint32 whose
 * unsigned order matches the float order: negatives are bit-inverted, and
 * non-negatives get the sign bit set. Those uint32s are then sorted with a
 * 4-pass LSD radix sort, 8 bits per pass. Because every pass is stable,
 * equal keys keep ascending position order.
 *
 * @param {Uint32Array} out - Receives the sorted positions in out[0..result).
 * @param {number} count - Number of leading entries of `keys` to consider.
 * @param {Float32Array} keys - Sort keys, reinterpreted bitwise.
 * @returns {number} How many finite-key positions were written to `out`.
 */
const radixSortIndicesByFloat = (out, count, keys) => {
    const rawBits = new Uint32Array(keys.buffer, keys.byteOffset, keys.length);
    let keysA = new Uint32Array(count);
    let n = 0;
    for (let src = 0; src < count; src++) {
        const b = rawBits[src];
        const isNonFinite = (b & 0x7F800000) === 0x7F800000;
        if (!isNonFinite) {
            keysA[n] = (b & 0x80000000) ? ~b : b | 0x80000000;
            out[n] = src;
            n++;
        }
    }
    if (n < 2) {
        return n;
    }

    // Ping-pong between `out` and a scratch buffer; after an even number of
    // passes (four) the sorted positions end up back in `out`.
    let idxA = out;
    let idxB = new Uint32Array(n);
    let keysB = new Uint32Array(n);
    const histogram = new Uint32Array(256);
    for (let shift = 0; shift < 32; shift += 8) {
        histogram.fill(0);
        for (let i = 0; i < n; i++) {
            histogram[(keysA[i] >>> shift) & 0xFF]++;
        }
        // Turn bucket counts into starting offsets (exclusive prefix sum).
        let offset = 0;
        for (let bucket = 0; bucket < 256; bucket++) {
            const c = histogram[bucket];
            histogram[bucket] = offset;
            offset += c;
        }
        for (let i = 0; i < n; i++) {
            const key = keysA[i];
            const dest = histogram[(key >>> shift) & 0xFF]++;
            idxB[dest] = idxA[i];
            keysB[dest] = key;
        }
        [idxA, idxB] = [idxB, idxA];
        [keysA, keysB] = [keysB, keysA];
    }
    return n;
};
1135
1218
  // ---------- sigmoid / logit ----------
1136
1219
  const sigmoid$1 = (x) => 1 / (1 + Math.exp(-x));
1137
1220
  const logit = (p) => {
@@ -1619,7 +1702,6 @@ const simplifyGaussians = (dataTable, targetCount) => {
1619
1702
  let current;
1620
1703
  if (keptIndices.length < N && keptIndices.length > targetCount) {
1621
1704
  current = dataTable.permuteRows(keptIndices);
1622
- logger.debug(`opacity pruning: ${N} -> ${current.numRows} splats (threshold=${pruneThreshold.toFixed(4)})`);
1623
1705
  }
1624
1706
  else {
1625
1707
  current = dataTable;
@@ -1630,7 +1712,8 @@ const simplifyGaussians = (dataTable, targetCount) => {
1630
1712
  while (current.numRows > targetCount) {
1631
1713
  const n = current.numRows;
1632
1714
  const kEff = Math.min(Math.max(1, KNN_K), Math.max(1, n - 1));
1633
- logger.debug(`merging iteration: ${n} -> ${targetCount} splats`);
1715
+ logger.progress.begin(5);
1716
+ logger.progress.step('Building KD-tree');
1634
1717
  const cx = current.getColumnByName('x').data;
1635
1718
  const cy = current.getColumnByName('y').data;
1636
1719
  const cz = current.getColumnByName('z').data;
@@ -1643,16 +1726,21 @@ const simplifyGaussians = (dataTable, targetCount) => {
1643
1726
  const cr2 = current.getColumnByName('rot_2').data;
1644
1727
  const cr3 = current.getColumnByName('rot_3').data;
1645
1728
  const cache = buildPerSplatCache(n, cx, cy, cz, cop, cs0, cs1, cs2, cr0, cr1, cr2, cr3);
1646
- // Build KNN graph
1647
1729
  const posTable = new DataTable([
1648
1730
  new Column('x', cx instanceof Float32Array ? cx : new Float32Array(cx)),
1649
1731
  new Column('y', cy instanceof Float32Array ? cy : new Float32Array(cy)),
1650
1732
  new Column('z', cz instanceof Float32Array ? cz : new Float32Array(cz))
1651
1733
  ]);
1652
1734
  const kdTree = new KdTree(posTable);
1653
- const edgeSet = new Set();
1654
- const edges = [];
1735
+ logger.progress.step('Finding nearest neighbors');
1736
+ let edgeCapacity = Math.ceil(n * kEff / 2);
1737
+ let edgeU = new Uint32Array(edgeCapacity);
1738
+ let edgeV = new Uint32Array(edgeCapacity);
1739
+ let edgeCount = 0;
1655
1740
  const queryPoint = new Float32Array(3);
1741
+ const knnInterval = Math.max(1, Math.ceil(n / PROGRESS_TICKS));
1742
+ const knnTicks = Math.ceil(n / knnInterval);
1743
+ logger.progress.begin(knnTicks);
1656
1744
  for (let i = 0; i < n; i++) {
1657
1745
  queryPoint[0] = cx[i];
1658
1746
  queryPoint[1] = cy[i];
@@ -1660,46 +1748,58 @@ const simplifyGaussians = (dataTable, targetCount) => {
1660
1748
  const knn = kdTree.findKNearest(queryPoint, kEff + 1);
1661
1749
  for (let ki = 0; ki < knn.indices.length; ki++) {
1662
1750
  const j = knn.indices[ki];
1663
- if (j === i || j < 0)
1751
+ if (j <= i)
1664
1752
  continue;
1665
- const u = Math.min(i, j);
1666
- const v = Math.max(i, j);
1667
- const key = `${u},${v}`;
1668
- if (!edgeSet.has(key)) {
1669
- edgeSet.add(key);
1670
- edges.push([u, v]);
1753
+ if (edgeCount === edgeCapacity) {
1754
+ edgeCapacity *= 2;
1755
+ const newU = new Uint32Array(edgeCapacity);
1756
+ const newV = new Uint32Array(edgeCapacity);
1757
+ newU.set(edgeU);
1758
+ newV.set(edgeV);
1759
+ edgeU = newU;
1760
+ edgeV = newV;
1671
1761
  }
1762
+ edgeU[edgeCount] = i;
1763
+ edgeV[edgeCount] = j;
1764
+ edgeCount++;
1672
1765
  }
1766
+ if ((i + 1) % knnInterval === 0)
1767
+ logger.progress.step();
1673
1768
  }
1674
- if (edges.length === 0)
1769
+ if (n % knnInterval !== 0)
1770
+ logger.progress.step();
1771
+ if (edgeCount === 0) {
1772
+ logger.progress.cancel();
1675
1773
  break;
1676
- // Compute edge costs
1774
+ }
1775
+ logger.progress.step('Computing edge costs');
1677
1776
  const appData = [];
1678
1777
  for (let ai = 0; ai < allAppearanceCols.length; ai++) {
1679
1778
  const col = current.getColumnByName(allAppearanceCols[ai]);
1680
1779
  if (col)
1681
1780
  appData.push(col.data);
1682
1781
  }
1683
- const costs = new Float32Array(edges.length);
1684
- for (let e = 0; e < edges.length; e++) {
1685
- costs[e] = computeEdgeCost(edges[e][0], edges[e][1], cx, cy, cz, cache, Z, appData, appData.length);
1686
- }
1687
- // Greedy disjoint pair selection
1688
- const valid = [];
1689
- for (let i = 0; i < edges.length; i++) {
1690
- if (Number.isFinite(costs[i]))
1691
- valid.push(i);
1692
- }
1693
- valid.sort((a, b) => {
1694
- const d = costs[a] - costs[b];
1695
- return d !== 0 ? d : a - b;
1696
- });
1697
1782
  const mergesNeeded = n - targetCount;
1783
+ const costs = new Float32Array(edgeCount);
1784
+ const costInterval = Math.max(1, Math.ceil(edgeCount / PROGRESS_TICKS));
1785
+ const costTicks = Math.ceil(edgeCount / costInterval);
1786
+ logger.progress.begin(costTicks);
1787
+ for (let e = 0; e < edgeCount; e++) {
1788
+ costs[e] = computeEdgeCost(edgeU[e], edgeV[e], cx, cy, cz, cache, Z, appData, appData.length);
1789
+ if ((e + 1) % costInterval === 0)
1790
+ logger.progress.step();
1791
+ }
1792
+ if (edgeCount % costInterval !== 0)
1793
+ logger.progress.step();
1794
+ logger.progress.step('Merging splats');
1795
+ // Sort and greedy disjoint pair selection
1796
+ const sorted = new Uint32Array(edgeCount);
1797
+ const validCount = radixSortIndicesByFloat(sorted, edgeCount, costs);
1698
1798
  const used = new Uint8Array(n);
1699
1799
  const pairs = [];
1700
- for (let t = 0; t < valid.length; t++) {
1701
- const e = valid[t];
1702
- const u = edges[e][0], v = edges[e][1];
1800
+ for (let t = 0; t < validCount; t++) {
1801
+ const e = sorted[t];
1802
+ const u = edgeU[e], v = edgeV[e];
1703
1803
  if (used[u] || used[v])
1704
1804
  continue;
1705
1805
  used[u] = 1;
@@ -1708,9 +1808,10 @@ const simplifyGaussians = (dataTable, targetCount) => {
1708
1808
  if (pairs.length >= mergesNeeded)
1709
1809
  break;
1710
1810
  }
1711
- if (pairs.length === 0)
1811
+ if (pairs.length === 0) {
1812
+ logger.progress.cancel();
1712
1813
  break;
1713
- logger.debug(`selected ${pairs.length} merge pairs from ${edges.length} edges`);
1814
+ }
1714
1815
  // Mark which indices are consumed by merging
1715
1816
  const usedSet = new Uint8Array(n);
1716
1817
  for (let p = 0; p < pairs.length; p++) {
@@ -1738,7 +1839,7 @@ const simplifyGaussians = (dataTable, targetCount) => {
1738
1839
  newTable.columns[c].data[dst] = cols[c].data[src];
1739
1840
  }
1740
1841
  }
1741
- // Merge pairs -- cache column refs and handled set once
1842
+ // Merge pairs
1742
1843
  const mergeOut = {
1743
1844
  mu: new Float64Array(3),
1744
1845
  sc: new Float64Array(3),
@@ -1766,6 +1867,9 @@ const simplifyGaussians = (dataTable, targetCount) => {
1766
1867
  .filter(col => !handledCols.has(col.name))
1767
1868
  .map(col => ({ src: col, dst: newTable.getColumnByName(col.name) }))
1768
1869
  .filter(pair => pair.dst);
1870
+ const mergeInterval = Math.max(1, Math.ceil(pairs.length / PROGRESS_TICKS));
1871
+ const mergeTicks = Math.ceil(pairs.length / mergeInterval);
1872
+ logger.progress.begin(mergeTicks);
1769
1873
  for (let p = 0; p < pairs.length; p++, dst++) {
1770
1874
  const pi = pairs[p][0], pj = pairs[p][1];
1771
1875
  momentMatch(pi, pj, cx, cy, cz, cop, cs0, cs1, cs2, cr0, cr1, cr2, cr3, mergeOut, appData, appData.length);
@@ -1788,7 +1892,12 @@ const simplifyGaussians = (dataTable, targetCount) => {
1788
1892
  for (let u = 0; u < unhandledColPairs.length; u++) {
1789
1893
  unhandledColPairs[u].dst.data[dst] = unhandledColPairs[u].src.data[dominant];
1790
1894
  }
1895
+ if ((p + 1) % mergeInterval === 0)
1896
+ logger.progress.step();
1791
1897
  }
1898
+ if (pairs.length % mergeInterval !== 0)
1899
+ logger.progress.step();
1900
+ logger.progress.step('Finalizing');
1792
1901
  current = newTable;
1793
1902
  }
1794
1903
  return current;
@@ -5616,7 +5725,7 @@ class CompressedChunk {
5616
5725
  }
5617
5726
  }
5618
5727
 
5619
- var version = "1.9.0";
5728
+ var version = "1.9.1";
5620
5729
 
5621
5730
  const generatedByString = `Generated by splat-transform ${version}`;
5622
5731
  const chunkProps = [