@jax-js/jax 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +16 -34
- package/dist/{backend-DeVfWEFS.cjs → backend-Bu9GY6sK.cjs} +222 -36
- package/dist/{backend-BqymqzuU.js → backend-tngXtWe4.js} +204 -36
- package/dist/index.cjs +1798 -955
- package/dist/index.d.cts +383 -97
- package/dist/index.d.ts +383 -97
- package/dist/index.js +1791 -949
- package/dist/{webgpu-BGuG58KZ.js → webgpu-ChVgx3b6.js} +410 -97
- package/dist/{webgpu-CcGP160M.cjs → webgpu-Oj3Kd-kd.cjs} +410 -97
- package/package.json +1 -1
|
@@ -68,6 +68,9 @@ function zipn(...arrays) {
|
|
|
68
68
|
const minLength = Math.min(...arrays.map((x) => x.length));
|
|
69
69
|
return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
|
|
70
70
|
}
|
|
71
|
+
/**
 * Return a new array holding the items of `arr` in ascending numeric order.
 *
 * The input is only iterated, never mutated, so any iterable of numbers
 * (for example a `Set`) is accepted.
 */
function sorted(arr) {
  const copy = [...arr];
  // An explicit comparator is required: the default sort compares as strings.
  copy.sort((lhs, rhs) => lhs - rhs);
  return copy;
}
|
|
71
74
|
function rep(length, value) {
|
|
72
75
|
if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
|
|
73
76
|
return new Array(length).fill(value);
|
|
@@ -144,7 +147,7 @@ function normalizeAxis(axis, ndim) {
|
|
|
144
147
|
if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
|
|
145
148
|
seen.add(ca);
|
|
146
149
|
}
|
|
147
|
-
return
|
|
150
|
+
return sorted(seen);
|
|
148
151
|
}
|
|
149
152
|
}
|
|
150
153
|
function range(start, stop, step = 1) {
|
|
@@ -557,16 +560,16 @@ var AluExp = class AluExp {
|
|
|
557
560
|
});
|
|
558
561
|
}
|
|
559
562
|
/** Reindex gid values in this expression as needed. */
|
|
560
|
-
reindexGids(
|
|
563
|
+
reindexGids(newGids) {
|
|
561
564
|
return this.rewrite((exp) => {
|
|
562
565
|
if (exp.op === AluOp.GlobalIndex) {
|
|
563
566
|
const [gid, len] = exp.arg;
|
|
564
|
-
const newGid =
|
|
565
|
-
if (newGid !==
|
|
567
|
+
const newGid = newGids[gid];
|
|
568
|
+
if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
|
|
566
569
|
} else if (exp.op === AluOp.GlobalView) {
|
|
567
570
|
const gid = exp.arg[0];
|
|
568
|
-
const newGid =
|
|
569
|
-
if (newGid !==
|
|
571
|
+
const newGid = newGids[gid];
|
|
572
|
+
if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
|
|
570
573
|
}
|
|
571
574
|
});
|
|
572
575
|
}
|
|
@@ -780,7 +783,7 @@ var AluExp = class AluExp {
|
|
|
780
783
|
if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
|
|
781
784
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
782
785
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
783
|
-
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
786
|
+
if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
|
|
784
787
|
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
785
788
|
}
|
|
786
789
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
@@ -1327,7 +1330,7 @@ var Reduction = class {
|
|
|
1327
1330
|
/** Evaluate this operation on CPU. */
|
|
1328
1331
|
evaluate(...values) {
|
|
1329
1332
|
if (this.dtype === DType.Bool) {
|
|
1330
|
-
if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b,
|
|
1333
|
+
if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, false);
|
|
1331
1334
|
else if (this.op === AluOp.Mul || this.op === AluOp.Min) return values.reduce((a, b) => a && b, true);
|
|
1332
1335
|
} else if (this.dtype === DType.Int32) {
|
|
1333
1336
|
if (this.op === AluOp.Add) return values.reduce((a, b) => a + b | 0, 0);
|
|
@@ -1437,6 +1440,112 @@ function erfc(x) {
|
|
|
1437
1440
|
else return 2 - _erfapprox$1(-x);
|
|
1438
1441
|
}
|
|
1439
1442
|
|
|
1443
|
+
//#endregion
|
|
1444
|
+
//#region src/routine.ts
|
|
1445
|
+
/**
 * Descriptor for an advanced operation that does not fit the `AluExp`
 * compiler representation.
 *
 * Iterative matrix algorithms, FFTs, sorting and similar operations are hard
 * to express efficiently as a `Kernel`, and are expensive enough that kernel
 * fusion and inlining matter less for them. They are instead dispatched to
 * the backend as custom operations, identified by a member of `Routines`,
 * and each backend implements them in its own way.
 *
 * Routines are never fused into other kernels and always operate on
 * contiguous arrays (default `ShapeTracker`).
 */
var Routine = class {
  /**
   * @param name Which routine to run.
   * @param type Type information for the routine's inputs and outputs.
   * @param params Extra routine-specific parameters, if any.
   */
  constructor(name, type, params) {
    Object.assign(this, { name, type, params });
  }
};
|
|
1467
|
+
/** Names of the `Routine` kinds that can be dispatched to a backend. */
let Routines = {
  /** Stable sorting algorithm along the last axis. */
  Sort: "Sort",
  /** Returns `int32` indices of the stably sorted array. */
  Argsort: "Argsort",
  /** Solve a triangular system of equations. */
  TriangularSolve: "TriangularSolve",
  /** Cholesky decomposition of 2D positive semi-definite matrices. */
  Cholesky: "Cholesky"
};
|
|
1479
|
+
/**
 * Run a routine on the CPU over raw input and output buffers.
 *
 * Each buffer is wrapped with `dtypedArray` using the dtype recorded at the
 * same position in `routine.type`, then the call is forwarded to the
 * implementation matching `routine.name`.
 *
 * @param routine Routine descriptor; `name`, `type` and `params` are read.
 * @param inputs Input buffers, in the order of `type.inputDtypes`.
 * @param outputs Output buffers, in the order of `type.outputDtypes`.
 * @returns Whatever the selected implementation returns.
 * @throws {Error} If `routine.name` is not a known routine. Previously this
 *   case fell through silently and left the outputs unwritten.
 */
function runCpuRoutine(routine, inputs, outputs) {
  const { name, type } = routine;
  const inputAr = inputs.map((buf, i) => dtypedArray(type.inputDtypes[i], buf));
  const outputAr = outputs.map((buf, i) => dtypedArray(type.outputDtypes[i], buf));
  switch (name) {
    case Routines.Sort: return runSort(type, inputAr, outputAr);
    case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
    case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
    case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
    default: throw new Error(`runCpuRoutine: unknown routine '${name}'`);
  }
}
|
|
1491
|
+
/**
 * Sort `x` along its last axis and write the result to `y`.
 *
 * The input is copied into the output, then each consecutive run of `n`
 * elements (where `n` is the last dimension of the input shape) is sorted in
 * place with the typed array's default numeric ordering.
 *
 * @param type Routine type; only `inputShapes[0]` is read.
 * @throws {Error} If the input shape is empty (a scalar).
 */
function runSort(type, [x], [y]) {
  const shape = type.inputShapes[0];
  if (shape.length === 0) throw new Error("sort: cannot sort a scalar");
  const rowLength = shape[shape.length - 1];
  y.set(x);
  let start = 0;
  while (start < y.length) {
    // `subarray` is a view, so sorting it reorders `y` itself.
    const row = y.subarray(start, start + rowLength);
    row.sort();
    start += rowLength;
  }
}
|
|
1498
|
+
/**
 * Stable argsort of `x` along its last axis.
 *
 * For every consecutive run of `n` elements (where `n` is the last dimension
 * of the input shape), writes the sorted values to `y` and the positions of
 * those values within the run to `yi`.
 *
 * NaN values are ordered after every other value, matching the default
 * typed-array ordering that `runSort` relies on. A plain subtraction
 * comparator returns NaN for such elements, which makes the ordering
 * inconsistent.
 *
 * @param type Routine type; only `inputShapes[0]` is read.
 * @throws {Error} If the input shape is empty (a scalar).
 */
function runArgsort(type, [x], [y, yi]) {
  const xs = type.inputShapes[0];
  if (xs.length === 0) throw new Error("argsort: cannot sort a scalar");
  const n = xs[xs.length - 1];
  for (let offset = 0; offset < y.length; offset += n) {
    const ar = x.subarray(offset, offset + n);
    const out = y.subarray(offset, offset + n);
    const outi = yi.subarray(offset, offset + n);
    // Total order over the values: numeric ascending, NaN last. Equal values
    // compare as 0 so the stable sort keeps their original relative order.
    const compare = (a, b) => {
      const va = ar[a];
      const vb = ar[b];
      if (va < vb) return -1;
      if (va > vb) return 1;
      if (va === vb) return 0;
      // At least one side is NaN from here on.
      if (Number.isNaN(va)) return Number.isNaN(vb) ? 0 : 1;
      return -1;
    };
    for (let i = 0; i < n; i++) outi[i] = i;
    outi.sort(compare);
    for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
  }
}
|
|
1511
|
+
/**
 * Solve upper-triangular systems by back-substitution.
 *
 * `a` is read as consecutive row-major `n x n` matrices and `b` as
 * consecutive right-hand-side vectors of length `n`, with
 * `bs[bs.length - 2]` vectors per matrix. For each vector `r` the solution
 * `s` satisfying `sum_j a[i][j] * s[j] = r[i]` (using only `j >= i`) is
 * written to the same position in `x`.
 *
 * @param type Routine type; `inputShapes[0]` and `inputShapes[1]` are read.
 * @param unitDiagonal When truthy, the diagonal of `a` is treated as all
 *   ones and is not read.
 * @throws {Error} If either shape has fewer than two dimensions, if `a` is
 *   not square, or if the last dimension of `b` differs from `n`.
 */
function runTriangularSolve(type, [a, b], [x], { unitDiagonal }) {
  const aShape = type.inputShapes[0];
  const bShape = type.inputShapes[1];
  if (aShape.length < 2) throw new Error(`triangular_solve: a must be at least 2D, got ${aShape}`);
  if (bShape.length < 2) throw new Error(`triangular_solve: b must be at least 2D, got ${bShape}`);
  const n = aShape[aShape.length - 2];
  const shapesAgree = n === aShape[aShape.length - 1] && n === bShape[bShape.length - 1];
  if (!shapesAgree) throw new Error(`triangular_solve: incompatible shapes a=${aShape}, b=${bShape}`);
  const rhsPerMatrix = bShape[bShape.length - 2];
  const matrixSize = n * n;
  const matrixCount = a.length / matrixSize;
  for (let m = 0; m < matrixCount; m++) {
    const upper = a.subarray(m * matrixSize, (m + 1) * matrixSize);
    for (let t = 0; t < rhsPerMatrix; t++) {
      const start = (m * rhsPerMatrix + t) * n;
      const rhs = b.subarray(start, start + n);
      const solution = x.subarray(start, start + n);
      // Walk rows bottom-up so every solution[j] with j > row is already known.
      for (let row = n - 1; row >= 0; row--) {
        let residual = rhs[row];
        for (let col = row + 1; col < n; col++) residual -= upper[row * n + col] * solution[col];
        if (unitDiagonal) solution[row] = residual;
        else solution[row] = residual / upper[row * n + row];
      }
    }
  }
}
|
|
1532
|
+
/**
 * Cholesky decomposition of each square matrix in `x`, written to `y`.
 *
 * `x` is read as consecutive row-major `n x n` matrices; only the lower
 * triangle (including the diagonal) of each is read. The matching block of
 * `y` receives the lower-triangular factor `L` with `L * L^T` equal to the
 * input. The strict upper triangle of every output block is explicitly set
 * to zero, so the result does not depend on what the output buffer held
 * before the call.
 *
 * @param type Routine type; only `inputShapes[0]` is read.
 * @throws {Error} If the input has fewer than two dimensions or its last two
 *   dimensions differ.
 */
function runCholesky(type, [x], [y]) {
  const xs = type.inputShapes[0];
  if (xs.length < 2) throw new Error("cholesky: input must be at least 2D");
  const n = xs[xs.length - 2];
  const m = xs[xs.length - 1];
  if (n !== m) throw new Error(`cholesky: input must be square, got [${n}, ${m}]`);
  const matrixSize = n * n;
  for (let offset = 0; offset < y.length; offset += matrixSize) {
    const ar = x.subarray(offset, offset + matrixSize);
    const out = y.subarray(offset, offset + matrixSize);
    for (let i = 0; i < n; i++) {
      for (let j = 0; j <= i; j++) {
        let sum = ar[i * n + j];
        for (let k = 0; k < j; k++) sum -= out[i * n + k] * out[j * n + k];
        out[i * n + j] = i === j ? Math.sqrt(sum) : sum / out[j * n + j];
      }
      // Clear the entries right of the diagonal; they are never computed
      // above and would otherwise keep stale buffer contents.
      out.fill(0, i * n + i + 1, (i + 1) * n);
    }
  }
}
|
|
1548
|
+
|
|
1440
1549
|
//#endregion
|
|
1441
1550
|
//#region src/shape.ts
|
|
1442
1551
|
const jstr = JSON.stringify;
|
|
@@ -1907,7 +2016,7 @@ var ShapeTracker = class ShapeTracker {
|
|
|
1907
2016
|
let st = this;
|
|
1908
2017
|
if (axis.length > 0) {
|
|
1909
2018
|
const unsqueezed = [...st.shape];
|
|
1910
|
-
for (const i of axis
|
|
2019
|
+
for (const i of sorted(axis)) unsqueezed.splice(i, 0, 1);
|
|
1911
2020
|
st = st.reshape(unsqueezed);
|
|
1912
2021
|
}
|
|
1913
2022
|
return st.expand(newShape);
|
|
@@ -2066,7 +2175,8 @@ function tuneNullopt(kernel) {
|
|
|
2066
2175
|
if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
2067
2176
|
return {
|
|
2068
2177
|
exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
2069
|
-
|
|
2178
|
+
epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
|
|
2179
|
+
outputIdxExp: vars.gidx,
|
|
2070
2180
|
threadCount: kernel.size,
|
|
2071
2181
|
size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
|
|
2072
2182
|
};
|
|
@@ -2099,7 +2209,11 @@ function tuneWebgpu(kernel) {
|
|
|
2099
2209
|
while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
|
|
2100
2210
|
const choices = [];
|
|
2101
2211
|
const composedSts = sts.map((st) => st.compose(dim.st));
|
|
2102
|
-
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2212
|
+
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2213
|
+
3,
|
|
2214
|
+
4,
|
|
2215
|
+
5
|
|
2216
|
+
]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
|
|
2103
2217
|
let nonzeroStrides = 0;
|
|
2104
2218
|
let totalStrides = 0;
|
|
2105
2219
|
for (const st of composedSts) {
|
|
@@ -2127,7 +2241,7 @@ function tuneWebgpu(kernel) {
|
|
|
2127
2241
|
break;
|
|
2128
2242
|
}
|
|
2129
2243
|
}
|
|
2130
|
-
for (const ax of
|
|
2244
|
+
for (const ax of sorted(upcastedAxis)) {
|
|
2131
2245
|
const s = dim.st.shape[ax];
|
|
2132
2246
|
for (const amount of [8, 4]) if (s % amount === 0) {
|
|
2133
2247
|
dim.applyLocal(ax, amount);
|
|
@@ -2175,7 +2289,15 @@ function tuneWebgpu(kernel) {
|
|
|
2175
2289
|
});
|
|
2176
2290
|
const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
|
|
2177
2291
|
const outputUpcast = dim.outputSt.shape.slice(dim.groups);
|
|
2178
|
-
const
|
|
2292
|
+
const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
|
|
2293
|
+
const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
|
|
2294
|
+
const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
|
|
2295
|
+
if (exp$1.op === AluOp.GlobalView) {
|
|
2296
|
+
const gid = exp$1.arg[0];
|
|
2297
|
+
const st = exp$1.arg[1];
|
|
2298
|
+
return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
|
|
2299
|
+
}
|
|
2300
|
+
});
|
|
2179
2301
|
if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
|
|
2180
2302
|
const size = {
|
|
2181
2303
|
groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
|
|
@@ -2185,6 +2307,7 @@ function tuneWebgpu(kernel) {
|
|
|
2185
2307
|
};
|
|
2186
2308
|
return {
|
|
2187
2309
|
exp: newExp.simplify(),
|
|
2310
|
+
epilogue: newEpilogue.simplify(),
|
|
2188
2311
|
outputIdxExp: outputIdxExp.simplify(),
|
|
2189
2312
|
threadCount: kernel.size / size.upcast * size.groups,
|
|
2190
2313
|
size
|
|
@@ -2236,17 +2359,25 @@ var CpuBackend = class {
|
|
|
2236
2359
|
if (count === void 0) count = buffer.byteLength - start;
|
|
2237
2360
|
return buffer.slice(start, start + count);
|
|
2238
2361
|
}
|
|
2239
|
-
async
|
|
2240
|
-
return this.
|
|
2362
|
+
async prepareKernel(kernel) {
|
|
2363
|
+
return this.prepareKernelSync(kernel);
|
|
2241
2364
|
}
|
|
2242
|
-
|
|
2365
|
+
prepareKernelSync(kernel) {
|
|
2243
2366
|
return new Executable(kernel, void 0);
|
|
2244
2367
|
}
|
|
2245
|
-
|
|
2246
|
-
|
|
2368
|
+
async prepareRoutine(routine) {
|
|
2369
|
+
return this.prepareRoutineSync(routine);
|
|
2370
|
+
}
|
|
2371
|
+
prepareRoutineSync(routine) {
|
|
2372
|
+
return new Executable(routine, void 0);
|
|
2373
|
+
}
|
|
2374
|
+
dispatch(exe, inputs, outputs) {
|
|
2375
|
+
if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
2376
|
+
const kernel = exe.source;
|
|
2377
|
+
const { exp, epilogue } = tuneNullopt(kernel);
|
|
2247
2378
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
2248
2379
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
2249
|
-
const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2380
|
+
const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2250
2381
|
const inputArrays = inputBuffers.map((buf, i) => {
|
|
2251
2382
|
const dtype = usedArgs.get(i);
|
|
2252
2383
|
if (!dtype) return null;
|
|
@@ -2268,7 +2399,10 @@ var CpuBackend = class {
|
|
|
2268
2399
|
}, globals);
|
|
2269
2400
|
acc = kernel.reduction.evaluate(acc, item);
|
|
2270
2401
|
}
|
|
2271
|
-
outputArray[i] =
|
|
2402
|
+
outputArray[i] = epilogue.evaluate({
|
|
2403
|
+
acc,
|
|
2404
|
+
gidx: i
|
|
2405
|
+
}, globals);
|
|
2272
2406
|
}
|
|
2273
2407
|
}
|
|
2274
2408
|
#getBuffer(slot) {
|
|
@@ -2297,8 +2431,10 @@ var WasmAllocator = class {
|
|
|
2297
2431
|
const sizeClass = this.#findSizeClass(size);
|
|
2298
2432
|
const freeList = this.#freeLists.get(sizeClass);
|
|
2299
2433
|
let ptr;
|
|
2300
|
-
if (freeList && freeList.length > 0)
|
|
2301
|
-
|
|
2434
|
+
if (freeList && freeList.length > 0) {
|
|
2435
|
+
ptr = freeList.pop();
|
|
2436
|
+
new Uint8Array(this.#memory.buffer, ptr, sizeClass).fill(0);
|
|
2437
|
+
} else ptr = this.#bumpAlloc(sizeClass);
|
|
2302
2438
|
this.#allocatedBuffers.set(ptr, sizeClass);
|
|
2303
2439
|
return ptr;
|
|
2304
2440
|
}
|
|
@@ -2431,7 +2567,7 @@ function wasm_log(cg) {
|
|
|
2431
2567
|
const t2 = cg.local.declare(cg.f32);
|
|
2432
2568
|
cg.local.get(0);
|
|
2433
2569
|
cg.f32.const(0);
|
|
2434
|
-
cg.f32.
|
|
2570
|
+
cg.f32.lt();
|
|
2435
2571
|
cg.if(cg.void);
|
|
2436
2572
|
cg.f32.const(NaN);
|
|
2437
2573
|
cg.return();
|
|
@@ -2446,6 +2582,20 @@ function wasm_log(cg) {
|
|
|
2446
2582
|
cg.i32.const(127);
|
|
2447
2583
|
cg.i32.sub();
|
|
2448
2584
|
cg.local.set(e);
|
|
2585
|
+
cg.local.get(e);
|
|
2586
|
+
cg.i32.const(-127);
|
|
2587
|
+
cg.i32.eq();
|
|
2588
|
+
cg.if(cg.void);
|
|
2589
|
+
cg.f32.const(-Infinity);
|
|
2590
|
+
cg.return();
|
|
2591
|
+
cg.end();
|
|
2592
|
+
cg.local.get(e);
|
|
2593
|
+
cg.i32.const(128);
|
|
2594
|
+
cg.i32.eq();
|
|
2595
|
+
cg.if(cg.void);
|
|
2596
|
+
cg.local.get(0);
|
|
2597
|
+
cg.return();
|
|
2598
|
+
cg.end();
|
|
2449
2599
|
cg.local.get(bits);
|
|
2450
2600
|
cg.i32.const(8388607);
|
|
2451
2601
|
cg.i32.and();
|
|
@@ -2511,7 +2661,7 @@ function _sincos(cg) {
|
|
|
2511
2661
|
cg.f32.mul();
|
|
2512
2662
|
cg.f32.nearest();
|
|
2513
2663
|
cg.local.tee(qf);
|
|
2514
|
-
cg.i32.
|
|
2664
|
+
cg.i32.trunc_sat_f32_s();
|
|
2515
2665
|
cg.local.set(q);
|
|
2516
2666
|
cg.local.get(y);
|
|
2517
2667
|
cg.local.get(qf);
|
|
@@ -3598,6 +3748,7 @@ var F32x4 = class extends V128 {
|
|
|
3598
3748
|
|
|
3599
3749
|
//#endregion
|
|
3600
3750
|
//#region src/backend/wasm.ts
|
|
3751
|
+
const moduleCache = /* @__PURE__ */ new Map();
|
|
3601
3752
|
/** Backend that compiles into WebAssembly bytecode for immediate execution. */
|
|
3602
3753
|
var WasmBackend = class {
|
|
3603
3754
|
type = "wasm";
|
|
@@ -3649,15 +3800,25 @@ var WasmBackend = class {
|
|
|
3649
3800
|
if (count === void 0) count = buffer.byteLength - start;
|
|
3650
3801
|
return buffer.slice(start, start + count);
|
|
3651
3802
|
}
|
|
3652
|
-
async
|
|
3653
|
-
return this.
|
|
3803
|
+
async prepareKernel(kernel) {
|
|
3804
|
+
return this.prepareKernelSync(kernel);
|
|
3654
3805
|
}
|
|
3655
|
-
|
|
3656
|
-
const
|
|
3657
|
-
const module =
|
|
3806
|
+
prepareKernelSync(kernel) {
|
|
3807
|
+
const kernelHash = FpHash.hash(kernel);
|
|
3808
|
+
const module = runWithCache(moduleCache, kernelHash.toString(), () => {
|
|
3809
|
+
const bytes = codegenWasm(kernel);
|
|
3810
|
+
return new WebAssembly.Module(bytes);
|
|
3811
|
+
});
|
|
3658
3812
|
return new Executable(kernel, { module });
|
|
3659
3813
|
}
|
|
3814
|
+
async prepareRoutine(routine) {
|
|
3815
|
+
return this.prepareRoutineSync(routine);
|
|
3816
|
+
}
|
|
3817
|
+
prepareRoutineSync(routine) {
|
|
3818
|
+
return new Executable(routine, void 0);
|
|
3819
|
+
}
|
|
3660
3820
|
dispatch(exe, inputs, outputs) {
|
|
3821
|
+
if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
3661
3822
|
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
3662
3823
|
const func = instance.exports.kernel;
|
|
3663
3824
|
const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
|
|
@@ -3675,7 +3836,7 @@ function codegenWasm(kernel) {
|
|
|
3675
3836
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3676
3837
|
const cg = new CodeGenerator();
|
|
3677
3838
|
cg.memory.import("env", "memory");
|
|
3678
|
-
const distinctOps = mapSetUnion(tune.exp.distinctOps(),
|
|
3839
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
3679
3840
|
const funcs = {};
|
|
3680
3841
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3681
3842
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
@@ -3753,7 +3914,10 @@ function codegenWasm(kernel) {
|
|
|
3753
3914
|
cg.br(1);
|
|
3754
3915
|
cg.end();
|
|
3755
3916
|
cg.end();
|
|
3756
|
-
translateExp(cg, funcs,
|
|
3917
|
+
translateExp(cg, funcs, tune.epilogue, {
|
|
3918
|
+
acc,
|
|
3919
|
+
gidx
|
|
3920
|
+
});
|
|
3757
3921
|
} else translateExp(cg, funcs, tune.exp, { gidx });
|
|
3758
3922
|
dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
|
|
3759
3923
|
cg.local.get(gidx);
|
|
@@ -4002,7 +4166,7 @@ async function createBackend(device) {
|
|
|
4002
4166
|
if (!navigator.gpu) return null;
|
|
4003
4167
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4004
4168
|
if (!adapter) return null;
|
|
4005
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4169
|
+
const { WebGPUBackend } = await import("./webgpu-ChVgx3b6.js");
|
|
4006
4170
|
const importantLimits = [
|
|
4007
4171
|
"maxBufferSize",
|
|
4008
4172
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4036,8 +4200,8 @@ function getBackend(device) {
|
|
|
4036
4200
|
return backend;
|
|
4037
4201
|
}
|
|
4038
4202
|
var Executable = class {
|
|
4039
|
-
constructor(
|
|
4040
|
-
this.
|
|
4203
|
+
constructor(source, data) {
|
|
4204
|
+
this.source = source;
|
|
4041
4205
|
this.data = data;
|
|
4042
4206
|
}
|
|
4043
4207
|
};
|
|
@@ -4053,7 +4217,11 @@ var UnsupportedOpError = class extends Error {
|
|
|
4053
4217
|
super(msg);
|
|
4054
4218
|
}
|
|
4055
4219
|
};
|
|
4220
|
+
/** Error signalling that a routine is not available on a given backend. */
var UnsupportedRoutineError = class extends Error {
  /**
   * @param name Name of the routine that was requested.
   * @param device Name of the backend that lacks the routine.
   */
  constructor(name, device) {
    const message = `routine '${name}' is not supported in ${device} backend`;
    super(message);
  }
};
|
|
4056
4225
|
|
|
4057
4226
|
//#endregion
|
|
4058
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4059
|
-
//# sourceMappingURL=backend-BqymqzuU.js.map
|
|
4227
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|