@jax-js/jax 0.0.1 → 0.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,559 +0,0 @@
1
- import {
2
- AluExp,
3
- AluGroup,
4
- DEBUG,
5
- Executable,
6
- FpHash,
7
- SlotError,
8
- findPow2,
9
- strip1,
10
- tuneWebgpu
11
- } from "./chunk-B2GFURUN.js";
12
-
13
- // src/backend/webgpu.ts
14
/**
 * WebGPU implementation of the backend interface.
 *
 * Device buffers are addressed by integer "slots" handed out by `malloc` and
 * reference counted through `incRef` / `decRef`; the underlying GPUBuffer is
 * destroyed once its count reaches zero. Kernels are translated to WGSL by
 * `pipelineSource`, compiled through a `ShaderPipelineCache`, and launched
 * with `dispatch`.
 */
var WebGPUBackend = class {
  /** @param device - The GPUDevice that owns every buffer and pipeline. */
  constructor(device) {
    this.device = device;
    if (DEBUG >= 3 && device.adapterInfo) {
      console.info(
        "webgpu adapter:",
        device.adapterInfo.vendor,
        device.adapterInfo.architecture,
      );
    }
    // Generated shaders bind every input plus one output storage buffer, so
    // one storage-buffer binding is kept free for the output.
    this.maxArgs = this.device.limits.maxStorageBuffersPerShaderStage - 1;
    this.pipelines = new ShaderPipelineCache(device);
    this.syncReader = new SyncReader(device);
    this.buffers = new Map();
    this.nextSlot = 1;
  }
  type = "webgpu";
  maxArgs;
  pipelines;
  syncReader;
  buffers;
  nextSlot;
  #cachedShaderMap = new Map();

  /**
   * Allocate a device buffer of `size` bytes and return its slot id.
   *
   * @param size - Buffer size in bytes.
   * @param initialData - Optional ArrayBuffer or ArrayBufferView whose raw
   *   bytes initialize the buffer; its `byteLength` must equal `size`.
   * @returns The new slot id, holding one reference.
   * @throws Error if `initialData.byteLength !== size`.
   */
  malloc(size, initialData) {
    let buffer;
    if (initialData) {
      if (initialData.byteLength !== size) {
        throw new Error("initialData size does not match buffer size");
      }
      // Reinterpret views as raw bytes: `new Uint8Array(someFloat32Array)`
      // would convert element values rather than copy the underlying bytes.
      const bytes = ArrayBuffer.isView(initialData)
        ? new Uint8Array(
            initialData.buffer,
            initialData.byteOffset,
            initialData.byteLength,
          )
        : new Uint8Array(initialData);
      if (size < 4096) {
        // Small uploads are written straight into a buffer mapped at creation.
        buffer = this.#createBuffer(size, { mapped: true });
        new Uint8Array(buffer.getMappedRange()).set(bytes);
        buffer.unmap();
      } else {
        buffer = this.#createBuffer(size);
        this.device.queue.writeBuffer(buffer, 0, bytes);
      }
    } else {
      buffer = this.#createBuffer(size);
    }
    const slot = this.nextSlot++;
    this.buffers.set(slot, { buffer, ref: 1 });
    return slot;
  }

  /** Add a reference to `slot`. @throws SlotError for an unknown slot. */
  incRef(slot) {
    const buffer = this.buffers.get(slot);
    if (!buffer) throw new SlotError(slot);
    buffer.ref++;
  }

  /**
   * Drop a reference to `slot`, destroying its GPUBuffer when none remain.
   * @throws SlotError for an unknown slot.
   */
  decRef(slot) {
    const buffer = this.buffers.get(slot);
    if (!buffer) throw new SlotError(slot);
    buffer.ref--;
    if (buffer.ref === 0) {
      this.buffers.delete(slot);
      buffer.buffer.destroy();
    }
  }

  /**
   * Asynchronously copy `count` bytes starting at byte `start` of `slot` back
   * to the CPU via a temporary MAP_READ staging buffer.
   *
   * @param slot - Buffer slot to read.
   * @param start - Byte offset (default 0); must be a multiple of 4.
   * @param count - Byte count (default: rest of buffer); must be a multiple of 4.
   * @returns A fresh ArrayBuffer holding the copied bytes.
   * @throws Error if `start` or `count` is not 4-byte aligned, as
   *   copyBufferToBuffer requires; SlotError for an unknown slot.
   */
  async read(slot, start, count) {
    const buffer = this.#getBuffer(slot);
    if (start === void 0) start = 0;
    if (count === void 0) count = buffer.size - start;
    if (start % 4 !== 0 || count % 4 !== 0) {
      throw new Error("Read offset and size must be multiples of 4 bytes");
    }
    const staging = this.#createBuffer(count, { read: true });
    try {
      const commandEncoder = this.device.createCommandEncoder();
      commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, count);
      this.device.queue.submit([commandEncoder.finish()]);
      await staging.mapAsync(GPUMapMode.READ);
      const arrayBuffer = staging.getMappedRange();
      // Copy out of the mapped range before the staging buffer is destroyed.
      return arrayBuffer.slice();
    } finally {
      staging.destroy();
    }
  }

  /**
   * Synchronous counterpart of `read`, delegating to the canvas-based
   * `SyncReader`. Same `start` / `count` defaults as `read`.
   */
  readSync(slot, start, count) {
    const buffer = this.#getBuffer(slot);
    if (start === void 0) start = 0;
    if (count === void 0) count = buffer.size - start;
    return this.syncReader.read(buffer, start, count);
  }

  /** Generate (or fetch memoized) WGSL source and grid for `kernel`. */
  #cachedShader(kernel) {
    const cacheKey = FpHash.hash(kernel);
    let result = this.#cachedShaderMap.get(cacheKey);
    if (!result) {
      result = pipelineSource(this.device, kernel);
      this.#cachedShaderMap.set(cacheKey, result);
    }
    return result;
  }

  /** Compile `kernel` asynchronously into an Executable. */
  async prepare(kernel) {
    const { shader, grid } = this.#cachedShader(kernel);
    const pipeline = await this.pipelines.prepare(shader);
    return new Executable(kernel, { shader, grid, pipeline });
  }

  /** Compile `kernel` synchronously into an Executable. */
  prepareSync(kernel) {
    const { shader, grid } = this.#cachedShader(kernel);
    const pipeline = this.pipelines.prepareSync(shader);
    return new Executable(kernel, { shader, grid, pipeline });
  }

  /** Submit `exe` with the buffers behind the given input/output slots. */
  dispatch(exe, inputs, outputs) {
    const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
    const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
    pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
  }

  /** Look up the GPUBuffer for `slot`. @throws SlotError if unknown. */
  #getBuffer(slot) {
    const buffer = this.buffers.get(slot);
    if (!buffer) throw new SlotError(slot);
    return buffer.buffer;
  }

  /**
   * Create a GPU buffer.
   *
   * By default, this creates a general-purpose buffer with the given size.
   *
   * - If `mapped` is true, initialize the buffer in mapped mode so that it can
   *   be populated with data from the CPU. (Call `.unmap()` later.)
   * - If `read` is true, create a staging buffer for returning data to CPU.
   *   (Call `.mapAsync()` later.)
   */
  #createBuffer(size, { mapped = false, read = false } = {}) {
    if (read && mapped) {
      throw new Error("mapped and read cannot both be true");
    }
    const buffer = this.device.createBuffer({
      size,
      usage: read
        ? GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
        : GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
      mappedAtCreation: mapped,
    });
    return buffer;
  }
};
147
/**
 * Map a jax-js dtype to its WGSL scalar type name.
 *
 * @param dtype - One of "bool", "int32", "float32".
 * @param storage - True when the type is used for a storage-buffer element.
 *   WebGPU does not allow bools in buffers, so bools are stored as i32 there.
 * @returns The WGSL type name.
 * @throws Error for any other dtype.
 */
