@jax-js/jax 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +16 -34
- package/dist/{backend-DeVfWEFS.cjs → backend-Bu9GY6sK.cjs} +222 -36
- package/dist/{backend-BqymqzuU.js → backend-tngXtWe4.js} +204 -36
- package/dist/index.cjs +1798 -955
- package/dist/index.d.cts +383 -97
- package/dist/index.d.ts +383 -97
- package/dist/index.js +1791 -949
- package/dist/{webgpu-BGuG58KZ.js → webgpu-ChVgx3b6.js} +410 -97
- package/dist/{webgpu-CcGP160M.cjs → webgpu-Oj3Kd-kd.cjs} +410 -97
- package/package.json +1 -1
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-Bu9GY6sK.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -66,6 +66,59 @@ fn erfc(x: f32) -> f32 {
|
|
|
66
66
|
return select(2.0 - E, E, x >= 0.0);
|
|
67
67
|
}`;
|
|
68
68
|
|
|
69
|
+
//#endregion
|
|
70
|
+
//#region src/backend/webgpu/codegen.ts
|
|
71
|
+
/**
 * WGSL helper functions emitted ahead of generated shader code. They build
 * f32 NaN and +Infinity by bit-casting fixed u32 bit patterns.
 */
const headerWgsl = [
  "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }",
  "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }"
].join("\n");
|
|
75
|
+
/**
 * Translate a backend `DType` into the name of the matching WGSL scalar type.
 *
 * @param dtype One of the `require_backend.DType` values.
 * @param storage When true, return the type used for storage-buffer elements.
 *   This only affects `Bool`, which becomes `i32` instead of `bool`.
 * @returns The WGSL type name.
 * @throws Error when `dtype` has no WGSL mapping.
 */
function dtypeToWgsl(dtype, storage = false) {
  const { DType } = require_backend;
  if (dtype === DType.Bool) return storage ? "i32" : "bool";
  if (dtype === DType.Int32) return "i32";
  if (dtype === DType.Uint32) return "u32";
  if (dtype === DType.Float32) return "f32";
  if (dtype === DType.Float16) return "f16";
  throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
}
|
|
85
|
+
/**
 * Return a WGSL expression for the largest value of `dtype`.
 *
 * Integer types yield a numeric literal, float types yield positive infinity
 * (through the `inf()` helper), and `Bool` yields `"1"`.
 *
 * @param dtype One of the `require_backend.DType` values.
 * @returns WGSL source text for the maximum value.
 * @throws Error when `dtype` is not supported.
 */
function maxValueWgsl(dtype) {
  const { DType } = require_backend;
  if (dtype === DType.Bool) return "1";
  if (dtype === DType.Int32) return "2147483647";
  if (dtype === DType.Uint32) return "4294967295u";
  if (dtype === DType.Float32) return "inf()";
  if (dtype === DType.Float16) return "f16(inf())";
  throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
}
|
|
95
|
+
/**
 * Render a JavaScript constant as WGSL source text of the given dtype.
 *
 * Non-finite floats are emitted through the `nan()` / `inf()` helpers instead
 * of numeric literals; Float16 values are additionally wrapped in `f16(...)`.
 *
 * @param dtype One of the `require_backend.DType` values.
 * @param value The constant to render.
 * @returns WGSL source text for the constant.
 * @throws Error when `dtype` is not supported.
 */
function constToWgsl(dtype, value) {
  const { DType } = require_backend;
  // Helper-call spelling for NaN / +-Infinity, or null for finite values.
  const nonFinite = () => {
    if (Number.isNaN(value)) return "nan()";
    if (Number.isFinite(value)) return null;
    return value > 0 ? "inf()" : "-inf()";
  };
  switch (dtype) {
    case DType.Bool:
      return value ? "true" : "false";
    case DType.Int32:
      return value.toString();
    case DType.Uint32:
      return `${value.toString()}u`;
    case DType.Float32:
      return nonFinite() ?? `f32(${value.toString()})`;
    case DType.Float16:
      return `f16(${nonFinite() ?? value.toString()})`;
    default:
      throw new Error(`Unsupported const dtype: ${dtype}`);
  }
}
|
|
111
|
+
/** X extent used once a dispatch has to be folded into two dimensions. */
const gridOffsetY = 16384;
/**
 * Turn a linear workgroup count into `[gridX, gridY]` dispatch dimensions.
 *
 * Counts up to 65535 are dispatched as a single row. Larger counts use rows of
 * `gridOffsetY` workgroups, with enough rows to cover `gridSize`; the product
 * `gridX * gridY` can therefore exceed `gridSize`.
 * NOTE(review): 65535 looks like the default WebGPU
 * `maxComputeWorkgroupsPerDimension` limit — confirm.
 *
 * @param gridSize Total number of workgroups required.
 * @returns A two-element array `[gridX, gridY]`.
 */
function calculateGrid(gridSize) {
  if (!(gridSize > 65535)) return [gridSize, 1];
  return [gridOffsetY, Math.ceil(gridSize / gridOffsetY)];
}
|
|
121
|
+
|
|
69
122
|
//#endregion
|
|
70
123
|
//#region src/backend/webgpu/reader.ts
|
|
71
124
|
/**
|
|
@@ -170,6 +223,205 @@ var SyncReader = class SyncReader {
|
|
|
170
223
|
}
|
|
171
224
|
};
|
|
172
225
|
|
|
226
|
+
//#endregion
|
|
227
|
+
//#region src/backend/webgpu/routines.ts
|
|
228
|
+
/**
 * Pack the parameters of one bitonic-sort pass into uniform-buffer bytes.
 *
 * The result is three consecutive u32 words: the pass kind (0 for "sort",
 * 1 for anything else), the merge step and the merge stage. Missing step or
 * stage values are written as 0.
 *
 * @param pass Object with `kind` and optional `mergeStep` / `mergeStage`.
 * @returns A 12-byte `Uint8Array` view over the packed words.
 */
function bitonicSortUniform(pass) {
  const kindWord = pass.kind === "sort" ? 0 : 1;
  const words = Uint32Array.of(kindWord, pass.mergeStep ?? 0, pass.mergeStage ?? 0);
  return new Uint8Array(words.buffer);
}
|
|
235
|
+
/**
 * Generate a bitonic sort shader.
 *
 * We implement a variant of bitonic sort that [only has forward comparators](
 * <https://sortingalgos.miraheze.org/wiki/Bitonic_Sort#Bitonic_Sort_using_Forward_Comparators>),
 * so we don't need to allocate memory for power-of-two padding.
 *
 * This uses workgroup shared memory up to `2*workgroupSize` elements, for each
 * array in `batches`. For larger arrays, multiple passes are done:
 *
 * - Initial "sort" pass: each workgroup sorts its `2*workgroupSize` elements.
 * - Subsequent "merge" passes: each pass merges sorted sequences of size
 *   `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
 *
 * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
 *
 * @param device GPU device; `limits.maxComputeWorkgroupSizeX` and `features`
 *   are read.
 * @param dtype Element dtype of the array being sorted.
 * @param n Length of the sorted (last) axis.
 * @param batches Number of independent arrays of length `n`.
 * @param outputIndices When true, a second output buffer receives the source
 *   index of every sorted element (argsort).
 * @returns A one-element array holding the shader description: `code`,
 *   `numInputs`, `numOutputs`, `hasUniform` and the list of `passes`.
 * @throws Error if `dtype` is Float16 and the device lacks `shader-f16`.
 */
