@jax-js/jax 0.1.9 → 0.1.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -313,6 +313,16 @@ function runWithCache(cache, key, thunk) {
313
313
  return value;
314
314
  }
315
315
  }
316
/**
 * In-flight computations started by `runWithCacheAsync`, grouped per cache.
 * Kept separate from the cache itself so that synchronous readers of the same
 * cache only ever observe settled values, never promises.
 */
const pendingCacheFills = /* @__PURE__ */ new WeakMap();
/**
 * Async version of `runWithCache`.
 *
 * Looks up `JSON.stringify(key)` in `cache`; on a miss, awaits `thunk()` and
 * stores its result. Callers that ask for the same key while a computation is
 * still running share that computation instead of starting a duplicate one.
 * A rejected computation is not cached, so a later call retries.
 *
 * @param cache Map-like object (`has` / `get` / `set`) holding settled values.
 * @param key Any JSON-serializable value identifying the entry.
 * @param thunk Function producing the value (or a promise of it).
 * @returns Promise of the cached or freshly computed value.
 */
async function runWithCacheAsync(cache, key, thunk) {
  const keyStr = JSON.stringify(key);
  if (cache.has(keyStr)) return cache.get(keyStr);
  let pending = pendingCacheFills.get(cache);
  if (!pending) {
    pending = /* @__PURE__ */ new Map();
    pendingCacheFills.set(cache, pending);
  }
  const inFlight = pending.get(keyStr);
  if (inFlight) return inFlight;
  const fill = (async () => {
    const value = await thunk();
    cache.set(keyStr, value);
    return value;
  })();
  pending.set(keyStr, fill);
  // Registered after `pending.set` so that even a thunk that throws
  // synchronously has its entry removed again.
  const forget = () => {
    if (pending.get(keyStr) === fill) pending.delete(keyStr);
  };
  fill.then(forget, forget);
  return fill;
}
316
326
 
317
327
  //#endregion
318
328
  //#region src/alu.ts
@@ -415,8 +425,23 @@ var AluExp = class AluExp {
415
425
  this.src = src;
416
426
  this.arg = arg;
417
427
  if (AluGroup.RequiredFloat.has(op) && !isFloatDtype(dtype)) throw new TypeError(`Unsupported dtype for ${op}: ${dtype}`);
418
- if (op === AluOp.Bitcast && (dtype === DType.Bool || src[0].dtype === DType.Bool || byteWidth(dtype) !== byteWidth(src[0].dtype))) throw new TypeError(`Bitcast from ${src[0].dtype} -> ${dtype}`);
419
- if (op === AluOp.Threefry2x32 && (dtype !== DType.Uint32 || src.some((x) => x.dtype !== DType.Uint32))) throw new TypeError("Threefry2x32 requires uint32 types");
428
+ switch (op) {
429
+ case AluOp.Bitcast:
430
+ if (dtype === DType.Bool || src[0].dtype === DType.Bool || byteWidth(dtype) !== byteWidth(src[0].dtype)) throw new TypeError(`Bitcast from ${src[0].dtype} -> ${dtype}`);
431
+ break;
432
+ case AluOp.Threefry2x32:
433
+ if (dtype !== DType.Uint32 || src.some((x) => x.dtype !== DType.Uint32)) throw new TypeError("Threefry2x32 requires uint32 types");
434
+ break;
435
+ case AluOp.BitCombine:
436
+ if (src[0].dtype !== src[1].dtype || isFloatDtype(src[0].dtype)) throw new TypeError(`BitCombine[${arg}] requires matching integral dtype, got ${src[0].dtype} and ${src[1].dtype}`);
437
+ break;
438
+ case AluOp.BitShift:
439
+ if (src[0].dtype === DType.Bool || src[1].dtype === DType.Bool || isFloatDtype(src[0].dtype) || isFloatDtype(src[1].dtype)) throw new TypeError(`BitShift[${arg}] requires two integral, non-bool dtypes, got ${src[0].dtype} and ${src[1].dtype}`);
440
+ break;
441
+ case AluOp.BitInvert:
442
+ if (isFloatDtype(src[0].dtype)) throw new TypeError(`BitInvert requires an integral dtype, got ${src[0].dtype}`);
443
+ break;
444
+ }
420
445
  }
421
446
  static add(a, b) {
422
447
  return new AluExp(AluOp.Add, a.dtype, [a, b]);
@@ -493,6 +518,12 @@ var AluExp = class AluExp {
493
518
  c1
494
519
  ], mode);
495
520
  }
521
+ static bitCombine(a, b, mode) {
522
+ return new AluExp(AluOp.BitCombine, a.dtype, [a, b], mode);
523
+ }
524
+ static bitShift(a, b, mode) {
525
+ return new AluExp(AluOp.BitShift, a.dtype, [a, b], mode);
526
+ }
496
527
  static cmplt(a, b) {
497
528
  return new AluExp(AluOp.Cmplt, DType.Bool, [a, b]);
498
529
  }
@@ -965,6 +996,16 @@ var AluExp = class AluExp {
965
996
  case AluOp.Mod: return x % y;
966
997
  case AluOp.Min: return Math.min(x, y);
967
998
  case AluOp.Max: return Math.max(x, y);
999
+ case AluOp.BitCombine: {
1000
+ let r;
1001
+ if (this.arg === "and") r = x & y;
1002
+ else if (this.arg === "or") r = x | y;
1003
+ else r = x ^ y;
1004
+ return this.dtype === DType.Int32 ? r | 0 : r >>> 0;
1005
+ }
1006
+ case AluOp.BitShift:
1007
+ if (this.arg === "shl") return this.dtype === DType.Int32 ? x << y | 0 : x << y >>> 0;
1008
+ return x >>> y;
968
1009
  case AluOp.Cmplt: return Number(x < y);
969
1010
  case AluOp.Cmpne: return Number(x != y);
970
1011
  default: throw new Error(`Missing implemementation for ${this.op}`);
@@ -1086,6 +1127,18 @@ var AluExp = class AluExp {
1086
1127
  }
1087
1128
  if (BIN_SYM[node.op]) return `(${parts[0]} ${BIN_SYM[node.op]} ${parts[1]})`;
1088
1129
  if (CMP_SYM[node.op]) return `(${parts[0]} ${CMP_SYM[node.op]} ${parts[1]})`;
1130
+ if (node.op === AluOp.BitCombine) {
1131
+ const sym = {
1132
+ and: "&",
1133
+ or: "|",
1134
+ xor: "^"
1135
+ }[node.arg];
1136
+ return `(${parts[0]} ${sym} ${parts[1]})`;
1137
+ }
1138
+ if (node.op === AluOp.BitShift) {
1139
+ const sym = node.arg === "shl" ? "<<" : ">>";
1140
+ return `(${parts[0]} ${sym} ${parts[1]})`;
1141
+ }
1089
1142
  if (UNARY_SYM[node.op]) return `${UNARY_SYM[node.op]}${parts[0]}`;
1090
1143
  if (node.op === AluOp.Cast) return `Cast<${node.dtype}>(${strip1(parts[0])})`;
1091
1144
  if (node.op === AluOp.Bitcast) return `Bitcast<${node.dtype}>(${strip1(parts[0])})`;
@@ -1178,6 +1231,9 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
1178
1231
  AluOp$1["Reciprocal"] = "Reciprocal";
1179
1232
  AluOp$1["Cast"] = "Cast";
1180
1233
  AluOp$1["Bitcast"] = "Bitcast";
1234
+ AluOp$1["BitCombine"] = "BitCombine";
1235
+ AluOp$1["BitInvert"] = "BitInvert";
1236
+ AluOp$1["BitShift"] = "BitShift";
1181
1237
  AluOp$1["Cmplt"] = "Cmplt";
1182
1238
  AluOp$1["Cmpne"] = "Cmpne";
1183
1239
  AluOp$1["Where"] = "Where";
@@ -1197,7 +1253,9 @@ const AluGroup = {
1197
1253
  AluOp.Idiv,
1198
1254
  AluOp.Mod,
1199
1255
  AluOp.Min,
1200
- AluOp.Max
1256
+ AluOp.Max,
1257
+ AluOp.BitCombine,
1258
+ AluOp.BitShift
1201
1259
  ]),
