@jax-js/jax 0.1.10 → 0.1.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -2
- package/dist/{backend-Ctqs8la1.js → backend-DI-V78Rk.js} +732 -21
- package/dist/{backend-DMauYnfl.cjs → backend-x-6vqzIM.cjs} +737 -20
- package/dist/index.cjs +372 -20
- package/dist/index.d.cts +172 -4
- package/dist/index.d.ts +172 -4
- package/dist/index.js +372 -21
- package/dist/{webgl-CvQ1QBX1.js → webgl-BhsnpeB0.js} +7 -1
- package/dist/{webgl-kvVt7-T7.cjs → webgl-CD3WK_Me.cjs} +7 -1
- package/dist/{webgpu-v_W_-oKw.js → webgpu-C2kLdkUh.js} +299 -149
- package/dist/{webgpu-DMSx7a6M.cjs → webgpu-C4S8Uq9e.cjs} +299 -149
- package/package.json +1 -1
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-DI-V78Rk.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -72,6 +72,45 @@ const headerWgsl = String.raw`
|
|
|
72
72
|
fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }
|
|
73
73
|
fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }
|
|
74
74
|
`.trim();
|
|
75
|
+
/**
 * Builder class for simple program generation with WGSL.
 *
 * Source lines are accumulated in `lines`. The `pushIndent` / `popIndent`
 * sentinels can be passed to `emit()` between lines to change the indentation
 * prefix applied to every following non-empty line.
 */