function dtypeToWgsl(dtype, storage = false) {
  if (dtype === "bool" /* Bool */) {
    return storage ? "i32" : "bool";
  }
  if (dtype === "int32" /* Int32 */) return "i32";
  if (dtype === "float32" /* Float32 */) return "f32";
  throw new Error(`Unsupported dtype: ${dtype}`);
}
160
/**
 * Render a constant as a WGSL literal of the given dtype.
 *
 * Float constants must read as floats in WGSL, so integral values get a
 * trailing ".0". Exponent forms produced by `Number#toString` (e.g. "1e-7",
 * "1e+21") are already valid WGSL float literals and are left unchanged.
 *
 * @param dtype - One of "bool", "int32", "float32".
 * @param value - The constant's value.
 * @returns WGSL source text for the literal.
 * @throws Error for an unsupported dtype, or for a NaN / infinite float,
 *   since WGSL has no literal syntax for those.
 */
function constToWgsl(dtype, value) {
  if (dtype === "bool" /* Bool */) return value ? "true" : "false";
  if (dtype === "int32" /* Int32 */) return value.toString();
  if (dtype === "float32" /* Float32 */) {
    if (typeof value === "number" && !Number.isFinite(value)) {
      throw new Error(
        `Cannot represent non-finite float32 constant in WGSL: ${value}`,
      );
    }
    const s = value.toString();
    // Previously "1e-7" became "1e-7.0", which is not a valid WGSL literal.
    return /[.e]/.test(s) ? s : `${s}.0`;
  }
  throw new Error(`Unsupported const dtype: ${dtype}`);
}
170
/**
 * Generate the WGSL compute shader for `kernel` plus the workgroup grid that
 * launches it.
 *
 * The kernel is first tuned with `tuneWebgpu`. Input `i` is bound read-only
 * at binding `i` as `in{i}`, and the output `result` is bound read-write at
 * binding `nargs`. Each invocation handles one `gidx` below
 * `tune.threadCount`. Sub-expressions referenced more than once are hoisted
 * into `let aluN` bindings.
 *
 * For reductions, each of `upcast` accumulators is initialized to the
 * reduction identity, updated once per `ridx` iteration with `unroll` combined
 * terms, then passed through the reduction's fusion expression and stored.
 *
 * @param device - WebGPU device (not used by the generator itself).
 * @param kernel - Kernel with `exp`, `nargs`, and an optional `reduction`.
 * @returns `{ shader, grid }`: WGSL source text and `[x, y]` workgroup counts.
 * @throws Error if tuning requests group optimization, an op has no WGSL
 *   translation, or the reduction op is unsupported.
 */
