@manycore/aholo-splat-transform 1.2.9 → 1.2.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/CHANGELOG.md CHANGED
@@ -1,5 +1,16 @@
1
1
  # ChangeLOG
2
2
 
3
+ ## 1.2.11
4
+
5
+ - 移除`webgpu`依赖,改为自主编译版本
6
+ - 优化`k-means`计算,采用`vec4`对齐,简化欧几里得距离计算。(性能提升约50%)
7
+ - 优化`clusterAverage`,减少js端overhead
8
+ - native线程池共享
9
+
10
+ ## 1.2.10
11
+
12
+ - 增加`darwin-arm64`的预编译版本
13
+
3
14
  ## 1.2.9
4
15
 
5
16
  - 改进编码风格开启`verbatimModuleSyntax` & `isolatedModules`,迁移至`OXC`
package/README.md CHANGED
@@ -8,7 +8,7 @@ A 3DGS modifier used by aholo
8
8
  - system
9
9
  - windows: windows 22H2+, x86_64, D3D12 or Vulkan compatible GPU(When use GPU features, dedicated GPU for better performance)
10
10
  - linux: x86_64, glibc >= 2.34, libstdc++ >= 3.4.30, Vulkan compatible GPU(When use GPU features, dedicated GPU for better performance)
11
- - osx: Not Supported
11
+ - osx: apple silicon ARM64 only.
12
12
 
13
13
  ## Usage
14
14
 
package/dist/file/sog.js CHANGED
@@ -3,7 +3,7 @@ import { Buffer } from 'node:buffer';
3
3
  import { decodeWebP, encodeWebP, WebPLosslessProfile } from '../native/index.js';
4
4
  import { ColIdx } from '../SplatData.js';
5
5
  import { SH_C0, SH_MAPS, NUM_F_REST_TO_SH_DEGREE } from '../constant.js';
6
- import { getOrCreateDevice, kmeans, logger, mortonSort, quantize1d, isUrl, extractFromRootDir, clamp, } from '../utils/index.js';
6
+ import { getOrCreateDevice, kMeans, logger, mortonSort, quantize1d, isUrl, extractFromRootDir, clamp, } from '../utils/index.js';
7
7
  const ZIP_MAGIC = 0x04034b50;
8
8
  const PERM_TABLE = [
9
9
  // original quat idx ---> actual storage idx
@@ -485,7 +485,7 @@ export class SogFile {
485
485
  const device = await getOrCreateDevice();
486
486
  logger.info(`SOG SH${shDegree} k-means with clusters=${paletteSize} iterations=${this.iterations}`);
487
487
  logger.time(`SOG SH${shDegree} k-means`);
488
- const { centroids, labels } = await kmeans(shDataTable, paletteSize, this.iterations, device);
488
+ const { centroids, labels } = await kMeans(shDataTable, paletteSize, this.iterations, device);
489
489
  logger.timeEnd(`SOG SH${shDegree} k-means`);
490
490
  const codebook = quantize1d(centroids);
491
491
  // write centroids
@@ -51,4 +51,4 @@ export declare function decodeAVIFBatched(inputs: Array<Uint8Array | Buffer>): {
51
51
  width: number;
52
52
  height: number;
53
53
  }[];
54
- export declare function clusterAverage(dataTable: Float32Array[], clusters: Uint32Array[], output: Float32Array[]): void;
54
+ export declare function clusterAverage(dataTable: Float32Array[], labels: Uint32Array, k: number, output: Float32Array[]): void;
@@ -1,26 +1,36 @@
1
1
  import { createRequire } from 'node:module';
2
2
  import { SplatData } from '../SplatData.js';
3
3
  import { Buffer } from 'node:buffer';
4
- import { isMusl } from './utils.js';
5
- import p from '../../package.json' with { type: 'json' };
4
+ import { getNativePackageName } from './utils.js';
6
5
  const getModule = (function () {
7
6
  let m = undefined;
8
7
  const require = createRequire(import.meta.url);
9
- let runtime = undefined;
10
- if (process.platform === 'win32') {
11
- runtime = 'msvc';
12
- }
13
- else if (process.platform === 'linux') {
14
- runtime = isMusl() ? 'musl' : 'gnu';
15
- }
16
- const binaryPackage = `${p.name}-${process.platform}-${process.arch}${runtime ? `-${runtime}` : ''}`;
8
+ const binaryModule = getNativePackageName() + '/splat-transform.node';
17
9
  return function () {
18
10
  if (!m) {
19
- m = require(binaryPackage);
11
+ m = require(binaryModule);
20
12
  }
21
13
  return m;
22
14
  };
23
15
  })();
16
+ const [defaultThreadPool, smallThreadPool] = (function () {
17
+ let defaultTheadPool;
18
+ let smallTheadPool;
19
+ return [
20
+ function () {
21
+ if (!defaultTheadPool) {
22
+ defaultTheadPool = new (getModule().ThreadPool)();
23
+ }
24
+ return defaultTheadPool;
25
+ },
26
+ function () {
27
+ if (!smallTheadPool) {
28
+ smallTheadPool = new (getModule().ThreadPool)(4);
29
+ }
30
+ return smallTheadPool;
31
+ },
32
+ ];
33
+ })();
24
34
  export function generateLod(splat, levelParameters, blockPrecision, minSize, maxStep) {
25
35
  if (splat.counts === 0) {
26
36
  return {
@@ -44,7 +54,7 @@ export function generateLod(splat, levelParameters, blockPrecision, minSize, max
44
54
  parameters[i * 2 + 1] = scaleBoost;
45
55
  }
46
56
  }
47
- const { blockBoxes, blockRefs, gaussianCount, data } = getModule().generate_lod(inputBuffers, splat.shCounts, buffer, blockPrecision, minSize, maxStep);
57
+ const { blockBoxes, blockRefs, gaussianCount, data } = getModule().generate_lod(inputBuffers, splat.shCounts, buffer, blockPrecision, minSize, maxStep, defaultThreadPool());
48
58
  const blockView = new Float32Array(blockBoxes.buffer, blockBoxes.byteOffset, blockBoxes.byteLength / 4);
49
59
  const blockRefsView = new Uint32Array(blockRefs.buffer, blockRefs.byteOffset, blockRefs.byteLength / 4);
50
60
  const blockCount = blockView.length / 6;
@@ -108,15 +118,15 @@ export function encodeAVIFBatched(inputs) {
108
118
  return getModule().avif_encode_rgba_batched(inputs.map(i => ({
109
119
  ...i,
110
120
  data: i.data instanceof Buffer ? i.data : Buffer.from(i.data.buffer, i.data.byteOffset, i.data.byteLength),
111
- })));
121
+ })), smallThreadPool());
112
122
  }
