@jax-js/jax 0.1.10 → 0.1.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -2
- package/dist/{backend-Ctqs8la1.js → backend-DI-V78Rk.js} +732 -21
- package/dist/{backend-DMauYnfl.cjs → backend-x-6vqzIM.cjs} +737 -20
- package/dist/index.cjs +372 -20
- package/dist/index.d.cts +172 -4
- package/dist/index.d.ts +172 -4
- package/dist/index.js +372 -21
- package/dist/{webgl-CvQ1QBX1.js → webgl-BhsnpeB0.js} +7 -1
- package/dist/{webgl-kvVt7-T7.cjs → webgl-CD3WK_Me.cjs} +7 -1
- package/dist/{webgpu-v_W_-oKw.js → webgpu-C2kLdkUh.js} +299 -149
- package/dist/{webgpu-DMSx7a6M.cjs → webgpu-C4S8Uq9e.cjs} +299 -149
- package/package.json +1 -1
|
@@ -313,6 +313,16 @@ function runWithCache(cache, key, thunk) {
|
|
|
313
313
|
return value;
|
|
314
314
|
}
|
|
315
315
|
}
|
|
316
|
+
/**
 * Promises for `runWithCacheAsync` computations that have not settled yet,
 * keyed first by the target cache object and then by the serialized key.
 */
const pendingCacheComputations = /* @__PURE__ */ new WeakMap();
/**
 * Async version of `runWithCache`.
 *
 * Returns the cached value for `key` if present; otherwise awaits `thunk()`,
 * stores the result in `cache` and returns it. Concurrent calls with the same
 * key share one in-flight computation, so `thunk` runs at most once per key
 * until it settles. A rejected computation is not cached, so a later call
 * retries it.
 *
 * @param {Map<string, any>} cache - Cache keyed by `JSON.stringify(key)`.
 * @param {any} key - JSON-serializable cache key.
 * @param {() => any} thunk - Computes the value (may return a promise).
 * @returns {Promise<any>} The cached or freshly computed value.
 */
