@manycore/aholo-splat-transform 1.2.7 → 1.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. package/CHANGELOG.md +120 -106
  2. package/COPYRIGHT.md +17 -0
  3. package/README.md +39 -39
  4. package/THIRD_PARTY_LICENSES.txt +1373 -0
  5. package/bin/cli.js +125 -118
  6. package/dist/SplatData.d.ts +67 -67
  7. package/dist/SplatData.js +167 -156
  8. package/dist/constant.d.ts +3 -3
  9. package/dist/constant.js +13 -13
  10. package/dist/file/IFile.d.ts +5 -5
  11. package/dist/file/IFile.js +1 -1
  12. package/dist/file/esz.d.ts +11 -0
  13. package/dist/file/esz.js +337 -0
  14. package/dist/file/index.d.ts +8 -7
  15. package/dist/file/index.js +7 -6
  16. package/dist/file/ksplat.d.ts +12 -12
  17. package/dist/file/ksplat.js +293 -232
  18. package/dist/file/lcc.d.ts +11 -11
  19. package/dist/file/lcc.js +161 -157
  20. package/dist/file/ply.d.ts +13 -13
  21. package/dist/file/ply.js +439 -388
  22. package/dist/file/sog.d.ts +80 -80
  23. package/dist/file/sog.js +525 -504
  24. package/dist/file/splat.d.ts +6 -6
  25. package/dist/file/splat.js +119 -99
  26. package/dist/file/spz.d.ts +11 -8
  27. package/dist/file/spz.js +597 -400
  28. package/dist/file/voxel.d.ts +43 -37
  29. package/dist/file/voxel.js +411 -280
  30. package/dist/index.d.ts +33 -33
  31. package/dist/index.js +54 -54
  32. package/dist/native/index.d.ts +54 -54
  33. package/dist/native/index.js +122 -128
  34. package/dist/native/utils.d.ts +1 -0
  35. package/dist/native/utils.js +54 -0
  36. package/dist/tasks/AutoChunkLodTask.d.ts +13 -13
  37. package/dist/tasks/AutoChunkLodTask.js +117 -117
  38. package/dist/tasks/AutoLodTask.d.ts +10 -10
  39. package/dist/tasks/AutoLodTask.js +20 -20
  40. package/dist/tasks/BaseTask.d.ts +15 -15
  41. package/dist/tasks/BaseTask.js +5 -5
  42. package/dist/tasks/FlexLodTask.d.ts +12 -12
  43. package/dist/tasks/FlexLodTask.js +54 -44
  44. package/dist/tasks/ModifyTask.d.ts +9 -9
  45. package/dist/tasks/ModifyTask.js +166 -156
  46. package/dist/tasks/ReadTask.d.ts +9 -9
  47. package/dist/tasks/ReadTask.js +29 -29
  48. package/dist/tasks/SkeletonLodTask.d.ts +10 -10
  49. package/dist/tasks/SkeletonLodTask.js +176 -156
  50. package/dist/tasks/VoxelTask.d.ts +35 -30
  51. package/dist/tasks/VoxelTask.js +40 -37
  52. package/dist/tasks/WriteTask.d.ts +12 -11
  53. package/dist/tasks/WriteTask.js +70 -70
  54. package/dist/utils/BufferReader.d.ts +12 -12
  55. package/dist/utils/BufferReader.js +45 -47
  56. package/dist/utils/Logger.d.ts +11 -11
  57. package/dist/utils/Logger.js +40 -38
  58. package/dist/utils/StreamChunkDecoder.d.ts +16 -16
  59. package/dist/utils/StreamChunkDecoder.js +31 -36
  60. package/dist/utils/index.d.ts +27 -27
  61. package/dist/utils/index.js +101 -101
  62. package/dist/utils/k-means.d.ts +4 -4
  63. package/dist/utils/k-means.js +340 -350
  64. package/dist/utils/math.d.ts +46 -46
  65. package/dist/utils/math.js +350 -351
  66. package/dist/utils/quantize-1d.d.ts +4 -4
  67. package/dist/utils/quantize-1d.js +164 -164
  68. package/dist/utils/sh-rotate.d.ts +2 -2
  69. package/dist/utils/sh-rotate.js +236 -175
  70. package/dist/utils/splat.d.ts +21 -20
  71. package/dist/utils/splat.js +397 -378
  72. package/dist/utils/voxel/binary.d.ts +8 -0
  73. package/dist/utils/voxel/binary.js +176 -0
  74. package/dist/utils/voxel/common.d.ts +178 -162
  75. package/dist/utils/voxel/common.js +1752 -1700
  76. package/dist/utils/voxel/coplanar-merge.d.ts +63 -63
  77. package/dist/utils/voxel/coplanar-merge.js +818 -819
  78. package/dist/utils/voxel/filter-cluster.d.ts +20 -0
  79. package/dist/utils/voxel/filter-cluster.js +628 -0
  80. package/dist/utils/voxel/gpu-dilation.d.ts +2 -2
  81. package/dist/utils/voxel/gpu-dilation.js +677 -665
  82. package/dist/utils/voxel/marching-cubes.d.ts +42 -42
  83. package/dist/utils/voxel/marching-cubes.js +1645 -1657
  84. package/dist/utils/voxel/mesh.d.ts +3 -3
  85. package/dist/utils/voxel/mesh.js +130 -130
  86. package/dist/utils/voxel/nav.d.ts +29 -29
  87. package/dist/utils/voxel/nav.js +1068 -1043
  88. package/dist/utils/voxel/postprocess.d.ts +23 -23
  89. package/dist/utils/voxel/postprocess.js +408 -375
  90. package/dist/utils/voxel/voxel-faces.d.ts +18 -18
  91. package/dist/utils/voxel/voxel-faces.js +662 -663
  92. package/dist/utils/voxel/voxelize.d.ts +34 -33
  93. package/dist/utils/voxel/voxelize.js +1208 -1193
  94. package/dist/utils/webgpu.d.ts +8 -8
  95. package/dist/utils/webgpu.js +122 -122
  96. package/package.json +37 -30
  97. package/dist/native/cpp/bin/linux/binding.node +0 -0
  98. package/dist/native/cpp/bin/windows/binding.node +0 -0
