@jax-js/jax 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,522 @@
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DaqL-MNz.js";
2
+
3
+ //#region src/backend/webgl/builtins.ts
4
/**
 * GLSL source for `threefry2x32(key, ctr)`.
 *
 * The shader runs 20 mixing rounds in five groups of four. Groups alternate
 * between two rotation schedules, and each group is followed by a key
 * injection that cycles through ks0/ks1/ks2 and adds the group number.
 */
const threefrySrc = (() => {
  // Left-rotation amounts; odd groups use the first row, even groups the second.
  const rotations = [
    [13, 15, 26, 6],
    [17, 29, 16, 24],
  ];
  const lines = [
    "",
    "uvec2 threefry2x32(uvec2 key, uvec2 ctr) {",
    "uint ks0 = key.x;",
    "uint ks1 = key.y;",
    "uint ks2 = ks0 ^ ks1 ^ 0x1BD11BDAu;",
    "",
    "uint x0 = ctr.x + ks0;",
    "uint x1 = ctr.y + ks1;",
    "",
  ];
  for (let group = 1; group <= 5; group++) {
    for (const left of rotations[(group - 1) % 2]) {
      lines.push(`x0 += x1; x1 = (x1 << ${left}u) | (x1 >> ${32 - left}u); x1 ^= x0;`);
    }
    // Key injection after every four rounds.
    lines.push(`x0 += ks${group % 3};`, `x1 += ks${(group + 1) % 3} + ${group}u;`, "");
  }
  lines.push("return uvec2(x0, x1);", "}");
  return lines.join("\n");
})();
50
/**
 * GLSL source defining `erf(float)` and `erfc(float)`.
 *
 * Both functions evaluate the same degree-5 polynomial in
 * t = 1 / (1 + p * |x|), scaled by exp(-x * x). `erfc` returns that product
 * directly for x >= 0 and 2 minus it otherwise.
 */
const erfSrc = [
  "",
  "const float _erf_p = 0.3275911;",
  "const float _erf_a1 = 0.254829592;",
  "const float _erf_a2 = -0.284496736;",
  "const float _erf_a3 = 1.421413741;",
  "const float _erf_a4 = -1.453152027;",
  "const float _erf_a5 = 1.061405429;",
  "float erf(float x) {",
  "float t = 1.0 / (1.0 + _erf_p * abs(x));",
  "float P_t = (((((_erf_a5 * t) + _erf_a4) * t + _erf_a3) * t + _erf_a2) * t + _erf_a1) * t;",
  "return sign(x) * (1.0 - P_t * exp(-x * x));",
  "}",
  "float erfc(float x) {",
  "float t = 1.0 / (1.0 + _erf_p * abs(x));",
  "float P_t = (((((_erf_a5 * t) + _erf_a4) * t + _erf_a3) * t + _erf_a2) * t + _erf_a1) * t;",
  "float E = P_t * exp(-x * x);",
  "return x >= 0.0 ? E : 2.0 - E;",
  "}",
].join("\n");
68
+
69
+ //#endregion
70
+ //#region src/backend/webgl.ts
71
/**
 * No-frills backend that uses WebGL2 textures and shaders for compute.
 *
 * WebGL2 is available in almost all modern browsers, and it has options for
 * floating-point numbers and integers in textures. However, it's still not a
 * "real" compute API.
 *
 * We make this backend available in case users want a fallback option compared
 * to WebGPU, which is only on newer browsers and iOS 26+.
 *
 * Implementation notes:
 * - All data is stored in RGBA32F textures, regardless of original dtype.
 *   Values are reinterpreted to the correct data type when loaded in shaders.
 * - Each texel holds 4 float32 values (128 bits).
 * - Compute is done by drawing a single screen-covering triangle with a
 *   fragment shader.
 * - Output is rendered to a framebuffer-attached texture, then read back.
 * - Every method that binds GL state restores the bindings it touched, also
 *   when it throws.
 */
class WebGLBackend {
  type = "webgl";
  maxArgs = 8;
  gl;
  #fbo;
  #buffers;
  #programCache;
  #nextSlot;

  /** @param gl WebGL2 rendering context used for all allocations and draws. */
  constructor(gl) {
    this.gl = gl;
    this.#fbo = gl.createFramebuffer();
    this.#buffers = /* @__PURE__ */ new Map();
    this.#programCache = /* @__PURE__ */ new Map();
    this.#nextSlot = 1;
  }

