@jax-js/jax 0.1.12 → 0.1.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2327,6 +2327,25 @@ var TuneDims = class {
2327
2327
  this.upcast++;
2328
2328
  }
2329
2329
  }
2330
+ applyGroup(axis, amount) {
2331
+ if (axis < this.reduce || axis >= this.unroll) throw new Error("Can only group reduction axes");
2332
+ const length = this.st.shape[axis];
2333
+ if (length % amount !== 0) throw new Error(`Group by ${amount} on axis length ${length}`);
2334
+ this.st = this.st.reshape([
2335
+ ...this.st.shape.slice(0, axis),
2336
+ length / amount,
2337
+ amount,
2338
+ ...this.st.shape.slice(axis + 1)
2339
+ ]).permute([
2340
+ ...range(this.reduce),
2341
+ axis + 1,
2342
+ ...range(this.reduce, axis + 1),
2343
+ ...range(axis + 2, this.st.shape.length + 1)
2344
+ ]);
2345
+ this.reduce++;
2346
+ this.unroll++;
2347
+ this.upcast++;
2348
+ }
2330
2349
  };
2331
2350
  /** Tuning step that does not apply any optimization. */
2332
2351
  function tuneNullopt(kernel) {
@@ -2398,6 +2417,17 @@ function tuneWebgpu(kernel) {
2398
2417
  upcastedAxis.add(choices[0][2]);
2399
2418
  } else break;
2400
2419
  }
2420
+ const groupCandidateSts = sts.map((st) => st.compose(dim.st));
2421
+ const groupCandidateHasContiguousInputs = groupCandidateSts.every((st) => {
2422
+ const hasOutputStride = st.lastStrides.slice(0, dim.groups).some((stride) => stride !== 0);
2423
+ return !hasOutputStride || Math.abs(st.lastStrides[dim.reduce]) <= 1;
2424
+ });
2425
+ if (reduction.op === AluOp.Add && reduction.dtype === DType.Float32 && groupCandidateHasContiguousInputs && prod(dim.st.shape.slice(0, dim.groups)) < 4096 && prod(dim.st.shape.slice(dim.reduce, dim.unroll)) >= 512) {
2426
+ const axis = dim.reduce;
2427
+ let amount = 16;
2428
+ while (amount > 1 && dim.st.shape[axis] % amount !== 0) amount /= 2;
2429
+ if (amount > 1) dim.applyGroup(axis, amount);
2430
+ }
2401
2431
  if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2402
2432
  const s = dim.st.shape[dim.unroll - 1];
2403
2433
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
@@ -3218,6 +3248,30 @@ function wasm_threefry2x32(cg) {
3218
3248
  });
3219
3249
  }
3220
3250
 
3251
+ //#endregion
3252
+ //#region src/backend/wasm/featureProbe.ts
3253
/**
 * Hex-encoded Wasm binaries, keyed by feature name. `hasWasmFeature` decodes
 * the entry for a feature and passes it to `WebAssembly.validate`.
 */
const featureProbes = {
  "relaxed-madd": "0061736d0100000001080160037b7b7b017b030201000a0d010b00200020012002fd85020b",
};

/** Memoized `hasWasmFeature` results, keyed by feature name. */
const featureSupportCache = /* @__PURE__ */ new Map();
3255
/**
 * Detects whether this environment supports a probed Wasm CPU feature.
 *
 * The probe binary registered for `feature` in `featureProbes` is decoded and
 * handed to `WebAssembly.validate`. The outcome is stored in
 * `featureSupportCache`, so each feature is probed at most once.
 *
 * @param feature Key into `featureProbes`.
 * @returns `true` only when `WebAssembly` exists and the probe validates. Any
 *   exception raised while decoding or validating counts as "not supported".
 */
function hasWasmFeature(feature) {
  const memoized = featureSupportCache.get(feature);
  if (memoized !== undefined) return memoized;

  const probeHex = featureProbes[feature];
  let result;
  try {
    if (typeof WebAssembly === "undefined") result = false;
    else result = WebAssembly.validate(decodeHex(probeHex));
  } catch {
    // A failed probe is reported as an unsupported feature, never as an error.
    result = false;
  }

  featureSupportCache.set(feature, result);
  return result;
}
3269
/**
 * Decodes a hexadecimal string into bytes.
 *
 * Unlike a bare `parseInt` loop, malformed input is rejected instead of being
 * silently turned into wrong bytes: an odd-length string would otherwise lose
 * its last digit, and pairs such as "zz", "-1" or " f" would decode to 0x00,
 * 0xff and 0x0f respectively.
 *
 * @param {string} hex Even-length string of hex digits (either case).
 * @returns {Uint8Array} One byte per pair of hex digits.
 * @throws {TypeError} If `hex` is not a string.
 * @throws {RangeError} If `hex` has an odd number of characters.
 * @throws {SyntaxError} If `hex` contains a non-hexadecimal character.
 */
function decodeHex(hex) {
  if (typeof hex !== "string") throw new TypeError("decodeHex: expected a string");
  if (hex.length % 2 !== 0) throw new RangeError("decodeHex: hex string must have an even length");
  if (!/^[0-9a-fA-F]*$/.test(hex)) throw new SyntaxError("decodeHex: invalid hexadecimal digit");

  const bytes = new Uint8Array(hex.length / 2);
  for (let i = 0; i < bytes.length; i++) {
    bytes[i] = Number.parseInt(hex.slice(i * 2, i * 2 + 2), 16);
  }
  return bytes;
}
3274
+
3221
3275
  //#endregion
3222
3276
  //#region src/backend/wasm/parallel.ts
3223
3277
  /** Check if SharedArrayBuffer is available. */
