@multiplekex/shallot 0.1.12 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. package/package.json +3 -4
  2. package/src/core/builder.ts +71 -32
  3. package/src/core/component.ts +25 -11
  4. package/src/core/index.ts +14 -13
  5. package/src/core/math.ts +135 -0
  6. package/src/core/runtime.ts +0 -1
  7. package/src/core/state.ts +9 -68
  8. package/src/core/xml.ts +381 -265
  9. package/src/editor/format.ts +5 -0
  10. package/src/editor/index.ts +101 -0
  11. package/src/extras/arrows/index.ts +28 -69
  12. package/src/extras/gradient/index.ts +36 -52
  13. package/src/extras/lines/index.ts +51 -122
  14. package/src/extras/orbit/index.ts +40 -15
  15. package/src/extras/text/font.ts +546 -0
  16. package/src/extras/text/index.ts +158 -204
  17. package/src/extras/text/sdf.ts +429 -0
  18. package/src/standard/activity/index.ts +172 -0
  19. package/src/standard/compute/graph.ts +23 -23
  20. package/src/standard/compute/index.ts +76 -61
  21. package/src/standard/defaults.ts +8 -5
  22. package/src/standard/index.ts +1 -0
  23. package/src/standard/input/index.ts +30 -19
  24. package/src/standard/loading/index.ts +18 -13
  25. package/src/standard/render/bvh/blas.ts +752 -0
  26. package/src/standard/render/bvh/radix.ts +476 -0
  27. package/src/standard/render/bvh/structs.ts +167 -0
  28. package/src/standard/render/bvh/tlas.ts +886 -0
  29. package/src/standard/render/bvh/traverse.ts +467 -0
  30. package/src/standard/render/camera.ts +302 -27
  31. package/src/standard/render/data.ts +93 -0
  32. package/src/standard/render/depth.ts +117 -0
  33. package/src/standard/render/forward/index.ts +259 -0
  34. package/src/standard/render/forward/raster.ts +228 -0
  35. package/src/standard/render/index.ts +443 -70
  36. package/src/standard/render/indirect.ts +40 -0
  37. package/src/standard/render/instance.ts +214 -0
  38. package/src/standard/render/intersection.ts +72 -0
  39. package/src/standard/render/light.ts +16 -16
  40. package/src/standard/render/mesh/index.ts +67 -75
  41. package/src/standard/render/mesh/unified.ts +96 -0
  42. package/src/standard/render/{transparent.ts → overlay.ts} +14 -15
  43. package/src/standard/render/pass.ts +10 -4
  44. package/src/standard/render/postprocess.ts +142 -64
  45. package/src/standard/render/ray.ts +61 -0
  46. package/src/standard/render/scene.ts +38 -164
  47. package/src/standard/render/shaders.ts +484 -0
  48. package/src/standard/render/surface/compile.ts +3 -10
  49. package/src/standard/render/surface/index.ts +60 -30
  50. package/src/standard/render/surface/noise.ts +45 -0
  51. package/src/standard/render/surface/structs.ts +60 -19
  52. package/src/standard/render/surface/wgsl.ts +573 -0
  53. package/src/standard/render/triangle.ts +84 -0
  54. package/src/standard/transforms/index.ts +4 -6
  55. package/src/standard/tween/index.ts +10 -1
  56. package/src/standard/tween/sequence.ts +24 -16
  57. package/src/standard/tween/tween.ts +67 -16
  58. package/src/core/types.ts +0 -37
  59. package/src/standard/compute/inspect.ts +0 -201
  60. package/src/standard/compute/pass.ts +0 -23
  61. package/src/standard/compute/timing.ts +0 -139
  62. package/src/standard/render/forward.ts +0 -273
@@ -0,0 +1,886 @@
1
+ import { MAX_ENTITIES } from "../../../core";
2
+ import type { ComputeNode, ExecutionContext } from "../../compute";
3
+ import { createRadixSort } from "./radix";
4
+ import {
5
+ TREE_NODE_STRUCT_WGSL,
6
+ BVH_NODE_STRUCT_WGSL,
7
+ LEAF_FLAG_WGSL,
8
+ TREE_NODE_SIZE,
9
+ BVH_NODE_SIZE,
10
+ } from "./structs";
11
+
// Invocations per compute workgroup. Interpolated into every shader's
// @workgroup_size attribute and used on the CPU side to size dispatches.
12
+ const WORKGROUP_SIZE = 256;
// Iteration cap interpolated into the bounded loops of the tree and collapse
// shaders: ceil(log2(MAX_ENTITIES)) + 1.
13
+ const MAX_TREE_DEPTH = Math.ceil(Math.log2(MAX_ENTITIES)) + 1;
14
+
// GPU buffers shared by the TLAS build passes defined in this file.
15
+ export interface TLASBuffers {
// Binary tree nodes. The tree pass writes the internal nodes; the propagate pass
// writes bounds into the same buffer (leaf slots start at index n - 1).
16
+ treeNodes: GPUBuffer;
// Final nodes written by the collapse pass (up to four children per node).
17
+ bvhNodes: GPUBuffer;
// One u32 Morton code per slot; used as the radix-sort keys.
18
+ mortonCodes: GPUBuffer;
// Entity id per slot; used as the radix-sort values and read by the propagate pass.
19
+ instanceIds: GPUBuffer;
// Slot -> entity id table read by the bounds and Morton passes.
// NOTE(review): nothing in this file writes it; presumably filled by the caller.
20
+ entityIds: GPUBuffer;
// 32-byte scene AABB stored as sortable i32 values (min xyz, pad, max xyz, pad).
21
+ sceneBounds: GPUBuffer;
// Parent index per node: leaves at [leaf], internal nodes at [n + node].
22
+ parentIndices: GPUBuffer;
// Per-internal-node atomic arrival counters used by the propagate pass.
23
+ boundsFlags: GPUBuffer;
24
+ }
25
+
/**
 * Allocates every GPU buffer used by the TLAS build, sized for MAX_ENTITIES.
 *
 * All buffers are storage buffers that can be copied from; every buffer except
 * `bvhNodes` can also be copied to.
 *
 * @param device Device that owns the created buffers.
 * @returns The freshly created buffer set.
 */
