@jax-js/jax 0.1.13 → 0.1.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2328,6 +2328,25 @@ var TuneDims = class {
2328
2328
  this.upcast++;
2329
2329
  }
2330
2330
  }
2331
+ applyGroup(axis, amount) {
2332
+ if (axis < this.reduce || axis >= this.unroll) throw new Error("Can only group reduction axes");
2333
+ const length = this.st.shape[axis];
2334
+ if (length % amount !== 0) throw new Error(`Group by ${amount} on axis length ${length}`);
2335
+ this.st = this.st.reshape([
2336
+ ...this.st.shape.slice(0, axis),
2337
+ length / amount,
2338
+ amount,
2339
+ ...this.st.shape.slice(axis + 1)
2340
+ ]).permute([
2341
+ ...range(this.reduce),
2342
+ axis + 1,
2343
+ ...range(this.reduce, axis + 1),
2344
+ ...range(axis + 2, this.st.shape.length + 1)
2345
+ ]);
2346
+ this.reduce++;
2347
+ this.unroll++;
2348
+ this.upcast++;
2349
+ }
2331
2350
  };
2332
2351
  /** Tuning step that does not apply any optimization. */
2333
2352
  function tuneNullopt(kernel) {
@@ -2399,6 +2418,17 @@ function tuneWebgpu(kernel) {
2399
2418
  upcastedAxis.add(choices[0][2]);
2400
2419
  } else break;
2401
2420
  }
2421
+ const groupCandidateSts = sts.map((st) => st.compose(dim.st));
2422
+ const groupCandidateHasContiguousInputs = groupCandidateSts.every((st) => {
2423
+ const hasOutputStride = st.lastStrides.slice(0, dim.groups).some((stride) => stride !== 0);
2424
+ return !hasOutputStride || Math.abs(st.lastStrides[dim.reduce]) <= 1;
2425
+ });
2426
+ if (reduction.op === AluOp.Add && reduction.dtype === DType.Float32 && groupCandidateHasContiguousInputs && prod(dim.st.shape.slice(0, dim.groups)) < 4096 && prod(dim.st.shape.slice(dim.reduce, dim.unroll)) >= 512) {
2427
+ const axis = dim.reduce;
2428
+ let amount = 16;
2429
+ while (amount > 1 && dim.st.shape[axis] % amount !== 0) amount /= 2;
2430
+ if (amount > 1) dim.applyGroup(axis, amount);
2431
+ }
2402
2432
  if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2403
2433
  const s = dim.st.shape[dim.unroll - 1];
2404
2434
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
@@ -3219,6 +3249,30 @@ function wasm_threefry2x32(cg) {
3219
3249
  });
3220
3250
  }
3221
3251
 
3252
+ //#endregion
3253
+ //#region src/backend/wasm/featureProbe.ts
3254
/**
 * Hex-encoded Wasm modules used to probe optional features with
 * `WebAssembly.validate`.
 *
 * "relaxed-madd": a single function of type (v128, v128, v128) -> v128 whose
 * body is `local.get 0; local.get 1; local.get 2; f32x4.relaxed_madd`.
 */
const featureProbes = {
  "relaxed-madd": "0061736d0100000001080160037b7b7b017b030201000a0d010b00200020012002fd85020b",
};

/** Memoized probe results, keyed by feature name. */
const featureSupportCache = /* @__PURE__ */ new Map();
3256
/**
 * Detects whether this environment supports a probed Wasm CPU feature.
 *
 * Validates the feature's probe module from `featureProbes` and memoizes the
 * answer per feature name. A missing `WebAssembly` global, an unknown feature
 * name, or any error thrown while decoding or validating all yield `false`.
 *
 * @param {string} feature - Key into `featureProbes`.
 * @returns {boolean} Whether the probe module validates.
 */
function hasWasmFeature(feature) {
  const memo = featureSupportCache.get(feature);
  if (memo !== undefined) return memo;

  let result;
  try {
    result = typeof WebAssembly !== "undefined" && WebAssembly.validate(decodeHex(featureProbes[feature]));
  } catch {
    // Best-effort probe: treat any failure as "unsupported".
    result = false;
  }
  featureSupportCache.set(feature, result);
  return result;
}
3270
/**
 * Decodes a string of hex digit pairs into a byte array.
 *
 * A trailing unpaired digit is ignored. Malformed pairs are not rejected:
 * each pair goes through `Number.parseInt(pair, 16)`, and a `NaN` result is
 * stored as 0.
 *
 * @param {string} hex - Hex string, two digits per byte.
 * @returns {Uint8Array} Decoded bytes.
 */
function decodeHex(hex) {
  const byteLength = Math.floor(hex.length / 2);
  return Uint8Array.from({ length: byteLength }, (_, index) => {
    const offset = 2 * index;
    return Number.parseInt(hex.slice(offset, offset + 2), 16);
  });
}
3275
+
3222
3276
  //#endregion
3223
3277
  //#region src/backend/wasm/parallel.ts
3224
3278
  /** Check if SharedArrayBuffer is available. */
