@jax-js/jax 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1261 @@
1
+ const require_backend = require('./backend-DziQSaoQ.cjs');
2
+
3
+ //#region src/backend/webgpu/builtins.ts
4
// WGSL source for `threefry2x32(key, ctr)`: mixes a two-word u32 key with a
// two-word u32 counter and returns two u32 words.
// Structure visible below: five groups of four add / rotate-left / xor rounds
// (20 rounds total). After each group a key word is added to x0 and a key word
// plus the group number (1..5) is added to x1. The third key word `ks2` is
// derived as ks0 ^ ks1 ^ 0x1BD11BDA.
+ const threefrySrc = `
5
+ fn threefry2x32(key: vec2<u32>, ctr: vec2<u32>) -> vec2<u32> {
6
+ let ks0: u32 = key.x;
7
+ let ks1: u32 = key.y;
8
+ let ks2: u32 = ks0 ^ ks1 ^ 0x1BD11BDAu;
9
+
10
+ var x0: u32 = ctr.x + ks0;
11
+ var x1: u32 = ctr.y + ks1;
12
+
13
+ x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
14
+ x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
15
+ x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
16
+ x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
17
+ x0 += ks1;
18
+ x1 += ks2 + 1u;
19
+
20
+ x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
21
+ x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
22
+ x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
23
+ x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
24
+ x0 += ks2;
25
+ x1 += ks0 + 2u;
26
+
27
+ x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
28
+ x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
29
+ x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
30
+ x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
31
+ x0 += ks0;
32
+ x1 += ks1 + 3u;
33
+
34
+ x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
35
+ x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
36
+ x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
37
+ x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
38
+ x0 += ks1;
39
+ x1 += ks2 + 4u;
40
+
41
+ x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
42
+ x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
43
+ x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
44
+ x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
45
+ x0 += ks2;
46
+ x1 += ks0 + 5u;
47
+
48
+ return vec2<u32>(x0, x1);
49
+ }`;
50
// WGSL source defining f32 `erf(x)` and `erfc(x)`.
// Both evaluate a degree-5 polynomial in t = 1 / (1 + p * |x|) (Horner form
// via nested fma) and multiply it by exp(-x * x).
//   erf(x)  = sign(x) * (1 - P(t) * exp(-x^2))
//   erfc(x) = E for x >= 0, and 2 - E for x < 0, where E = P(t) * exp(-x^2)
// NOTE(review): the constants look like the Abramowitz & Stegun 7.1.26
// coefficients (absolute error around 1.5e-7) — confirm against the reference.
+ const erfSrc = `
51
+ const _erf_p: f32 = 0.3275911;
52
+ const _erf_a1: f32 = 0.254829592;
53
+ const _erf_a2: f32 = -0.284496736;
54
+ const _erf_a3: f32 = 1.421413741;
55
+ const _erf_a4: f32 = -1.453152027;
56
+ const _erf_a5: f32 = 1.061405429;
57
+ fn erf(x: f32) -> f32 {
58
+ let t = 1.0 / (1.0 + _erf_p * abs(x));
59
+ let P_t = fma(fma(fma(fma(_erf_a5, t, _erf_a4), t, _erf_a3), t, _erf_a2), t, _erf_a1) * t;
60
+ return sign(x) * (1.0 - P_t * exp(-x * x));
61
+ }
62
+ fn erfc(x: f32) -> f32 {
63
+ let t = 1.0 / (1.0 + _erf_p * abs(x));
64
+ let P_t = fma(fma(fma(fma(_erf_a5, t, _erf_a4), t, _erf_a3), t, _erf_a2), t, _erf_a1) * t;
65
+ let E = P_t * exp(-x * x);
66
+ return select(2.0 - E, E, x >= 0.0);
67
+ }`;
68
+
69
+ //#endregion
70
+ //#region src/backend/webgpu/codegen.ts
71
/**
 * WGSL prelude shared by generated shaders: helper functions that produce an
 * f32 NaN and an f32 +infinity by bit-casting fixed u32 patterns.
 */
const headerWgsl = [
	"fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }",
	"fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }"
].join("\n");
75
/**
 * Translate a dtype into the name of the WGSL scalar type that represents it.
 *
 * @param dtype A member of `require_backend.DType`.
 * @param storage When true, return the type used for buffer storage; the only
 *   difference is that Bool is stored as `i32` instead of `bool`.
 * @returns The WGSL type name.
 * @throws Error if the dtype is not one of Bool, Int32, Uint32, Float32, Float16.
 */
function dtypeToWgsl(dtype, storage = false) {
	const { DType } = require_backend;
	if (dtype === DType.Bool) return storage ? "i32" : "bool";
	if (dtype === DType.Int32) return "i32";
	if (dtype === DType.Uint32) return "u32";
	if (dtype === DType.Float32) return "f32";
	if (dtype === DType.Float16) return "f16";
	throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
}
85
/**
 * Return a WGSL expression for the largest value of a dtype: `1` for Bool,
 * the integer maximum for Int32/Uint32, and positive infinity for floats.
 *
 * @param dtype A member of `require_backend.DType`.
 * @returns WGSL source text for the maximum value.
 * @throws Error if the dtype is not supported.
 */
function maxValueWgsl(dtype) {
	const { DType } = require_backend;
	const table = [
		[DType.Bool, "1"],
		[DType.Int32, "2147483647"],
		[DType.Uint32, "4294967295u"],
		[DType.Float32, "inf()"],
		[DType.Float16, "f16(inf())"]
	];
	const entry = table.find(([candidate]) => candidate === dtype);
	if (entry === void 0) throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
	return entry[1];
}
95
/**
 * Render a JavaScript constant as a WGSL literal/expression of the given dtype.
 *
 * - Bool becomes `true` / `false` based on truthiness.
 * - Int32 is printed as-is, Uint32 gets a `u` suffix.
 * - Floats are wrapped in a `f32(...)` / `f16(...)` conversion; NaN and the
 *   infinities are expressed through the `nan()` / `inf()` helper functions.
 *
 * @param dtype A member of `require_backend.DType`.
 * @param value The constant to render.
 * @returns WGSL source text for the constant.
 * @throws Error if the dtype is not supported.
 */
function constToWgsl(dtype, value) {
	const { DType } = require_backend;
	switch (dtype) {
		case DType.Bool: return value ? "true" : "false";
		case DType.Int32: return value.toString();
		case DType.Uint32: return value.toString() + "u";
		case DType.Float32:
		case DType.Float16: {
			const half = dtype !== DType.Float32;
			// Non-finite values cannot be written as literals, so use the helpers;
			// for f16 the f32 helper result is converted.
			const special = (expr) => half ? `f16(${expr})` : expr;
			if (Number.isNaN(value)) return special("nan()");
			if (!Number.isFinite(value)) return special(value > 0 ? "inf()" : "-inf()");
			return `${half ? "f16" : "f32"}(${value.toString()})`;
		}
		default: throw new Error(`Unsupported const dtype: ${dtype}`);
	}
}
111
/**
 * X extent used when a dispatch is folded into two dimensions. Shaders in this
 * file recover the linear workgroup index as `wg_id.x + wg_id.y * gridOffsetY`.
 */
const gridOffsetY = 16384;
/**
 * Split a linear workgroup count into an `[x, y]` dispatch grid.
 *
 * Counts up to 65535 are dispatched as a single row. Larger counts use rows of
 * `gridOffsetY` workgroups, rounding the row count up, so the grid may contain
 * more workgroups than requested.
 *
 * @param gridSize Number of workgroups needed.
 * @returns `[gridX, gridY]`.
 */