var WgslBuilder = class {
	/** Sentinel for `emit()`: increase indentation by one level. */
	pushIndent = Symbol("pushIndent");
	/** Sentinel for `emit()`: decrease indentation by one level. */
	popIndent = Symbol("popIndent");
	/** Emitted lines, in order; joined with newlines by `toString()`. */
	lines = [];
	#indent = "";
	/**
	 * Text of a single indentation level. Push and pop both use this unit so
	 * that they always change the prefix by the same amount; otherwise nested
	 * scopes would lose their indentation after the first pop.
	 */
	#indentUnit = "  ";
	/**
	 * Append lines to the program.
	 *
	 * @param {...(string|symbol)} lines Source lines, optionally interleaved
	 *   with `this.pushIndent` / `this.popIndent`. Empty (falsy) lines are
	 *   stored as `""` without any indentation prefix.
	 */
	emit(...lines) {
		for (const line of lines) {
			if (line === this.pushIndent) {
				this.#indent += this.#indentUnit;
			} else if (line === this.popIndent) {
				// slice(0, -0) would clear the string, so the unit must be non-empty.
				this.#indent = this.#indent.slice(0, -this.#indentUnit.length);
			} else {
				this.lines.push(line ? this.#indent + line : "");
			}
		}
	}
	/**
	 * Emit the shader preamble: the optional `enable f16;` directive, the
	 * common header, and whichever builtin sources the expressions require.
	 *
	 * @param device Object whose `features` set is checked for "shader-f16".
	 * @param exps Expressions to scan; `null`/`undefined` entries are skipped.
	 * @throws {Error} If an expression uses Float16 and the device lacks the
	 *   "shader-f16" feature.
	 */
	emitPreamble(device, exps) {
		let hasFloat16 = false;
		let distinctOps = /* @__PURE__ */ new Map();
		for (const exp of exps) {
			if (exp == null) continue;
			hasFloat16 ||= exp.some((e) => e.dtype === DType.Float16);
			distinctOps = mapSetUnion(distinctOps, exp.distinctOps());
		}
		if (hasFloat16) {
			if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
			this.emit("enable f16;");
		}
		this.emit(headerWgsl);
		if (distinctOps.has(AluOp.Threefry2x32)) this.emit(threefrySrc);
		if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) this.emit(erfSrc);
		this.emit("");
	}
	/**
	 * Insert phony assignments, in case some inputs are not in use.
	 * <https://github.com/gpuweb/gpuweb/discussions/4582#discussioncomment-9146686>
	 *
	 * @param {string[]} args Names of the bindings to reference; nothing is
	 *   emitted when the list is empty.
	 */
	emitPhonyAssignments(args) {
		if (args.length > 0) this.emit(args.map((arg) => `_ = &${arg};`).join(" "));
	}
	/** @returns {string} All emitted lines joined with "\n". */
	toString() {
		return this.lines.join("\n");
	}
};
|
|
75
114
|
function dtypeToWgsl(dtype, storage = false) {
|
|
76
115
|
switch (dtype) {
|
|
77
116
|
case DType.Bool: return storage ? "i32" : "bool";
|
|
@@ -108,6 +147,136 @@ function constToWgsl(dtype, value) {
|
|
|
108
147
|
}
|
|
109
148
|
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
110
149
|
}
|
|
150
|
+
/**
 * Codegen for WebGPU expressions, linearizing AluOp into a kernel.
 *
 * Statements needed by an expression (temporaries) are written to the builder
 * passed as `wb`; `args` maps GlobalIndex argument indices to binding names.
 */
var WgslExpCodegen = class {
	#gensymCount = 0;
	#references = new Map();
	#seen = new Set();
	#context = new Map();
	constructor(wb, args) {
		this.wb = wb;
		this.args = args;
	}
	/** Produce a fresh temporary name of the form `alu<N>`. */
	#gensym() {
		const id = this.#gensymCount;
		this.#gensymCount += 1;
		return `alu${id}`;
	}
	/** Truthy when `text` is exactly a name produced by `#gensym()`. */
	#isGensym(text) {
		return text.match(/^alu[0-9]+$/);
	}
	/**
	 * Count references for an expression.
	 *
	 * Builds the ahead-of-time reference count for every node reachable from
	 * `exp`. A node's children are only visited the first time the node is
	 * encountered. Nodes referenced more than once are later stored in
	 * temporary variables by `run()` to avoid recomputation.
	 */
	countReferences(exp) {
		const previous = this.#references.get(exp) ?? 0;
		this.#references.set(exp, previous + 1);
		if (this.#seen.has(exp)) return;
		this.#seen.add(exp);
		for (const child of exp.src) this.countReferences(child);
	}
	/** Forget reference counts and generated code; the gensym counter is kept. */
	reset() {
		this.#references.clear();
		this.#seen.clear();
		this.#context.clear();
	}
	/**
	 * Source text for a Binary/Compare-group node, or "" if the op is not
	 * handled. Both operands are generated (in order) before dispatching.
	 */
	#binarySource({ op, src, dtype, arg }) {
		const a = this.run(src[0]);
		const b = this.run(src[1]);
		const isBool = dtype === DType.Bool;
		switch (op) {
			case AluOp.Add:
				return isBool ? `(${a} || ${b})` : `(${a} + ${b})`;
			case AluOp.Sub:
				return `(${a} - ${b})`;
			case AluOp.Mul:
				return isBool ? `(${a} && ${b})` : `(${a} * ${b})`;
			case AluOp.Idiv:
				return isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
			case AluOp.Mod:
				return `(${a} % ${b})`;
			case AluOp.Min:
				return isBool ? `(${a} && ${b})` : `min(${strip1(a)}, ${strip1(b)})`;
			case AluOp.Max:
				return isBool ? `(${a} || ${b})` : `max(${strip1(a)}, ${strip1(b)})`;
			case AluOp.BitCombine:
				if (arg === "and") return `(${a} & ${b})`;
				if (arg === "or") return `(${a} | ${b})`;
				return isBool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
			case AluOp.BitShift:
				return arg === "shl" ? `(${a} << ${b})` : `(${a} >> ${b})`;
			case AluOp.Cmplt:
				return `(${a} < ${b})`;
			case AluOp.Cmpne: {
				if (!isFloatDtype(src[0].dtype)) return `(${a} != ${b})`;
				// The float form mentions the left operand several times, so bind
				// it to a temporary unless it already is one.
				const x = this.#isGensym(a) ? a : this.#gensym();
				if (x !== a) this.wb.emit(`let ${x} = ${a};`);
				return `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
			}
			default:
				return "";
		}
	}
	/** Source text for a Unary-group node, or "" if the op is not handled. */
	#unarySource({ op, src, dtype }) {
		// Reciprocal of a square root is fused into a single inverseSqrt call.
		if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
			const inner = this.run(src[0].src[0]);
			return `inverseSqrt(${inner})`;
		}
		const a = this.run(src[0]);
		switch (op) {
			case AluOp.Sin:
				return `sin(${strip1(a)})`;
			case AluOp.Cos:
				return `cos(${strip1(a)})`;
			case AluOp.Asin:
				return `asin(${strip1(a)})`;
			case AluOp.Atan:
				return `atan(${strip1(a)})`;
			case AluOp.Exp:
				return `exp(${strip1(a)})`;
			case AluOp.Log:
				return `log(${strip1(a)})`;
			case AluOp.Erf:
			case AluOp.Erfc: {
				const funcName = op === AluOp.Erf ? "erf" : "erfc";
				// Non-f32 inputs go through f32 and are converted back.
				if (dtype === DType.Float32) return `${funcName}(${strip1(a)})`;
				return `${dtypeToWgsl(dtype)}(${funcName}(f32(${strip1(a)})))`;
			}
			case AluOp.Sqrt:
				return `sqrt(${strip1(a)})`;
			case AluOp.Reciprocal:
				return `(1.0 / ${a})`;
			case AluOp.Floor:
				return `floor(${strip1(a)})`;
			case AluOp.Ceil:
				return `ceil(${strip1(a)})`;
			case AluOp.Cast: {
				const srcTy = dtypeToWgsl(src[0].dtype);
				const dstTy = dtypeToWgsl(dtype);
				const floatToInteger = isFloatDtype(src[0].dtype) && !isFloatDtype(dtype) && dtype !== DType.Bool;
				if (!floatToInteger) return `${dstTy}(${strip1(a)})`;
				// Float-to-integer casts select the destination maximum when the
				// value is at or above it.
				const maxVal = maxValueWgsl(dtype);
				const x = this.#isGensym(a) ? a : this.#gensym();
				if (x !== a) this.wb.emit(`let ${x}: ${srcTy} = ${strip1(a)};`);
				return `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
			}
			case AluOp.Bitcast:
				return `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
			default:
				return "";
		}
	}
	/**
	 * Generate code for an expression.
	 *
	 * Recurses over the expression tree. Nodes whose reference count (from
	 * `countReferences`) exceeds one are bound to a `let` temporary emitted to
	 * the builder, so generating one expression may emit several lines.
	 *
	 * @returns {string} WGSL source text (or a temporary's name) for `exp`.
	 * @throws {UnsupportedOpError} If no source could be produced for the op.
	 */
	run(exp) {
		if (this.#context.has(exp)) return this.#context.get(exp);
		const { op, src, dtype, arg } = exp;
		let source = "";
		if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
			source = this.#binarySource(exp);
		} else if (AluGroup.Unary.has(op)) {
			source = this.#unarySource(exp);
		} else if (op === AluOp.Where) {
			// Operands are generated in the order they appear in select(...).
			const whenFalse = strip1(this.run(src[2]));
			const whenTrue = strip1(this.run(src[1]));
			const condition = strip1(this.run(src[0]));
			source = `select(${whenFalse}, ${whenTrue}, ${condition})`;
		} else if (op === AluOp.Threefry2x32) {
			const pair = this.#gensym();
			const [k0, k1, c0, c1] = src.map((operand) => strip1(this.run(operand)));
			this.wb.emit(`let ${pair} = threefry2x32(vec2(${k0}, ${k1}), vec2(${c0}, ${c1}));`);
			if (arg === "xor") source = `(${pair}.x ^ ${pair}.y)`;
			else if (arg === 0) source = `${pair}.x`;
			else if (arg === 1) source = `${pair}.y`;
			else throw new UnsupportedOpError(op, dtype, "webgpu", arg);
		} else if (op === AluOp.Const) {
			// Const, Special and Variable are returned directly, without caching.
			return constToWgsl(dtype, arg);
		} else if (op === AluOp.Special) {
			return arg[0];
		} else if (op === AluOp.Variable) {
			return arg;
		} else if (op === AluOp.GlobalIndex) {
			source = `${this.args[arg[0]]}[${strip1(this.run(src[0]))}]`;
			if (dtype === DType.Bool) source = `(${source} != 0)`;
		}
		if (!source) throw new UnsupportedOpError(op, dtype, "webgpu", arg);
		const typeName = dtypeToWgsl(dtype);
		const isShared = (this.#references.get(exp) ?? 0) > 1;
		if (!isShared) {
			this.#context.set(exp, source);
			return source;
		}
		const name = this.#gensym();
		this.#context.set(exp, name);
		this.wb.emit(`let ${name}: ${typeName} = ${strip1(source)};`);
		return name;
	}
};
|
|
111
280
|
const gridOffsetY = 16384;
|
|
112
281
|
function calculateGrid(gridSize) {
|
|
113
282
|
let gridX = gridSize;
|
|
@@ -119,6 +288,91 @@ function calculateGrid(gridSize) {
|
|
|
119
288
|
return [gridX, gridY];
|
|
120
289
|
}
|
|
121
290
|
|
|
291
|
+
//#endregion
|
|
292
|
+
//#region src/backend/webgpu/nullaryKernel.ts
|
|
293
|
+
/**
 * Dtype used to store a constant of `dtype` inside the uniform struct.
 * Float16 is stored as Float32 and Bool as Int32; anything else is unchanged.
 */
function uniformDtype(dtype) {
	switch (dtype) {
		case DType.Float16:
			return DType.Float32;
		case DType.Bool:
			return DType.Int32;
		default:
			return dtype;
	}
}
|
|
298
|
+
/**
 * Build the expression that reads a lifted constant back from `uniforms.<name>`,
 * converting from the uniform's storage dtype to the constant's own dtype
 * (cast for Float16, `!= 0` comparison for Bool).
 */
function replacementFor(uniform) {
	const { name, dtype, uniformDtype: storageDtype } = uniform;
	const value = AluExp.variable(storageDtype, `uniforms.${name}`);
	switch (dtype) {
		case DType.Float16:
			return AluExp.cast(DType.Float16, value);
		case DType.Bool:
			return AluExp.cmpne(value, AluExp.i32(0));
		default:
			return value;
	}
}
|
|
304
|
+
/**
 * Write one uniform value into `view` at byte `offset`, little-endian.
 *
 * @throws {Error} If `dtype` is not Float32, Int32 or Uint32.
 */
function writeUniform(view, offset, dtype, value) {
	const littleEndian = true;
	if (dtype === DType.Float32) {
		view.setFloat32(offset, value, littleEndian);
	} else if (dtype === DType.Int32) {
		view.setInt32(offset, value, littleEndian);
	} else if (dtype === DType.Uint32) {
		view.setUint32(offset, value, littleEndian);
	} else {
		throw new Error(`Unsupported dtype for constant uniform: ${dtype}`);
	}
}
|
|
318
|
+
/**
 * Replace constants in `exp` with reads from a uniform struct.
 *
 * Every Const node whose `arg` is not `0` gets a uniform entry named `c<i>`
 * and is rewritten to the matching `uniforms.c<i>` read.
 *
 * @returns A pair `[rewrittenExp, uniforms]`.
 */
function liftConstants(exp) {
	const uniforms = [];
	const liftNode = (node) => {
		const shouldLift = node.op === AluOp.Const && node.arg !== 0;
		if (!shouldLift) return undefined;
		const { dtype, arg } = node;
		const uniform = {
			name: `c${uniforms.length}`,
			dtype,
			uniformDtype: uniformDtype(dtype),
			value: arg
		};
		uniforms.push(uniform);
		return replacementFor(uniform);
	};
	const lifted = exp.rewrite(liftNode);
	return [lifted, uniforms];
}
|
|
333
|
+
/**
 * Serialize uniforms into bytes: 4 bytes per entry, in array order.
 *
 * @returns {Uint8Array} Buffer of length `uniforms.length * 4`.
 */
function uniformsData(uniforms) {
	const buffer = new ArrayBuffer(uniforms.length * 4);
	const view = new DataView(buffer);
	uniforms.forEach((entry, index) => {
		writeUniform(view, index * 4, entry.uniformDtype, entry.value);
	});
	return new Uint8Array(buffer);
}
|
|
339
|
+
/**
 * Generate the shader for a kernel with no inputs and no reduction, passing
 * its lifted constants through a uniform struct.
 *
 * @returns The pipeline source description (`code`, `numInputs`, `numOutputs`,
 *   `hasUniform`, `passes`), or `null` when the kernel has arguments or a
 *   reduction.
 */
function nullaryKernelSource(device, kernel) {
	if (kernel.nargs !== 0 || kernel.reduction) return null;

	const gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
	const [exp, uniforms] = liftConstants(kernel.exp.substitute({ gidx }).simplify());
	const hasUniform = uniforms.length > 0;

	const wb = new WgslBuilder();
	wb.emitPreamble(device, [exp]);
	if (hasUniform) {
		const fields = uniforms.map((u) => `${u.name}: ${dtypeToWgsl(u.uniformDtype)},`);
		wb.emit("struct Uniforms {", wb.pushIndent, ...fields, wb.popIndent, "}\n");
	}
	const resultTy = dtypeToWgsl(kernel.dtype, true);
	wb.emit(`@group(0) @binding(0) var<storage, read_write> result : array<${resultTy}>;`);
	if (hasUniform) wb.emit(`@group(1) @binding(0) var<uniform> uniforms: Uniforms;`);

	const workgroupSize = findPow2(kernel.size, 256);
	const gridSize = Math.ceil(kernel.size / workgroupSize);
	const [gridX, gridY] = calculateGrid(gridSize);
	wb.emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);

	// With a single grid row the invocation's x id is the index; otherwise the
	// y id selects a row of `gridX * workgroupSize` invocations.
	const flatIndex = gridY === 1 ? "id.x" : `${gridX * workgroupSize} * id.y + id.x`;
	wb.emit(`if (${flatIndex} >= ${kernel.size}) { return; }`, `let gidx: i32 = i32(${flatIndex});`);

	const gen = new WgslExpCodegen(wb, []);
	gen.countReferences(exp);
	let rhs = strip1(gen.run(exp));
	// Convert when the storage type of the result differs from the expression type.
	if (resultTy !== dtypeToWgsl(exp.dtype)) rhs = `${resultTy}(${rhs})`;
	wb.emit(`result[gidx] = ${rhs};`, wb.popIndent, "}");

	const pass = {
		grid: [gridX, gridY],
		uniform: hasUniform ? uniformsData(uniforms) : void 0
	};
	return {
		code: wb.toString(),
		numInputs: 0,
		numOutputs: 1,
		hasUniform,
		passes: [pass]
	};
}
|
|
375
|
+
|
|
122
376
|
//#endregion
|
|
123
377
|
//#region src/backend/webgpu/reader.ts
|
|
124
378
|
/**
|
|
@@ -1024,28 +1278,14 @@ var WebGPUBackend = class {
|
|
|
1024
1278
|
* and y axes, to run the kernel.
|
|
1025
1279
|
*/
|
|
1026
1280
|
function pipelineSource(device, kernel) {
|
|
1281
|
+
const nullaryKernel = nullaryKernelSource(device, kernel);
|
|
1282
|
+
if (nullaryKernel) return nullaryKernel;
|
|
1027
1283
|
const tune = tuneWebgpu(kernel);
|
|
1028
1284
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
1029
1285
|
const { nargs, reduction: re } = kernel;
|
|
1030
1286
|
const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
|
|
1031
|
-
const
|
|
1032
|
-
|
|
1033
|
-
const pushIndent = Symbol("pushIndent");
|
|
1034
|
-
const popIndent = Symbol("popIndent");
|
|
1035
|
-
const emit = (...lines) => {
|
|
1036
|
-
for (const line of lines) if (line === pushIndent) indent += " ";
|
|
1037
|
-
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
1038
|
-
else shader.push(line ? indent + line : line);
|
|
1039
|
-
};
|
|
1040
|
-
if (tune.exp.some((exp) => exp.dtype === DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === DType.Float16)) {
|
|
1041
|
-
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
1042
|
-
emit("enable f16;");
|
|
1043
|
-
}
|
|
1044
|
-
emit(headerWgsl);
|
|
1045
|
-
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
1046
|
-
if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
|
|
1047
|
-
if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
|
|
1048
|
-
emit("");
|
|
1287
|
+
const wb = new WgslBuilder();
|
|
1288
|
+
wb.emitPreamble(device, [tune.exp, tune.epilogue]);
|
|
1049
1289
|
const usedArgs = Array.from({ length: nargs }, () => null);
|
|
1050
1290
|
tune.exp.fold((exp) => {
|
|
1051
1291
|
if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
@@ -1055,127 +1295,33 @@ function pipelineSource(device, kernel) {
|
|
|
1055
1295
|
});
|
|
1056
1296
|
for (let i = 0; i < nargs; i++) {
|
|
1057
1297
|
const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
|
|
1058
|
-
emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
|
|
1298
|
+
wb.emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
|
|
1059
1299
|
}
|
|
1060
1300
|
const resultTy = dtypeToWgsl(kernel.dtype, true);
|
|
1061
|
-
emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
|
|
1301
|
+
wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
|
|
1062
1302
|
const workgroupSize = findPow2(tune.threadCount, 256);
|
|
1063
1303
|
const gridSize = Math.ceil(tune.threadCount / workgroupSize);
|
|
1064
1304
|
const [gridX, gridY] = calculateGrid(gridSize);
|
|
1065
|
-
emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", pushIndent);
|
|
1066
|
-
if (gridY === 1) emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
|
|
1305
|
+
wb.emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
|
|
1306
|
+
if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
|
|
1067
1307
|
else {
|
|
1068
1308
|
const sizeX = gridX * workgroupSize;
|
|
1069
|
-
emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
|
|
1309
|
+
wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
|
|
1070
1310
|
}
|
|
1071
|
-
|
|
1072
|
-
const
|
|
1073
|
-
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
1074
|
-
if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
|
|
1075
|
-
const references = /* @__PURE__ */ new Map();
|
|
1076
|
-
const seen = /* @__PURE__ */ new Set();
|
|
1077
|
-
const countReferences = (exp) => {
|
|
1078
|
-
references.set(exp, (references.get(exp) ?? 0) + 1);
|
|
1079
|
-
if (!seen.has(exp)) {
|
|
1080
|
-
seen.add(exp);
|
|
1081
|
-
for (const src of exp.src) countReferences(src);
|
|
1082
|
-
}
|
|
1083
|
-
};
|
|
1084
|
-
const expContext = /* @__PURE__ */ new Map();
|
|
1085
|
-
const gen = (exp) => {
|
|
1086
|
-
if (expContext.has(exp)) return expContext.get(exp);
|
|
1087
|
-
const { op, src, dtype, arg } = exp;
|
|
1088
|
-
let source = "";
|
|
1089
|
-
if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
|
|
1090
|
-
const a = gen(src[0]);
|
|
1091
|
-
const b = gen(src[1]);
|
|
1092
|
-
if (op === AluOp.Add) if (dtype === DType.Bool) source = `(${a} || ${b})`;
|
|
1093
|
-
else source = `(${a} + ${b})`;
|
|
1094
|
-
else if (op === AluOp.Sub) source = `(${a} - ${b})`;
|
|
1095
|
-
else if (op === AluOp.Mul) if (dtype === DType.Bool) source = `(${a} && ${b})`;
|
|
1096
|
-
else source = `(${a} * ${b})`;
|
|
1097
|
-
else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
|
|
1098
|
-
else if (op === AluOp.Mod) source = `(${a} % ${b})`;
|
|
1099
|
-
else if (op === AluOp.Min) if (dtype === DType.Bool) source = `(${a} && ${b})`;
|
|
1100
|
-
else source = `min(${strip1(a)}, ${strip1(b)})`;
|
|
1101
|
-
else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
|
|
1102
|
-
else source = `max(${strip1(a)}, ${strip1(b)})`;
|
|
1103
|
-
else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
1104
|
-
else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
|
|
1105
|
-
const x = isGensym(a) ? a : gensym();
|
|
1106
|
-
if (x !== a) emit(`let ${x} = ${a};`);
|
|
1107
|
-
source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
|
|
1108
|
-
} else source = `(${a} != ${b})`;
|
|
1109
|
-
} else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
|
|
1110
|
-
const a = gen(src[0].src[0]);
|
|
1111
|
-
source = `inverseSqrt(${a})`;
|
|
1112
|
-
} else {
|
|
1113
|
-
const a = gen(src[0]);
|
|
1114
|
-
if (op === AluOp.Sin) source = `sin(${strip1(a)})`;
|
|
1115
|
-
else if (op === AluOp.Cos) source = `cos(${strip1(a)})`;
|
|
1116
|
-
else if (op === AluOp.Asin) source = `asin(${strip1(a)})`;
|
|
1117
|
-
else if (op === AluOp.Atan) source = `atan(${strip1(a)})`;
|
|
1118
|
-
else if (op === AluOp.Exp) source = `exp(${strip1(a)})`;
|
|
1119
|
-
else if (op === AluOp.Log) source = `log(${strip1(a)})`;
|
|
1120
|
-
else if (op === AluOp.Erf || op === AluOp.Erfc) {
|
|
1121
|
-
const funcName = op === AluOp.Erf ? "erf" : "erfc";
|
|
1122
|
-
if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${strip1(a)})))`;
|
|
1123
|
-
else source = `${funcName}(${strip1(a)})`;
|
|
1124
|
-
} else if (op === AluOp.Sqrt) source = `sqrt(${strip1(a)})`;
|
|
1125
|
-
else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
1126
|
-
else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
|
|
1127
|
-
else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
|
|
1128
|
-
else if (op === AluOp.Cast) {
|
|
1129
|
-
const srcTy = dtypeToWgsl(src[0].dtype);
|
|
1130
|
-
const dstTy = dtypeToWgsl(dtype);
|
|
1131
|
-
if (isFloatDtype(src[0].dtype) && !(isFloatDtype(dtype) || dtype === DType.Bool)) {
|
|
1132
|
-
const maxVal = maxValueWgsl(dtype);
|
|
1133
|
-
const x = isGensym(a) ? a : gensym();
|
|
1134
|
-
if (x !== a) emit(`let ${x}: ${srcTy} = ${strip1(a)};`);
|
|
1135
|
-
source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
|
|
1136
|
-
} else source = `${dstTy}(${strip1(a)})`;
|
|
1137
|
-
} else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
|
|
1138
|
-
}
|
|
1139
|
-
else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
|
|
1140
|
-
else if (op === AluOp.Threefry2x32) {
|
|
1141
|
-
const x = gensym();
|
|
1142
|
-
const [k0, k1, c0, c1] = src.map((x$1) => strip1(gen(x$1)));
|
|
1143
|
-
emit(`let ${x} = threefry2x32(vec2(${k0}, ${k1}), vec2(${c0}, ${c1}));`);
|
|
1144
|
-
if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
|
|
1145
|
-
else if (arg === 0) source = `${x}.x`;
|
|
1146
|
-
else if (arg === 1) source = `${x}.y`;
|
|
1147
|
-
else throw new UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
1148
|
-
} else if (op === AluOp.Const) return constToWgsl(dtype, arg);
|
|
1149
|
-
else if (op === AluOp.Special) return arg[0];
|
|
1150
|
-
else if (op === AluOp.Variable) return arg;
|
|
1151
|
-
else if (op === AluOp.GlobalIndex) {
|
|
1152
|
-
source = `${args[arg[0]]}[${strip1(gen(src[0]))}]`;
|
|
1153
|
-
if (dtype === DType.Bool) source = `(${source} != 0)`;
|
|
1154
|
-
}
|
|
1155
|
-
if (!source) throw new UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
1156
|
-
const typeName = dtypeToWgsl(dtype);
|
|
1157
|
-
if ((references.get(exp) ?? 0) > 1) {
|
|
1158
|
-
const name = gensym();
|
|
1159
|
-
expContext.set(exp, name);
|
|
1160
|
-
emit(`let ${name}: ${typeName} = ${strip1(source)};`);
|
|
1161
|
-
return name;
|
|
1162
|
-
} else {
|
|
1163
|
-
expContext.set(exp, source);
|
|
1164
|
-
return source;
|
|
1165
|
-
}
|
|
1166
|
-
};
|
|
1311
|
+
wb.emitPhonyAssignments(args);
|
|
1312
|
+
const gen = new WgslExpCodegen(wb, args);
|
|
1167
1313
|
if (!re) {
|
|
1168
|
-
countReferences(tune.exp);
|
|
1169
|
-
let rhs = strip1(gen(tune.exp));
|
|
1314
|
+
gen.countReferences(tune.exp);
|
|
1315
|
+
let rhs = strip1(gen.run(tune.exp));
|
|
1170
1316
|
if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1171
|
-
emit(`result[gidx] = ${rhs};`);
|
|
1317
|
+
wb.emit(`result[gidx] = ${rhs};`);
|
|
1172
1318
|
} else {
|
|
1173
1319
|
if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
|
|
1174
1320
|
const unroll = tune.size.unroll ?? 1;
|
|
1175
1321
|
const upcast = tune.size.upcast ?? 1;
|
|
1176
1322
|
const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
|
|
1177
|
-
for (let i = 0; i < upcast; i++) emit(`var ${acc[i]}: ${dtypeToWgsl(re.dtype)} = ${constToWgsl(re.dtype, re.identity)};`);
|
|
1178
|
-
emit(`for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`, pushIndent);
|
|
1323
|
+
for (let i = 0; i < upcast; i++) wb.emit(`var ${acc[i]}: ${dtypeToWgsl(re.dtype)} = ${constToWgsl(re.dtype, re.identity)};`);
|
|
1324
|
+
wb.emit(`for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`, wb.pushIndent);
|
|
1179
1325
|
const exps = [];
|
|
1180
1326
|
const cache = /* @__PURE__ */ new Map();
|
|
1181
1327
|
for (let up = 0; up < upcast; up++) {
|
|
@@ -1186,10 +1332,10 @@ function pipelineSource(device, kernel) {
|
|
|
1186
1332
|
unroll: AluExp.i32(un)
|
|
1187
1333
|
});
|
|
1188
1334
|
exps[up].push(exp.simplify(cache));
|
|
1189
|
-
countReferences(exps[up][un]);
|
|
1335
|
+
gen.countReferences(exps[up][un]);
|
|
1190
1336
|
}
|
|
1191
1337
|
}
|
|
1192
|
-
const items = exps.map((ar) => ar.map(gen).map(strip1));
|
|
1338
|
+
const items = exps.map((ar) => ar.map((x) => gen.run(x)).map(strip1));
|
|
1193
1339
|
for (let i = 0; i < upcast; i++) {
|
|
1194
1340
|
let rhs = items[i][0];
|
|
1195
1341
|
for (let j = 1; j < unroll; j++) if (re.op === AluOp.Add) rhs = `${rhs} + ${items[i][j]}`;
|
|
@@ -1197,40 +1343,38 @@ function pipelineSource(device, kernel) {
|
|
|
1197
1343
|
else if (re.op === AluOp.Min) rhs = re.dtype === DType.Bool ? `(${rhs} && ${items[i][j]})` : `min(${rhs}, ${items[i][j]})`;
|
|
1198
1344
|
else if (re.op === AluOp.Max) rhs = re.dtype === DType.Bool ? `(${rhs} || ${items[i][j]})` : `max(${rhs}, ${items[i][j]})`;
|
|
1199
1345
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
1200
|
-
if (re.op === AluOp.Add) emit(`${acc[i]} += ${rhs};`);
|
|
1201
|
-
else if (re.op === AluOp.Mul) emit(`${acc[i]} *= ${rhs};`);
|
|
1202
|
-
else if (re.op === AluOp.Min) if (re.dtype === DType.Bool) emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
|
|
1203
|
-
else emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
|
|
1204
|
-
else if (re.op === AluOp.Max) if (re.dtype === DType.Bool) emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
|
|
1205
|
-
else emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
|
|
1346
|
+
if (re.op === AluOp.Add) wb.emit(`${acc[i]} += ${rhs};`);
|
|
1347
|
+
else if (re.op === AluOp.Mul) wb.emit(`${acc[i]} *= ${rhs};`);
|
|
1348
|
+
else if (re.op === AluOp.Min) if (re.dtype === DType.Bool) wb.emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
|
|
1349
|
+
else wb.emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
|
|
1350
|
+
else if (re.op === AluOp.Max) if (re.dtype === DType.Bool) wb.emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
|
|
1351
|
+
else wb.emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
|
|
1206
1352
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
1207
1353
|
}
|
|
1208
|
-
emit(popIndent, "}");
|
|
1209
|
-
|
|
1210
|
-
references.clear();
|
|
1211
|
-
seen.clear();
|
|
1354
|
+
wb.emit(wb.popIndent, "}");
|
|
1355
|
+
gen.reset();
|
|
1212
1356
|
const outputIdxExps = [];
|
|
1213
1357
|
const fusionExps = [];
|
|
1214
1358
|
for (let i = 0; i < upcast; i++) {
|
|
1215
1359
|
const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
|
|
1216
1360
|
outputIdxExps.push(exp.simplify(cache));
|
|
1217
|
-
countReferences(outputIdxExps[i]);
|
|
1361
|
+
gen.countReferences(outputIdxExps[i]);
|
|
1218
1362
|
fusionExps.push(tune.epilogue.substitute({
|
|
1219
1363
|
acc: AluExp.variable(re.dtype, acc[i]),
|
|
1220
1364
|
upcast: AluExp.i32(i)
|
|
1221
1365
|
}).simplify(cache));
|
|
1222
|
-
countReferences(fusionExps[i]);
|
|
1366
|
+
gen.countReferences(fusionExps[i]);
|
|
1223
1367
|
}
|
|
1224
1368
|
for (let i = 0; i < upcast; i++) {
|
|
1225
|
-
const index = strip1(gen(outputIdxExps[i]));
|
|
1226
|
-
let rhs = strip1(gen(fusionExps[i]));
|
|
1369
|
+
const index = strip1(gen.run(outputIdxExps[i]));
|
|
1370
|
+
let rhs = strip1(gen.run(fusionExps[i]));
|
|
1227
1371
|
if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1228
|
-
emit(`result[${index}] = ${rhs};`);
|
|
1372
|
+
wb.emit(`result[${index}] = ${rhs};`);
|
|
1229
1373
|
}
|
|
1230
1374
|
}
|
|
1231
|
-
emit(popIndent, "}");
|
|
1375
|
+
wb.emit(wb.popIndent, "}");
|
|
1232
1376
|
return {
|
|
1233
|
-
code:
|
|
1377
|
+
code: wb.toString(),
|
|
1234
1378
|
numInputs: nargs,
|
|
1235
1379
|
numOutputs: 1,
|
|
1236
1380
|
hasUniform: false,
|
|
@@ -1274,11 +1418,17 @@ function pipelineSubmit(device, exe, inputs, outputs) {
|
|
|
1274
1418
|
}
|
|
1275
1419
|
for (let i = 0; i < filteredPasses.length; i++) {
|
|
1276
1420
|
const { grid } = filteredPasses[i];
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1421
|
+
let timestampWrites;
|
|
1422
|
+
if (slot) {
|
|
1423
|
+
const isFirst = i === 0;
|
|
1424
|
+
const isLast = i === filteredPasses.length - 1;
|
|
1425
|
+
if (isFirst || isLast) timestampWrites = {
|
|
1426
|
+
querySet: slot.batch.querySet,
|
|
1427
|
+
...isFirst ? { beginningOfPassWriteIndex: slot.beginIndex } : {},
|
|
1428
|
+
...isLast ? { endOfPassWriteIndex: slot.endIndex } : {}
|
|
1429
|
+
};
|
|
1430
|
+
}
|
|
1431
|
+
const passEncoder = commandEncoder.beginComputePass({ timestampWrites });
|
|
1282
1432
|
passEncoder.setPipeline(pipeline);
|
|
1283
1433
|
passEncoder.setBindGroup(0, bindGroup);
|
|
1284
1434
|
if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
|