@jax-js/jax 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +15 -9
- package/dist/{backend-BY8wlLEl.js → backend-DaqL-MNz.js} +240 -21
- package/dist/{backend-CmaidnkQ.cjs → backend-DziQSaoQ.cjs} +264 -21
- package/dist/index.cjs +2407 -1132
- package/dist/index.d.cts +596 -97
- package/dist/index.d.ts +596 -97
- package/dist/index.js +2400 -1126
- package/dist/webgl-ClIYb8jP.cjs +522 -0
- package/dist/webgl-RSuZKvgc.js +522 -0
- package/dist/webgpu-Db2JrNBr.cjs +1261 -0
- package/dist/webgpu-Dh7k9io0.js +1261 -0
- package/package.json +1 -1
- package/dist/webgpu-BVns4DbI.cjs +0 -663
- package/dist/webgpu-C9iAP5h5.js +0 -663
|
@@ -0,0 +1,1261 @@
|
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-DaqL-MNz.js";
|
|
2
|
+
|
|
3
|
+
//#region src/backend/webgpu/builtins.ts
|
|
4
|
+
const threefrySrc = `
|
|
5
|
+
fn threefry2x32(key: vec2<u32>, ctr: vec2<u32>) -> vec2<u32> {
|
|
6
|
+
let ks0: u32 = key.x;
|
|
7
|
+
let ks1: u32 = key.y;
|
|
8
|
+
let ks2: u32 = ks0 ^ ks1 ^ 0x1BD11BDAu;
|
|
9
|
+
|
|
10
|
+
var x0: u32 = ctr.x + ks0;
|
|
11
|
+
var x1: u32 = ctr.y + ks1;
|
|
12
|
+
|
|
13
|
+
x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
|
|
14
|
+
x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
|
|
15
|
+
x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
|
|
16
|
+
x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
|
|
17
|
+
x0 += ks1;
|
|
18
|
+
x1 += ks2 + 1u;
|
|
19
|
+
|
|
20
|
+
x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
|
|
21
|
+
x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
|
|
22
|
+
x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
|
|
23
|
+
x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
|
|
24
|
+
x0 += ks2;
|
|
25
|
+
x1 += ks0 + 2u;
|
|
26
|
+
|
|
27
|
+
x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
|
|
28
|
+
x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
|
|
29
|
+
x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
|
|
30
|
+
x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
|
|
31
|
+
x0 += ks0;
|
|
32
|
+
x1 += ks1 + 3u;
|
|
33
|
+
|
|
34
|
+
x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
|
|
35
|
+
x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
|
|
36
|
+
x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
|
|
37
|
+
x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
|
|
38
|
+
x0 += ks1;
|
|
39
|
+
x1 += ks2 + 4u;
|
|
40
|
+
|
|
41
|
+
x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
|
|
42
|
+
x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
|
|
43
|
+
x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
|
|
44
|
+
x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
|
|
45
|
+
x0 += ks2;
|
|
46
|
+
x1 += ks0 + 5u;
|
|
47
|
+
|
|
48
|
+
return vec2<u32>(x0, x1);
|
|
49
|
+
}`;
|
|
50
|
+
const erfSrc = `
|
|
51
|
+
const _erf_p: f32 = 0.3275911;
|
|
52
|
+
const _erf_a1: f32 = 0.254829592;
|
|
53
|
+
const _erf_a2: f32 = -0.284496736;
|
|
54
|
+
const _erf_a3: f32 = 1.421413741;
|
|
55
|
+
const _erf_a4: f32 = -1.453152027;
|
|
56
|
+
const _erf_a5: f32 = 1.061405429;
|
|
57
|
+
fn erf(x: f32) -> f32 {
|
|
58
|
+
let t = 1.0 / (1.0 + _erf_p * abs(x));
|
|
59
|
+
let P_t = fma(fma(fma(fma(_erf_a5, t, _erf_a4), t, _erf_a3), t, _erf_a2), t, _erf_a1) * t;
|
|
60
|
+
return sign(x) * (1.0 - P_t * exp(-x * x));
|
|
61
|
+
}
|
|
62
|
+
fn erfc(x: f32) -> f32 {
|
|
63
|
+
let t = 1.0 / (1.0 + _erf_p * abs(x));
|
|
64
|
+
let P_t = fma(fma(fma(fma(_erf_a5, t, _erf_a4), t, _erf_a3), t, _erf_a2), t, _erf_a1) * t;
|
|
65
|
+
let E = P_t * exp(-x * x);
|
|
66
|
+
return select(2.0 - E, E, x >= 0.0);
|
|
67
|
+
}`;
|
|
68
|
+
|
|
69
|
+
//#endregion
|
|
70
|
+
//#region src/backend/webgpu/codegen.ts
|
|
71
|
+
const headerWgsl = String.raw`
|
|
72
|
+
fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }
|
|
73
|
+
fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }
|
|
74
|
+
`.trim();
|
|
75
|
+
function dtypeToWgsl(dtype, storage = false) {
|
|
76
|
+
switch (dtype) {
|
|
77
|
+
case DType.Bool: return storage ? "i32" : "bool";
|
|
78
|
+
case DType.Int32: return "i32";
|
|
79
|
+
case DType.Uint32: return "u32";
|
|
80
|
+
case DType.Float32: return "f32";
|
|
81
|
+
case DType.Float16: return "f16";
|
|
82
|
+
default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
function maxValueWgsl(dtype) {
|
|
86
|
+
switch (dtype) {
|
|
87
|
+
case DType.Bool: return "1";
|
|
88
|
+
case DType.Int32: return "2147483647";
|
|
89
|
+
case DType.Uint32: return "4294967295u";
|
|
90
|
+
case DType.Float32: return "inf()";
|
|
91
|
+
case DType.Float16: return "f16(inf())";
|
|
92
|
+
default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
function constToWgsl(dtype, value) {
|
|
96
|
+
if (dtype === DType.Bool) return value ? "true" : "false";
|
|
97
|
+
if (dtype === DType.Int32) return value.toString();
|
|
98
|
+
if (dtype === DType.Uint32) return value.toString() + "u";
|
|
99
|
+
if (dtype === DType.Float32) {
|
|
100
|
+
if (Number.isNaN(value)) return "nan()";
|
|
101
|
+
if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
|
|
102
|
+
return "f32(" + value.toString() + ")";
|
|
103
|
+
}
|
|
104
|
+
if (dtype === DType.Float16) {
|
|
105
|
+
if (Number.isNaN(value)) return "f16(nan())";
|
|
106
|
+
if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
|
|
107
|
+
return "f16(" + value.toString() + ")";
|
|
108
|
+
}
|
|
109
|
+
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
110
|
+
}
|
|
111
|
+
/**
 * Row stride, in workgroups, used when a dispatch is folded into two grid
 * dimensions. Shaders rebuild the linear workgroup index as
 * `wg_id.x + wg_id.y * gridOffsetY`.
 */
const gridOffsetY = 16384;
/**
 * Choose a 2D dispatch grid for `gridSize` linear workgroups.
 *
 * Counts up to 65535 (WebGPU's default `maxComputeWorkgroupsPerDimension`)
 * are dispatched as a single row. Larger counts use rows of `gridOffsetY`
 * workgroups, rounding the row count up. The grid may therefore contain more
 * workgroups than requested, and shaders must bounds-check their linear index.
 *
 * @param gridSize - number of workgroups needed.
 * @returns `[gridX, gridY]`.
 */
function calculateGrid(gridSize) {
	// Written as !(a > b) so a NaN size still takes the single-row branch.
	if (!(gridSize > 65535)) return [gridSize, 1];
	const rows = Math.ceil(gridSize / gridOffsetY);
	return [gridOffsetY, rows];
}
|
|
121
|
+
|
|
122
|
+
//#endregion
|
|
123
|
+
//#region src/backend/webgpu/reader.ts
|
|
124
|
+
/**
 * Graphics state used to synchronously read data from WebGPU buffers.
 *
 * This trick is borrowed from TensorFlow.js. WebGPU only exposes asynchronous
 * buffer mapping, so instead we create offscreen canvases with one pixel for
 * every 4 bytes ("device storage") and configure each with a WebGPU context.
 * The buffer is copied into the canvas texture. That canvas is then drawn onto
 * another offscreen canvas with a '2d' context ("host storage"), whose pixels
 * can be read synchronously with `getImageData()`.
 *
 * Canvases are 256x256 pixels (256 KiB per copy). Larger reads are split into
 * whole-canvas chunks, then a block of full rows, then one partial row.
 * Performance is poor because of the repeated copies, but it works without
 * awaiting. Every chunk is copied twice: once through an "opaque" canvas, from
 * which the RGB bytes are taken, and once through a "premultiplied" canvas,
 * from which only the alpha byte is taken.
 *
 * https://github.com/tensorflow/tfjs/blob/tfjs-v4.22.0/tfjs-backend-webgpu/src/backend_webgpu.ts#L379
 */
var SyncReader = class SyncReader {
	// One device canvas per alpha mode; index i of every per-mode array below
	// corresponds to alphaModes[i].
	static alphaModes = ["opaque", "premultiplied"];
	static width = 256;
	static height = 256;
	initialized = false;
	deviceStorage;
	deviceContexts;
	hostStorage;
	hostContext;
	/** @param device - GPUDevice whose buffers will be read. Canvases are created lazily on first read. */
	constructor(device) {
		this.device = device;
	}
	/** Create the per-alpha-mode WebGPU canvases and the shared 2D host canvas. */
	#init() {
		const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
		this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
		this.deviceContexts = this.deviceStorage.map((canvas, i) => {
			const context = canvas.getContext("webgpu");
			// COPY_DST so buffers can be copied straight into the canvas texture.
			context.configure({
				device: this.device,
				format: "bgra8unorm",
				usage: GPUTextureUsage.COPY_DST,
				alphaMode: SyncReader.alphaModes[i]
			});
			return context;
		});
		this.hostStorage = makeCanvas();
		// Hint to the browser that getImageData() will be called repeatedly.
		this.hostContext = this.hostStorage.getContext("2d", { willReadFrequently: true });
		this.initialized = true;
	}
	/**
	 * Synchronously read `count` bytes of `buffer` starting at byte `start`.
	 *
	 * @param buffer - GPUBuffer to read; copyBufferToTexture requires it to
	 *   have COPY_SRC usage.
	 * @param start - byte offset into `buffer`. NOTE(review): the texture copy
	 *   presumably needs this to be a multiple of 4 (one bgra8 texel); confirm
	 *   callers guarantee it.
	 * @param count - number of bytes to return.
	 * @returns a Uint8Array of length `count` over a freshly allocated buffer.
	 */
	read(buffer, start, count) {
		if (!this.initialized) this.#init();
		const deviceStorage = this.deviceStorage;
		const deviceContexts = this.deviceContexts;
		const hostContext = this.hostContext;
		// Work in whole 4-byte pixels; up to 3 bytes past start + count are read
		// and trimmed off at the end.
		const pixelsSize = Math.ceil(count / 4);
		// One full canvas row (1024 bytes, satisfying WebGPU's 256-byte
		// bytesPerRow alignment).
		const bytesPerRow = SyncReader.width * 4;
		const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
		for (let i = 0; i < deviceContexts.length; i++) {
			const texture = deviceContexts[i].getCurrentTexture();
			// Copy a width x height pixel region, starting at byte start + offset$1,
			// into the canvas texture. Then draw it to the host canvas, read it back,
			// and write the recovered bytes into valsGPU at offset$1.
			const readData = (width, height, offset$1) => {
				const encoder = this.device.createCommandEncoder();
				encoder.copyBufferToTexture({
					buffer,
					bytesPerRow,
					offset: offset$1 + start
				}, { texture }, {
					width,
					height,
					depthOrArrayLayers: 1
				});
				const commandBuffer = encoder.finish();
				this.device.queue.submit([commandBuffer]);
				hostContext.clearRect(0, 0, width, height);
				hostContext.drawImage(deviceStorage[i], 0, 0);
				const values = hostContext.getImageData(0, 0, width, height).data;
				const span = new Uint8ClampedArray(valsGPU, offset$1, 4 * width * height);
				const alphaMode = SyncReader.alphaModes[i];
				// The premultiplied pass contributes only byte 3 (alpha) of each pixel.
				// The opaque pass contributes bytes 0-2: the texture is bgra8unorm but
				// getImageData returns RGBA, so bytes 0 and 2 are swapped back.
				for (let k = 0; k < span.length; k += 4) if (alphaMode === "premultiplied") span[k + 3] = values[k + 3];
				else {
					span[k] = values[k + 2];
					span[k + 1] = values[k + 1];
					span[k + 2] = values[k];
				}
			};
			// Split the read into full canvases, then full rows, then one partial row.
			const pixelsPerCanvas = SyncReader.width * SyncReader.height;
			const wholeChunks = Math.floor(pixelsSize / pixelsPerCanvas);
			let remainder = pixelsSize % pixelsPerCanvas;
			const remainderRows = Math.floor(remainder / SyncReader.width);
			remainder = remainder % SyncReader.width;
			let offset = 0;
			for (let j = 0; j < wholeChunks; j++) {
				readData(SyncReader.width, SyncReader.height, offset);
				offset += pixelsPerCanvas * 4;
			}
			if (remainderRows > 0) {
				readData(SyncReader.width, remainderRows, offset);
				offset += remainderRows * SyncReader.width * 4;
			}
			if (remainder > 0) readData(remainder, 1, offset);
		}
		// Drop the pixel-rounding padding.
		return new Uint8Array(valsGPU, 0, count);
	}
};
|
|
225
|
+
|
|
226
|
+
//#endregion
|
|
227
|
+
//#region src/backend/webgpu/routines.ts
|
|
228
|
+
/**
 * Encode one bitonic-sort pass as the bytes of the shader's `Uniforms` struct.
 *
 * Layout is three u32 words in platform byte order:
 * `kind` (0 = "sort", anything else = merge), `merge_step`, `merge_stage`.
 * Missing (null/undefined) step or stage fields encode as 0.
 *
 * @param pass - pass descriptor `{ kind, mergeStep?, mergeStage? }`.
 * @returns a 12-byte Uint8Array view of the words.
 */
function bitonicSortUniform(pass) {
	const kindWord = pass.kind === "sort" ? 0 : 1;
	const words = Uint32Array.of(kindWord, pass.mergeStep ?? 0, pass.mergeStage ?? 0);
	return new Uint8Array(words.buffer);
}
|
|
235
|
+
/**
 * Generate a bitonic sort shader.
 *
 * We implement a variant of bitonic sort that [only has forward comparators](
 * <https://sortingalgos.miraheze.org/wiki/Bitonic_Sort#Bitonic_Sort_using_Forward_Comparators>),
 * so we don't need to allocate memory for power-of-two padding.
 *
 * This uses workgroup shared memory up to `2*workgroupSize` elements, for each
 * array in `batches`. For larger arrays, multiple passes are done:
 *
 * - Initial "sort" pass: each workgroup sorts its `2*workgroupSize` elements.
 * - Subsequent "merge" passes: each pass merges sorted sequences of size
 *   `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
 *
 * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
 *
 * Robustness:
 * - `calculateGrid()` rounds large dispatches up to whole rows of
 *   `gridOffsetY` workgroups. Workgroups whose batch index is past `batches`
 *   return immediately instead of reading and writing out of bounds.
 * - For `n <= 1` there are no compare stages (`numLocalStages === 0`). The
 *   emitted stage index is clamped to 0 so the WGSL never contains `-1u`,
 *   which is invalid because unary minus is not defined for u32.
 *
 * @param device - GPUDevice; its `limits.maxComputeWorkgroupSizeX` bounds the workgroup size.
 * @param dtype - element dtype of the array being sorted.
 * @param n - length of the sorted (last) axis.
 * @param batches - number of independent rows of length `n`.
 * @param outputIndices - when true, also produce an `i32` index output (argsort).
 * @returns a single shader descriptor with one "sort" pass followed by the merge passes.
 */
function bitonicSortShader(device, dtype, n, batches, outputIndices) {
	const ty = dtypeToWgsl(dtype, true);
	const paddedN = 1 << Math.ceil(Math.log2(n || 1));
	const numThreads = Math.ceil(paddedN / 2);
	const workgroupSize = findPow2(numThreads, device.limits.maxComputeWorkgroupSizeX);
	const workgroupsPerBatch = numThreads / workgroupSize;
	const numStages = Math.log2(paddedN);
	const numLocalStages = Math.min(numStages, Math.log2(workgroupSize * 2));
	// Last stage that fits in shared memory. Clamped at 0 because numLocalStages
	// is 0 when n <= 1; in that case the stage loop is empty and no merge passes
	// exist, so the clamped value is never acted on.
	const lastLocalStage = Math.max(numLocalStages - 1, 0);
	const needsF16 = dtype === DType.Float16;
	// Padding sorts last: NaN for floats (compare() orders NaN after everything), max value otherwise.
	const padValue = isFloatDtype(dtype) ? `${ty}(nan())` : maxValueWgsl(dtype);
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

struct Uniforms {
kind: u32, // 0 = sort, 1 = merge
merge_step: u32, // half_block = 2^step
merge_stage: u32, // only used for merge
}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> output: array<${ty}>;
${outputIndices ? `@group(0) @binding(2) var<storage, read_write> output_idx: array<i32>;` : ""}

@group(1) @binding(0) var<uniform> uniforms: Uniforms;

var<workgroup> shared_vals: array<${ty}, ${workgroupSize * 2}>;
${outputIndices ? `var<workgroup> shared_idx: array<i32, ${workgroupSize * 2}>;` : ""}

fn compare(a: ${ty}, b: ${ty}) -> bool {
${isFloatDtype(dtype) ? `
let min_value = min(a, b);
return a == min_value && b != min_value;` : " return a < b;"}
}

fn compare_and_swap(i: u32, j: u32) {
let val_i = shared_vals[i];
let val_j = shared_vals[j];
if (compare(val_j, val_i)) {
shared_vals[i] = val_j;
shared_vals[j] = val_i;
${outputIndices ? `
let tmp_idx = shared_idx[i];
shared_idx[i] = shared_idx[j];
shared_idx[j] = tmp_idx;` : ""}
}
}

@compute @workgroup_size(${workgroupSize})
fn main(
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
let blockid = wg_id.x + wg_id.y * ${gridOffsetY}u;
let batch = blockid / ${workgroupsPerBatch}u;
// The dispatch grid may be padded past the last batch; those workgroups
// must not touch memory. workgroup_id is uniform, so the barriers below
// stay in uniform control flow.
if (batch >= ${batches}u) {
return;
}
let wg_in_batch = blockid % ${workgroupsPerBatch}u;

let tid = local_id.x;
let base = batch * ${n}u;

if (uniforms.kind == 0u || (uniforms.kind == 1u && uniforms.merge_step == ${lastLocalStage}u)) {
let wg_base = wg_in_batch * ${workgroupSize * 2}u;

// Load data into shared memory (2 elements per thread)
let idx0 = tid * 2u;
let idx1 = tid * 2u + 1u;
// Load from input for initial 'sort' pass, then from output (read-write) for 'merge' passes.
if (uniforms.kind == 0u) {
shared_vals[idx0] = select(${padValue}, input[base + wg_base + idx0], wg_base + idx0 < ${n}u);
shared_vals[idx1] = select(${padValue}, input[base + wg_base + idx1], wg_base + idx1 < ${n}u);
${outputIndices ? `
shared_idx[idx0] = i32(wg_base + idx0);
shared_idx[idx1] = i32(wg_base + idx1);` : ""}
} else {
shared_vals[idx0] = select(${padValue}, output[base + wg_base + idx0], wg_base + idx0 < ${n}u);
shared_vals[idx1] = select(${padValue}, output[base + wg_base + idx1], wg_base + idx1 < ${n}u);
${outputIndices ? `
shared_idx[idx0] = select(${n}, output_idx[base + wg_base + idx0], wg_base + idx0 < ${n}u);
shared_idx[idx1] = select(${n}, output_idx[base + wg_base + idx1], wg_base + idx1 < ${n}u);` : ""}
}
workgroupBarrier();

let initial_stage = select(0u, ${lastLocalStage}u, uniforms.kind != 0u);
for (var stage = initial_stage; stage < ${numLocalStages}u; stage++) {
for (var step1 = stage + 1u; step1 > 0u; step1--) {
let step = step1 - 1u;
let half_block = 1u << step;
let is_first_step = uniforms.kind == 0u && step == stage;

let block_offset = (tid / half_block) * half_block;
let local_offset = tid % half_block;
let i = block_offset * 2u + local_offset;
let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);
compare_and_swap(i, j);

workgroupBarrier();
}
}

if (wg_base + idx0 < ${n}u) {
output[base + wg_base + idx0] = shared_vals[idx0];
${outputIndices ? `output_idx[base + wg_base + idx0] = shared_idx[idx0];` : ""}
}
if (wg_base + idx1 < ${n}u) {
output[base + wg_base + idx1] = shared_vals[idx1];
${outputIndices ? `output_idx[base + wg_base + idx1] = shared_idx[idx1];` : ""}
}
} else {
// Execute single merge pass for a step >= numLocalStages.
let half_block = 1u << uniforms.merge_step; // half_block >= workgroupSize * 2
let thread_in_batch = wg_in_batch * ${workgroupSize} + tid;
let is_first_step = uniforms.merge_step == uniforms.merge_stage;

let block_offset = (thread_in_batch / half_block) * half_block;
let local_offset = thread_in_batch % half_block;
let i = block_offset * 2u + local_offset;
let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);

// Global version of compare_and_swap()
if (j < ${n}u) {
let val_i = output[base + i];
let val_j = output[base + j];
if (compare(val_j, val_i)) {
output[base + i] = val_j;
output[base + j] = val_i;
${outputIndices ? `
let tmp_idx = output_idx[base + i];
output_idx[base + i] = output_idx[base + j];
output_idx[base + j] = tmp_idx;` : ""}
}
}
}
}
`.trim();
	const grid = calculateGrid(batches * workgroupsPerBatch);
	// One shared-memory "sort" pass, then for each stage that does not fit in
	// shared memory, global merge steps down to lastLocalStage (that final step
	// is finished in shared memory by the local branch of the shader).
	const passes = [{ kind: "sort" }];
	for (let mergeStage = numLocalStages; mergeStage < numStages; mergeStage++) {
		for (let mergeStep = mergeStage; mergeStep >= numLocalStages - 1; mergeStep--) {
			passes.push({
				kind: "merge",
				mergeStep,
				mergeStage
			});
		}
	}
	return [{
		code,
		numInputs: 1,
		numOutputs: outputIndices ? 2 : 1,
		hasUniform: true,
		passes: passes.map((pass) => ({
			grid,
			uniform: bitonicSortUniform(pass)
		}))
	}];
}
|
|
403
|
+
/**
 * Build the shader for `Routines.Sort`: sorts each row along the last axis of
 * the single input.
 *
 * @param device - GPUDevice forwarded to the bitonic sort generator.
 * @param type - routine type; only `inputDtypes[0]` and `inputShapes[0]` are read.
 * @returns shader descriptors from `bitonicSortShader`, without an index output.
 */
function createSort(device, type) {
	const elementDtype = type.inputDtypes[0];
	const inputShape = type.inputShapes[0];
	const leadingDims = inputShape.slice(0, -1);
	return bitonicSortShader(device, elementDtype, inputShape[inputShape.length - 1], prod(leadingDims), false);
}
|
|
410
|
+
/**
 * Build the shader for `Routines.Argsort`: sorts each row along the last axis
 * of the single input and also emits the original `i32` position of every
 * element.
 *
 * @param device - GPUDevice forwarded to the bitonic sort generator.
 * @param type - routine type; only `inputDtypes[0]` and `inputShapes[0]` are read.
 * @returns shader descriptors from `bitonicSortShader`, with an index output.
 */
function createArgsort(device, type) {
	const elementDtype = type.inputDtypes[0];
	const inputShape = type.inputShapes[0];
	const leadingDims = inputShape.slice(0, -1);
	return bitonicSortShader(device, elementDtype, inputShape[inputShape.length - 1], prod(leadingDims), true);
}
|
|
417
|
+
/**
 * Generate a triangular solve shader.
 *
 * Solves A @ X.T = B.T for X, where A is upper-triangular (only the upper
 * triangle of `a` is read). Lower-triangular systems are not handled here.
 * Uses a parallelized back-substitution:
 * 1. Copy b to x
 * 2. For j = n-1 down to 0:
 *    - Divide x[j] by a[j,j] (single thread); skipped when
 *      `params.unitDiagonal` is set
 *    - All threads subtract x[j] * a[i,j] from x[i] for i < j in parallel
 *
 * One workgroup handles one (matrix, right-hand side) pair. `a` is indexed as
 * (numMatrices, n, n) and `b`/`x` as (numMatrices, numRhs, n). Because `b` and
 * `x` share the index `bx_base`, b's leading dims must match a's; no
 * broadcasting is done here.
 *
 * NOTE(review): the template emits `${n - 1}u`, which becomes `-1u` (invalid
 * WGSL) if n === 0; confirm callers never pass empty matrices.
 *
 * @param device - GPUDevice; its `limits.maxComputeWorkgroupSizeX` bounds the workgroup size.
 * @param type - routine type; `inputDtypes[0]`, `inputShapes[0]` (a) and `inputShapes[1]` (b) are read.
 * @param params - routine params; `unitDiagonal` skips division by the diagonal.
 * @returns a single shader descriptor with 2 inputs, 1 output and one pass.
 */
function createTriangularSolve(device, type, params) {
	const dtype = type.inputDtypes[0];
	const aShape = type.inputShapes[0];
	const bShape = type.inputShapes[1];
	const n = aShape[aShape.length - 1];
	const numRhs = bShape[bShape.length - 2];
	const numMatrices = prod(aShape.slice(0, -2));
	const needsF16 = dtype === DType.Float16;
	const ty = dtypeToWgsl(dtype, true);
	const workgroupSize = findPow2(n, device.limits.maxComputeWorkgroupSizeX);
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> a: array<${ty}>;
@group(0) @binding(1) var<storage, read> b: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> x: array<${ty}>;

// Shared memory for the current pivot value x[j]
var<workgroup> x_j: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
let wg_idx = wg_id.x + wg_id.y * ${gridOffsetY}u;
let mat_idx = wg_idx / ${numRhs}u;
let rhs_idx = wg_idx % ${numRhs}u;

if (mat_idx >= ${numMatrices}u) {
return;
}

let a_base = mat_idx * ${n * n}u;
let bx_base = (mat_idx * ${numRhs}u + rhs_idx) * ${n}u;
let tid = local_id.x;

// Step 1: Copy b to x (threads collaborate)
for (var idx = tid; idx < ${n}u; idx += ${workgroupSize}u) {
x[bx_base + idx] = b[bx_base + idx];
}
storageBarrier();

// Step 2: Back-substitution from j = n-1 down to 0
for (var jj = 0u; jj < ${n}u; jj++) {
let j = ${n - 1}u - jj;

// Thread 0 computes x[j] = x[j] / a[j,j]
if (tid == 0u) {
${params.unitDiagonal ? `x_j = x[bx_base + j];` : `x_j = x[bx_base + j] / a[a_base + j * ${n}u + j];`}
x[bx_base + j] = x_j;
}
workgroupBarrier(); // Sync shared memory x_j

// All threads subtract x[j] * a[i,j] from x[i] for i < j
for (var i = tid; i < j; i += ${workgroupSize}u) {
x[bx_base + i] -= x_j * a[a_base + i * ${n}u + j];
}
workgroupBarrier();
storageBarrier();
}
}
`.trim();
	// One workgroup per (matrix, rhs) pair; the shader discards padded workgroups.
	const totalWorkgroups = numMatrices * numRhs;
	const grid = calculateGrid(totalWorkgroups);
	return [{
		code,
		numInputs: 2,
		numOutputs: 1,
		hasUniform: false,
		passes: [{ grid }]
	}];
}
|
|
501
|
+
/**
 * Generate a Cholesky decomposition shader.
 *
 * Computes the lower triangular matrix L such that A = L * L^T for each
 * matrix in the batch, using the Cholesky-Crout algorithm, which processes
 * the matrix column by column. Only the lower triangle of the input is read;
 * the upper triangle of the output is zeroed.
 *
 * The input must be positive definite: each column divides by L[j][j], so a
 * zero diagonal produces inf/NaN. A negative value under the square root gives
 * an indeterminate result (typically NaN). No error is raised either way.
 *
 * For each column j:
 * 1. All threads compute their row's sum in parallel and store to output
 * 2. Thread 0 computes L[j][j] = sqrt(output[j][j]) and stores to shared memory
 * 3. All threads divide their output[i][j] by L[j][j] in parallel
 *
 * @param device - GPUDevice; its `limits.maxComputeWorkgroupSizeX` bounds the workgroup size.
 * @param type - routine type; `inputDtypes[0]` and `inputShapes[0]` (..., n, n) are read.
 * @returns a single shader descriptor with 1 input, 1 output and one pass
 *   (one workgroup per matrix).
 */
function createCholesky(device, type) {
	const dtype = type.inputDtypes[0];
	const shape = type.inputShapes[0];
	const n = shape[shape.length - 1];
	const batches = prod(shape.slice(0, -2));
	const needsF16 = dtype === DType.Float16;
	const ty = dtypeToWgsl(dtype, true);
	const workgroupSize = findPow2(n, device.limits.maxComputeWorkgroupSizeX);
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> output: array<${ty}>;

// Shared memory for the diagonal element
var<workgroup> L_jj: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
if (batch >= ${batches}u) {
return;
}

let base = batch * ${n * n}u;
let tid = local_id.x;

// Zero out output and copy lower triangle from input (threads collaborate)
for (var idx = tid; idx < ${n * n}u; idx += ${workgroupSize}u) {
let row = idx / ${n}u;
let col = idx % ${n}u;
output[base + idx] = select(0, input[base + idx], col <= row);
}
storageBarrier();

// Cholesky-Crout algorithm: process column by column
for (var j = 0u; j < ${n}u; j++) {
// Step 1: All threads compute sum for their rows i >= j in parallel
// sum = A[i][j] - sum(L[i][k] * L[j][k] for k < j)
for (var i = j + tid; i < ${n}u; i += ${workgroupSize}u) {
var sum = output[base + i * ${n}u + j];
for (var k = 0u; k < j; k++) {
sum -= output[base + i * ${n}u + k] * output[base + j * ${n}u + k];
}
output[base + i * ${n}u + j] = sum;
}
storageBarrier();

// Step 2: Thread 0 computes L[j][j] = sqrt(output[j][j])
if (tid == 0u) {
L_jj = sqrt(output[base + j * ${n}u + j]);
output[base + j * ${n}u + j] = L_jj;
}
workgroupBarrier();

// Step 3: All threads divide output[i][j] by L[j][j] for i > j
for (var i = j + 1u + tid; i < ${n}u; i += ${workgroupSize}u) {
output[base + i * ${n}u + j] /= L_jj;
}
storageBarrier();
}
}
`.trim();
	// One workgroup per matrix; the shader discards padded workgroups.
	const grid = calculateGrid(batches);
	return [{
		code,
		numInputs: 1,
		numOutputs: 1,
		hasUniform: false,
		passes: [{ grid }]
	}];
}
|
|
589
|
+
/**
 * Generate an LU decomposition shader with partial pivoting.
 *
 * Computes PA = LU where P is a permutation matrix, L is lower triangular
 * with unit diagonal, and U is upper triangular. L (below the diagonal) and U
 * (on and above it) are packed into the single `lu` output.
 *
 * For each column j < min(m, n):
 * 1. Find pivot row (max absolute value in column j, rows >= j)
 * 2. Swap rows j and pivot row
 * 3. Compute L[i][j] = A[i][j] / A[j][j] for i > j
 * 4. Update submatrix: A[i][k] -= L[i][j] * A[j][k] for i > j, k > j
 *
 * Steps 3-4 are skipped when the pivot is exactly zero. Since the pivot has
 * the largest magnitude in the column, that column is then all zeros below
 * row j (barring NaNs), so there is nothing to eliminate. Dividing would
 * instead produce 0/0 = NaN and poison the whole trailing submatrix, e.g.
 * making the determinant of a singular matrix NaN instead of 0. This mirrors
 * LAPACK getf2, which leaves the column unscaled for a zero pivot.
 *
 * Outputs:
 * - `lu`: (..., m, n) packed factors.
 * - `pivots`: (..., min(m, n)) i32; entry j is the row swapped with row j at step j.
 * - `perm`: (..., m) i32; final row permutation, so row r of LU is row perm[r] of A.
 *
 * @param device - GPUDevice; its `limits.maxComputeWorkgroupSizeX` bounds the workgroup size.
 * @param type - routine type; `inputDtypes[0]` and `inputShapes[0]` (..., m, n) are read.
 * @returns a single shader descriptor with 1 input, 3 outputs and one pass
 *   (one workgroup per matrix).
 */
function createLU(device, type) {
	const dtype = type.inputDtypes[0];
	const shape = type.inputShapes[0];
	const m = shape[shape.length - 2];
	const n = shape[shape.length - 1];
	const r = Math.min(m, n);
	const batches = prod(shape.slice(0, -2));
	const needsF16 = dtype === DType.Float16;
	const ty = dtypeToWgsl(dtype, true);
	const workgroupSize = findPow2(Math.max(m, n), device.limits.maxComputeWorkgroupSizeX);
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> lu: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> pivots: array<i32>;
@group(0) @binding(3) var<storage, read_write> perm: array<i32>;

var<workgroup> pivot_row: u32;
var<workgroup> pivot_val: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
if (batch >= ${batches}u) {
return;
}

let lu_base = batch * ${m * n}u;
let piv_base = batch * ${r}u;
let perm_base = batch * ${m}u;
let tid = local_id.x;

// Copy input to lu
for (var idx = tid; idx < ${m * n}u; idx += ${workgroupSize}u) {
lu[lu_base + idx] = input[lu_base + idx];
}
// Initialize permutation
for (var idx = tid; idx < ${m}u; idx += ${workgroupSize}u) {
perm[perm_base + idx] = i32(idx);
}
storageBarrier();

// LU decomposition with partial pivoting
for (var j = 0u; j < ${r}u; j++) {
// Step 1: Thread 0 finds pivot (max abs value in column j, rows >= j)
if (tid == 0u) {
var max_val = abs(lu[lu_base + j * ${n}u + j]);
var max_row = j;
for (var i = j + 1u; i < ${m}u; i++) {
let val = abs(lu[lu_base + i * ${n}u + j]);
if (val > max_val) {
max_val = val;
max_row = i;
}
}
pivot_row = max_row;
pivot_val = lu[lu_base + max_row * ${n}u + j];
pivots[piv_base + j] = i32(max_row);
}
workgroupBarrier();

// Step 2: Swap rows j and pivot_row (threads collaborate)
let pr = pivot_row;
if (pr != j) {
for (var col = tid; col < ${n}u; col += ${workgroupSize}u) {
let tmp = lu[lu_base + j * ${n}u + col];
lu[lu_base + j * ${n}u + col] = lu[lu_base + pr * ${n}u + col];
lu[lu_base + pr * ${n}u + col] = tmp;
}
if (tid == 0u) {
let tmp_p = perm[perm_base + j];
perm[perm_base + j] = perm[perm_base + pr];
perm[perm_base + pr] = tmp_p;
}
}
storageBarrier();

// Step 3: Compute L[i][j] and update submatrix
// Each thread handles one row i > j. A zero pivot means the column below
// the diagonal is already zero; skip it instead of dividing 0/0 = NaN.
let pv = pivot_val;
if (pv != ${ty}(0)) {
for (var i = j + 1u + tid; i < ${m}u; i += ${workgroupSize}u) {
let factor = lu[lu_base + i * ${n}u + j] / pv;
lu[lu_base + i * ${n}u + j] = factor; // L[i][j]
for (var k = j + 1u; k < ${n}u; k++) {
lu[lu_base + i * ${n}u + k] -= factor * lu[lu_base + j * ${n}u + k];
}
}
}
storageBarrier();
}
}
`.trim();
	// One workgroup per matrix; the shader discards padded workgroups.
	const grid = calculateGrid(batches);
	return [{
		code,
		numInputs: 1,
		numOutputs: 3,
		hasUniform: false,
		passes: [{ grid }]
	}];
}
|
|
705
|
+
function createRoutineShader(device, routine) {
|
|
706
|
+
switch (routine.name) {
|
|
707
|
+
case Routines.Sort: return createSort(device, routine.type);
|
|
708
|
+
case Routines.Argsort: return createArgsort(device, routine.type);
|
|
709
|
+
case Routines.TriangularSolve: return createTriangularSolve(device, routine.type, routine.params);
|
|
710
|
+
case Routines.Cholesky: return createCholesky(device, routine.type);
|
|
711
|
+
case Routines.LU: return createLU(device, routine.type);
|
|
712
|
+
default: throw new UnsupportedRoutineError(routine.name, "webgpu");
|
|
713
|
+
}
|
|
714
|
+
}
|
|
715
|
+
|
|
716
|
+
//#endregion
|
|
717
|
+
//#region src/backend/webgpu.ts
|
|
718
|
+
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
719
|
+
var WebGPUBackend = class {
|
|
720
|
+
type = "webgpu";
|
|
721
|
+
maxArgs;
|
|
722
|
+
pipelines;
|
|
723
|
+
syncReader;
|
|
724
|
+
buffers;
|
|
725
|
+
nextSlot;
|
|
726
|
+
#cachedShaderMap = /* @__PURE__ */ new Map();
|
|
727
|
+
#reusableZsb;
|
|
728
|
+
constructor(device) {
|
|
729
|
+
this.device = device;
|
|
730
|
+
if (DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
|
|
731
|
+
this.maxArgs = this.device.limits.maxStorageBuffersPerShaderStage - 1;
|
|
732
|
+
this.pipelines = new ShaderPipelineCache(device);
|
|
733
|
+
this.syncReader = new SyncReader(device);
|
|
734
|
+
this.buffers = /* @__PURE__ */ new Map();
|
|
735
|
+
this.nextSlot = 1;
|
|
736
|
+
this.#reusableZsb = this.#createBuffer(4);
|
|
737
|
+
device.addEventListener("uncapturederror", (event) => {
|
|
738
|
+
console.error("Uncaptured error in WebGPU backend:", event.error.message);
|
|
739
|
+
});
|
|
740
|
+
}
|
|
741
|
+
malloc(size, initialData) {
|
|
742
|
+
let buffer;
|
|
743
|
+
const paddedSize = Math.ceil(size / 4) * 4;
|
|
744
|
+
if (size === 0) buffer = this.#reusableZsb;
|
|
745
|
+
else if (initialData) {
|
|
746
|
+
if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
747
|
+
if (initialData.byteLength < 4096) {
|
|
748
|
+
buffer = this.#createBuffer(paddedSize, { mapped: true });
|
|
749
|
+
new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
|
|
750
|
+
buffer.unmap();
|
|
751
|
+
} else {
|
|
752
|
+
buffer = this.#createBuffer(paddedSize);
|
|
753
|
+
if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
754
|
+
else {
|
|
755
|
+
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
756
|
+
this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
757
|
+
const remainder = new Uint8Array(4);
|
|
758
|
+
remainder.set(initialData.subarray(aligned));
|
|
759
|
+
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
760
|
+
}
|
|
761
|
+
}
|
|
762
|
+
} else buffer = this.#createBuffer(paddedSize);
|
|
763
|
+
const slot = this.nextSlot++;
|
|
764
|
+
this.buffers.set(slot, {
|
|
765
|
+
buffer,
|
|
766
|
+
size,
|
|
767
|
+
ref: 1
|
|
768
|
+
});
|
|
769
|
+
return slot;
|
|
770
|
+
}
|
|
771
|
+
incRef(slot) {
|
|
772
|
+
const buffer = this.buffers.get(slot);
|
|
773
|
+
if (!buffer) throw new SlotError(slot);
|
|
774
|
+
buffer.ref++;
|
|
775
|
+
}
|
|
776
|
+
decRef(slot) {
|
|
777
|
+
const buffer = this.buffers.get(slot);
|
|
778
|
+
if (!buffer) throw new SlotError(slot);
|
|
779
|
+
buffer.ref--;
|
|
780
|
+
if (buffer.ref === 0) {
|
|
781
|
+
this.buffers.delete(slot);
|
|
782
|
+
if (buffer.buffer !== this.#reusableZsb) buffer.buffer.destroy();
|
|
783
|
+
}
|
|
784
|
+
}
|
|
785
|
+
async read(slot, start, count) {
|
|
786
|
+
const { buffer, size } = this.#getBuffer(slot);
|
|
787
|
+
if (buffer === this.#reusableZsb) return new Uint8Array();
|
|
788
|
+
if (start === void 0) start = 0;
|
|
789
|
+
if (count === void 0) count = size - start;
|
|
790
|
+
const paddedSize = Math.ceil(count / 4) * 4;
|
|
791
|
+
const staging = this.#createBuffer(paddedSize, { read: true });
|
|
792
|
+
try {
|
|
793
|
+
const commandEncoder = this.device.createCommandEncoder();
|
|
794
|
+
commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
|
|
795
|
+
this.device.queue.submit([commandEncoder.finish()]);
|
|
796
|
+
await staging.mapAsync(GPUMapMode.READ);
|
|
797
|
+
const arrayBuffer = staging.getMappedRange();
|
|
798
|
+
return new Uint8Array(arrayBuffer.slice(), 0, count);
|
|
799
|
+
} finally {
|
|
800
|
+
staging.destroy();
|
|
801
|
+
}
|
|
802
|
+
}
|
|
803
|
+
readSync(slot, start, count) {
|
|
804
|
+
const { buffer, size } = this.#getBuffer(slot);
|
|
805
|
+
if (buffer === this.#reusableZsb) return new Uint8Array();
|
|
806
|
+
if (start === void 0) start = 0;
|
|
807
|
+
if (count === void 0) count = size - start;
|
|
808
|
+
return this.syncReader.read(buffer, start, count);
|
|
809
|
+
}
|
|
810
|
+
#cachedShader(kernel) {
|
|
811
|
+
const cacheKey = FpHash.hash(kernel);
|
|
812
|
+
let result = this.#cachedShaderMap.get(cacheKey);
|
|
813
|
+
if (!result) {
|
|
814
|
+
result = pipelineSource(this.device, kernel);
|
|
815
|
+
this.#cachedShaderMap.set(cacheKey, result);
|
|
816
|
+
}
|
|
817
|
+
return result;
|
|
818
|
+
}
|
|
819
|
+
async prepareKernel(kernel) {
|
|
820
|
+
const shader = this.#cachedShader(kernel);
|
|
821
|
+
const pipeline = await this.pipelines.prepare(shader);
|
|
822
|
+
return new Executable(kernel, [{
|
|
823
|
+
...shader,
|
|
824
|
+
pipeline
|
|
825
|
+
}]);
|
|
826
|
+
}
|
|
827
|
+
prepareKernelSync(kernel) {
|
|
828
|
+
const shader = this.#cachedShader(kernel);
|
|
829
|
+
const pipeline = this.pipelines.prepareSync(shader);
|
|
830
|
+
return new Executable(kernel, [{
|
|
831
|
+
...shader,
|
|
832
|
+
pipeline
|
|
833
|
+
}]);
|
|
834
|
+
}
|
|
835
|
+
async prepareRoutine(routine) {
|
|
836
|
+
const shaders = createRoutineShader(this.device, routine);
|
|
837
|
+
const dispatches = await Promise.all(shaders.map(async (shader) => {
|
|
838
|
+
const pipeline = await this.pipelines.prepare(shader);
|
|
839
|
+
return {
|
|
840
|
+
...shader,
|
|
841
|
+
pipeline
|
|
842
|
+
};
|
|
843
|
+
}));
|
|
844
|
+
return new Executable(routine, dispatches);
|
|
845
|
+
}
|
|
846
|
+
prepareRoutineSync(routine) {
|
|
847
|
+
const shaders = createRoutineShader(this.device, routine);
|
|
848
|
+
const dispatches = shaders.map((shader) => {
|
|
849
|
+
const pipeline = this.pipelines.prepareSync(shader);
|
|
850
|
+
return {
|
|
851
|
+
...shader,
|
|
852
|
+
pipeline
|
|
853
|
+
};
|
|
854
|
+
});
|
|
855
|
+
return new Executable(routine, dispatches);
|
|
856
|
+
}
|
|
857
|
+
dispatch(exe, inputs, outputs) {
|
|
858
|
+
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
859
|
+
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
860
|
+
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
861
|
+
}
|
|
862
|
+
#getBuffer(slot) {
|
|
863
|
+
const buffer = this.buffers.get(slot);
|
|
864
|
+
if (!buffer) throw new SlotError(slot);
|
|
865
|
+
return {
|
|
866
|
+
buffer: buffer.buffer,
|
|
867
|
+
size: buffer.size
|
|
868
|
+
};
|
|
869
|
+
}
|
|
870
|
+
/**
|
|
871
|
+
* Create a GPU buffer.
|
|
872
|
+
*
|
|
873
|
+
* By default, this creates a general-purpose buffer with the given size.
|
|
874
|
+
*
|
|
875
|
+
* - If `mapped` is true, initialize the buffer in mapped mode so that it can
|
|
876
|
+
* be populated with data from the CPU. (Call `.unmap()` later.)
|
|
877
|
+
* - If `read` is true, create a staging buffer for returning data to CPU.
|
|
878
|
+
* (Call `.mapAsync()` later.)
|
|
879
|
+
*/
|
|
880
|
+
#createBuffer(size, { mapped = false, read = false } = {}) {
|
|
881
|
+
if (read && mapped) throw new Error("mapped and read cannot both be true");
|
|
882
|
+
const buffer = this.device.createBuffer({
|
|
883
|
+
size,
|
|
884
|
+
usage: read ? GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST : GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
|
|
885
|
+
mappedAtCreation: mapped
|
|
886
|
+
});
|
|
887
|
+
return buffer;
|
|
888
|
+
}
|
|
889
|
+
};
|
|
890
|
+
/**
 * Compiles a kernel expression into WebGPU (WGSL) shader source code.
 *
 * The generated shader binds the kernel's `nargs` inputs as read-only storage
 * arrays `in0..in{nargs-1}` (bindings 0..nargs-1) and the output as the
 * read-write array `result` (binding `nargs`). Each invocation computes one
 * thread of `tune.threadCount`. Without a reduction it writes
 * `result[gidx]`; with one, it loops over `ridx`, keeps `upcast`
 * accumulators, and writes each through `tune.epilogue` at the index given by
 * `tune.outputIdxExp`.
 *
 * Returns `{ code, numInputs: nargs, numOutputs: 1, hasUniform: false,
 * passes: [{ grid: [gridX, gridY] }] }`, where the grid is the number of
 * workgroups to dispatch along the x and y axes.
 *
 * NOTE: `gen` appends `let` statements to the shader as a side effect while
 * it builds expression strings, so the order of `gen` calls is part of the
 * output.
 */
function pipelineSource(device, kernel) {
  const tune = tuneWebgpu(kernel);
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
  const { nargs, reduction: re } = kernel;
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
  const shader = [];
  let indent = "";
  // Sentinel values accepted by `emit` to change the current indentation.
  const pushIndent = Symbol("pushIndent");
  const popIndent = Symbol("popIndent");
  // Append lines to the shader at the current indentation (empty lines unindented).
  const emit = (...lines) => {
    for (const line of lines) if (line === pushIndent) indent += " ";
    else if (line === popIndent) indent = indent.slice(0, -2);
    else shader.push(line ? indent + line : line);
  };
  // f16 arithmetic anywhere in the kernel requires the optional device feature.
  if (tune.exp.some((exp) => exp.dtype === DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === DType.Float16)) {
    if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
    emit("enable f16;");
  }
  emit(headerWgsl);
  // Only include helper functions that the kernel actually uses.
  const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
  if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
  if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
  emit("");
  // Element dtype of each input array, taken from the GlobalIndex reads of it;
  // inputs that are never read fall back to Float32 below.
  const usedArgs = Array.from({ length: nargs }, () => null);
  tune.exp.fold((exp) => {
    if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
  });
  tune.epilogue?.fold((exp) => {
    if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
  });
  for (let i = 0; i < nargs; i++) {
    const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
    emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
  }
  const resultTy = dtypeToWgsl(kernel.dtype, true);
  emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
  // One thread per output work item; the workgroup count may be split across
  // x and y by calculateGrid.
  const workgroupSize = findPow2(tune.threadCount, 256);
  const gridSize = Math.ceil(tune.threadCount / workgroupSize);
  const [gridX, gridY] = calculateGrid(gridSize);
  emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", pushIndent);
  // Flatten the invocation id into `gidx` and drop threads past threadCount.
  if (gridY === 1) emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
  else {
    const sizeX = gridX * workgroupSize;
    emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
  }
  // Fresh local names alu0, alu1, ... for hoisted subexpressions.
  let gensymCount = 0;
  const gensym = () => `alu${gensymCount++}`;
  const isGensym = (text) => text.match(/^alu[0-9]+$/);
  // Reference every input binding in the entry point; presumably so that each
  // binding counts as used and matches the explicit pipeline layout — confirm.
  if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
  // Count how many times each expression node is referenced in the DAG; nodes
  // referenced more than once are bound to a `let` by `gen` instead of inlined.
  const references = /* @__PURE__ */ new Map();
  const seen = /* @__PURE__ */ new Set();
  const countReferences = (exp) => {
    references.set(exp, (references.get(exp) ?? 0) + 1);
    if (!seen.has(exp)) {
      seen.add(exp);
      for (const src of exp.src) countReferences(src);
    }
  };
  // Memoized WGSL text (inline expression or `let` name) for each node.
  const expContext = /* @__PURE__ */ new Map();
  // Translate an expression node into a WGSL expression string, emitting any
  // `let` statements it needs first.
  const gen = (exp) => {
    if (expContext.has(exp)) return expContext.get(exp);
    const { op, src, dtype, arg } = exp;
    let source = "";
    if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
      const a = gen(src[0]);
      const b = gen(src[1]);
      // On booleans, Add/Max map to logical OR and Mul/Min to logical AND.
      if (op === AluOp.Add) if (dtype === DType.Bool) source = `(${a} || ${b})`;
      else source = `(${a} + ${b})`;
      else if (op === AluOp.Sub) source = `(${a} - ${b})`;
      else if (op === AluOp.Mul) if (dtype === DType.Bool) source = `(${a} && ${b})`;
      else source = `(${a} * ${b})`;
      else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
      else if (op === AluOp.Mod) source = `(${a} % ${b})`;
      else if (op === AluOp.Min) if (dtype === DType.Bool) source = `(${a} && ${b})`;
      else source = `min(${strip1(a)}, ${strip1(b)})`;
      else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
      else source = `max(${strip1(a)}, ${strip1(b)})`;
      else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
      else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
        // Float `!=` gets an extra `min(x, inf()) != x` term; presumably a NaN
        // check that survives compilers assuming no NaNs — confirm intent.
        // `a` is bound to a name first since it appears three times.
        const x = isGensym(a) ? a : gensym();
        if (x !== a) emit(`let ${x} = ${a};`);
        source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
      } else source = `(${a} != ${b})`;
    } else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
      // Fuse 1 / sqrt(x) into WGSL's inverseSqrt.
      const a = gen(src[0].src[0]);
      source = `inverseSqrt(${a})`;
    } else {
      const a = gen(src[0]);
      if (op === AluOp.Sin) source = `sin(${strip1(a)})`;
      else if (op === AluOp.Cos) source = `cos(${strip1(a)})`;
      else if (op === AluOp.Asin) source = `asin(${strip1(a)})`;
      else if (op === AluOp.Atan) source = `atan(${strip1(a)})`;
      else if (op === AluOp.Exp) source = `exp(${strip1(a)})`;
      else if (op === AluOp.Log) source = `log(${strip1(a)})`;
      else if (op === AluOp.Erf || op === AluOp.Erfc) {
        // erf/erfc helpers are called on f32; other dtypes are cast in and out.
        const funcName = op === AluOp.Erf ? "erf" : "erfc";
        if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${strip1(a)})))`;
        else source = `${funcName}(${strip1(a)})`;
      } else if (op === AluOp.Sqrt) source = `sqrt(${strip1(a)})`;
      else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
      else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
      else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
      else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
      else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
    }
    // WGSL select(falseValue, trueValue, condition).
    else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
    else if (op === AluOp.Threefry2x32) {
      // Call the threefry helper once; `arg` picks word 0, word 1, or their xor.
      const x = gensym();
      const [k0, k1, c0, c1] = src.map((x$1) => strip1(gen(x$1)));
      emit(`let ${x} = threefry2x32(vec2(${k0}, ${k1}), vec2(${c0}, ${c1}));`);
      if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
      else if (arg === 0) source = `${x}.x`;
      else if (arg === 1) source = `${x}.y`;
      else throw new UnsupportedOpError(op, dtype, "webgpu", arg);
    // Leaves are returned directly and never memoized or hoisted.
    } else if (op === AluOp.Const) return constToWgsl(dtype, arg);
    else if (op === AluOp.Special) return arg[0];
    else if (op === AluOp.Variable) return arg;
    else if (op === AluOp.GlobalIndex) {
      source = `${args[arg[0]]}[${strip1(gen(src[0]))}]`;
      // Bool inputs are stored as a non-bool WGSL type; compare against 0.
      if (dtype === DType.Bool) source = `(${source} != 0)`;
    }
    if (!source) throw new UnsupportedOpError(op, dtype, "webgpu", arg);
    const typeName = dtypeToWgsl(dtype);
    if ((references.get(exp) ?? 0) > 1) {
      // Shared subexpression: bind it once and refer to it by name.
      const name = gensym();
      expContext.set(exp, name);
      emit(`let ${name}: ${typeName} = ${strip1(source)};`);
      return name;
    } else {
      expContext.set(exp, source);
      return source;
    }
  };
  if (!re) {
    // Elementwise kernel: one output element per thread at index gidx.
    countReferences(tune.exp);
    let rhs = strip1(gen(tune.exp));
    if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
    emit(`result[gidx] = ${rhs};`);
  } else {
    if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
    const unroll = tune.size.unroll ?? 1;
    const upcast = tune.size.upcast ?? 1;
    // One accumulator per upcast lane, initialized with the reduction identity.
    const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
    for (let i = 0; i < upcast; i++) emit(`var ${acc[i]}: ${dtypeToWgsl(re.dtype)} = ${constToWgsl(re.dtype, re.identity)};`);
    emit(`for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`, pushIndent);
    // Specialize the body for every (upcast, unroll) pair; `cache` is shared by
    // all simplify() calls in this function.
    const exps = [];
    const cache = /* @__PURE__ */ new Map();
    for (let up = 0; up < upcast; up++) {
      exps.push([]);
      for (let un = 0; un < unroll; un++) {
        const exp = tune.exp.substitute({
          upcast: AluExp.i32(up),
          unroll: AluExp.i32(un)
        });
        exps[up].push(exp.simplify(cache));
        countReferences(exps[up][un]);
      }
    }
    const items = exps.map((ar) => ar.map(gen).map(strip1));
    // Combine the unrolled terms of each lane, then fold them into its accumulator.
    for (let i = 0; i < upcast; i++) {
      let rhs = items[i][0];
      for (let j = 1; j < unroll; j++) if (re.op === AluOp.Add) rhs = `${rhs} + ${items[i][j]}`;
      else if (re.op === AluOp.Mul) rhs = `${rhs} * ${items[i][j]}`;
      else if (re.op === AluOp.Min) rhs = re.dtype === DType.Bool ? `(${rhs} && ${items[i][j]})` : `min(${rhs}, ${items[i][j]})`;
      else if (re.op === AluOp.Max) rhs = re.dtype === DType.Bool ? `(${rhs} || ${items[i][j]})` : `max(${rhs}, ${items[i][j]})`;
      else throw new Error(`Unsupported reduction op: ${re.op}`);
      if (re.op === AluOp.Add) emit(`${acc[i]} += ${rhs};`);
      else if (re.op === AluOp.Mul) emit(`${acc[i]} *= ${rhs};`);
      else if (re.op === AluOp.Min) if (re.dtype === DType.Bool) emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
      else emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
      else if (re.op === AluOp.Max) if (re.dtype === DType.Bool) emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
      else emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
      else throw new Error(`Unsupported reduction op: ${re.op}`);
    }
    emit(popIndent, "}");
    // `let` names emitted inside the loop body are out of scope after the
    // closing brace, so restart memoization and reference counting.
    expContext.clear();
    references.clear();
    seen.clear();
    // Output index and epilogue per lane, with `acc` bound to its accumulator.
    const outputIdxExps = [];
    const fusionExps = [];
    for (let i = 0; i < upcast; i++) {
      const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
      outputIdxExps.push(exp.simplify(cache));
      countReferences(outputIdxExps[i]);
      fusionExps.push(tune.epilogue.substitute({
        acc: AluExp.variable(re.dtype, acc[i]),
        upcast: AluExp.i32(i)
      }).simplify(cache));
      countReferences(fusionExps[i]);
    }
    for (let i = 0; i < upcast; i++) {
      const index = strip1(gen(outputIdxExps[i]));
      let rhs = strip1(gen(fusionExps[i]));
      if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
      emit(`result[${index}] = ${rhs};`);
    }
  }
  emit(popIndent, "}");
  return {
    code: shader.join("\n"),
    numInputs: nargs,
    numOutputs: 1,
    hasUniform: false,
    passes: [{ grid: [gridX, gridY] }]
  };
}
|
|
1102
|
+
function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
1103
|
+
const commandEncoder = device.createCommandEncoder();
|
|
1104
|
+
for (const { pipeline,...shader } of pipelines) {
|
|
1105
|
+
if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
|
|
1106
|
+
const filteredPasses = shader.passes.filter(({ grid }) => prod(grid) > 0);
|
|
1107
|
+
if (filteredPasses.length === 0) continue;
|
|
1108
|
+
const bindGroup = device.createBindGroup({
|
|
1109
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
1110
|
+
entries: [...inputs.map((buffer, i) => ({
|
|
1111
|
+
binding: i,
|
|
1112
|
+
resource: { buffer }
|
|
1113
|
+
})), ...outputs.map((buffer, i) => ({
|
|
1114
|
+
binding: inputs.length + i,
|
|
1115
|
+
resource: { buffer }
|
|
1116
|
+
}))]
|
|
1117
|
+
});
|
|
1118
|
+
let uniformBindGroup = null;
|
|
1119
|
+
let uniformAlignment = 0;
|
|
1120
|
+
if (shader.hasUniform) {
|
|
1121
|
+
const uniforms = filteredPasses.map(({ uniform }) => uniform);
|
|
1122
|
+
const [uniformBuffer, alignment] = combineUniforms(device, uniforms);
|
|
1123
|
+
uniformAlignment = alignment;
|
|
1124
|
+
uniformBindGroup = device.createBindGroup({
|
|
1125
|
+
layout: pipeline.getBindGroupLayout(1),
|
|
1126
|
+
entries: [{
|
|
1127
|
+
binding: 0,
|
|
1128
|
+
resource: {
|
|
1129
|
+
buffer: uniformBuffer,
|
|
1130
|
+
size: alignment
|
|
1131
|
+
}
|
|
1132
|
+
}]
|
|
1133
|
+
});
|
|
1134
|
+
}
|
|
1135
|
+
for (let i = 0; i < filteredPasses.length; i++) {
|
|
1136
|
+
const { grid } = filteredPasses[i];
|
|
1137
|
+
const passEncoder = commandEncoder.beginComputePass();
|
|
1138
|
+
passEncoder.setPipeline(pipeline);
|
|
1139
|
+
passEncoder.setBindGroup(0, bindGroup);
|
|
1140
|
+
if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
|
|
1141
|
+
passEncoder.dispatchWorkgroups(grid[0], grid[1]);
|
|
1142
|
+
passEncoder.end();
|
|
1143
|
+
}
|
|
1144
|
+
}
|
|
1145
|
+
device.queue.submit([commandEncoder.finish()]);
|
|
1146
|
+
}
|
|
1147
|
+
function combineUniforms(device, uniforms) {
|
|
1148
|
+
for (const buf of uniforms) if (!buf || buf.byteLength === 0 || buf.byteLength !== uniforms[0].byteLength) throw new Error("webgpu: Uniform mismatch between shader passes");
|
|
1149
|
+
const minAlign = device.limits.minUniformBufferOffsetAlignment;
|
|
1150
|
+
const alignment = Math.ceil(uniforms[0].byteLength / minAlign) * minAlign;
|
|
1151
|
+
const buffer = device.createBuffer({
|
|
1152
|
+
size: alignment * uniforms.length,
|
|
1153
|
+
usage: GPUBufferUsage.UNIFORM,
|
|
1154
|
+
mappedAtCreation: true
|
|
1155
|
+
});
|
|
1156
|
+
const bufferMapped = new Uint8Array(buffer.getMappedRange());
|
|
1157
|
+
for (let i = 0; i < uniforms.length; i++) bufferMapped.set(uniforms[i], i * alignment);
|
|
1158
|
+
buffer.unmap();
|
|
1159
|
+
return [buffer, alignment];
|
|
1160
|
+
}
|
|
1161
|
+
/**
|
|
1162
|
+
* A cache for compiled GPU compute pipelines, keyed by the shader source.
|
|
1163
|
+
*
|
|
1164
|
+
* This supports both async compilation (recommended) and a synchronous variant.
|
|
1165
|
+
* If the pipeline is not in the cache, it will be compiled and added. For async
|
|
1166
|
+
* compilation, only one compilation will be in progress at a time for a given
|
|
1167
|
+
* shader source.
|
|
1168
|
+
*/
|
|
1169
|
+
var ShaderPipelineCache = class {
|
|
1170
|
+
cache;
|
|
1171
|
+
inProgress;
|
|
1172
|
+
constructor(device) {
|
|
1173
|
+
this.device = device;
|
|
1174
|
+
this.cache = /* @__PURE__ */ new Map();
|
|
1175
|
+
this.inProgress = /* @__PURE__ */ new Map();
|
|
1176
|
+
}
|
|
1177
|
+
#getLayout(shader) {
|
|
1178
|
+
if (shader.numInputs + shader.numOutputs > this.device.limits.maxStorageBuffersPerShaderStage) {
|
|
1179
|
+
const actual = shader.numInputs + shader.numOutputs;
|
|
1180
|
+
const max = this.device.limits.maxStorageBuffersPerShaderStage;
|
|
1181
|
+
throw new Error(`Too many buffers (${actual}) for WebGPU pipeline (max: ${max})`);
|
|
1182
|
+
}
|
|
1183
|
+
const bindGroupLayouts = [this.device.createBindGroupLayout({ entries: range(shader.numInputs + shader.numOutputs).map((i) => ({
|
|
1184
|
+
binding: i,
|
|
1185
|
+
visibility: GPUShaderStage.COMPUTE,
|
|
1186
|
+
buffer: { type: i < shader.numInputs ? "read-only-storage" : "storage" }
|
|
1187
|
+
})) })];
|
|
1188
|
+
if (shader.hasUniform) bindGroupLayouts.push(this.device.createBindGroupLayout({ entries: [{
|
|
1189
|
+
binding: 0,
|
|
1190
|
+
visibility: GPUShaderStage.COMPUTE,
|
|
1191
|
+
buffer: {
|
|
1192
|
+
type: "uniform",
|
|
1193
|
+
hasDynamicOffset: true
|
|
1194
|
+
}
|
|
1195
|
+
}] }));
|
|
1196
|
+
return this.device.createPipelineLayout({ bindGroupLayouts });
|
|
1197
|
+
}
|
|
1198
|
+
async prepare(shader) {
|
|
1199
|
+
const existingPipeline = this.cache.get(shader.code);
|
|
1200
|
+
if (existingPipeline) return existingPipeline;
|
|
1201
|
+
const existingPromise = this.inProgress.get(shader.code);
|
|
1202
|
+
if (existingPromise) return await existingPromise;
|
|
1203
|
+
if (DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
|
|
1204
|
+
const shaderModule = this.device.createShaderModule({ code: shader.code });
|
|
1205
|
+
const promise = (async () => {
|
|
1206
|
+
this.device.pushErrorScope("validation");
|
|
1207
|
+
try {
|
|
1208
|
+
const pipeline$1 = await this.device.createComputePipelineAsync({
|
|
1209
|
+
layout: this.#getLayout(shader),
|
|
1210
|
+
compute: {
|
|
1211
|
+
module: shaderModule,
|
|
1212
|
+
entryPoint: "main"
|
|
1213
|
+
}
|
|
1214
|
+
});
|
|
1215
|
+
await this.device.popErrorScope();
|
|
1216
|
+
return pipeline$1;
|
|
1217
|
+
} catch (_error) {
|
|
1218
|
+
const scope = await this.device.popErrorScope();
|
|
1219
|
+
const emsg = await compileError(shaderModule, scope, shader.code);
|
|
1220
|
+
throw new Error(emsg);
|
|
1221
|
+
}
|
|
1222
|
+
})();
|
|
1223
|
+
this.inProgress.set(shader.code, promise);
|
|
1224
|
+
const pipeline = await promise;
|
|
1225
|
+
this.cache.set(shader.code, pipeline);
|
|
1226
|
+
return pipeline;
|
|
1227
|
+
}
|
|
1228
|
+
prepareSync(shader) {
|
|
1229
|
+
const existingPipeline = this.cache.get(shader.code);
|
|
1230
|
+
if (existingPipeline) return existingPipeline;
|
|
1231
|
+
if (DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
|
|
1232
|
+
const shaderModule = this.device.createShaderModule({ code: shader.code });
|
|
1233
|
+
this.device.pushErrorScope("validation");
|
|
1234
|
+
const pipeline = this.device.createComputePipeline({
|
|
1235
|
+
layout: this.#getLayout(shader),
|
|
1236
|
+
compute: {
|
|
1237
|
+
module: shaderModule,
|
|
1238
|
+
entryPoint: "main"
|
|
1239
|
+
}
|
|
1240
|
+
});
|
|
1241
|
+
this.device.popErrorScope().then(async (scope) => {
|
|
1242
|
+
if (scope !== null) {
|
|
1243
|
+
const emsg = await compileError(shaderModule, scope, shader.code);
|
|
1244
|
+
console.error(emsg);
|
|
1245
|
+
}
|
|
1246
|
+
});
|
|
1247
|
+
this.cache.set(shader.code, pipeline);
|
|
1248
|
+
return pipeline;
|
|
1249
|
+
}
|
|
1250
|
+
};
|
|
1251
|
+
/**
 * Build a readable message for a shader compilation failure.
 *
 * The message starts with the error-scope text (or "(no error scope)"), then
 * lists each compilation-info entry as `[type at line:pos] text`, and ends
 * with the full shader source when `code` is non-empty.
 */
async function compileError(shaderModule, scope, code) {
  const headline = scope ? scope.message : "(no error scope)";
  const { messages } = await shaderModule.getCompilationInfo();
  const details = Array.from(messages, (m) => `\n [${m.type} at ${m.lineNum}:${m.linePos}] ${m.message}`);
  const sourceDump = code ? `\n\n${code}` : "";
  return `Failed to compile shader: ${headline}` + details.join("") + sourceDump;
}
|
|
1259
|
+
|
|
1260
|
+
//#endregion
|
|
1261
|
+
export { WebGPUBackend };
|