async function runWithCacheAsync(cache, key, thunk) {
  const keyStr = JSON.stringify(key);
  if (cache.has(keyStr)) return cache.get(keyStr);
  let pending = pendingCacheComputations.get(cache);
  if (pending === undefined) {
    pending = new Map();
    pendingCacheComputations.set(cache, pending);
  }
  const inFlight = pending.get(keyStr);
  if (inFlight !== undefined) return inFlight;
  const computation = (async () => {
    const value = await thunk();
    cache.set(keyStr, value);
    return value;
  })();
  // Registered before awaiting, so callers arriving while `thunk` runs share it.
  pending.set(keyStr, computation);
  try {
    return await computation;
  } finally {
    if (pending.get(keyStr) === computation) pending.delete(keyStr);
  }
}
|
|
316
326
|
|
|
317
327
|
//#endregion
|
|
318
328
|
//#region src/alu.ts
|
|
@@ -415,8 +425,23 @@ var AluExp = class AluExp {
|
|
|
415
425
|
this.src = src;
|
|
416
426
|
this.arg = arg;
|
|
417
427
|
if (AluGroup.RequiredFloat.has(op) && !isFloatDtype(dtype)) throw new TypeError(`Unsupported dtype for ${op}: ${dtype}`);
|
|
418
|
-
|
|
419
|
-
|
|
428
|
+
switch (op) {
|
|
429
|
+
case AluOp.Bitcast:
|
|
430
|
+
if (dtype === DType.Bool || src[0].dtype === DType.Bool || byteWidth(dtype) !== byteWidth(src[0].dtype)) throw new TypeError(`Bitcast from ${src[0].dtype} -> ${dtype}`);
|
|
431
|
+
break;
|
|
432
|
+
case AluOp.Threefry2x32:
|
|
433
|
+
if (dtype !== DType.Uint32 || src.some((x) => x.dtype !== DType.Uint32)) throw new TypeError("Threefry2x32 requires uint32 types");
|
|
434
|
+
break;
|
|
435
|
+
case AluOp.BitCombine:
|
|
436
|
+
if (src[0].dtype !== src[1].dtype || isFloatDtype(src[0].dtype)) throw new TypeError(`BitCombine[${arg}] requires matching integral dtype, got ${src[0].dtype} and ${src[1].dtype}`);
|
|
437
|
+
break;
|
|
438
|
+
case AluOp.BitShift:
|
|
439
|
+
if (src[0].dtype === DType.Bool || src[1].dtype === DType.Bool || isFloatDtype(src[0].dtype) || isFloatDtype(src[1].dtype)) throw new TypeError(`BitShift[${arg}] requires two integral, non-bool dtypes, got ${src[0].dtype} and ${src[1].dtype}`);
|
|
440
|
+
break;
|
|
441
|
+
case AluOp.BitInvert:
|
|
442
|
+
if (isFloatDtype(src[0].dtype)) throw new TypeError(`BitInvert requires an integral dtype, got ${src[0].dtype}`);
|
|
443
|
+
break;
|
|
444
|
+
}
|
|
420
445
|
}
|
|
421
446
|
static add(a, b) {
|
|
422
447
|
return new AluExp(AluOp.Add, a.dtype, [a, b]);
|
|
@@ -493,6 +518,12 @@ var AluExp = class AluExp {
|
|
|
493
518
|
c1
|
|
494
519
|
], mode);
|
|
495
520
|
}
|
|
521
|
+
static bitCombine(a, b, mode) {
|
|
522
|
+
return new AluExp(AluOp.BitCombine, a.dtype, [a, b], mode);
|
|
523
|
+
}
|
|
524
|
+
static bitShift(a, b, mode) {
|
|
525
|
+
return new AluExp(AluOp.BitShift, a.dtype, [a, b], mode);
|
|
526
|
+
}
|
|
496
527
|
static cmplt(a, b) {
|
|
497
528
|
return new AluExp(AluOp.Cmplt, DType.Bool, [a, b]);
|
|
498
529
|
}
|
|
@@ -965,6 +996,16 @@ var AluExp = class AluExp {
|
|
|
965
996
|
case AluOp.Mod: return x % y;
|
|
966
997
|
case AluOp.Min: return Math.min(x, y);
|
|
967
998
|
case AluOp.Max: return Math.max(x, y);
|
|
999
|
+
case AluOp.BitCombine: {
|
|
1000
|
+
let r;
|
|
1001
|
+
if (this.arg === "and") r = x & y;
|
|
1002
|
+
else if (this.arg === "or") r = x | y;
|
|
1003
|
+
else r = x ^ y;
|
|
1004
|
+
return this.dtype === DType.Int32 ? r | 0 : r >>> 0;
|
|
1005
|
+
}
|
|
1006
|
+
case AluOp.BitShift:
|
|
1007
|
+
if (this.arg === "shl") return this.dtype === DType.Int32 ? x << y | 0 : x << y >>> 0;
|
|
1008
|
+
return x >>> y;
|
|
968
1009
|
case AluOp.Cmplt: return Number(x < y);
|
|
969
1010
|
case AluOp.Cmpne: return Number(x != y);
|
|
970
1011
|
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
@@ -1086,6 +1127,18 @@ var AluExp = class AluExp {
|
|
|
1086
1127
|
}
|
|
1087
1128
|
if (BIN_SYM[node.op]) return `(${parts[0]} ${BIN_SYM[node.op]} ${parts[1]})`;
|
|
1088
1129
|
if (CMP_SYM[node.op]) return `(${parts[0]} ${CMP_SYM[node.op]} ${parts[1]})`;
|
|
1130
|
+
if (node.op === AluOp.BitCombine) {
|
|
1131
|
+
const sym = {
|
|
1132
|
+
and: "&",
|
|
1133
|
+
or: "|",
|
|
1134
|
+
xor: "^"
|
|
1135
|
+
}[node.arg];
|
|
1136
|
+
return `(${parts[0]} ${sym} ${parts[1]})`;
|
|
1137
|
+
}
|
|
1138
|
+
if (node.op === AluOp.BitShift) {
|
|
1139
|
+
const sym = node.arg === "shl" ? "<<" : ">>";
|
|
1140
|
+
return `(${parts[0]} ${sym} ${parts[1]})`;
|
|
1141
|
+
}
|
|
1089
1142
|
if (UNARY_SYM[node.op]) return `${UNARY_SYM[node.op]}${parts[0]}`;
|
|
1090
1143
|
if (node.op === AluOp.Cast) return `Cast<${node.dtype}>(${strip1(parts[0])})`;
|
|
1091
1144
|
if (node.op === AluOp.Bitcast) return `Bitcast<${node.dtype}>(${strip1(parts[0])})`;
|
|
@@ -1178,6 +1231,9 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1178
1231
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1179
1232
|
AluOp$1["Cast"] = "Cast";
|
|
1180
1233
|
AluOp$1["Bitcast"] = "Bitcast";
|
|
1234
|
+
AluOp$1["BitCombine"] = "BitCombine";
|
|
1235
|
+
AluOp$1["BitInvert"] = "BitInvert";
|
|
1236
|
+
AluOp$1["BitShift"] = "BitShift";
|
|
1181
1237
|
AluOp$1["Cmplt"] = "Cmplt";
|
|
1182
1238
|
AluOp$1["Cmpne"] = "Cmpne";
|
|
1183
1239
|
AluOp$1["Where"] = "Where";
|
|
@@ -1197,7 +1253,9 @@ const AluGroup = {
|
|
|
1197
1253
|
AluOp.Idiv,
|
|
1198
1254
|
AluOp.Mod,
|
|
1199
1255
|
AluOp.Min,
|
|
1200
|
-
AluOp.Max
|
|
1256
|
+
AluOp.Max,
|
|
1257
|
+
AluOp.BitCombine,
|
|
1258
|
+
AluOp.BitShift
|
|
1201
1259
|
]),
|
|
1202
1260
|
Unary: new Set([
|
|
1203
1261
|
AluOp.Sin,
|
|
@@ -1372,11 +1430,13 @@ var Reduction = class {
|
|
|
1372
1430
|
/**
 * Expression loading `indices` from global buffer `gid` through shape tracker `st`.
 *
 * Positions outside the tracker's validity mask read as 0. The `where` guard is
 * skipped when `valid.resolve()` is truthy (presumably: the mask folds to a
 * constant true).
 */
function accessorGlobal(dtype, gid, st, indices) {
  const [index, valid] = st.toAluExp(indices);
  const len = st.views[0].dataRange()[1];
  const load = AluExp.globalIndex(dtype, gid, len, index);
  return valid.resolve() ? load : AluExp.where(valid, load, AluExp.const(dtype, 0));
}
|
|
1377
1436
|
/**
 * Expression for accessing `indices` in an array recipe with variable "idx".
 *
 * `exp` is evaluated with "idx" replaced by the tracker's index expression.
 * Positions outside the validity mask read as 0 unless `valid.resolve()` is
 * truthy, in which case no guard is emitted.
 */
function accessorAluExp(exp, st, indices) {
  const [index, valid] = st.toAluExp(indices);
  const value = exp.substitute({ idx: index });
  if (valid.resolve()) return value;
  return AluExp.where(valid, value, AluExp.const(exp.dtype, 0));
}
|
|
1382
1442
|
function threefry2x32(k0, k1, c0, c1) {
|
|
@@ -3158,6 +3218,147 @@ function wasm_threefry2x32(cg) {
|
|
|
3158
3218
|
});
|
|
3159
3219
|
}
|
|
3160
3220
|
|
|
3221
|
+
//#endregion
|
|
3222
|
+
//#region src/backend/wasm/parallel.ts
|
|
3223
|
+
/**
 * Whether parallel dispatch is possible here: both the `SharedArrayBuffer`
 * and the `Worker` globals must be defined.
 */
function hasSharedArrayBuffer() {
  const sharedMemoryAvailable = typeof SharedArrayBuffer !== "undefined";
  const workersAvailable = typeof Worker !== "undefined";
  return sharedMemoryAvailable && workersAvailable;
}
|
|
3227
|
+
/** Minimum number of elements handed to each worker before another worker is used. */
const MIN_ELEMS_PER_THREAD = 256;
/**
 * Source of the pool's worker script, loaded from a Blob URL.
 *
 * Protocol:
 * - `{ type: "init", memory }` stores the memory and replies `{ type: "ready" }`.
 * - `{ module, ptrs, begin, end }` instantiates `module` against that memory
 *   (cached per module), calls its `kernel` export as
 *   `kernel(...ptrs, begin, end)` and replies `{ type: "done", ok, error? }`.
 *
 * The cache is updated only after instantiation succeeds. This prevents a
 * failed instantiation from pairing the new module with the previous module's
 * `kernel`, which would silently run the wrong code on the next dispatch.
 */
const WORKER_SOURCE = `
let memory = null;
let cachedModule = null;
let cachedFunc = null;

self.onmessage = (e) => {
const msg = e.data;
if (msg.type === "init") {
memory = msg.memory;
postMessage({ type: "ready" });
return;
}
try {
const { module, ptrs, begin, end } = msg;
if (module !== cachedModule) {
const instance = new WebAssembly.Instance(module, { env: { memory } });
cachedFunc = instance.exports.kernel;
cachedModule = module;
}
cachedFunc(...ptrs, begin, end);
postMessage({ type: "done", ok: true });
} catch (err) {
postMessage({ type: "done", ok: false, error: String(err) });
}
};
`;
|
|
3254
|
+
/**
 * Pool of Web Workers for parallel WASM kernel dispatch.
 *
 * All workers receive the same `WebAssembly.Memory`. Dispatches are
 * serialized: each one splits `[0, size)` into chunks, runs the chunks
 * concurrently on the workers, and only then lets the next dispatch start.
 * Finished dispatches are counted by a monotonically increasing epoch that
 * callers can wait on. A failed dispatch rejects the waiters of its epoch.
 */
var WasmWorkerPool = class {
  #memory;
  #numWorkers;
  #workers = [];
  /** Settles once every worker has acknowledged `init`; rejects if one fails to load. */
  #ready = Promise.resolve();
  /** Serializes dispatches so concurrent read() calls don't clobber onmessage. */
  #queue = Promise.resolve();
  /** Number of dispatches that have finished, successfully or not. */
  #epoch = 0n;
  /** Number of dispatches that have been enqueued. */
  #epochEnd = 0n;
  /** Waiters (`{ resolve, reject }`) keyed by the epoch they wait for. */
  #hooks = /* @__PURE__ */ new Map();
  /** Errors of failed dispatches, keyed by epoch (only failures are stored). */
  #errors = /* @__PURE__ */ new Map();

  /**
   * @param {WebAssembly.Memory} memory - Memory posted to every worker on init.
   * @param {number} numWorkers - Number of workers spawned on the first dispatch.
   * @throws {Error} If `numWorkers` is not positive.
   */
  constructor(memory, numWorkers) {
    if (numWorkers <= 0) throw new Error("numWorkers must be positive");
    this.#memory = memory;
    this.#numWorkers = numWorkers;
  }

  /** Number of dispatches that have finished so far. */
  get epoch() {
    return this.#epoch;
  }

  /**
   * Wait until dispatch number `target` has finished.
   *
   * @returns {Promise<void>} Resolves if that dispatch succeeded, rejects with its error otherwise.
   */
  waitForEpoch(target) {
    if (target <= this.#epoch) {
      const error = this.#errors.get(target);
      return error === undefined ? Promise.resolve() : Promise.reject(error);
    }
    return new Promise((resolve, reject) => {
      const waiter = { resolve, reject };
      const hooks = this.#hooks.get(target);
      if (hooks) hooks.push(waiter);
      else this.#hooks.set(target, [waiter]);
    });
  }

  /** Spawn the workers and send them the memory (first call only). */
  #ensureInit() {
    if (this.#workers.length > 0) return;
    const blob = new Blob([WORKER_SOURCE], { type: "application/javascript" });
    const url = URL.createObjectURL(blob);
    const readyPromises = [];
    try {
      for (let i = 0; i < this.#numWorkers; i++) {
        const worker = new Worker(url, { type: "module" });
        this.#workers.push(worker);
        readyPromises.push(new Promise((resolve, reject) => {
          worker.onmessage = () => resolve();
          worker.onerror = (e) => reject(new Error(e.message || "Worker failed to load"));
        }));
        worker.postMessage({
          type: "init",
          memory: this.#memory
        });
      }
    } catch (err) {
      // Don't leave a half-initialized pool behind; a later dispatch retries from scratch.
      for (const worker of this.#workers) worker.terminate();
      this.#workers = [];
      URL.revokeObjectURL(url);
      throw err;
    }
    // The blob URL must stay valid until every worker has loaded (or failed to).
    this.#ready = Promise.all(readyPromises).finally(() => URL.revokeObjectURL(url));
  }

  /**
   * Dispatch a kernel across multiple workers.
   *
   * Returns an epoch that can be used to wait for the ongoing work to complete,
   * which is guaranteed to be monotonically increasing. Errors (including a
   * failure to start the workers) are reported through `waitForEpoch()`.
   */
  dispatch(module, ptrs, size) {
    this.#ensureInit();
    const epoch = ++this.#epochEnd;
    const result = this.#queue
      .then(() => this.#ready)
      .then(() => this.#dispatchNow(module, ptrs, size));
    this.#queue = result
      .then(() => null, (err) => (err instanceof Error ? err : new Error(String(err))))
      .then((error) => this.#finishEpoch(error));
    return epoch;
  }

  /** Mark the oldest unfinished dispatch as done (`error` null on success) and notify its waiters. */
  #finishEpoch(error) {
    const epoch = ++this.#epoch;
    if (error !== null) this.#errors.set(epoch, error);
    const hooks = this.#hooks.get(epoch);
    if (!hooks) return;
    this.#hooks.delete(epoch);
    for (const { resolve, reject } of hooks) {
      if (error === null) resolve();
      else reject(error);
    }
  }

  /** Split `[0, size)` into chunks whose starts are multiples of 16 and run them on the workers. */
  async #dispatchNow(module, ptrs, size) {
    if (size === 0) return;
    const n = Math.min(this.#workers.length, Math.ceil(size / MIN_ELEMS_PER_THREAD));
    const chunkSize = Math.ceil(size / n / 16) * 16;
    const promises = [];
    for (let i = 0; i < n; i++) {
      const begin = i * chunkSize;
      if (begin >= size) break;
      const end = Math.min(begin + chunkSize, size);
      const worker = this.#workers[i];
      promises.push(new Promise((resolve, reject) => {
        worker.onmessage = (e) => {
          if (e.data.ok) resolve();
          else reject(new Error(`Worker error: ${e.data.error}`));
        };
        // A worker that dies mid-kernel never replies; without this the queue would hang forever.
        worker.onerror = (e) => reject(new Error(`Worker error: ${e.message || "worker crashed"}`));
        worker.postMessage({
          module,
          ptrs,
          begin,
          end
        });
      }));
    }
    await Promise.all(promises);
  }
};
|
|
3351
|
+
/**
 * Try to create a worker pool over `memory`.
 *
 * Returns null when `hasSharedArrayBuffer()` is false or the pool constructor
 * throws. Uses `navigator.hardwareConcurrency` workers when it is known and
 * non-zero, otherwise 4.
 */
function createWorkerPool(memory) {
  if (!hasSharedArrayBuffer()) return null;
  try {
    const reported = typeof navigator === "undefined" ? undefined : navigator.hardwareConcurrency;
    return new WasmWorkerPool(memory, Math.max(1, reported || 4));
  } catch {
    return null;
  }
}
|
|
3361
|
+
|
|
3161
3362
|
//#endregion
|
|
3162
3363
|
//#region src/backend/wasm/wasmblr.ts
|
|
3163
3364
|
/**
|
|
@@ -3495,7 +3696,7 @@ var CodeGenerator = class {
|
|
|
3495
3696
|
concat(importSectionBytes, encodeString(this.memory.aString));
|
|
3496
3697
|
concat(importSectionBytes, encodeString(this.memory.bString));
|
|
3497
3698
|
importSectionBytes.push(2);
|
|
3498
|
-
if (this.memory.
|
|
3699
|
+
if (this.memory.max) {
|
|
3499
3700
|
if (this.memory.isShared) importSectionBytes.push(3);
|
|
3500
3701
|
else importSectionBytes.push(1);
|
|
3501
3702
|
concat(importSectionBytes, encodeUnsigned(this.memory.min));
|
|
@@ -3902,6 +4103,8 @@ var I32x4 = class extends V128 {
|
|
|
3902
4103
|
min_u = VECTOR_OP("min_u", 183, ["v128", "v128"], "v128");
|
|
3903
4104
|
max_s = VECTOR_OP("max_s", 184, ["v128", "v128"], "v128");
|
|
3904
4105
|
max_u = VECTOR_OP("max_u", 185, ["v128", "v128"], "v128");
|
|
4106
|
+
trunc_sat_f32x4_s = VECTOR_OP("trunc_sat_f32x4_s", 248, ["v128"], "v128");
|
|
4107
|
+
trunc_sat_f32x4_u = VECTOR_OP("trunc_sat_f32x4_u", 249, ["v128"], "v128");
|
|
3905
4108
|
};
|
|
3906
4109
|
var F32x4 = class extends V128 {
|
|
3907
4110
|
splat = VECTOR_OP("splat", 19, ["f32"], "v128");
|
|
@@ -3928,10 +4131,333 @@ var F32x4 = class extends V128 {
|
|
|
3928
4131
|
max = VECTOR_OP("max", 233, ["v128", "v128"], "v128");
|
|
3929
4132
|
pmin = VECTOR_OP("pmin", 234, ["v128", "v128"], "v128");
|
|
3930
4133
|
pmax = VECTOR_OP("pmax", 235, ["v128", "v128"], "v128");
|
|
4134
|
+
convert_i32x4_s = VECTOR_OP("convert_i32x4_s", 250, ["v128"], "v128");
|
|
4135
|
+
convert_i32x4_u = VECTOR_OP("convert_i32x4_u", 251, ["v128"], "v128");
|
|
3931
4136
|
};
|
|
3932
4137
|
|
|
3933
4138
|
//#endregion
|
|
3934
4139
|
//#region src/backend/wasm.ts
|
|
4140
|
+
/**
 * SIMD version of translateExp: emits v128 (f32x4 or i32x4) instructions instead of scalar.
 * gidx always steps by 4. strideMap classifies each GlobalIndex as broadcast/contiguous/gather.
 *
 * Every node leaves one v128 on the stack. Nodes referenced more than once
 * inside `exp` are stored in a v128 local on first emission and re-read from it
 * afterwards. Comparison results are normalized to 0/1 per lane. Index
 * expressions under a GlobalIndex are emitted with the scalar `translateExp`.
 *
 * @throws {UnsupportedOpError} For a Cmplt whose operand dtype has no SIMD compare.
 * @throws {Error} For Variable/Special nodes outside an index, or ops without a SIMD lowering.
 */
function translateExpSimd(cg, funcs, exp, ctx, strideMap) {
  const isIntLike = (dt) => dt === DType.Int32 || dt === DType.Uint32 || dt === DType.Bool;

  // Reference counts per node; nodes shared within the DAG get cached in locals.
  const references = /* @__PURE__ */ new Map();
  const seen = /* @__PURE__ */ new Set();
  const countReferences = (node) => {
    references.set(node, (references.get(node) ?? 0) + 1);
    if (seen.has(node)) return;
    seen.add(node);
    for (const child of node.src) countReferences(child);
  };

  /** Turn the all-ones / all-zeros lane mask on the stack into 0/1 lanes. */
  const emitMaskToBool = () => {
    cg.i32.const(1);
    cg.i32x4.splat();
    cg.v128.and();
  };

  /** With an i32 index on the stack, save it in `local` and leave `index < len ? index : 0` (unsigned). */
  const emitIndexOrZero = (local, len) => {
    cg.local.tee(local);
    cg.i32.const(0);
    cg.local.get(local);
    cg.i32.const(len);
    cg.i32.lt_u();
    cg.select();
  };

  /** Scale the i32 element index on the stack by byteWidth(dtype) and add the value of local `gid`. */
  const emitByteAddress = (dtype, gid) => {
    cg.i32.const(byteWidth(dtype));
    cg.i32.mul();
    cg.local.get(gid);
    cg.i32.add();
  };

  /**
   * Gather: evaluate the scalar index once per lane with the gidx local
   * temporarily set to gidx + lane, load each element into its lane, then
   * restore gidx.
   */
  const emitGather = (indexSubtree, gid, len, dtype, isInt) => {
    const steppingLocal = ctx["gidx"];
    const origValue = cg.local.declare(cg.i32);
    cg.local.get(steppingLocal);
    cg.local.set(origValue);
    if (isInt) {
      cg.i32.const(0);
      cg.i32x4.splat();
    } else {
      cg.f32.const(0);
      cg.f32x4.splat();
    }
    const vec = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
    cg.local.set(vec);
    const idx = cg.local.declare(cg.i32);
    const scalarVal = cg.local.declare(isInt ? cg.i32 : cg.f32);
    for (let lane = 0; lane < SIMD_LANES; lane++) {
      cg.local.get(origValue);
      if (lane > 0) {
        cg.i32.const(lane);
        cg.i32.add();
      }
      cg.local.set(steppingLocal);
      translateExp(cg, funcs, indexSubtree, ctx);
      emitIndexOrZero(idx, len);
      emitByteAddress(dtype, gid);
      if (isInt) cg.i32.load(2);
      else cg.f32.load(2);
      cg.local.set(scalarVal);
      cg.local.get(vec);
      cg.local.get(scalarVal);
      if (isInt) cg.i32x4.replace_lane(lane);
      else cg.f32x4.replace_lane(lane);
      cg.local.set(vec);
    }
    cg.local.get(origValue);
    cg.local.set(steppingLocal);
    cg.local.get(vec);
  };

  /** Emit a 4-lane load for a GlobalIndex node using the strategy recorded in strideMap. */
  const emitGlobalIndex = (node, isInt) => {
    const { src, arg, dtype } = node;
    const [gid, len] = arg;
    const indexSubtree = src[0];
    const stride = strideMap.get(node) ?? GATHER;
    if (stride.kind === "contiguous") {
      // One v128 load starting at the lane-0 index, clamped (unsigned min) to max(len - SIMD_LANES, 0).
      translateExp(cg, funcs, indexSubtree, ctx);
      const maxIdx = Math.max(len - SIMD_LANES, 0);
      const wideIdx = cg.local.declare(cg.i32);
      cg.local.set(wideIdx);
      cg.local.get(wideIdx);
      cg.i32.const(maxIdx);
      cg.local.get(wideIdx);
      cg.i32.const(maxIdx);
      cg.i32.lt_u();
      cg.select();
      emitByteAddress(dtype, gid);
      if (isInt) cg.i32x4.load(4);
      else cg.f32x4.load(4);
    } else if (stride.kind === "broadcast") {
      // All lanes share one index: a scalar load splatted to every lane.
      translateExp(cg, funcs, indexSubtree, ctx);
      const local = cg.local.declare(cg.i32);
      emitIndexOrZero(local, len);
      emitByteAddress(dtype, gid);
      if (isInt) {
        cg.i32.load(2);
        cg.i32x4.splat();
      } else {
        cg.f32.load(2);
        cg.f32x4.splat();
      }
    } else {
      emitGather(indexSubtree, gid, len, dtype, isInt);
    }
  };

  const expContext = /* @__PURE__ */ new Map();
  const gen = (node) => {
    if (expContext.has(node)) return cg.local.get(expContext.get(node));
    const { op, src, arg, dtype } = node;
    const isInt = isIntLike(dtype);
    const isSigned = dtype === DType.Int32;
    switch (op) {
      case AluOp.Add:
        gen(src[0]);
        gen(src[1]);
        if (isInt) cg.i32x4.add();
        else cg.f32x4.add();
        break;
      case AluOp.Sub:
        gen(src[0]);
        gen(src[1]);
        if (isInt) cg.i32x4.sub();
        else cg.f32x4.sub();
        break;
      case AluOp.Mul:
        gen(src[0]);
        gen(src[1]);
        if (isInt) cg.i32x4.mul();
        else cg.f32x4.mul();
        break;
      case AluOp.Min:
        gen(src[0]);
        gen(src[1]);
        if (!isInt) cg.f32x4.min();
        else if (isSigned) cg.i32x4.min_s();
        else cg.i32x4.min_u();
        break;
      case AluOp.Max:
        gen(src[0]);
        gen(src[1]);
        if (!isInt) cg.f32x4.max();
        else if (isSigned) cg.i32x4.max_s();
        else cg.i32x4.max_u();
        break;
      case AluOp.Sqrt:
        gen(src[0]);
        cg.f32x4.sqrt();
        break;
      case AluOp.Floor:
        gen(src[0]);
        cg.f32x4.floor();
        break;
      case AluOp.Ceil:
        gen(src[0]);
        cg.f32x4.ceil();
        break;
      case AluOp.Const:
        if (isInt) {
          cg.i32.const(arg);
          cg.i32x4.splat();
        } else {
          cg.f32.const(arg);
          cg.f32x4.splat();
        }
        break;
      case AluOp.Cast: {
        gen(src[0]);
        const srcDtype = src[0].dtype;
        const srcIsInt = isIntLike(srcDtype);
        if (isInt && !srcIsInt) {
          if (isSigned) cg.i32x4.trunc_sat_f32x4_s();
          else cg.i32x4.trunc_sat_f32x4_u();
        } else if (!isInt && srcIsInt) {
          // Int32 and Bool (0/1) lanes convert as signed; only Uint32 needs the unsigned form.
          if (srcDtype === DType.Uint32) cg.f32x4.convert_i32x4_u();
          else cg.f32x4.convert_i32x4_s();
        }
        // Otherwise source and target share a lane representation: no instruction is emitted.
        break;
      }
      case AluOp.Cmplt: {
        gen(src[0]);
        gen(src[1]);
        const srcDtype = src[0].dtype;
        if (srcDtype === DType.Float32) cg.f32x4.lt();
        else if (srcDtype === DType.Int32) cg.i32x4.lt_s();
        // Bool lanes hold 0/1, so an unsigned compare orders false < true.
        else if (srcDtype === DType.Uint32 || srcDtype === DType.Bool) cg.i32x4.lt_u();
        else throw new UnsupportedOpError(op, dtype, "wasm");
        emitMaskToBool();
        break;
      }
      case AluOp.Cmpne:
        gen(src[0]);
        gen(src[1]);
        if (src[0].dtype === DType.Float32) cg.f32x4.ne();
        else cg.i32x4.ne();
        emitMaskToBool();
        break;
      case AluOp.Where:
        // v128.bitselect(ifTrue, ifFalse, mask) takes bits from ifTrue where the mask is set.
        gen(src[1]);
        gen(src[2]);
        gen(src[0]);
        cg.i32.const(0);
        cg.i32x4.splat();
        cg.i32x4.ne();
        cg.v128.bitselect();
        break;
      case AluOp.Variable:
      case AluOp.Special:
        throw new Error(`translateExpSimd: unexpected ${op}(${arg})`);
      case AluOp.GlobalIndex:
        emitGlobalIndex(node, isInt);
        break;
      default:
        throw new Error(`translateExpSimd: unsupported op ${op}`);
    }
    if ((references.get(node) ?? 0) > 1) {
      const local = cg.local.declare(isInt ? cg.i32x4 : cg.f32x4);
      cg.local.tee(local);
      expContext.set(node, local);
    }
  };

  countReferences(exp);
  gen(exp);
}
|
|
4336
|
+
/** Number of SIMD lanes (f32x4 / i32x4 = 4 lanes). */
const SIMD_LANES = 4;
/** True if `exp` or any node below it is the `gidx` Special variable (pre-order, left to right). */
function referencesGidx(exp) {
  const stack = [exp];
  while (stack.length > 0) {
    const node = stack.pop();
    if (node.op === AluOp.Special && node.arg[0] === "gidx") return true;
    for (let i = node.src.length - 1; i >= 0; i--) stack.push(node.src[i]);
  }
  return false;
}
|
|
4342
|
+
/**
 * When tileSize > N but doesn't divide evenly, the last group before the
 * inner reset is shorter than N — a SIMD group could straddle it.
 *
 * @param {number} tileSize - Run length of the inner index pattern (Infinity if unbounded).
 * @param {number} N - Divisor/modulus applied on top of that pattern.
 * @returns {boolean} True when the pair cannot be classified safely.
 */
function hasFragmentRisk(tileSize, N) {
  return Number.isFinite(tileSize) && tileSize > N && tileSize % N !== 0;
}
/** Stride classification for "no recognized pattern": each lane is loaded separately. */
const GATHER = { kind: "gather" };
|
|
4348
|
+
/**
 * Classify how a GlobalIndex's index expression behaves as gidx increments.
 *
 * Returns one of:
 * - `{ kind: "broadcast", tileSize }`: the index is identical for every gidx
 *   within each run of `tileSize` consecutive values.
 * - `{ kind: "contiguous", tileSize, shift }`: within each run of `tileSize`
 *   values the index increases by one per gidx. `shift` is the constant added
 *   on top of the underlying pattern, or `NaN` when the offset is not a
 *   compile-time integer constant.
 * - `GATHER`: no pattern recognized.
 *
 * `(x + s) % N` and `(x + s) / N` keep their wrap points aligned with the
 * underlying pattern only when `s` is a non-negative multiple of `N`. Any
 * other offset moves the wrap inside a run, so those cases fall back to
 * GATHER.
 */
function analyzeStride(exp) {
  if (!referencesGidx(exp)) return {
    kind: "broadcast",
    tileSize: Infinity
  };
  if (exp.op === AluOp.Special && exp.arg[0] === "gidx") return {
    kind: "contiguous",
    tileSize: Infinity,
    shift: 0
  };
  if ((exp.op === AluOp.Idiv || exp.op === AluOp.Mod) && exp.src[1].op === AluOp.Const) {
    const N = exp.src[1].arg;
    const inner = analyzeStride(exp.src[0]);
    if (inner.kind === "broadcast") return inner;
    if (inner.kind !== "contiguous") return GATHER;
    if (hasFragmentRisk(inner.tileSize, N)) return GATHER;
    const { shift } = inner;
    if (!(Number.isInteger(shift) && shift >= 0 && shift % N === 0)) return GATHER;
    const tileSize = Math.min(inner.tileSize, N);
    if (exp.op === AluOp.Idiv) return {
      kind: "broadcast",
      tileSize
    };
    // With an aligned shift, (x + s) % N === x % N, which starts every run at 0.
    return {
      kind: "contiguous",
      tileSize,
      shift: 0
    };
  }
  if (exp.op === AluOp.Mul) {
    for (let i = 0; i < 2; i++) if (exp.src[i].op === AluOp.Const) {
      const inner = analyzeStride(exp.src[1 - i]);
      return inner.kind === "broadcast" ? inner : GATHER;
    }
  }
  if (exp.op === AluOp.Add) {
    const lhsHasGidx = referencesGidx(exp.src[0]);
    const rhsHasGidx = referencesGidx(exp.src[1]);
    if (lhsHasGidx !== rhsHasGidx) {
      const varying = lhsHasGidx ? exp.src[0] : exp.src[1];
      const uniform = lhsHasGidx ? exp.src[1] : exp.src[0];
      const inner = analyzeStride(varying);
      // Adding a gidx-independent value keeps broadcast runs constant; GATHER stays GATHER.
      if (inner.kind !== "contiguous") return inner;
      const offset = uniform.op === AluOp.Const && Number.isInteger(uniform.arg) ? uniform.arg : NaN;
      return {
        kind: "contiguous",
        tileSize: inner.tileSize,
        shift: inner.shift + offset
      };
    }
  }
  return GATHER;
}
|
|
4397
|
+
/** Ops that have direct SIMD (f32x4) instruction variants. */
const simdF32Ops = new Set([
  "Add",
  "Sub",
  "Mul",
  "Floor",
  "Ceil",
  "Min",
  "Max",
  "Sqrt",
  "Cast",
  "Where",
  "Const",
  "GlobalIndex"
].map((name) => AluOp[name]));
/** Ops that have direct SIMD (i32x4) instruction variants. */
const simdI32Ops = new Set([
  "Add",
  "Sub",
  "Mul",
  "Min",
  "Max",
  "Cast",
  "Where",
  "Const",
  "GlobalIndex"
].map((name) => AluOp[name]));
/** Ops that produce Bool (i32x4 bitmask) in SIMD. */
const simdBoolOps = new Set([
  "Cmplt",
  "Cmpne",
  "Const",
  "GlobalIndex"
].map((name) => AluOp[name]));
|
|
4431
|
+
/**
 * Check if a kernel is eligible for SIMD codegen.
 *
 * A kernel qualifies when:
 * - size >= 4 (need at least 4 elements for a SIMD group)
 * - For reductions: the reduction op has a SIMD variant for its dtype
 * - All nodes have a supported dtype (f32, i32, u32, bool) with SIMD variants
 *
 * Children of GlobalIndex nodes (the scalar index computation) are not checked.
 */
function isSimdEligible(tunedExp, kernel) {
  if (kernel.size < SIMD_LANES) return false;
  const { reduction } = kernel;
  if (reduction && !simdSupportedOpsForDtype(reduction.dtype)?.has(reduction.op)) return false;
  const visited = /* @__PURE__ */ new Set();
  const pending = [tunedExp];
  while (pending.length > 0) {
    const node = pending.pop();
    if (visited.has(node)) continue;
    visited.add(node);
    if (!simdSupportedOpsForDtype(node.dtype)?.has(node.op)) return false;
    if (node.op !== AluOp.GlobalIndex) pending.push(...node.src);
  }
  return true;
}
|
|
4455
|
+
/** Set of ops with a SIMD lowering for values of `dtype`, or null when the dtype has none. */
function simdSupportedOpsForDtype(dtype) {
  switch (dtype) {
    case DType.Float32:
      return simdF32Ops;
    case DType.Int32:
    case DType.Uint32:
      return simdI32Ops;
    case DType.Bool:
      return simdBoolOps;
    default:
      return null;
  }
}
|
|
3935
4461
|
/**
 * Compiled WebAssembly modules shared by prepareKernel and prepareKernelSync.
 * Entries are stored through runWithCache / runWithCacheAsync, keyed by the
 * JSON-stringified `FpHash` of the kernel.
 */
const moduleCache = /* @__PURE__ */ new Map();
|
|
3936
4462
|
/** Backend that compiles into WebAssembly bytecode for immediate execution. */
|
|
3937
4463
|
var WasmBackend = class {
|
|
@@ -3941,11 +4467,18 @@ var WasmBackend = class {
|
|
|
3941
4467
|
#nextSlot;
|
|
3942
4468
|
#allocator;
|
|
3943
4469
|
#buffers;
|
|
4470
|
+
#workerPool;
|
|
4471
|
+
#pendingWork = /* @__PURE__ */ new Map();
|
|
3944
4472
|
constructor() {
|
|
3945
|
-
this.#memory = new WebAssembly.Memory({
|
|
4473
|
+
this.#memory = hasSharedArrayBuffer() ? new WebAssembly.Memory({
|
|
4474
|
+
initial: 0,
|
|
4475
|
+
maximum: 65536,
|
|
4476
|
+
shared: true
|
|
4477
|
+
}) : new WebAssembly.Memory({ initial: 0 });
|
|
3946
4478
|
this.#allocator = new WasmAllocator(this.#memory);
|
|
3947
4479
|
this.#nextSlot = 1;
|
|
3948
4480
|
this.#buffers = /* @__PURE__ */ new Map();
|
|
4481
|
+
this.#workerPool = createWorkerPool(this.#memory);
|
|
3949
4482
|
}
|
|
3950
4483
|
malloc(size, initialData) {
|
|
3951
4484
|
const ptr = this.#allocator.malloc(size);
|
|
@@ -3976,40 +4509,65 @@ var WasmBackend = class {
|
|
|
3976
4509
|
}
|
|
3977
4510
|
}
|
|
3978
4511
|
async read(slot, start, count) {
|
|
3979
|
-
|
|
4512
|
+
const epoch = this.#pendingWork.get(slot);
|
|
4513
|
+
if (epoch) await this.#workerPool.waitForEpoch(epoch);
|
|
4514
|
+
return this.#readData(slot, start, count);
|
|
3980
4515
|
}
|
|
3981
4516
|
readSync(slot, start, count) {
|
|
4517
|
+
const epoch = this.#pendingWork.get(slot);
|
|
4518
|
+
if (epoch && this.#workerPool.epoch < epoch) throw new Error("cannot read synchronously from a slot with async work");
|
|
4519
|
+
return this.#readData(slot, start, count);
|
|
4520
|
+
}
|
|
4521
|
+
#readData(slot, start, count) {
|
|
3982
4522
|
const buffer = this.#getBuffer(slot);
|
|
3983
4523
|
if (start === void 0) start = 0;
|
|
3984
4524
|
if (count === void 0) count = buffer.byteLength - start;
|
|
3985
|
-
return buffer.slice(start, start + count);
|
|
4525
|
+
if (buffer.buffer instanceof SharedArrayBuffer) return new Uint8Array(buffer.slice(start, start + count));
|
|
4526
|
+
else return buffer.slice(start, start + count);
|
|
3986
4527
|
}
|
|
3987
4528
|
async prepareKernel(kernel) {
|
|
3988
|
-
|
|
4529
|
+
const kernelHash = FpHash.hash(kernel);
|
|
4530
|
+
const module = await runWithCacheAsync(moduleCache, kernelHash.toString(), () => WebAssembly.compile(codegenWasm(kernel)));
|
|
4531
|
+
return new Executable(kernel, {
|
|
4532
|
+
module,
|
|
4533
|
+
parallel: this.#workerPool !== null
|
|
4534
|
+
});
|
|
3989
4535
|
}
|
|
3990
4536
|
prepareKernelSync(kernel) {
|
|
3991
4537
|
const kernelHash = FpHash.hash(kernel);
|
|
3992
|
-
const module = runWithCache(moduleCache, kernelHash.toString(), () =>
|
|
3993
|
-
|
|
3994
|
-
|
|
4538
|
+
const module = runWithCache(moduleCache, kernelHash.toString(), () => new WebAssembly.Module(codegenWasm(kernel)));
|
|
4539
|
+
return new Executable(kernel, {
|
|
4540
|
+
module,
|
|
4541
|
+
parallel: false
|
|
3995
4542
|
});
|
|
3996
|
-
return new Executable(kernel, { module });
|
|
3997
4543
|
}
|
|
3998
4544
|
async prepareRoutine(routine) {
|
|
3999
4545
|
return this.prepareRoutineSync(routine);
|
|
4000
4546
|
}
|
|
4001
4547
|
prepareRoutineSync(routine) {
|
|
4002
|
-
return new Executable(routine,
|
|
4548
|
+
return new Executable(routine, {
|
|
4549
|
+
module: void 0,
|
|
4550
|
+
parallel: false
|
|
4551
|
+
});
|
|
4003
4552
|
}
|
|
4004
4553
|
dispatch(exe, inputs, outputs) {
|
|
4005
4554
|
const tracing = isTracing();
|
|
4006
4555
|
const start = tracing ? performance.now() : 0;
|
|
4007
4556
|
if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
4008
4557
|
else {
|
|
4009
|
-
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
4010
|
-
const func = instance.exports.kernel;
|
|
4011
4558
|
const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
|
|
4012
|
-
|
|
4559
|
+
if (exe.data.parallel && this.#workerPool) {
|
|
4560
|
+
const epoch = this.#workerPool.dispatch(exe.data.module, ptrs, exe.source.size);
|
|
4561
|
+
for (const slot of outputs) this.#pendingWork.set(slot, epoch);
|
|
4562
|
+
} else {
|
|
4563
|
+
if (inputs.some((slot) => {
|
|
4564
|
+
const epoch = this.#pendingWork.get(slot);
|
|
4565
|
+
return epoch && this.#workerPool.epoch < epoch;
|
|
4566
|
+
})) throw new Error("cannot dispatch synchronously with pending async work");
|
|
4567
|
+
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
4568
|
+
const func = instance.exports.kernel;
|
|
4569
|
+
func(...ptrs, 0, exe.source.size);
|
|
4570
|
+
}
|
|
4013
4571
|
}
|
|
4014
4572
|
if (tracing) {
|
|
4015
4573
|
const info = traceSourceInfo(exe.source);
|
|
@@ -4022,12 +4580,36 @@ var WasmBackend = class {
|
|
|
4022
4580
|
return new Uint8Array(this.#memory.buffer, buffer.ptr, buffer.size);
|
|
4023
4581
|
}
|
|
4024
4582
|
};
|
|
4583
|
+
/** Emit a runtime guard: enter the if-block only when [begin, end) is SIMD-aligned. */
|
|
4584
|
+
function emitAlignmentGuard(cg, paramBegin, paramEnd) {
|
|
4585
|
+
const mask = SIMD_LANES - 1;
|
|
4586
|
+
cg.local.get(paramEnd);
|
|
4587
|
+
cg.local.get(paramBegin);
|
|
4588
|
+
cg.i32.sub();
|
|
4589
|
+
cg.i32.const(mask);
|
|
4590
|
+
cg.i32.and();
|
|
4591
|
+
cg.i32.eqz();
|
|
4592
|
+
cg.local.get(paramBegin);
|
|
4593
|
+
cg.i32.const(mask);
|
|
4594
|
+
cg.i32.and();
|
|
4595
|
+
cg.i32.eqz();
|
|
4596
|
+
cg.i32.and();
|
|
4597
|
+
cg.if(cg.void);
|
|
4598
|
+
}
|
|
4025
4599
|
function codegenWasm(kernel) {
|
|
4026
4600
|
const tune = tuneNullopt(kernel);
|
|
4027
4601
|
const re = kernel.reduction;
|
|
4028
4602
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
4603
|
+
const useSimd = isSimdEligible(tune.exp, kernel);
|
|
4604
|
+
const bufferStrides = /* @__PURE__ */ new Map();
|
|
4605
|
+
if (useSimd) tune.exp.collect((e) => e.op === AluOp.GlobalIndex).forEach((gi) => {
|
|
4606
|
+
const result = analyzeStride(gi.src[0]);
|
|
4607
|
+
if (result.kind !== "gather" && (result.tileSize < SIMD_LANES || isFinite(result.tileSize) && result.tileSize % SIMD_LANES !== 0)) bufferStrides.set(gi, GATHER);
|
|
4608
|
+
else bufferStrides.set(gi, result);
|
|
4609
|
+
});
|
|
4029
4610
|
const cg = new CodeGenerator();
|
|
4030
4611
|
cg.memory.import("env", "memory");
|
|
4612
|
+
if (hasSharedArrayBuffer()) cg.memory.pages(0, 65536).shared(true);
|
|
4031
4613
|
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
4032
4614
|
const funcs = {};
|
|
4033
4615
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
@@ -4039,12 +4621,127 @@ function codegenWasm(kernel) {
|
|
|
4039
4621
|
if (distinctOps.has(AluOp.Erf)) funcs.erf = wasm_erf(cg, funcs.exp);
|
|
4040
4622
|
if (distinctOps.has(AluOp.Erfc)) funcs.erfc = wasm_erfc(cg, funcs.exp);
|
|
4041
4623
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
4042
|
-
const
|
|
4624
|
+
const paramBegin = kernel.nargs + 1;
|
|
4625
|
+
const paramEnd = kernel.nargs + 2;
|
|
4626
|
+
const kernelFunc = cg.function(rep(kernel.nargs + 3, cg.i32), [], () => {
|
|
4043
4627
|
const gidx = cg.local.declare(cg.i32);
|
|
4628
|
+
cg.local.get(paramBegin);
|
|
4629
|
+
cg.local.set(gidx);
|
|
4630
|
+
if (useSimd) {
|
|
4631
|
+
emitAlignmentGuard(cg, paramBegin, paramEnd);
|
|
4632
|
+
cg.loop(cg.void);
|
|
4633
|
+
if (!re) {
|
|
4634
|
+
cg.block(cg.void);
|
|
4635
|
+
cg.local.get(gidx);
|
|
4636
|
+
cg.local.get(paramEnd);
|
|
4637
|
+
cg.i32.ge_u();
|
|
4638
|
+
cg.br_if(0);
|
|
4639
|
+
cg.local.get(kernel.nargs);
|
|
4640
|
+
cg.local.get(gidx);
|
|
4641
|
+
cg.i32.const(byteWidth(kernel.dtype));
|
|
4642
|
+
cg.i32.mul();
|
|
4643
|
+
cg.i32.add();
|
|
4644
|
+
translateExpSimd(cg, funcs, tune.exp, { gidx }, bufferStrides);
|
|
4645
|
+
cg.v128.store(4);
|
|
4646
|
+
cg.local.get(gidx);
|
|
4647
|
+
cg.i32.const(SIMD_LANES);
|
|
4648
|
+
cg.i32.add();
|
|
4649
|
+
cg.local.set(gidx);
|
|
4650
|
+
cg.br(1);
|
|
4651
|
+
cg.end();
|
|
4652
|
+
} else {
|
|
4653
|
+
const reIsInt = kernel.exp.dtype === DType.Int32 || kernel.exp.dtype === DType.Uint32;
|
|
4654
|
+
cg.block(cg.void);
|
|
4655
|
+
cg.local.get(gidx);
|
|
4656
|
+
cg.local.get(paramEnd);
|
|
4657
|
+
cg.i32.ge_u();
|
|
4658
|
+
cg.br_if(0);
|
|
4659
|
+
const vecAcc = cg.local.declare(reIsInt ? cg.i32x4 : cg.f32x4);
|
|
4660
|
+
if (reIsInt) {
|
|
4661
|
+
cg.i32.const(re.identity);
|
|
4662
|
+
cg.i32x4.splat();
|
|
4663
|
+
} else {
|
|
4664
|
+
cg.f32.const(re.identity);
|
|
4665
|
+
cg.f32x4.splat();
|
|
4666
|
+
}
|
|
4667
|
+
cg.local.set(vecAcc);
|
|
4668
|
+
const ridx = cg.local.declare(cg.i32);
|
|
4669
|
+
cg.i32.const(0);
|
|
4670
|
+
cg.local.set(ridx);
|
|
4671
|
+
cg.loop(cg.void);
|
|
4672
|
+
cg.block(cg.void);
|
|
4673
|
+
cg.local.get(ridx);
|
|
4674
|
+
cg.i32.const(re.size);
|
|
4675
|
+
cg.i32.ge_u();
|
|
4676
|
+
cg.br_if(0);
|
|
4677
|
+
translateExpSimd(cg, funcs, tune.exp, {
|
|
4678
|
+
gidx,
|
|
4679
|
+
ridx
|
|
4680
|
+
}, bufferStrides);
|
|
4681
|
+
cg.local.get(vecAcc);
|
|
4682
|
+
if (reIsInt) if (re.op === AluOp.Add) cg.i32x4.add();
|
|
4683
|
+
else if (re.op === AluOp.Mul) cg.i32x4.mul();
|
|
4684
|
+
else if (re.op === AluOp.Min) if (re.dtype === DType.Int32) cg.i32x4.min_s();
|
|
4685
|
+
else cg.i32x4.min_u();
|
|
4686
|
+
else if (re.op === AluOp.Max) if (re.dtype === DType.Int32) cg.i32x4.max_s();
|
|
4687
|
+
else cg.i32x4.max_u();
|
|
4688
|
+
else throw new Error(`invalid SIMD reduction op: ${re.op}`);
|
|
4689
|
+
else if (re.op === AluOp.Add) cg.f32x4.add();
|
|
4690
|
+
else if (re.op === AluOp.Mul) cg.f32x4.mul();
|
|
4691
|
+
else if (re.op === AluOp.Min) cg.f32x4.min();
|
|
4692
|
+
else if (re.op === AluOp.Max) cg.f32x4.max();
|
|
4693
|
+
else throw new Error(`invalid SIMD reduction op: ${re.op}`);
|
|
4694
|
+
cg.local.set(vecAcc);
|
|
4695
|
+
cg.local.get(ridx);
|
|
4696
|
+
cg.i32.const(1);
|
|
4697
|
+
cg.i32.add();
|
|
4698
|
+
cg.local.set(ridx);
|
|
4699
|
+
cg.br(1);
|
|
4700
|
+
cg.end();
|
|
4701
|
+
cg.end();
|
|
4702
|
+
for (let lane = 0; lane < SIMD_LANES; lane++) {
|
|
4703
|
+
cg.local.get(kernel.nargs);
|
|
4704
|
+
cg.local.get(gidx);
|
|
4705
|
+
if (lane > 0) {
|
|
4706
|
+
cg.i32.const(lane);
|
|
4707
|
+
cg.i32.add();
|
|
4708
|
+
}
|
|
4709
|
+
cg.i32.const(byteWidth(kernel.dtype));
|
|
4710
|
+
cg.i32.mul();
|
|
4711
|
+
cg.i32.add();
|
|
4712
|
+
const acc = cg.local.declare(reIsInt ? cg.i32 : cg.f32);
|
|
4713
|
+
cg.local.get(vecAcc);
|
|
4714
|
+
if (reIsInt) cg.i32x4.extract_lane(lane);
|
|
4715
|
+
else cg.f32x4.extract_lane(lane);
|
|
4716
|
+
cg.local.set(acc);
|
|
4717
|
+
const laneGidx = cg.local.declare(cg.i32);
|
|
4718
|
+
cg.local.get(gidx);
|
|
4719
|
+
if (lane > 0) {
|
|
4720
|
+
cg.i32.const(lane);
|
|
4721
|
+
cg.i32.add();
|
|
4722
|
+
}
|
|
4723
|
+
cg.local.set(laneGidx);
|
|
4724
|
+
translateExp(cg, funcs, tune.epilogue, {
|
|
4725
|
+
acc,
|
|
4726
|
+
gidx: laneGidx
|
|
4727
|
+
});
|
|
4728
|
+
dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
|
|
4729
|
+
}
|
|
4730
|
+
cg.local.get(gidx);
|
|
4731
|
+
cg.i32.const(SIMD_LANES);
|
|
4732
|
+
cg.i32.add();
|
|
4733
|
+
cg.local.set(gidx);
|
|
4734
|
+
cg.br(1);
|
|
4735
|
+
cg.end();
|
|
4736
|
+
}
|
|
4737
|
+
cg.end();
|
|
4738
|
+
cg.return();
|
|
4739
|
+
cg.end();
|
|
4740
|
+
}
|
|
4044
4741
|
cg.loop(cg.void);
|
|
4045
4742
|
cg.block(cg.void);
|
|
4046
4743
|
cg.local.get(gidx);
|
|
4047
|
-
cg.
|
|
4744
|
+
cg.local.get(paramEnd);
|
|
4048
4745
|
cg.i32.ge_u();
|
|
4049
4746
|
cg.br_if(0);
|
|
4050
4747
|
cg.local.get(kernel.nargs);
|
|
@@ -4183,6 +4880,11 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
4183
4880
|
else cg.i32.gt_u();
|
|
4184
4881
|
cg.select();
|
|
4185
4882
|
} else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
4883
|
+
else if (op === AluOp.BitCombine) if (arg === "and") cg.i32.and();
|
|
4884
|
+
else if (arg === "or") cg.i32.or();
|
|
4885
|
+
else cg.i32.xor();
|
|
4886
|
+
else if (op === AluOp.BitShift) if (arg === "shl") cg.i32.shl();
|
|
4887
|
+
else cg.i32.shr_u();
|
|
4186
4888
|
else if (op === AluOp.Cmplt) {
|
|
4187
4889
|
const srcDtype = src[0].dtype;
|
|
4188
4890
|
if (isFloatDtype(srcDtype)) dtyF(cg, op, srcDtype).lt();
|
|
@@ -4359,7 +5061,7 @@ async function createBackend(device) {
|
|
|
4359
5061
|
if (!navigator.gpu) return null;
|
|
4360
5062
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4361
5063
|
if (!adapter) return null;
|
|
4362
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
5064
|
+
const { WebGPUBackend } = await import("./webgpu-C2kLdkUh.js");
|
|
4363
5065
|
const importantLimits = [
|
|
4364
5066
|
"maxBufferSize",
|
|
4365
5067
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4397,7 +5099,7 @@ async function createBackend(device) {
|
|
|
4397
5099
|
});
|
|
4398
5100
|
if (!gl) return null;
|
|
4399
5101
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4400
|
-
const { WebGLBackend } = await import("./webgl-
|
|
5102
|
+
const { WebGLBackend } = await import("./webgl-BhsnpeB0.js");
|
|
4401
5103
|
return new WebGLBackend(gl);
|
|
4402
5104
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4403
5105
|
}
|
|
@@ -4431,6 +5133,15 @@ var UnsupportedRoutineError = class extends Error {
|
|
|
4431
5133
|
super(`routine '${name}' is not supported in ${device} backend`);
|
|
4432
5134
|
}
|
|
4433
5135
|
};
|
|
5136
|
+
/**
|
|
5137
|
+
* If the WebGPU backend has been initialized, return the `GPUDevice` that this
|
|
5138
|
+
* backend runs on. This is useful for sharing buffers.
|
|
5139
|
+
*/
|
|
5140
|
+
function getWebGPUDevice() {
|
|
5141
|
+
const backend = initializedBackends.get("webgpu");
|
|
5142
|
+
if (!backend) throw new Error("WebGPU backend not initialized, call init('webgpu') first");
|
|
5143
|
+
return backend.device;
|
|
5144
|
+
}
|
|
4434
5145
|
|
|
4435
5146
|
//#endregion
|
|
4436
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, emitTrace, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, isTracing, mapSetUnion, normalizeAxis, onFlushTrace, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, strip1, toposort, traceSourceInfo, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
5147
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, emitTrace, findPow2, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, isTracing, mapSetUnion, normalizeAxis, onFlushTrace, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, strip1, toposort, traceSourceInfo, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|