@jax-js/jax 0.1.11 → 0.1.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-DZvR7mZV.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-DI-V78Rk.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -72,6 +72,45 @@ const headerWgsl = String.raw`
72
72
  fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }
73
73
  fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }
74
74
  `.trim();
75
+ /** Builder class for simple program generation with WGSL. */
76
+ var WgslBuilder = class {
77
+ pushIndent = Symbol("pushIndent");
78
+ popIndent = Symbol("popIndent");
79
+ lines = [];
80
+ #indent = "";
81
+ emit(...lines) {
82
+ for (const line of lines) if (line === this.pushIndent) this.#indent += " ";
83
+ else if (line === this.popIndent) this.#indent = this.#indent.slice(0, -2);
84
+ else this.lines.push(line ? this.#indent + line : "");
85
+ }
86
+ emitPreamble(device, exps) {
87
+ let hasFloat16 = false;
88
+ let distinctOps = /* @__PURE__ */ new Map();
89
+ for (const exp of exps) {
90
+ if (exp == null) continue;
91
+ hasFloat16 ||= exp.some((e) => e.dtype === DType.Float16);
92
+ distinctOps = mapSetUnion(distinctOps, exp.distinctOps());
93
+ }
94
+ if (hasFloat16) {
95
+ if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
96
+ this.emit("enable f16;");
97
+ }
98
+ this.emit(headerWgsl);
99
+ if (distinctOps.has(AluOp.Threefry2x32)) this.emit(threefrySrc);
100
+ if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) this.emit(erfSrc);
101
+ this.emit("");
102
+ }
103
+ /**
104
+ * Insert phony assignments, in case some inputs are not in use.
105
+ * <https://github.com/gpuweb/gpuweb/discussions/4582#discussioncomment-9146686>
106
+ */
107
+ emitPhonyAssignments(args) {
108
+ if (args.length > 0) this.emit(args.map((arg) => `_ = &${arg};`).join(" "));
109
+ }
110
+ toString() {
111
+ return this.lines.join("\n");
112
+ }
113
+ };
75
114
  function dtypeToWgsl(dtype, storage = false) {
76
115
  switch (dtype) {
77
116
  case DType.Bool: return storage ? "i32" : "bool";
@@ -108,6 +147,136 @@ function constToWgsl(dtype, value) {
108
147
  }
109
148
  throw new Error(`Unsupported const dtype: ${dtype}`);
110
149
  }
150
+ /** Codegen for WebGPU expressions, linearizing AluOp into a kernel. */
151
+ var WgslExpCodegen = class {
152
+ #gensymCount = 0;
153
+ #references = /* @__PURE__ */ new Map();
154
+ #seen = /* @__PURE__ */ new Set();
155
+ #context = /* @__PURE__ */ new Map();
156
+ constructor(wb, args) {
157
+ this.wb = wb;
158
+ this.args = args;
159
+ }
160
+ #gensym() {
161
+ return `alu${this.#gensymCount++}`;
162
+ }
163
+ #isGensym(text) {
164
+ return text.match(/^alu[0-9]+$/);
165
+ }
166
+ /**
167
+ * Count references for an expression.
168
+ *
169
+ * Used to get an ahead-of-time reference count for each node in the AluExp.
170
+ * Expressions with reference count greater than 1 are stored in temporary
171
+ * variables to avoid recomputation.
172
+ */
173
+ countReferences(exp) {
174
+ this.#references.set(exp, (this.#references.get(exp) ?? 0) + 1);
175
+ if (!this.#seen.has(exp)) {
176
+ this.#seen.add(exp);
177
+ for (const src of exp.src) this.countReferences(src);
178
+ }
179
+ }
180
+ reset() {
181
+ this.#references.clear();
182
+ this.#seen.clear();
183
+ this.#context.clear();
184
+ }
185
+ /**
186
+ * Generate code for an expression.
187
+ *
188
+ * Calls itself recursively and eliminates common subexpressions by storing
189
+ * them in temporary variables, emitted to the current builder scope. This is
190
+ * a side-effect that leads to multiline code generation.
191
+ */
192
+ run(exp) {
193
+ if (this.#context.has(exp)) return this.#context.get(exp);
194
+ const { op, src, dtype, arg } = exp;
195
+ let source = "";
196
+ if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
197
+ const a = this.run(src[0]);
198
+ const b = this.run(src[1]);
199
+ if (op === AluOp.Add) if (dtype === DType.Bool) source = `(${a} || ${b})`;
200
+ else source = `(${a} + ${b})`;
201
+ else if (op === AluOp.Sub) source = `(${a} - ${b})`;
202
+ else if (op === AluOp.Mul) if (dtype === DType.Bool) source = `(${a} && ${b})`;
203
+ else source = `(${a} * ${b})`;
204
+ else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
205
+ else if (op === AluOp.Mod) source = `(${a} % ${b})`;
206
+ else if (op === AluOp.Min) if (dtype === DType.Bool) source = `(${a} && ${b})`;
207
+ else source = `min(${strip1(a)}, ${strip1(b)})`;
208
+ else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
209
+ else source = `max(${strip1(a)}, ${strip1(b)})`;
210
+ else if (op === AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
211
+ else if (arg === "or") source = `(${a} | ${b})`;
212
+ else source = dtype === DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
213
+ else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
214
+ else source = `(${a} >> ${b})`;
215
+ else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
216
+ else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
217
+ const x = this.#isGensym(a) ? a : this.#gensym();
218
+ if (x !== a) this.wb.emit(`let ${x} = ${a};`);
219
+ source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
220
+ } else source = `(${a} != ${b})`;
221
+ } else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
222
+ const a = this.run(src[0].src[0]);
223
+ source = `inverseSqrt(${a})`;
224
+ } else {
225
+ const a = this.run(src[0]);
226
+ if (op === AluOp.Sin) source = `sin(${strip1(a)})`;
227
+ else if (op === AluOp.Cos) source = `cos(${strip1(a)})`;
228
+ else if (op === AluOp.Asin) source = `asin(${strip1(a)})`;
229
+ else if (op === AluOp.Atan) source = `atan(${strip1(a)})`;
230
+ else if (op === AluOp.Exp) source = `exp(${strip1(a)})`;
231
+ else if (op === AluOp.Log) source = `log(${strip1(a)})`;
232
+ else if (op === AluOp.Erf || op === AluOp.Erfc) {
233
+ const funcName = op === AluOp.Erf ? "erf" : "erfc";
234
+ if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${strip1(a)})))`;
235
+ else source = `${funcName}(${strip1(a)})`;
236
+ } else if (op === AluOp.Sqrt) source = `sqrt(${strip1(a)})`;
237
+ else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
238
+ else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
239
+ else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
240
+ else if (op === AluOp.Cast) {
241
+ const srcTy = dtypeToWgsl(src[0].dtype);
242
+ const dstTy = dtypeToWgsl(dtype);
243
+ if (isFloatDtype(src[0].dtype) && !(isFloatDtype(dtype) || dtype === DType.Bool)) {
244
+ const maxVal = maxValueWgsl(dtype);
245
+ const x = this.#isGensym(a) ? a : this.#gensym();
246
+ if (x !== a) this.wb.emit(`let ${x}: ${srcTy} = ${strip1(a)};`);
247
+ source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
248
+ } else source = `${dstTy}(${strip1(a)})`;
249
+ } else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
250
+ }
251
+ else if (op === AluOp.Where) source = `select(${strip1(this.run(src[2]))}, ${strip1(this.run(src[1]))}, ${strip1(this.run(src[0]))})`;
252
+ else if (op === AluOp.Threefry2x32) {
253
+ const x = this.#gensym();
254
+ const [k0, k1, c0, c1] = src.map((x$1) => strip1(this.run(x$1)));
255
+ this.wb.emit(`let ${x} = threefry2x32(vec2(${k0}, ${k1}), vec2(${c0}, ${c1}));`);
256
+ if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
257
+ else if (arg === 0) source = `${x}.x`;
258
+ else if (arg === 1) source = `${x}.y`;
259
+ else throw new UnsupportedOpError(op, dtype, "webgpu", arg);
260
+ } else if (op === AluOp.Const) return constToWgsl(dtype, arg);
261
+ else if (op === AluOp.Special) return arg[0];
262
+ else if (op === AluOp.Variable) return arg;
263
+ else if (op === AluOp.GlobalIndex) {
264
+ source = `${this.args[arg[0]]}[${strip1(this.run(src[0]))}]`;
265
+ if (dtype === DType.Bool) source = `(${source} != 0)`;
266
+ }
267
+ if (!source) throw new UnsupportedOpError(op, dtype, "webgpu", arg);
268
+ const typeName = dtypeToWgsl(dtype);
269
+ if ((this.#references.get(exp) ?? 0) > 1) {
270
+ const name = this.#gensym();
271
+ this.#context.set(exp, name);
272
+ this.wb.emit(`let ${name}: ${typeName} = ${strip1(source)};`);
273
+ return name;
274
+ } else {
275
+ this.#context.set(exp, source);
276
+ return source;
277
+ }
278
+ }
279
+ };
111
280
  const gridOffsetY = 16384;
112
281
  function calculateGrid(gridSize) {
113
282
  let gridX = gridSize;
@@ -119,6 +288,91 @@ function calculateGrid(gridSize) {
119
288
  return [gridX, gridY];
120
289
  }
121
290
 
291
+ //#endregion
292
+ //#region src/backend/webgpu/nullaryKernel.ts
293
+ function uniformDtype(dtype) {
294
+ if (dtype === DType.Float16) return DType.Float32;
295
+ if (dtype === DType.Bool) return DType.Int32;
296
+ return dtype;
297
+ }
298
+ function replacementFor({ name, dtype, uniformDtype: uniformDtype$1 }) {
299
+ const value = AluExp.variable(uniformDtype$1, `uniforms.${name}`);
300
+ if (dtype === DType.Float16) return AluExp.cast(DType.Float16, value);
301
+ if (dtype === DType.Bool) return AluExp.cmpne(value, AluExp.i32(0));
302
+ return value;
303
+ }
304
+ function writeUniform(view, offset, dtype, value) {
305
+ switch (dtype) {
306
+ case DType.Float32:
307
+ view.setFloat32(offset, value, true);
308
+ break;
309
+ case DType.Int32:
310
+ view.setInt32(offset, value, true);
311
+ break;
312
+ case DType.Uint32:
313
+ view.setUint32(offset, value, true);
314
+ break;
315
+ default: throw new Error(`Unsupported dtype for constant uniform: ${dtype}`);
316
+ }
317
+ }
318
+ function liftConstants(exp) {
319
+ const uniforms = [];
320
+ const lifted = exp.rewrite((node) => {
321
+ if (node.op !== AluOp.Const || node.arg === 0) return;
322
+ const uniform = {
323
+ name: `c${uniforms.length}`,
324
+ dtype: node.dtype,
325
+ uniformDtype: uniformDtype(node.dtype),
326
+ value: node.arg
327
+ };
328
+ uniforms.push(uniform);
329
+ return replacementFor(uniform);
330
+ });
331
+ return [lifted, uniforms];
332
+ }
333
+ function uniformsData(uniforms) {
334
+ const data = new Uint8Array(uniforms.length * 4);
335
+ const view = new DataView(data.buffer);
336
+ uniforms.forEach((u, i) => writeUniform(view, i * 4, u.uniformDtype, u.value));
337
+ return data;
338
+ }
339
+ function nullaryKernelSource(device, kernel) {
340
+ if (kernel.nargs !== 0 || kernel.reduction) return null;
341
+ let exp = kernel.exp.substitute({ gidx: AluExp.special(DType.Int32, "gidx", kernel.size) }).simplify();
342
+ let uniforms = [];
343
+ [exp, uniforms] = liftConstants(exp);
344
+ const wb = new WgslBuilder();
345
+ wb.emitPreamble(device, [exp]);
346
+ if (uniforms.length > 0) wb.emit("struct Uniforms {", wb.pushIndent, ...uniforms.map((u) => `${u.name}: ${dtypeToWgsl(u.uniformDtype)},`), wb.popIndent, "}\n");
347
+ const resultTy = dtypeToWgsl(kernel.dtype, true);
348
+ wb.emit(`@group(0) @binding(0) var<storage, read_write> result : array<${resultTy}>;`);
349
+ if (uniforms.length > 0) wb.emit(`@group(1) @binding(0) var<uniform> uniforms: Uniforms;`);
350
+ const workgroupSize = findPow2(kernel.size, 256);
351
+ const gridSize = Math.ceil(kernel.size / workgroupSize);
352
+ const [gridX, gridY] = calculateGrid(gridSize);
353
+ wb.emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
354
+ if (gridY === 1) wb.emit(`if (id.x >= ${kernel.size}) { return; }`, "let gidx: i32 = i32(id.x);");
355
+ else {
356
+ const sizeX = gridX * workgroupSize;
357
+ wb.emit(`if (${sizeX} * id.y + id.x >= ${kernel.size}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
358
+ }
359
+ const gen = new WgslExpCodegen(wb, []);
360
+ gen.countReferences(exp);
361
+ let rhs = strip1(gen.run(exp));
362
+ if (resultTy !== dtypeToWgsl(exp.dtype)) rhs = `${resultTy}(${rhs})`;
363
+ wb.emit(`result[gidx] = ${rhs};`, wb.popIndent, "}");
364
+ return {
365
+ code: wb.toString(),
366
+ numInputs: 0,
367
+ numOutputs: 1,
368
+ hasUniform: uniforms.length > 0,
369
+ passes: [{
370
+ grid: [gridX, gridY],
371
+ uniform: uniforms.length > 0 ? uniformsData(uniforms) : void 0
372
+ }]
373
+ };
374
+ }
375
+
122
376
  //#endregion
