@jax-js/jax 0.1.9 → 0.1.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +35 -19
- package/dist/{backend-BId79r5b.js → backend-DZvR7mZV.js} +831 -26
- package/dist/{backend-DpI0riom.cjs → backend-DlYlOYqN.cjs} +872 -25
- package/dist/index.cjs +364 -20
- package/dist/index.d.cts +175 -11
- package/dist/index.d.ts +175 -11
- package/dist/index.js +363 -21
- package/dist/{webgl-DnGrclTz.js → webgl-D8-14NzA.js} +7 -1
- package/dist/{webgl-C5NjXc1p.cjs → webgl-Ovaaa-Qx.cjs} +7 -1
- package/dist/{webgpu-AN0cG_nB.js → webgpu-Dg8FpYrH.js} +141 -6
- package/dist/{webgpu-CdjiJSa7.cjs → webgpu-uU9nnttc.cjs} +141 -6
- package/package.json +5 -16
|
@@ -314,6 +314,16 @@ function runWithCache(cache, key, thunk) {
|
|
|
314
314
|
return value;
|
|
315
315
|
}
|
|
316
316
|
}
|
|
317
|
+
/**
 * Async version of `runWithCache`.
 *
 * The key is serialized with `JSON.stringify` and looked up in `cache`. On a
 * miss, `thunk` is awaited, its result is stored under the serialized key and
 * returned. On a hit the stored value is returned and `thunk` is not called.
 *
 * @param cache Map-like object providing `has`, `get` and `set`.
 * @param key JSON-serializable cache key.
 * @param thunk Function producing the value (or a promise of it) on a miss.
 * @returns Promise of the cached or freshly computed value.
 */