function bitonicSortShader(device, dtype, n, batches, outputIndices) {
  const ty = dtypeToWgsl(dtype, true);
  const paddedN = 1 << Math.ceil(Math.log2(n || 1));
  const numThreads = Math.ceil(paddedN / 2);
  // NOTE(review): assumes findPow2 returns a power of two that divides
  // numThreads (itself a power of two) — confirm against its definition.
  const workgroupSize = require_backend.findPow2(numThreads, device.limits.maxComputeWorkgroupSizeX);
  const workgroupsPerBatch = numThreads / workgroupSize;
  const numStages = Math.log2(paddedN);
  const numLocalStages = Math.min(numStages, Math.log2(workgroupSize * 2));
  // For n <= 1 there are no stages at all, so `numLocalStages - 1` would be
  // -1 and be rendered as the invalid WGSL literal `-1u` (u32 has no unary
  // minus). Clamp to 0; no merge pass exists in that case, so the clamped
  // value is never matched by a real merge step.
  const lastLocalStage = Math.max(numLocalStages - 1, 0);
  const isFloat = require_backend.isFloatDtype(dtype);
  const needsF16 = dtype === require_backend.DType.Float16;
  // Same up-front check that pipelineSource() performs before `enable f16;`.
  if (needsF16 && !device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
  // Out-of-range slots are padded so they sort to the end: NaN for floats,
  // the dtype's maximum otherwise.
  const padValue = isFloat ? `${ty}(nan())` : maxValueWgsl(dtype);
  const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

struct Uniforms {
  kind: u32, // 0 = sort, 1 = merge
  merge_step: u32, // half_block = 2^step
  merge_stage: u32, // only used for merge
}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> output: array<${ty}>;
${outputIndices ? `@group(0) @binding(2) var<storage, read_write> output_idx: array<i32>;` : ""}

@group(1) @binding(0) var<uniform> uniforms: Uniforms;

var<workgroup> shared_vals: array<${ty}, ${workgroupSize * 2}>;
${outputIndices ? `var<workgroup> shared_idx: array<i32, ${workgroupSize * 2}>;` : ""}

fn compare(a: ${ty}, b: ${ty}) -> bool {
${isFloat ? `
  let min_value = min(a, b);
  return a == min_value && b != min_value;` : "  return a < b;"}
}

fn compare_and_swap(i: u32, j: u32) {
  let val_i = shared_vals[i];
  let val_j = shared_vals[j];
  if (compare(val_j, val_i)) {
    shared_vals[i] = val_j;
    shared_vals[j] = val_i;
    ${outputIndices ? `
    let tmp_idx = shared_idx[i];
    shared_idx[i] = shared_idx[j];
    shared_idx[j] = tmp_idx;` : ""}
  }
}

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let blockid = wg_id.x + wg_id.y * ${gridOffsetY}u;
  let batch = blockid / ${workgroupsPerBatch}u;
  let wg_in_batch = blockid % ${workgroupsPerBatch}u;

  // The 2D dispatch grid may be rounded up past the real workgroup count.
  // Surplus workgroups must not touch the buffers. The condition depends only
  // on workgroup_id, so the barriers below stay in uniform control flow.
  if (batch >= ${batches}u) {
    return;
  }

  let tid = local_id.x;
  let base = batch * ${n}u;

  if (uniforms.kind == 0u || (uniforms.kind == 1u && uniforms.merge_step == ${lastLocalStage}u)) {
    let wg_base = wg_in_batch * ${workgroupSize * 2}u;

    // Load data into shared memory (2 elements per thread)
    let idx0 = tid * 2u;
    let idx1 = tid * 2u + 1u;
    // Load from input for initial 'sort' pass, then from output (read-write) for 'merge' passes.
    if (uniforms.kind == 0u) {
      shared_vals[idx0] = select(${padValue}, input[base + wg_base + idx0], wg_base + idx0 < ${n}u);
      shared_vals[idx1] = select(${padValue}, input[base + wg_base + idx1], wg_base + idx1 < ${n}u);
      ${outputIndices ? `
      shared_idx[idx0] = i32(wg_base + idx0);
      shared_idx[idx1] = i32(wg_base + idx1);` : ""}
    } else {
      shared_vals[idx0] = select(${padValue}, output[base + wg_base + idx0], wg_base + idx0 < ${n}u);
      shared_vals[idx1] = select(${padValue}, output[base + wg_base + idx1], wg_base + idx1 < ${n}u);
      ${outputIndices ? `
      shared_idx[idx0] = select(${n}, output_idx[base + wg_base + idx0], wg_base + idx0 < ${n}u);
      shared_idx[idx1] = select(${n}, output_idx[base + wg_base + idx1], wg_base + idx1 < ${n}u);` : ""}
    }
    workgroupBarrier();

    let initial_stage = select(0u, ${lastLocalStage}u, uniforms.kind != 0u);
    for (var stage = initial_stage; stage < ${numLocalStages}u; stage++) {
      for (var step1 = stage + 1u; step1 > 0u; step1--) {
        let step = step1 - 1u;
        let half_block = 1u << step;
        let is_first_step = uniforms.kind == 0u && step == stage;

        let block_offset = (tid / half_block) * half_block;
        let local_offset = tid % half_block;
        let i = block_offset * 2u + local_offset;
        let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);
        compare_and_swap(i, j);

        workgroupBarrier();
      }
    }

    if (wg_base + idx0 < ${n}u) {
      output[base + wg_base + idx0] = shared_vals[idx0];
      ${outputIndices ? `output_idx[base + wg_base + idx0] = shared_idx[idx0];` : ""}
    }
    if (wg_base + idx1 < ${n}u) {
      output[base + wg_base + idx1] = shared_vals[idx1];
      ${outputIndices ? `output_idx[base + wg_base + idx1] = shared_idx[idx1];` : ""}
    }
  } else {
    // Execute single merge pass for a step >= numLocalStages.
    let half_block = 1u << uniforms.merge_step; // half_block >= workgroupSize * 2
    let thread_in_batch = wg_in_batch * ${workgroupSize} + tid;
    let is_first_step = uniforms.merge_step == uniforms.merge_stage;

    let block_offset = (thread_in_batch / half_block) * half_block;
    let local_offset = thread_in_batch % half_block;
    let i = block_offset * 2u + local_offset;
    let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);

    // Global version of compare_and_swap()
    if (j < ${n}u) {
      let val_i = output[base + i];
      let val_j = output[base + j];
      if (compare(val_j, val_i)) {
        output[base + i] = val_j;
        output[base + j] = val_i;
        ${outputIndices ? `
        let tmp_idx = output_idx[base + i];
        output_idx[base + i] = output_idx[base + j];
        output_idx[base + j] = tmp_idx;` : ""}
      }
    }
  }
}
`.trim();
  const grid = calculateGrid(batches * workgroupsPerBatch);
  // One local "sort" pass, then for every stage that no longer fits in a
  // workgroup: global merge passes down to step `numLocalStages`, finished by
  // one pass at step `lastLocalStage` that the shader runs in shared memory.
  const passes = [{ kind: "sort" }];
  for (let mergeStage = numLocalStages; mergeStage < numStages; mergeStage++) {
    for (let mergeStep = mergeStage; mergeStep >= lastLocalStage; mergeStep--) {
      passes.push({
        kind: "merge",
        mergeStep,
        mergeStage
      });
    }
  }
  return [{
    code,
    numInputs: 1,
    numOutputs: outputIndices ? 2 : 1,
    hasUniform: true,
    passes: passes.map((pass) => ({
      grid,
      uniform: bitonicSortUniform(pass)
    }))
  }];
}
|
|
403
|
+
/**
 * Build the shader description for a `Sort` routine.
 *
 * Sorting happens along the last axis of the first input; all leading axes
 * are flattened into the batch count.
 *
 * @param device GPU device forwarded to `bitonicSortShader`.
 * @param type Routine type; `inputDtypes[0]` and `inputShapes[0]` are read.
 * @returns The result of `bitonicSortShader` with index output disabled.
 */
function createSort(device, type) {
  const elementDtype = type.inputDtypes[0];
  const dims = type.inputShapes[0];
  const sortLength = dims[dims.length - 1];
  const batchCount = require_backend.prod(dims.slice(0, -1));
  return bitonicSortShader(device, elementDtype, sortLength, batchCount, false);
}
|
|
410
|
+
/**
 * Build the shader description for an `Argsort` routine.
 *
 * Sorting happens along the last axis of the first input; all leading axes
 * are flattened into the batch count.
 *
 * @param device GPU device forwarded to `bitonicSortShader`.
 * @param type Routine type; `inputDtypes[0]` and `inputShapes[0]` are read.
 * @returns The result of `bitonicSortShader` with index output enabled.
 */
function createArgsort(device, type) {
  const elementDtype = type.inputDtypes[0];
  const dims = type.inputShapes[0];
  const sortLength = dims[dims.length - 1];
  const batchCount = require_backend.prod(dims.slice(0, -1));
  return bitonicSortShader(device, elementDtype, sortLength, batchCount, true);
}
|
|
417
|
+
/**
 * Select the shader generator that implements `routine` on WebGPU.
 *
 * @param device GPU device forwarded to the generator.
 * @param routine Routine descriptor; `name` picks the generator and `type` is
 *   passed through to it.
 * @returns The shader descriptions produced by the generator.
 * @throws require_backend.UnsupportedRoutineError for any routine other than
 *   `Sort` and `Argsort`.
 */
function createRoutineShader(device, routine) {
  const routineName = routine.name;
  if (routineName === require_backend.Routines.Sort) return createSort(device, routine.type);
  if (routineName === require_backend.Routines.Argsort) return createArgsort(device, routine.type);
  throw new require_backend.UnsupportedRoutineError(routine.name, "webgpu");
}
|
|
424
|
+
|
|
173
425
|
//#endregion
|
|
174
426
|
//#region src/backend/webgpu.ts
|
|
175
427
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -181,6 +433,7 @@ var WebGPUBackend = class {
|
|
|
181
433
|
buffers;
|
|
182
434
|
nextSlot;
|
|
183
435
|
#cachedShaderMap = /* @__PURE__ */ new Map();
|
|
436
|
+
#reusableZsb;
|
|
184
437
|
constructor(device) {
|
|
185
438
|
this.device = device;
|
|
186
439
|
if (require_backend.DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
|
|
@@ -189,11 +442,16 @@ var WebGPUBackend = class {
|
|
|
189
442
|
this.syncReader = new SyncReader(device);
|
|
190
443
|
this.buffers = /* @__PURE__ */ new Map();
|
|
191
444
|
this.nextSlot = 1;
|
|
445
|
+
this.#reusableZsb = this.#createBuffer(4);
|
|
446
|
+
device.addEventListener("uncapturederror", (event) => {
|
|
447
|
+
console.error("Uncaptured error in WebGPU backend:", event.error.message);
|
|
448
|
+
});
|
|
192
449
|
}
|
|
193
450
|
malloc(size, initialData) {
|
|
194
451
|
let buffer;
|
|
195
452
|
const paddedSize = Math.ceil(size / 4) * 4;
|
|
196
|
-
if (
|
|
453
|
+
if (size === 0) buffer = this.#reusableZsb;
|
|
454
|
+
else if (initialData) {
|
|
197
455
|
if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
198
456
|
if (initialData.byteLength < 4096) {
|
|
199
457
|
buffer = this.#createBuffer(paddedSize, { mapped: true });
|
|
@@ -230,11 +488,12 @@ var WebGPUBackend = class {
|
|
|
230
488
|
buffer.ref--;
|
|
231
489
|
if (buffer.ref === 0) {
|
|
232
490
|
this.buffers.delete(slot);
|
|
233
|
-
buffer.buffer.destroy();
|
|
491
|
+
if (buffer.buffer !== this.#reusableZsb) buffer.buffer.destroy();
|
|
234
492
|
}
|
|
235
493
|
}
|
|
236
494
|
async read(slot, start, count) {
|
|
237
495
|
const { buffer, size } = this.#getBuffer(slot);
|
|
496
|
+
if (buffer === this.#reusableZsb) return new Uint8Array();
|
|
238
497
|
if (start === void 0) start = 0;
|
|
239
498
|
if (count === void 0) count = size - start;
|
|
240
499
|
const paddedSize = Math.ceil(count / 4) * 4;
|
|
@@ -252,6 +511,7 @@ var WebGPUBackend = class {
|
|
|
252
511
|
}
|
|
253
512
|
readSync(slot, start, count) {
|
|
254
513
|
const { buffer, size } = this.#getBuffer(slot);
|
|
514
|
+
if (buffer === this.#reusableZsb) return new Uint8Array();
|
|
255
515
|
if (start === void 0) start = 0;
|
|
256
516
|
if (count === void 0) count = size - start;
|
|
257
517
|
return this.syncReader.read(buffer, start, count);
|
|
@@ -265,23 +525,43 @@ var WebGPUBackend = class {
|
|
|
265
525
|
}
|
|
266
526
|
return result;
|
|
267
527
|
}
|
|
268
|
-
async
|
|
269
|
-
const
|
|
528
|
+
async prepareKernel(kernel) {
|
|
529
|
+
const shader = this.#cachedShader(kernel);
|
|
270
530
|
const pipeline = await this.pipelines.prepare(shader);
|
|
271
|
-
return new require_backend.Executable(kernel, {
|
|
272
|
-
shader,
|
|
273
|
-
grid,
|
|
531
|
+
return new require_backend.Executable(kernel, [{
|
|
532
|
+
...shader,
|
|
274
533
|
pipeline
|
|
275
|
-
});
|
|
534
|
+
}]);
|
|
276
535
|
}
|
|
277
|
-
|
|
278
|
-
const
|
|
536
|
+
prepareKernelSync(kernel) {
|
|
537
|
+
const shader = this.#cachedShader(kernel);
|
|
279
538
|
const pipeline = this.pipelines.prepareSync(shader);
|
|
280
|
-
return new require_backend.Executable(kernel, {
|
|
281
|
-
shader,
|
|
282
|
-
grid,
|
|
539
|
+
return new require_backend.Executable(kernel, [{
|
|
540
|
+
...shader,
|
|
283
541
|
pipeline
|
|
542
|
+
}]);
|
|
543
|
+
}
|
|
544
|
+
async prepareRoutine(routine) {
|
|
545
|
+
const shaders = createRoutineShader(this.device, routine);
|
|
546
|
+
const dispatches = await Promise.all(shaders.map(async (shader) => {
|
|
547
|
+
const pipeline = await this.pipelines.prepare(shader);
|
|
548
|
+
return {
|
|
549
|
+
...shader,
|
|
550
|
+
pipeline
|
|
551
|
+
};
|
|
552
|
+
}));
|
|
553
|
+
return new require_backend.Executable(routine, dispatches);
|
|
554
|
+
}
|
|
555
|
+
prepareRoutineSync(routine) {
|
|
556
|
+
const shaders = createRoutineShader(this.device, routine);
|
|
557
|
+
const dispatches = shaders.map((shader) => {
|
|
558
|
+
const pipeline = this.pipelines.prepareSync(shader);
|
|
559
|
+
return {
|
|
560
|
+
...shader,
|
|
561
|
+
pipeline
|
|
562
|
+
};
|
|
284
563
|
});
|
|
564
|
+
return new require_backend.Executable(routine, dispatches);
|
|
285
565
|
}
|
|
286
566
|
dispatch(exe, inputs, outputs) {
|
|
287
567
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
@@ -316,32 +596,6 @@ var WebGPUBackend = class {
|
|
|
316
596
|
return buffer;
|
|
317
597
|
}
|
|
318
598
|
};
|
|
319
|
-
function dtypeToWgsl(dtype, storage = false) {
|
|
320
|
-
switch (dtype) {
|
|
321
|
-
case require_backend.DType.Bool: return storage ? "i32" : "bool";
|
|
322
|
-
case require_backend.DType.Int32: return "i32";
|
|
323
|
-
case require_backend.DType.Uint32: return "u32";
|
|
324
|
-
case require_backend.DType.Float32: return "f32";
|
|
325
|
-
case require_backend.DType.Float16: return "f16";
|
|
326
|
-
default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
|
|
327
|
-
}
|
|
328
|
-
}
|
|
329
|
-
function constToWgsl(dtype, value) {
|
|
330
|
-
if (dtype === require_backend.DType.Bool) return value ? "true" : "false";
|
|
331
|
-
if (dtype === require_backend.DType.Int32) return value.toString();
|
|
332
|
-
if (dtype === require_backend.DType.Uint32) return value.toString() + "u";
|
|
333
|
-
if (dtype === require_backend.DType.Float32) {
|
|
334
|
-
if (Number.isNaN(value)) return "nan()";
|
|
335
|
-
if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
|
|
336
|
-
return "f32(" + value.toString() + ")";
|
|
337
|
-
}
|
|
338
|
-
if (dtype === require_backend.DType.Float16) {
|
|
339
|
-
if (Number.isNaN(value)) return "f16(nan())";
|
|
340
|
-
if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
|
|
341
|
-
return "f16(" + value.toString() + ")";
|
|
342
|
-
}
|
|
343
|
-
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
344
|
-
}
|
|
345
599
|
/**
|
|
346
600
|
* Compiles an expression into WebGPU shader source code.
|
|
347
601
|
*
|
|
@@ -362,12 +616,12 @@ function pipelineSource(device, kernel) {
|
|
|
362
616
|
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
363
617
|
else shader.push(line ? indent + line : line);
|
|
364
618
|
};
|
|
365
|
-
if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) ||
|
|
619
|
+
if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === require_backend.DType.Float16)) {
|
|
366
620
|
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
367
621
|
emit("enable f16;");
|
|
368
622
|
}
|
|
369
|
-
emit(
|
|
370
|
-
const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(),
|
|
623
|
+
emit(headerWgsl);
|
|
624
|
+
const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
371
625
|
if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
|
|
372
626
|
if (distinctOps.has(require_backend.AluOp.Erf) || distinctOps.has(require_backend.AluOp.Erfc)) emit(erfSrc);
|
|
373
627
|
emit("");
|
|
@@ -375,6 +629,9 @@ function pipelineSource(device, kernel) {
|
|
|
375
629
|
tune.exp.fold((exp) => {
|
|
376
630
|
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
377
631
|
});
|
|
632
|
+
tune.epilogue?.fold((exp) => {
|
|
633
|
+
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
634
|
+
});
|
|
378
635
|
for (let i = 0; i < nargs; i++) {
|
|
379
636
|
const ty = dtypeToWgsl(usedArgs[i] ?? require_backend.DType.Float32, true);
|
|
380
637
|
emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
|
|
@@ -383,12 +640,7 @@ function pipelineSource(device, kernel) {
|
|
|
383
640
|
emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
|
|
384
641
|
const workgroupSize = require_backend.findPow2(tune.threadCount, 256);
|
|
385
642
|
const gridSize = Math.ceil(tune.threadCount / workgroupSize);
|
|
386
|
-
|
|
387
|
-
let gridY = 1;
|
|
388
|
-
if (gridSize > device.limits.maxComputeWorkgroupsPerDimension) {
|
|
389
|
-
gridX = 16384;
|
|
390
|
-
gridY = Math.ceil(gridSize / gridX);
|
|
391
|
-
}
|
|
643
|
+
const [gridX, gridY] = calculateGrid(gridSize);
|
|
392
644
|
emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", pushIndent);
|
|
393
645
|
if (gridY === 1) emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
|
|
394
646
|
else {
|
|
@@ -398,7 +650,7 @@ function pipelineSource(device, kernel) {
|
|
|
398
650
|
let gensymCount = 0;
|
|
399
651
|
const gensym = () => `alu${gensymCount++}`;
|
|
400
652
|
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
401
|
-
|
|
653
|
+
if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
|
|
402
654
|
const references = /* @__PURE__ */ new Map();
|
|
403
655
|
const seen = /* @__PURE__ */ new Set();
|
|
404
656
|
const countReferences = (exp) => {
|
|
@@ -511,13 +763,15 @@ function pipelineSource(device, kernel) {
|
|
|
511
763
|
let rhs = items[i][0];
|
|
512
764
|
for (let j = 1; j < unroll; j++) if (re.op === require_backend.AluOp.Add) rhs = `${rhs} + ${items[i][j]}`;
|
|
513
765
|
else if (re.op === require_backend.AluOp.Mul) rhs = `${rhs} * ${items[i][j]}`;
|
|
514
|
-
else if (re.op === require_backend.AluOp.Min) rhs = `min(${rhs}, ${items[i][j]})`;
|
|
515
|
-
else if (re.op === require_backend.AluOp.Max) rhs = `max(${rhs}, ${items[i][j]})`;
|
|
766
|
+
else if (re.op === require_backend.AluOp.Min) rhs = re.dtype === require_backend.DType.Bool ? `(${rhs} && ${items[i][j]})` : `min(${rhs}, ${items[i][j]})`;
|
|
767
|
+
else if (re.op === require_backend.AluOp.Max) rhs = re.dtype === require_backend.DType.Bool ? `(${rhs} || ${items[i][j]})` : `max(${rhs}, ${items[i][j]})`;
|
|
516
768
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
517
769
|
if (re.op === require_backend.AluOp.Add) emit(`${acc[i]} += ${rhs};`);
|
|
518
770
|
else if (re.op === require_backend.AluOp.Mul) emit(`${acc[i]} *= ${rhs};`);
|
|
519
|
-
else if (re.op === require_backend.AluOp.Min) emit(`${acc[i]} =
|
|
520
|
-
else
|
|
771
|
+
else if (re.op === require_backend.AluOp.Min) if (re.dtype === require_backend.DType.Bool) emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
|
|
772
|
+
else emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
|
|
773
|
+
else if (re.op === require_backend.AluOp.Max) if (re.dtype === require_backend.DType.Bool) emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
|
|
774
|
+
else emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
|
|
521
775
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
522
776
|
}
|
|
523
777
|
emit(popIndent, "}");
|
|
@@ -530,7 +784,10 @@ function pipelineSource(device, kernel) {
|
|
|
530
784
|
const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
|
|
531
785
|
outputIdxExps.push(exp.simplify(cache));
|
|
532
786
|
countReferences(outputIdxExps[i]);
|
|
533
|
-
fusionExps.push(
|
|
787
|
+
fusionExps.push(tune.epilogue.substitute({
|
|
788
|
+
acc: require_backend.AluExp.variable(re.dtype, acc[i]),
|
|
789
|
+
upcast: require_backend.AluExp.i32(i)
|
|
790
|
+
}).simplify(cache));
|
|
534
791
|
countReferences(fusionExps[i]);
|
|
535
792
|
}
|
|
536
793
|
for (let i = 0; i < upcast; i++) {
|
|
@@ -542,36 +799,72 @@ function pipelineSource(device, kernel) {
|
|
|
542
799
|
}
|
|
543
800
|
emit(popIndent, "}");
|
|
544
801
|
return {
|
|
545
|
-
|
|
546
|
-
|
|
802
|
+
code: shader.join("\n"),
|
|
803
|
+
numInputs: nargs,
|
|
804
|
+
numOutputs: 1,
|
|
805
|
+
hasUniform: false,
|
|
806
|
+
passes: [{ grid: [gridX, gridY] }]
|
|
547
807
|
};
|
|
548
808
|
}
|
|
549
|
-
function pipelineSubmit(device,
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
return {
|
|
809
|
+
/**
 * Encode every pass of every prepared shader into one command buffer and
 * submit it to the device queue.
 *
 * Passes whose grid has zero workgroups are skipped; a shader with no
 * remaining passes is skipped entirely. Shaders with `hasUniform` get their
 * per-pass uniforms packed into one buffer that is bound at bind group 1 with
 * a dynamic offset per pass.
 *
 * @param device GPU device used to create the encoder and bind groups.
 * @param pipelines Dispatch entries: shader fields plus a compiled `pipeline`.
 * @param inputs GPU buffers bound at bindings `0..inputs.length-1`.
 * @param outputs GPU buffers bound after the inputs.
 * @throws Error if the buffer counts differ from what a shader declares.
 */
function pipelineSubmit(device, pipelines, inputs, outputs) {
  const commandEncoder = device.createCommandEncoder();
  // Uniform buffers exist only for this submission. They are destroyed
  // explicitly instead of waiting for garbage collection, on the throw path
  // too. Destroying after submit() is valid: already-enqueued work completes.
  const transientBuffers = [];
  try {
    for (const { pipeline, ...shader } of pipelines) {
      if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
      const filteredPasses = shader.passes.filter(({ grid }) => require_backend.prod(grid) > 0);
      if (filteredPasses.length === 0) continue;
      const bindGroup = device.createBindGroup({
        layout: pipeline.getBindGroupLayout(0),
        entries: [...inputs.map((buffer, i) => ({
          binding: i,
          resource: { buffer }
        })), ...outputs.map((buffer, i) => ({
          binding: inputs.length + i,
          resource: { buffer }
        }))]
      });
      let uniformBindGroup = null;
      let uniformAlignment = 0;
      if (shader.hasUniform) {
        const uniforms = filteredPasses.map(({ uniform }) => uniform);
        const [uniformBuffer, alignment] = combineUniforms(device, uniforms);
        transientBuffers.push(uniformBuffer);
        uniformAlignment = alignment;
        uniformBindGroup = device.createBindGroup({
          layout: pipeline.getBindGroupLayout(1),
          entries: [{
            binding: 0,
            resource: {
              buffer: uniformBuffer,
              size: alignment
            }
          }]
        });
      }
      for (let i = 0; i < filteredPasses.length; i++) {
        const { grid } = filteredPasses[i];
        const passEncoder = commandEncoder.beginComputePass();
        passEncoder.setPipeline(pipeline);
        passEncoder.setBindGroup(0, bindGroup);
        // Pass i reads its uniforms from slot i of the combined buffer.
        if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
        passEncoder.dispatchWorkgroups(grid[0], grid[1]);
        passEncoder.end();
      }
    }
    device.queue.submit([commandEncoder.finish()]);
  } finally {
    for (const buffer of transientBuffers) buffer.destroy();
  }
}
|
|
854
|
+
/**
 * Pack the uniform blocks of all passes into one GPU buffer.
 *
 * Each block is placed at a multiple of `alignment`, which is the block size
 * rounded up to the device's `minUniformBufferOffsetAlignment`, so a pass can
 * select its block with a dynamic offset.
 *
 * @param device GPU device; reads `limits` and creates the buffer.
 * @param uniforms One byte array per pass, all non-empty and of equal length.
 * @returns `[buffer, alignment]`: the unmapped uniform buffer and the stride
 *   in bytes between consecutive blocks.
 * @throws Error if an entry is missing, empty, or differs in length from the
 *   first entry.
 */
function combineUniforms(device, uniforms) {
  for (let k = 0; k < uniforms.length; k++) {
    const block = uniforms[k];
    const valid = block && block.byteLength !== 0 && block.byteLength === uniforms[0].byteLength;
    if (!valid) throw new Error("webgpu: Uniform mismatch between shader passes");
  }
  const { minUniformBufferOffsetAlignment: minAlign } = device.limits;
  const stride = Math.ceil(uniforms[0].byteLength / minAlign) * minAlign;
  const buffer = device.createBuffer({
    size: stride * uniforms.length,
    usage: GPUBufferUsage.UNIFORM,
    mappedAtCreation: true
  });
  // The buffer is created mapped so the blocks can be copied in directly.
  const mappedBytes = new Uint8Array(buffer.getMappedRange());
  uniforms.forEach((block, k) => mappedBytes.set(block, k * stride));
  buffer.unmap();
  return [buffer, stride];
}
|
|
575
868
|
/**
|
|
576
869
|
* A cache for compiled GPU compute pipelines, keyed by the shader source.
|
|
577
870
|
*
|
|
@@ -588,18 +881,39 @@ var ShaderPipelineCache = class {
|
|
|
588
881
|
this.cache = /* @__PURE__ */ new Map();
|
|
589
882
|
this.inProgress = /* @__PURE__ */ new Map();
|
|
590
883
|
}
|
|
591
|
-
|
|
592
|
-
|
|
884
|
+
#getLayout(shader) {
|
|
885
|
+
if (shader.numInputs + shader.numOutputs > this.device.limits.maxStorageBuffersPerShaderStage) {
|
|
886
|
+
const actual = shader.numInputs + shader.numOutputs;
|
|
887
|
+
const max = this.device.limits.maxStorageBuffersPerShaderStage;
|
|
888
|
+
throw new Error(`Too many buffers (${actual}) for WebGPU pipeline (max: ${max})`);
|
|
889
|
+
}
|
|
890
|
+
const bindGroupLayouts = [this.device.createBindGroupLayout({ entries: require_backend.range(shader.numInputs + shader.numOutputs).map((i) => ({
|
|
891
|
+
binding: i,
|
|
892
|
+
visibility: GPUShaderStage.COMPUTE,
|
|
893
|
+
buffer: { type: i < shader.numInputs ? "read-only-storage" : "storage" }
|
|
894
|
+
})) })];
|
|
895
|
+
if (shader.hasUniform) bindGroupLayouts.push(this.device.createBindGroupLayout({ entries: [{
|
|
896
|
+
binding: 0,
|
|
897
|
+
visibility: GPUShaderStage.COMPUTE,
|
|
898
|
+
buffer: {
|
|
899
|
+
type: "uniform",
|
|
900
|
+
hasDynamicOffset: true
|
|
901
|
+
}
|
|
902
|
+
}] }));
|
|
903
|
+
return this.device.createPipelineLayout({ bindGroupLayouts });
|
|
904
|
+
}
|
|
905
|
+
async prepare(shader) {
|
|
906
|
+
const existingPipeline = this.cache.get(shader.code);
|
|
593
907
|
if (existingPipeline) return existingPipeline;
|
|
594
|
-
const existingPromise = this.inProgress.get(code);
|
|
908
|
+
const existingPromise = this.inProgress.get(shader.code);
|
|
595
909
|
if (existingPromise) return await existingPromise;
|
|
596
|
-
if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
|
|
597
|
-
const shaderModule = this.device.createShaderModule({ code });
|
|
910
|
+
if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
|
|
911
|
+
const shaderModule = this.device.createShaderModule({ code: shader.code });
|
|
598
912
|
const promise = (async () => {
|
|
599
913
|
this.device.pushErrorScope("validation");
|
|
600
914
|
try {
|
|
601
915
|
const pipeline$1 = await this.device.createComputePipelineAsync({
|
|
602
|
-
layout:
|
|
916
|
+
layout: this.#getLayout(shader),
|
|
603
917
|
compute: {
|
|
604
918
|
module: shaderModule,
|
|
605
919
|
entryPoint: "main"
|
|
@@ -609,23 +923,23 @@ var ShaderPipelineCache = class {
|
|
|
609
923
|
return pipeline$1;
|
|
610
924
|
} catch (_error) {
|
|
611
925
|
const scope = await this.device.popErrorScope();
|
|
612
|
-
const emsg = await compileError(shaderModule, scope, code);
|
|
926
|
+
const emsg = await compileError(shaderModule, scope, shader.code);
|
|
613
927
|
throw new Error(emsg);
|
|
614
928
|
}
|
|
615
929
|
})();
|
|
616
|
-
this.inProgress.set(code, promise);
|
|
930
|
+
this.inProgress.set(shader.code, promise);
|
|
617
931
|
const pipeline = await promise;
|
|
618
|
-
this.cache.set(code, pipeline);
|
|
932
|
+
this.cache.set(shader.code, pipeline);
|
|
619
933
|
return pipeline;
|
|
620
934
|
}
|
|
621
|
-
prepareSync(
|
|
622
|
-
const existingPipeline = this.cache.get(code);
|
|
935
|
+
prepareSync(shader) {
|
|
936
|
+
const existingPipeline = this.cache.get(shader.code);
|
|
623
937
|
if (existingPipeline) return existingPipeline;
|
|
624
|
-
if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
|
|
625
|
-
const shaderModule = this.device.createShaderModule({ code });
|
|
938
|
+
if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
|
|
939
|
+
const shaderModule = this.device.createShaderModule({ code: shader.code });
|
|
626
940
|
this.device.pushErrorScope("validation");
|
|
627
941
|
const pipeline = this.device.createComputePipeline({
|
|
628
|
-
layout:
|
|
942
|
+
layout: this.#getLayout(shader),
|
|
629
943
|
compute: {
|
|
630
944
|
module: shaderModule,
|
|
631
945
|
entryPoint: "main"
|
|
@@ -633,11 +947,11 @@ var ShaderPipelineCache = class {
|
|
|
633
947
|
});
|
|
634
948
|
this.device.popErrorScope().then(async (scope) => {
|
|
635
949
|
if (scope !== null) {
|
|
636
|
-
const emsg = await compileError(shaderModule, scope, code);
|
|
950
|
+
const emsg = await compileError(shaderModule, scope, shader.code);
|
|
637
951
|
console.error(emsg);
|
|
638
952
|
}
|
|
639
953
|
});
|
|
640
|
-
this.cache.set(code, pipeline);
|
|
954
|
+
this.cache.set(shader.code, pipeline);
|
|
641
955
|
return pipeline;
|
|
642
956
|
}
|
|
643
957
|
};
|
|
@@ -651,5 +965,4 @@ async function compileError(shaderModule, scope, code) {
|
|
|
651
965
|
}
|
|
652
966
|
|
|
653
967
|
//#endregion
|
|
654
|
-
exports.WebGPUBackend = WebGPUBackend;
|
|
655
|
-
//# sourceMappingURL=webgpu-CcGP160M.cjs.map
|
|
968
|
+
exports.WebGPUBackend = WebGPUBackend;
|