function pipelineSource(device, kernel) {
  const tune = tuneWebgpu(kernel);
  if (DEBUG >= 3) {
    console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
  }
  const { nargs, reduction: re } = kernel;
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);

  const shader = [];
  let indent = "";
  const pushIndent = Symbol("pushIndent");
  const popIndent = Symbol("popIndent");
  // Append lines to the shader. The two symbols change indentation by one
  // level of two spaces, so pushes and pops stay symmetric.
  const emit = (...lines) => {
    for (const line of lines) {
      if (line === pushIndent) indent += "  ";
      else if (line === popIndent) indent = indent.slice(0, -2);
      else shader.push(line ? indent + line : line);
    }
  };

  // Record the dtype each input is read as, so its buffer can be declared
  // with a matching WGSL element type (unread inputs default to f32).
  const usedArgs = Array.from({ length: nargs }, () => null);
  tune.exp.fold((exp) => {
    if (exp.op === "GlobalIndex" /* GlobalIndex */) usedArgs[exp.arg] = exp.dtype;
  });
  for (let i = 0; i < nargs; i++) {
    const ty = dtypeToWgsl(usedArgs[i] ?? "float32" /* Float32 */, true);
    emit(
      `@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`,
    );
  }
  const resultTy = dtypeToWgsl(re?.dtype ?? tune.exp.dtype, true);
  emit(
    `@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`,
  );

  // NOTE(review): the grid is 1-D, so ceil(threadCount / workgroupSize) must
  // stay within maxComputeWorkgroupsPerDimension. Confirm tuning guarantees it.
  const workgroupSize = findPow2(tune.threadCount, 256);
  emit(
    "",
    `@compute @workgroup_size(${workgroupSize})`,
    "fn main(@builtin(global_invocation_id) id : vec3<u32>) {",
    pushIndent,
    `if (id.x >= ${tune.threadCount}) { return; }`,
    "let gidx: i32 = i32(id.x);",
  );

  let gensymCount = 0;
  const gensym = () => `alu${gensymCount++}`;

  // Reference unread inputs so that the "auto" pipeline layout still contains
  // their bindings; pipelineSubmit binds every input buffer.
  for (let i = 0; i < args.length; i++) {
    if (!usedArgs[i]) emit(`_ = &${args[i]};`);
  }

  // Count how often each node is referenced so shared nodes are hoisted.
  const references = new Map();
  const seen = new Set();
  const countReferences = (exp) => {
    references.set(exp, (references.get(exp) ?? 0) + 1);
    if (!seen.has(exp)) {
      seen.add(exp);
      for (const src of exp.src) countReferences(src);
    }
  };

  // Translate an expression to WGSL. Binary/compare results are always fully
  // parenthesized; nodes referenced more than once become a `let` binding.
  const expContext = new Map();
  const gen = (exp) => {
    if (expContext.has(exp)) return expContext.get(exp);
    const { op, src, dtype, arg } = exp;
    let source = "";
    if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
      const a = gen(src[0]);
      const b = gen(src[1]);
      if (op === "Add" /* Add */) {
        if (dtype === "bool" /* Bool */) source = `(${a} || ${b})`;
        else source = `(${a} + ${b})`;
      } else if (op === "Sub" /* Sub */) source = `(${a} - ${b})`;
      else if (op === "Mul" /* Mul */) {
        if (dtype === "bool" /* Bool */) source = `(${a} && ${b})`;
        else source = `(${a} * ${b})`;
      } else if (op === "Idiv" /* Idiv */)
        source = dtype === "int32" /* Int32 */ ? `(${a} / ${b})` : `floor(${a} / ${b})`;
      else if (op === "Mod" /* Mod */) source = `(${a} % ${b})`;
      else if (op === "Min" /* Min */) source = `min(${strip1(a)}, ${strip1(b)})`;
      else if (op === "Max" /* Max */) source = `max(${strip1(a)}, ${strip1(b)})`;
      else if (op === "Cmplt" /* Cmplt */) source = `(${a} < ${b})`;
      else if (op === "Cmpne" /* Cmpne */) source = `(${a} != ${b})`;
    } else if (AluGroup.Unary.has(op)) {
      const a = gen(src[0]);
      if (op === "Sin" /* Sin */) source = `sin(${a})`;
      else if (op === "Cos" /* Cos */) source = `cos(${a})`;
      else if (op === "Exp" /* Exp */) source = `exp(${a})`;
      else if (op === "Log" /* Log */) source = `log(${a})`;
      else if (op === "Reciprocal" /* Reciprocal */) source = `(1.0 / ${a})`;
      else if (op === "Cast" /* Cast */)
        source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
    } else if (op === "Where" /* Where */) {
      source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
    } else if (op === "Const" /* Const */) {
      return constToWgsl(dtype, arg);
    } else if (op === "Special" /* Special */) {
      return arg[0];
    } else if (op === "Variable" /* Variable */) {
      return arg;
    } else if (op === "GlobalIndex" /* GlobalIndex */) {
      source = `${args[arg]}[${strip1(gen(src[0]))}]`;
      // Bools are stored as i32 in buffers; convert back on load.
      if (dtype === "bool" /* Bool */) source = `(${source} != 0)`;
    }
    if (!source) throw new Error(`Missing impl for op: ${op}`);
    const typeName = dtypeToWgsl(dtype);
    if ((references.get(exp) ?? 0) > 1) {
      const name = gensym();
      expContext.set(exp, name);
      emit(`let ${name}: ${typeName} = ${strip1(source)};`);
      return name;
    } else {
      expContext.set(exp, source);
      return source;
    }
  };

  if (!kernel.reduction) {
    countReferences(tune.exp);
    let rhs = strip1(gen(tune.exp));
    if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
    emit(`result[gidx] = ${rhs};`);
  } else {
    const re2 = kernel.reduction;
    if ((tune.size.groups ?? 1) > 1) {
      throw new Error("WebGPU backend does not support group optimization yet");
    }
    const unroll = tune.size.unroll ?? 1;
    const upcast = tune.size.upcast ?? 1;
    const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
    for (let i = 0; i < upcast; i++) {
      emit(
        `var ${acc[i]}: ${dtypeToWgsl(tune.exp.dtype)} = ${constToWgsl(re2.dtype, re2.identity)};`,
      );
    }
    emit(
      `for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`,
      pushIndent,
    );
    const exps = [];
    const cache = new Map();
    for (let up = 0; up < upcast; up++) {
      exps.push([]);
      for (let un = 0; un < unroll; un++) {
        const exp = tune.exp.substitute({
          upcast: AluExp.i32(up),
          unroll: AluExp.i32(un),
        });
        exps[up].push(exp.simplify(cache));
        countReferences(exps[up][un]);
      }
    }
    // Keep each unrolled term's own parentheses while joining. Stripping them
    // first (as before) turned product terms `(a + b)` and `(c + d)` into
    // `a + b * c + d`.
    const items = exps.map((ar) => ar.map((exp) => gen(exp)));
    const combine = (lhs, rhs) => {
      if (re2.op === "Add" /* Add */) return `(${lhs} + ${rhs})`;
      if (re2.op === "Mul" /* Mul */) return `(${lhs} * ${rhs})`;
      if (re2.op === "Min" /* Min */) return `min(${strip1(lhs)}, ${strip1(rhs)})`;
      if (re2.op === "Max" /* Max */) return `max(${strip1(lhs)}, ${strip1(rhs)})`;
      throw new Error(`Unsupported reduction op: ${re2.op}`);
    };
    for (let i = 0; i < upcast; i++) {
      let total = items[i][0];
      for (let j = 1; j < unroll; j++) total = combine(total, items[i][j]);
      const rhs = strip1(total);
      if (re2.op === "Add" /* Add */) emit(`${acc[i]} += ${rhs};`);
      else if (re2.op === "Mul" /* Mul */) emit(`${acc[i]} *= ${rhs};`);
      else if (re2.op === "Min" /* Min */) emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
      else if (re2.op === "Max" /* Max */) emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
      else throw new Error(`Unsupported reduction op: ${re2.op}`);
    }
    emit(popIndent, "}");

    // `let` names generated inside the loop are out of scope after it, so
    // restart reference counting and naming context for the epilogue.
    expContext.clear();
    references.clear();
    seen.clear();
    const outputIdxExps = [];
    const fusionExps = [];
    for (let i = 0; i < upcast; i++) {
      const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
      outputIdxExps.push(exp.simplify(cache));
      countReferences(outputIdxExps[i]);
      fusionExps.push(
        re2.fusion.substitute({ acc: AluExp.variable(re2.dtype, acc[i]) }).simplify(cache),
      );
      countReferences(fusionExps[i]);
    }
    for (let i = 0; i < upcast; i++) {
      const index = strip1(gen(outputIdxExps[i]));
      let rhs = strip1(gen(fusionExps[i]));
      if (resultTy !== dtypeToWgsl(fusionExps[i].dtype))
        rhs = `${resultTy}(${rhs})`;
      emit(`result[${index}] = ${rhs};`);
    }
  }
  emit(popIndent, "}");
  return {
    shader: shader.join("\n"),
    grid: [Math.ceil(tune.threadCount / workgroupSize), 1],
  };
}
361
/**
 * Bind buffers to a compiled pipeline and submit a single compute dispatch.
 *
 * Input buffers are bound at bindings 0..inputs.length-1 and the single
 * output at binding `inputs.length`. This matches the layout emitted by
 * `pipelineSource`.
 *
 * @param device - GPUDevice used to record and submit the commands.
 * @param exe - `{ pipeline, grid }`; `grid` is the `[x, y]` workgroup count.
 * @param inputs - Input GPUBuffers in kernel-argument order.
 * @param outputs - Output GPUBuffers; exactly one is supported.
 * @throws Error if the total buffer count exceeds
 *   `maxStorageBuffersPerShaderStage`, if `outputs` does not contain exactly
 *   one buffer, or if `grid` exceeds `maxComputeWorkgroupsPerDimension`.
 */