async function runWithCacheAsync(cache, key, thunk) {
  const cacheKey = JSON.stringify(key);
  if (!cache.has(cacheKey)) {
    const computed = await thunk();
    cache.set(cacheKey, computed);
    return computed;
  }
  return cache.get(cacheKey);
}
|
|
317
327
|
|
|
318
328
|
//#endregion
|
|
319
329
|
//#region src/alu.ts
|
|
@@ -416,8 +426,23 @@ var AluExp = class AluExp {
|
|
|
416
426
|
this.src = src;
|
|
417
427
|
this.arg = arg;
|
|
418
428
|
if (AluGroup.RequiredFloat.has(op) && !isFloatDtype(dtype)) throw new TypeError(`Unsupported dtype for ${op}: ${dtype}`);
|
|
419
|
-
|
|
420
|
-
|
|
429
|
+
switch (op) {
|
|
430
|
+
case AluOp.Bitcast:
|
|
431
|
+
if (dtype === DType.Bool || src[0].dtype === DType.Bool || byteWidth(dtype) !== byteWidth(src[0].dtype)) throw new TypeError(`Bitcast from ${src[0].dtype} -> ${dtype}`);
|
|
432
|
+
break;
|
|
433
|
+
case AluOp.Threefry2x32:
|
|
434
|
+
if (dtype !== DType.Uint32 || src.some((x) => x.dtype !== DType.Uint32)) throw new TypeError("Threefry2x32 requires uint32 types");
|
|
435
|
+
break;
|
|
436
|
+
case AluOp.BitCombine:
|
|
437
|
+
if (src[0].dtype !== src[1].dtype || isFloatDtype(src[0].dtype)) throw new TypeError(`BitCombine[${arg}] requires matching integral dtype, got ${src[0].dtype} and ${src[1].dtype}`);
|
|
438
|
+
break;
|
|
439
|
+
case AluOp.BitShift:
|
|
440
|
+
if (src[0].dtype === DType.Bool || src[1].dtype === DType.Bool || isFloatDtype(src[0].dtype) || isFloatDtype(src[1].dtype)) throw new TypeError(`BitShift[${arg}] requires two integral, non-bool dtypes, got ${src[0].dtype} and ${src[1].dtype}`);
|
|
441
|
+
break;
|
|
442
|
+
case AluOp.BitInvert:
|
|
443
|
+
if (isFloatDtype(src[0].dtype)) throw new TypeError(`BitInvert requires an integral dtype, got ${src[0].dtype}`);
|
|
444
|
+
break;
|
|
445
|
+
}
|
|
421
446
|
}
|
|
422
447
|
static add(a, b) {
|
|
423
448
|
return new AluExp(AluOp.Add, a.dtype, [a, b]);
|
|
@@ -494,6 +519,12 @@ var AluExp = class AluExp {
|
|
|
494
519
|
c1
|
|
495
520
|
], mode);
|
|
496
521
|
}
|
|
522
|
+
static bitCombine(a, b, mode) {
|
|
523
|
+
return new AluExp(AluOp.BitCombine, a.dtype, [a, b], mode);
|
|
524
|
+
}
|
|
525
|
+
static bitShift(a, b, mode) {
|
|
526
|
+
return new AluExp(AluOp.BitShift, a.dtype, [a, b], mode);
|
|
527
|
+
}
|
|
497
528
|
static cmplt(a, b) {
|
|
498
529
|
return new AluExp(AluOp.Cmplt, DType.Bool, [a, b]);
|
|
499
530
|
}
|
|
@@ -966,6 +997,16 @@ var AluExp = class AluExp {
|
|
|
966
997
|
case AluOp.Mod: return x % y;
|
|
967
998
|
case AluOp.Min: return Math.min(x, y);
|
|
968
999
|
case AluOp.Max: return Math.max(x, y);
|
|
1000
|
+
case AluOp.BitCombine: {
|
|
1001
|
+
let r;
|
|
1002
|
+
if (this.arg === "and") r = x & y;
|
|
1003
|
+
else if (this.arg === "or") r = x | y;
|
|
1004
|
+
else r = x ^ y;
|
|
1005
|
+
return this.dtype === DType.Int32 ? r | 0 : r >>> 0;
|
|
1006
|
+
}
|
|
1007
|
+
case AluOp.BitShift:
|
|
1008
|
+
if (this.arg === "shl") return this.dtype === DType.Int32 ? x << y | 0 : x << y >>> 0;
|
|
1009
|
+
return x >>> y;
|
|
969
1010
|
case AluOp.Cmplt: return Number(x < y);
|
|
970
1011
|
case AluOp.Cmpne: return Number(x != y);
|
|
971
1012
|
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
@@ -1087,6 +1128,18 @@ var AluExp = class AluExp {
|
|
|
1087
1128
|
}
|
|
1088
1129
|
if (BIN_SYM[node.op]) return `(${parts[0]} ${BIN_SYM[node.op]} ${parts[1]})`;
|
|
1089
1130
|
if (CMP_SYM[node.op]) return `(${parts[0]} ${CMP_SYM[node.op]} ${parts[1]})`;
|
|
1131
|
+
if (node.op === AluOp.BitCombine) {
|
|
1132
|
+
const sym = {
|
|
1133
|
+
and: "&",
|
|
1134
|
+
or: "|",
|
|
1135
|
+
xor: "^"
|
|
1136
|
+
}[node.arg];
|
|
1137
|
+
return `(${parts[0]} ${sym} ${parts[1]})`;
|
|
1138
|
+
}
|
|
1139
|
+
if (node.op === AluOp.BitShift) {
|
|
1140
|
+
const sym = node.arg === "shl" ? "<<" : ">>";
|
|
1141
|
+
return `(${parts[0]} ${sym} ${parts[1]})`;
|
|
1142
|
+
}
|
|
1090
1143
|
if (UNARY_SYM[node.op]) return `${UNARY_SYM[node.op]}${parts[0]}`;
|
|
1091
1144
|
if (node.op === AluOp.Cast) return `Cast<${node.dtype}>(${strip1(parts[0])})`;
|
|
1092
1145
|
if (node.op === AluOp.Bitcast) return `Bitcast<${node.dtype}>(${strip1(parts[0])})`;
|
|
@@ -1179,6 +1232,9 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1179
1232
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1180
1233
|
AluOp$1["Cast"] = "Cast";
|
|
1181
1234
|
AluOp$1["Bitcast"] = "Bitcast";
|
|
1235
|
+
AluOp$1["BitCombine"] = "BitCombine";
|
|
1236
|
+
AluOp$1["BitInvert"] = "BitInvert";
|
|
1237
|
+
AluOp$1["BitShift"] = "BitShift";
|
|
1182
1238
|
AluOp$1["Cmplt"] = "Cmplt";
|
|
1183
1239
|
AluOp$1["Cmpne"] = "Cmpne";
|
|
1184
1240
|
AluOp$1["Where"] = "Where";
|
|
@@ -1198,7 +1254,9 @@ const AluGroup = {
|
|
|
1198
1254
|
AluOp.Idiv,
|
|
1199
1255
|
AluOp.Mod,
|
|
1200
1256
|
AluOp.Min,
|
|
1201
|
-
AluOp.Max
|
|
1257
|
+
AluOp.Max,
|
|
1258
|
+
AluOp.BitCombine,
|
|
1259
|
+
AluOp.BitShift
|
|
1202
1260
|
]),
|
|
1203
1261
|
Unary: new Set([
|
|
1204
1262
|
AluOp.Sin,
|
|
@@ -1313,6 +1371,10 @@ var Reduction = class {
|
|
|
1313
1371
|
this.epilogue = epilogue;
|
|
1314
1372
|
if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
|
|
1315
1373
|
this.epilogue = epilogue.simplify();
|
|
1374
|
+
if (this.dtype === DType.Float16 && this.op === AluOp.Add) {
|
|
1375
|
+
this.epilogue = this.epilogue.substitute({ acc: AluExp.cast(this.dtype, AluVar.acc(DType.Float32)) });
|
|
1376
|
+
this.dtype = DType.Float32;
|
|
1377
|
+
}
|
|
1316
1378
|
}
|
|
1317
1379
|
hash(state) {
|
|
1318
1380
|
state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
|
|
@@ -2267,11 +2329,15 @@ var TuneDims = class {
|
|
|
2267
2329
|
};
|
|
2268
2330
|
/** Tuning step that does not apply any optimization. */
|
|
2269
2331
|
function tuneNullopt(kernel) {
|
|
2332
|
+
let exp = kernel.exp;
|
|
2270
2333
|
const vars = {};
|
|
2271
2334
|
vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
|
|
2272
|
-
if (kernel.reduction)
|
|
2335
|
+
if (kernel.reduction) {
|
|
2336
|
+
vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
2337
|
+
if (exp.dtype !== kernel.reduction.dtype) exp = AluExp.cast(kernel.reduction.dtype, exp);
|
|
2338
|
+
}
|
|
2273
2339
|
return {
|
|
2274
|
-
exp:
|
|
2340
|
+
exp: exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
2275
2341
|
epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
|
|
2276
2342
|
outputIdxExp: vars.gidx,
|
|
2277
2343
|
threadCount: kernel.size,
|
|
@@ -2280,8 +2346,9 @@ function tuneNullopt(kernel) {
|
|
|
2280
2346
|
}
|
|
2281
2347
|
/** Tuning for WebGPU kernels. */
|
|
2282
2348
|
function tuneWebgpu(kernel) {
|
|
2283
|
-
const
|
|
2349
|
+
const reduction = kernel.reduction;
|
|
2284
2350
|
if (!reduction) return tuneNullopt(kernel);
|
|
2351
|
+
const exp = AluExp.cast(reduction.dtype, kernel.exp);
|
|
2285
2352
|
const globalIndexes = exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex);
|
|
2286
2353
|
if (globalIndexes.length > 0) {
|
|
2287
2354
|
if (DEBUG >= 4) console.info("Tuning: Found GlobalIndex ops, skipping opt.");
|
|
@@ -2509,6 +2576,85 @@ var CpuBackend = class {
|
|
|
2509
2576
|
}
|
|
2510
2577
|
};
|
|
2511
2578
|
|
|
2579
|
+
//#endregion
|
|
2580
|
+
//#region src/tracing.ts
|
|
2581
|
+
/** Whether kernel traces are currently being collected. */
let traceEnabled = false;
/** Callbacks run by `stopTrace()`, in registration order. */
const flushCallbacks = [];
/**
 * Begin collecting kernel traces.
 *
 * Traces show up in the developer tools "Performance" tab and are meant for
 * measuring fine-grained kernel execution time.
 */
function startTrace() {
  traceEnabled = true;
}
/**
 * Stop collecting kernel traces, then run every callback registered through
 * `onFlushTrace()` so that pending trace data can be flushed.
 */
function stopTrace() {
  traceEnabled = false;
  for (let i = 0; i < flushCallbacks.length; i++) {
    const flush = flushCallbacks[i];
    flush();
  }
}
/** Report whether trace collection is currently switched on. */
function isTracing() {
  return traceEnabled;
}
/** Register `cb` to be called (with no arguments) each time tracing stops. */
function onFlushTrace(cb) {
  flushCallbacks.push(cb);
}
|
|
2610
|
+
/**
 * Format a count with three significant digits and a K / M / B suffix.
 *
 * Values below 1000 (and anything that is not `>= 1000`, such as NaN or a
 * negative number) are printed unchanged.
 *
 * `toPrecision(3)` switches to exponent notation once the scaled value rounds
 * up to 1000 (for example 999999 / 1e3 gives "1.00e+3"). Such values are
 * promoted to the next unit instead ("1.00M"). For the largest unit, where no
 * promotion is possible, the exponent form is expanded back into plain digits
 * ("1500B").
 *
 * @param n Number to format.
 * @returns Short human-readable string such as "999", "1.50K" or "2.00B".
 */
function humanSize(n) {
  const units = [
    [1e3, "K"],
    [1e6, "M"],
    [1e9, "B"],
  ];
  // Pick the largest unit whose threshold `n` reaches.
  let tier = -1;
  for (let i = 0; i < units.length; i++) {
    if (n >= units[i][0]) tier = i;
  }
  if (tier < 0) return `${n}`;
  const [scale, suffix] = units[tier];
  const text = (n / scale).toPrecision(3);
  if (!text.includes("e")) return `${text}${suffix}`;
  // Below the top unit the scaled value is < 1000, so exponent notation can
  // only mean "rounded up to 1000", i.e. exactly 1.00 of the next unit.
  if (tier + 1 < units.length) return `1.00${units[tier + 1][1]}`;
  return `${Number(text)}${suffix}`;
}
|
|
2616
|
+
/**
 * Build the label, color and property list of a trace entry.
 *
 * A `Kernel` source is labelled `Kernel[<size>]` and lists its expression,
 * size and argument count. It is colored "primary", or "secondary" with an
 * extra "reduction" property when it has a reduction. Any other source is
 * labelled with its `name`, colored "tertiary", and lists the input shapes,
 * output shapes and input dtypes found on `source.type`.
 *
 * @param source A `Kernel`, or an object with `name` and `type`.
 * @returns `{ label, color, properties }`, where `properties` is an array of
 *   `[key, value]` string pairs.
 */
function traceSourceInfo(source) {
  if (source instanceof Kernel) {
    const label = `Kernel[${humanSize(source.size)}]`;
    const properties = [
      ["exp", `${source.exp}`],
      ["size", `${source.size}`],
      ["nargs", `${source.nargs}`],
    ];
    if (!source.reduction) {
      return { label, color: "primary", properties };
    }
    properties.push(["reduction", `${source.reduction.op}:${source.reduction.size}`]);
    return { label, color: "secondary", properties };
  }
  const formatShapes = (shapes) => shapes.map((s) => `[${s}]`).join(", ");
  const label = source.name;
  const properties = [
    ["inputShapes", formatShapes(source.type.inputShapes)],
    ["outputShapes", formatShapes(source.type.outputShapes)],
    ["dtype", source.type.inputDtypes.join(", ")],
  ];
  return { label, color: "tertiary", properties };
}
|
|
2644
|
+
/**
 * Emit a trace entry as a `performance.measure` carrying devtools metadata.
 *
 * @param track Track name, shown under the "JAX Profiler" track group.
 * @param info Object with `label`, `color` and `properties`, as produced by
 *   `traceSourceInfo`.
 * @param start Start of the measured span, passed through to `performance.measure`.
 * @param end End of the measured span, passed through to `performance.measure`.
 */
function emitTrace(track, info, start, end) {
  const { label, color, properties } = info;
  const devtools = {
    trackGroup: "JAX Profiler",
    track,
    color,
    properties,
  };
  performance.measure(label, {
    detail: { devtools },
    start,
    end,
  });
}
|
|
2657
|
+
|
|
2512
2658
|
//#endregion
|
|
2513
2659
|
//#region src/backend/wasm/allocator.ts
|
|
2514
2660
|
/** Simple tensor memory allocator for WebAssembly linear memory. */
|
|
@@ -3071,6 +3217,147 @@ function wasm_threefry2x32(cg) {
|
|
|
3071
3217
|
});
|
|
3072
3218
|
}
|
|
3073
3219
|
|
|
3220
|
+
//#endregion
|
|
3221
|
+
//#region src/backend/wasm/parallel.ts
|
|
3222
|
+
/**
 * Check whether both `SharedArrayBuffer` and `Worker` exist as globals.
 *
 * @returns `true` only when both globals are defined.
 */
function hasSharedArrayBuffer() {
  if (typeof SharedArrayBuffer === "undefined") return false;
  return typeof Worker !== "undefined";
}
|
|
3226
|
+
const MIN_ELEMS_PER_THREAD = 256;
|
|
3227
|
+
const WORKER_SOURCE = `
|
|
3228
|
+
let memory = null;
|
|
3229
|
+
let cachedModule = null;
|
|
3230
|
+
let cachedFunc = null;
|
|
3231
|
+
|
|
3232
|
+
self.onmessage = (e) => {
|
|
3233
|
+
const msg = e.data;
|
|
3234
|
+
if (msg.type === "init") {
|
|
3235
|
+
memory = msg.memory;
|
|
3236
|
+
postMessage({ type: "ready" });
|
|
3237
|
+
return;
|
|
3238
|
+
}
|
|
3239
|
+
try {
|
|
3240
|
+
const { module, ptrs, begin, end } = msg;
|
|
3241
|
+
if (module !== cachedModule) {
|
|
3242
|
+
cachedModule = module;
|
|
3243
|
+
const instance = new WebAssembly.Instance(module, { env: { memory } });
|
|
3244
|
+
cachedFunc = instance.exports.kernel;
|
|
3245
|
+
}
|
|
3246
|
+
cachedFunc(...ptrs, begin, end);
|
|
3247
|
+
postMessage({ type: "done", ok: true });
|
|
3248
|
+
} catch (err) {
|
|
3249
|
+
postMessage({ type: "done", ok: false, error: String(err) });
|
|
3250
|
+
}
|
|
3251
|
+
};
|
|
3252
|
+
`;
|
|
3253
|
+
/**
 * Pool of Web Workers for parallel WASM kernel dispatch.
 *
 * Workers are created lazily on the first `dispatch()`. Each one receives the
 * memory object in an "init" message and answers "ready". A dispatch splits
 * `[0, size)` into chunks and posts one chunk per worker. Dispatches run
 * strictly one after another; each completed dispatch (successful or not)
 * advances `epoch` by one and releases the matching `waitForEpoch()` callers.
 */
var WasmWorkerPool = class {
  /** Memory object posted to every worker in its "init" message. */
  #memory;
  /** Number of workers to create on first use. */
  #numWorkers;
  #workers = [];
  /** Settles once every worker has answered its "init" message (or failed). */
  #ready = Promise.resolve();
  /** Serializes dispatches so concurrent read() calls don't clobber onmessage. */
  #queue = Promise.resolve();
  /** Number of dispatches that have finished. */
  #epoch = 0n;
  /** Number of dispatches that have been submitted. */
  #epochEnd = 0n;
  /** Pending `waitForEpoch` resolvers, keyed by the epoch they wait for. */
  #hooks = /* @__PURE__ */ new Map();
  /**
   * @param memory Memory object to share with the workers.
   * @param numWorkers Number of workers; must be greater than zero.
   * @throws {Error} If `numWorkers` is not a positive number (NaN included).
   */
  constructor(memory, numWorkers) {
    if (!(numWorkers > 0)) throw new Error("numWorkers must be positive");
    this.#memory = memory;
    this.#numWorkers = numWorkers;
  }
  /** Epoch of the most recently finished dispatch (0n before any finishes). */
  get epoch() {
    return this.#epoch;
  }
  /**
   * Wait until the dispatch that returned `target` has finished.
   *
   * @param target Epoch returned by `dispatch()`.
   * @returns Promise resolved once `epoch >= target`.
   */
  waitForEpoch(target) {
    if (target <= this.#epoch) return Promise.resolve();
    return new Promise((resolve) => {
      const hooks = this.#hooks.get(target);
      if (hooks) hooks.push(resolve);
      else this.#hooks.set(target, [resolve]);
    });
  }
  /** Create the workers on first use and start their "init" handshake. */
  #ensureInit() {
    if (this.#workers.length > 0) return;
    const blob = new Blob([WORKER_SOURCE], { type: "application/javascript" });
    const url = URL.createObjectURL(blob);
    this.#workers = [];
    const readyPromises = [];
    for (let i = 0; i < this.#numWorkers; i++) {
      const worker = new Worker(url, { type: "module" });
      this.#workers.push(worker);
      readyPromises.push(new Promise((resolve, reject) => {
        worker.onmessage = () => resolve();
        worker.onerror = (e) => reject(new Error(e.message || "Worker failed to load"));
      }));
      worker.postMessage({
        type: "init",
        memory: this.#memory
      });
    }
    // Wait for every worker to settle before releasing the Blob URL, and
    // release it even when a worker failed to start, so the URL never leaks.
    this.#ready = Promise.allSettled(readyPromises).then((results) => {
      URL.revokeObjectURL(url);
      const failure = results.find((r) => r.status === "rejected");
      if (failure) throw failure.reason;
    });
    this.#queue = this.#ready;
  }
  /**
   * Dispatch a kernel across multiple workers.
   *
   * Returns an epoch that can be used to wait for the ongoing work to complete,
   * which is guaranteed to be monotonically increasing.
   *
   * A failed dispatch still advances the epoch so waiters are never stranded;
   * the failure is reported through `console.error`.
   *
   * @param module$1 Module posted to the workers; they instantiate it and call
   *   its exported `kernel`.
   * @param ptrs Leading arguments passed to `kernel` before `begin` and `end`.
   * @param size Total number of elements to split across workers.
   * @returns Epoch (bigint) to pass to `waitForEpoch()`.
   */
  dispatch(module$1, ptrs, size) {
    this.#ensureInit();
    this.#epochEnd++;
    const result = this.#queue.then(() => this.#dispatchNow(module$1, ptrs, size));
    this.#queue = result.then(() => {}, (err) => {
      console.error("WasmWorkerPool: kernel dispatch failed:", err);
    }).then(() => {
      this.#epoch++;
      const hooks = this.#hooks.get(this.#epoch);
      if (hooks) {
        for (const hook of hooks) hook();
        this.#hooks.delete(this.#epoch);
      }
    });
    return this.#epochEnd;
  }
  /** Split `[0, size)` into 16-aligned chunks and run one chunk per worker. */
  async #dispatchNow(module$1, ptrs, size) {
    if (size === 0) return;
    const n = Math.min(this.#workers.length, Math.ceil(size / MIN_ELEMS_PER_THREAD));
    const chunkSize = Math.ceil(size / n / 16) * 16;
    const promises = [];
    for (let i = 0; i < n; i++) {
      const begin = i * chunkSize;
      const end = Math.min(begin + chunkSize, size);
      if (begin >= size) break;
      const worker = this.#workers[i];
      promises.push(new Promise((resolve, reject) => {
        worker.onmessage = (e) => {
          if (e.data.ok) resolve();
          else reject(/* @__PURE__ */ new Error(`Worker error: ${e.data.error}`));
        };
        // Without this, a worker-level error would leave the promise pending
        // forever and stall every later dispatch.
        worker.onerror = (e) => reject(/* @__PURE__ */ new Error(`Worker error: ${e.message || "worker crashed"}`));
        worker.postMessage({
          module: module$1,
          ptrs,
          begin,
          end
        });
      }));
    }
    // Wait for every worker, not just the first failure, so that a late reply
    // from this dispatch cannot be mistaken for a reply to the next one.
    const results = await Promise.allSettled(promises);
    const failure = results.find((r) => r.status === "rejected");
    if (failure) throw failure.reason;
  }
};
|
|
3350
|
+
/**
 * Try to create a worker pool over `memory`.
 *
 * The worker count is `navigator.hardwareConcurrency` when that is available
 * and truthy, otherwise 4, and never less than 1.
 *
 * @param memory Memory object handed to the pool.
 * @returns A `WasmWorkerPool`, or `null` when `hasSharedArrayBuffer()` is
 *   false or constructing the pool throws.
 */