export function createTLASBuffers(device: GPUDevice): TLASBuffers {
    const storageReadback = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC;
    const storageReadWrite = storageReadback | GPUBufferUsage.COPY_DST;
    // One 4-byte word per entity.
    const wordPerEntity = MAX_ENTITIES * 4;

    const allocate = (label: string, size: number, usage: number): GPUBuffer =>
        device.createBuffer({ label, size, usage });

    return {
        treeNodes: allocate(
            "tlas-tree-nodes",
            2 * MAX_ENTITIES * TREE_NODE_SIZE,
            storageReadWrite
        ),
        bvhNodes: allocate("tlas-bvh-nodes", MAX_ENTITIES * BVH_NODE_SIZE, storageReadback),
        mortonCodes: allocate("tlas-morton-codes", wordPerEntity, storageReadWrite),
        instanceIds: allocate("tlas-instance-ids", wordPerEntity, storageReadWrite),
        entityIds: allocate("tlas-entity-ids", wordPerEntity, storageReadWrite),
        // Two groups of three i32 values, each padded to 16 bytes.
        sceneBounds: allocate("tlas-scene-bounds", 32, storageReadWrite),
        parentIndices: allocate("tlas-parent-indices", 2 * wordPerEntity, storageReadWrite),
        boundsFlags: allocate("tlas-bounds-flags", wordPerEntity, storageReadWrite),
    };
}
70
+
// WGSL declaration of the per-instance axis-aligned bounding box record:
// min xyz, one padding word, max xyz, one padding word (eight 32-bit fields).
// Shaders in this file index the AABB array by entity id.
71
+ const INSTANCE_AABB_STRUCT_WGSL = /* wgsl */ `
72
+ struct InstanceAABB {
73
+ minX: f32,
74
+ minY: f32,
75
+ minZ: f32,
76
+ _pad0: u32,
77
+ maxX: f32,
78
+ maxY: f32,
79
+ maxZ: f32,
80
+ _pad1: u32,
81
+ }`;
82
+
// WGSL declaration of the scene bounds with atomic i32 components. Used by the
// bounds shader, which merges per-workgroup results with atomicMin/atomicMax.
// Shares the struct name "SceneBounds" with the read-only variant, so a shader
// may include only one of the two.
83
+ const SCENE_BOUNDS_STRUCT_WGSL = /* wgsl */ `
84
+ struct SceneBounds {
85
+ minX: atomic<i32>,
86
+ minY: atomic<i32>,
87
+ minZ: atomic<i32>,
88
+ _pad0: u32,
89
+ maxX: atomic<i32>,
90
+ maxY: atomic<i32>,
91
+ maxZ: atomic<i32>,
92
+ _pad1: u32,
93
+ }`;
94
+
// Non-atomic view of the same eight-word scene bounds record, for shaders that
// only read the bounds (the Morton shader binds it as read-only storage).
95
+ const SCENE_BOUNDS_READ_STRUCT_WGSL = /* wgsl */ `
96
+ struct SceneBounds {
97
+ minX: i32,
98
+ minY: i32,
99
+ minZ: i32,
100
+ _pad0: u32,
101
+ maxX: i32,
102
+ maxY: i32,
103
+ maxZ: i32,
104
+ _pad1: u32,
105
+ }`;
106
+
// WGSL helpers converting between f32 and an i32 encoding whose signed integer
// order matches float order: for negative inputs (sign bit set) the low 31 bits
// are flipped, positive inputs are unchanged. This lets integer atomicMin/atomicMax
// stand in for float min/max. The XOR is its own inverse, so sortableIntToFloat
// applies the same mask.
107
+ const FLOAT_INT_CONVERSION_WGSL = /* wgsl */ `
108
+ fn floatToSortableInt(f: f32) -> i32 {
109
+ let bits = bitcast<i32>(f);
110
+ let mask = (bits >> 31) & 0x7FFFFFFF;
111
+ return bits ^ mask;
112
+ }
113
+
114
+ fn sortableIntToFloat(i: i32) -> f32 {
115
+ let mask = (i >> 31) & 0x7FFFFFFF;
116
+ return bitcast<f32>(i ^ mask);
117
+ }`;
118
+
// WGSL helpers for 30-bit 3D Morton codes. expandBits keeps the low 10 bits of
// its argument and spreads them so each is followed by two zero bits; mortonCode
// interleaves three such values with x in the most significant position of each
// bit triple, then y, then z.
119
+ const MORTON_CODE_WGSL = /* wgsl */ `
120
+ fn expandBits(v: u32) -> u32 {
121
+ var x = v & 0x3ffu;
122
+ x = (x | (x << 16u)) & 0x030000ffu;
123
+ x = (x | (x << 8u)) & 0x0300f00fu;
124
+ x = (x | (x << 4u)) & 0x030c30c3u;
125
+ x = (x | (x << 2u)) & 0x09249249u;
126
+ return x;
127
+ }
128
+
129
+ fn mortonCode(x: u32, y: u32, z: u32) -> u32 {
130
+ return (expandBits(x) << 2u) | (expandBits(y) << 1u) | expandBits(z);
131
+ }`;
132
+
// WGSL helper counting the leading zero bits of a 32-bit value with a
// five-step binary search over the top bits; returns 32 for an input of 0.
133
+ const CLZ_WGSL = /* wgsl */ `
134
+ fn clz(x: u32) -> u32 {
135
+ if (x == 0u) { return 32u; }
136
+ var n = 0u;
137
+ var v = x;
138
+ if ((v & 0xffff0000u) == 0u) { n += 16u; v <<= 16u; }
139
+ if ((v & 0xff000000u) == 0u) { n += 8u; v <<= 8u; }
140
+ if ((v & 0xf0000000u) == 0u) { n += 4u; v <<= 4u; }
141
+ if ((v & 0xc0000000u) == 0u) { n += 2u; v <<= 2u; }
142
+ if ((v & 0x80000000u) == 0u) { n += 1u; }
143
+ return n;
144
+ }`;
145
+
// WGSL helpers for child references that carry a leaf marker: isLeaf tests the
// LEAF_FLAG bit(s) and leafIndex clears them. LEAF_FLAG itself is not declared
// here; a shader using these helpers must also include its declaration.
146
+ const LEAF_FUNCTIONS_WGSL = /* wgsl */ `
147
+ fn isLeaf(child: u32) -> bool {
148
+ return (child & LEAF_FLAG) != 0u;
149
+ }
150
+
151
+ fn leafIndex(child: u32) -> u32 {
152
+ return child & ~LEAF_FLAG;
153
+ }`;
154
+
// WGSL source of the scene-bounds pass.
// Each invocation with tid < instanceCount[0] loads the AABB of entityIds[tid];
// other invocations contribute the neutral values (+1e30 min, -1e30 max).
// The workgroup reduces min/max in workgroup memory with a halving stride, and
// local invocation 0 merges the result into sceneBounds with atomicMin/atomicMax
// on the sortable-int encoding of each component.
// sceneBounds must be reset before the pass runs (createTLASNode writes the
// initial values each frame).
155
+ const boundsShader = /* wgsl */ `
156
+ ${INSTANCE_AABB_STRUCT_WGSL}
157
+ ${SCENE_BOUNDS_STRUCT_WGSL}
158
+
159
+ @group(0) @binding(0) var<storage, read> instanceAABBs: array<InstanceAABB>;
160
+ @group(0) @binding(1) var<storage, read> instanceCount: array<u32>;
161
+ @group(0) @binding(2) var<storage, read_write> sceneBounds: SceneBounds;
162
+ @group(0) @binding(3) var<storage, read> entityIds: array<u32>;
163
+
164
+ var<workgroup> sharedMin: array<vec3<f32>, ${WORKGROUP_SIZE}>;
165
+ var<workgroup> sharedMax: array<vec3<f32>, ${WORKGROUP_SIZE}>;
166
+
167
+ ${FLOAT_INT_CONVERSION_WGSL}
168
+
169
+ @compute @workgroup_size(${WORKGROUP_SIZE})
170
+ fn main(
171
+ @builtin(global_invocation_id) gid: vec3<u32>,
172
+ @builtin(local_invocation_id) lid: vec3<u32>,
173
+ ) {
174
+ let count = instanceCount[0];
175
+ let tid = gid.x;
176
+ let localId = lid.x;
177
+
178
+ var localMin = vec3<f32>(1e30, 1e30, 1e30);
179
+ var localMax = vec3<f32>(-1e30, -1e30, -1e30);
180
+
181
+ if (tid < count) {
182
+ let eid = entityIds[tid];
183
+ let aabb = instanceAABBs[eid];
184
+ localMin = vec3<f32>(aabb.minX, aabb.minY, aabb.minZ);
185
+ localMax = vec3<f32>(aabb.maxX, aabb.maxY, aabb.maxZ);
186
+ }
187
+
188
+ sharedMin[localId] = localMin;
189
+ sharedMax[localId] = localMax;
190
+ workgroupBarrier();
191
+
192
+ for (var stride = ${WORKGROUP_SIZE}u / 2u; stride > 0u; stride >>= 1u) {
193
+ if (localId < stride) {
194
+ sharedMin[localId] = min(sharedMin[localId], sharedMin[localId + stride]);
195
+ sharedMax[localId] = max(sharedMax[localId], sharedMax[localId + stride]);
196
+ }
197
+ workgroupBarrier();
198
+ }
199
+
200
+ if (localId == 0u) {
201
+ let wgMin = sharedMin[0];
202
+ let wgMax = sharedMax[0];
203
+
204
+ atomicMin(&sceneBounds.minX, floatToSortableInt(wgMin.x));
205
+ atomicMin(&sceneBounds.minY, floatToSortableInt(wgMin.y));
206
+ atomicMin(&sceneBounds.minZ, floatToSortableInt(wgMin.z));
207
+ atomicMax(&sceneBounds.maxX, floatToSortableInt(wgMax.x));
208
+ atomicMax(&sceneBounds.maxY, floatToSortableInt(wgMax.y));
209
+ atomicMax(&sceneBounds.maxZ, floatToSortableInt(wgMax.z));
210
+ }
211
+ }
212
+ `;
213
+
// WGSL source of the Morton-code pass.
// For each slot tid < instanceCount[0]: takes the centroid of the AABB of
// entityIds[tid], normalises it against the decoded scene bounds (extent clamped
// to at least 1e-6 per axis), quantises each axis to 0..1023 and writes the
// interleaved code to mortonCodes[tid] and the entity id to instanceIds[tid].
// Slots from the count up to MAX_ENTITIES get code 0xFFFFFFFF and id 0 --
// presumably so padding sorts after every real 30-bit code; confirm against the
// radix sort.
214
+ const mortonShader = /* wgsl */ `
215
+ ${INSTANCE_AABB_STRUCT_WGSL}
216
+ ${SCENE_BOUNDS_READ_STRUCT_WGSL}
217
+
218
+ @group(0) @binding(0) var<storage, read> instanceAABBs: array<InstanceAABB>;
219
+ @group(0) @binding(1) var<storage, read> instanceCount: array<u32>;
220
+ @group(0) @binding(2) var<storage, read> sceneBounds: SceneBounds;
221
+ @group(0) @binding(3) var<storage, read_write> mortonCodes: array<u32>;
222
+ @group(0) @binding(4) var<storage, read_write> instanceIds: array<u32>;
223
+ @group(0) @binding(5) var<storage, read> entityIds: array<u32>;
224
+
225
+ ${FLOAT_INT_CONVERSION_WGSL}
226
+ ${MORTON_CODE_WGSL}
227
+
228
+ @compute @workgroup_size(${WORKGROUP_SIZE})
229
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
230
+ let tid = gid.x;
231
+ if (tid >= ${MAX_ENTITIES}u) { return; }
232
+
233
+ let count = instanceCount[0];
234
+ if (tid >= count) {
235
+ mortonCodes[tid] = 0xFFFFFFFFu;
236
+ instanceIds[tid] = 0u;
237
+ return;
238
+ }
239
+
240
+ let eid = entityIds[tid];
241
+ let aabb = instanceAABBs[eid];
242
+ let centroid = vec3<f32>(
243
+ (aabb.minX + aabb.maxX) * 0.5,
244
+ (aabb.minY + aabb.maxY) * 0.5,
245
+ (aabb.minZ + aabb.maxZ) * 0.5
246
+ );
247
+
248
+ let boundsMin = vec3<f32>(
249
+ sortableIntToFloat(sceneBounds.minX),
250
+ sortableIntToFloat(sceneBounds.minY),
251
+ sortableIntToFloat(sceneBounds.minZ)
252
+ );
253
+ let boundsMax = vec3<f32>(
254
+ sortableIntToFloat(sceneBounds.maxX),
255
+ sortableIntToFloat(sceneBounds.maxY),
256
+ sortableIntToFloat(sceneBounds.maxZ)
257
+ );
258
+
259
+ let size = boundsMax - boundsMin;
260
+ let safeSize = max(size, vec3<f32>(1e-6, 1e-6, 1e-6));
261
+
262
+ let normalized = (centroid - boundsMin) / safeSize;
263
+ let clamped = clamp(normalized, vec3<f32>(0.0), vec3<f32>(1.0));
264
+
265
+ let quantized = vec3<u32>(clamped * 1023.0);
266
+
267
+ mortonCodes[tid] = mortonCode(quantized.x, quantized.y, quantized.z);
268
+ instanceIds[tid] = eid;
269
+ }
270
+ `;
271
+
/**
 * WGSL source of the tree pass: one invocation per internal node builds a binary
 * radix tree over the sorted Morton codes.
 *
 * For internal node `i` the shader finds the key range [first, last] it covers,
 * finds the split position `gamma` inside that range, and writes the node
 * (bounds set to the +/-1e30 sentinel, children tagged with LEAF_FLAG when they
 * are leaves) plus the parent index of both children. Leaf parents are stored at
 * `parentIndices[leaf]`, internal-node parents at `parentIndices[n + node]`.
 *
 * Duplicate Morton codes are disambiguated by key index inside `delta()`. The
 * split search uses that same `delta()`; mixing it with a raw-code comparison or
 * a midpoint shortcut produces a split that disagrees with the ranges the child
 * nodes compute for themselves, which orphans leaves and gives nodes two parents
 * as soon as a run of equal codes is not power-of-two aligned.
 */