  /** Look up the record for a slot, throwing `SlotError` if it is unknown. */
  #lookup(slot) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new SlotError(slot);
    return buffer;
  }

  /** Map a `gl.getError()` code to its symbolic name for error messages. */
  #errorName(code) {
    const gl = this.gl;
    const names = [
      "INVALID_ENUM",
      "INVALID_VALUE",
      "INVALID_OPERATION",
      "INVALID_FRAMEBUFFER_OPERATION",
      "OUT_OF_MEMORY",
      "CONTEXT_LOST_WEBGL",
    ];
    return names.find((name) => gl[name] === code) ?? `UNKNOWN(${code})`;
  }

  /**
   * Allocate a buffer of `size` bytes, optionally filled with `initialData`.
   *
   * All buffers use the RGBA32F texture format internally. Data is stored as
   * raw bits and reinterpreted with floatBitsToInt/intBitsToFloat in shaders.
   *
   * @param size Logical size in bytes.
   * @param initialData Optional bytes copied to the start of the buffer.
   * @returns The slot number identifying the new buffer (ref count 1).
   * @throws RangeError if `initialData` does not fit in the backing texture.
   */
  malloc(size, initialData) {
    const gl = this.gl;
    const numFloats = Math.ceil(size / 4) || 1;
    const numTexels = Math.ceil(numFloats / 4) || 1;
    const { width, height } = computeTextureDimensions(numTexels);

    // Stage the upload before touching GL: if the data is too large this
    // throws without leaving a half-initialized texture behind.
    let pixels = null;
    if (initialData) {
      pixels = new Float32Array(width * height * 4);
      new Uint8Array(pixels.buffer).set(initialData);
    }

    const texture = gl.createTexture();
    if (!texture) throw new Error("Failed to create texture");
    gl.bindTexture(gl.TEXTURE_2D, texture);
    try {
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
      gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
      gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA32F, width, height, 0, gl.RGBA, gl.FLOAT, pixels);
    } finally {
      gl.bindTexture(gl.TEXTURE_2D, null);
    }

    const slot = this.#nextSlot++;
    this.#buffers.set(slot, { ref: 1, size, texture, width, height });
    return slot;
  }

  /** Increment the reference count of `slot`. Throws `SlotError` if unknown. */
  incRef(slot) {
    this.#lookup(slot).ref++;
  }

  /**
   * Decrement the reference count of `slot`; when it reaches zero the texture
   * is deleted and the slot is forgotten. Throws `SlotError` if unknown.
   */
  decRef(slot) {
    const buffer = this.#lookup(slot);
    buffer.ref--;
    if (buffer.ref === 0) {
      this.gl.deleteTexture(buffer.texture);
      this.#buffers.delete(slot);
    }
  }

  /**
   * Asynchronously read bytes back from a buffer.
   *
   * The whole texture is copied into a pixel-pack buffer, a fence is polled
   * every 5 ms until the copy has completed, and the requested byte range is
   * sliced out of the result.
   *
   * @param slot Buffer to read.
   * @param start Byte offset, defaults to 0.
   * @param count Number of bytes, defaults to the rest of the buffer.
   * @returns A fresh `Uint8Array` with the requested bytes.
   */
  async read(slot, start, count) {
    const buffer = this.#lookup(slot);
    const gl = this.gl;
    if (start === undefined) start = 0;
    if (count === undefined) count = buffer.size - start;

    const totalBytes = buffer.width * buffer.height * 4 * 4;
    const pbo = gl.createBuffer();
    if (!pbo) throw new Error("Failed to create PBO");
    let sync = null;
    try {
      gl.bindFramebuffer(gl.FRAMEBUFFER, this.#fbo);
      gl.bindBuffer(gl.PIXEL_PACK_BUFFER, pbo);
      try {
        gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, buffer.texture, 0);
        gl.bufferData(gl.PIXEL_PACK_BUFFER, totalBytes, gl.STREAM_READ);
        gl.readPixels(0, 0, buffer.width, buffer.height, gl.RGBA, gl.FLOAT, 0);
        const readError = gl.getError();
        if (readError !== gl.NO_ERROR) throw new Error(`WebGL error after readPixels: ${readError}`);
        sync = gl.fenceSync(gl.SYNC_GPU_COMMANDS_COMPLETE, 0);
        if (!sync) throw new Error("Failed to create sync object");
        gl.flush();
      } finally {
        // Never leave our bindings active while other work runs during the wait.
        gl.bindBuffer(gl.PIXEL_PACK_BUFFER, null);
        gl.bindFramebuffer(gl.FRAMEBUFFER, null);
      }

      const fence = sync;
      await new Promise((resolve, reject) => {
        const poll = () => {
          try {
            const status = gl.clientWaitSync(fence, 0, 0);
            if (status === gl.TIMEOUT_EXPIRED) {
              setTimeout(poll, 5);
            } else if (status === gl.WAIT_FAILED) {
              reject(/* @__PURE__ */ new Error("clientWaitSync failed"));
            } else {
              resolve();
            }
          } catch (err) {
            // An exception inside a timer callback would otherwise leave the
            // promise pending forever.
            reject(err);
          }
        };
        poll();
      });

      const floatData = new Float32Array(totalBytes / 4);
      gl.bindBuffer(gl.PIXEL_PACK_BUFFER, pbo);
      try {
        gl.getBufferSubData(gl.PIXEL_PACK_BUFFER, 0, floatData);
      } finally {
        gl.bindBuffer(gl.PIXEL_PACK_BUFFER, null);
      }
      // `slice` already copies, so the result does not alias `floatData`.
      return new Uint8Array(floatData.buffer).slice(start, start + count);
    } finally {
      if (sync) gl.deleteSync(sync);
      gl.deleteBuffer(pbo);
    }
  }

  /**
   * Synchronously read bytes back from a buffer (blocks on the GPU).
   *
   * @param slot Buffer to read.
   * @param start Byte offset, defaults to 0.
   * @param count Number of bytes, defaults to the rest of the buffer.
   * @returns A fresh `Uint8Array` with the requested bytes.
   */
  readSync(slot, start, count) {
    const buffer = this.#lookup(slot);
    const gl = this.gl;
    if (start === undefined) start = 0;
    if (count === undefined) count = buffer.size - start;

    const floatData = new Float32Array(buffer.width * buffer.height * 4);
    gl.bindFramebuffer(gl.FRAMEBUFFER, this.#fbo);
    try {
      gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, buffer.texture, 0);
      gl.readPixels(0, 0, buffer.width, buffer.height, gl.RGBA, gl.FLOAT, floatData);
    } finally {
      gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    }
    return new Uint8Array(floatData.buffer).slice(start, start + count);
  }

  /** Async wrapper around {@link prepareKernelSync}. */
  async prepareKernel(kernel) {
    return this.prepareKernelSync(kernel);
  }

  /**
   * Generate the shader for `kernel` and compile it, reusing a previously
   * compiled program when the generated source text is identical.
   */
  prepareKernelSync(kernel) {
    const shader = generateShader(kernel);
    let dispatch = this.#programCache.get(shader.code);
    if (!dispatch) {
      dispatch = compileShader(this.gl, shader);
      this.#programCache.set(shader.code, dispatch);
    }
    return new Executable(kernel, dispatch);
  }

  /** Routines are not supported by this backend; always throws. */
  prepareRoutine(routine) {
    throw new UnsupportedRoutineError(routine.name, "webgl");
  }

  /** Routines are not supported by this backend; always throws. */
  prepareRoutineSync(routine) {
    throw new UnsupportedRoutineError(routine.name, "webgl");
  }

  /**
   * Run a prepared executable, rendering into the single output buffer.
   *
   * @param exe Executable whose `data` holds `program`, `inputLocations`
   *   and `numInputs`.
   * @param inputs Slots bound to texture units 0..n-1 in order.
   * @param outputs Exactly one slot that receives the result.
   * @throws Error on context loss, arity mismatch, incomplete framebuffer or
   *   a GL error reported after the draw call.
   * @throws SlotError if any slot is unknown.
   */
  dispatch(exe, inputs, outputs) {
    const gl = this.gl;
    if (gl.isContextLost()) throw new Error("WebGL context lost - cannot dispatch");
    const { program, inputLocations, numInputs } = exe.data;
    if (inputs.length !== numInputs) throw new Error(`Expected ${numInputs} inputs, got ${inputs.length}`);
    if (outputs.length !== 1) throw new Error(`Expected 1 output, got ${outputs.length}`);

    // Resolve every slot before changing any GL state.
    const outputBuffer = this.#lookup(outputs[0]);
    const inputBuffers = inputs.map((slot) => this.#lookup(slot));

    try {
      gl.bindFramebuffer(gl.FRAMEBUFFER, this.#fbo);
      gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, outputBuffer.texture, 0);
      const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
      if (status !== gl.FRAMEBUFFER_COMPLETE) throw new Error(`Framebuffer incomplete: ${status}`);
      gl.viewport(0, 0, outputBuffer.width, outputBuffer.height);
      gl.useProgram(program);
      inputBuffers.forEach((inputBuffer, unit) => {
        gl.activeTexture(gl.TEXTURE0 + unit);
        gl.bindTexture(gl.TEXTURE_2D, inputBuffer.texture);
        // A null location means the sampler is not referenced by the program.
        if (inputLocations[unit] !== null) gl.uniform1i(inputLocations[unit], unit);
      });
      gl.drawArrays(gl.TRIANGLES, 0, 3);
      const error = gl.getError();
      if (error !== gl.NO_ERROR) {
        throw new Error(`WebGL error after drawArrays: ${this.#errorName(error)}`);
      }
    } finally {
      gl.bindFramebuffer(gl.FRAMEBUFFER, null);
      gl.useProgram(null);
    }
  }
}
274
/**
 * Generate the GLSL ES 3.00 fragment shader source for a kernel.
 *
 * The shader declares one sampler per kernel argument, a
 * `compute(int gidx)` function that evaluates one output element (zero for
 * indices at or beyond `kernel.size`), and a `main` that packs four
 * consecutive elements into the RGBA channels of the output texel.
 *
 * @param kernel Kernel description; reads `nargs`, `reduction`, `dtype`,
 *   `size` and `exp`.
 * @returns `{ code, numInputs, outputSize: [width, height], outputDtype }`.
 * @throws Error if the reduction op is not Add, Mul, Min or Max.
 */