1202
1260
  Unary: new Set([
1203
1261
  AluOp.Sin,
@@ -1312,6 +1370,10 @@ var Reduction = class {
1312
1370
  this.epilogue = epilogue;
1313
1371
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
1314
1372
  this.epilogue = epilogue.simplify();
1373
+ if (this.dtype === DType.Float16 && this.op === AluOp.Add) {
1374
+ this.epilogue = this.epilogue.substitute({ acc: AluExp.cast(this.dtype, AluVar.acc(DType.Float32)) });
1375
+ this.dtype = DType.Float32;
1376
+ }
1315
1377
  }
1316
1378
  hash(state) {
1317
1379
  state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
@@ -2266,11 +2328,15 @@ var TuneDims = class {
2266
2328
  };
2267
2329
  /** Tuning step that does not apply any optimization. */
2268
2330
  function tuneNullopt(kernel) {
2331
+ let exp = kernel.exp;
2269
2332
  const vars = {};
2270
2333
  vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
2271
- if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2334
+ if (kernel.reduction) {
2335
+ vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2336
+ if (exp.dtype !== kernel.reduction.dtype) exp = AluExp.cast(kernel.reduction.dtype, exp);
2337
+ }
2272
2338
  return {
2273
- exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
2339
+ exp: exp.substitute(vars).rewriteGlobalViews().simplify(),
2274
2340
  epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
2275
2341
  outputIdxExp: vars.gidx,
2276
2342
  threadCount: kernel.size,
@@ -2279,8 +2345,9 @@ function tuneNullopt(kernel) {
2279
2345
  }
2280
2346
  /** Tuning for WebGPU kernels. */
2281
2347
  function tuneWebgpu(kernel) {
2282
- const { exp, reduction } = kernel;
2348
+ const reduction = kernel.reduction;
2283
2349
  if (!reduction) return tuneNullopt(kernel);
2350
+ const exp = AluExp.cast(reduction.dtype, kernel.exp);
2284
2351
  const globalIndexes = exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex);
2285
2352
  if (globalIndexes.length > 0) {
2286
2353
  if (DEBUG >= 4) console.info("Tuning: Found GlobalIndex ops, skipping opt.");
@@ -2508,6 +2575,85 @@ var CpuBackend = class {
2508
2575
  }
2509
2576
  };
2510
2577
 
2578
+ //#endregion
2579
+ //#region src/tracing.ts
2580
/** Module-wide switch read by `isTracing()`. */
let traceEnabled = false;
/** Callbacks registered through `onFlushTrace`, run by `stopTrace()`. */
const flushCallbacks = [];
/**
 * Turn kernel trace collection on.
 *
 * Collected traces show up in the "Performance" tab of the developer tools
 * and help with measuring fine-grained kernel execution time.
 */
function startTrace() {
  traceEnabled = true;
}
/**
 * Turn kernel trace collection off and run every registered flush callback,
 * in registration order.
 */
function stopTrace() {
  traceEnabled = false;
  for (let i = 0; i < flushCallbacks.length; i++) {
    const flush = flushCallbacks[i];
    flush();
  }
}
/** Report whether trace collection is currently switched on. */
function isTracing() {
  return traceEnabled;
}
/** Add a callback that `stopTrace()` invokes to flush pending trace data. */
function onFlushTrace(cb) {
  flushCallbacks.push(cb);
}
2609
/**
 * Format a count compactly: three significant digits plus a `K` (1e3),
 * `M` (1e6) or `B` (1e9) suffix. Values below 1000 are printed unchanged.
 *
 * `toPrecision(3)` falls back to exponential notation once the rounded value
 * needs four integer digits (for example `999.999` becomes `"1.00e+3"`), so
 * those cases are handled explicitly instead of leaking `"1.00e+3K"`.
 *
 * @param n The number to format.
 * @returns The formatted string, e.g. `"1.50K"`, `"123K"`, `"1.00M"`.
 */
function humanSize(n) {
  const units = [
    [1e9, "B"],
    [1e6, "M"],
    [1e3, "K"]
  ];
  for (let i = 0; i < units.length; i++) {
    const [scale, suffix] = units[i];
    if (!(n >= scale)) continue;
    const text = (n / scale).toPrecision(3);
    if (!text.includes("e")) return `${text}${suffix}`;
    // Below the largest unit, exponential output can only mean the value
    // rounded up to 1000 of this unit, i.e. exactly 1.00 of the next one.
    if (i > 0) return `1.00${units[i - 1][1]}`;
    // Beyond the largest unit: keep three significant digits, plain notation.
    return `${Number(text)}${suffix}`;
  }
  return `${n}`;
}
2615
/**
 * Describe a trace entry for a dispatched source.
 *
 * For a `Kernel` the label is `Kernel[<size>]` and the properties list its
 * expression, size and argument count (plus the reduction op and size when
 * one is present). Any other source is treated as a routine and described by
 * its name, input/output shapes and input dtypes.
 *
 * @returns `{ label, color, properties }` where `properties` is a list of
 *   `[name, value]` string pairs.
 */
function traceSourceInfo(source) {
  if (!(source instanceof Kernel)) {
    const formatShapes = (shapes) => shapes.map((s) => `[${s}]`).join(", ");
    return {
      label: source.name,
      color: "tertiary",
      properties: [
        ["inputShapes", formatShapes(source.type.inputShapes)],
        ["outputShapes", formatShapes(source.type.outputShapes)],
        ["dtype", source.type.inputDtypes.join(", ")]
      ]
    };
  }
  const label = `Kernel[${humanSize(source.size)}]`;
  const properties = [
    ["exp", `${source.exp}`],
    ["size", `${source.size}`],
    ["nargs", `${source.nargs}`]
  ];
  let color = "primary";
  if (source.reduction) {
    color = "secondary";
    properties.push(["reduction", `${source.reduction.op}:${source.reduction.size}`]);
  }
  return {
    label,
    color,
    properties
  };
}
2643
/**
 * Record one trace entry through `performance.measure`.
 *
 * The entry is named after `info.label`, spans `start`..`end`, and carries a
 * `detail.devtools` payload placing it on `track` inside the "JAX Profiler"
 * track group with the color and properties from `info`.
 */
function emitTrace(track, info, start, end) {
  const devtools = {
    trackGroup: "JAX Profiler",
    track,
    color: info.color,
    properties: info.properties
  };
  const options = {
    detail: { devtools },
    start,
    end
  };
  performance.measure(info.label, options);
}
2656
+
2511
2657
  //#endregion
2512
2658
  //#region src/backend/wasm/allocator.ts
2513
2659
  /** Simple tensor memory allocator for WebAssembly linear memory. */
@@ -3070,6 +3216,147 @@ function wasm_threefry2x32(cg) {
3070
3216
  });
3071
3217
  }
3072
3218
 
3219
+ //#endregion
3220
+ //#region src/backend/wasm/parallel.ts
3221
/**
 * Report whether both globals needed for multi-threaded execution exist:
 * `SharedArrayBuffer` and `Worker`.
 */
function hasSharedArrayBuffer() {
  if (typeof SharedArrayBuffer === "undefined") return false;
  return typeof Worker !== "undefined";
}
3225
// Granularity used when splitting a dispatch across workers: the pool uses at
// most ceil(size / MIN_ELEMS_PER_THREAD) workers for a kernel of `size` elements.
+ const MIN_ELEMS_PER_THREAD = 256;
3226
// Script text that the worker pool turns into a Blob and runs in each Worker.
// Message protocol implemented by the script:
//   - { type: "init", memory }: remember the memory object, reply { type: "ready" }.
//   - anything else is a job { module, ptrs, begin, end }: instantiate `module`
//     against the stored memory (the instance is reused while the same module
//     object keeps arriving), call its exported `kernel(...ptrs, begin, end)`,
//     then reply { type: "done", ok } with `error` set to String(err) on failure.
+ const WORKER_SOURCE = `
3227
+ let memory = null;
3228
+ let cachedModule = null;
3229
+ let cachedFunc = null;
3230
+
3231
+ self.onmessage = (e) => {
3232
+ const msg = e.data;
3233
+ if (msg.type === "init") {
3234
+ memory = msg.memory;
3235
+ postMessage({ type: "ready" });
3236
+ return;
3237
+ }
3238
+ try {
3239
+ const { module, ptrs, begin, end } = msg;
3240
+ if (module !== cachedModule) {
3241
+ cachedModule = module;
3242
+ const instance = new WebAssembly.Instance(module, { env: { memory } });
3243
+ cachedFunc = instance.exports.kernel;
3244
+ }
3245
+ cachedFunc(...ptrs, begin, end);
3246
+ postMessage({ type: "done", ok: true });
3247
+ } catch (err) {
3248
+ postMessage({ type: "done", ok: false, error: String(err) });
3249
+ }
3250
+ };
3251
+ `;
3252
/**
 * Pool of Web Workers for parallel WASM kernel dispatch.
 *
 * Workers are created lazily on the first `dispatch()`. Dispatches run one at
 * a time in submission order; each one is identified by an epoch number, and
 * `waitForEpoch()` resolves once that dispatch has finished (successfully or
 * not). Failures are reported through `console.error` rather than dropped.
 */
var WasmWorkerPool = class {
  #memory;
  #numWorkers;
  #workers = [];
  #ready = Promise.resolve();
  /** Serializes dispatches so concurrent read() calls don't clobber onmessage. */
  #queue = Promise.resolve();
  /** Number of dispatches that have finished. */
  #epoch = 0n;
  /** Number of dispatches that have been submitted. */
  #epochEnd = 0n;
  /** Resolvers waiting for a given epoch, keyed by that epoch. */
  #hooks = /* @__PURE__ */ new Map();
  /**
   * @param memory Memory object forwarded to every worker in its init message.
   * @param numWorkers Number of workers to create; must be positive.
   * @throws {Error} If `numWorkers` is not positive.
   */
  constructor(memory, numWorkers) {
    if (numWorkers <= 0) throw new Error("numWorkers must be positive");
    this.#memory = memory;
    this.#numWorkers = numWorkers;
  }
  /** Epoch of the most recently finished dispatch (0n before any). */
  get epoch() {
    return this.#epoch;
  }
  /** Resolve once the dispatch with epoch `target` has finished. */
  waitForEpoch(target) {
    if (target <= this.#epoch) return Promise.resolve();
    return new Promise((resolve) => {
      const hooks = this.#hooks.get(target);
      if (hooks) hooks.push(resolve);
      else this.#hooks.set(target, [resolve]);
    });
  }
  /** Create the workers and send each its init message (first call only). */
  #ensureInit() {
    if (this.#workers.length > 0) return;
    const blob = new Blob([WORKER_SOURCE], { type: "application/javascript" });
    const url = URL.createObjectURL(blob);
    this.#workers = [];
    const readyPromises = [];
    for (let i = 0; i < this.#numWorkers; i++) {
      const worker = new Worker(url, { type: "module" });
      this.#workers.push(worker);
      readyPromises.push(new Promise((resolve, reject) => {
        worker.onmessage = () => resolve();
        worker.onerror = (e) => reject(new Error(e.message || "Worker failed to load"));
      }));
      worker.postMessage({
        type: "init",
        memory: this.#memory
      });
    }
    // Release the blob URL whether or not every worker came up.
    this.#ready = Promise.all(readyPromises).finally(() => {
      URL.revokeObjectURL(url);
    });
    this.#queue = this.#ready;
  }
  /**
   * Dispatch a kernel across multiple workers.
   *
   * Returns an epoch that can be used to wait for the ongoing work to complete,
   * which is guaranteed to be monotonically increasing.
   */
  dispatch(module, ptrs, size) {
    this.#ensureInit();
    this.#epochEnd++;
    const result = this.#queue.then(() => this.#dispatchNow(module, ptrs, size));
    this.#queue = result.catch((err) => {
      // The queue must keep advancing, but a failed kernel leaves its outputs
      // unwritten, so make the failure visible.
      console.error("WasmWorkerPool: kernel dispatch failed", err);
    }).then(() => {
      this.#epoch++;
      const hooks = this.#hooks.get(this.#epoch);
      if (hooks) {
        for (const hook of hooks) hook();
        this.#hooks.delete(this.#epoch);
      }
    });
    return this.#epochEnd;
  }
  /**
   * Split `[0, size)` into 16-element-aligned chunks, hand one chunk to each
   * participating worker, and wait until every one of them has reported back.
   */
  async #dispatchNow(module, ptrs, size) {
    if (size === 0) return;
    const n = Math.min(this.#workers.length, Math.ceil(size / MIN_ELEMS_PER_THREAD));
    const chunkSize = Math.ceil(size / n / 16) * 16;
    const promises = [];
    for (let i = 0; i < n; i++) {
      const begin = i * chunkSize;
      const end = Math.min(begin + chunkSize, size);
      if (begin >= size) break;
      const worker = this.#workers[i];
      promises.push(new Promise((resolve, reject) => {
        worker.onmessage = (e) => {
          if (e.data.ok) resolve();
          else reject(/* @__PURE__ */ new Error(`Worker error: ${e.data.error}`));
        };
        // Without this an `error` event would leave the promise pending
        // forever and stall every later dispatch.
        worker.onerror = (e) => {
          reject(/* @__PURE__ */ new Error(`Worker error: ${e.message || "unknown error"}`));
        };
        worker.postMessage({
          module,
          ptrs,
          begin,
          end
        });
      }));
    }
    // Wait for every worker, even after a failure, so the next dispatch never
    // replaces the handlers of a worker that is still busy with this one.
    const outcomes = await Promise.allSettled(promises);
    const failed = outcomes.find((outcome) => outcome.status === "rejected");
    if (failed) throw failed.reason;
  }
};
3349
/**
 * Build a `WasmWorkerPool` for `memory`, or return null when
 * `hasSharedArrayBuffer()` is false or construction throws.
 *
 * The pool size is `navigator.hardwareConcurrency` when that is available
 * and truthy, otherwise 4, and never less than 1.
 */