const treeShader = /* wgsl */ `
${TREE_NODE_STRUCT_WGSL}
${LEAF_FLAG_WGSL}

@group(0) @binding(0) var<storage, read> mortonCodes: array<u32>;
@group(0) @binding(1) var<storage, read> instanceCount: array<u32>;
@group(0) @binding(2) var<storage, read_write> treeNodes: array<TreeNode>;
@group(0) @binding(3) var<storage, read_write> parentIndices: array<u32>;

${CLZ_WGSL}

// Length of the common prefix of keys i and j, where a key is the Morton code
// followed by the key index. Equal codes therefore compare by index and always
// share more than 32 bits. Returns -1 when j is outside [0, n).
fn delta(i: i32, j: i32, n: i32) -> i32 {
    if (j < 0 || j >= n) {
        return -1;
    }
    let codeI = mortonCodes[i];
    let codeJ = mortonCodes[j];
    if (codeI == codeJ) {
        return i32(clz(u32(i) ^ u32(j))) + 32;
    }
    return i32(clz(codeI ^ codeJ));
}

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let n = i32(instanceCount[0]);
    let i = i32(gid.x);

    // n keys produce n - 1 internal nodes; this also exits when n <= 1.
    if (i >= n - 1) {
        return;
    }

    var first: i32;
    var last: i32;

    if (i == 0) {
        // The root always covers every key.
        first = 0;
        last = n - 1;
    } else {
        // Direction in which the range grows: towards the more similar neighbour.
        let d = select(-1, 1, delta(i, i + 1, n) > delta(i, i - 1, n));

        // Prefix shared with the neighbour on the other side; every key in the
        // range must share strictly more than this with key i.
        let deltaMin = delta(i, i - d, n);

        // Exponential search for an upper bound on the range length.
        var lmax = 2;
        for (var iter = 0; iter < ${MAX_TREE_DEPTH}; iter++) {
            if (delta(i, i + lmax * d, n) <= deltaMin) { break; }
            lmax *= 2;
        }

        // Binary search for the exact far end of the range.
        var l = 0;
        var t = lmax / 2;
        for (var iter2 = 0; iter2 < ${MAX_TREE_DEPTH}; iter2++) {
            if (t < 1) { break; }
            if (delta(i, i + (l + t) * d, n) > deltaMin) {
                l += t;
            }
            t /= 2;
        }

        let j = i + l * d;
        first = min(i, j);
        last = max(i, j);
    }

    // Split search: the last key that shares more than the node's own prefix
    // with the first key. Uses delta() so equal Morton codes are split by index,
    // consistently with the range search above.
    let deltaNode = delta(first, last, n);

    var split = first;
    var stride = last - first;

    for (var iter3 = 0; iter3 < ${MAX_TREE_DEPTH}; iter3++) {
        stride = (stride + 1) / 2;
        let middle = split + stride;

        if (middle < last) {
            if (delta(first, middle, n) > deltaNode) {
                split = middle;
            }
        }

        if (stride <= 1) {
            break;
        }
    }

    let gamma = split;

    // A child is a leaf when its side of the split holds exactly one key.
    let leftIsLeaf = first == gamma;
    let rightIsLeaf = last == gamma + 1;

    var node: TreeNode;
    node.minX = 1e30;
    node.minY = 1e30;
    node.minZ = 1e30;
    node.maxX = -1e30;
    node.maxY = -1e30;
    node.maxZ = -1e30;

    if (leftIsLeaf) {
        node.leftChild = u32(gamma) | LEAF_FLAG;
        parentIndices[u32(gamma)] = u32(i);
    } else {
        node.leftChild = u32(gamma);
        parentIndices[u32(n) + u32(gamma)] = u32(i);
    }

    if (rightIsLeaf) {
        node.rightChild = u32(gamma + 1) | LEAF_FLAG;
        parentIndices[u32(gamma + 1)] = u32(i);
    } else {
        node.rightChild = u32(gamma + 1);
        parentIndices[u32(n) + u32(gamma + 1)] = u32(i);
    }

    treeNodes[i] = node;
}
`;
399
+
/**
 * WGSL source of the bounds-propagation pass: one invocation per leaf walks up
 * the binary tree and fills in the bounds of every internal node.
 *
 * Each invocation first stores its leaf's bounds at node slot `n - 1 + leafIdx`,
 * then climbs through `parentIndices`. At every parent it increments that
 * node's counter in `boundsFlags`: the first arrival stops, the second arrival
 * merges both children's bounds, stores them on the parent and continues.
 * `boundsFlags` must be zero before the pass runs.
 *
 * Binding 3 views the node buffer as raw 32-bit words, eight per node: words
 * 0-2 and 4-6 are read and written as min/max bounds, words 3 and 7 are read as
 * the left and right child references. createTLASNode binds the tree-node
 * buffer here. NOTE(review): this relies on TreeNode (declared in ./structs)
 * having exactly that eight-word layout -- confirm when changing the struct.
 */
