@jax-js/jax 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +5 -2
- package/dist/{backend-CmaidnkQ.cjs → backend-Bu9GY6sK.cjs} +166 -18
- package/dist/{backend-BY8wlLEl.js → backend-tngXtWe4.js} +148 -18
- package/dist/index.cjs +1683 -1004
- package/dist/index.d.cts +365 -95
- package/dist/index.d.ts +365 -95
- package/dist/index.js +1675 -997
- package/dist/{webgpu-C9iAP5h5.js → webgpu-ChVgx3b6.js} +400 -95
- package/dist/{webgpu-BVns4DbI.cjs → webgpu-Oj3Kd-kd.cjs} +400 -95
- package/package.json +1 -1
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-Bu9GY6sK.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -66,6 +66,59 @@ fn erfc(x: f32) -> f32 {
|
|
|
66
66
|
return select(2.0 - E, E, x >= 0.0);
|
|
67
67
|
}`;
|
|
68
68
|
|
|
69
|
+
//#endregion
|
|
70
|
+
//#region src/backend/webgpu/codegen.ts
|
|
71
|
+
/**
 * WGSL helper functions prepended to generated shaders.
 *
 * WGSL has no literal spelling for NaN or infinity, so both are produced by
 * bitcasting u32 bit patterns: all-ones for NaN and 0x7f800000 for +infinity.
 * The patterns are held in `let` bindings, which keeps the bitcast out of
 * constant evaluation.
 */
const headerWgsl = [
  "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }",
  "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }",
].join("\n");
|
|
75
|
+
/**
 * Maps a backend dtype to the WGSL scalar type name.
 *
 * Booleans are `bool` inside shader code but are stored as `i32` in buffers
 * (WGSL `bool` is not host-shareable); pass `storage = true` to get the buffer
 * representation.
 *
 * @param dtype backend dtype to translate.
 * @param storage whether the type is for a storage buffer element.
 * @returns WGSL type name such as "f32" or "i32".
 * @throws Error for dtypes the WebGPU backend cannot represent.
 */
function dtypeToWgsl(dtype, storage = false) {
  const { DType } = require_backend;
  if (dtype === DType.Bool) return storage ? "i32" : "bool";
  if (dtype === DType.Int32) return "i32";
  if (dtype === DType.Uint32) return "u32";
  if (dtype === DType.Float32) return "f32";
  if (dtype === DType.Float16) return "f16";
  throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
}
|
|
85
|
+
/**
 * Returns a WGSL expression for the largest value of `dtype`.
 *
 * Floating-point types map to +infinity through the `inf()` helper from
 * `headerWgsl`. Bool maps to `1`, which fits its i32 storage representation.
 *
 * @throws Error for dtypes the WebGPU backend cannot represent.
 */
function maxValueWgsl(dtype) {
  const { DType } = require_backend;
  if (dtype === DType.Float32) return "inf()";
  if (dtype === DType.Float16) return "f16(inf())";
  if (dtype === DType.Int32) return "2147483647";
  if (dtype === DType.Uint32) return "4294967295u";
  if (dtype === DType.Bool) return "1";
  throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
}
|
|
95
|
+
/**
 * Formats a JS number as a WGSL constant expression of the given dtype.
 *
 * Float constants are wrapped in an explicit `f32(...)` / `f16(...)` conversion.
 * NaN and infinities use the `nan()` / `inf()` helpers from the shared WGSL
 * header, since WGSL has no literals for them.
 *
 * A JS number can lie outside the finite range of the target float type
 * (e.g. 1e39 for f32, 70000 for f16). Converting such a constant is a WGSL
 * shader-creation error, so these values are emitted the way rounding to the
 * target type treats them: ±infinity when they overflow, otherwise the largest
 * finite value.
 *
 * @param dtype backend dtype of the constant.
 * @param value numeric (or boolean-like) value to format.
 * @returns WGSL source text for the constant.
 * @throws Error for dtypes the WebGPU backend cannot represent.
 */
function constToWgsl(dtype, value) {
  if (dtype === require_backend.DType.Bool) return value ? "true" : "false";
  if (dtype === require_backend.DType.Int32) return value.toString();
  if (dtype === require_backend.DType.Uint32) return value.toString() + "u";
  if (dtype === require_backend.DType.Float32) {
    if (Number.isNaN(value)) return "nan()";
    // Math.fround overflows to ±Infinity exactly when f32 rounding would.
    const rounded = Math.fround(value);
    if (!Number.isFinite(rounded)) return rounded > 0 ? "inf()" : "-inf()";
    // Just above the largest finite f32 (but not overflowing): use the rounded
    // value so the WGSL literal stays inside the f32 range.
    const F32_MAX = 3.4028234663852886e38;
    const text = Math.abs(value) > F32_MAX ? rounded : value;
    return "f32(" + text.toString() + ")";
  }
  if (dtype === require_backend.DType.Float16) {
    if (Number.isNaN(value)) return "f16(nan())";
    // 65504 is the largest finite f16; values with magnitude >= 65520 round to infinity.
    const magnitude = Math.abs(value);
    if (!Number.isFinite(value) || magnitude >= 65520) return value > 0 ? "f16(inf())" : "f16(-inf())";
    if (magnitude > 65504) return value > 0 ? "f16(65504.0)" : "f16(-65504.0)";
    return "f16(" + value.toString() + ")";
  }
  throw new Error(`Unsupported const dtype: ${dtype}`);
}
|
|
111
|
+
/**
 * Row width used when a dispatch is folded into two dimensions. Shaders rebuild
 * the linear workgroup index as `wg_id.x + wg_id.y * gridOffsetY`.
 */
const gridOffsetY = 16384;
/**
 * Maps a linear workgroup count onto a 2D dispatch grid.
 *
 * Counts up to 65535 are dispatched as `[gridSize, 1]`. Larger counts are folded
 * into rows of `gridOffsetY` workgroups. The last row may be partial, so the
 * folded grid can contain up to `gridOffsetY - 1` extra workgroups, which
 * shaders must ignore.
 *
 * @param gridSize total number of workgroups required (non-negative integer).
 * @returns `[gridX, gridY]`.
 * @throws Error if even the folded grid would need more than 65535 rows.
 */
function calculateGrid(gridSize) {
  const maxPerDimension = 65535;
  if (gridSize <= maxPerDimension) return [gridSize, 1];
  const gridY = Math.ceil(gridSize / gridOffsetY);
  // Fail early with a clear message instead of an opaque dispatch validation error.
  if (gridY > maxPerDimension) throw new Error(`webgpu: dispatch of ${gridSize} workgroups exceeds the maximum grid size`);
  return [gridOffsetY, gridY];
}
|
|
121
|
+
|
|
69
122
|
//#endregion
|
|
70
123
|
//#region src/backend/webgpu/reader.ts
|
|
71
124
|
/**
|
|
@@ -170,6 +223,205 @@ var SyncReader = class SyncReader {
|
|
|
170
223
|
}
|
|
171
224
|
};
|
|
172
225
|
|
|
226
|
+
//#endregion
|
|
227
|
+
//#region src/backend/webgpu/routines.ts
|
|
228
|
+
/**
 * Encodes one bitonic-sort pass as the bytes of the shader's `Uniforms` struct.
 *
 * Layout: three u32 words `[kind, merge_step, merge_stage]`, where kind is 0 for
 * the "sort" pass and 1 otherwise. Missing step/stage fields encode as 0.
 *
 * @param pass `{ kind, mergeStep?, mergeStage? }` pass descriptor.
 * @returns a 12-byte view over the encoded words.
 */
function bitonicSortUniform(pass) {
  const kindCode = pass.kind === "sort" ? 0 : 1;
  const words = Uint32Array.of(kindCode, pass.mergeStep ?? 0, pass.mergeStage ?? 0);
  return new Uint8Array(words.buffer);
}
|
|
235
|
+
/**
 * Generate a bitonic sort shader.
 *
 * We implement a variant of bitonic sort that [only has forward comparators](
 * <https://sortingalgos.miraheze.org/wiki/Bitonic_Sort#Bitonic_Sort_using_Forward_Comparators>),
 * so we don't need to allocate memory for power-of-two padding.
 *
 * This uses workgroup shared memory up to `2*workgroupSize` elements, for each
 * array in `batches`. For larger arrays, multiple passes are done:
 *
 * - Initial "sort" pass: each workgroup sorts its `2*workgroupSize` elements.
 * - Subsequent "merge" passes: each pass merges sorted sequences of size
 *   `2^(step+1)` with multiple workgroups, directly in the output buffer. The
 *   last merge pass of each stage (step `numLocalStages - 1`) finishes the
 *   remaining small steps in shared memory instead.
 *
 * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
 *
 * Robustness:
 * - `calculateGrid()` may round the dispatch up past the last batch. The shader
 *   returns early for those workgroups instead of indexing, and storing, outside
 *   the buffers.
 * - With `n <= 1` there are no sort stages. The merge-step sentinel is clamped
 *   to 0 so the WGSL never contains the invalid u32 literal `-1u`.
 * - With `n === 0` there is nothing to sort, so every pass gets an empty grid.
 *
 * @param device GPU device; `limits.maxComputeWorkgroupSizeX` bounds the workgroup size.
 * @param dtype element dtype of the sorted array.
 * @param n length of the sorted (last) axis.
 * @param batches number of independent arrays of length `n`.
 * @param outputIndices when true, also write the permutation (argsort) as i32.
 * @returns a one-element array holding the shader descriptor with one dispatch per pass.
 */
function bitonicSortShader(device, dtype, n, batches, outputIndices) {
  const ty = dtypeToWgsl(dtype, true);
  const paddedN = 1 << Math.ceil(Math.log2(n || 1));
  const numThreads = Math.ceil(paddedN / 2);
  const workgroupSize = require_backend.findPow2(numThreads, device.limits.maxComputeWorkgroupSizeX);
  const workgroupsPerBatch = numThreads / workgroupSize;
  const numStages = Math.log2(paddedN);
  const numLocalStages = Math.min(numStages, Math.log2(workgroupSize * 2));
  // Merge step handled in shared memory. Clamped: with n <= 1 numLocalStages is 0
  // (and no merge passes exist), and "-1u" would not be valid WGSL.
  const localMergeStep = Math.max(numLocalStages - 1, 0);
  // Workgroups that carry real data; an empty axis needs no dispatch at all.
  const totalWorkgroups = n === 0 ? 0 : batches * workgroupsPerBatch;
  const isFloat = require_backend.isFloatDtype(dtype);
  const needsF16 = dtype === require_backend.DType.Float16;
  const padValue = isFloat ? `${ty}(nan())` : maxValueWgsl(dtype);
  const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

struct Uniforms {
  kind: u32, // 0 = sort, 1 = merge
  merge_step: u32, // half_block = 2^step
  merge_stage: u32, // only used for merge
}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> output: array<${ty}>;
${outputIndices ? `@group(0) @binding(2) var<storage, read_write> output_idx: array<i32>;` : ""}

@group(1) @binding(0) var<uniform> uniforms: Uniforms;

var<workgroup> shared_vals: array<${ty}, ${workgroupSize * 2}>;
${outputIndices ? `var<workgroup> shared_idx: array<i32, ${workgroupSize * 2}>;` : ""}

fn compare(a: ${ty}, b: ${ty}) -> bool {
${isFloat ? `  let min_value = min(a, b);
  return a == min_value && b != min_value;` : "  return a < b;"}
}

fn compare_and_swap(i: u32, j: u32) {
  let val_i = shared_vals[i];
  let val_j = shared_vals[j];
  if (compare(val_j, val_i)) {
    shared_vals[i] = val_j;
    shared_vals[j] = val_i;
${outputIndices ? `    let tmp_idx = shared_idx[i];
    shared_idx[i] = shared_idx[j];
    shared_idx[j] = tmp_idx;` : ""}
  }
}

@compute @workgroup_size(${workgroupSize})
fn main(
  @builtin(workgroup_id) wg_id: vec3<u32>,
  @builtin(local_invocation_id) local_id: vec3<u32>,
) {
  let blockid = wg_id.x + wg_id.y * ${gridOffsetY}u;
  // The 2D grid may contain workgroups past the last batch; they must not touch
  // memory. workgroup_id is uniform, so this early return keeps the
  // workgroupBarrier() calls below in uniform control flow.
  if (blockid >= ${totalWorkgroups}u) {
    return;
  }
  let batch = blockid / ${workgroupsPerBatch}u;
  let wg_in_batch = blockid % ${workgroupsPerBatch}u;

  let tid = local_id.x;
  let base = batch * ${n}u;

  if (uniforms.kind == 0u || (uniforms.kind == 1u && uniforms.merge_step == ${localMergeStep}u)) {
    let wg_base = wg_in_batch * ${workgroupSize * 2}u;

    // Load data into shared memory (2 elements per thread)
    let idx0 = tid * 2u;
    let idx1 = tid * 2u + 1u;
    // Load from input for initial 'sort' pass, then from output (read-write) for 'merge' passes.
    if (uniforms.kind == 0u) {
      shared_vals[idx0] = select(${padValue}, input[base + wg_base + idx0], wg_base + idx0 < ${n}u);
      shared_vals[idx1] = select(${padValue}, input[base + wg_base + idx1], wg_base + idx1 < ${n}u);
${outputIndices ? `      shared_idx[idx0] = i32(wg_base + idx0);
      shared_idx[idx1] = i32(wg_base + idx1);` : ""}
    } else {
      shared_vals[idx0] = select(${padValue}, output[base + wg_base + idx0], wg_base + idx0 < ${n}u);
      shared_vals[idx1] = select(${padValue}, output[base + wg_base + idx1], wg_base + idx1 < ${n}u);
${outputIndices ? `      shared_idx[idx0] = select(${n}, output_idx[base + wg_base + idx0], wg_base + idx0 < ${n}u);
      shared_idx[idx1] = select(${n}, output_idx[base + wg_base + idx1], wg_base + idx1 < ${n}u);` : ""}
    }
    workgroupBarrier();

    let initial_stage = select(0u, ${localMergeStep}u, uniforms.kind != 0u);
    for (var stage = initial_stage; stage < ${numLocalStages}u; stage++) {
      for (var step1 = stage + 1u; step1 > 0u; step1--) {
        let step = step1 - 1u;
        let half_block = 1u << step;
        let is_first_step = uniforms.kind == 0u && step == stage;

        let block_offset = (tid / half_block) * half_block;
        let local_offset = tid % half_block;
        let i = block_offset * 2u + local_offset;
        let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);
        compare_and_swap(i, j);

        workgroupBarrier();
      }
    }

    if (wg_base + idx0 < ${n}u) {
      output[base + wg_base + idx0] = shared_vals[idx0];
      ${outputIndices ? `output_idx[base + wg_base + idx0] = shared_idx[idx0];` : ""}
    }
    if (wg_base + idx1 < ${n}u) {
      output[base + wg_base + idx1] = shared_vals[idx1];
      ${outputIndices ? `output_idx[base + wg_base + idx1] = shared_idx[idx1];` : ""}
    }
  } else {
    // Execute single merge pass for a step >= numLocalStages.
    let half_block = 1u << uniforms.merge_step; // half_block >= workgroupSize * 2
    let thread_in_batch = wg_in_batch * ${workgroupSize}u + tid;
    let is_first_step = uniforms.merge_step == uniforms.merge_stage;

    let block_offset = (thread_in_batch / half_block) * half_block;
    let local_offset = thread_in_batch % half_block;
    let i = block_offset * 2u + local_offset;
    let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);

    // Global version of compare_and_swap()
    if (j < ${n}u) {
      let val_i = output[base + i];
      let val_j = output[base + j];
      if (compare(val_j, val_i)) {
        output[base + i] = val_j;
        output[base + j] = val_i;
${outputIndices ? `        let tmp_idx = output_idx[base + i];
        output_idx[base + i] = output_idx[base + j];
        output_idx[base + j] = tmp_idx;` : ""}
      }
    }
  }
}
`.trim();
  const grid = calculateGrid(totalWorkgroups);
  const passes = [{ kind: "sort" }];
  for (let mergeStage = numLocalStages; mergeStage < numStages; mergeStage++) {
    for (let mergeStep = mergeStage; mergeStep >= numLocalStages - 1; mergeStep--) {
      passes.push({
        kind: "merge",
        mergeStep,
        mergeStage
      });
    }
  }
  return [{
    code,
    numInputs: 1,
    numOutputs: outputIndices ? 2 : 1,
    hasUniform: true,
    passes: passes.map((pass) => ({
      grid,
      uniform: bitonicSortUniform(pass)
    }))
  }];
}
|
|
403
|
+
function createSort(device, type) {
|
|
404
|
+
const dtype = type.inputDtypes[0];
|
|
405
|
+
const shape = type.inputShapes[0];
|
|
406
|
+
const n = shape[shape.length - 1];
|
|
407
|
+
const batches = require_backend.prod(shape.slice(0, -1));
|
|
408
|
+
return bitonicSortShader(device, dtype, n, batches, false);
|
|
409
|
+
}
|
|
410
|
+
function createArgsort(device, type) {
|
|
411
|
+
const dtype = type.inputDtypes[0];
|
|
412
|
+
const shape = type.inputShapes[0];
|
|
413
|
+
const n = shape[shape.length - 1];
|
|
414
|
+
const batches = require_backend.prod(shape.slice(0, -1));
|
|
415
|
+
return bitonicSortShader(device, dtype, n, batches, true);
|
|
416
|
+
}
|
|
417
|
+
function createRoutineShader(device, routine) {
|
|
418
|
+
switch (routine.name) {
|
|
419
|
+
case require_backend.Routines.Sort: return createSort(device, routine.type);
|
|
420
|
+
case require_backend.Routines.Argsort: return createArgsort(device, routine.type);
|
|
421
|
+
default: throw new require_backend.UnsupportedRoutineError(routine.name, "webgpu");
|
|
422
|
+
}
|
|
423
|
+
}
|
|
424
|
+
|
|
173
425
|
//#endregion
|
|
174
426
|
//#region src/backend/webgpu.ts
|
|
175
427
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -181,6 +433,7 @@ var WebGPUBackend = class {
|
|
|
181
433
|
buffers;
|
|
182
434
|
nextSlot;
|
|
183
435
|
#cachedShaderMap = /* @__PURE__ */ new Map();
|
|
436
|
+
#reusableZsb;
|
|
184
437
|
constructor(device) {
|
|
185
438
|
this.device = device;
|
|
186
439
|
if (require_backend.DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
|
|
@@ -189,11 +442,16 @@ var WebGPUBackend = class {
|
|
|
189
442
|
this.syncReader = new SyncReader(device);
|
|
190
443
|
this.buffers = /* @__PURE__ */ new Map();
|
|
191
444
|
this.nextSlot = 1;
|
|
445
|
+
this.#reusableZsb = this.#createBuffer(4);
|
|
446
|
+
device.addEventListener("uncapturederror", (event) => {
|
|
447
|
+
console.error("Uncaptured error in WebGPU backend:", event.error.message);
|
|
448
|
+
});
|
|
192
449
|
}
|
|
193
450
|
malloc(size, initialData) {
|
|
194
451
|
let buffer;
|
|
195
452
|
const paddedSize = Math.ceil(size / 4) * 4;
|
|
196
|
-
if (
|
|
453
|
+
if (size === 0) buffer = this.#reusableZsb;
|
|
454
|
+
else if (initialData) {
|
|
197
455
|
if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
198
456
|
if (initialData.byteLength < 4096) {
|
|
199
457
|
buffer = this.#createBuffer(paddedSize, { mapped: true });
|
|
@@ -230,11 +488,12 @@ var WebGPUBackend = class {
|
|
|
230
488
|
buffer.ref--;
|
|
231
489
|
if (buffer.ref === 0) {
|
|
232
490
|
this.buffers.delete(slot);
|
|
233
|
-
buffer.buffer.destroy();
|
|
491
|
+
if (buffer.buffer !== this.#reusableZsb) buffer.buffer.destroy();
|
|
234
492
|
}
|
|
235
493
|
}
|
|
236
494
|
async read(slot, start, count) {
|
|
237
495
|
const { buffer, size } = this.#getBuffer(slot);
|
|
496
|
+
if (buffer === this.#reusableZsb) return new Uint8Array();
|
|
238
497
|
if (start === void 0) start = 0;
|
|
239
498
|
if (count === void 0) count = size - start;
|
|
240
499
|
const paddedSize = Math.ceil(count / 4) * 4;
|
|
@@ -252,6 +511,7 @@ var WebGPUBackend = class {
|
|
|
252
511
|
}
|
|
253
512
|
readSync(slot, start, count) {
|
|
254
513
|
const { buffer, size } = this.#getBuffer(slot);
|
|
514
|
+
if (buffer === this.#reusableZsb) return new Uint8Array();
|
|
255
515
|
if (start === void 0) start = 0;
|
|
256
516
|
if (count === void 0) count = size - start;
|
|
257
517
|
return this.syncReader.read(buffer, start, count);
|
|
@@ -265,27 +525,45 @@ var WebGPUBackend = class {
|
|
|
265
525
|
}
|
|
266
526
|
return result;
|
|
267
527
|
}
|
|
268
|
-
async
|
|
269
|
-
const
|
|
528
|
+
async prepareKernel(kernel) {
|
|
529
|
+
const shader = this.#cachedShader(kernel);
|
|
270
530
|
const pipeline = await this.pipelines.prepare(shader);
|
|
271
|
-
return new require_backend.Executable(kernel, {
|
|
272
|
-
shader,
|
|
273
|
-
grid,
|
|
531
|
+
return new require_backend.Executable(kernel, [{
|
|
532
|
+
...shader,
|
|
274
533
|
pipeline
|
|
275
|
-
});
|
|
534
|
+
}]);
|
|
276
535
|
}
|
|
277
|
-
|
|
278
|
-
const
|
|
536
|
+
prepareKernelSync(kernel) {
|
|
537
|
+
const shader = this.#cachedShader(kernel);
|
|
279
538
|
const pipeline = this.pipelines.prepareSync(shader);
|
|
280
|
-
return new require_backend.Executable(kernel, {
|
|
281
|
-
shader,
|
|
282
|
-
grid,
|
|
539
|
+
return new require_backend.Executable(kernel, [{
|
|
540
|
+
...shader,
|
|
283
541
|
pipeline
|
|
542
|
+
}]);
|
|
543
|
+
}
|
|
544
|
+
async prepareRoutine(routine) {
|
|
545
|
+
const shaders = createRoutineShader(this.device, routine);
|
|
546
|
+
const dispatches = await Promise.all(shaders.map(async (shader) => {
|
|
547
|
+
const pipeline = await this.pipelines.prepare(shader);
|
|
548
|
+
return {
|
|
549
|
+
...shader,
|
|
550
|
+
pipeline
|
|
551
|
+
};
|
|
552
|
+
}));
|
|
553
|
+
return new require_backend.Executable(routine, dispatches);
|
|
554
|
+
}
|
|
555
|
+
prepareRoutineSync(routine) {
|
|
556
|
+
const shaders = createRoutineShader(this.device, routine);
|
|
557
|
+
const dispatches = shaders.map((shader) => {
|
|
558
|
+
const pipeline = this.pipelines.prepareSync(shader);
|
|
559
|
+
return {
|
|
560
|
+
...shader,
|
|
561
|
+
pipeline
|
|
562
|
+
};
|
|
284
563
|
});
|
|
564
|
+
return new require_backend.Executable(routine, dispatches);
|
|
285
565
|
}
|
|
286
566
|
dispatch(exe, inputs, outputs) {
|
|
287
|
-
if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
|
|
288
|
-
if (exe.kernel.size === 0) return;
|
|
289
567
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
290
568
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
291
569
|
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
@@ -318,32 +596,6 @@ var WebGPUBackend = class {
|
|
|
318
596
|
return buffer;
|
|
319
597
|
}
|
|
320
598
|
};
|
|
321
|
-
/**
 * Maps a backend dtype to the WGSL scalar type name. Booleans are `bool` in
 * shader code but stored as `i32` in buffers; pass `storage = true` for the
 * buffer representation.
 *
 * @throws Error for dtypes the WebGPU backend cannot represent.
 */
function dtypeToWgsl(dtype, storage = false) {
  switch (dtype) {
    case require_backend.DType.Bool: return storage ? "i32" : "bool";
    case require_backend.DType.Int32: return "i32";
    case require_backend.DType.Uint32: return "u32";
    case require_backend.DType.Float32: return "f32";
    case require_backend.DType.Float16: return "f16";
    default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
  }
}
/**
 * Formats a JS number as a WGSL constant expression of the given dtype.
 *
 * NaN and infinities use the `nan()` / `inf()` header helpers. Finite values
 * outside the range of the target float type are emitted as rounding to that
 * type treats them (±infinity on overflow, otherwise the largest finite value),
 * because converting an out-of-range constant is a WGSL shader-creation error.
 *
 * @throws Error for dtypes the WebGPU backend cannot represent.
 */
function constToWgsl(dtype, value) {
  if (dtype === require_backend.DType.Bool) return value ? "true" : "false";
  if (dtype === require_backend.DType.Int32) return value.toString();
  if (dtype === require_backend.DType.Uint32) return value.toString() + "u";
  if (dtype === require_backend.DType.Float32) {
    if (Number.isNaN(value)) return "nan()";
    // Math.fround overflows to ±Infinity exactly when f32 rounding would.
    const rounded = Math.fround(value);
    if (!Number.isFinite(rounded)) return rounded > 0 ? "inf()" : "-inf()";
    const F32_MAX = 3.4028234663852886e38;
    const text = Math.abs(value) > F32_MAX ? rounded : value;
    return "f32(" + text.toString() + ")";
  }
  if (dtype === require_backend.DType.Float16) {
    if (Number.isNaN(value)) return "f16(nan())";
    // 65504 is the largest finite f16; values with magnitude >= 65520 round to infinity.
    const magnitude = Math.abs(value);
    if (!Number.isFinite(value) || magnitude >= 65520) return value > 0 ? "f16(inf())" : "f16(-inf())";
    if (magnitude > 65504) return value > 0 ? "f16(65504.0)" : "f16(-65504.0)";
    return "f16(" + value.toString() + ")";
  }
  throw new Error(`Unsupported const dtype: ${dtype}`);
}
|
|
347
599
|
/**
|
|
348
600
|
* Compiles an expression into WebGPU shader source code.
|
|
349
601
|
*
|
|
@@ -368,7 +620,7 @@ function pipelineSource(device, kernel) {
|
|
|
368
620
|
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
369
621
|
emit("enable f16;");
|
|
370
622
|
}
|
|
371
|
-
emit(
|
|
623
|
+
emit(headerWgsl);
|
|
372
624
|
const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
373
625
|
if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
|
|
374
626
|
if (distinctOps.has(require_backend.AluOp.Erf) || distinctOps.has(require_backend.AluOp.Erfc)) emit(erfSrc);
|
|
@@ -388,12 +640,7 @@ function pipelineSource(device, kernel) {
|
|
|
388
640
|
emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
|
|
389
641
|
const workgroupSize = require_backend.findPow2(tune.threadCount, 256);
|
|
390
642
|
const gridSize = Math.ceil(tune.threadCount / workgroupSize);
|
|
391
|
-
|
|
392
|
-
let gridY = 1;
|
|
393
|
-
if (gridSize > device.limits.maxComputeWorkgroupsPerDimension) {
|
|
394
|
-
gridX = 16384;
|
|
395
|
-
gridY = Math.ceil(gridSize / gridX);
|
|
396
|
-
}
|
|
643
|
+
const [gridX, gridY] = calculateGrid(gridSize);
|
|
397
644
|
emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", pushIndent);
|
|
398
645
|
if (gridY === 1) emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
|
|
399
646
|
else {
|
|
@@ -516,13 +763,15 @@ function pipelineSource(device, kernel) {
|
|
|
516
763
|
let rhs = items[i][0];
|
|
517
764
|
for (let j = 1; j < unroll; j++) if (re.op === require_backend.AluOp.Add) rhs = `${rhs} + ${items[i][j]}`;
|
|
518
765
|
else if (re.op === require_backend.AluOp.Mul) rhs = `${rhs} * ${items[i][j]}`;
|
|
519
|
-
else if (re.op === require_backend.AluOp.Min) rhs = `min(${rhs}, ${items[i][j]})`;
|
|
520
|
-
else if (re.op === require_backend.AluOp.Max) rhs = `max(${rhs}, ${items[i][j]})`;
|
|
766
|
+
else if (re.op === require_backend.AluOp.Min) rhs = re.dtype === require_backend.DType.Bool ? `(${rhs} && ${items[i][j]})` : `min(${rhs}, ${items[i][j]})`;
|
|
767
|
+
else if (re.op === require_backend.AluOp.Max) rhs = re.dtype === require_backend.DType.Bool ? `(${rhs} || ${items[i][j]})` : `max(${rhs}, ${items[i][j]})`;
|
|
521
768
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
522
769
|
if (re.op === require_backend.AluOp.Add) emit(`${acc[i]} += ${rhs};`);
|
|
523
770
|
else if (re.op === require_backend.AluOp.Mul) emit(`${acc[i]} *= ${rhs};`);
|
|
524
|
-
else if (re.op === require_backend.AluOp.Min) emit(`${acc[i]} =
|
|
525
|
-
else
|
|
771
|
+
else if (re.op === require_backend.AluOp.Min) if (re.dtype === require_backend.DType.Bool) emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
|
|
772
|
+
else emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
|
|
773
|
+
else if (re.op === require_backend.AluOp.Max) if (re.dtype === require_backend.DType.Bool) emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
|
|
774
|
+
else emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
|
|
526
775
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
527
776
|
}
|
|
528
777
|
emit(popIndent, "}");
|
|
@@ -550,36 +799,72 @@ function pipelineSource(device, kernel) {
|
|
|
550
799
|
}
|
|
551
800
|
emit(popIndent, "}");
|
|
552
801
|
return {
|
|
553
|
-
|
|
554
|
-
|
|
802
|
+
code: shader.join("\n"),
|
|
803
|
+
numInputs: nargs,
|
|
804
|
+
numOutputs: 1,
|
|
805
|
+
hasUniform: false,
|
|
806
|
+
passes: [{ grid: [gridX, gridY] }]
|
|
555
807
|
};
|
|
556
808
|
}
|
|
557
|
-
function pipelineSubmit(device,
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
return {
|
|
809
|
+
function pipelineSubmit(device, pipelines, inputs, outputs) {
|
|
810
|
+
const commandEncoder = device.createCommandEncoder();
|
|
811
|
+
for (const { pipeline,...shader } of pipelines) {
|
|
812
|
+
if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
|
|
813
|
+
const filteredPasses = shader.passes.filter(({ grid }) => require_backend.prod(grid) > 0);
|
|
814
|
+
if (filteredPasses.length === 0) continue;
|
|
815
|
+
const bindGroup = device.createBindGroup({
|
|
816
|
+
layout: pipeline.getBindGroupLayout(0),
|
|
817
|
+
entries: [...inputs.map((buffer, i) => ({
|
|
567
818
|
binding: i,
|
|
568
819
|
resource: { buffer }
|
|
569
|
-
}
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
}
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
820
|
+
})), ...outputs.map((buffer, i) => ({
|
|
821
|
+
binding: inputs.length + i,
|
|
822
|
+
resource: { buffer }
|
|
823
|
+
}))]
|
|
824
|
+
});
|
|
825
|
+
let uniformBindGroup = null;
|
|
826
|
+
let uniformAlignment = 0;
|
|
827
|
+
if (shader.hasUniform) {
|
|
828
|
+
const uniforms = filteredPasses.map(({ uniform }) => uniform);
|
|
829
|
+
const [uniformBuffer, alignment] = combineUniforms(device, uniforms);
|
|
830
|
+
uniformAlignment = alignment;
|
|
831
|
+
uniformBindGroup = device.createBindGroup({
|
|
832
|
+
layout: pipeline.getBindGroupLayout(1),
|
|
833
|
+
entries: [{
|
|
834
|
+
binding: 0,
|
|
835
|
+
resource: {
|
|
836
|
+
buffer: uniformBuffer,
|
|
837
|
+
size: alignment
|
|
838
|
+
}
|
|
839
|
+
}]
|
|
840
|
+
});
|
|
841
|
+
}
|
|
842
|
+
for (let i = 0; i < filteredPasses.length; i++) {
|
|
843
|
+
const { grid } = filteredPasses[i];
|
|
844
|
+
const passEncoder = commandEncoder.beginComputePass();
|
|
845
|
+
passEncoder.setPipeline(pipeline);
|
|
846
|
+
passEncoder.setBindGroup(0, bindGroup);
|
|
847
|
+
if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
|
|
848
|
+
passEncoder.dispatchWorkgroups(grid[0], grid[1]);
|
|
849
|
+
passEncoder.end();
|
|
850
|
+
}
|
|
851
|
+
}
|
|
581
852
|
device.queue.submit([commandEncoder.finish()]);
|
|
582
853
|
}
|
|
854
|
+
function combineUniforms(device, uniforms) {
|
|
855
|
+
for (const buf of uniforms) if (!buf || buf.byteLength === 0 || buf.byteLength !== uniforms[0].byteLength) throw new Error("webgpu: Uniform mismatch between shader passes");
|
|
856
|
+
const minAlign = device.limits.minUniformBufferOffsetAlignment;
|
|
857
|
+
const alignment = Math.ceil(uniforms[0].byteLength / minAlign) * minAlign;
|
|
858
|
+
const buffer = device.createBuffer({
|
|
859
|
+
size: alignment * uniforms.length,
|
|
860
|
+
usage: GPUBufferUsage.UNIFORM,
|
|
861
|
+
mappedAtCreation: true
|
|
862
|
+
});
|
|
863
|
+
const bufferMapped = new Uint8Array(buffer.getMappedRange());
|
|
864
|
+
for (let i = 0; i < uniforms.length; i++) bufferMapped.set(uniforms[i], i * alignment);
|
|
865
|
+
buffer.unmap();
|
|
866
|
+
return [buffer, alignment];
|
|
867
|
+
}
|
|
583
868
|
/**
|
|
584
869
|
* A cache for compiled GPU compute pipelines, keyed by the shader source.
|
|
585
870
|
*
|
|
@@ -596,18 +881,39 @@ var ShaderPipelineCache = class {
|
|
|
596
881
|
this.cache = /* @__PURE__ */ new Map();
|
|
597
882
|
this.inProgress = /* @__PURE__ */ new Map();
|
|
598
883
|
}
|
|
599
|
-
|
|
600
|
-
|
|
884
|
+
#getLayout(shader) {
|
|
885
|
+
if (shader.numInputs + shader.numOutputs > this.device.limits.maxStorageBuffersPerShaderStage) {
|
|
886
|
+
const actual = shader.numInputs + shader.numOutputs;
|
|
887
|
+
const max = this.device.limits.maxStorageBuffersPerShaderStage;
|
|
888
|
+
throw new Error(`Too many buffers (${actual}) for WebGPU pipeline (max: ${max})`);
|
|
889
|
+
}
|
|
890
|
+
const bindGroupLayouts = [this.device.createBindGroupLayout({ entries: require_backend.range(shader.numInputs + shader.numOutputs).map((i) => ({
|
|
891
|
+
binding: i,
|
|
892
|
+
visibility: GPUShaderStage.COMPUTE,
|
|
893
|
+
buffer: { type: i < shader.numInputs ? "read-only-storage" : "storage" }
|
|
894
|
+
})) })];
|
|
895
|
+
if (shader.hasUniform) bindGroupLayouts.push(this.device.createBindGroupLayout({ entries: [{
|
|
896
|
+
binding: 0,
|
|
897
|
+
visibility: GPUShaderStage.COMPUTE,
|
|
898
|
+
buffer: {
|
|
899
|
+
type: "uniform",
|
|
900
|
+
hasDynamicOffset: true
|
|
901
|
+
}
|
|
902
|
+
}] }));
|
|
903
|
+
return this.device.createPipelineLayout({ bindGroupLayouts });
|
|
904
|
+
}
|
|
905
|
+
async prepare(shader) {
|
|
906
|
+
const existingPipeline = this.cache.get(shader.code);
|
|
601
907
|
if (existingPipeline) return existingPipeline;
|
|
602
|
-
const existingPromise = this.inProgress.get(code);
|
|
908
|
+
const existingPromise = this.inProgress.get(shader.code);
|
|
603
909
|
if (existingPromise) return await existingPromise;
|
|
604
|
-
if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
|
|
605
|
-
const shaderModule = this.device.createShaderModule({ code });
|
|
910
|
+
if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
|
|
911
|
+
const shaderModule = this.device.createShaderModule({ code: shader.code });
|
|
606
912
|
const promise = (async () => {
|
|
607
913
|
this.device.pushErrorScope("validation");
|
|
608
914
|
try {
|
|
609
915
|
const pipeline$1 = await this.device.createComputePipelineAsync({
|
|
610
|
-
layout:
|
|
916
|
+
layout: this.#getLayout(shader),
|
|
611
917
|
compute: {
|
|
612
918
|
module: shaderModule,
|
|
613
919
|
entryPoint: "main"
|
|
@@ -617,23 +923,23 @@ var ShaderPipelineCache = class {
|
|
|
617
923
|
return pipeline$1;
|
|
618
924
|
} catch (_error) {
|
|
619
925
|
const scope = await this.device.popErrorScope();
|
|
620
|
-
const emsg = await compileError(shaderModule, scope, code);
|
|
926
|
+
const emsg = await compileError(shaderModule, scope, shader.code);
|
|
621
927
|
throw new Error(emsg);
|
|
622
928
|
}
|
|
623
929
|
})();
|
|
624
|
-
this.inProgress.set(code, promise);
|
|
930
|
+
this.inProgress.set(shader.code, promise);
|
|
625
931
|
const pipeline = await promise;
|
|
626
|
-
this.cache.set(code, pipeline);
|
|
932
|
+
this.cache.set(shader.code, pipeline);
|
|
627
933
|
return pipeline;
|
|
628
934
|
}
|
|
629
|
-
prepareSync(
|
|
630
|
-
const existingPipeline = this.cache.get(code);
|
|
935
|
+
prepareSync(shader) {
|
|
936
|
+
const existingPipeline = this.cache.get(shader.code);
|
|
631
937
|
if (existingPipeline) return existingPipeline;
|
|
632
|
-
if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
|
|
633
|
-
const shaderModule = this.device.createShaderModule({ code });
|
|
938
|
+
if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
|
|
939
|
+
const shaderModule = this.device.createShaderModule({ code: shader.code });
|
|
634
940
|
this.device.pushErrorScope("validation");
|
|
635
941
|
const pipeline = this.device.createComputePipeline({
|
|
636
|
-
layout:
|
|
942
|
+
layout: this.#getLayout(shader),
|
|
637
943
|
compute: {
|
|
638
944
|
module: shaderModule,
|
|
639
945
|
entryPoint: "main"
|
|
@@ -641,11 +947,11 @@ var ShaderPipelineCache = class {
|
|
|
641
947
|
});
|
|
642
948
|
this.device.popErrorScope().then(async (scope) => {
|
|
643
949
|
if (scope !== null) {
|
|
644
|
-
const emsg = await compileError(shaderModule, scope, code);
|
|
950
|
+
const emsg = await compileError(shaderModule, scope, shader.code);
|
|
645
951
|
console.error(emsg);
|
|
646
952
|
}
|
|
647
953
|
});
|
|
648
|
-
this.cache.set(code, pipeline);
|
|
954
|
+
this.cache.set(shader.code, pipeline);
|
|
649
955
|
return pipeline;
|
|
650
956
|
}
|
|
651
957
|
};
|
|
@@ -659,5 +965,4 @@ async function compileError(shaderModule, scope, code) {
|
|
|
659
965
|
}
|
|
660
966
|
|
|
661
967
|
//#endregion
|
|
662
|
-
exports.WebGPUBackend = WebGPUBackend;
|
|
663
|
-
//# sourceMappingURL=webgpu-BVns4DbI.cjs.map
|
|
968
|
+
exports.WebGPUBackend = WebGPUBackend;
|