@jax-js/jax 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-CmaidnkQ.cjs');
1
+ const require_backend = require('./backend-Bu9GY6sK.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -66,6 +66,59 @@ fn erfc(x: f32) -> f32 {
66
66
  return select(2.0 - E, E, x >= 0.0);
67
67
  }`;
68
68
 
69
+ //#endregion
70
+ //#region src/backend/webgpu/codegen.ts
71
+ const headerWgsl = String.raw`
72
+ fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }
73
+ fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }
74
+ `.trim();
75
+ function dtypeToWgsl(dtype, storage = false) {
76
+ switch (dtype) {
77
+ case require_backend.DType.Bool: return storage ? "i32" : "bool";
78
+ case require_backend.DType.Int32: return "i32";
79
+ case require_backend.DType.Uint32: return "u32";
80
+ case require_backend.DType.Float32: return "f32";
81
+ case require_backend.DType.Float16: return "f16";
82
+ default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
83
+ }
84
+ }
85
+ function maxValueWgsl(dtype) {
86
+ switch (dtype) {
87
+ case require_backend.DType.Bool: return "1";
88
+ case require_backend.DType.Int32: return "2147483647";
89
+ case require_backend.DType.Uint32: return "4294967295u";
90
+ case require_backend.DType.Float32: return "inf()";
91
+ case require_backend.DType.Float16: return "f16(inf())";
92
+ default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
93
+ }
94
+ }
95
+ function constToWgsl(dtype, value) {
96
+ if (dtype === require_backend.DType.Bool) return value ? "true" : "false";
97
+ if (dtype === require_backend.DType.Int32) return value.toString();
98
+ if (dtype === require_backend.DType.Uint32) return value.toString() + "u";
99
+ if (dtype === require_backend.DType.Float32) {
100
+ if (Number.isNaN(value)) return "nan()";
101
+ if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
102
+ return "f32(" + value.toString() + ")";
103
+ }
104
+ if (dtype === require_backend.DType.Float16) {
105
+ if (Number.isNaN(value)) return "f16(nan())";
106
+ if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
107
+ return "f16(" + value.toString() + ")";
108
+ }
109
+ throw new Error(`Unsupported const dtype: ${dtype}`);
110
+ }
111
+ const gridOffsetY = 16384;
112
+ function calculateGrid(gridSize) {
113
+ let gridX = gridSize;
114
+ let gridY = 1;
115
+ if (gridSize > 65535) {
116
+ gridX = gridOffsetY;
117
+ gridY = Math.ceil(gridSize / gridOffsetY);
118
+ }
119
+ return [gridX, gridY];
120
+ }
121
+
69
122
  //#endregion
70
123
  //#region src/backend/webgpu/reader.ts
71
124
  /**
@@ -170,6 +223,205 @@ var SyncReader = class SyncReader {
170
223
  }
171
224
  };
172
225
 
226
+ //#endregion
227
+ //#region src/backend/webgpu/routines.ts
228
+ function bitonicSortUniform(pass) {
229
+ const ar = new Uint32Array(3);
230
+ ar[0] = pass.kind === "sort" ? 0 : 1;
231
+ ar[1] = pass.mergeStep ?? 0;
232
+ ar[2] = pass.mergeStage ?? 0;
233
+ return new Uint8Array(ar.buffer);
234
+ }
235
+ /**
236
+ * Generate a bitonic sort shader.
237
+ *
238
+ * We implement a variant of bitonic sort that [only has forward comparators](
239
+ * <https://sortingalgos.miraheze.org/wiki/Bitonic_Sort#Bitonic_Sort_using_Forward_Comparators>),
240
+ * so we don't need to allocate memory for power-of-two padding.
241
+ *
242
+ * This uses workgroup shared memory up to `2*workgroupSize` elements, for each
243
+ * array in `batches`. For larger arrays, multiple passes are done:
244
+ *
245
+ * - Initial "sort" pass: each workgroup sorts its `2*workgroupSize` elements.
246
+ * - Subsequent "merge" passes: each pass merges sorted sequences of size
247
+ * `2^(step+1)` with multiple workgroups. This doesn't use shared memory.
248
+ *
249
+ * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
250
+ */
251
+ function bitonicSortShader(device, dtype, n, batches, outputIndices) {
252
+ const ty = dtypeToWgsl(dtype, true);
253
+ const paddedN = 1 << Math.ceil(Math.log2(n || 1));
254
+ const numThreads = Math.ceil(paddedN / 2);
255
+ const workgroupSize = require_backend.findPow2(numThreads, device.limits.maxComputeWorkgroupSizeX);
256
+ const workgroupsPerBatch = numThreads / workgroupSize;
257
+ const numStages = Math.log2(paddedN);
258
+ const numLocalStages = Math.min(numStages, Math.log2(workgroupSize * 2));
259
+ const needsF16 = dtype === require_backend.DType.Float16;
260
+ const padValue = require_backend.isFloatDtype(dtype) ? `${ty}(nan())` : maxValueWgsl(dtype);
261
+ const code = `
262
+ ${needsF16 ? "enable f16;" : ""}
263
+ ${headerWgsl}
264
+
265
+ struct Uniforms {
266
+ kind: u32, // 0 = sort, 1 = merge
267
+ merge_step: u32, // half_block = 2^step
268
+ merge_stage: u32, // only used for merge
269
+ }
270
+
271
+ @group(0) @binding(0) var<storage, read> input: array<${ty}>;
272
+ @group(0) @binding(1) var<storage, read_write> output: array<${ty}>;
273
+ ${outputIndices ? `@group(0) @binding(2) var<storage, read_write> output_idx: array<i32>;` : ""}
274
+
275
+ @group(1) @binding(0) var<uniform> uniforms: Uniforms;
276
+
277
+ var<workgroup> shared_vals: array<${ty}, ${workgroupSize * 2}>;
278
+ ${outputIndices ? `var<workgroup> shared_idx: array<i32, ${workgroupSize * 2}>;` : ""}
279
+
280
+ fn compare(a: ${ty}, b: ${ty}) -> bool {
281
+ ${require_backend.isFloatDtype(dtype) ? `
282
+ let min_value = min(a, b);
283
+ return a == min_value && b != min_value;` : " return a < b;"}
284
+ }
285
+
286
+ fn compare_and_swap(i: u32, j: u32) {
287
+ let val_i = shared_vals[i];
288
+ let val_j = shared_vals[j];
289
+ if (compare(val_j, val_i)) {
290
+ shared_vals[i] = val_j;
291
+ shared_vals[j] = val_i;
292
+ ${outputIndices ? `
293
+ let tmp_idx = shared_idx[i];
294
+ shared_idx[i] = shared_idx[j];
295
+ shared_idx[j] = tmp_idx;` : ""}
296
+ }
297
+ }
298
+
299
+ @compute @workgroup_size(${workgroupSize})
300
+ fn main(
301
+ @builtin(workgroup_id) wg_id: vec3<u32>,
302
+ @builtin(local_invocation_id) local_id: vec3<u32>,
303
+ ) {
304
+ let blockid = wg_id.x + wg_id.y * ${gridOffsetY}u;
305
+ let batch = blockid / ${workgroupsPerBatch}u;
306
+ let wg_in_batch = blockid % ${workgroupsPerBatch}u;
307
+
308
+ let tid = local_id.x;
309
+ let base = batch * ${n}u;
310
+
311
+ if (uniforms.kind == 0u || (uniforms.kind == 1u && uniforms.merge_step == ${numLocalStages - 1}u)) {
312
+ let wg_base = wg_in_batch * ${workgroupSize * 2}u;
313
+
314
+ // Load data into shared memory (2 elements per thread)
315
+ let idx0 = tid * 2u;
316
+ let idx1 = tid * 2u + 1u;
317
+ // Load from input for initial 'sort' pass, then from output (read-write) for 'merge' passes.
318
+ if (uniforms.kind == 0u) {
319
+ shared_vals[idx0] = select(${padValue}, input[base + wg_base + idx0], wg_base + idx0 < ${n}u);
320
+ shared_vals[idx1] = select(${padValue}, input[base + wg_base + idx1], wg_base + idx1 < ${n}u);
321
+ ${outputIndices ? `
322
+ shared_idx[idx0] = i32(wg_base + idx0);
323
+ shared_idx[idx1] = i32(wg_base + idx1);` : ""}
324
+ } else {
325
+ shared_vals[idx0] = select(${padValue}, output[base + wg_base + idx0], wg_base + idx0 < ${n}u);
326
+ shared_vals[idx1] = select(${padValue}, output[base + wg_base + idx1], wg_base + idx1 < ${n}u);
327
+ ${outputIndices ? `
328
+ shared_idx[idx0] = select(${n}, output_idx[base + wg_base + idx0], wg_base + idx0 < ${n}u);
329
+ shared_idx[idx1] = select(${n}, output_idx[base + wg_base + idx1], wg_base + idx1 < ${n}u);` : ""}
330
+ }
331
+ workgroupBarrier();
332
+
333
+ let initial_stage = select(0u, ${numLocalStages - 1}u, uniforms.kind != 0u);
334
+ for (var stage = initial_stage; stage < ${numLocalStages}u; stage++) {
335
+ for (var step1 = stage + 1u; step1 > 0u; step1--) {
336
+ let step = step1 - 1u;
337
+ let half_block = 1u << step;
338
+ let is_first_step = uniforms.kind == 0u && step == stage;
339
+
340
+ let block_offset = (tid / half_block) * half_block;
341
+ let local_offset = tid % half_block;
342
+ let i = block_offset * 2u + local_offset;
343
+ let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);
344
+ compare_and_swap(i, j);
345
+
346
+ workgroupBarrier();
347
+ }
348
+ }
349
+
350
+ if (wg_base + idx0 < ${n}u) {
351
+ output[base + wg_base + idx0] = shared_vals[idx0];
352
+ ${outputIndices ? `output_idx[base + wg_base + idx0] = shared_idx[idx0];` : ""}
353
+ }
354
+ if (wg_base + idx1 < ${n}u) {
355
+ output[base + wg_base + idx1] = shared_vals[idx1];
356
+ ${outputIndices ? `output_idx[base + wg_base + idx1] = shared_idx[idx1];` : ""}
357
+ }
358
+ } else {
359
+ // Execute single merge pass for a step >= numLocalStages.
360
+ let half_block = 1u << uniforms.merge_step; // half_block >= workgroupSize * 2
361
+ let thread_in_batch = wg_in_batch * ${workgroupSize} + tid;
362
+ let is_first_step = uniforms.merge_step == uniforms.merge_stage;
363
+
364
+ let block_offset = (thread_in_batch / half_block) * half_block;
365
+ let local_offset = thread_in_batch % half_block;
366
+ let i = block_offset * 2u + local_offset;
367
+ let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);
368
+
369
+ // Global version of compare_and_swap()
370
+ if (j < ${n}u) {
371
+ let val_i = output[base + i];
372
+ let val_j = output[base + j];
373
+ if (compare(val_j, val_i)) {
374
+ output[base + i] = val_j;
375
+ output[base + j] = val_i;
376
+ ${outputIndices ? `
377
+ let tmp_idx = output_idx[base + i];
378
+ output_idx[base + i] = output_idx[base + j];
379
+ output_idx[base + j] = tmp_idx;` : ""}
380
+ }
381
+ }
382
+ }
383
+ }
384
+ `.trim();
385
+ const grid = calculateGrid(batches * workgroupsPerBatch);
386
+ const passes = [{ kind: "sort" }];
387
+ for (let mergeStage = numLocalStages; mergeStage < numStages; mergeStage++) for (let mergeStep = mergeStage; mergeStep >= numLocalStages - 1; mergeStep--) passes.push({
388
+ kind: "merge",
389
+ mergeStep,
390
+ mergeStage
391
+ });
392
+ return [{
393
+ code,
394
+ numInputs: 1,
395
+ numOutputs: outputIndices ? 2 : 1,
396
+ hasUniform: true,
397
+ passes: passes.map((pass) => ({
398
+ grid,
399
+ uniform: bitonicSortUniform(pass)
400
+ }))
401
+ }];
402
+ }
403
+ function createSort(device, type) {
404
+ const dtype = type.inputDtypes[0];
405
+ const shape = type.inputShapes[0];
406
+ const n = shape[shape.length - 1];
407
+ const batches = require_backend.prod(shape.slice(0, -1));
408
+ return bitonicSortShader(device, dtype, n, batches, false);
409
+ }
410
+ function createArgsort(device, type) {
411
+ const dtype = type.inputDtypes[0];
412
+ const shape = type.inputShapes[0];
413
+ const n = shape[shape.length - 1];
414
+ const batches = require_backend.prod(shape.slice(0, -1));
415
+ return bitonicSortShader(device, dtype, n, batches, true);
416
+ }
417
+ function createRoutineShader(device, routine) {
418
+ switch (routine.name) {
419
+ case require_backend.Routines.Sort: return createSort(device, routine.type);
420
+ case require_backend.Routines.Argsort: return createArgsort(device, routine.type);
421
+ default: throw new require_backend.UnsupportedRoutineError(routine.name, "webgpu");
422
+ }
423
+ }
424
+
173
425
  //#endregion
174
426
  //#region src/backend/webgpu.ts
175
427
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -181,6 +433,7 @@ var WebGPUBackend = class {
181
433
  buffers;
182
434
  nextSlot;
183
435
  #cachedShaderMap = /* @__PURE__ */ new Map();
436
+ #reusableZsb;
184
437
  constructor(device) {
185
438
  this.device = device;
186
439
  if (require_backend.DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
@@ -189,11 +442,16 @@ var WebGPUBackend = class {
189
442
  this.syncReader = new SyncReader(device);
190
443
  this.buffers = /* @__PURE__ */ new Map();
191
444
  this.nextSlot = 1;
445
+ this.#reusableZsb = this.#createBuffer(4);
446
+ device.addEventListener("uncapturederror", (event) => {
447
+ console.error("Uncaptured error in WebGPU backend:", event.error.message);
448
+ });
192
449
  }
193
450
  malloc(size, initialData) {
194
451
  let buffer;
195
452
  const paddedSize = Math.ceil(size / 4) * 4;
196
- if (initialData) {
453
+ if (size === 0) buffer = this.#reusableZsb;
454
+ else if (initialData) {
197
455
  if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
198
456
  if (initialData.byteLength < 4096) {
199
457
  buffer = this.#createBuffer(paddedSize, { mapped: true });
@@ -230,11 +488,12 @@ var WebGPUBackend = class {
230
488
  buffer.ref--;
231
489
  if (buffer.ref === 0) {
232
490
  this.buffers.delete(slot);
233
- buffer.buffer.destroy();
491
+ if (buffer.buffer !== this.#reusableZsb) buffer.buffer.destroy();
234
492
  }
235
493
  }
236
494
  async read(slot, start, count) {
237
495
  const { buffer, size } = this.#getBuffer(slot);
496
+ if (buffer === this.#reusableZsb) return new Uint8Array();
238
497
  if (start === void 0) start = 0;
239
498
  if (count === void 0) count = size - start;
240
499
  const paddedSize = Math.ceil(count / 4) * 4;
@@ -252,6 +511,7 @@ var WebGPUBackend = class {
252
511
  }
253
512
  readSync(slot, start, count) {
254
513
  const { buffer, size } = this.#getBuffer(slot);
514
+ if (buffer === this.#reusableZsb) return new Uint8Array();
255
515
  if (start === void 0) start = 0;
256
516
  if (count === void 0) count = size - start;
257
517
  return this.syncReader.read(buffer, start, count);
@@ -265,27 +525,45 @@ var WebGPUBackend = class {
265
525
  }
266
526
  return result;
267
527
  }
268
- async prepare(kernel) {
269
- const { shader, grid } = this.#cachedShader(kernel);
528
+ async prepareKernel(kernel) {
529
+ const shader = this.#cachedShader(kernel);
270
530
  const pipeline = await this.pipelines.prepare(shader);
271
- return new require_backend.Executable(kernel, {
272
- shader,
273
- grid,
531
+ return new require_backend.Executable(kernel, [{
532
+ ...shader,
274
533
  pipeline
275
- });
534
+ }]);
276
535
  }
277
- prepareSync(kernel) {
278
- const { shader, grid } = this.#cachedShader(kernel);
536
+ prepareKernelSync(kernel) {
537
+ const shader = this.#cachedShader(kernel);
279
538
  const pipeline = this.pipelines.prepareSync(shader);
280
- return new require_backend.Executable(kernel, {
281
- shader,
282
- grid,
539
+ return new require_backend.Executable(kernel, [{
540
+ ...shader,
283
541
  pipeline
542
+ }]);
543
+ }
544
+ async prepareRoutine(routine) {
545
+ const shaders = createRoutineShader(this.device, routine);
546
+ const dispatches = await Promise.all(shaders.map(async (shader) => {
547
+ const pipeline = await this.pipelines.prepare(shader);
548
+ return {
549
+ ...shader,
550
+ pipeline
551
+ };
552
+ }));
553
+ return new require_backend.Executable(routine, dispatches);
554
+ }
555
+ prepareRoutineSync(routine) {
556
+ const shaders = createRoutineShader(this.device, routine);
557
+ const dispatches = shaders.map((shader) => {
558
+ const pipeline = this.pipelines.prepareSync(shader);
559
+ return {
560
+ ...shader,
561
+ pipeline
562
+ };
284
563
  });
564
+ return new require_backend.Executable(routine, dispatches);
285
565
  }
286
566
  dispatch(exe, inputs, outputs) {
287
- if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
288
- if (exe.kernel.size === 0) return;
289
567
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
290
568
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
291
569
  pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
@@ -318,32 +596,6 @@ var WebGPUBackend = class {
318
596
  return buffer;
319
597
  }
320
598
  };
321
- function dtypeToWgsl(dtype, storage = false) {
322
- switch (dtype) {
323
- case require_backend.DType.Bool: return storage ? "i32" : "bool";
324
- case require_backend.DType.Int32: return "i32";
325
- case require_backend.DType.Uint32: return "u32";
326
- case require_backend.DType.Float32: return "f32";
327
- case require_backend.DType.Float16: return "f16";
328
- default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
329
- }
330
- }
331
- function constToWgsl(dtype, value) {
332
- if (dtype === require_backend.DType.Bool) return value ? "true" : "false";
333
- if (dtype === require_backend.DType.Int32) return value.toString();
334
- if (dtype === require_backend.DType.Uint32) return value.toString() + "u";
335
- if (dtype === require_backend.DType.Float32) {
336
- if (Number.isNaN(value)) return "nan()";
337
- if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
338
- return "f32(" + value.toString() + ")";
339
- }
340
- if (dtype === require_backend.DType.Float16) {
341
- if (Number.isNaN(value)) return "f16(nan())";
342
- if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
343
- return "f16(" + value.toString() + ")";
344
- }
345
- throw new Error(`Unsupported const dtype: ${dtype}`);
346
- }
347
599
  /**
348
600
  * Compiles an expression into WebGPU shader source code.
349
601
  *
@@ -368,7 +620,7 @@ function pipelineSource(device, kernel) {
368
620
  if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
369
621
  emit("enable f16;");
370
622
  }
371
- emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
623
+ emit(headerWgsl);
372
624
  const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
373
625
  if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
374
626
  if (distinctOps.has(require_backend.AluOp.Erf) || distinctOps.has(require_backend.AluOp.Erfc)) emit(erfSrc);
@@ -388,12 +640,7 @@ function pipelineSource(device, kernel) {
388
640
  emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
389
641
  const workgroupSize = require_backend.findPow2(tune.threadCount, 256);
390
642
  const gridSize = Math.ceil(tune.threadCount / workgroupSize);
391
- let gridX = gridSize;
392
- let gridY = 1;
393
- if (gridSize > device.limits.maxComputeWorkgroupsPerDimension) {
394
- gridX = 16384;
395
- gridY = Math.ceil(gridSize / gridX);
396
- }
643
+ const [gridX, gridY] = calculateGrid(gridSize);
397
644
  emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", pushIndent);
398
645
  if (gridY === 1) emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
399
646
  else {
@@ -516,13 +763,15 @@ function pipelineSource(device, kernel) {
516
763
  let rhs = items[i][0];
517
764
  for (let j = 1; j < unroll; j++) if (re.op === require_backend.AluOp.Add) rhs = `${rhs} + ${items[i][j]}`;
518
765
  else if (re.op === require_backend.AluOp.Mul) rhs = `${rhs} * ${items[i][j]}`;
519
- else if (re.op === require_backend.AluOp.Min) rhs = `min(${rhs}, ${items[i][j]})`;
520
- else if (re.op === require_backend.AluOp.Max) rhs = `max(${rhs}, ${items[i][j]})`;
766
+ else if (re.op === require_backend.AluOp.Min) rhs = re.dtype === require_backend.DType.Bool ? `(${rhs} && ${items[i][j]})` : `min(${rhs}, ${items[i][j]})`;
767
+ else if (re.op === require_backend.AluOp.Max) rhs = re.dtype === require_backend.DType.Bool ? `(${rhs} || ${items[i][j]})` : `max(${rhs}, ${items[i][j]})`;
521
768
  else throw new Error(`Unsupported reduction op: ${re.op}`);
522
769
  if (re.op === require_backend.AluOp.Add) emit(`${acc[i]} += ${rhs};`);
523
770
  else if (re.op === require_backend.AluOp.Mul) emit(`${acc[i]} *= ${rhs};`);
524
- else if (re.op === require_backend.AluOp.Min) emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
525
- else if (re.op === require_backend.AluOp.Max) emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
771
+ else if (re.op === require_backend.AluOp.Min) if (re.dtype === require_backend.DType.Bool) emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
772
+ else emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
773
+ else if (re.op === require_backend.AluOp.Max) if (re.dtype === require_backend.DType.Bool) emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
774
+ else emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
526
775
  else throw new Error(`Unsupported reduction op: ${re.op}`);
527
776
  }
528
777
  emit(popIndent, "}");
@@ -550,36 +799,72 @@ function pipelineSource(device, kernel) {
550
799
  }
551
800
  emit(popIndent, "}");
552
801
  return {
553
- shader: shader.join("\n"),
554
- grid: [gridX, gridY]
802
+ code: shader.join("\n"),
803
+ numInputs: nargs,
804
+ numOutputs: 1,
805
+ hasUniform: false,
806
+ passes: [{ grid: [gridX, gridY] }]
555
807
  };
556
808
  }
557
- function pipelineSubmit(device, { pipeline, grid }, inputs, outputs) {
558
- if (inputs.length + outputs.length > device.limits.maxStorageBuffersPerShaderStage) {
559
- const actual = inputs.length + outputs.length;
560
- const max = device.limits.maxStorageBuffersPerShaderStage;
561
- throw new Error(`Too many buffers (${actual}) for WebGPU pipeline (max: ${max})`);
562
- }
563
- const bindGroup = device.createBindGroup({
564
- layout: pipeline.getBindGroupLayout(0),
565
- entries: [...inputs.map((buffer, i) => {
566
- return {
809
+ function pipelineSubmit(device, pipelines, inputs, outputs) {
810
+ const commandEncoder = device.createCommandEncoder();
811
+ for (const { pipeline,...shader } of pipelines) {
812
+ if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
813
+ const filteredPasses = shader.passes.filter(({ grid }) => require_backend.prod(grid) > 0);
814
+ if (filteredPasses.length === 0) continue;
815
+ const bindGroup = device.createBindGroup({
816
+ layout: pipeline.getBindGroupLayout(0),
817
+ entries: [...inputs.map((buffer, i) => ({
567
818
  binding: i,
568
819
  resource: { buffer }
569
- };
570
- }), {
571
- binding: inputs.length,
572
- resource: { buffer: outputs[0] }
573
- }]
574
- });
575
- const commandEncoder = device.createCommandEncoder();
576
- const passEncoder = commandEncoder.beginComputePass();
577
- passEncoder.setPipeline(pipeline);
578
- passEncoder.setBindGroup(0, bindGroup);
579
- passEncoder.dispatchWorkgroups(grid[0], grid[1]);
580
- passEncoder.end();
820
+ })), ...outputs.map((buffer, i) => ({
821
+ binding: inputs.length + i,
822
+ resource: { buffer }
823
+ }))]
824
+ });
825
+ let uniformBindGroup = null;
826
+ let uniformAlignment = 0;
827
+ if (shader.hasUniform) {
828
+ const uniforms = filteredPasses.map(({ uniform }) => uniform);
829
+ const [uniformBuffer, alignment] = combineUniforms(device, uniforms);
830
+ uniformAlignment = alignment;
831
+ uniformBindGroup = device.createBindGroup({
832
+ layout: pipeline.getBindGroupLayout(1),
833
+ entries: [{
834
+ binding: 0,
835
+ resource: {
836
+ buffer: uniformBuffer,
837
+ size: alignment
838
+ }
839
+ }]
840
+ });
841
+ }
842
+ for (let i = 0; i < filteredPasses.length; i++) {
843
+ const { grid } = filteredPasses[i];
844
+ const passEncoder = commandEncoder.beginComputePass();
845
+ passEncoder.setPipeline(pipeline);
846
+ passEncoder.setBindGroup(0, bindGroup);
847
+ if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
848
+ passEncoder.dispatchWorkgroups(grid[0], grid[1]);
849
+ passEncoder.end();
850
+ }
851
+ }
581
852
  device.queue.submit([commandEncoder.finish()]);
582
853
  }
854
+ function combineUniforms(device, uniforms) {
855
+ for (const buf of uniforms) if (!buf || buf.byteLength === 0 || buf.byteLength !== uniforms[0].byteLength) throw new Error("webgpu: Uniform mismatch between shader passes");
856
+ const minAlign = device.limits.minUniformBufferOffsetAlignment;
857
+ const alignment = Math.ceil(uniforms[0].byteLength / minAlign) * minAlign;
858
+ const buffer = device.createBuffer({
859
+ size: alignment * uniforms.length,
860
+ usage: GPUBufferUsage.UNIFORM,
861
+ mappedAtCreation: true
862
+ });
863
+ const bufferMapped = new Uint8Array(buffer.getMappedRange());
864
+ for (let i = 0; i < uniforms.length; i++) bufferMapped.set(uniforms[i], i * alignment);
865
+ buffer.unmap();
866
+ return [buffer, alignment];
867
+ }
583
868
  /**
584
869
  * A cache for compiled GPU compute pipelines, keyed by the shader source.
585
870
  *
@@ -596,18 +881,39 @@ var ShaderPipelineCache = class {
596
881
  this.cache = /* @__PURE__ */ new Map();
597
882
  this.inProgress = /* @__PURE__ */ new Map();
598
883
  }
599
- async prepare(code) {
600
- const existingPipeline = this.cache.get(code);
884
+ #getLayout(shader) {
885
+ if (shader.numInputs + shader.numOutputs > this.device.limits.maxStorageBuffersPerShaderStage) {
886
+ const actual = shader.numInputs + shader.numOutputs;
887
+ const max = this.device.limits.maxStorageBuffersPerShaderStage;
888
+ throw new Error(`Too many buffers (${actual}) for WebGPU pipeline (max: ${max})`);
889
+ }
890
+ const bindGroupLayouts = [this.device.createBindGroupLayout({ entries: require_backend.range(shader.numInputs + shader.numOutputs).map((i) => ({
891
+ binding: i,
892
+ visibility: GPUShaderStage.COMPUTE,
893
+ buffer: { type: i < shader.numInputs ? "read-only-storage" : "storage" }
894
+ })) })];
895
+ if (shader.hasUniform) bindGroupLayouts.push(this.device.createBindGroupLayout({ entries: [{
896
+ binding: 0,
897
+ visibility: GPUShaderStage.COMPUTE,
898
+ buffer: {
899
+ type: "uniform",
900
+ hasDynamicOffset: true
901
+ }
902
+ }] }));
903
+ return this.device.createPipelineLayout({ bindGroupLayouts });
904
+ }
905
+ async prepare(shader) {
906
+ const existingPipeline = this.cache.get(shader.code);
601
907
  if (existingPipeline) return existingPipeline;
602
- const existingPromise = this.inProgress.get(code);
908
+ const existingPromise = this.inProgress.get(shader.code);
603
909
  if (existingPromise) return await existingPromise;
604
- if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
605
- const shaderModule = this.device.createShaderModule({ code });
910
+ if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
911
+ const shaderModule = this.device.createShaderModule({ code: shader.code });
606
912
  const promise = (async () => {
607
913
  this.device.pushErrorScope("validation");
608
914
  try {
609
915
  const pipeline$1 = await this.device.createComputePipelineAsync({
610
- layout: "auto",
916
+ layout: this.#getLayout(shader),
611
917
  compute: {
612
918
  module: shaderModule,
613
919
  entryPoint: "main"
@@ -617,23 +923,23 @@ var ShaderPipelineCache = class {
617
923
  return pipeline$1;
618
924
  } catch (_error) {
619
925
  const scope = await this.device.popErrorScope();
620
- const emsg = await compileError(shaderModule, scope, code);
926
+ const emsg = await compileError(shaderModule, scope, shader.code);
621
927
  throw new Error(emsg);
622
928
  }
623
929
  })();
624
- this.inProgress.set(code, promise);
930
+ this.inProgress.set(shader.code, promise);
625
931
  const pipeline = await promise;
626
- this.cache.set(code, pipeline);
932
+ this.cache.set(shader.code, pipeline);
627
933
  return pipeline;
628
934
  }
629
- prepareSync(code) {
630
- const existingPipeline = this.cache.get(code);
935
+ prepareSync(shader) {
936
+ const existingPipeline = this.cache.get(shader.code);
631
937
  if (existingPipeline) return existingPipeline;
632
- if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
633
- const shaderModule = this.device.createShaderModule({ code });
938
+ if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
939
+ const shaderModule = this.device.createShaderModule({ code: shader.code });
634
940
  this.device.pushErrorScope("validation");
635
941
  const pipeline = this.device.createComputePipeline({
636
- layout: "auto",
942
+ layout: this.#getLayout(shader),
637
943
  compute: {
638
944
  module: shaderModule,
639
945
  entryPoint: "main"
@@ -641,11 +947,11 @@ var ShaderPipelineCache = class {
641
947
  });
642
948
  this.device.popErrorScope().then(async (scope) => {
643
949
  if (scope !== null) {
644
- const emsg = await compileError(shaderModule, scope, code);
950
+ const emsg = await compileError(shaderModule, scope, shader.code);
645
951
  console.error(emsg);
646
952
  }
647
953
  });
648
- this.cache.set(code, pipeline);
954
+ this.cache.set(shader.code, pipeline);
649
955
  return pipeline;
650
956
  }
651
957
  };
@@ -659,5 +965,4 @@ async function compileError(shaderModule, scope, code) {
659
965
  }
660
966
 
661
967
  //#endregion
662
- exports.WebGPUBackend = WebGPUBackend;
663
- //# sourceMappingURL=webgpu-BVns4DbI.cjs.map
968
+ exports.WebGPUBackend = WebGPUBackend;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.3",
3
+ "version": "0.1.4",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",