@jax-js/jax 0.1.11 → 0.1.13
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +4 -1
- package/dist/{backend-DZvR7mZV.js → backend-DLEk-B3V.js} +5 -3
- package/dist/{backend-DlYlOYqN.cjs → backend-DMyuoWi2.cjs} +5 -3
- package/dist/index.cjs +233 -18
- package/dist/index.d.cts +106 -1
- package/dist/index.d.ts +106 -1
- package/dist/index.js +233 -18
- package/dist/{webgl-D8-14NzA.js → webgl-NsFtyIts.js} +1 -1
- package/dist/{webgl-Ovaaa-Qx.cjs → webgl-pbfUGDA6.cjs} +1 -1
- package/dist/{webgpu-uU9nnttc.cjs → webgpu-DDGCYtHa.cjs} +338 -176
- package/dist/{webgpu-Dg8FpYrH.js → webgpu-NkF1TZ0t.js} +338 -176
- package/package.json +1 -1
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-DMyuoWi2.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -72,6 +72,45 @@ const headerWgsl = String.raw`
|
|
|
72
72
|
fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }
|
|
73
73
|
fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }
|
|
74
74
|
`.trim();
|
|
75
|
+
/**
 * Builder class for simple program generation with WGSL.
 *
 * Output is accumulated line by line in `lines`. The `pushIndent` and
 * `popIndent` symbols are per-instance sentinels: pass them to `emit()` in
 * between text lines to open or close one indentation level.
 */
