@jax-js/jax 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.3",
3
+ "version": "0.1.5",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -1,663 +0,0 @@
1
- const require_backend = require('./backend-CmaidnkQ.cjs');
2
-
3
//#region src/backend/webgpu/builtins.ts
/**
 * WGSL source for `threefry2x32(key, ctr)`, a 20-round Threefry-2x32 block
 * function. pipelineSource() splices this into a shader only when the kernel
 * uses the `Threefry2x32` ALU op.
 *
 * Key schedule: ks2 = ks0 ^ ks1 ^ 0x1BD11BDA. The counter words are first
 * offset by (ks0, ks1); then five groups of four add/rotate/xor rounds run,
 * alternating rotation constants (13, 15, 26, 6) and (17, 29, 16, 24). Each
 * group is followed by a key injection that also adds an incrementing
 * constant (1u..5u) to the second word. Returns both output words.
 */
const threefrySrc = `
fn threefry2x32(key: vec2<u32>, ctr: vec2<u32>) -> vec2<u32> {
let ks0: u32 = key.x;
let ks1: u32 = key.y;
let ks2: u32 = ks0 ^ ks1 ^ 0x1BD11BDAu;

var x0: u32 = ctr.x + ks0;
var x1: u32 = ctr.y + ks1;

x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
x0 += ks1;
x1 += ks2 + 1u;

x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
x0 += ks2;
x1 += ks0 + 2u;

x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
x0 += ks0;
x1 += ks1 + 3u;

x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
x0 += ks1;
x1 += ks2 + 4u;

x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
x0 += ks2;
x1 += ks0 + 5u;

return vec2<u32>(x0, x1);
}`;
50
/**
 * WGSL source for `erf(x)` and `erfc(x)` on f32. pipelineSource() splices this
 * into a shader only when the kernel uses the `Erf` or `Erfc` ALU op.
 *
 * Both use the rational approximation whose constants (p, a1..a5) match
 * Abramowitz & Stegun formula 7.1.26:
 *   t = 1 / (1 + p|x|),  P(t) = t(a1 + t(a2 + t(a3 + t(a4 + t a5)))),
 *   E = P(t) exp(-x^2).
 * `erf` returns sign(x) * (1 - E). `erfc` returns E directly for x >= 0 and
 * 2 - E otherwise, rather than computing 1 - erf(x), which would cancel for
 * large positive x.
 */
const erfSrc = `
const _erf_p: f32 = 0.3275911;
const _erf_a1: f32 = 0.254829592;
const _erf_a2: f32 = -0.284496736;
const _erf_a3: f32 = 1.421413741;
const _erf_a4: f32 = -1.453152027;
const _erf_a5: f32 = 1.061405429;
fn erf(x: f32) -> f32 {
let t = 1.0 / (1.0 + _erf_p * abs(x));
let P_t = fma(fma(fma(fma(_erf_a5, t, _erf_a4), t, _erf_a3), t, _erf_a2), t, _erf_a1) * t;
return sign(x) * (1.0 - P_t * exp(-x * x));
}
fn erfc(x: f32) -> f32 {
let t = 1.0 / (1.0 + _erf_p * abs(x));
let P_t = fma(fma(fma(fma(_erf_a5, t, _erf_a4), t, _erf_a3), t, _erf_a2), t, _erf_a1) * t;
let E = P_t * exp(-x * x);
return select(2.0 - E, E, x >= 0.0);
}`;
68
-
69
- //#endregion
70
- //#region src/backend/webgpu/reader.ts
71
/**
 * Synchronous readback of WebGPU buffer contents through canvases.
 *
 * Technique borrowed from TensorFlow.js. Each "device storage" OffscreenCanvas
 * is configured with a WebGPU context, one per alpha mode. A region of the
 * buffer is copied into the canvas texture as 4-byte pixels. That canvas is
 * then drawn onto a 2d "host storage" canvas, and `getImageData()` pulls the
 * bytes back out to the CPU.
 *
 * Canvases are 256x256 pixels, i.e. 256 KiB of buffer data per full canvas.
 * Every region is read twice:
 * - the "opaque" pass supplies the first three bytes of each pixel, undoing
 *   the bgra -> rgba channel swap;
 * - the "premultiplied" pass supplies the fourth byte.
 * It is slow because of the many copies, but it needs no await.
 *
 * https://github.com/tensorflow/tfjs/blob/tfjs-v4.22.0/tfjs-backend-webgpu/src/backend_webgpu.ts#L379
 */
