@multiplekex/shallot 0.2.4 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +1 -1
- package/src/core/component.ts +1 -1
- package/src/core/index.ts +1 -13
- package/src/core/math.ts +186 -0
- package/src/core/state.ts +1 -1
- package/src/core/xml.ts +56 -41
- package/src/extras/arrows/index.ts +3 -3
- package/src/extras/caustic.ts +37 -0
- package/src/extras/gradient/index.ts +63 -69
- package/src/extras/index.ts +3 -0
- package/src/extras/lines/index.ts +3 -3
- package/src/extras/orbit/index.ts +1 -1
- package/src/extras/skylab/index.ts +314 -0
- package/src/extras/text/font.ts +69 -14
- package/src/extras/text/index.ts +17 -69
- package/src/extras/text/sdf.ts +13 -2
- package/src/extras/water/index.ts +119 -0
- package/src/standard/defaults.ts +2 -0
- package/src/standard/index.ts +2 -0
- package/src/standard/raster/batch.ts +149 -0
- package/src/standard/raster/forward.ts +832 -0
- package/src/standard/raster/index.ts +191 -0
- package/src/standard/raster/shadow.ts +408 -0
- package/src/standard/{render → raytracing}/bvh/blas.ts +336 -88
- package/src/standard/raytracing/bvh/radix.ts +473 -0
- package/src/standard/raytracing/bvh/refit.ts +711 -0
- package/src/standard/{render → raytracing}/bvh/structs.ts +0 -55
- package/src/standard/{render → raytracing}/bvh/tlas.ts +155 -140
- package/src/standard/{render → raytracing}/bvh/traverse.ts +72 -64
- package/src/standard/{render → raytracing}/depth.ts +9 -9
- package/src/standard/raytracing/index.ts +409 -0
- package/src/standard/{render → raytracing}/instance.ts +31 -16
- package/src/standard/{render → raytracing}/ray.ts +1 -1
- package/src/standard/raytracing/shaders.ts +798 -0
- package/src/standard/{render → raytracing}/triangle.ts +1 -1
- package/src/standard/render/camera.ts +96 -106
- package/src/standard/render/data.ts +1 -1
- package/src/standard/render/index.ts +136 -220
- package/src/standard/render/indirect.ts +9 -10
- package/src/standard/render/light.ts +2 -2
- package/src/standard/render/mesh.ts +404 -0
- package/src/standard/render/overlay.ts +8 -5
- package/src/standard/render/pass.ts +1 -1
- package/src/standard/render/postprocess.ts +263 -242
- package/src/standard/render/scene.ts +28 -16
- package/src/standard/render/surface/index.ts +81 -12
- package/src/standard/render/surface/shaders.ts +511 -0
- package/src/standard/render/surface/structs.ts +23 -6
- package/src/standard/tween/tween.ts +44 -115
- package/src/standard/render/bvh/radix.ts +0 -476
- package/src/standard/render/forward/index.ts +0 -259
- package/src/standard/render/forward/raster.ts +0 -228
- package/src/standard/render/mesh/box.ts +0 -20
- package/src/standard/render/mesh/index.ts +0 -446
- package/src/standard/render/mesh/plane.ts +0 -11
- package/src/standard/render/mesh/sphere.ts +0 -40
- package/src/standard/render/mesh/unified.ts +0 -96
- package/src/standard/render/shaders.ts +0 -484
- package/src/standard/render/surface/compile.ts +0 -67
- package/src/standard/render/surface/noise.ts +0 -45
- package/src/standard/render/surface/wgsl.ts +0 -573
- /package/src/standard/{render → raytracing}/intersection.ts +0 -0
|
@@ -0,0 +1,473 @@
|
|
|
1
|
+
import type { ComputeNode, ExecutionContext } from "../../compute";
|
|
2
|
+
|
|
3
|
+
/** Workgroup width (x) shared by every kernel in this file. */
const WG_X = 16;
/** Workgroup height (y) shared by every kernel in this file. */
const WG_Y = 16;
/** Invocations per workgroup; the radix kernels process one key per invocation. */
const WG_SIZE = WG_Y * WG_X;
/** Elements handled per workgroup by the prefix-sum kernel (two per invocation). */
const ITEMS_PER_WG = WG_SIZE << 1;
|
|
7
|
+
|
|
8
|
+
/**
 * Radix step, stage 1: per-workgroup digit counting for one 2-bit digit.
 *
 * Each invocation loads one key; indices at or past COUNT are treated as key 0.
 * It then extracts the digit `(key >> BIT) & 3`.
 *
 * For each of the 4 digit values, the workgroup scans a 0/1 "key has this digit"
 * mask in `wgData`, using a log-step loop. The mask is stored one slot to the
 * right of the invocation's index and read back at its own index, so the result
 * is an exclusive count. The loop ping-pongs between two halves of `wgData`,
 * each `WG_SIZE + 1` entries long (`stride`).
 *
 * The result is written per key: `localSums[gid]` is the number of earlier keys
 * in the same workgroup that share this key's digit.
 *
 * The last in-range invocation of each workgroup (`lastThread`) also writes
 * `sum + mask`, the workgroup's total for each digit. These totals go to
 * `blockSums[digit * WG_COUNT + workgroup]`, a digit-major layout. Workgroups
 * at or beyond WG_COUNT (trailing groups of a 2D grid) keep
 * `lastThread = 0xffffffff` and write no totals.
 *
 * Overrides:
 * - WG_COUNT: number of workgroups covering COUNT keys.
 * - BIT: shift of the current digit.
 * - COUNT: number of keys.
 */