function pipelineSubmit(device, { pipeline, grid }, inputs, outputs) {
  const max = device.limits.maxStorageBuffersPerShaderStage;
  const actual = inputs.length + outputs.length;
  if (actual > max) {
    throw new Error(
      `Too many buffers (${actual}) for WebGPU pipeline (max: ${max})`,
    );
  }
  // The generated shader has a single `result` binding: extra outputs would
  // silently never be written, and a missing one cannot be bound.
  if (outputs.length !== 1) {
    throw new Error(
      `WebGPU pipeline expects exactly 1 output buffer, got ${outputs.length}`,
    );
  }
  // An oversized dispatch would fail validation and skip the kernel silently.
  const maxWorkgroups = device.limits.maxComputeWorkgroupsPerDimension;
  if (maxWorkgroups !== undefined && grid.some((n) => n > maxWorkgroups)) {
    throw new Error(
      `Workgroup grid [${grid.join(", ")}] exceeds device limit (max: ${maxWorkgroups} per dimension)`,
    );
  }
  const bindGroup = device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      ...inputs.map((buffer, i) => ({ binding: i, resource: { buffer } })),
      { binding: inputs.length, resource: { buffer: outputs[0] } },
    ],
  });
  const commandEncoder = device.createCommandEncoder();
  const passEncoder = commandEncoder.beginComputePass();
  passEncoder.setPipeline(pipeline);
  passEncoder.setBindGroup(0, bindGroup);
  passEncoder.dispatchWorkgroups(grid[0], grid[1]);
  passEncoder.end();
  device.queue.submit([commandEncoder.finish()]);
}
386
/**
 * Compiles WGSL source into compute pipelines (entry point `main`, layout
 * "auto") and caches them by their exact source text.
 *
 * `prepare` compiles asynchronously and de-duplicates concurrent requests for
 * the same source. `prepareSync` compiles synchronously and only reports
 * compilation errors to the console once the device's error scope resolves.
 */