function createWorkerPool(memory) {
  if (!hasSharedArrayBuffer()) return null;
  try {
    const reported = typeof navigator !== "undefined" && navigator.hardwareConcurrency;
    const numWorkers = Math.max(1, reported || 4);
    return new WasmWorkerPool(memory, numWorkers);
  } catch {
    return null;
  }
}
3359
+
3073
3360
  //#endregion
3074
3361
  //#region src/backend/wasm/wasmblr.ts
3075
3362
  /**
@@ -3407,7 +3694,7 @@ var CodeGenerator = class {
3407
3694
  concat(importSectionBytes, encodeString(this.memory.aString));
3408
3695
  concat(importSectionBytes, encodeString(this.memory.bString));
3409
3696
  importSectionBytes.push(2);
3410
- if (this.memory.min && this.memory.max) {
3697
+ if (this.memory.max) {
3411
3698
  if (this.memory.isShared) importSectionBytes.push(3);
3412
3699
  else importSectionBytes.push(1);
3413
3700
  concat(importSectionBytes, encodeUnsigned(this.memory.min));
@@ -3814,6 +4101,8 @@ var I32x4 = class extends V128 {
3814
4101
  min_u = VECTOR_OP("min_u", 183, ["v128", "v128"], "v128");
3815
4102
  max_s = VECTOR_OP("max_s", 184, ["v128", "v128"], "v128");
3816
4103
  max_u = VECTOR_OP("max_u", 185, ["v128", "v128"], "v128");
4104
+ trunc_sat_f32x4_s = VECTOR_OP("trunc_sat_f32x4_s", 248, ["v128"], "v128");
4105
+ trunc_sat_f32x4_u = VECTOR_OP("trunc_sat_f32x4_u", 249, ["v128"], "v128");
3817
4106
  };
3818
4107
  var F32x4 = class extends V128 {
3819
4108
  splat = VECTOR_OP("splat", 19, ["f32"], "v128");
@@ -3840,10 +4129,333 @@ var F32x4 = class extends V128 {
3840
4129
  max = VECTOR_OP("max", 233, ["v128", "v128"], "v128");
3841
4130
  pmin = VECTOR_OP("pmin", 234, ["v128", "v128"], "v128");
3842
4131
  pmax = VECTOR_OP("pmax", 235, ["v128", "v128"], "v128");
4132
+ convert_i32x4_s = VECTOR_OP("convert_i32x4_s", 250, ["v128"], "v128");
4133
+ convert_i32x4_u = VECTOR_OP("convert_i32x4_u", 251, ["v128"], "v128");
3843
4134
  };
3844
4135
 
3845
4136
  //#endregion
3846
4137
  //#region src/backend/wasm.ts
4138
/**
 * SIMD counterpart of `translateExp`: emits v128 (f32x4 / i32x4) instructions
 * that leave the value of `exp` on the stack as a four-lane vector.
 *
 * gidx always steps by 4. `strideMap` classifies each GlobalIndex node as
 * broadcast / contiguous / gather; nodes missing from the map are gathered.
 * Index expressions inside a GlobalIndex are still emitted as scalar code via
 * `translateExp`. Nodes referenced more than once are computed a single time
 * and kept in a local.
 */
