@jax-js/jax 0.1.2 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-BqymqzuU.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, findPow2, isFloatDtype, mapSetUnion, prod, range, strip1, tuneWebgpu } from "./backend-tngXtWe4.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -66,6 +66,59 @@ fn erfc(x: f32) -> f32 {
66
66
  return select(2.0 - E, E, x >= 0.0);
67
67
  }`;
68
68
 
69
+ //#endregion
70
+ //#region src/backend/webgpu/codegen.ts
71
/**
 * WGSL prelude shared by every generated shader. `nan()` and `inf()` build
 * the f32 NaN / +infinity values by bitcasting their IEEE-754 bit patterns,
 * since WGSL has no literal for either.
 */
const headerWgsl = [
  "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }",
  "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }",
].join("\n");
75
+ function dtypeToWgsl(dtype, storage = false) {
76
+ switch (dtype) {
77
+ case DType.Bool: return storage ? "i32" : "bool";
78
+ case DType.Int32: return "i32";
79
+ case DType.Uint32: return "u32";
80
+ case DType.Float32: return "f32";
81
+ case DType.Float16: return "f16";
82
+ default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
83
+ }
84
+ }
85
+ function maxValueWgsl(dtype) {
86
+ switch (dtype) {
87
+ case DType.Bool: return "1";
88
+ case DType.Int32: return "2147483647";
89
+ case DType.Uint32: return "4294967295u";
90
+ case DType.Float32: return "inf()";
91
+ case DType.Float16: return "f16(inf())";
92
+ default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
93
+ }
94
+ }
95
+ function constToWgsl(dtype, value) {
96
+ if (dtype === DType.Bool) return value ? "true" : "false";
97
+ if (dtype === DType.Int32) return value.toString();
98
+ if (dtype === DType.Uint32) return value.toString() + "u";
99
+ if (dtype === DType.Float32) {
100
+ if (Number.isNaN(value)) return "nan()";
101
+ if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
102
+ return "f32(" + value.toString() + ")";
103
+ }
104
+ if (dtype === DType.Float16) {
105
+ if (Number.isNaN(value)) return "f16(nan())";
106
+ if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
107
+ return "f16(" + value.toString() + ")";
108
+ }
109
+ throw new Error(`Unsupported const dtype: ${dtype}`);
110
+ }
111
/** Workgroups per row along X once a dispatch is split into two dimensions. */
const gridOffsetY = 16384;
/**
 * Lay out `gridSize` workgroups as an `[x, y]` dispatch grid.
 *
 * Up to 65535 workgroups (the WebGPU default `maxComputeWorkgroupsPerDimension`)
 * the grid stays one-dimensional. Beyond that it becomes rows of `gridOffsetY`
 * workgroups, so `x * y` may exceed `gridSize`. Shaders recover a flat index
 * as `x + y * gridOffsetY` and must bounds-check it.
 */
function calculateGrid(gridSize) {
  if (gridSize > 65535) return [gridOffsetY, Math.ceil(gridSize / gridOffsetY)];
  return [gridSize, 1];
}
121
+
69
122
  //#endregion
70
123
  //#region src/backend/webgpu/reader.ts
71
124
  /**
@@ -170,6 +223,205 @@ var SyncReader = class SyncReader {
170
223
  }
171
224
  };
172
225
 
226
+ //#endregion
227
+ //#region src/backend/webgpu/routines.ts
228
/**
 * Encode one bitonic-sort pass as the bytes of the shader's `Uniforms`
 * struct: three u32 fields `kind` (0 = sort, 1 = merge), `merge_step` and
 * `merge_stage`. Missing step/stage values encode as 0.
 */
function bitonicSortUniform(pass) {
  const fields = Uint32Array.of(
    pass.kind === "sort" ? 0 : 1,
    pass.mergeStep ?? 0,
    pass.mergeStage ?? 0,
  );
  return new Uint8Array(fields.buffer);
}
235
+ /**
236
+ * Generate a bitonic sort shader.
237
+ *
238
+ * We implement a variant of bitonic sort that [only has forward comparators](
239
+ * <https://sortingalgos.miraheze.org/wiki/Bitonic_Sort#Bitonic_Sort_using_Forward_Comparators>),
240
+ * so we don't need to allocate memory for power-of-two padding.
241
+ *
242
+ * This uses workgroup shared memory up to `2*workgroupSize` elements, for each
243
+ * array in `batches`. For larger arrays, multiple passes are done:
244
+ *
245
+ * - Initial "sort" pass: each workgroup sorts its `2*workgroupSize` elements.
246
+ * - Subsequent "merge" passes: each pass merges sorted sequences of size
247
+ * `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
248
+ *
249
+ * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
250
+ */
251
+ function bitonicSortShader(device, dtype, n, batches, outputIndices) {
252
+ const ty = dtypeToWgsl(dtype, true);
253
+ const paddedN = 1 << Math.ceil(Math.log2(n || 1));
254
+ const numThreads = Math.ceil(paddedN / 2);
255
+ const workgroupSize = findPow2(numThreads, device.limits.maxComputeWorkgroupSizeX);
256
+ const workgroupsPerBatch = numThreads / workgroupSize;
257
+ const numStages = Math.log2(paddedN);
258
+ const numLocalStages = Math.min(numStages, Math.log2(workgroupSize * 2));
259
+ const needsF16 = dtype === DType.Float16;
260
+ const padValue = isFloatDtype(dtype) ? `${ty}(nan())` : maxValueWgsl(dtype);
261
+ const code = `
262
+ ${needsF16 ? "enable f16;" : ""}
263
+ ${headerWgsl}
264
+
265
+ struct Uniforms {
266
+ kind: u32, // 0 = sort, 1 = merge
267
+ merge_step: u32, // half_block = 2^step
268
+ merge_stage: u32, // only used for merge
269
+ }
270
+
271
+ @group(0) @binding(0) var<storage, read> input: array<${ty}>;
272
+ @group(0) @binding(1) var<storage, read_write> output: array<${ty}>;
273
+ ${outputIndices ? `@group(0) @binding(2) var<storage, read_write> output_idx: array<i32>;` : ""}
274
+
275
+ @group(1) @binding(0) var<uniform> uniforms: Uniforms;
276
+
277
+ var<workgroup> shared_vals: array<${ty}, ${workgroupSize * 2}>;
278
+ ${outputIndices ? `var<workgroup> shared_idx: array<i32, ${workgroupSize * 2}>;` : ""}
279
+
280
+ fn compare(a: ${ty}, b: ${ty}) -> bool {
281
+ ${isFloatDtype(dtype) ? `
282
+ let min_value = min(a, b);
283
+ return a == min_value && b != min_value;` : " return a < b;"}
284
+ }
285
+
286
+ fn compare_and_swap(i: u32, j: u32) {
287
+ let val_i = shared_vals[i];
288
+ let val_j = shared_vals[j];
289
+ if (compare(val_j, val_i)) {
290
+ shared_vals[i] = val_j;
291
+ shared_vals[j] = val_i;
292
+ ${outputIndices ? `
293
+ let tmp_idx = shared_idx[i];
294
+ shared_idx[i] = shared_idx[j];
295
+ shared_idx[j] = tmp_idx;` : ""}
296
+ }
297
+ }
298
+
299
+ @compute @workgroup_size(${workgroupSize})
300
+ fn main(
301
+ @builtin(workgroup_id) wg_id: vec3<u32>,
302
+ @builtin(local_invocation_id) local_id: vec3<u32>,
303
+ ) {
304
+ let blockid = wg_id.x + wg_id.y * ${gridOffsetY}u;
305
+ let batch = blockid / ${workgroupsPerBatch}u;
306
+ let wg_in_batch = blockid % ${workgroupsPerBatch}u;
307
+
308
+ let tid = local_id.x;
309
+ let base = batch * ${n}u;
310
+
311
+ if (uniforms.kind == 0u || (uniforms.kind == 1u && uniforms.merge_step == ${numLocalStages - 1}u)) {
312
+ let wg_base = wg_in_batch * ${workgroupSize * 2}u;
313
+
314
+ // Load data into shared memory (2 elements per thread)
315
+ let idx0 = tid * 2u;
316
+ let idx1 = tid * 2u + 1u;
317
+ // Load from input for initial 'sort' pass, then from output (read-write) for 'merge' passes.
318
+ if (uniforms.kind == 0u) {
319
+ shared_vals[idx0] = select(${padValue}, input[base + wg_base + idx0], wg_base + idx0 < ${n}u);
320
+ shared_vals[idx1] = select(${padValue}, input[base + wg_base + idx1], wg_base + idx1 < ${n}u);
321
+ ${outputIndices ? `
322
+ shared_idx[idx0] = i32(wg_base + idx0);
323
+ shared_idx[idx1] = i32(wg_base + idx1);` : ""}
324
+ } else {
325
+ shared_vals[idx0] = select(${padValue}, output[base + wg_base + idx0], wg_base + idx0 < ${n}u);
326
+ shared_vals[idx1] = select(${padValue}, output[base + wg_base + idx1], wg_base + idx1 < ${n}u);
327
+ ${outputIndices ? `
328
+ shared_idx[idx0] = select(${n}, output_idx[base + wg_base + idx0], wg_base + idx0 < ${n}u);
329
+ shared_idx[idx1] = select(${n}, output_idx[base + wg_base + idx1], wg_base + idx1 < ${n}u);` : ""}
330
+ }
331
+ workgroupBarrier();
332
+
333
+ let initial_stage = select(0u, ${numLocalStages - 1}u, uniforms.kind != 0u);
334
+ for (var stage = initial_stage; stage < ${numLocalStages}u; stage++) {
335
+ for (var step1 = stage + 1u; step1 > 0u; step1--) {
336
+ let step = step1 - 1u;
337
+ let half_block = 1u << step;
338
+ let is_first_step = uniforms.kind == 0u && step == stage;
339
+
340
+ let block_offset = (tid / half_block) * half_block;
341
+ let local_offset = tid % half_block;
342
+ let i = block_offset * 2u + local_offset;
343
+ let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);
344
+ compare_and_swap(i, j);
345
+
346
+ workgroupBarrier();
347
+ }
348
+ }
349
+
350
+ if (wg_base + idx0 < ${n}u) {
351
+ output[base + wg_base + idx0] = shared_vals[idx0];
352
+ ${outputIndices ? `output_idx[base + wg_base + idx0] = shared_idx[idx0];` : ""}
353
+ }
354
+ if (wg_base + idx1 < ${n}u) {
355
+ output[base + wg_base + idx1] = shared_vals[idx1];
356
+ ${outputIndices ? `output_idx[base + wg_base + idx1] = shared_idx[idx1];` : ""}
357
+ }
358
+ } else {
359
+ // Execute single merge pass for a step >= numLocalStages.
360
+ let half_block = 1u << uniforms.merge_step; // half_block >= workgroupSize * 2
361
+ let thread_in_batch = wg_in_batch * ${workgroupSize} + tid;
362
+ let is_first_step = uniforms.merge_step == uniforms.merge_stage;
363
+
364
+ let block_offset = (thread_in_batch / half_block) * half_block;
365
+ let local_offset = thread_in_batch % half_block;
366
+ let i = block_offset * 2u + local_offset;
367
+ let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);
368
+
369
+ // Global version of compare_and_swap()
370
+ if (j < ${n}u) {
371
+ let val_i = output[base + i];
372
+ let val_j = output[base + j];
373
+ if (compare(val_j, val_i)) {
374
+ output[base + i] = val_j;
375
+ output[base + j] = val_i;
376
+ ${outputIndices ? `
377
+ let tmp_idx = output_idx[base + i];
378
+ output_idx[base + i] = output_idx[base + j];
379
+ output_idx[base + j] = tmp_idx;` : ""}
380
+ }
381
+ }
382
+ }
383
+ }
384
+ `.trim();
385
+ const grid = calculateGrid(batches * workgroupsPerBatch);
386
+ const passes = [{ kind: "sort" }];
387
+ for (let mergeStage = numLocalStages; mergeStage < numStages; mergeStage++) for (let mergeStep = mergeStage; mergeStep >= numLocalStages - 1; mergeStep--) passes.push({
388
+ kind: "merge",
389
+ mergeStep,
390
+ mergeStage
391
+ });
392
+ return [{
393
+ code,
394
+ numInputs: 1,
395
+ numOutputs: outputIndices ? 2 : 1,
396
+ hasUniform: true,
397
+ passes: passes.map((pass) => ({
398
+ grid,
399
+ uniform: bitonicSortUniform(pass)
400
+ }))
401
+ }];
402
+ }
403
+ function createSort(device, type) {
404
+ const dtype = type.inputDtypes[0];
405
+ const shape = type.inputShapes[0];
406
+ const n = shape[shape.length - 1];
407
+ const batches = prod(shape.slice(0, -1));
408
+ return bitonicSortShader(device, dtype, n, batches, false);
409
+ }
410
+ function createArgsort(device, type) {
411
+ const dtype = type.inputDtypes[0];
412
+ const shape = type.inputShapes[0];
413
+ const n = shape[shape.length - 1];
414
+ const batches = prod(shape.slice(0, -1));
415
+ return bitonicSortShader(device, dtype, n, batches, true);
416
+ }
417
+ function createRoutineShader(device, routine) {
418
+ switch (routine.name) {
419
+ case Routines.Sort: return createSort(device, routine.type);
420
+ case Routines.Argsort: return createArgsort(device, routine.type);
421
+ default: throw new UnsupportedRoutineError(routine.name, "webgpu");
422
+ }
423
+ }
424
+
173
425
  //#endregion