const blockSumShader = /* wgsl */ `
@group(0) @binding(0) var<storage, read> input: array<u32>;
@group(0) @binding(1) var<storage, read_write> localSums: array<u32>;
@group(0) @binding(2) var<storage, read_write> blockSums: array<u32>;

override WG_COUNT: u32;
override BIT: u32;
override COUNT: u32;

var<workgroup> wgData: array<u32, 2 * (${WG_SIZE} + 1)>;

@compute @workgroup_size(${WG_X}, ${WG_Y}, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) wdim: vec3<u32>,
@builtin(local_invocation_index) tid: u32,
) {
let workgroup = wid.x + wid.y * wdim.x;
let base = workgroup * ${WG_SIZE}u;
let gid = base + tid;

let val = select(input[gid], 0u, gid >= COUNT);
let bits = (val >> BIT) & 0x3;

var sums = array<u32, 4>(0, 0, 0, 0);
var lastThread = 0xffffffffu;

if (workgroup < WG_COUNT) {
lastThread = min(${WG_SIZE}u, COUNT - base) - 1;
}

let stride = ${WG_SIZE}u + 1;
var swap = 0u;
var inOff = tid;
var outOff = tid + stride;

for (var b = 0u; b < 4; b++) {
let mask = select(0u, 1u, bits == b);
wgData[inOff + 1] = mask;
workgroupBarrier();

var sum = 0u;
for (var off = 1u; off < ${WG_SIZE}u; off *= 2) {
if (tid >= off) {
sum = wgData[inOff] + wgData[inOff - off];
} else {
sum = wgData[inOff];
}
wgData[outOff] = sum;
outOff = inOff;
swap = stride - swap;
inOff = tid + swap;
workgroupBarrier();
}

sums[b] = sum;

if (tid == lastThread) {
blockSums[b * WG_COUNT + workgroup] = sum + mask;
}

outOff = inOff;
swap = stride - swap;
inOff = tid + swap;
}

if (gid < COUNT) {
localSums[gid] = sums[bits];
}
}
`;
|
|
79
|
+
|
|
80
|
+
/**
 * Radix step, stage 3: stable scatter of keys and values for one 2-bit digit.
 *
 * Each key below COUNT, together with its value, is written to index
 * `blockSums[digit * WG_COUNT + workgroup] + localSums[gid]`.
 *
 * By the time this runs (see dispatchRadixSort), `blockSums` has been
 * exclusively prefix-summed over its digit-major layout. Its entry therefore
 * counts:
 * - every key with a smaller digit, plus
 * - every key with the same digit in earlier workgroups.
 *
 * `localSums[gid]` adds the key's rank among same-digit keys within its own
 * workgroup. Equal digits therefore keep their input order, i.e. the scatter
 * is stable.
 *
 * Overrides match blockSumShader (WG_COUNT, BIT, COUNT).
 */
