@jax-js/jax 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -7
- package/dist/{backend-tngXtWe4.js → backend-DaqL-MNz.js} +96 -7
- package/dist/{backend-Bu9GY6sK.cjs → backend-DziQSaoQ.cjs} +101 -6
- package/dist/index.cjs +737 -141
- package/dist/index.d.cts +238 -9
- package/dist/index.d.ts +238 -9
- package/dist/index.js +737 -141
- package/dist/webgl-ClIYb8jP.cjs +522 -0
- package/dist/webgl-RSuZKvgc.js +522 -0
- package/dist/{webgpu-Oj3Kd-kd.cjs → webgpu-Db2JrNBr.cjs} +296 -3
- package/dist/{webgpu-ChVgx3b6.js → webgpu-Dh7k9io0.js} +296 -3
- package/package.json +1 -1
|
@@ -0,0 +1,522 @@
|
|
|
1
|
+
const require_backend = require('./backend-DziQSaoQ.cjs');
|
|
2
|
+
|
|
3
|
+
//#region src/backend/webgl/builtins.ts
|
|
4
|
+
const threefrySrc = `
|
|
5
|
+
uvec2 threefry2x32(uvec2 key, uvec2 ctr) {
|
|
6
|
+
uint ks0 = key.x;
|
|
7
|
+
uint ks1 = key.y;
|
|
8
|
+
uint ks2 = ks0 ^ ks1 ^ 0x1BD11BDAu;
|
|
9
|
+
|
|
10
|
+
uint x0 = ctr.x + ks0;
|
|
11
|
+
uint x1 = ctr.y + ks1;
|
|
12
|
+
|
|
13
|
+
x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
|
|
14
|
+
x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
|
|
15
|
+
x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
|
|
16
|
+
x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
|
|
17
|
+
x0 += ks1;
|
|
18
|
+
x1 += ks2 + 1u;
|
|
19
|
+
|
|
20
|
+
x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
|
|
21
|
+
x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
|
|
22
|
+
x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
|
|
23
|
+
x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
|
|
24
|
+
x0 += ks2;
|
|
25
|
+
x1 += ks0 + 2u;
|
|
26
|
+
|
|
27
|
+
x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
|
|
28
|
+
x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
|
|
29
|
+
x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
|
|
30
|
+
x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
|
|
31
|
+
x0 += ks0;
|
|
32
|
+
x1 += ks1 + 3u;
|
|
33
|
+
|
|
34
|
+
x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
|
|
35
|
+
x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
|
|
36
|
+
x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
|
|
37
|
+
x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
|
|
38
|
+
x0 += ks1;
|
|
39
|
+
x1 += ks2 + 4u;
|
|
40
|
+
|
|
41
|
+
x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
|
|
42
|
+
x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
|
|
43
|
+
x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
|
|
44
|
+
x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
|
|
45
|
+
x0 += ks2;
|
|
46
|
+
x1 += ks0 + 5u;
|
|
47
|
+
|
|
48
|
+
return uvec2(x0, x1);
|
|
49
|
+
}`;
|
|
50
|
+
const erfSrc = `
|
|
51
|
+
const float _erf_p = 0.3275911;
|
|
52
|
+
const float _erf_a1 = 0.254829592;
|
|
53
|
+
const float _erf_a2 = -0.284496736;
|
|
54
|
+
const float _erf_a3 = 1.421413741;
|
|
55
|
+
const float _erf_a4 = -1.453152027;
|
|
56
|
+
const float _erf_a5 = 1.061405429;
|
|
57
|
+
float erf(float x) {
|
|
58
|
+
float t = 1.0 / (1.0 + _erf_p * abs(x));
|
|
59
|
+
float P_t = (((((_erf_a5 * t) + _erf_a4) * t + _erf_a3) * t + _erf_a2) * t + _erf_a1) * t;
|
|
60
|
+
return sign(x) * (1.0 - P_t * exp(-x * x));
|
|
61
|
+
}
|
|
62
|
+
float erfc(float x) {
|
|
63
|
+
float t = 1.0 / (1.0 + _erf_p * abs(x));
|
|
64
|
+
float P_t = (((((_erf_a5 * t) + _erf_a4) * t + _erf_a3) * t + _erf_a2) * t + _erf_a1) * t;
|
|
65
|
+
float E = P_t * exp(-x * x);
|
|
66
|
+
return x >= 0.0 ? E : 2.0 - E;
|
|
67
|
+
}`;
|
|
68
|
+
|
|
69
|
+
//#endregion
|
|
70
|
+
//#region src/backend/webgl.ts
|
|
71
|
+
/**
 * No-frills backend that uses WebGL2 textures and shaders for compute.
 *
 * WebGL2 is available in almost all modern browsers, and it has options for
 * floating-point numbers and integers in textures. However, it's still not a
 * "real" compute API, and only float32 arithmetic is available.
 *
 * We make this backend available in case users want a fallback option compared
 * to WebGPU, which is only on newer browsers and iOS 26+.
 *
 * Implementation notes:
 * - All data is stored in RGBA32F textures, regardless of original dtype.
 *   Values are converted to the correct data type when loaded in shaders.
 * - Each texel holds 4 float32 values (128 bits).
 * - Compute is done by rendering a full-screen triangle with a fragment shader.
 * - Output is rendered to a framebuffer-attached texture, then read back.
 * - Every method that binds GL state (framebuffer, program, pixel-pack
 *   buffer) restores it to `null` before returning, including on errors.
 */