var WgslBuilder = class {
  pushIndent = Symbol("pushIndent");
  popIndent = Symbol("popIndent");
  lines = [];
  #indent = "";
  /**
   * Text of a single indentation level. Push and pop both use this value so
   * that they always add and remove exactly the same number of characters.
   */
  #indentUnit = "  ";

  /**
   * Append lines to the program, honoring the current indentation.
   *
   * @param {...(string|symbol)} lines Text lines, or the `pushIndent` /
   *   `popIndent` sentinels. Empty strings are emitted as blank lines with no
   *   indentation. Only the start of each string is indented, so a string
   *   that itself contains newlines is indented on its first line only.
   */
  emit(...lines) {
    for (const line of lines) {
      if (line === this.pushIndent) {
        this.#indent += this.#indentUnit;
      } else if (line === this.popIndent) {
        this.#indent = this.#indent.slice(0, -this.#indentUnit.length);
      } else {
        this.lines.push(line ? this.#indent + line : "");
      }
    }
  }

  /**
   * Emit the shared shader header plus any helper sources the given
   * expressions need.
   *
   * @param device Object whose `features.has("shader-f16")` is consulted when
   *   any expression contains a Float16 node.
   * @param exps Array of expressions; `null`/`undefined` entries are skipped.
   * @throws {Error} If a Float16 node is present and the device lacks the
   *   `shader-f16` feature.
   */
  emitPreamble(device, exps) {
    let hasFloat16 = false;
    let distinctOps = /* @__PURE__ */ new Map();
    for (const exp of exps) {
      if (exp == null) continue;
      hasFloat16 ||= exp.some((e) => e.dtype === require_backend.DType.Float16);
      distinctOps = require_backend.mapSetUnion(distinctOps, exp.distinctOps());
    }
    if (hasFloat16) {
      if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
      this.emit("enable f16;");
    }
    this.emit(headerWgsl);
    // Helper sources are only included when an op that calls them is present.
    if (distinctOps.has(require_backend.AluOp.Threefry2x32)) this.emit(threefrySrc);
    if (distinctOps.has(require_backend.AluOp.Erf) || distinctOps.has(require_backend.AluOp.Erfc)) this.emit(erfSrc);
    this.emit("");
  }

  /**
   * Insert phony assignments, in case some inputs are not in use.
   * <https://github.com/gpuweb/gpuweb/discussions/4582#discussioncomment-9146686>
   *
   * @param {string[]} args Names to reference; nothing is emitted when empty.
   */
  emitPhonyAssignments(args) {
    if (args.length > 0) this.emit(args.map((arg) => `_ = &${arg};`).join(" "));
  }

  /** @returns {string} All emitted lines joined with newlines. */
  toString() {
    return this.lines.join("\n");
  }
};
|
|
75
114
|
function dtypeToWgsl(dtype, storage = false) {
|
|
76
115
|
switch (dtype) {
|
|
77
116
|
case require_backend.DType.Bool: return storage ? "i32" : "bool";
|
|
@@ -108,6 +147,136 @@ function constToWgsl(dtype, value) {
|
|
|
108
147
|
}
|
|
109
148
|
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
110
149
|
}
|
|
150
|
+
/**
 * Codegen for WebGPU expressions, linearizing AluOp into a kernel.
 *
 * Usage: call `countReferences()` on every root expression first, then
 * `run()` on each of them. `run()` returns a WGSL expression string and may
 * emit `let` statements into the builder as a side effect.
 */
var WgslExpCodegen = class {
  #gensymCount = 0;
  #references = new Map();
  #seen = new Set();
  #context = new Map();

  /**
   * @param wb Builder that receives any emitted `let` statements.
   * @param args Names of the input bindings, indexed by `GlobalIndex` arg[0].
   */
  constructor(wb, args) {
    this.wb = wb;
    this.args = args;
  }

  /** Produce a fresh temporary name of the form `alu<N>`. */
  #gensym() {
    const id = this.#gensymCount;
    this.#gensymCount += 1;
    return `alu${id}`;
  }

  /** Whether `text` is exactly a name produced by `#gensym()`. */
  #isGensym(text) {
    return /^alu[0-9]+$/.test(text);
  }

  /**
   * Count references for an expression.
   *
   * Walks the expression once per distinct node and records how many times
   * each node is referenced. `run()` later stores nodes referenced more than
   * once in a temporary so they are not recomputed.
   */
  countReferences(exp) {
    const previous = this.#references.get(exp) ?? 0;
    this.#references.set(exp, previous + 1);
    if (this.#seen.has(exp)) return;
    this.#seen.add(exp);
    for (const child of exp.src) this.countReferences(child);
  }

  /** Forget reference counts and cached results; the gensym counter is kept. */
  reset() {
    this.#references.clear();
    this.#seen.clear();
    this.#context.clear();
  }

  /** Source for ops in the Binary or Compare group; "" if the op is unknown. */
  #binarySource({ op, src, dtype, arg }) {
    const { AluOp, DType } = require_backend;
    // Both operands are generated up front, left then right.
    const a = this.run(src[0]);
    const b = this.run(src[1]);
    switch (op) {
      case AluOp.Add:
        return dtype === DType.Bool ? `(${a} || ${b})` : `(${a} + ${b})`;
      case AluOp.Sub:
        return `(${a} - ${b})`;
      case AluOp.Mul:
        return dtype === DType.Bool ? `(${a} && ${b})` : `(${a} * ${b})`;
      case AluOp.Idiv:
        return require_backend.isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
      case AluOp.Mod:
        return `(${a} % ${b})`;
      case AluOp.Min:
        if (dtype === DType.Bool) return `(${a} && ${b})`;
        return `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
      case AluOp.Max:
        if (dtype === DType.Bool) return `(${a} || ${b})`;
        return `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
      case AluOp.BitCombine:
        if (arg === "and") return `(${a} & ${b})`;
        if (arg === "or") return `(${a} | ${b})`;
        return dtype === DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
      case AluOp.BitShift:
        return arg === "shl" ? `(${a} << ${b})` : `(${a} >> ${b})`;
      case AluOp.Cmplt:
        return `(${a} < ${b})`;
      case AluOp.Cmpne: {
        if (!require_backend.isFloatDtype(src[0].dtype)) return `(${a} != ${b})`;
        // The left operand appears three times below, so bind it to a name
        // unless it already is a generated temporary.
        const x = this.#isGensym(a) ? a : this.#gensym();
        if (x !== a) this.wb.emit(`let ${x} = ${a};`);
        return `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
      }
      default:
        return "";
    }
  }

  /** Source for ops in the Unary group; "" if the op is unknown. */
  #unarySource({ op, src, dtype }) {
    const { AluOp, DType } = require_backend;
    // reciprocal(sqrt(x)) maps onto the single builtin, skipping the sqrt node.
    if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
      return `inverseSqrt(${this.run(src[0].src[0])})`;
    }
    const a = this.run(src[0]);
    const call = (fn) => `${fn}(${require_backend.strip1(a)})`;
    switch (op) {
      case AluOp.Sin:
        return call("sin");
      case AluOp.Cos:
        return call("cos");
      case AluOp.Asin:
        return call("asin");
      case AluOp.Atan:
        return call("atan");
      case AluOp.Exp:
        return call("exp");
      case AluOp.Log:
        return call("log");
      case AluOp.Erf:
      case AluOp.Erfc: {
        const funcName = op === AluOp.Erf ? "erf" : "erfc";
        // Non-f32 inputs are routed through f32 and converted back.
        if (dtype !== DType.Float32) {
          return `${dtypeToWgsl(dtype)}(${funcName}(f32(${require_backend.strip1(a)})))`;
        }
        return call(funcName);
      }
      case AluOp.Sqrt:
        return call("sqrt");
      case AluOp.Reciprocal:
        return `(1.0 / ${a})`;
      case AluOp.Floor:
        return call("floor");
      case AluOp.Ceil:
        return call("ceil");
      case AluOp.Cast: {
        const srcTy = dtypeToWgsl(src[0].dtype);
        const dstTy = dtypeToWgsl(dtype);
        const fromFloat = require_backend.isFloatDtype(src[0].dtype);
        const toFloatOrBool = require_backend.isFloatDtype(dtype) || dtype === DType.Bool;
        if (!fromFloat || toFloatOrBool) return `${dstTy}(${require_backend.strip1(a)})`;
        // Float to integer: values at or above the destination's maximum
        // select the maximum instead of the plain conversion.
        const maxVal = maxValueWgsl(dtype);
        const x = this.#isGensym(a) ? a : this.#gensym();
        if (x !== a) this.wb.emit(`let ${x}: ${srcTy} = ${require_backend.strip1(a)};`);
        return `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
      }
      case AluOp.Bitcast:
        return `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
      default:
        return "";
    }
  }

  /**
   * Generate code for an expression.
   *
   * Recurses into sources. Nodes whose reference count (from
   * `countReferences()`) exceeds one are bound to a `let` temporary emitted
   * into the builder, and the temporary's name is returned instead of the
   * inline source, so generation can produce multiple lines.
   *
   * @returns {string} WGSL source for the expression.
   * @throws UnsupportedOpError when no source could be produced for the op.
   */
  run(exp) {
    if (this.#context.has(exp)) return this.#context.get(exp);
    const { op, src, dtype, arg } = exp;
    const { AluGroup, AluOp } = require_backend;
    let source = "";
    if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
      source = this.#binarySource(exp);
    } else if (AluGroup.Unary.has(op)) {
      source = this.#unarySource(exp);
    } else {
      switch (op) {
        case AluOp.Where: {
          // Generated in select() argument order: false value, true value, condition.
          const whenFalse = require_backend.strip1(this.run(src[2]));
          const whenTrue = require_backend.strip1(this.run(src[1]));
          const cond = require_backend.strip1(this.run(src[0]));
          source = `select(${whenFalse}, ${whenTrue}, ${cond})`;
          break;
        }
        case AluOp.Threefry2x32: {
          // The temporary name is reserved before the operands are generated.
          const x = this.#gensym();
          const [k0, k1, c0, c1] = src.map((operand) => require_backend.strip1(this.run(operand)));
          this.wb.emit(`let ${x} = threefry2x32(vec2(${k0}, ${k1}), vec2(${c0}, ${c1}));`);
          if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
          else if (arg === 0) source = `${x}.x`;
          else if (arg === 1) source = `${x}.y`;
          else throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
          break;
        }
        // Leaves are returned directly: never cached and never bound to a temporary.
        case AluOp.Const:
          return constToWgsl(dtype, arg);
        case AluOp.Special:
          return arg[0];
        case AluOp.Variable:
          return arg;
        case AluOp.GlobalIndex: {
          source = `${this.args[arg[0]]}[${require_backend.strip1(this.run(src[0]))}]`;
          // Loaded value is compared against 0 to produce a bool.
          if (dtype === require_backend.DType.Bool) source = `(${source} != 0)`;
          break;
        }
        default:
          break;
      }
    }
    if (!source) throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
    const typeName = dtypeToWgsl(dtype);
    const shared = (this.#references.get(exp) ?? 0) > 1;
    if (!shared) {
      this.#context.set(exp, source);
      return source;
    }
    const name = this.#gensym();
    this.#context.set(exp, name);
    this.wb.emit(`let ${name}: ${typeName} = ${require_backend.strip1(source)};`);
    return name;
  }
};
|
|
111
280
|
const gridOffsetY = 16384;
|
|
112
281
|
function calculateGrid(gridSize) {
|
|
113
282
|
let gridX = gridSize;
|
|
@@ -119,6 +288,91 @@ function calculateGrid(gridSize) {
|
|
|
119
288
|
return [gridX, gridY];
|
|
120
289
|
}
|
|
121
290
|
|
|
291
|
+
//#endregion
|
|
292
|
+
//#region src/backend/webgpu/nullaryKernel.ts
|
|
293
|
+
/**
 * Pick the dtype used to store a lifted constant in the uniform struct.
 *
 * Float16 is stored as Float32 and Bool as Int32; every other dtype is
 * returned unchanged.
 */
function uniformDtype(dtype) {
  const { DType } = require_backend;
  switch (dtype) {
    case DType.Float16:
      return DType.Float32;
    case DType.Bool:
      return DType.Int32;
    default:
      return dtype;
  }
}
|
|
298
|
+
/**
 * Build the expression that replaces a lifted constant.
 *
 * The replacement reads `uniforms.<name>` using the uniform's storage dtype,
 * then converts back to the constant's own dtype: a cast for Float16, a
 * `!= 0` comparison for Bool, and the bare variable otherwise.
 */
function replacementFor({ name, dtype, uniformDtype: storageDtype }) {
  const { AluExp, DType } = require_backend;
  const value = AluExp.variable(storageDtype, `uniforms.${name}`);
  switch (dtype) {
    case DType.Float16:
      return AluExp.cast(DType.Float16, value);
    case DType.Bool:
      return AluExp.cmpne(value, AluExp.i32(0));
    default:
      return value;
  }
}
|
|
304
|
+
/**
 * Write one uniform value into a DataView, little-endian.
 *
 * @param {DataView} view Destination view.
 * @param {number} offset Byte offset within the view.
 * @param dtype Storage dtype; Float32, Int32 and Uint32 are handled.
 * @param {number} value Value to store.
 * @throws {Error} For any other dtype.
 */
function writeUniform(view, offset, dtype, value) {
  const { DType } = require_backend;
  const littleEndian = true;
  if (dtype === DType.Float32) {
    view.setFloat32(offset, value, littleEndian);
  } else if (dtype === DType.Int32) {
    view.setInt32(offset, value, littleEndian);
  } else if (dtype === DType.Uint32) {
    view.setUint32(offset, value, littleEndian);
  } else {
    throw new Error(`Unsupported dtype for constant uniform: ${dtype}`);
  }
}
|
|
318
|
+
/**
 * Replace constants in an expression with reads from a uniform struct.
 *
 * Every `Const` node whose `arg` is not `0` is recorded as a uniform named
 * `c<index>` (in the order the rewrite callback visits it) and swapped for
 * the expression built by `replacementFor()`. Other nodes are left alone.
 *
 * @returns A pair `[rewrittenExp, uniforms]`.
 */
function liftConstants(exp) {
  const uniforms = [];
  const liftNode = (node) => {
    const isConst = node.op === require_backend.AluOp.Const;
    // Returning undefined leaves the node to the rewriter's default handling.
    if (!isConst || node.arg === 0) return undefined;
    const { dtype, arg } = node;
    const uniform = {
      name: `c${uniforms.length}`,
      dtype,
      uniformDtype: uniformDtype(dtype),
      value: arg,
    };
    uniforms.push(uniform);
    return replacementFor(uniform);
  };
  const lifted = exp.rewrite(liftNode);
  return [lifted, uniforms];
}
|
|
333
|
+
/**
 * Serialize uniforms into a byte array, one 4-byte slot per uniform in order.
 *
 * @returns {Uint8Array} Buffer of `uniforms.length * 4` bytes.
 */
function uniformsData(uniforms) {
  const bytes = new Uint8Array(uniforms.length * 4);
  const view = new DataView(bytes.buffer);
  for (const [index, uniform] of uniforms.entries()) {
    writeUniform(view, index * 4, uniform.uniformDtype, uniform.value);
  }
  return bytes;
}
|
|
339
|
+
/**
 * Generate a shader for a kernel with no inputs and no reduction.
 *
 * Constants in the kernel expression are lifted into a uniform struct bound
 * at group 1, binding 0; the output buffer is bound at group 0, binding 0.
 *
 * @returns `null` when the kernel has inputs or a reduction; otherwise an
 *   object with the shader `code`, input/output counts, `hasUniform`, and a
 *   single pass carrying the dispatch grid and the serialized uniform bytes
 *   (or `undefined` when nothing was lifted).
 */
function nullaryKernelSource(device, kernel) {
  if (kernel.nargs !== 0 || kernel.reduction) return null;

  const simplified = kernel.exp
    .substitute({ gidx: require_backend.AluExp.special(require_backend.DType.Int32, "gidx", kernel.size) })
    .simplify();
  const [exp, uniforms] = liftConstants(simplified);
  const hasUniform = uniforms.length > 0;

  const wb = new WgslBuilder();
  wb.emitPreamble(device, [exp]);

  // Declarations: optional uniform struct, the output buffer, then the uniform binding.
  if (hasUniform) {
    const fields = uniforms.map((u) => `${u.name}: ${dtypeToWgsl(u.uniformDtype)},`);
    wb.emit("struct Uniforms {", wb.pushIndent, ...fields, wb.popIndent, "}\n");
  }
  const resultTy = dtypeToWgsl(kernel.dtype, true);
  wb.emit(`@group(0) @binding(0) var<storage, read_write> result : array<${resultTy}>;`);
  if (hasUniform) {
    wb.emit(`@group(1) @binding(0) var<uniform> uniforms: Uniforms;`);
  }

  // One invocation per output element, laid out over an X/Y grid.
  const workgroupSize = require_backend.findPow2(kernel.size, 256);
  const [gridX, gridY] = calculateGrid(Math.ceil(kernel.size / workgroupSize));
  wb.emit(
    "",
    `@compute @workgroup_size(${workgroupSize})`,
    "fn main(@builtin(global_invocation_id) id : vec3<u32>) {",
    wb.pushIndent,
  );
  if (gridY !== 1) {
    // With more than one grid row the flat index combines both axes.
    const sizeX = gridX * workgroupSize;
    wb.emit(
      `if (${sizeX} * id.y + id.x >= ${kernel.size}) { return; }`,
      `let gidx: i32 = i32(${sizeX} * id.y + id.x);`,
    );
  } else {
    wb.emit(`if (id.x >= ${kernel.size}) { return; }`, "let gidx: i32 = i32(id.x);");
  }

  const gen = new WgslExpCodegen(wb, []);
  gen.countReferences(exp);
  let rhs = require_backend.strip1(gen.run(exp));
  // Convert when the buffer's storage type differs from the expression's type.
  if (resultTy !== dtypeToWgsl(exp.dtype)) rhs = `${resultTy}(${rhs})`;
  wb.emit(`result[gidx] = ${rhs};`, wb.popIndent, "}");

  const code = wb.toString();
  const uniform = hasUniform ? uniformsData(uniforms) : void 0;
  return {
    code,
    numInputs: 0,
    numOutputs: 1,
    hasUniform,
    passes: [{ grid: [gridX, gridY], uniform }],
  };
}
|
|
375
|
+
|
|
122
376
|
//#endregion
|
|
123
377
|
//#region src/backend/webgpu/reader.ts
|
|
124
378
|
/**
|
|
@@ -845,6 +1099,8 @@ function flushTracingBatch(device, batch) {
|
|
|
845
1099
|
|
|
846
1100
|
//#endregion
|
|
847
1101
|
//#region src/backend/webgpu.ts
|
|
1102
|
+
/** Largest buffer size (64 MiB) the buffer pool will handle; larger buffers are created fresh and destroyed on release. */
const MAX_REUSABLE_BUFFER_BYTES = 64 * 1024 * 1024;
|
|
1103
|
+
/** Cap on pooled buffers kept per distinct size; buffers released to a bucket already at this length are destroyed. */
const MAX_REUSABLE_BUFFERS_PER_SIZE = 64;
|
|
848
1104
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
849
1105
|
var WebGPUBackend = class {
|
|
850
1106
|
type = "webgpu";
|
|
@@ -855,6 +1111,7 @@ var WebGPUBackend = class {
|
|
|
855
1111
|
nextSlot;
|
|
856
1112
|
#cachedShaderMap = /* @__PURE__ */ new Map();
|
|
857
1113
|
#reusableZsb;
|
|
1114
|
+
#bufferPool = /* @__PURE__ */ new Map();
|
|
858
1115
|
constructor(device) {
|
|
859
1116
|
this.device = device;
|
|
860
1117
|
if (require_backend.DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
|
|
@@ -869,31 +1126,22 @@ var WebGPUBackend = class {
|
|
|
869
1126
|
});
|
|
870
1127
|
}
|
|
871
1128
|
malloc(size, initialData) {
|
|
872
|
-
|
|
873
|
-
const
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
884
|
-
else {
|
|
885
|
-
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
886
|
-
this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
887
|
-
const remainder = new Uint8Array(4);
|
|
888
|
-
remainder.set(initialData.subarray(aligned));
|
|
889
|
-
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
890
|
-
}
|
|
891
|
-
}
|
|
892
|
-
} else buffer = this.#createBuffer(paddedSize);
|
|
1129
|
+
if (initialData && initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
1130
|
+
const allocatedSize = Math.ceil(size / 4) * 4 || 4;
|
|
1131
|
+
const buffer = size === 0 ? this.#reusableZsb : this.#acquireBuffer(allocatedSize);
|
|
1132
|
+
if (initialData && size > 0) if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
1133
|
+
else {
|
|
1134
|
+
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
1135
|
+
if (aligned > 0) this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
1136
|
+
const remainder = new Uint8Array(4);
|
|
1137
|
+
remainder.set(initialData.subarray(aligned));
|
|
1138
|
+
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
1139
|
+
}
|
|
893
1140
|
const slot = this.nextSlot++;
|
|
894
1141
|
this.buffers.set(slot, {
|
|
895
1142
|
buffer,
|
|
896
1143
|
size,
|
|
1144
|
+
allocatedSize,
|
|
897
1145
|
ref: 1
|
|
898
1146
|
});
|
|
899
1147
|
return slot;
|
|
@@ -909,7 +1157,7 @@ var WebGPUBackend = class {
|
|
|
909
1157
|
buffer.ref--;
|
|
910
1158
|
if (buffer.ref === 0) {
|
|
911
1159
|
this.buffers.delete(slot);
|
|
912
|
-
if (buffer.buffer !== this.#reusableZsb) buffer.buffer.
|
|
1160
|
+
if (buffer.buffer !== this.#reusableZsb) this.#releaseBuffer(buffer.buffer, buffer.allocatedSize);
|
|
913
1161
|
}
|
|
914
1162
|
}
|
|
915
1163
|
async read(slot, start, count) {
|
|
@@ -997,6 +1245,29 @@ var WebGPUBackend = class {
|
|
|
997
1245
|
size: buffer.size
|
|
998
1246
|
};
|
|
999
1247
|
}
|
|
1248
|
+
#acquireBuffer(size) {
|
|
1249
|
+
if (size > MAX_REUSABLE_BUFFER_BYTES) return this.#createBuffer(size);
|
|
1250
|
+
const bucket = this.#bufferPool.get(size);
|
|
1251
|
+
const buffer = bucket?.pop();
|
|
1252
|
+
if (bucket && bucket.length === 0) this.#bufferPool.delete(size);
|
|
1253
|
+
return buffer ?? this.#createBuffer(size);
|
|
1254
|
+
}
|
|
1255
|
+
#releaseBuffer(buffer, size) {
|
|
1256
|
+
if (size > MAX_REUSABLE_BUFFER_BYTES) {
|
|
1257
|
+
buffer.destroy();
|
|
1258
|
+
return;
|
|
1259
|
+
}
|
|
1260
|
+
const bucket = this.#bufferPool.get(size);
|
|
1261
|
+
if (!bucket) {
|
|
1262
|
+
this.#bufferPool.set(size, [buffer]);
|
|
1263
|
+
return;
|
|
1264
|
+
}
|
|
1265
|
+
if (bucket.length >= MAX_REUSABLE_BUFFERS_PER_SIZE) {
|
|
1266
|
+
buffer.destroy();
|
|
1267
|
+
return;
|
|
1268
|
+
}
|
|
1269
|
+
bucket.push(buffer);
|
|
1270
|
+
}
|
|
1000
1271
|
/**
|
|
1001
1272
|
* Create a GPU buffer.
|
|
1002
1273
|
*
|
|
@@ -1024,28 +1295,14 @@ var WebGPUBackend = class {
|
|
|
1024
1295
|
* and y axes, to run the kernel.
|
|
1025
1296
|
*/
|
|
1026
1297
|
function pipelineSource(device, kernel) {
|
|
1298
|
+
const nullaryKernel = nullaryKernelSource(device, kernel);
|
|
1299
|
+
if (nullaryKernel) return nullaryKernel;
|
|
1027
1300
|
const tune = require_backend.tuneWebgpu(kernel);
|
|
1028
1301
|
if (require_backend.DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
1029
1302
|
const { nargs, reduction: re } = kernel;
|
|
1030
1303
|
const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
|
|
1031
|
-
const
|
|
1032
|
-
|
|
1033
|
-
const pushIndent = Symbol("pushIndent");
|
|
1034
|
-
const popIndent = Symbol("popIndent");
|
|
1035
|
-
const emit = (...lines) => {
|
|
1036
|
-
for (const line of lines) if (line === pushIndent) indent += " ";
|
|
1037
|
-
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
1038
|
-
else shader.push(line ? indent + line : line);
|
|
1039
|
-
};
|
|
1040
|
-
if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === require_backend.DType.Float16)) {
|
|
1041
|
-
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
1042
|
-
emit("enable f16;");
|
|
1043
|
-
}
|
|
1044
|
-
emit(headerWgsl);
|
|
1045
|
-
const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
1046
|
-
if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
|
|
1047
|
-
if (distinctOps.has(require_backend.AluOp.Erf) || distinctOps.has(require_backend.AluOp.Erfc)) emit(erfSrc);
|
|
1048
|
-
emit("");
|
|
1304
|
+
const wb = new WgslBuilder();
|
|
1305
|
+
wb.emitPreamble(device, [tune.exp, tune.epilogue]);
|
|
1049
1306
|
const usedArgs = Array.from({ length: nargs }, () => null);
|
|
1050
1307
|
tune.exp.fold((exp) => {
|
|
1051
1308
|
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
@@ -1055,132 +1312,33 @@ function pipelineSource(device, kernel) {
|
|
|
1055
1312
|
});
|
|
1056
1313
|
for (let i = 0; i < nargs; i++) {
|
|
1057
1314
|
const ty = dtypeToWgsl(usedArgs[i] ?? require_backend.DType.Float32, true);
|
|
1058
|
-
emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
|
|
1315
|
+
wb.emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
|
|
1059
1316
|
}
|
|
1060
1317
|
const resultTy = dtypeToWgsl(kernel.dtype, true);
|
|
1061
|
-
emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
|
|
1318
|
+
wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
|
|
1062
1319
|
const workgroupSize = require_backend.findPow2(tune.threadCount, 256);
|
|
1063
1320
|
const gridSize = Math.ceil(tune.threadCount / workgroupSize);
|
|
1064
1321
|
const [gridX, gridY] = calculateGrid(gridSize);
|
|
1065
|
-
emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", pushIndent);
|
|
1066
|
-
if (gridY === 1) emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
|
|
1322
|
+
wb.emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
|
|
1323
|
+
if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
|
|
1067
1324
|
else {
|
|
1068
1325
|
const sizeX = gridX * workgroupSize;
|
|
1069
|
-
emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
|
|
1326
|
+
wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
|
|
1070
1327
|
}
|
|
1071
|
-
|
|
1072
|
-
const
|
|
1073
|
-
const isGensym = (text) => text.match(/^alu[0-9]+$/);
|
|
1074
|
-
if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
|
|
1075
|
-
const references = /* @__PURE__ */ new Map();
|
|
1076
|
-
const seen = /* @__PURE__ */ new Set();
|
|
1077
|
-
const countReferences = (exp) => {
|
|
1078
|
-
references.set(exp, (references.get(exp) ?? 0) + 1);
|
|
1079
|
-
if (!seen.has(exp)) {
|
|
1080
|
-
seen.add(exp);
|
|
1081
|
-
for (const src of exp.src) countReferences(src);
|
|
1082
|
-
}
|
|
1083
|
-
};
|
|
1084
|
-
const expContext = /* @__PURE__ */ new Map();
|
|
1085
|
-
const gen = (exp) => {
|
|
1086
|
-
if (expContext.has(exp)) return expContext.get(exp);
|
|
1087
|
-
const { op, src, dtype, arg } = exp;
|
|
1088
|
-
let source = "";
|
|
1089
|
-
if (require_backend.AluGroup.Binary.has(op) || require_backend.AluGroup.Compare.has(op)) {
|
|
1090
|
-
const a = gen(src[0]);
|
|
1091
|
-
const b = gen(src[1]);
|
|
1092
|
-
if (op === require_backend.AluOp.Add) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
|
|
1093
|
-
else source = `(${a} + ${b})`;
|
|
1094
|
-
else if (op === require_backend.AluOp.Sub) source = `(${a} - ${b})`;
|
|
1095
|
-
else if (op === require_backend.AluOp.Mul) if (dtype === require_backend.DType.Bool) source = `(${a} && ${b})`;
|
|
1096
|
-
else source = `(${a} * ${b})`;
|
|
1097
|
-
else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
|
|
1098
|
-
else if (op === require_backend.AluOp.Mod) source = `(${a} % ${b})`;
|
|
1099
|
-
else if (op === require_backend.AluOp.Min) if (dtype === require_backend.DType.Bool) source = `(${a} && ${b})`;
|
|
1100
|
-
else source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
1101
|
-
else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
|
|
1102
|
-
else source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
1103
|
-
else if (op === require_backend.AluOp.BitCombine) if (arg === "and") source = `(${a} & ${b})`;
|
|
1104
|
-
else if (arg === "or") source = `(${a} | ${b})`;
|
|
1105
|
-
else source = dtype === require_backend.DType.Bool ? `(${a} != ${b})` : `(${a} ^ ${b})`;
|
|
1106
|
-
else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
|
|
1107
|
-
else source = `(${a} >> ${b})`;
|
|
1108
|
-
else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
1109
|
-
else if (op === require_backend.AluOp.Cmpne) if (require_backend.isFloatDtype(src[0].dtype)) {
|
|
1110
|
-
const x = isGensym(a) ? a : gensym();
|
|
1111
|
-
if (x !== a) emit(`let ${x} = ${a};`);
|
|
1112
|
-
source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
|
|
1113
|
-
} else source = `(${a} != ${b})`;
|
|
1114
|
-
} else if (require_backend.AluGroup.Unary.has(op)) if (op === require_backend.AluOp.Reciprocal && src[0].op === require_backend.AluOp.Sqrt) {
|
|
1115
|
-
const a = gen(src[0].src[0]);
|
|
1116
|
-
source = `inverseSqrt(${a})`;
|
|
1117
|
-
} else {
|
|
1118
|
-
const a = gen(src[0]);
|
|
1119
|
-
if (op === require_backend.AluOp.Sin) source = `sin(${require_backend.strip1(a)})`;
|
|
1120
|
-
else if (op === require_backend.AluOp.Cos) source = `cos(${require_backend.strip1(a)})`;
|
|
1121
|
-
else if (op === require_backend.AluOp.Asin) source = `asin(${require_backend.strip1(a)})`;
|
|
1122
|
-
else if (op === require_backend.AluOp.Atan) source = `atan(${require_backend.strip1(a)})`;
|
|
1123
|
-
else if (op === require_backend.AluOp.Exp) source = `exp(${require_backend.strip1(a)})`;
|
|
1124
|
-
else if (op === require_backend.AluOp.Log) source = `log(${require_backend.strip1(a)})`;
|
|
1125
|
-
else if (op === require_backend.AluOp.Erf || op === require_backend.AluOp.Erfc) {
|
|
1126
|
-
const funcName = op === require_backend.AluOp.Erf ? "erf" : "erfc";
|
|
1127
|
-
if (dtype !== require_backend.DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${require_backend.strip1(a)})))`;
|
|
1128
|
-
else source = `${funcName}(${require_backend.strip1(a)})`;
|
|
1129
|
-
} else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${require_backend.strip1(a)})`;
|
|
1130
|
-
else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
1131
|
-
else if (op === require_backend.AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
|
|
1132
|
-
else if (op === require_backend.AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
|
|
1133
|
-
else if (op === require_backend.AluOp.Cast) {
|
|
1134
|
-
const srcTy = dtypeToWgsl(src[0].dtype);
|
|
1135
|
-
const dstTy = dtypeToWgsl(dtype);
|
|
1136
|
-
if (require_backend.isFloatDtype(src[0].dtype) && !(require_backend.isFloatDtype(dtype) || dtype === require_backend.DType.Bool)) {
|
|
1137
|
-
const maxVal = maxValueWgsl(dtype);
|
|
1138
|
-
const x = isGensym(a) ? a : gensym();
|
|
1139
|
-
if (x !== a) emit(`let ${x}: ${srcTy} = ${require_backend.strip1(a)};`);
|
|
1140
|
-
source = `select(${dstTy}(${x}), ${maxVal}, ${x} >= ${srcTy}(${maxVal}))`;
|
|
1141
|
-
} else source = `${dstTy}(${require_backend.strip1(a)})`;
|
|
1142
|
-
} else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
|
|
1143
|
-
}
|
|
1144
|
-
else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
|
|
1145
|
-
else if (op === require_backend.AluOp.Threefry2x32) {
|
|
1146
|
-
const x = gensym();
|
|
1147
|
-
const [k0, k1, c0, c1] = src.map((x$1) => require_backend.strip1(gen(x$1)));
|
|
1148
|
-
emit(`let ${x} = threefry2x32(vec2(${k0}, ${k1}), vec2(${c0}, ${c1}));`);
|
|
1149
|
-
if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
|
|
1150
|
-
else if (arg === 0) source = `${x}.x`;
|
|
1151
|
-
else if (arg === 1) source = `${x}.y`;
|
|
1152
|
-
else throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
1153
|
-
} else if (op === require_backend.AluOp.Const) return constToWgsl(dtype, arg);
|
|
1154
|
-
else if (op === require_backend.AluOp.Special) return arg[0];
|
|
1155
|
-
else if (op === require_backend.AluOp.Variable) return arg;
|
|
1156
|
-
else if (op === require_backend.AluOp.GlobalIndex) {
|
|
1157
|
-
source = `${args[arg[0]]}[${require_backend.strip1(gen(src[0]))}]`;
|
|
1158
|
-
if (dtype === require_backend.DType.Bool) source = `(${source} != 0)`;
|
|
1159
|
-
}
|
|
1160
|
-
if (!source) throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
1161
|
-
const typeName = dtypeToWgsl(dtype);
|
|
1162
|
-
if ((references.get(exp) ?? 0) > 1) {
|
|
1163
|
-
const name = gensym();
|
|
1164
|
-
expContext.set(exp, name);
|
|
1165
|
-
emit(`let ${name}: ${typeName} = ${require_backend.strip1(source)};`);
|
|
1166
|
-
return name;
|
|
1167
|
-
} else {
|
|
1168
|
-
expContext.set(exp, source);
|
|
1169
|
-
return source;
|
|
1170
|
-
}
|
|
1171
|
-
};
|
|
1328
|
+
wb.emitPhonyAssignments(args);
|
|
1329
|
+
const gen = new WgslExpCodegen(wb, args);
|
|
1172
1330
|
if (!re) {
|
|
1173
|
-
countReferences(tune.exp);
|
|
1174
|
-
let rhs = require_backend.strip1(gen(tune.exp));
|
|
1331
|
+
gen.countReferences(tune.exp);
|
|
1332
|
+
let rhs = require_backend.strip1(gen.run(tune.exp));
|
|
1175
1333
|
if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1176
|
-
emit(`result[gidx] = ${rhs};`);
|
|
1334
|
+
wb.emit(`result[gidx] = ${rhs};`);
|
|
1177
1335
|
} else {
|
|
1178
1336
|
if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
|
|
1179
1337
|
const unroll = tune.size.unroll ?? 1;
|
|
1180
1338
|
const upcast = tune.size.upcast ?? 1;
|
|
1181
1339
|
const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
|
|
1182
|
-
for (let i = 0; i < upcast; i++) emit(`var ${acc[i]}: ${dtypeToWgsl(re.dtype)} = ${constToWgsl(re.dtype, re.identity)};`);
|
|
1183
|
-
emit(`for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`, pushIndent);
|
|
1340
|
+
for (let i = 0; i < upcast; i++) wb.emit(`var ${acc[i]}: ${dtypeToWgsl(re.dtype)} = ${constToWgsl(re.dtype, re.identity)};`);
|
|
1341
|
+
wb.emit(`for (var ridx: i32 = 0; ridx < ${tune.size.reduce}; ridx++) {`, wb.pushIndent);
|
|
1184
1342
|
const exps = [];
|
|
1185
1343
|
const cache = /* @__PURE__ */ new Map();
|
|
1186
1344
|
for (let up = 0; up < upcast; up++) {
|
|
@@ -1191,10 +1349,10 @@ function pipelineSource(device, kernel) {
|
|
|
1191
1349
|
unroll: require_backend.AluExp.i32(un)
|
|
1192
1350
|
});
|
|
1193
1351
|
exps[up].push(exp.simplify(cache));
|
|
1194
|
-
countReferences(exps[up][un]);
|
|
1352
|
+
gen.countReferences(exps[up][un]);
|
|
1195
1353
|
}
|
|
1196
1354
|
}
|
|
1197
|
-
const items = exps.map((ar) => ar.map(gen).map(require_backend.strip1));
|
|
1355
|
+
const items = exps.map((ar) => ar.map((x) => gen.run(x)).map(require_backend.strip1));
|
|
1198
1356
|
for (let i = 0; i < upcast; i++) {
|
|
1199
1357
|
let rhs = items[i][0];
|
|
1200
1358
|
for (let j = 1; j < unroll; j++) if (re.op === require_backend.AluOp.Add) rhs = `${rhs} + ${items[i][j]}`;
|
|
@@ -1202,40 +1360,38 @@ function pipelineSource(device, kernel) {
|
|
|
1202
1360
|
else if (re.op === require_backend.AluOp.Min) rhs = re.dtype === require_backend.DType.Bool ? `(${rhs} && ${items[i][j]})` : `min(${rhs}, ${items[i][j]})`;
|
|
1203
1361
|
else if (re.op === require_backend.AluOp.Max) rhs = re.dtype === require_backend.DType.Bool ? `(${rhs} || ${items[i][j]})` : `max(${rhs}, ${items[i][j]})`;
|
|
1204
1362
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
1205
|
-
if (re.op === require_backend.AluOp.Add) emit(`${acc[i]} += ${rhs};`);
|
|
1206
|
-
else if (re.op === require_backend.AluOp.Mul) emit(`${acc[i]} *= ${rhs};`);
|
|
1207
|
-
else if (re.op === require_backend.AluOp.Min) if (re.dtype === require_backend.DType.Bool) emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
|
|
1208
|
-
else emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
|
|
1209
|
-
else if (re.op === require_backend.AluOp.Max) if (re.dtype === require_backend.DType.Bool) emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
|
|
1210
|
-
else emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
|
|
1363
|
+
if (re.op === require_backend.AluOp.Add) wb.emit(`${acc[i]} += ${rhs};`);
|
|
1364
|
+
else if (re.op === require_backend.AluOp.Mul) wb.emit(`${acc[i]} *= ${rhs};`);
|
|
1365
|
+
else if (re.op === require_backend.AluOp.Min) if (re.dtype === require_backend.DType.Bool) wb.emit(`${acc[i]} = ${acc[i]} && ${rhs};`);
|
|
1366
|
+
else wb.emit(`${acc[i]} = min(${acc[i]}, ${rhs});`);
|
|
1367
|
+
else if (re.op === require_backend.AluOp.Max) if (re.dtype === require_backend.DType.Bool) wb.emit(`${acc[i]} = ${acc[i]} || ${rhs};`);
|
|
1368
|
+
else wb.emit(`${acc[i]} = max(${acc[i]}, ${rhs});`);
|
|
1211
1369
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
1212
1370
|
}
|
|
1213
|
-
emit(popIndent, "}");
|
|
1214
|
-
|
|
1215
|
-
references.clear();
|
|
1216
|
-
seen.clear();
|
|
1371
|
+
wb.emit(wb.popIndent, "}");
|
|
1372
|
+
gen.reset();
|
|
1217
1373
|
const outputIdxExps = [];
|
|
1218
1374
|
const fusionExps = [];
|
|
1219
1375
|
for (let i = 0; i < upcast; i++) {
|
|
1220
1376
|
const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
|
|
1221
1377
|
outputIdxExps.push(exp.simplify(cache));
|
|
1222
|
-
countReferences(outputIdxExps[i]);
|
|
1378
|
+
gen.countReferences(outputIdxExps[i]);
|
|
1223
1379
|
fusionExps.push(tune.epilogue.substitute({
|
|
1224
1380
|
acc: require_backend.AluExp.variable(re.dtype, acc[i]),
|
|
1225
1381
|
upcast: require_backend.AluExp.i32(i)
|
|
1226
1382
|
}).simplify(cache));
|
|
1227
|
-
countReferences(fusionExps[i]);
|
|
1383
|
+
gen.countReferences(fusionExps[i]);
|
|
1228
1384
|
}
|
|
1229
1385
|
for (let i = 0; i < upcast; i++) {
|
|
1230
|
-
const index = require_backend.strip1(gen(outputIdxExps[i]));
|
|
1231
|
-
let rhs = require_backend.strip1(gen(fusionExps[i]));
|
|
1386
|
+
const index = require_backend.strip1(gen.run(outputIdxExps[i]));
|
|
1387
|
+
let rhs = require_backend.strip1(gen.run(fusionExps[i]));
|
|
1232
1388
|
if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1233
|
-
emit(`result[${index}] = ${rhs};`);
|
|
1389
|
+
wb.emit(`result[${index}] = ${rhs};`);
|
|
1234
1390
|
}
|
|
1235
1391
|
}
|
|
1236
|
-
emit(popIndent, "}");
|
|
1392
|
+
wb.emit(wb.popIndent, "}");
|
|
1237
1393
|
return {
|
|
1238
|
-
code:
|
|
1394
|
+
code: wb.toString(),
|
|
1239
1395
|
numInputs: nargs,
|
|
1240
1396
|
numOutputs: 1,
|
|
1241
1397
|
hasUniform: false,
|
|
@@ -1279,11 +1435,17 @@ function pipelineSubmit(device, exe, inputs, outputs) {
|
|
|
1279
1435
|
}
|
|
1280
1436
|
for (let i = 0; i < filteredPasses.length; i++) {
|
|
1281
1437
|
const { grid } = filteredPasses[i];
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1438
|
+
let timestampWrites;
|
|
1439
|
+
if (slot) {
|
|
1440
|
+
const isFirst = i === 0;
|
|
1441
|
+
const isLast = i === filteredPasses.length - 1;
|
|
1442
|
+
if (isFirst || isLast) timestampWrites = {
|
|
1443
|
+
querySet: slot.batch.querySet,
|
|
1444
|
+
...isFirst ? { beginningOfPassWriteIndex: slot.beginIndex } : {},
|
|
1445
|
+
...isLast ? { endOfPassWriteIndex: slot.endIndex } : {}
|
|
1446
|
+
};
|
|
1447
|
+
}
|
|
1448
|
+
const passEncoder = commandEncoder.beginComputePass({ timestampWrites });
|
|
1287
1449
|
passEncoder.setPipeline(pipeline);
|
|
1288
1450
|
passEncoder.setBindGroup(0, bindGroup);
|
|
1289
1451
|
if (uniformBindGroup) passEncoder.setBindGroup(1, uniformBindGroup, [i * uniformAlignment]);
|