const reorderShader = /* wgsl */ `
@group(0) @binding(0) var<storage, read> inKeys: array<u32>;
@group(0) @binding(1) var<storage, read_write> outKeys: array<u32>;
@group(0) @binding(2) var<storage, read> localSums: array<u32>;
@group(0) @binding(3) var<storage, read> blockSums: array<u32>;
@group(0) @binding(4) var<storage, read> inVals: array<u32>;
@group(0) @binding(5) var<storage, read_write> outVals: array<u32>;

override WG_COUNT: u32;
override BIT: u32;
override COUNT: u32;

@compute @workgroup_size(${WG_X}, ${WG_Y}, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) wdim: vec3<u32>,
@builtin(local_invocation_index) tid: u32,
) {
let workgroup = wid.x + wid.y * wdim.x;
let gid = workgroup * ${WG_SIZE}u + tid;

if (gid >= COUNT) { return; }

let k = inKeys[gid];
let v = inVals[gid];
let bits = (k >> BIT) & 0x3;
let dst = blockSums[bits * WG_COUNT + workgroup] + localSums[gid];

outKeys[dst] = k;
outVals[dst] = v;
}
`;
|
|
112
|
+
|
|
113
|
+
/**
 * Work-efficient (up-sweep / down-sweep) exclusive prefix sum over `data[0..COUNT)`.
 *
 * `scan`:
 * - Each workgroup owns ITEMS_PER_WG consecutive elements, two per invocation.
 *   Slots at or past COUNT load as 0.
 * - It reduces them in workgroup memory and stores the block total in
 *   `blockSums[workgroup]`.
 * - It then zeroes the root, down-sweeps, and writes the block-local exclusive
 *   scan back to `data`.
 *
 * `addBlocks`: adds `blockSums[workgroup]` to both elements owned by each
 * invocation. buildPrefixPasses runs it only after the block totals themselves
 * have been scanned, which turns the block-local results into a global scan.
 *
 * `temp` is sized to exactly ITEMS_PER_WG, since the largest index touched is
 * ITEMS_PER_WG - 1.
 *
 * The `blockSums` store is bounds-checked with `arrayLength` for two reasons:
 * - A 2D dispatch grid (see dispatchSize) can launch trailing workgroups past
 *   the last block.
 * - WGSL allows an out-of-bounds store to land on an in-bounds element rather
 *   than being discarded, so an unchecked store could overwrite a real block
 *   total with 0.
 */