var ShaderPipelineCache = class {
  /** @param device - GPUDevice used to create shader modules and pipelines. */
  constructor(device) {
    this.device = device;
    this.cache = new Map();
    this.inProgress = new Map();
  }
  cache;
  inProgress;

  /**
   * Get or asynchronously compile the pipeline for `code`.
   *
   * @param code - WGSL source text.
   * @returns The compiled GPUComputePipeline.
   * @throws Error carrying the compiler diagnostics (with the underlying
   *   rejection as `cause`) if compilation fails.
   */
  async prepare(code) {
    const existingPipeline = this.cache.get(code);
    if (existingPipeline) return existingPipeline;
    const existingPromise = this.inProgress.get(code);
    if (existingPromise) return await existingPromise;
    if (DEBUG >= 2) {
      console.info("=========== WebGPU shader ===========\n" + code);
    }
    const promise = this.#compileAsync(code);
    this.inProgress.set(code, promise);
    try {
      const pipeline = await promise;
      this.cache.set(code, pipeline);
      return pipeline;
    } finally {
      // Without this, entries (including failed ones) were kept forever.
      this.inProgress.delete(code);
    }
  }

  /**
   * Compile `code` asynchronously. The error scope is pushed and popped
   * without an intervening `await`. This keeps concurrent compilations from
   * interleaving scopes on the device's shared error-scope stack, and lets the
   * scope capture shader-module errors too.
   */
  async #compileAsync(code) {
    this.device.pushErrorScope("validation");
    const shaderModule = this.device.createShaderModule({ code });
    const pipelinePromise = this.device.createComputePipelineAsync({
      layout: "auto",
      compute: {
        module: shaderModule,
        entryPoint: "main",
      },
    });
    const scopePromise = this.device.popErrorScope();
    let pipeline;
    try {
      pipeline = await pipelinePromise;
    } catch (error) {
      const scope = await scopePromise;
      const emsg = await compileError(shaderModule, scope, code);
      throw new Error(emsg, { cause: error });
    }
    await scopePromise;
    return pipeline;
  }

  /**
   * Get or synchronously compile the pipeline for `code`.
   *
   * Compilation problems cannot be detected synchronously. They are logged via
   * `console.error` once the error scope resolves, and the (possibly invalid)
   * pipeline is still cached and returned.
   *
   * @param code - WGSL source text.
   * @returns The GPUComputePipeline.
   */
  prepareSync(code) {
    const existingPipeline = this.cache.get(code);
    if (existingPipeline) return existingPipeline;
    if (DEBUG >= 2) {
      console.info("=========== WebGPU shader ===========\n" + code);
    }
    // Open the scope before creating the module so shader errors are captured.
    this.device.pushErrorScope("validation");
    const shaderModule = this.device.createShaderModule({ code });
    const pipeline = this.device.createComputePipeline({
      layout: "auto",
      compute: {
        module: shaderModule,
        entryPoint: "main",
      },
    });
    this.device
      .popErrorScope()
      .then(async (scope) => {
        if (scope !== null) {
          const emsg = await compileError(shaderModule, scope, code);
          console.error(emsg);
        }
      })
      .catch((error) => console.error(error));
    this.cache.set(code, pipeline);
    return pipeline;
  }
};
451
/**
 * Build a human-readable shader compilation error message.
 *
 * @param shaderModule - Module whose `getCompilationInfo()` diagnostics are
 *   appended, one per line as `[type at line:pos] message`.
 * @param scope - GPUError popped from an error scope, or null/undefined.
 * @param code - Shader source; when non-empty it is appended after a blank line.
 * @returns The assembled message string.
 */