function createWorkerPool(memory) {
  if (!hasSharedArrayBuffer()) return null;
  try {
    let detected = 4;
    if (typeof navigator !== "undefined" && navigator.hardwareConcurrency) {
      detected = navigator.hardwareConcurrency;
    }
    return new WasmWorkerPool(memory, Math.max(1, detected));
  } catch {
    return null;
  }
}
|
|
3360
|
+
|
|
3074
3361
|
//#endregion
|
|
3075
3362
|
//#region src/backend/wasm/wasmblr.ts
|
|
3076
3363
|
/**
|
|
@@ -3408,7 +3695,7 @@ var CodeGenerator = class {
|
|
|
3408
3695
|
concat(importSectionBytes, encodeString(this.memory.aString));
|
|
3409
3696
|
concat(importSectionBytes, encodeString(this.memory.bString));
|
|
3410
3697
|
importSectionBytes.push(2);
|
|
3411
|
-
if (this.memory.
|
|
3698
|
+
if (this.memory.max) {
|
|
3412
3699
|
if (this.memory.isShared) importSectionBytes.push(3);
|
|
3413
3700
|
else importSectionBytes.push(1);
|
|
3414
3701
|
concat(importSectionBytes, encodeUnsigned(this.memory.min));
|
|
@@ -3815,6 +4102,8 @@ var I32x4 = class extends V128 {
|
|
|
3815
4102
|
min_u = VECTOR_OP("min_u", 183, ["v128", "v128"], "v128");
|
|
3816
4103
|
max_s = VECTOR_OP("max_s", 184, ["v128", "v128"], "v128");
|
|
3817
4104
|
max_u = VECTOR_OP("max_u", 185, ["v128", "v128"], "v128");
|
|
4105
|
+
trunc_sat_f32x4_s = VECTOR_OP("trunc_sat_f32x4_s", 248, ["v128"], "v128");
|
|
4106
|
+
trunc_sat_f32x4_u = VECTOR_OP("trunc_sat_f32x4_u", 249, ["v128"], "v128");
|
|
3818
4107
|
};
|
|
3819
4108
|
var F32x4 = class extends V128 {
|
|
3820
4109
|
splat = VECTOR_OP("splat", 19, ["f32"], "v128");
|
|
@@ -3841,10 +4130,333 @@ var F32x4 = class extends V128 {
|
|
|
3841
4130
|
max = VECTOR_OP("max", 233, ["v128", "v128"], "v128");
|
|
3842
4131
|
pmin = VECTOR_OP("pmin", 234, ["v128", "v128"], "v128");
|
|
3843
4132
|
pmax = VECTOR_OP("pmax", 235, ["v128", "v128"], "v128");
|
|
4133
|
+
convert_i32x4_s = VECTOR_OP("convert_i32x4_s", 250, ["v128"], "v128");
|
|
4134
|
+
convert_i32x4_u = VECTOR_OP("convert_i32x4_u", 251, ["v128"], "v128");
|
|
3844
4135
|
};
|
|
3845
4136
|
|
|
3846
4137
|
//#endregion
|
|
3847
4138
|
//#region src/backend/wasm.ts
|
|
4139
|
+
/**
 * SIMD version of translateExp: emits v128 (f32x4 or i32x4) instructions instead of scalar.
 * gidx always steps by 4. strideMap classifies each GlobalIndex as broadcast/contiguous/gather.
 *
 * Nodes referenced more than once are stored in a local after their first
 * emission and re-read from that local afterwards. Index expressions of
 * GlobalIndex nodes are emitted through the scalar `translateExp`.
 *
 * @param cg Code generator receiving the instructions.
 * @param funcs Passed through to `translateExp`.
 * @param exp Root expression to emit.
 * @param ctx Passed through to `translateExp`; `ctx["gidx"]` is the local that
 *   the gather path temporarily overwrites for each lane.
 * @param strideMap Map from GlobalIndex node to `{ kind }`; nodes missing from
 *   the map are treated as gather.
 * @throws {Error} On Variable/Special nodes and on ops without a SIMD lowering.
 */