function translateExpSimd(cg, funcs, exp, ctx, strideMap) {
  const isIntegral = (dt) => dt === DType.Int32 || dt === DType.Uint32 || dt === DType.Bool;

  // How often each node is reached from its parents (children of a node are
  // only walked the first time that node is seen).
  const useCounts = /* @__PURE__ */ new Map();
  const expanded = /* @__PURE__ */ new Set();
  const countReferences = (node) => {
    useCounts.set(node, (useCounts.get(node) ?? 0) + 1);
    if (expanded.has(node)) return;
    expanded.add(node);
    for (const child of node.src) countReferences(child);
  };

  // Locals holding the already-computed value of shared nodes.
  const savedLocals = /* @__PURE__ */ new Map();

  /** With an i32 index on the stack: keep it if it is below `len` (unsigned), else use 0. */
  const emitClampToZero = (scratch, len) => {
    cg.local.tee(scratch);
    cg.i32.const(0);
    cg.local.get(scratch);
    cg.i32.const(len);
    cg.i32.lt_u();
    cg.select();
  };

  /** With an element index on the stack: scale by the element width and add local `gid`. */
  const emitByteAddress = (gid, dtype) => {
    cg.i32.const(byteWidth(dtype));
    cg.i32.mul();
    cg.local.get(gid);
    cg.i32.add();
  };

  const genGlobalIndex = (node, isInt) => {
    const [gid, len] = node.arg;
    const indexExp = node.src[0];
    const stride = strideMap.get(node) ?? GATHER;

    if (stride.kind === "contiguous") {
      // One 16-byte load; the starting index is capped at max(len - 4, 0).
      translateExp(cg, funcs, indexExp, ctx);
      const lastStart = Math.max(len - SIMD_LANES, 0);
      const firstIdx = cg.local.declare(cg.i32);
      cg.local.set(firstIdx);
      cg.local.get(firstIdx);
      cg.i32.const(lastStart);
      cg.local.get(firstIdx);
      cg.i32.const(lastStart);
      cg.i32.lt_u();
      cg.select();
      emitByteAddress(gid, node.dtype);
      if (isInt) cg.i32x4.load(4);
      else cg.f32x4.load(4);
      return;
    }

    if (stride.kind === "broadcast") {
      // One scalar load, replicated into all four lanes.
      translateExp(cg, funcs, indexExp, ctx);
      const scratch = cg.local.declare(cg.i32);
      emitClampToZero(scratch, len);
      emitByteAddress(gid, node.dtype);
      if (isInt) {
        cg.i32.load(2);
        cg.i32x4.splat();
      } else {
        cg.f32.load(2);
        cg.f32x4.splat();
      }
      return;
    }

    // Gather: evaluate the scalar index once per lane with the gidx local
    // temporarily set to gidx + lane, then restore the original gidx.
    const gidxLocal = ctx["gidx"];
    const savedGidx = cg.local.declare(cg.i32);
    cg.local.get(gidxLocal);
    cg.local.set(savedGidx);
    if (isInt) {
      cg.i32.const(0);
      cg.i32x4.splat();
    } else {
      cg.f32.const(0);
      cg.f32x4.splat();
    }
    const vec = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
    cg.local.set(vec);
    const idx = cg.local.declare(cg.i32);
    const laneValue = cg.local.declare(isInt ? cg.i32 : cg.f32);
    for (let lane = 0; lane < SIMD_LANES; lane++) {
      cg.local.get(savedGidx);
      if (lane > 0) {
        cg.i32.const(lane);
        cg.i32.add();
      }
      cg.local.set(gidxLocal);
      translateExp(cg, funcs, indexExp, ctx);
      emitClampToZero(idx, len);
      emitByteAddress(gid, node.dtype);
      if (isInt) cg.i32.load(2);
      else cg.f32.load(2);
      cg.local.set(laneValue);
      cg.local.get(vec);
      cg.local.get(laneValue);
      if (isInt) cg.i32x4.replace_lane(lane);
      else cg.f32x4.replace_lane(lane);
      cg.local.set(vec);
    }
    cg.local.get(savedGidx);
    cg.local.set(gidxLocal);
    cg.local.get(vec);
  };

  /** With a comparison mask on the stack: reduce every lane to 0 or 1. */
  const emitMaskToBool = () => {
    cg.i32.const(1);
    cg.i32x4.splat();
    cg.v128.and();
  };

  const gen = (node) => {
    if (savedLocals.has(node)) return cg.local.get(savedLocals.get(node));
    const { op, src, arg, dtype } = node;
    const isInt = isIntegral(dtype);
    const isSigned = dtype === DType.Int32;

    switch (op) {
      case AluOp.Add:
        gen(src[0]);
        gen(src[1]);
        if (isInt) cg.i32x4.add();
        else cg.f32x4.add();
        break;
      case AluOp.Sub:
        gen(src[0]);
        gen(src[1]);
        if (isInt) cg.i32x4.sub();
        else cg.f32x4.sub();
        break;
      case AluOp.Mul:
        gen(src[0]);
        gen(src[1]);
        if (isInt) cg.i32x4.mul();
        else cg.f32x4.mul();
        break;
      case AluOp.Min:
        gen(src[0]);
        gen(src[1]);
        if (!isInt) cg.f32x4.min();
        else if (isSigned) cg.i32x4.min_s();
        else cg.i32x4.min_u();
        break;
      case AluOp.Max:
        gen(src[0]);
        gen(src[1]);
        if (!isInt) cg.f32x4.max();
        else if (isSigned) cg.i32x4.max_s();
        else cg.i32x4.max_u();
        break;
      case AluOp.Sqrt:
        gen(src[0]);
        cg.f32x4.sqrt();
        break;
      case AluOp.Floor:
        gen(src[0]);
        cg.f32x4.floor();
        break;
      case AluOp.Ceil:
        gen(src[0]);
        cg.f32x4.ceil();
        break;
      case AluOp.Const:
        if (isInt) {
          cg.i32.const(arg);
          cg.i32x4.splat();
        } else {
          cg.f32.const(arg);
          cg.f32x4.splat();
        }
        break;
      case AluOp.Cast: {
        gen(src[0]);
        const fromDtype = src[0].dtype;
        const fromInt = isIntegral(fromDtype);
        if (isInt && !fromInt) {
          if (isSigned) cg.i32x4.trunc_sat_f32x4_s();
          else cg.i32x4.trunc_sat_f32x4_u();
        } else if (!isInt && fromInt) {
          if (fromDtype === DType.Int32 || fromDtype === DType.Bool) cg.f32x4.convert_i32x4_s();
          else cg.f32x4.convert_i32x4_u();
        }
        // Same-class casts emit no instruction.
        break;
      }
      case AluOp.Cmplt: {
        gen(src[0]);
        gen(src[1]);
        const operandDtype = src[0].dtype;
        if (operandDtype === DType.Float32) cg.f32x4.lt();
        else if (operandDtype === DType.Int32) cg.i32x4.lt_s();
        else if (operandDtype === DType.Uint32) cg.i32x4.lt_u();
        else throw new UnsupportedOpError(op, dtype, "wasm");
        emitMaskToBool();
        break;
      }
      case AluOp.Cmpne: {
        gen(src[0]);
        gen(src[1]);
        if (src[0].dtype === DType.Float32) cg.f32x4.ne();
        else cg.i32x4.ne();
        emitMaskToBool();
        break;
      }
      case AluOp.Where:
        // Both branches first, then the condition turned into a lane mask
        // (condition != 0) that drives the bitselect.
        gen(src[1]);
        gen(src[2]);
        gen(src[0]);
        cg.i32.const(0);
        cg.i32x4.splat();
        cg.i32x4.ne();
        cg.v128.bitselect();
        break;
      case AluOp.Variable:
      case AluOp.Special:
        throw new Error(`translateExpSimd: unexpected ${op}(${arg})`);
      case AluOp.GlobalIndex:
        genGlobalIndex(node, isInt);
        break;
      default:
        throw new Error(`translateExpSimd: unsupported op ${op}`);
    }

    if ((useCounts.get(node) ?? 0) > 1) {
      const saved = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
      cg.local.tee(saved);
      savedLocals.set(node, saved);
    }
  };

  countReferences(exp);
  gen(exp);
}
4334
+ /** Number of SIMD lanes (f32x4 / i32x4 = 4 lanes). */
4335
+ const SIMD_LANES = 4;
4336
/** True when `exp`, or any node below it, is the `gidx` special. */
function referencesGidx(exp) {
  if (exp.op === AluOp.Special && exp.arg[0] === "gidx") return true;
  for (const child of exp.src) {
    if (referencesGidx(child)) return true;
  }
  return false;
}
4340
/**
 * When tileSize > N but doesn't divide evenly, the last group before the
 * inner reset is shorter than N, so a SIMD group could straddle it.
 * An infinite tile size never carries that risk.
 */