const propagateShader = /* wgsl */ `
${INSTANCE_AABB_STRUCT_WGSL}
${LEAF_FLAG_WGSL}

@group(0) @binding(0) var<storage, read> instanceAABBs: array<InstanceAABB>;
@group(0) @binding(1) var<storage, read> instanceIds: array<u32>;
@group(0) @binding(2) var<storage, read> instanceCount: array<u32>;
@group(0) @binding(3) var<storage, read_write> bvhNodesRaw: array<atomic<u32>>;
@group(0) @binding(4) var<storage, read_write> boundsFlags: array<atomic<u32>>;
@group(0) @binding(5) var<storage, read> parentIndices: array<u32>;

${LEAF_FUNCTIONS_WGSL}

// Bounds of the instance stored in sorted slot leafIdx, as [min, max].
fn getInstanceBounds(leafIdx: u32) -> array<vec3<f32>, 2> {
    let eid = instanceIds[leafIdx];
    let aabb = instanceAABBs[eid];
    return array<vec3<f32>, 2>(
        vec3<f32>(aabb.minX, aabb.minY, aabb.minZ),
        vec3<f32>(aabb.maxX, aabb.maxY, aabb.maxZ)
    );
}

// Parent of a node; leaf and internal parents live in separate index ranges.
fn getParent(nodeIdx: u32, isLeafNode: bool, n: u32) -> u32 {
    if (isLeafNode) {
        return parentIndices[nodeIdx];
    } else {
        return parentIndices[n + nodeIdx];
    }
}

// First word of node idx in the raw view (eight words per node).
fn nodeBase(idx: u32) -> u32 {
    return idx * 8u;
}

// Bounds previously stored on internal node childIdx, as [min, max].
fn readChildBounds(childIdx: u32) -> array<vec3<f32>, 2> {
    let base = nodeBase(childIdx);
    let minX = bitcast<f32>(atomicLoad(&bvhNodesRaw[base + 0u]));
    let minY = bitcast<f32>(atomicLoad(&bvhNodesRaw[base + 1u]));
    let minZ = bitcast<f32>(atomicLoad(&bvhNodesRaw[base + 2u]));
    let maxX = bitcast<f32>(atomicLoad(&bvhNodesRaw[base + 4u]));
    let maxY = bitcast<f32>(atomicLoad(&bvhNodesRaw[base + 5u]));
    let maxZ = bitcast<f32>(atomicLoad(&bvhNodesRaw[base + 6u]));
    return array<vec3<f32>, 2>(vec3(minX, minY, minZ), vec3(maxX, maxY, maxZ));
}

// Stores bounds on node nodeIdx, leaving the child words (3 and 7) untouched.
fn writeBounds(nodeIdx: u32, minB: vec3<f32>, maxB: vec3<f32>) {
    let base = nodeBase(nodeIdx);
    atomicStore(&bvhNodesRaw[base + 0u], bitcast<u32>(minB.x));
    atomicStore(&bvhNodesRaw[base + 1u], bitcast<u32>(minB.y));
    atomicStore(&bvhNodesRaw[base + 2u], bitcast<u32>(minB.z));
    atomicStore(&bvhNodesRaw[base + 4u], bitcast<u32>(maxB.x));
    atomicStore(&bvhNodesRaw[base + 5u], bitcast<u32>(maxB.y));
    atomicStore(&bvhNodesRaw[base + 6u], bitcast<u32>(maxB.z));
}

fn readLeftChild(nodeIdx: u32) -> u32 {
    return atomicLoad(&bvhNodesRaw[nodeBase(nodeIdx) + 3u]);
}

fn readRightChild(nodeIdx: u32) -> u32 {
    return atomicLoad(&bvhNodesRaw[nodeBase(nodeIdx) + 7u]);
}

// Leaf slots follow the n - 1 internal nodes in the node buffer.
fn writeLeafBounds(leafIdx: u32, n: u32, minB: vec3<f32>, maxB: vec3<f32>) {
    writeBounds(n - 1u + leafIdx, minB, maxB);
}

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let n = instanceCount[0];
    let leafIdx = gid.x;

    if (leafIdx >= n) {
        return;
    }

    let bounds = getInstanceBounds(leafIdx);
    writeLeafBounds(leafIdx, n, bounds[0], bounds[1]);

    var current = leafIdx;
    var isLeafNode = true;

    // Bounded climb towards the root.
    for (var iter = 0u; iter < 64u; iter++) {
        let parent = getParent(current, isLeafNode, n);

        // The first invocation to reach a node stops; the sibling subtree is not
        // finished yet. The second one continues with both children available.
        let oldFlag = atomicAdd(&boundsFlags[parent], 1u);
        if (oldFlag == 0u) {
            return;
        }

        let left = readLeftChild(parent);
        let right = readRightChild(parent);

        var leftMin: vec3<f32>;
        var leftMax: vec3<f32>;
        var rightMin: vec3<f32>;
        var rightMax: vec3<f32>;

        // Leaf children are read straight from the instance AABBs; internal
        // children from the bounds stored earlier in this pass.
        if (isLeaf(left)) {
            let leftBounds = getInstanceBounds(leafIndex(left));
            leftMin = leftBounds[0];
            leftMax = leftBounds[1];
        } else {
            let leftBounds = readChildBounds(left);
            leftMin = leftBounds[0];
            leftMax = leftBounds[1];
        }

        if (isLeaf(right)) {
            let rightBounds = getInstanceBounds(leafIndex(right));
            rightMin = rightBounds[0];
            rightMax = rightBounds[1];
        } else {
            let rightBounds = readChildBounds(right);
            rightMin = rightBounds[0];
            rightMax = rightBounds[1];
        }

        let newMin = min(leftMin, rightMin);
        let newMax = max(leftMax, rightMax);

        writeBounds(parent, newMin, newMax);

        current = parent;
        isLeafNode = false;

        // Node 0 is the root.
        if (parent == 0u) {
            break;
        }
    }
}
`;
541
+
/**
 * WGSL source of the collapse pass: one invocation per internal binary node
 * writes one output BVHNode with up to four children.
 *
 * Nodes at even depth pull in their grandchildren (a leaf child is kept as is),
 * giving up to four children. Nodes at odd depth are written with just their
 * two direct children. Unused child slots hold INVALID_NODE and the +/-1e30
 * sentinel bounds. Child bounds are read from the tree-node buffer, where leaf
 * `k` lives at slot `n - 1 + k`.
 *
 * Special cases:
 * - n == 0: node 0 is written as an empty node. Without this, `n - 1u` wraps
 *   around and every invocation would process stale tree data.
 * - n == 1: node 0 is written with the single leaf as child 0.
 *
 * The depth walk is capped at 64 steps, the same bound the propagate pass uses
 * for its climb: the depth of a radix tree is limited by the key length, not by
 * log2 of the entity count.
 */