var WebGLBackend = class {
  type = "webgl";
  maxArgs = 8;
  /** The WebGL2 rendering context all operations are issued on. */
  gl;
  /** Single framebuffer reused for every render and read-back. */
  #fbo;
  /** Map from slot id to `{ ref, size, texture, width, height }`. */
  #buffers;
  /** Map from generated fragment-shader source to its compiled program data. */
  #programCache;
  /** Next slot id handed out by `malloc`; ids start at 1. */
  #nextSlot;

  /**
   * @param gl WebGL2 rendering context used for all GPU work.
   */
  constructor(gl) {
    this.gl = gl;
    this.#fbo = gl.createFramebuffer();
    this.#buffers = /* @__PURE__ */ new Map();
    this.#programCache = /* @__PURE__ */ new Map();
    this.#nextSlot = 1;
  }

  /**
   * Allocate a buffer of `size` bytes, optionally filled with `initialData`.
   *
   * All buffers use RGBA32F texture format internally. Data is stored as raw
   * bits and reinterpreted using floatBitsToInt/intBitsToFloat in shaders.
   * This mirrors how WebGPU handles untyped byte buffers.
   *
   * @param size Logical size of the buffer in bytes.
   * @param initialData Optional bytes copied to the start of the texture.
   * @returns The slot id of the new buffer (reference count 1).
   * @throws {RangeError} If `initialData` is larger than the backing texture.
   * @throws {Error} If the texture cannot be created.
   */
  malloc(size, initialData) {
    const gl = this.gl;
    // At least one texel is always allocated, even for zero-sized buffers.
    const numFloats = Math.ceil(size / 4) || 1;
    const numTexels = Math.ceil(numFloats / 4) || 1;
    const { width, height } = computeTextureDimensions(numTexels);
    const totalFloats = width * height * 4;
    // Validate before touching GL so a bad call cannot leak a bound texture.
    if (initialData && initialData.length > totalFloats * 4) {
      throw new RangeError(`Initial data (${initialData.length} bytes) does not fit in a ${width}x${height} RGBA32F texture`);
    }
    const texture = gl.createTexture();
    if (!texture) throw new Error("Failed to create texture");
    gl.bindTexture(gl.TEXTURE_2D, texture);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
    gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
    let pixels = null;
    if (initialData) {
      // Copy the raw bytes into a texture-sized, zero-padded float array.
      pixels = new Float32Array(totalFloats);
      new Uint8Array(pixels.buffer).set(initialData);
    }
    gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA32F, width, height, 0, gl.RGBA, gl.FLOAT, pixels);
    gl.bindTexture(gl.TEXTURE_2D, null);
    const slot = this.#nextSlot++;
    this.#buffers.set(slot, {
      ref: 1,
      size,
      texture,
      width,
      height
    });
    return slot;
  }

  /**
   * Increment the reference count of a buffer.
   * @throws SlotError if `slot` is not a live buffer.
   */
  incRef(slot) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new require_backend.SlotError(slot);
    buffer.ref++;
  }

  /**
   * Decrement the reference count of a buffer, deleting its texture and
   * forgetting the slot once the count reaches zero.
   * @throws SlotError if `slot` is not a live buffer.
   */
  decRef(slot) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new require_backend.SlotError(slot);
    buffer.ref--;
    if (buffer.ref === 0) {
      this.gl.deleteTexture(buffer.texture);
      this.#buffers.delete(slot);
    }
  }

  /**
   * Asynchronously read bytes back from a buffer.
   *
   * The whole texture is copied into a pixel-pack buffer, a fence is polled
   * every 5 ms until the GPU has finished, and the requested byte range is
   * then copied out. The temporary buffer and fence are always released, and
   * GL bindings are reset, even when an error is thrown.
   *
   * @param slot Slot id of the buffer to read.
   * @param start Byte offset to start at (default 0).
   * @param count Number of bytes to read (default: up to `size`).
   * @returns A fresh `Uint8Array` holding the requested bytes.
   * @throws SlotError if `slot` is not a live buffer.
   * @throws {Error} On GL errors or if the fence wait fails.
   */
  async read(slot, start, count) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new require_backend.SlotError(slot);
    const gl = this.gl;
    if (start === void 0) start = 0;
    if (count === void 0) count = buffer.size - start;
    const totalBytes = buffer.width * buffer.height * 4 * 4;
    const pbo = gl.createBuffer();
    if (!pbo) throw new Error("Failed to create PBO");
    let sync = null;
    try {
      try {
        gl.bindFramebuffer(gl.FRAMEBUFFER, this.#fbo);
        gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, buffer.texture, 0);
        gl.bindBuffer(gl.PIXEL_PACK_BUFFER, pbo);
        gl.bufferData(gl.PIXEL_PACK_BUFFER, totalBytes, gl.STREAM_READ);
        // With a pixel-pack buffer bound, the last argument is a byte offset.
        gl.readPixels(0, 0, buffer.width, buffer.height, gl.RGBA, gl.FLOAT, 0);
        const readError = gl.getError();
        if (readError !== gl.NO_ERROR) throw new Error(`WebGL error after readPixels: ${readError}`);
        sync = gl.fenceSync(gl.SYNC_GPU_COMMANDS_COMPLETE, 0);
        if (!sync) throw new Error("Failed to create sync object");
        gl.flush();
      } finally {
        // Other work may run while we wait; never leave our bindings behind.
        gl.bindBuffer(gl.PIXEL_PACK_BUFFER, null);
        gl.bindFramebuffer(gl.FRAMEBUFFER, null);
      }
      const fence = sync;
      await new Promise((resolve, reject) => {
        const poll = () => {
          const status = gl.clientWaitSync(fence, 0, 0);
          if (status === gl.TIMEOUT_EXPIRED) {
            setTimeout(poll, 5);
            return;
          }
          if (status === gl.WAIT_FAILED) {
            reject(/* @__PURE__ */ new Error("clientWaitSync failed"));
            return;
          }
          resolve();
        };
        poll();
      });
      const floatData = new Float32Array(totalBytes / 4);
      gl.bindBuffer(gl.PIXEL_PACK_BUFFER, pbo);
      gl.getBufferSubData(gl.PIXEL_PACK_BUFFER, 0, floatData);
      gl.bindBuffer(gl.PIXEL_PACK_BUFFER, null);
      // `slice` already returns an independent copy of the requested range.
      return new Uint8Array(floatData.buffer).slice(start, start + count);
    } finally {
      if (sync) gl.deleteSync(sync);
      gl.deleteBuffer(pbo);
    }
  }

  /**
   * Synchronously read bytes back from a buffer. Blocks until the GPU has
   * produced the texture contents.
   *
   * @param slot Slot id of the buffer to read.
   * @param start Byte offset to start at (default 0).
   * @param count Number of bytes to read (default: up to `size`).
   * @returns A fresh `Uint8Array` holding the requested bytes.
   * @throws SlotError if `slot` is not a live buffer.
   */
  readSync(slot, start, count) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new require_backend.SlotError(slot);
    const gl = this.gl;
    if (start === void 0) start = 0;
    if (count === void 0) count = buffer.size - start;
    const floatData = new Float32Array(buffer.width * buffer.height * 4);
    gl.bindFramebuffer(gl.FRAMEBUFFER, this.#fbo);
    try {
      gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, buffer.texture, 0);
      gl.readPixels(0, 0, buffer.width, buffer.height, gl.RGBA, gl.FLOAT, floatData);
    } finally {
      gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    }
    return new Uint8Array(floatData.buffer).slice(start, start + count);
  }

  /** Async wrapper around {@link prepareKernelSync}. */
  async prepareKernel(kernel) {
    return this.prepareKernelSync(kernel);
  }

  /**
   * Generate (or fetch from cache) the shader program for a kernel.
   * Programs are cached by their generated fragment-shader source.
   * @returns An `Executable` pairing the kernel with its program data.
   */
  prepareKernelSync(kernel) {
    const shader = generateShader(kernel);
    const cached = this.#programCache.get(shader.code);
    if (cached) return new require_backend.Executable(kernel, cached);
    const dispatch = compileShader(this.gl, shader);
    this.#programCache.set(shader.code, dispatch);
    return new require_backend.Executable(kernel, dispatch);
  }

  /** Routines are not supported by this backend; always throws. */
  prepareRoutine(routine) {
    throw new require_backend.UnsupportedRoutineError(routine.name, "webgl");
  }

  /** Routines are not supported by this backend; always throws. */
  prepareRoutineSync(routine) {
    throw new require_backend.UnsupportedRoutineError(routine.name, "webgl");
  }

  /**
   * Run a prepared kernel: bind the output texture as render target, bind
   * each input texture to its own texture unit, and draw one triangle.
   * The framebuffer and program bindings are reset on exit, even on error.
   *
   * @param exe Executable whose `data` holds `program`, `inputLocations`
   *   and `numInputs`.
   * @param inputs Slot ids of the input buffers, in argument order.
   * @param outputs Slot ids of the output buffers; exactly one is required.
   * @throws SlotError if any slot is not a live buffer.
   * @throws {Error} On lost context, arity mismatch, incomplete framebuffer,
   *   or a GL error reported after the draw call.
   */
  dispatch(exe, inputs, outputs) {
    const gl = this.gl;
    if (gl.isContextLost()) throw new Error("WebGL context lost - cannot dispatch");
    const { program, inputLocations } = exe.data;
    if (inputs.length !== exe.data.numInputs) throw new Error(`Expected ${exe.data.numInputs} inputs, got ${inputs.length}`);
    if (outputs.length !== 1) throw new Error(`Expected 1 output, got ${outputs.length}`);
    const outputBuffer = this.#buffers.get(outputs[0]);
    if (!outputBuffer) throw new require_backend.SlotError(outputs[0]);
    gl.bindFramebuffer(gl.FRAMEBUFFER, this.#fbo);
    try {
      gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, outputBuffer.texture, 0);
      const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
      if (status !== gl.FRAMEBUFFER_COMPLETE) throw new Error(`Framebuffer incomplete: ${status}`);
      gl.viewport(0, 0, outputBuffer.width, outputBuffer.height);
      gl.useProgram(program);
      for (let i = 0; i < inputs.length; i++) {
        const inputBuffer = this.#buffers.get(inputs[i]);
        if (!inputBuffer) throw new require_backend.SlotError(inputs[i]);
        gl.activeTexture(gl.TEXTURE0 + i);
        gl.bindTexture(gl.TEXTURE_2D, inputBuffer.texture);
        // A null location means the sampler was optimized out of the program.
        if (inputLocations[i] !== null) gl.uniform1i(inputLocations[i], i);
      }
      gl.drawArrays(gl.TRIANGLES, 0, 3);
      const error = gl.getError();
      if (error !== gl.NO_ERROR) {
        const knownErrors = [
          "INVALID_ENUM",
          "INVALID_VALUE",
          "INVALID_OPERATION",
          "INVALID_FRAMEBUFFER_OPERATION",
          "OUT_OF_MEMORY",
          "CONTEXT_LOST_WEBGL"
        ];
        const errorName = knownErrors.find((name) => gl[name] === error) ?? `UNKNOWN(${error})`;
        throw new Error(`WebGL error after drawArrays: ${errorName}`);
      }
    } finally {
      gl.bindFramebuffer(gl.FRAMEBUFFER, null);
      gl.useProgram(null);
    }
  }
};
|
|
274
|
+
/**
 * Build the GLSL ES 3.00 fragment shader for a kernel.
 *
 * The shader declares one `sampler2D` per kernel argument, a typed load
 * helper per distinct input dtype, any builtins the expression uses, a
 * `compute(int gidx)` function producing one output element, and a `main`
 * that packs four consecutive elements into the RGBA output texel.
 *
 * @param kernel Kernel description (`nargs`, `reduction`, `dtype`, `size`, `exp`).
 * @returns `{ code, numInputs, outputSize: [width, height], outputDtype }`.
 */