function hasFragmentRisk(tileSize, N) {
  if (!isFinite(tileSize)) return false;
  if (!(tileSize > N)) return false;
  return tileSize % N !== 0;
}
4345
/** Shared result object for index expressions with no recognized pattern. */
const GATHER = { kind: "gather" };
/**
 * Classify how a GlobalIndex's index expression behaves as gidx increments.
 *
 * Returns `{ kind: "broadcast" | "contiguous", tileSize }` for recognized
 * patterns and the shared `GATHER` object otherwise:
 * - no gidx below `exp`: broadcast with infinite tile size
 * - gidx itself: contiguous with infinite tile size
 * - `x / N` or `x % N` with constant N: a broadcast `x` is returned as is; a
 *   contiguous `x` becomes broadcast (division) or stays contiguous (modulo)
 *   with tile size `min(tile(x), N)`, unless `hasFragmentRisk` applies
 * - `x * const`: the result for `x` if that is broadcast, otherwise gather
 * - `a + b` where only one side references gidx: the result for that side
 */
function analyzeStride(exp) {
  if (!referencesGidx(exp)) return {
    kind: "broadcast",
    tileSize: Infinity
  };
  const { op, src } = exp;
  if (op === AluOp.Special && exp.arg[0] === "gidx") return {
    kind: "contiguous",
    tileSize: Infinity
  };

  const isDivisionLike = op === AluOp.Idiv || op === AluOp.Mod;
  if (isDivisionLike && src[1].op === AluOp.Const) {
    const N = src[1].arg;
    const inner = analyzeStride(src[0]);
    if (inner.kind === "broadcast") return inner;
    if (inner.kind !== "contiguous" || hasFragmentRisk(inner.tileSize, N)) return GATHER;
    return {
      kind: op === AluOp.Idiv ? "broadcast" : "contiguous",
      tileSize: Math.min(inner.tileSize, N)
    };
  }

  if (op === AluOp.Mul) {
    let constAt = -1;
    if (src[0].op === AluOp.Const) constAt = 0;
    else if (src[1].op === AluOp.Const) constAt = 1;
    if (constAt !== -1) {
      const inner = analyzeStride(src[1 - constAt]);
      return inner.kind === "broadcast" ? inner : GATHER;
    }
  }

  if (op === AluOp.Add) {
    const lhsHasGidx = referencesGidx(src[0]);
    const rhsHasGidx = referencesGidx(src[1]);
    if (lhsHasGidx !== rhsHasGidx) return analyzeStride(lhsHasGidx ? src[0] : src[1]);
  }

  return GATHER;
}
4395
+ /** Ops that have direct SIMD (f32x4) instruction variants. */
4396
+ const simdF32Ops = new Set([
4397
+ AluOp.Add,
4398
+ AluOp.Sub,
4399
+ AluOp.Mul,
4400
+ AluOp.Floor,
4401
+ AluOp.Ceil,
4402
+ AluOp.Min,
4403
+ AluOp.Max,
4404
+ AluOp.Sqrt,
4405
+ AluOp.Cast,
4406
+ AluOp.Where,
4407
+ AluOp.Const,
4408
+ AluOp.GlobalIndex
4409
+ ]);
4410
+ /** Ops that have direct SIMD (i32x4) instruction variants. */
4411
+ const simdI32Ops = new Set([
4412
+ AluOp.Add,
4413
+ AluOp.Sub,
4414
+ AluOp.Mul,
4415
+ AluOp.Min,
4416
+ AluOp.Max,
4417
+ AluOp.Cast,
4418
+ AluOp.Where,
4419
+ AluOp.Const,
4420
+ AluOp.GlobalIndex
4421
+ ]);
4422
+ /** Ops that produce Bool (i32x4 bitmask) in SIMD. */
4423
+ const simdBoolOps = new Set([
4424
+ AluOp.Cmplt,
4425
+ AluOp.Cmpne,
4426
+ AluOp.Const,
4427
+ AluOp.GlobalIndex
4428
+ ]);
4429
/**
 * Decide whether a kernel can use the SIMD code path.
 *
 * All of the following must hold:
 * - the kernel has at least `SIMD_LANES` elements
 * - a reduction, if present, uses an op with a SIMD variant for its dtype
 * - every node reachable from `tunedExp` has a dtype with a SIMD op table
 *   that contains the node's op (nodes below a GlobalIndex are not inspected)
 */