function calculateGrid(gridSize) {
	return gridSize > 65535
		? [gridOffsetY, Math.ceil(gridSize / gridOffsetY)]
		: [gridSize, 1];
}
121
+
122
+ //#endregion
123
+ //#region src/backend/webgpu/reader.ts
124
/**
 * Synchronous readback of WebGPU buffer contents by way of canvases.
 *
 * The approach (taken from TensorFlow.js, see link below): offscreen canvases
 * configured with a WebGPU context act as "device storage", one pixel per
 * 4 bytes. Buffer bytes are copied into the canvas texture, the canvas is
 * drawn onto a second offscreen canvas that has a '2d' context ("host
 * storage"), and the pixels are read back with `getImageData()`.
 *
 * Canvases are 256x256 pixels (256 KiB per copy), so larger reads are done in
 * several chunks. Every chunk is copied twice: once through an "opaque"
 * canvas to recover the colour channels and once through a "premultiplied"
 * canvas to recover the alpha channel. This is slow, but it works.
 *
 * https://github.com/tensorflow/tfjs/blob/tfjs-v4.22.0/tfjs-backend-webgpu/src/backend_webgpu.ts#L379
 */
var SyncReader = class SyncReader {
	static alphaModes = ["opaque", "premultiplied"];
	static width = 256;
	static height = 256;
	initialized = false;
	deviceStorage;
	deviceContexts;
	hostStorage;
	hostContext;
	constructor(device) {
		this.device = device;
	}
	/** Lazily create the canvases and contexts on first use. */
	#init() {
		const newCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
		this.deviceStorage = SyncReader.alphaModes.map(() => newCanvas());
		const contexts = [];
		for (let i = 0; i < this.deviceStorage.length; i++) {
			const gpuContext = this.deviceStorage[i].getContext("webgpu");
			gpuContext.configure({
				device: this.device,
				format: "bgra8unorm",
				usage: GPUTextureUsage.COPY_DST,
				alphaMode: SyncReader.alphaModes[i]
			});
			contexts.push(gpuContext);
		}
		this.deviceContexts = contexts;
		this.hostStorage = newCanvas();
		this.hostContext = this.hostStorage.getContext("2d", { willReadFrequently: true });
		this.initialized = true;
	}
	/**
	 * Read `count` bytes of `buffer`, beginning at byte offset `start`.
	 *
	 * @returns A Uint8Array of length `count` over a freshly allocated buffer.
	 */
	read(buffer, start, count) {
		if (!this.initialized) this.#init();
		const { deviceStorage: canvases, deviceContexts: contexts, hostContext: host } = this;
		const pixelCount = Math.ceil(count / 4);
		const rowPitch = SyncReader.width * 4;
		const result = /* @__PURE__ */ new ArrayBuffer(pixelCount * 4);
		// Decompose the read into whole canvases, whole rows, and a partial row.
		const canvasPixels = SyncReader.width * SyncReader.height;
		const fullCanvases = Math.floor(pixelCount / canvasPixels);
		const leftover = pixelCount % canvasPixels;
		const fullRows = Math.floor(leftover / SyncReader.width);
		const tailPixels = leftover % SyncReader.width;
		contexts.forEach((gpuContext, modeIndex) => {
			const texture = gpuContext.getCurrentTexture();
			const alphaPass = SyncReader.alphaModes[modeIndex] === "premultiplied";
			const copyRegion = (width, height, byteOffset) => {
				const encoder = this.device.createCommandEncoder();
				encoder.copyBufferToTexture({
					buffer,
					bytesPerRow: rowPitch,
					offset: byteOffset + start
				}, { texture }, {
					width,
					height,
					depthOrArrayLayers: 1
				});
				this.device.queue.submit([encoder.finish()]);
				host.clearRect(0, 0, width, height);
				host.drawImage(canvases[modeIndex], 0, 0);
				const rgba = host.getImageData(0, 0, width, height).data;
				const dest = new Uint8ClampedArray(result, byteOffset, 4 * width * height);
				for (let k = 0; k < dest.length; k += 4) {
					if (alphaPass) {
						// Alpha byte only; colour bytes come from the opaque pass.
						dest[k + 3] = rgba[k + 3];
					} else {
						// The texture is BGRA while image data is RGBA: swap bytes 0 and 2.
						dest[k] = rgba[k + 2];
						dest[k + 1] = rgba[k + 1];
						dest[k + 2] = rgba[k];
					}
				}
			};
			let byteOffset = 0;
			for (let chunk = 0; chunk < fullCanvases; chunk++) {
				copyRegion(SyncReader.width, SyncReader.height, byteOffset);
				byteOffset += canvasPixels * 4;
			}
			if (fullRows > 0) {
				copyRegion(SyncReader.width, fullRows, byteOffset);
				byteOffset += fullRows * SyncReader.width * 4;
			}
			if (tailPixels > 0) copyRegion(tailPixels, 1, byteOffset);
		});
		return new Uint8Array(result, 0, count);
	}
};
225
+
226
+ //#endregion
227
+ //#region src/backend/webgpu/routines.ts
228
/**
 * Pack a bitonic sort pass descriptor into the bytes of three u32 words:
 * `[kind, mergeStep, mergeStage]`, where kind is 0 for a "sort" pass and 1
 * otherwise. Missing step/stage values are written as 0.
 *
 * @param pass Object with `kind` and optional `mergeStep` / `mergeStage`.
 * @returns A 12-byte Uint8Array view of the packed words.
 */
function bitonicSortUniform(pass) {
	const words = Uint32Array.of(
		pass.kind === "sort" ? 0 : 1,
		pass.mergeStep ?? 0,
		pass.mergeStage ?? 0
	);
	return new Uint8Array(words.buffer);
}
235
/**
 * Generate a bitonic sort shader.
 *
 * We implement a variant of bitonic sort that [only has forward comparators](
 * <https://sortingalgos.miraheze.org/wiki/Bitonic_Sort#Bitonic_Sort_using_Forward_Comparators>),
 * so we don't need to allocate memory for power-of-two padding.
 *
 * This uses workgroup shared memory up to `2*workgroupSize` elements, for each
 * array in `batches`. For larger arrays, multiple passes are done:
 *
 * - Initial "sort" pass: each workgroup sorts its `2*workgroupSize` elements.
 * - Subsequent "merge" passes: each pass merges sorted sequences of size
 *   `2^(step+1)` with multiple workgroups. Steps at or above the local stage
 *   count work directly on the output buffer; the last step of every merge
 *   stage is finished in shared memory.
 *
 * The total number of passes is roughly `log2(n / workgroupSize)^2 / 2`.
 *
 * @param device Device whose `limits.maxComputeWorkgroupSizeX` caps the workgroup size.
 * @param dtype Element dtype; floats are padded with NaN, other types with their maximum.
 * @param n Length of each array being sorted.
 * @param batches Number of independent arrays.
 * @param outputIndices When true, a second output receives the i32 source indices (argsort).
 * @returns A one-element list holding the shader description and its passes.
 */