function generateShader(kernel) {
  const tune = tuneNullopt(kernel);
  if (DEBUG >= 3) console.info(`webgl kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
  const { nargs, reduction: re } = kernel;
  const outputDtype = kernel.dtype;
  const numTexels = Math.ceil(kernel.size / 4) || 1;
  const outputSize = computeTextureDimensions(numTexels);

  // Arguments never referenced by a GlobalIndex default to float32.
  const inputDtypes = Array(nargs).fill(DType.Float32);
  const builtins = { erf: false, threefry: false };
  const collectInfo = (exp) => {
    if (exp.op === AluOp.GlobalIndex) inputDtypes[exp.arg[0]] = exp.dtype;
    else if (exp.op === AluOp.Erf || exp.op === AluOp.Erfc) builtins.erf = true;
    else if (exp.op === AluOp.Threefry2x32) builtins.threefry = true;
  };
  tune.exp.fold(collectInfo);
  tune.epilogue?.fold(collectInfo);

  // Line emitter with symmetric push/pop so nested blocks line up.
  const indentUnit = "  ";
  const shader = [];
  let indent = "";
  const pushIndent = Symbol("pushIndent");
  const popIndent = Symbol("popIndent");
  const emit = (...lines) => {
    for (const line of lines) {
      if (line === pushIndent) indent += indentUnit;
      else if (line === popIndent) indent = indent.slice(0, -indentUnit.length);
      else shader.push(line ? indent + line : line);
    }
  };

  emit("#version 300 es", "precision highp float;", "precision highp int;", "");
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
  const resultType = glslType(outputDtype);
  for (const name of args) emit(`uniform highp sampler2D ${name};`);
  emit("out vec4 out0;");

  // One load helper per distinct input dtype.
  for (const dtype of new Set(inputDtypes)) emit(generateLoadFunction(dtype));
  if (builtins.erf) emit(erfSrc);
  if (builtins.threefry) emit(threefrySrc);

  emit(
    `${resultType} compute(int gidx) {`,
    pushIndent,
    `${resultType} result = ${constToGlsl(outputDtype, 0)};`,
    `if (gidx < ${kernel.size}) {`,
    pushIndent,
  );
  if (!re) {
    const code = generateExpression(tune.exp, args, inputDtypes);
    emit(`result = ${strip1(code)};`);
  } else {
    const accType = glslType(re.dtype);
    const accInit = constToGlsl(re.dtype, re.identity);
    emit(`${accType} acc = ${accInit};`, `for (int ridx = 0; ridx < ${tune.size.reduce}; ridx++) {`, pushIndent);
    const code = generateExpression(tune.exp, args, inputDtypes);
    // GLSL has no arithmetic on bool, so bool accumulators use the logical
    // operators, matching how binary expressions on bools are generated.
    const isBool = re.dtype === DType.Bool;
    if (re.op === AluOp.Add) {
      emit(isBool ? `acc = acc || ${code};` : `acc += ${strip1(code)};`);
    } else if (re.op === AluOp.Mul) {
      emit(isBool ? `acc = acc && ${code};` : `acc *= ${strip1(code)};`);
    } else if (re.op === AluOp.Min) {
      emit(isBool ? `acc = acc && ${code};` : `acc = min(acc, ${strip1(code)});`);
    } else if (re.op === AluOp.Max) {
      emit(isBool ? `acc = acc || ${code};` : `acc = max(acc, ${strip1(code)});`);
    } else {
      throw new Error(`Unsupported reduction op: ${re.op}`);
    }
    emit(popIndent, "}");
    emit(`result = ${generateExpression(tune.epilogue, args, inputDtypes)};`);
  }
  emit(popIndent, "}", "return result;", popIndent, "}\n");

  const packed = range(4).map((i) => toRGBA32F(outputDtype, `result${i}`)).join(", ");
  emit(
    "void main() {",
    pushIndent,
    "ivec2 fragCoord = ivec2(gl_FragCoord.xy);",
    `int texelIdx = fragCoord.y * ${outputSize.width} + fragCoord.x;`,
    `${resultType} result0 = compute(texelIdx * 4);`,
    `${resultType} result1 = compute(texelIdx * 4 + 1);`,
    `${resultType} result2 = compute(texelIdx * 4 + 2);`,
    `${resultType} result3 = compute(texelIdx * 4 + 3);`,
    `out0 = vec4(${packed});`,
  );
  emit(popIndent, "}");

  return {
    code: shader.join("\n"),
    numInputs: nargs,
    outputSize: [outputSize.width, outputSize.height],
    outputDtype,
  };
}
341
/**
 * Compile a single GLSL shader stage.
 *
 * @param gl WebGL2 rendering context.
 * @param type `gl.VERTEX_SHADER` or `gl.FRAGMENT_SHADER`.
 * @param src GLSL source text.
 * @returns The compiled shader object.
 * @throws Error with the driver's info log if creation or compilation fails;
 *   the shader object is deleted first so failed compiles do not leak.
 */
function compile(gl, type, src) {
  const shader = gl.createShader(type);
  if (!shader) throw new Error("Failed to create shader");
  gl.shaderSource(shader, src);
  gl.compileShader(shader);
  if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
    // Fetch the log before deleting: it is not available afterwards.
    const log = gl.getShaderInfoLog(shader) || "Unknown shader compile error";
    gl.deleteShader(shader);
    throw new Error(log);
  }
  return shader;
}
348
/**
 * Compile a vertex and a fragment shader and link them into a program.
 *
 * The shader objects are deleted once linking has been attempted: a linked
 * program keeps what it needs, so holding on to them only leaks GL objects.
 * On any failure the partially built program is deleted as well.
 *
 * @param gl WebGL2 rendering context.
 * @param vsSrc Vertex shader source.
 * @param fsSrc Fragment shader source.
 * @returns The linked program.
 * @throws Error if either stage fails to compile or the program fails to link.
 */
function link(gl, vsSrc, fsSrc) {
  const shaders = [];
  let program = null;
  try {
    shaders.push(compile(gl, gl.VERTEX_SHADER, vsSrc));
    shaders.push(compile(gl, gl.FRAGMENT_SHADER, fsSrc));
    program = gl.createProgram();
    if (!program) throw new Error("Failed to create program");
    for (const shader of shaders) gl.attachShader(program, shader);
    gl.linkProgram(program);
    if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
      throw new Error(gl.getProgramInfoLog(program) || "Unknown program link error");
    }
    return program;
  } catch (err) {
    if (program) gl.deleteProgram(program);
    throw err;
  } finally {
    for (const shader of shaders) gl.deleteShader(shader);
  }
}
356
/**
 * Vertex shader shared by every kernel. It takes no attributes: the three
 * positions are indexed by `gl_VertexID` and form one triangle that extends
 * past the clip-space square, so a 3-vertex draw covers the whole viewport.
 */
const vertexShaderSource = [
  "#version 300 es",
  "precision highp float;",
  "const vec2 pos[3] = vec2[](vec2(-1.0,-1.0), vec2(3.0,-1.0), vec2(-1.0,3.0));",
  "void main() { gl_Position = vec4(pos[gl_VertexID], 0.0, 1.0); }",
  "",
].join("\n");
361
/**
 * Link the generated fragment shader with the shared vertex shader and look
 * up the uniform location of every input sampler (`in0`, `in1`, ...).
 *
 * @param gl WebGL2 rendering context.
 * @param shader Result of `generateShader`; reads `code` and `numInputs`.
 * @returns A copy of `shader` extended with `program` and `inputLocations`.
 */
function compileShader(gl, shader) {
  if (DEBUG >= 1) {
    console.info("=========== WebGL shader ===========\n" + shader.code);
  }
  const program = link(gl, vertexShaderSource, shader.code);
  const inputLocations = Array.from({ length: shader.numInputs }, (_, index) =>
    gl.getUniformLocation(program, `in${index}`),
  );
  return { ...shader, program, inputLocations };
}
372
+ /** Compute 2D texture dimensions for a given number of texels. */
373
+ function computeTextureDimensions(numTexels) {
374
+ const maxDim = 16384;
375
+ let width = Math.min(Math.ceil(Math.sqrt(numTexels)), maxDim);
376
+ width = Math.min(1 << Math.ceil(Math.log2(width)), maxDim);
377
+ const height = Math.min(Math.ceil(numTexels / width), maxDim);
378
+ return {
379
+ width,
380
+ height
381
+ };
382
+ }
383
/**
 * Return the GLSL scalar type name used for a dtype.
 *
 * @param dtype One of `DType.Float32`, `DType.Int32`, `DType.Uint32`, `DType.Bool`.
 * @returns `"float"`, `"int"`, `"uint"` or `"bool"`.
 * @throws Error for any other dtype.
 */
function glslType(dtype) {
  const table = [
    [DType.Float32, "float"],
    [DType.Int32, "int"],
    [DType.Uint32, "uint"],
    [DType.Bool, "bool"],
  ];
  const entry = table.find(([candidate]) => candidate === dtype);
  if (entry === undefined) {
    throw new Error(`Unsupported dtype for WebGL: ${dtype}`);
  }
  return entry[1];
}
392
/**
 * Generate the GLSL helper `load_<dtype>(sampler, idx)` that fetches element
 * `idx` from a texture holding four values per texel and reinterprets the
 * stored float bits as the requested dtype.
 *
 * @param dtype Element dtype of the buffer being read.
 * @returns GLSL source text for the helper function.
 * @throws Error if the dtype has no GLSL type or no bit conversion.
 */
function generateLoadFunction(dtype) {
  const funcName = `load_${dtype}`;
  const returnType = glslType(dtype);

  // How the raw float channel is turned into the typed value.
  const pickConversion = () => {
    if (isFloatDtype(dtype)) return "val";
    switch (dtype) {
      case DType.Int32:
        return "floatBitsToInt(val)";
      case DType.Uint32:
        return "floatBitsToUint(val)";
      case DType.Bool:
        return "floatBitsToInt(val) != 0";
      default:
        throw new Error(`Unsupported dtype for WebGL fetch: ${dtype}`);
    }
  };
  const conversion = pickConversion();

  return [
    "",
    `${returnType} ${funcName}(highp sampler2D tex, int idx) {`,
    "ivec2 texSize = textureSize(tex, 0);",
    "int texel = idx / 4;",
    "int component = idx - texel * 4;",
    "ivec2 coord = ivec2(texel % texSize.x, texel / texSize.x);",
    "vec4 texVal = texelFetch(tex, coord, 0);",
    "float val;",
    "if (component == 0) val = texVal.x;",
    "else if (component == 1) val = texVal.y;",
    "else if (component == 2) val = texVal.z;",
    "else val = texVal.w;",
    `return ${conversion};`,
    "}",
    "",
  ].join("\n");
}
417
/**
 * Wrap a GLSL expression so its value can be written to one float channel
 * of the RGBA32F output without changing its bits.
 *
 * @param dtype Dtype of the value produced by `source`.
 * @param source GLSL expression text.
 * @returns GLSL expression of type `float`; bools are stored as int 1 or 0.
 * @throws Error for dtypes other than Float32, Int32, Uint32 and Bool.
 */
function toRGBA32F(dtype, source) {
  if (dtype === DType.Float32) return source;
  if (dtype === DType.Int32) return `intBitsToFloat(${source})`;
  if (dtype === DType.Uint32) return `uintBitsToFloat(${source})`;
  if (dtype === DType.Bool) return `intBitsToFloat(${source} ? 1 : 0)`;
  throw new Error(`Unsupported dtype for WebGL output: ${dtype}`);
}
426
/**
 * Render a JavaScript constant as a GLSL literal of the given dtype.
 *
 * @param dtype Target dtype (Bool, Int32, Uint32 or Float32).
 * @param value The constant value.
 * @returns GLSL expression text.
 * @throws Error for any other dtype.
 */
function constToGlsl(dtype, value) {
  switch (dtype) {
    case DType.Bool:
      return value ? "true" : "false";
    case DType.Int32:
      return value.toString();
    case DType.Uint32:
      return `${value.toString()}u`;
    case DType.Float32: {
      // NaN and the infinities have no literal form; build them from bits.
      if (Number.isNaN(value)) return "uintBitsToFloat(0x7fc00000u)";
      if (!Number.isFinite(value)) {
        return value > 0 ? "uintBitsToFloat(0x7f800000u)" : "uintBitsToFloat(0xff800000u)";
      }
      // toString() drops the sign of negative zero.
      if (Object.is(value, -0)) return "float(-0.0)";
      let literal = value.toString();
      // A bare digit string is an *integer* literal in GLSL: values of 2^31
      // and above would wrap to a negative int or fail to compile inside
      // float(...). Force a floating-point literal instead.
      if (/^-?\d+$/.test(literal)) literal += ".0";
      return `float(${literal})`;
    }
    default:
      throw new Error(`Unsupported dtype for WebGL constant: ${dtype}`);
  }
}
438
/**
 * Generate GLSL expression code from an AluExp.
 *
 * Walks the expression tree recursively and returns GLSL source text for it.
 * Results are memoized per node object for the duration of one call, so a
 * shared sub-expression is only translated once.
 *
 * @param exp Root expression node (`op`, `src`, `dtype`, `arg`).
 * @param args Sampler names, indexed by the kernel argument id.
 * @param inputDtypes Dtype of each kernel argument, indexed the same way.
 * @returns GLSL expression text.
 * @throws UnsupportedOpError if a node cannot be translated.
 */
function generateExpression(exp, args, inputDtypes) {
  const memo = /* @__PURE__ */ new Map();

  // Unary ops that map one-to-one onto a GLSL function call.
  const unaryCalls = [
    [AluOp.Sin, "sin"],
    [AluOp.Cos, "cos"],
    [AluOp.Asin, "asin"],
    [AluOp.Atan, "atan"],
    [AluOp.Exp, "exp"],
    [AluOp.Log, "log"],
    [AluOp.Erf, "erf"],
    [AluOp.Erfc, "erfc"],
    [AluOp.Sqrt, "sqrt"],
    [AluOp.Floor, "floor"],
    [AluOp.Ceil, "ceil"],
  ];

  /** Binary arithmetic; bools use logical operators in place of +, *, min, max. */
  const binary = (op, dtype, lhs, rhs) => {
    const isBool = dtype === DType.Bool;
    switch (op) {
      case AluOp.Add:
        return isBool ? `(${lhs} || ${rhs})` : `(${lhs} + ${rhs})`;
      case AluOp.Sub:
        return `(${lhs} - ${rhs})`;
      case AluOp.Mul:
        return isBool ? `(${lhs} && ${rhs})` : `(${lhs} * ${rhs})`;
      case AluOp.Idiv:
        return isFloatDtype(dtype) ? `trunc(${lhs} / ${rhs})` : `(${lhs} / ${rhs})`;
      case AluOp.Mod:
        return isFloatDtype(dtype)
          ? `(${lhs} - ${rhs} * trunc(${lhs} / ${rhs}))`
          : `(${lhs} % ${rhs})`;
      case AluOp.Min:
        return isBool ? `(${lhs} && ${rhs})` : `min(${lhs}, ${rhs})`;
      case AluOp.Max:
        return isBool ? `(${lhs} || ${rhs})` : `max(${lhs}, ${rhs})`;
      default:
        return "";
    }
  };

  /** Comparisons; float inequality also holds when either side is NaN. */
  const compare = (op, operandDtype, lhs, rhs) => {
    switch (op) {
      case AluOp.Cmplt:
        return `(${lhs} < ${rhs})`;
      case AluOp.Cmpne:
        return isFloatDtype(operandDtype)
          ? `(${lhs} != ${rhs} || isnan(${lhs}) || isnan(${rhs}))`
          : `(${lhs} != ${rhs})`;
      default:
        return "";
    }
  };

  /** Bit-preserving reinterpretation between float32, int32 and uint32. */
  const bitcast = (toDtype, fromDtype, operand) => {
    if (toDtype === fromDtype) return operand;
    switch (toDtype) {
      case DType.Float32:
        if (fromDtype === DType.Int32) return `intBitsToFloat(${strip1(operand)})`;
        if (fromDtype === DType.Uint32) return `uintBitsToFloat(${strip1(operand)})`;
        return "";
      case DType.Int32:
        if (fromDtype === DType.Float32) return `floatBitsToInt(${strip1(operand)})`;
        if (fromDtype === DType.Uint32) return `int(${strip1(operand)})`;
        return "";
      case DType.Uint32:
        if (fromDtype === DType.Float32) return `floatBitsToUint(${strip1(operand)})`;
        if (fromDtype === DType.Int32) return `uint(${strip1(operand)})`;
        return "";
      default:
        return "";
    }
  };

  const unary = (op, dtype, operandDtype, operand) => {
    const call = unaryCalls.find(([candidate]) => candidate === op);
    if (call !== undefined) return `${call[1]}(${strip1(operand)})`;
    if (op === AluOp.Reciprocal) return `(1.0 / ${operand})`;
    if (op === AluOp.Cast) return `${glslType(dtype)}(${strip1(operand)})`;
    if (op === AluOp.Bitcast) return bitcast(dtype, operandDtype, operand);
    return "";
  };

  /** `mode` selects the first word (0), the second word (1) or their xor. */
  const threefry = (mode, k0, k1, c0, c1) => {
    const call = `threefry2x32(uvec2(${k0}, ${k1}), uvec2(${c0}, ${c1}))`;
    if (mode === "xor") return `(${call}.x ^ ${call}.y)`;
    if (mode === 0) return `${call}.x`;
    if (mode === 1) return `${call}.y`;
    return "";
  };

  const translate = (node) => {
    const { op, src, dtype, arg } = node;
    if (AluGroup.Binary.has(op)) {
      const lhs = gen(src[0]);
      const rhs = gen(src[1]);
      return binary(op, dtype, lhs, rhs);
    }
    if (AluGroup.Compare.has(op)) {
      const lhs = gen(src[0]);
      const rhs = gen(src[1]);
      return compare(op, src[0].dtype, lhs, rhs);
    }
    if (AluGroup.Unary.has(op)) {
      const operand = gen(src[0]);
      return unary(op, dtype, src[0].dtype, operand);
    }
    if (op === AluOp.Threefry2x32) {
      const [k0, k1, c0, c1] = src.map((child) => strip1(gen(child)));
      return threefry(arg, k0, k1, c0, c1);
    }
    if (op === AluOp.Where) {
      const [cond, whenTrue, whenFalse] = src.map((child) => gen(child));
      return `(${cond} ? ${whenTrue} : ${whenFalse})`;
    }
    if (op === AluOp.Const) return constToGlsl(dtype, arg);
    if (op === AluOp.Special) return arg[0];
    if (op === AluOp.Variable) return arg;
    if (op === AluOp.GlobalIndex) {
      const gid = arg[0];
      const bufidx = gen(src[0]);
      return `load_${inputDtypes[gid]}(${args[gid]}, ${strip1(bufidx)})`;
    }
    return "";
  };

  function gen(node) {
    if (memo.has(node)) return memo.get(node);
    const source = translate(node);
    // An empty result means no branch recognized the op/dtype combination.
    if (!source) throw new UnsupportedOpError(node.op, node.dtype, "webgl", node.arg);
    memo.set(node, source);
    return source;
  }

  return gen(exp);
}
520
+
521
+ //#endregion
522
+ export { WebGLBackend };