var SyncReader = class SyncReader {
  static alphaModes = ["opaque", "premultiplied"];
  static width = 256;
  static height = 256;
  initialized = false;
  deviceStorage;
  deviceContexts;
  hostStorage;
  hostContext;

  constructor(device) {
    this.device = device;
  }

  /** Create the canvases and configure one WebGPU context per alpha mode. */
  #init() {
    const { width, height, alphaModes } = SyncReader;
    this.deviceStorage = alphaModes.map(() => new OffscreenCanvas(width, height));
    const contexts = [];
    for (const [modeIndex, canvas] of this.deviceStorage.entries()) {
      const gpuContext = canvas.getContext("webgpu");
      gpuContext.configure({
        device: this.device,
        format: "bgra8unorm",
        usage: GPUTextureUsage.COPY_DST,
        alphaMode: alphaModes[modeIndex],
      });
      contexts.push(gpuContext);
    }
    this.deviceContexts = contexts;
    this.hostStorage = new OffscreenCanvas(width, height);
    this.hostContext = this.hostStorage.getContext("2d", { willReadFrequently: true });
    this.initialized = true;
  }

  /**
   * Synchronously read `count` bytes of `buffer`, starting at byte `start`.
   *
   * The region is rounded up to whole 4-byte pixels for the copies; the
   * returned view is trimmed back to exactly `count` bytes.
   *
   * @returns a Uint8Array of length `count`.
   */
  read(buffer, start, count) {
    if (!this.initialized) this.#init();
    const { width, height, alphaModes } = SyncReader;
    const { deviceStorage, deviceContexts, hostContext } = this;
    const totalPixels = Math.ceil(count / 4);
    const rowBytes = width * 4;
    const output = new ArrayBuffer(totalPixels * 4);

    // Plan the copies once: whole canvases, then whole rows, then one partial row.
    const canvasPixels = width * height;
    const chunks = [];
    let byteOffset = 0;
    for (let remaining = Math.floor(totalPixels / canvasPixels); remaining > 0; remaining--) {
      chunks.push({ w: width, h: height, at: byteOffset });
      byteOffset += canvasPixels * 4;
    }
    const leftover = totalPixels % canvasPixels;
    const fullRows = Math.floor(leftover / width);
    const tailPixels = leftover % width;
    if (fullRows > 0) {
      chunks.push({ w: width, h: fullRows, at: byteOffset });
      byteOffset += fullRows * width * 4;
    }
    if (tailPixels > 0) chunks.push({ w: tailPixels, h: 1, at: byteOffset });

    deviceContexts.forEach((gpuContext, modeIndex) => {
      const texture = gpuContext.getCurrentTexture();
      const alphaOnly = alphaModes[modeIndex] === "premultiplied";
      for (const { w, h, at } of chunks) {
        const encoder = this.device.createCommandEncoder();
        encoder.copyBufferToTexture(
          { buffer, bytesPerRow: rowBytes, offset: at + start },
          { texture },
          { width: w, height: h, depthOrArrayLayers: 1 },
        );
        this.device.queue.submit([encoder.finish()]);
        hostContext.clearRect(0, 0, w, h);
        hostContext.drawImage(deviceStorage[modeIndex], 0, 0);
        const pixels = hostContext.getImageData(0, 0, w, h).data;
        const dest = new Uint8ClampedArray(output, at, 4 * w * h);
        for (let k = 0; k < dest.length; k += 4) {
          if (alphaOnly) {
            // Only the alpha byte is trustworthy in premultiplied mode.
            dest[k + 3] = pixels[k + 3];
          } else {
            // The texture is bgra8unorm but image data is RGBA: swap back.
            dest[k] = pixels[k + 2];
            dest[k + 1] = pixels[k + 1];
            dest[k + 2] = pixels[k];
          }
        }
      }
    });
    return new Uint8Array(output, 0, count);
  }
};
172
-
173
- //#endregion
174
- //#region src/backend/webgpu.ts
175
/**
 * Implementation of `Backend` that uses WebGPU in browsers.
 *
 * Device memory is handed out as integer "slots" indexing `buffers`. Each slot
 * is reference-counted, and its `GPUBuffer` is destroyed when the count
 * reaches zero.
 */