async function compileError(shaderModule, scope, code) {
  const headline = `Failed to compile shader: ${scope ? scope.message : "(no error scope)"}`;
  const { messages } = await shaderModule.getCompilationInfo();
  const diagnostics = Array.from(
    messages,
    (msg) => `\n[${msg.type} at ${msg.lineNum}:${msg.linePos}] ${msg.message}`,
  ).join("");
  const source = code ? `\n\n${code}` : "";
  return headline + diagnostics + source;
}
465
/**
 * Synchronously reads GPU buffer contents back to the CPU through canvases.
 *
 * Bytes are copied into the current texture of a "bgra8unorm" WebGPU canvas
 * context (4 bytes per pixel), drawn onto a 2D canvas, and read back with
 * `getImageData`. There is one WebGPU canvas per alpha mode. The "opaque"
 * pass recovers the first three bytes of each pixel (swapping B and R from
 * bgra order) and the "premultiplied" pass recovers the fourth byte.
 * Presumably this is because an opaque canvas forces alpha to 255, while
 * premultiplication would distort the color bytes. TODO confirm.
 */
var SyncReader = class _SyncReader {
  /** @param device - GPUDevice that owns the buffers to be read. */
  constructor(device) {
    this.device = device;
  }
  static alphaModes = [
    "opaque",
    "premultiplied",
  ];
  static width = 256;
  static height = 256;
  initialized = false;
  deviceStorage;
  deviceContexts;
  hostStorage;
  hostContext;

  /** Lazily create the WebGPU canvases and the 2D host canvas. */
  #init() {
    const makeCanvas = () => new OffscreenCanvas(_SyncReader.width, _SyncReader.height);
    this.deviceStorage = _SyncReader.alphaModes.map(makeCanvas);
    this.deviceContexts = this.deviceStorage.map((canvas, i) => {
      const context = canvas.getContext("webgpu");
      context.configure({
        device: this.device,
        // rgba8unorm is not supported on Chrome for macOS.
        // https://bugs.chromium.org/p/chromium/issues/detail?id=1298618
        format: "bgra8unorm",
        usage: GPUTextureUsage.COPY_DST,
        alphaMode: _SyncReader.alphaModes[i],
      });
      return context;
    });
    this.hostStorage = makeCanvas();
    this.hostContext = this.hostStorage.getContext("2d", {
      willReadFrequently: true,
    });
    this.initialized = true;
  }

  /**
   * Read `count` bytes from `buffer`, starting at byte `start`.
   *
   * @param buffer - GPUBuffer with COPY_SRC usage.
   * @param start - Byte offset; must be a multiple of 4 (the texel size),
   *   as copyBufferToTexture requires.
   * @param count - Byte count; must be a multiple of 4.
   * @returns A new ArrayBuffer of `count` bytes.
   * @throws Error if `start` or `count` is not 4-byte aligned.
   */
  read(buffer, start, count) {
    // Validate before touching canvases or the GPU.
    if (count % 4 !== 0) {
      throw new Error("Read size must be a multiple of 4 bytes");
    }
    if (start % 4 !== 0) {
      throw new Error("Read offset must be a multiple of 4 bytes");
    }
    const valsGPU = new ArrayBuffer(count);
    if (count === 0) return valsGPU;
    if (!this.initialized) this.#init();

    const { width: canvasWidth, height: canvasHeight, alphaModes } = _SyncReader;
    const hostContext = this.hostContext;
    const bytesPerRow = canvasWidth * 4;
    // Split the read into whole canvases, then whole rows, then one partial row.
    const pixelsSize = count / 4;
    const pixelsPerCanvas = canvasWidth * canvasHeight;
    const wholeChunks = Math.floor(pixelsSize / pixelsPerCanvas);
    const leftoverPixels = pixelsSize % pixelsPerCanvas;
    const remainderRows = Math.floor(leftoverPixels / canvasWidth);
    const remainder = leftoverPixels % canvasWidth;

    for (let i = 0; i < this.deviceContexts.length; i++) {
      const texture = this.deviceContexts[i].getCurrentTexture();
      const deviceCanvas = this.deviceStorage[i];
      const alphaMode = alphaModes[i];
      const readData = (width, height, offset) => {
        const encoder = this.device.createCommandEncoder();
        encoder.copyBufferToTexture(
          { buffer, bytesPerRow, offset: offset + start },
          { texture },
          { width, height, depthOrArrayLayers: 1 },
        );
        this.device.queue.submit([encoder.finish()]);
        hostContext.clearRect(0, 0, width, height);
        hostContext.drawImage(deviceCanvas, 0, 0);
        const values = hostContext.getImageData(0, 0, width, height).data;
        const span = new Uint8ClampedArray(valsGPU, offset, 4 * width * height);
        for (let k = 0; k < span.length; k += 4) {
          if (alphaMode === "premultiplied") {
            span[k + 3] = values[k + 3];
          } else {
            // Canvas pixels come back as RGBA; the texture holds BGRA bytes.
            span[k] = values[k + 2];
            span[k + 1] = values[k + 1];
            span[k + 2] = values[k];
          }
        }
      };

      let offset = 0;
      for (let j = 0; j < wholeChunks; j++) {
        readData(canvasWidth, canvasHeight, offset);
        offset += pixelsPerCanvas * 4;
      }
      if (remainderRows > 0) {
        readData(canvasWidth, remainderRows, offset);
        offset += remainderRows * canvasWidth * 4;
      }
      if (remainder > 0) readData(remainder, 1, offset);
    }
    return valsGPU;
  }
};
557
- export {
558
- WebGPUBackend
559
- };