@@ -1,350 +1,340 @@
1
- import { clusterAverage } from '../native/index.js';
2
- import { logger } from './index.js';
3
- // in the 1d case we use quantile-based initialization for better handling of skewed data
4
- function initializeCentroids1D(data, centroids) {
5
- const n = data.length;
6
- const k = centroids.length;
7
- // Sort data to compute quantiles
8
- const sorted = Float32Array.from(data).sort((a, b) => a - b);
9
- for (let i = 0; i < k; ++i) {
10
- // Place centroid at the center of its expected cluster region
11
- const quantile = (2 * i + 1) / (2 * k);
12
- const index = Math.min(Math.floor(quantile * n), n - 1);
13
- centroids[i] = sorted[index];
14
- }
15
- }
16
- ;
17
- // use floyd's algorithm to pick m unique random indices from 0..n-1
18
- function pickRandomIndices(n, m) {
19
- const chosen = new Set();
20
- for (let j = n - m; j < n; j++) {
21
- const t = Math.floor(Math.random() * (j + 1));
22
- chosen.add(chosen.has(t) ? j : t);
23
- }
24
- return [...chosen];
25
- }
26
- ;
27
- function initializeCentroids(dataTable, centroids) {
28
- const indices = pickRandomIndices(dataTable[0].length, centroids[0].length);
29
- for (let i = 0; i < centroids[0].length; i++) {
30
- for (let j = 0; j < dataTable.length; j++) {
31
- centroids[j][i] = dataTable[j][indices[i]];
32
- }
33
- }
34
- }
35
- ;
36
- const chunkSize = 128;
37
- const workgroupSize = 64;
38
- function clusterWgsl(numColumns) {
39
- return /* wgsl */ `
40
- struct Uniforms {
41
- numPoints: u32,
42
- numCentroids: u32
43
- };
44
-
45
- @group(0) @binding(0) var<uniform> uniforms: Uniforms;
46
- @group(0) @binding(1) var<storage, read> points: array<f32>;
47
- @group(0) @binding(2) var<storage, read> centroids: array<f32>;
48
- @group(0) @binding(3) var<storage, read_write> results: array<u32>;
49
-
50
- const numColumns = ${numColumns}; // number of columns in the points and centroids tables
51
- const chunkSize = ${chunkSize}u; // must be a multiple of 64
52
- const workgroupSize = ${workgroupSize}u;
53
- var<workgroup> sharedChunk: array<f32, numColumns * chunkSize>;
54
-
55
- // calculate the squared distance between the point and centroid
56
- fn calcDistanceSqr(point: array<f32, numColumns>, centroid: u32) -> f32 {
57
- var result = 0.0;
58
-
59
- var ci = centroid * numColumns;
60
-
61
- for (var i = 0u; i < numColumns; i++) {
62
- let v = f32(point[i] - sharedChunk[ci+i]);
63
- result += v * v;
64
- }
65
-
66
- return result;
67
- }
68
-
69
- @compute @workgroup_size(workgroupSize)
70
- fn main(
71
- @builtin(local_invocation_index) local_id : u32,
72
- @builtin(global_invocation_id) global_id: vec3u,
73
- @builtin(num_workgroups) num_workgroups: vec3u
74
- ) {
75
- // calculate row index for this thread point
76
- let pointIndex = global_id.x + global_id.y * num_workgroups.x * workgroupSize;
77
-
78
- // copy the point data from global memory
79
- var point: array<f32, numColumns>;
80
- if (pointIndex < uniforms.numPoints) {
81
- for (var i = 0u; i < numColumns; i++) {
82
- point[i] = points[pointIndex * numColumns + i];
83
- }
84
- }
85
-
86
- var mind = 1000000.0;
87
- var mini = 0u;
88
-
89
- // work through the list of centroids in shared memory chunks
90
- let numChunks = u32(ceil(f32(uniforms.numCentroids) / f32(chunkSize)));
91
- for (var i = 0u; i < numChunks; i++) {
92
-
93
- // copy this thread's slice of the centroid shared chunk data
94
- let dstRow = local_id * (chunkSize / workgroupSize);
95
- let srcRow = min(uniforms.numCentroids, i * chunkSize + local_id * chunkSize / workgroupSize);
96
- let numRows = min(uniforms.numCentroids, srcRow + chunkSize / workgroupSize) - srcRow;
97
-
98
- var dst = dstRow * numColumns;
99
- var src = srcRow * numColumns;
100
-
101
- for (var c = 0u; c < numRows * numColumns; c++) {
102
- sharedChunk[dst + c] = centroids[src + c];
103
- }
104
-
105
- // wait for all threads to finish writing their part of centroids shared memory buffer
106
- workgroupBarrier();
107
-
108
- // loop over the next chunk of centroids finding the closest
109
- if (pointIndex < uniforms.numPoints) {
110
- let thisChunkSize = min(chunkSize, uniforms.numCentroids - i * chunkSize);
111
- for (var c = 0u; c < thisChunkSize; c++) {
112
- let d = calcDistanceSqr(point, c);
113
- if (d < mind) {
114
- mind = d;
115
- mini = i * chunkSize + c;
116
- }
117
- }
118
- }
119
-
120
- // next loop will overwrite the shared memory, so wait
121
- workgroupBarrier();
122
- }
123
-
124
- if (pointIndex < uniforms.numPoints) {
125
- results[pointIndex] = mini;
126
- }
127
- }
128
- `;
129
- }
130
- function interleaveData(result, dataTable, numRows, rowOffset) {
131
- const numColumns = dataTable.length;
132
- for (let c = 0; c < numColumns; ++c) {
133
- const column = dataTable[c];
134
- for (let r = 0; r < numRows; ++r) {
135
- result[r * numColumns + c] = column[rowOffset + r];
136
- }
137
- }
138
- }
139
- const MAX_CONCURRENCY_BATCHES = 10;
140
- class GpuClustering {
141
- device;
142
- numPoints;
143
- numColumns;
144
- numCentroids;
145
- batchSize;
146
- resource;
147
- numBatches;
148
- concurrencyBatches;
149
- concurrencyRuns;
150
- constructor(device, numPoints, numColumns, numCentroids) {
151
- this.device = device;
152
- this.numPoints = numPoints;
153
- this.numColumns = numColumns;
154
- this.numCentroids = numCentroids;
155
- const workgroupsPerBatch = Math.min(device.limits.maxComputeWorkgroupsPerDimension, // device dispatch limit
156
- Math.floor(device.limits.maxBufferSize / (numColumns * workgroupSize * 4)), // point storage limit
157
- Math.ceil(numPoints / workgroupSize) // max limit
158
- );
159
- this.batchSize = workgroupsPerBatch * workgroupSize;
160
- this.numBatches = Math.ceil(numPoints / this.batchSize);
161
- this.concurrencyBatches = Math.min(MAX_CONCURRENCY_BATCHES, this.numBatches);
162
- this.concurrencyRuns = Math.ceil(this.numBatches / this.concurrencyBatches);
163
- const shader = device.createShaderModule({
164
- code: clusterWgsl(numColumns),
165
- });
166
- const pipeline = device.createComputePipeline({
167
- layout: 'auto',
168
- compute: {
169
- module: shader,
170
- entryPoint: 'main'
171
- }
172
- });
173
- const pointsBackBuffer = new Float32Array(numColumns * this.batchSize);
174
- const centroidsBackBuffer = new Float32Array(numColumns * numCentroids);
175
- const uniformBackBuffer = new Uint32Array([0, numCentroids]);
176
- const pointsBuffers = [];
177
- const centroidsBuffer = device.createBuffer({
178
- size: centroidsBackBuffer.byteLength,
179
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE
180
- });
181
- const uniformBuffer = device.createBuffer({
182
- size: 256 * this.concurrencyBatches,
183
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM
184
- });
185
- const resultBuffer = device.createBuffer({
186
- size: this.concurrencyBatches * this.batchSize * 4,
187
- usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.STORAGE
188
- });
189
- const resultReadBackBuffer = device.createBuffer({
190
- size: this.concurrencyBatches * this.batchSize * 4,
191
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
192
- });
193
- const layout = pipeline.getBindGroupLayout(0);
194
- const bindGroups = [];
195
- for (let i = 0; i < this.concurrencyBatches; i++) {
196
- const pointsBuffer = device.createBuffer({
197
- size: pointsBackBuffer.byteLength,
198
- usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE
199
- });
200
- pointsBuffers.push(pointsBuffer);
201
- bindGroups.push(device.createBindGroup({
202
- layout,
203
- entries: [{
204
- binding: 0,
205
- resource: {
206
- buffer: uniformBuffer,
207
- offset: i * 256,
208
- size: 8
209
- }
210
- }, {
211
- binding: 1,
212
- resource: pointsBuffer,
213
- }, {
214
- binding: 2,
215
- resource: centroidsBuffer,
216
- }, {
217
- binding: 3,
218
- resource: {
219
- buffer: resultBuffer,
220
- offset: i * this.batchSize * 4,
221
- size: this.batchSize * 4
222
- }
223
- }]
224
- }));
225
- }
226
- this.resource = {
227
- pipeline,
228
- bindGroups,
229
- gpuBuffers: {
230
- uniform: uniformBuffer,
231
- points: pointsBuffers,
232
- centroids: centroidsBuffer,
233
- result: resultBuffer,
234
- resultReadBack: resultReadBackBuffer,
235
- },
236
- backBuffers: {
237
- uniform: uniformBackBuffer,
238
- points: pointsBackBuffer,
239
- centroids: centroidsBackBuffer,
240
- },
241
- uploadedBatches: [],
242
- };
243
- logger.info(`GPU k-means kernel bootstrapped with batch ${workgroupsPerBatch}*${workgroupSize}*${this.numBatches}, concurrency: ${this.concurrencyBatches}, runs: ${this.concurrencyRuns}`);
244
- }
245
- async execute(points, centroids, labels) {
246
- const { device, numPoints, numColumns, numCentroids, numBatches, batchSize, resource, concurrencyBatches, concurrencyRuns } = this;
247
- // upload centroid data to gpu
248
- interleaveData(resource.backBuffers.centroids, centroids, numCentroids, 0);
249
- device.queue.writeBuffer(resource.gpuBuffers.centroids, 0, resource.backBuffers.centroids.buffer);
250
- for (let i = 0; i < concurrencyRuns; i++) {
251
- const batchStart = i * concurrencyBatches;
252
- let resultCount = 0;
253
- for (let j = 0; j < concurrencyBatches; j++) {
254
- const batchIndex = batchStart + j;
255
- if (batchIndex >= numBatches) {
256
- break;
257
- }
258
- const currentBatchSize = Math.min(numPoints - batchIndex * batchSize, batchSize);
259
- resultCount += currentBatchSize;
260
- // write this batch of point data to gpu
261
- if (resource.uploadedBatches[j] !== batchIndex) {
262
- interleaveData(resource.backBuffers.points, points, currentBatchSize, batchIndex * batchSize);
263
- device.queue.writeBuffer(resource.gpuBuffers.points[j], 0, resource.backBuffers.points.buffer, 0, numColumns * currentBatchSize * 4);
264
- resource.backBuffers.uniform[0] = currentBatchSize;
265
- device.queue.writeBuffer(resource.gpuBuffers.uniform, 256 * j, resource.backBuffers.uniform.buffer, 0, 8);
266
- resource.uploadedBatches[j] = batchIndex;
267
- }
268
- }
269
- const encoder = device.createCommandEncoder();
270
- const computePass = encoder.beginComputePass();
271
- computePass.setPipeline(resource.pipeline);
272
- for (let j = 0; j < concurrencyBatches; j++) {
273
- const batchIndex = batchStart + j;
274
- if (batchIndex >= numBatches) {
275
- break;
276
- }
277
- const currentBatchSize = Math.min(numPoints - batchIndex * batchSize, batchSize);
278
- const groups = Math.ceil(currentBatchSize / workgroupSize);
279
- computePass.setBindGroup(0, resource.bindGroups[j]);
280
- computePass.dispatchWorkgroups(groups);
281
- }
282
- computePass.end();
283
- encoder.copyBufferToBuffer(resource.gpuBuffers.result, 0, resource.gpuBuffers.resultReadBack, 0, resultCount * 4);
284
- device.queue.submit([encoder.finish()]);
285
- await resource.gpuBuffers.resultReadBack.mapAsync(GPUMapMode.READ);
286
- const mapped = resource.gpuBuffers.resultReadBack.getMappedRange();
287
- labels.set(new Uint32Array(mapped, 0, resultCount), batchStart * batchSize);
288
- resource.gpuBuffers.resultReadBack.unmap();
289
- }
290
- ;
291
- }
292
- destroy() {
293
- this.resource.gpuBuffers.uniform.destroy();
294
- this.resource.gpuBuffers.centroids.destroy();
295
- this.resource.gpuBuffers.result.destroy();
296
- this.resource.gpuBuffers.resultReadBack.destroy();
297
- for (const buffer of this.resource.gpuBuffers.points) {
298
- buffer.destroy();
299
- }
300
- }
301
- }
302
- function groupLabels(labels, k) {
303
- const clusters = [];
304
- for (let i = 0; i < k; ++i) {
305
- clusters[i] = [];
306
- }
307
- for (let i = 0; i < labels.length; ++i) {
308
- clusters[labels[i]].push(i);
309
- }
310
- return clusters.map(c => new Uint32Array(c));
311
- }
312
- ;
313
- // https://github.com/playcanvas/splat-transform/blob/main/src/lib/spatial/k-means.ts
314
- export async function kmeans(points, k, iterations, device) {
315
- const numRows = points.length > 0 ? points[0].length : 0;
316
- if (numRows < k) {
317
- return {
318
- centroids: points,
319
- // use a typed array here so downstream code can rely on
320
- // labels supporting subarray(), even in this early-return
321
- // path used for very small datasets.
322
- labels: new Uint32Array(numRows).map((_, i) => i)
323
- };
324
- }
325
- const centroids = points.map(_ => new Float32Array(k));
326
- if (points.length === 1) {
327
- initializeCentroids1D(points[0], centroids[0]);
328
- }
329
- else {
330
- initializeCentroids(points, centroids);
331
- }
332
- const gpuClustering = new GpuClustering(device, numRows, points.length, k);
333
- const labels = new Uint32Array(numRows);
334
- let converged = false;
335
- let steps = 0;
336
- while (!converged) {
337
- logger.info(`kmeans iteration ${steps + 1}`);
338
- await gpuClustering.execute(points, centroids, labels);
339
- clusterAverage(points, groupLabels(labels, k), centroids);
340
- steps++;
341
- if (steps >= iterations) {
342
- converged = true;
343
- }
344
- }
345
- gpuClustering.destroy();
346
- return {
347
- centroids,
348
- labels
349
- };
350
- }
1
+ import { clusterAverage } from '../native/index.js';
2
+ import { logger } from './index.js';
3
+ // in the 1d case we use quantile-based initialization for better handling of skewed data
4
+ function initializeCentroids1D(data, centroids) {
5
+ const n = data.length;
6
+ const k = centroids.length;
7
+ // Sort data to compute quantiles
8
+ const sorted = Float32Array.from(data).sort((a, b) => a - b);
9
+ for (let i = 0; i < k; ++i) {
10
+ // Place centroid at the center of its expected cluster region
11
+ const quantile = (2 * i + 1) / (2 * k);
12
+ const index = Math.min(Math.floor(quantile * n), n - 1);
13
+ centroids[i] = sorted[index];
14
+ }
15
+ }
16
+ // use floyd's algorithm to pick m unique random indices from 0..n-1
17
+ function pickRandomIndices(n, m) {
18
+ const chosen = new Set();
19
+ for (let j = n - m; j < n; j++) {
20
+ const t = Math.floor(Math.random() * (j + 1));
21
+ chosen.add(chosen.has(t) ? j : t);
22
+ }
23
+ return [...chosen];
24
+ }
25
+ function initializeCentroids(dataTable, centroids) {
26
+ const indices = pickRandomIndices(dataTable[0].length, centroids[0].length);
27
+ for (let i = 0; i < centroids[0].length; i++) {
28
+ for (let j = 0; j < dataTable.length; j++) {
29
+ centroids[j][i] = dataTable[j][indices[i]];
30
+ }
31
+ }
32
+ }
33
+ const chunkSize = 128;
34
+ const workgroupSize = 64;
35
+ function clusterWgsl(numColumns) {
36
+ return /* wgsl */ `
37
+ struct Uniforms {
38
+ numPoints: u32,
39
+ numCentroids: u32
40
+ };
41
+
42
+ @group(0) @binding(0) var<uniform> uniforms: Uniforms;
43
+ @group(0) @binding(1) var<storage, read> points: array<f32>;
44
+ @group(0) @binding(2) var<storage, read> centroids: array<f32>;
45
+ @group(0) @binding(3) var<storage, read_write> results: array<u32>;
46
+
47
+ const numColumns = ${numColumns}; // number of columns in the points and centroids tables
48
+ const chunkSize = ${chunkSize}u; // must be a multiple of 64
49
+ const workgroupSize = ${workgroupSize}u;
50
+ var<workgroup> sharedChunk: array<f32, numColumns * chunkSize>;
51
+
52
+ // calculate the squared distance between the point and centroid
53
+ fn calcDistanceSqr(point: array<f32, numColumns>, centroid: u32) -> f32 {
54
+ var result = 0.0;
55
+
56
+ var ci = centroid * numColumns;
57
+
58
+ for (var i = 0u; i < numColumns; i++) {
59
+ let v = f32(point[i] - sharedChunk[ci+i]);
60
+ result += v * v;
61
+ }
62
+
63
+ return result;
64
+ }
65
+
66
+ @compute @workgroup_size(workgroupSize)
67
+ fn main(
68
+ @builtin(local_invocation_index) local_id : u32,
69
+ @builtin(global_invocation_id) global_id: vec3u,
70
+ @builtin(num_workgroups) num_workgroups: vec3u
71
+ ) {
72
+ // calculate row index for this thread point
73
+ let pointIndex = global_id.x + global_id.y * num_workgroups.x * workgroupSize;
74
+
75
+ // copy the point data from global memory
76
+ var point: array<f32, numColumns>;
77
+ if (pointIndex < uniforms.numPoints) {
78
+ for (var i = 0u; i < numColumns; i++) {
79
+ point[i] = points[pointIndex * numColumns + i];
80
+ }
81
+ }
82
+
83
+ var mind = 1000000.0;
84
+ var mini = 0u;
85
+
86
+ // work through the list of centroids in shared memory chunks
87
+ let numChunks = u32(ceil(f32(uniforms.numCentroids) / f32(chunkSize)));
88
+ for (var i = 0u; i < numChunks; i++) {
89
+
90
+ // copy this thread's slice of the centroid shared chunk data
91
+ let dstRow = local_id * (chunkSize / workgroupSize);
92
+ let srcRow = min(uniforms.numCentroids, i * chunkSize + local_id * chunkSize / workgroupSize);
93
+ let numRows = min(uniforms.numCentroids, srcRow + chunkSize / workgroupSize) - srcRow;
94
+
95
+ var dst = dstRow * numColumns;
96
+ var src = srcRow * numColumns;
97
+
98
+ for (var c = 0u; c < numRows * numColumns; c++) {
99
+ sharedChunk[dst + c] = centroids[src + c];
100
+ }
101
+
102
+ // wait for all threads to finish writing their part of centroids shared memory buffer
103
+ workgroupBarrier();
104
+
105
+ // loop over the next chunk of centroids finding the closest
106
+ if (pointIndex < uniforms.numPoints) {
107
+ let thisChunkSize = min(chunkSize, uniforms.numCentroids - i * chunkSize);
108
+ for (var c = 0u; c < thisChunkSize; c++) {
109
+ let d = calcDistanceSqr(point, c);
110
+ if (d < mind) {
111
+ mind = d;
112
+ mini = i * chunkSize + c;
113
+ }
114
+ }
115
+ }
116
+
117
+ // next loop will overwrite the shared memory, so wait
118
+ workgroupBarrier();
119
+ }
120
+
121
+ if (pointIndex < uniforms.numPoints) {
122
+ results[pointIndex] = mini;
123
+ }
124
+ }
125
+ `;
126
+ }
127
+ function interleaveData(result, dataTable, numRows, rowOffset) {
128
+ const numColumns = dataTable.length;
129
+ for (let c = 0; c < numColumns; ++c) {
130
+ const column = dataTable[c];
131
+ for (let r = 0; r < numRows; ++r) {
132
+ result[r * numColumns + c] = column[rowOffset + r];
133
+ }
134
+ }
135
+ }
136
+ const MAX_CONCURRENCY_BATCHES = 10;
137
+ class GpuClustering {
138
+ constructor(device, numPoints, numColumns, numCentroids) {
139
+ this.device = device;
140
+ this.numPoints = numPoints;
141
+ this.numColumns = numColumns;
142
+ this.numCentroids = numCentroids;
143
+ const workgroupsPerBatch = Math.min(device.limits.maxComputeWorkgroupsPerDimension, // device dispatch limit
144
+ Math.floor(device.limits.maxBufferSize / (numColumns * workgroupSize * 4)), // point storage limit
145
+ Math.ceil(numPoints / workgroupSize));
146
+ this.batchSize = workgroupsPerBatch * workgroupSize;
147
+ this.numBatches = Math.ceil(numPoints / this.batchSize);
148
+ this.concurrencyBatches = Math.min(MAX_CONCURRENCY_BATCHES, this.numBatches);
149
+ this.concurrencyRuns = Math.ceil(this.numBatches / this.concurrencyBatches);
150
+ const shader = device.createShaderModule({
151
+ code: clusterWgsl(numColumns),
152
+ });
153
+ const pipeline = device.createComputePipeline({
154
+ layout: 'auto',
155
+ compute: {
156
+ module: shader,
157
+ entryPoint: 'main',
158
+ },
159
+ });
160
+ const pointsBackBuffer = new Float32Array(numColumns * this.batchSize);
161
+ const centroidsBackBuffer = new Float32Array(numColumns * numCentroids);
162
+ const uniformBackBuffer = new Uint32Array([0, numCentroids]);
163
+ const pointsBuffers = [];
164
+ const centroidsBuffer = device.createBuffer({
165
+ size: centroidsBackBuffer.byteLength,
166
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
167
+ });
168
+ const uniformBuffer = device.createBuffer({
169
+ size: 256 * this.concurrencyBatches,
170
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM,
171
+ });
172
+ const resultBuffer = device.createBuffer({
173
+ size: this.concurrencyBatches * this.batchSize * 4,
174
+ usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.STORAGE,
175
+ });
176
+ const resultReadBackBuffer = device.createBuffer({
177
+ size: this.concurrencyBatches * this.batchSize * 4,
178
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
179
+ });
180
+ const layout = pipeline.getBindGroupLayout(0);
181
+ const bindGroups = [];
182
+ for (let i = 0; i < this.concurrencyBatches; i++) {
183
+ const pointsBuffer = device.createBuffer({
184
+ size: pointsBackBuffer.byteLength,
185
+ usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
186
+ });
187
+ pointsBuffers.push(pointsBuffer);
188
+ bindGroups.push(device.createBindGroup({
189
+ layout,
190
+ entries: [
191
+ {
192
+ binding: 0,
193
+ resource: {
194
+ buffer: uniformBuffer,
195
+ offset: i * 256,
196
+ size: 8,
197
+ },
198
+ },
199
+ {
200
+ binding: 1,
201
+ resource: pointsBuffer,
202
+ },
203
+ {
204
+ binding: 2,
205
+ resource: centroidsBuffer,
206
+ },
207
+ {
208
+ binding: 3,
209
+ resource: {
210
+ buffer: resultBuffer,
211
+ offset: i * this.batchSize * 4,
212
+ size: this.batchSize * 4,
213
+ },
214
+ },
215
+ ],
216
+ }));
217
+ }
218
+ this.resource = {
219
+ pipeline,
220
+ bindGroups,
221
+ gpuBuffers: {
222
+ uniform: uniformBuffer,
223
+ points: pointsBuffers,
224
+ centroids: centroidsBuffer,
225
+ result: resultBuffer,
226
+ resultReadBack: resultReadBackBuffer,
227
+ },
228
+ backBuffers: {
229
+ uniform: uniformBackBuffer,
230
+ points: pointsBackBuffer,
231
+ centroids: centroidsBackBuffer,
232
+ },
233
+ uploadedBatches: [],
234
+ };
235
+ logger.info(`GPU k-means kernel bootstrapped with batch ${workgroupsPerBatch}*${workgroupSize}*${this.numBatches}, concurrency: ${this.concurrencyBatches}, runs: ${this.concurrencyRuns}`);
236
+ }
237
+ async execute(points, centroids, labels) {
238
+ const { device, numPoints, numColumns, numCentroids, numBatches, batchSize, resource, concurrencyBatches, concurrencyRuns, } = this;
239
+ // upload centroid data to gpu
240
+ interleaveData(resource.backBuffers.centroids, centroids, numCentroids, 0);
241
+ device.queue.writeBuffer(resource.gpuBuffers.centroids, 0, resource.backBuffers.centroids.buffer);
242
+ for (let i = 0; i < concurrencyRuns; i++) {
243
+ const batchStart = i * concurrencyBatches;
244
+ let resultCount = 0;
245
+ for (let j = 0; j < concurrencyBatches; j++) {
246
+ const batchIndex = batchStart + j;
247
+ if (batchIndex >= numBatches) {
248
+ break;
249
+ }
250
+ const currentBatchSize = Math.min(numPoints - batchIndex * batchSize, batchSize);
251
+ resultCount += currentBatchSize;
252
+ // write this batch of point data to gpu
253
+ if (resource.uploadedBatches[j] !== batchIndex) {
254
+ interleaveData(resource.backBuffers.points, points, currentBatchSize, batchIndex * batchSize);
255
+ device.queue.writeBuffer(resource.gpuBuffers.points[j], 0, resource.backBuffers.points.buffer, 0, numColumns * currentBatchSize * 4);
256
+ resource.backBuffers.uniform[0] = currentBatchSize;
257
+ device.queue.writeBuffer(resource.gpuBuffers.uniform, 256 * j, resource.backBuffers.uniform.buffer, 0, 8);
258
+ resource.uploadedBatches[j] = batchIndex;
259
+ }
260
+ }
261
+ const encoder = device.createCommandEncoder();
262
+ const computePass = encoder.beginComputePass();
263
+ computePass.setPipeline(resource.pipeline);
264
+ for (let j = 0; j < concurrencyBatches; j++) {
265
+ const batchIndex = batchStart + j;
266
+ if (batchIndex >= numBatches) {
267
+ break;
268
+ }
269
+ const currentBatchSize = Math.min(numPoints - batchIndex * batchSize, batchSize);
270
+ const groups = Math.ceil(currentBatchSize / workgroupSize);
271
+ computePass.setBindGroup(0, resource.bindGroups[j]);
272
+ computePass.dispatchWorkgroups(groups);
273
+ }
274
+ computePass.end();
275
+ encoder.copyBufferToBuffer(resource.gpuBuffers.result, 0, resource.gpuBuffers.resultReadBack, 0, resultCount * 4);
276
+ device.queue.submit([encoder.finish()]);
277
+ await resource.gpuBuffers.resultReadBack.mapAsync(GPUMapMode.READ);
278
+ const mapped = resource.gpuBuffers.resultReadBack.getMappedRange();
279
+ labels.set(new Uint32Array(mapped, 0, resultCount), batchStart * batchSize);
280
+ resource.gpuBuffers.resultReadBack.unmap();
281
+ }
282
+ }
283
+ destroy() {
284
+ this.resource.gpuBuffers.uniform.destroy();
285
+ this.resource.gpuBuffers.centroids.destroy();
286
+ this.resource.gpuBuffers.result.destroy();
287
+ this.resource.gpuBuffers.resultReadBack.destroy();
288
+ for (const buffer of this.resource.gpuBuffers.points) {
289
+ buffer.destroy();
290
+ }
291
+ }
292
+ }
293
+ function groupLabels(labels, k) {
294
+ const clusters = [];
295
+ for (let i = 0; i < k; ++i) {
296
+ clusters[i] = [];
297
+ }
298
+ for (let i = 0; i < labels.length; ++i) {
299
+ clusters[labels[i]].push(i);
300
+ }
301
+ return clusters.map(c => new Uint32Array(c));
302
+ }
303
+ // https://github.com/playcanvas/splat-transform/blob/main/src/lib/spatial/k-means.ts
304
+ export async function kmeans(points, k, iterations, device) {
305
+ const numRows = points.length > 0 ? points[0].length : 0;
306
+ if (numRows < k) {
307
+ return {
308
+ centroids: points,
309
+ // use a typed array here so downstream code can rely on
310
+ // labels supporting subarray(), even in this early-return
311
+ // path used for very small datasets.
312
+ labels: new Uint32Array(numRows).map((_, i) => i),
313
+ };
314
+ }
315
+ const centroids = points.map(_ => new Float32Array(k));
316
+ if (points.length === 1) {
317
+ initializeCentroids1D(points[0], centroids[0]);
318
+ }
319
+ else {
320
+ initializeCentroids(points, centroids);
321
+ }
322
+ const gpuClustering = new GpuClustering(device, numRows, points.length, k);
323
+ const labels = new Uint32Array(numRows);
324
+ let converged = false;
325
+ let steps = 0;
326
+ while (!converged) {
327
+ logger.info(`kmeans iteration ${steps + 1}`);
328
+ await gpuClustering.execute(points, centroids, labels);
329
+ clusterAverage(points, groupLabels(labels, k), centroids);
330
+ steps++;
331
+ if (steps >= iterations) {
332
+ converged = true;
333
+ }
334
+ }
335
+ gpuClustering.destroy();
336
+ return {
337
+ centroids,
338
+ labels,
339
+ };
340
+ }