@jax-js/jax 0.0.1 → 0.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,591 @@
1
+ const require_backend = require('./backend-BK21PBVP.cjs');
2
+
3
+ //#region src/backend/webgpu.ts
4
/**
 * WebGPU-backed implementation of `Backend`, intended for browsers.
 *
 * Memory is exposed to callers as integer slots. Each slot owns one GPUBuffer
 * together with a reference count; the buffer is destroyed as soon as the
 * count drops to zero.
 */
var WebGPUBackend = class {
  type = "webgpu";
  maxArgs;
  pipelines;
  syncReader;
  buffers;
  nextSlot;
  #cachedShaderMap = /* @__PURE__ */ new Map();

  /** @param device The GPUDevice all buffers and pipelines are created on. */
  constructor(device) {
    this.device = device;
    if (require_backend.DEBUG >= 3 && device.adapterInfo) {
      console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
    }
    // One storage binding per kernel is taken by its `result` buffer.
    this.maxArgs = this.device.limits.maxStorageBuffersPerShaderStage - 1;
    this.pipelines = new ShaderPipelineCache(device);
    this.syncReader = new SyncReader(device);
    this.buffers = /* @__PURE__ */ new Map();
    this.nextSlot = 1;
  }

  /**
   * Allocate a buffer of `size` bytes and return its slot (refcount 1).
   *
   * @param size Buffer size in bytes.
   * @param initialData Optional contents; its `byteLength` must equal `size`.
   * @throws Error if `initialData` has a different length than `size`.
   */
  malloc(size, initialData) {
    const gpuBuffer = initialData ? this.#createFilledBuffer(size, initialData) : this.#createBuffer(size);
    const slot = this.nextSlot++;
    this.buffers.set(slot, {
      buffer: gpuBuffer,
      ref: 1,
    });
    return slot;
  }

  /** Add one reference to `slot`; throws `SlotError` for an unknown slot. */
  incRef(slot) {
    this.#entry(slot).ref++;
  }

  /**
   * Drop one reference to `slot`. When no references remain the slot is
   * forgotten and its GPUBuffer destroyed. Throws `SlotError` for an unknown
   * slot.
   */
  decRef(slot) {
    const entry = this.#entry(slot);
    if (--entry.ref === 0) {
      this.buffers.delete(slot);
      entry.buffer.destroy();
    }
  }

  /**
   * Asynchronously copy `count` bytes starting at byte `start` out of `slot`.
   * `start` defaults to 0 and `count` to the remainder of the buffer.
   *
   * @returns A fresh ArrayBuffer holding a copy of the requested bytes.
   */
  async read(slot, start, count) {
    const source = this.#getBuffer(slot);
    const offset = start === undefined ? 0 : start;
    const length = count === undefined ? source.size - offset : count;
    const staging = this.#createBuffer(length, { read: true });
    try {
      const encoder = this.device.createCommandEncoder();
      encoder.copyBufferToBuffer(source, offset, staging, 0, length);
      this.device.queue.submit([encoder.finish()]);
      await staging.mapAsync(GPUMapMode.READ);
      // Copy out of the mapping; the staging buffer is destroyed below.
      return staging.getMappedRange().slice();
    } finally {
      staging.destroy();
    }
  }

  /** Synchronous counterpart of `read()`, delegated to the `SyncReader`. */
  readSync(slot, start, count) {
    const source = this.#getBuffer(slot);
    const offset = start === undefined ? 0 : start;
    const length = count === undefined ? source.size - offset : count;
    return this.syncReader.read(source, offset, length);
  }

  /** Generate (or fetch the memoized) shader source and grid for `kernel`. */
  #cachedShader(kernel) {
    const key = require_backend.FpHash.hash(kernel);
    const known = this.#cachedShaderMap.get(key);
    if (known) return known;
    const generated = pipelineSource(this.device, kernel);
    this.#cachedShaderMap.set(key, generated);
    return generated;
  }

  /** Compile `kernel` asynchronously into an `Executable`. */
  async prepare(kernel) {
    const { shader, grid } = this.#cachedShader(kernel);
    const pipeline = await this.pipelines.prepare(shader);
    return new require_backend.Executable(kernel, { shader, grid, pipeline });
  }

  /** Compile `kernel` synchronously into an `Executable`. */
  prepareSync(kernel) {
    const { shader, grid } = this.#cachedShader(kernel);
    const pipeline = this.pipelines.prepareSync(shader);
    return new require_backend.Executable(kernel, { shader, grid, pipeline });
  }

  /** Run a prepared executable over the buffers behind the given slots. */
  dispatch(exe, inputs, outputs) {
    const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
    const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
    pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
  }

  /** Look up the `{ buffer, ref }` record of a slot, or throw `SlotError`. */
  #entry(slot) {
    const entry = this.buffers.get(slot);
    if (!entry) throw new require_backend.SlotError(slot);
    return entry;
  }

  /** Resolve a slot to its GPUBuffer, or throw `SlotError`. */
  #getBuffer(slot) {
    return this.#entry(slot).buffer;
  }

  /**
   * Create a buffer pre-populated with `initialData`. Payloads under 4096
   * bytes are written through a mapped-at-creation buffer; larger ones go
   * through `queue.writeBuffer()`.
   */
  #createFilledBuffer(size, initialData) {
    if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
    if (initialData.byteLength < 4096) {
      const mappedBuffer = this.#createBuffer(size, { mapped: true });
      new Uint8Array(mappedBuffer.getMappedRange()).set(new Uint8Array(initialData));
      mappedBuffer.unmap();
      return mappedBuffer;
    }
    const plainBuffer = this.#createBuffer(size);
    this.device.queue.writeBuffer(plainBuffer, 0, initialData);
    return plainBuffer;
  }

  /**
   * Create a GPU buffer.
   *
   * Without options this is a general-purpose storage buffer of `size` bytes
   * that can be copied from and to.
   *
   * - `mapped`: create the buffer already mapped so the CPU can fill it
   *   (the caller must `.unmap()` it afterwards).
   * - `read`: create a staging buffer for returning data to the CPU
   *   (the caller must `.mapAsync()` it afterwards).
   *
   * The two options are mutually exclusive.
   */
  #createBuffer(size, { mapped = false, read = false } = {}) {
    if (read && mapped) throw new Error("mapped and read cannot both be true");
    return this.device.createBuffer({
      size,
      usage: read
        ? GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
        : GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
      mappedAtCreation: mapped,
    });
  }
};
135
/**
 * Name of the WGSL scalar type used for `dtype`.
 *
 * @param dtype One of the `DType` values Bool, Int32, Uint32 or Float32.
 * @param storage When true, return the type used inside storage arrays;
 *   booleans are then represented as `i32` instead of `bool`.
 * @throws Error for any other dtype.
 */