function bitonicSortShader(device, dtype, n, batches, outputIndices) {
	const ty = dtypeToWgsl(dtype, true);
	const paddedN = 1 << Math.ceil(Math.log2(n || 1));
	const numThreads = Math.ceil(paddedN / 2);
	const workgroupSize = require_backend.findPow2(numThreads, device.limits.maxComputeWorkgroupSizeX);
	const workgroupsPerBatch = numThreads / workgroupSize;
	const numStages = Math.log2(paddedN);
	const numLocalStages = Math.min(numStages, Math.log2(workgroupSize * 2));
	// For n <= 1 there are no stages at all. Clamp so the template never emits
	// `-1u`, which is not valid WGSL (unary minus is not defined for u32).
	const lastLocalStage = Math.max(numLocalStages - 1, 0);
	const needsF16 = dtype === require_backend.DType.Float16;
	const padValue = require_backend.isFloatDtype(dtype) ? `${ty}(nan())` : maxValueWgsl(dtype);
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

struct Uniforms {
kind: u32, // 0 = sort, 1 = merge
merge_step: u32, // half_block = 2^step
merge_stage: u32, // only used for merge
}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> output: array<${ty}>;
${outputIndices ? `@group(0) @binding(2) var<storage, read_write> output_idx: array<i32>;` : ""}

@group(1) @binding(0) var<uniform> uniforms: Uniforms;

var<workgroup> shared_vals: array<${ty}, ${workgroupSize * 2}>;
${outputIndices ? `var<workgroup> shared_idx: array<i32, ${workgroupSize * 2}>;` : ""}

fn compare(a: ${ty}, b: ${ty}) -> bool {
${require_backend.isFloatDtype(dtype) ? `
let min_value = min(a, b);
return a == min_value && b != min_value;` : " return a < b;"}
}

fn compare_and_swap(i: u32, j: u32) {
let val_i = shared_vals[i];
let val_j = shared_vals[j];
if (compare(val_j, val_i)) {
shared_vals[i] = val_j;
shared_vals[j] = val_i;
${outputIndices ? `
let tmp_idx = shared_idx[i];
shared_idx[i] = shared_idx[j];
shared_idx[j] = tmp_idx;` : ""}
}
}

@compute @workgroup_size(${workgroupSize})
fn main(
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
let blockid = wg_id.x + wg_id.y * ${gridOffsetY}u;
let batch = blockid / ${workgroupsPerBatch}u;
let wg_in_batch = blockid % ${workgroupsPerBatch}u;

// The 2D grid may contain padding workgroups past the last batch; they must
// not touch memory. The condition is uniform across the workgroup.
if (batch >= ${batches}u) {
return;
}

let tid = local_id.x;
let base = batch * ${n}u;

if (uniforms.kind == 0u || (uniforms.kind == 1u && uniforms.merge_step == ${lastLocalStage}u)) {
let wg_base = wg_in_batch * ${workgroupSize * 2}u;

// Load data into shared memory (2 elements per thread)
let idx0 = tid * 2u;
let idx1 = tid * 2u + 1u;
// Load from input for initial 'sort' pass, then from output (read-write) for 'merge' passes.
if (uniforms.kind == 0u) {
shared_vals[idx0] = select(${padValue}, input[base + wg_base + idx0], wg_base + idx0 < ${n}u);
shared_vals[idx1] = select(${padValue}, input[base + wg_base + idx1], wg_base + idx1 < ${n}u);
${outputIndices ? `
shared_idx[idx0] = i32(wg_base + idx0);
shared_idx[idx1] = i32(wg_base + idx1);` : ""}
} else {
shared_vals[idx0] = select(${padValue}, output[base + wg_base + idx0], wg_base + idx0 < ${n}u);
shared_vals[idx1] = select(${padValue}, output[base + wg_base + idx1], wg_base + idx1 < ${n}u);
${outputIndices ? `
shared_idx[idx0] = select(${n}, output_idx[base + wg_base + idx0], wg_base + idx0 < ${n}u);
shared_idx[idx1] = select(${n}, output_idx[base + wg_base + idx1], wg_base + idx1 < ${n}u);` : ""}
}
workgroupBarrier();

let initial_stage = select(0u, ${lastLocalStage}u, uniforms.kind != 0u);
for (var stage = initial_stage; stage < ${numLocalStages}u; stage++) {
for (var step1 = stage + 1u; step1 > 0u; step1--) {
let step = step1 - 1u;
let half_block = 1u << step;
let is_first_step = uniforms.kind == 0u && step == stage;

let block_offset = (tid / half_block) * half_block;
let local_offset = tid % half_block;
let i = block_offset * 2u + local_offset;
let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);
compare_and_swap(i, j);

workgroupBarrier();
}
}

if (wg_base + idx0 < ${n}u) {
output[base + wg_base + idx0] = shared_vals[idx0];
${outputIndices ? `output_idx[base + wg_base + idx0] = shared_idx[idx0];` : ""}
}
if (wg_base + idx1 < ${n}u) {
output[base + wg_base + idx1] = shared_vals[idx1];
${outputIndices ? `output_idx[base + wg_base + idx1] = shared_idx[idx1];` : ""}
}
} else {
// Execute single merge pass for a step >= numLocalStages.
let half_block = 1u << uniforms.merge_step; // half_block >= workgroupSize * 2
let thread_in_batch = wg_in_batch * ${workgroupSize} + tid;
let is_first_step = uniforms.merge_step == uniforms.merge_stage;

let block_offset = (thread_in_batch / half_block) * half_block;
let local_offset = thread_in_batch % half_block;
let i = block_offset * 2u + local_offset;
let j = select(i + half_block, i ^ (half_block * 2u - 1u), is_first_step);

// Global version of compare_and_swap()
if (j < ${n}u) {
let val_i = output[base + i];
let val_j = output[base + j];
if (compare(val_j, val_i)) {
output[base + i] = val_j;
output[base + j] = val_i;
${outputIndices ? `
let tmp_idx = output_idx[base + i];
output_idx[base + i] = output_idx[base + j];
output_idx[base + j] = tmp_idx;` : ""}
}
}
}
}
`.trim();
	const grid = calculateGrid(batches * workgroupsPerBatch);
	// One local sort pass, then for every stage that does not fit in a workgroup
	// one pass per step, ending with the step handled in shared memory.
	const passes = [{ kind: "sort" }];
	for (let mergeStage = numLocalStages; mergeStage < numStages; mergeStage++) {
		for (let mergeStep = mergeStage; mergeStep >= lastLocalStage; mergeStep--) {
			passes.push({
				kind: "merge",
				mergeStep,
				mergeStage
			});
		}
	}
	return [{
		code,
		numInputs: 1,
		numOutputs: outputIndices ? 2 : 1,
		hasUniform: true,
		passes: passes.map((pass) => ({
			grid,
			uniform: bitonicSortUniform(pass)
		}))
	}];
}
403
/**
 * Build the sort routine: a bitonic sort along the last axis of the first
 * input, with all leading axes flattened into the batch count and no index
 * output.
 */
function createSort(device, type) {
	const { inputDtypes, inputShapes } = type;
	const elementDtype = inputDtypes[0];
	const dims = inputShapes[0];
	const sortLength = dims[dims.length - 1];
	const batchCount = require_backend.prod(dims.slice(0, -1));
	return bitonicSortShader(device, elementDtype, sortLength, batchCount, false);
}
410
/**
 * Build the argsort routine: the same bitonic sort as `createSort`, but with
 * the index output enabled.
 */
function createArgsort(device, type) {
	const { inputDtypes, inputShapes } = type;
	const elementDtype = inputDtypes[0];
	const dims = inputShapes[0];
	const sortLength = dims[dims.length - 1];
	const batchCount = require_backend.prod(dims.slice(0, -1));
	return bitonicSortShader(device, elementDtype, sortLength, batchCount, true);
}
417
/**
 * Build the triangular solve shader.
 *
 * Solves A @ X.T = B.T for X with an upper-triangular A. One workgroup is
 * launched per (matrix, right-hand side) pair and runs back-substitution:
 *  1. The row of b is copied into x.
 *  2. For j from n-1 down to 0, thread 0 finalises x[j] (dividing by a[j,j]
 *     unless `params.unitDiagonal` is set) and publishes it through workgroup
 *     memory; all threads then subtract x[j] * a[i,j] from x[i] for i < j.
 *
 * @param device Device whose `limits.maxComputeWorkgroupSizeX` caps the workgroup size.
 * @param type Routine type; reads `inputDtypes[0]` and the shapes of a and b.
 * @param params Reads `unitDiagonal`.
 * @returns A one-element list with the shader description (2 inputs, 1 output).
 */
function createTriangularSolve(device, type, params) {
	const { inputDtypes, inputShapes } = type;
	const dtype = inputDtypes[0];
	const aShape = inputShapes[0];
	const bShape = inputShapes[1];
	const n = aShape[aShape.length - 1];
	const numRhs = bShape[bShape.length - 2];
	const numMatrices = require_backend.prod(aShape.slice(0, -2));
	const needsF16 = dtype === require_backend.DType.Float16;
	const ty = dtypeToWgsl(dtype, true);
	const workgroupSize = require_backend.findPow2(n, device.limits.maxComputeWorkgroupSizeX);
	// With a unit diagonal the division by a[j,j] is skipped.
	const pivotStatement = params.unitDiagonal
		? `x_j = x[bx_base + j];`
		: `x_j = x[bx_base + j] / a[a_base + j * ${n}u + j];`;
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> a: array<${ty}>;
@group(0) @binding(1) var<storage, read> b: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> x: array<${ty}>;

// Shared memory for the current pivot value x[j]
var<workgroup> x_j: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
let wg_idx = wg_id.x + wg_id.y * ${gridOffsetY}u;
let mat_idx = wg_idx / ${numRhs}u;
let rhs_idx = wg_idx % ${numRhs}u;

if (mat_idx >= ${numMatrices}u) {
return;
}

let a_base = mat_idx * ${n * n}u;
let bx_base = (mat_idx * ${numRhs}u + rhs_idx) * ${n}u;
let tid = local_id.x;

// Step 1: Copy b to x (threads collaborate)
for (var idx = tid; idx < ${n}u; idx += ${workgroupSize}u) {
x[bx_base + idx] = b[bx_base + idx];
}
storageBarrier();

// Step 2: Back-substitution from j = n-1 down to 0
for (var jj = 0u; jj < ${n}u; jj++) {
let j = ${n - 1}u - jj;

// Thread 0 computes x[j] = x[j] / a[j,j]
if (tid == 0u) {
${pivotStatement}
x[bx_base + j] = x_j;
}
workgroupBarrier(); // Sync shared memory x_j

// All threads subtract x[j] * a[i,j] from x[i] for i < j
for (var i = tid; i < j; i += ${workgroupSize}u) {
x[bx_base + i] -= x_j * a[a_base + i * ${n}u + j];
}
workgroupBarrier();
storageBarrier();
}
}
`.trim();
	const grid = calculateGrid(numMatrices * numRhs);
	const shader = {
		code,
		numInputs: 2,
		numOutputs: 1,
		hasUniform: false,
		passes: [{ grid }]
	};
	return [shader];
}
501
/**
 * Build the Cholesky decomposition shader.
 *
 * For every matrix in the batch this produces the lower-triangular factor L
 * with A = L * L^T, using the column-by-column Cholesky-Crout scheme. Only the
 * lower triangle of the input is read; the strict upper triangle of the output
 * is zeroed. One workgroup handles one matrix.
 *
 * Per column j:
 *  1. Threads compute, for their rows i >= j, the input value minus the dot
 *     product of the already finished parts of rows i and j.
 *  2. Thread 0 takes the square root of the diagonal entry and shares it
 *     through workgroup memory.
 *  3. Threads divide the entries below the diagonal by that value.
 *
 * @param device Device whose `limits.maxComputeWorkgroupSizeX` caps the workgroup size.
 * @param type Routine type; reads `inputDtypes[0]` and `inputShapes[0]`.
 * @returns A one-element list with the shader description (1 input, 1 output).
 */
function createCholesky(device, type) {
	const { inputDtypes, inputShapes } = type;
	const dtype = inputDtypes[0];
	const shape = inputShapes[0];
	const n = shape[shape.length - 1];
	const batches = require_backend.prod(shape.slice(0, -2));
	const needsF16 = dtype === require_backend.DType.Float16;
	const ty = dtypeToWgsl(dtype, true);
	const workgroupSize = require_backend.findPow2(n, device.limits.maxComputeWorkgroupSizeX);
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> output: array<${ty}>;

// Shared memory for the diagonal element
var<workgroup> L_jj: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
if (batch >= ${batches}u) {
return;
}

let base = batch * ${n * n}u;
let tid = local_id.x;

// Zero out output and copy lower triangle from input (threads collaborate)
for (var idx = tid; idx < ${n * n}u; idx += ${workgroupSize}u) {
let row = idx / ${n}u;
let col = idx % ${n}u;
output[base + idx] = select(0, input[base + idx], col <= row);
}
storageBarrier();

// Cholesky-Crout algorithm: process column by column
for (var j = 0u; j < ${n}u; j++) {
// Step 1: All threads compute sum for their rows i >= j in parallel
// sum = A[i][j] - sum(L[i][k] * L[j][k] for k < j)
for (var i = j + tid; i < ${n}u; i += ${workgroupSize}u) {
var sum = output[base + i * ${n}u + j];
for (var k = 0u; k < j; k++) {
sum -= output[base + i * ${n}u + k] * output[base + j * ${n}u + k];
}
output[base + i * ${n}u + j] = sum;
}
storageBarrier();

// Step 2: Thread 0 computes L[j][j] = sqrt(output[j][j])
if (tid == 0u) {
L_jj = sqrt(output[base + j * ${n}u + j]);
output[base + j * ${n}u + j] = L_jj;
}
workgroupBarrier();

// Step 3: All threads divide output[i][j] by L[j][j] for i > j
for (var i = j + 1u + tid; i < ${n}u; i += ${workgroupSize}u) {
output[base + i * ${n}u + j] /= L_jj;
}
storageBarrier();
}
}
`.trim();
	const shader = {
		code,
		numInputs: 1,
		numOutputs: 1,
		hasUniform: false,
		passes: [{ grid: calculateGrid(batches) }]
	};
	return [shader];
}
589
/**
 * Build the LU decomposition shader (partial pivoting).
 *
 * Factors each m x n matrix as PA = LU, with L unit lower triangular and U
 * upper triangular, both stored in the single `lu` output. Two further i32
 * outputs are written: `pivots` (the chosen pivot row for each of the
 * min(m, n) columns) and `perm` (the resulting row permutation, m entries).
 * One workgroup handles one matrix.
 *
 * Per column j:
 *  1. Thread 0 scans rows >= j for the entry with the largest absolute value.
 *  2. Rows j and the pivot row are swapped (threads share the columns).
 *  3. Each thread takes rows i > j, stores L[i][j] = A[i][j] / pivot and
 *     subtracts L[i][j] * A[j][k] from A[i][k] for k > j.
 *
 * @param device Device whose `limits.maxComputeWorkgroupSizeX` caps the workgroup size.
 * @param type Routine type; reads `inputDtypes[0]` and `inputShapes[0]`.
 * @returns A one-element list with the shader description (1 input, 3 outputs).
 */
function createLU(device, type) {
	const { inputDtypes, inputShapes } = type;
	const dtype = inputDtypes[0];
	const shape = inputShapes[0];
	const m = shape[shape.length - 2];
	const n = shape[shape.length - 1];
	const r = Math.min(m, n);
	const batches = require_backend.prod(shape.slice(0, -2));
	const needsF16 = dtype === require_backend.DType.Float16;
	const ty = dtypeToWgsl(dtype, true);
	const longestSide = Math.max(m, n);
	const workgroupSize = require_backend.findPow2(longestSide, device.limits.maxComputeWorkgroupSizeX);
	const code = `
${needsF16 ? "enable f16;" : ""}
${headerWgsl}