const prefixSumShader = /* wgsl */ `
@group(0) @binding(0) var<storage, read_write> data: array<u32>;
@group(0) @binding(1) var<storage, read_write> blockSums: array<u32>;

override COUNT: u32;

var<workgroup> temp: array<u32, ${ITEMS_PER_WG}>;

@compute @workgroup_size(${WG_X}, ${WG_Y}, 1)
fn scan(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) wdim: vec3<u32>,
@builtin(local_invocation_index) tid: u32,
) {
let workgroup = wid.x + wid.y * wdim.x;
let base = workgroup * ${WG_SIZE}u;
let gid = base + tid;
let eid = gid * 2;

temp[tid * 2] = select(data[eid], 0u, eid >= COUNT);
temp[tid * 2 + 1] = select(data[eid + 1], 0u, eid + 1 >= COUNT);

var offset = 1u;
for (var d = ${ITEMS_PER_WG}u >> 1; d > 0; d >>= 1) {
workgroupBarrier();
if (tid < d) {
let ai = offset * (tid * 2 + 1) - 1;
let bi = offset * (tid * 2 + 2) - 1;
temp[bi] += temp[ai];
}
offset *= 2;
}

if (tid == 0) {
if (workgroup < arrayLength(&blockSums)) {
blockSums[workgroup] = temp[${ITEMS_PER_WG}u - 1];
}
temp[${ITEMS_PER_WG}u - 1] = 0;
}

for (var d = 1u; d < ${ITEMS_PER_WG}u; d *= 2) {
offset >>= 1;
workgroupBarrier();
if (tid < d) {
let ai = offset * (tid * 2 + 1) - 1;
let bi = offset * (tid * 2 + 2) - 1;
let t = temp[ai];
temp[ai] = temp[bi];
temp[bi] += t;
}
}
workgroupBarrier();

if (eid < COUNT) { data[eid] = temp[tid * 2]; }
if (eid + 1 < COUNT) { data[eid + 1] = temp[tid * 2 + 1]; }
}

@compute @workgroup_size(${WG_X}, ${WG_Y}, 1)
fn addBlocks(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) wdim: vec3<u32>,
@builtin(local_invocation_index) tid: u32,
) {
let workgroup = wid.x + wid.y * wdim.x;
let eid = (workgroup * ${WG_SIZE}u + tid) * 2;

if (eid >= COUNT) { return; }

let sum = blockSums[workgroup];
data[eid] += sum;
if (eid + 1 < COUNT) { data[eid + 1] += sum; }
}
`;
|
|
184
|
+
|
|
185
|
+
/**
 * Maps a 1D workgroup count onto a dispatch grid that respects the device's
 * per-dimension limit.
 *
 * Counts within the limit become `[count, 1]`. Larger counts are spread over a
 * near-square grid whose area is at least `count`. Kernels flatten the grid as
 * `wid.x + wid.y * wdim.x`, so the trailing workgroups must bounds-check.
 */
function dispatchSize(device: GPUDevice, count: number): [number, number] {
  const { maxComputeWorkgroupsPerDimension: limit } = device.limits;
  const fitsInOneRow = count <= limit;
  if (fitsInOneRow) return [count, 1];

  const side = Math.ceil(Math.sqrt(count));
  const rows = Math.ceil(count / side);
  return [side, rows];
}
|
|
191
|
+
|
|
192
|
+
/** A single prefix-sum dispatch: what to run, what it binds, and the workgroup grid. */
interface PrefixPass {
  /** `[x, y]` workgroup counts for `dispatchWorkgroups`. */
  dispatch: [number, number];
  pipeline: GPUComputePipeline;
  /** Bound at group 0. */
  bindGroup: GPUBindGroup;
}
|
|
197
|
+
|
|
198
|
+
/** Every dispatch that makes up one (possibly multi-level) prefix sum. */
interface PrefixSumState {
  /** Recorded in array order; later entries depend on earlier ones. */
  passes: PrefixPass[];
}
|
|
201
|
+
|
|
202
|
+
/**
 * Appends, in dispatch order, the passes that replace `data[0..count)` with its
 * exclusive prefix sum.
 *
 * A "scan" pass handles blocks of ITEMS_PER_WG elements and writes each block's
 * total into a freshly allocated buffer. When there is more than one block:
 * - those totals are scanned recursively, then
 * - an "addBlocks" pass adds them back onto their blocks.
 *
 * @param module compiled prefixSumShader module
 * @param data   buffer scanned in place
 * @param count  number of u32 elements to scan (baked in as the COUNT override)
 * @param passes output list; entries are pushed in the order they must run
 */