113
123
  export function decodeAVIF(data) {
114
124
  const buffer = data instanceof Buffer ? data : Buffer.from(data.buffer, data.byteOffset, data.byteLength);
115
125
  return getModule().avif_decode_rgba(buffer);
116
126
  }
117
127
  export function decodeAVIFBatched(inputs) {
118
- return getModule().avif_decode_rgba_batched(inputs.map(i => (i instanceof Buffer ? i : Buffer.from(i.buffer, i.byteOffset, i.byteLength))));
128
+ return getModule().avif_decode_rgba_batched(inputs.map(i => (i instanceof Buffer ? i : Buffer.from(i.buffer, i.byteOffset, i.byteLength))), smallThreadPool());
119
129
  }
120
- export function clusterAverage(dataTable, clusters, output) {
121
- return getModule().cluster_average(dataTable.map(t => Buffer.from(t.buffer, t.byteOffset, t.byteLength)), clusters.map(t => Buffer.from(t.buffer, t.byteOffset, t.byteLength)), output.map(t => Buffer.from(t.buffer, t.byteOffset, t.byteLength)));
130
+ export function clusterAverage(dataTable, labels, k, output) {
131
+ return getModule().cluster_average(dataTable.map(t => Buffer.from(t.buffer, t.byteOffset, t.byteLength)), Buffer.from(labels.buffer, labels.byteOffset, labels.byteLength), k, output.map(t => Buffer.from(t.buffer, t.byteOffset, t.byteLength)), defaultThreadPool());
122
132
  }
@@ -1 +1 @@
1
- export declare function isMusl(): boolean;
1
+ export declare function getNativePackageName(): string;
@@ -1,6 +1,7 @@
1
1
  import { readFileSync } from 'node:fs';
2
2
  import child_process from 'node:child_process';
3
- export function isMusl() {
3
+ import p from '../../package.json' with { type: 'json' };
4
+ function isMusl() {
4
5
  let musl = false;
5
6
  if (process.platform === 'linux') {
6
7
  musl = isMuslFromFilesystem();
@@ -52,3 +53,13 @@ function isMuslFromChildProcess() {
52
53
  return false;
53
54
  }
54
55
  }
56
+ export function getNativePackageName() {
57
+ let runtime = undefined;
58
+ if (process.platform === 'win32') {
59
+ runtime = 'msvc';
60
+ }
61
+ else if (process.platform === 'linux') {
62
+ runtime = isMusl() ? 'musl' : 'gnu';
63
+ }
64
+ return `${p.name}-${process.platform}-${process.arch}${runtime ? `-${runtime}` : ''}`;
65
+ }
@@ -15,7 +15,7 @@ export * from './StreamChunkDecoder.js';
15
15
  export * from './math.js';
16
16
  export * from './sh-rotate.js';
17
17
  export * from './splat.js';
18
- export * from './k-means.js';
18
+ export * from './k-means/index.js';
19
19
  export * from './quantize-1d.js';
20
20
  export * from './webgpu.js';
21
21
  export * from './voxel/common.js';
@@ -89,7 +89,7 @@ export * from './StreamChunkDecoder.js';
89
89
  export * from './math.js';
90
90
  export * from './sh-rotate.js';
91
91
  export * from './splat.js';
92
- export * from './k-means.js';
92
+ export * from './k-means/index.js';
93
93
  export * from './quantize-1d.js';
94
94
  export * from './webgpu.js';
95
95
  export * from './voxel/common.js';
@@ -0,0 +1,16 @@
1
+ export default class GpuClustering {
2
+ private device;
3
+ private numPoints;
4
+ private numCentroids;
5
+ private batchSize;
6
+ private resource;
7
+ private numBatches;
8
+ private concurrencyBatches;
9
+ private concurrencyRuns;
10
+ private workgroupSize;
11
+ private pointStride;
12
+ private vecColumns;
13
+ constructor(device: GPUDevice, numPoints: number, numColumns: number, numCentroids: number);
14
+ execute(points: Float32Array[], centroids: Float32Array[], labels: Uint32Array): Promise<void>;
15
+ destroy(): void;
16
+ }
@@ -0,0 +1,287 @@
1
+ import { logger } from '../index.js';
2
+ const chunkSize = 128;
3
+ const workgroupSize = 128;
4
+ function clusterWgsl(vecColumns) {
5
+ return /* wgsl */ `
6
+ struct Uniforms {
7
+ numPoints: u32,
8
+ numCentroids: u32
9
+ };
10
+
11
+ @group(0) @binding(0) var<uniform> uniforms: Uniforms;
12
+ @group(0) @binding(1) var<storage, read> points: array<vec4<f32>>;
13
+ @group(0) @binding(2) var<storage, read> centroids: array<vec4<f32>>;
14
+ @group(0) @binding(3) var<storage, read> centroidSq: array<f32>;
15
+ @group(0) @binding(4) var<storage, read_write> results: array<u32>;
16
+
17
+ const vecColumns = ${vecColumns}u;
18
+ const chunkSize = ${chunkSize}u;
19
+ const workgroupSize = ${workgroupSize}u;
20
+ var<workgroup> sharedChunk: array<vec4<f32>, vecColumns * chunkSize>;
21
+ var<workgroup> sharedSq: array<f32, chunkSize>;
22
+
23
+ fn calcDistance(point: array<vec4<f32>, vecColumns>, centroid: u32) -> f32 {
24
+ let ci = centroid * vecColumns;
25
+ var result = sharedSq[centroid];
26
+
27
+ for (var i = 0u; i < vecColumns; i++) {
28
+ // euclid distance simplify
29
+ // (centroid - point) ^ 2 = centroid ^ 2 - 2 * dot(centroid, point) + point ^ 2
30
+ // point ^ 2 omitted, for same point find nearest centroid is not necessary
31
+ result -= 2.0 * dot(point[i], sharedChunk[ci + i]);
32
+ }
33
+
34
+ return result;
35
+ }
36
+
37
+ @compute @workgroup_size(workgroupSize)
38
+ fn main(
39
+ @builtin(local_invocation_index) localId : u32,
40
+ @builtin(global_invocation_id) globalId: vec3u,
41
+ @builtin(num_workgroups) numWorkgroups: vec3u
42
+ ) {
43
+ let pointIndex = globalId.x + globalId.y * numWorkgroups.x * workgroupSize;
44
+
45
+ var point: array<vec4<f32>, vecColumns>;
46
+ if (pointIndex < uniforms.numPoints) {
47
+ for (var i = 0u; i < vecColumns; i++) {
48
+ point[i] = points[pointIndex * vecColumns + i];
49
+ }
50
+ }
51
+
52
+ var mind = 1000000.0;
53
+ var mini = 0u;
54
+
55
+ let numChunks = u32(ceil(f32(uniforms.numCentroids) / f32(chunkSize)));
56
+ for (var i = 0u; i < numChunks; i++) {
57
+ let chunkToLoad = min(chunkSize, uniforms.numCentroids - i * chunkSize);
58
+ for (var row = localId; row < chunkToLoad; row += workgroupSize) {
59
+ let srcRow = i * chunkSize + row;
60
+ let dst = row * vecColumns;
61
+ let src = srcRow * vecColumns;
62
+
63
+ for (var c = 0u; c < vecColumns; c++) {
64
+ sharedChunk[dst + c] = centroids[src + c];
65
+ }
66
+ sharedSq[row] = centroidSq[srcRow];
67
+ }
68
+
69
+ workgroupBarrier();
70
+
71
+ if (pointIndex < uniforms.numPoints) {
72
+ let thisChunkSize = min(chunkSize, uniforms.numCentroids - i * chunkSize);
73
+ for (var c = 0u; c < thisChunkSize; c++) {
74
+ let d = calcDistance(point, c);
75
+ if (d < mind) {
76
+ mind = d;
77
+ mini = i * chunkSize + c;
78
+ }
79
+ }
80
+ }
81
+
82
+ workgroupBarrier();
83
+ }
84
+
85
+ if (pointIndex < uniforms.numPoints) {
86
+ results[pointIndex] = mini;
87
+ }
88
+ }
89
+ `;
90
+ }
91
+ function packVec4Data(result, dataTable, numRows, rowOffset, vecColumns, norms) {
92
+ const numColumns = dataTable.length;
93
+ const stride = vecColumns * 4;
94
+ for (let r = 0; r < numRows; r++) {
95
+ const dst = r * stride;
96
+ let norm = 0.0;
97
+ for (let c = 0; c < numColumns; c++) {
98
+ const v = dataTable[c][rowOffset + r];
99
+ result[dst + c] = v;
100
+ if (norms) {
101
+ norm += v * v;
102
+ }
103
+ }
104
+ for (let c = numColumns; c < stride; c++) {
105
+ result[dst + c] = 0.0;
106
+ }
107
+ if (norms) {
108
+ norms[r] = norm;
109
+ }
110
+ }
111
+ }
112
+ const MAX_CONCURRENCY_BATCHES = 10;
113
+ export default class GpuClustering {
114
+ constructor(device, numPoints, numColumns, numCentroids) {
115
+ this.device = device;
116
+ this.numPoints = numPoints;
117
+ this.numCentroids = numCentroids;
118
+ this.vecColumns = Math.ceil(numColumns / 4);
119
+ this.workgroupSize = workgroupSize;
120
+ this.pointStride = this.vecColumns * 4;
121
+ const storageLimit = Math.min(device.limits.maxBufferSize, device.limits.maxStorageBufferBindingSize);
122
+ const workgroupsPerBatch = Math.max(1, Math.min(device.limits.maxComputeWorkgroupsPerDimension, // device dispatch limit
123
+ Math.floor(storageLimit / (this.pointStride * this.workgroupSize * 4)), // point storage limit
124
+ Math.ceil(numPoints / this.workgroupSize)));
125
+ this.batchSize = workgroupsPerBatch * this.workgroupSize;
126
+ this.numBatches = Math.ceil(numPoints / this.batchSize);
127
+ this.concurrencyBatches = Math.min(MAX_CONCURRENCY_BATCHES, this.numBatches);
128
+ this.concurrencyRuns = Math.ceil(this.numBatches / this.concurrencyBatches);
129
+ const shader = device.createShaderModule({
130
+ code: clusterWgsl(this.vecColumns),
131
+ });
132
+ const pipeline = device.createComputePipeline({
133
+ layout: 'auto',
134
+ compute: {
135
+ module: shader,
136
+ entryPoint: 'main',
137
+ },
138
+ });
139
+ const pointsBackBuffer = new Float32Array(this.pointStride * this.batchSize);
140
+ const centroidsBackBuffer = new Float32Array(this.pointStride * numCentroids);
141
+ const centroidSqBackBuffer = new Float32Array(numCentroids);
142
+ const uniformBackBuffer = new Uint32Array([0, numCentroids]);
143
+ const pointsBuffers = [];
144
+ const centroidsBuffer = device.createBuffer({
145
+ size: centroidsBackBuffer.byteLength,
146
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
147
+ });
148
+ const centroidSqBuffer = device.createBuffer({
149
+ size: centroidSqBackBuffer.byteLength,
150
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
151
+ });
152
+ const uniformBuffer = device.createBuffer({
153
+ size: 256 * this.concurrencyBatches,
154
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM,
155
+ });
156
+ const resultBuffer = device.createBuffer({
157
+ size: this.concurrencyBatches * this.batchSize * 4,
158
+ usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.STORAGE,
159
+ });
160
+ const resultReadBackBuffer = device.createBuffer({
161
+ size: this.concurrencyBatches * this.batchSize * 4,
162
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
163
+ });
164
+ const layout = pipeline.getBindGroupLayout(0);
165
+ const bindGroups = [];
166
+ for (let i = 0; i < this.concurrencyBatches; i++) {
167
+ const pointsBuffer = device.createBuffer({
168
+ size: pointsBackBuffer.byteLength,
169
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
170
+ });
171
+ pointsBuffers.push(pointsBuffer);
172
+ const entries = [
173
+ {
174
+ binding: 0,
175
+ resource: {
176
+ buffer: uniformBuffer,
177
+ offset: i * 256,
178
+ size: 8,
179
+ },
180
+ },
181
+ {
182
+ binding: 1,
183
+ resource: pointsBuffer,
184
+ },
185
+ {
186
+ binding: 2,
187
+ resource: centroidsBuffer,
188
+ },
189
+ {
190
+ binding: 3,
191
+ resource: centroidSqBuffer,
192
+ },
193
+ {
194
+ binding: 4,
195
+ resource: {
196
+ buffer: resultBuffer,
197
+ offset: i * this.batchSize * 4,
198
+ size: this.batchSize * 4,
199
+ },
200
+ },
201
+ ];
202
+ bindGroups.push(device.createBindGroup({
203
+ layout,
204
+ entries,
205
+ }));
206
+ }
207
+ const gpuBuffers = {
208
+ uniform: uniformBuffer,
209
+ points: pointsBuffers,
210
+ centroids: centroidsBuffer,
211
+ centroidSq: centroidSqBuffer,
212
+ result: resultBuffer,
213
+ resultReadBack: resultReadBackBuffer,
214
+ };
215
+ const backBuffers = {
216
+ uniform: uniformBackBuffer,
217
+ points: pointsBackBuffer,
218
+ centroids: centroidsBackBuffer,
219
+ centroidSq: centroidSqBackBuffer,
220
+ };
221
+ this.resource = {
222
+ pipeline,
223
+ bindGroups,
224
+ gpuBuffers,
225
+ backBuffers,
226
+ uploadedBatches: [],
227
+ };
228
+ logger.info(`GPU k-means kernel bootstrapped with batch ${workgroupsPerBatch}*${this.workgroupSize}*${this.numBatches}, concurrency: ${this.concurrencyBatches}, runs: ${this.concurrencyRuns}`);
229
+ }
230
+ async execute(points, centroids, labels) {
231
+ const { device, numPoints, numCentroids, numBatches, batchSize, resource, concurrencyBatches, concurrencyRuns, pointStride, vecColumns, workgroupSize, } = this;
232
+ // upload centroid data to gpu
233
+ packVec4Data(resource.backBuffers.centroids, centroids, numCentroids, 0, vecColumns, resource.backBuffers.centroidSq);
234
+ device.queue.writeBuffer(resource.gpuBuffers.centroids, 0, resource.backBuffers.centroids.buffer);
235
+ device.queue.writeBuffer(resource.gpuBuffers.centroidSq, 0, resource.backBuffers.centroidSq.buffer);
236
+ for (let i = 0; i < concurrencyRuns; i++) {
237
+ const batchStart = i * concurrencyBatches;
238
+ let resultCount = 0;
239
+ for (let j = 0; j < concurrencyBatches; j++) {
240
+ const batchIndex = batchStart + j;
241
+ if (batchIndex >= numBatches) {
242
+ break;
243
+ }
244
+ const currentBatchSize = Math.min(numPoints - batchIndex * batchSize, batchSize);
245
+ resultCount += currentBatchSize;
246
+ // write this batch of point data to gpu
247
+ if (resource.uploadedBatches[j] !== batchIndex) {
248
+ packVec4Data(resource.backBuffers.points, points, currentBatchSize, batchIndex * batchSize, vecColumns);
249
+ device.queue.writeBuffer(resource.gpuBuffers.points[j], 0, resource.backBuffers.points.buffer, 0, pointStride * currentBatchSize * 4);
250
+ resource.backBuffers.uniform[0] = currentBatchSize;
251
+ device.queue.writeBuffer(resource.gpuBuffers.uniform, 256 * j, resource.backBuffers.uniform.buffer, 0, 8);
252
+ resource.uploadedBatches[j] = batchIndex;
253
+ }
254
+ }
255
+ const encoder = device.createCommandEncoder();
256
+ const computePass = encoder.beginComputePass();
257
+ computePass.setPipeline(resource.pipeline);
258
+ for (let j = 0; j < concurrencyBatches; j++) {
259
+ const batchIndex = batchStart + j;
260
+ if (batchIndex >= numBatches) {
261
+ break;
262
+ }
263
+ const currentBatchSize = Math.min(numPoints - batchIndex * batchSize, batchSize);
264
+ const groups = Math.ceil(currentBatchSize / workgroupSize);
265
+ computePass.setBindGroup(0, resource.bindGroups[j]);
266
+ computePass.dispatchWorkgroups(groups);
267
+ }
268
+ computePass.end();
269
+ encoder.copyBufferToBuffer(resource.gpuBuffers.result, 0, resource.gpuBuffers.resultReadBack, 0, resultCount * 4);
270
+ device.queue.submit([encoder.finish()]);
271
+ await resource.gpuBuffers.resultReadBack.mapAsync(GPUMapMode.READ);
272
+ const mapped = resource.gpuBuffers.resultReadBack.getMappedRange();
273
+ labels.set(new Uint32Array(mapped, 0, resultCount), batchStart * batchSize);
274
+ resource.gpuBuffers.resultReadBack.unmap();
275
+ }
276
+ }
277
+ destroy() {
278
+ this.resource.gpuBuffers.uniform.destroy();
279
+ this.resource.gpuBuffers.centroids.destroy();
280
+ this.resource.gpuBuffers.centroidSq?.destroy();
281
+ this.resource.gpuBuffers.result.destroy();
282
+ this.resource.gpuBuffers.resultReadBack.destroy();
283
+ for (const buffer of this.resource.gpuBuffers.points) {
284
+ buffer.destroy();
285
+ }
286
+ }
287
+ }
@@ -1,4 +1,4 @@
1
- export declare function kmeans(points: Float32Array[], k: number, iterations: number, device: GPUDevice): Promise<{
1
+ export declare function kMeans(points: Float32Array[], k: number, iterations: number, device: GPUDevice): Promise<{
2
2
  centroids: Float32Array<ArrayBufferLike>[];
3
3
  labels: Uint32Array<ArrayBuffer>;
4
4
  }>;
@@ -0,0 +1,71 @@
1
+ import { clusterAverage } from '../../native/index.js';
2
+ import { logger } from '../index.js';
3
+ import GpuClustering from './GpuClustering.js';
4
+ // in the 1d case we use quantile-based initialization for better handling of skewed data
5
+ function initializeCentroids1D(data, centroids) {
6
+ const n = data.length;
7
+ const k = centroids.length;
8
+ // Sort data to compute quantiles
9
+ const sorted = Float32Array.from(data).sort((a, b) => a - b);
10
+ for (let i = 0; i < k; ++i) {
11
+ // Place centroid at the center of its expected cluster region
12
+ const quantile = (2 * i + 1) / (2 * k);
13
+ const index = Math.min(Math.floor(quantile * n), n - 1);
14
+ centroids[i] = sorted[index];
15
+ }
16
+ }
17
+ // use floyd's algorithm to pick m unique random indices from 0..n-1
18
+ function pickRandomIndices(n, m) {
19
+ const chosen = new Set();
20
+ for (let j = n - m; j < n; j++) {
21
+ const t = Math.floor(Math.random() * (j + 1));
22
+ chosen.add(chosen.has(t) ? j : t);
23
+ }
24
+ return [...chosen];
25
+ }
26
+ function initializeCentroids(dataTable, centroids) {
27
+ const indices = pickRandomIndices(dataTable[0].length, centroids[0].length);
28
+ for (let i = 0; i < centroids[0].length; i++) {
29
+ for (let j = 0; j < dataTable.length; j++) {
30
+ centroids[j][i] = dataTable[j][indices[i]];
31
+ }
32
+ }
33
+ }
34
+ // https://github.com/playcanvas/splat-transform/blob/main/src/lib/spatial/k-means.ts
35
+ export async function kMeans(points, k, iterations, device) {
36
+ const numRows = points.length > 0 ? points[0].length : 0;
37
+ if (numRows < k) {
38
+ return {
39
+ centroids: points,
40
+ // use a typed array here so downstream code can rely on
41
+ // labels supporting subarray(), even in this early-return
42
+ // path used for very small datasets.
43
+ labels: new Uint32Array(numRows).map((_, i) => i),
44
+ };
45
+ }
46
+ const centroids = points.map(_ => new Float32Array(k));
47
+ if (points.length === 1) {
48
+ initializeCentroids1D(points[0], centroids[0]);
49
+ }
50
+ else {
51
+ initializeCentroids(points, centroids);
52
+ }
53
+ const gpuClustering = new GpuClustering(device, numRows, points.length, k);
54
+ const labels = new Uint32Array(numRows);
55
+ let converged = false;
56
+ let steps = 0;
57
+ while (!converged) {
58
+ logger.info(`kmeans iteration ${steps + 1}`);
59
+ await gpuClustering.execute(points, centroids, labels);
60
+ clusterAverage(points, labels, k, centroids);
61
+ steps++;
62
+ if (steps >= iterations) {
63
+ converged = true;
64
+ }
65
+ }
66
+ gpuClustering.destroy();
67
+ return {
68
+ centroids,
69
+ labels,
70
+ };
71
+ }
@@ -1,11 +1,12 @@
1
1
  /// <reference types="@webgpu/types" />
2
2
  import { createRequire } from 'node:module';
3
3
  import { logger } from './index.js';
4
+ import { getNativePackageName } from '../native/utils.js';
4
5
  const getModule = (function () {
5
6
  let m = undefined;
6
7
  return function () {
7
8
  if (!m) {
8
- m = createRequire(import.meta.url)('webgpu');
9
+ m = createRequire(import.meta.url)(getNativePackageName() + '/dawn.node');
9
10
  Object.assign(globalThis, m.globals);
10
11
  }
11
12
  return m;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@manycore/aholo-splat-transform",
3
- "version": "1.2.9",
3
+ "version": "1.2.11",
4
4
  "description": "Transform & filter Gaussian splats file",
5
5
  "author": "egs",
6
6
  "repository": "https://github.com/manycoretech/aholo-viewer.git",
@@ -27,11 +27,11 @@
27
27
  "dependencies": {
28
28
  "commander": "^14.0.2",
29
29
  "fflate": "^0.8.2",
30
- "tslib": "^2.8.1",
31
- "webgpu": "^0.4.0"
30
+ "tslib": "^2.8.1"
32
31
  },
33
32
  "optionalDependencies": {
34
- "@manycore/aholo-splat-transform-linux-x64-gnu": "1.2.9",
35
- "@manycore/aholo-splat-transform-win32-x64-msvc": "1.2.9"
33
+ "@manycore/aholo-splat-transform-darwin-arm64": "1.2.11",
34
+ "@manycore/aholo-splat-transform-linux-x64-gnu": "1.2.11",
35
+ "@manycore/aholo-splat-transform-win32-x64-msvc": "1.2.11"
36
36
  }
37
37
  }
@@ -1,340 +0,0 @@
1
- import { clusterAverage } from '../native/index.js';
2
- import { logger } from './index.js';
3
- // in the 1d case we use quantile-based initialization for better handling of skewed data
4
- function initializeCentroids1D(data, centroids) {
5
- const n = data.length;
6
- const k = centroids.length;
7
- // Sort data to compute quantiles
8
- const sorted = Float32Array.from(data).sort((a, b) => a - b);
9
- for (let i = 0; i < k; ++i) {
10
- // Place centroid at the center of its expected cluster region
11
- const quantile = (2 * i + 1) / (2 * k);
12
- const index = Math.min(Math.floor(quantile * n), n - 1);
13
- centroids[i] = sorted[index];
14
- }
15
- }
16
- // use floyd's algorithm to pick m unique random indices from 0..n-1
17
- function pickRandomIndices(n, m) {
18
- const chosen = new Set();
19
- for (let j = n - m; j < n; j++) {
20
- const t = Math.floor(Math.random() * (j + 1));
21
- chosen.add(chosen.has(t) ? j : t);
22
- }
23
- return [...chosen];
24
- }
25
- function initializeCentroids(dataTable, centroids) {
26
- const indices = pickRandomIndices(dataTable[0].length, centroids[0].length);
27
- for (let i = 0; i < centroids[0].length; i++) {
28
- for (let j = 0; j < dataTable.length; j++) {
29
- centroids[j][i] = dataTable[j][indices[i]];
30
- }
31
- }
32
- }
33
- const chunkSize = 128;
34
- const workgroupSize = 64;
35
- function clusterWgsl(numColumns) {
36
- return /* wgsl */ `
37
- struct Uniforms {
38
- numPoints: u32,
39
- numCentroids: u32
40
- };
41
-
42
- @group(0) @binding(0) var<uniform> uniforms: Uniforms;
43
- @group(0) @binding(1) var<storage, read> points: array<f32>;
44
- @group(0) @binding(2) var<storage, read> centroids: array<f32>;
45
- @group(0) @binding(3) var<storage, read_write> results: array<u32>;
46
-
47
- const numColumns = ${numColumns}; // number of columns in the points and centroids tables
48
- const chunkSize = ${chunkSize}u; // must be a multiple of 64
49
- const workgroupSize = ${workgroupSize}u;
50
- var<workgroup> sharedChunk: array<f32, numColumns * chunkSize>;
51
-
52
- // calculate the squared distance between the point and centroid
53
- fn calcDistanceSqr(point: array<f32, numColumns>, centroid: u32) -> f32 {
54
- var result = 0.0;
55
-
56
- var ci = centroid * numColumns;
57
-
58
- for (var i = 0u; i < numColumns; i++) {
59
- let v = f32(point[i] - sharedChunk[ci+i]);
60
- result += v * v;
61
- }
62
-
63
- return result;
64
- }
65
-
66
- @compute @workgroup_size(workgroupSize)
67
- fn main(
68
- @builtin(local_invocation_index) local_id : u32,
69
- @builtin(global_invocation_id) global_id: vec3u,
70
- @builtin(num_workgroups) num_workgroups: vec3u
71
- ) {
72
- // calculate row index for this thread point
73
- let pointIndex = global_id.x + global_id.y * num_workgroups.x * workgroupSize;
74
-
75
- // copy the point data from global memory
76
- var point: array<f32, numColumns>;
77
- if (pointIndex < uniforms.numPoints) {
78
- for (var i = 0u; i < numColumns; i++) {
79
- point[i] = points[pointIndex * numColumns + i];
80
- }
81
- }
82
-
83
- var mind = 1000000.0;
84
- var mini = 0u;
85
-
86
- // work through the list of centroids in shared memory chunks
87
- let numChunks = u32(ceil(f32(uniforms.numCentroids) / f32(chunkSize)));
88
- for (var i = 0u; i < numChunks; i++) {
89
-
90
- // copy this thread's slice of the centroid shared chunk data
91
- let dstRow = local_id * (chunkSize / workgroupSize);
92
- let srcRow = min(uniforms.numCentroids, i * chunkSize + local_id * chunkSize / workgroupSize);
93
- let numRows = min(uniforms.numCentroids, srcRow + chunkSize / workgroupSize) - srcRow;
94
-
95
- var dst = dstRow * numColumns;
96
- var src = srcRow * numColumns;
97
-
98
- for (var c = 0u; c < numRows * numColumns; c++) {
99
- sharedChunk[dst + c] = centroids[src + c];
100
- }
101
-
102
- // wait for all threads to finish writing their part of centroids shared memory buffer
103
- workgroupBarrier();
104
-
105
- // loop over the next chunk of centroids finding the closest
106
- if (pointIndex < uniforms.numPoints) {
107
- let thisChunkSize = min(chunkSize, uniforms.numCentroids - i * chunkSize);
108
- for (var c = 0u; c < thisChunkSize; c++) {
109
- let d = calcDistanceSqr(point, c);
110
- if (d < mind) {
111
- mind = d;
112
- mini = i * chunkSize + c;
113
- }
114
- }
115
- }
116
-
117
- // next loop will overwrite the shared memory, so wait
118
- workgroupBarrier();
119
- }
120
-
121
- if (pointIndex < uniforms.numPoints) {
122
- results[pointIndex] = mini;
123
- }
124
- }
125
- `;
126
- }
127
- function interleaveData(result, dataTable, numRows, rowOffset) {
128
- const numColumns = dataTable.length;
129
- for (let c = 0; c < numColumns; ++c) {
130
- const column = dataTable[c];
131
- for (let r = 0; r < numRows; ++r) {
132
- result[r * numColumns + c] = column[rowOffset + r];
133
- }
134
- }
135
- }
136
- const MAX_CONCURRENCY_BATCHES = 10;
137
- class GpuClustering {
138
- constructor(device, numPoints, numColumns, numCentroids) {
139
- this.device = device;
140
- this.numPoints = numPoints;
141
- this.numColumns = numColumns;
142
- this.numCentroids = numCentroids;
143
- const workgroupsPerBatch = Math.min(device.limits.maxComputeWorkgroupsPerDimension, // device dispatch limit
144
- Math.floor(device.limits.maxBufferSize / (numColumns * workgroupSize * 4)), // point storage limit
145
- Math.ceil(numPoints / workgroupSize));
146
- this.batchSize = workgroupsPerBatch * workgroupSize;
147
- this.numBatches = Math.ceil(numPoints / this.batchSize);
148
- this.concurrencyBatches = Math.min(MAX_CONCURRENCY_BATCHES, this.numBatches);
149
- this.concurrencyRuns = Math.ceil(this.numBatches / this.concurrencyBatches);
150
- const shader = device.createShaderModule({
151
- code: clusterWgsl(numColumns),
152
- });
153
- const pipeline = device.createComputePipeline({
154
- layout: 'auto',
155
- compute: {
156
- module: shader,
157
- entryPoint: 'main',
158
- },
159
- });
160
- const pointsBackBuffer = new Float32Array(numColumns * this.batchSize);
161
- const centroidsBackBuffer = new Float32Array(numColumns * numCentroids);
162
- const uniformBackBuffer = new Uint32Array([0, numCentroids]);
163
- const pointsBuffers = [];
164
- const centroidsBuffer = device.createBuffer({
165
- size: centroidsBackBuffer.byteLength,
166
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
167
- });
168
- const uniformBuffer = device.createBuffer({
169
- size: 256 * this.concurrencyBatches,
170
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM,
171
- });
172
- const resultBuffer = device.createBuffer({
173
- size: this.concurrencyBatches * this.batchSize * 4,
174
- usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.STORAGE,
175
- });
176
- const resultReadBackBuffer = device.createBuffer({
177
- size: this.concurrencyBatches * this.batchSize * 4,
178
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
179
- });
180
- const layout = pipeline.getBindGroupLayout(0);
181
- const bindGroups = [];
182
- for (let i = 0; i < this.concurrencyBatches; i++) {
183
- const pointsBuffer = device.createBuffer({
184
- size: pointsBackBuffer.byteLength,
185
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
186
- });
187
- pointsBuffers.push(pointsBuffer);
188
- bindGroups.push(device.createBindGroup({
189
- layout,
190
- entries: [
191
- {
192
- binding: 0,
193
- resource: {
194
- buffer: uniformBuffer,
195
- offset: i * 256,
196
- size: 8,
197
- },
198
- },
199
- {
200
- binding: 1,
201
- resource: pointsBuffer,
202
- },
203
- {
204
- binding: 2,
205
- resource: centroidsBuffer,
206
- },
207
- {
208
- binding: 3,
209
- resource: {
210
- buffer: resultBuffer,
211
- offset: i * this.batchSize * 4,
212
- size: this.batchSize * 4,
213
- },
214
- },
215
- ],
216
- }));
217
- }
218
- this.resource = {
219
- pipeline,
220
- bindGroups,
221
- gpuBuffers: {
222
- uniform: uniformBuffer,
223
- points: pointsBuffers,
224
- centroids: centroidsBuffer,
225
- result: resultBuffer,
226
- resultReadBack: resultReadBackBuffer,
227
- },
228
- backBuffers: {
229
- uniform: uniformBackBuffer,
230
- points: pointsBackBuffer,
231
- centroids: centroidsBackBuffer,
232
- },
233
- uploadedBatches: [],
234
- };
235
- logger.info(`GPU k-means kernel bootstrapped with batch ${workgroupsPerBatch}*${workgroupSize}*${this.numBatches}, concurrency: ${this.concurrencyBatches}, runs: ${this.concurrencyRuns}`);
236
- }
237
- async execute(points, centroids, labels) {
238
- const { device, numPoints, numColumns, numCentroids, numBatches, batchSize, resource, concurrencyBatches, concurrencyRuns, } = this;
239
- // upload centroid data to gpu
240
- interleaveData(resource.backBuffers.centroids, centroids, numCentroids, 0);
241
- device.queue.writeBuffer(resource.gpuBuffers.centroids, 0, resource.backBuffers.centroids.buffer);
242
- for (let i = 0; i < concurrencyRuns; i++) {
243
- const batchStart = i * concurrencyBatches;
244
- let resultCount = 0;
245
- for (let j = 0; j < concurrencyBatches; j++) {
246
- const batchIndex = batchStart + j;
247
- if (batchIndex >= numBatches) {
248
- break;
249
- }
250
- const currentBatchSize = Math.min(numPoints - batchIndex * batchSize, batchSize);
251
- resultCount += currentBatchSize;
252
- // write this batch of point data to gpu
253
- if (resource.uploadedBatches[j] !== batchIndex) {
254
- interleaveData(resource.backBuffers.points, points, currentBatchSize, batchIndex * batchSize);
255
- device.queue.writeBuffer(resource.gpuBuffers.points[j], 0, resource.backBuffers.points.buffer, 0, numColumns * currentBatchSize * 4);
256
- resource.backBuffers.uniform[0] = currentBatchSize;
257
- device.queue.writeBuffer(resource.gpuBuffers.uniform, 256 * j, resource.backBuffers.uniform.buffer, 0, 8);
258
- resource.uploadedBatches[j] = batchIndex;
259
- }
260
- }
261
- const encoder = device.createCommandEncoder();
262
- const computePass = encoder.beginComputePass();
263
- computePass.setPipeline(resource.pipeline);
264
- for (let j = 0; j < concurrencyBatches; j++) {
265
- const batchIndex = batchStart + j;
266
- if (batchIndex >= numBatches) {
267
- break;
268
- }
269
- const currentBatchSize = Math.min(numPoints - batchIndex * batchSize, batchSize);
270
- const groups = Math.ceil(currentBatchSize / workgroupSize);
271
- computePass.setBindGroup(0, resource.bindGroups[j]);
272
- computePass.dispatchWorkgroups(groups);
273
- }
274
- computePass.end();
275
- encoder.copyBufferToBuffer(resource.gpuBuffers.result, 0, resource.gpuBuffers.resultReadBack, 0, resultCount * 4);
276
- device.queue.submit([encoder.finish()]);
277
- await resource.gpuBuffers.resultReadBack.mapAsync(GPUMapMode.READ);
278
- const mapped = resource.gpuBuffers.resultReadBack.getMappedRange();
279
- labels.set(new Uint32Array(mapped, 0, resultCount), batchStart * batchSize);
280
- resource.gpuBuffers.resultReadBack.unmap();
281
- }
282
- }
283
- destroy() {
284
- this.resource.gpuBuffers.uniform.destroy();
285
- this.resource.gpuBuffers.centroids.destroy();
286
- this.resource.gpuBuffers.result.destroy();
287
- this.resource.gpuBuffers.resultReadBack.destroy();
288
- for (const buffer of this.resource.gpuBuffers.points) {
289
- buffer.destroy();
290
- }
291
- }
292
- }
293
- function groupLabels(labels, k) {
294
- const clusters = [];
295
- for (let i = 0; i < k; ++i) {
296
- clusters[i] = [];
297
- }
298
- for (let i = 0; i < labels.length; ++i) {
299
- clusters[labels[i]].push(i);
300
- }
301
- return clusters.map(c => new Uint32Array(c));
302
- }
303
// https://github.com/playcanvas/splat-transform/blob/main/src/lib/spatial/k-means.ts
/**
 * Cluster points with iterative k-means, running the point-assignment step on the GPU.
 *
 * Each round assigns every point to a centroid (`GpuClustering.execute`), then updates
 * the centroids from the grouped labels (`clusterAverage`, presumably the per-cluster
 * mean — confirm against its definition). The returned labels come from the final
 * round's assignment, i.e. they were computed against the centroids *before* that
 * round's update.
 *
 * @param points point data as one array per column; all columns have the same length
 *   (the number of points)
 * @param k number of clusters
 * @param iterations number of assign/update rounds. At least one round always runs, so
 *   every label is assigned even when this is 0, negative, NaN or undefined.
 * @param device GPU device passed to GpuClustering
 * @returns `centroids` (one array of length `k` per column) and `labels` (a Uint32Array
 *   with one cluster index per point). With fewer points than `k`, each point is its
 *   own centroid and `centroids` is the `points` array itself.
 */
export async function kmeans(points, k, iterations, device) {
    const numRows = points.length > 0 ? points[0].length : 0;
    if (numRows < k) {
        return {
            centroids: points,
            // use a typed array here so downstream code can rely on
            // labels supporting subarray(), even in this early-return
            // path used for very small datasets.
            labels: new Uint32Array(numRows).map((_, i) => i),
        };
    }
    const centroids = points.map(_ => new Float32Array(k));
    if (points.length === 1) {
        initializeCentroids1D(points[0], centroids[0]);
    }
    else {
        initializeCentroids(points, centroids);
    }
    const gpuClustering = new GpuClustering(device, numRows, points.length, k);
    try {
        const labels = new Uint32Array(numRows);
        let steps = 0;
        // do/while: the assignment step must run at least once, and the `<` test
        // terminates even when `iterations` is NaN/undefined (`>=` never became true).
        do {
            logger.info(`kmeans iteration ${steps + 1}`);
            await gpuClustering.execute(points, centroids, labels);
            clusterAverage(points, groupLabels(labels, k), centroids);
            steps++;
        } while (steps < iterations);
        return {
            centroids,
            labels,
        };
    }
    finally {
        // Release GPU buffers even when an iteration throws.
        gpuClustering.destroy();
    }
}