function dtypeToWgsl(dtype, storage = false) {
  if (dtype === require_backend.DType.Bool) return storage ? "i32" : "bool";
  if (dtype === require_backend.DType.Int32) return "i32";
  if (dtype === require_backend.DType.Uint32) return "u32";
  if (dtype === require_backend.DType.Float32) return "f32";
  throw new Error(`Unsupported dtype: ${dtype}`);
}
144
/**
 * Render a JavaScript constant as a WGSL literal of the given dtype.
 *
 * - Bool    -> `true` / `false`
 * - Int32   -> decimal digits
 * - Uint32  -> decimal digits with a `u` suffix
 * - Float32 -> a float literal; NaN and the infinities are rendered as calls
 *   to the `nan()` / `inf()` helper functions that `pipelineSource` emits at
 *   the top of every shader.
 *
 * @param dtype A `DType` value.
 * @param value The constant to render.
 * @returns WGSL source text for the constant.
 * @throws Error for dtypes other than the four listed above.
 */
function constToWgsl(dtype, value) {
  if (dtype === require_backend.DType.Bool) return value ? "true" : "false";
  if (dtype === require_backend.DType.Int32) return value.toString();
  if (dtype === require_backend.DType.Uint32) return value.toString() + "u";
  if (dtype === require_backend.DType.Float32) {
    if (Number.isNaN(value)) return "nan()";
    if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
    const text = value.toString();
    // Plain integers need a ".0" so WGSL reads them as floats. Text that
    // already has a decimal point or an exponent ("1e+21", "2.5e-10") is a
    // valid float literal as-is; appending ".0" to it would be a syntax error.
    return /[.e]/.test(text) ? text : text + ".0";
  }
  throw new Error(`Unsupported const dtype: ${dtype}`);
}
157
/**
 * Compiles a kernel's expression into WebGPU (WGSL) compute shader source.
 *
 * The generated module contains the `nan()` / `inf()` helpers, the Threefry
 * helper when the expression uses it, one read-only storage binding per kernel
 * argument (`in0`, `in1`, ...), a read-write `result` binding at index
 * `nargs`, and a `main` entry point in which each invocation handles one
 * global index `gidx`.
 *
 * @param device GPU device; only `limits.maxComputeWorkgroupsPerDimension` is
 *   read, to decide whether the dispatch has to be split across x and y.
 * @param kernel Kernel to compile; `nargs`, `dtype`, `exp` and `reduction`
 *   are read, and the kernel is tuned with `tuneWebgpu`.
 * @returns `{ shader, grid }` - the shader source and the number of
 *   workgroups to dispatch along the x and y axes to run the kernel.
 * @throws Error for ops, dtypes, Threefry modes or reduction ops that have no
 *   lowering here, and when the tuning asks for group optimization.
 */