@@ -3308,10 +3362,10 @@ var WasmWorkerPool = class {
3308
3362
  * Returns an epoch that can be used to wait for the ongoing work to complete,
3309
3363
  * which is guaranteed to be monotonically increasing.
3310
3364
  */
3311
- dispatch(module, ptrs, size) {
3365
+ dispatch(module, ptrs, size, chunkAlignment = 16, minWorkPerWorker = MIN_ELEMS_PER_THREAD) {
3312
3366
  this.#ensureInit();
3313
3367
  this.#epochEnd++;
3314
- const result = this.#queue.then(() => this.#dispatchNow(module, ptrs, size));
3368
+ const result = this.#queue.then(() => this.#dispatchNow(module, ptrs, size, chunkAlignment, minWorkPerWorker));
3315
3369
  this.#queue = result.then(() => {}, () => {}).then(() => {
3316
3370
  this.#epoch++;
3317
3371
  const hooks = this.#hooks.get(this.#epoch);
@@ -3322,10 +3376,10 @@ var WasmWorkerPool = class {
3322
3376
  });
3323
3377
  return this.#epochEnd;
3324
3378
  }
3325
- async #dispatchNow(module, ptrs, size) {
3379
+ async #dispatchNow(module, ptrs, size, chunkAlignment, minWorkPerWorker) {
3326
3380
  if (size === 0) return;
3327
- const n = Math.min(this.#workers.length, Math.ceil(size / MIN_ELEMS_PER_THREAD));
3328
- const chunkSize = Math.ceil(size / n / 16) * 16;
3381
+ const n = Math.min(this.#workers.length, Math.ceil(size / minWorkPerWorker));
3382
+ const chunkSize = Math.ceil(size / n / chunkAlignment) * chunkAlignment;
3329
3383
  const promises = [];
3330
3384
  for (let i = 0; i < n; i++) {
3331
3385
  const begin = i * chunkSize;
@@ -3360,1421 +3414,1529 @@ function createWorkerPool(memory) {
3360
3414
  }
3361
3415
 
3362
3416
  //#endregion
3363
- //#region src/backend/wasm/wasmblr.ts
3364
- /**
3365
- * @file Minimalist WebAssembly assembler. This allows you to emit WebAssembly
3366
- * bytecode directly from the browser.
3367
- *
3368
- * Self-contained port of https://github.com/bwasti/wasmblr to TypeScript.
3369
- * Some operation names in this module are written in `snake_case` to match
3370
- * their names in the Wasm specification.
3371
- *
3372
- * Reference: https://pengowray.github.io/wasm-ops/.
3373
- */
3374
- const magicModuleHeader = [
3375
- 0,
3376
- 97,
3377
- 115,
3378
- 109
3379
- ];
3380
- const moduleVersion = [
3381
- 1,
3382
- 0,
3383
- 0,
3384
- 0
3385
- ];
3386
- function assert(condition, message) {
3387
- if (!condition) throw new Error(message || "Assertion failed");
3417
+ //#region src/backend/wasm/tilePlan.ts
3418
/** Lane count used when sizing and validating SIMD tiles in this region. */
const simdLanes = 4;

// Upper bounds used by the tiled-SIMD reduction plan (reductionTilePlan and
// its tiledRows / tiledColumns helpers).
const TILED_SIMD_ROWS = 128;
const TILED_SIMD_COLUMNS = 128;
const TILED_SIMD_K = 64;
const TILED_SIMD_MICRO_ROWS = 4;
const TILED_SIMD_MICRO_VECTORS = 4;

// Parameters of the K-axis SIMD plan (reductionKTilePlan).
const K_SIMD_MICRO_ROWS = 4;
const K_SIMD_MICRO_COLS = 4;
const K_SIMD_UNROLL = 4;

/** Divisor applied to an axis size in tileAxisLimit to cap the tile extent. */
const TILE_AXIS_PARTS = 8;
3428
/**
 * Tells whether `exp` is itself the symbol called `name`: either a Variable
 * node whose `arg` is `name`, or a Special node whose `arg[0]` is `name`.
 */
function isSymbol(exp, name) {
  if (exp.op === AluOp.Variable && exp.arg === name) return true;
  return exp.op === AluOp.Special && exp.arg[0] === name;
}
3389
- function encodeSigned(n) {
3390
- const out = [];
3391
- let more = true;
3392
- while (more) {
3393
- let byte = n & 127;
3394
- n >>= 7;
3395
- if (n === 0 && (byte & 64) === 0 || n === -1 && (byte & 64) !== 0) more = false;
3396
- else byte |= 128;
3397
- out.push(byte);
3431
/**
 * Tells whether the symbol `name` occurs anywhere in the expression tree
 * rooted at `exp` (the node itself or, recursively, any entry of `src`).
 */
function referencesSymbol(exp, name) {
  if (isSymbol(exp, name)) return true;
  return exp.src.some((child) => referencesSymbol(child, name));
}
3434
/** Tells whether the expression tree rooted at `exp` mentions `gidx`. */
function referencesGidx(exp) {
  const symbol = "gidx";
  return referencesSymbol(exp, symbol);
}
3437
/**
 * Reports whether a finite `tileSize` is larger than `N` without being a
 * multiple of it. Infinite tile sizes and sizes not greater than `N` never
 * count as a risk.
 */
function hasFragmentRisk(tileSize, N) {
  if (!isFinite(tileSize)) return false;
  if (!(tileSize > N)) return false;
  return tileSize % N !== 0;
}
3440
/**
 * Returns the value of `exp` when it is a Const node holding an integer,
 * and `null` for every other node.
 */
function constInt(exp) {
  if (exp.op !== AluOp.Const) return null;
  const literal = exp.arg;
  if (Number.isInteger(literal)) return literal;
  return null;
}
3445
+ function coefficientOfSymbol(exp, name) {
3446
+ if (!referencesSymbol(exp, name)) return 0;
3447
+ if (isSymbol(exp, name)) return 1;
3448
+ if (exp.op === AluOp.Add || exp.op === AluOp.Sub) {
3449
+ const a = coefficientOfSymbol(exp.src[0], name);
3450
+ const b = coefficientOfSymbol(exp.src[1], name);
3451
+ if (a === null || b === null) return null;
3452
+ return exp.op === AluOp.Add ? a + b : a - b;
3398
3453
  }
3399
- return out;
3454
+ if (exp.op === AluOp.Mul) {
3455
+ const lhs = constInt(exp.src[0]);
3456
+ if (lhs !== null) {
3457
+ const rhsCoeff = coefficientOfSymbol(exp.src[1], name);
3458
+ return rhsCoeff === null ? null : lhs * rhsCoeff;
3459
+ }
3460
+ const rhs = constInt(exp.src[1]);
3461
+ if (rhs !== null) {
3462
+ const lhsCoeff = coefficientOfSymbol(exp.src[0], name);
3463
+ return lhsCoeff === null ? null : rhs * lhsCoeff;
3464
+ }
3465
+ }
3466
+ return null;
3400
3467
  }
3401
- function encodeUnsigned(n) {
3402
- const out = [];
3403
- do {
3404
- let byte = n & 127;
3405
- n = n >>> 7;
3406
- if (n !== 0) byte |= 128;
3407
- out.push(byte);
3408
- } while (n !== 0);
3409
- return out;
3468
/**
 * Replaces every occurrence of the symbol `name` inside `exp` with the node
 * produced by `rewrite(node)`, then simplifies the result. Nodes that are not
 * the symbol are passed over (the callback yields `undefined` for them).
 */
function rewriteSymbol(exp, name, rewrite) {
  const substitute = (node) => {
    if (isSymbol(node, name)) return rewrite(node);
    return undefined;
  };
  return exp.rewrite(substitute).simplify();
}
3411
- function encodeString(s) {
3412
- const bytes = new TextEncoder().encode(s);
3413
- return [bytes.length, ...bytes];
3471
/**
 * Checks whether `exp` is unchanged when `gidx` is replaced by
 * `gidx + tileSize`. The check compares the hash of the shifted-and-simplified
 * expression with the hash of `exp`. A non-finite `tileSize` always yields
 * `false`.
 */
function repeatsAcrossGidxTile(exp, tileSize) {
  if (!isFinite(tileSize)) return false;
  const shiftByTile = (node) => AluExp.add(node, AluExp.i32(tileSize));
  const shifted = rewriteSymbol(exp, "gidx", shiftByTile);
  const shiftedHash = shifted.getHash();
  return shiftedHash === exp.getHash();
}
3415
- function encodeBlocktype(type) {
3416
- assert(type.length > 0, "blocktype must have at least one type");
3417
- if (type.length === 1) return [type[0].typeId];
3418
- return [
3419
- 96,
3420
- ...encodeUnsigned(0),
3421
- ...encodeUnsigned(type.length),
3422
- ...type.map((t) => t.typeId)
3423
- ];
3476
/**
 * Finds the largest divisor of `value` that does not exceed `limit`.
 * Falls back to 1 when no divisor greater than 1 qualifies.
 */
function divisorAtMost(value, limit) {
  let candidate = Math.min(value, limit);
  while (candidate > 1) {
    if (value % candidate === 0) return candidate;
    candidate--;
  }
  return 1;
}
3425
- function encodeOpcode(opcode) {
3426
- if (typeof opcode === "number") return [opcode];
3427
- return [opcode[0], ...encodeUnsigned(opcode[1])];
3480
/**
 * Caps a tile extent along one axis: the result is `axisSize` divided by
 * `TILE_AXIS_PARTS` (rounded down, but at least 1), and never more than
 * `maxTileSize`.
 */
function tileAxisLimit(axisSize, maxTileSize) {
  const share = Math.floor(axisSize / TILE_AXIS_PARTS);
  const atLeastOne = Math.max(1, share);
  return Math.min(maxTileSize, atLeastOne);
}
3429
- function concat(out, inp) {
3430
- out.push(...inp);
3483
/**
 * Picks a single tile size compatible with every analyzed stride.
 *
 * Only non-gather strides with a finite `tileSize` take part. When there are
 * none, `unconstrainedTileSize` is returned. Otherwise the smallest such size
 * is chosen, provided it is at least `minTileSize`, divides `kernelSize`, and
 * divides every other participating size. If any condition fails the result
 * is `null`.
 */
function commonTileSize(kernelSize, strideMap, minTileSize, unconstrainedTileSize = null) {
  const finiteSizes = [...strideMap.values()]
    .filter((stride) => stride.kind !== "gather" && isFinite(stride.tileSize))
    .map((stride) => stride.tileSize);
  if (finiteSizes.length === 0) return unconstrainedTileSize;

  const smallest = Math.min(...finiteSizes);
  if (smallest < minTileSize) return null;
  if (kernelSize % smallest !== 0) return null;
  const allAreMultiples = finiteSizes.every((size) => size % smallest === 0);
  return allAreMultiples ? smallest : null;
}
3432
- var Function_ = class {
3433
- inputTypes;
3434
- outputTypes;
3435
- body;
3436
- locals = [];
3437
- constructor(inputTypes, outputTypes, body) {
3438
- this.inputTypes = inputTypes;
3439
- this.outputTypes = outputTypes;
3440
- this.body = body || (() => {});
3441
- }
3442
- emit() {
3443
- this.locals = [];
3444
- this.body();
3445
- }
3446
- };
3447
- var Memory = class {
3448
- min = 0;
3449
- max = 0;
3450
- isShared = false;
3451
- aString = "";
3452
- bString = "";
3453
- constructor(cg) {
3454
- this.cg = cg;
3455
- }
3456
- /** Declare the size of the memory. Each page is 64 KiB. */
3457
- pages(min, max = 0) {
3458
- assert(this.min === 0 && this.max === 0);
3459
- this.min = min;
3460
- this.max = max;
3461
- return this;
3462
- }
3463
- export(a) {
3464
- assert(!this.isImport && !this.isExport, "already set");
3465
- this.aString = a;
3466
- return this;
3467
- }
3468
- shared(isShared) {
3469
- this.isShared = isShared;
3470
- return this;
3471
- }
3472
- import(a, b) {
3473
- assert(!this.isImport && !this.isExport, "already set");
3474
- this.aString = a;
3475
- this.bString = b;
3476
- return this;
3477
- }
3478
- size() {
3479
- this.cg._emit(63);
3480
- this.cg._emit(0);
3481
- }
3482
- grow() {
3483
- this.cg._emit(64);
3484
- this.cg._emit(0);
3485
- }
3486
- get isImport() {
3487
- return this.aString.length > 0 && this.bString.length > 0;
3488
- }
3489
- get isExport() {
3490
- return this.aString.length > 0 && this.bString.length === 0;
3491
- }
3492
- };
3493
- /** Public API of WebAssembly assembler. */
3494
- var CodeGenerator = class {
3495
- local;
3496
- i32;
3497
- f32;
3498
- f64;
3499
- v128;
3500
- i32x4;
3501
- f32x4;
3502
- memory;
3503
- void = {
3504
- typeId: 64,
3505
- name: "void"
3491
/**
 * Chooses how many rows go into one tile: the largest divisor of the row
 * count (`kernelSize / tileSize`) that fits under the row-axis limit derived
 * from `TILED_SIMD_ROWS`.
 */
function tiledRows(kernelSize, tileSize) {
  const rows = kernelSize / tileSize;
  const cap = tileAxisLimit(rows, TILED_SIMD_ROWS);
  return divisorAtMost(rows, cap);
}
3495
/**
 * Chooses how many columns go into one tile, counted in units of `laneWidth`
 * elements: the largest divisor of `tileSize / laneWidth` that fits under the
 * column-axis limit (derived from `TILED_SIMD_COLUMNS`, divided by
 * `laneWidth`, and at least 1).
 */
function tiledColumns(tileSize, laneWidth = 1) {
  const units = tileSize / laneWidth;
  const elementCap = tileAxisLimit(tileSize, TILED_SIMD_COLUMNS);
  const unitCap = Math.max(1, Math.floor(elementCap / laneWidth));
  return divisorAtMost(units, unitCap);
}
3498
+ function periodicStride(exp, kind) {
3499
+ if (exp.src[1].op !== AluOp.Const) return null;
3500
+ const N = exp.src[1].arg;
3501
+ const inner = analyzeStride(exp.src[0]);
3502
+ if (inner.kind === "broadcast") return inner;
3503
+ if (inner.kind !== "contiguous" || hasFragmentRisk(inner.tileSize, N)) return { kind: "gather" };
3504
+ return {
3505
+ kind,
3506
+ tileSize: Math.min(inner.tileSize, N)
3506
3507
  };
3507
- #functions = [];
3508
- #importedFunctions = [];
3509
- #exportedFunctions = /* @__PURE__ */ new Map();
3510
- #curFunction = null;
3511
- #curBytes = [];
3512
- #typeStack = [];
3513
- #blockFrames = [];
3514
- constructor() {
3515
- this.local = new Local(this);
3516
- this.i32 = new I32(this);
3517
- this.f32 = new F32(this);
3518
- this.f64 = new F64(this);
3519
- this.v128 = new V128(this);
3520
- this.i32x4 = new I32x4(this);
3521
- this.f32x4 = new F32x4(this);
3522
- this.memory = new Memory(this);
3523
- }
3524
- unreachable() {
3525
- this._emit(0);
3526
- }
3527
- nop() {
3528
- this._emit(1);
3529
- }
3530
- block(...type) {
3531
- this.#blockFrames.push({
3532
- idx: this.#typeStack.length,
3533
- ty: type
3534
- });
3535
- this._emit(2);
3536
- this._emit(encodeBlocktype(type));
3508
+ }
3509
/**
 * Combines the stride classifications of the two operands of an addition.
 *
 * - If either side is a gather, the sum is a gather.
 * - If one side is a broadcast, the sum takes the other side's kind, with the
 *   smaller of the two tile sizes.
 * - Two non-broadcast sides give a gather.
 */
function addStrides(lhs, rhs) {
  const gather = { kind: "gather" };
  if (lhs.kind === "gather") return gather;
  if (rhs.kind === "gather") return gather;

  const tileSize = Math.min(lhs.tileSize, rhs.tileSize);
  if (lhs.kind === "broadcast") return { kind: rhs.kind, tileSize };
  if (rhs.kind === "broadcast") return { kind: lhs.kind, tileSize };
  return gather;
}
3522
+ function analyzeStride(exp) {
3523
+ if (!referencesGidx(exp)) return {
3524
+ kind: "broadcast",
3525
+ tileSize: Infinity
3526
+ };
3527
+ if (exp.op === AluOp.Special && exp.arg[0] === "gidx") return {
3528
+ kind: "contiguous",
3529
+ tileSize: Infinity
3530
+ };
3531
+ if (exp.op === AluOp.Idiv || exp.op === AluOp.Mod) {
3532
+ const stride = periodicStride(exp, exp.op === AluOp.Idiv ? "broadcast" : "contiguous");
3533
+ if (stride) return stride;
3537
3534
  }
3538
- loop(...type) {
3539
- this.#blockFrames.push({
3540
- idx: this.#typeStack.length,
3541
- ty: type
3542
- });
3543
- this._emit(3);
3544
- this._emit(encodeBlocktype(type));
3535
+ if (exp.op === AluOp.Mul) {
3536
+ for (let i = 0; i < 2; i++) if (exp.src[i].op === AluOp.Const) {
3537
+ const inner = analyzeStride(exp.src[1 - i]);
3538
+ if (inner.kind === "broadcast") return inner;
3539
+ return { kind: "gather" };
3540
+ }
3545
3541
  }
3546
- if(...type) {
3547
- assert(this._pop().typeId === this.i32.typeId, "if_: expected i32");
3548
- this.#blockFrames.push({
3549
- idx: this.#typeStack.length,
3550
- ty: type
3542
+ if (exp.op === AluOp.Add) return addStrides(analyzeStride(exp.src[0]), analyzeStride(exp.src[1]));
3543
+ return { kind: "gather" };
3544
+ }
3545
/**
 * Classifies one GlobalIndex node for SIMD code generation.
 *
 * Starts from `analyzeStride` on the node's index expression (`src[0]`) and
 * downgrades the result to a gather when:
 * - a non-gather tile is narrower than `simdLanes`, or is finite and not a
 *   multiple of `simdLanes`; or
 * - the stride is contiguous but the index range (`min`/`max`) can fall
 *   outside `[0, len)`, where `len` is the second entry of the node's `arg`.
 */
function simdStrideResult(globalIndex) {
  const index = globalIndex.src[0];
  const result = analyzeStride(index);
  const [, len] = globalIndex.arg;

  if (result.kind !== "gather") {
    const narrowerThanVector = result.tileSize < simdLanes;
    const notLaneAligned = isFinite(result.tileSize) && result.tileSize % simdLanes !== 0;
    if (narrowerThanVector || notLaneAligned) return { kind: "gather" };
  }

  if (result.kind === "contiguous") {
    if (index.min < 0 || index.max >= len) return { kind: "gather" };
  }
  return result;
}
3553
/**
 * Builds a Map from every GlobalIndex node found in `exp` to its SIMD stride
 * classification. A nullish `exp` produces an empty Map.
 */
function collectSimdStrides(exp) {
  const loads = exp?.collect((node) => node.op === AluOp.GlobalIndex) ?? [];
  const entries = loads.map((gi) => [gi, simdStrideResult(gi)]);
  return new Map(entries);
}
3556
+ function reductionPointerCandidates(exp, strideMap) {
3557
+ const candidates = [];
3558
+ for (const gi of exp.collect((node) => node.op === AluOp.GlobalIndex)) {
3559
+ const stride = strideMap?.get(gi) ?? {
3560
+ kind: "broadcast",
3561
+ tileSize: Infinity
3562
+ };
3563
+ if (strideMap && stride.kind === "gather") continue;
3564
+ const [gid, len] = gi.arg;
3565
+ const index = gi.src[0];
3566
+ const strideElems = coefficientOfSymbol(index, "ridx");
3567
+ if (strideElems === null || !Number.isInteger(strideElems)) continue;
3568
+ if (index.min < 0 || index.max >= len) continue;
3569
+ candidates.push({
3570
+ exp: gi,
3571
+ gid,
3572
+ dtype: gi.dtype,
3573
+ stride,
3574
+ baseIndex: rewriteSymbol(index, "ridx", () => AluExp.i32(0)),
3575
+ strideBytes: strideElems * byteWidth(gi.dtype)
3551
3576
  });
3552
- this._emit(4);
3553
- this._emit(encodeBlocktype(type));
3554
- }
3555
- else() {
3556
- assert(this.#blockFrames.length > 0, "else: no block to else");
3557
- const frame = this.#blockFrames[this.#blockFrames.length - 1];
3558
- this.#typeStack = this.#typeStack.slice(0, frame.idx);
3559
- this._emit(5);
3560
3577
  }
3561
- /** End a block (`block`, `if`/`else`, `loop`, or function). */
3562
- end() {
3563
- const frame = this.#blockFrames.pop();
3564
- assert(frame !== void 0, "end: no block to end");
3565
- this.#typeStack = this.#typeStack.slice(0, frame.idx);
3566
- for (const ty of frame.ty) if (ty.typeId !== this.void.typeId) this._push(ty);
3567
- this._emit(11);
3568
- }
3569
- /** Branch to a block a certain depth outward on the stack. */
3570
- br(depth) {
3571
- this._emit(12);
3572
- this._emit(encodeUnsigned(depth));
3573
- }
3574
- /** Conditional branch to a block a certain depth outward on the stack. */
3575
- br_if(depth) {
3576
- assert(this._pop().typeId === this.i32.typeId, "br_if: expected i32");
3577
- this._emit(13);
3578
- this._emit(encodeUnsigned(depth));
3579
- }
3580
- /** Jump table that indexes into a label vector (like switch). */
3581
- br_table(...depths) {
3582
- assert(this._pop().typeId === this.i32.typeId, "br_table: expected i32");
3583
- assert(depths.length > 0, "br_table: expected at least one default depth");
3584
- this._emit(14);
3585
- this._emit(encodeUnsigned(depths.length - 1));
3586
- for (const d of depths) this._emit(encodeUnsigned(d));
3587
- }
3588
- /** Return from a function, branching out of the outermost block. */
3589
- return() {
3590
- this._emit(15);
3591
- }
3592
- /** Call a function with the given ID. */
3593
- call(fn) {
3594
- const totalFunctions = this.#importedFunctions.length + this.#functions.length;
3595
- assert(fn < totalFunctions, "function index does not exist");
3596
- const func = fn < this.#importedFunctions.length ? this.#importedFunctions[fn] : this.#functions[fn - this.#importedFunctions.length];
3597
- for (let i = func.inputTypes.length - 1; i >= 0; i--) {
3598
- const argType = this._pop();
3599
- assert(argType.typeId === func.inputTypes[i].typeId, `call: argument ${i} type mismatch, expected ${func.inputTypes[i].name} got ${argType.name}`);
3578
+ return candidates;
3579
+ }
3580
/**
 * Derives the tiling parameters for a tiled-SIMD reduction.
 *
 * @returns `null` when the kernel has no reduction or when no common tile
 *   size of at least `simdLanes` exists. Otherwise an object with `tileSize`,
 *   `tileRows`, `tileVectors`, `tileK`, `microRows` and `microVectors`, each
 *   chosen as a divisor bounded by the matching `TILED_SIMD_*` constant.
 */
function reductionTilePlan(kernel, strideMap) {
  const { reduction } = kernel;
  if (!reduction) return null;

  const tileSize = commonTileSize(kernel.size, strideMap, simdLanes);
  if (tileSize === null) return null;

  const tileRows = tiledRows(kernel.size, tileSize);
  const tileVectors = tiledColumns(tileSize, simdLanes);
  const tileK = divisorAtMost(reduction.size, TILED_SIMD_K);
  const microRows = divisorAtMost(tileRows, TILED_SIMD_MICRO_ROWS);
  const microVectors = divisorAtMost(tileVectors, TILED_SIMD_MICRO_VECTORS);
  return { tileSize, tileRows, tileVectors, tileK, microRows, microVectors };
}
3595
/**
 * Derives the tiling parameters for the K-axis SIMD reduction plan.
 *
 * @returns `null` when the kernel has no reduction, when the reduction size
 *   is not a multiple of `simdLanes * K_SIMD_UNROLL`, or when no common tile
 *   size exists (an unconstrained kernel falls back to a tile size of 1).
 *   Otherwise an object with `tileSize`, `tileRows`, `tileCols`, `microRows`,
 *   `microCols` and `kUnroll`.
 */
function reductionKTilePlan(kernel, strideMap) {
  const { reduction } = kernel;
  if (!reduction) return null;

  const kStep = simdLanes * K_SIMD_UNROLL;
  if (reduction.size % kStep !== 0) return null;

  const tileSize = commonTileSize(kernel.size, strideMap, 1, 1);
  if (tileSize === null) return null;

  const tileRows = tiledRows(kernel.size, tileSize);
  const tileCols = tiledColumns(tileSize);
  const microRows = divisorAtMost(tileRows, K_SIMD_MICRO_ROWS);
  const microCols = divisorAtMost(tileCols, K_SIMD_MICRO_COLS);
  return { tileSize, tileRows, tileCols, microRows, microCols, kUnroll: K_SIMD_UNROLL };
}
3611
/**
 * Builds a string key for a reduction pointer candidate. The key always
 * starts with the hash of the candidate's expression; the suffix depends on
 * the candidate's stride:
 * - broadcast: `:row<row>` for a finite tile size, `:all` otherwise;
 * - contiguous with a finite tile size: `:vec<vector>` when the base index
 *   repeats across a gidx tile, else `:row<row>:vec<vector>`;
 * - anything else: `:g<groupIndex>`.
 * NOTE(review): presumably candidates with equal keys share one pointer;
 * confirm against the caller.
 */
function pointerShareKey(candidate, row, vector, groupIndex) {
  const hash = candidate.exp.getHash().toString();
  const { kind, tileSize } = candidate.stride;
  const bounded = isFinite(tileSize);

  if (kind === "broadcast") {
    return bounded ? `${hash}:row${row}` : `${hash}:all`;
  }
  if (kind === "contiguous" && bounded) {
    if (repeatsAcrossGidxTile(candidate.baseIndex, tileSize)) return `${hash}:vec${vector}`;
    return `${hash}:row${row}:vec${vector}`;
  }
  return `${hash}:g${groupIndex}`;
}
3618
/**
 * Builds a string key for a pointer candidate in the K-axis reduction plan.
 * The key always starts with the hash of the candidate's expression:
 * - `:row<row>` when `outputStride` is a broadcast with a finite tile size;
 * - `:col<col>` when the candidate's base index repeats across a gidx tile of
 *   `tileSize`;
 * - `:g<groupIndex>` otherwise.
 * NOTE(review): presumably candidates with equal keys share one pointer;
 * confirm against the caller.
 */
function kReductionPointerShareKey(candidate, outputStride, tileSize, row, col, groupIndex) {
  const hash = candidate.exp.getHash().toString();

  const sharedPerRow = outputStride.kind === "broadcast" && isFinite(outputStride.tileSize);
  if (sharedPerRow) return `${hash}:row${row}`;

  const sharedPerColumn = repeatsAcrossGidxTile(candidate.baseIndex, tileSize);
  return sharedPerColumn ? `${hash}:col${col}` : `${hash}:g${groupIndex}`;
}
3624
+
3625
+ //#endregion
3626
+ //#region src/backend/wasm/translation.ts
3627
+ function translateExp(cg, funcs, exp, ctx, pointerMap = /* @__PURE__ */ new Map()) {
3628
+ const references = /* @__PURE__ */ new Map();
3629
+ const seen = /* @__PURE__ */ new Set();
3630
+ const countReferences = (exp$1) => {
3631
+ references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
3632
+ if (!seen.has(exp$1)) {
3633
+ seen.add(exp$1);
3634
+ for (const src of exp$1.src) countReferences(src);
3600
3635
  }
3601
- for (const outputType of func.outputTypes) this._push(outputType);
3602
- this._emit(16);
3603
- this._emit(encodeUnsigned(fn));
3604
- }
3605
- /** Throw away an operand on the stack. */
3606
- drop() {
3607
- this._pop();
3608
- this._emit(26);
3609
- }
3610
- /** Select one of the first two operands (T, F) based on the third operand (i32)'s value. */
3611
- select() {
3612
- assert(this._pop().typeId === this.i32.typeId, "select: expected i32 condition");
3613
- const [b, a] = [this._pop(), this._pop()];
3614
- assert(a.typeId === b.typeId, "select: expected same type for both operands");
3615
- this._push(a);
3616
- this._emit(27);
3617
- }
3618
- /** Import a JavaScript function; returns its index. */
3619
- importFunction(module, name, inputTypes, outputTypes) {
3620
- if (this.#functions.length > 0) throw new Error("function imports must precede defining functions");
3621
- const idx = this.#importedFunctions.length;
3622
- this.#importedFunctions.push({
3623
- module,
3624
- name,
3625
- inputTypes,
3626
- outputTypes
3627
- });
3628
- return idx;
3629
- }
3630
- /** Export a function. */
3631
- export(fn, name) {
3632
- this.#exportedFunctions.set(fn, name);
3633
- }
3634
- /** Declare a new function; returns its index. */
3635
- function(inputTypes, outputTypes, body) {
3636
- const idx = this.#importedFunctions.length + this.#functions.length;
3637
- this.#functions.push(new Function_(inputTypes, outputTypes, body));
3638
- return idx;
3639
- }
3640
- _declareLocal(type) {
3641
- assert(this.#curFunction !== null, "No current function");
3642
- const idx = this.#curFunction.locals.length + this.#curFunction.inputTypes.length;
3643
- this.#curFunction.locals.push(type);
3644
- return idx;
3645
- }
3646
- _inputTypes() {
3647
- assert(this.#curFunction !== null, "No current function");
3648
- return this.#curFunction.inputTypes;
3649
- }
3650
- _locals() {
3651
- assert(this.#curFunction !== null, "No current function");
3652
- return this.#curFunction.locals;
3636
+ };
3637
+ const expContext = /* @__PURE__ */ new Map();
3638
+ const gen = (exp$1) => {
3639
+ if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
3640
+ const { op, src, dtype, arg } = exp$1;
3641
+ if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
3642
+ gen(src[0]), gen(src[1]);
3643
+ if (op === AluOp.Add) if (dtype === DType.Bool) cg.i32.or();
3644
+ else dty(cg, op, dtype).add();
3645
+ else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
3646
+ else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
3647
+ else dty(cg, op, dtype).mul();
3648
+ else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
3649
+ dtyF(cg, op, dtype).div();
3650
+ dtyF(cg, op, dtype).trunc();
3651
+ } else if (dtype === DType.Uint32) cg.i32.div_u();
3652
+ else if (dtype === DType.Int32) cg.i32.div_s();
3653
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3654
+ else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
3655
+ const dt = dtyF(cg, op, dtype);
3656
+ const a = cg.local.declare(dt);
3657
+ const b = cg.local.declare(dt);
3658
+ cg.local.set(b);
3659
+ cg.local.tee(a);
3660
+ cg.local.get(a);
3661
+ cg.local.get(b);
3662
+ dt.div();
3663
+ dt.trunc();
3664
+ cg.local.get(b);
3665
+ dt.mul();
3666
+ dt.sub();
3667
+ } else if (dtype === DType.Uint32) cg.i32.rem_u();
3668
+ else if (dtype === DType.Int32) cg.i32.rem_s();
3669
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3670
+ else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
3671
+ else dtyF(cg, op, dtype).max();
3672
+ else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
3673
+ const a = cg.local.declare(cg.i32);
3674
+ const b = cg.local.declare(cg.i32);
3675
+ cg.local.set(b);
3676
+ cg.local.tee(a);
3677
+ cg.local.get(b);
3678
+ cg.local.get(a);
3679
+ cg.local.get(b);
3680
+ if (dtype === DType.Int32) if (op === AluOp.Min) cg.i32.lt_s();
3681
+ else cg.i32.gt_s();
3682
+ else if (op === AluOp.Min) cg.i32.lt_u();
3683
+ else cg.i32.gt_u();
3684
+ cg.select();
3685
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3686
+ else if (op === AluOp.BitCombine) if (arg === "and") cg.i32.and();
3687
+ else if (arg === "or") cg.i32.or();
3688
+ else cg.i32.xor();
3689
+ else if (op === AluOp.BitShift) if (arg === "shl") cg.i32.shl();
3690
+ else cg.i32.shr_u();
3691
+ else if (op === AluOp.Cmplt) {
3692
+ const srcDtype = src[0].dtype;
3693
+ if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
3694
+ else if (srcDtype === DType.Int32) cg.i32.lt_s();
3695
+ else if (srcDtype === DType.Uint32) cg.i32.lt_u();
3696
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3697
+ } else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
3698
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3699
+ } else if (AluGroup.Unary.has(op)) {
3700
+ const callFuncF32 = (func) => {
3701
+ if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
3702
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3703
+ cg.call(func);
3704
+ if (dtype === DType.Float64) cg.f64.promote_f32();
3705
+ };
3706
+ if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
3707
+ else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
3708
+ else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
3709
+ else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
3710
+ else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
3711
+ else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
3712
+ else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
3713
+ else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
3714
+ else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
3715
+ else if (op === AluOp.Reciprocal) {
3716
+ const dt = dtyF(cg, op, dtype);
3717
+ dt.const(1), gen(src[0]), dt.div();
3718
+ } else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
3719
+ else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
3720
+ else if (op === AluOp.Cast) {
3721
+ gen(src[0]);
3722
+ const dtype0 = src[0].dtype;
3723
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3724
+ if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
3725
+ else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
3726
+ else if (i32repr);
3727
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3728
+ else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
3729
+ else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
3730
+ else if (i32repr);
3731
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3732
+ else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
3733
+ else if (dtype0 === DType.Float64) cg.f32.demote_f64();
3734
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
3735
+ else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
3736
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3737
+ else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
3738
+ else if (dtype0 === DType.Float64);
3739
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
3740
+ else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
3741
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3742
+ else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
3743
+ else if (i32repr) cg.i32.const(0), cg.i32.ne();
3744
+ else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
3745
+ else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
3746
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3747
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3748
+ } else if (op === AluOp.Bitcast) {
3749
+ gen(src[0]);
3750
+ const dtype0 = src[0].dtype;
3751
+ if (dtype !== dtype0) {
3752
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
3753
+ if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
3754
+ else if (i32repr);
3755
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3756
+ else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
3757
+ else if (dtype0 === DType.Float32);
3758
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3759
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3760
+ }
3761
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3762
+ } else if (op === AluOp.Where) {
3763
+ gen(src[1]);
3764
+ gen(src[2]);
3765
+ gen(src[0]);
3766
+ cg.select();
3767
+ } else if (op === AluOp.Threefry2x32) {
3768
+ for (let i = 0; i < 4; i++) gen(src[i]);
3769
+ cg.call(funcs.threefry2x32);
3770
+ if (arg === "xor") cg.i32.xor();
3771
+ else if (arg === 0) cg.drop();
3772
+ else if (arg === 1) {
3773
+ const local = cg.local.declare(cg.i32);
3774
+ cg.local.set(local);
3775
+ cg.drop();
3776
+ cg.local.get(local);
3777
+ } else throw new UnsupportedOpError(op, dtype, "wasm", arg);
3778
+ } else if (op === AluOp.Const) return dty(cg, op, dtype).const(arg);
3779
+ else if (op === AluOp.Special) return cg.local.get(ctx[arg[0]]);
3780
+ else if (op === AluOp.Variable) return cg.local.get(ctx[arg]);
3781
+ else if (op === AluOp.GlobalIndex) {
3782
+ const [gid, len] = arg;
3783
+ const pointer = pointerMap.get(exp$1);
3784
+ if (pointer) cg.local.get(pointer.ptr);
3785
+ else {
3786
+ gen(src[0]);
3787
+ const local = cg.local.declare(cg.i32);
3788
+ cg.local.tee(local);
3789
+ cg.i32.const(0);
3790
+ cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
3791
+ cg.select();
3792
+ cg.i32.const(byteWidth(dtype));
3793
+ cg.i32.mul();
3794
+ cg.local.get(gid);
3795
+ cg.i32.add();
3796
+ }
3797
+ dty(cg, op, dtype).load(Math.log2(byteWidth(dtype)));
3798
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3799
+ if ((references.get(exp$1) ?? 0) > 1) {
3800
+ const local = cg.local.declare(dty(cg, op, dtype));
3801
+ cg.local.tee(local);
3802
+ expContext.set(exp$1, local);
3803
+ }
3804
+ };
3805
+ countReferences(exp);
3806
+ gen(exp);
3807
+ }
3808
+ /**
3809
+ * SIMD version of translateExp. Emits one v128 value for `exp`, interpreting
3810
+ * the current `gidx` local as the first lane and the following lanes as
3811
+ * `gidx + 1`, `gidx + 2`, etc.
3812
+ *
3813
+ * GlobalIndex loads are the only places where per-lane address behavior
3814
+ * matters:
3815
+ * - `strideMap` classifies each GlobalIndex as contiguous, broadcast, or
3816
+ * gather. Contiguous loads become one v128.load, broadcast loads become a
3817
+ * scalar load plus splat, and gather loads fall back to four scalar loads.
3818
+ * - `pointerMap` optionally supplies precomputed reduction pointers for
3819
+ * GlobalIndex nodes whose address can be advanced by the surrounding
3820
+ * reduction loop. This avoids re-emitting scalar index math in the hot path.
3821
+ * - `pointerValueCache` lets pointer plans with the same `valueKey` share a
3822
+ * single loaded vector within one emitted reduction step.
3823
+ */
3824
+ function translateExpSimd(cg, funcs, exp, ctx, strideMap, pointerMap = /* @__PURE__ */ new Map(), pointerValueCache = /* @__PURE__ */ new Map()) {
3825
+ const references = /* @__PURE__ */ new Map();
3826
+ const seen = /* @__PURE__ */ new Set();
3827
+ const countReferences = (exp$1) => {
3828
+ references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
3829
+ if (!seen.has(exp$1)) {
3830
+ seen.add(exp$1);
3831
+ for (const src of exp$1.src) countReferences(src);
3832
+ }
3833
+ };
3834
+ const expContext = /* @__PURE__ */ new Map();
3835
+ const gen = (exp$1) => {
3836
+ if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
3837
+ const { op, src, arg, dtype } = exp$1;
3838
+ const isInt = dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool;
3839
+ const isSigned = dtype === DType.Int32;
3840
+ if (op === AluOp.Add) {
3841
+ gen(src[0]), gen(src[1]);
3842
+ if (dtype === DType.Bool) cg.v128.or();
3843
+ else if (isInt) cg.i32x4.add();
3844
+ else cg.f32x4.add();
3845
+ } else if (op === AluOp.Sub) {
3846
+ gen(src[0]), gen(src[1]);
3847
+ if (isInt) cg.i32x4.sub();
3848
+ else cg.f32x4.sub();
3849
+ } else if (op === AluOp.Mul) {
3850
+ gen(src[0]), gen(src[1]);
3851
+ if (dtype === DType.Bool) cg.v128.and();
3852
+ else if (isInt) cg.i32x4.mul();
3853
+ else cg.f32x4.mul();
3854
+ } else if (op === AluOp.Min) {
3855
+ gen(src[0]), gen(src[1]);
3856
+ if (isInt) if (isSigned) cg.i32x4.min_s();
3857
+ else cg.i32x4.min_u();
3858
+ else cg.f32x4.min();
3859
+ } else if (op === AluOp.Max) {
3860
+ gen(src[0]), gen(src[1]);
3861
+ if (isInt) if (isSigned) cg.i32x4.max_s();
3862
+ else cg.i32x4.max_u();
3863
+ else cg.f32x4.max();
3864
+ } else if (op === AluOp.Sqrt) {
3865
+ gen(src[0]);
3866
+ cg.f32x4.sqrt();
3867
+ } else if (op === AluOp.Floor) {
3868
+ gen(src[0]);
3869
+ cg.f32x4.floor();
3870
+ } else if (op === AluOp.Ceil) {
3871
+ gen(src[0]);
3872
+ cg.f32x4.ceil();
3873
+ } else if (op === AluOp.Const) if (isInt) {
3874
+ cg.i32.const(arg);
3875
+ cg.i32x4.splat();
3876
+ } else {
3877
+ cg.f32.const(arg);
3878
+ cg.f32x4.splat();
3879
+ }
3880
+ else if (op === AluOp.Cast) {
3881
+ gen(src[0]);
3882
+ const dtype0 = src[0].dtype;
3883
+ const src0IsInt = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3884
+ if (isInt && !src0IsInt) if (isSigned) cg.i32x4.trunc_sat_f32x4_s();
3885
+ else cg.i32x4.trunc_sat_f32x4_u();
3886
+ else if (!isInt && src0IsInt) if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32x4.convert_i32x4_s();
3887
+ else cg.f32x4.convert_i32x4_u();
3888
+ } else if (op === AluOp.Cmplt) {
3889
+ gen(src[0]), gen(src[1]);
3890
+ const srcDtype = src[0].dtype;
3891
+ if (srcDtype === DType.Float32) cg.f32x4.lt();
3892
+ else if (srcDtype === DType.Int32) cg.i32x4.lt_s();
3893
+ else if (srcDtype === DType.Uint32) cg.i32x4.lt_u();
3894
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3895
+ cg.i32.const(1);
3896
+ cg.i32x4.splat();
3897
+ cg.v128.and();
3898
+ } else if (op === AluOp.Cmpne) {
3899
+ gen(src[0]), gen(src[1]);
3900
+ const srcDtype = src[0].dtype;
3901
+ if (srcDtype === DType.Float32) cg.f32x4.ne();
3902
+ else cg.i32x4.ne();
3903
+ cg.i32.const(1);
3904
+ cg.i32x4.splat();
3905
+ cg.v128.and();
3906
+ } else if (op === AluOp.Where) {
3907
+ gen(src[1]);
3908
+ gen(src[2]);
3909
+ gen(src[0]);
3910
+ cg.i32.const(0);
3911
+ cg.i32x4.splat();
3912
+ cg.i32x4.ne();
3913
+ cg.v128.bitselect();
3914
+ } else if (op === AluOp.Variable || op === AluOp.Special) throw new Error(`translateExpSimd: unexpected ${op}(${arg})`);
3915
+ else if (op === AluOp.GlobalIndex) {
3916
+ const [gid, len] = arg;
3917
+ const indexSubtree = src[0];
3918
+ const pointer = pointerMap.get(exp$1);
3919
+ const stride = pointer?.stride ?? strideMap.get(exp$1) ?? { kind: "gather" };
3920
+ if (pointer) {
3921
+ const cached = pointer.valueKey ? pointerValueCache.get(pointer.valueKey) : void 0;
3922
+ if (cached !== void 0) cg.local.get(cached);
3923
+ else {
3924
+ cg.local.get(pointer.ptr);
3925
+ if (stride.kind === "contiguous") if (isInt) cg.i32x4.load(4);
3926
+ else cg.f32x4.load(4);
3927
+ else if (stride.kind === "broadcast") if (isInt) {
3928
+ cg.i32.load(2);
3929
+ cg.i32x4.splat();
3930
+ } else {
3931
+ cg.f32.load(2);
3932
+ cg.f32x4.splat();
3933
+ }
3934
+ else throw new Error("reduction pointer plan cannot use gather loads");
3935
+ if (pointer.valueKey) {
3936
+ const local = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
3937
+ cg.local.tee(local);
3938
+ pointerValueCache.set(pointer.valueKey, local);
3939
+ }
3940
+ }
3941
+ } else if (stride.kind === "contiguous") {
3942
+ translateExp(cg, funcs, indexSubtree, ctx);
3943
+ {
3944
+ const maxIdx = Math.max(len - simdLanes, 0);
3945
+ const wideIdx = cg.local.declare(cg.i32);
3946
+ cg.local.set(wideIdx);
3947
+ cg.local.get(wideIdx);
3948
+ cg.i32.const(maxIdx);
3949
+ cg.local.get(wideIdx);
3950
+ cg.i32.const(maxIdx);
3951
+ cg.i32.lt_u();
3952
+ cg.select();
3953
+ }
3954
+ cg.i32.const(byteWidth(dtype));
3955
+ cg.i32.mul();
3956
+ cg.local.get(gid);
3957
+ cg.i32.add();
3958
+ if (isInt) cg.i32x4.load(4);
3959
+ else cg.f32x4.load(4);
3960
+ } else if (stride.kind === "broadcast") {
3961
+ translateExp(cg, funcs, indexSubtree, ctx);
3962
+ const local = cg.local.declare(cg.i32);
3963
+ cg.local.tee(local);
3964
+ cg.i32.const(0);
3965
+ cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
3966
+ cg.select();
3967
+ cg.i32.const(byteWidth(dtype));
3968
+ cg.i32.mul();
3969
+ cg.local.get(gid);
3970
+ cg.i32.add();
3971
+ if (isInt) {
3972
+ cg.i32.load(2);
3973
+ cg.i32x4.splat();
3974
+ } else {
3975
+ cg.f32.load(2);
3976
+ cg.f32x4.splat();
3977
+ }
3978
+ } else {
3979
+ const steppingLocal = ctx["gidx"];
3980
+ const origValue = cg.local.declare(cg.i32);
3981
+ cg.local.get(steppingLocal);
3982
+ cg.local.set(origValue);
3983
+ if (isInt) {
3984
+ cg.i32.const(0);
3985
+ cg.i32x4.splat();
3986
+ } else {
3987
+ cg.f32.const(0);
3988
+ cg.f32x4.splat();
3989
+ }
3990
+ const vec = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
3991
+ cg.local.set(vec);
3992
+ const idx = cg.local.declare(cg.i32);
3993
+ const scalarVal = cg.local.declare(isInt ? cg.i32 : cg.f32);
3994
+ for (let lane = 0; lane < simdLanes; lane++) {
3995
+ cg.local.get(origValue);
3996
+ if (lane > 0) {
3997
+ cg.i32.const(lane);
3998
+ cg.i32.add();
3999
+ }
4000
+ cg.local.set(steppingLocal);
4001
+ translateExp(cg, funcs, indexSubtree, ctx);
4002
+ cg.local.tee(idx);
4003
+ cg.i32.const(0);
4004
+ cg.local.get(idx), cg.i32.const(len), cg.i32.lt_u();
4005
+ cg.select();
4006
+ cg.i32.const(byteWidth(dtype));
4007
+ cg.i32.mul();
4008
+ cg.local.get(gid);
4009
+ cg.i32.add();
4010
+ if (isInt) cg.i32.load(2);
4011
+ else cg.f32.load(2);
4012
+ cg.local.set(scalarVal);
4013
+ cg.local.get(vec);
4014
+ cg.local.get(scalarVal);
4015
+ if (isInt) cg.i32x4.replace_lane(lane);
4016
+ else cg.f32x4.replace_lane(lane);
4017
+ cg.local.set(vec);
4018
+ }
4019
+ cg.local.get(origValue);
4020
+ cg.local.set(steppingLocal);
4021
+ cg.local.get(vec);
4022
+ }
4023
+ } else throw new Error(`translateExpSimd: unsupported op ${op}`);
4024
+ if ((references.get(exp$1) ?? 0) > 1) {
4025
+ const local = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
4026
+ cg.local.tee(local);
4027
+ expContext.set(exp$1, local);
4028
+ }
4029
+ };
4030
+ countReferences(exp);
4031
+ gen(exp);
4032
+ }
4033
+ function dty(cg, op, dtype) {
4034
+ switch (dtype) {
4035
+ case DType.Float32: return cg.f32;
4036
+ case DType.Float64: return cg.f64;
4037
+ case DType.Int32:
4038
+ case DType.Uint32:
4039
+ case DType.Bool: return cg.i32;
4040
+ default: throw new UnsupportedOpError(op, dtype, "wasm");
3653
4041
  }
3654
- _push(type) {
3655
- if (!type) throw new Error(`pushing type ${type}`);
3656
- this.#typeStack.push(type);
4042
+ }
4043
+ function dtyF(cg, op, dtype) {
4044
+ switch (dtype) {
4045
+ case DType.Float32: return cg.f32;
4046
+ case DType.Float64: return cg.f64;
4047
+ default: throw new UnsupportedOpError(op, dtype, "wasm");
3657
4048
  }
3658
- _pop() {
3659
- assert(this.#typeStack.length > 0, "popping empty stack");
3660
- return this.#typeStack.pop();
4049
+ }
4050
+ /** Subset of operations supported in SIMD compilation mode. */
4051
+ const simdSupportedOps = /* @__PURE__ */ new Map();
4052
+ simdSupportedOps.set(DType.Float32, new Set([
4053
+ AluOp.Add,
4054
+ AluOp.Sub,
4055
+ AluOp.Mul,
4056
+ AluOp.Floor,
4057
+ AluOp.Ceil,
4058
+ AluOp.Min,
4059
+ AluOp.Max,
4060
+ AluOp.Sqrt,
4061
+ AluOp.Cast,
4062
+ AluOp.Where,
4063
+ AluOp.Const,
4064
+ AluOp.GlobalIndex
4065
+ ]));
4066
+ simdSupportedOps.set(DType.Int32, new Set([
4067
+ AluOp.Add,
4068
+ AluOp.Sub,
4069
+ AluOp.Mul,
4070
+ AluOp.Min,
4071
+ AluOp.Max,
4072
+ AluOp.Cast,
4073
+ AluOp.Where,
4074
+ AluOp.Const,
4075
+ AluOp.GlobalIndex
4076
+ ]));
4077
+ simdSupportedOps.set(DType.Uint32, simdSupportedOps.get(DType.Int32));
4078
+ simdSupportedOps.set(DType.Bool, new Set([
4079
+ AluOp.Add,
4080
+ AluOp.Mul,
4081
+ AluOp.Min,
4082
+ AluOp.Max,
4083
+ AluOp.Cmplt,
4084
+ AluOp.Cmpne,
4085
+ AluOp.Const,
4086
+ AluOp.GlobalIndex
4087
+ ]));
4088
+
4089
+ //#endregion
4090
+ //#region src/backend/wasm/wasmblr.ts
4091
+ /**
4092
+ * @file Minimalist WebAssembly assembler. This allows you to emit WebAssembly
4093
+ * bytecode directly from the browser.
4094
+ *
4095
+ * Self-contained port of https://github.com/bwasti/wasmblr to TypeScript.
4096
+ * Some operation names in this module are written in `snake_case` to match
4097
+ * their names in the Wasm specification.
4098
+ *
4099
+ * Reference: https://pengowray.github.io/wasm-ops/.
4100
+ */
4101
+ const magicModuleHeader = [
4102
+ 0,
4103
+ 97,
4104
+ 115,
4105
+ 109
4106
+ ];
4107
+ const moduleVersion = [
4108
+ 1,
4109
+ 0,
4110
+ 0,
4111
+ 0
4112
+ ];
4113
+ function assert(condition, message) {
4114
+ if (!condition) throw new Error(message || "Assertion failed");
4115
+ }
4116
+ function encodeSigned(n) {
4117
+ const out = [];
4118
+ let more = true;
4119
+ while (more) {
4120
+ let byte = n & 127;
4121
+ n >>= 7;
4122
+ if (n === 0 && (byte & 64) === 0 || n === -1 && (byte & 64) !== 0) more = false;
4123
+ else byte |= 128;
4124
+ out.push(byte);
3661
4125
  }
3662
- _emit(bytes) {
3663
- if (typeof bytes === "number") this.#curBytes.push(bytes);
3664
- else this.#curBytes.push(...bytes);
4126
+ return out;
4127
+ }
4128
+ function encodeUnsigned(n) {
4129
+ const out = [];
4130
+ do {
4131
+ let byte = n & 127;
4132
+ n = n >>> 7;
4133
+ if (n !== 0) byte |= 128;
4134
+ out.push(byte);
4135
+ } while (n !== 0);
4136
+ return out;
4137
+ }
4138
+ function encodeString(s) {
4139
+ const bytes = new TextEncoder().encode(s);
4140
+ return [bytes.length, ...bytes];
4141
+ }
4142
+ function encodeBlocktype(type) {
4143
+ assert(type.length > 0, "blocktype must have at least one type");
4144
+ if (type.length === 1) return [type[0].typeId];
4145
+ return [
4146
+ 96,
4147
+ ...encodeUnsigned(0),
4148
+ ...encodeUnsigned(type.length),
4149
+ ...type.map((t) => t.typeId)
4150
+ ];
4151
+ }
4152
+ function encodeOpcode(opcode) {
4153
+ if (typeof opcode === "number") return [opcode];
4154
+ return [opcode[0], ...encodeUnsigned(opcode[1])];
4155
+ }
4156
+ function appendLengthEncodedBlock(out, inp) {
4157
+ out.push(...encodeUnsigned(inp.length));
4158
+ for (const b of inp) out.push(b);
4159
+ }
4160
+ var Function_ = class {
4161
+ inputTypes;
4162
+ outputTypes;
4163
+ body;
4164
+ locals = [];
4165
+ constructor(inputTypes, outputTypes, body) {
4166
+ this.inputTypes = inputTypes;
4167
+ this.outputTypes = outputTypes;
4168
+ this.body = body || (() => {});
3665
4169
  }
3666
- finish() {
3667
- this.#curBytes = [];
3668
- const emittedBytes = [];
3669
- concat(emittedBytes, magicModuleHeader);
3670
- concat(emittedBytes, moduleVersion);
3671
- const typeSectionBytes = [];
3672
- const totalFunctionTypes = this.#importedFunctions.length + this.#functions.length;
3673
- concat(typeSectionBytes, encodeUnsigned(totalFunctionTypes));
3674
- for (const f of [...this.#importedFunctions, ...this.#functions]) {
3675
- typeSectionBytes.push(96);
3676
- concat(typeSectionBytes, encodeUnsigned(f.inputTypes.length));
3677
- for (const t of f.inputTypes) typeSectionBytes.push(t.typeId);
3678
- concat(typeSectionBytes, encodeUnsigned(f.outputTypes.length));
3679
- for (const t of f.outputTypes) typeSectionBytes.push(t.typeId);
3680
- }
3681
- emittedBytes.push(1);
3682
- concat(emittedBytes, encodeUnsigned(typeSectionBytes.length));
3683
- concat(emittedBytes, typeSectionBytes);
3684
- const importSectionBytes = [];
3685
- const numImports = this.#importedFunctions.length + (this.memory.isImport ? 1 : 0);
3686
- if (numImports > 0) {
3687
- concat(importSectionBytes, encodeUnsigned(numImports));
3688
- for (let i = 0; i < this.#importedFunctions.length; i++) {
3689
- const f = this.#importedFunctions[i];
3690
- concat(importSectionBytes, encodeString(f.module));
3691
- concat(importSectionBytes, encodeString(f.name));
3692
- importSectionBytes.push(0);
3693
- concat(importSectionBytes, encodeUnsigned(i));
3694
- }
3695
- if (this.memory.isImport) {
3696
- concat(importSectionBytes, encodeString(this.memory.aString));
3697
- concat(importSectionBytes, encodeString(this.memory.bString));
3698
- importSectionBytes.push(2);
3699
- if (this.memory.max) {
3700
- if (this.memory.isShared) importSectionBytes.push(3);
3701
- else importSectionBytes.push(1);
3702
- concat(importSectionBytes, encodeUnsigned(this.memory.min));
3703
- concat(importSectionBytes, encodeUnsigned(this.memory.max));
3704
- } else {
3705
- assert(!this.memory.isShared, "shared memory must have a max size");
3706
- importSectionBytes.push(0);
3707
- concat(importSectionBytes, encodeUnsigned(this.memory.min));
3708
- }
3709
- }
3710
- emittedBytes.push(2);
3711
- concat(emittedBytes, encodeUnsigned(importSectionBytes.length));
3712
- concat(emittedBytes, importSectionBytes);
3713
- }
3714
- const functionSectionBytes = [];
3715
- concat(functionSectionBytes, encodeUnsigned(this.#functions.length));
3716
- for (let i = 0; i < this.#functions.length; i++) {
3717
- const typeIndex = this.#importedFunctions.length + i;
3718
- concat(functionSectionBytes, encodeUnsigned(typeIndex));
3719
- }
3720
- emittedBytes.push(3);
3721
- concat(emittedBytes, encodeUnsigned(functionSectionBytes.length));
3722
- concat(emittedBytes, functionSectionBytes);
3723
- const memorySectionBytes = [];
3724
- if (!this.memory.isImport && (this.memory.min || this.memory.max)) {
3725
- memorySectionBytes.push(1);
3726
- if (this.memory.min && this.memory.max) {
3727
- if (this.memory.isShared) memorySectionBytes.push(3);
3728
- else memorySectionBytes.push(1);
3729
- concat(memorySectionBytes, encodeUnsigned(this.memory.min));
3730
- concat(memorySectionBytes, encodeUnsigned(this.memory.max));
3731
- } else {
3732
- assert(!this.memory.isShared, "shared memory must have a max size");
3733
- memorySectionBytes.push(0);
3734
- concat(memorySectionBytes, encodeUnsigned(this.memory.min));
3735
- }
3736
- emittedBytes.push(5);
3737
- concat(emittedBytes, encodeUnsigned(memorySectionBytes.length));
3738
- concat(emittedBytes, memorySectionBytes);
3739
- }
3740
- const exportSectionBytes = [];
3741
- const numExports = this.#exportedFunctions.size + (this.memory.isExport ? 1 : 0);
3742
- concat(exportSectionBytes, encodeUnsigned(numExports));
3743
- if (this.memory.isExport) {
3744
- concat(exportSectionBytes, encodeString(this.memory.aString));
3745
- exportSectionBytes.push(2);
3746
- exportSectionBytes.push(0);
3747
- }
3748
- for (const [key, name] of this.#exportedFunctions.entries()) {
3749
- concat(exportSectionBytes, encodeString(name));
3750
- exportSectionBytes.push(0);
3751
- concat(exportSectionBytes, encodeUnsigned(key));
3752
- }
3753
- emittedBytes.push(7);
3754
- concat(emittedBytes, encodeUnsigned(exportSectionBytes.length));
3755
- concat(emittedBytes, exportSectionBytes);
3756
- const codeSectionBytes = [];
3757
- concat(codeSectionBytes, encodeUnsigned(this.#functions.length));
3758
- for (const f of this.#functions) {
3759
- this.#typeStack = [];
3760
- this.#blockFrames = [{
3761
- idx: 0,
3762
- ty: f.outputTypes
3763
- }];
3764
- this.#curFunction = f;
3765
- this.#curBytes = [];
3766
- f.emit();
3767
- this.end();
3768
- const bodyBytes = [...this.#curBytes];
3769
- this.#curBytes = [];
3770
- concat(this.#curBytes, encodeUnsigned(f.locals.length));
3771
- for (const l of f.locals) {
3772
- this._emit(1);
3773
- this._emit(l.typeId);
3774
- }
3775
- const headerBytes = [...this.#curBytes];
3776
- const fnSize = headerBytes.length + bodyBytes.length;
3777
- concat(codeSectionBytes, encodeUnsigned(fnSize));
3778
- concat(codeSectionBytes, headerBytes);
3779
- concat(codeSectionBytes, bodyBytes);
3780
- }
3781
- this.#curFunction = null;
3782
- emittedBytes.push(10);
3783
- concat(emittedBytes, encodeUnsigned(codeSectionBytes.length));
3784
- concat(emittedBytes, codeSectionBytes);
3785
- return new Uint8Array(emittedBytes);
4170
+ emit() {
4171
+ this.locals = [];
4172
+ this.body();
3786
4173
  }
3787
4174
  };
3788
- var Local = class {
4175
+ var Memory = class {
4176
+ min = 0;
4177
+ max = 0;
4178
+ isShared = false;
4179
+ aString = "";
4180
+ bString = "";
3789
4181
  constructor(cg) {
3790
4182
  this.cg = cg;
3791
4183
  }
3792
- declare(type) {
3793
- return this.cg._declareLocal(type);
4184
+ /** Declare the size of the memory. Each page is 64 KiB. */
4185
+ pages(min, max = 0) {
4186
+ assert(this.min === 0 && this.max === 0);
4187
+ this.min = min;
4188
+ this.max = max;
4189
+ return this;
3794
4190
  }
3795
- get(idx) {
3796
- assert(Number.isInteger(idx), "getting non-integer local");
3797
- const inputTypes = this.cg._inputTypes();
3798
- if (idx < inputTypes.length) this.cg._push(inputTypes[idx]);
3799
- else this.cg._push(this.cg._locals()[idx - inputTypes.length]);
3800
- this.cg._emit(32);
3801
- this.cg._emit(encodeUnsigned(idx));
4191
+ export(a) {
4192
+ assert(!this.isImport && !this.isExport, "already set");
4193
+ this.aString = a;
4194
+ return this;
3802
4195
  }
3803
- set(idx) {
3804
- const t = this.cg._pop();
3805
- const inputTypes = this.cg._inputTypes();
3806
- const expectedType = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
3807
- assert(expectedType.typeId === t.typeId, "can't set local to this value (wrong type)");
3808
- this.cg._emit(33);
3809
- this.cg._emit(encodeUnsigned(idx));
4196
+ shared(isShared) {
4197
+ this.isShared = isShared;
4198
+ return this;
4199
+ }
4200
+ import(a, b) {
4201
+ assert(!this.isImport && !this.isExport, "already set");
4202
+ this.aString = a;
4203
+ this.bString = b;
4204
+ return this;
4205
+ }
4206
+ size() {
4207
+ this.cg._emit(63);
4208
+ this.cg._emit(0);
4209
+ }
4210
+ grow() {
4211
+ this.cg._emit(64);
4212
+ this.cg._emit(0);
3810
4213
  }
3811
- tee(idx) {
3812
- const t = this.cg._pop();
3813
- const inputTypes = this.cg._inputTypes();
3814
- const expectedType = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
3815
- assert(expectedType.typeId === t.typeId, "can't tee local to this value (wrong type)");
3816
- this.cg._emit(34);
3817
- this.cg._emit(encodeUnsigned(idx));
3818
- this.cg._push(expectedType);
4214
+ get isImport() {
4215
+ return this.aString.length > 0 && this.bString.length > 0;
4216
+ }
4217
+ get isExport() {
4218
+ return this.aString.length > 0 && this.bString.length === 0;
3819
4219
  }
3820
4220
  };
3821
- function UNARY_OP(op, opcode, inType, outType) {
3822
- return function() {
3823
- const t = this.cg._pop();
3824
- assert(t.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inType} -> ${outType})`);
3825
- this.cg._emit(encodeOpcode(opcode));
3826
- this.cg._push(this.cg[outType]);
3827
- };
3828
- }
3829
- function BINARY_OP(op, opcode, typeA, typeB, outType) {
3830
- return function() {
3831
- const b = this.cg._pop();
3832
- const a = this.cg._pop();
3833
- assert(a.typeId === this.cg[typeA].typeId && b.typeId === this.cg[typeB].typeId, `invalid type for ${op} (${typeA}, ${typeB} -> ${outType})`);
3834
- this.cg._emit(encodeOpcode(opcode));
3835
- this.cg._push(this.cg[outType]);
3836
- };
3837
- }
3838
- function LOAD_OP(op, opcode, outType) {
3839
- return function(align = 0, offset = 0) {
3840
- const idxType = this.cg._pop();
3841
- assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
3842
- this.cg._emit(encodeOpcode(opcode));
3843
- this.cg._emit(encodeUnsigned(align));
3844
- this.cg._emit(encodeUnsigned(offset));
3845
- this.cg._push(this.cg[outType]);
3846
- };
3847
- }
3848
- function STORE_OP(op, opcode, inType) {
3849
- return function(align = 0, offset = 0) {
3850
- const valType = this.cg._pop();
3851
- const idxType = this.cg._pop();
3852
- assert(valType.typeId === this.cg[inType].typeId, `invalid value type for ${op} (${inType})`);
3853
- assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
3854
- this.cg._emit(encodeOpcode(opcode));
3855
- this.cg._emit(encodeUnsigned(align));
3856
- this.cg._emit(encodeUnsigned(offset));
4221
+ /** Public API of WebAssembly assembler. */
4222
+ var CodeGenerator = class {
4223
+ local;
4224
+ i32;
4225
+ f32;
4226
+ f64;
4227
+ v128;
4228
+ i32x4;
4229
+ f32x4;
4230
+ memory;
4231
+ void = {
4232
+ typeId: 64,
4233
+ name: "void"
3857
4234
  };
3858
- }
3859
- var I32 = class {
3860
- constructor(cg) {
3861
- this.cg = cg;
4235
+ #functions = [];
4236
+ #importedFunctions = [];
4237
+ #exportedFunctions = /* @__PURE__ */ new Map();
4238
+ #curFunction = null;
4239
+ #curBytes = [];
4240
+ #typeStack = [];
4241
+ #blockFrames = [];
4242
+ constructor() {
4243
+ this.local = new Local(this);
4244
+ this.i32 = new I32(this);
4245
+ this.f32 = new F32(this);
4246
+ this.f64 = new F64(this);
4247
+ this.v128 = new V128(this);
4248
+ this.i32x4 = new I32x4(this);
4249
+ this.f32x4 = new F32x4(this);
4250
+ this.memory = new Memory(this);
3862
4251
  }
3863
- get typeId() {
3864
- return 127;
4252
+ unreachable() {
4253
+ this._emit(0);
3865
4254
  }
3866
- get name() {
3867
- return "i32";
4255
+ nop() {
4256
+ this._emit(1);
3868
4257
  }
3869
- const(i) {
3870
- this.cg._emit(65);
3871
- this.cg._emit(encodeSigned(i));
3872
- this.cg._push(this);
4258
+ block(...type) {
4259
+ this.#blockFrames.push({
4260
+ idx: this.#typeStack.length,
4261
+ ty: type
4262
+ });
4263
+ this._emit(2);
4264
+ this._emit(encodeBlocktype(type));
3873
4265
  }
3874
- clz = UNARY_OP("clz", 103, "i32", "i32");
3875
- ctz = UNARY_OP("ctz", 104, "i32", "i32");
3876
- popcnt = UNARY_OP("popcnt", 105, "i32", "i32");
3877
- lt_s = BINARY_OP("lt_s", 72, "i32", "i32", "i32");
3878
- lt_u = BINARY_OP("lt_u", 73, "i32", "i32", "i32");
3879
- gt_s = BINARY_OP("gt_s", 74, "i32", "i32", "i32");
3880
- gt_u = BINARY_OP("gt_u", 75, "i32", "i32", "i32");
3881
- le_s = BINARY_OP("le_s", 76, "i32", "i32", "i32");
3882
- le_u = BINARY_OP("le_u", 77, "i32", "i32", "i32");
3883
- ge_s = BINARY_OP("ge_s", 78, "i32", "i32", "i32");
3884
- ge_u = BINARY_OP("ge_u", 79, "i32", "i32", "i32");
3885
- add = BINARY_OP("add", 106, "i32", "i32", "i32");
3886
- sub = BINARY_OP("sub", 107, "i32", "i32", "i32");
3887
- mul = BINARY_OP("mul", 108, "i32", "i32", "i32");
3888
- div_s = BINARY_OP("div_s", 109, "i32", "i32", "i32");
3889
- div_u = BINARY_OP("div_u", 110, "i32", "i32", "i32");
3890
- rem_s = BINARY_OP("rem_s", 111, "i32", "i32", "i32");
3891
- rem_u = BINARY_OP("rem_u", 112, "i32", "i32", "i32");
3892
- and = BINARY_OP("and", 113, "i32", "i32", "i32");
3893
- or = BINARY_OP("or", 114, "i32", "i32", "i32");
3894
- xor = BINARY_OP("xor", 115, "i32", "i32", "i32");
3895
- shl = BINARY_OP("shl", 116, "i32", "i32", "i32");
3896
- shr_s = BINARY_OP("shr_s", 117, "i32", "i32", "i32");
3897
- shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
3898
- rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
3899
- rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
3900
- eqz = UNARY_OP("eqz", 69, "i32", "i32");
3901
- eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
3902
- ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
3903
- trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
3904
- trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
3905
- trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
3906
- trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
3907
- load = LOAD_OP("load", 40, "i32");
3908
- load8_s = LOAD_OP("load8_s", 44, "i32");
3909
- load8_u = LOAD_OP("load8_u", 45, "i32");
3910
- load16_s = LOAD_OP("load16_s", 46, "i32");
3911
- load16_u = LOAD_OP("load16_u", 47, "i32");
3912
- store = STORE_OP("store", 54, "i32");
3913
- store8 = STORE_OP("store8", 58, "i32");
3914
- store16 = STORE_OP("store16", 59, "i32");
3915
- reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
3916
- trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
3917
- trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
3918
- trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
3919
- trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
3920
- };
3921
- var F32 = class {
3922
- constructor(cg) {
3923
- this.cg = cg;
4266
+ loop(...type) {
4267
+ this.#blockFrames.push({
4268
+ idx: this.#typeStack.length,
4269
+ ty: type
4270
+ });
4271
+ this._emit(3);
4272
+ this._emit(encodeBlocktype(type));
3924
4273
  }
3925
- get typeId() {
3926
- return 125;
4274
+ if(...type) {
4275
+ assert(this._pop().typeId === this.i32.typeId, "if_: expected i32");
4276
+ this.#blockFrames.push({
4277
+ idx: this.#typeStack.length,
4278
+ ty: type
4279
+ });
4280
+ this._emit(4);
4281
+ this._emit(encodeBlocktype(type));
3927
4282
  }
3928
- get name() {
3929
- return "f32";
4283
+ else() {
4284
+ assert(this.#blockFrames.length > 0, "else: no block to else");
4285
+ const frame = this.#blockFrames[this.#blockFrames.length - 1];
4286
+ this.#typeStack = this.#typeStack.slice(0, frame.idx);
4287
+ this._emit(5);
3930
4288
  }
3931
- const(f) {
3932
- this.cg._emit(67);
3933
- const buffer = /* @__PURE__ */ new ArrayBuffer(4);
3934
- new DataView(buffer).setFloat32(0, f, true);
3935
- const bytes = new Uint8Array(buffer);
3936
- for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
3937
- this.cg._push(this);
4289
+ /** End a block (`block`, `if`/`else`, `loop`, or function). */
4290
+ end() {
4291
+ const frame = this.#blockFrames.pop();
4292
+ assert(frame !== void 0, "end: no block to end");
4293
+ this.#typeStack = this.#typeStack.slice(0, frame.idx);
4294
+ for (const ty of frame.ty) if (ty.typeId !== this.void.typeId) this._push(ty);
4295
+ this._emit(11);
3938
4296
  }
3939
- load = LOAD_OP("load", 42, "f32");
3940
- store = STORE_OP("store", 56, "f32");
3941
- eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
3942
- ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
3943
- lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
3944
- gt = BINARY_OP("gt", 94, "f32", "f32", "i32");
3945
- le = BINARY_OP("le", 95, "f32", "f32", "i32");
3946
- ge = BINARY_OP("ge", 96, "f32", "f32", "i32");
3947
- abs = UNARY_OP("abs", 139, "f32", "f32");
3948
- neg = UNARY_OP("neg", 140, "f32", "f32");
3949
- ceil = UNARY_OP("ceil", 141, "f32", "f32");
3950
- floor = UNARY_OP("floor", 142, "f32", "f32");
3951
- trunc = UNARY_OP("trunc", 143, "f32", "f32");
3952
- nearest = UNARY_OP("nearest", 144, "f32", "f32");
3953
- sqrt = UNARY_OP("sqrt", 145, "f32", "f32");
3954
- add = BINARY_OP("add", 146, "f32", "f32", "f32");
3955
- sub = BINARY_OP("sub", 147, "f32", "f32", "f32");
3956
- mul = BINARY_OP("mul", 148, "f32", "f32", "f32");
3957
- div = BINARY_OP("div", 149, "f32", "f32", "f32");
3958
- min = BINARY_OP("min", 150, "f32", "f32", "f32");
3959
- max = BINARY_OP("max", 151, "f32", "f32", "f32");
3960
- copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
3961
- convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
3962
- convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
3963
- demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
3964
- reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
3965
- };
3966
- var F64 = class {
3967
- constructor(cg) {
3968
- this.cg = cg;
4297
+ /** Branch to a block a certain depth outward on the stack. */
4298
+ br(depth) {
4299
+ this._emit(12);
4300
+ this._emit(encodeUnsigned(depth));
3969
4301
  }
3970
- get typeId() {
3971
- return 124;
4302
+ /** Conditional branch to a block a certain depth outward on the stack. */
4303
+ br_if(depth) {
4304
+ assert(this._pop().typeId === this.i32.typeId, "br_if: expected i32");
4305
+ this._emit(13);
4306
+ this._emit(encodeUnsigned(depth));
4307
+ }
4308
+ /** Jump table that indexes into a label vector (like switch). */
4309
+ br_table(...depths) {
4310
+ assert(this._pop().typeId === this.i32.typeId, "br_table: expected i32");
4311
+ assert(depths.length > 0, "br_table: expected at least one default depth");
4312
+ this._emit(14);
4313
+ this._emit(encodeUnsigned(depths.length - 1));
4314
+ for (const d of depths) this._emit(encodeUnsigned(d));
4315
+ }
4316
+ /** Return from a function, branching out of the outermost block. */
4317
+ return() {
4318
+ this._emit(15);
4319
+ }
4320
+ /** Call a function with the given ID. */
4321
+ call(fn) {
4322
+ const totalFunctions = this.#importedFunctions.length + this.#functions.length;
4323
+ assert(fn < totalFunctions, "function index does not exist");
4324
+ const func = fn < this.#importedFunctions.length ? this.#importedFunctions[fn] : this.#functions[fn - this.#importedFunctions.length];
4325
+ for (let i = func.inputTypes.length - 1; i >= 0; i--) {
4326
+ const argType = this._pop();
4327
+ assert(argType.typeId === func.inputTypes[i].typeId, `call: argument ${i} type mismatch, expected ${func.inputTypes[i].name} got ${argType.name}`);
4328
+ }
4329
+ for (const outputType of func.outputTypes) this._push(outputType);
4330
+ this._emit(16);
4331
+ this._emit(encodeUnsigned(fn));
3972
4332
  }
3973
- get name() {
3974
- return "f64";
4333
+ /** Throw away an operand on the stack. */
4334
+ drop() {
4335
+ this._pop();
4336
+ this._emit(26);
3975
4337
  }
3976
- const(f) {
3977
- this.cg._emit(68);
3978
- const buffer = /* @__PURE__ */ new ArrayBuffer(8);
3979
- new DataView(buffer).setFloat64(0, f, true);
3980
- const bytes = new Uint8Array(buffer);
3981
- for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
3982
- this.cg._push(this);
4338
+ /** Select one of the first two operands (T, F) based on the third operand (i32)'s value. */
4339
+ select() {
4340
+ assert(this._pop().typeId === this.i32.typeId, "select: expected i32 condition");
4341
+ const [b, a] = [this._pop(), this._pop()];
4342
+ assert(a.typeId === b.typeId, "select: expected same type for both operands");
4343
+ this._push(a);
4344
+ this._emit(27);
3983
4345
  }
3984
- load = LOAD_OP("load", 43, "f64");
3985
- store = STORE_OP("store", 57, "f64");
3986
- eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
3987
- ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
3988
- lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
3989
- gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
3990
- le = BINARY_OP("le", 101, "f64", "f64", "i32");
3991
- ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
3992
- abs = UNARY_OP("abs", 153, "f64", "f64");
3993
- neg = UNARY_OP("neg", 154, "f64", "f64");
3994
- ceil = UNARY_OP("ceil", 155, "f64", "f64");
3995
- floor = UNARY_OP("floor", 156, "f64", "f64");
3996
- trunc = UNARY_OP("trunc", 157, "f64", "f64");
3997
- nearest = UNARY_OP("nearest", 158, "f64", "f64");
3998
- sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
3999
- add = BINARY_OP("add", 160, "f64", "f64", "f64");
4000
- sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
4001
- mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
4002
- div = BINARY_OP("div", 163, "f64", "f64", "f64");
4003
- min = BINARY_OP("min", 164, "f64", "f64", "f64");
4004
- max = BINARY_OP("max", 165, "f64", "f64", "f64");
4005
- copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
4006
- convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
4007
- convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
4008
- promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
4009
- };
4010
- function VECTOR_OP(op, vopcode, inTypes, outType) {
4011
- return function() {
4012
- for (const inType of inTypes.toReversed()) {
4013
- const actualType = this.cg._pop();
4014
- assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes.join(", ")} -> ${outType})`);
4015
- }
4016
- this.cg._emit(encodeOpcode([253, vopcode]));
4017
- this.cg._push(this.cg[outType]);
4018
- };
4019
- }
4020
- function VECTOR_OPL(op, vopcode, inTypes, outType) {
4021
- return function(lane) {
4022
- for (const inType of inTypes.toReversed()) {
4023
- const actualType = this.cg._pop();
4024
- assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes} -> ${outType})`);
4025
- }
4026
- this.cg._emit(encodeOpcode([253, vopcode]));
4027
- this.cg._emit(lane);
4028
- this.cg._push(this.cg[outType]);
4029
- };
4030
- }
4031
- function VECTOR_LOAD_OP(op, vopcode) {
4032
- return function(align = 0, offset = 0) {
4033
- const idxType = this.cg._pop();
4034
- assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
4035
- this.cg._emit(encodeOpcode([253, vopcode]));
4036
- this.cg._emit(encodeUnsigned(align));
4037
- this.cg._emit(encodeUnsigned(offset));
4038
- this.cg._push(this.cg.v128);
4039
- };
4040
- }
4041
- var V128 = class {
4042
- constructor(cg) {
4043
- this.cg = cg;
4346
+ /** Import a JavaScript function; returns its index. */
4347
+ importFunction(module, name, inputTypes, outputTypes) {
4348
+ if (this.#functions.length > 0) throw new Error("function imports must precede defining functions");
4349
+ const idx = this.#importedFunctions.length;
4350
+ this.#importedFunctions.push({
4351
+ module,
4352
+ name,
4353
+ inputTypes,
4354
+ outputTypes
4355
+ });
4356
+ return idx;
4044
4357
  }
4045
- get typeId() {
4046
- return 123;
4358
+ /** Export a function. */
4359
+ export(fn, name) {
4360
+ this.#exportedFunctions.set(fn, name);
4047
4361
  }
4048
- get name() {
4049
- return "v128";
4362
+ /** Declare a new function; returns its index. */
4363
+ function(inputTypes, outputTypes, body) {
4364
+ const idx = this.#importedFunctions.length + this.#functions.length;
4365
+ this.#functions.push(new Function_(inputTypes, outputTypes, body));
4366
+ return idx;
4050
4367
  }
4051
- load = VECTOR_LOAD_OP("load", 0);
4052
- load32x2_s = VECTOR_LOAD_OP("load32x2_s", 5);
4053
- load32x2_u = VECTOR_LOAD_OP("load32x2_u", 6);
4054
- load32_splat = VECTOR_LOAD_OP("load32_splat", 9);
4055
- load32_zero = VECTOR_LOAD_OP("load32_zero", 92);
4056
- store(align = 0, offset = 0) {
4057
- const valType = this.cg._pop();
4058
- assert(valType.typeId === this.cg.v128.typeId, `invalid type for store`);
4059
- const idxType = this.cg._pop();
4060
- assert(idxType.typeId === this.cg.i32.typeId, `invalid type for store`);
4061
- this.cg._emit(253);
4062
- this.cg._emit(encodeUnsigned(11));
4063
- this.cg._emit(encodeUnsigned(align));
4064
- this.cg._emit(encodeUnsigned(offset));
4368
+ _declareLocal(type) {
4369
+ assert(this.#curFunction !== null, "No current function");
4370
+ const idx = this.#curFunction.locals.length + this.#curFunction.inputTypes.length;
4371
+ this.#curFunction.locals.push(type);
4372
+ return idx;
4065
4373
  }
4066
- not = VECTOR_OP("not", 77, ["v128"], "v128");
4067
- and = VECTOR_OP("and", 78, ["v128", "v128"], "v128");
4068
- andnot = VECTOR_OP("andnot", 79, ["v128", "v128"], "v128");
4069
- or = VECTOR_OP("or", 80, ["v128", "v128"], "v128");
4070
- xor = VECTOR_OP("xor", 81, ["v128", "v128"], "v128");
4071
- bitselect = VECTOR_OP("bitselect", 82, [
4072
- "v128",
4073
- "v128",
4074
- "v128"
4075
- ], "v128");
4076
- any_true = VECTOR_OP("any_true", 83, ["v128"], "i32");
4077
- };
4078
- var I32x4 = class extends V128 {
4079
- splat = VECTOR_OP("splat", 17, ["i32"], "v128");
4080
- extract_lane = VECTOR_OPL("extract_lane", 27, ["v128"], "i32");
4081
- replace_lane = VECTOR_OPL("replace_lane", 28, ["v128", "i32"], "v128");
4082
- eq = VECTOR_OP("eq", 55, ["v128", "v128"], "v128");
4083
- ne = VECTOR_OP("ne", 56, ["v128", "v128"], "v128");
4084
- lt_s = VECTOR_OP("lt_s", 57, ["v128", "v128"], "v128");
4085
- lt_u = VECTOR_OP("lt_u", 58, ["v128", "v128"], "v128");
4086
- gt_s = VECTOR_OP("gt_s", 59, ["v128", "v128"], "v128");
4087
- gt_u = VECTOR_OP("gt_u", 60, ["v128", "v128"], "v128");
4088
- le_s = VECTOR_OP("le_s", 61, ["v128", "v128"], "v128");
4089
- le_u = VECTOR_OP("le_u", 62, ["v128", "v128"], "v128");
4090
- ge_s = VECTOR_OP("ge_s", 63, ["v128", "v128"], "v128");
4091
- ge_u = VECTOR_OP("ge_u", 64, ["v128", "v128"], "v128");
4092
- abs = VECTOR_OP("abs", 160, ["v128"], "v128");
4093
- neg = VECTOR_OP("neg", 161, ["v128"], "v128");
4094
- all_true = VECTOR_OP("all_true", 163, ["v128"], "i32");
4095
- bitmask = VECTOR_OP("bitmask", 164, ["v128"], "i32");
4096
- shl = VECTOR_OP("shl", 171, ["v128", "i32"], "v128");
4097
- shr_s = VECTOR_OP("shr_s", 172, ["v128", "i32"], "v128");
4098
- shr_u = VECTOR_OP("shr_u", 173, ["v128", "i32"], "v128");
4099
- add = VECTOR_OP("add", 174, ["v128", "v128"], "v128");
4100
- sub = VECTOR_OP("sub", 177, ["v128", "v128"], "v128");
4101
- mul = VECTOR_OP("mul", 181, ["v128", "v128"], "v128");
4102
- min_s = VECTOR_OP("min_s", 182, ["v128", "v128"], "v128");
4103
- min_u = VECTOR_OP("min_u", 183, ["v128", "v128"], "v128");
4104
- max_s = VECTOR_OP("max_s", 184, ["v128", "v128"], "v128");
4105
- max_u = VECTOR_OP("max_u", 185, ["v128", "v128"], "v128");
4106
- trunc_sat_f32x4_s = VECTOR_OP("trunc_sat_f32x4_s", 248, ["v128"], "v128");
4107
- trunc_sat_f32x4_u = VECTOR_OP("trunc_sat_f32x4_u", 249, ["v128"], "v128");
4108
- };
4109
- var F32x4 = class extends V128 {
4110
- splat = VECTOR_OP("splat", 19, ["f32"], "v128");
4111
- extract_lane = VECTOR_OPL("extract_lane", 31, ["v128"], "f32");
4112
- replace_lane = VECTOR_OPL("replace_lane", 32, ["v128", "f32"], "v128");
4113
- eq = VECTOR_OP("eq", 65, ["v128", "v128"], "v128");
4114
- ne = VECTOR_OP("ne", 66, ["v128", "v128"], "v128");
4115
- lt = VECTOR_OP("lt", 67, ["v128", "v128"], "v128");
4116
- gt = VECTOR_OP("gt", 68, ["v128", "v128"], "v128");
4117
- le = VECTOR_OP("le", 69, ["v128", "v128"], "v128");
4118
- ge = VECTOR_OP("ge", 70, ["v128", "v128"], "v128");
4119
- ceil = VECTOR_OP("ceil", 103, ["v128"], "v128");
4120
- floor = VECTOR_OP("floor", 104, ["v128"], "v128");
4121
- trunc = VECTOR_OP("trunc", 105, ["v128"], "v128");
4122
- nearest = VECTOR_OP("nearest", 106, ["v128"], "v128");
4123
- abs = VECTOR_OP("abs", 224, ["v128"], "v128");
4124
- neg = VECTOR_OP("neg", 225, ["v128"], "v128");
4125
- sqrt = VECTOR_OP("sqrt", 227, ["v128"], "v128");
4126
- add = VECTOR_OP("add", 228, ["v128", "v128"], "v128");
4127
- sub = VECTOR_OP("sub", 229, ["v128", "v128"], "v128");
4128
- mul = VECTOR_OP("mul", 230, ["v128", "v128"], "v128");
4129
- div = VECTOR_OP("div", 231, ["v128", "v128"], "v128");
4130
- min = VECTOR_OP("min", 232, ["v128", "v128"], "v128");
4131
- max = VECTOR_OP("max", 233, ["v128", "v128"], "v128");
4132
- pmin = VECTOR_OP("pmin", 234, ["v128", "v128"], "v128");
4133
- pmax = VECTOR_OP("pmax", 235, ["v128", "v128"], "v128");
4134
- convert_i32x4_s = VECTOR_OP("convert_i32x4_s", 250, ["v128"], "v128");
4135
- convert_i32x4_u = VECTOR_OP("convert_i32x4_u", 251, ["v128"], "v128");
4136
- };
4137
-
4138
- //#endregion
4139
- //#region src/backend/wasm.ts
4140
- /**
4141
- * SIMD version of translateExp: emits v128 (f32x4 or i32x4) instructions instead of scalar.
4142
- * gidx always steps by 4. strideMap classifies each GlobalIndex as broadcast/contiguous/gather.
4143
- */
4144
- function translateExpSimd(cg, funcs, exp, ctx, strideMap) {
4145
- const references = /* @__PURE__ */ new Map();
4146
- const seen = /* @__PURE__ */ new Set();
4147
- const countReferences = (exp$1) => {
4148
- references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
4149
- if (!seen.has(exp$1)) {
4150
- seen.add(exp$1);
4151
- for (const src of exp$1.src) countReferences(src);
4152
- }
4153
- };
4154
- const expContext = /* @__PURE__ */ new Map();
4155
- const gen = (exp$1) => {
4156
- if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
4157
- const { op, src, arg, dtype } = exp$1;
4158
- const isInt = dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool;
4159
- const isSigned = dtype === DType.Int32;
4160
- if (op === AluOp.Add) {
4161
- gen(src[0]);
4162
- gen(src[1]);
4163
- if (isInt) cg.i32x4.add();
4164
- else cg.f32x4.add();
4165
- } else if (op === AluOp.Sub) {
4166
- gen(src[0]);
4167
- gen(src[1]);
4168
- if (isInt) cg.i32x4.sub();
4169
- else cg.f32x4.sub();
4170
- } else if (op === AluOp.Mul) {
4171
- gen(src[0]);
4172
- gen(src[1]);
4173
- if (isInt) cg.i32x4.mul();
4174
- else cg.f32x4.mul();
4175
- } else if (op === AluOp.Min) {
4176
- gen(src[0]);
4177
- gen(src[1]);
4178
- if (isInt) if (isSigned) cg.i32x4.min_s();
4179
- else cg.i32x4.min_u();
4180
- else cg.f32x4.min();
4181
- } else if (op === AluOp.Max) {
4182
- gen(src[0]);
4183
- gen(src[1]);
4184
- if (isInt) if (isSigned) cg.i32x4.max_s();
4185
- else cg.i32x4.max_u();
4186
- else cg.f32x4.max();
4187
- } else if (op === AluOp.Sqrt) {
4188
- gen(src[0]);
4189
- cg.f32x4.sqrt();
4190
- } else if (op === AluOp.Floor) {
4191
- gen(src[0]);
4192
- cg.f32x4.floor();
4193
- } else if (op === AluOp.Ceil) {
4194
- gen(src[0]);
4195
- cg.f32x4.ceil();
4196
- } else if (op === AluOp.Const) if (isInt) {
4197
- cg.i32.const(arg);
4198
- cg.i32x4.splat();
4199
- } else {
4200
- cg.f32.const(arg);
4201
- cg.f32x4.splat();
4374
+ _inputTypes() {
4375
+ assert(this.#curFunction !== null, "No current function");
4376
+ return this.#curFunction.inputTypes;
4377
+ }
4378
+ _locals() {
4379
+ assert(this.#curFunction !== null, "No current function");
4380
+ return this.#curFunction.locals;
4381
+ }
4382
+ _push(type) {
4383
+ if (!type) throw new Error(`pushing type ${type}`);
4384
+ this.#typeStack.push(type);
4385
+ }
4386
+ _pop() {
4387
+ assert(this.#typeStack.length > 0, "popping empty stack");
4388
+ return this.#typeStack.pop();
4389
+ }
4390
+ _emit(bytes) {
4391
+ if (typeof bytes === "number") this.#curBytes.push(bytes);
4392
+ else this.#curBytes.push(...bytes);
4393
+ }
4394
+ finish() {
4395
+ this.#curBytes = [];
4396
+ const emittedBytes = [];
4397
+ emittedBytes.push(...magicModuleHeader);
4398
+ emittedBytes.push(...moduleVersion);
4399
+ const typeSectionBytes = [];
4400
+ const totalFunctionTypes = this.#importedFunctions.length + this.#functions.length;
4401
+ typeSectionBytes.push(...encodeUnsigned(totalFunctionTypes));
4402
+ for (const f of [...this.#importedFunctions, ...this.#functions]) {
4403
+ typeSectionBytes.push(96);
4404
+ typeSectionBytes.push(...encodeUnsigned(f.inputTypes.length));
4405
+ for (const t of f.inputTypes) typeSectionBytes.push(t.typeId);
4406
+ typeSectionBytes.push(...encodeUnsigned(f.outputTypes.length));
4407
+ for (const t of f.outputTypes) typeSectionBytes.push(t.typeId);
4202
4408
  }
4203
- else if (op === AluOp.Cast) {
4204
- gen(src[0]);
4205
- const dtype0 = src[0].dtype;
4206
- const src0IsInt = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
4207
- if (isInt && !src0IsInt) if (isSigned) cg.i32x4.trunc_sat_f32x4_s();
4208
- else cg.i32x4.trunc_sat_f32x4_u();
4209
- else if (!isInt && src0IsInt) if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32x4.convert_i32x4_s();
4210
- else cg.f32x4.convert_i32x4_u();
4211
- } else if (op === AluOp.Cmplt) {
4212
- gen(src[0]);
4213
- gen(src[1]);
4214
- const srcDtype = src[0].dtype;
4215
- if (srcDtype === DType.Float32) cg.f32x4.lt();
4216
- else if (srcDtype === DType.Int32) cg.i32x4.lt_s();
4217
- else if (srcDtype === DType.Uint32) cg.i32x4.lt_u();
4218
- else throw new UnsupportedOpError(op, dtype, "wasm");
4219
- cg.i32.const(1);
4220
- cg.i32x4.splat();
4221
- cg.v128.and();
4222
- } else if (op === AluOp.Cmpne) {
4223
- gen(src[0]);
4224
- gen(src[1]);
4225
- const srcDtype = src[0].dtype;
4226
- if (srcDtype === DType.Float32) cg.f32x4.ne();
4227
- else cg.i32x4.ne();
4228
- cg.i32.const(1);
4229
- cg.i32x4.splat();
4230
- cg.v128.and();
4231
- } else if (op === AluOp.Where) {
4232
- gen(src[1]);
4233
- gen(src[2]);
4234
- gen(src[0]);
4235
- cg.i32.const(0);
4236
- cg.i32x4.splat();
4237
- cg.i32x4.ne();
4238
- cg.v128.bitselect();
4239
- } else if (op === AluOp.Variable || op === AluOp.Special) throw new Error(`translateExpSimd: unexpected ${op}(${arg})`);
4240
- else if (op === AluOp.GlobalIndex) {
4241
- const [gid, len] = arg;
4242
- const indexSubtree = src[0];
4243
- const stride = strideMap.get(exp$1) ?? GATHER;
4244
- if (stride.kind === "contiguous") {
4245
- translateExp(cg, funcs, indexSubtree, ctx);
4246
- {
4247
- const maxIdx = Math.max(len - SIMD_LANES, 0);
4248
- const wideIdx = cg.local.declare(cg.i32);
4249
- cg.local.set(wideIdx);
4250
- cg.local.get(wideIdx);
4251
- cg.i32.const(maxIdx);
4252
- cg.local.get(wideIdx);
4253
- cg.i32.const(maxIdx);
4254
- cg.i32.lt_u();
4255
- cg.select();
4256
- }
4257
- cg.i32.const(byteWidth(dtype));
4258
- cg.i32.mul();
4259
- cg.local.get(gid);
4260
- cg.i32.add();
4261
- if (isInt) cg.i32x4.load(4);
4262
- else cg.f32x4.load(4);
4263
- } else if (stride.kind === "broadcast") {
4264
- translateExp(cg, funcs, indexSubtree, ctx);
4265
- const local = cg.local.declare(cg.i32);
4266
- cg.local.tee(local);
4267
- cg.i32.const(0);
4268
- cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
4269
- cg.select();
4270
- cg.i32.const(byteWidth(dtype));
4271
- cg.i32.mul();
4272
- cg.local.get(gid);
4273
- cg.i32.add();
4274
- if (isInt) {
4275
- cg.i32.load(2);
4276
- cg.i32x4.splat();
4409
+ emittedBytes.push(1);
4410
+ appendLengthEncodedBlock(emittedBytes, typeSectionBytes);
4411
+ const importSectionBytes = [];
4412
+ const numImports = this.#importedFunctions.length + (this.memory.isImport ? 1 : 0);
4413
+ if (numImports > 0) {
4414
+ importSectionBytes.push(...encodeUnsigned(numImports));
4415
+ for (let i = 0; i < this.#importedFunctions.length; i++) {
4416
+ const f = this.#importedFunctions[i];
4417
+ importSectionBytes.push(...encodeString(f.module));
4418
+ importSectionBytes.push(...encodeString(f.name));
4419
+ importSectionBytes.push(0);
4420
+ importSectionBytes.push(...encodeUnsigned(i));
4421
+ }
4422
+ if (this.memory.isImport) {
4423
+ importSectionBytes.push(...encodeString(this.memory.aString));
4424
+ importSectionBytes.push(...encodeString(this.memory.bString));
4425
+ importSectionBytes.push(2);
4426
+ if (this.memory.max) {
4427
+ if (this.memory.isShared) importSectionBytes.push(3);
4428
+ else importSectionBytes.push(1);
4429
+ importSectionBytes.push(...encodeUnsigned(this.memory.min));
4430
+ importSectionBytes.push(...encodeUnsigned(this.memory.max));
4277
4431
  } else {
4278
- cg.f32.load(2);
4279
- cg.f32x4.splat();
4432
+ assert(!this.memory.isShared, "shared memory must have a max size");
4433
+ importSectionBytes.push(0);
4434
+ importSectionBytes.push(...encodeUnsigned(this.memory.min));
4280
4435
  }
4436
+ }
4437
+ emittedBytes.push(2);
4438
+ appendLengthEncodedBlock(emittedBytes, importSectionBytes);
4439
+ }
4440
+ const functionSectionBytes = [];
4441
+ functionSectionBytes.push(...encodeUnsigned(this.#functions.length));
4442
+ for (let i = 0; i < this.#functions.length; i++) {
4443
+ const typeIndex = this.#importedFunctions.length + i;
4444
+ functionSectionBytes.push(...encodeUnsigned(typeIndex));
4445
+ }
4446
+ emittedBytes.push(3);
4447
+ appendLengthEncodedBlock(emittedBytes, functionSectionBytes);
4448
+ const memorySectionBytes = [];
4449
+ if (!this.memory.isImport && (this.memory.min || this.memory.max)) {
4450
+ memorySectionBytes.push(1);
4451
+ if (this.memory.min && this.memory.max) {
4452
+ if (this.memory.isShared) memorySectionBytes.push(3);
4453
+ else memorySectionBytes.push(1);
4454
+ memorySectionBytes.push(...encodeUnsigned(this.memory.min));
4455
+ memorySectionBytes.push(...encodeUnsigned(this.memory.max));
4281
4456
  } else {
4282
- const steppingLocal = ctx["gidx"];
4283
- const origValue = cg.local.declare(cg.i32);
4284
- cg.local.get(steppingLocal);
4285
- cg.local.set(origValue);
4286
- if (isInt) {
4287
- cg.i32.const(0);
4288
- cg.i32x4.splat();
4289
- } else {
4290
- cg.f32.const(0);
4291
- cg.f32x4.splat();
4292
- }
4293
- const vec = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
4294
- cg.local.set(vec);
4295
- const idx = cg.local.declare(cg.i32);
4296
- const scalarVal = cg.local.declare(isInt ? cg.i32 : cg.f32);
4297
- for (let lane = 0; lane < SIMD_LANES; lane++) {
4298
- cg.local.get(origValue);
4299
- if (lane > 0) {
4300
- cg.i32.const(lane);
4301
- cg.i32.add();
4302
- }
4303
- cg.local.set(steppingLocal);
4304
- translateExp(cg, funcs, indexSubtree, ctx);
4305
- cg.local.tee(idx);
4306
- cg.i32.const(0);
4307
- cg.local.get(idx), cg.i32.const(len), cg.i32.lt_u();
4308
- cg.select();
4309
- cg.i32.const(byteWidth(dtype));
4310
- cg.i32.mul();
4311
- cg.local.get(gid);
4312
- cg.i32.add();
4313
- if (isInt) cg.i32.load(2);
4314
- else cg.f32.load(2);
4315
- cg.local.set(scalarVal);
4316
- cg.local.get(vec);
4317
- cg.local.get(scalarVal);
4318
- if (isInt) cg.i32x4.replace_lane(lane);
4319
- else cg.f32x4.replace_lane(lane);
4320
- cg.local.set(vec);
4321
- }
4322
- cg.local.get(origValue);
4323
- cg.local.set(steppingLocal);
4324
- cg.local.get(vec);
4457
+ assert(!this.memory.isShared, "shared memory must have a max size");
4458
+ memorySectionBytes.push(0);
4459
+ memorySectionBytes.push(...encodeUnsigned(this.memory.min));
4325
4460
  }
4326
- } else throw new Error(`translateExpSimd: unsupported op ${op}`);
4327
- if ((references.get(exp$1) ?? 0) > 1) {
4328
- const local = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
4329
- cg.local.tee(local);
4330
- expContext.set(exp$1, local);
4461
+ emittedBytes.push(5);
4462
+ appendLengthEncodedBlock(emittedBytes, memorySectionBytes);
4331
4463
  }
4332
- };
4333
- countReferences(exp);
4334
- gen(exp);
4335
- }
4336
- /** Number of SIMD lanes (f32x4 / i32x4 = 4 lanes). */
4337
- const SIMD_LANES = 4;
4338
- function referencesGidx(exp) {
4339
- if (exp.op === AluOp.Special && exp.arg[0] === "gidx") return true;
4340
- return exp.src.some(referencesGidx);
4341
- }
4342
- /** When tileSize > N but doesn't divide evenly, the last group before the
4343
- * inner reset is shorter than N — a SIMD group could straddle it. */
4344
- function hasFragmentRisk(tileSize, N) {
4345
- return isFinite(tileSize) && tileSize > N && tileSize % N !== 0;
4346
- }
4347
- const GATHER = { kind: "gather" };
4348
- /**
4349
- * Classify how a GlobalIndex's index expression behaves as gidx increments.
4350
- */
4351
- function analyzeStride(exp) {
4352
- if (!referencesGidx(exp)) return {
4353
- kind: "broadcast",
4354
- tileSize: Infinity
4355
- };
4356
- if (exp.op === AluOp.Special && exp.arg[0] === "gidx") return {
4357
- kind: "contiguous",
4358
- tileSize: Infinity
4359
- };
4360
- if (exp.op === AluOp.Idiv && exp.src[1].op === AluOp.Const) {
4361
- const N = exp.src[1].arg;
4362
- const inner = analyzeStride(exp.src[0]);
4363
- if (inner.kind === "broadcast") return inner;
4364
- if (inner.kind !== "contiguous") return GATHER;
4365
- if (hasFragmentRisk(inner.tileSize, N)) return GATHER;
4366
- return {
4367
- kind: "broadcast",
4368
- tileSize: Math.min(inner.tileSize, N)
4369
- };
4464
+ const exportSectionBytes = [];
4465
+ const numExports = this.#exportedFunctions.size + (this.memory.isExport ? 1 : 0);
4466
+ exportSectionBytes.push(...encodeUnsigned(numExports));
4467
+ if (this.memory.isExport) {
4468
+ exportSectionBytes.push(...encodeString(this.memory.aString));
4469
+ exportSectionBytes.push(2);
4470
+ exportSectionBytes.push(0);
4471
+ }
4472
+ for (const [key, name] of this.#exportedFunctions.entries()) {
4473
+ exportSectionBytes.push(...encodeString(name));
4474
+ exportSectionBytes.push(0);
4475
+ exportSectionBytes.push(...encodeUnsigned(key));
4476
+ }
4477
+ emittedBytes.push(7);
4478
+ appendLengthEncodedBlock(emittedBytes, exportSectionBytes);
4479
+ const codeSectionBytes = [];
4480
+ codeSectionBytes.push(...encodeUnsigned(this.#functions.length));
4481
+ for (const f of this.#functions) {
4482
+ this.#typeStack = [];
4483
+ this.#blockFrames = [{
4484
+ idx: 0,
4485
+ ty: f.outputTypes
4486
+ }];
4487
+ this.#curFunction = f;
4488
+ this.#curBytes = [];
4489
+ f.emit();
4490
+ this.end();
4491
+ const bodyBytes = this.#curBytes;
4492
+ this.#curBytes = [];
4493
+ this.#curBytes.push(...encodeUnsigned(f.locals.length));
4494
+ for (const l of f.locals) {
4495
+ this._emit(1);
4496
+ this._emit(l.typeId);
4497
+ }
4498
+ const fnBytes = this.#curBytes.concat(bodyBytes);
4499
+ appendLengthEncodedBlock(codeSectionBytes, fnBytes);
4500
+ }
4501
+ this.#curFunction = null;
4502
+ emittedBytes.push(10);
4503
+ appendLengthEncodedBlock(emittedBytes, codeSectionBytes);
4504
+ return new Uint8Array(emittedBytes);
4370
4505
  }
4371
- if (exp.op === AluOp.Mod && exp.src[1].op === AluOp.Const) {
4372
- const N = exp.src[1].arg;
4373
- const inner = analyzeStride(exp.src[0]);
4374
- if (inner.kind === "broadcast") return inner;
4375
- if (inner.kind !== "contiguous") return GATHER;
4376
- if (hasFragmentRisk(inner.tileSize, N)) return GATHER;
4377
- return {
4378
- kind: "contiguous",
4379
- tileSize: Math.min(inner.tileSize, N)
4380
- };
4506
+ };
4507
+ var Local = class {
4508
+ constructor(cg) {
4509
+ this.cg = cg;
4381
4510
  }
4382
- if (exp.op === AluOp.Mul) {
4383
- for (let i = 0; i < 2; i++) if (exp.src[i].op === AluOp.Const) {
4384
- const inner = analyzeStride(exp.src[1 - i]);
4385
- if (inner.kind === "broadcast") return inner;
4386
- return GATHER;
4387
- }
4511
+ declare(type) {
4512
+ return this.cg._declareLocal(type);
4513
+ }
4514
+ get(idx) {
4515
+ assert(Number.isInteger(idx), "getting non-integer local");
4516
+ const inputTypes = this.cg._inputTypes();
4517
+ if (idx < inputTypes.length) this.cg._push(inputTypes[idx]);
4518
+ else this.cg._push(this.cg._locals()[idx - inputTypes.length]);
4519
+ this.cg._emit(32);
4520
+ this.cg._emit(encodeUnsigned(idx));
4521
+ }
4522
+ set(idx) {
4523
+ const t = this.cg._pop();
4524
+ const inputTypes = this.cg._inputTypes();
4525
+ const expectedType = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
4526
+ assert(expectedType.typeId === t.typeId, "can't set local to this value (wrong type)");
4527
+ this.cg._emit(33);
4528
+ this.cg._emit(encodeUnsigned(idx));
4388
4529
  }
4389
- if (exp.op === AluOp.Add) {
4390
- const lhsHasGidx = referencesGidx(exp.src[0]);
4391
- const rhsHasGidx = referencesGidx(exp.src[1]);
4392
- if (lhsHasGidx && !rhsHasGidx) return analyzeStride(exp.src[0]);
4393
- if (!lhsHasGidx && rhsHasGidx) return analyzeStride(exp.src[1]);
4530
+ tee(idx) {
4531
+ const t = this.cg._pop();
4532
+ const inputTypes = this.cg._inputTypes();
4533
+ const expectedType = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
4534
+ assert(expectedType.typeId === t.typeId, "can't tee local to this value (wrong type)");
4535
+ this.cg._emit(34);
4536
+ this.cg._emit(encodeUnsigned(idx));
4537
+ this.cg._push(expectedType);
4394
4538
  }
4395
- return GATHER;
4539
+ };
4540
+ function UNARY_OP(op, opcode, inType, outType) {
4541
+ return function() {
4542
+ const t = this.cg._pop();
4543
+ assert(t.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inType} -> ${outType})`);
4544
+ this.cg._emit(encodeOpcode(opcode));
4545
+ this.cg._push(this.cg[outType]);
4546
+ };
4396
4547
  }
4397
- /** Ops that have direct SIMD (f32x4) instruction variants. */
4398
- const simdF32Ops = new Set([
4399
- AluOp.Add,
4400
- AluOp.Sub,
4401
- AluOp.Mul,
4402
- AluOp.Floor,
4403
- AluOp.Ceil,
4404
- AluOp.Min,
4405
- AluOp.Max,
4406
- AluOp.Sqrt,
4407
- AluOp.Cast,
4408
- AluOp.Where,
4409
- AluOp.Const,
4410
- AluOp.GlobalIndex
4411
- ]);
4412
- /** Ops that have direct SIMD (i32x4) instruction variants. */
4413
- const simdI32Ops = new Set([
4414
- AluOp.Add,
4415
- AluOp.Sub,
4416
- AluOp.Mul,
4417
- AluOp.Min,
4418
- AluOp.Max,
4419
- AluOp.Cast,
4420
- AluOp.Where,
4421
- AluOp.Const,
4422
- AluOp.GlobalIndex
4423
- ]);
4424
- /** Ops that produce Bool (i32x4 bitmask) in SIMD. */
4425
- const simdBoolOps = new Set([
4426
- AluOp.Cmplt,
4427
- AluOp.Cmpne,
4428
- AluOp.Const,
4429
- AluOp.GlobalIndex
4430
- ]);
4431
- /**
4432
- * Check if a kernel is eligible for SIMD codegen.
4433
- *
4434
- * A kernel qualifies when:
4435
- * - size >= 4 (need at least 4 elements for a SIMD group)
4436
- * - For reductions: the reduction op has a SIMD variant for its dtype
4437
- * - All nodes have a supported dtype (f32, i32, u32, bool) with SIMD variants
4438
- */
4439
- function isSimdEligible(tunedExp, kernel) {
4440
- if (kernel.size < SIMD_LANES) return false;
4441
- if (kernel.reduction) {
4442
- if (!simdSupportedOpsForDtype(kernel.reduction.dtype)?.has(kernel.reduction.op)) return false;
4443
- }
4444
- const check = (exp, visited) => {
4445
- if (visited.has(exp)) return true;
4446
- visited.add(exp);
4447
- const supportedOps = simdSupportedOpsForDtype(exp.dtype);
4448
- if (!supportedOps || !supportedOps.has(exp.op)) return false;
4449
- if (exp.op === AluOp.GlobalIndex) return true;
4450
- for (const child of exp.src) if (!check(child, visited)) return false;
4451
- return true;
4548
+ function BINARY_OP(op, opcode, typeA, typeB, outType) {
4549
+ return function() {
4550
+ const b = this.cg._pop();
4551
+ const a = this.cg._pop();
4552
+ assert(a.typeId === this.cg[typeA].typeId && b.typeId === this.cg[typeB].typeId, `invalid type for ${op} (${typeA}, ${typeB} -> ${outType})`);
4553
+ this.cg._emit(encodeOpcode(opcode));
4554
+ this.cg._push(this.cg[outType]);
4452
4555
  };
4453
- return check(tunedExp, /* @__PURE__ */ new Set());
4454
4556
  }
4455
- function simdSupportedOpsForDtype(dtype) {
4456
- if (dtype === DType.Float32) return simdF32Ops;
4457
- if (dtype === DType.Int32 || dtype === DType.Uint32) return simdI32Ops;
4458
- if (dtype === DType.Bool) return simdBoolOps;
4459
- return null;
4557
+ function LOAD_OP(op, opcode, outType) {
4558
+ return function(align = 0, offset = 0) {
4559
+ const idxType = this.cg._pop();
4560
+ assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
4561
+ this.cg._emit(encodeOpcode(opcode));
4562
+ this.cg._emit(encodeUnsigned(align));
4563
+ this.cg._emit(encodeUnsigned(offset));
4564
+ this.cg._push(this.cg[outType]);
4565
+ };
4460
4566
  }
4461
- const moduleCache = /* @__PURE__ */ new Map();
4462
- /** Backend that compiles into WebAssembly bytecode for immediate execution. */
4463
- var WasmBackend = class {
4464
- type = "wasm";
4465
- maxArgs = 64;
4466
- #memory;
4467
- #nextSlot;
4468
- #allocator;
4469
- #buffers;
4470
- #workerPool;
4471
- #pendingWork = /* @__PURE__ */ new Map();
4472
- constructor() {
4473
- this.#memory = hasSharedArrayBuffer() ? new WebAssembly.Memory({
4474
- initial: 0,
4475
- maximum: 65536,
4476
- shared: true
4477
- }) : new WebAssembly.Memory({ initial: 0 });
4478
- this.#allocator = new WasmAllocator(this.#memory);
4479
- this.#nextSlot = 1;
4480
- this.#buffers = /* @__PURE__ */ new Map();
4481
- this.#workerPool = createWorkerPool(this.#memory);
4567
+ function STORE_OP(op, opcode, inType) {
4568
+ return function(align = 0, offset = 0) {
4569
+ const valType = this.cg._pop();
4570
+ const idxType = this.cg._pop();
4571
+ assert(valType.typeId === this.cg[inType].typeId, `invalid value type for ${op} (${inType})`);
4572
+ assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
4573
+ this.cg._emit(encodeOpcode(opcode));
4574
+ this.cg._emit(encodeUnsigned(align));
4575
+ this.cg._emit(encodeUnsigned(offset));
4576
+ };
4577
+ }
4578
+ var I32 = class {
4579
+ constructor(cg) {
4580
+ this.cg = cg;
4482
4581
  }
4483
- malloc(size, initialData) {
4484
- const ptr = this.#allocator.malloc(size);
4485
- if (initialData) {
4486
- if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
4487
- new Uint8Array(this.#memory.buffer, ptr, size).set(initialData);
4488
- }
4489
- const slot = this.#nextSlot++;
4490
- this.#buffers.set(slot, {
4491
- ptr,
4492
- size,
4493
- ref: 1
4494
- });
4495
- return slot;
4582
+ get typeId() {
4583
+ return 127;
4496
4584
  }
4497
- incRef(slot) {
4498
- const buffer = this.#buffers.get(slot);
4499
- if (!buffer) throw new SlotError(slot);
4500
- buffer.ref++;
4585
+ get name() {
4586
+ return "i32";
4501
4587
  }
4502
- decRef(slot) {
4503
- const buffer = this.#buffers.get(slot);
4504
- if (!buffer) throw new SlotError(slot);
4505
- buffer.ref--;
4506
- if (buffer.ref === 0) {
4507
- this.#allocator.free(buffer.ptr);
4508
- this.#buffers.delete(slot);
4509
- }
4588
+ const(i) {
4589
+ this.cg._emit(65);
4590
+ this.cg._emit(encodeSigned(i));
4591
+ this.cg._push(this);
4510
4592
  }
4511
- async read(slot, start, count) {
4512
- const epoch = this.#pendingWork.get(slot);
4513
- if (epoch) await this.#workerPool.waitForEpoch(epoch);
4514
- return this.#readData(slot, start, count);
4593
+ clz = UNARY_OP("clz", 103, "i32", "i32");
4594
+ ctz = UNARY_OP("ctz", 104, "i32", "i32");
4595
+ popcnt = UNARY_OP("popcnt", 105, "i32", "i32");
4596
+ lt_s = BINARY_OP("lt_s", 72, "i32", "i32", "i32");
4597
+ lt_u = BINARY_OP("lt_u", 73, "i32", "i32", "i32");
4598
+ gt_s = BINARY_OP("gt_s", 74, "i32", "i32", "i32");
4599
+ gt_u = BINARY_OP("gt_u", 75, "i32", "i32", "i32");
4600
+ le_s = BINARY_OP("le_s", 76, "i32", "i32", "i32");
4601
+ le_u = BINARY_OP("le_u", 77, "i32", "i32", "i32");
4602
+ ge_s = BINARY_OP("ge_s", 78, "i32", "i32", "i32");
4603
+ ge_u = BINARY_OP("ge_u", 79, "i32", "i32", "i32");
4604
+ add = BINARY_OP("add", 106, "i32", "i32", "i32");
4605
+ sub = BINARY_OP("sub", 107, "i32", "i32", "i32");
4606
+ mul = BINARY_OP("mul", 108, "i32", "i32", "i32");
4607
+ div_s = BINARY_OP("div_s", 109, "i32", "i32", "i32");
4608
+ div_u = BINARY_OP("div_u", 110, "i32", "i32", "i32");
4609
+ rem_s = BINARY_OP("rem_s", 111, "i32", "i32", "i32");
4610
+ rem_u = BINARY_OP("rem_u", 112, "i32", "i32", "i32");
4611
+ and = BINARY_OP("and", 113, "i32", "i32", "i32");
4612
+ or = BINARY_OP("or", 114, "i32", "i32", "i32");
4613
+ xor = BINARY_OP("xor", 115, "i32", "i32", "i32");
4614
+ shl = BINARY_OP("shl", 116, "i32", "i32", "i32");
4615
+ shr_s = BINARY_OP("shr_s", 117, "i32", "i32", "i32");
4616
+ shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
4617
+ rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
4618
+ rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
4619
+ eqz = UNARY_OP("eqz", 69, "i32", "i32");
4620
+ eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
4621
+ ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
4622
+ trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
4623
+ trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
4624
+ trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
4625
+ trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
4626
+ load = LOAD_OP("load", 40, "i32");
4627
+ load8_s = LOAD_OP("load8_s", 44, "i32");
4628
+ load8_u = LOAD_OP("load8_u", 45, "i32");
4629
+ load16_s = LOAD_OP("load16_s", 46, "i32");
4630
+ load16_u = LOAD_OP("load16_u", 47, "i32");
4631
+ store = STORE_OP("store", 54, "i32");
4632
+ store8 = STORE_OP("store8", 58, "i32");
4633
+ store16 = STORE_OP("store16", 59, "i32");
4634
+ reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
4635
+ trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
4636
+ trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
4637
+ trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
4638
+ trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
4639
+ };
4640
+ var F32 = class {
4641
+ constructor(cg) {
4642
+ this.cg = cg;
4515
4643
  }
4516
- readSync(slot, start, count) {
4517
- const epoch = this.#pendingWork.get(slot);
4518
- if (epoch && this.#workerPool.epoch < epoch) throw new Error("cannot read synchronously from a slot with async work");
4519
- return this.#readData(slot, start, count);
4644
+ get typeId() {
4645
+ return 125;
4520
4646
  }
4521
- #readData(slot, start, count) {
4522
- const buffer = this.#getBuffer(slot);
4523
- if (start === void 0) start = 0;
4524
- if (count === void 0) count = buffer.byteLength - start;
4525
- if (buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
4526
- else return buffer.slice(start, start + count);
4647
+ get name() {
4648
+ return "f32";
4527
4649
  }
4528
- async prepareKernel(kernel) {
4529
- const kernelHash = FpHash.hash(kernel);
4530
- const module = await runWithCacheAsync(moduleCache, kernelHash.toString(), () => WebAssembly.compile(codegenWasm(kernel)));
4531
- return new Executable(kernel, {
4532
- module,
4533
- parallel: this.#workerPool !== null
4534
- });
4650
+ const(f) {
4651
+ this.cg._emit(67);
4652
+ const buffer = /* @__PURE__ */ new ArrayBuffer(4);
4653
+ new DataView(buffer).setFloat32(0, f, true);
4654
+ const bytes = new Uint8Array(buffer);
4655
+ for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
4656
+ this.cg._push(this);
4657
+ }
4658
+ load = LOAD_OP("load", 42, "f32");
4659
+ store = STORE_OP("store", 56, "f32");
4660
+ eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
4661
+ ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
4662
+ lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
4663
+ gt = BINARY_OP("gt", 94, "f32", "f32", "i32");
4664
+ le = BINARY_OP("le", 95, "f32", "f32", "i32");
4665
+ ge = BINARY_OP("ge", 96, "f32", "f32", "i32");
4666
+ abs = UNARY_OP("abs", 139, "f32", "f32");
4667
+ neg = UNARY_OP("neg", 140, "f32", "f32");
4668
+ ceil = UNARY_OP("ceil", 141, "f32", "f32");
4669
+ floor = UNARY_OP("floor", 142, "f32", "f32");
4670
+ trunc = UNARY_OP("trunc", 143, "f32", "f32");
4671
+ nearest = UNARY_OP("nearest", 144, "f32", "f32");
4672
+ sqrt = UNARY_OP("sqrt", 145, "f32", "f32");
4673
+ add = BINARY_OP("add", 146, "f32", "f32", "f32");
4674
+ sub = BINARY_OP("sub", 147, "f32", "f32", "f32");
4675
+ mul = BINARY_OP("mul", 148, "f32", "f32", "f32");
4676
+ div = BINARY_OP("div", 149, "f32", "f32", "f32");
4677
+ min = BINARY_OP("min", 150, "f32", "f32", "f32");
4678
+ max = BINARY_OP("max", 151, "f32", "f32", "f32");
4679
+ copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
4680
+ convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
4681
+ convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
4682
+ demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
4683
+ reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
4684
+ };
4685
+ var F64 = class {
4686
+ constructor(cg) {
4687
+ this.cg = cg;
4535
4688
  }
4536
- prepareKernelSync(kernel) {
4537
- const kernelHash = FpHash.hash(kernel);
4538
- const module = runWithCache(moduleCache, kernelHash.toString(), () => new WebAssembly.Module(codegenWasm(kernel)));
4539
- return new Executable(kernel, {
4540
- module,
4541
- parallel: false
4542
- });
4689
+ get typeId() {
4690
+ return 124;
4543
4691
  }
4544
- async prepareRoutine(routine) {
4545
- return this.prepareRoutineSync(routine);
4692
+ get name() {
4693
+ return "f64";
4546
4694
  }
4547
- prepareRoutineSync(routine) {
4548
- return new Executable(routine, {
4549
- module: void 0,
4550
- parallel: false
4551
- });
4695
+ const(f) {
4696
+ this.cg._emit(68);
4697
+ const buffer = /* @__PURE__ */ new ArrayBuffer(8);
4698
+ new DataView(buffer).setFloat64(0, f, true);
4699
+ const bytes = new Uint8Array(buffer);
4700
+ for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
4701
+ this.cg._push(this);
4552
4702
  }
4553
- dispatch(exe, inputs, outputs) {
4554
- const tracing = isTracing();
4555
- const start = tracing ? performance.now() : 0;
4556
- if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
4557
- else {
4558
- const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
4559
- if (exe.data.parallel && this.#workerPool) {
4560
- const epoch = this.#workerPool.dispatch(exe.data.module, ptrs, exe.source.size);
4561
- for (const slot of outputs) this.#pendingWork.set(slot, epoch);
4562
- } else {
4563
- if (inputs.some((slot) => {
4564
- const epoch = this.#pendingWork.get(slot);
4565
- return epoch && this.#workerPool.epoch < epoch;
4566
- })) throw new Error("cannot dispatch synchronously with pending async work");
4567
- const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
4568
- const func = instance.exports.kernel;
4569
- func(...ptrs, 0, exe.source.size);
4570
- }
4703
+ load = LOAD_OP("load", 43, "f64");
4704
+ store = STORE_OP("store", 57, "f64");
4705
+ eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
4706
+ ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
4707
+ lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
4708
+ gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
4709
+ le = BINARY_OP("le", 101, "f64", "f64", "i32");
4710
+ ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
4711
+ abs = UNARY_OP("abs", 153, "f64", "f64");
4712
+ neg = UNARY_OP("neg", 154, "f64", "f64");
4713
+ ceil = UNARY_OP("ceil", 155, "f64", "f64");
4714
+ floor = UNARY_OP("floor", 156, "f64", "f64");
4715
+ trunc = UNARY_OP("trunc", 157, "f64", "f64");
4716
+ nearest = UNARY_OP("nearest", 158, "f64", "f64");
4717
+ sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
4718
+ add = BINARY_OP("add", 160, "f64", "f64", "f64");
4719
+ sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
4720
+ mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
4721
+ div = BINARY_OP("div", 163, "f64", "f64", "f64");
4722
+ min = BINARY_OP("min", 164, "f64", "f64", "f64");
4723
+ max = BINARY_OP("max", 165, "f64", "f64", "f64");
4724
+ copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
4725
+ convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
4726
+ convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
4727
+ promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
4728
+ };
4729
+ function VECTOR_OP(op, vopcode, inTypes, outType) {
4730
+ return function() {
4731
+ for (const inType of inTypes.toReversed()) {
4732
+ const actualType = this.cg._pop();
4733
+ assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes.join(", ")} -> ${outType})`);
4571
4734
  }
4572
- if (tracing) {
4573
- const info = traceSourceInfo(exe.source);
4574
- emitTrace("wasm", info, start, performance.now());
4735
+ this.cg._emit(encodeOpcode([253, vopcode]));
4736
+ this.cg._push(this.cg[outType]);
4737
+ };
4738
+ }
4739
+ function VECTOR_OPL(op, vopcode, inTypes, outType) {
4740
+ return function(lane) {
4741
+ for (const inType of inTypes.toReversed()) {
4742
+ const actualType = this.cg._pop();
4743
+ assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes} -> ${outType})`);
4575
4744
  }
4745
+ this.cg._emit(encodeOpcode([253, vopcode]));
4746
+ this.cg._emit(lane);
4747
+ this.cg._push(this.cg[outType]);
4748
+ };
4749
+ }
4750
+ function VECTOR_LOAD_OP(op, vopcode) {
4751
+ return function(align = 0, offset = 0) {
4752
+ const idxType = this.cg._pop();
4753
+ assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
4754
+ this.cg._emit(encodeOpcode([253, vopcode]));
4755
+ this.cg._emit(encodeUnsigned(align));
4756
+ this.cg._emit(encodeUnsigned(offset));
4757
+ this.cg._push(this.cg.v128);
4758
+ };
4759
+ }
4760
+ var V128 = class {
4761
+ constructor(cg) {
4762
+ this.cg = cg;
4576
4763
  }
4577
- #getBuffer(slot) {
4578
- const buffer = this.#buffers.get(slot);
4579
- if (!buffer) throw new SlotError(slot);
4580
- return new Uint8Array(this.#memory.buffer, buffer.ptr, buffer.size);
4764
+ get typeId() {
4765
+ return 123;
4766
+ }
4767
+ get name() {
4768
+ return "v128";
4769
+ }
4770
+ load = VECTOR_LOAD_OP("load", 0);
4771
+ load32x2_s = VECTOR_LOAD_OP("load32x2_s", 5);
4772
+ load32x2_u = VECTOR_LOAD_OP("load32x2_u", 6);
4773
+ load32_splat = VECTOR_LOAD_OP("load32_splat", 9);
4774
+ load32_zero = VECTOR_LOAD_OP("load32_zero", 92);
4775
+ store(align = 0, offset = 0) {
4776
+ const valType = this.cg._pop();
4777
+ assert(valType.typeId === this.cg.v128.typeId, `invalid type for store`);
4778
+ const idxType = this.cg._pop();
4779
+ assert(idxType.typeId === this.cg.i32.typeId, `invalid type for store`);
4780
+ this.cg._emit(253);
4781
+ this.cg._emit(encodeUnsigned(11));
4782
+ this.cg._emit(encodeUnsigned(align));
4783
+ this.cg._emit(encodeUnsigned(offset));
4581
4784
  }
4785
+ not = VECTOR_OP("not", 77, ["v128"], "v128");
4786
+ and = VECTOR_OP("and", 78, ["v128", "v128"], "v128");
4787
+ andnot = VECTOR_OP("andnot", 79, ["v128", "v128"], "v128");
4788
+ or = VECTOR_OP("or", 80, ["v128", "v128"], "v128");
4789
+ xor = VECTOR_OP("xor", 81, ["v128", "v128"], "v128");
4790
+ bitselect = VECTOR_OP("bitselect", 82, [
4791
+ "v128",
4792
+ "v128",
4793
+ "v128"
4794
+ ], "v128");
4795
+ any_true = VECTOR_OP("any_true", 83, ["v128"], "i32");
4582
4796
  };
4583
- /** Emit a runtime guard: enter the if-block only when [begin, end) is SIMD-aligned. */
4584
- function emitAlignmentGuard(cg, paramBegin, paramEnd) {
4585
- const mask = SIMD_LANES - 1;
4586
- cg.local.get(paramEnd);
4587
- cg.local.get(paramBegin);
4588
- cg.i32.sub();
4589
- cg.i32.const(mask);
4590
- cg.i32.and();
4591
- cg.i32.eqz();
4592
- cg.local.get(paramBegin);
4593
- cg.i32.const(mask);
4594
- cg.i32.and();
4595
- cg.i32.eqz();
4596
- cg.i32.and();
4597
- cg.if(cg.void);
4598
- }
4599
- function codegenWasm(kernel) {
4600
- const tune = tuneNullopt(kernel);
4601
- const re = kernel.reduction;
4602
- if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
4603
- const useSimd = isSimdEligible(tune.exp, kernel);
4604
- const bufferStrides = /* @__PURE__ */ new Map();
4605
- if (useSimd) tune.exp.collect((e) => e.op === AluOp.GlobalIndex).forEach((gi) => {
4606
- const result = analyzeStride(gi.src[0]);
4607
- if (result.kind !== "gather" && (result.tileSize < SIMD_LANES || isFinite(result.tileSize) && result.tileSize % SIMD_LANES !== 0)) bufferStrides.set(gi, GATHER);
4608
- else bufferStrides.set(gi, result);
4609
- });
4610
- const cg = new CodeGenerator();
4611
- cg.memory.import("env", "memory");
4612
- if (hasSharedArrayBuffer()) cg.memory.pages(0, 65536).shared(true);
4613
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
4614
- const funcs = {};
4615
- if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
4616
- if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
4617
- if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
4618
- if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
4619
- if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
4620
- if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
4621
- if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
4622
- if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
4623
- if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
4624
- const paramBegin = kernel.nargs + 1;
4625
- const paramEnd = kernel.nargs + 2;
4626
- const kernelFunc = cg.function(rep(kernel.nargs + 3, cg.i32), [], () => {
4627
- const gidx = cg.local.declare(cg.i32);
4628
- cg.local.get(paramBegin);
4629
- cg.local.set(gidx);
4630
- if (useSimd) {
4631
- emitAlignmentGuard(cg, paramBegin, paramEnd);
4632
- cg.loop(cg.void);
4633
- if (!re) {
4634
- cg.block(cg.void);
4635
- cg.local.get(gidx);
4636
- cg.local.get(paramEnd);
4637
- cg.i32.ge_u();
4638
- cg.br_if(0);
4639
- cg.local.get(kernel.nargs);
4640
- cg.local.get(gidx);
4641
- cg.i32.const(byteWidth(kernel.dtype));
4642
- cg.i32.mul();
4643
- cg.i32.add();
4644
- translateExpSimd(cg, funcs, tune.exp, { gidx }, bufferStrides);
4645
- cg.v128.store(4);
4646
- cg.local.get(gidx);
4647
- cg.i32.const(SIMD_LANES);
4648
- cg.i32.add();
4649
- cg.local.set(gidx);
4650
- cg.br(1);
4651
- cg.end();
4652
- } else {
4653
- const reIsInt = kernel.exp.dtype === DType.Int32 || kernel.exp.dtype === DType.Uint32;
4654
- cg.block(cg.void);
4655
- cg.local.get(gidx);
4656
- cg.local.get(paramEnd);
4657
- cg.i32.ge_u();
4658
- cg.br_if(0);
4659
- const vecAcc = cg.local.declare(reIsInt ? cg.i32x4 : cg.f32x4);
4660
- if (reIsInt) {
4661
- cg.i32.const(re.identity);
4662
- cg.i32x4.splat();
4663
- } else {
4664
- cg.f32.const(re.identity);
4665
- cg.f32x4.splat();
4666
- }
4667
- cg.local.set(vecAcc);
4668
- const ridx = cg.local.declare(cg.i32);
4669
- cg.i32.const(0);
4670
- cg.local.set(ridx);
4671
- cg.loop(cg.void);
4672
- cg.block(cg.void);
4673
- cg.local.get(ridx);
4674
- cg.i32.const(re.size);
4675
- cg.i32.ge_u();
4676
- cg.br_if(0);
4677
- translateExpSimd(cg, funcs, tune.exp, {
4678
- gidx,
4679
- ridx
4680
- }, bufferStrides);
4681
- cg.local.get(vecAcc);
4682
- if (reIsInt) if (re.op === AluOp.Add) cg.i32x4.add();
4683
- else if (re.op === AluOp.Mul) cg.i32x4.mul();
4684
- else if (re.op === AluOp.Min) if (re.dtype === DType.Int32) cg.i32x4.min_s();
4685
- else cg.i32x4.min_u();
4686
- else if (re.op === AluOp.Max) if (re.dtype === DType.Int32) cg.i32x4.max_s();
4687
- else cg.i32x4.max_u();
4688
- else throw new Error(`invalid SIMD reduction op: ${re.op}`);
4689
- else if (re.op === AluOp.Add) cg.f32x4.add();
4690
- else if (re.op === AluOp.Mul) cg.f32x4.mul();
4691
- else if (re.op === AluOp.Min) cg.f32x4.min();
4692
- else if (re.op === AluOp.Max) cg.f32x4.max();
4693
- else throw new Error(`invalid SIMD reduction op: ${re.op}`);
4694
- cg.local.set(vecAcc);
4695
- cg.local.get(ridx);
4696
- cg.i32.const(1);
4697
- cg.i32.add();
4698
- cg.local.set(ridx);
4699
- cg.br(1);
4700
- cg.end();
4701
- cg.end();
4702
- for (let lane = 0; lane < SIMD_LANES; lane++) {
4703
- cg.local.get(kernel.nargs);
4704
- cg.local.get(gidx);
4705
- if (lane > 0) {
4706
- cg.i32.const(lane);
4707
- cg.i32.add();
4708
- }
4709
- cg.i32.const(byteWidth(kernel.dtype));
4710
- cg.i32.mul();
4711
- cg.i32.add();
4712
- const acc = cg.local.declare(reIsInt ? cg.i32 : cg.f32);
4713
- cg.local.get(vecAcc);
4714
- if (reIsInt) cg.i32x4.extract_lane(lane);
4715
- else cg.f32x4.extract_lane(lane);
4716
- cg.local.set(acc);
4717
- const laneGidx = cg.local.declare(cg.i32);
4718
- cg.local.get(gidx);
4719
- if (lane > 0) {
4720
- cg.i32.const(lane);
4721
- cg.i32.add();
4722
- }
4723
- cg.local.set(laneGidx);
4724
- translateExp(cg, funcs, tune.epilogue, {
4725
- acc,
4726
- gidx: laneGidx
4727
- });
4728
- dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
4729
- }
4730
- cg.local.get(gidx);
4731
- cg.i32.const(SIMD_LANES);
4732
- cg.i32.add();
4733
- cg.local.set(gidx);
4734
- cg.br(1);
4735
- cg.end();
4736
- }
4737
- cg.end();
4738
- cg.return();
4739
- cg.end();
4740
- }
4741
- cg.loop(cg.void);
4742
- cg.block(cg.void);
4743
- cg.local.get(gidx);
4744
- cg.local.get(paramEnd);
4745
- cg.i32.ge_u();
4746
- cg.br_if(0);
4747
- cg.local.get(kernel.nargs);
4748
- cg.local.get(gidx);
4749
- cg.i32.const(byteWidth(kernel.dtype));
4797
+ var I32x4 = class extends V128 {
4798
+ splat = VECTOR_OP("splat", 17, ["i32"], "v128");
4799
+ extract_lane = VECTOR_OPL("extract_lane", 27, ["v128"], "i32");
4800
+ replace_lane = VECTOR_OPL("replace_lane", 28, ["v128", "i32"], "v128");
4801
+ eq = VECTOR_OP("eq", 55, ["v128", "v128"], "v128");
4802
+ ne = VECTOR_OP("ne", 56, ["v128", "v128"], "v128");
4803
+ lt_s = VECTOR_OP("lt_s", 57, ["v128", "v128"], "v128");
4804
+ lt_u = VECTOR_OP("lt_u", 58, ["v128", "v128"], "v128");
4805
+ gt_s = VECTOR_OP("gt_s", 59, ["v128", "v128"], "v128");
4806
+ gt_u = VECTOR_OP("gt_u", 60, ["v128", "v128"], "v128");
4807
+ le_s = VECTOR_OP("le_s", 61, ["v128", "v128"], "v128");
4808
+ le_u = VECTOR_OP("le_u", 62, ["v128", "v128"], "v128");
4809
+ ge_s = VECTOR_OP("ge_s", 63, ["v128", "v128"], "v128");
4810
+ ge_u = VECTOR_OP("ge_u", 64, ["v128", "v128"], "v128");
4811
+ abs = VECTOR_OP("abs", 160, ["v128"], "v128");
4812
+ neg = VECTOR_OP("neg", 161, ["v128"], "v128");
4813
+ all_true = VECTOR_OP("all_true", 163, ["v128"], "i32");
4814
+ bitmask = VECTOR_OP("bitmask", 164, ["v128"], "i32");
4815
+ shl = VECTOR_OP("shl", 171, ["v128", "i32"], "v128");
4816
+ shr_s = VECTOR_OP("shr_s", 172, ["v128", "i32"], "v128");
4817
+ shr_u = VECTOR_OP("shr_u", 173, ["v128", "i32"], "v128");
4818
+ add = VECTOR_OP("add", 174, ["v128", "v128"], "v128");
4819
+ sub = VECTOR_OP("sub", 177, ["v128", "v128"], "v128");
4820
+ mul = VECTOR_OP("mul", 181, ["v128", "v128"], "v128");
4821
+ min_s = VECTOR_OP("min_s", 182, ["v128", "v128"], "v128");
4822
+ min_u = VECTOR_OP("min_u", 183, ["v128", "v128"], "v128");
4823
+ max_s = VECTOR_OP("max_s", 184, ["v128", "v128"], "v128");
4824
+ max_u = VECTOR_OP("max_u", 185, ["v128", "v128"], "v128");
4825
+ trunc_sat_f32x4_s = VECTOR_OP("trunc_sat_f32x4_s", 248, ["v128"], "v128");
4826
+ trunc_sat_f32x4_u = VECTOR_OP("trunc_sat_f32x4_u", 249, ["v128"], "v128");
4827
+ };
4828
+ var F32x4 = class extends V128 {
4829
+ splat = VECTOR_OP("splat", 19, ["f32"], "v128");
4830
+ extract_lane = VECTOR_OPL("extract_lane", 31, ["v128"], "f32");
4831
+ replace_lane = VECTOR_OPL("replace_lane", 32, ["v128", "f32"], "v128");
4832
+ eq = VECTOR_OP("eq", 65, ["v128", "v128"], "v128");
4833
+ ne = VECTOR_OP("ne", 66, ["v128", "v128"], "v128");
4834
+ lt = VECTOR_OP("lt", 67, ["v128", "v128"], "v128");
4835
+ gt = VECTOR_OP("gt", 68, ["v128", "v128"], "v128");
4836
+ le = VECTOR_OP("le", 69, ["v128", "v128"], "v128");
4837
+ ge = VECTOR_OP("ge", 70, ["v128", "v128"], "v128");
4838
+ ceil = VECTOR_OP("ceil", 103, ["v128"], "v128");
4839
+ floor = VECTOR_OP("floor", 104, ["v128"], "v128");
4840
+ trunc = VECTOR_OP("trunc", 105, ["v128"], "v128");
4841
+ nearest = VECTOR_OP("nearest", 106, ["v128"], "v128");
4842
+ abs = VECTOR_OP("abs", 224, ["v128"], "v128");
4843
+ neg = VECTOR_OP("neg", 225, ["v128"], "v128");
4844
+ sqrt = VECTOR_OP("sqrt", 227, ["v128"], "v128");
4845
+ add = VECTOR_OP("add", 228, ["v128", "v128"], "v128");
4846
+ sub = VECTOR_OP("sub", 229, ["v128", "v128"], "v128");
4847
+ mul = VECTOR_OP("mul", 230, ["v128", "v128"], "v128");
4848
+ div = VECTOR_OP("div", 231, ["v128", "v128"], "v128");
4849
+ min = VECTOR_OP("min", 232, ["v128", "v128"], "v128");
4850
+ max = VECTOR_OP("max", 233, ["v128", "v128"], "v128");
4851
+ pmin = VECTOR_OP("pmin", 234, ["v128", "v128"], "v128");
4852
+ pmax = VECTOR_OP("pmax", 235, ["v128", "v128"], "v128");
4853
+ convert_i32x4_s = VECTOR_OP("convert_i32x4_s", 250, ["v128"], "v128");
4854
+ convert_i32x4_u = VECTOR_OP("convert_i32x4_u", 251, ["v128"], "v128");
4855
+ relaxed_madd = VECTOR_OP("relaxed_madd", 261, [
4856
+ "v128",
4857
+ "v128",
4858
+ "v128"
4859
+ ], "v128");
4860
+ relaxed_nmadd = VECTOR_OP("relaxed_nmadd", 262, [
4861
+ "v128",
4862
+ "v128",
4863
+ "v128"
4864
+ ], "v128");
4865
+ };
4866
+
4867
+ //#endregion
4868
+ //#region src/backend/wasm/codegen.ts
4869
+ function isIdentityEpilogue(exp) {
4870
+ return exp?.op === AluOp.Variable && exp.arg === "acc";
4871
+ }
4872
+ function initializeReductionPointer(cg, funcs, candidate, ctx, valueKey, ridxOffset) {
4873
+ const ptr = cg.local.declare(cg.i32);
4874
+ translateExp(cg, funcs, candidate.baseIndex, ctx);
4875
+ cg.i32.const(byteWidth(candidate.dtype));
4876
+ cg.i32.mul();
4877
+ cg.local.get(candidate.gid);
4878
+ cg.i32.add();
4879
+ if (ridxOffset !== void 0 && candidate.strideBytes !== 0) {
4880
+ cg.local.get(ridxOffset);
4881
+ cg.i32.const(candidate.strideBytes);
4750
4882
  cg.i32.mul();
4751
4883
  cg.i32.add();
4752
- if (re) {
4753
- const acc = cg.local.declare(dty(cg, null, kernel.exp.dtype));
4754
- dty(cg, null, kernel.exp.dtype).const(re.identity);
4755
- cg.local.set(acc);
4756
- const ridx = cg.local.declare(cg.i32);
4757
- cg.i32.const(0);
4758
- cg.local.set(ridx);
4759
- cg.loop(cg.void);
4760
- cg.block(cg.void);
4761
- cg.local.get(ridx);
4762
- cg.i32.const(re.size);
4763
- cg.i32.ge_u();
4764
- cg.br_if(0);
4765
- translateExp(cg, funcs, tune.exp, {
4766
- gidx,
4767
- ridx
4768
- });
4769
- if (re.op === AluOp.Add) {
4770
- cg.local.get(acc);
4771
- if (re.dtype === DType.Bool) cg.i32.or();
4772
- else dty(cg, re.op, re.dtype).add();
4773
- } else if (re.op === AluOp.Mul) {
4774
- cg.local.get(acc);
4775
- if (re.dtype === DType.Bool) cg.i32.and();
4776
- else dty(cg, re.op, re.dtype).mul();
4777
- } else if (re.op === AluOp.Min || re.op === AluOp.Max) if (isFloatDtype(re.dtype)) {
4884
+ }
4885
+ cg.local.set(ptr);
4886
+ return {
4887
+ ...candidate,
4888
+ ptr,
4889
+ valueKey
4890
+ };
4891
+ }
4892
+ function incrementReductionPointers(cg, pointers, multiplier = 1) {
4893
+ for (const pointer of pointers) {
4894
+ if (pointer.strideBytes === 0) continue;
4895
+ cg.local.get(pointer.ptr);
4896
+ cg.i32.const(pointer.strideBytes * multiplier);
4897
+ cg.i32.add();
4898
+ cg.local.set(pointer.ptr);
4899
+ }
4900
+ }
4901
+ function emitSimdReductionOp(cg, re, reIsInt, valueAlreadyAccumulated) {
4902
+ if (!reIsInt && valueAlreadyAccumulated) return;
4903
+ switch (re.op) {
4904
+ case AluOp.Add:
4905
+ if (reIsInt) cg.i32x4.add();
4906
+ else cg.f32x4.add();
4907
+ return;
4908
+ case AluOp.Mul:
4909
+ if (reIsInt) cg.i32x4.mul();
4910
+ else cg.f32x4.mul();
4911
+ return;
4912
+ case AluOp.Min:
4913
+ if (reIsInt) if (re.dtype === DType.Int32) cg.i32x4.min_s();
4914
+ else cg.i32x4.min_u();
4915
+ else cg.f32x4.min();
4916
+ return;
4917
+ case AluOp.Max:
4918
+ if (reIsInt) if (re.dtype === DType.Int32) cg.i32x4.max_s();
4919
+ else cg.i32x4.max_u();
4920
+ else cg.f32x4.max();
4921
+ return;
4922
+ default: throw new Error(`invalid SIMD reduction op: ${re.op}`);
4923
+ }
4924
+ }
4925
+ function emitScalarReductionOp(cg, re, acc) {
4926
+ switch (re.op) {
4927
+ case AluOp.Add:
4928
+ cg.local.get(acc);
4929
+ if (re.dtype === DType.Bool) cg.i32.or();
4930
+ else dty(cg, re.op, re.dtype).add();
4931
+ return;
4932
+ case AluOp.Mul:
4933
+ cg.local.get(acc);
4934
+ if (re.dtype === DType.Bool) cg.i32.and();
4935
+ else dty(cg, re.op, re.dtype).mul();
4936
+ return;
4937
+ case AluOp.Min:
4938
+ case AluOp.Max:
4939
+ if (isFloatDtype(re.dtype)) {
4778
4940
  cg.local.get(acc);
4779
4941
  if (re.op === AluOp.Min) dtyF(cg, re.op, re.dtype).min();
4780
4942
  else dtyF(cg, re.op, re.dtype).max();
@@ -4794,227 +4956,619 @@ function codegenWasm(kernel) {
4794
4956
  else cg.i32.gt_u();
4795
4957
  cg.select();
4796
4958
  } else throw new Error(`invalid reduction min/max over ${re.dtype}`);
4797
- else throw new Error(`invalid wasm reduction op: ${re.op}`);
4798
- cg.local.set(acc);
4799
- cg.local.get(ridx);
4800
- cg.i32.const(1);
4801
- cg.i32.add();
4802
- cg.local.set(ridx);
4803
- cg.br(1);
4804
- cg.end();
4805
- cg.end();
4806
- translateExp(cg, funcs, tune.epilogue, {
4807
- acc,
4808
- gidx
4809
- });
4810
- } else translateExp(cg, funcs, tune.exp, { gidx });
4811
- dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
4812
- cg.local.get(gidx);
4813
- cg.i32.const(1);
4814
- cg.i32.add();
4815
- cg.local.set(gidx);
4816
- cg.br(1);
4817
- cg.end();
4818
- cg.end();
4819
- });
4820
- cg.export(kernelFunc, "kernel");
4821
- return cg.finish();
4959
+ return;
4960
+ default: throw new Error(`invalid wasm reduction op: ${re.op}`);
4961
+ }
4822
4962
  }
4823
- function translateExp(cg, funcs, exp, ctx) {
4824
- const references = /* @__PURE__ */ new Map();
4825
- const seen = /* @__PURE__ */ new Set();
4826
- const countReferences = (exp$1) => {
4827
- references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
4828
- if (!seen.has(exp$1)) {
4829
- seen.add(exp$1);
4830
- for (const src of exp$1.src) countReferences(src);
4831
- }
4963
+ /**
4964
+ * Check if a kernel is eligible for SIMD codegen.
4965
+ *
4966
+ * A kernel qualifies when:
4967
+ * - size >= 4 (need at least 4 elements for a SIMD group)
4968
+ * - For reductions: the reduction op has a SIMD variant for its dtype
4969
+ * - All nodes have a supported dtype (f32, i32, u32, bool) with SIMD variants
4970
+ */
4971
+ function isSimdEligible(tunedExp, kernel) {
4972
+ if (kernel.size < simdLanes) return false;
4973
+ if (kernel.reduction) {
4974
+ if (!simdSupportedOps.get(kernel.reduction.dtype)?.has(kernel.reduction.op)) return false;
4975
+ }
4976
+ const check = (exp, visited) => {
4977
+ if (visited.has(exp)) return true;
4978
+ visited.add(exp);
4979
+ if (!simdSupportedOps.get(exp.dtype)?.has(exp.op)) return false;
4980
+ if (exp.op === AluOp.GlobalIndex) return true;
4981
+ for (const child of exp.src) if (!check(child, visited)) return false;
4982
+ return true;
4832
4983
  };
4833
- const expContext = /* @__PURE__ */ new Map();
4834
- const gen = (exp$1) => {
4835
- if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
4836
- const { op, src, dtype, arg } = exp$1;
4837
- if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
4838
- gen(src[0]);
4839
- gen(src[1]);
4840
- if (op === AluOp.Add) if (dtype === DType.Bool) cg.i32.or();
4841
- else dty(cg, op, dtype).add();
4842
- else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
4843
- else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
4844
- else dty(cg, op, dtype).mul();
4845
- else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
4846
- dtyF(cg, op, dtype).div();
4847
- dtyF(cg, op, dtype).trunc();
4848
- } else if (dtype === DType.Uint32) cg.i32.div_u();
4849
- else if (dtype === DType.Int32) cg.i32.div_s();
4850
- else throw new UnsupportedOpError(op, dtype, "wasm");
4851
- else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
4852
- const dt = dtyF(cg, op, dtype);
4853
- const a = cg.local.declare(dt);
4854
- const b = cg.local.declare(dt);
4855
- cg.local.set(b);
4856
- cg.local.tee(a);
4857
- cg.local.get(a);
4858
- cg.local.get(b);
4859
- dt.div();
4860
- dt.trunc();
4861
- cg.local.get(b);
4862
- dt.mul();
4863
- dt.sub();
4864
- } else if (dtype === DType.Uint32) cg.i32.rem_u();
4865
- else if (dtype === DType.Int32) cg.i32.rem_s();
4866
- else throw new UnsupportedOpError(op, dtype, "wasm");
4867
- else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
4868
- else dtyF(cg, op, dtype).max();
4869
- else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
4870
- const a = cg.local.declare(cg.i32);
4871
- const b = cg.local.declare(cg.i32);
4872
- cg.local.set(b);
4873
- cg.local.tee(a);
4874
- cg.local.get(b);
4875
- cg.local.get(a);
4876
- cg.local.get(b);
4877
- if (dtype === DType.Int32) if (op === AluOp.Min) cg.i32.lt_s();
4878
- else cg.i32.gt_s();
4879
- else if (op === AluOp.Min) cg.i32.lt_u();
4880
- else cg.i32.gt_u();
4881
- cg.select();
4882
- } else throw new UnsupportedOpError(op, dtype, "wasm");
4883
- else if (op === AluOp.BitCombine) if (arg === "and") cg.i32.and();
4884
- else if (arg === "or") cg.i32.or();
4885
- else cg.i32.xor();
4886
- else if (op === AluOp.BitShift) if (arg === "shl") cg.i32.shl();
4887
- else cg.i32.shr_u();
4888
- else if (op === AluOp.Cmplt) {
4889
- const srcDtype = src[0].dtype;
4890
- if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
4891
- else if (srcDtype === DType.Int32) cg.i32.lt_s();
4892
- else if (srcDtype === DType.Uint32) cg.i32.lt_u();
4893
- else throw new UnsupportedOpError(op, dtype, "wasm");
4894
- } else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
4895
- else throw new UnsupportedOpError(op, dtype, "wasm");
4896
- } else if (AluGroup.Unary.has(op)) {
4897
- const callFuncF32 = (func) => {
4898
- if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
4899
- else throw new UnsupportedOpError(op, dtype, "wasm");
4900
- cg.call(func);
4901
- if (dtype === DType.Float64) cg.f64.promote_f32();
4902
- };
4903
- if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
4904
- else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
4905
- else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
4906
- else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
4907
- else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
4908
- else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
4909
- else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
4910
- else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
4911
- else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
4912
- else if (op === AluOp.Reciprocal) {
4913
- const dt = dtyF(cg, op, dtype);
4914
- dt.const(1), gen(src[0]), dt.div();
4915
- } else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
4916
- else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
4917
- else if (op === AluOp.Cast) {
4918
- gen(src[0]);
4919
- const dtype0 = src[0].dtype;
4920
- const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
4921
- if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
4922
- else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
4923
- else if (i32repr);
4924
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4925
- else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
4926
- else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
4927
- else if (i32repr);
4928
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4929
- else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
4930
- else if (dtype0 === DType.Float64) cg.f32.demote_f64();
4931
- else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
4932
- else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
4933
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4934
- else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
4935
- else if (dtype0 === DType.Float64);
4936
- else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
4937
- else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
4938
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4939
- else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
4940
- else if (i32repr) cg.i32.const(0), cg.i32.ne();
4941
- else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
4942
- else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
4943
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4944
- else throw new UnsupportedOpError(op, dtype, "wasm");
4945
- } else if (op === AluOp.Bitcast) {
4946
- gen(src[0]);
4947
- const dtype0 = src[0].dtype;
4948
- if (dtype !== dtype0) {
4949
- const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
4950
- if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
4951
- else if (i32repr);
4952
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4953
- else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
4954
- else if (dtype0 === DType.Float32);
4955
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4956
- else throw new UnsupportedOpError(op, dtype, "wasm");
4957
- }
4958
- } else throw new UnsupportedOpError(op, dtype, "wasm");
4959
- } else if (op === AluOp.Where) {
4960
- gen(src[1]);
4961
- gen(src[2]);
4962
- gen(src[0]);
4963
- cg.select();
4964
- } else if (op === AluOp.Threefry2x32) {
4965
- for (let i = 0; i < 4; i++) gen(src[i]);
4966
- cg.call(funcs.threefry2x32);
4967
- if (arg === "xor") cg.i32.xor();
4968
- else if (arg === 0) cg.drop();
4969
- else if (arg === 1) {
4970
- const local = cg.local.declare(cg.i32);
4971
- cg.local.set(local);
4972
- cg.drop();
4973
- cg.local.get(local);
4974
- } else throw new UnsupportedOpError(op, dtype, "wasm", arg);
4975
- } else if (op === AluOp.Const) return dty(cg, op, dtype).const(arg);
4976
- else if (op === AluOp.Special) return cg.local.get(ctx[arg[0]]);
4977
- else if (op === AluOp.Variable) return cg.local.get(ctx[arg]);
4978
- else if (op === AluOp.GlobalIndex) {
4979
- const [gid, len] = arg;
4980
- gen(src[0]);
4981
- const local = cg.local.declare(cg.i32);
4982
- cg.local.tee(local);
4983
- cg.i32.const(0);
4984
- cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
4985
- cg.select();
4986
- cg.i32.const(byteWidth(dtype));
4984
+ return check(tunedExp, /* @__PURE__ */ new Set());
4985
+ }
4986
/**
 * Decides whether a reduction can be vectorized along the reduced (k) axis.
 *
 * All of the following must hold:
 *  - the reduction operator is `AluOp.Add`;
 *  - the number of `GlobalIndex` nodes collected from `exp` equals the number
 *    of entries in `pointers`;
 *  - every entry in `pointers` is `Float32` with a byte stride of either 0 or
 *    exactly one element (`byteWidth(dtype)`).
 *
 * @param exp Expression tree exposing `collect(predicate)`.
 * @param re Reduction descriptor; only `re.op` is read.
 * @param pointers Candidates exposing `dtype` and `strideBytes`.
 * @returns {boolean} Whether all conditions above are satisfied.
 */
function canUseKSimdReduction(exp, re, pointers) {
  if (re.op !== AluOp.Add) return false;
  const globalLoads = exp.collect((node) => node.op === AluOp.GlobalIndex);
  if (globalLoads.length !== pointers.length) return false;
  for (const pointer of pointers) {
    if (pointer.dtype !== DType.Float32) return false;
    const stride = pointer.strideBytes;
    if (stride !== 0 && stride !== byteWidth(pointer.dtype)) return false;
  }
  return true;
}
4992
/**
 * Emit a runtime guard: enter the if-block only when [begin, end) is aligned.
 *
 * The emitted condition is
 * `((end - begin) % alignment == 0) & (begin % alignment == 0)`, using
 * unsigned remainders. The function finishes by opening an `if` with a void
 * result type and does not close it, so the caller must emit the matching
 * `end` after the guarded body.
 *
 * @param cg Code generator receiving the instructions.
 * @param paramBegin Local index holding the start of the range.
 * @param paramEnd Local index holding the end of the range.
 * @param alignment Positive integer divisor; defaults to `simdLanes`.
 * @throws {RangeError} If `alignment` is not a positive integer. A zero
 *   divisor would make the emitted `i32.rem_u` trap when the kernel runs, so
 *   it is rejected here, at code-generation time, instead.
 */
function emitAlignmentGuard(cg, paramBegin, paramEnd, alignment = simdLanes) {
  if (!Number.isInteger(alignment) || alignment <= 0) {
    throw new RangeError(`alignment must be a positive integer, got ${alignment}`);
  }
  // (end - begin) % alignment == 0
  cg.local.get(paramEnd);
  cg.local.get(paramBegin);
  cg.i32.sub();
  cg.i32.const(alignment);
  cg.i32.rem_u();
  cg.i32.eqz();
  // begin % alignment == 0
  cg.local.get(paramBegin);
  cg.i32.const(alignment);
  cg.i32.rem_u();
  cg.i32.eqz();
  // Both conditions must hold to take the guarded path.
  cg.i32.and();
  cg.if(cg.void);
}
5007
/**
 * Generates a WebAssembly module for `kernel`.
 *
 * The module imports its memory as `env.memory` and exports a single function
 * named "kernel" taking `kernel.nargs + 3` i32 parameters. Inside the function
 * local `kernel.nargs` is used as the base address of the output, and locals
 * `kernel.nargs + 1` / `kernel.nargs + 2` as the begin / end of the index
 * range `[begin, end)` to process.
 *
 * Up to three guarded fast paths are emitted ahead of a scalar loop. Each fast
 * path is wrapped in an alignment guard and returns from the function when
 * taken; the scalar loop at the end handles every other case.
 *
 * @param kernel Kernel description (reads `exp`, `reduction`, `nargs`,
 *   `dtype`, `size`).
 * @returns Object with the module `bytes` plus scheduling metadata
 *   (`workSize`, `chunkAlignment`, `minWorkPerWorker`).
 */
function codegenWasm(kernel) {
  const tune = tuneNullopt(kernel);
  const re = kernel.reduction;
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
  // --- Strategy selection (all decided at code-generation time) ---
  const simdEligible = isSimdEligible(tune.exp, kernel);
  const hasIdentityEpilogue = isIdentityEpilogue(tune.epilogue);
  const expStrides = simdEligible ? collectSimdStrides(tune.exp) : /* @__PURE__ */ new Map();
  // A reduction with any "gather" stride is excluded from the regular SIMD path.
  const reductionHasLaneGather = re && [...expStrides.values()].some((stride) => stride.kind === "gather");
  const useSimd = simdEligible && !reductionHasLaneGather;
  const simdReductionPointerCandidates = useSimd && re ? reductionPointerCandidates(tune.exp, expStrides) : [];
  const reductionPointers = re ? reductionPointerCandidates(tune.exp) : [];
  const useKSimdReduction = simdEligible && re && canUseKSimdReduction(tune.exp, re, reductionPointers);
  // For the k-SIMD path every pointer is re-tagged as broadcast (stride 0) or contiguous.
  const kSimdReductionPointerCandidates = useKSimdReduction && re ? reductionPointers.map((candidate) => ({
    ...candidate,
    stride: candidate.strideBytes === 0 ? {
      kind: "broadcast",
      tileSize: Infinity
    } : {
      kind: "contiguous",
      tileSize: Infinity
    }
  })) : [];
  // The tiled SIMD plan is only requested when the epilogue is the identity.
  const simdTilePlan = useSimd && re && hasIdentityEpilogue ? reductionTilePlan(kernel, expStrides) : null;
  // The k-SIMD plan is only requested for reductions that have a lane gather.
  const kSimdTilePlan = reductionHasLaneGather && useKSimdReduction ? reductionKTilePlan(kernel, expStrides) : null;
  // Fused multiply-add is used only for float32 sum reductions whose expression root is a Mul.
  const useRelaxedMadd = hasWasmFeature("relaxed-madd") && re?.op === AluOp.Add && tune.exp.dtype === DType.Float32 && tune.exp.op === AluOp.Mul;
  const cg = new CodeGenerator();
  cg.memory.import("env", "memory");
  if (hasSharedArrayBuffer()) cg.memory.pages(0, 65536).shared(true);
  // Only emit the math helper functions that the expression or epilogue actually uses.
  const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
  const funcs = {};
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
  if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
  if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
  // exp is also emitted for Erf/Erfc because their helpers receive funcs.exp.
  if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
  if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
  if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
  if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
  const paramBegin = kernel.nargs + 1;
  const paramEnd = kernel.nargs + 2;
  const kernelFunc = cg.function(rep(kernel.nargs + 3, cg.i32), [], () => {
    // gidx is the running output index, initialised from the `begin` parameter.
    const gidx = cg.local.declare(cg.i32);
    cg.local.get(paramBegin);
    cg.local.set(gidx);
    // --- Small emission helpers ---
    // Pushes `local + amount` onto the stack.
    const emitLocalPlusConst = (local, amount) => {
      cg.local.get(local);
      cg.i32.const(amount);
      cg.i32.add();
    };
    // local += amount
    const bumpLocal = (local, amount) => {
      emitLocalPlusConst(local, amount);
      cg.local.set(local);
    };
    // local = value
    const setLocalConst = (local, value) => {
      cg.i32.const(value);
      cg.local.set(local);
    };
    // target = source
    const copyLocal = (target, source) => {
      cg.local.get(source);
      cg.local.set(target);
    };
    // target = rowTileBase + rowOffset * tileSize
    const setRowBase = (target, rowTileBase, rowOffset, tileSize) => {
      cg.local.get(rowTileBase);
      cg.local.get(rowOffset);
      cg.i32.const(tileSize);
      cg.i32.mul();
      cg.i32.add();
      cg.local.set(target);
    };
    // Pushes the byte address of output element `index`: local[nargs] + index * byteWidth.
    const emitOutputAddress = (index) => {
      cg.local.get(kernel.nargs);
      cg.local.get(index);
      cg.i32.const(byteWidth(kernel.dtype));
      cg.i32.mul();
      cg.i32.add();
    };
    // Declares a new local holding rowBase + rowOffset + col + colOffset; zero offsets emit no add.
    const declareTileGidx = (rowBase, col, rowOffset, colOffset) => {
      const local = cg.local.declare(cg.i32);
      cg.local.get(rowBase);
      if (rowOffset !== 0) {
        cg.i32.const(rowOffset);
        cg.i32.add();
      }
      cg.local.get(col);
      cg.i32.add();
      if (colOffset !== 0) {
        cg.i32.const(colOffset);
        cg.i32.add();
      }
      cg.local.set(local);
      return local;
    };
    // Emits `while (index < bound) body` with an unsigned comparison. The body
    // sits in a block nested inside a loop: br_if(0) leaves the block (and so
    // the loop) once index >= bound, br(1) jumps back to the loop header.
    const emitLoopWhileLt = (index, emitBound, emitBody) => {
      cg.loop(cg.void);
      cg.block(cg.void);
      cg.local.get(index);
      emitBound();
      cg.i32.ge_u();
      cg.br_if(0);
      emitBody();
      cg.br(1);
      cg.end();
      cg.end();
    };
    const emitLoopWhileLocalLt = (index, bound, emitBody) => emitLoopWhileLt(index, () => cg.local.get(bound), emitBody);
    const emitLoopWhileConstLt = (index, bound, emitBody) => emitLoopWhileLt(index, () => cg.i32.const(bound), emitBody);
    // Pushes the SIMD value of the expression followed by the accumulator.
    // Returns true when the accumulator has already been folded in (relaxed
    // multiply-add path), false when the caller still has to combine them.
    const emitSimdExpWithAccumulator = (ctx, pointerMap, pointerValueCache, acc) => {
      if (useRelaxedMadd) {
        translateExpSimd(cg, funcs, tune.exp.src[0], ctx, expStrides, pointerMap, pointerValueCache);
        translateExpSimd(cg, funcs, tune.exp.src[1], ctx, expStrides, pointerMap, pointerValueCache);
        cg.local.get(acc);
        cg.f32x4.relaxed_madd();
        return true;
      }
      translateExpSimd(cg, funcs, tune.exp, ctx, expStrides, pointerMap, pointerValueCache);
      cg.local.get(acc);
      return false;
    };
    // Emits a vectorized reduction for each output index local in `gidxs`,
    // one vector accumulator per entry. `ridxStart` / `ridxEnd` are optional
    // locals bounding the reduction index; when omitted the loop runs over
    // [0, re.size).
    const emitSimdReductionForGidxs = (gidxs, pointerMaps, uniquePointers, ridxStart, ridxEnd) => {
      if (!re) throw new Error("internal: missing reduction");
      const reIsInt = kernel.exp.dtype === DType.Int32 || kernel.exp.dtype === DType.Uint32;
      const vecAccs = gidxs.map(() => cg.local.declare(reIsInt ? cg.i32x4 : cg.f32x4));
      const initializeIdentityAccumulators = () => {
        for (const acc of vecAccs) {
          if (reIsInt) {
            cg.i32.const(re.identity);
            cg.i32x4.splat();
          } else {
            cg.f32.const(re.identity);
            cg.f32x4.splat();
          }
          cg.local.set(acc);
        }
      };
      // Reloads accumulators from the output buffer.
      const loadPartialAccumulators = () => {
        for (let i = 0; i < gidxs.length; i++) {
          emitOutputAddress(gidxs[i]);
          if (reIsInt) cg.i32x4.load(4);
          else cg.f32x4.load(4);
          cg.local.set(vecAccs[i]);
        }
      };
      if (ridxStart === void 0) initializeIdentityAccumulators();
      else {
        // Runtime choice: start from the identity when ridxStart == 0,
        // otherwise continue from the values already stored in the output.
        cg.local.get(ridxStart);
        cg.i32.eqz();
        cg.if(cg.void);
        initializeIdentityAccumulators();
        cg.else();
        loadPartialAccumulators();
        cg.end();
      }
      const ridx = cg.local.declare(cg.i32);
      const emitReductionStep = () => {
        // One cache per step, shared across all groups handled in that step.
        const pointerValueCache = /* @__PURE__ */ new Map();
        for (let i = 0; i < gidxs.length; i++) {
          const valueAlreadyAccumulated = emitSimdExpWithAccumulator({
            gidx: gidxs[i],
            ridx
          }, pointerMaps[i], pointerValueCache, vecAccs[i]);
          emitSimdReductionOp(cg, re, reIsInt, valueAlreadyAccumulated);
          cg.local.set(vecAccs[i]);
        }
        incrementReductionPointers(cg, uniquePointers);
      };
      if (ridxStart === void 0) setLocalConst(ridx, 0);
      else copyLocal(ridx, ridxStart);
      const emitReductionLoopBody = () => {
        emitReductionStep();
        bumpLocal(ridx, 1);
      };
      if (ridxEnd === void 0) emitLoopWhileConstLt(ridx, re.size, emitReductionLoopBody);
      else emitLoopWhileLocalLt(ridx, ridxEnd, emitReductionLoopBody);
      if (hasIdentityEpilogue) {
        // Identity epilogue: store each accumulator vector directly.
        for (let i = 0; i < gidxs.length; i++) {
          emitOutputAddress(gidxs[i]);
          cg.local.get(vecAccs[i]);
          cg.v128.store(4);
        }
        return;
      }
      // Otherwise apply the scalar epilogue lane by lane and store each result.
      const laneGidx = cg.local.declare(cg.i32);
      const laneAcc = cg.local.declare(reIsInt ? cg.i32 : cg.f32);
      for (let i = 0; i < gidxs.length; i++) for (let lane = 0; lane < simdLanes; lane++) {
        cg.local.get(kernel.nargs);
        cg.local.get(gidxs[i]);
        if (lane > 0) {
          cg.i32.const(lane);
          cg.i32.add();
        }
        // tee keeps the lane's index on the stack for the address computation
        // while also saving it for the epilogue context.
        cg.local.tee(laneGidx);
        cg.i32.const(byteWidth(kernel.dtype));
        cg.i32.mul();
        cg.i32.add();
        cg.local.get(vecAccs[i]);
        if (reIsInt) cg.i32x4.extract_lane(lane);
        else cg.f32x4.extract_lane(lane);
        cg.local.set(laneAcc);
        translateExp(cg, funcs, tune.epilogue, {
          acc: laneAcc,
          gidx: laneGidx
        });
        dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
      }
    };
    // Builds one pointer map per group. Pointer plans are shared between
    // groups whenever `keyFor` yields the same key, so each distinct plan is
    // initialized once and appears once in `uniquePointers`.
    const initializePointerMaps = (groups, candidates, keyFor, ctxFor, ridxOffset) => {
      const pointerMaps = groups.map(() => /* @__PURE__ */ new Map());
      const sharedPointers = /* @__PURE__ */ new Map();
      const uniquePointers = [];
      for (let i = 0; i < groups.length; i++) {
        const group = groups[i];
        for (const candidate of candidates) {
          const key = keyFor(candidate, group, i);
          let plan = sharedPointers.get(key);
          if (!plan) {
            plan = initializeReductionPointer(cg, funcs, candidate, ctxFor(group), key, ridxOffset);
            sharedPointers.set(key, plan);
            uniquePointers.push(plan);
          }
          pointerMaps[i].set(candidate.exp, plan);
        }
      }
      return {
        pointerMaps,
        uniquePointers
      };
    };
    // Untiled SIMD reduction: one vector of outputs per step, then gidx += simdLanes.
    const emitSimdReductionStep = () => {
      const groups = [{
        gidx,
        row: 0,
        vector: 0
      }];
      const { pointerMaps, uniquePointers } = initializePointerMaps(groups, simdReductionPointerCandidates, (candidate, group, i) => pointerShareKey(candidate, group.row, group.vector, i), (group) => ({ gidx: group.gidx }));
      emitSimdReductionForGidxs([gidx], pointerMaps, uniquePointers);
      bumpLocal(gidx, simdLanes);
    };
    // Elementwise SIMD (no reduction): compute one vector, store it, gidx += simdLanes.
    const emitElementwiseSimdStep = () => {
      emitOutputAddress(gidx);
      translateExpSimd(cg, funcs, tune.exp, { gidx }, expStrides);
      cg.v128.store(4);
      bumpLocal(gidx, simdLanes);
    };
    // k-SIMD reduction: each group keeps an f32x4 accumulator along the
    // reduction axis, which is summed across lanes into a scalar at the end.
    const emitKSimdReductionForGroups = (groups, pointerMaps, uniquePointers) => {
      if (!re) throw new Error("internal: missing reduction");
      if (!kSimdTilePlan) throw new Error("internal: missing K SIMD plan");
      const vecAccs = groups.map(() => cg.local.declare(cg.f32x4));
      for (const acc of vecAccs) {
        cg.f32.const(re.identity);
        cg.f32x4.splat();
        cg.local.set(acc);
      }
      const ridx = cg.local.declare(cg.i32);
      setLocalConst(ridx, 0);
      emitLoopWhileConstLt(ridx, re.size, () => {
        // The loop body is unrolled kUnroll times at code-generation time;
        // ridx advances by simdLanes * kUnroll per iteration.
        for (let u = 0; u < kSimdTilePlan.kUnroll; u++) {
          const pointerValueCache = /* @__PURE__ */ new Map();
          for (let i = 0; i < groups.length; i++) {
            const valueAlreadyAccumulated = emitSimdExpWithAccumulator({
              gidx: groups[i].gidx,
              ridx
            }, pointerMaps[i], pointerValueCache, vecAccs[i]);
            if (!valueAlreadyAccumulated) cg.f32x4.add();
            cg.local.set(vecAccs[i]);
          }
          incrementReductionPointers(cg, uniquePointers, simdLanes);
        }
        bumpLocal(ridx, simdLanes * kSimdTilePlan.kUnroll);
      });
      for (let i = 0; i < groups.length; i++) {
        const acc = cg.local.declare(cg.f32);
        // Horizontal sum: extract every lane and add them together.
        for (let lane = 0; lane < simdLanes; lane++) {
          cg.local.get(vecAccs[i]);
          cg.f32x4.extract_lane(lane);
          if (lane > 0) cg.f32.add();
        }
        cg.local.set(acc);
        emitOutputAddress(groups[i].gidx);
        translateExp(cg, funcs, tune.epilogue, {
          acc,
          gidx: groups[i].gidx
        });
        dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
      }
    };
    // Tiled driver for the k-SIMD path. Loop nest, outermost first:
    // gidx (steps of tileRows * tileSize) -> colTile -> rowOffset -> col.
    const emitKSimdReductionLoop = (plan) => {
      const colTile = cg.local.declare(cg.i32);
      const col = cg.local.declare(cg.i32);
      const rowTileBase = cg.local.declare(cg.i32);
      const rowOffset = cg.local.declare(cg.i32);
      const rowBase = cg.local.declare(cg.i32);
      const emitBlock = () => {
        // One group per (row, column) of the microRows x microCols micro-tile.
        const groups = [];
        for (let row = 0; row < plan.microRows; row++) for (let colOffset = 0; colOffset < plan.microCols; colOffset++) groups.push({
          gidx: declareTileGidx(rowBase, col, row * plan.tileSize, colOffset),
          row,
          col: colOffset
        });
        if (!kSimdTilePlan) throw new Error("internal: missing K SIMD plan");
        const { pointerMaps, uniquePointers } = initializePointerMaps(groups, kSimdReductionPointerCandidates, (candidate, group, i) => kReductionPointerShareKey(candidate, expStrides.get(candidate.exp) ?? { kind: "gather" }, kSimdTilePlan.tileSize, group.row, group.col, i), (group) => ({ gidx: group.gidx }));
        emitKSimdReductionForGroups(groups, pointerMaps, uniquePointers);
      };
      emitLoopWhileLocalLt(gidx, paramEnd, () => {
        copyLocal(rowTileBase, gidx);
        setLocalConst(colTile, 0);
        emitLoopWhileConstLt(colTile, plan.tileSize, () => {
          setLocalConst(rowOffset, 0);
          emitLoopWhileConstLt(rowOffset, plan.tileRows, () => {
            copyLocal(col, colTile);
            emitLoopWhileLt(col, () => emitLocalPlusConst(colTile, plan.tileCols), () => {
              setRowBase(rowBase, rowTileBase, rowOffset, plan.tileSize);
              emitBlock();
              bumpLocal(col, plan.microCols);
            });
            bumpLocal(rowOffset, plan.microRows);
          });
          bumpLocal(colTile, plan.tileCols);
        });
        bumpLocal(gidx, plan.tileRows * plan.tileSize);
      });
    };
    // Tiled driver for the regular SIMD reduction. Loop nest, outermost first:
    // gidx -> colTile -> kTile -> rowOffset -> col. Because kTile sits outside
    // the row/column loops, each output is revisited once per k-tile and
    // emitSimdReductionForGidxs is given [kTile, tileEnd) as its range.
    const emitTiledSimdReductionLoop = (plan) => {
      const colTile = cg.local.declare(cg.i32);
      const col = cg.local.declare(cg.i32);
      const kTile = cg.local.declare(cg.i32);
      const tileEnd = cg.local.declare(cg.i32);
      const rowTileBase = cg.local.declare(cg.i32);
      const rowOffset = cg.local.declare(cg.i32);
      const rowBase = cg.local.declare(cg.i32);
      const emitRowBlock = () => {
        // One group per (row, vector) of the microRows x microVectors micro-tile.
        const groups = [];
        for (let row = 0; row < plan.microRows; row++) for (let vector = 0; vector < plan.microVectors; vector++) groups.push({
          gidx: declareTileGidx(rowBase, col, row * plan.tileSize, vector * simdLanes),
          row,
          vector
        });
        const { pointerMaps, uniquePointers } = initializePointerMaps(groups, simdReductionPointerCandidates, (candidate, group, i) => pointerShareKey(candidate, group.row, group.vector, i), (group) => ({ gidx: group.gidx }), kTile);
        emitSimdReductionForGidxs(groups.map((group) => group.gidx), pointerMaps, uniquePointers, kTile, tileEnd);
      };
      emitLoopWhileLocalLt(gidx, paramEnd, () => {
        copyLocal(rowTileBase, gidx);
        setLocalConst(colTile, 0);
        emitLoopWhileConstLt(colTile, plan.tileSize, () => {
          setLocalConst(kTile, 0);
          emitLoopWhileConstLt(kTile, re.size, () => {
            // tileEnd = kTile + tileK
            emitLocalPlusConst(kTile, plan.tileK);
            cg.local.set(tileEnd);
            setLocalConst(rowOffset, 0);
            emitLoopWhileConstLt(rowOffset, plan.tileRows, () => {
              copyLocal(col, colTile);
              emitLoopWhileLt(col, () => emitLocalPlusConst(colTile, plan.tileVectors * simdLanes), () => {
                setRowBase(rowBase, rowTileBase, rowOffset, plan.tileSize);
                emitRowBlock();
                bumpLocal(col, plan.microVectors * simdLanes);
              });
              bumpLocal(rowOffset, plan.microRows);
            });
            bumpLocal(kTile, plan.tileK);
          });
          bumpLocal(colTile, plan.tileVectors * simdLanes);
        });
        bumpLocal(gidx, plan.tileRows * plan.tileSize);
      });
    };
    // Wraps `emit` in an alignment guard; when the guard passes at runtime the
    // fast path runs and the function returns, skipping everything after it.
    const emitGuardedFastPath = (alignment, emit) => {
      emitAlignmentGuard(cg, paramBegin, paramEnd, alignment);
      emit();
      cg.return();
      cg.end();
    };
    // --- Fast paths, in the order they are tried at runtime ---
    if (kSimdTilePlan) emitGuardedFastPath(kSimdTilePlan.tileRows * kSimdTilePlan.tileSize, () => emitKSimdReductionLoop(kSimdTilePlan));
    if (useSimd) {
      if (simdTilePlan) emitGuardedFastPath(simdTilePlan.tileRows * simdTilePlan.tileSize, () => emitTiledSimdReductionLoop(simdTilePlan));
      emitGuardedFastPath(simdLanes, () => {
        emitLoopWhileLocalLt(gidx, paramEnd, re ? emitSimdReductionStep : emitElementwiseSimdStep);
      });
    }
    // --- Scalar fallback: one output element per iteration ---
    emitLoopWhileLocalLt(gidx, paramEnd, () => {
      // The output address is pushed first and stays on the stack until the final store.
      emitOutputAddress(gidx);
      if (re) {
        const acc = cg.local.declare(dty(cg, null, kernel.exp.dtype));
        dty(cg, null, kernel.exp.dtype).const(re.identity);
        cg.local.set(acc);
        const ridx = cg.local.declare(cg.i32);
        const emitReductionStep = () => {
          translateExp(cg, funcs, tune.exp, {
            gidx,
            ridx
          });
          emitScalarReductionOp(cg, re, acc);
          cg.local.set(acc);
        };
        setLocalConst(ridx, 0);
        emitLoopWhileConstLt(ridx, re.size, () => {
          emitReductionStep();
          bumpLocal(ridx, 1);
        });
        translateExp(cg, funcs, tune.epilogue, {
          acc,
          gidx
        });
      } else translateExp(cg, funcs, tune.exp, { gidx });
      dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
      bumpLocal(gidx, 1);
    });
  });
  cg.export(kernelFunc, "kernel");
  // At most one tiled plan is expected to be set; simdTilePlan takes precedence.
  const tiledPlan = simdTilePlan ?? kSimdTilePlan;
  return {
    bytes: cg.finish(),
    workSize: kernel.size,
    chunkAlignment: tiledPlan ? tiledPlan.tileRows * tiledPlan.tileSize : 16,
    minWorkPerWorker: tiledPlan && kernel.size / tiledPlan.tileSize >= 1024 ? tiledPlan.tileSize * 32 : 256
  };
}
5001
- function dty(cg, op, dtype) {
5002
- switch (dtype) {
5003
- case DType.Float32: return cg.f32;
5004
- case DType.Float64: return cg.f64;
5005
- case DType.Int32:
5006
- case DType.Uint32:
5007
- case DType.Bool: return cg.i32;
5008
- default: throw new UnsupportedOpError(op, dtype, "wasm");
5423
+
5424
+ //#endregion
5425
+ //#region src/backend/wasm.ts
5426
/** Compiled programs keyed by kernel hash; shared by the sync and async prepare paths. */
const compiledProgramCache = /* @__PURE__ */ new Map();
/**
 * Backend that compiles into WebAssembly bytecode for immediate execution.
 *
 * Buffers live in a single `WebAssembly.Memory` and are addressed through
 * integer slots with reference counts. Kernels may run asynchronously on a
 * worker pool; `#pendingWork` records, per output slot, the epoch of the last
 * asynchronous dispatch that writes to it.
 */
var WasmBackend = class {
  type = "wasm";
  maxArgs = 64;
  #memory;
  #nextSlot;
  #allocator;
  #buffers;
  #workerPool;
  #pendingWork = /* @__PURE__ */ new Map();
  constructor() {
    // Use shared memory when SharedArrayBuffer is available.
    this.#memory = hasSharedArrayBuffer() ? new WebAssembly.Memory({
      initial: 0,
      maximum: 65536,
      shared: true
    }) : new WebAssembly.Memory({ initial: 0 });
    this.#allocator = new WasmAllocator(this.#memory);
    this.#nextSlot = 1;
    this.#buffers = /* @__PURE__ */ new Map();
    this.#workerPool = createWorkerPool(this.#memory);
  }
  /**
   * Allocates `size` bytes and returns a new slot with reference count 1.
   * @param size Number of bytes to allocate.
   * @param initialData Optional bytes copied into the new buffer.
   * @throws {Error} If `initialData.byteLength` differs from `size`.
   */
  malloc(size, initialData) {
    // Validate before allocating so a size mismatch does not leak the allocation.
    if (initialData && initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
    const ptr = this.#allocator.malloc(size);
    if (initialData) new Uint8Array(this.#memory.buffer, ptr, size).set(initialData);
    const slot = this.#nextSlot++;
    this.#buffers.set(slot, {
      ptr,
      size,
      ref: 1
    });
    return slot;
  }
  /** Increments the reference count of `slot`. Throws `SlotError` if unknown. */
  incRef(slot) {
    this.#lookup(slot).ref++;
  }
  /**
   * Decrements the reference count of `slot`; at zero the memory is freed and
   * the slot (with any pending-work record) is forgotten.
   * Throws `SlotError` if the slot is unknown.
   */
  decRef(slot) {
    const buffer = this.#lookup(slot);
    buffer.ref--;
    if (buffer.ref === 0) {
      this.#allocator.free(buffer.ptr);
      this.#buffers.delete(slot);
      this.#pendingWork.delete(slot);
    }
  }
  /** Reads a copy of the slot's bytes after waiting for any pending async work on it. */
  async read(slot, start, count) {
    const epoch = this.#pendingWork.get(slot);
    if (epoch) await this.#workerPool.waitForEpoch(epoch);
    return this.#readData(slot, start, count);
  }
  /**
   * Reads a copy of the slot's bytes synchronously.
   * @throws {Error} If asynchronous work on the slot has not completed yet.
   */
  readSync(slot, start, count) {
    const epoch = this.#pendingWork.get(slot);
    if (epoch && this.#workerPool.epoch < epoch) throw new Error("cannot read synchronously from a slot with async work");
    return this.#readData(slot, start, count);
  }
  /** Copies `count` bytes starting at `start` (defaults: 0 and the rest of the buffer). */
  #readData(slot, start, count) {
    const buffer = this.#getBuffer(slot);
    if (start === void 0) start = 0;
    if (count === void 0) count = buffer.byteLength - start;
    if (hasSharedArrayBuffer() && buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
    else return buffer.slice(start, start + count);
  }
  /** Compiles `kernel` (cached by hash) into an executable that is dispatched asynchronously. */
  async prepareKernel(kernel) {
    const kernelHash = FpHash.hash(kernel);
    const hashKey = kernelHash.toString();
    const program = await runWithCacheAsync(compiledProgramCache, hashKey, async () => {
      const { bytes, ...metadata } = codegenWasm(kernel);
      const module = await WebAssembly.compile(bytes);
      return {
        module,
        ...metadata
      };
    });
    return new Executable(kernel, {
      program,
      sync: false
    });
  }
  /** Synchronous counterpart of `prepareKernel`; the executable is dispatched synchronously. */
  prepareKernelSync(kernel) {
    const kernelHash = FpHash.hash(kernel);
    const hashKey = kernelHash.toString();
    const compiled = runWithCache(compiledProgramCache, hashKey, () => {
      const { bytes, ...metadata } = codegenWasm(kernel);
      const module = new WebAssembly.Module(bytes);
      return {
        module,
        ...metadata
      };
    });
    return new Executable(kernel, {
      program: compiled,
      sync: true
    });
  }
  /** Routines need no compilation; this defers to `prepareRoutineSync`. */
  async prepareRoutine(routine) {
    return this.prepareRoutineSync(routine);
  }
  /** Wraps `routine` in an executable with no compiled program. */
  prepareRoutineSync(routine) {
    return new Executable(routine, {
      program: void 0,
      sync: true
    });
  }
  /**
   * Runs `exe` with the given input and output slots.
   *
   * Routines run through `runCpuRoutine`. Kernels run on the worker pool when
   * one exists and the executable is asynchronous; otherwise they are
   * instantiated and called directly over the full work range.
   *
   * @throws {SlotError} If any input or output slot is unknown.
   * @throws {Error} If a synchronous dispatch reads a slot with pending async work.
   */
  dispatch(exe, inputs, outputs) {
    const tracing = isTracing();
    const start = tracing ? performance.now() : 0;
    if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
    else {
      const { program, sync } = exe.data;
      // Resolve through #lookup so an unknown slot raises SlotError, matching
      // incRef/decRef, rather than a TypeError on `undefined.ptr`.
      const ptrs = [...inputs, ...outputs].map((slot) => this.#lookup(slot).ptr);
      if (this.#workerPool && !sync) {
        // Hold a reference on every buffer until the workers have finished.
        const retainedSlots = [...inputs, ...outputs];
        for (const slot of retainedSlots) this.incRef(slot);
        const epoch = this.#workerPool.dispatch(program.module, ptrs, program.workSize, program.chunkAlignment, program.minWorkPerWorker);
        for (const slot of outputs) this.#pendingWork.set(slot, epoch);
        this.#workerPool.waitForEpoch(epoch).then(() => {
          // Clear the record only if no later dispatch has replaced it.
          for (const slot of outputs) if (this.#pendingWork.get(slot) === epoch) this.#pendingWork.delete(slot);
          for (const slot of retainedSlots) this.decRef(slot);
        });
      } else {
        if (inputs.some((slot) => {
          const epoch = this.#pendingWork.get(slot);
          return epoch && this.#workerPool.epoch < epoch;
        })) throw new Error("cannot dispatch synchronously with pending async work");
        const instance = new WebAssembly.Instance(program.module, { env: { memory: this.#memory } });
        const func = instance.exports.kernel;
        func(...ptrs, 0, program.workSize);
      }
    }
    if (tracing) {
      const info = traceSourceInfo(exe.source);
      emitTrace("wasm", info, start, performance.now());
    }
  }
  /** Returns the bookkeeping record for `slot`, or throws `SlotError` if unknown. */
  #lookup(slot) {
    const buffer = this.#buffers.get(slot);
    if (!buffer) throw new SlotError(slot);
    return buffer;
  }
  /** Returns a `Uint8Array` view over the slot's bytes in the backend memory. */
  #getBuffer(slot) {
    const buffer = this.#lookup(slot);
    return new Uint8Array(this.#memory.buffer, buffer.ptr, buffer.size);
  }
};
5018
5572
 
5019
5573
  //#endregion
5020
5574
  //#region src/backend.ts
@@ -5061,7 +5615,7 @@ async function createBackend(device) {
5061
5615
  if (!navigator.gpu) return null;
5062
5616
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
5063
5617
  if (!adapter) return null;
5064
- const { WebGPUBackend } = await import("./webgpu-C2kLdkUh.js");
5618
+ const { WebGPUBackend } = await import("./webgpu-BRv5r9Sl.js");
5065
5619
  const importantLimits = [
5066
5620
  "maxBufferSize",
5067
5621
  "maxComputeInvocationsPerWorkgroup",
@@ -5099,7 +5653,7 @@ async function createBackend(device) {
5099
5653
  });
5100
5654
  if (!gl) return null;
5101
5655
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
5102
- const { WebGLBackend } = await import("./webgl-BhsnpeB0.js");
5656
+ const { WebGLBackend } = await import("./webgl-Hh0FX6oV.js");
5103
5657
  return new WebGLBackend(gl);
5104
5658
  } else throw new Error(`Backend not found: ${device}`);
5105
5659
  }