var WebGPUBackend = class {
  type = "webgpu";
  maxArgs;
  pipelines;
  syncReader;
  buffers;
  nextSlot;
  #cachedShaderMap = /* @__PURE__ */ new Map();

  /** @param device - the `GPUDevice` all buffers and pipelines are created on. */
  constructor(device) {
    this.device = device;
    if (require_backend.DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
    // One storage-buffer binding per shader stage is reserved for the output.
    this.maxArgs = this.device.limits.maxStorageBuffersPerShaderStage - 1;
    this.pipelines = new ShaderPipelineCache(device);
    this.syncReader = new SyncReader(device);
    this.buffers = /* @__PURE__ */ new Map();
    this.nextSlot = 1;
  }

  /**
   * Allocate a device buffer of `size` bytes and return its slot (refcount 1).
   *
   * The GPU buffer is padded up to a multiple of 4 bytes. If `initialData` is
   * given, its `byteLength` must equal `size`. It may be an ArrayBuffer or any
   * ArrayBuffer view; its raw bytes are copied regardless of element type.
   *
   * @throws Error if `initialData.byteLength !== size`.
   */
  malloc(size, initialData) {
    const paddedSize = Math.ceil(size / 4) * 4;
    let buffer;
    if (!initialData) {
      buffer = this.#createBuffer(paddedSize);
    } else {
      if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
      // Uint8Array#set and subarray() operate on elements, so a non-byte view
      // (e.g. Float32Array) must be reinterpreted as bytes before copying.
      const bytes = ArrayBuffer.isView(initialData)
        ? new Uint8Array(initialData.buffer, initialData.byteOffset, initialData.byteLength)
        : new Uint8Array(initialData);
      if (size < 4096) {
        // Small uploads: fill the buffer while it is mapped at creation.
        buffer = this.#createBuffer(paddedSize, { mapped: true });
        new Uint8Array(buffer.getMappedRange(), 0, size).set(bytes);
        buffer.unmap();
      } else {
        buffer = this.#createBuffer(paddedSize);
        // writeBuffer sizes must be multiples of 4: write the aligned prefix,
        // then the trailing 1-3 bytes zero-padded to a full word.
        const aligned = size - (size % 4);
        this.device.queue.writeBuffer(buffer, 0, bytes, 0, aligned);
        if (aligned < size) {
          const remainder = new Uint8Array(4);
          remainder.set(bytes.subarray(aligned));
          this.device.queue.writeBuffer(buffer, aligned, remainder);
        }
      }
    }
    const slot = this.nextSlot++;
    this.buffers.set(slot, {
      buffer,
      size,
      ref: 1
    });
    return slot;
  }

  /** Increment the reference count of `slot`. @throws SlotError if unknown. */
  incRef(slot) {
    const buffer = this.buffers.get(slot);
    if (!buffer) throw new require_backend.SlotError(slot);
    buffer.ref++;
  }

  /**
   * Decrement the reference count of `slot`; at zero the slot is removed and
   * its GPU buffer destroyed. @throws SlotError if unknown.
   */
  decRef(slot) {
    const buffer = this.buffers.get(slot);
    if (!buffer) throw new require_backend.SlotError(slot);
    buffer.ref--;
    if (buffer.ref === 0) {
      this.buffers.delete(slot);
      buffer.buffer.destroy();
    }
  }

  /**
   * Asynchronously read `count` bytes (default: through the end of the
   * allocation) starting at byte `start` (default 0) of `slot`.
   *
   * The data is copied through a MAP_READ staging buffer, which is destroyed
   * afterwards. WebGPU requires buffer copy offsets to be multiples of 4, so
   * `start` presumably must be 4-aligned.
   *
   * @returns a Uint8Array of length `count`, backed by a private copy.
   */
  async read(slot, start, count) {
    const { buffer, size } = this.#getBuffer(slot);
    if (start === void 0) start = 0;
    if (count === void 0) count = size - start;
    const paddedSize = Math.ceil(count / 4) * 4;
    const staging = this.#createBuffer(paddedSize, { read: true });
    try {
      const commandEncoder = this.device.createCommandEncoder();
      commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
      this.device.queue.submit([commandEncoder.finish()]);
      await staging.mapAsync(GPUMapMode.READ);
      const arrayBuffer = staging.getMappedRange();
      // Copy out of the mapped range: it is invalidated when staging is destroyed.
      return new Uint8Array(arrayBuffer.slice(), 0, count);
    } finally {
      staging.destroy();
    }
  }

  /** Synchronous variant of `read()`, implemented with `SyncReader`. */
  readSync(slot, start, count) {
    const { buffer, size } = this.#getBuffer(slot);
    if (start === void 0) start = 0;
    if (count === void 0) count = size - start;
    return this.syncReader.read(buffer, start, count);
  }

  /** Generate (or fetch from cache, keyed by kernel hash) the WGSL for `kernel`. */
  #cachedShader(kernel) {
    const cacheKey = require_backend.FpHash.hash(kernel);
    let result = this.#cachedShaderMap.get(cacheKey);
    if (!result) {
      result = pipelineSource(this.device, kernel);
      this.#cachedShaderMap.set(cacheKey, result);
    }
    return result;
  }

  /** Compile `kernel` asynchronously into an `Executable`. */
  async prepare(kernel) {
    const { shader, grid } = this.#cachedShader(kernel);
    const pipeline = await this.pipelines.prepare(shader);
    return new require_backend.Executable(kernel, {
      shader,
      grid,
      pipeline
    });
  }

  /** Compile `kernel` synchronously into an `Executable`. */
  prepareSync(kernel) {
    const { shader, grid } = this.#cachedShader(kernel);
    const pipeline = this.pipelines.prepareSync(shader);
    return new require_backend.Executable(kernel, {
      shader,
      grid,
      pipeline
    });
  }

  /**
   * Run a prepared executable on the given input/output slots.
   * Empty kernels (`size === 0`) are skipped.
   *
   * @throws Error if the number of inputs differs from the kernel's `nargs`.
   */
  dispatch(exe, inputs, outputs) {
    if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
    if (exe.kernel.size === 0) return;
    const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
    const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
    pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
  }

  /** Look up a slot's GPU buffer and logical size. @throws SlotError if unknown. */
  #getBuffer(slot) {
    const buffer = this.buffers.get(slot);
    if (!buffer) throw new require_backend.SlotError(slot);
    return {
      buffer: buffer.buffer,
      size: buffer.size
    };
  }

  /**
   * Create a GPU buffer.
   *
   * By default, this creates a general-purpose buffer with the given size
   * (STORAGE | COPY_SRC | COPY_DST).
   *
   * - If `mapped` is true, initialize the buffer in mapped mode so that it can
   *   be populated with data from the CPU. (Call `.unmap()` later.)
   * - If `read` is true, create a staging buffer for returning data to CPU.
   *   (Call `.mapAsync()` later.)
   */
  #createBuffer(size, { mapped = false, read = false } = {}) {
    if (read && mapped) throw new Error("mapped and read cannot both be true");
    const buffer = this.device.createBuffer({
      size,
      usage: read ? GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST : GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
      mappedAtCreation: mapped
    });
    return buffer;
  }
};
321
/**
 * Map a jax-js dtype to its WGSL scalar type name.
 *
 * @param dtype - the `DType` to translate.
 * @param storage - when true, return the type used for storage-buffer
 *   elements. `Bool` maps to `"i32"` there, since `bool` is only used for
 *   values computed inside the shader.
 * @returns the WGSL type name.
 * @throws Error for dtypes this backend cannot represent.
 */