function pipelineSource(device, kernel) {
  const { AluOp, AluGroup, AluExp, DType } = require_backend;
  const tune = require_backend.tuneWebgpu(kernel);
  if (require_backend.DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
  const { nargs } = kernel;
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);

  // The shader is accumulated line by line. `pushIndent` / `popIndent` are
  // sentinels understood by `emit` that change the indentation of subsequent
  // lines. Both use the same unit so nested blocks stay balanced.
  const shader = [];
  const indentUnit = "  ";
  let indent = "";
  const pushIndent = Symbol("pushIndent");
  const popIndent = Symbol("popIndent");
  const emit = (...lines) => {
    for (const line of lines) {
      if (line === pushIndent) indent += indentUnit;
      else if (line === popIndent) indent = indent.slice(0, -indentUnit.length);
      else shader.push(line ? indent + line : line);
    }
  };

  // Helpers that constToWgsl() relies on for non-finite float constants.
  emit(
    "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }",
    "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }",
  );
  if (tune.exp.collect((exp) => exp.op === AluOp.Threefry2x32).length > 0) emit(threefrySrc);
  emit("");

  // Record the dtype each argument is loaded as; arguments the expression
  // never loads stay null and are declared as f32 arrays.
  const usedArgs = Array.from({ length: nargs }, () => null);
  tune.exp.fold((exp) => {
    if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg] = exp.dtype;
  });
  for (let i = 0; i < nargs; i++) {
    const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
    emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
  }
  const resultTy = dtypeToWgsl(kernel.dtype, true);
  emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);

  // Dispatch geometry: a 1-D grid when it fits within the device limit,
  // otherwise a fixed x extent with the remainder spread over y.
  const workgroupSize = require_backend.findPow2(tune.threadCount, 256);
  const gridSize = Math.ceil(tune.threadCount / workgroupSize);
  let gridX = gridSize;
  let gridY = 1;
  if (gridSize > device.limits.maxComputeWorkgroupsPerDimension) {
    gridX = 16384;
    gridY = Math.ceil(gridSize / gridX);
  }

  emit(
    "",
    `@compute @workgroup_size(${workgroupSize})`,
    "fn main(@builtin(global_invocation_id) id : vec3<u32>) {",
    pushIndent,
  );
  if (gridY === 1) {
    emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
  } else {
    const sizeX = gridX * workgroupSize;
    emit(
      `if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`,
      `let gidx: i32 = i32(${sizeX} * id.y + id.x);`,
    );
  }

  let gensymCount = 0;
  const gensym = () => `alu${gensymCount++}`;

  // Reference every binding the expression does not load from.
  // NOTE(review): presumably this keeps unused bindings in the auto-generated
  // bind group layout so that pipelineSubmit() can bind them - confirm.
  for (let i = 0; i < args.length; i++) {
    if (!usedArgs[i]) emit(`_ = &${args[i]};`);
  }

  // Number of parents per sub-expression; anything referenced more than once
  // is hoisted into a `let` by gen() instead of being repeated inline.
  const references = /* @__PURE__ */ new Map();
  const seen = /* @__PURE__ */ new Set();
  const countReferences = (exp) => {
    references.set(exp, (references.get(exp) ?? 0) + 1);
    if (!seen.has(exp)) {
      seen.add(exp);
      for (const src of exp.src) countReferences(src);
    }
  };

  // Source text (or hoisted variable name) already generated per expression.
  const expContext = /* @__PURE__ */ new Map();
  const gen = (exp) => {
    if (expContext.has(exp)) return expContext.get(exp);
    const { op, src, dtype, arg } = exp;
    let source = "";
    if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
      const a = gen(src[0]);
      const b = gen(src[1]);
      if (op === AluOp.Add) {
        // Add / Mul on booleans are emitted as logical or / and.
        if (dtype === DType.Bool) source = `(${a} || ${b})`;
        else source = `(${a} + ${b})`;
      } else if (op === AluOp.Sub) source = `(${a} - ${b})`;
      else if (op === AluOp.Mul) {
        if (dtype === DType.Bool) source = `(${a} && ${b})`;
        else source = `(${a} * ${b})`;
      } else if (op === AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `floor(${a} / ${b})` : `(${a} / ${b})`;
      else if (op === AluOp.Mod) source = `(${a} % ${b})`;
      else if (op === AluOp.Min) source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
      else if (op === AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
      else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
      else if (op === AluOp.Cmpne) source = `(${a} != ${b})`;
    } else if (AluGroup.Unary.has(op)) {
      const a = gen(src[0]);
      if (op === AluOp.Sin) source = `sin(${a})`;
      else if (op === AluOp.Cos) source = `cos(${a})`;
      else if (op === AluOp.Exp) source = `exp(${a})`;
      else if (op === AluOp.Log) source = `log(${a})`;
      else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
      else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
      else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
    } else if (op === AluOp.Where) {
      // WGSL select(falseValue, trueValue, condition): operands are reversed.
      source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
    } else if (op === AluOp.Threefry2x32) {
      const x = gensym();
      const [k0, k1, c0, c1] = src.map((x$1) => require_backend.strip1(gen(x$1)));
      emit(`let ${x} = threefry2x32(vec2(${k0}, ${k1}), vec2(${c0}, ${c1}));`);
      if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
      else if (arg === 0) source = `${x}.x`;
      else if (arg === 1) source = `${x}.y`;
      else throw new Error("Invalid Threefry2x32 mode: " + arg);
    } else if (op === AluOp.Const) return constToWgsl(dtype, arg);
    else if (op === AluOp.Special) return arg[0];
    else if (op === AluOp.Variable) return arg;
    else if (op === AluOp.GlobalIndex) {
      source = `${args[arg]}[${require_backend.strip1(gen(src[0]))}]`;
      // Booleans are stored as i32 (see dtypeToWgsl with storage = true).
      if (dtype === DType.Bool) source = `(${source} != 0)`;
    }
    if (!source) throw new Error(`Missing impl for op: ${op}`);
    const typeName = dtypeToWgsl(dtype);
    if ((references.get(exp) ?? 0) > 1) {
      const name = gensym();
      expContext.set(exp, name);
      emit(`let ${name}: ${typeName} = ${require_backend.strip1(source)};`);
      return name;
    } else {
      expContext.set(exp, source);
      return source;
    }
  };

  if (!kernel.reduction) {
    // Elementwise kernel: one output element per invocation.
    countReferences(tune.exp);
    let rhs = require_backend.strip1(gen(tune.exp));
    if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
    emit(`result[gidx] = ${rhs};`);
  } else {
    const re = kernel.reduction;
    if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
    const unroll = tune.size.unroll ?? 1;
    const upcast = tune.size.upcast ?? 1;

    // One accumulator per upcast lane, initialised to the reduction identity.
    const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
    for (let i = 0; i < upcast; i++) {
      emit(`var ${acc[i]}: ${dtypeToWgsl(re.dtype)} = ${constToWgsl(re.dtype, re.identity)};`);
    }
    emit(`for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`, pushIndent);

    // Specialise the expression for every (upcast, unroll) pair. References
    // are counted over all of them so shared sub-expressions are hoisted once.
    const exps = [];
    const cache = /* @__PURE__ */ new Map();
    for (let up = 0; up < upcast; up++) {
      exps.push([]);
      for (let un = 0; un < unroll; un++) {
        const exp = tune.exp.substitute({
          upcast: AluExp.i32(up),
          unroll: AluExp.i32(un),
        });
        exps[up].push(exp.simplify(cache));
        countReferences(exps[up][un]);
      }
    }
    const items = exps.map((ar) => ar.map(gen).map(require_backend.strip1));
    for (let i = 0; i < upcast; i++) {
      // Combine the unrolled terms first, then fold them into the accumulator.
      let rhs = items[i][0];
      for (let j = 1; j < unroll; j++) {
        if (re.op === AluOp.Add) rhs = `${rhs} + ${items[i][j]}`;
        else if (re.op === AluOp.Mul) rhs = `${rhs} * ${items[i][j]}`;
        else if (re.op === AluOp.Min) rhs = `min(${rhs}, ${items[i][j]})`;
        else if (re.op === AluOp.Max) rhs = `max(${rhs}, ${items[i][j]})`;
        else throw new Error(`Unsupported reduction op: ${re.op}`);
      }
      if (re.op === AluOp.Add) emit(`${acc[i]} += ${rhs};`);
      else if (re.op === AluOp.Mul) emit(`${acc[i]} *= ${rhs};`);
      else if (re.op === AluOp.Min) emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
      else if (re.op === AluOp.Max) emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
      else throw new Error(`Unsupported reduction op: ${re.op}`);
    }
    emit(popIndent, "}");

    // Names hoisted inside the loop body are out of scope after it, so the
    // code-generation state is reset before the epilogue is generated.
    expContext.clear();
    references.clear();
    seen.clear();

    const outputIdxExps = [];
    const fusionExps = [];
    for (let i = 0; i < upcast; i++) {
      const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
      outputIdxExps.push(exp.simplify(cache));
      countReferences(outputIdxExps[i]);
      fusionExps.push(re.fusion.substitute({ acc: AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
      countReferences(fusionExps[i]);
    }
    for (let i = 0; i < upcast; i++) {
      const index = require_backend.strip1(gen(outputIdxExps[i]));
      let rhs = require_backend.strip1(gen(fusionExps[i]));
      if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
      emit(`result[${index}] = ${rhs};`);
    }
  }
  emit(popIndent, "}");
  return {
    shader: shader.join("\n"),
    grid: [gridX, gridY],
  };
}
338
/**
 * Bind buffers to a prepared compute pipeline and submit one dispatch.
 *
 * Inputs are bound at bindings `0..inputs.length - 1`; the first output is
 * bound at binding `inputs.length`.
 *
 * @param device GPU device used to build and submit the command buffer.
 * @param pipeline Compute pipeline to run (first field of the second argument).
 * @param grid `[x, y]` workgroup counts (second field of the second argument).
 * @param inputs GPU buffers for the kernel arguments.
 * @param outputs GPU buffers for the kernel results; only `outputs[0]` is bound.
 * @throws Error if inputs plus outputs exceed
 *   `device.limits.maxStorageBuffersPerShaderStage`.
 */
function pipelineSubmit(device, { pipeline, grid }, inputs, outputs) {
  const bufferCount = inputs.length + outputs.length;
  const bufferLimit = device.limits.maxStorageBuffersPerShaderStage;
  if (bufferCount > bufferLimit) {
    throw new Error(`Too many buffers (${bufferCount}) for WebGPU pipeline (max: ${bufferLimit})`);
  }

  const layout = pipeline.getBindGroupLayout(0);
  const entries = inputs.map((buffer, binding) => ({ binding, resource: { buffer } }));
  entries.push({ binding: inputs.length, resource: { buffer: outputs[0] } });
  const bindGroup = device.createBindGroup({ layout, entries });

  const encoder = device.createCommandEncoder();
  const pass = encoder.beginComputePass();
  pass.setPipeline(pipeline);
  pass.setBindGroup(0, bindGroup);
  pass.dispatchWorkgroups(grid[0], grid[1]);
  pass.end();
  device.queue.submit([encoder.finish()]);
}
364
/**
 * A cache for compiled GPU compute pipelines, keyed by the shader source.
 *
 * This supports both async compilation (recommended) and a synchronous variant.
 * If the pipeline is not in the cache, it will be compiled and added. For async
 * compilation, only one compilation will be in progress at a time for a given
 * shader source; concurrent callers share the same promise.
 *
 * `inProgress` only holds compilations that have not settled yet. A failed
 * compilation is therefore not remembered, and a later `prepare()` call for
 * the same source starts a new attempt.
 */
var ShaderPipelineCache = class {
  cache;
  inProgress;

  /** @param device The GPUDevice used to create shader modules and pipelines. */
  constructor(device) {
    this.device = device;
    this.cache = /* @__PURE__ */ new Map();
    this.inProgress = /* @__PURE__ */ new Map();
  }

  /**
   * Return the compute pipeline for `code`, compiling it asynchronously if it
   * is not cached yet.
   *
   * @param code WGSL source with a `main` compute entry point.
   * @returns The compiled pipeline.
   * @throws Error with the formatted compiler diagnostics if pipeline creation
   *   fails; the underlying rejection is available as `cause`.
   */
  async prepare(code) {
    const existingPipeline = this.cache.get(code);
    if (existingPipeline) return existingPipeline;
    const existingPromise = this.inProgress.get(code);
    if (existingPromise) return await existingPromise;
    if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
    const shaderModule = this.device.createShaderModule({ code });
    const promise = (async () => {
      this.device.pushErrorScope("validation");
      try {
        const pipeline$1 = await this.device.createComputePipelineAsync({
          layout: "auto",
          compute: {
            module: shaderModule,
            entryPoint: "main",
          },
        });
        // Balance the pushErrorScope() above.
        await this.device.popErrorScope();
        return pipeline$1;
      } catch (error) {
        const scope = await this.device.popErrorScope();
        const emsg = await compileError(shaderModule, scope, code);
        throw new Error(emsg, { cause: error });
      }
    })();
    this.inProgress.set(code, promise);
    try {
      const pipeline = await promise;
      this.cache.set(code, pipeline);
      return pipeline;
    } finally {
      // Settled either way: the result is now in `cache`, or the failure
      // should not be replayed to future callers.
      this.inProgress.delete(code);
    }
  }

  /**
   * Synchronous variant of `prepare()`.
   *
   * Validation errors cannot be observed synchronously, so the pipeline is
   * cached and returned immediately; any error is reported later through
   * `console.error`.
   *
   * @param code WGSL source with a `main` compute entry point.
   * @returns The (possibly invalid) pipeline.
   */
  prepareSync(code) {
    const existingPipeline = this.cache.get(code);
    if (existingPipeline) return existingPipeline;
    if (require_backend.DEBUG >= 2) console.info("=========== WebGPU shader ===========\n" + code);
    const shaderModule = this.device.createShaderModule({ code });
    this.device.pushErrorScope("validation");
    const pipeline = this.device.createComputePipeline({
      layout: "auto",
      compute: {
        module: shaderModule,
        entryPoint: "main",
      },
    });
    this.device
      .popErrorScope()
      .then(async (scope) => {
        if (scope !== null) {
          const emsg = await compileError(shaderModule, scope, code);
          console.error(emsg);
        }
      })
      .catch((error) => {
        // Reporting is best-effort; never leave the promise unhandled.
        console.error(error);
      });
    this.cache.set(code, pipeline);
    return pipeline;
  }
};
433
/**
 * Build a human-readable report for a shader that failed to compile.
 *
 * @param shaderModule Shader module whose `getCompilationInfo()` is queried.
 * @param scope Error captured by the error scope, or a falsy value if none.
 * @param code Shader source; appended to the report when non-empty.
 * @returns The headline, one line per compilation message, then the source.
 */
async function compileError(shaderModule, scope, code) {
  const reason = scope ? scope.message : "(no error scope)";
  const { messages } = await shaderModule.getCompilationInfo();
  const parts = [`Failed to compile shader: ${reason}`];
  for (const { type, lineNum, linePos, message } of messages) {
    parts.push(`\n [${type} at ${lineNum}:${linePos}] ${message}`);
  }
  if (code) parts.push(`\n\n${code}`);
  return parts.join("");
}
441
/**
 * Graphics state used to read WebGPU buffers back synchronously.
 *
 * The trick is borrowed from TensorFlow.js: an offscreen canvas configured
 * with a WebGPU context acts as "device storage", holding one pixel per 4
 * bytes. The buffer is copied into the canvas texture, the canvas is drawn
 * onto a second offscreen canvas with a '2d' context ("host storage"), and the
 * pixels are then read with `getImageData()`.
 *
 * Canvases are 256x256 (256 KiB per pass). This is slow because the data is
 * copied several times, but it works. Every region is also copied twice: once
 * through an "opaque" canvas for the RGB bytes and once through a
 * "premultiplied" canvas for the alpha byte.
 *
 * https://github.com/tensorflow/tfjs/blob/tfjs-v4.22.0/tfjs-backend-webgpu/src/backend_webgpu.ts#L379
 */
var SyncReader = class SyncReader {
  static alphaModes = ["opaque", "premultiplied"];
  static width = 256;
  static height = 256;
  initialized = false;
  deviceStorage;
  deviceContexts;
  hostStorage;
  hostContext;

  /** @param device The GPUDevice whose buffers will be read. */
  constructor(device) {
    this.device = device;
  }

  /** Lazily create the canvases: one WebGPU canvas per alpha mode plus one 2d canvas. */
  #init() {
    const { alphaModes, width, height } = SyncReader;
    const newCanvas = () => new OffscreenCanvas(width, height);
    this.deviceStorage = alphaModes.map(() => newCanvas());
    this.deviceContexts = this.deviceStorage.map((canvas, i) => {
      const gpuContext = canvas.getContext("webgpu");
      gpuContext.configure({
        device: this.device,
        format: "bgra8unorm",
        usage: GPUTextureUsage.COPY_DST,
        alphaMode: alphaModes[i],
      });
      return gpuContext;
    });
    this.hostStorage = newCanvas();
    this.hostContext = this.hostStorage.getContext("2d", { willReadFrequently: true });
    this.initialized = true;
  }

  /**
   * Synchronously read `count` bytes of `buffer`, starting at byte `start`.
   *
   * @param buffer Source GPUBuffer.
   * @param start Byte offset into `buffer`.
   * @param count Number of bytes to read; must be a multiple of 4.
   * @returns A new ArrayBuffer of `count` bytes.
   * @throws Error if `count` is not a multiple of 4.
   */
  read(buffer, start, count) {
    if (!this.initialized) this.#init();
    if (count % 4 !== 0) throw new Error("Read size must be a multiple of 4 bytes");

    const { deviceStorage, deviceContexts, hostContext } = this;
    const { alphaModes, width: canvasWidth, height: canvasHeight } = SyncReader;
    const bytesPerRow = canvasWidth * 4;
    const result = new ArrayBuffer(count);

    // Split the request into full canvases, then full rows, then the pixels
    // of one final partial row.
    const pixelCount = count / 4;
    const pixelsPerCanvas = canvasWidth * canvasHeight;
    const fullCanvases = Math.floor(pixelCount / pixelsPerCanvas);
    const leftoverPixels = pixelCount % pixelsPerCanvas;
    const fullRows = Math.floor(leftoverPixels / canvasWidth);
    const tailPixels = leftoverPixels % canvasWidth;

    deviceContexts.forEach((gpuContext, i) => {
      const texture = gpuContext.getCurrentTexture();
      const alphaOnly = alphaModes[i] === "premultiplied";

      // Copy a width x height pixel region located `byteOffset` bytes into
      // the request, and merge this pass's channels into `result`.
      const copyRegion = (width, height, byteOffset) => {
        const encoder = this.device.createCommandEncoder();
        encoder.copyBufferToTexture(
          { buffer, bytesPerRow, offset: byteOffset + start },
          { texture },
          { width, height, depthOrArrayLayers: 1 },
        );
        this.device.queue.submit([encoder.finish()]);
        hostContext.clearRect(0, 0, width, height);
        hostContext.drawImage(deviceStorage[i], 0, 0);
        const pixels = hostContext.getImageData(0, 0, width, height).data;
        const target = new Uint8ClampedArray(result, byteOffset, 4 * width * height);
        if (alphaOnly) {
          for (let k = 0; k < target.length; k += 4) target[k + 3] = pixels[k + 3];
        } else {
          // The texture is BGRA while getImageData() returns RGBA, so the
          // red and blue bytes are swapped back.
          for (let k = 0; k < target.length; k += 4) {
            target[k] = pixels[k + 2];
            target[k + 1] = pixels[k + 1];
            target[k + 2] = pixels[k];
          }
        }
      };

      let byteOffset = 0;
      for (let j = 0; j < fullCanvases; j++) {
        copyRegion(canvasWidth, canvasHeight, byteOffset);
        byteOffset += pixelsPerCanvas * 4;
      }
      if (fullRows > 0) {
        copyRegion(canvasWidth, fullRows, byteOffset);
        byteOffset += fullRows * canvasWidth * 4;
      }
      if (tailPixels > 0) copyRegion(tailPixels, 1, byteOffset);
    });
    return result;
  }
};
543
/**
 * WGSL source of the `threefry2x32(key, ctr)` helper. `pipelineSource` emits
 * it into a shader only when the kernel expression contains a Threefry2x32 op.
 *
 * The function mixes the two counter words with the two key words (plus a
 * derived third key word) over five groups of four add/rotate/xor rounds, and
 * returns both resulting words.
 */
const threefrySrc = `
fn threefry2x32(key: vec2<u32>, ctr: vec2<u32>) -> vec2<u32> {
  let ks0: u32 = key.x;
  let ks1: u32 = key.y;
  let ks2: u32 = ks0 ^ ks1 ^ 0x1BD11BDAu;

  var x0: u32 = ctr.x + ks0;
  var x1: u32 = ctr.y + ks1;

  x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
  x0 += ks1;
  x1 += ks2 + 1u;

  x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
  x0 += ks2;
  x1 += ks0 + 2u;

  x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
  x0 += ks0;
  x1 += ks1 + 3u;

  x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
  x0 += ks1;
  x1 += ks2 + 4u;

  x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
  x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
  x0 += ks2;
  x1 += ks0 + 5u;

  return vec2<u32>(x0, x1);
}`;
589
+
590
+ //#endregion
591
+ exports.WebGPUBackend = WebGPUBackend;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.0.1",
3
+ "version": "0.0.2",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -35,35 +35,40 @@
35
35
  "types": "dist/index.d.ts",
36
36
  "module": "dist/index.js",
37
37
  "license": "MIT",
38
- "scripts": {
39
- "build": "tsup",
40
- "build:watch": "tsup --watch src",
41
- "format": "prettier --write .",
42
- "format:check": "prettier --check .",
43
- "lint": "eslint",
44
- "test": "vitest",
45
- "prepublishOnly": "npm run build",
46
- "postpublish": "git tag jax/v$npm_package_version && git push --tags"
47
- },
48
38
  "devDependencies": {
49
- "@eslint/js": "^9.25.0",
39
+ "@eslint/js": "^9.31.0",
50
40
  "@types/debug": "^4.1.12",
51
- "@vitest/browser": "^3.0.5",
52
- "@webgpu/types": "^0.1.54",
53
- "eslint": "^9.25.0",
54
- "eslint-plugin-import": "^2.31.0",
41
+ "@vitest/browser": "^3.2.4",
42
+ "@webgpu/types": "^0.1.64",
43
+ "eslint": "^9.31.0",
44
+ "eslint-plugin-import": "^2.32.0",
55
45
  "globals": "^16.0.0",
56
- "playwright": "^1.50.1",
57
- "prettier": "^3.5.2",
58
- "prettier-plugin-svelte": "^3.3.3",
59
- "tsup": "^8.3.6",
60
- "typescript": "~5.7.3",
61
- "typescript-eslint": "^8.30.1",
62
- "vitest": "^3.0.5"
46
+ "playwright": "~1.50.1",
47
+ "prettier": "^3.6.2",
48
+ "prettier-plugin-svelte": "^3.4.0",
49
+ "tsdown": "^0.13.0",
50
+ "tsx": "^4.20.3",
51
+ "typedoc": "^0.28.7",
52
+ "typescript": "~5.8.3",
53
+ "typescript-eslint": "^8.38.0",
54
+ "vitest": "^3.2.4"
55
+ },
56
+ "engines": {
57
+ "pnpm": ">=10.0.0"
63
58
  },
64
59
  "prettier": {
65
60
  "plugins": [
66
61
  "prettier-plugin-svelte"
67
62
  ]
63
+ },
64
+ "scripts": {
65
+ "build": "tsdown",
66
+ "build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
67
+ "check": "tsc",
68
+ "docs": "tsx typedoc.ts",
69
+ "format": "prettier --write .",
70
+ "format:check": "prettier --check .",
71
+ "lint": "eslint",
72
+ "test": "vitest"
68
73
  }
69
- }
74
+ }