async function buildPrefixPasses(
  device: GPUDevice,
  module: GPUShaderModule,
  data: GPUBuffer,
  count: number,
  passes: PrefixPass[]
): Promise<void> {
  const blocks = Math.ceil(count / ITEMS_PER_WG);
  const grid = dispatchSize(device, blocks);

  // One u32 total per block, never smaller than a single u32.
  const totals = device.createBuffer({
    size: Math.max(blocks * 4, 4),
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
  });

  const groupLayout = device.createBindGroupLayout({
    entries: [
      { binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage" } },
      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage" } },
    ],
  });

  const group = device.createBindGroup({
    layout: groupLayout,
    entries: [
      { binding: 0, resource: { buffer: data } },
      { binding: 1, resource: { buffer: totals } },
    ],
  });

  const sharedPipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [groupLayout] });

  // Both entry points share bindings and grid; only the entry point differs.
  const makePass = async (entryPoint: "scan" | "addBlocks"): Promise<PrefixPass> => ({
    pipeline: await device.createComputePipelineAsync({
      layout: sharedPipelineLayout,
      compute: { module, entryPoint, constants: { COUNT: count } },
    }),
    bindGroup: group,
    dispatch: grid,
  });

  passes.push(await makePass("scan"));

  if (blocks > 1) {
    // Scan the block totals first, then fold them back into each block.
    await buildPrefixPasses(device, module, totals, blocks, passes);
    passes.push(await makePass("addBlocks"));
  }
}
|
|
256
|
+
|
|
257
|
+
/**
 * Compiles the prefix-sum shader and prepares every pass needed to turn
 * `data[0..count)` into its exclusive prefix sum in place.
 */
async function createPrefixSum(
  device: GPUDevice,
  data: GPUBuffer,
  count: number
): Promise<PrefixSumState> {
  const shaderModule = device.createShaderModule({ code: prefixSumShader });
  const state: PrefixSumState = { passes: [] };
  await buildPrefixPasses(device, shaderModule, data, count, state.passes);
  return state;
}
|
|
267
|
+
|
|
268
|
+
/** Records every prefix-sum pass, in order, into an already-open compute pass. */
function dispatchPrefixSum(state: PrefixSumState, pass: GPUComputePassEncoder): void {
  state.passes.forEach(({ pipeline, bindGroup, dispatch: [gridX, gridY] }) => {
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(gridX, gridY, 1);
  });
}
|
|
275
|
+
|
|
276
|
+
/**
 * The two digit-specific dispatches of one 2-bit radix step. The prefix sum that
 * runs between them is shared by all steps (see RadixSortState.prefixSum).
 */
interface RadixPass {
  /** Stable scatter of keys/values to their positions for this digit. */
  reorder: { bindGroup: GPUBindGroup; pipeline: GPUComputePipeline };
  /** Per-workgroup digit counting into localSums / blockSums. */
  blockSum: { bindGroup: GPUBindGroup; pipeline: GPUComputePipeline };
}
|
|
280
|
+
|
|
281
|
+
/** Everything dispatchRadixSort needs to record a full sort. */
interface RadixSortState {
  /** `[x, y]` grid used by the counting and scatter dispatches. */
  workgroups: [number, number];
  /** Prefix sum over the block-sum counters, reused by every step. */
  prefixSum: PrefixSumState;
  /** One entry per 2-bit digit, least significant first. */
  passes: RadixPass[];
}
|
|
286
|
+
|
|
287
|
+
/**
 * Input to createRadixSort.
 *
 * `keys` and `values` are read and written as arrays of u32. After a dispatch,
 * they hold the pairs sorted by key. Both must hold at least `count` entries and
 * be bindable as storage buffers.
 */