function dtypeToWgsl(dtype, storage = false) {
  const { DType } = require_backend;
  if (dtype === DType.Bool) return storage ? "i32" : "bool";
  if (dtype === DType.Int32) return "i32";
  if (dtype === DType.Uint32) return "u32";
  if (dtype === DType.Float32) return "f32";
  if (dtype === DType.Float16) return "f16";
  throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
}
331
/**
 * Render a scalar constant as a WGSL expression of the given dtype.
 *
 * NaN and infinities use the `nan()` / `inf()` helper functions that
 * pipelineSource() emits into every shader. Finite floats are range-checked
 * against the target type. WGSL rejects an out-of-range float constant
 * conversion at shader creation, so:
 * - magnitudes that round to infinity (f32 via `Math.fround`, f16 at >= 65520)
 *   become `inf()`;
 * - magnitudes just past the largest finite value, which round down to it,
 *   are emitted as that largest value.
 *
 * @param dtype - the `DType` of the constant.
 * @param value - the constant (boolean-ish for Bool, number otherwise).
 * @returns WGSL source for the constant.
 * @throws Error for unsupported dtypes.
 */
function constToWgsl(dtype, value) {
  const { DType } = require_backend;
  const F32_MAX = 3.4028234663852886e38;
  const F16_MAX = 65504;
  // Round-to-nearest-even overflows to infinity from the midpoint between
  // 65504 and the next (unrepresentable) step, 65536.
  const F16_OVERFLOW = 65520;
  if (dtype === DType.Bool) return value ? "true" : "false";
  if (dtype === DType.Int32) return value.toString();
  if (dtype === DType.Uint32) return value.toString() + "u";
  if (dtype === DType.Float32) {
    if (Number.isNaN(value)) return "nan()";
    const rounded = Math.fround(value);
    if (!Number.isFinite(rounded)) return rounded > 0 ? "inf()" : "-inf()";
    const literal = Math.abs(value) > F32_MAX ? rounded : value;
    return "f32(" + literal.toString() + ")";
  }
  if (dtype === DType.Float16) {
    if (Number.isNaN(value)) return "f16(nan())";
    if (!Number.isFinite(value) || Math.abs(value) >= F16_OVERFLOW) return value > 0 ? "f16(inf())" : "f16(-inf())";
    const literal = Math.abs(value) > F16_MAX ? Math.sign(value) * F16_MAX : value;
    return "f16(" + literal.toString() + ")";
  }
  throw new Error(`Unsupported const dtype: ${dtype}`);
}
347
/**
 * Fold one upcast lane's unrolled reduction terms into a single WGSL expression.
 *
 * @param re - the kernel's reduction descriptor (`op` and `dtype` are read).
 * @param terms - the code generator's output for each unrolled element, in
 *   order. Each is an atomic expression, a call, or a parenthesized binary
 *   expression, so it can be joined with an infix operator as-is.
 * @returns a WGSL expression combining all terms.
 * @throws Error for reduction ops other than Add, Mul, Min and Max.
 */