const collapseShader = /* wgsl */ `
${TREE_NODE_STRUCT_WGSL}
${BVH_NODE_STRUCT_WGSL}
${LEAF_FLAG_WGSL}

const INVALID_NODE: u32 = 0xFFFFFFFFu;

@group(0) @binding(0) var<storage, read> treeNodes: array<TreeNode>;
@group(0) @binding(1) var<storage, read> instanceCount: array<u32>;
@group(0) @binding(2) var<storage, read> parentIndices: array<u32>;
@group(0) @binding(3) var<storage, read_write> bvhNodes: array<BVHNode>;

${LEAF_FUNCTIONS_WGSL}

// Number of parent links between internal node nodeIdx and the root (node 0).
fn getDepth(nodeIdx: u32, n: u32) -> u32 {
    var depth = 0u;
    var current = nodeIdx;
    for (var iter = 0u; iter < 64u; iter++) {
        if (current == 0u) { break; }
        current = parentIndices[n + current];
        depth++;
    }
    return depth;
}

// Bounds stored for a child reference, as [min, max].
fn getChildBounds(child: u32, n: u32) -> array<vec3<f32>, 2> {
    var slot = child;
    if (isLeaf(child)) {
        slot = n - 1u + leafIndex(child);
    }
    let node = treeNodes[slot];
    return array<vec3<f32>, 2>(
        vec3<f32>(node.minX, node.minY, node.minZ),
        vec3<f32>(node.maxX, node.maxY, node.maxZ)
    );
}

// A node whose four child slots are all unused.
fn emptyNode() -> BVHNode {
    var out: BVHNode;
    out.child0 = INVALID_NODE;
    out.child1 = INVALID_NODE;
    out.child2 = INVALID_NODE;
    out.child3 = INVALID_NODE;
    out.c0_minX = 1e30; out.c0_minY = 1e30; out.c0_minZ = 1e30;
    out.c0_maxX = -1e30; out.c0_maxY = -1e30; out.c0_maxZ = -1e30;
    out.c1_minX = 1e30; out.c1_minY = 1e30; out.c1_minZ = 1e30;
    out.c1_maxX = -1e30; out.c1_maxY = -1e30; out.c1_maxZ = -1e30;
    out.c2_minX = 1e30; out.c2_minY = 1e30; out.c2_minZ = 1e30;
    out.c2_maxX = -1e30; out.c2_maxY = -1e30; out.c2_maxZ = -1e30;
    out.c3_minX = 1e30; out.c3_minY = 1e30; out.c3_minZ = 1e30;
    out.c3_maxX = -1e30; out.c3_maxY = -1e30; out.c3_maxZ = -1e30;
    return out;
}

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let n = instanceCount[0];
    let nodeIdx = gid.x;

    // Empty scene: publish an empty root and stop before n - 1u can wrap.
    if (n == 0u) {
        if (nodeIdx == 0u) {
            bvhNodes[0] = emptyNode();
        }
        return;
    }

    // Single instance: there is no internal node, so the root holds the leaf.
    if (n == 1u) {
        if (nodeIdx == 0u) {
            var single = emptyNode();
            single.child0 = 0u | LEAF_FLAG;

            let bounds = getChildBounds(0u | LEAF_FLAG, n);
            single.c0_minX = bounds[0].x; single.c0_minY = bounds[0].y; single.c0_minZ = bounds[0].z;
            single.c0_maxX = bounds[1].x; single.c0_maxY = bounds[1].y; single.c0_maxZ = bounds[1].z;

            bvhNodes[0] = single;
        }
        return;
    }

    if (nodeIdx >= n - 1u) {
        return;
    }

    let depth = getDepth(nodeIdx, n);
    let node = treeNodes[nodeIdx];
    let left = node.leftChild;
    let right = node.rightChild;

    var out = emptyNode();

    // Odd depth: keep the two direct children.
    if ((depth & 1u) != 0u) {
        out.child0 = left;
        let bounds0 = getChildBounds(left, n);
        out.c0_minX = bounds0[0].x; out.c0_minY = bounds0[0].y; out.c0_minZ = bounds0[0].z;
        out.c0_maxX = bounds0[1].x; out.c0_maxY = bounds0[1].y; out.c0_maxZ = bounds0[1].z;

        out.child1 = right;
        let bounds1 = getChildBounds(right, n);
        out.c1_minX = bounds1[0].x; out.c1_minY = bounds1[0].y; out.c1_minZ = bounds1[0].z;
        out.c1_maxX = bounds1[1].x; out.c1_maxY = bounds1[1].y; out.c1_maxZ = bounds1[1].z;

        bvhNodes[nodeIdx] = out;
        return;
    }

    // Even depth: the left subtree fills slots 0-1, the right subtree slots 2-3.
    if (isLeaf(left)) {
        out.child0 = left;
        let bounds = getChildBounds(left, n);
        out.c0_minX = bounds[0].x; out.c0_minY = bounds[0].y; out.c0_minZ = bounds[0].z;
        out.c0_maxX = bounds[1].x; out.c0_maxY = bounds[1].y; out.c0_maxZ = bounds[1].z;
    } else {
        let leftNode = treeNodes[left];
        let ll = leftNode.leftChild;
        let lr = leftNode.rightChild;

        out.child0 = ll;
        let bounds0 = getChildBounds(ll, n);
        out.c0_minX = bounds0[0].x; out.c0_minY = bounds0[0].y; out.c0_minZ = bounds0[0].z;
        out.c0_maxX = bounds0[1].x; out.c0_maxY = bounds0[1].y; out.c0_maxZ = bounds0[1].z;

        out.child1 = lr;
        let bounds1 = getChildBounds(lr, n);
        out.c1_minX = bounds1[0].x; out.c1_minY = bounds1[0].y; out.c1_minZ = bounds1[0].z;
        out.c1_maxX = bounds1[1].x; out.c1_maxY = bounds1[1].y; out.c1_maxZ = bounds1[1].z;
    }

    if (isLeaf(right)) {
        out.child2 = right;
        let bounds = getChildBounds(right, n);
        out.c2_minX = bounds[0].x; out.c2_minY = bounds[0].y; out.c2_minZ = bounds[0].z;
        out.c2_maxX = bounds[1].x; out.c2_maxY = bounds[1].y; out.c2_maxZ = bounds[1].z;
    } else {
        let rightNode = treeNodes[right];
        let rl = rightNode.leftChild;
        let rr = rightNode.rightChild;

        out.child2 = rl;
        let bounds2 = getChildBounds(rl, n);
        out.c2_minX = bounds2[0].x; out.c2_minY = bounds2[0].y; out.c2_minZ = bounds2[0].z;
        out.c2_maxX = bounds2[1].x; out.c2_maxY = bounds2[1].y; out.c2_maxZ = bounds2[1].z;

        out.child3 = rr;
        let bounds3 = getChildBounds(rr, n);
        out.c3_minX = bounds3[0].x; out.c3_minY = bounds3[0].y; out.c3_minZ = bounds3[0].z;
        out.c3_maxX = bounds3[1].x; out.c3_maxY = bounds3[1].y; out.c3_maxZ = bounds3[1].z;
    }

    bvhNodes[nodeIdx] = out;
}
`;
702
+
// One compute pipeline per TLAS build shader in this file; the radix sort keeps
// its own pipelines.
703
+ interface TLASPipelines {
// Scene-bounds reduction (boundsShader).
704
+ bounds: GPUComputePipeline;
// Morton code generation (mortonShader).
705
+ morton: GPUComputePipeline;
// Binary radix tree construction (treeShader).
706
+ tree: GPUComputePipeline;
// Bottom-up bounds propagation (propagateShader).
707
+ propagate: GPUComputePipeline;
// Collapse into nodes with up to four children (collapseShader).
708
+ collapse: GPUComputePipeline;
709
+ }
710
+
// Bind group 0 for each pipeline in TLASPipelines, keyed by the same names.
711
+ interface TLASBindGroups {
712
+ bounds: GPUBindGroup;
713
+ morton: GPUBindGroup;
714
+ tree: GPUBindGroup;
715
+ propagate: GPUBindGroup;
716
+ collapse: GPUBindGroup;
717
+ }
718
+
// Inputs required to create the TLAS compute node.
719
+ export interface TLASConfig {
// Per-entity AABB records; the shaders index this buffer by entity id.
720
+ instanceAABBs: GPUBuffer;
// Buffer whose first u32 is the number of instances to build over.
721
+ instanceCount: GPUBuffer;
// Working and output buffers, e.g. from createTLASBuffers.
722
+ tlas: TLASBuffers;
// NOTE(review): createTLASNode in this file never calls this; dispatch sizes
// are derived from MAX_ENTITIES instead.
723
+ getEntityCount: () => number;
724
+ }
725
+
/**
 * Creates the compute node that rebuilds the top-level acceleration structure.
 *
 * `prepare` compiles the five build pipelines plus the radix sort and creates
 * their bind groups. `execute` records, in order: scene bounds, Morton codes,
 * radix sort, binary tree construction, bounds propagation and collapse. Every
 * dispatch is sized for MAX_ENTITIES; the shaders read the live count from
 * `config.instanceCount` and exit early for unused slots.
 *
 * @param config Input buffers and the TLAS working buffers.
 * @returns A compute node with id "tlas".
 * @throws Error from `execute` when it runs before `prepare` has completed.
 */