function isSimdEligible(tunedExp, kernel) {
  if (kernel.size < SIMD_LANES) return false;
  const reduction = kernel.reduction;
  if (reduction && !simdSupportedOpsForDtype(reduction.dtype)?.has(reduction.op)) return false;

  const visited = /* @__PURE__ */ new Set();
  const pending = [tunedExp];
  while (pending.length > 0) {
    const node = pending.pop();
    if (visited.has(node)) continue;
    visited.add(node);
    const supportedOps = simdSupportedOpsForDtype(node.dtype);
    if (!supportedOps || !supportedOps.has(node.op)) return false;
    if (node.op !== AluOp.GlobalIndex) pending.push(...node.src);
  }
  return true;
}
4453
/**
 * Look up the set of ops with a SIMD variant for `dtype`.
 * Returns null for dtypes other than Float32, Int32, Uint32 and Bool.
 */
function simdSupportedOpsForDtype(dtype) {
  switch (dtype) {
    case DType.Float32:
      return simdF32Ops;
    case DType.Int32:
    case DType.Uint32:
      return simdI32Ops;
    case DType.Bool:
      return simdBoolOps;
    default:
      return null;
  }
}
3847
4459
  const moduleCache = /* @__PURE__ */ new Map();
3848
4460
  /** Backend that compiles into WebAssembly bytecode for immediate execution. */
3849
4461
  var WasmBackend = class {
@@ -3853,11 +4465,18 @@ var WasmBackend = class {
3853
4465
  #nextSlot;
3854
4466
  #allocator;
3855
4467
  #buffers;
4468
+ #workerPool;
4469
+ #pendingWork = /* @__PURE__ */ new Map();
3856
4470
  constructor() {
3857
- this.#memory = new WebAssembly.Memory({ initial: 0 });
4471
+ this.#memory = hasSharedArrayBuffer() ? new WebAssembly.Memory({
4472
+ initial: 0,
4473
+ maximum: 65536,
4474
+ shared: true
4475
+ }) : new WebAssembly.Memory({ initial: 0 });
3858
4476
  this.#allocator = new WasmAllocator(this.#memory);
3859
4477
  this.#nextSlot = 1;
3860
4478
  this.#buffers = /* @__PURE__ */ new Map();
4479
+ this.#workerPool = createWorkerPool(this.#memory);
3861
4480
  }
3862
4481
  malloc(size, initialData) {
3863
4482
  const ptr = this.#allocator.malloc(size);
@@ -3888,37 +4507,70 @@ var WasmBackend = class {
3888
4507
  }
3889
4508
  }
3890
4509
  async read(slot, start, count) {
3891
- return this.readSync(slot, start, count);
4510
+ const epoch = this.#pendingWork.get(slot);
4511
+ if (epoch) await this.#workerPool.waitForEpoch(epoch);
4512
+ return this.#readData(slot, start, count);
3892
4513
  }