export interface RadixSortConfig {
  /** Number of key/value pairs to sort; fixed for the lifetime of the sort. */
  count: number;
  /** u32 sort keys. */
  keys: GPUBuffer;
  /** u32 payload moved alongside each key. */
  values: GPUBuffer;
}
|
|
292
|
+
|
|
293
|
+
/**
 * Builds the GPU resources and pipelines for a stable LSD radix sort of
 * `config.count` u32 keys, each carrying one u32 value.
 *
 * The sort uses 2 bits per step, so 16 steps cover all 32 key bits. Each step
 * runs three stages (see dispatchRadixSort):
 * 1. per-workgroup digit counting (blockSumShader);
 * 2. an exclusive prefix sum over the digit-major `blockSums` counters;
 * 3. a stable scatter (reorderShader).
 *
 * Steps ping-pong between the caller's buffers and internal scratch copies.
 * Because the step count is even, the sorted result ends up back in
 * `config.keys` / `config.values`.
 *
 * `count` is baked into the pipelines as an override constant, so the returned
 * state is only valid for exactly this `count`.
 */
export async function createRadixSort(
  device: GPUDevice,
  config: RadixSortConfig
): Promise<RadixSortState> {
  const { keys, values, count } = config;
  const wgCount = Math.ceil(count / WG_SIZE);
  const workgroups = dispatchSize(device, wgCount);

  // Scratch storage. Sizes are clamped to one u32, mirroring the guard in
  // buildPrefixPasses, so a zero `count` never produces a zero-sized binding
  // for the shaders' runtime-sized arrays.
  const createScratch = (bytes: number) =>
    device.createBuffer({
      size: Math.max(bytes, 4),
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
    });

  const tmpKeys = createScratch(count * 4);
  const tmpVals = createScratch(count * 4);
  const localSums = createScratch(count * 4);
  // One u32 counter per (digit, workgroup), indexed digit * wgCount + workgroup.
  const blockSums = createScratch(4 * wgCount * 4);

  // Exclusive scan of all counters turns them into global scatter base offsets.
  const prefixSum = await createPrefixSum(device, blockSums, 4 * wgCount);

  const blockSumModule = device.createShaderModule({ code: blockSumShader });
  const reorderModule = device.createShaderModule({ code: reorderShader });

  const blockSumLayout = device.createBindGroupLayout({
    entries: [
      {
        binding: 0,
        visibility: GPUShaderStage.COMPUTE,
        buffer: { type: "read-only-storage" },
      },
      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage" } },
      { binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage" } },
    ],
  });

  const reorderLayout = device.createBindGroupLayout({
    entries: [
      {
        binding: 0,
        visibility: GPUShaderStage.COMPUTE,
        buffer: { type: "read-only-storage" },
      },
      { binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage" } },
      {
        binding: 2,
        visibility: GPUShaderStage.COMPUTE,
        buffer: { type: "read-only-storage" },
      },
      {
        binding: 3,
        visibility: GPUShaderStage.COMPUTE,
        buffer: { type: "read-only-storage" },
      },
      {
        binding: 4,
        visibility: GPUShaderStage.COMPUTE,
        buffer: { type: "read-only-storage" },
      },
      { binding: 5, visibility: GPUShaderStage.COMPUTE, buffer: { type: "storage" } },
    ],
  });

  // Pipeline layouts do not depend on the digit, so build them once rather than
  // once per pipeline.
  const blockSumPipelineLayout = device.createPipelineLayout({
    bindGroupLayouts: [blockSumLayout],
  });
  const reorderPipelineLayout = device.createPipelineLayout({
    bindGroupLayouts: [reorderLayout],
  });

  // 16 steps x 2 bits = 32 key bits. Only the BIT override differs per step, and
  // all pipelines are compiled concurrently.
  const stepCount = 16;
  const pipelines = await Promise.all(
    Array.from({ length: stepCount }, async (_, step) => {
      const constants = { WG_COUNT: wgCount, BIT: step * 2, COUNT: count };
      const [blockSum, reorder] = await Promise.all([
        device.createComputePipelineAsync({
          layout: blockSumPipelineLayout,
          compute: { module: blockSumModule, entryPoint: "main", constants },
        }),
        device.createComputePipelineAsync({
          layout: reorderPipelineLayout,
          compute: { module: reorderModule, entryPoint: "main", constants },
        }),
      ]);
      return { blockSum, reorder };
    })
  );

  const passes: RadixPass[] = pipelines.map(({ blockSum, reorder }, step) => {
    // Even steps read the caller's buffers and write the scratch copies; odd
    // steps go the other way.
    const fromCaller = step % 2 === 0;
    const inK = fromCaller ? keys : tmpKeys;
    const inV = fromCaller ? values : tmpVals;
    const outK = fromCaller ? tmpKeys : keys;
    const outV = fromCaller ? tmpVals : values;

    return {
      blockSum: {
        pipeline: blockSum,
        bindGroup: device.createBindGroup({
          layout: blockSumLayout,
          entries: [
            { binding: 0, resource: { buffer: inK } },
            { binding: 1, resource: { buffer: localSums } },
            { binding: 2, resource: { buffer: blockSums } },
          ],
        }),
      },
      reorder: {
        pipeline: reorder,
        bindGroup: device.createBindGroup({
          layout: reorderLayout,
          entries: [
            { binding: 0, resource: { buffer: inK } },
            { binding: 1, resource: { buffer: outK } },
            { binding: 2, resource: { buffer: localSums } },
            { binding: 3, resource: { buffer: blockSums } },
            { binding: 4, resource: { buffer: inV } },
            { binding: 5, resource: { buffer: outV } },
          ],
        }),
      },
    };
  });

  return { passes, prefixSum, workgroups };
}
|
|
439
|
+
|
|
440
|
+
/**
 * Records the whole radix sort into an already-open compute pass.
 *
 * For every entry in `state.passes` (least significant digit first), it records:
 * - the digit-count dispatch,
 * - the shared prefix sum over the block counters,
 * - the scatter dispatch.
 */
export function dispatchRadixSort(state: RadixSortState, pass: GPUComputePassEncoder): void {
  const { workgroups, prefixSum, passes } = state;
  const [gridX, gridY] = workgroups;

  passes.forEach(({ blockSum, reorder }) => {
    // 1. Count digit occurrences per workgroup.
    pass.setPipeline(blockSum.pipeline);
    pass.setBindGroup(0, blockSum.bindGroup);
    pass.dispatchWorkgroups(gridX, gridY, 1);

    // 2. Turn the per-(digit, workgroup) counts into scatter offsets.
    dispatchPrefixSum(prefixSum, pass);

    // 3. Stable scatter into the destination buffers.
    pass.setPipeline(reorder.pipeline);
    pass.setBindGroup(0, reorder.bindGroup);
    pass.dispatchWorkgroups(gridX, gridY, 1);
  });
}
|
|
454
|
+
|
|
455
|
+
/**
 * Wraps a radix sort of `config` as a compute-graph node.
 *
 * - `prepare` builds all buffers and pipelines via createRadixSort.
 * - `execute` records the full sort into its own compute pass on the context's
 *   encoder.
 *
 * @throws Error from `execute` if it runs before `prepare` has finished.
 */
export function createRadixSortNode(config: RadixSortConfig): ComputeNode {
  let sort: RadixSortState | null = null;

  return {
    id: "radix-sort",
    inputs: [],
    outputs: [],

    async prepare(device: GPUDevice) {
      sort = await createRadixSort(device, config);
    },

    execute(ctx: ExecutionContext) {
      // Check before opening the pass, so that a premature call cannot leave
      // an unterminated compute pass on the shared encoder.
      const state = sort;
      if (state === null) {
        throw new Error("radix-sort: execute() called before prepare() completed");
      }
      const pass = ctx.encoder.beginComputePass();
      dispatchRadixSort(state, pass);
      pass.end();
    },
  };
}
|