@@ -3309,10 +3363,10 @@ var WasmWorkerPool = class {
3309
3363
  * Returns an epoch that can be used to wait for the ongoing work to complete,
3310
3364
  * which is guaranteed to be monotonically increasing.
3311
3365
  */
3312
- dispatch(module$1, ptrs, size) {
3366
+ dispatch(module$1, ptrs, size, chunkAlignment = 16, minWorkPerWorker = MIN_ELEMS_PER_THREAD) {
3313
3367
  this.#ensureInit();
3314
3368
  this.#epochEnd++;
3315
- const result = this.#queue.then(() => this.#dispatchNow(module$1, ptrs, size));
3369
+ const result = this.#queue.then(() => this.#dispatchNow(module$1, ptrs, size, chunkAlignment, minWorkPerWorker));
3316
3370
  this.#queue = result.then(() => {}, () => {}).then(() => {
3317
3371
  this.#epoch++;
3318
3372
  const hooks = this.#hooks.get(this.#epoch);
@@ -3323,10 +3377,10 @@ var WasmWorkerPool = class {
3323
3377
  });
3324
3378
  return this.#epochEnd;
3325
3379
  }
3326
- async #dispatchNow(module$1, ptrs, size) {
3380
+ async #dispatchNow(module$1, ptrs, size, chunkAlignment, minWorkPerWorker) {
3327
3381
  if (size === 0) return;
3328
- const n = Math.min(this.#workers.length, Math.ceil(size / MIN_ELEMS_PER_THREAD));
3329
- const chunkSize = Math.ceil(size / n / 16) * 16;
3382
+ const n = Math.min(this.#workers.length, Math.ceil(size / minWorkPerWorker));
3383
+ const chunkSize = Math.ceil(size / n / chunkAlignment) * chunkAlignment;
3330
3384
  const promises = [];
3331
3385
  for (let i = 0; i < n; i++) {
3332
3386
  const begin = i * chunkSize;
@@ -3361,1421 +3415,1529 @@ function createWorkerPool(memory) {
3361
3415
  }
3362
3416
 
3363
3417
  //#endregion
3364
- //#region src/backend/wasm/wasmblr.ts
3365
- /**
3366
- * @file Minimalist WebAssembly assembler. This allows you to emit WebAssembly
3367
- * bytecode directly from the browser.
3368
- *
3369
- * Self-contained port of https://github.com/bwasti/wasmblr to TypeScript.
3370
- * Some operation names in this module are written in `snake_case` to match
3371
- * their names in the Wasm specification.
3372
- *
3373
- * Reference: https://pengowray.github.io/wasm-ops/.
3374
- */
3375
- const magicModuleHeader = [
3376
- 0,
3377
- 97,
3378
- 115,
3379
- 109
3380
- ];
3381
- const moduleVersion = [
3382
- 1,
3383
- 0,
3384
- 0,
3385
- 0
3386
- ];
3387
- function assert(condition, message) {
3388
- if (!condition) throw new Error(message || "Assertion failed");
3418
+ //#region src/backend/wasm/tilePlan.ts
3419
/** Lanes per SIMD vector (four 32-bit lanes in a v128). */
const simdLanes = 4;

// Caps used by `reductionTilePlan`: tile rows, tile columns, the reduction
// (K) tile, and the micro-tile inside each tile.
const TILED_SIMD_ROWS = 128;
const TILED_SIMD_COLUMNS = 128;
const TILED_SIMD_K = 64;
const TILED_SIMD_MICRO_ROWS = 4;
const TILED_SIMD_MICRO_VECTORS = 4;

// Parameters of `reductionKTilePlan`: micro-tile caps and the K unroll factor
// (the reduction length must be a multiple of simdLanes * K_SIMD_UNROLL).
const K_SIMD_MICRO_ROWS = 4;
const K_SIMD_MICRO_COLS = 4;
const K_SIMD_UNROLL = 4;

/** `tileAxisLimit` caps a tile at `floor(axisSize / TILE_AXIS_PARTS)` (min 1). */
const TILE_AXIS_PARTS = 8;
3429
/**
 * Returns true when `exp` is the symbol `name`: either a Variable node whose
 * `arg` is `name`, or a Special node whose `arg[0]` is `name`.
 */
function isSymbol(exp, name) {
  switch (exp.op) {
    case AluOp.Variable:
      return exp.arg === name;
    case AluOp.Special:
      return exp.arg[0] === name;
    default:
      return false;
  }
}
3390
- function encodeSigned(n) {
3391
- const out = [];
3392
- let more = true;
3393
- while (more) {
3394
- let byte = n & 127;
3395
- n >>= 7;
3396
- if (n === 0 && (byte & 64) === 0 || n === -1 && (byte & 64) !== 0) more = false;
3397
- else byte |= 128;
3398
- out.push(byte);
3432
/**
 * Returns true if `exp`, or any node reachable through `src`, is the symbol
 * `name` (as decided by `isSymbol`). Nodes are visited depth-first, and the
 * search stops at the first match.
 */
function referencesSymbol(exp, name) {
  if (isSymbol(exp, name)) return true;
  for (const child of exp.src) {
    if (referencesSymbol(child, name)) return true;
  }
  return false;
}
3435
/** Returns true if `exp` mentions the `gidx` symbol anywhere in its tree. */
function referencesGidx(exp) {
  const symbol = "gidx";
  return referencesSymbol(exp, symbol);
}
3438
/**
 * Returns true when `tileSize` is finite, strictly larger than the period `N`,
 * and not a multiple of `N`. In that case a tile would not line up with whole
 * periods.
 */
function hasFragmentRisk(tileSize, N) {
  if (!isFinite(tileSize)) return false;
  if (!(tileSize > N)) return false;
  return tileSize % N !== 0;
}
3441
/**
 * Returns the value of a Const node when it is an integer. Returns null for
 * non-Const nodes and for non-integer constants.
 */
function constInt(exp) {
  const isIntegerConst = exp.op === AluOp.Const && Number.isInteger(exp.arg);
  return isIntegerConst ? exp.arg : null;
}
3446
+ function coefficientOfSymbol(exp, name) {
3447
+ if (!referencesSymbol(exp, name)) return 0;
3448
+ if (isSymbol(exp, name)) return 1;
3449
+ if (exp.op === AluOp.Add || exp.op === AluOp.Sub) {
3450
+ const a = coefficientOfSymbol(exp.src[0], name);
3451
+ const b = coefficientOfSymbol(exp.src[1], name);
3452
+ if (a === null || b === null) return null;
3453
+ return exp.op === AluOp.Add ? a + b : a - b;
3399
3454
  }
3400
- return out;
3455
+ if (exp.op === AluOp.Mul) {
3456
+ const lhs = constInt(exp.src[0]);
3457
+ if (lhs !== null) {
3458
+ const rhsCoeff = coefficientOfSymbol(exp.src[1], name);
3459
+ return rhsCoeff === null ? null : lhs * rhsCoeff;
3460
+ }
3461
+ const rhs = constInt(exp.src[1]);
3462
+ if (rhs !== null) {
3463
+ const lhsCoeff = coefficientOfSymbol(exp.src[0], name);
3464
+ return lhsCoeff === null ? null : rhs * lhsCoeff;
3465
+ }
3466
+ }
3467
+ return null;
3401
3468
  }
3402
- function encodeUnsigned(n) {
3403
- const out = [];
3404
- do {
3405
- let byte = n & 127;
3406
- n = n >>> 7;
3407
- if (n !== 0) byte |= 128;
3408
- out.push(byte);
3409
- } while (n !== 0);
3410
- return out;
3469
/**
 * Replaces every node that `isSymbol(node, name)` accepts with
 * `rewrite(node)`, then simplifies the result.
 *
 * For all other nodes the callback returns `undefined`. NOTE(review): this
 * presumably tells `exp.rewrite` to keep the node as-is — confirm against
 * AluExp.rewrite.
 */
function rewriteSymbol(exp, name, rewrite) {
  const replaced = exp.rewrite((node) => {
    if (!isSymbol(node, name)) return undefined;
    return rewrite(node);
  });
  return replaced.simplify();
}
3412
- function encodeString(s) {
3413
- const bytes = new TextEncoder().encode(s);
3414
- return [bytes.length, ...bytes];
3472
/**
 * Returns true when substituting `gidx + tileSize` for `gidx` leaves `exp`
 * unchanged, compared by `getHash()`. That means `exp` repeats from one gidx
 * tile to the next. Always false for a non-finite tile size.
 */
function repeatsAcrossGidxTile(exp, tileSize) {
  if (isFinite(tileSize)) {
    const shiftByTile = (node) => AluExp.add(node, AluExp.i32(tileSize));
    const shifted = rewriteSymbol(exp, "gidx", shiftByTile);
    return shifted.getHash() === exp.getHash();
  }
  return false;
}
3416
- function encodeBlocktype(type) {
3417
- assert(type.length > 0, "blocktype must have at least one type");
3418
- if (type.length === 1) return [type[0].typeId];
3419
- return [
3420
- 96,
3421
- ...encodeUnsigned(0),
3422
- ...encodeUnsigned(type.length),
3423
- ...type.map((t) => t.typeId)
3424
- ];
3477
/**
 * Returns the largest integer `d` with `1 < d <= limit` that divides `value`,
 * or 1 when there is none.
 *
 * The search starts at `floor(min(value, limit))`, so a fractional `limit` is
 * never returned as a "divisor" (e.g. `divisorAtMost(6, 3.9)` is 3, and
 * `divisorAtMost(5, 2.5)` is 1). A non-finite start (both arguments infinite,
 * or NaN) returns 1 instead of looping forever.
 *
 * @param {number} value - Number to find a divisor of (expected integer).
 * @param {number} limit - Upper bound for the divisor.
 * @returns {number} The chosen divisor, at least 1.
 */
function divisorAtMost(value, limit) {
  const start = Math.floor(Math.min(value, limit));
  if (!Number.isFinite(start)) return 1;
  for (let i = start; i > 1; i--) {
    if (value % i === 0) return i;
  }
  return 1;
}
3426
- function encodeOpcode(opcode) {
3427
- if (typeof opcode === "number") return [opcode];
3428
- return [opcode[0], ...encodeUnsigned(opcode[1])];
3481
/**
 * Upper bound for a tile along an axis of `axisSize` elements. The bound is
 * `floor(axisSize / TILE_AXIS_PARTS)`, raised to at least 1, then capped by
 * `maxTileSize`.
 */
function tileAxisLimit(axisSize, maxTileSize) {
  let perPart = Math.floor(axisSize / TILE_AXIS_PARTS);
  if (perPart < 1) perPart = 1;
  return Math.min(maxTileSize, perPart);
}
3430
- function concat(out, inp) {
3431
- out.push(...inp);
3484
/**
 * Picks a single tile size that works for every non-gather stride with a
 * finite tile.
 *
 * The candidate is the smallest such tile. It is rejected (null) when it is
 * below `minTileSize`, does not divide `kernelSize`, or does not divide every
 * other finite tile. When no stride constrains the tile,
 * `unconstrainedTileSize` is returned.
 *
 * @param {number} kernelSize - Total number of output elements.
 * @param {Map} strideMap - Stride classifications (`{ kind, tileSize }`).
 * @param {number} minTileSize - Smallest acceptable tile.
 * @param {number|null} [unconstrainedTileSize=null] - Result when nothing constrains the tile.
 */
function commonTileSize(kernelSize, strideMap, minTileSize, unconstrainedTileSize = null) {
  const finiteTiles = Array.from(strideMap.values())
    .filter((stride) => stride.kind !== "gather" && isFinite(stride.tileSize))
    .map((stride) => stride.tileSize);
  if (finiteTiles.length === 0) return unconstrainedTileSize;

  const smallest = finiteTiles.reduce((a, b) => Math.min(a, b));
  if (smallest < minTileSize) return null;
  if (kernelSize % smallest !== 0) return null;
  return finiteTiles.every((size) => size % smallest === 0) ? smallest : null;
}
3433
- var Function_ = class {
3434
- inputTypes;
3435
- outputTypes;
3436
- body;
3437
- locals = [];
3438
- constructor(inputTypes, outputTypes, body) {
3439
- this.inputTypes = inputTypes;
3440
- this.outputTypes = outputTypes;
3441
- this.body = body || (() => {});
3442
- }
3443
- emit() {
3444
- this.locals = [];
3445
- this.body();
3446
- }
3447
- };
3448
- var Memory = class {
3449
- min = 0;
3450
- max = 0;
3451
- isShared = false;
3452
- aString = "";
3453
- bString = "";
3454
- constructor(cg) {
3455
- this.cg = cg;
3456
- }
3457
- /** Declare the size of the memory. Each page is 64 KiB. */
3458
- pages(min, max = 0) {
3459
- assert(this.min === 0 && this.max === 0);
3460
- this.min = min;
3461
- this.max = max;
3462
- return this;
3463
- }
3464
- export(a) {
3465
- assert(!this.isImport && !this.isExport, "already set");
3466
- this.aString = a;
3467
- return this;
3468
- }
3469
- shared(isShared) {
3470
- this.isShared = isShared;
3471
- return this;
3472
- }
3473
- import(a, b) {
3474
- assert(!this.isImport && !this.isExport, "already set");
3475
- this.aString = a;
3476
- this.bString = b;
3477
- return this;
3478
- }
3479
- size() {
3480
- this.cg._emit(63);
3481
- this.cg._emit(0);
3482
- }
3483
- grow() {
3484
- this.cg._emit(64);
3485
- this.cg._emit(0);
3486
- }
3487
- get isImport() {
3488
- return this.aString.length > 0 && this.bString.length > 0;
3489
- }
3490
- get isExport() {
3491
- return this.aString.length > 0 && this.bString.length === 0;
3492
- }
3493
- };
3494
- /** Public API of WebAssembly assembler. */
3495
- var CodeGenerator = class {
3496
- local;
3497
- i32;
3498
- f32;
3499
- f64;
3500
- v128;
3501
- i32x4;
3502
- f32x4;
3503
- memory;
3504
- void = {
3505
- typeId: 64,
3506
- name: "void"
3492
/**
 * Number of tile-sized rows to group into one tile: the largest divisor of
 * `kernelSize / tileSize` that does not exceed the per-axis limit from
 * `tileAxisLimit` (capped by TILED_SIMD_ROWS).
 */
function tiledRows(kernelSize, tileSize) {
  const rows = kernelSize / tileSize;
  const cap = tileAxisLimit(rows, TILED_SIMD_ROWS);
  return divisorAtMost(rows, cap);
}
3496
/**
 * Number of `laneWidth`-wide column groups per tile. This is the largest
 * divisor of `tileSize / laneWidth` within the column limit from
 * `tileAxisLimit` (capped by TILED_SIMD_COLUMNS), converted to lane-width
 * units and floored at 1.
 */
function tiledColumns(tileSize, laneWidth = 1) {
  const vectorCount = tileSize / laneWidth;
  const elementCap = tileAxisLimit(tileSize, TILED_SIMD_COLUMNS);
  const vectorCap = Math.max(1, Math.floor(elementCap / laneWidth));
  return divisorAtMost(vectorCount, vectorCap);
}
3499
+ function periodicStride(exp, kind) {
3500
+ if (exp.src[1].op !== AluOp.Const) return null;
3501
+ const N = exp.src[1].arg;
3502
+ const inner = analyzeStride(exp.src[0]);
3503
+ if (inner.kind === "broadcast") return inner;
3504
+ if (inner.kind !== "contiguous" || hasFragmentRisk(inner.tileSize, N)) return { kind: "gather" };
3505
+ return {
3506
+ kind,
3507
+ tileSize: Math.min(inner.tileSize, N)
3507
3508
  };
3508
- #functions = [];
3509
- #importedFunctions = [];
3510
- #exportedFunctions = /* @__PURE__ */ new Map();
3511
- #curFunction = null;
3512
- #curBytes = [];
3513
- #typeStack = [];
3514
- #blockFrames = [];
3515
- constructor() {
3516
- this.local = new Local(this);
3517
- this.i32 = new I32(this);
3518
- this.f32 = new F32(this);
3519
- this.f64 = new F64(this);
3520
- this.v128 = new V128(this);
3521
- this.i32x4 = new I32x4(this);
3522
- this.f32x4 = new F32x4(this);
3523
- this.memory = new Memory(this);
3524
- }
3525
- unreachable() {
3526
- this._emit(0);
3527
- }
3528
- nop() {
3529
- this._emit(1);
3530
- }
3531
- block(...type) {
3532
- this.#blockFrames.push({
3533
- idx: this.#typeStack.length,
3534
- ty: type
3535
- });
3536
- this._emit(2);
3537
- this._emit(encodeBlocktype(type));
3509
+ }
3510
/**
 * Combines the stride classifications of the two operands of an Add.
 *
 * Gather on either side gives gather. Broadcast plus X takes X's kind with
 * the smaller of the two tile sizes. Any other pair (e.g. contiguous plus
 * contiguous) gives gather.
 */
function addStrides(lhs, rhs) {
  if (lhs.kind === "gather" || rhs.kind === "gather") return { kind: "gather" };
  let kind;
  if (lhs.kind === "broadcast") kind = rhs.kind;
  else if (rhs.kind === "broadcast") kind = lhs.kind;
  else return { kind: "gather" };
  return { kind, tileSize: Math.min(lhs.tileSize, rhs.tileSize) };
}
3523
+ function analyzeStride(exp) {
3524
+ if (!referencesGidx(exp)) return {
3525
+ kind: "broadcast",
3526
+ tileSize: Infinity
3527
+ };
3528
+ if (exp.op === AluOp.Special && exp.arg[0] === "gidx") return {
3529
+ kind: "contiguous",
3530
+ tileSize: Infinity
3531
+ };
3532
+ if (exp.op === AluOp.Idiv || exp.op === AluOp.Mod) {
3533
+ const stride = periodicStride(exp, exp.op === AluOp.Idiv ? "broadcast" : "contiguous");
3534
+ if (stride) return stride;
3538
3535
  }
3539
- loop(...type) {
3540
- this.#blockFrames.push({
3541
- idx: this.#typeStack.length,
3542
- ty: type
3543
- });
3544
- this._emit(3);
3545
- this._emit(encodeBlocktype(type));
3536
+ if (exp.op === AluOp.Mul) {
3537
+ for (let i = 0; i < 2; i++) if (exp.src[i].op === AluOp.Const) {
3538
+ const inner = analyzeStride(exp.src[1 - i]);
3539
+ if (inner.kind === "broadcast") return inner;
3540
+ return { kind: "gather" };
3541
+ }
3546
3542
  }
3547
- if(...type) {
3548
- assert(this._pop().typeId === this.i32.typeId, "if_: expected i32");
3549
- this.#blockFrames.push({
3550
- idx: this.#typeStack.length,
3551
- ty: type
3543
+ if (exp.op === AluOp.Add) return addStrides(analyzeStride(exp.src[0]), analyzeStride(exp.src[1]));
3544
+ return { kind: "gather" };
3545
+ }
3546
/**
 * Classifies a GlobalIndex load for SIMD emission.
 *
 * Starts from `analyzeStride` on the index expression. The result is
 * downgraded to gather when:
 * - the tile is smaller than `simdLanes`,
 * - a finite tile is not a multiple of `simdLanes`, or
 * - a contiguous index may fall outside `[0, len)`.
 */
function simdStrideResult(globalIndex) {
  const index = globalIndex.src[0];
  const verdict = analyzeStride(index);
  const [, len] = globalIndex.arg;
  if (verdict.kind === "gather") return verdict;

  const { tileSize } = verdict;
  const misaligned = tileSize < simdLanes || (isFinite(tileSize) && tileSize % simdLanes !== 0);
  if (misaligned) return { kind: "gather" };

  if (verdict.kind === "contiguous" && (index.min < 0 || index.max >= len)) return { kind: "gather" };
  return verdict;
}
3554
/**
 * Maps every GlobalIndex node in `exp` to its `simdStrideResult`. A nullish
 * `exp`, or a `collect` call returning nullish, gives an empty Map.
 */
function collectSimdStrides(exp) {
  const strides = new Map();
  const loads = exp?.collect((node) => node.op === AluOp.GlobalIndex) ?? [];
  for (const load of loads) strides.set(load, simdStrideResult(load));
  return strides;
}
3557
+ function reductionPointerCandidates(exp, strideMap) {
3558
+ const candidates = [];
3559
+ for (const gi of exp.collect((node) => node.op === AluOp.GlobalIndex)) {
3560
+ const stride = strideMap?.get(gi) ?? {
3561
+ kind: "broadcast",
3562
+ tileSize: Infinity
3563
+ };
3564
+ if (strideMap && stride.kind === "gather") continue;
3565
+ const [gid, len] = gi.arg;
3566
+ const index = gi.src[0];
3567
+ const strideElems = coefficientOfSymbol(index, "ridx");
3568
+ if (strideElems === null || !Number.isInteger(strideElems)) continue;
3569
+ if (index.min < 0 || index.max >= len) continue;
3570
+ candidates.push({
3571
+ exp: gi,
3572
+ gid,
3573
+ dtype: gi.dtype,
3574
+ stride,
3575
+ baseIndex: rewriteSymbol(index, "ridx", () => AluExp.i32(0)),
3576
+ strideBytes: strideElems * byteWidth(gi.dtype)
3552
3577
  });
3553
- this._emit(4);
3554
- this._emit(encodeBlocktype(type));
3555
- }
3556
- else() {
3557
- assert(this.#blockFrames.length > 0, "else: no block to else");
3558
- const frame = this.#blockFrames[this.#blockFrames.length - 1];
3559
- this.#typeStack = this.#typeStack.slice(0, frame.idx);
3560
- this._emit(5);
3561
3578
  }
3562
- /** End a block (`block`, `if`/`else`, `loop`, or function). */
3563
- end() {
3564
- const frame = this.#blockFrames.pop();
3565
- assert(frame !== void 0, "end: no block to end");
3566
- this.#typeStack = this.#typeStack.slice(0, frame.idx);
3567
- for (const ty of frame.ty) if (ty.typeId !== this.void.typeId) this._push(ty);
3568
- this._emit(11);
3569
- }
3570
- /** Branch to a block a certain depth outward on the stack. */
3571
- br(depth) {
3572
- this._emit(12);
3573
- this._emit(encodeUnsigned(depth));
3574
- }
3575
- /** Conditional branch to a block a certain depth outward on the stack. */
3576
- br_if(depth) {
3577
- assert(this._pop().typeId === this.i32.typeId, "br_if: expected i32");
3578
- this._emit(13);
3579
- this._emit(encodeUnsigned(depth));
3580
- }
3581
- /** Jump table that indexes into a label vector (like switch). */
3582
- br_table(...depths) {
3583
- assert(this._pop().typeId === this.i32.typeId, "br_table: expected i32");
3584
- assert(depths.length > 0, "br_table: expected at least one default depth");
3585
- this._emit(14);
3586
- this._emit(encodeUnsigned(depths.length - 1));
3587
- for (const d of depths) this._emit(encodeUnsigned(d));
3588
- }
3589
- /** Return from a function, branching out of the outermost block. */
3590
- return() {
3591
- this._emit(15);
3592
- }
3593
- /** Call a function with the given ID. */
3594
- call(fn) {
3595
- const totalFunctions = this.#importedFunctions.length + this.#functions.length;
3596
- assert(fn < totalFunctions, "function index does not exist");
3597
- const func = fn < this.#importedFunctions.length ? this.#importedFunctions[fn] : this.#functions[fn - this.#importedFunctions.length];
3598
- for (let i = func.inputTypes.length - 1; i >= 0; i--) {
3599
- const argType = this._pop();
3600
- assert(argType.typeId === func.inputTypes[i].typeId, `call: argument ${i} type mismatch, expected ${func.inputTypes[i].name} got ${argType.name}`);
3579
+ return candidates;
3580
+ }
3581
/**
 * Builds the row/column tiling plan for a SIMD reduction kernel.
 *
 * Returns null when the kernel has no reduction, or when `commonTileSize`
 * finds no shared tile of at least `simdLanes` elements. Each component of
 * the plan is chosen with `divisorAtMost` (directly or via
 * `tiledRows`/`tiledColumns`), so it evenly divides the extent it tiles.
 *
 * @param kernel - Object with `size` and an optional `reduction` (with `size`).
 * @param {Map} strideMap - Per-GlobalIndex stride classifications.
 */
function reductionTilePlan(kernel, strideMap) {
  const tileSize = kernel.reduction ? commonTileSize(kernel.size, strideMap, simdLanes) : null;
  if (tileSize === null) return null;

  const tileRows = tiledRows(kernel.size, tileSize);
  const tileVectors = tiledColumns(tileSize, simdLanes);
  const tileK = divisorAtMost(kernel.reduction.size, TILED_SIMD_K);
  const microRows = divisorAtMost(tileRows, TILED_SIMD_MICRO_ROWS);
  const microVectors = divisorAtMost(tileVectors, TILED_SIMD_MICRO_VECTORS);
  return { tileSize, tileRows, tileVectors, tileK, microRows, microVectors };
}
3596
/**
 * Builds the tiling plan for the K-vectorized SIMD reduction.
 *
 * Returns null when:
 * - the kernel has no reduction,
 * - the reduction length is not a multiple of `simdLanes * K_SIMD_UNROLL`, or
 * - `commonTileSize` finds no shared tile.
 *
 * When no stride constrains the tile, the tile size defaults to 1.
 *
 * @param kernel - Object with `size` and an optional `reduction` (with `size`).
 * @param {Map} strideMap - Per-GlobalIndex stride classifications.
 */
function reductionKTilePlan(kernel, strideMap) {
  const { reduction } = kernel;
  if (!reduction) return null;
  const kStep = simdLanes * K_SIMD_UNROLL;
  if (reduction.size % kStep !== 0) return null;

  const tileSize = commonTileSize(kernel.size, strideMap, 1, 1);
  if (tileSize === null) return null;

  const tileRows = tiledRows(kernel.size, tileSize);
  const tileCols = tiledColumns(tileSize);
  return {
    tileSize,
    tileRows,
    tileCols,
    microRows: divisorAtMost(tileRows, K_SIMD_MICRO_ROWS),
    microCols: divisorAtMost(tileCols, K_SIMD_MICRO_COLS),
    kUnroll: K_SIMD_UNROLL,
  };
}
3612
/**
 * Builds the key under which a reduction pointer's loaded value may be shared
 * within a tiled reduction step.
 *
 * Keys by load kind:
 * - broadcast with a finite tile: shared per row (`:row<row>`).
 * - broadcast with an infinite tile: shared everywhere (`:all`).
 * - contiguous with a finite tile: shared per vector (`:vec<vector>`) when
 *   the base index repeats across gidx tiles, otherwise per row and vector.
 * - anything else: keyed by `groupIndex`.
 */
function pointerShareKey(candidate, row, vector, groupIndex) {
  const hash = candidate.exp.getHash().toString();
  const { kind, tileSize } = candidate.stride;
  switch (kind) {
    case "broadcast":
      return isFinite(tileSize) ? `${hash}:row${row}` : `${hash}:all`;
    case "contiguous":
      if (isFinite(tileSize)) {
        const repeats = repeatsAcrossGidxTile(candidate.baseIndex, tileSize);
        return repeats ? `${hash}:vec${vector}` : `${hash}:row${row}:vec${vector}`;
      }
      break;
  }
  return `${hash}:g${groupIndex}`;
}
3619
/**
 * Builds the value-sharing key for a pointer in the K-vectorized reduction.
 *
 * The key is chosen by the first matching rule:
 * - the output stride is broadcast with a finite tile: share per row;
 * - the base index repeats across `tileSize`-wide gidx tiles: share per column;
 * - otherwise: key by `groupIndex`.
 */
function kReductionPointerShareKey(candidate, outputStride, tileSize, row, col, groupIndex) {
  const hash = candidate.exp.getHash().toString();
  const rowBroadcast = outputStride.kind === "broadcast" && isFinite(outputStride.tileSize);
  let suffix;
  if (rowBroadcast) suffix = `row${row}`;
  else if (repeatsAcrossGidxTile(candidate.baseIndex, tileSize)) suffix = `col${col}`;
  else suffix = `g${groupIndex}`;
  return `${hash}:${suffix}`;
}
3625
+
3626
+ //#endregion
3627
+ //#region src/backend/wasm/translation.ts
3628
+ function translateExp(cg, funcs, exp, ctx, pointerMap = /* @__PURE__ */ new Map()) {
3629
+ const references = /* @__PURE__ */ new Map();
3630
+ const seen = /* @__PURE__ */ new Set();
3631
+ const countReferences = (exp$1) => {
3632
+ references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
3633
+ if (!seen.has(exp$1)) {
3634
+ seen.add(exp$1);
3635
+ for (const src of exp$1.src) countReferences(src);
3601
3636
  }
3602
- for (const outputType of func.outputTypes) this._push(outputType);
3603
- this._emit(16);
3604
- this._emit(encodeUnsigned(fn));
3605
- }
3606
- /** Throw away an operand on the stack. */
3607
- drop() {
3608
- this._pop();
3609
- this._emit(26);
3610
- }
3611
- /** Select one of the first two operands (T, F) based on the third operand (i32)'s value. */
3612
- select() {
3613
- assert(this._pop().typeId === this.i32.typeId, "select: expected i32 condition");
3614
- const [b, a] = [this._pop(), this._pop()];
3615
- assert(a.typeId === b.typeId, "select: expected same type for both operands");
3616
- this._push(a);
3617
- this._emit(27);
3618
- }
3619
- /** Import a JavaScript function; returns its index. */
3620
- importFunction(module$1, name, inputTypes, outputTypes) {
3621
- if (this.#functions.length > 0) throw new Error("function imports must precede defining functions");
3622
- const idx = this.#importedFunctions.length;
3623
- this.#importedFunctions.push({
3624
- module: module$1,
3625
- name,
3626
- inputTypes,
3627
- outputTypes
3628
- });
3629
- return idx;
3630
- }
3631
- /** Export a function. */
3632
- export(fn, name) {
3633
- this.#exportedFunctions.set(fn, name);
3634
- }
3635
- /** Declare a new function; returns its index. */
3636
- function(inputTypes, outputTypes, body) {
3637
- const idx = this.#importedFunctions.length + this.#functions.length;
3638
- this.#functions.push(new Function_(inputTypes, outputTypes, body));
3639
- return idx;
3640
- }
3641
- _declareLocal(type) {
3642
- assert(this.#curFunction !== null, "No current function");
3643
- const idx = this.#curFunction.locals.length + this.#curFunction.inputTypes.length;
3644
- this.#curFunction.locals.push(type);
3645
- return idx;
3646
- }
3647
- _inputTypes() {
3648
- assert(this.#curFunction !== null, "No current function");
3649
- return this.#curFunction.inputTypes;
3650
- }
3651
- _locals() {
3652
- assert(this.#curFunction !== null, "No current function");
3653
- return this.#curFunction.locals;
3637
+ };
3638
+ const expContext = /* @__PURE__ */ new Map();
3639
+ const gen = (exp$1) => {
3640
+ if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
3641
+ const { op, src, dtype, arg } = exp$1;
3642
+ if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
3643
+ gen(src[0]), gen(src[1]);
3644
+ if (op === AluOp.Add) if (dtype === DType.Bool) cg.i32.or();
3645
+ else dty(cg, op, dtype).add();
3646
+ else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
3647
+ else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
3648
+ else dty(cg, op, dtype).mul();
3649
+ else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
3650
+ dtyF(cg, op, dtype).div();
3651
+ dtyF(cg, op, dtype).trunc();
3652
+ } else if (dtype === DType.Uint32) cg.i32.div_u();
3653
+ else if (dtype === DType.Int32) cg.i32.div_s();
3654
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3655
+ else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
3656
+ const dt = dtyF(cg, op, dtype);
3657
+ const a = cg.local.declare(dt);
3658
+ const b = cg.local.declare(dt);
3659
+ cg.local.set(b);
3660
+ cg.local.tee(a);
3661
+ cg.local.get(a);
3662
+ cg.local.get(b);
3663
+ dt.div();
3664
+ dt.trunc();
3665
+ cg.local.get(b);
3666
+ dt.mul();
3667
+ dt.sub();
3668
+ } else if (dtype === DType.Uint32) cg.i32.rem_u();
3669
+ else if (dtype === DType.Int32) cg.i32.rem_s();
3670
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3671
+ else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
3672
+ else dtyF(cg, op, dtype).max();
3673
+ else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
3674
+ const a = cg.local.declare(cg.i32);
3675
+ const b = cg.local.declare(cg.i32);
3676
+ cg.local.set(b);
3677
+ cg.local.tee(a);
3678
+ cg.local.get(b);
3679
+ cg.local.get(a);
3680
+ cg.local.get(b);
3681
+ if (dtype === DType.Int32) if (op === AluOp.Min) cg.i32.lt_s();
3682
+ else cg.i32.gt_s();
3683
+ else if (op === AluOp.Min) cg.i32.lt_u();
3684
+ else cg.i32.gt_u();
3685
+ cg.select();
3686
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3687
+ else if (op === AluOp.BitCombine) if (arg === "and") cg.i32.and();
3688
+ else if (arg === "or") cg.i32.or();
3689
+ else cg.i32.xor();
3690
+ else if (op === AluOp.BitShift) if (arg === "shl") cg.i32.shl();
3691
+ else cg.i32.shr_u();
3692
+ else if (op === AluOp.Cmplt) {
3693
+ const srcDtype = src[0].dtype;
3694
+ if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
3695
+ else if (srcDtype === DType.Int32) cg.i32.lt_s();
3696
+ else if (srcDtype === DType.Uint32) cg.i32.lt_u();
3697
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3698
+ } else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
3699
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3700
+ } else if (AluGroup.Unary.has(op)) {
3701
+ const callFuncF32 = (func) => {
3702
+ if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
3703
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3704
+ cg.call(func);
3705
+ if (dtype === DType.Float64) cg.f64.promote_f32();
3706
+ };
3707
+ if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
3708
+ else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
3709
+ else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
3710
+ else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
3711
+ else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
3712
+ else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
3713
+ else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
3714
+ else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
3715
+ else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
3716
+ else if (op === AluOp.Reciprocal) {
3717
+ const dt = dtyF(cg, op, dtype);
3718
+ dt.const(1), gen(src[0]), dt.div();
3719
+ } else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
3720
+ else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
3721
+ else if (op === AluOp.Cast) {
3722
+ gen(src[0]);
3723
+ const dtype0 = src[0].dtype;
3724
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3725
+ if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
3726
+ else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
3727
+ else if (i32repr);
3728
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3729
+ else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
3730
+ else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
3731
+ else if (i32repr);
3732
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3733
+ else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
3734
+ else if (dtype0 === DType.Float64) cg.f32.demote_f64();
3735
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
3736
+ else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
3737
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3738
+ else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
3739
+ else if (dtype0 === DType.Float64);
3740
+ else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
3741
+ else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
3742
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3743
+ else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
3744
+ else if (i32repr) cg.i32.const(0), cg.i32.ne();
3745
+ else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
3746
+ else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
3747
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3748
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3749
+ } else if (op === AluOp.Bitcast) {
3750
+ gen(src[0]);
3751
+ const dtype0 = src[0].dtype;
3752
+ if (dtype !== dtype0) {
3753
+ const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
3754
+ if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
3755
+ else if (i32repr);
3756
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3757
+ else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
3758
+ else if (dtype0 === DType.Float32);
3759
+ else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
3760
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3761
+ }
3762
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3763
+ } else if (op === AluOp.Where) {
3764
+ gen(src[1]);
3765
+ gen(src[2]);
3766
+ gen(src[0]);
3767
+ cg.select();
3768
+ } else if (op === AluOp.Threefry2x32) {
3769
+ for (let i = 0; i < 4; i++) gen(src[i]);
3770
+ cg.call(funcs.threefry2x32);
3771
+ if (arg === "xor") cg.i32.xor();
3772
+ else if (arg === 0) cg.drop();
3773
+ else if (arg === 1) {
3774
+ const local = cg.local.declare(cg.i32);
3775
+ cg.local.set(local);
3776
+ cg.drop();
3777
+ cg.local.get(local);
3778
+ } else throw new UnsupportedOpError(op, dtype, "wasm", arg);
3779
+ } else if (op === AluOp.Const) return dty(cg, op, dtype).const(arg);
3780
+ else if (op === AluOp.Special) return cg.local.get(ctx[arg[0]]);
3781
+ else if (op === AluOp.Variable) return cg.local.get(ctx[arg]);
3782
+ else if (op === AluOp.GlobalIndex) {
3783
+ const [gid, len] = arg;
3784
+ const pointer = pointerMap.get(exp$1);
3785
+ if (pointer) cg.local.get(pointer.ptr);
3786
+ else {
3787
+ gen(src[0]);
3788
+ const local = cg.local.declare(cg.i32);
3789
+ cg.local.tee(local);
3790
+ cg.i32.const(0);
3791
+ cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
3792
+ cg.select();
3793
+ cg.i32.const(byteWidth(dtype));
3794
+ cg.i32.mul();
3795
+ cg.local.get(gid);
3796
+ cg.i32.add();
3797
+ }
3798
+ dty(cg, op, dtype).load(Math.log2(byteWidth(dtype)));
3799
+ } else throw new UnsupportedOpError(op, dtype, "wasm");
3800
+ if ((references.get(exp$1) ?? 0) > 1) {
3801
+ const local = cg.local.declare(dty(cg, op, dtype));
3802
+ cg.local.tee(local);
3803
+ expContext.set(exp$1, local);
3804
+ }
3805
+ };
3806
+ countReferences(exp);
3807
+ gen(exp);
3808
+ }
3809
+ /**
3810
+ * SIMD version of translateExp. Emits one v128 value for `exp`, interpreting
3811
+ * the current `gidx` local as the first lane and the following lanes as
3812
+ * `gidx + 1`, `gidx + 2`, etc.
3813
+ *
3814
+ * GlobalIndex loads are the only places where per-lane address behavior
3815
+ * matters:
3816
+ * - `strideMap` classifies each GlobalIndex as contiguous, broadcast, or
3817
+ * gather. Contiguous loads become one v128.load, broadcast loads become a
3818
+ * scalar load plus splat, and gather loads fall back to four scalar loads.
3819
+ * - `pointerMap` optionally supplies precomputed reduction pointers for
3820
+ * GlobalIndex nodes whose address can be advanced by the surrounding
3821
+ * reduction loop. This avoids re-emitting scalar index math in the hot path.
3822
+ * - `pointerValueCache` lets pointer plans with the same `valueKey` share a
3823
+ * single loaded vector within one emitted reduction step.
3824
+ */
3825
+ function translateExpSimd(cg, funcs, exp, ctx, strideMap, pointerMap = /* @__PURE__ */ new Map(), pointerValueCache = /* @__PURE__ */ new Map()) {
3826
+ const references = /* @__PURE__ */ new Map();
3827
+ const seen = /* @__PURE__ */ new Set();
3828
+ const countReferences = (exp$1) => {
3829
+ references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
3830
+ if (!seen.has(exp$1)) {
3831
+ seen.add(exp$1);
3832
+ for (const src of exp$1.src) countReferences(src);
3833
+ }
3834
+ };
3835
+ const expContext = /* @__PURE__ */ new Map();
3836
+ const gen = (exp$1) => {
3837
+ if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
3838
+ const { op, src, arg, dtype } = exp$1;
3839
+ const isInt = dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool;
3840
+ const isSigned = dtype === DType.Int32;
3841
+ if (op === AluOp.Add) {
3842
+ gen(src[0]), gen(src[1]);
3843
+ if (dtype === DType.Bool) cg.v128.or();
3844
+ else if (isInt) cg.i32x4.add();
3845
+ else cg.f32x4.add();
3846
+ } else if (op === AluOp.Sub) {
3847
+ gen(src[0]), gen(src[1]);
3848
+ if (isInt) cg.i32x4.sub();
3849
+ else cg.f32x4.sub();
3850
+ } else if (op === AluOp.Mul) {
3851
+ gen(src[0]), gen(src[1]);
3852
+ if (dtype === DType.Bool) cg.v128.and();
3853
+ else if (isInt) cg.i32x4.mul();
3854
+ else cg.f32x4.mul();
3855
+ } else if (op === AluOp.Min) {
3856
+ gen(src[0]), gen(src[1]);
3857
+ if (isInt) if (isSigned) cg.i32x4.min_s();
3858
+ else cg.i32x4.min_u();
3859
+ else cg.f32x4.min();
3860
+ } else if (op === AluOp.Max) {
3861
+ gen(src[0]), gen(src[1]);
3862
+ if (isInt) if (isSigned) cg.i32x4.max_s();
3863
+ else cg.i32x4.max_u();
3864
+ else cg.f32x4.max();
3865
+ } else if (op === AluOp.Sqrt) {
3866
+ gen(src[0]);
3867
+ cg.f32x4.sqrt();
3868
+ } else if (op === AluOp.Floor) {
3869
+ gen(src[0]);
3870
+ cg.f32x4.floor();
3871
+ } else if (op === AluOp.Ceil) {
3872
+ gen(src[0]);
3873
+ cg.f32x4.ceil();
3874
+ } else if (op === AluOp.Const) if (isInt) {
3875
+ cg.i32.const(arg);
3876
+ cg.i32x4.splat();
3877
+ } else {
3878
+ cg.f32.const(arg);
3879
+ cg.f32x4.splat();
3880
+ }
3881
+ else if (op === AluOp.Cast) {
3882
+ gen(src[0]);
3883
+ const dtype0 = src[0].dtype;
3884
+ const src0IsInt = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
3885
+ if (isInt && !src0IsInt) if (isSigned) cg.i32x4.trunc_sat_f32x4_s();
3886
+ else cg.i32x4.trunc_sat_f32x4_u();
3887
+ else if (!isInt && src0IsInt) if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32x4.convert_i32x4_s();
3888
+ else cg.f32x4.convert_i32x4_u();
3889
+ } else if (op === AluOp.Cmplt) {
3890
+ gen(src[0]), gen(src[1]);
3891
+ const srcDtype = src[0].dtype;
3892
+ if (srcDtype === DType.Float32) cg.f32x4.lt();
3893
+ else if (srcDtype === DType.Int32) cg.i32x4.lt_s();
3894
+ else if (srcDtype === DType.Uint32) cg.i32x4.lt_u();
3895
+ else throw new UnsupportedOpError(op, dtype, "wasm");
3896
+ cg.i32.const(1);
3897
+ cg.i32x4.splat();
3898
+ cg.v128.and();
3899
+ } else if (op === AluOp.Cmpne) {
3900
+ gen(src[0]), gen(src[1]);
3901
+ const srcDtype = src[0].dtype;
3902
+ if (srcDtype === DType.Float32) cg.f32x4.ne();
3903
+ else cg.i32x4.ne();
3904
+ cg.i32.const(1);
3905
+ cg.i32x4.splat();
3906
+ cg.v128.and();
3907
+ } else if (op === AluOp.Where) {
3908
+ gen(src[1]);
3909
+ gen(src[2]);
3910
+ gen(src[0]);
3911
+ cg.i32.const(0);
3912
+ cg.i32x4.splat();
3913
+ cg.i32x4.ne();
3914
+ cg.v128.bitselect();
3915
+ } else if (op === AluOp.Variable || op === AluOp.Special) throw new Error(`translateExpSimd: unexpected ${op}(${arg})`);
3916
+ else if (op === AluOp.GlobalIndex) {
3917
+ const [gid, len] = arg;
3918
+ const indexSubtree = src[0];
3919
+ const pointer = pointerMap.get(exp$1);
3920
+ const stride = pointer?.stride ?? strideMap.get(exp$1) ?? { kind: "gather" };
3921
+ if (pointer) {
3922
+ const cached = pointer.valueKey ? pointerValueCache.get(pointer.valueKey) : void 0;
3923
+ if (cached !== void 0) cg.local.get(cached);
3924
+ else {
3925
+ cg.local.get(pointer.ptr);
3926
+ if (stride.kind === "contiguous") if (isInt) cg.i32x4.load(4);
3927
+ else cg.f32x4.load(4);
3928
+ else if (stride.kind === "broadcast") if (isInt) {
3929
+ cg.i32.load(2);
3930
+ cg.i32x4.splat();
3931
+ } else {
3932
+ cg.f32.load(2);
3933
+ cg.f32x4.splat();
3934
+ }
3935
+ else throw new Error("reduction pointer plan cannot use gather loads");
3936
+ if (pointer.valueKey) {
3937
+ const local = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
3938
+ cg.local.tee(local);
3939
+ pointerValueCache.set(pointer.valueKey, local);
3940
+ }
3941
+ }
3942
+ } else if (stride.kind === "contiguous") {
3943
+ translateExp(cg, funcs, indexSubtree, ctx);
3944
+ {
3945
+ const maxIdx = Math.max(len - simdLanes, 0);
3946
+ const wideIdx = cg.local.declare(cg.i32);
3947
+ cg.local.set(wideIdx);
3948
+ cg.local.get(wideIdx);
3949
+ cg.i32.const(maxIdx);
3950
+ cg.local.get(wideIdx);
3951
+ cg.i32.const(maxIdx);
3952
+ cg.i32.lt_u();
3953
+ cg.select();
3954
+ }
3955
+ cg.i32.const(byteWidth(dtype));
3956
+ cg.i32.mul();
3957
+ cg.local.get(gid);
3958
+ cg.i32.add();
3959
+ if (isInt) cg.i32x4.load(4);
3960
+ else cg.f32x4.load(4);
3961
+ } else if (stride.kind === "broadcast") {
3962
+ translateExp(cg, funcs, indexSubtree, ctx);
3963
+ const local = cg.local.declare(cg.i32);
3964
+ cg.local.tee(local);
3965
+ cg.i32.const(0);
3966
+ cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
3967
+ cg.select();
3968
+ cg.i32.const(byteWidth(dtype));
3969
+ cg.i32.mul();
3970
+ cg.local.get(gid);
3971
+ cg.i32.add();
3972
+ if (isInt) {
3973
+ cg.i32.load(2);
3974
+ cg.i32x4.splat();
3975
+ } else {
3976
+ cg.f32.load(2);
3977
+ cg.f32x4.splat();
3978
+ }
3979
+ } else {
3980
+ const steppingLocal = ctx["gidx"];
3981
+ const origValue = cg.local.declare(cg.i32);
3982
+ cg.local.get(steppingLocal);
3983
+ cg.local.set(origValue);
3984
+ if (isInt) {
3985
+ cg.i32.const(0);
3986
+ cg.i32x4.splat();
3987
+ } else {
3988
+ cg.f32.const(0);
3989
+ cg.f32x4.splat();
3990
+ }
3991
+ const vec = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
3992
+ cg.local.set(vec);
3993
+ const idx = cg.local.declare(cg.i32);
3994
+ const scalarVal = cg.local.declare(isInt ? cg.i32 : cg.f32);
3995
+ for (let lane = 0; lane < simdLanes; lane++) {
3996
+ cg.local.get(origValue);
3997
+ if (lane > 0) {
3998
+ cg.i32.const(lane);
3999
+ cg.i32.add();
4000
+ }
4001
+ cg.local.set(steppingLocal);
4002
+ translateExp(cg, funcs, indexSubtree, ctx);
4003
+ cg.local.tee(idx);
4004
+ cg.i32.const(0);
4005
+ cg.local.get(idx), cg.i32.const(len), cg.i32.lt_u();
4006
+ cg.select();
4007
+ cg.i32.const(byteWidth(dtype));
4008
+ cg.i32.mul();
4009
+ cg.local.get(gid);
4010
+ cg.i32.add();
4011
+ if (isInt) cg.i32.load(2);
4012
+ else cg.f32.load(2);
4013
+ cg.local.set(scalarVal);
4014
+ cg.local.get(vec);
4015
+ cg.local.get(scalarVal);
4016
+ if (isInt) cg.i32x4.replace_lane(lane);
4017
+ else cg.f32x4.replace_lane(lane);
4018
+ cg.local.set(vec);
4019
+ }
4020
+ cg.local.get(origValue);
4021
+ cg.local.set(steppingLocal);
4022
+ cg.local.get(vec);
4023
+ }
4024
+ } else throw new Error(`translateExpSimd: unsupported op ${op}`);
4025
+ if ((references.get(exp$1) ?? 0) > 1) {
4026
+ const local = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
4027
+ cg.local.tee(local);
4028
+ expContext.set(exp$1, local);
4029
+ }
4030
+ };
4031
+ countReferences(exp);
4032
+ gen(exp);
4033
+ }
4034
+ function dty(cg, op, dtype) {
4035
+ switch (dtype) {
4036
+ case DType.Float32: return cg.f32;
4037
+ case DType.Float64: return cg.f64;
4038
+ case DType.Int32:
4039
+ case DType.Uint32:
4040
+ case DType.Bool: return cg.i32;
4041
+ default: throw new UnsupportedOpError(op, dtype, "wasm");
3654
4042
  }
3655
- _push(type) {
3656
- if (!type) throw new Error(`pushing type ${type}`);
3657
- this.#typeStack.push(type);
4043
+ }
4044
+ function dtyF(cg, op, dtype) {
4045
+ switch (dtype) {
4046
+ case DType.Float32: return cg.f32;
4047
+ case DType.Float64: return cg.f64;
4048
+ default: throw new UnsupportedOpError(op, dtype, "wasm");
3658
4049
  }
3659
- _pop() {
3660
- assert(this.#typeStack.length > 0, "popping empty stack");
3661
- return this.#typeStack.pop();
4050
+ }
4051
+ /** Subset of operations supported in SIMD compilation mode. */
4052
+ const simdSupportedOps = /* @__PURE__ */ new Map();
4053
+ simdSupportedOps.set(DType.Float32, new Set([
4054
+ AluOp.Add,
4055
+ AluOp.Sub,
4056
+ AluOp.Mul,
4057
+ AluOp.Floor,
4058
+ AluOp.Ceil,
4059
+ AluOp.Min,
4060
+ AluOp.Max,
4061
+ AluOp.Sqrt,
4062
+ AluOp.Cast,
4063
+ AluOp.Where,
4064
+ AluOp.Const,
4065
+ AluOp.GlobalIndex
4066
+ ]));
4067
+ simdSupportedOps.set(DType.Int32, new Set([
4068
+ AluOp.Add,
4069
+ AluOp.Sub,
4070
+ AluOp.Mul,
4071
+ AluOp.Min,
4072
+ AluOp.Max,
4073
+ AluOp.Cast,
4074
+ AluOp.Where,
4075
+ AluOp.Const,
4076
+ AluOp.GlobalIndex
4077
+ ]));
4078
+ simdSupportedOps.set(DType.Uint32, simdSupportedOps.get(DType.Int32));
4079
+ simdSupportedOps.set(DType.Bool, new Set([
4080
+ AluOp.Add,
4081
+ AluOp.Mul,
4082
+ AluOp.Min,
4083
+ AluOp.Max,
4084
+ AluOp.Cmplt,
4085
+ AluOp.Cmpne,
4086
+ AluOp.Const,
4087
+ AluOp.GlobalIndex
4088
+ ]));
4089
+
4090
+ //#endregion
4091
+ //#region src/backend/wasm/wasmblr.ts
4092
+ /**
4093
+ * @file Minimalist WebAssembly assembler. This allows you to emit WebAssembly
4094
+ * bytecode directly from the browser.
4095
+ *
4096
+ * Self-contained port of https://github.com/bwasti/wasmblr to TypeScript.
4097
+ * Some operation names in this module are written in `snake_case` to match
4098
+ * their names in the Wasm specification.
4099
+ *
4100
+ * Reference: https://pengowray.github.io/wasm-ops/.
4101
+ */
4102
+ const magicModuleHeader = [
4103
+ 0,
4104
+ 97,
4105
+ 115,
4106
+ 109
4107
+ ];
4108
+ const moduleVersion = [
4109
+ 1,
4110
+ 0,
4111
+ 0,
4112
+ 0
4113
+ ];
4114
+ function assert(condition, message) {
4115
+ if (!condition) throw new Error(message || "Assertion failed");
4116
+ }
4117
+ function encodeSigned(n) {
4118
+ const out = [];
4119
+ let more = true;
4120
+ while (more) {
4121
+ let byte = n & 127;
4122
+ n >>= 7;
4123
+ if (n === 0 && (byte & 64) === 0 || n === -1 && (byte & 64) !== 0) more = false;
4124
+ else byte |= 128;
4125
+ out.push(byte);
3662
4126
  }
3663
- _emit(bytes) {
3664
- if (typeof bytes === "number") this.#curBytes.push(bytes);
3665
- else this.#curBytes.push(...bytes);
4127
+ return out;
4128
+ }
4129
+ function encodeUnsigned(n) {
4130
+ const out = [];
4131
+ do {
4132
+ let byte = n & 127;
4133
+ n = n >>> 7;
4134
+ if (n !== 0) byte |= 128;
4135
+ out.push(byte);
4136
+ } while (n !== 0);
4137
+ return out;
4138
+ }
4139
+ function encodeString(s) {
4140
+ const bytes = new TextEncoder().encode(s);
4141
+ return [bytes.length, ...bytes];
4142
+ }
4143
+ function encodeBlocktype(type) {
4144
+ assert(type.length > 0, "blocktype must have at least one type");
4145
+ if (type.length === 1) return [type[0].typeId];
4146
+ return [
4147
+ 96,
4148
+ ...encodeUnsigned(0),
4149
+ ...encodeUnsigned(type.length),
4150
+ ...type.map((t) => t.typeId)
4151
+ ];
4152
+ }
4153
+ function encodeOpcode(opcode) {
4154
+ if (typeof opcode === "number") return [opcode];
4155
+ return [opcode[0], ...encodeUnsigned(opcode[1])];
4156
+ }
4157
+ function appendLengthEncodedBlock(out, inp) {
4158
+ out.push(...encodeUnsigned(inp.length));
4159
+ for (const b of inp) out.push(b);
4160
+ }
4161
+ var Function_ = class {
4162
+ inputTypes;
4163
+ outputTypes;
4164
+ body;
4165
+ locals = [];
4166
+ constructor(inputTypes, outputTypes, body) {
4167
+ this.inputTypes = inputTypes;
4168
+ this.outputTypes = outputTypes;
4169
+ this.body = body || (() => {});
3666
4170
  }
3667
- finish() {
3668
- this.#curBytes = [];
3669
- const emittedBytes = [];
3670
- concat(emittedBytes, magicModuleHeader);
3671
- concat(emittedBytes, moduleVersion);
3672
- const typeSectionBytes = [];
3673
- const totalFunctionTypes = this.#importedFunctions.length + this.#functions.length;
3674
- concat(typeSectionBytes, encodeUnsigned(totalFunctionTypes));
3675
- for (const f of [...this.#importedFunctions, ...this.#functions]) {
3676
- typeSectionBytes.push(96);
3677
- concat(typeSectionBytes, encodeUnsigned(f.inputTypes.length));
3678
- for (const t of f.inputTypes) typeSectionBytes.push(t.typeId);
3679
- concat(typeSectionBytes, encodeUnsigned(f.outputTypes.length));
3680
- for (const t of f.outputTypes) typeSectionBytes.push(t.typeId);
3681
- }
3682
- emittedBytes.push(1);
3683
- concat(emittedBytes, encodeUnsigned(typeSectionBytes.length));
3684
- concat(emittedBytes, typeSectionBytes);
3685
- const importSectionBytes = [];
3686
- const numImports = this.#importedFunctions.length + (this.memory.isImport ? 1 : 0);
3687
- if (numImports > 0) {
3688
- concat(importSectionBytes, encodeUnsigned(numImports));
3689
- for (let i = 0; i < this.#importedFunctions.length; i++) {
3690
- const f = this.#importedFunctions[i];
3691
- concat(importSectionBytes, encodeString(f.module));
3692
- concat(importSectionBytes, encodeString(f.name));
3693
- importSectionBytes.push(0);
3694
- concat(importSectionBytes, encodeUnsigned(i));
3695
- }
3696
- if (this.memory.isImport) {
3697
- concat(importSectionBytes, encodeString(this.memory.aString));
3698
- concat(importSectionBytes, encodeString(this.memory.bString));
3699
- importSectionBytes.push(2);
3700
- if (this.memory.max) {
3701
- if (this.memory.isShared) importSectionBytes.push(3);
3702
- else importSectionBytes.push(1);
3703
- concat(importSectionBytes, encodeUnsigned(this.memory.min));
3704
- concat(importSectionBytes, encodeUnsigned(this.memory.max));
3705
- } else {
3706
- assert(!this.memory.isShared, "shared memory must have a max size");
3707
- importSectionBytes.push(0);
3708
- concat(importSectionBytes, encodeUnsigned(this.memory.min));
3709
- }
3710
- }
3711
- emittedBytes.push(2);
3712
- concat(emittedBytes, encodeUnsigned(importSectionBytes.length));
3713
- concat(emittedBytes, importSectionBytes);
3714
- }
3715
- const functionSectionBytes = [];
3716
- concat(functionSectionBytes, encodeUnsigned(this.#functions.length));
3717
- for (let i = 0; i < this.#functions.length; i++) {
3718
- const typeIndex = this.#importedFunctions.length + i;
3719
- concat(functionSectionBytes, encodeUnsigned(typeIndex));
3720
- }
3721
- emittedBytes.push(3);
3722
- concat(emittedBytes, encodeUnsigned(functionSectionBytes.length));
3723
- concat(emittedBytes, functionSectionBytes);
3724
- const memorySectionBytes = [];
3725
- if (!this.memory.isImport && (this.memory.min || this.memory.max)) {
3726
- memorySectionBytes.push(1);
3727
- if (this.memory.min && this.memory.max) {
3728
- if (this.memory.isShared) memorySectionBytes.push(3);
3729
- else memorySectionBytes.push(1);
3730
- concat(memorySectionBytes, encodeUnsigned(this.memory.min));
3731
- concat(memorySectionBytes, encodeUnsigned(this.memory.max));
3732
- } else {
3733
- assert(!this.memory.isShared, "shared memory must have a max size");
3734
- memorySectionBytes.push(0);
3735
- concat(memorySectionBytes, encodeUnsigned(this.memory.min));
3736
- }
3737
- emittedBytes.push(5);
3738
- concat(emittedBytes, encodeUnsigned(memorySectionBytes.length));
3739
- concat(emittedBytes, memorySectionBytes);
3740
- }
3741
- const exportSectionBytes = [];
3742
- const numExports = this.#exportedFunctions.size + (this.memory.isExport ? 1 : 0);
3743
- concat(exportSectionBytes, encodeUnsigned(numExports));
3744
- if (this.memory.isExport) {
3745
- concat(exportSectionBytes, encodeString(this.memory.aString));
3746
- exportSectionBytes.push(2);
3747
- exportSectionBytes.push(0);
3748
- }
3749
- for (const [key, name] of this.#exportedFunctions.entries()) {
3750
- concat(exportSectionBytes, encodeString(name));
3751
- exportSectionBytes.push(0);
3752
- concat(exportSectionBytes, encodeUnsigned(key));
3753
- }
3754
- emittedBytes.push(7);
3755
- concat(emittedBytes, encodeUnsigned(exportSectionBytes.length));
3756
- concat(emittedBytes, exportSectionBytes);
3757
- const codeSectionBytes = [];
3758
- concat(codeSectionBytes, encodeUnsigned(this.#functions.length));
3759
- for (const f of this.#functions) {
3760
- this.#typeStack = [];
3761
- this.#blockFrames = [{
3762
- idx: 0,
3763
- ty: f.outputTypes
3764
- }];
3765
- this.#curFunction = f;
3766
- this.#curBytes = [];
3767
- f.emit();
3768
- this.end();
3769
- const bodyBytes = [...this.#curBytes];
3770
- this.#curBytes = [];
3771
- concat(this.#curBytes, encodeUnsigned(f.locals.length));
3772
- for (const l of f.locals) {
3773
- this._emit(1);
3774
- this._emit(l.typeId);
3775
- }
3776
- const headerBytes = [...this.#curBytes];
3777
- const fnSize = headerBytes.length + bodyBytes.length;
3778
- concat(codeSectionBytes, encodeUnsigned(fnSize));
3779
- concat(codeSectionBytes, headerBytes);
3780
- concat(codeSectionBytes, bodyBytes);
3781
- }
3782
- this.#curFunction = null;
3783
- emittedBytes.push(10);
3784
- concat(emittedBytes, encodeUnsigned(codeSectionBytes.length));
3785
- concat(emittedBytes, codeSectionBytes);
3786
- return new Uint8Array(emittedBytes);
4171
+ emit() {
4172
+ this.locals = [];
4173
+ this.body();
3787
4174
  }
3788
4175
  };
3789
- var Local = class {
4176
+ var Memory = class {
4177
+ min = 0;
4178
+ max = 0;
4179
+ isShared = false;
4180
+ aString = "";
4181
+ bString = "";
3790
4182
  constructor(cg) {
3791
4183
  this.cg = cg;
3792
4184
  }
3793
- declare(type) {
3794
- return this.cg._declareLocal(type);
4185
+ /** Declare the size of the memory. Each page is 64 KiB. */
4186
+ pages(min, max = 0) {
4187
+ assert(this.min === 0 && this.max === 0);
4188
+ this.min = min;
4189
+ this.max = max;
4190
+ return this;
3795
4191
  }
3796
- get(idx) {
3797
- assert(Number.isInteger(idx), "getting non-integer local");
3798
- const inputTypes = this.cg._inputTypes();
3799
- if (idx < inputTypes.length) this.cg._push(inputTypes[idx]);
3800
- else this.cg._push(this.cg._locals()[idx - inputTypes.length]);
3801
- this.cg._emit(32);
3802
- this.cg._emit(encodeUnsigned(idx));
4192
+ export(a) {
4193
+ assert(!this.isImport && !this.isExport, "already set");
4194
+ this.aString = a;
4195
+ return this;
3803
4196
  }
3804
- set(idx) {
3805
- const t = this.cg._pop();
3806
- const inputTypes = this.cg._inputTypes();
3807
- const expectedType = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
3808
- assert(expectedType.typeId === t.typeId, "can't set local to this value (wrong type)");
3809
- this.cg._emit(33);
3810
- this.cg._emit(encodeUnsigned(idx));
4197
+ shared(isShared) {
4198
+ this.isShared = isShared;
4199
+ return this;
4200
+ }
4201
+ import(a, b) {
4202
+ assert(!this.isImport && !this.isExport, "already set");
4203
+ this.aString = a;
4204
+ this.bString = b;
4205
+ return this;
4206
+ }
4207
+ size() {
4208
+ this.cg._emit(63);
4209
+ this.cg._emit(0);
4210
+ }
4211
+ grow() {
4212
+ this.cg._emit(64);
4213
+ this.cg._emit(0);
4214
+ }
4215
+ get isImport() {
4216
+ return this.aString.length > 0 && this.bString.length > 0;
3811
4217
  }
3812
- tee(idx) {
3813
- const t = this.cg._pop();
3814
- const inputTypes = this.cg._inputTypes();
3815
- const expectedType = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
3816
- assert(expectedType.typeId === t.typeId, "can't tee local to this value (wrong type)");
3817
- this.cg._emit(34);
3818
- this.cg._emit(encodeUnsigned(idx));
3819
- this.cg._push(expectedType);
4218
+ get isExport() {
4219
+ return this.aString.length > 0 && this.bString.length === 0;
3820
4220
  }
3821
4221
  };
3822
- function UNARY_OP(op, opcode, inType, outType) {
3823
- return function() {
3824
- const t = this.cg._pop();
3825
- assert(t.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inType} -> ${outType})`);
3826
- this.cg._emit(encodeOpcode(opcode));
3827
- this.cg._push(this.cg[outType]);
3828
- };
3829
- }
3830
- function BINARY_OP(op, opcode, typeA, typeB, outType) {
3831
- return function() {
3832
- const b = this.cg._pop();
3833
- const a = this.cg._pop();
3834
- assert(a.typeId === this.cg[typeA].typeId && b.typeId === this.cg[typeB].typeId, `invalid type for ${op} (${typeA}, ${typeB} -> ${outType})`);
3835
- this.cg._emit(encodeOpcode(opcode));
3836
- this.cg._push(this.cg[outType]);
3837
- };
3838
- }
3839
- function LOAD_OP(op, opcode, outType) {
3840
- return function(align = 0, offset = 0) {
3841
- const idxType = this.cg._pop();
3842
- assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
3843
- this.cg._emit(encodeOpcode(opcode));
3844
- this.cg._emit(encodeUnsigned(align));
3845
- this.cg._emit(encodeUnsigned(offset));
3846
- this.cg._push(this.cg[outType]);
3847
- };
3848
- }
3849
- function STORE_OP(op, opcode, inType) {
3850
- return function(align = 0, offset = 0) {
3851
- const valType = this.cg._pop();
3852
- const idxType = this.cg._pop();
3853
- assert(valType.typeId === this.cg[inType].typeId, `invalid value type for ${op} (${inType})`);
3854
- assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
3855
- this.cg._emit(encodeOpcode(opcode));
3856
- this.cg._emit(encodeUnsigned(align));
3857
- this.cg._emit(encodeUnsigned(offset));
4222
+ /** Public API of WebAssembly assembler. */
4223
+ var CodeGenerator = class {
4224
+ local;
4225
+ i32;
4226
+ f32;
4227
+ f64;
4228
+ v128;
4229
+ i32x4;
4230
+ f32x4;
4231
+ memory;
4232
+ void = {
4233
+ typeId: 64,
4234
+ name: "void"
3858
4235
  };
3859
- }
3860
- var I32 = class {
3861
- constructor(cg) {
3862
- this.cg = cg;
4236
+ #functions = [];
4237
+ #importedFunctions = [];
4238
+ #exportedFunctions = /* @__PURE__ */ new Map();
4239
+ #curFunction = null;
4240
+ #curBytes = [];
4241
+ #typeStack = [];
4242
+ #blockFrames = [];
4243
+ constructor() {
4244
+ this.local = new Local(this);
4245
+ this.i32 = new I32(this);
4246
+ this.f32 = new F32(this);
4247
+ this.f64 = new F64(this);
4248
+ this.v128 = new V128(this);
4249
+ this.i32x4 = new I32x4(this);
4250
+ this.f32x4 = new F32x4(this);
4251
+ this.memory = new Memory(this);
3863
4252
  }
3864
- get typeId() {
3865
- return 127;
4253
+ unreachable() {
4254
+ this._emit(0);
3866
4255
  }
3867
- get name() {
3868
- return "i32";
4256
+ nop() {
4257
+ this._emit(1);
3869
4258
  }
3870
- const(i) {
3871
- this.cg._emit(65);
3872
- this.cg._emit(encodeSigned(i));
3873
- this.cg._push(this);
4259
+ block(...type) {
4260
+ this.#blockFrames.push({
4261
+ idx: this.#typeStack.length,
4262
+ ty: type
4263
+ });
4264
+ this._emit(2);
4265
+ this._emit(encodeBlocktype(type));
3874
4266
  }
3875
- clz = UNARY_OP("clz", 103, "i32", "i32");
3876
- ctz = UNARY_OP("ctz", 104, "i32", "i32");
3877
- popcnt = UNARY_OP("popcnt", 105, "i32", "i32");
3878
- lt_s = BINARY_OP("lt_s", 72, "i32", "i32", "i32");
3879
- lt_u = BINARY_OP("lt_u", 73, "i32", "i32", "i32");
3880
- gt_s = BINARY_OP("gt_s", 74, "i32", "i32", "i32");
3881
- gt_u = BINARY_OP("gt_u", 75, "i32", "i32", "i32");
3882
- le_s = BINARY_OP("le_s", 76, "i32", "i32", "i32");
3883
- le_u = BINARY_OP("le_u", 77, "i32", "i32", "i32");
3884
- ge_s = BINARY_OP("ge_s", 78, "i32", "i32", "i32");
3885
- ge_u = BINARY_OP("ge_u", 79, "i32", "i32", "i32");
3886
- add = BINARY_OP("add", 106, "i32", "i32", "i32");
3887
- sub = BINARY_OP("sub", 107, "i32", "i32", "i32");
3888
- mul = BINARY_OP("mul", 108, "i32", "i32", "i32");
3889
- div_s = BINARY_OP("div_s", 109, "i32", "i32", "i32");
3890
- div_u = BINARY_OP("div_u", 110, "i32", "i32", "i32");
3891
- rem_s = BINARY_OP("rem_s", 111, "i32", "i32", "i32");
3892
- rem_u = BINARY_OP("rem_u", 112, "i32", "i32", "i32");
3893
- and = BINARY_OP("and", 113, "i32", "i32", "i32");
3894
- or = BINARY_OP("or", 114, "i32", "i32", "i32");
3895
- xor = BINARY_OP("xor", 115, "i32", "i32", "i32");
3896
- shl = BINARY_OP("shl", 116, "i32", "i32", "i32");
3897
- shr_s = BINARY_OP("shr_s", 117, "i32", "i32", "i32");
3898
- shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
3899
- rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
3900
- rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
3901
- eqz = UNARY_OP("eqz", 69, "i32", "i32");
3902
- eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
3903
- ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
3904
- trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
3905
- trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
3906
- trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
3907
- trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
3908
- load = LOAD_OP("load", 40, "i32");
3909
- load8_s = LOAD_OP("load8_s", 44, "i32");
3910
- load8_u = LOAD_OP("load8_u", 45, "i32");
3911
- load16_s = LOAD_OP("load16_s", 46, "i32");
3912
- load16_u = LOAD_OP("load16_u", 47, "i32");
3913
- store = STORE_OP("store", 54, "i32");
3914
- store8 = STORE_OP("store8", 58, "i32");
3915
- store16 = STORE_OP("store16", 59, "i32");
3916
- reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
3917
- trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
3918
- trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
3919
- trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
3920
- trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
3921
- };
3922
- var F32 = class {
3923
- constructor(cg) {
3924
- this.cg = cg;
4267
+ loop(...type) {
4268
+ this.#blockFrames.push({
4269
+ idx: this.#typeStack.length,
4270
+ ty: type
4271
+ });
4272
+ this._emit(3);
4273
+ this._emit(encodeBlocktype(type));
3925
4274
  }
3926
- get typeId() {
3927
- return 125;
4275
+ if(...type) {
4276
+ assert(this._pop().typeId === this.i32.typeId, "if_: expected i32");
4277
+ this.#blockFrames.push({
4278
+ idx: this.#typeStack.length,
4279
+ ty: type
4280
+ });
4281
+ this._emit(4);
4282
+ this._emit(encodeBlocktype(type));
3928
4283
  }
3929
- get name() {
3930
- return "f32";
4284
+ else() {
4285
+ assert(this.#blockFrames.length > 0, "else: no block to else");
4286
+ const frame = this.#blockFrames[this.#blockFrames.length - 1];
4287
+ this.#typeStack = this.#typeStack.slice(0, frame.idx);
4288
+ this._emit(5);
3931
4289
  }
3932
- const(f) {
3933
- this.cg._emit(67);
3934
- const buffer = /* @__PURE__ */ new ArrayBuffer(4);
3935
- new DataView(buffer).setFloat32(0, f, true);
3936
- const bytes = new Uint8Array(buffer);
3937
- for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
3938
- this.cg._push(this);
4290
+ /** End a block (`block`, `if`/`else`, `loop`, or function). */
4291
+ end() {
4292
+ const frame = this.#blockFrames.pop();
4293
+ assert(frame !== void 0, "end: no block to end");
4294
+ this.#typeStack = this.#typeStack.slice(0, frame.idx);
4295
+ for (const ty of frame.ty) if (ty.typeId !== this.void.typeId) this._push(ty);
4296
+ this._emit(11);
3939
4297
  }
3940
- load = LOAD_OP("load", 42, "f32");
3941
- store = STORE_OP("store", 56, "f32");
3942
- eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
3943
- ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
3944
- lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
3945
- gt = BINARY_OP("gt", 94, "f32", "f32", "i32");
3946
- le = BINARY_OP("le", 95, "f32", "f32", "i32");
3947
- ge = BINARY_OP("ge", 96, "f32", "f32", "i32");
3948
- abs = UNARY_OP("abs", 139, "f32", "f32");
3949
- neg = UNARY_OP("neg", 140, "f32", "f32");
3950
- ceil = UNARY_OP("ceil", 141, "f32", "f32");
3951
- floor = UNARY_OP("floor", 142, "f32", "f32");
3952
- trunc = UNARY_OP("trunc", 143, "f32", "f32");
3953
- nearest = UNARY_OP("nearest", 144, "f32", "f32");
3954
- sqrt = UNARY_OP("sqrt", 145, "f32", "f32");
3955
- add = BINARY_OP("add", 146, "f32", "f32", "f32");
3956
- sub = BINARY_OP("sub", 147, "f32", "f32", "f32");
3957
- mul = BINARY_OP("mul", 148, "f32", "f32", "f32");
3958
- div = BINARY_OP("div", 149, "f32", "f32", "f32");
3959
- min = BINARY_OP("min", 150, "f32", "f32", "f32");
3960
- max = BINARY_OP("max", 151, "f32", "f32", "f32");
3961
- copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
3962
- convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
3963
- convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
3964
- demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
3965
- reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
3966
- };
3967
- var F64 = class {
3968
- constructor(cg) {
3969
- this.cg = cg;
4298
+ /** Branch to a block a certain depth outward on the stack. */
4299
+ br(depth) {
4300
+ this._emit(12);
4301
+ this._emit(encodeUnsigned(depth));
3970
4302
  }
3971
- get typeId() {
3972
- return 124;
4303
+ /** Conditional branch to a block a certain depth outward on the stack. */
4304
+ br_if(depth) {
4305
+ assert(this._pop().typeId === this.i32.typeId, "br_if: expected i32");
4306
+ this._emit(13);
4307
+ this._emit(encodeUnsigned(depth));
4308
+ }
4309
+ /** Jump table that indexes into a label vector (like switch). */
4310
+ br_table(...depths) {
4311
+ assert(this._pop().typeId === this.i32.typeId, "br_table: expected i32");
4312
+ assert(depths.length > 0, "br_table: expected at least one default depth");
4313
+ this._emit(14);
4314
+ this._emit(encodeUnsigned(depths.length - 1));
4315
+ for (const d of depths) this._emit(encodeUnsigned(d));
4316
+ }
4317
+ /** Return from a function, branching out of the outermost block. */
4318
+ return() {
4319
+ this._emit(15);
4320
+ }
4321
+ /** Call a function with the given ID. */
4322
+ call(fn) {
4323
+ const totalFunctions = this.#importedFunctions.length + this.#functions.length;
4324
+ assert(fn < totalFunctions, "function index does not exist");
4325
+ const func = fn < this.#importedFunctions.length ? this.#importedFunctions[fn] : this.#functions[fn - this.#importedFunctions.length];
4326
+ for (let i = func.inputTypes.length - 1; i >= 0; i--) {
4327
+ const argType = this._pop();
4328
+ assert(argType.typeId === func.inputTypes[i].typeId, `call: argument ${i} type mismatch, expected ${func.inputTypes[i].name} got ${argType.name}`);
4329
+ }
4330
+ for (const outputType of func.outputTypes) this._push(outputType);
4331
+ this._emit(16);
4332
+ this._emit(encodeUnsigned(fn));
4333
+ }
4334
+ /** Throw away an operand on the stack. */
4335
+ drop() {
4336
+ this._pop();
4337
+ this._emit(26);
3973
4338
  }
3974
- get name() {
3975
- return "f64";
4339
+ /** Select one of the first two operands (T, F) based on the third operand (i32)'s value. */
4340
+ select() {
4341
+ assert(this._pop().typeId === this.i32.typeId, "select: expected i32 condition");
4342
+ const [b, a] = [this._pop(), this._pop()];
4343
+ assert(a.typeId === b.typeId, "select: expected same type for both operands");
4344
+ this._push(a);
4345
+ this._emit(27);
3976
4346
  }
3977
- const(f) {
3978
- this.cg._emit(68);
3979
- const buffer = /* @__PURE__ */ new ArrayBuffer(8);
3980
- new DataView(buffer).setFloat64(0, f, true);
3981
- const bytes = new Uint8Array(buffer);
3982
- for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
3983
- this.cg._push(this);
4347
+ /** Import a JavaScript function; returns its index. */
4348
+ importFunction(module$1, name, inputTypes, outputTypes) {
4349
+ if (this.#functions.length > 0) throw new Error("function imports must precede defining functions");
4350
+ const idx = this.#importedFunctions.length;
4351
+ this.#importedFunctions.push({
4352
+ module: module$1,
4353
+ name,
4354
+ inputTypes,
4355
+ outputTypes
4356
+ });
4357
+ return idx;
3984
4358
  }
3985
- load = LOAD_OP("load", 43, "f64");
3986
- store = STORE_OP("store", 57, "f64");
3987
- eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
3988
- ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
3989
- lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
3990
- gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
3991
- le = BINARY_OP("le", 101, "f64", "f64", "i32");
3992
- ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
3993
- abs = UNARY_OP("abs", 153, "f64", "f64");
3994
- neg = UNARY_OP("neg", 154, "f64", "f64");
3995
- ceil = UNARY_OP("ceil", 155, "f64", "f64");
3996
- floor = UNARY_OP("floor", 156, "f64", "f64");
3997
- trunc = UNARY_OP("trunc", 157, "f64", "f64");
3998
- nearest = UNARY_OP("nearest", 158, "f64", "f64");
3999
- sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
4000
- add = BINARY_OP("add", 160, "f64", "f64", "f64");
4001
- sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
4002
- mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
4003
- div = BINARY_OP("div", 163, "f64", "f64", "f64");
4004
- min = BINARY_OP("min", 164, "f64", "f64", "f64");
4005
- max = BINARY_OP("max", 165, "f64", "f64", "f64");
4006
- copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
4007
- convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
4008
- convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
4009
- promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
4010
- };
4011
- function VECTOR_OP(op, vopcode, inTypes, outType) {
4012
- return function() {
4013
- for (const inType of inTypes.toReversed()) {
4014
- const actualType = this.cg._pop();
4015
- assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes.join(", ")} -> ${outType})`);
4016
- }
4017
- this.cg._emit(encodeOpcode([253, vopcode]));
4018
- this.cg._push(this.cg[outType]);
4019
- };
4020
- }
4021
- function VECTOR_OPL(op, vopcode, inTypes, outType) {
4022
- return function(lane) {
4023
- for (const inType of inTypes.toReversed()) {
4024
- const actualType = this.cg._pop();
4025
- assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes} -> ${outType})`);
4026
- }
4027
- this.cg._emit(encodeOpcode([253, vopcode]));
4028
- this.cg._emit(lane);
4029
- this.cg._push(this.cg[outType]);
4030
- };
4031
- }
4032
- function VECTOR_LOAD_OP(op, vopcode) {
4033
- return function(align = 0, offset = 0) {
4034
- const idxType = this.cg._pop();
4035
- assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
4036
- this.cg._emit(encodeOpcode([253, vopcode]));
4037
- this.cg._emit(encodeUnsigned(align));
4038
- this.cg._emit(encodeUnsigned(offset));
4039
- this.cg._push(this.cg.v128);
4040
- };
4041
- }
4042
- var V128 = class {
4043
- constructor(cg) {
4044
- this.cg = cg;
4359
+ /** Export a function. */
4360
+ export(fn, name) {
4361
+ this.#exportedFunctions.set(fn, name);
4045
4362
  }
4046
- get typeId() {
4047
- return 123;
4363
+ /** Declare a new function; returns its index. */
4364
+ function(inputTypes, outputTypes, body) {
4365
+ const idx = this.#importedFunctions.length + this.#functions.length;
4366
+ this.#functions.push(new Function_(inputTypes, outputTypes, body));
4367
+ return idx;
4048
4368
  }
4049
- get name() {
4050
- return "v128";
4369
+ _declareLocal(type) {
4370
+ assert(this.#curFunction !== null, "No current function");
4371
+ const idx = this.#curFunction.locals.length + this.#curFunction.inputTypes.length;
4372
+ this.#curFunction.locals.push(type);
4373
+ return idx;
4051
4374
  }
4052
- load = VECTOR_LOAD_OP("load", 0);
4053
- load32x2_s = VECTOR_LOAD_OP("load32x2_s", 5);
4054
- load32x2_u = VECTOR_LOAD_OP("load32x2_u", 6);
4055
- load32_splat = VECTOR_LOAD_OP("load32_splat", 9);
4056
- load32_zero = VECTOR_LOAD_OP("load32_zero", 92);
4057
- store(align = 0, offset = 0) {
4058
- const valType = this.cg._pop();
4059
- assert(valType.typeId === this.cg.v128.typeId, `invalid type for store`);
4060
- const idxType = this.cg._pop();
4061
- assert(idxType.typeId === this.cg.i32.typeId, `invalid type for store`);
4062
- this.cg._emit(253);
4063
- this.cg._emit(encodeUnsigned(11));
4064
- this.cg._emit(encodeUnsigned(align));
4065
- this.cg._emit(encodeUnsigned(offset));
4375
+ _inputTypes() {
4376
+ assert(this.#curFunction !== null, "No current function");
4377
+ return this.#curFunction.inputTypes;
4066
4378
  }
4067
- not = VECTOR_OP("not", 77, ["v128"], "v128");
4068
- and = VECTOR_OP("and", 78, ["v128", "v128"], "v128");
4069
- andnot = VECTOR_OP("andnot", 79, ["v128", "v128"], "v128");
4070
- or = VECTOR_OP("or", 80, ["v128", "v128"], "v128");
4071
- xor = VECTOR_OP("xor", 81, ["v128", "v128"], "v128");
4072
- bitselect = VECTOR_OP("bitselect", 82, [
4073
- "v128",
4074
- "v128",
4075
- "v128"
4076
- ], "v128");
4077
- any_true = VECTOR_OP("any_true", 83, ["v128"], "i32");
4078
- };
4079
- var I32x4 = class extends V128 {
4080
- splat = VECTOR_OP("splat", 17, ["i32"], "v128");
4081
- extract_lane = VECTOR_OPL("extract_lane", 27, ["v128"], "i32");
4082
- replace_lane = VECTOR_OPL("replace_lane", 28, ["v128", "i32"], "v128");
4083
- eq = VECTOR_OP("eq", 55, ["v128", "v128"], "v128");
4084
- ne = VECTOR_OP("ne", 56, ["v128", "v128"], "v128");
4085
- lt_s = VECTOR_OP("lt_s", 57, ["v128", "v128"], "v128");
4086
- lt_u = VECTOR_OP("lt_u", 58, ["v128", "v128"], "v128");
4087
- gt_s = VECTOR_OP("gt_s", 59, ["v128", "v128"], "v128");
4088
- gt_u = VECTOR_OP("gt_u", 60, ["v128", "v128"], "v128");
4089
- le_s = VECTOR_OP("le_s", 61, ["v128", "v128"], "v128");
4090
- le_u = VECTOR_OP("le_u", 62, ["v128", "v128"], "v128");
4091
- ge_s = VECTOR_OP("ge_s", 63, ["v128", "v128"], "v128");
4092
- ge_u = VECTOR_OP("ge_u", 64, ["v128", "v128"], "v128");
4093
- abs = VECTOR_OP("abs", 160, ["v128"], "v128");
4094
- neg = VECTOR_OP("neg", 161, ["v128"], "v128");
4095
- all_true = VECTOR_OP("all_true", 163, ["v128"], "i32");
4096
- bitmask = VECTOR_OP("bitmask", 164, ["v128"], "i32");
4097
- shl = VECTOR_OP("shl", 171, ["v128", "i32"], "v128");
4098
- shr_s = VECTOR_OP("shr_s", 172, ["v128", "i32"], "v128");
4099
- shr_u = VECTOR_OP("shr_u", 173, ["v128", "i32"], "v128");
4100
- add = VECTOR_OP("add", 174, ["v128", "v128"], "v128");
4101
- sub = VECTOR_OP("sub", 177, ["v128", "v128"], "v128");
4102
- mul = VECTOR_OP("mul", 181, ["v128", "v128"], "v128");
4103
- min_s = VECTOR_OP("min_s", 182, ["v128", "v128"], "v128");
4104
- min_u = VECTOR_OP("min_u", 183, ["v128", "v128"], "v128");
4105
- max_s = VECTOR_OP("max_s", 184, ["v128", "v128"], "v128");
4106
- max_u = VECTOR_OP("max_u", 185, ["v128", "v128"], "v128");
4107
- trunc_sat_f32x4_s = VECTOR_OP("trunc_sat_f32x4_s", 248, ["v128"], "v128");
4108
- trunc_sat_f32x4_u = VECTOR_OP("trunc_sat_f32x4_u", 249, ["v128"], "v128");
4109
- };
4110
- var F32x4 = class extends V128 {
4111
- splat = VECTOR_OP("splat", 19, ["f32"], "v128");
4112
- extract_lane = VECTOR_OPL("extract_lane", 31, ["v128"], "f32");
4113
- replace_lane = VECTOR_OPL("replace_lane", 32, ["v128", "f32"], "v128");
4114
- eq = VECTOR_OP("eq", 65, ["v128", "v128"], "v128");
4115
- ne = VECTOR_OP("ne", 66, ["v128", "v128"], "v128");
4116
- lt = VECTOR_OP("lt", 67, ["v128", "v128"], "v128");
4117
- gt = VECTOR_OP("gt", 68, ["v128", "v128"], "v128");
4118
- le = VECTOR_OP("le", 69, ["v128", "v128"], "v128");
4119
- ge = VECTOR_OP("ge", 70, ["v128", "v128"], "v128");
4120
- ceil = VECTOR_OP("ceil", 103, ["v128"], "v128");
4121
- floor = VECTOR_OP("floor", 104, ["v128"], "v128");
4122
- trunc = VECTOR_OP("trunc", 105, ["v128"], "v128");
4123
- nearest = VECTOR_OP("nearest", 106, ["v128"], "v128");
4124
- abs = VECTOR_OP("abs", 224, ["v128"], "v128");
4125
- neg = VECTOR_OP("neg", 225, ["v128"], "v128");
4126
- sqrt = VECTOR_OP("sqrt", 227, ["v128"], "v128");
4127
- add = VECTOR_OP("add", 228, ["v128", "v128"], "v128");
4128
- sub = VECTOR_OP("sub", 229, ["v128", "v128"], "v128");
4129
- mul = VECTOR_OP("mul", 230, ["v128", "v128"], "v128");
4130
- div = VECTOR_OP("div", 231, ["v128", "v128"], "v128");
4131
- min = VECTOR_OP("min", 232, ["v128", "v128"], "v128");
4132
- max = VECTOR_OP("max", 233, ["v128", "v128"], "v128");
4133
- pmin = VECTOR_OP("pmin", 234, ["v128", "v128"], "v128");
4134
- pmax = VECTOR_OP("pmax", 235, ["v128", "v128"], "v128");
4135
- convert_i32x4_s = VECTOR_OP("convert_i32x4_s", 250, ["v128"], "v128");
4136
- convert_i32x4_u = VECTOR_OP("convert_i32x4_u", 251, ["v128"], "v128");
4137
- };
4138
-
4139
- //#endregion
4140
- //#region src/backend/wasm.ts
4141
- /**
4142
- * SIMD version of translateExp: emits v128 (f32x4 or i32x4) instructions instead of scalar.
4143
- * gidx always steps by 4. strideMap classifies each GlobalIndex as broadcast/contiguous/gather.
4144
- */
4145
- function translateExpSimd(cg, funcs, exp, ctx, strideMap) {
4146
- const references = /* @__PURE__ */ new Map();
4147
- const seen = /* @__PURE__ */ new Set();
4148
- const countReferences = (exp$1) => {
4149
- references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
4150
- if (!seen.has(exp$1)) {
4151
- seen.add(exp$1);
4152
- for (const src of exp$1.src) countReferences(src);
4153
- }
4154
- };
4155
- const expContext = /* @__PURE__ */ new Map();
4156
- const gen = (exp$1) => {
4157
- if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
4158
- const { op, src, arg, dtype } = exp$1;
4159
- const isInt = dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool;
4160
- const isSigned = dtype === DType.Int32;
4161
- if (op === AluOp.Add) {
4162
- gen(src[0]);
4163
- gen(src[1]);
4164
- if (isInt) cg.i32x4.add();
4165
- else cg.f32x4.add();
4166
- } else if (op === AluOp.Sub) {
4167
- gen(src[0]);
4168
- gen(src[1]);
4169
- if (isInt) cg.i32x4.sub();
4170
- else cg.f32x4.sub();
4171
- } else if (op === AluOp.Mul) {
4172
- gen(src[0]);
4173
- gen(src[1]);
4174
- if (isInt) cg.i32x4.mul();
4175
- else cg.f32x4.mul();
4176
- } else if (op === AluOp.Min) {
4177
- gen(src[0]);
4178
- gen(src[1]);
4179
- if (isInt) if (isSigned) cg.i32x4.min_s();
4180
- else cg.i32x4.min_u();
4181
- else cg.f32x4.min();
4182
- } else if (op === AluOp.Max) {
4183
- gen(src[0]);
4184
- gen(src[1]);
4185
- if (isInt) if (isSigned) cg.i32x4.max_s();
4186
- else cg.i32x4.max_u();
4187
- else cg.f32x4.max();
4188
- } else if (op === AluOp.Sqrt) {
4189
- gen(src[0]);
4190
- cg.f32x4.sqrt();
4191
- } else if (op === AluOp.Floor) {
4192
- gen(src[0]);
4193
- cg.f32x4.floor();
4194
- } else if (op === AluOp.Ceil) {
4195
- gen(src[0]);
4196
- cg.f32x4.ceil();
4197
- } else if (op === AluOp.Const) if (isInt) {
4198
- cg.i32.const(arg);
4199
- cg.i32x4.splat();
4200
- } else {
4201
- cg.f32.const(arg);
4202
- cg.f32x4.splat();
4379
+ _locals() {
4380
+ assert(this.#curFunction !== null, "No current function");
4381
+ return this.#curFunction.locals;
4382
+ }
4383
+ _push(type) {
4384
+ if (!type) throw new Error(`pushing type ${type}`);
4385
+ this.#typeStack.push(type);
4386
+ }
4387
+ _pop() {
4388
+ assert(this.#typeStack.length > 0, "popping empty stack");
4389
+ return this.#typeStack.pop();
4390
+ }
4391
+ _emit(bytes) {
4392
+ if (typeof bytes === "number") this.#curBytes.push(bytes);
4393
+ else this.#curBytes.push(...bytes);
4394
+ }
4395
+ finish() {
4396
+ this.#curBytes = [];
4397
+ const emittedBytes = [];
4398
+ emittedBytes.push(...magicModuleHeader);
4399
+ emittedBytes.push(...moduleVersion);
4400
+ const typeSectionBytes = [];
4401
+ const totalFunctionTypes = this.#importedFunctions.length + this.#functions.length;
4402
+ typeSectionBytes.push(...encodeUnsigned(totalFunctionTypes));
4403
+ for (const f of [...this.#importedFunctions, ...this.#functions]) {
4404
+ typeSectionBytes.push(96);
4405
+ typeSectionBytes.push(...encodeUnsigned(f.inputTypes.length));
4406
+ for (const t of f.inputTypes) typeSectionBytes.push(t.typeId);
4407
+ typeSectionBytes.push(...encodeUnsigned(f.outputTypes.length));
4408
+ for (const t of f.outputTypes) typeSectionBytes.push(t.typeId);
4203
4409
  }
4204
- else if (op === AluOp.Cast) {
4205
- gen(src[0]);
4206
- const dtype0 = src[0].dtype;
4207
- const src0IsInt = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
4208
- if (isInt && !src0IsInt) if (isSigned) cg.i32x4.trunc_sat_f32x4_s();
4209
- else cg.i32x4.trunc_sat_f32x4_u();
4210
- else if (!isInt && src0IsInt) if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32x4.convert_i32x4_s();
4211
- else cg.f32x4.convert_i32x4_u();
4212
- } else if (op === AluOp.Cmplt) {
4213
- gen(src[0]);
4214
- gen(src[1]);
4215
- const srcDtype = src[0].dtype;
4216
- if (srcDtype === DType.Float32) cg.f32x4.lt();
4217
- else if (srcDtype === DType.Int32) cg.i32x4.lt_s();
4218
- else if (srcDtype === DType.Uint32) cg.i32x4.lt_u();
4219
- else throw new UnsupportedOpError(op, dtype, "wasm");
4220
- cg.i32.const(1);
4221
- cg.i32x4.splat();
4222
- cg.v128.and();
4223
- } else if (op === AluOp.Cmpne) {
4224
- gen(src[0]);
4225
- gen(src[1]);
4226
- const srcDtype = src[0].dtype;
4227
- if (srcDtype === DType.Float32) cg.f32x4.ne();
4228
- else cg.i32x4.ne();
4229
- cg.i32.const(1);
4230
- cg.i32x4.splat();
4231
- cg.v128.and();
4232
- } else if (op === AluOp.Where) {
4233
- gen(src[1]);
4234
- gen(src[2]);
4235
- gen(src[0]);
4236
- cg.i32.const(0);
4237
- cg.i32x4.splat();
4238
- cg.i32x4.ne();
4239
- cg.v128.bitselect();
4240
- } else if (op === AluOp.Variable || op === AluOp.Special) throw new Error(`translateExpSimd: unexpected ${op}(${arg})`);
4241
- else if (op === AluOp.GlobalIndex) {
4242
- const [gid, len] = arg;
4243
- const indexSubtree = src[0];
4244
- const stride = strideMap.get(exp$1) ?? GATHER;
4245
- if (stride.kind === "contiguous") {
4246
- translateExp(cg, funcs, indexSubtree, ctx);
4247
- {
4248
- const maxIdx = Math.max(len - SIMD_LANES, 0);
4249
- const wideIdx = cg.local.declare(cg.i32);
4250
- cg.local.set(wideIdx);
4251
- cg.local.get(wideIdx);
4252
- cg.i32.const(maxIdx);
4253
- cg.local.get(wideIdx);
4254
- cg.i32.const(maxIdx);
4255
- cg.i32.lt_u();
4256
- cg.select();
4257
- }
4258
- cg.i32.const(byteWidth(dtype));
4259
- cg.i32.mul();
4260
- cg.local.get(gid);
4261
- cg.i32.add();
4262
- if (isInt) cg.i32x4.load(4);
4263
- else cg.f32x4.load(4);
4264
- } else if (stride.kind === "broadcast") {
4265
- translateExp(cg, funcs, indexSubtree, ctx);
4266
- const local = cg.local.declare(cg.i32);
4267
- cg.local.tee(local);
4268
- cg.i32.const(0);
4269
- cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
4270
- cg.select();
4271
- cg.i32.const(byteWidth(dtype));
4272
- cg.i32.mul();
4273
- cg.local.get(gid);
4274
- cg.i32.add();
4275
- if (isInt) {
4276
- cg.i32.load(2);
4277
- cg.i32x4.splat();
4410
+ emittedBytes.push(1);
4411
+ appendLengthEncodedBlock(emittedBytes, typeSectionBytes);
4412
+ const importSectionBytes = [];
4413
+ const numImports = this.#importedFunctions.length + (this.memory.isImport ? 1 : 0);
4414
+ if (numImports > 0) {
4415
+ importSectionBytes.push(...encodeUnsigned(numImports));
4416
+ for (let i = 0; i < this.#importedFunctions.length; i++) {
4417
+ const f = this.#importedFunctions[i];
4418
+ importSectionBytes.push(...encodeString(f.module));
4419
+ importSectionBytes.push(...encodeString(f.name));
4420
+ importSectionBytes.push(0);
4421
+ importSectionBytes.push(...encodeUnsigned(i));
4422
+ }
4423
+ if (this.memory.isImport) {
4424
+ importSectionBytes.push(...encodeString(this.memory.aString));
4425
+ importSectionBytes.push(...encodeString(this.memory.bString));
4426
+ importSectionBytes.push(2);
4427
+ if (this.memory.max) {
4428
+ if (this.memory.isShared) importSectionBytes.push(3);
4429
+ else importSectionBytes.push(1);
4430
+ importSectionBytes.push(...encodeUnsigned(this.memory.min));
4431
+ importSectionBytes.push(...encodeUnsigned(this.memory.max));
4278
4432
  } else {
4279
- cg.f32.load(2);
4280
- cg.f32x4.splat();
4433
+ assert(!this.memory.isShared, "shared memory must have a max size");
4434
+ importSectionBytes.push(0);
4435
+ importSectionBytes.push(...encodeUnsigned(this.memory.min));
4281
4436
  }
4437
+ }
4438
+ emittedBytes.push(2);
4439
+ appendLengthEncodedBlock(emittedBytes, importSectionBytes);
4440
+ }
4441
+ const functionSectionBytes = [];
4442
+ functionSectionBytes.push(...encodeUnsigned(this.#functions.length));
4443
+ for (let i = 0; i < this.#functions.length; i++) {
4444
+ const typeIndex = this.#importedFunctions.length + i;
4445
+ functionSectionBytes.push(...encodeUnsigned(typeIndex));
4446
+ }
4447
+ emittedBytes.push(3);
4448
+ appendLengthEncodedBlock(emittedBytes, functionSectionBytes);
4449
+ const memorySectionBytes = [];
4450
+ if (!this.memory.isImport && (this.memory.min || this.memory.max)) {
4451
+ memorySectionBytes.push(1);
4452
+ if (this.memory.min && this.memory.max) {
4453
+ if (this.memory.isShared) memorySectionBytes.push(3);
4454
+ else memorySectionBytes.push(1);
4455
+ memorySectionBytes.push(...encodeUnsigned(this.memory.min));
4456
+ memorySectionBytes.push(...encodeUnsigned(this.memory.max));
4282
4457
  } else {
4283
- const steppingLocal = ctx["gidx"];
4284
- const origValue = cg.local.declare(cg.i32);
4285
- cg.local.get(steppingLocal);
4286
- cg.local.set(origValue);
4287
- if (isInt) {
4288
- cg.i32.const(0);
4289
- cg.i32x4.splat();
4290
- } else {
4291
- cg.f32.const(0);
4292
- cg.f32x4.splat();
4293
- }
4294
- const vec = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
4295
- cg.local.set(vec);
4296
- const idx = cg.local.declare(cg.i32);
4297
- const scalarVal = cg.local.declare(isInt ? cg.i32 : cg.f32);
4298
- for (let lane = 0; lane < SIMD_LANES; lane++) {
4299
- cg.local.get(origValue);
4300
- if (lane > 0) {
4301
- cg.i32.const(lane);
4302
- cg.i32.add();
4303
- }
4304
- cg.local.set(steppingLocal);
4305
- translateExp(cg, funcs, indexSubtree, ctx);
4306
- cg.local.tee(idx);
4307
- cg.i32.const(0);
4308
- cg.local.get(idx), cg.i32.const(len), cg.i32.lt_u();
4309
- cg.select();
4310
- cg.i32.const(byteWidth(dtype));
4311
- cg.i32.mul();
4312
- cg.local.get(gid);
4313
- cg.i32.add();
4314
- if (isInt) cg.i32.load(2);
4315
- else cg.f32.load(2);
4316
- cg.local.set(scalarVal);
4317
- cg.local.get(vec);
4318
- cg.local.get(scalarVal);
4319
- if (isInt) cg.i32x4.replace_lane(lane);
4320
- else cg.f32x4.replace_lane(lane);
4321
- cg.local.set(vec);
4322
- }
4323
- cg.local.get(origValue);
4324
- cg.local.set(steppingLocal);
4325
- cg.local.get(vec);
4458
+ assert(!this.memory.isShared, "shared memory must have a max size");
4459
+ memorySectionBytes.push(0);
4460
+ memorySectionBytes.push(...encodeUnsigned(this.memory.min));
4326
4461
  }
4327
- } else throw new Error(`translateExpSimd: unsupported op ${op}`);
4328
- if ((references.get(exp$1) ?? 0) > 1) {
4329
- const local = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
4330
- cg.local.tee(local);
4331
- expContext.set(exp$1, local);
4462
+ emittedBytes.push(5);
4463
+ appendLengthEncodedBlock(emittedBytes, memorySectionBytes);
4332
4464
  }
4333
- };
4334
- countReferences(exp);
4335
- gen(exp);
4336
- }
4337
- /** Number of SIMD lanes (f32x4 / i32x4 = 4 lanes). */
4338
- const SIMD_LANES = 4;
4339
- function referencesGidx(exp) {
4340
- if (exp.op === AluOp.Special && exp.arg[0] === "gidx") return true;
4341
- return exp.src.some(referencesGidx);
4342
- }
4343
- /** When tileSize > N but doesn't divide evenly, the last group before the
4344
- * inner reset is shorter than N — a SIMD group could straddle it. */
4345
- function hasFragmentRisk(tileSize, N) {
4346
- return isFinite(tileSize) && tileSize > N && tileSize % N !== 0;
4347
- }
4348
- const GATHER = { kind: "gather" };
4349
- /**
4350
- * Classify how a GlobalIndex's index expression behaves as gidx increments.
4351
- */
4352
- function analyzeStride(exp) {
4353
- if (!referencesGidx(exp)) return {
4354
- kind: "broadcast",
4355
- tileSize: Infinity
4356
- };
4357
- if (exp.op === AluOp.Special && exp.arg[0] === "gidx") return {
4358
- kind: "contiguous",
4359
- tileSize: Infinity
4360
- };
4361
- if (exp.op === AluOp.Idiv && exp.src[1].op === AluOp.Const) {
4362
- const N = exp.src[1].arg;
4363
- const inner = analyzeStride(exp.src[0]);
4364
- if (inner.kind === "broadcast") return inner;
4365
- if (inner.kind !== "contiguous") return GATHER;
4366
- if (hasFragmentRisk(inner.tileSize, N)) return GATHER;
4367
- return {
4368
- kind: "broadcast",
4369
- tileSize: Math.min(inner.tileSize, N)
4370
- };
4465
+ const exportSectionBytes = [];
4466
+ const numExports = this.#exportedFunctions.size + (this.memory.isExport ? 1 : 0);
4467
+ exportSectionBytes.push(...encodeUnsigned(numExports));
4468
+ if (this.memory.isExport) {
4469
+ exportSectionBytes.push(...encodeString(this.memory.aString));
4470
+ exportSectionBytes.push(2);
4471
+ exportSectionBytes.push(0);
4472
+ }
4473
+ for (const [key, name] of this.#exportedFunctions.entries()) {
4474
+ exportSectionBytes.push(...encodeString(name));
4475
+ exportSectionBytes.push(0);
4476
+ exportSectionBytes.push(...encodeUnsigned(key));
4477
+ }
4478
+ emittedBytes.push(7);
4479
+ appendLengthEncodedBlock(emittedBytes, exportSectionBytes);
4480
+ const codeSectionBytes = [];
4481
+ codeSectionBytes.push(...encodeUnsigned(this.#functions.length));
4482
+ for (const f of this.#functions) {
4483
+ this.#typeStack = [];
4484
+ this.#blockFrames = [{
4485
+ idx: 0,
4486
+ ty: f.outputTypes
4487
+ }];
4488
+ this.#curFunction = f;
4489
+ this.#curBytes = [];
4490
+ f.emit();
4491
+ this.end();
4492
+ const bodyBytes = this.#curBytes;
4493
+ this.#curBytes = [];
4494
+ this.#curBytes.push(...encodeUnsigned(f.locals.length));
4495
+ for (const l of f.locals) {
4496
+ this._emit(1);
4497
+ this._emit(l.typeId);
4498
+ }
4499
+ const fnBytes = this.#curBytes.concat(bodyBytes);
4500
+ appendLengthEncodedBlock(codeSectionBytes, fnBytes);
4501
+ }
4502
+ this.#curFunction = null;
4503
+ emittedBytes.push(10);
4504
+ appendLengthEncodedBlock(emittedBytes, codeSectionBytes);
4505
+ return new Uint8Array(emittedBytes);
4371
4506
  }
4372
- if (exp.op === AluOp.Mod && exp.src[1].op === AluOp.Const) {
4373
- const N = exp.src[1].arg;
4374
- const inner = analyzeStride(exp.src[0]);
4375
- if (inner.kind === "broadcast") return inner;
4376
- if (inner.kind !== "contiguous") return GATHER;
4377
- if (hasFragmentRisk(inner.tileSize, N)) return GATHER;
4378
- return {
4379
- kind: "contiguous",
4380
- tileSize: Math.min(inner.tileSize, N)
4381
- };
4507
+ };
4508
+ var Local = class {
4509
+ constructor(cg) {
4510
+ this.cg = cg;
4382
4511
  }
4383
- if (exp.op === AluOp.Mul) {
4384
- for (let i = 0; i < 2; i++) if (exp.src[i].op === AluOp.Const) {
4385
- const inner = analyzeStride(exp.src[1 - i]);
4386
- if (inner.kind === "broadcast") return inner;
4387
- return GATHER;
4388
- }
4512
+ declare(type) {
4513
+ return this.cg._declareLocal(type);
4514
+ }
4515
+ get(idx) {
4516
+ assert(Number.isInteger(idx), "getting non-integer local");
4517
+ const inputTypes = this.cg._inputTypes();
4518
+ if (idx < inputTypes.length) this.cg._push(inputTypes[idx]);
4519
+ else this.cg._push(this.cg._locals()[idx - inputTypes.length]);
4520
+ this.cg._emit(32);
4521
+ this.cg._emit(encodeUnsigned(idx));
4522
+ }
4523
+ set(idx) {
4524
+ const t = this.cg._pop();
4525
+ const inputTypes = this.cg._inputTypes();
4526
+ const expectedType = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
4527
+ assert(expectedType.typeId === t.typeId, "can't set local to this value (wrong type)");
4528
+ this.cg._emit(33);
4529
+ this.cg._emit(encodeUnsigned(idx));
4389
4530
  }
4390
- if (exp.op === AluOp.Add) {
4391
- const lhsHasGidx = referencesGidx(exp.src[0]);
4392
- const rhsHasGidx = referencesGidx(exp.src[1]);
4393
- if (lhsHasGidx && !rhsHasGidx) return analyzeStride(exp.src[0]);
4394
- if (!lhsHasGidx && rhsHasGidx) return analyzeStride(exp.src[1]);
4531
+ tee(idx) {
4532
+ const t = this.cg._pop();
4533
+ const inputTypes = this.cg._inputTypes();
4534
+ const expectedType = idx < inputTypes.length ? inputTypes[idx] : this.cg._locals()[idx - inputTypes.length];
4535
+ assert(expectedType.typeId === t.typeId, "can't tee local to this value (wrong type)");
4536
+ this.cg._emit(34);
4537
+ this.cg._emit(encodeUnsigned(idx));
4538
+ this.cg._push(expectedType);
4395
4539
  }
4396
- return GATHER;
4540
+ };
4541
+ function UNARY_OP(op, opcode, inType, outType) {
4542
+ return function() {
4543
+ const t = this.cg._pop();
4544
+ assert(t.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inType} -> ${outType})`);
4545
+ this.cg._emit(encodeOpcode(opcode));
4546
+ this.cg._push(this.cg[outType]);
4547
+ };
4397
4548
  }
4398
- /** Ops that have direct SIMD (f32x4) instruction variants. */
4399
- const simdF32Ops = new Set([
4400
- AluOp.Add,
4401
- AluOp.Sub,
4402
- AluOp.Mul,
4403
- AluOp.Floor,
4404
- AluOp.Ceil,
4405
- AluOp.Min,
4406
- AluOp.Max,
4407
- AluOp.Sqrt,
4408
- AluOp.Cast,
4409
- AluOp.Where,
4410
- AluOp.Const,
4411
- AluOp.GlobalIndex
4412
- ]);
4413
- /** Ops that have direct SIMD (i32x4) instruction variants. */
4414
- const simdI32Ops = new Set([
4415
- AluOp.Add,
4416
- AluOp.Sub,
4417
- AluOp.Mul,
4418
- AluOp.Min,
4419
- AluOp.Max,
4420
- AluOp.Cast,
4421
- AluOp.Where,
4422
- AluOp.Const,
4423
- AluOp.GlobalIndex
4424
- ]);
4425
- /** Ops that produce Bool (i32x4 bitmask) in SIMD. */
4426
- const simdBoolOps = new Set([
4427
- AluOp.Cmplt,
4428
- AluOp.Cmpne,
4429
- AluOp.Const,
4430
- AluOp.GlobalIndex
4431
- ]);
4432
- /**
4433
- * Check if a kernel is eligible for SIMD codegen.
4434
- *
4435
- * A kernel qualifies when:
4436
- * - size >= 4 (need at least 4 elements for a SIMD group)
4437
- * - For reductions: the reduction op has a SIMD variant for its dtype
4438
- * - All nodes have a supported dtype (f32, i32, u32, bool) with SIMD variants
4439
- */
4440
- function isSimdEligible(tunedExp, kernel) {
4441
- if (kernel.size < SIMD_LANES) return false;
4442
- if (kernel.reduction) {
4443
- if (!simdSupportedOpsForDtype(kernel.reduction.dtype)?.has(kernel.reduction.op)) return false;
4444
- }
4445
- const check = (exp, visited) => {
4446
- if (visited.has(exp)) return true;
4447
- visited.add(exp);
4448
- const supportedOps = simdSupportedOpsForDtype(exp.dtype);
4449
- if (!supportedOps || !supportedOps.has(exp.op)) return false;
4450
- if (exp.op === AluOp.GlobalIndex) return true;
4451
- for (const child of exp.src) if (!check(child, visited)) return false;
4452
- return true;
4549
+ function BINARY_OP(op, opcode, typeA, typeB, outType) {
4550
+ return function() {
4551
+ const b = this.cg._pop();
4552
+ const a = this.cg._pop();
4553
+ assert(a.typeId === this.cg[typeA].typeId && b.typeId === this.cg[typeB].typeId, `invalid type for ${op} (${typeA}, ${typeB} -> ${outType})`);
4554
+ this.cg._emit(encodeOpcode(opcode));
4555
+ this.cg._push(this.cg[outType]);
4453
4556
  };
4454
- return check(tunedExp, /* @__PURE__ */ new Set());
4455
4557
  }
4456
- function simdSupportedOpsForDtype(dtype) {
4457
- if (dtype === DType.Float32) return simdF32Ops;
4458
- if (dtype === DType.Int32 || dtype === DType.Uint32) return simdI32Ops;
4459
- if (dtype === DType.Bool) return simdBoolOps;
4460
- return null;
4558
+ function LOAD_OP(op, opcode, outType) {
4559
+ return function(align = 0, offset = 0) {
4560
+ const idxType = this.cg._pop();
4561
+ assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
4562
+ this.cg._emit(encodeOpcode(opcode));
4563
+ this.cg._emit(encodeUnsigned(align));
4564
+ this.cg._emit(encodeUnsigned(offset));
4565
+ this.cg._push(this.cg[outType]);
4566
+ };
4461
4567
  }
4462
- const moduleCache = /* @__PURE__ */ new Map();
4463
- /** Backend that compiles into WebAssembly bytecode for immediate execution. */
4464
- var WasmBackend = class {
4465
- type = "wasm";
4466
- maxArgs = 64;
4467
- #memory;
4468
- #nextSlot;
4469
- #allocator;
4470
- #buffers;
4471
- #workerPool;
4472
- #pendingWork = /* @__PURE__ */ new Map();
4473
- constructor() {
4474
- this.#memory = hasSharedArrayBuffer() ? new WebAssembly.Memory({
4475
- initial: 0,
4476
- maximum: 65536,
4477
- shared: true
4478
- }) : new WebAssembly.Memory({ initial: 0 });
4479
- this.#allocator = new WasmAllocator(this.#memory);
4480
- this.#nextSlot = 1;
4481
- this.#buffers = /* @__PURE__ */ new Map();
4482
- this.#workerPool = createWorkerPool(this.#memory);
4568
+ function STORE_OP(op, opcode, inType) {
4569
+ return function(align = 0, offset = 0) {
4570
+ const valType = this.cg._pop();
4571
+ const idxType = this.cg._pop();
4572
+ assert(valType.typeId === this.cg[inType].typeId, `invalid value type for ${op} (${inType})`);
4573
+ assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
4574
+ this.cg._emit(encodeOpcode(opcode));
4575
+ this.cg._emit(encodeUnsigned(align));
4576
+ this.cg._emit(encodeUnsigned(offset));
4577
+ };
4578
+ }
4579
+ var I32 = class {
4580
+ constructor(cg) {
4581
+ this.cg = cg;
4483
4582
  }
4484
- malloc(size, initialData) {
4485
- const ptr = this.#allocator.malloc(size);
4486
- if (initialData) {
4487
- if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
4488
- new Uint8Array(this.#memory.buffer, ptr, size).set(initialData);
4489
- }
4490
- const slot = this.#nextSlot++;
4491
- this.#buffers.set(slot, {
4492
- ptr,
4493
- size,
4494
- ref: 1
4495
- });
4496
- return slot;
4583
+ get typeId() {
4584
+ return 127;
4497
4585
  }
4498
- incRef(slot) {
4499
- const buffer = this.#buffers.get(slot);
4500
- if (!buffer) throw new SlotError(slot);
4501
- buffer.ref++;
4586
+ get name() {
4587
+ return "i32";
4502
4588
  }
4503
- decRef(slot) {
4504
- const buffer = this.#buffers.get(slot);
4505
- if (!buffer) throw new SlotError(slot);
4506
- buffer.ref--;
4507
- if (buffer.ref === 0) {
4508
- this.#allocator.free(buffer.ptr);
4509
- this.#buffers.delete(slot);
4510
- }
4589
+ const(i) {
4590
+ this.cg._emit(65);
4591
+ this.cg._emit(encodeSigned(i));
4592
+ this.cg._push(this);
4511
4593
  }
4512
- async read(slot, start, count) {
4513
- const epoch = this.#pendingWork.get(slot);
4514
- if (epoch) await this.#workerPool.waitForEpoch(epoch);
4515
- return this.#readData(slot, start, count);
4594
+ clz = UNARY_OP("clz", 103, "i32", "i32");
4595
+ ctz = UNARY_OP("ctz", 104, "i32", "i32");
4596
+ popcnt = UNARY_OP("popcnt", 105, "i32", "i32");
4597
+ lt_s = BINARY_OP("lt_s", 72, "i32", "i32", "i32");
4598
+ lt_u = BINARY_OP("lt_u", 73, "i32", "i32", "i32");
4599
+ gt_s = BINARY_OP("gt_s", 74, "i32", "i32", "i32");
4600
+ gt_u = BINARY_OP("gt_u", 75, "i32", "i32", "i32");
4601
+ le_s = BINARY_OP("le_s", 76, "i32", "i32", "i32");
4602
+ le_u = BINARY_OP("le_u", 77, "i32", "i32", "i32");
4603
+ ge_s = BINARY_OP("ge_s", 78, "i32", "i32", "i32");
4604
+ ge_u = BINARY_OP("ge_u", 79, "i32", "i32", "i32");
4605
+ add = BINARY_OP("add", 106, "i32", "i32", "i32");
4606
+ sub = BINARY_OP("sub", 107, "i32", "i32", "i32");
4607
+ mul = BINARY_OP("mul", 108, "i32", "i32", "i32");
4608
+ div_s = BINARY_OP("div_s", 109, "i32", "i32", "i32");
4609
+ div_u = BINARY_OP("div_u", 110, "i32", "i32", "i32");
4610
+ rem_s = BINARY_OP("rem_s", 111, "i32", "i32", "i32");
4611
+ rem_u = BINARY_OP("rem_u", 112, "i32", "i32", "i32");
4612
+ and = BINARY_OP("and", 113, "i32", "i32", "i32");
4613
+ or = BINARY_OP("or", 114, "i32", "i32", "i32");
4614
+ xor = BINARY_OP("xor", 115, "i32", "i32", "i32");
4615
+ shl = BINARY_OP("shl", 116, "i32", "i32", "i32");
4616
+ shr_s = BINARY_OP("shr_s", 117, "i32", "i32", "i32");
4617
+ shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
4618
+ rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
4619
+ rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
4620
+ eqz = UNARY_OP("eqz", 69, "i32", "i32");
4621
+ eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
4622
+ ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
4623
+ trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
4624
+ trunc_f32_u = UNARY_OP("trunc_f32_u", 169, "f32", "i32");
4625
+ trunc_f64_s = UNARY_OP("trunc_f64_s", 170, "f64", "i32");
4626
+ trunc_f64_u = UNARY_OP("trunc_f64_u", 171, "f64", "i32");
4627
+ load = LOAD_OP("load", 40, "i32");
4628
+ load8_s = LOAD_OP("load8_s", 44, "i32");
4629
+ load8_u = LOAD_OP("load8_u", 45, "i32");
4630
+ load16_s = LOAD_OP("load16_s", 46, "i32");
4631
+ load16_u = LOAD_OP("load16_u", 47, "i32");
4632
+ store = STORE_OP("store", 54, "i32");
4633
+ store8 = STORE_OP("store8", 58, "i32");
4634
+ store16 = STORE_OP("store16", 59, "i32");
4635
+ reinterpret_f32 = UNARY_OP("reinterpret_f32", 188, "f32", "i32");
4636
+ trunc_sat_f32_s = UNARY_OP("trunc_sat_f32_s", [252, 0], "f32", "i32");
4637
+ trunc_sat_f32_u = UNARY_OP("trunc_sat_f32_u", [252, 1], "f32", "i32");
4638
+ trunc_sat_f64_s = UNARY_OP("trunc_sat_f64_s", [252, 2], "f64", "i32");
4639
+ trunc_sat_f64_u = UNARY_OP("trunc_sat_f64_u", [252, 3], "f64", "i32");
4640
+ };
4641
+ var F32 = class {
4642
+ constructor(cg) {
4643
+ this.cg = cg;
4516
4644
  }
4517
- readSync(slot, start, count) {
4518
- const epoch = this.#pendingWork.get(slot);
4519
- if (epoch && this.#workerPool.epoch < epoch) throw new Error("cannot read synchronously from a slot with async work");
4520
- return this.#readData(slot, start, count);
4645
+ get typeId() {
4646
+ return 125;
4521
4647
  }
4522
- #readData(slot, start, count) {
4523
- const buffer = this.#getBuffer(slot);
4524
- if (start === void 0) start = 0;
4525
- if (count === void 0) count = buffer.byteLength - start;
4526
- if (hasSharedArrayBuffer() && buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
4527
- else return buffer.slice(start, start + count);
4648
+ get name() {
4649
+ return "f32";
4528
4650
  }
4529
- async prepareKernel(kernel) {
4530
- const kernelHash = FpHash.hash(kernel);
4531
- const module$1 = await runWithCacheAsync(moduleCache, kernelHash.toString(), () => WebAssembly.compile(codegenWasm(kernel)));
4532
- return new Executable(kernel, {
4533
- module: module$1,
4534
- parallel: this.#workerPool !== null
4535
- });
4651
+ const(f) {
4652
+ this.cg._emit(67);
4653
+ const buffer = /* @__PURE__ */ new ArrayBuffer(4);
4654
+ new DataView(buffer).setFloat32(0, f, true);
4655
+ const bytes = new Uint8Array(buffer);
4656
+ for (let i = 0; i < 4; i++) this.cg._emit(bytes[i]);
4657
+ this.cg._push(this);
4658
+ }
4659
+ load = LOAD_OP("load", 42, "f32");
4660
+ store = STORE_OP("store", 56, "f32");
4661
+ eq = BINARY_OP("eq", 91, "f32", "f32", "i32");
4662
+ ne = BINARY_OP("ne", 92, "f32", "f32", "i32");
4663
+ lt = BINARY_OP("lt", 93, "f32", "f32", "i32");
4664
+ gt = BINARY_OP("gt", 94, "f32", "f32", "i32");
4665
+ le = BINARY_OP("le", 95, "f32", "f32", "i32");
4666
+ ge = BINARY_OP("ge", 96, "f32", "f32", "i32");
4667
+ abs = UNARY_OP("abs", 139, "f32", "f32");
4668
+ neg = UNARY_OP("neg", 140, "f32", "f32");
4669
+ ceil = UNARY_OP("ceil", 141, "f32", "f32");
4670
+ floor = UNARY_OP("floor", 142, "f32", "f32");
4671
+ trunc = UNARY_OP("trunc", 143, "f32", "f32");
4672
+ nearest = UNARY_OP("nearest", 144, "f32", "f32");
4673
+ sqrt = UNARY_OP("sqrt", 145, "f32", "f32");
4674
+ add = BINARY_OP("add", 146, "f32", "f32", "f32");
4675
+ sub = BINARY_OP("sub", 147, "f32", "f32", "f32");
4676
+ mul = BINARY_OP("mul", 148, "f32", "f32", "f32");
4677
+ div = BINARY_OP("div", 149, "f32", "f32", "f32");
4678
+ min = BINARY_OP("min", 150, "f32", "f32", "f32");
4679
+ max = BINARY_OP("max", 151, "f32", "f32", "f32");
4680
+ copysign = BINARY_OP("copysign", 152, "f32", "f32", "f32");
4681
+ convert_i32_s = UNARY_OP("convert_i32_s", 178, "i32", "f32");
4682
+ convert_i32_u = UNARY_OP("convert_i32_u", 179, "i32", "f32");
4683
+ demote_f64 = UNARY_OP("demote_f64", 182, "f64", "f32");
4684
+ reinterpret_i32 = UNARY_OP("reinterpret_i32", 190, "i32", "f32");
4685
+ };
4686
+ var F64 = class {
4687
+ constructor(cg) {
4688
+ this.cg = cg;
4536
4689
  }
4537
- prepareKernelSync(kernel) {
4538
- const kernelHash = FpHash.hash(kernel);
4539
- const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => new WebAssembly.Module(codegenWasm(kernel)));
4540
- return new Executable(kernel, {
4541
- module: module$1,
4542
- parallel: false
4543
- });
4690
+ get typeId() {
4691
+ return 124;
4544
4692
  }
4545
- async prepareRoutine(routine) {
4546
- return this.prepareRoutineSync(routine);
4693
+ get name() {
4694
+ return "f64";
4547
4695
  }
4548
- prepareRoutineSync(routine) {
4549
- return new Executable(routine, {
4550
- module: void 0,
4551
- parallel: false
4552
- });
4696
+ const(f) {
4697
+ this.cg._emit(68);
4698
+ const buffer = /* @__PURE__ */ new ArrayBuffer(8);
4699
+ new DataView(buffer).setFloat64(0, f, true);
4700
+ const bytes = new Uint8Array(buffer);
4701
+ for (let i = 0; i < 8; i++) this.cg._emit(bytes[i]);
4702
+ this.cg._push(this);
4553
4703
  }
4554
- dispatch(exe, inputs, outputs) {
4555
- const tracing = isTracing();
4556
- const start = tracing ? performance.now() : 0;
4557
- if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
4558
- else {
4559
- const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
4560
- if (exe.data.parallel && this.#workerPool) {
4561
- const epoch = this.#workerPool.dispatch(exe.data.module, ptrs, exe.source.size);
4562
- for (const slot of outputs) this.#pendingWork.set(slot, epoch);
4563
- } else {
4564
- if (inputs.some((slot) => {
4565
- const epoch = this.#pendingWork.get(slot);
4566
- return epoch && this.#workerPool.epoch < epoch;
4567
- })) throw new Error("cannot dispatch synchronously with pending async work");
4568
- const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
4569
- const func = instance.exports.kernel;
4570
- func(...ptrs, 0, exe.source.size);
4571
- }
4704
+ load = LOAD_OP("load", 43, "f64");
4705
+ store = STORE_OP("store", 57, "f64");
4706
+ eq = BINARY_OP("eq", 97, "f64", "f64", "i32");
4707
+ ne = BINARY_OP("ne", 98, "f64", "f64", "i32");
4708
+ lt = BINARY_OP("lt", 99, "f64", "f64", "i32");
4709
+ gt = BINARY_OP("gt", 100, "f64", "f64", "i32");
4710
+ le = BINARY_OP("le", 101, "f64", "f64", "i32");
4711
+ ge = BINARY_OP("ge", 102, "f64", "f64", "i32");
4712
+ abs = UNARY_OP("abs", 153, "f64", "f64");
4713
+ neg = UNARY_OP("neg", 154, "f64", "f64");
4714
+ ceil = UNARY_OP("ceil", 155, "f64", "f64");
4715
+ floor = UNARY_OP("floor", 156, "f64", "f64");
4716
+ trunc = UNARY_OP("trunc", 157, "f64", "f64");
4717
+ nearest = UNARY_OP("nearest", 158, "f64", "f64");
4718
+ sqrt = UNARY_OP("sqrt", 159, "f64", "f64");
4719
+ add = BINARY_OP("add", 160, "f64", "f64", "f64");
4720
+ sub = BINARY_OP("sub", 161, "f64", "f64", "f64");
4721
+ mul = BINARY_OP("mul", 162, "f64", "f64", "f64");
4722
+ div = BINARY_OP("div", 163, "f64", "f64", "f64");
4723
+ min = BINARY_OP("min", 164, "f64", "f64", "f64");
4724
+ max = BINARY_OP("max", 165, "f64", "f64", "f64");
4725
+ copysign = BINARY_OP("copysign", 166, "f64", "f64", "f64");
4726
+ convert_i32_s = UNARY_OP("convert_i32_s", 183, "i32", "f64");
4727
+ convert_i32_u = UNARY_OP("convert_i32_u", 184, "i32", "f64");
4728
+ promote_f32 = UNARY_OP("promote_f32", 187, "f32", "f64");
4729
+ };
4730
+ function VECTOR_OP(op, vopcode, inTypes, outType) {
4731
+ return function() {
4732
+ for (const inType of inTypes.toReversed()) {
4733
+ const actualType = this.cg._pop();
4734
+ assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes.join(", ")} -> ${outType})`);
4572
4735
  }
4573
- if (tracing) {
4574
- const info = traceSourceInfo(exe.source);
4575
- emitTrace("wasm", info, start, performance.now());
4736
+ this.cg._emit(encodeOpcode([253, vopcode]));
4737
+ this.cg._push(this.cg[outType]);
4738
+ };
4739
+ }
4740
+ function VECTOR_OPL(op, vopcode, inTypes, outType) {
4741
+ return function(lane) {
4742
+ for (const inType of inTypes.toReversed()) {
4743
+ const actualType = this.cg._pop();
4744
+ assert(actualType.typeId === this.cg[inType].typeId, `invalid type for ${op} (${inTypes} -> ${outType})`);
4576
4745
  }
4746
+ this.cg._emit(encodeOpcode([253, vopcode]));
4747
+ this.cg._emit(lane);
4748
+ this.cg._push(this.cg[outType]);
4749
+ };
4750
+ }
4751
+ function VECTOR_LOAD_OP(op, vopcode) {
4752
+ return function(align = 0, offset = 0) {
4753
+ const idxType = this.cg._pop();
4754
+ assert(idxType.typeId === this.cg.i32.typeId, `invalid type for ${op}`);
4755
+ this.cg._emit(encodeOpcode([253, vopcode]));
4756
+ this.cg._emit(encodeUnsigned(align));
4757
+ this.cg._emit(encodeUnsigned(offset));
4758
+ this.cg._push(this.cg.v128);
4759
+ };
4760
+ }
4761
+ var V128 = class {
4762
+ constructor(cg) {
4763
+ this.cg = cg;
4577
4764
  }
4578
- #getBuffer(slot) {
4579
- const buffer = this.#buffers.get(slot);
4580
- if (!buffer) throw new SlotError(slot);
4581
- return new Uint8Array(this.#memory.buffer, buffer.ptr, buffer.size);
4765
+ get typeId() {
4766
+ return 123;
4767
+ }
4768
+ get name() {
4769
+ return "v128";
4770
+ }
4771
+ load = VECTOR_LOAD_OP("load", 0);
4772
+ load32x2_s = VECTOR_LOAD_OP("load32x2_s", 5);
4773
+ load32x2_u = VECTOR_LOAD_OP("load32x2_u", 6);
4774
+ load32_splat = VECTOR_LOAD_OP("load32_splat", 9);
4775
+ load32_zero = VECTOR_LOAD_OP("load32_zero", 92);
4776
+ store(align = 0, offset = 0) {
4777
+ const valType = this.cg._pop();
4778
+ assert(valType.typeId === this.cg.v128.typeId, `invalid type for store`);
4779
+ const idxType = this.cg._pop();
4780
+ assert(idxType.typeId === this.cg.i32.typeId, `invalid type for store`);
4781
+ this.cg._emit(253);
4782
+ this.cg._emit(encodeUnsigned(11));
4783
+ this.cg._emit(encodeUnsigned(align));
4784
+ this.cg._emit(encodeUnsigned(offset));
4582
4785
  }
4786
+ not = VECTOR_OP("not", 77, ["v128"], "v128");
4787
+ and = VECTOR_OP("and", 78, ["v128", "v128"], "v128");
4788
+ andnot = VECTOR_OP("andnot", 79, ["v128", "v128"], "v128");
4789
+ or = VECTOR_OP("or", 80, ["v128", "v128"], "v128");
4790
+ xor = VECTOR_OP("xor", 81, ["v128", "v128"], "v128");
4791
+ bitselect = VECTOR_OP("bitselect", 82, [
4792
+ "v128",
4793
+ "v128",
4794
+ "v128"
4795
+ ], "v128");
4796
+ any_true = VECTOR_OP("any_true", 83, ["v128"], "i32");
4583
4797
  };
4584
- /** Emit a runtime guard: enter the if-block only when [begin, end) is SIMD-aligned. */
4585
- function emitAlignmentGuard(cg, paramBegin, paramEnd) {
4586
- const mask = SIMD_LANES - 1;
4587
- cg.local.get(paramEnd);
4588
- cg.local.get(paramBegin);
4589
- cg.i32.sub();
4590
- cg.i32.const(mask);
4591
- cg.i32.and();
4592
- cg.i32.eqz();
4593
- cg.local.get(paramBegin);
4594
- cg.i32.const(mask);
4595
- cg.i32.and();
4596
- cg.i32.eqz();
4597
- cg.i32.and();
4598
- cg.if(cg.void);
4599
- }
4600
- function codegenWasm(kernel) {
4601
- const tune = tuneNullopt(kernel);
4602
- const re = kernel.reduction;
4603
- if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
4604
- const useSimd = isSimdEligible(tune.exp, kernel);
4605
- const bufferStrides = /* @__PURE__ */ new Map();
4606
- if (useSimd) tune.exp.collect((e) => e.op === AluOp.GlobalIndex).forEach((gi) => {
4607
- const result = analyzeStride(gi.src[0]);
4608
- if (result.kind !== "gather" && (result.tileSize < SIMD_LANES || isFinite(result.tileSize) && result.tileSize % SIMD_LANES !== 0)) bufferStrides.set(gi, GATHER);
4609
- else bufferStrides.set(gi, result);
4610
- });
4611
- const cg = new CodeGenerator();
4612
- cg.memory.import("env", "memory");
4613
- if (hasSharedArrayBuffer()) cg.memory.pages(0, 65536).shared(true);
4614
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
4615
- const funcs = {};
4616
- if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
4617
- if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
4618
- if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
4619
- if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
4620
- if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
4621
- if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
4622
- if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
4623
- if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
4624
- if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
4625
- const paramBegin = kernel.nargs + 1;
4626
- const paramEnd = kernel.nargs + 2;
4627
- const kernelFunc = cg.function(rep(kernel.nargs + 3, cg.i32), [], () => {
4628
- const gidx = cg.local.declare(cg.i32);
4629
- cg.local.get(paramBegin);
4630
- cg.local.set(gidx);
4631
- if (useSimd) {
4632
- emitAlignmentGuard(cg, paramBegin, paramEnd);
4633
- cg.loop(cg.void);
4634
- if (!re) {
4635
- cg.block(cg.void);
4636
- cg.local.get(gidx);
4637
- cg.local.get(paramEnd);
4638
- cg.i32.ge_u();
4639
- cg.br_if(0);
4640
- cg.local.get(kernel.nargs);
4641
- cg.local.get(gidx);
4642
- cg.i32.const(byteWidth(kernel.dtype));
4643
- cg.i32.mul();
4644
- cg.i32.add();
4645
- translateExpSimd(cg, funcs, tune.exp, { gidx }, bufferStrides);
4646
- cg.v128.store(4);
4647
- cg.local.get(gidx);
4648
- cg.i32.const(SIMD_LANES);
4649
- cg.i32.add();
4650
- cg.local.set(gidx);
4651
- cg.br(1);
4652
- cg.end();
4653
- } else {
4654
- const reIsInt = kernel.exp.dtype === DType.Int32 || kernel.exp.dtype === DType.Uint32;
4655
- cg.block(cg.void);
4656
- cg.local.get(gidx);
4657
- cg.local.get(paramEnd);
4658
- cg.i32.ge_u();
4659
- cg.br_if(0);
4660
- const vecAcc = cg.local.declare(reIsInt ? cg.i32x4 : cg.f32x4);
4661
- if (reIsInt) {
4662
- cg.i32.const(re.identity);
4663
- cg.i32x4.splat();
4664
- } else {
4665
- cg.f32.const(re.identity);
4666
- cg.f32x4.splat();
4667
- }
4668
- cg.local.set(vecAcc);
4669
- const ridx = cg.local.declare(cg.i32);
4670
- cg.i32.const(0);
4671
- cg.local.set(ridx);
4672
- cg.loop(cg.void);
4673
- cg.block(cg.void);
4674
- cg.local.get(ridx);
4675
- cg.i32.const(re.size);
4676
- cg.i32.ge_u();
4677
- cg.br_if(0);
4678
- translateExpSimd(cg, funcs, tune.exp, {
4679
- gidx,
4680
- ridx
4681
- }, bufferStrides);
4682
- cg.local.get(vecAcc);
4683
- if (reIsInt) if (re.op === AluOp.Add) cg.i32x4.add();
4684
- else if (re.op === AluOp.Mul) cg.i32x4.mul();
4685
- else if (re.op === AluOp.Min) if (re.dtype === DType.Int32) cg.i32x4.min_s();
4686
- else cg.i32x4.min_u();
4687
- else if (re.op === AluOp.Max) if (re.dtype === DType.Int32) cg.i32x4.max_s();
4688
- else cg.i32x4.max_u();
4689
- else throw new Error(`invalid SIMD reduction op: ${re.op}`);
4690
- else if (re.op === AluOp.Add) cg.f32x4.add();
4691
- else if (re.op === AluOp.Mul) cg.f32x4.mul();
4692
- else if (re.op === AluOp.Min) cg.f32x4.min();
4693
- else if (re.op === AluOp.Max) cg.f32x4.max();
4694
- else throw new Error(`invalid SIMD reduction op: ${re.op}`);
4695
- cg.local.set(vecAcc);
4696
- cg.local.get(ridx);
4697
- cg.i32.const(1);
4698
- cg.i32.add();
4699
- cg.local.set(ridx);
4700
- cg.br(1);
4701
- cg.end();
4702
- cg.end();
4703
- for (let lane = 0; lane < SIMD_LANES; lane++) {
4704
- cg.local.get(kernel.nargs);
4705
- cg.local.get(gidx);
4706
- if (lane > 0) {
4707
- cg.i32.const(lane);
4708
- cg.i32.add();
4709
- }
4710
- cg.i32.const(byteWidth(kernel.dtype));
4711
- cg.i32.mul();
4712
- cg.i32.add();
4713
- const acc = cg.local.declare(reIsInt ? cg.i32 : cg.f32);
4714
- cg.local.get(vecAcc);
4715
- if (reIsInt) cg.i32x4.extract_lane(lane);
4716
- else cg.f32x4.extract_lane(lane);
4717
- cg.local.set(acc);
4718
- const laneGidx = cg.local.declare(cg.i32);
4719
- cg.local.get(gidx);
4720
- if (lane > 0) {
4721
- cg.i32.const(lane);
4722
- cg.i32.add();
4723
- }
4724
- cg.local.set(laneGidx);
4725
- translateExp(cg, funcs, tune.epilogue, {
4726
- acc,
4727
- gidx: laneGidx
4728
- });
4729
- dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
4730
- }
4731
- cg.local.get(gidx);
4732
- cg.i32.const(SIMD_LANES);
4733
- cg.i32.add();
4734
- cg.local.set(gidx);
4735
- cg.br(1);
4736
- cg.end();
4737
- }
4738
- cg.end();
4739
- cg.return();
4740
- cg.end();
4741
- }
4742
- cg.loop(cg.void);
4743
- cg.block(cg.void);
4744
- cg.local.get(gidx);
4745
- cg.local.get(paramEnd);
4746
- cg.i32.ge_u();
4747
- cg.br_if(0);
4748
- cg.local.get(kernel.nargs);
4749
- cg.local.get(gidx);
4750
- cg.i32.const(byteWidth(kernel.dtype));
4798
+ var I32x4 = class extends V128 {
4799
+ splat = VECTOR_OP("splat", 17, ["i32"], "v128");
4800
+ extract_lane = VECTOR_OPL("extract_lane", 27, ["v128"], "i32");
4801
+ replace_lane = VECTOR_OPL("replace_lane", 28, ["v128", "i32"], "v128");
4802
+ eq = VECTOR_OP("eq", 55, ["v128", "v128"], "v128");
4803
+ ne = VECTOR_OP("ne", 56, ["v128", "v128"], "v128");
4804
+ lt_s = VECTOR_OP("lt_s", 57, ["v128", "v128"], "v128");
4805
+ lt_u = VECTOR_OP("lt_u", 58, ["v128", "v128"], "v128");
4806
+ gt_s = VECTOR_OP("gt_s", 59, ["v128", "v128"], "v128");
4807
+ gt_u = VECTOR_OP("gt_u", 60, ["v128", "v128"], "v128");
4808
+ le_s = VECTOR_OP("le_s", 61, ["v128", "v128"], "v128");
4809
+ le_u = VECTOR_OP("le_u", 62, ["v128", "v128"], "v128");
4810
+ ge_s = VECTOR_OP("ge_s", 63, ["v128", "v128"], "v128");
4811
+ ge_u = VECTOR_OP("ge_u", 64, ["v128", "v128"], "v128");
4812
+ abs = VECTOR_OP("abs", 160, ["v128"], "v128");
4813
+ neg = VECTOR_OP("neg", 161, ["v128"], "v128");
4814
+ all_true = VECTOR_OP("all_true", 163, ["v128"], "i32");
4815
+ bitmask = VECTOR_OP("bitmask", 164, ["v128"], "i32");
4816
+ shl = VECTOR_OP("shl", 171, ["v128", "i32"], "v128");
4817
+ shr_s = VECTOR_OP("shr_s", 172, ["v128", "i32"], "v128");
4818
+ shr_u = VECTOR_OP("shr_u", 173, ["v128", "i32"], "v128");
4819
+ add = VECTOR_OP("add", 174, ["v128", "v128"], "v128");
4820
+ sub = VECTOR_OP("sub", 177, ["v128", "v128"], "v128");
4821
+ mul = VECTOR_OP("mul", 181, ["v128", "v128"], "v128");
4822
+ min_s = VECTOR_OP("min_s", 182, ["v128", "v128"], "v128");
4823
+ min_u = VECTOR_OP("min_u", 183, ["v128", "v128"], "v128");
4824
+ max_s = VECTOR_OP("max_s", 184, ["v128", "v128"], "v128");
4825
+ max_u = VECTOR_OP("max_u", 185, ["v128", "v128"], "v128");
4826
+ trunc_sat_f32x4_s = VECTOR_OP("trunc_sat_f32x4_s", 248, ["v128"], "v128");
4827
+ trunc_sat_f32x4_u = VECTOR_OP("trunc_sat_f32x4_u", 249, ["v128"], "v128");
4828
+ };
4829
+ var F32x4 = class extends V128 {
4830
+ splat = VECTOR_OP("splat", 19, ["f32"], "v128");
4831
+ extract_lane = VECTOR_OPL("extract_lane", 31, ["v128"], "f32");
4832
+ replace_lane = VECTOR_OPL("replace_lane", 32, ["v128", "f32"], "v128");
4833
+ eq = VECTOR_OP("eq", 65, ["v128", "v128"], "v128");
4834
+ ne = VECTOR_OP("ne", 66, ["v128", "v128"], "v128");
4835
+ lt = VECTOR_OP("lt", 67, ["v128", "v128"], "v128");
4836
+ gt = VECTOR_OP("gt", 68, ["v128", "v128"], "v128");
4837
+ le = VECTOR_OP("le", 69, ["v128", "v128"], "v128");
4838
+ ge = VECTOR_OP("ge", 70, ["v128", "v128"], "v128");
4839
+ ceil = VECTOR_OP("ceil", 103, ["v128"], "v128");
4840
+ floor = VECTOR_OP("floor", 104, ["v128"], "v128");
4841
+ trunc = VECTOR_OP("trunc", 105, ["v128"], "v128");
4842
+ nearest = VECTOR_OP("nearest", 106, ["v128"], "v128");
4843
+ abs = VECTOR_OP("abs", 224, ["v128"], "v128");
4844
+ neg = VECTOR_OP("neg", 225, ["v128"], "v128");
4845
+ sqrt = VECTOR_OP("sqrt", 227, ["v128"], "v128");
4846
+ add = VECTOR_OP("add", 228, ["v128", "v128"], "v128");
4847
+ sub = VECTOR_OP("sub", 229, ["v128", "v128"], "v128");
4848
+ mul = VECTOR_OP("mul", 230, ["v128", "v128"], "v128");
4849
+ div = VECTOR_OP("div", 231, ["v128", "v128"], "v128");
4850
+ min = VECTOR_OP("min", 232, ["v128", "v128"], "v128");
4851
+ max = VECTOR_OP("max", 233, ["v128", "v128"], "v128");
4852
+ pmin = VECTOR_OP("pmin", 234, ["v128", "v128"], "v128");
4853
+ pmax = VECTOR_OP("pmax", 235, ["v128", "v128"], "v128");
4854
+ convert_i32x4_s = VECTOR_OP("convert_i32x4_s", 250, ["v128"], "v128");
4855
+ convert_i32x4_u = VECTOR_OP("convert_i32x4_u", 251, ["v128"], "v128");
4856
+ relaxed_madd = VECTOR_OP("relaxed_madd", 261, [
4857
+ "v128",
4858
+ "v128",
4859
+ "v128"
4860
+ ], "v128");
4861
+ relaxed_nmadd = VECTOR_OP("relaxed_nmadd", 262, [
4862
+ "v128",
4863
+ "v128",
4864
+ "v128"
4865
+ ], "v128");
4866
+ };
4867
+
4868
+ //#endregion
4869
+ //#region src/backend/wasm/codegen.ts
4870
+ function isIdentityEpilogue(exp) {
4871
+ return exp?.op === AluOp.Variable && exp.arg === "acc";
4872
+ }
4873
+ function initializeReductionPointer(cg, funcs, candidate, ctx, valueKey, ridxOffset) {
4874
+ const ptr = cg.local.declare(cg.i32);
4875
+ translateExp(cg, funcs, candidate.baseIndex, ctx);
4876
+ cg.i32.const(byteWidth(candidate.dtype));
4877
+ cg.i32.mul();
4878
+ cg.local.get(candidate.gid);
4879
+ cg.i32.add();
4880
+ if (ridxOffset !== void 0 && candidate.strideBytes !== 0) {
4881
+ cg.local.get(ridxOffset);
4882
+ cg.i32.const(candidate.strideBytes);
4751
4883
  cg.i32.mul();
4752
4884
  cg.i32.add();
4753
- if (re) {
4754
- const acc = cg.local.declare(dty(cg, null, kernel.exp.dtype));
4755
- dty(cg, null, kernel.exp.dtype).const(re.identity);
4756
- cg.local.set(acc);
4757
- const ridx = cg.local.declare(cg.i32);
4758
- cg.i32.const(0);
4759
- cg.local.set(ridx);
4760
- cg.loop(cg.void);
4761
- cg.block(cg.void);
4762
- cg.local.get(ridx);
4763
- cg.i32.const(re.size);
4764
- cg.i32.ge_u();
4765
- cg.br_if(0);
4766
- translateExp(cg, funcs, tune.exp, {
4767
- gidx,
4768
- ridx
4769
- });
4770
- if (re.op === AluOp.Add) {
4771
- cg.local.get(acc);
4772
- if (re.dtype === DType.Bool) cg.i32.or();
4773
- else dty(cg, re.op, re.dtype).add();
4774
- } else if (re.op === AluOp.Mul) {
4775
- cg.local.get(acc);
4776
- if (re.dtype === DType.Bool) cg.i32.and();
4777
- else dty(cg, re.op, re.dtype).mul();
4778
- } else if (re.op === AluOp.Min || re.op === AluOp.Max) if (isFloatDtype(re.dtype)) {
4885
+ }
4886
+ cg.local.set(ptr);
4887
+ return {
4888
+ ...candidate,
4889
+ ptr,
4890
+ valueKey
4891
+ };
4892
+ }
4893
+ function incrementReductionPointers(cg, pointers, multiplier = 1) {
4894
+ for (const pointer of pointers) {
4895
+ if (pointer.strideBytes === 0) continue;
4896
+ cg.local.get(pointer.ptr);
4897
+ cg.i32.const(pointer.strideBytes * multiplier);
4898
+ cg.i32.add();
4899
+ cg.local.set(pointer.ptr);
4900
+ }
4901
+ }
4902
+ function emitSimdReductionOp(cg, re, reIsInt, valueAlreadyAccumulated) {
4903
+ if (!reIsInt && valueAlreadyAccumulated) return;
4904
+ switch (re.op) {
4905
+ case AluOp.Add:
4906
+ if (reIsInt) cg.i32x4.add();
4907
+ else cg.f32x4.add();
4908
+ return;
4909
+ case AluOp.Mul:
4910
+ if (reIsInt) cg.i32x4.mul();
4911
+ else cg.f32x4.mul();
4912
+ return;
4913
+ case AluOp.Min:
4914
+ if (reIsInt) if (re.dtype === DType.Int32) cg.i32x4.min_s();
4915
+ else cg.i32x4.min_u();
4916
+ else cg.f32x4.min();
4917
+ return;
4918
+ case AluOp.Max:
4919
+ if (reIsInt) if (re.dtype === DType.Int32) cg.i32x4.max_s();
4920
+ else cg.i32x4.max_u();
4921
+ else cg.f32x4.max();
4922
+ return;
4923
+ default: throw new Error(`invalid SIMD reduction op: ${re.op}`);
4924
+ }
4925
+ }
4926
+ function emitScalarReductionOp(cg, re, acc) {
4927
+ switch (re.op) {
4928
+ case AluOp.Add:
4929
+ cg.local.get(acc);
4930
+ if (re.dtype === DType.Bool) cg.i32.or();
4931
+ else dty(cg, re.op, re.dtype).add();
4932
+ return;
4933
+ case AluOp.Mul:
4934
+ cg.local.get(acc);
4935
+ if (re.dtype === DType.Bool) cg.i32.and();
4936
+ else dty(cg, re.op, re.dtype).mul();
4937
+ return;
4938
+ case AluOp.Min:
4939
+ case AluOp.Max:
4940
+ if (isFloatDtype(re.dtype)) {
4779
4941
  cg.local.get(acc);
4780
4942
  if (re.op === AluOp.Min) dtyF(cg, re.op, re.dtype).min();
4781
4943
  else dtyF(cg, re.op, re.dtype).max();
@@ -4795,227 +4957,619 @@ function codegenWasm(kernel) {
4795
4957
  else cg.i32.gt_u();
4796
4958
  cg.select();
4797
4959
  } else throw new Error(`invalid reduction min/max over ${re.dtype}`);
4798
- else throw new Error(`invalid wasm reduction op: ${re.op}`);
4799
- cg.local.set(acc);
4800
- cg.local.get(ridx);
4801
- cg.i32.const(1);
4802
- cg.i32.add();
4803
- cg.local.set(ridx);
4804
- cg.br(1);
4805
- cg.end();
4806
- cg.end();
4807
- translateExp(cg, funcs, tune.epilogue, {
4808
- acc,
4809
- gidx
4810
- });
4811
- } else translateExp(cg, funcs, tune.exp, { gidx });
4812
- dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
4813
- cg.local.get(gidx);
4814
- cg.i32.const(1);
4815
- cg.i32.add();
4816
- cg.local.set(gidx);
4817
- cg.br(1);
4818
- cg.end();
4819
- cg.end();
4820
- });
4821
- cg.export(kernelFunc, "kernel");
4822
- return cg.finish();
4960
+ return;
4961
+ default: throw new Error(`invalid wasm reduction op: ${re.op}`);
4962
+ }
4823
4963
  }
4824
- function translateExp(cg, funcs, exp, ctx) {
4825
- const references = /* @__PURE__ */ new Map();
4826
- const seen = /* @__PURE__ */ new Set();
4827
- const countReferences = (exp$1) => {
4828
- references.set(exp$1, (references.get(exp$1) ?? 0) + 1);
4829
- if (!seen.has(exp$1)) {
4830
- seen.add(exp$1);
4831
- for (const src of exp$1.src) countReferences(src);
4832
- }
4964
+ /**
4965
+ * Check if a kernel is eligible for SIMD codegen.
4966
+ *
4967
+ * A kernel qualifies when:
4968
+ * - size >= 4 (need at least 4 elements for a SIMD group)
4969
+ * - For reductions: the reduction op has a SIMD variant for its dtype
4970
+ * - All nodes have a supported dtype (f32, i32, u32, bool) with SIMD variants
4971
+ */
4972
+ function isSimdEligible(tunedExp, kernel) {
4973
+ if (kernel.size < simdLanes) return false;
4974
+ if (kernel.reduction) {
4975
+ if (!simdSupportedOps.get(kernel.reduction.dtype)?.has(kernel.reduction.op)) return false;
4976
+ }
4977
+ const check = (exp, visited) => {
4978
+ if (visited.has(exp)) return true;
4979
+ visited.add(exp);
4980
+ if (!simdSupportedOps.get(exp.dtype)?.has(exp.op)) return false;
4981
+ if (exp.op === AluOp.GlobalIndex) return true;
4982
+ for (const child of exp.src) if (!check(child, visited)) return false;
4983
+ return true;
4833
4984
  };
4834
- const expContext = /* @__PURE__ */ new Map();
4835
- const gen = (exp$1) => {
4836
- if (expContext.has(exp$1)) return cg.local.get(expContext.get(exp$1));
4837
- const { op, src, dtype, arg } = exp$1;
4838
- if (AluGroup.Binary.has(op) || AluGroup.Compare.has(op)) {
4839
- gen(src[0]);
4840
- gen(src[1]);
4841
- if (op === AluOp.Add) if (dtype === DType.Bool) cg.i32.or();
4842
- else dty(cg, op, dtype).add();
4843
- else if (op === AluOp.Sub) dty(cg, op, dtype).sub();
4844
- else if (op === AluOp.Mul) if (dtype === DType.Bool) cg.i32.and();
4845
- else dty(cg, op, dtype).mul();
4846
- else if (op === AluOp.Idiv) if (isFloatDtype(dtype)) {
4847
- dtyF(cg, op, dtype).div();
4848
- dtyF(cg, op, dtype).trunc();
4849
- } else if (dtype === DType.Uint32) cg.i32.div_u();
4850
- else if (dtype === DType.Int32) cg.i32.div_s();
4851
- else throw new UnsupportedOpError(op, dtype, "wasm");
4852
- else if (op === AluOp.Mod) if (isFloatDtype(dtype)) {
4853
- const dt = dtyF(cg, op, dtype);
4854
- const a = cg.local.declare(dt);
4855
- const b = cg.local.declare(dt);
4856
- cg.local.set(b);
4857
- cg.local.tee(a);
4858
- cg.local.get(a);
4859
- cg.local.get(b);
4860
- dt.div();
4861
- dt.trunc();
4862
- cg.local.get(b);
4863
- dt.mul();
4864
- dt.sub();
4865
- } else if (dtype === DType.Uint32) cg.i32.rem_u();
4866
- else if (dtype === DType.Int32) cg.i32.rem_s();
4867
- else throw new UnsupportedOpError(op, dtype, "wasm");
4868
- else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
4869
- else dtyF(cg, op, dtype).max();
4870
- else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
4871
- const a = cg.local.declare(cg.i32);
4872
- const b = cg.local.declare(cg.i32);
4873
- cg.local.set(b);
4874
- cg.local.tee(a);
4875
- cg.local.get(b);
4876
- cg.local.get(a);
4877
- cg.local.get(b);
4878
- if (dtype === DType.Int32) if (op === AluOp.Min) cg.i32.lt_s();
4879
- else cg.i32.gt_s();
4880
- else if (op === AluOp.Min) cg.i32.lt_u();
4881
- else cg.i32.gt_u();
4882
- cg.select();
4883
- } else throw new UnsupportedOpError(op, dtype, "wasm");
4884
- else if (op === AluOp.BitCombine) if (arg === "and") cg.i32.and();
4885
- else if (arg === "or") cg.i32.or();
4886
- else cg.i32.xor();
4887
- else if (op === AluOp.BitShift) if (arg === "shl") cg.i32.shl();
4888
- else cg.i32.shr_u();
4889
- else if (op === AluOp.Cmplt) {
4890
- const srcDtype = src[0].dtype;
4891
- if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
4892
- else if (srcDtype === DType.Int32) cg.i32.lt_s();
4893
- else if (srcDtype === DType.Uint32) cg.i32.lt_u();
4894
- else throw new UnsupportedOpError(op, dtype, "wasm");
4895
- } else if (op === AluOp.Cmpne) dty(cg, op, src[0].dtype).ne();
4896
- else throw new UnsupportedOpError(op, dtype, "wasm");
4897
- } else if (AluGroup.Unary.has(op)) {
4898
- const callFuncF32 = (func) => {
4899
- if (dtype !== DType.Float32) if (dtype === DType.Float64) cg.f32.demote_f64();
4900
- else throw new UnsupportedOpError(op, dtype, "wasm");
4901
- cg.call(func);
4902
- if (dtype === DType.Float64) cg.f64.promote_f32();
4903
- };
4904
- if (op === AluOp.Sin) gen(src[0]), callFuncF32(funcs.sin);
4905
- else if (op === AluOp.Cos) gen(src[0]), callFuncF32(funcs.cos);
4906
- else if (op === AluOp.Asin) gen(src[0]), callFuncF32(funcs.asin);
4907
- else if (op === AluOp.Atan) gen(src[0]), callFuncF32(funcs.atan);
4908
- else if (op === AluOp.Exp) gen(src[0]), callFuncF32(funcs.exp);
4909
- else if (op === AluOp.Log) gen(src[0]), callFuncF32(funcs.log);
4910
- else if (op === AluOp.Erf) gen(src[0]), callFuncF32(funcs.erf);
4911
- else if (op === AluOp.Erfc) gen(src[0]), callFuncF32(funcs.erfc);
4912
- else if (op === AluOp.Sqrt) gen(src[0]), dtyF(cg, op, dtype).sqrt();
4913
- else if (op === AluOp.Reciprocal) {
4914
- const dt = dtyF(cg, op, dtype);
4915
- dt.const(1), gen(src[0]), dt.div();
4916
- } else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
4917
- else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
4918
- else if (op === AluOp.Cast) {
4919
- gen(src[0]);
4920
- const dtype0 = src[0].dtype;
4921
- const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
4922
- if (dtype === DType.Int32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_s();
4923
- else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_s();
4924
- else if (i32repr);
4925
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4926
- else if (dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.trunc_sat_f32_u();
4927
- else if (dtype0 === DType.Float64) cg.i32.trunc_sat_f64_u();
4928
- else if (i32repr);
4929
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4930
- else if (dtype === DType.Float32) if (dtype0 === DType.Float32);
4931
- else if (dtype0 === DType.Float64) cg.f32.demote_f64();
4932
- else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f32.convert_i32_s();
4933
- else if (dtype0 === DType.Uint32) cg.f32.convert_i32_u();
4934
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4935
- else if (dtype === DType.Float64) if (dtype0 === DType.Float32) cg.f64.promote_f32();
4936
- else if (dtype0 === DType.Float64);
4937
- else if (dtype0 === DType.Int32 || dtype0 === DType.Bool) cg.f64.convert_i32_s();
4938
- else if (dtype0 === DType.Uint32) cg.f64.convert_i32_u();
4939
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4940
- else if (dtype === DType.Bool) if (dtype0 === DType.Bool);
4941
- else if (i32repr) cg.i32.const(0), cg.i32.ne();
4942
- else if (dtype0 === DType.Float32) cg.f32.const(0), cg.f32.ne();
4943
- else if (dtype0 === DType.Float64) cg.f64.const(0), cg.f64.ne();
4944
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4945
- else throw new UnsupportedOpError(op, dtype, "wasm");
4946
- } else if (op === AluOp.Bitcast) {
4947
- gen(src[0]);
4948
- const dtype0 = src[0].dtype;
4949
- if (dtype !== dtype0) {
4950
- const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32;
4951
- if (dtype === DType.Int32 || dtype === DType.Uint32) if (dtype0 === DType.Float32) cg.i32.reinterpret_f32();
4952
- else if (i32repr);
4953
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4954
- else if (dtype === DType.Float32) if (i32repr) cg.f32.reinterpret_i32();
4955
- else if (dtype0 === DType.Float32);
4956
- else throw new UnsupportedOpError(op, dtype, "wasm", dtype0);
4957
- else throw new UnsupportedOpError(op, dtype, "wasm");
4958
- }
4959
- } else throw new UnsupportedOpError(op, dtype, "wasm");
4960
- } else if (op === AluOp.Where) {
4961
- gen(src[1]);
4962
- gen(src[2]);
4963
- gen(src[0]);
4964
- cg.select();
4965
- } else if (op === AluOp.Threefry2x32) {
4966
- for (let i = 0; i < 4; i++) gen(src[i]);
4967
- cg.call(funcs.threefry2x32);
4968
- if (arg === "xor") cg.i32.xor();
4969
- else if (arg === 0) cg.drop();
4970
- else if (arg === 1) {
4971
- const local = cg.local.declare(cg.i32);
4972
- cg.local.set(local);
4973
- cg.drop();
4974
- cg.local.get(local);
4975
- } else throw new UnsupportedOpError(op, dtype, "wasm", arg);
4976
- } else if (op === AluOp.Const) return dty(cg, op, dtype).const(arg);
4977
- else if (op === AluOp.Special) return cg.local.get(ctx[arg[0]]);
4978
- else if (op === AluOp.Variable) return cg.local.get(ctx[arg]);
4979
- else if (op === AluOp.GlobalIndex) {
4980
- const [gid, len] = arg;
4981
- gen(src[0]);
4982
- const local = cg.local.declare(cg.i32);
4983
- cg.local.tee(local);
4984
- cg.i32.const(0);
4985
- cg.local.get(local), cg.i32.const(len), cg.i32.lt_u();
4986
- cg.select();
4987
- cg.i32.const(byteWidth(dtype));
4985
+ return check(tunedExp, /* @__PURE__ */ new Set());
4986
+ }
4987
+ /** Checks if SIMD over the reduced-k dimension is workable. */
4988
+ function canUseKSimdReduction(exp, re, pointers) {
4989
+ if (re.op !== AluOp.Add) return false;
4990
+ const globalIndexCount = exp.collect((node) => node.op === AluOp.GlobalIndex).length;
4991
+ return globalIndexCount === pointers.length && pointers.every((candidate) => candidate.dtype === DType.Float32 && (candidate.strideBytes === 0 || candidate.strideBytes === byteWidth(candidate.dtype)));
4992
+ }
4993
+ /** Emit a runtime guard: enter the if-block only when [begin, end) is SIMD-aligned. */
4994
+ function emitAlignmentGuard(cg, paramBegin, paramEnd, alignment = simdLanes) {
4995
+ cg.local.get(paramEnd);
4996
+ cg.local.get(paramBegin);
4997
+ cg.i32.sub();
4998
+ cg.i32.const(alignment);
4999
+ cg.i32.rem_u();
5000
+ cg.i32.eqz();
5001
+ cg.local.get(paramBegin);
5002
+ cg.i32.const(alignment);
5003
+ cg.i32.rem_u();
5004
+ cg.i32.eqz();
5005
+ cg.i32.and();
5006
+ cg.if(cg.void);
5007
+ }
5008
+ function codegenWasm(kernel) {
5009
+ const tune = tuneNullopt(kernel);
5010
+ const re = kernel.reduction;
5011
+ if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
5012
+ const simdEligible = isSimdEligible(tune.exp, kernel);
5013
+ const hasIdentityEpilogue = isIdentityEpilogue(tune.epilogue);
5014
+ const expStrides = simdEligible ? collectSimdStrides(tune.exp) : /* @__PURE__ */ new Map();
5015
+ const reductionHasLaneGather = re && [...expStrides.values()].some((stride) => stride.kind === "gather");
5016
+ const useSimd = simdEligible && !reductionHasLaneGather;
5017
+ const simdReductionPointerCandidates = useSimd && re ? reductionPointerCandidates(tune.exp, expStrides) : [];
5018
+ const reductionPointers = re ? reductionPointerCandidates(tune.exp) : [];
5019
+ const useKSimdReduction = simdEligible && re && canUseKSimdReduction(tune.exp, re, reductionPointers);
5020
+ const kSimdReductionPointerCandidates = useKSimdReduction && re ? reductionPointers.map((candidate) => ({
5021
+ ...candidate,
5022
+ stride: candidate.strideBytes === 0 ? {
5023
+ kind: "broadcast",
5024
+ tileSize: Infinity
5025
+ } : {
5026
+ kind: "contiguous",
5027
+ tileSize: Infinity
5028
+ }
5029
+ })) : [];
5030
+ const simdTilePlan = useSimd && re && hasIdentityEpilogue ? reductionTilePlan(kernel, expStrides) : null;
5031
+ const kSimdTilePlan = reductionHasLaneGather && useKSimdReduction ? reductionKTilePlan(kernel, expStrides) : null;
5032
+ const useRelaxedMadd = hasWasmFeature("relaxed-madd") && re?.op === AluOp.Add && tune.exp.dtype === DType.Float32 && tune.exp.op === AluOp.Mul;
5033
+ const cg = new CodeGenerator();
5034
+ cg.memory.import("env", "memory");
5035
+ if (hasSharedArrayBuffer()) cg.memory.pages(0, 65536).shared(true);
5036
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
5037
+ const funcs = {};
5038
+ if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
5039
+ if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
5040
+ if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
5041
+ if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
5042
+ if (distinctOps.has(AluOp.Exp) || distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) funcs.exp = wasm_exp(cg);
5043
+ if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
5044
+ if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
5045
+ if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
5046
+ if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
5047
+ const paramBegin = kernel.nargs + 1;
5048
+ const paramEnd = kernel.nargs + 2;
5049
+ const kernelFunc = cg.function(rep(kernel.nargs + 3, cg.i32), [], () => {
5050
+ const gidx = cg.local.declare(cg.i32);
5051
+ cg.local.get(paramBegin);
5052
+ cg.local.set(gidx);
5053
+ const emitLocalPlusConst = (local, amount) => {
5054
+ cg.local.get(local);
5055
+ cg.i32.const(amount);
5056
+ cg.i32.add();
5057
+ };
5058
+ const bumpLocal = (local, amount) => {
5059
+ emitLocalPlusConst(local, amount);
5060
+ cg.local.set(local);
5061
+ };
5062
+ const setLocalConst = (local, value) => {
5063
+ cg.i32.const(value);
5064
+ cg.local.set(local);
5065
+ };
5066
+ const copyLocal = (target, source) => {
5067
+ cg.local.get(source);
5068
+ cg.local.set(target);
5069
+ };
5070
+ const setRowBase = (target, rowTileBase, rowOffset, tileSize) => {
5071
+ cg.local.get(rowTileBase);
5072
+ cg.local.get(rowOffset);
5073
+ cg.i32.const(tileSize);
4988
5074
  cg.i32.mul();
4989
- cg.local.get(gid);
4990
5075
  cg.i32.add();
4991
- dty(cg, op, dtype).load(Math.log2(byteWidth(dtype)));
4992
- } else throw new UnsupportedOpError(op, dtype, "wasm");
4993
- if ((references.get(exp$1) ?? 0) > 1) {
4994
- const local = cg.local.declare(dty(cg, op, dtype));
4995
- cg.local.tee(local);
4996
- expContext.set(exp$1, local);
5076
+ cg.local.set(target);
5077
+ };
5078
+ const emitOutputAddress = (index) => {
5079
+ cg.local.get(kernel.nargs);
5080
+ cg.local.get(index);
5081
+ cg.i32.const(byteWidth(kernel.dtype));
5082
+ cg.i32.mul();
5083
+ cg.i32.add();
5084
+ };
5085
+ const declareTileGidx = (rowBase, col, rowOffset, colOffset) => {
5086
+ const local = cg.local.declare(cg.i32);
5087
+ cg.local.get(rowBase);
5088
+ if (rowOffset !== 0) {
5089
+ cg.i32.const(rowOffset);
5090
+ cg.i32.add();
5091
+ }
5092
+ cg.local.get(col);
5093
+ cg.i32.add();
5094
+ if (colOffset !== 0) {
5095
+ cg.i32.const(colOffset);
5096
+ cg.i32.add();
5097
+ }
5098
+ cg.local.set(local);
5099
+ return local;
5100
+ };
5101
+ const emitLoopWhileLt = (index, emitBound, emitBody) => {
5102
+ cg.loop(cg.void);
5103
+ cg.block(cg.void);
5104
+ cg.local.get(index);
5105
+ emitBound();
5106
+ cg.i32.ge_u();
5107
+ cg.br_if(0);
5108
+ emitBody();
5109
+ cg.br(1);
5110
+ cg.end();
5111
+ cg.end();
5112
+ };
5113
+ const emitLoopWhileLocalLt = (index, bound, emitBody) => emitLoopWhileLt(index, () => cg.local.get(bound), emitBody);
5114
+ const emitLoopWhileConstLt = (index, bound, emitBody) => emitLoopWhileLt(index, () => cg.i32.const(bound), emitBody);
5115
+ const emitSimdExpWithAccumulator = (ctx, pointerMap, pointerValueCache, acc) => {
5116
+ if (useRelaxedMadd) {
5117
+ translateExpSimd(cg, funcs, tune.exp.src[0], ctx, expStrides, pointerMap, pointerValueCache);
5118
+ translateExpSimd(cg, funcs, tune.exp.src[1], ctx, expStrides, pointerMap, pointerValueCache);
5119
+ cg.local.get(acc);
5120
+ cg.f32x4.relaxed_madd();
5121
+ return true;
5122
+ }
5123
+ translateExpSimd(cg, funcs, tune.exp, ctx, expStrides, pointerMap, pointerValueCache);
5124
+ cg.local.get(acc);
5125
+ return false;
5126
+ };
5127
+ const emitSimdReductionForGidxs = (gidxs, pointerMaps, uniquePointers, ridxStart, ridxEnd) => {
5128
+ if (!re) throw new Error("internal: missing reduction");
5129
+ const reIsInt = kernel.exp.dtype === DType.Int32 || kernel.exp.dtype === DType.Uint32;
5130
+ const vecAccs = gidxs.map(() => cg.local.declare(reIsInt ? cg.i32x4 : cg.f32x4));
5131
+ const initializeIdentityAccumulators = () => {
5132
+ for (const acc of vecAccs) {
5133
+ if (reIsInt) {
5134
+ cg.i32.const(re.identity);
5135
+ cg.i32x4.splat();
5136
+ } else {
5137
+ cg.f32.const(re.identity);
5138
+ cg.f32x4.splat();
5139
+ }
5140
+ cg.local.set(acc);
5141
+ }
5142
+ };
5143
+ const loadPartialAccumulators = () => {
5144
+ for (let i = 0; i < gidxs.length; i++) {
5145
+ emitOutputAddress(gidxs[i]);
5146
+ if (reIsInt) cg.i32x4.load(4);
5147
+ else cg.f32x4.load(4);
5148
+ cg.local.set(vecAccs[i]);
5149
+ }
5150
+ };
5151
+ if (ridxStart === void 0) initializeIdentityAccumulators();
5152
+ else {
5153
+ cg.local.get(ridxStart);
5154
+ cg.i32.eqz();
5155
+ cg.if(cg.void);
5156
+ initializeIdentityAccumulators();
5157
+ cg.else();
5158
+ loadPartialAccumulators();
5159
+ cg.end();
5160
+ }
5161
+ const ridx = cg.local.declare(cg.i32);
5162
+ const emitReductionStep = () => {
5163
+ const pointerValueCache = /* @__PURE__ */ new Map();
5164
+ for (let i = 0; i < gidxs.length; i++) {
5165
+ const valueAlreadyAccumulated = emitSimdExpWithAccumulator({
5166
+ gidx: gidxs[i],
5167
+ ridx
5168
+ }, pointerMaps[i], pointerValueCache, vecAccs[i]);
5169
+ emitSimdReductionOp(cg, re, reIsInt, valueAlreadyAccumulated);
5170
+ cg.local.set(vecAccs[i]);
5171
+ }
5172
+ incrementReductionPointers(cg, uniquePointers);
5173
+ };
5174
+ if (ridxStart === void 0) setLocalConst(ridx, 0);
5175
+ else copyLocal(ridx, ridxStart);
5176
+ const emitReductionLoopBody = () => {
5177
+ emitReductionStep();
5178
+ bumpLocal(ridx, 1);
5179
+ };
5180
+ if (ridxEnd === void 0) emitLoopWhileConstLt(ridx, re.size, emitReductionLoopBody);
5181
+ else emitLoopWhileLocalLt(ridx, ridxEnd, emitReductionLoopBody);
5182
+ if (hasIdentityEpilogue) {
5183
+ for (let i = 0; i < gidxs.length; i++) {
5184
+ emitOutputAddress(gidxs[i]);
5185
+ cg.local.get(vecAccs[i]);
5186
+ cg.v128.store(4);
5187
+ }
5188
+ return;
5189
+ }
5190
+ const laneGidx = cg.local.declare(cg.i32);
5191
+ const laneAcc = cg.local.declare(reIsInt ? cg.i32 : cg.f32);
5192
+ for (let i = 0; i < gidxs.length; i++) for (let lane = 0; lane < simdLanes; lane++) {
5193
+ cg.local.get(kernel.nargs);
5194
+ cg.local.get(gidxs[i]);
5195
+ if (lane > 0) {
5196
+ cg.i32.const(lane);
5197
+ cg.i32.add();
5198
+ }
5199
+ cg.local.tee(laneGidx);
5200
+ cg.i32.const(byteWidth(kernel.dtype));
5201
+ cg.i32.mul();
5202
+ cg.i32.add();
5203
+ cg.local.get(vecAccs[i]);
5204
+ if (reIsInt) cg.i32x4.extract_lane(lane);
5205
+ else cg.f32x4.extract_lane(lane);
5206
+ cg.local.set(laneAcc);
5207
+ translateExp(cg, funcs, tune.epilogue, {
5208
+ acc: laneAcc,
5209
+ gidx: laneGidx
5210
+ });
5211
+ dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
5212
+ }
5213
+ };
5214
+ const initializePointerMaps = (groups, candidates, keyFor, ctxFor, ridxOffset) => {
5215
+ const pointerMaps = groups.map(() => /* @__PURE__ */ new Map());
5216
+ const sharedPointers = /* @__PURE__ */ new Map();
5217
+ const uniquePointers = [];
5218
+ for (let i = 0; i < groups.length; i++) {
5219
+ const group = groups[i];
5220
+ for (const candidate of candidates) {
5221
+ const key = keyFor(candidate, group, i);
5222
+ let plan = sharedPointers.get(key);
5223
+ if (!plan) {
5224
+ plan = initializeReductionPointer(cg, funcs, candidate, ctxFor(group), key, ridxOffset);
5225
+ sharedPointers.set(key, plan);
5226
+ uniquePointers.push(plan);
5227
+ }
5228
+ pointerMaps[i].set(candidate.exp, plan);
5229
+ }
5230
+ }
5231
+ return {
5232
+ pointerMaps,
5233
+ uniquePointers
5234
+ };
5235
+ };
5236
+ const emitSimdReductionStep = () => {
5237
+ const groups = [{
5238
+ gidx,
5239
+ row: 0,
5240
+ vector: 0
5241
+ }];
5242
+ const { pointerMaps, uniquePointers } = initializePointerMaps(groups, simdReductionPointerCandidates, (candidate, group, i) => pointerShareKey(candidate, group.row, group.vector, i), (group) => ({ gidx: group.gidx }));
5243
+ emitSimdReductionForGidxs([gidx], pointerMaps, uniquePointers);
5244
+ bumpLocal(gidx, simdLanes);
5245
+ };
5246
+ const emitElementwiseSimdStep = () => {
5247
+ emitOutputAddress(gidx);
5248
+ translateExpSimd(cg, funcs, tune.exp, { gidx }, expStrides);
5249
+ cg.v128.store(4);
5250
+ bumpLocal(gidx, simdLanes);
5251
+ };
5252
+ const emitKSimdReductionForGroups = (groups, pointerMaps, uniquePointers) => {
5253
+ if (!re) throw new Error("internal: missing reduction");
5254
+ if (!kSimdTilePlan) throw new Error("internal: missing K SIMD plan");
5255
+ const vecAccs = groups.map(() => cg.local.declare(cg.f32x4));
5256
+ for (const acc of vecAccs) {
5257
+ cg.f32.const(re.identity);
5258
+ cg.f32x4.splat();
5259
+ cg.local.set(acc);
5260
+ }
5261
+ const ridx = cg.local.declare(cg.i32);
5262
+ setLocalConst(ridx, 0);
5263
+ emitLoopWhileConstLt(ridx, re.size, () => {
5264
+ for (let u = 0; u < kSimdTilePlan.kUnroll; u++) {
5265
+ const pointerValueCache = /* @__PURE__ */ new Map();
5266
+ for (let i = 0; i < groups.length; i++) {
5267
+ const valueAlreadyAccumulated = emitSimdExpWithAccumulator({
5268
+ gidx: groups[i].gidx,
5269
+ ridx
5270
+ }, pointerMaps[i], pointerValueCache, vecAccs[i]);
5271
+ if (!valueAlreadyAccumulated) cg.f32x4.add();
5272
+ cg.local.set(vecAccs[i]);
5273
+ }
5274
+ incrementReductionPointers(cg, uniquePointers, simdLanes);
5275
+ }
5276
+ bumpLocal(ridx, simdLanes * kSimdTilePlan.kUnroll);
5277
+ });
5278
+ for (let i = 0; i < groups.length; i++) {
5279
+ const acc = cg.local.declare(cg.f32);
5280
+ for (let lane = 0; lane < simdLanes; lane++) {
5281
+ cg.local.get(vecAccs[i]);
5282
+ cg.f32x4.extract_lane(lane);
5283
+ if (lane > 0) cg.f32.add();
5284
+ }
5285
+ cg.local.set(acc);
5286
+ emitOutputAddress(groups[i].gidx);
5287
+ translateExp(cg, funcs, tune.epilogue, {
5288
+ acc,
5289
+ gidx: groups[i].gidx
5290
+ });
5291
+ dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
5292
+ }
5293
+ };
5294
+ const emitKSimdReductionLoop = (plan) => {
5295
+ const colTile = cg.local.declare(cg.i32);
5296
+ const col = cg.local.declare(cg.i32);
5297
+ const rowTileBase = cg.local.declare(cg.i32);
5298
+ const rowOffset = cg.local.declare(cg.i32);
5299
+ const rowBase = cg.local.declare(cg.i32);
5300
+ const emitBlock = () => {
5301
+ const groups = [];
5302
+ for (let row = 0; row < plan.microRows; row++) for (let colOffset = 0; colOffset < plan.microCols; colOffset++) groups.push({
5303
+ gidx: declareTileGidx(rowBase, col, row * plan.tileSize, colOffset),
5304
+ row,
5305
+ col: colOffset
5306
+ });
5307
+ if (!kSimdTilePlan) throw new Error("internal: missing K SIMD plan");
5308
+ const { pointerMaps, uniquePointers } = initializePointerMaps(groups, kSimdReductionPointerCandidates, (candidate, group, i) => kReductionPointerShareKey(candidate, expStrides.get(candidate.exp) ?? { kind: "gather" }, kSimdTilePlan.tileSize, group.row, group.col, i), (group) => ({ gidx: group.gidx }));
5309
+ emitKSimdReductionForGroups(groups, pointerMaps, uniquePointers);
5310
+ };
5311
+ emitLoopWhileLocalLt(gidx, paramEnd, () => {
5312
+ copyLocal(rowTileBase, gidx);
5313
+ setLocalConst(colTile, 0);
5314
+ emitLoopWhileConstLt(colTile, plan.tileSize, () => {
5315
+ setLocalConst(rowOffset, 0);
5316
+ emitLoopWhileConstLt(rowOffset, plan.tileRows, () => {
5317
+ copyLocal(col, colTile);
5318
+ emitLoopWhileLt(col, () => emitLocalPlusConst(colTile, plan.tileCols), () => {
5319
+ setRowBase(rowBase, rowTileBase, rowOffset, plan.tileSize);
5320
+ emitBlock();
5321
+ bumpLocal(col, plan.microCols);
5322
+ });
5323
+ bumpLocal(rowOffset, plan.microRows);
5324
+ });
5325
+ bumpLocal(colTile, plan.tileCols);
5326
+ });
5327
+ bumpLocal(gidx, plan.tileRows * plan.tileSize);
5328
+ });
5329
+ };
5330
+ const emitTiledSimdReductionLoop = (plan) => {
5331
+ const colTile = cg.local.declare(cg.i32);
5332
+ const col = cg.local.declare(cg.i32);
5333
+ const kTile = cg.local.declare(cg.i32);
5334
+ const tileEnd = cg.local.declare(cg.i32);
5335
+ const rowTileBase = cg.local.declare(cg.i32);
5336
+ const rowOffset = cg.local.declare(cg.i32);
5337
+ const rowBase = cg.local.declare(cg.i32);
5338
+ const emitRowBlock = () => {
5339
+ const groups = [];
5340
+ for (let row = 0; row < plan.microRows; row++) for (let vector = 0; vector < plan.microVectors; vector++) groups.push({
5341
+ gidx: declareTileGidx(rowBase, col, row * plan.tileSize, vector * simdLanes),
5342
+ row,
5343
+ vector
5344
+ });
5345
+ const { pointerMaps, uniquePointers } = initializePointerMaps(groups, simdReductionPointerCandidates, (candidate, group, i) => pointerShareKey(candidate, group.row, group.vector, i), (group) => ({ gidx: group.gidx }), kTile);
5346
+ emitSimdReductionForGidxs(groups.map((group) => group.gidx), pointerMaps, uniquePointers, kTile, tileEnd);
5347
+ };
5348
+ emitLoopWhileLocalLt(gidx, paramEnd, () => {
5349
+ copyLocal(rowTileBase, gidx);
5350
+ setLocalConst(colTile, 0);
5351
+ emitLoopWhileConstLt(colTile, plan.tileSize, () => {
5352
+ setLocalConst(kTile, 0);
5353
+ emitLoopWhileConstLt(kTile, re.size, () => {
5354
+ emitLocalPlusConst(kTile, plan.tileK);
5355
+ cg.local.set(tileEnd);
5356
+ setLocalConst(rowOffset, 0);
5357
+ emitLoopWhileConstLt(rowOffset, plan.tileRows, () => {
5358
+ copyLocal(col, colTile);
5359
+ emitLoopWhileLt(col, () => emitLocalPlusConst(colTile, plan.tileVectors * simdLanes), () => {
5360
+ setRowBase(rowBase, rowTileBase, rowOffset, plan.tileSize);
5361
+ emitRowBlock();
5362
+ bumpLocal(col, plan.microVectors * simdLanes);
5363
+ });
5364
+ bumpLocal(rowOffset, plan.microRows);
5365
+ });
5366
+ bumpLocal(kTile, plan.tileK);
5367
+ });
5368
+ bumpLocal(colTile, plan.tileVectors * simdLanes);
5369
+ });
5370
+ bumpLocal(gidx, plan.tileRows * plan.tileSize);
5371
+ });
5372
+ };
5373
+ const emitGuardedFastPath = (alignment, emit) => {
5374
+ emitAlignmentGuard(cg, paramBegin, paramEnd, alignment);
5375
+ emit();
5376
+ cg.return();
5377
+ cg.end();
5378
+ };
5379
+ if (kSimdTilePlan) emitGuardedFastPath(kSimdTilePlan.tileRows * kSimdTilePlan.tileSize, () => emitKSimdReductionLoop(kSimdTilePlan));
5380
+ if (useSimd) {
5381
+ if (simdTilePlan) emitGuardedFastPath(simdTilePlan.tileRows * simdTilePlan.tileSize, () => emitTiledSimdReductionLoop(simdTilePlan));
5382
+ emitGuardedFastPath(simdLanes, () => {
5383
+ emitLoopWhileLocalLt(gidx, paramEnd, re ? emitSimdReductionStep : emitElementwiseSimdStep);
5384
+ });
4997
5385
  }
5386
+ emitLoopWhileLocalLt(gidx, paramEnd, () => {
5387
+ emitOutputAddress(gidx);
5388
+ if (re) {
5389
+ const acc = cg.local.declare(dty(cg, null, kernel.exp.dtype));
5390
+ dty(cg, null, kernel.exp.dtype).const(re.identity);
5391
+ cg.local.set(acc);
5392
+ const ridx = cg.local.declare(cg.i32);
5393
+ const emitReductionStep = () => {
5394
+ translateExp(cg, funcs, tune.exp, {
5395
+ gidx,
5396
+ ridx
5397
+ });
5398
+ emitScalarReductionOp(cg, re, acc);
5399
+ cg.local.set(acc);
5400
+ };
5401
+ setLocalConst(ridx, 0);
5402
+ emitLoopWhileConstLt(ridx, re.size, () => {
5403
+ emitReductionStep();
5404
+ bumpLocal(ridx, 1);
5405
+ });
5406
+ translateExp(cg, funcs, tune.epilogue, {
5407
+ acc,
5408
+ gidx
5409
+ });
5410
+ } else translateExp(cg, funcs, tune.exp, { gidx });
5411
+ dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
5412
+ bumpLocal(gidx, 1);
5413
+ });
5414
+ });
5415
+ cg.export(kernelFunc, "kernel");
5416
+ const tiledPlan = simdTilePlan ?? kSimdTilePlan;
5417
+ return {
5418
+ bytes: cg.finish(),
5419
+ workSize: kernel.size,
5420
+ chunkAlignment: tiledPlan ? tiledPlan.tileRows * tiledPlan.tileSize : 16,
5421
+ minWorkPerWorker: tiledPlan && kernel.size / tiledPlan.tileSize >= 1024 ? tiledPlan.tileSize * 32 : 256
4998
5422
  };
4999
- countReferences(exp);
5000
- gen(exp);
5001
5423
  }
5002
- function dty(cg, op, dtype) {
5003
- switch (dtype) {
5004
- case DType.Float32: return cg.f32;
5005
- case DType.Float64: return cg.f64;
5006
- case DType.Int32:
5007
- case DType.Uint32:
5008
- case DType.Bool: return cg.i32;
5009
- default: throw new UnsupportedOpError(op, dtype, "wasm");
5424
+
5425
+ //#endregion
5426
+ //#region src/backend/wasm.ts
5427
+ const compiledProgramCache = /* @__PURE__ */ new Map();
5428
+ /** Backend that compiles into WebAssembly bytecode for immediate execution. */
5429
+ var WasmBackend = class {
5430
+ type = "wasm";
5431
+ maxArgs = 64;
5432
+ #memory;
5433
+ #nextSlot;
5434
+ #allocator;
5435
+ #buffers;
5436
+ #workerPool;
5437
+ #pendingWork = /* @__PURE__ */ new Map();
5438
+ constructor() {
5439
+ this.#memory = hasSharedArrayBuffer() ? new WebAssembly.Memory({
5440
+ initial: 0,
5441
+ maximum: 65536,
5442
+ shared: true
5443
+ }) : new WebAssembly.Memory({ initial: 0 });
5444
+ this.#allocator = new WasmAllocator(this.#memory);
5445
+ this.#nextSlot = 1;
5446
+ this.#buffers = /* @__PURE__ */ new Map();
5447
+ this.#workerPool = createWorkerPool(this.#memory);
5010
5448
  }
5011
- }
5012
- function dtyF(cg, op, dtype) {
5013
- switch (dtype) {
5014
- case DType.Float32: return cg.f32;
5015
- case DType.Float64: return cg.f64;
5016
- default: throw new UnsupportedOpError(op, dtype, "wasm");
5449
+ malloc(size, initialData) {
5450
+ const ptr = this.#allocator.malloc(size);
5451
+ if (initialData) {
5452
+ if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
5453
+ new Uint8Array(this.#memory.buffer, ptr, size).set(initialData);
5454
+ }
5455
+ const slot = this.#nextSlot++;
5456
+ this.#buffers.set(slot, {
5457
+ ptr,
5458
+ size,
5459
+ ref: 1
5460
+ });
5461
+ return slot;
5017
5462
  }
5018
- }
5463
+ incRef(slot) {
5464
+ const buffer = this.#buffers.get(slot);
5465
+ if (!buffer) throw new SlotError(slot);
5466
+ buffer.ref++;
5467
+ }
5468
+ decRef(slot) {
5469
+ const buffer = this.#buffers.get(slot);
5470
+ if (!buffer) throw new SlotError(slot);
5471
+ buffer.ref--;
5472
+ if (buffer.ref === 0) {
5473
+ this.#allocator.free(buffer.ptr);
5474
+ this.#buffers.delete(slot);
5475
+ this.#pendingWork.delete(slot);
5476
+ }
5477
+ }
5478
+ async read(slot, start, count) {
5479
+ const epoch = this.#pendingWork.get(slot);
5480
+ if (epoch) await this.#workerPool.waitForEpoch(epoch);
5481
+ return this.#readData(slot, start, count);
5482
+ }
5483
+ readSync(slot, start, count) {
5484
+ const epoch = this.#pendingWork.get(slot);
5485
+ if (epoch && this.#workerPool.epoch < epoch) throw new Error("cannot read synchronously from a slot with async work");
5486
+ return this.#readData(slot, start, count);
5487
+ }
5488
+ #readData(slot, start, count) {
5489
+ const buffer = this.#getBuffer(slot);
5490
+ if (start === void 0) start = 0;
5491
+ if (count === void 0) count = buffer.byteLength - start;
5492
+ if (hasSharedArrayBuffer() && buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
5493
+ else return buffer.slice(start, start + count);
5494
+ }
5495
+ async prepareKernel(kernel) {
5496
+ const kernelHash = FpHash.hash(kernel);
5497
+ const hashKey = kernelHash.toString();
5498
+ const program = await runWithCacheAsync(compiledProgramCache, hashKey, async () => {
5499
+ const { bytes,...metadata } = codegenWasm(kernel);
5500
+ const module$1 = await WebAssembly.compile(bytes);
5501
+ return {
5502
+ module: module$1,
5503
+ ...metadata
5504
+ };
5505
+ });
5506
+ return new Executable(kernel, {
5507
+ program,
5508
+ sync: false
5509
+ });
5510
+ }
5511
+ prepareKernelSync(kernel) {
5512
+ const kernelHash = FpHash.hash(kernel);
5513
+ const hashKey = kernelHash.toString();
5514
+ const compiled = runWithCache(compiledProgramCache, hashKey, () => {
5515
+ const { bytes,...metadata } = codegenWasm(kernel);
5516
+ const module$1 = new WebAssembly.Module(bytes);
5517
+ return {
5518
+ module: module$1,
5519
+ ...metadata
5520
+ };
5521
+ });
5522
+ return new Executable(kernel, {
5523
+ program: compiled,
5524
+ sync: true
5525
+ });
5526
+ }
5527
+ async prepareRoutine(routine) {
5528
+ return this.prepareRoutineSync(routine);
5529
+ }
5530
+ prepareRoutineSync(routine) {
5531
+ return new Executable(routine, {
5532
+ program: void 0,
5533
+ sync: true
5534
+ });
5535
+ }
5536
+ dispatch(exe, inputs, outputs) {
5537
+ const tracing = isTracing();
5538
+ const start = tracing ? performance.now() : 0;
5539
+ if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
5540
+ else {
5541
+ const { program, sync } = exe.data;
5542
+ const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
5543
+ if (this.#workerPool && !sync) {
5544
+ const retainedSlots = [...inputs, ...outputs];
5545
+ for (const slot of retainedSlots) this.incRef(slot);
5546
+ const epoch = this.#workerPool.dispatch(program.module, ptrs, program.workSize, program.chunkAlignment, program.minWorkPerWorker);
5547
+ for (const slot of outputs) this.#pendingWork.set(slot, epoch);
5548
+ this.#workerPool.waitForEpoch(epoch).then(() => {
5549
+ for (const slot of outputs) if (this.#pendingWork.get(slot) === epoch) this.#pendingWork.delete(slot);
5550
+ for (const slot of retainedSlots) this.decRef(slot);
5551
+ });
5552
+ } else {
5553
+ if (inputs.some((slot) => {
5554
+ const epoch = this.#pendingWork.get(slot);
5555
+ return epoch && this.#workerPool.epoch < epoch;
5556
+ })) throw new Error("cannot dispatch synchronously with pending async work");
5557
+ const instance = new WebAssembly.Instance(program.module, { env: { memory: this.#memory } });
5558
+ const func = instance.exports.kernel;
5559
+ func(...ptrs, 0, program.workSize);
5560
+ }
5561
+ }
5562
+ if (tracing) {
5563
+ const info = traceSourceInfo(exe.source);
5564
+ emitTrace("wasm", info, start, performance.now());
5565
+ }
5566
+ }
5567
+ #getBuffer(slot) {
5568
+ const buffer = this.#buffers.get(slot);
5569
+ if (!buffer) throw new SlotError(slot);
5570
+ return new Uint8Array(this.#memory.buffer, buffer.ptr, buffer.size);
5571
+ }
5572
+ };
5019
5573
 
5020
5574
  //#endregion
5021
5575
  //#region src/backend.ts
@@ -5062,7 +5616,7 @@ async function createBackend(device) {
5062
5616
  if (!navigator.gpu) return null;
5063
5617
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
5064
5618
  if (!adapter) return null;
5065
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-DDGCYtHa.cjs"));
5619
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-pWnE96Xc.cjs"));
5066
5620
  const importantLimits = [
5067
5621
  "maxBufferSize",
5068
5622
  "maxComputeInvocationsPerWorkgroup",
@@ -5100,7 +5654,7 @@ async function createBackend(device) {
5100
5654
  });
5101
5655
  if (!gl) return null;
5102
5656
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
5103
- const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-pbfUGDA6.cjs"));
5657
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-C6rCbloA.cjs"));
5104
5658
  return new WebGLBackend(gl);
5105
5659
  } else throw new Error(`Backend not found: ${device}`);
5106
5660
  }