@group(0) @binding(0) var<storage, read> input: array<${ty}>;
@group(0) @binding(1) var<storage, read_write> lu: array<${ty}>;
@group(0) @binding(2) var<storage, read_write> pivots: array<i32>;
@group(0) @binding(3) var<storage, read_write> perm: array<i32>;

var<workgroup> pivot_row: u32;
var<workgroup> pivot_val: ${ty};

@compute @workgroup_size(${workgroupSize})
fn main(
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
) {
let batch = wg_id.x + wg_id.y * ${gridOffsetY}u;
if (batch >= ${batches}u) {
return;
}

let lu_base = batch * ${m * n}u;
let piv_base = batch * ${r}u;
let perm_base = batch * ${m}u;
let tid = local_id.x;

// Copy input to lu
for (var idx = tid; idx < ${m * n}u; idx += ${workgroupSize}u) {
lu[lu_base + idx] = input[lu_base + idx];
}
// Initialize permutation
for (var idx = tid; idx < ${m}u; idx += ${workgroupSize}u) {
perm[perm_base + idx] = i32(idx);
}
storageBarrier();

// LU decomposition with partial pivoting
for (var j = 0u; j < ${r}u; j++) {
// Step 1: Thread 0 finds pivot (max abs value in column j, rows >= j)
if (tid == 0u) {
var max_val = abs(lu[lu_base + j * ${n}u + j]);
var max_row = j;
for (var i = j + 1u; i < ${m}u; i++) {
let val = abs(lu[lu_base + i * ${n}u + j]);
if (val > max_val) {
max_val = val;
max_row = i;
}
}
pivot_row = max_row;
pivot_val = lu[lu_base + max_row * ${n}u + j];
pivots[piv_base + j] = i32(max_row);
}
workgroupBarrier();

// Step 2: Swap rows j and pivot_row (threads collaborate)
let pr = pivot_row;
if (pr != j) {
for (var col = tid; col < ${n}u; col += ${workgroupSize}u) {
let tmp = lu[lu_base + j * ${n}u + col];
lu[lu_base + j * ${n}u + col] = lu[lu_base + pr * ${n}u + col];
lu[lu_base + pr * ${n}u + col] = tmp;
}
if (tid == 0u) {
let tmp_p = perm[perm_base + j];
perm[perm_base + j] = perm[perm_base + pr];
perm[perm_base + pr] = tmp_p;
}
}
storageBarrier();

// Step 3: Compute L[i][j] and update submatrix
// Each thread handles one row i > j
for (var i = j + 1u + tid; i < ${m}u; i += ${workgroupSize}u) {
let factor = lu[lu_base + i * ${n}u + j] / pivot_val;
lu[lu_base + i * ${n}u + j] = factor; // L[i][j]
for (var k = j + 1u; k < ${n}u; k++) {
lu[lu_base + i * ${n}u + k] -= factor * lu[lu_base + j * ${n}u + k];
}
}
storageBarrier();
}
}
`.trim();
	const shader = {
		code,
		numInputs: 1,
		numOutputs: 3,
		hasUniform: false,
		passes: [{ grid: calculateGrid(batches) }]
	};
	return [shader];
}
705
/**
 * Pick the shader generator that matches `routine.name` and run it.
 *
 * @param device Device handed to the generator.
 * @param routine Object with `name`, `type` and (for triangular solve) `params`.
 * @returns The list of shader descriptions produced by the generator.
 * @throws require_backend.UnsupportedRoutineError for any other routine name.
 */
function createRoutineShader(device, routine) {
	const { Routines } = require_backend;
	const name = routine.name;
	if (name === Routines.Sort) return createSort(device, routine.type);
	if (name === Routines.Argsort) return createArgsort(device, routine.type);
	if (name === Routines.TriangularSolve) return createTriangularSolve(device, routine.type, routine.params);
	if (name === Routines.Cholesky) return createCholesky(device, routine.type);
	if (name === Routines.LU) return createLU(device, routine.type);
	throw new require_backend.UnsupportedRoutineError(routine.name, "webgpu");
}
715
+
716
+ //#endregion
717
+ //#region src/backend/webgpu.ts
718
+ /** Implementation of `Backend` that uses WebGPU in browsers. */
719
+ var WebGPUBackend = class {
720
+ type = "webgpu";
721
+ maxArgs;
722
+ pipelines;
723
+ syncReader;
724
+ buffers;
725
+ nextSlot;
726
+ #cachedShaderMap = /* @__PURE__ */ new Map();
727
+ #reusableZsb;
728
+ constructor(device) {
729
+ this.device = device;
730
+ if (require_backend.DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
731
+ this.maxArgs = this.device.limits.maxStorageBuffersPerShaderStage - 1;
732
+ this.pipelines = new ShaderPipelineCache(device);
733
+ this.syncReader = new SyncReader(device);
734
+ this.buffers = /* @__PURE__ */ new Map();
735
+ this.nextSlot = 1;
736
+ this.#reusableZsb = this.#createBuffer(4);
737
+ device.addEventListener("uncapturederror", (event) => {
738
+ console.error("Uncaptured error in WebGPU backend:", event.error.message);
739
+ });
740
+ }
741
+ malloc(size, initialData) {
742
+ let buffer;
743
+ const paddedSize = Math.ceil(size / 4) * 4;
744
+ if (size === 0) buffer = this.#reusableZsb;
745
+ else if (initialData) {
746
+ if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
747
+ if (initialData.byteLength < 4096) {
748
+ buffer = this.#createBuffer(paddedSize, { mapped: true });
749
+ new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
750
+ buffer.unmap();
751
+ } else {
752
+ buffer = this.#createBuffer(paddedSize);
753
+ if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
754
+ else {
755
+ const aligned = initialData.byteLength - initialData.byteLength % 4;
756
+ this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
757
+ const remainder = new Uint8Array(4);
758
+ remainder.set(initialData.subarray(aligned));
759
+ this.device.queue.writeBuffer(buffer, aligned, remainder);
760
+ }
761
+ }
762
+ } else buffer = this.#createBuffer(paddedSize);
763
+ const slot = this.nextSlot++;
764
+ this.buffers.set(slot, {
765
+ buffer,
766
+ size,
767
+ ref: 1
768
+ });
769
+ return slot;
770
+ }
771
+ incRef(slot) {
772
+ const buffer = this.buffers.get(slot);
773
+ if (!buffer) throw new require_backend.SlotError(slot);
774
+ buffer.ref++;
775
+ }
776
+ decRef(slot) {
777
+ const buffer = this.buffers.get(slot);
778
+ if (!buffer) throw new require_backend.SlotError(slot);
779
+ buffer.ref--;
780
+ if (buffer.ref === 0) {
781
+ this.buffers.delete(slot);
782
+ if (buffer.buffer !== this.#reusableZsb) buffer.buffer.destroy();
783
+ }
784
+ }
785
+ async read(slot, start, count) {
786
+ const { buffer, size } = this.#getBuffer(slot);
787
+ if (buffer === this.#reusableZsb) return new Uint8Array();
788
+ if (start === void 0) start = 0;
789
+ if (count === void 0) count = size - start;
790
+ const paddedSize = Math.ceil(count / 4) * 4;
791
+ const staging = this.#createBuffer(paddedSize, { read: true });
792
+ try {
793
+ const commandEncoder = this.device.createCommandEncoder();
794
+ commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
795
+ this.device.queue.submit([commandEncoder.finish()]);
796
+ await staging.mapAsync(GPUMapMode.READ);
797
+ const arrayBuffer = staging.getMappedRange();
798
+ return new Uint8Array(arrayBuffer.slice(), 0, count);
799
+ } finally {
800
+ staging.destroy();
801
+ }
802
+ }
803
+ readSync(slot, start, count) {
804
+ const { buffer, size } = this.#getBuffer(slot);
805
+ if (buffer === this.#reusableZsb) return new Uint8Array();
806
+ if (start === void 0) start = 0;
807
+ if (count === void 0) count = size - start;
808
+ return this.syncReader.read(buffer, start, count);
809
+ }
810
+ #cachedShader(kernel) {
811
+ const cacheKey = require_backend.FpHash.hash(kernel);
812
+ let result = this.#cachedShaderMap.get(cacheKey);
813
+ if (!result) {
814
+ result = pipelineSource(this.device, kernel);
815
+ this.#cachedShaderMap.set(cacheKey, result);
816
+ }
817
+ return result;
818
+ }
819
+ async prepareKernel(kernel) {
820
+ const shader = this.#cachedShader(kernel);
821
+ const pipeline = await this.pipelines.prepare(shader);
822
+ return new require_backend.Executable(kernel, [{
823
+ ...shader,
824
+ pipeline
825
+ }]);
826
+ }
827
+ prepareKernelSync(kernel) {
828
+ const shader = this.#cachedShader(kernel);
829
+ const pipeline = this.pipelines.prepareSync(shader);
830
+ return new require_backend.Executable(kernel, [{
831
+ ...shader,
832
+ pipeline
833
+ }]);
834
+ }
835
+ async prepareRoutine(routine) {
836
+ const shaders = createRoutineShader(this.device, routine);
837
+ const dispatches = await Promise.all(shaders.map(async (shader) => {
838
+ const pipeline = await this.pipelines.prepare(shader);
839
+ return {
840
+ ...shader,
841
+ pipeline
842
+ };
843
+ }));
844
+ return new require_backend.Executable(routine, dispatches);
845
+ }
846
+ prepareRoutineSync(routine) {
847
+ const shaders = createRoutineShader(this.device, routine);
848
+ const dispatches = shaders.map((shader) => {
849
+ const pipeline = this.pipelines.prepareSync(shader);
850
+ return {
851
+ ...shader,
852
+ pipeline
853
+ };
854
+ });
855
+ return new require_backend.Executable(routine, dispatches);
856
+ }
857
+ dispatch(exe, inputs, outputs) {
858
+ const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
859
+ const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
860
+ pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
861
+ }
862
+ #getBuffer(slot) {
863
+ const buffer = this.buffers.get(slot);
864
+ if (!buffer) throw new require_backend.SlotError(slot);
865
+ return {
866
+ buffer: buffer.buffer,
867
+ size: buffer.size
868
+ };
869
+ }
870
+ /**
871
+ * Create a GPU buffer.
872
+ *
873
+ * By default, this creates a general-purpose buffer with the given size.
874
+ *
875
+ * - If `mapped` is true, initialize the buffer in mapped mode so that it can
876
+ * be populated with data from the CPU. (Call `.unmap()` later.)
877
+ * - If `read` is true, create a staging buffer for returning data to CPU.
878
+ * (Call `.mapAsync()` later.)
879
+ */
880
+ #createBuffer(size, { mapped = false, read = false } = {}) {
881
+ if (read && mapped) throw new Error("mapped and read cannot both be true");
882
+ const buffer = this.device.createBuffer({
883
+ size,
884
+ usage: read ? GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST : GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
885
+ mappedAtCreation: mapped
886
+ });
887
+ return buffer;
888
+ }
889
+ };
890
/**
 * Compiles a kernel into WebGPU (WGSL) compute shader source code.
 *
 * The kernel is first tuned with `tuneWebgpu`. The generated shader binds one
 * read-only storage buffer per kernel argument (`in0`, `in1`, ...) plus one
 * read-write `result` buffer, and either writes one element per invocation or,
 * when `kernel.reduction` is set, runs a reduction loop followed by the tuned
 * epilogue.
 *
 * @param device GPU device; only `device.features` is consulted here.
 * @param kernel Kernel to compile (`exp`, `nargs`, `dtype`, `reduction`).
 * @returns `{ code, numInputs, numOutputs, hasUniform, passes }` where
 *   `passes[0].grid` holds the number of workgroups along the x and y axes.
 * @throws Error if Float16 is used without the `shader-f16` feature, if group
 *   optimization is requested, or if the reduction op is not supported.
 * @throws UnsupportedOpError for expression ops with no WGSL translation.
 */
function pipelineSource(device, kernel) {
  const { AluExp, AluGroup, AluOp, DType, UnsupportedOpError, isFloatDtype, strip1 } = require_backend;
  const tune = require_backend.tuneWebgpu(kernel);
  if (require_backend.DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
  const { nargs, reduction: re } = kernel;
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);

  // --- Line emitter with indentation tracking -------------------------------
  const shader = [];
  // A single unit shared by push and pop so that nesting levels stay balanced.
  const indentUnit = "  ";
  let indent = "";
  const pushIndent = Symbol("pushIndent");
  const popIndent = Symbol("popIndent");
  const emit = (...lines) => {
    for (const line of lines) {
      if (line === pushIndent) indent += indentUnit;
      else if (line === popIndent) indent = indent.slice(0, -indentUnit.length);
      else shader.push(line ? indent + line : line);
    }
  };

  // --- Prelude: feature enables and helper functions ------------------------
  const usesF16 = (exp) => exp.dtype === DType.Float16;
  if (tune.exp.some(usesF16) || tune.epilogue?.some(usesF16)) {
    if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
    emit("enable f16;");
  }
  emit(headerWgsl);
  const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
  if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
  if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
  emit("");

  // --- Buffer bindings ------------------------------------------------------
  // The element type of each input is taken from the GlobalIndex expressions
  // that read it; inputs that are never read default to Float32.
  const usedArgs = Array.from({ length: nargs }, () => null);
  const recordArgDtype = (exp) => {
    if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
  };
  tune.exp.fold(recordArgDtype);
  tune.epilogue?.fold(recordArgDtype);
  for (let i = 0; i < nargs; i++) {
    const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
    emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
  }
  const resultTy = dtypeToWgsl(kernel.dtype, true);
  emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);

  // --- Entry point and global index -----------------------------------------
  const workgroupSize = require_backend.findPow2(tune.threadCount, 256);
  const gridSize = Math.ceil(tune.threadCount / workgroupSize);
  const [gridX, gridY] = calculateGrid(gridSize);
  emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", pushIndent);
  if (gridY === 1) {
    emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
  } else {
    // The grid is split over two axes; flatten (x, y) back into a linear index.
    const sizeX = gridX * workgroupSize;
    emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
  }

  let gensymCount = 0;
  const gensym = () => `alu${gensymCount++}`;
  const isGensym = (text) => text.match(/^alu[0-9]+$/);
  // Reference every input so that each binding counts as used by the shader.
  if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));

  // --- Expression code generation -------------------------------------------
  // Expressions referenced more than once are bound to a `let` exactly once.
  const references = /* @__PURE__ */ new Map();
  const seen = /* @__PURE__ */ new Set();
  const countReferences = (exp) => {
    references.set(exp, (references.get(exp) ?? 0) + 1);
    if (!seen.has(exp)) {
      seen.add(exp);
      for (const src of exp.src) countReferences(src);
    }
  };
  // Unary ops that map directly onto a WGSL builtin of one argument.
  const unaryBuiltins = /* @__PURE__ */ new Map([
    [AluOp.Sin, "sin"],
    [AluOp.Cos, "cos"],
    [AluOp.Asin, "asin"],
    [AluOp.Atan, "atan"],
    [AluOp.Exp, "exp"],
    [AluOp.Log, "log"],
    [AluOp.Sqrt, "sqrt"],
    [AluOp.Floor, "floor"],
    [AluOp.Ceil, "ceil"]
  ]);
  const expContext = /* @__PURE__ */ new Map();
  const gen = (exp) => {
    if (expContext.has(exp)) return expContext.get(exp);
    const { op, src, dtype, arg } = exp;
    let source = "";
    if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
      const a = gen(src[0]);
      const b = gen(src[1]);
      const isBool = dtype === DType.Bool;
      if (op === AluOp.Add) source = isBool ? `(${a} || ${b})` : `(${a} + ${b})`;
      else if (op === AluOp.Sub) source = `(${a} - ${b})`;
      else if (op === AluOp.Mul) source = isBool ? `(${a} && ${b})` : `(${a} * ${b})`;
      else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
      else if (op === AluOp.Mod) source = `(${a} % ${b})`;
      else if (op === AluOp.Min) source = isBool ? `(${a} && ${b})` : `min(${strip1(a)}, ${strip1(b)})`;
      else if (op === AluOp.Max) source = isBool ? `(${a} || ${b})` : `max(${strip1(a)}, ${strip1(b)})`;
      else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
      else if (op === AluOp.Cmpne) {
        if (isFloatDtype(src[0].dtype)) {
          // The left operand appears three times below, so bind it to a name
          // unless it already is one.
          const x = isGensym(a) ? a : gensym();
          if (x !== a) emit(`let ${x} = ${a};`);
          source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
        } else {
          source = `(${a} != ${b})`;
        }
      }
    } else if (AluGroup.Unary.has(op)) {
      if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
        // 1 / sqrt(x) has a dedicated builtin.
        const a = gen(src[0].src[0]);
        source = `inverseSqrt(${a})`;
      } else {
        const a = gen(src[0]);
        const builtin = unaryBuiltins.get(op);
        if (builtin !== undefined) {
          source = `${builtin}(${strip1(a)})`;
        } else if (op === AluOp.Erf || op === AluOp.Erfc) {
          // Non-f32 operands are routed through f32 and cast back.
          const funcName = op === AluOp.Erf ? "erf" : "erfc";
          if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${strip1(a)})))`;
          else source = `${funcName}(${strip1(a)})`;
        } else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
        else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
        else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
      }
    } else if (op === AluOp.Where) {
      // select(falseValue, trueValue, condition)
      source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
    } else if (op === AluOp.Threefry2x32) {
      const x = gensym();
      const [k0, k1, c0, c1] = src.map((operand) => strip1(gen(operand)));
      emit(`let ${x} = threefry2x32(vec2(${k0}, ${k1}), vec2(${c0}, ${c1}));`);
      if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
      else if (arg === 0) source = `${x}.x`;
      else if (arg === 1) source = `${x}.y`;
      else throw new UnsupportedOpError(op, dtype, "webgpu", arg);
    } else if (op === AluOp.Const) {
      return constToWgsl(dtype, arg);
    } else if (op === AluOp.Special) {
      return arg[0];
    } else if (op === AluOp.Variable) {
      return arg;
    } else if (op === AluOp.GlobalIndex) {
      source = `${args[arg[0]]}[${strip1(gen(src[0]))}]`;
      // Bool results are derived by comparing the loaded element with zero.
      if (dtype === DType.Bool) source = `(${source} != 0)`;
    }
    if (!source) throw new UnsupportedOpError(op, dtype, "webgpu", arg);
    const typeName = dtypeToWgsl(dtype);
    if ((references.get(exp) ?? 0) > 1) {
      const name = gensym();
      expContext.set(exp, name);
      emit(`let ${name}: ${typeName} = ${strip1(source)};`);
      return name;
    }
    expContext.set(exp, source);
    return source;
  };

  if (!re) {
    // --- Elementwise kernel: one output element per invocation --------------
    countReferences(tune.exp);
    let rhs = strip1(gen(tune.exp));
    if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
    emit(`result[gidx] = ${rhs};`);
  } else {
    // --- Reduction kernel ---------------------------------------------------
    if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
    const unroll = tune.size.unroll ?? 1;
    const upcast = tune.size.upcast ?? 1;
    const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
    for (let i = 0; i < upcast; i++) emit(`var ${acc[i]}: ${dtypeToWgsl(re.dtype)} = ${constToWgsl(re.dtype, re.identity)};`);
    emit(`for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`, pushIndent);
    const exps = [];
    const cache = /* @__PURE__ */ new Map();
    for (let up = 0; up < upcast; up++) {
      exps.push([]);
      for (let un = 0; un < unroll; un++) {
        const exp = tune.exp.substitute({
          upcast: AluExp.i32(up),
          unroll: AluExp.i32(un)
        });
        exps[up].push(exp.simplify(cache));
        countReferences(exps[up][un]);
      }
    }
    const items = exps.map((ar) => ar.map(gen).map(strip1));

    // The terms above went through strip1 (presumably dropping their outer
    // parentheses), so they must be parenthesized again before being glued
    // together with an infix operator. Otherwise a term such as `a + b` would
    // bind incorrectly inside a product or a logical chain.
    const isBoolReduce = re.dtype === DType.Bool;
    const paren = (text) => `(${text})`;
    const joinInfix = (terms, operator) => (terms.length > 1 ? terms.map(paren).join(operator) : terms[0]);
    const nestCalls = (fn, terms) => {
      let out = terms[0];
      for (let j = 1; j < terms.length; j++) out = `${fn}(${out}, ${terms[j]})`;
      return out;
    };
    for (let i = 0; i < upcast; i++) {
      const target = acc[i];
      const terms = items[i];
      if (re.op === AluOp.Add) emit(`${target} += ${joinInfix(terms, " + ")};`);
      else if (re.op === AluOp.Mul) emit(`${target} *= ${joinInfix(terms, " * ")};`);
      else if (re.op === AluOp.Min) {
        if (isBoolReduce) emit(`${target} = ${[target, ...terms].map(paren).join(" && ")};`);
        else emit(`${target} = min(${target}, ${nestCalls("min", terms)});`);
      } else if (re.op === AluOp.Max) {
        if (isBoolReduce) emit(`${target} = ${[target, ...terms].map(paren).join(" || ")};`);
        else emit(`${target} = max(${target}, ${nestCalls("max", terms)});`);
      } else throw new Error(`Unsupported reduction op: ${re.op}`);
    }
    emit(popIndent, "}");

    // The epilogue is generated outside the loop body, so names bound inside
    // the loop must not be reused: start over with fresh bookkeeping.
    expContext.clear();
    references.clear();
    seen.clear();
    const outputIdxExps = [];
    const fusionExps = [];
    for (let i = 0; i < upcast; i++) {
      const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
      outputIdxExps.push(exp.simplify(cache));
      countReferences(outputIdxExps[i]);
      fusionExps.push(tune.epilogue.substitute({
        acc: AluExp.variable(re.dtype, acc[i]),
        upcast: AluExp.i32(i)
      }).simplify(cache));
      countReferences(fusionExps[i]);
    }
    for (let i = 0; i < upcast; i++) {
      const index = strip1(gen(outputIdxExps[i]));
      let rhs = strip1(gen(fusionExps[i]));
      if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
      emit(`result[${index}] = ${rhs};`);
    }
  }
  emit(popIndent, "}");
  return {
    code: shader.join("\n"),
    numInputs: nargs,
    numOutputs: 1,
    hasUniform: false,
    passes: [{ grid: [gridX, gridY] }]
  };
}
1102
+ function pipelineSubmit(device, pipelines, inputs, outputs) {
1103
+ const commandEncoder = device.createCommandEncoder();
1104
+ for (const { pipeline,...shader } of pipelines) {
1105
+ if (inputs.length !== shader.numInputs || outputs.length !== shader.numOutputs) throw new Error(`webgpu: expected ${shader.numInputs} inputs and ${shader.numOutputs} outputs, got ${inputs.length} inputs and ${outputs.length} outputs`);
1106
+ const filteredPasses = shader.passes.filter(({ grid }) => require_backend.prod(grid) > 0);
1107
+ if (filteredPasses.length === 0) continue;
1108
+ const bindGroup = device.createBindGroup({
1109
+ layout: pipeline.getBindGroupLayout(0),
1110
+ entries: [...inputs.map((buffer, i) => ({
1111
+ binding: i,
1112
+ resource: { buffer }
1113
+ })), ...outputs.map((buffer, i) => ({
1114
+ binding: inputs.length + i,
1115
+ resource: { buffer }
1116
+ }))]
1117
+ });
1118
+ let uniformBindGroup = null;
1119
+ let uniformAlignment = 0;
1120
+ if (shader.hasUniform) {
1121
+ const uniforms = filteredPasses.map(({ uniform }) => uniform);
1122
+ const [uniformBuffer, alignment] = combineUniforms(device, uniforms);
1123
+ uniformAlignment = alignment;
1124
+ uniformBindGroup = device.createBindGroup({
1125
+ layout: pipeline.getBindGroupLayout(1),
1126
+ entries: [{
1127
+ binding: 0,
1128
+ resource: {
1129
+ buffer: uniformBuffer,
1130
+ size: alignment
1131
+ }
1132
+ }]
1133
+ });
1134
+ }
1135
+ for (let i = 0; i < filteredPasses.length; i++) {
1136
+ const { grid } = filteredPasses[i];
1137
+ const passEncoder = commandEncoder.beginComputePass();
1138
+ passEncoder.setPipeline(pipeline);
1139
+ passEncoder.setBindGroup(0, bindGroup);
1140
+ if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
1141
+ passEncoder.dispatchWorkgroups(grid[0], grid[1]);
1142
+ passEncoder.end();
1143
+ }
1144
+ }
1145
+ device.queue.submit([commandEncoder.finish()]);
1146
+ }
1147
+ function combineUniforms(device, uniforms) {
1148
+ for (const buf of uniforms) if (!buf || buf.byteLength === 0 || buf.byteLength !== uniforms[0].byteLength) throw new Error("webgpu: Uniform mismatch between shader passes");
1149
+ const minAlign = device.limits.minUniformBufferOffsetAlignment;
1150
+ const alignment = Math.ceil(uniforms[0].byteLength / minAlign) * minAlign;
1151
+ const buffer = device.createBuffer({
1152
+ size: alignment * uniforms.length,
1153
+ usage: GPUBufferUsage.UNIFORM,
1154
+ mappedAtCreation: true
1155
+ });
1156
+ const bufferMapped = new Uint8Array(buffer.getMappedRange());
1157
+ for (let i = 0; i < uniforms.length; i++) bufferMapped.set(uniforms[i], i * alignment);
1158
+ buffer.unmap();
1159
+ return [buffer, alignment];
1160
+ }
1161
+ /**
1162
+ * A cache for compiled GPU compute pipelines, keyed by the shader source.
1163
+ *
1164
+ * This supports both async compilation (recommended) and a synchronous variant.
1165
+ * If the pipeline is not in the cache, it will be compiled and added. For async
1166
+ * compilation, only one compilation will be in progress at a time for a given
1167
+ * shader source.
1168
+ */
1169
+ var ShaderPipelineCache = class {
1170
+ cache;
1171
+ inProgress;
1172
+ constructor(device) {
1173
+ this.device = device;
1174
+ this.cache = /* @__PURE__ */ new Map();
1175
+ this.inProgress = /* @__PURE__ */ new Map();
1176
+ }
1177
+ #getLayout(shader) {
1178
+ if (shader.numInputs + shader.numOutputs > this.device.limits.maxStorageBuffersPerShaderStage) {
1179
+ const actual = shader.numInputs + shader.numOutputs;
1180
+ const max = this.device.limits.maxStorageBuffersPerShaderStage;
1181
+ throw new Error(`Too many buffers (${actual}) for WebGPU pipeline (max: ${max})`);
1182
+ }
1183
+ const bindGroupLayouts = [this.device.createBindGroupLayout({ entries: require_backend.range(shader.numInputs + shader.numOutputs).map((i) => ({
1184
+ binding: i,
1185
+ visibility: GPUShaderStage.COMPUTE,
1186
+ buffer: { type: i < shader.numInputs ? "read-only-storage" : "storage" }
1187
+ })) })];
1188
+ if (shader.hasUniform) bindGroupLayouts.push(this.device.createBindGroupLayout({ entries: [{
1189
+ binding: 0,
1190
+ visibility: GPUShaderStage.COMPUTE,
1191
+ buffer: {
1192
+ type: "uniform",
1193
+ hasDynamicOffset: true
1194
+ }
1195
+ }] }));
1196
+ return this.device.createPipelineLayout({ bindGroupLayouts });
1197
+ }
1198
+ async prepare(shader) {
1199
+ const existingPipeline = this.cache.get(shader.code);
1200
+ if (existingPipeline) return existingPipeline;
1201
+ const existingPromise = this.inProgress.get(shader.code);
1202
+ if (existingPromise) return await existingPromise;
1203
+ if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
1204
+ const shaderModule = this.device.createShaderModule({ code: shader.code });
1205
+ const promise = (async () => {
1206
+ this.device.pushErrorScope("validation");
1207
+ try {
1208
+ const pipeline$1 = await this.device.createComputePipelineAsync({
1209
+ layout: this.#getLayout(shader),
1210
+ compute: {
1211
+ module: shaderModule,
1212
+ entryPoint: "main"
1213
+ }
1214
+ });
1215
+ await this.device.popErrorScope();
1216
+ return pipeline$1;
1217
+ } catch (_error) {
1218
+ const scope = await this.device.popErrorScope();
1219
+ const emsg = await compileError(shaderModule, scope, shader.code);
1220
+ throw new Error(emsg);
1221
+ }
1222
+ })();
1223
+ this.inProgress.set(shader.code, promise);
1224
+ const pipeline = await promise;
1225
+ this.cache.set(shader.code, pipeline);
1226
+ return pipeline;
1227
+ }
1228
+ prepareSync(shader) {
1229
+ const existingPipeline = this.cache.get(shader.code);
1230
+ if (existingPipeline) return existingPipeline;
1231
+ if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + shader.code);
1232
+ const shaderModule = this.device.createShaderModule({ code: shader.code });
1233
+ this.device.pushErrorScope("validation");
1234
+ const pipeline = this.device.createComputePipeline({
1235
+ layout: this.#getLayout(shader),
1236
+ compute: {
1237
+ module: shaderModule,
1238
+ entryPoint: "main"
1239
+ }
1240
+ });
1241
+ this.device.popErrorScope().then(async (scope) => {
1242
+ if (scope !== null) {
1243
+ const emsg = await compileError(shaderModule, scope, shader.code);
1244
+ console.error(emsg);
1245
+ }
1246
+ });
1247
+ this.cache.set(shader.code, pipeline);
1248
+ return pipeline;
1249
+ }
1250
+ };
1251
+ /** Gather information about a compilation error and format it. */
1252
+ async function compileError(shaderModule, scope, code) {
1253
+ let message = `Failed to compile shader: ${scope ? scope.message : "(no error scope)"}`;
1254
+ const info = await shaderModule.getCompilationInfo();
1255
+ for (const msg of info.messages) message += `\n [${msg.type} at ${msg.lineNum}:${msg.linePos}] ${msg.message}`;
1256
+ if (code) message += `\n\n${code}`;
1257
+ return message;
1258
+ }
1259
+
1260
+ //#endregion
1261
+ exports.WebGPUBackend = WebGPUBackend;