export function createTLASNode(config: TLASConfig): ComputeNode {
    let pipelines: TLASPipelines | null = null;
    let bindGroups: TLASBindGroups | null = null;
    let radixSort: Awaited<ReturnType<typeof createRadixSort>> | null = null;

    // One invocation per entity slot / per internal node respectively.
    const entityWorkgroups = Math.ceil(MAX_ENTITIES / WORKGROUP_SIZE);
    const internalWorkgroups = Math.ceil((MAX_ENTITIES - 1) / WORKGROUP_SIZE);

    // Reset values for the scene bounds in the sortable-int encoding used by the
    // shaders: min starts at +FLT_MAX (0x7f7fffff), max at -FLT_MAX (0x80800000,
    // which wraps to the matching negative i32 bit pattern). Reused every frame;
    // writeBuffer copies the data when it is called.
    const initBounds = new Int32Array([
        0x7f7fffff, 0x7f7fffff, 0x7f7fffff, 0, 0x80800000, 0x80800000, 0x80800000, 0,
    ]);

    return {
        id: "tlas",
        sync: true,
        inputs: [
            { id: "instance-aabbs", access: "read" },
            { id: "instance-count", access: "read" },
        ],
        outputs: [
            { id: "tlas-bvh-nodes", access: "write" },
            { id: "tlas-morton-codes", access: "write" },
            { id: "tlas-instance-ids", access: "write" },
        ],

        async prepare(device: GPUDevice) {
            // createShaderModule returns the module synchronously; only pipeline
            // creation is asynchronous.
            const createPipeline = (code: string): Promise<GPUComputePipeline> =>
                device.createComputePipelineAsync({
                    layout: "auto",
                    compute: { module: device.createShaderModule({ code }), entryPoint: "main" },
                });

            const [bounds, morton, tree, propagate, collapse, sort] = await Promise.all([
                createPipeline(boundsShader),
                createPipeline(mortonShader),
                createPipeline(treeShader),
                createPipeline(propagateShader),
                createPipeline(collapseShader),
                createRadixSort(device, {
                    keys: config.tlas.mortonCodes,
                    values: config.tlas.instanceIds,
                    count: MAX_ENTITIES,
                }),
            ]);

            // Binds the buffers to bindings 0..k of the pipeline's group 0, in order.
            const bind = (pipeline: GPUComputePipeline, buffers: GPUBuffer[]): GPUBindGroup =>
                device.createBindGroup({
                    layout: pipeline.getBindGroupLayout(0),
                    entries: buffers.map((buffer, binding) => ({
                        binding,
                        resource: { buffer },
                    })),
                });

            const { instanceAABBs, instanceCount, tlas } = config;

            const builtBindGroups: TLASBindGroups = {
                bounds: bind(bounds, [
                    instanceAABBs,
                    instanceCount,
                    tlas.sceneBounds,
                    tlas.entityIds,
                ]),
                morton: bind(morton, [
                    instanceAABBs,
                    instanceCount,
                    tlas.sceneBounds,
                    tlas.mortonCodes,
                    tlas.instanceIds,
                    tlas.entityIds,
                ]),
                tree: bind(tree, [
                    tlas.mortonCodes,
                    instanceCount,
                    tlas.treeNodes,
                    tlas.parentIndices,
                ]),
                // The propagate shader views the tree-node buffer as raw words.
                propagate: bind(propagate, [
                    instanceAABBs,
                    tlas.instanceIds,
                    instanceCount,
                    tlas.treeNodes,
                    tlas.boundsFlags,
                    tlas.parentIndices,
                ]),
                collapse: bind(collapse, [
                    tlas.treeNodes,
                    instanceCount,
                    tlas.parentIndices,
                    tlas.bvhNodes,
                ]),
            };

            pipelines = { bounds, morton, tree, propagate, collapse };
            radixSort = sort;
            bindGroups = builtBindGroups;
        },

        execute(ctx: ExecutionContext) {
            // Fail before recording anything rather than half-way through a pass.
            if (pipelines === null || bindGroups === null || radixSort === null) {
                throw new Error("tlas: execute() called before prepare() completed");
            }

            const { device, encoder } = ctx;

            const runPass = (
                pipeline: GPUComputePipeline,
                bindGroup: GPUBindGroup,
                workgroupCount: number
            ): void => {
                const pass = encoder.beginComputePass();
                pass.setPipeline(pipeline);
                pass.setBindGroup(0, bindGroup);
                pass.dispatchWorkgroups(workgroupCount);
                pass.end();
            };

            // Per-frame state reset: scene bounds, arrival counters, parent links.
            device.queue.writeBuffer(config.tlas.sceneBounds, 0, initBounds);
            encoder.clearBuffer(config.tlas.boundsFlags);
            encoder.clearBuffer(config.tlas.parentIndices);

            runPass(pipelines.bounds, bindGroups.bounds, entityWorkgroups);
            runPass(pipelines.morton, bindGroups.morton, entityWorkgroups);

            const sortPass = encoder.beginComputePass();
            radixSort.dispatch(sortPass);
            sortPass.end();

            runPass(pipelines.tree, bindGroups.tree, internalWorkgroups);
            runPass(pipelines.propagate, bindGroups.propagate, entityWorkgroups);
            runPass(pipelines.collapse, bindGroups.collapse, internalWorkgroups);
        },
    };
}