123
377
  //#region src/backend/webgpu/reader.ts
124
378
  /**
@@ -1024,28 +1278,14 @@ var WebGPUBackend = class {
1024
1278
  * and y axes, to run the kernel.
1025
1279
  */
1026
1280
  function pipelineSource(device, kernel) {
1281
+ const nullaryKernel = nullaryKernelSource(device, kernel);
1282
+ if (nullaryKernel) return nullaryKernel;
1027
1283
  const tune = tuneWebgpu(kernel);
1028
1284
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
1029
1285
  const { nargs, reduction: re } = kernel;
1030
1286
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
1031
- const shader = [];
1032
- let indent = "";
1033
- const pushIndent = Symbol("pushIndent");
1034
- const popIndent = Symbol("popIndent");
1035
- const emit = (...lines) => {
1036
- for (const line of lines) if (line === pushIndent) indent += " ";
1037
- else if (line === popIndent) indent = indent.slice(0, -2);
1038
- else shader.push(line ? indent + line : line);
1039
- };
1040
- if (tune.exp.some((exp) => exp.dtype === DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === DType.Float16)) {
1041
- if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
1042
- emit("enable f16;");
1043
- }
1044
- emit(headerWgsl);
1045
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
1046
- if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
1047
- if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
1048
- emit("");
1287
+ const wb = new WgslBuilder();
1288
+ wb.emitPreamble(device, [tune.exp, tune.epilogue]);
1049
1289
  const usedArgs = Array.from({ length: nargs }, () => null);
1050
1290
  tune.exp.fold((exp) => {
1051
1291
  if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
@@ -1055,132 +1295,33 @@ function pipelineSource(device, kernel) {
1055
1295
  });
1056
1296
  for (let i = 0; i < nargs; i++) {
1057
1297
  const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
1058
- emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
1298
+ wb.emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
1059
1299
  }
1060
1300
  const resultTy = dtypeToWgsl(kernel.dtype, true);
1061
- emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
1301
+ wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
1062
1302
  const workgroupSize = findPow2(tune.threadCount, 256);
1063
1303
  const gridSize = Math.ceil(tune.threadCount / workgroupSize);
1064
1304
  const [gridX, gridY] = calculateGrid(gridSize);
1065
- emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", pushIndent);
1066
- if (gridY === 1) emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
1305
+ wb.emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
1306
+ if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
1067
1307
  else {
1068
1308
  const sizeX = gridX * workgroupSize;
1069
- emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
1309
+ wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
1070
1310
  }
1071
- let gensymCount = 0;
1072
- const gensym = () => `alu${gensymCount++}`;
1073
- const isGensym = (text) => text.match(/^alu[0-9]+$/);
1074
- if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
1075
- const references = /* @__PURE__ */ new Map();
1076
- const seen = /* @__PURE__ */ new Set();
1077
- const countReferences = (exp) => {
1078
- references.set(exp, (references.get(exp) ?? 0) + 1);
1079
- if (!seen.has(exp)) {
1080
- seen.add(exp);
1081
- for (const src of exp.src) countReferences(src);
1082
- }
1083
- };
1084
- const expContext = /* @__PURE__ */ new Map();
1085
- const gen = (exp) => {
1086
- if (expContext.has(exp)) return expContext.get(exp);
1087
- const { op, src, dtype, arg } = exp;
1088
- let source = "";
1089
- if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
1090
- const a = gen(src[0]);
1091
- const b = gen(src[1]);
1092
- if (op === AluOp.Add) if (dtype === DType.Bool) source = `(${a} || ${b})`;
1093
- else source = `(${a} + ${b})`;
1094
- else if (op === AluOp.Sub) source = `(${a} - ${b})`;
1095
- else if (op === AluOp.Mul) if (dtype === DType.Bool) source = `(${a} && ${b})`;
1096
- else source = `(${a} * ${b})`;
1097
- else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
1098
- else if (op === AluOp.Mod) source = `(${a} % ${b})`;
1099
- else if (op === AluOp.Min) if (dtype === DType.Bool) source = `(${a} && ${b})`;
1100
- else source = `min(${strip1(a)}, ${strip1(b)})`;
1101
- else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
1102
- else source = `max(${strip1(a)}, ${strip1(b)})`;
1103
- else if (op === AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
1104
- else if (arg === "or") source = `(${a} | ${b})`;
1105
- else source = dtype === DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
1106
- else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
1107
- else source = `(${a} >> ${b})`;
1108
- else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
1109
- else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
1110
- const x = isGensym(a) ? a : gensym();
1111
- if (x !== a) emit(`let ${x} = ${a};`);
1112
- source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
1113
- } else source = `(${a} != ${b})`;
1114
- } else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
1115
- const a = gen(src[0].src[0]);
1116
- source = `inverseSqrt(${a})`;
1117
- } else {
1118
- const a = gen(src[0]);
1119
- if (op === AluOp.Sin) source = `sin(${strip1(a)})`;
1120
- else if (op === AluOp.Cos) source = `cos(${strip1(a)})`;
1121
- else if (op === AluOp.Asin) source = `asin(${strip1(a)})`;
1122
- else if (op === AluOp.Atan) source = `atan(${strip1(a)})`;
1123
- else if (op === AluOp.Exp) source = `exp(${strip1(a)})`;
1124
- else if (op === AluOp.Log) source = `log(${strip1(a)})`;
1125
- else if (op === AluOp.Erf || op === AluOp.Erfc) {
1126
- const funcName = op === AluOp.Erf ? "erf" : "erfc";
1127
- if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${strip1(a)})))`;
1128
- else source = `${funcName}(${strip1(a)})`;
1129
- } else if (op === AluOp.Sqrt) source = `sqrt(${strip1(a)})`;
1130
- else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
1131
- else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
1132
- else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
1133
- else if (op === AluOp.Cast) {
1134
- const srcTy = dtypeToWgsl(src[0].dtype);
1135
- const dstTy = dtypeToWgsl(dtype);
1136
- if (isFloatDtype(src[0].dtype) && !(isFloatDtype(dtype) || dtype === DType.Bool)) {
1137
- const maxVal = maxValueWgsl(dtype);
1138
- const x = isGensym(a) ? a : gensym();
1139
- if (x !== a) emit(`let ${x}: ${srcTy} = ${strip1(a)};`);
1140
- source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
1141
- } else source = `${dstTy}(${strip1(a)})`;
1142
- } else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
1143
- }
1144
- else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
1145
- else if (op === AluOp.Threefry2x32) {
1146
- const x = gensym();
1147
- const [k0, k1, c0, c1] = src.map((x$1) => strip1(gen(x$1)));
1148
- emit(`let ${x} = threefry2x32(vec2(${k0}, ${k1}), vec2(${c0}, ${c1}));`);
1149
- if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
1150
- else if (arg === 0) source = `${x}.x`;
1151
- else if (arg === 1) source = `${x}.y`;
1152
- else throw new UnsupportedOpError(op, dtype, "webgpu", arg);
1153
- } else if (op === AluOp.Const) return constToWgsl(dtype, arg);
1154
- else if (op === AluOp.Special) return arg[0];
1155
- else if (op === AluOp.Variable) return arg;
1156
- else if (op === AluOp.GlobalIndex) {
1157
- source = `${args[arg[0]]}[${strip1(gen(src[0]))}]`;
1158
- if (dtype === DType.Bool) source = `(${source} != 0)`;
1159
- }
1160
- if (!source) throw new UnsupportedOpError(op, dtype, "webgpu", arg);
1161
- const typeName = dtypeToWgsl(dtype);
1162
- if ((references.get(exp) ?? 0) > 1) {
1163
- const name = gensym();
1164
- expContext.set(exp, name);
1165
- emit(`let ${name}: ${typeName} = ${strip1(source)};`);
1166
- return name;
1167
- } else {
1168
- expContext.set(exp, source);
1169
- return source;
1170
- }
1171
- };
1311
+ wb.emitPhonyAssignments(args);
1312
+ const gen = new WgslExpCodegen(wb, args);
1172
1313
  if (!re) {
1173
- countReferences(tune.exp);
1174
- let rhs = strip1(gen(tune.exp));
1314
+ gen.countReferences(tune.exp);
1315
+ let rhs = strip1(gen.run(tune.exp));
1175
1316
  if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
1176
- emit(`result[gidx] = ${rhs};`);
1317
+ wb.emit(`result[gidx] = ${rhs};`);
1177
1318
  } else {
1178
1319
  if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
1179
1320
  const unroll = tune.size.unroll ?? 1;
1180
1321
  const upcast = tune.size.upcast ?? 1;
1181
1322
  const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
1182
- for (let i = 0; i < upcast; i++) emit(`var ${acc[i]}: ${dtypeToWgsl(re.dtype)} = ${constToWgsl(re.dtype, re.identity)};`);
1183
- emit(`for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`, pushIndent);
1323
+ for (let i = 0; i < upcast; i++) wb.emit(`var ${acc[i]}: ${dtypeToWgsl(re.dtype)} = ${constToWgsl(re.dtype, re.identity)};`);
1324
+ wb.emit(`for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`, wb.pushIndent);
1184
1325
  const exps = [];
1185
1326
  const cache = /* @__PURE__ */ new Map();
1186
1327
  for (let up = 0; up < upcast; up++) {
@@ -1191,10 +1332,10 @@ function pipelineSource(device, kernel) {
1191
1332
  unroll: AluExp.i32(un)
1192
1333
  });
1193
1334
  exps[up].push(exp.simplify(cache));
1194
- countReferences(exps[up][un]);
1335
+ gen.countReferences(exps[up][un]);
1195
1336
  }
1196
1337
  }
1197
- const items = exps.map((ar) => ar.map(gen).map(strip1));
1338
+ const items = exps.map((ar) => ar.map((x) => gen.run(x)).map(strip1));
1198
1339
  for (let i = 0; i < upcast; i++) {
1199
1340
  let rhs = items[i][0];
1200
1341
  for (let j = 1; j < unroll; j++) if (re.op === AluOp.Add) rhs = `${rhs} + ${items[i][j]}`;
@@ -1202,40 +1343,38 @@ function pipelineSource(device, kernel) {
1202
1343
  else if (re.op === AluOp.Min) rhs = re.dtype === DType.Bool ? `(${rhs} && ${items[i][j]})` : `min(${rhs}, ${items[i][j]})`;
1203
1344
  else if (re.op === AluOp.Max) rhs = re.dtype === DType.Bool ? `(${rhs} || ${items[i][j]})` : `max(${rhs}, ${items[i][j]})`;
1204
1345
  else throw new Error(`Unsupported reduction op: ${re.op}`);
1205
- if (re.op === AluOp.Add) emit(`${acc[i]} += ${rhs};`);
1206
- else if (re.op === AluOp.Mul) emit(`${acc[i]} *= ${rhs};`);
1207
- else if (re.op === AluOp.Min) if (re.dtype === DType.Bool) emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
1208
- else emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
1209
- else if (re.op === AluOp.Max) if (re.dtype === DType.Bool) emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
1210
- else emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
1346
+ if (re.op === AluOp.Add) wb.emit(`${acc[i]} += ${rhs};`);
1347
+ else if (re.op === AluOp.Mul) wb.emit(`${acc[i]} *= ${rhs};`);
1348
+ else if (re.op === AluOp.Min) if (re.dtype === DType.Bool) wb.emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
1349
+ else wb.emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
1350
+ else if (re.op === AluOp.Max) if (re.dtype === DType.Bool) wb.emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
1351
+ else wb.emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
1211
1352
  else throw new Error(`Unsupported reduction op: ${re.op}`);
1212
1353
  }
1213
- emit(popIndent, "}");
1214
- expContext.clear();
1215
- references.clear();
1216
- seen.clear();
1354
+ wb.emit(wb.popIndent, "}");
1355
+ gen.reset();
1217
1356
  const outputIdxExps = [];
1218
1357
  const fusionExps = [];
1219
1358
  for (let i = 0; i < upcast; i++) {
1220
1359
  const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
1221
1360
  outputIdxExps.push(exp.simplify(cache));
1222
- countReferences(outputIdxExps[i]);
1361
+ gen.countReferences(outputIdxExps[i]);
1223
1362
  fusionExps.push(tune.epilogue.substitute({
1224
1363
  acc: AluExp.variable(re.dtype, acc[i]),
1225
1364
  upcast: AluExp.i32(i)
1226
1365
  }).simplify(cache));
1227
- countReferences(fusionExps[i]);
1366
+ gen.countReferences(fusionExps[i]);
1228
1367
  }
1229
1368
  for (let i = 0; i < upcast; i++) {
1230
- const index = strip1(gen(outputIdxExps[i]));
1231
- let rhs = strip1(gen(fusionExps[i]));
1369
+ const index = strip1(gen.run(outputIdxExps[i]));
1370
+ let rhs = strip1(gen.run(fusionExps[i]));
1232
1371
  if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
1233
- emit(`result[${index}] = ${rhs};`);
1372
+ wb.emit(`result[${index}] = ${rhs};`);
1234
1373
  }
1235
1374
  }
1236
- emit(popIndent, "}");
1375
+ wb.emit(wb.popIndent, "}");
1237
1376
  return {
1238
- code: shader.join("\n"),
1377
+ code: wb.toString(),
1239
1378
  numInputs: nargs,
1240
1379
  numOutputs: 1,
1241
1380
  hasUniform: false,
@@ -1279,11 +1418,17 @@ function pipelineSubmit(device, exe, inputs, outputs) {
1279
1418
  }
1280
1419
  for (let i = 0; i < filteredPasses.length; i++) {
1281
1420
  const { grid } = filteredPasses[i];
1282
- const passEncoder = commandEncoder.beginComputePass({ timestampWrites: slot ? {
1283
- querySet: slot.batch.querySet,
1284
- beginningOfPassWriteIndex: i === 0 ? slot.beginIndex : void 0,
1285
- endOfPassWriteIndex: i === filteredPasses.length - 1 ? slot.endIndex : void 0
1286
- } : void 0 });
1421
+ let timestampWrites;
1422
+ if (slot) {
1423
+ const isFirst = i === 0;
1424
+ const isLast = i === filteredPasses.length - 1;
1425
+ if (isFirst || isLast) timestampWrites = {
1426
+ querySet: slot.batch.querySet,
1427
+ ...isFirst ? { beginningOfPassWriteIndex: slot.beginIndex } : {},
1428
+ ...isLast ? { endOfPassWriteIndex: slot.endIndex } : {}
1429
+ };
1430
+ }
1431
+ const passEncoder = commandEncoder.beginComputePass({ timestampWrites });
1287
1432
  passEncoder.setPipeline(pipeline);
1288
1433
  passEncoder.setBindGroup(0, bindGroup);
1289
1434
  if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);