function translateExpSimd(cg, funcs, exp, ctx, strideMap) {
  const isIntegral = (dt) => dt === DType.Int32 || dt === DType.Uint32 || dt === DType.Bool;

  // Pass 1: count how often each node is referenced (children walked once).
  const useCounts = /* @__PURE__ */ new Map();
  const walked = /* @__PURE__ */ new Set();
  const countReferences = (node) => {
    useCounts.set(node, (useCounts.get(node) ?? 0) + 1);
    if (walked.has(node)) return;
    walked.add(node);
    for (const child of node.src) countReferences(child);
  };

  /** Locals holding the value of already-emitted shared nodes. */
  const cachedLocals = /* @__PURE__ */ new Map();

  /** AND the comparison result on the stack with a splat of 1. */
  const maskToBool = () => {
    cg.i32.const(1);
    cg.i32x4.splat();
    cg.v128.and();
  };

  /** Emit the `index < len ? index : 0` select sequence, keeping the index in `slot`. */
  const guardIndex = (slot, len) => {
    cg.local.tee(slot);
    cg.i32.const(0);
    cg.local.get(slot);
    cg.i32.const(len);
    cg.i32.lt_u();
    cg.select();
  };

  /** Scale the index on the stack by the element width and add the `gid` local. */
  const toAddress = (dtype, gid) => {
    cg.i32.const(byteWidth(dtype));
    cg.i32.mul();
    cg.local.get(gid);
    cg.i32.add();
  };

  /** Emit the load for a GlobalIndex node according to its stride class. */
  const emitGlobalLoad = (node, isInt) => {
    const { src, arg, dtype } = node;
    const [gid, len] = arg;
    const indexExp = src[0];
    const stride = strideMap.get(node) ?? GATHER;

    if (stride.kind === "contiguous") {
      // One vector load starting at the index, limited to `len - SIMD_LANES`.
      translateExp(cg, funcs, indexExp, ctx);
      const maxIdx = Math.max(len - SIMD_LANES, 0);
      const wideIdx = cg.local.declare(cg.i32);
      cg.local.set(wideIdx);
      cg.local.get(wideIdx);
      cg.i32.const(maxIdx);
      cg.local.get(wideIdx);
      cg.i32.const(maxIdx);
      cg.i32.lt_u();
      cg.select();
      toAddress(dtype, gid);
      if (isInt) cg.i32x4.load(4);
      else cg.f32x4.load(4);
      return;
    }

    if (stride.kind === "broadcast") {
      // One scalar load, splatted to all lanes.
      translateExp(cg, funcs, indexExp, ctx);
      const slot = cg.local.declare(cg.i32);
      guardIndex(slot, len);
      toAddress(dtype, gid);
      if (isInt) {
        cg.i32.load(2);
        cg.i32x4.splat();
      } else {
        cg.f32.load(2);
        cg.f32x4.splat();
      }
      return;
    }

    // Gather: evaluate the index once per lane with gidx temporarily offset.
    const gidxLocal = ctx["gidx"];
    const savedGidx = cg.local.declare(cg.i32);
    cg.local.get(gidxLocal);
    cg.local.set(savedGidx);
    if (isInt) {
      cg.i32.const(0);
      cg.i32x4.splat();
    } else {
      cg.f32.const(0);
      cg.f32x4.splat();
    }
    const vec = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
    cg.local.set(vec);
    const idx = cg.local.declare(cg.i32);
    const laneValue = cg.local.declare(isInt ? cg.i32 : cg.f32);
    for (let lane = 0; lane < SIMD_LANES; lane++) {
      cg.local.get(savedGidx);
      if (lane > 0) {
        cg.i32.const(lane);
        cg.i32.add();
      }
      cg.local.set(gidxLocal);
      translateExp(cg, funcs, indexExp, ctx);
      guardIndex(idx, len);
      toAddress(dtype, gid);
      if (isInt) cg.i32.load(2);
      else cg.f32.load(2);
      cg.local.set(laneValue);
      cg.local.get(vec);
      cg.local.get(laneValue);
      if (isInt) cg.i32x4.replace_lane(lane);
      else cg.f32x4.replace_lane(lane);
      cg.local.set(vec);
    }
    // Restore gidx before leaving the assembled vector on the stack.
    cg.local.get(savedGidx);
    cg.local.set(gidxLocal);
    cg.local.get(vec);
  };

  // Pass 2: emit code, post-order.
  const gen = (node) => {
    if (cachedLocals.has(node)) return cg.local.get(cachedLocals.get(node));
    const { op, src, arg, dtype } = node;
    const isInt = isIntegral(dtype);
    const isSigned = dtype === DType.Int32;

    switch (op) {
      case AluOp.Add:
      case AluOp.Sub:
      case AluOp.Mul: {
        gen(src[0]);
        gen(src[1]);
        const lanes = isInt ? cg.i32x4 : cg.f32x4;
        if (op === AluOp.Add) lanes.add();
        else if (op === AluOp.Sub) lanes.sub();
        else lanes.mul();
        break;
      }
      case AluOp.Min:
        gen(src[0]);
        gen(src[1]);
        if (!isInt) cg.f32x4.min();
        else if (isSigned) cg.i32x4.min_s();
        else cg.i32x4.min_u();
        break;
      case AluOp.Max:
        gen(src[0]);
        gen(src[1]);
        if (!isInt) cg.f32x4.max();
        else if (isSigned) cg.i32x4.max_s();
        else cg.i32x4.max_u();
        break;
      case AluOp.Sqrt:
        gen(src[0]);
        cg.f32x4.sqrt();
        break;
      case AluOp.Floor:
        gen(src[0]);
        cg.f32x4.floor();
        break;
      case AluOp.Ceil:
        gen(src[0]);
        cg.f32x4.ceil();
        break;
      case AluOp.Const:
        if (isInt) {
          cg.i32.const(arg);
          cg.i32x4.splat();
        } else {
          cg.f32.const(arg);
          cg.f32x4.splat();
        }
        break;
      case AluOp.Cast: {
        gen(src[0]);
        const from = src[0].dtype;
        const fromInt = isIntegral(from);
        if (isInt && !fromInt) {
          if (isSigned) cg.i32x4.trunc_sat_f32x4_s();
          else cg.i32x4.trunc_sat_f32x4_u();
        } else if (!isInt && fromInt) {
          if (from === DType.Int32 || from === DType.Bool) cg.f32x4.convert_i32x4_s();
          else cg.f32x4.convert_i32x4_u();
        }
        // Integral-to-integral and float-to-float casts emit nothing.
        break;
      }
      case AluOp.Cmplt:
        gen(src[0]);
        gen(src[1]);
        switch (src[0].dtype) {
          case DType.Float32:
            cg.f32x4.lt();
            break;
          case DType.Int32:
            cg.i32x4.lt_s();
            break;
          case DType.Uint32:
            cg.i32x4.lt_u();
            break;
          default:
            throw new UnsupportedOpError(op, dtype, "wasm");
        }
        maskToBool();
        break;
      case AluOp.Cmpne:
        gen(src[0]);
        gen(src[1]);
        if (src[0].dtype === DType.Float32) cg.f32x4.ne();
        else cg.i32x4.ne();
        maskToBool();
        break;
      case AluOp.Where:
        // Both branches first, then the condition compared against zero.
        gen(src[1]);
        gen(src[2]);
        gen(src[0]);
        cg.i32.const(0);
        cg.i32x4.splat();
        cg.i32x4.ne();
        cg.v128.bitselect();
        break;
      case AluOp.Variable:
      case AluOp.Special:
        throw new Error(`translateExpSimd: unexpected ${op}(${arg})`);
      case AluOp.GlobalIndex:
        emitGlobalLoad(node, isInt);
        break;
      default:
        throw new Error(`translateExpSimd: unsupported op ${op}`);
    }

    // Keep shared sub-expressions in a local so later uses just re-read it.
    if ((useCounts.get(node) ?? 0) > 1) {
      const slot = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
      cg.local.tee(slot);
      cachedLocals.set(node, slot);
    }
  };

  countReferences(exp);
  gen(exp);
}
|
|
4335
|
+
/** Number of SIMD lanes (f32x4 / i32x4 = 4 lanes). */
|
|
4336
|
+
const SIMD_LANES = 4;
|
|
4337
|
+
/**
 * Report whether `exp` or any node reachable through `src` is the `gidx`
 * Special (op `Special` with `arg[0] === "gidx"`).
 */