function generateShader(kernel) {
  const rb = require_backend;
  const tune = rb.tuneNullopt(kernel);
  if (rb.DEBUG >= 3) console.info(`webgl kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
  const nargs = kernel.nargs;
  const re = kernel.reduction;
  const outputDtype = kernel.dtype;
  // Four output elements are packed into every texel.
  const outputSize = computeTextureDimensions(Math.ceil(kernel.size / 4) || 1);

  // Walk the expression trees once to learn each argument's dtype and which
  // GLSL builtins have to be included.
  const inputDtypes = Array(nargs).fill(rb.DType.Float32);
  const builtins = {
    erf: false,
    threefry: false
  };
  const collectInfo = (exp) => {
    switch (exp.op) {
      case rb.AluOp.GlobalIndex:
        inputDtypes[exp.arg[0]] = exp.dtype;
        break;
      case rb.AluOp.Erf:
      case rb.AluOp.Erfc:
        builtins.erf = true;
        break;
      case rb.AluOp.Threefry2x32:
        builtins.threefry = true;
        break;
    }
  };
  tune.exp.fold(collectInfo);
  tune.epilogue?.fold(collectInfo);

  // Line-oriented emitter; the two symbols adjust the current indentation.
  const shader = [];
  let indent = "";
  const pushIndent = Symbol("pushIndent");
  const popIndent = Symbol("popIndent");
  const emit = (...lines) => {
    lines.forEach((line) => {
      if (line === pushIndent) {
        indent += " ";
        return;
      }
      if (line === popIndent) {
        indent = indent.slice(0, -2);
        return;
      }
      shader.push(line ? indent + line : line);
    });
  };

  emit("#version 300 es", "precision highp float;", "precision highp int;", "");
  const args = [];
  for (let i = 0; i < nargs; i++) args.push(`in${i}`);
  const resultType = glslType(outputDtype);
  args.forEach((name) => emit(`uniform highp sampler2D ${name};`));
  emit("out vec4 out0;");
  for (const dtype of new Set(inputDtypes)) emit(generateLoadFunction(dtype));
  if (builtins.erf) emit(erfSrc);
  if (builtins.threefry) emit(threefrySrc);

  // compute(gidx): one output element; out-of-range indices yield zero.
  emit(`${resultType} compute(int gidx) {`, pushIndent, `${resultType} result = ${constToGlsl(outputDtype, 0)};`, `if (gidx < ${kernel.size}) {`, pushIndent);
  if (re) {
    const accType = glslType(re.dtype);
    const accInit = constToGlsl(re.dtype, re.identity);
    emit(`${accType} acc = ${accInit};`, `for (int ridx = 0; ridx < ${tune.size.reduce}; ridx++) {`, pushIndent);
    const term = generateExpression(tune.exp, args, inputDtypes);
    const isBool = re.dtype === rb.DType.Bool;
    let step;
    switch (re.op) {
      case rb.AluOp.Add:
        step = `acc += ${rb.strip1(term)};`;
        break;
      case rb.AluOp.Mul:
        step = `acc *= ${rb.strip1(term)};`;
        break;
      case rb.AluOp.Min:
        // Boolean min/max are logical and/or.
        step = isBool ? `acc = acc && ${term};` : `acc = min(acc, ${rb.strip1(term)});`;
        break;
      case rb.AluOp.Max:
        step = isBool ? `acc = acc || ${term};` : `acc = max(acc, ${rb.strip1(term)});`;
        break;
      default:
        throw new Error(`Unsupported reduction op: ${re.op}`);
    }
    emit(step);
    emit(popIndent, "}");
    emit(`result = ${generateExpression(tune.epilogue, args, inputDtypes)};`);
  } else {
    const body = generateExpression(tune.exp, args, inputDtypes);
    emit(`result = ${rb.strip1(body)};`);
  }
  emit(popIndent, "}", "return result;", popIndent, "}\n");

  // main(): map the fragment to a texel index and fill its four lanes.
  const packed = rb.range(4).map((i) => toRGBA32F(outputDtype, `result${i}`)).join(", ");
  emit(
    "void main() {",
    pushIndent,
    "ivec2 fragCoord = ivec2(gl_FragCoord.xy);",
    `int texelIdx = fragCoord.y * ${outputSize.width} + fragCoord.x;`,
    `${resultType} result0 = compute(texelIdx * 4);`,
    `${resultType} result1 = compute(texelIdx * 4 + 1);`,
    `${resultType} result2 = compute(texelIdx * 4 + 2);`,
    `${resultType} result3 = compute(texelIdx * 4 + 3);`,
    `out0 = vec4(${packed});`
  );
  emit(popIndent, "}");

  return {
    code: shader.join("\n"),
    numInputs: nargs,
    outputSize: [outputSize.width, outputSize.height],
    outputDtype
  };
}
|
|
341
|
+
/**
 * Compile a single GLSL shader stage.
 *
 * @param gl WebGL2 rendering context.
 * @param type Shader stage (`gl.VERTEX_SHADER` or `gl.FRAGMENT_SHADER`).
 * @param src GLSL source text.
 * @returns The compiled shader object; the caller owns (and must delete) it.
 * @throws {Error} If the shader cannot be created or fails to compile. The
 *   failed shader object is deleted before throwing so it does not leak.
 */
function compile(gl, type, src) {
  const shader = gl.createShader(type);
  if (!shader) throw new Error("Failed to create shader");
  gl.shaderSource(shader, src);
  gl.compileShader(shader);
  if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
    // Grab the log first: it is unavailable once the shader is deleted.
    const message = gl.getShaderInfoLog(shader) ?? "Unknown shader compile error";
    gl.deleteShader(shader);
    throw new Error(message);
  }
  return shader;
}
|
|
348
|
+
/**
 * Compile a vertex and a fragment shader and link them into a program.
 *
 * The intermediate shader objects are detached and deleted once linking has
 * been attempted (a linked program keeps its executable without them), and
 * the program itself is deleted if anything fails, so no GL objects leak.
 *
 * @param gl WebGL2 rendering context.
 * @param vsSrc Vertex shader source.
 * @param fsSrc Fragment shader source.
 * @returns The linked program.
 * @throws {Error} If the program cannot be created, a stage fails to compile,
 *   or linking fails (message is the program info log when available).
 */
function link(gl, vsSrc, fsSrc) {
  const program = gl.createProgram();
  if (!program) throw new Error("Failed to create program");
  const attached = [];
  let linked = false;
  try {
    const stages = [
      [gl.VERTEX_SHADER, vsSrc],
      [gl.FRAGMENT_SHADER, fsSrc]
    ];
    for (const [type, src] of stages) {
      const shader = compile(gl, type, src);
      attached.push(shader);
      gl.attachShader(program, shader);
    }
    gl.linkProgram(program);
    if (!gl.getProgramParameter(program, gl.LINK_STATUS)) throw new Error(gl.getProgramInfoLog(program) ?? "Unknown program link error");
    linked = true;
    return program;
  } finally {
    // Release shaders before the program so every call sees a valid program.
    for (const shader of attached) {
      gl.detachShader(program, shader);
      gl.deleteShader(shader);
    }
    if (!linked) gl.deleteProgram(program);
  }
}
|
|
356
|
+
const vertexShaderSource = `#version 300 es
|
|
357
|
+
precision highp float;
|
|
358
|
+
const vec2 pos[3] = vec2[](vec2(-1.0,-1.0), vec2(3.0,-1.0), vec2(-1.0,3.0));
|
|
359
|
+
void main() { gl_Position = vec4(pos[gl_VertexID], 0.0, 1.0); }
|
|
360
|
+
`;
|
|
361
|
+
/**
 * Turn a generated shader description into something dispatchable.
 *
 * Links the fragment source against the shared vertex shader and looks up
 * the uniform location of every input sampler (`in0`, `in1`, ...).
 *
 * @param gl WebGL2 rendering context.
 * @param shader Result of `generateShader` (`code`, `numInputs`, ...).
 * @returns A copy of `shader` extended with `program` and `inputLocations`.
 */
function compileShader(gl, shader) {
  if (require_backend.DEBUG >= 1) console.info("=========== WebGL shader ===========\n" + shader.code);
  const program = link(gl, vertexShaderSource, shader.code);
  const inputLocations = Array.from(
    { length: shader.numInputs },
    (_, index) => gl.getUniformLocation(program, `in${index}`)
  );
  return Object.assign({}, shader, { program, inputLocations });
}
|
|
372
|
+
/**
 * Compute 2D texture dimensions for a given number of texels.
 *
 * The width is the smallest power of two that is at least
 * `ceil(sqrt(numTexels))`, and the height is however many rows of that width
 * are needed. Both are limited to 16384.
 *
 * @param numTexels Number of texels that must fit; values below 1 are
 *   treated as 1 so the result is never an empty texture.
 * @returns `{ width, height }` with `width * height >= numTexels`.
 * @throws {RangeError} If `numTexels` is not finite or exceeds
 *   16384 x 16384; previously the height was silently clamped, producing a
 *   texture too small for the data.
 */
function computeTextureDimensions(numTexels) {
  const maxDim = 16384;
  if (!Number.isFinite(numTexels) || numTexels > maxDim * maxDim) {
    throw new RangeError(`Cannot fit ${numTexels} texels in a ${maxDim}x${maxDim} texture`);
  }
  const texels = Math.max(1, Math.ceil(numTexels));
  const side = Math.min(Math.ceil(Math.sqrt(texels)), maxDim);
  // Round up to a power of two with integer arithmetic (no log2 rounding).
  let width = 1;
  while (width < side) width *= 2;
  // width >= sqrt(texels), hence height <= width <= maxDim.
  const height = Math.ceil(texels / width);
  return {
    width,
    height
  };
}
|
|
383
|
+
/**
 * Map a dtype to the GLSL scalar type used for it inside shaders.
 *
 * @param dtype One of the `DType` values.
 * @returns `"float"`, `"int"`, `"uint"` or `"bool"`.
 * @throws {Error} For any dtype without a GLSL counterpart here.
 */
function glslType(dtype) {
  const { DType } = require_backend;
  const glslNames = new Map([
    [DType.Float32, "float"],
    [DType.Int32, "int"],
    [DType.Uint32, "uint"],
    [DType.Bool, "bool"]
  ]);
  const glslName = glslNames.get(dtype);
  if (glslName === undefined) throw new Error(`Unsupported dtype for WebGL: ${dtype}`);
  return glslName;
}
|
|
392
|
+
/**
 * Generate the GLSL helper `load_<dtype>(sampler2D tex, int idx)`.
 *
 * The helper locates element `idx` (texel `idx / 4`, component `idx % 4`,
 * laid out row-major over the texture width), fetches it as a float and
 * reinterprets the bits as the requested dtype.
 *
 * @param dtype Dtype the loaded value should have.
 * @returns GLSL source of the helper, with a leading and trailing newline.
 * @throws {Error} If the dtype cannot be represented.
 */
function generateLoadFunction(dtype) {
  const { DType } = require_backend;
  const funcName = `load_${dtype}`;
  const returnType = glslType(dtype);

  // How the raw float `val` becomes a value of the target type.
  const reinterpret = () => {
    if (require_backend.isFloatDtype(dtype)) return "val";
    switch (dtype) {
      case DType.Int32:
        return "floatBitsToInt(val)";
      case DType.Uint32:
        return "floatBitsToUint(val)";
      case DType.Bool:
        return "floatBitsToInt(val) != 0";
      default:
        throw new Error(`Unsupported dtype for WebGL fetch: ${dtype}`);
    }
  };
  const conversion = reinterpret();

  return [
    "",
    `${returnType} ${funcName}(highp sampler2D tex, int idx) {`,
    "ivec2 texSize = textureSize(tex, 0);",
    "int texel = idx / 4;",
    "int component = idx - texel * 4;",
    "ivec2 coord = ivec2(texel % texSize.x, texel / texSize.x);",
    "vec4 texVal = texelFetch(tex, coord, 0);",
    "float val;",
    "if (component == 0) val = texVal.x;",
    "else if (component == 1) val = texVal.y;",
    "else if (component == 2) val = texVal.z;",
    "else val = texVal.w;",
    `return ${conversion};`,
    "}",
    ""
  ].join("\n");
}
|
|
417
|
+
/**
 * Wrap a GLSL expression so its value can be written to one float channel
 * of the RGBA32F output: floats pass through, integers are bit-cast, and
 * booleans are stored as the bits of integer 1 or 0.
 *
 * @param dtype Dtype of the value produced by `source`.
 * @param source GLSL expression text.
 * @returns GLSL expression of type `float`.
 * @throws {Error} For dtypes that cannot be stored.
 */
function toRGBA32F(dtype, source) {
  const { DType } = require_backend;
  if (dtype === DType.Float32) return source;
  if (dtype === DType.Int32) return `intBitsToFloat(${source})`;
  if (dtype === DType.Uint32) return `uintBitsToFloat(${source})`;
  if (dtype === DType.Bool) return `intBitsToFloat(${source} ? 1 : 0)`;
  throw new Error(`Unsupported dtype for WebGL output: ${dtype}`);
}
|
|
426
|
+
/**
 * Render a JavaScript constant as a GLSL literal of the given dtype.
 *
 * @param dtype Dtype of the constant.
 * @param value The constant (number, or anything truthy/falsy for Bool).
 * @returns GLSL source text for the literal.
 * @throws {Error} For dtypes that have no GLSL literal form here.
 */
function constToGlsl(dtype, value) {
  switch (dtype) {
    case require_backend.DType.Bool: return value ? "true" : "false";
    case require_backend.DType.Int32: return value.toString();
    case require_backend.DType.Uint32: return value.toString() + "u";
    case require_backend.DType.Float32: {
      // NaN and infinities have no literal syntax; build them from bits.
      if (Number.isNaN(value)) return "uintBitsToFloat(0x7fc00000u)";
      if (!Number.isFinite(value)) return value > 0 ? "uintBitsToFloat(0x7f800000u)" : "uintBitsToFloat(0xff800000u)";
      // JS drops the sign of negative zero when stringifying.
      let literal = Object.is(value, -0) ? "-0" : value.toString();
      // Integer-valued numbers print without '.' or exponent, which GLSL
      // parses as an *int* literal; that fails to compile once the magnitude
      // exceeds the 32-bit integer range (e.g. 3000000000). Force a float
      // literal instead.
      if (!/[.e]/.test(literal)) literal += ".0";
      return "float(" + literal + ")";
    }
    default: throw new Error(`Unsupported dtype for WebGL constant: ${dtype}`);
  }
}
|
|
438
|
+
/**
 * Generate GLSL expression code from an AluExp.
 *
 * The tree is translated recursively; the text generated for each node is
 * memoized per node object, so a shared sub-expression is translated once
 * (its text is still inlined at every use).
 *
 * @param exp Root expression node (`{ op, src, dtype, arg }`).
 * @param args GLSL sampler names, indexed by kernel argument.
 * @param inputDtypes Dtype of each kernel argument, indexed like `args`.
 * @returns GLSL source text of the expression.
 * @throws UnsupportedOpError if a node has no translation.
 */
function generateExpression(exp, args, inputDtypes) {
  const rb = require_backend;
  const { AluOp, DType } = rb;

  // Unary ops that map 1:1 onto a GLSL function call.
  const unaryCalls = new Map([
    [AluOp.Sin, "sin"],
    [AluOp.Cos, "cos"],
    [AluOp.Asin, "asin"],
    [AluOp.Atan, "atan"],
    [AluOp.Exp, "exp"],
    [AluOp.Log, "log"],
    [AluOp.Erf, "erf"],
    [AluOp.Erfc, "erfc"],
    [AluOp.Sqrt, "sqrt"],
    [AluOp.Floor, "floor"],
    [AluOp.Ceil, "ceil"]
  ]);

  // bitcastCalls.get(target).get(source) -> GLSL reinterpretation function.
  const bitcastCalls = new Map([
    [DType.Float32, new Map([[DType.Int32, "intBitsToFloat"], [DType.Uint32, "uintBitsToFloat"]])],
    [DType.Int32, new Map([[DType.Float32, "floatBitsToInt"], [DType.Uint32, "int"]])],
    [DType.Uint32, new Map([[DType.Float32, "floatBitsToUint"], [DType.Int32, "uint"]])]
  ]);

  // Each helper returns "" when it has no translation for the op.
  const binary = (op, dtype, a, b) => {
    const isBool = dtype === DType.Bool;
    switch (op) {
      case AluOp.Add:
        return isBool ? `(${a} || ${b})` : `(${a} + ${b})`;
      case AluOp.Sub:
        return `(${a} - ${b})`;
      case AluOp.Mul:
        return isBool ? `(${a} && ${b})` : `(${a} * ${b})`;
      case AluOp.Idiv:
        return rb.isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
      case AluOp.Mod:
        return rb.isFloatDtype(dtype) ? `(${a} - ${b} * trunc(${a} / ${b}))` : `(${a} % ${b})`;
      case AluOp.Min:
        return isBool ? `(${a} && ${b})` : `min(${a}, ${b})`;
      case AluOp.Max:
        return isBool ? `(${a} || ${b})` : `max(${a}, ${b})`;
      default:
        return "";
    }
  };

  const compare = (op, src, a, b) => {
    if (op === AluOp.Cmplt) return `(${a} < ${b})`;
    if (op === AluOp.Cmpne) {
      // For floats, also report "not equal" whenever either side is NaN.
      return rb.isFloatDtype(src[0].dtype)
        ? `(${a} != ${b} || isnan(${a}) || isnan(${b}))`
        : `(${a} != ${b})`;
    }
    return "";
  };

  const unary = (op, dtype, src, a) => {
    const fn = unaryCalls.get(op);
    if (fn !== undefined) return `${fn}(${rb.strip1(a)})`;
    if (op === AluOp.Reciprocal) return `(1.0 / ${a})`;
    if (op === AluOp.Cast) return `${glslType(dtype)}(${rb.strip1(a)})`;
    if (op === AluOp.Bitcast) {
      const dtype0 = src[0].dtype;
      if (dtype === dtype0) return a;
      const reinterpret = bitcastCalls.get(dtype)?.get(dtype0);
      return reinterpret === undefined ? "" : `${reinterpret}(${rb.strip1(a)})`;
    }
    return "";
  };

  const threefry = (src, mode, gen) => {
    const [k0, k1, c0, c1] = src.map((x) => rb.strip1(gen(x)));
    const call = `threefry2x32(uvec2(${k0}, ${k1}), uvec2(${c0}, ${c1}))`;
    if (mode === "xor") return `(${call}.x ^ ${call}.y)`;
    if (mode === 0) return `${call}.x`;
    if (mode === 1) return `${call}.y`;
    return "";
  };

  const memo = /* @__PURE__ */ new Map();
  const gen = (e) => {
    const known = memo.get(e);
    if (known !== undefined) return known;
    const { op, src, dtype, arg } = e;
    let source = "";
    if (rb.AluGroup.Binary.has(op)) {
      const lhs = gen(src[0]);
      const rhs = gen(src[1]);
      source = binary(op, dtype, lhs, rhs);
    } else if (rb.AluGroup.Compare.has(op)) {
      const lhs = gen(src[0]);
      const rhs = gen(src[1]);
      source = compare(op, src, lhs, rhs);
    } else if (rb.AluGroup.Unary.has(op)) {
      source = unary(op, dtype, src, gen(src[0]));
    } else {
      switch (op) {
        case AluOp.Threefry2x32:
          source = threefry(src, arg, gen);
          break;
        case AluOp.Where: {
          const [cond, whenTrue, whenFalse] = src.map((s) => gen(s));
          source = `(${cond} ? ${whenTrue} : ${whenFalse})`;
          break;
        }
        case AluOp.Const:
          source = constToGlsl(dtype, arg);
          break;
        case AluOp.Special:
          source = arg[0];
          break;
        case AluOp.Variable:
          source = arg;
          break;
        case AluOp.GlobalIndex: {
          const gid = arg[0];
          const bufidx = gen(src[0]);
          source = `load_${inputDtypes[gid]}(${args[gid]}, ${rb.strip1(bufidx)})`;
          break;
        }
      }
    }
    if (!source) throw new rb.UnsupportedOpError(op, dtype, "webgl", arg);
    memo.set(e, source);
    return source;
  };
  return gen(exp);
}
|
|
520
|
+
|
|
521
|
+
//#endregion
|
|
522
|
+
exports.WebGLBackend = WebGLBackend;
|