function reductionTermsToWgsl(re, terms) {
  const { AluOp, DType } = require_backend;
  if (re.dtype === DType.Bool) {
    // WGSL has no arithmetic or min/max on `bool`. Lower the same way the
    // code generator lowers Bool Add/Mul: Add/Max -> `||`, Mul/Min -> `&&`.
    if (re.op === AluOp.Add || re.op === AluOp.Max) return terms.join(" || ");
    if (re.op === AluOp.Mul || re.op === AluOp.Min) return terms.join(" && ");
    throw new Error(`Unsupported reduction op: ${re.op}`);
  }
  const stripped = terms.map((term) => require_backend.strip1(term));
  if (re.op === AluOp.Add) return stripped.join(" + ");
  // Keep each factor parenthesized: joining stripped `a + b` and `c + d` with
  // `*` would parse as `a + b * c + d`.
  if (re.op === AluOp.Mul) return terms.length === 1 ? stripped[0] : terms.join(" * ");
  if (re.op === AluOp.Min) return stripped.reduce((lhs, rhs) => `min(${lhs}, ${rhs})`);
  if (re.op === AluOp.Max) return stripped.reduce((lhs, rhs) => `max(${lhs}, ${rhs})`);
  throw new Error(`Unsupported reduction op: ${re.op}`);
}

/**
 * WGSL statement that folds `rhs` (from reductionTermsToWgsl) into the
 * accumulator variable named `acc`.
 *
 * @throws Error for reduction ops other than Add, Mul, Min and Max.
 */
function reductionUpdateToWgsl(re, acc, rhs) {
  const { AluOp, DType } = require_backend;
  if (re.dtype === DType.Bool) {
    if (re.op === AluOp.Add || re.op === AluOp.Max) return `${acc} = ${acc} || ${rhs};`;
    if (re.op === AluOp.Mul || re.op === AluOp.Min) return `${acc} = ${acc} && ${rhs};`;
  } else {
    if (re.op === AluOp.Add) return `${acc} += ${rhs};`;
    if (re.op === AluOp.Mul) return `${acc} *= ${rhs};`;
    if (re.op === AluOp.Min) return `${acc} = min(${acc}, ${rhs});`;
    if (re.op === AluOp.Max) return `${acc} = max(${acc}, ${rhs});`;
  }
  throw new Error(`Unsupported reduction op: ${re.op}`);
}

/**
 * Compiles a kernel into WebGPU (WGSL) shader source code.
 *
 * The kernel is first tuned with `tuneWebgpu`. The shader declares one
 * read-only storage array per kernel argument (`in0`, `in1`, ...) plus a
 * read-write `result` array. Invocations with a flat index at or beyond
 * `tune.threadCount` return immediately; the rest use that index as `gidx`.
 *
 * Without a reduction, each invocation writes `result[gidx]`. With a
 * reduction, each invocation:
 * - keeps `upcast` accumulators;
 * - loops `ridx` over `tune.size.reduce` iterations, folding `unroll` terms
 *   into each accumulator per iteration;
 * - writes each accumulator, passed through `tune.epilogue`, at the index
 *   given by `tune.outputIdxExp`.
 *
 * @param device - the GPUDevice; its limits and `shader-f16` feature are consulted.
 * @param kernel - the kernel to compile.
 * @returns `{ shader, grid }`: the WGSL source and the number of workgroups to
 *   dispatch along the x and y axes.
 * @throws Error if f16 is needed but unsupported, if group optimization was
 *   requested, or for an unsupported reduction op.
 * @throws UnsupportedOpError for an ALU op the code generator cannot lower.
 */
function pipelineSource(device, kernel) {
  const { AluOp, AluGroup, DType } = require_backend;
  const tune = require_backend.tuneWebgpu(kernel);
  if (require_backend.DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
  const { nargs, reduction: re } = kernel;
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);

  // Shader text is accumulated line by line. The pushIndent/popIndent markers
  // passed to emit() adjust the indentation of the lines that follow; both
  // use the same indentation unit so nesting stays balanced.
  const INDENT = "  ";
  const shader = [];
  let indent = "";
  const pushIndent = Symbol("pushIndent");
  const popIndent = Symbol("popIndent");
  const emit = (...lines) => {
    for (const line of lines) {
      if (line === pushIndent) indent += INDENT;
      else if (line === popIndent) indent = indent.slice(0, -INDENT.length);
      else shader.push(line ? indent + line : line);
    }
  };

  if (tune.exp.some((exp) => exp.dtype === DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === DType.Float16)) {
    if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
    emit("enable f16;");
  }
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
  // Only include builtin helper functions the kernel actually uses.
  const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
  if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
  if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
  emit("");

  // Element type of each input array, taken from the GlobalIndex loads that
  // read it. Unread arguments default to f32.
  const usedArgs = Array.from({ length: nargs }, () => null);
  const recordArgType = (exp) => {
    if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
  };
  tune.exp.fold(recordArgType);
  tune.epilogue?.fold(recordArgType);
  for (let i = 0; i < nargs; i++) {
    const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
    emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
  }
  const resultTy = dtypeToWgsl(kernel.dtype, true);
  emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);

  // Use a 2D grid when the 1D workgroup count exceeds the per-dimension limit.
  const workgroupSize = require_backend.findPow2(tune.threadCount, 256);
  const gridSize = Math.ceil(tune.threadCount / workgroupSize);
  let gridX = gridSize;
  let gridY = 1;
  if (gridSize > device.limits.maxComputeWorkgroupsPerDimension) {
    gridX = 16384;
    gridY = Math.ceil(gridSize / gridX);
  }
  emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", pushIndent);
  if (gridY === 1) {
    emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
  } else {
    const sizeX = gridX * workgroupSize;
    emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
  }

  let gensymCount = 0;
  const gensym = () => `alu${gensymCount++}`;
  const isGensym = (text) => text.match(/^alu[0-9]+$/);
  // Reference every binding so none is dropped from the auto-generated layout.
  if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));

  // Count how many times each subexpression is used; shared ones get a `let`.
  const references = /* @__PURE__ */ new Map();
  const seen = /* @__PURE__ */ new Set();
  const countReferences = (exp) => {
    references.set(exp, (references.get(exp) ?? 0) + 1);
    if (!seen.has(exp)) {
      seen.add(exp);
      for (const src of exp.src) countReferences(src);
    }
  };

  // Generate WGSL for an expression tree. Results are memoized per node;
  // binary/compare results are parenthesized, others are atomic or calls.
  const expContext = /* @__PURE__ */ new Map();
  const gen = (exp) => {
    if (expContext.has(exp)) return expContext.get(exp);
    const { op, src, dtype, arg } = exp;
    let source = "";
    if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
      const a = gen(src[0]);
      const b = gen(src[1]);
      if (op === AluOp.Add) source = dtype === DType.Bool ? `(${a} || ${b})` : `(${a} + ${b})`;
      else if (op === AluOp.Sub) source = `(${a} - ${b})`;
      else if (op === AluOp.Mul) source = dtype === DType.Bool ? `(${a} && ${b})` : `(${a} * ${b})`;
      else if (op === AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
      else if (op === AluOp.Mod) source = `(${a} % ${b})`;
      else if (op === AluOp.Min) source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
      else if (op === AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
      else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
      else if (op === AluOp.Cmpne) {
        if (require_backend.isFloatDtype(src[0].dtype)) {
          // `x != y` alone is false for NaN vs NaN; also report NaN x as
          // "not equal" via the min(x, inf) != x test.
          const x = isGensym(a) ? a : gensym();
          if (x !== a) emit(`let ${x} = ${a};`);
          source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
        } else {
          source = `(${a} != ${b})`;
        }
      }
    } else if (AluGroup.Unary.has(op)) {
      if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
        // Fuse 1/sqrt(x) into a single builtin.
        const a = gen(src[0].src[0]);
        source = `inverseSqrt(${a})`;
      } else {
        const a = gen(src[0]);
        if (op === AluOp.Sin) source = `sin(${require_backend.strip1(a)})`;
        else if (op === AluOp.Cos) source = `cos(${require_backend.strip1(a)})`;
        else if (op === AluOp.Asin) source = `asin(${require_backend.strip1(a)})`;
        else if (op === AluOp.Atan) source = `atan(${require_backend.strip1(a)})`;
        else if (op === AluOp.Exp) source = `exp(${require_backend.strip1(a)})`;
        else if (op === AluOp.Log) source = `log(${require_backend.strip1(a)})`;
        else if (op === AluOp.Erf || op === AluOp.Erfc) {
          // The erf/erfc helpers are f32-only; round-trip other dtypes.
          const funcName = op === AluOp.Erf ? "erf" : "erfc";
          if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${require_backend.strip1(a)})))`;
          else source = `${funcName}(${require_backend.strip1(a)})`;
        } else if (op === AluOp.Sqrt) source = `sqrt(${require_backend.strip1(a)})`;
        else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
        else if (op === AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
        else if (op === AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
        else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
        else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
      }
    } else if (op === AluOp.Where) {
      // select(false_value, true_value, condition); operands generated in that order.
      source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
    } else if (op === AluOp.Threefry2x32) {
      const x = gensym();
      const [k0, k1, c0, c1] = src.map((operand) => require_backend.strip1(gen(operand)));
      emit(`let ${x} = threefry2x32(vec2(${k0}, ${k1}), vec2(${c0}, ${c1}));`);
      if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
      else if (arg === 0) source = `${x}.x`;
      else if (arg === 1) source = `${x}.y`;
      else throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
    } else if (op === AluOp.Const) {
      return constToWgsl(dtype, arg);
    } else if (op === AluOp.Special) {
      return arg[0];
    } else if (op === AluOp.Variable) {
      return arg;
    } else if (op === AluOp.GlobalIndex) {
      source = `${args[arg[0]]}[${require_backend.strip1(gen(src[0]))}]`;
      // Bools are stored as i32.
      if (dtype === DType.Bool) source = `(${source} != 0)`;
    }
    if (!source) throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
    const typeName = dtypeToWgsl(dtype);
    if ((references.get(exp) ?? 0) > 1) {
      const name = gensym();
      expContext.set(exp, name);
      emit(`let ${name}: ${typeName} = ${require_backend.strip1(source)};`);
      return name;
    }
    expContext.set(exp, source);
    return source;
  };

  if (!re) {
    countReferences(tune.exp);
    let rhs = require_backend.strip1(gen(tune.exp));
    if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
    emit(`result[gidx] = ${rhs};`);
  } else {
    if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
    const unroll = tune.size.unroll ?? 1;
    const upcast = tune.size.upcast ?? 1;
    const acc = Array.from({ length: upcast }, (_, i) => `acc${i}`);
    for (const name of acc) emit(`var ${name}: ${dtypeToWgsl(re.dtype)} = ${constToWgsl(re.dtype, re.identity)};`);
    emit(`for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`, pushIndent);
    const exps = [];
    const cache = /* @__PURE__ */ new Map();
    for (let up = 0; up < upcast; up++) {
      exps.push([]);
      for (let un = 0; un < unroll; un++) {
        const exp = tune.exp.substitute({
          upcast: require_backend.AluExp.i32(up),
          unroll: require_backend.AluExp.i32(un)
        });
        exps[up].push(exp.simplify(cache));
        countReferences(exps[up][un]);
      }
    }
    // Generate every term first (in upcast-major order), then fold them.
    const terms = exps.map((row) => row.map((exp) => gen(exp)));
    for (let i = 0; i < upcast; i++) {
      emit(reductionUpdateToWgsl(re, acc[i], reductionTermsToWgsl(re, terms[i])));
    }
    emit(popIndent, "}");

    // The epilogue runs outside the loop: start fresh sharing/memo state.
    expContext.clear();
    references.clear();
    seen.clear();
    const outputIdxExps = [];
    const fusionExps = [];
    for (let i = 0; i < upcast; i++) {
      const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
      outputIdxExps.push(exp.simplify(cache));
      countReferences(outputIdxExps[i]);
      fusionExps.push(tune.epilogue.substitute({
        acc: require_backend.AluExp.variable(re.dtype, acc[i]),
        upcast: require_backend.AluExp.i32(i)
      }).simplify(cache));
      countReferences(fusionExps[i]);
    }
    for (let i = 0; i < upcast; i++) {
      const index = require_backend.strip1(gen(outputIdxExps[i]));
      let rhs = require_backend.strip1(gen(fusionExps[i]));
      if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
      emit(`result[${index}] = ${rhs};`);
    }
  }
  emit(popIndent, "}");
  return {
    shader: shader.join("\n"),
    grid: [gridX, gridY]
  };
}
557
/**
 * Bind buffers to a compiled compute pipeline and submit one dispatch.
 *
 * Inputs are bound at bindings 0..n-1 and `outputs[0]` at binding n, all in
 * bind group 0. The work is recorded in its own command encoder and submitted
 * immediately.
 *
 * @param device - the GPUDevice.
 * @param exe - `{ pipeline, grid }`: the compiled pipeline and workgroup counts (x, y).
 * @param inputs - GPUBuffers for the kernel arguments.
 * @param outputs - GPUBuffers for the outputs (only the first is bound).
 * @throws Error if the buffers exceed the device's per-stage storage-buffer limit.
 */
function pipelineSubmit(device, { pipeline, grid }, inputs, outputs) {
  const totalBuffers = inputs.length + outputs.length;
  if (totalBuffers > device.limits.maxStorageBuffersPerShaderStage) {
    const limit = device.limits.maxStorageBuffersPerShaderStage;
    throw new Error(`Too many buffers (${totalBuffers}) for WebGPU pipeline (max: ${limit})`);
  }
  const layout = pipeline.getBindGroupLayout(0);
  const entries = inputs.map((buffer, binding) => ({ binding, resource: { buffer } }));
  entries.push({ binding: inputs.length, resource: { buffer: outputs[0] } });
  const bindGroup = device.createBindGroup({ layout, entries });

  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  const [groupsX, groupsY] = grid;
  pass.dispatchWorkgroups(groupsX, groupsY);
  pass.end();
  device.queue.submit([encoder.finish()]);
}
583
/**
 * A cache for compiled GPU compute pipelines, keyed by the shader source.
 *
 * Supports async compilation (recommended) and a synchronous variant. If the
 * pipeline is not cached, it is compiled and added; only successful
 * compilations are cached. For async compilation, concurrent requests for the
 * same source share one in-flight compilation. Its `inProgress` entry is
 * removed once it settles, so a failed compilation is retried on the next
 * request.
 */
var ShaderPipelineCache = class {
  cache;
  inProgress;

  /** @param device - the GPUDevice pipelines are compiled on. */
  constructor(device) {
    this.device = device;
    this.cache = /* @__PURE__ */ new Map();
    this.inProgress = /* @__PURE__ */ new Map();
  }

  /**
   * Get or asynchronously compile the pipeline for `code` (entry point `main`,
   * auto layout).
   *
   * @throws Error with compiler diagnostics and the source if compilation
   *   fails; the underlying pipeline error is attached as `cause`.
   */
  async prepare(code) {
    const existingPipeline = this.cache.get(code);
    if (existingPipeline) return existingPipeline;
    const existingPromise = this.inProgress.get(code);
    if (existingPromise) return await existingPromise;
    if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
    const shaderModule = this.device.createShaderModule({ code });
    const promise = this.#compileAsync(shaderModule, code);
    this.inProgress.set(code, promise);
    try {
      const pipeline = await promise;
      this.cache.set(code, pipeline);
      return pipeline;
    } finally {
      this.inProgress.delete(code);
    }
  }

  /**
   * Start async pipeline creation inside a validation error scope.
   *
   * Error scopes form a per-device stack, so the scope is pushed and popped
   * synchronously around the call. Holding it open across an `await` would let
   * concurrent compilations pop each other's scopes, and would capture (and
   * discard) validation errors from unrelated GPU work issued meanwhile.
   */
  async #compileAsync(shaderModule, code) {
    this.device.pushErrorScope("validation");
    // The async wrapper turns a synchronous throw into a rejection, so the pop
    // below always runs and the scope stack stays balanced.
    const pipelinePromise = (async () => this.device.createComputePipelineAsync({
      layout: "auto",
      compute: {
        module: shaderModule,
        entryPoint: "main"
      }
    }))();
    const scopePromise = this.device.popErrorScope();
    const [pipelineResult, scopeResult] = await Promise.allSettled([pipelinePromise, scopePromise]);
    if (pipelineResult.status === "fulfilled") return pipelineResult.value;
    const error = pipelineResult.reason;
    const scope = scopeResult.status === "fulfilled" ? scopeResult.value : null;
    // Fall back to the pipeline error itself when the scope captured nothing.
    const emsg = await compileError(shaderModule, scope ?? error, code);
    throw new Error(emsg, { cause: error });
  }

  /**
   * Get or synchronously compile the pipeline for `code`.
   *
   * Validation errors are only known asynchronously, so they are logged via
   * `console.error` rather than thrown; the returned pipeline is cached
   * regardless.
   */
  prepareSync(code) {
    const existingPipeline = this.cache.get(code);
    if (existingPipeline) return existingPipeline;
    if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
    const shaderModule = this.device.createShaderModule({ code });
    this.device.pushErrorScope("validation");
    const pipeline = this.device.createComputePipeline({
      layout: "auto",
      compute: {
        module: shaderModule,
        entryPoint: "main"
      }
    });
    this.device.popErrorScope().then(async (scope) => {
      if (scope !== null) {
        const emsg = await compileError(shaderModule, scope, code);
        console.error(emsg);
      }
    });
    this.cache.set(code, pipeline);
    return pipeline;
  }
};
652
/**
 * Build a human-readable shader compilation error message.
 *
 * @param shaderModule - the GPUShaderModule whose compilation info is listed.
 * @param scope - the captured error (its `message` is used), or null.
 * @param code - the shader source, appended when non-empty.
 * @returns the formatted message: a header line, one line per compiler
 *   message, then the source.
 */
async function compileError(shaderModule, scope, code) {
  const header = `Failed to compile shader: ${scope ? scope.message : "(no error scope)"}`;
  const { messages } = await shaderModule.getCompilationInfo();
  const details = Array.from(messages, (m) => ` [${m.type} at ${m.lineNum}:${m.linePos}] ${m.message}`);
  const body = [header, ...details].join("\n");
  return code ? `${body}\n\n${code}` : body;
}
660
-
661
//#endregion
// CommonJS export: WebGPUBackend is the only binding this file exports.
exports.WebGPUBackend = WebGPUBackend;
//# sourceMappingURL=webgpu-BVns4DbI.cjs.map