function referencesGidx(exp) {
  const pending = [exp];
  while (pending.length > 0) {
    const node = pending.pop();
    if (node.op === AluOp.Special && node.arg[0] === "gidx") return true;
    pending.push(...node.src);
  }
  return false;
}
|
|
4341
|
+
/**
 * When tileSize > N but doesn't divide evenly, the last group before the
 * inner reset is shorter than N, so a SIMD group could straddle it.
 *
 * @returns `true` only for a finite `tileSize` that is larger than `N` and
 *   not a multiple of it.
 */
function hasFragmentRisk(tileSize, N) {
  if (!isFinite(tileSize)) return false;
  if (!(tileSize > N)) return false;
  return tileSize % N !== 0;
}
|
|
4346
|
+
const GATHER = { kind: "gather" };
|
|
4347
|
+
/**
 * Classify how a GlobalIndex's index expression behaves as gidx increments.
 *
 * - An expression that does not reference gidx is "broadcast".
 * - gidx itself is "contiguous".
 * - `inner / const` and `inner % const` over a contiguous `inner` are
 *   "broadcast" and "contiguous" respectively, with `tileSize` capped at the
 *   constant. They fall back to gather when `hasFragmentRisk` reports a risk.
 * - `const * inner` stays "broadcast" if `inner` is, otherwise gather.
 * - `a + b` where only one side references gidx takes that side's class.
 * - Everything else is gather.
 *
 * @returns `{ kind, tileSize }`, or the shared `GATHER` object.
 */
function analyzeStride(exp) {
  if (!referencesGidx(exp)) return { kind: "broadcast", tileSize: Infinity };
  const { op, src } = exp;
  switch (op) {
    case AluOp.Special:
      if (exp.arg[0] === "gidx") return { kind: "contiguous", tileSize: Infinity };
      break;
    case AluOp.Idiv:
    case AluOp.Mod: {
      const divisor = src[1];
      if (divisor.op !== AluOp.Const) break;
      const N = divisor.arg;
      const inner = analyzeStride(src[0]);
      if (inner.kind === "broadcast") return inner;
      if (inner.kind !== "contiguous" || hasFragmentRisk(inner.tileSize, N)) return GATHER;
      return {
        kind: op === AluOp.Idiv ? "broadcast" : "contiguous",
        tileSize: Math.min(inner.tileSize, N),
      };
    }
    case AluOp.Mul:
      for (let i = 0; i < 2; i++) {
        if (src[i].op !== AluOp.Const) continue;
        const inner = analyzeStride(src[1 - i]);
        return inner.kind === "broadcast" ? inner : GATHER;
      }
      break;
    case AluOp.Add: {
      const lhsHasGidx = referencesGidx(src[0]);
      const rhsHasGidx = referencesGidx(src[1]);
      if (lhsHasGidx !== rhsHasGidx) return analyzeStride(src[lhsHasGidx ? 0 : 1]);
      break;
    }
  }
  return GATHER;
}
|
|
4396
|
+
/** Ops that have direct SIMD (f32x4) instruction variants. */
|
|
4397
|
+
const simdF32Ops = new Set([
|
|
4398
|
+
AluOp.Add,
|
|
4399
|
+
AluOp.Sub,
|
|
4400
|
+
AluOp.Mul,
|
|
4401
|
+
AluOp.Floor,
|
|
4402
|
+
AluOp.Ceil,
|
|
4403
|
+
AluOp.Min,
|
|
4404
|
+
AluOp.Max,
|
|
4405
|
+
AluOp.Sqrt,
|
|
4406
|
+
AluOp.Cast,
|
|
4407
|
+
AluOp.Where,
|
|
4408
|
+
AluOp.Const,
|
|
4409
|
+
AluOp.GlobalIndex
|
|
4410
|
+
]);
|
|
4411
|
+
/** Ops that have direct SIMD (i32x4) instruction variants. */
|
|
4412
|
+
const simdI32Ops = new Set([
|
|
4413
|
+
AluOp.Add,
|
|
4414
|
+
AluOp.Sub,
|
|
4415
|
+
AluOp.Mul,
|
|
4416
|
+
AluOp.Min,
|
|
4417
|
+
AluOp.Max,
|
|
4418
|
+
AluOp.Cast,
|
|
4419
|
+
AluOp.Where,
|
|
4420
|
+
AluOp.Const,
|
|
4421
|
+
AluOp.GlobalIndex
|
|
4422
|
+
]);
|
|
4423
|
+
/** Ops that produce Bool (i32x4 bitmask) in SIMD. */
|
|
4424
|
+
const simdBoolOps = new Set([
|
|
4425
|
+
AluOp.Cmplt,
|
|
4426
|
+
AluOp.Cmpne,
|
|
4427
|
+
AluOp.Const,
|
|
4428
|
+
AluOp.GlobalIndex
|
|
4429
|
+
]);
|
|
4430
|
+
/**
 * Check if a kernel is eligible for SIMD codegen.
 *
 * A kernel qualifies when:
 * - size >= 4 (need at least 4 elements for a SIMD group)
 * - For reductions: the reduction op has a SIMD variant for its dtype
 * - All nodes have a supported dtype (f32, i32, u32, bool) with SIMD variants
 *
 * The walk does not descend below GlobalIndex nodes; their index expressions
 * are not checked here.
 *
 * @param tunedExp Root expression to inspect.
 * @param kernel Kernel supplying `size` and the optional `reduction`.
 * @returns `true` when every inspected node is supported.
 */
function isSimdEligible(tunedExp, kernel) {
  if (kernel.size < SIMD_LANES) return false;
  const { reduction } = kernel;
  if (reduction) {
    const reduceOps = simdSupportedOpsForDtype(reduction.dtype);
    if (!reduceOps?.has(reduction.op)) return false;
  }
  const visited = /* @__PURE__ */ new Set();
  const pending = [tunedExp];
  while (pending.length > 0) {
    const node = pending.pop();
    if (visited.has(node)) continue;
    visited.add(node);
    const supportedOps = simdSupportedOpsForDtype(node.dtype);
    if (!supportedOps || !supportedOps.has(node.op)) return false;
    if (node.op !== AluOp.GlobalIndex) pending.push(...node.src);
  }
  return true;
}
|
|
4454
|
+
/**
 * Look up the set of ops that may appear in SIMD code for a given dtype.
 *
 * @param dtype Dtype of the expression node being checked.
 * @returns The matching op set, or `null` when the dtype has no SIMD op set.
 */
function simdSupportedOpsForDtype(dtype) {
  switch (dtype) {
    case DType.Float32:
      return simdF32Ops;
    // Signed and unsigned 32-bit integers share one op set.
    case DType.Int32:
    case DType.Uint32:
      return simdI32Ops;
    case DType.Bool:
      return simdBoolOps;
    default:
      return null;
  }
}
|
|
3848
4460
|
const moduleCache = /* @__PURE__ */ new Map();
|
|
3849
4461
|
/** Backend that compiles into WebAssembly bytecode for immediate execution. */
|
|
3850
4462
|
var WasmBackend = class {
|
|
@@ -3854,11 +4466,18 @@ var WasmBackend = class {
|
|
|
3854
4466
|
#nextSlot;
|
|
3855
4467
|
#allocator;
|
|
3856
4468
|
#buffers;
|
|
4469
|
+
#workerPool;
|
|
4470
|
+
#pendingWork = /* @__PURE__ */ new Map();
|
|
3857
4471
|
constructor() {
|
|
3858
|
-
this.#memory = new WebAssembly.Memory({
|
|
4472
|
+
this.#memory = hasSharedArrayBuffer() ? new WebAssembly.Memory({
|
|
4473
|
+
initial: 0,
|
|
4474
|
+
maximum: 65536,
|
|
4475
|
+
shared: true
|
|
4476
|
+
}) : new WebAssembly.Memory({ initial: 0 });
|
|
3859
4477
|
this.#allocator = new WasmAllocator(this.#memory);
|
|
3860
4478
|
this.#nextSlot = 1;
|
|
3861
4479
|
this.#buffers = /* @__PURE__ */ new Map();
|
|
4480
|
+
this.#workerPool = createWorkerPool(this.#memory);
|
|
3862
4481
|
}
|
|
3863
4482
|
malloc(size, initialData) {
|
|
3864
4483
|
const ptr = this.#allocator.malloc(size);
|
|
@@ -3889,37 +4508,70 @@ var WasmBackend = class {
|
|
|
3889
4508
|
}
|
|
3890
4509
|
}
|
|
3891
4510
|
async read(slot, start, count) {
|
|
3892
|
-
|
|
4511
|
+
const epoch = this.#pendingWork.get(slot);
|
|
4512
|
+
if (epoch) await this.#workerPool.waitForEpoch(epoch);
|
|
4513
|
+
return this.#readData(slot, start, count);
|
|
3893
4514
|
}
|
|
3894
4515
|
readSync(slot, start, count) {
|
|
4516
|
+
const epoch = this.#pendingWork.get(slot);
|
|
4517
|
+
if (epoch && this.#workerPool.epoch < epoch) throw new Error("cannot read synchronously from a slot with async work");
|
|
4518
|
+
return this.#readData(slot, start, count);
|
|
4519
|
+
}
|
|
4520
|
+
#readData(slot, start, count) {
|
|
3895
4521
|
const buffer = this.#getBuffer(slot);
|
|
3896
4522
|
if (start === void 0) start = 0;
|
|
3897
4523
|
if (count === void 0) count = buffer.byteLength - start;
|
|
3898
|
-
return buffer.slice(start, start + count);
|
|
4524
|
+
if (buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
|
|
4525
|
+
else return buffer.slice(start, start + count);
|
|
3899
4526
|
}
|
|
3900
4527
|
async prepareKernel(kernel) {
|
|
3901
|
-
|
|
4528
|
+
const kernelHash = FpHash.hash(kernel);
|
|
4529
|
+
const module$1 = await runWithCacheAsync(moduleCache, kernelHash.toString(), () => WebAssembly.compile(codegenWasm(kernel)));
|
|
4530
|
+
return new Executable(kernel, {
|
|
4531
|
+
module: module$1,
|
|
4532
|
+
parallel: this.#workerPool !== null
|
|
4533
|
+
});
|
|
3902
4534
|
}
|
|
3903
4535
|
prepareKernelSync(kernel) {
|
|
3904
4536
|
const kernelHash = FpHash.hash(kernel);
|
|
3905
|
-
const module$1 = runWithCache(moduleCache, kernelHash.toString(), () =>
|
|
3906
|
-
|
|
3907
|
-
|
|
4537
|
+
const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => new WebAssembly.Module(codegenWasm(kernel)));
|
|
4538
|
+
return new Executable(kernel, {
|
|
4539
|
+
module: module$1,
|
|
4540
|
+
parallel: false
|
|
3908
4541
|
});
|
|
3909
|
-
return new Executable(kernel, { module: module$1 });
|
|
3910
4542
|
}
|
|
3911
4543
|
async prepareRoutine(routine) {
|
|
3912
4544
|
return this.prepareRoutineSync(routine);
|
|
3913
4545
|
}
|
|
3914
4546
|
prepareRoutineSync(routine) {
|
|
3915
|
-
return new Executable(routine,
|
|
4547
|
+
return new Executable(routine, {
|
|
4548
|
+
module: void 0,
|
|
4549
|
+
parallel: false
|
|
4550
|
+
});
|
|
3916
4551
|
}
|
|
3917
4552
|
dispatch(exe, inputs, outputs) {
|
|
3918
|
-
|
|
3919
|
-
const
|
|
3920
|
-
|
|
3921
|
-
|
|
3922
|
-
|
|
4553
|
+
const tracing = isTracing();
|
|
4554
|
+
const start = tracing ? performance.now() : 0;
|
|
4555
|
+
if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
4556
|
+
else {
|
|
4557
|
+
const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
|
|
4558
|
+
if (exe.data.parallel && this.#workerPool) {
|
|
4559
|
+
const epoch = this.#workerPool.dispatch(exe.data.module, ptrs, exe.source.size);
|
|
4560
|
+
for (const slot of outputs) this.#pendingWork.set(slot, epoch);
|
|
4561
|
+
} else {
|
|
4562
|
+
if (inputs.some((slot) => {
|
|
4563
|
+
const epoch = this.#pendingWork.get(slot);
|
|
4564
|
+
return epoch && this.#workerPool.epoch < epoch;
|
|
4565
|
+
})) throw new Error("cannot dispatch synchronously with pending async work");
|
|
4566
|
+
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
4567
|
+
const func = instance.exports.kernel;
|
|
4568
|
+
func(...ptrs, 0, exe.source.size);
|
|
4569
|
+
}
|
|
4570
|
+
}
|
|
4571
|
+
if (tracing) {
|
|
4572
|
+
const info = traceSourceInfo(exe.source);
|
|
4573
|
+
emitTrace("wasm", info, start, performance.now());
|
|
4574
|
+
}
|
|
3923
4575
|
}
|
|
3924
4576
|
#getBuffer(slot) {
|
|
3925
4577
|
const buffer = this.#buffers.get(slot);
|
|
@@ -3927,12 +4579,36 @@ var WasmBackend = class {
|
|
|
3927
4579
|
return new Uint8Array(this.#memory.buffer, buffer.ptr, buffer.size);
|
|
3928
4580
|
}
|
|
3929
4581
|
};
|
|
4582
|
+
/**
 * Emit a runtime guard: enter the if-block only when [begin, end) is
 * SIMD-aligned, i.e. both `begin` and `end - begin` have no bits set in
 * `SIMD_LANES - 1`.
 *
 * The function opens an `if` block and leaves it open; the caller is
 * responsible for emitting the matching `end`.
 *
 * NOTE(review): the mask test equals "multiple of SIMD_LANES" only if
 * SIMD_LANES is a power of two - confirm against its definition.
 *
 * @param cg Code generator receiving the instructions.
 * @param paramBegin Local index holding the range start.
 * @param paramEnd Local index holding the range end.
 */
function emitAlignmentGuard(cg, paramBegin, paramEnd) {
  const laneMask = SIMD_LANES - 1;
  // Replace the i32 on top of the stack with ((value & laneMask) == 0).
  const emitLowBitsAreZero = () => {
    cg.i32.const(laneMask);
    cg.i32.and();
    cg.i32.eqz();
  };

  // Is the range length aligned?
  cg.local.get(paramEnd);
  cg.local.get(paramBegin);
  cg.i32.sub();
  emitLowBitsAreZero();

  // Is the range start aligned?
  cg.local.get(paramBegin);
  emitLowBitsAreZero();

  // Both conditions must hold to enter the block.
  cg.i32.and();
  cg.if(cg.void);
}
|
|
3930
4598
|
function codegenWasm(kernel) {
|
|
3931
4599
|
const tune = tuneNullopt(kernel);
|
|
3932
4600
|
const re = kernel.reduction;
|
|
3933
4601
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
4602
|
+
const useSimd = isSimdEligible(tune.exp, kernel);
|
|
4603
|
+
const bufferStrides = /* @__PURE__ */ new Map();
|
|
4604
|
+
if (useSimd) tune.exp.collect((e) => e.op === AluOp.GlobalIndex).forEach((gi) => {
|
|
4605
|
+
const result = analyzeStride(gi.src[0]);
|
|
4606
|
+
if (result.kind !== "gather" && (result.tileSize < SIMD_LANES || isFinite(result.tileSize) && result.tileSize % SIMD_LANES !== 0)) bufferStrides.set(gi, GATHER);
|
|
4607
|
+
else bufferStrides.set(gi, result);
|
|
4608
|
+
});
|
|
3934
4609
|
const cg = new CodeGenerator();
|
|
3935
4610
|
cg.memory.import("env", "memory");
|
|
4611
|
+
if (hasSharedArrayBuffer()) cg.memory.pages(0, 65536).shared(true);
|
|
3936
4612
|
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
3937
4613
|
const funcs = {};
|
|
3938
4614
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
@@ -3944,12 +4620,127 @@ function codegenWasm(kernel) {
|
|
|
3944
4620
|
if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
|
|
3945
4621
|
if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
|
|
3946
4622
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
3947
|
-
const
|
|
4623
|
+
const paramBegin = kernel.nargs + 1;
|
|
4624
|
+
const paramEnd = kernel.nargs + 2;
|
|
4625
|
+
const kernelFunc = cg.function(rep(kernel.nargs + 3, cg.i32), [], () => {
|
|
3948
4626
|
const gidx = cg.local.declare(cg.i32);
|
|
4627
|
+
cg.local.get(paramBegin);
|
|
4628
|
+
cg.local.set(gidx);
|
|
4629
|
+
if (useSimd) {
|
|
4630
|
+
emitAlignmentGuard(cg, paramBegin, paramEnd);
|
|
4631
|
+
cg.loop(cg.void);
|
|
4632
|
+
if (!re) {
|
|
4633
|
+
cg.block(cg.void);
|
|
4634
|
+
cg.local.get(gidx);
|
|
4635
|
+
cg.local.get(paramEnd);
|
|
4636
|
+
cg.i32.ge_u();
|
|
4637
|
+
cg.br_if(0);
|
|
4638
|
+
cg.local.get(kernel.nargs);
|
|
4639
|
+
cg.local.get(gidx);
|
|
4640
|
+
cg.i32.const(byteWidth(kernel.dtype));
|
|
4641
|
+
cg.i32.mul();
|
|
4642
|
+
cg.i32.add();
|
|
4643
|
+
translateExpSimd(cg, funcs, tune.exp, { gidx }, bufferStrides);
|
|
4644
|
+
cg.v128.store(4);
|
|
4645
|
+
cg.local.get(gidx);
|
|
4646
|
+
cg.i32.const(SIMD_LANES);
|
|
4647
|
+
cg.i32.add();
|
|
4648
|
+
cg.local.set(gidx);
|
|
4649
|
+
cg.br(1);
|
|
4650
|
+
cg.end();
|
|
4651
|
+
} else {
|
|
4652
|
+
const reIsInt = kernel.exp.dtype === DType.Int32 || kernel.exp.dtype === DType.Uint32;
|
|
4653
|
+
cg.block(cg.void);
|
|
4654
|
+
cg.local.get(gidx);
|
|
4655
|
+
cg.local.get(paramEnd);
|
|
4656
|
+
cg.i32.ge_u();
|
|
4657
|
+
cg.br_if(0);
|
|
4658
|
+
const vecAcc = cg.local.declare(reIsInt ? cg.i32x4 : cg.f32x4);
|
|
4659
|
+
if (reIsInt) {
|
|
4660
|
+
cg.i32.const(re.identity);
|
|
4661
|
+
cg.i32x4.splat();
|
|
4662
|
+
} else {
|
|
4663
|
+
cg.f32.const(re.identity);
|
|
4664
|
+
cg.f32x4.splat();
|
|
4665
|
+
}
|
|
4666
|
+
cg.local.set(vecAcc);
|
|
4667
|
+
const ridx = cg.local.declare(cg.i32);
|
|
4668
|
+
cg.i32.const(0);
|
|
4669
|
+
cg.local.set(ridx);
|
|
4670
|
+
cg.loop(cg.void);
|
|
4671
|
+
cg.block(cg.void);
|
|
4672
|
+
cg.local.get(ridx);
|
|
4673
|
+
cg.i32.const(re.size);
|
|
4674
|
+
cg.i32.ge_u();
|
|
4675
|
+
cg.br_if(0);
|
|
4676
|
+
translateExpSimd(cg, funcs, tune.exp, {
|
|
4677
|
+
gidx,
|
|
4678
|
+
ridx
|
|
4679
|
+
}, bufferStrides);
|
|
4680
|
+
cg.local.get(vecAcc);
|
|
4681
|
+
if (reIsInt) if (re.op === AluOp.Add) cg.i32x4.add();
|
|
4682
|
+
else if (re.op === AluOp.Mul) cg.i32x4.mul();
|
|
4683
|
+
else if (re.op === AluOp.Min) if (re.dtype === DType.Int32) cg.i32x4.min_s();
|
|
4684
|
+
else cg.i32x4.min_u();
|
|
4685
|
+
else if (re.op === AluOp.Max) if (re.dtype === DType.Int32) cg.i32x4.max_s();
|
|
4686
|
+
else cg.i32x4.max_u();
|
|
4687
|
+
else throw new Error(`invalid SIMD reduction op: ${re.op}`);
|
|
4688
|
+
else if (re.op === AluOp.Add) cg.f32x4.add();
|
|
4689
|
+
else if (re.op === AluOp.Mul) cg.f32x4.mul();
|
|
4690
|
+
else if (re.op === AluOp.Min) cg.f32x4.min();
|
|
4691
|
+
else if (re.op === AluOp.Max) cg.f32x4.max();
|
|
4692
|
+
else throw new Error(`invalid SIMD reduction op: ${re.op}`);
|
|
4693
|
+
cg.local.set(vecAcc);
|
|
4694
|
+
cg.local.get(ridx);
|
|
4695
|
+
cg.i32.const(1);
|
|
4696
|
+
cg.i32.add();
|
|
4697
|
+
cg.local.set(ridx);
|
|
4698
|
+
cg.br(1);
|
|
4699
|
+
cg.end();
|
|
4700
|
+
cg.end();
|
|
4701
|
+
for (let lane = 0; lane < SIMD_LANES; lane++) {
|
|
4702
|
+
cg.local.get(kernel.nargs);
|
|
4703
|
+
cg.local.get(gidx);
|
|
4704
|
+
if (lane > 0) {
|
|
4705
|
+
cg.i32.const(lane);
|
|
4706
|
+
cg.i32.add();
|
|
4707
|
+
}
|
|
4708
|
+
cg.i32.const(byteWidth(kernel.dtype));
|
|
4709
|
+
cg.i32.mul();
|
|
4710
|
+
cg.i32.add();
|
|
4711
|
+
const acc = cg.local.declare(reIsInt ? cg.i32 : cg.f32);
|
|
4712
|
+
cg.local.get(vecAcc);
|
|
4713
|
+
if (reIsInt) cg.i32x4.extract_lane(lane);
|
|
4714
|
+
else cg.f32x4.extract_lane(lane);
|
|
4715
|
+
cg.local.set(acc);
|
|
4716
|
+
const laneGidx = cg.local.declare(cg.i32);
|
|
4717
|
+
cg.local.get(gidx);
|
|
4718
|
+
if (lane > 0) {
|
|
4719
|
+
cg.i32.const(lane);
|
|
4720
|
+
cg.i32.add();
|
|
4721
|
+
}
|
|
4722
|
+
cg.local.set(laneGidx);
|
|
4723
|
+
translateExp(cg, funcs, tune.epilogue, {
|
|
4724
|
+
acc,
|
|
4725
|
+
gidx: laneGidx
|
|
4726
|
+
});
|
|
4727
|
+
dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
|
|
4728
|
+
}
|
|
4729
|
+
cg.local.get(gidx);
|
|
4730
|
+
cg.i32.const(SIMD_LANES);
|
|
4731
|
+
cg.i32.add();
|
|
4732
|
+
cg.local.set(gidx);
|
|
4733
|
+
cg.br(1);
|
|
4734
|
+
cg.end();
|
|
4735
|
+
}
|
|
4736
|
+
cg.end();
|
|
4737
|
+
cg.return();
|
|
4738
|
+
cg.end();
|
|
4739
|
+
}
|
|
3949
4740
|
cg.loop(cg.void);
|
|
3950
4741
|
cg.block(cg.void);
|
|
3951
4742
|
cg.local.get(gidx);
|
|
3952
|
-
cg.
|
|
4743
|
+
cg.local.get(paramEnd);
|
|
3953
4744
|
cg.i32.ge_u();
|
|
3954
4745
|
cg.br_if(0);
|
|
3955
4746
|
cg.local.get(kernel.nargs);
|
|
@@ -4088,6 +4879,11 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
4088
4879
|
else cg.i32.gt_u();
|
|
4089
4880
|
cg.select();
|
|
4090
4881
|
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
4882
|
+
else if (op === AluOp.BitCombine) if (arg === "and") cg.i32.and();
|
|
4883
|
+
else if (arg === "or") cg.i32.or();
|
|
4884
|
+
else cg.i32.xor();
|
|
4885
|
+
else if (op === AluOp.BitShift) if (arg === "shl") cg.i32.shl();
|
|
4886
|
+
else cg.i32.shr_u();
|
|
4091
4887
|
else if (op === AluOp.Cmplt) {
|
|
4092
4888
|
const srcDtype = src[0].dtype;
|
|
4093
4889
|
if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
|
|
@@ -4264,7 +5060,7 @@ async function createBackend(device) {
|
|
|
4264
5060
|
if (!navigator.gpu) return null;
|
|
4265
5061
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4266
5062
|
if (!adapter) return null;
|
|
4267
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
5063
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-uU9nnttc.cjs"));
|
|
4268
5064
|
const importantLimits = [
|
|
4269
5065
|
"maxBufferSize",
|
|
4270
5066
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4302,7 +5098,7 @@ async function createBackend(device) {
|
|
|
4302
5098
|
});
|
|
4303
5099
|
if (!gl) return null;
|
|
4304
5100
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4305
|
-
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-
|
|
5101
|
+
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-Ovaaa-Qx.cjs"));
|
|
4306
5102
|
return new WebGLBackend(gl);
|
|
4307
5103
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4308
5104
|
}
|
|
@@ -4336,6 +5132,15 @@ var UnsupportedRoutineError = class extends Error {
|
|
|
4336
5132
|
super(`routine '${name}' is not supported in ${device} backend`);
|
|
4337
5133
|
}
|
|
4338
5134
|
};
|
|
5135
|
+
/**
 * Return the `GPUDevice` used by the WebGPU backend, e.g. for sharing
 * buffers with other WebGPU code.
 *
 * @returns The `device` property of the initialized WebGPU backend.
 * @throws {Error} If no WebGPU backend has been initialized yet.
 */
function getWebGPUDevice() {
  const webgpuBackend = initializedBackends.get("webgpu");
  if (webgpuBackend) return webgpuBackend.device;
  throw new Error("WebGPU backend not initialized, call init('webgpu') first");
}
|
|
4339
5144
|
|
|
4340
5145
|
//#endregion
|
|
4341
5146
|
Object.defineProperty(exports, 'AluExp', {
|
|
@@ -4506,6 +5311,12 @@ Object.defineProperty(exports, 'dtypedJsArray', {
|
|
|
4506
5311
|
return dtypedJsArray;
|
|
4507
5312
|
}
|
|
4508
5313
|
});
|
|
5314
|
+
Object.defineProperty(exports, 'emitTrace', {
|
|
5315
|
+
enumerable: true,
|
|
5316
|
+
get: function () {
|
|
5317
|
+
return emitTrace;
|
|
5318
|
+
}
|
|
5319
|
+
});
|
|
4509
5320
|
Object.defineProperty(exports, 'findPow2', {
|
|
4510
5321
|
enumerable: true,
|
|
4511
5322
|
get: function () {
|
|
@@ -4524,6 +5335,12 @@ Object.defineProperty(exports, 'getBackend', {
|
|
|
4524
5335
|
return getBackend;
|
|
4525
5336
|
}
|
|
4526
5337
|
});
|
|
5338
|
+
Object.defineProperty(exports, 'getWebGPUDevice', {
|
|
5339
|
+
enumerable: true,
|
|
5340
|
+
get: function () {
|
|
5341
|
+
return getWebGPUDevice;
|
|
5342
|
+
}
|
|
5343
|
+
});
|
|
4527
5344
|
Object.defineProperty(exports, 'init', {
|
|
4528
5345
|
enumerable: true,
|
|
4529
5346
|
get: function () {
|
|
@@ -4554,6 +5371,12 @@ Object.defineProperty(exports, 'isPermutation', {
|
|
|
4554
5371
|
return isPermutation;
|
|
4555
5372
|
}
|
|
4556
5373
|
});
|
|
5374
|
+
Object.defineProperty(exports, 'isTracing', {
|
|
5375
|
+
enumerable: true,
|
|
5376
|
+
get: function () {
|
|
5377
|
+
return isTracing;
|
|
5378
|
+
}
|
|
5379
|
+
});
|
|
4557
5380
|
Object.defineProperty(exports, 'mapSetUnion', {
|
|
4558
5381
|
enumerable: true,
|
|
4559
5382
|
get: function () {
|
|
@@ -4566,6 +5389,12 @@ Object.defineProperty(exports, 'normalizeAxis', {
|
|
|
4566
5389
|
return normalizeAxis;
|
|
4567
5390
|
}
|
|
4568
5391
|
});
|
|
5392
|
+
Object.defineProperty(exports, 'onFlushTrace', {
|
|
5393
|
+
enumerable: true,
|
|
5394
|
+
get: function () {
|
|
5395
|
+
return onFlushTrace;
|
|
5396
|
+
}
|
|
5397
|
+
});
|
|
4569
5398
|
Object.defineProperty(exports, 'partitionList', {
|
|
4570
5399
|
enumerable: true,
|
|
4571
5400
|
get: function () {
|
|
@@ -4614,6 +5443,18 @@ Object.defineProperty(exports, 'setDebug', {
|
|
|
4614
5443
|
return setDebug;
|
|
4615
5444
|
}
|
|
4616
5445
|
});
|
|
5446
|
+
Object.defineProperty(exports, 'startTrace', {
|
|
5447
|
+
enumerable: true,
|
|
5448
|
+
get: function () {
|
|
5449
|
+
return startTrace;
|
|
5450
|
+
}
|
|
5451
|
+
});
|
|
5452
|
+
Object.defineProperty(exports, 'stopTrace', {
|
|
5453
|
+
enumerable: true,
|
|
5454
|
+
get: function () {
|
|
5455
|
+
return stopTrace;
|
|
5456
|
+
}
|
|
5457
|
+
});
|
|
4617
5458
|
Object.defineProperty(exports, 'strip1', {
|
|
4618
5459
|
enumerable: true,
|
|
4619
5460
|
get: function () {
|
|
@@ -4626,6 +5467,12 @@ Object.defineProperty(exports, 'toposort', {
|
|
|
4626
5467
|
return toposort;
|
|
4627
5468
|
}
|
|
4628
5469
|
});
|
|
5470
|
+
Object.defineProperty(exports, 'traceSourceInfo', {
|
|
5471
|
+
enumerable: true,
|
|
5472
|
+
get: function () {
|
|
5473
|
+
return traceSourceInfo;
|
|
5474
|
+
}
|
|
5475
|
+
});
|
|
4629
5476
|
Object.defineProperty(exports, 'tuneNullopt', {
|
|
4630
5477
|
enumerable: true,
|
|
4631
5478
|
get: function () {
|