174
426
  //#region src/backend/webgpu.ts
175
427
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -181,6 +433,7 @@ var WebGPUBackend = class {
181
433
  buffers;
182
434
  nextSlot;
183
435
  #cachedShaderMap = /* @__PURE__ */ new Map();
436
+ #reusableZsb;
184
437
  constructor(device) {
185
438
  this.device = device;
186
439
  if (DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
@@ -189,11 +442,16 @@ var WebGPUBackend = class {
189
442
  this.syncReader = new SyncReader(device);
190
443
  this.buffers = /* @__PURE__ */ new Map();
191
444
  this.nextSlot = 1;
445
+ this.#reusableZsb = this.#createBuffer(4);
446
+ device.addEventListener("uncapturederror", (event) => {
447
+ console.error("Uncaptured error in WebGPU backend:", event.error.message);
448
+ });
192
449
  }
193
450
  malloc(size, initialData) {
194
451
  let buffer;
195
452
  const paddedSize = Math.ceil(size / 4) * 4;
196
- if (initialData) {
453
+ if (size === 0) buffer = this.#reusableZsb;
454
+ else if (initialData) {
197
455
  if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
198
456
  if (initialData.byteLength < 4096) {
199
457
  buffer = this.#createBuffer(paddedSize, { mapped: true });
@@ -230,11 +488,12 @@ var WebGPUBackend = class {
230
488
  buffer.ref--;
231
489
  if (buffer.ref === 0) {
232
490
  this.buffers.delete(slot);
233
- buffer.buffer.destroy();
491
+ if (buffer.buffer !== this.#reusableZsb) buffer.buffer.destroy();
234
492
  }
235
493
  }
236
494
  async read(slot, start, count) {
237
495
  const { buffer, size } = this.#getBuffer(slot);
496
+ if (buffer === this.#reusableZsb) return new Uint8Array();
238
497
  if (start === void 0) start = 0;
239
498
  if (count === void 0) count = size - start;
240
499
  const paddedSize = Math.ceil(count / 4) * 4;
@@ -252,6 +511,7 @@ var WebGPUBackend = class {
252
511
  }
253
512
  readSync(slot, start, count) {
254
513
  const { buffer, size } = this.#getBuffer(slot);
514
+ if (buffer === this.#reusableZsb) return new Uint8Array();
255
515
  if (start === void 0) start = 0;
256
516
  if (count === void 0) count = size - start;
257
517
  return this.syncReader.read(buffer, start, count);
@@ -265,23 +525,43 @@ var WebGPUBackend = class {
265
525
  }
266
526
  return result;
267
527
  }
268
- async prepare(kernel) {
269
- const { shader, grid } = this.#cachedShader(kernel);
528
+ async prepareKernel(kernel) {
529
+ const shader = this.#cachedShader(kernel);
270
530
  const pipeline = await this.pipelines.prepare(shader);
271
- return new Executable(kernel, {
272
- shader,
273
- grid,
531
+ return new Executable(kernel, [{
532
+ ...shader,
274
533
  pipeline
275
- });
534
+ }]);
276
535
  }
277
- prepareSync(kernel) {
278
- const { shader, grid } = this.#cachedShader(kernel);
536
+ prepareKernelSync(kernel) {
537
+ const shader = this.#cachedShader(kernel);
279
538
  const pipeline = this.pipelines.prepareSync(shader);
280
- return new Executable(kernel, {
281
- shader,
282
- grid,
539
+ return new Executable(kernel, [{
540
+ ...shader,
283
541
  pipeline
542
+ }]);
543
+ }
544
+ async prepareRoutine(routine) {
545
+ const shaders = createRoutineShader(this.device, routine);
546
+ const dispatches = await Promise.all(shaders.map(async (shader) => {
547
+ const pipeline = await this.pipelines.prepare(shader);
548
+ return {
549
+ ...shader,
550
+ pipeline
551
+ };
552
+ }));
553
+ return new Executable(routine, dispatches);
554
+ }
555
+ prepareRoutineSync(routine) {
556
+ const shaders = createRoutineShader(this.device, routine);
557
+ const dispatches = shaders.map((shader) => {
558
+ const pipeline = this.pipelines.prepareSync(shader);
559
+ return {
560
+ ...shader,
561
+ pipeline
562
+ };
284
563
  });
564
+ return new Executable(routine, dispatches);
285
565
  }
286
566
  dispatch(exe, inputs, outputs) {
287
567
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
@@ -316,32 +596,6 @@ var WebGPUBackend = class {
316
596
  return buffer;
317
597
  }
318
598
  };
319
- function dtypeToWgsl(dtype, storage = false) {
320
- switch (dtype) {
321
- case DType.Bool: return storage ? "i32" : "bool";
322
- case DType.Int32: return "i32";
323
- case DType.Uint32: return "u32";
324
- case DType.Float32: return "f32";
325
- case DType.Float16: return "f16";
326
- default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
327
- }
328
- }
329
- function constToWgsl(dtype, value) {
330
- if (dtype === DType.Bool) return value ? "true" : "false";
331
- if (dtype === DType.Int32) return value.toString();
332
- if (dtype === DType.Uint32) return value.toString() + "u";
333
- if (dtype === DType.Float32) {
334
- if (Number.isNaN(value)) return "nan()";
335
- if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
336
- return "f32(" + value.toString() + ")";
337
- }
338
- if (dtype === DType.Float16) {
339
- if (Number.isNaN(value)) return "f16(nan())";
340
- if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
341
- return "f16(" + value.toString() + ")";
342
- }
343
- throw new Error(`Unsupported const dtype: ${dtype}`);
344
- }
345
599
  /**
346
600
  * Compiles an expression into WebGPU shader source code.
347
601
  *
@@ -362,12 +616,12 @@ function pipelineSource(device, kernel) {
362
616
  else if (line === popIndent) indent = indent.slice(0, -2);
363
617
  else shader.push(line ? indent + line : line);
364
618
  };
365
- if (tune.exp.some((exp) => exp.dtype === DType.Float16) || re?.epilogue.some((exp) => exp.dtype === DType.Float16)) {
619
+ if (tune.exp.some((exp) => exp.dtype === DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === DType.Float16)) {
366
620
  if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
367
621
  emit("enable f16;");
368
622
  }
369
- emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
370
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
623
+ emit(headerWgsl);
624
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
371
625
  if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
372
626
  if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
373
627
  emit("");
@@ -375,6 +629,9 @@ function pipelineSource(device, kernel) {
375
629
  tune.exp.fold((exp) => {
376
630
  if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
377
631
  });
632
+ tune.epilogue?.fold((exp) => {
633
+ if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
634
+ });
378
635
  for (let i = 0; i < nargs; i++) {
379
636
  const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
380
637
  emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
@@ -383,12 +640,7 @@ function pipelineSource(device, kernel) {
383
640
  emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
384
641
  const workgroupSize = findPow2(tune.threadCount, 256);
385
642
  const gridSize = Math.ceil(tune.threadCount / workgroupSize);
386
- let gridX = gridSize;
387
- let gridY = 1;
388
- if (gridSize > device.limits.maxComputeWorkgroupsPerDimension) {
389
- gridX = 16384;
390
- gridY = Math.ceil(gridSize / gridX);
391
- }
643
+ const [gridX, gridY] = calculateGrid(gridSize);
392
644
  emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", pushIndent);
393
645
  if (gridY === 1) emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
394
646
  else {
@@ -398,7 +650,7 @@ function pipelineSource(device, kernel) {
398
650
  let gensymCount = 0;
399
651
  const gensym = () => `alu${gensymCount++}`;
400
652
  const isGensym = (text) => text.match(/^alu[0-9]+$/);
401
- for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
653
+ if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
402
654
  const references = /* @__PURE__ */ new Map();
403
655
  const seen = /* @__PURE__ */ new Set();
404
656
  const countReferences = (exp) => {
@@ -511,13 +763,15 @@ function pipelineSource(device, kernel) {
511
763
  let rhs = items[i][0];
512
764
  for (let j = 1; j < unroll; j++) if (re.op === AluOp.Add) rhs = `${rhs} + ${items[i][j]}`;
513
765
  else if (re.op === AluOp.Mul) rhs = `${rhs} * ${items[i][j]}`;
514
- else if (re.op === AluOp.Min) rhs = `min(${rhs}, ${items[i][j]})`;
515
- else if (re.op === AluOp.Max) rhs = `max(${rhs}, ${items[i][j]})`;
766
+ else if (re.op === AluOp.Min) rhs = re.dtype === DType.Bool ? `(${rhs} && ${items[i][j]})` : `min(${rhs}, ${items[i][j]})`;
767
+ else if (re.op === AluOp.Max) rhs = re.dtype === DType.Bool ? `(${rhs} || ${items[i][j]})` : `max(${rhs}, ${items[i][j]})`;
516
768
  else throw new Error(`Unsupported reduction op: ${re.op}`);
517
769
  if (re.op === AluOp.Add) emit(`${acc[i]} += ${rhs};`);
518
770
  else if (re.op === AluOp.Mul) emit(`${acc[i]} *= ${rhs};`);
519
- else if (re.op === AluOp.Min) emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
520
- else if (re.op === AluOp.Max) emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
771
+ else if (re.op === AluOp.Min) if (re.dtype === DType.Bool) emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
772
+ else emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
773
+ else if (re.op === AluOp.Max) if (re.dtype === DType.Bool) emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
774
+ else emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
521
775
  else throw new Error(`Unsupported reduction op: ${re.op}`);
522
776
  }
523
777
  emit(popIndent, "}");
@@ -530,7 +784,10 @@ function pipelineSource(device, kernel) {
530
784
  const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
531
785
  outputIdxExps.push(exp.simplify(cache));
532
786
  countReferences(outputIdxExps[i]);
533
- fusionExps.push(re.epilogue.substitute({ acc: AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
787
+ fusionExps.push(tune.epilogue.substitute({
788
+ acc: AluExp.variable(re.dtype, acc[i]),
789
+ upcast: AluExp.i32(i)
790
+ }).simplify(cache));
534
791
  countReferences(fusionExps[i]);
535
792
  }
536
793
  for (let i = 0; i < upcast; i++) {
@@ -542,36 +799,72 @@ function pipelineSource(device, kernel) {
542
799
  }
543
800
  emit(popIndent, "}");
544
801
  return {
545
- shader: shader.join("\n"),
546
- grid: [gridX, gridY]
802
+ code: shader.join("\n"),
803
+ numInputs: nargs,
804
+ numOutputs: 1,
805
+ hasUniform: false,
806
+ passes: [{ grid: [gridX, gridY] }]
547
807
  };
548
808
  }
549
- function pipelineSubmit(device, { pipeline, grid }, inputs, outputs) {
550
- if (inputs.length + outputs.length > device.limits.maxStorageBuffersPerShaderStage) {
551
- const actual = inputs.length + outputs.length;
552
- const max = device.limits.maxStorageBuffersPerShaderStage;
553
- throw new Error(`Too many buffers (${actual}) for WebGPU pipeline (max: ${max})`);
554
- }
555
- const bindGroup = device.createBindGroup({
556
- layout: pipeline.getBindGroupLayout(0),
557
- entries: [...inputs.map((buffer, i) => {
558
- return {
809
+ function pipelineSubmit(device, pipelines, inputs, outputs) {
810
+ const commandEncoder = device.createCommandEncoder();
811
+ for (const { pipeline,...shader } of pipelines) {
812
+ if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
813
+ const filteredPasses = shader.passes.filter(({ grid }) => prod(grid) > 0);
814
+ if (filteredPasses.length === 0) continue;
815
+ const bindGroup = device.createBindGroup({
816
+ layout: pipeline.getBindGroupLayout(0),
817
+ entries: [...inputs.map((buffer, i) => ({
559
818
  binding: i,
560
819
  resource: { buffer }
561
- };
562
- }), {
563
- binding: inputs.length,
564
- resource: { buffer: outputs[0] }
565
- }]
566
- });
567
- const commandEncoder = device.createCommandEncoder();
568
- const passEncoder = commandEncoder.beginComputePass();
569
- passEncoder.setPipeline(pipeline);
570
- passEncoder.setBindGroup(0, bindGroup);
571
- passEncoder.dispatchWorkgroups(grid[0], grid[1]);
572
- passEncoder.end();
820
+ })), ...outputs.map((buffer, i) => ({
821
+ binding: inputs.length + i,
822
+ resource: { buffer }
823
+ }))]
824
+ });
825
+ let uniformBindGroup = null;
826
+ let uniformAlignment = 0;
827
+ if (shader.hasUniform) {
828
+ const uniforms = filteredPasses.map(({ uniform }) => uniform);
829
+ const [uniformBuffer, alignment] = combineUniforms(device, uniforms);
830
+ uniformAlignment = alignment;
831
+ uniformBindGroup = device.createBindGroup({
832
+ layout: pipeline.getBindGroupLayout(1),
833
+ entries: [{
834
+ binding: 0,
835
+ resource: {
836
+ buffer: uniformBuffer,
837
+ size: alignment
838
+ }
839
+ }]
840
+ });
841
+ }
842
+ for (let i = 0; i < filteredPasses.length; i++) {
843
+ const { grid } = filteredPasses[i];
844
+ const passEncoder = commandEncoder.beginComputePass();
845
+ passEncoder.setPipeline(pipeline);
846
+ passEncoder.setBindGroup(0, bindGroup);
847
+ if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
848
+ passEncoder.dispatchWorkgroups(grid[0], grid[1]);
849
+ passEncoder.end();
850
+ }
851
+ }
573
852
  device.queue.submit([commandEncoder.finish()]);
574
853
  }
854
+ function combineUniforms(device, uniforms) {
855
+ for (const buf of uniforms) if (!buf || buf.byteLength === 0 || buf.byteLength !== uniforms[0].byteLength) throw new Error("webgpu: Uniform mismatch between shader passes");
856
+ const minAlign = device.limits.minUniformBufferOffsetAlignment;
857
+ const alignment = Math.ceil(uniforms[0].byteLength / minAlign) * minAlign;
858
+ const buffer = device.createBuffer({
859
+ size: alignment * uniforms.length,
860
+ usage: GPUBufferUsage.UNIFORM,
861
+ mappedAtCreation: true
862
+ });
863
+ const bufferMapped = new Uint8Array(buffer.getMappedRange());
864
+ for (let i = 0; i < uniforms.length; i++) bufferMapped.set(uniforms[i], i * alignment);
865
+ buffer.unmap();
866
+ return [buffer, alignment];
867
+ }
575
868
  /**
576
869
  * A cache for compiled GPU compute pipelines, keyed by the shader source.
577
870
  *
@@ -588,18 +881,39 @@ var ShaderPipelineCache = class {
588
881
  this.cache = /* @__PURE__ */ new Map();
589
882
  this.inProgress = /* @__PURE__ */ new Map();
590
883
  }
591
- async prepare(code) {
592
- const existingPipeline = this.cache.get(code);
884
+ #getLayout(shader) {
885
+ if (shader.numInputs + shader.numOutputs > this.device.limits.maxStorageBuffersPerShaderStage) {
886
+ const actual = shader.numInputs + shader.numOutputs;
887
+ const max = this.device.limits.maxStorageBuffersPerShaderStage;
888
+ throw new Error(`Too many buffers (${actual}) for WebGPU pipeline (max: ${max})`);
889
+ }
890
+ const bindGroupLayouts = [this.device.createBindGroupLayout({ entries: range(shader.numInputs + shader.numOutputs).map((i) => ({
891
+ binding: i,
892
+ visibility: GPUShaderStage.COMPUTE,
893
+ buffer: { type: i < shader.numInputs ? "read-only-storage" : "storage" }
894
+ })) })];
895
+ if (shader.hasUniform) bindGroupLayouts.push(this.device.createBindGroupLayout({ entries: [{
896
+ binding: 0,
897
+ visibility: GPUShaderStage.COMPUTE,
898
+ buffer: {
899
+ type: "uniform",
900
+ hasDynamicOffset: true
901
+ }
902
+ }] }));
903
+ return this.device.createPipelineLayout({ bindGroupLayouts });
904
+ }
905
+ async prepare(shader) {
906
+ const existingPipeline = this.cache.get(shader.code);
593
907
  if (existingPipeline) return existingPipeline;
594
- const existingPromise = this.inProgress.get(code);
908
+ const existingPromise = this.inProgress.get(shader.code);
595
909
  if (existingPromise) return await existingPromise;
596
- if (DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
597
- const shaderModule = this.device.createShaderModule({ code });
910
+ if (DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
911
+ const shaderModule = this.device.createShaderModule({ code: shader.code });
598
912
  const promise = (async () => {
599
913
  this.device.pushErrorScope("validation");
600
914
  try {
601
915
  const pipeline$1 = await this.device.createComputePipelineAsync({
602
- layout: "auto",
916
+ layout: this.#getLayout(shader),
603
917
  compute: {
604
918
  module: shaderModule,
605
919
  entryPoint: "main"
@@ -609,23 +923,23 @@ var ShaderPipelineCache = class {
609
923
  return pipeline$1;
610
924
  } catch (_error) {
611
925
  const scope = await this.device.popErrorScope();
612
- const emsg = await compileError(shaderModule, scope, code);
926
+ const emsg = await compileError(shaderModule, scope, shader.code);
613
927
  throw new Error(emsg);
614
928
  }
615
929
  })();
616
- this.inProgress.set(code, promise);
930
+ this.inProgress.set(shader.code, promise);
617
931
  const pipeline = await promise;
618
- this.cache.set(code, pipeline);
932
+ this.cache.set(shader.code, pipeline);
619
933
  return pipeline;
620
934
  }
621
- prepareSync(code) {
622
- const existingPipeline = this.cache.get(code);
935
+ prepareSync(shader) {
936
+ const existingPipeline = this.cache.get(shader.code);
623
937
  if (existingPipeline) return existingPipeline;
624
- if (DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
625
- const shaderModule = this.device.createShaderModule({ code });
938
+ if (DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
939
+ const shaderModule = this.device.createShaderModule({ code: shader.code });
626
940
  this.device.pushErrorScope("validation");
627
941
  const pipeline = this.device.createComputePipeline({
628
- layout: "auto",
942
+ layout: this.#getLayout(shader),
629
943
  compute: {
630
944
  module: shaderModule,
631
945
  entryPoint: "main"
@@ -633,11 +947,11 @@ var ShaderPipelineCache = class {
633
947
  });
634
948
  this.device.popErrorScope().then(async (scope) => {
635
949
  if (scope !== null) {
636
- const emsg = await compileError(shaderModule, scope, code);
950
+ const emsg = await compileError(shaderModule, scope, shader.code);
637
951
  console.error(emsg);
638
952
  }
639
953
  });
640
- this.cache.set(code, pipeline);
954
+ this.cache.set(shader.code, pipeline);
641
955
  return pipeline;
642
956
  }
643
957
  };
@@ -651,5 +965,4 @@ async function compileError(shaderModule, scope, code) {
651
965
  }
652
966
 
653
967
  //#endregion
654
- export { WebGPUBackend };
655
- //# sourceMappingURL=webgpu-BGuG58KZ.js.map
968
+ export { WebGPUBackend };