3893
4514
  readSync(slot, start, count) {
4515
+ const epoch = this.#pendingWork.get(slot);
4516
+ if (epoch && this.#workerPool.epoch < epoch) throw new Error("cannot read synchronously from a slot with async work");
4517
+ return this.#readData(slot, start, count);
4518
+ }
4519
+ #readData(slot, start, count) {
3894
4520
  const buffer = this.#getBuffer(slot);
3895
4521
  if (start === void 0) start = 0;
3896
4522
  if (count === void 0) count = buffer.byteLength - start;
3897
- return buffer.slice(start, start + count);
4523
+ if (buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
4524
+ else return buffer.slice(start, start + count);
3898
4525
  }
3899
4526
  async prepareKernel(kernel) {
3900
- return this.prepareKernelSync(kernel);
4527
+ const kernelHash = FpHash.hash(kernel);
4528
+ const module = await runWithCacheAsync(moduleCache, kernelHash.toString(), () => WebAssembly.compile(codegenWasm(kernel)));
4529
+ return new Executable(kernel, {
4530
+ module,
4531
+ parallel: this.#workerPool !== null
4532
+ });
3901
4533
  }
3902
4534
  prepareKernelSync(kernel) {
3903
4535
  const kernelHash = FpHash.hash(kernel);
3904
- const module = runWithCache(moduleCache, kernelHash.toString(), () => {
3905
- const bytes = codegenWasm(kernel);
3906
- return new WebAssembly.Module(bytes);
4536
+ const module = runWithCache(moduleCache, kernelHash.toString(), () => new WebAssembly.Module(codegenWasm(kernel)));
4537
+ return new Executable(kernel, {
4538
+ module,
4539
+ parallel: false
3907
4540
  });
3908
- return new Executable(kernel, { module });
3909
4541
  }
3910
4542
  async prepareRoutine(routine) {
3911
4543
  return this.prepareRoutineSync(routine);
3912
4544
  }
3913
4545
  prepareRoutineSync(routine) {
3914
- return new Executable(routine, void 0);
4546
+ return new Executable(routine, {
4547
+ module: void 0,
4548
+ parallel: false
4549
+ });
3915
4550
  }
3916
4551
  dispatch(exe, inputs, outputs) {
3917
- if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
3918
- const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
3919
- const func = instance.exports.kernel;
3920
- const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
3921
- func(...ptrs);
4552
+ const tracing = isTracing();
4553
+ const start = tracing ? performance.now() : 0;
4554
+ if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
4555
+ else {
4556
+ const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
4557
+ if (exe.data.parallel && this.#workerPool) {
4558
+ const epoch = this.#workerPool.dispatch(exe.data.module, ptrs, exe.source.size);
4559
+ for (const slot of outputs) this.#pendingWork.set(slot, epoch);
4560
+ } else {
4561
+ if (inputs.some((slot) => {
4562
+ const epoch = this.#pendingWork.get(slot);
4563
+ return epoch && this.#workerPool.epoch < epoch;
4564
+ })) throw new Error("cannot dispatch synchronously with pending async work");
4565
+ const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
4566
+ const func = instance.exports.kernel;
4567
+ func(...ptrs, 0, exe.source.size);
4568
+ }
4569
+ }
4570
+ if (tracing) {
4571
+ const info = traceSourceInfo(exe.source);
4572
+ emitTrace("wasm", info, start, performance.now());
4573
+ }
3922
4574
  }
3923
4575
  #getBuffer(slot) {
3924
4576
  const buffer = this.#buffers.get(slot);
@@ -3926,12 +4578,36 @@ var WasmBackend = class {
3926
4578
  return new Uint8Array(this.#memory.buffer, buffer.ptr, buffer.size);
3927
4579
  }
3928
4580
  };
4581
+ /** Emit a runtime guard: enter the if-block only when [begin, end) is SIMD-aligned. */
4582
+ function emitAlignmentGuard(cg, paramBegin, paramEnd) {
4583
+ const mask = SIMD_LANES - 1;
4584
+ cg.local.get(paramEnd);
4585
+ cg.local.get(paramBegin);
4586
+ cg.i32.sub();
4587
+ cg.i32.const(mask);
4588
+ cg.i32.and();
4589
+ cg.i32.eqz();
4590
+ cg.local.get(paramBegin);
4591
+ cg.i32.const(mask);
4592
+ cg.i32.and();
4593
+ cg.i32.eqz();
4594
+ cg.i32.and();
4595
+ cg.if(cg.void);
4596
+ }
3929
4597
  function codegenWasm(kernel) {
3930
4598
  const tune = tuneNullopt(kernel);
3931
4599
  const re = kernel.reduction;
3932
4600
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
4601
+ const useSimd = isSimdEligible(tune.exp, kernel);
4602
+ const bufferStrides = /* @__PURE__ */ new Map();
4603
+ if (useSimd) tune.exp.collect((e) => e.op === AluOp.GlobalIndex).forEach((gi) => {
4604
+ const result = analyzeStride(gi.src[0]);
4605
+ if (result.kind !== "gather" && (result.tileSize < SIMD_LANES || isFinite(result.tileSize) && result.tileSize % SIMD_LANES !== 0)) bufferStrides.set(gi, GATHER);
4606
+ else bufferStrides.set(gi, result);
4607
+ });
3933
4608
  const cg = new CodeGenerator();
3934
4609
  cg.memory.import("env", "memory");
4610
+ if (hasSharedArrayBuffer()) cg.memory.pages(0, 65536).shared(true);
3935
4611
  const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
3936
4612
  const funcs = {};
3937
4613
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
@@ -3943,12 +4619,127 @@ function codegenWasm(kernel) {
3943
4619
  if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
3944
4620
  if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
3945
4621
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
3946
- const kernelFunc = cg.function(rep(kernel.nargs + 1, cg.i32), [], () => {
4622
+ const paramBegin = kernel.nargs + 1;
4623
+ const paramEnd = kernel.nargs + 2;
4624
+ const kernelFunc = cg.function(rep(kernel.nargs + 3, cg.i32), [], () => {
3947
4625
  const gidx = cg.local.declare(cg.i32);
4626
+ cg.local.get(paramBegin);
4627
+ cg.local.set(gidx);
4628
+ if (useSimd) {
4629
+ emitAlignmentGuard(cg, paramBegin, paramEnd);
4630
+ cg.loop(cg.void);
4631
+ if (!re) {
4632
+ cg.block(cg.void);
4633
+ cg.local.get(gidx);
4634
+ cg.local.get(paramEnd);
4635
+ cg.i32.ge_u();
4636
+ cg.br_if(0);
4637
+ cg.local.get(kernel.nargs);
4638
+ cg.local.get(gidx);
4639
+ cg.i32.const(byteWidth(kernel.dtype));
4640
+ cg.i32.mul();
4641
+ cg.i32.add();
4642
+ translateExpSimd(cg, funcs, tune.exp, { gidx }, bufferStrides);
4643
+ cg.v128.store(4);
4644
+ cg.local.get(gidx);
4645
+ cg.i32.const(SIMD_LANES);
4646
+ cg.i32.add();
4647
+ cg.local.set(gidx);
4648
+ cg.br(1);
4649
+ cg.end();
4650
+ } else {
4651
+ const reIsInt = kernel.exp.dtype === DType.Int32 || kernel.exp.dtype === DType.Uint32;
4652
+ cg.block(cg.void);
4653
+ cg.local.get(gidx);
4654
+ cg.local.get(paramEnd);
4655
+ cg.i32.ge_u();
4656
+ cg.br_if(0);
4657
+ const vecAcc = cg.local.declare(reIsInt ? cg.i32x4 : cg.f32x4);
4658
+ if (reIsInt) {
4659
+ cg.i32.const(re.identity);
4660
+ cg.i32x4.splat();
4661
+ } else {
4662
+ cg.f32.const(re.identity);
4663
+ cg.f32x4.splat();
4664
+ }
4665
+ cg.local.set(vecAcc);
4666
+ const ridx = cg.local.declare(cg.i32);
4667
+ cg.i32.const(0);
4668
+ cg.local.set(ridx);
4669
+ cg.loop(cg.void);
4670
+ cg.block(cg.void);
4671
+ cg.local.get(ridx);
4672
+ cg.i32.const(re.size);
4673
+ cg.i32.ge_u();
4674
+ cg.br_if(0);
4675
+ translateExpSimd(cg, funcs, tune.exp, {
4676
+ gidx,
4677
+ ridx
4678
+ }, bufferStrides);
4679
+ cg.local.get(vecAcc);
4680
+ if (reIsInt) if (re.op === AluOp.Add) cg.i32x4.add();
4681
+ else if (re.op === AluOp.Mul) cg.i32x4.mul();
4682
+ else if (re.op === AluOp.Min) if (re.dtype === DType.Int32) cg.i32x4.min_s();
4683
+ else cg.i32x4.min_u();
4684
+ else if (re.op === AluOp.Max) if (re.dtype === DType.Int32) cg.i32x4.max_s();
4685
+ else cg.i32x4.max_u();
4686
+ else throw new Error(`invalid SIMD reduction op: ${re.op}`);
4687
+ else if (re.op === AluOp.Add) cg.f32x4.add();
4688
+ else if (re.op === AluOp.Mul) cg.f32x4.mul();
4689
+ else if (re.op === AluOp.Min) cg.f32x4.min();
4690
+ else if (re.op === AluOp.Max) cg.f32x4.max();
4691
+ else throw new Error(`invalid SIMD reduction op: ${re.op}`);
4692
+ cg.local.set(vecAcc);
4693
+ cg.local.get(ridx);
4694
+ cg.i32.const(1);
4695
+ cg.i32.add();
4696
+ cg.local.set(ridx);
4697
+ cg.br(1);
4698
+ cg.end();
4699
+ cg.end();
4700
+ for (let lane = 0; lane < SIMD_LANES; lane++) {
4701
+ cg.local.get(kernel.nargs);
4702
+ cg.local.get(gidx);
4703
+ if (lane > 0) {
4704
+ cg.i32.const(lane);
4705
+ cg.i32.add();
4706
+ }
4707
+ cg.i32.const(byteWidth(kernel.dtype));
4708
+ cg.i32.mul();
4709
+ cg.i32.add();
4710
+ const acc = cg.local.declare(reIsInt ? cg.i32 : cg.f32);
4711
+ cg.local.get(vecAcc);
4712
+ if (reIsInt) cg.i32x4.extract_lane(lane);
4713
+ else cg.f32x4.extract_lane(lane);
4714
+ cg.local.set(acc);
4715
+ const laneGidx = cg.local.declare(cg.i32);
4716
+ cg.local.get(gidx);
4717
+ if (lane > 0) {
4718
+ cg.i32.const(lane);
4719
+ cg.i32.add();
4720
+ }
4721
+ cg.local.set(laneGidx);
4722
+ translateExp(cg, funcs, tune.epilogue, {
4723
+ acc,
4724
+ gidx: laneGidx
4725
+ });
4726
+ dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
4727
+ }
4728
+ cg.local.get(gidx);
4729
+ cg.i32.const(SIMD_LANES);
4730
+ cg.i32.add();
4731
+ cg.local.set(gidx);
4732
+ cg.br(1);
4733
+ cg.end();
4734
+ }
4735
+ cg.end();
4736
+ cg.return();
4737
+ cg.end();
4738
+ }
3948
4739
  cg.loop(cg.void);
3949
4740
  cg.block(cg.void);
3950
4741
  cg.local.get(gidx);
3951
- cg.i32.const(kernel.size);
4742
+ cg.local.get(paramEnd);
3952
4743
  cg.i32.ge_u();
3953
4744
  cg.br_if(0);
3954
4745
  cg.local.get(kernel.nargs);
@@ -4087,6 +4878,11 @@ function translateExp(cg, funcs, exp, ctx) {
4087
4878
  else cg.i32.gt_u();
4088
4879
  cg.select();
4089
4880
  } else throw new UnsupportedOpError(op, dtype, "wasm");
4881
+ else if (op === AluOp.BitCombine) if (arg === "and") cg.i32.and();
4882
+ else if (arg === "or") cg.i32.or();
4883
+ else cg.i32.xor();
4884
+ else if (op === AluOp.BitShift) if (arg === "shl") cg.i32.shl();
4885
+ else cg.i32.shr_u();
4090
4886
  else if (op === AluOp.Cmplt) {
4091
4887
  const srcDtype = src[0].dtype;
4092
4888
  if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
@@ -4263,7 +5059,7 @@ async function createBackend(device) {
4263
5059
  if (!navigator.gpu) return null;
4264
5060
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4265
5061
  if (!adapter) return null;
4266
- const { WebGPUBackend } = await import("./webgpu-AN0cG_nB.js");
5062
+ const { WebGPUBackend } = await import("./webgpu-Dg8FpYrH.js");
4267
5063
  const importantLimits = [
4268
5064
  "maxBufferSize",
4269
5065
  "maxComputeInvocationsPerWorkgroup",
@@ -4301,7 +5097,7 @@ async function createBackend(device) {
4301
5097
  });
4302
5098
  if (!gl) return null;
4303
5099
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4304
- const { WebGLBackend } = await import("./webgl-DnGrclTz.js");
5100
+ const { WebGLBackend } = await import("./webgl-D8-14NzA.js");
4305
5101
  return new WebGLBackend(gl);
4306
5102
  } else throw new Error(`Backend not found: ${device}`);
4307
5103
  }
@@ -4335,6 +5131,15 @@ var UnsupportedRoutineError = class extends Error {
4335
5131
  super(`routine '${name}' is not supported in ${device} backend`);
4336
5132
  }
4337
5133
  };
5134
+ /**
5135
+ * If the WebGPU backend has been initialized, return the `GPUDevice` that this
5136
+ * backend runs on. This is useful for sharing buffers.
5137
+ */
5138
+ function getWebGPUDevice() {
5139
+ const backend = initializedBackends.get("webgpu");
5140
+ if (!backend) throw new Error("WebGPU backend not initialized, call init('webgpu') first");
5141
+ return backend.device;
5142
+ }
4338
5143
 
4339
5144
  //#endregion
4340
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
5145
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, emitTrace, findPow2, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, isTracing, mapSetUnion, normalizeAxis, onFlushTrace, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, strip1, toposort, traceSourceInfo, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };