@jax-js/jax 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +16 -34
- package/dist/{backend-DeVfWEFS.cjs → backend-Bu9GY6sK.cjs} +222 -36
- package/dist/{backend-BqymqzuU.js → backend-tngXtWe4.js} +204 -36
- package/dist/index.cjs +1798 -955
- package/dist/index.d.cts +383 -97
- package/dist/index.d.ts +383 -97
- package/dist/index.js +1791 -949
- package/dist/{webgpu-BGuG58KZ.js → webgpu-ChVgx3b6.js} +410 -97
- package/dist/{webgpu-CcGP160M.cjs → webgpu-Oj3Kd-kd.cjs} +410 -97
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -3,7 +3,8 @@
|
|
|
3
3
|
<p align="center"><strong>
|
|
4
4
|
<a href="https://jax-js.com">Website</a> |
|
|
5
5
|
<a href="https://jax-js.com/docs/">API Reference</a> |
|
|
6
|
-
<a href="./FEATURES.md">Compatibility Table</a>
|
|
6
|
+
<a href="./FEATURES.md">Compatibility Table</a> |
|
|
7
|
+
<a href="https://discord.gg/BW6YsCd4Tf">Discord</a>
|
|
7
8
|
</strong></p>
|
|
8
9
|
|
|
9
10
|
**jax-js** is a machine learning framework for the browser. It aims to bring
|
|
@@ -257,36 +258,12 @@ await devicePut(ar, "webgpu"); // Now device="webgpu"
|
|
|
257
258
|
There are other libraries in the `@jax-js` namespace that can work with jax-js, or be used in a
|
|
258
259
|
self-contained way in other projects.
|
|
259
260
|
|
|
260
|
-
**`@jax-js/
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
const solver = adam(1e-3);
|
|
268
|
-
let optState = solver.init(params.ref);
|
|
269
|
-
let updates: np.Array;
|
|
270
|
-
|
|
271
|
-
const f = (x: np.Array) => squaredError(x, np.ones([3])).sum();
|
|
272
|
-
|
|
273
|
-
for (let i = 0; i < 100; i++) {
|
|
274
|
-
const paramsGrad = grad(f)(params.ref);
|
|
275
|
-
[updates, optState] = solver.update(paramsGrad, optState);
|
|
276
|
-
params = applyUpdates(params, updates);
|
|
277
|
-
}
|
|
278
|
-
```
|
|
279
|
-
|
|
280
|
-
**`@jax-js/loaders`** can load tensors from various formats like Safetensors, includes a fast and
|
|
281
|
-
compliant implementation of BPE, and caches HTTP requests for large assets like model weights in
|
|
282
|
-
OPFS.
|
|
283
|
-
|
|
284
|
-
```ts
|
|
285
|
-
import { tokenizers } from "@jax-js/loaders";
|
|
286
|
-
|
|
287
|
-
const enc = await tokenizers.getBpe("clip");
|
|
288
|
-
const tokens = enc.encode("Hello, world!"); // => [ 49406, 3306, 267, 1002, ... ]
|
|
289
|
-
```
|
|
261
|
+
- [**`@jax-js/loaders`**](packages/loaders) can load tensors from various formats like Safetensors,
|
|
262
|
+
includes a fast and compliant implementation of BPE, and caches HTTP requests for large assets
|
|
263
|
+
like model weights in OPFS.
|
|
264
|
+
- [**`@jax-js/onnx`**](packages/onnx) is a model loader from the [ONNX](https://onnx.ai/) format
|
|
265
|
+
into native jax-js functions.
|
|
266
|
+
- [**`@jax-js/optax`**](packages/optax) provides implementations of optimizers like Adam and SGD.
|
|
290
267
|
|
|
291
268
|
### Performance
|
|
292
269
|
|
|
@@ -311,6 +288,7 @@ If you make something cool with jax-js, don't be a stranger! We can feature it h
|
|
|
311
288
|
|
|
312
289
|
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
313
290
|
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
291
|
+
- [Object detection with DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
314
292
|
- [In-browser REPL](https://jax-js.com/repl)
|
|
315
293
|
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
316
294
|
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
@@ -346,12 +324,16 @@ pnpm -C website dev
|
|
|
346
324
|
|
|
347
325
|
## Future work / help wanted
|
|
348
326
|
|
|
349
|
-
Contributions are welcomed!
|
|
327
|
+
Contributions are welcomed! Some fruitful areas to look into:
|
|
350
328
|
|
|
351
329
|
- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
|
|
352
330
|
- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
|
|
353
331
|
and multithreading. (Even single-threaded Wasm could be ~20x faster.)
|
|
354
|
-
-
|
|
355
|
-
|
|
332
|
+
- Adding support for `jax.profiling`, in particular the start and end trace functions. We should be
|
|
333
|
+
able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
|
|
334
|
+
to help with model performance debugging.
|
|
335
|
+
- Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
|
|
356
336
|
- Adding WebGL runtime for older browsers that don't support WebGPU.
|
|
357
337
|
- Making a fast transformer inference engine, comparing against onnxruntime-web.
|
|
338
|
+
|
|
339
|
+
You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
|
|
@@ -69,6 +69,9 @@ function zipn(...arrays) {
|
|
|
69
69
|
const minLength = Math.min(...arrays.map((x) => x.length));
|
|
70
70
|
return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
|
|
71
71
|
}
|
|
72
|
+
/**
 * Return a new array holding the items of `arr` in ascending numeric order.
 *
 * `arr` may be any iterable (it is spread into a fresh array first), so the
 * input itself is never reordered.
 */
function sorted(arr) {
  const copy = [...arr];
  // Explicit numeric comparator: the default sort would compare as strings.
  copy.sort((lhs, rhs) => lhs - rhs);
  return copy;
}
|
|
72
75
|
function rep(length, value) {
|
|
73
76
|
if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
|
|
74
77
|
return new Array(length).fill(value);
|
|
@@ -145,7 +148,7 @@ function normalizeAxis(axis, ndim) {
|
|
|
145
148
|
if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
|
|
146
149
|
seen.add(ca);
|
|
147
150
|
}
|
|
148
|
-
return
|
|
151
|
+
return sorted(seen);
|
|
149
152
|
}
|
|
150
153
|
}
|
|
151
154
|
function range(start, stop, step = 1) {
|
|
@@ -558,16 +561,16 @@ var AluExp = class AluExp {
|
|
|
558
561
|
});
|
|
559
562
|
}
|
|
560
563
|
/** Reindex gid values in this expression as needed. */
|
|
561
|
-
reindexGids(
|
|
564
|
+
reindexGids(newGids) {
|
|
562
565
|
return this.rewrite((exp) => {
|
|
563
566
|
if (exp.op === AluOp.GlobalIndex) {
|
|
564
567
|
const [gid, len] = exp.arg;
|
|
565
|
-
const newGid =
|
|
566
|
-
if (newGid !==
|
|
568
|
+
const newGid = newGids[gid];
|
|
569
|
+
if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
|
|
567
570
|
} else if (exp.op === AluOp.GlobalView) {
|
|
568
571
|
const gid = exp.arg[0];
|
|
569
|
-
const newGid =
|
|
570
|
-
if (newGid !==
|
|
572
|
+
const newGid = newGids[gid];
|
|
573
|
+
if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
|
|
571
574
|
}
|
|
572
575
|
});
|
|
573
576
|
}
|
|
@@ -781,7 +784,7 @@ var AluExp = class AluExp {
|
|
|
781
784
|
if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
|
|
782
785
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
783
786
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
784
|
-
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
787
|
+
if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
|
|
785
788
|
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
786
789
|
}
|
|
787
790
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
@@ -1328,7 +1331,7 @@ var Reduction = class {
|
|
|
1328
1331
|
/** Evaluate this operation on CPU. */
|
|
1329
1332
|
evaluate(...values) {
|
|
1330
1333
|
if (this.dtype === DType.Bool) {
|
|
1331
|
-
if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b,
|
|
1334
|
+
if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, false);
|
|
1332
1335
|
else if (this.op === AluOp.Mul || this.op === AluOp.Min) return values.reduce((a, b) => a && b, true);
|
|
1333
1336
|
} else if (this.dtype === DType.Int32) {
|
|
1334
1337
|
if (this.op === AluOp.Add) return values.reduce((a, b) => a + b | 0, 0);
|
|
@@ -1438,6 +1441,112 @@ function erfc(x) {
|
|
|
1438
1441
|
else return 2 - _erfapprox$1(-x);
|
|
1439
1442
|
}
|
|
1440
1443
|
|
|
1444
|
+
//#endregion
|
|
1445
|
+
//#region src/routine.ts
|
|
1446
|
+
/**
|
|
1447
|
+
* Advanced operations that don't fit into the `AluExp` compiler representation.
|
|
1448
|
+
*
|
|
1449
|
+
* Some routines like iterative matrix algorithms, FFTs, or sorting may not be
|
|
1450
|
+
* easy to express efficiently as a `Kernel` object. These also tend to be
|
|
1451
|
+
* somewhat expensive, so the benefit of kernel fusion and inlining is less
|
|
1452
|
+
* relevant.
|
|
1453
|
+
*
|
|
1454
|
+
* For these operations, we dispatch them as a custom operation on the backend,
|
|
1455
|
+
* which each backend implements in a specific way. These are listed in the
|
|
1456
|
+
* `Routines` enum below.
|
|
1457
|
+
*
|
|
1458
|
+
* Routines cannot be fused into other kernels and always operate on contiguous
|
|
1459
|
+
* arrays (default `ShapeTracker`).
|
|
1460
|
+
*/
|
|
1461
|
+
var Routine = class {
|
|
1462
|
+
constructor(name, type, params) {
|
|
1463
|
+
this.name = name;
|
|
1464
|
+
this.type = type;
|
|
1465
|
+
this.params = params;
|
|
1466
|
+
}
|
|
1467
|
+
};
|
|
1468
|
+
/** Names of the valid `Routine` kinds that can be dispatched to a backend. */
|
|
1469
|
+
let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
1470
|
+
/** Stable sorting algorithm along the last axis. */
|
|
1471
|
+
Routines$1["Sort"] = "Sort";
|
|
1472
|
+
/** Returns `int32` indices of the stably sorted array. */
|
|
1473
|
+
Routines$1["Argsort"] = "Argsort";
|
|
1474
|
+
/** Solve a triangular system of equations. */
|
|
1475
|
+
Routines$1["TriangularSolve"] = "TriangularSolve";
|
|
1476
|
+
/** Cholesky decomposition of 2D positive semi-definite matrices. */
|
|
1477
|
+
Routines$1["Cholesky"] = "Cholesky";
|
|
1478
|
+
return Routines$1;
|
|
1479
|
+
}({});
|
|
1480
|
+
/**
 * Execute a routine on the CPU.
 *
 * Each input/output buffer is wrapped with `dtypedArray` using the dtype at
 * the same position in `routine.type.inputDtypes` / `outputDtypes`, then the
 * call is forwarded to the implementation matching `routine.name`. Results
 * are written into the output buffers; the implementations return nothing.
 *
 * @param routine Descriptor with `name`, `type` and (optionally) `params`.
 * @param inputs Raw input buffers, one per entry of `type.inputDtypes`.
 * @param outputs Raw output buffers, one per entry of `type.outputDtypes`.
 * @throws {Error} If `routine.name` is not a known routine. Previously this
 *   case fell through silently and left the outputs unwritten.
 */
function runCpuRoutine(routine, inputs, outputs) {
  const { name, type } = routine;
  const inputAr = inputs.map((buf, i) => dtypedArray(type.inputDtypes[i], buf));
  const outputAr = outputs.map((buf, i) => dtypedArray(type.outputDtypes[i], buf));
  switch (name) {
    case Routines.Sort: return runSort(type, inputAr, outputAr);
    case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
    case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
    case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
    // Fail loudly rather than returning with garbage in the output buffers.
    default: throw new Error(`Unknown routine '${String(name)}' dispatched to CPU`);
  }
}
|
|
1492
|
+
/**
 * CPU implementation of the `Sort` routine.
 *
 * Copies `x` into `y`, then sorts `y` in place one row at a time, where a row
 * is a consecutive run of `n` elements and `n` is the last entry of
 * `type.inputShapes[0]`. Each row is ordered by the array view's default
 * `sort()`.
 *
 * @param type Routine type; only `inputShapes[0]` is read.
 * @param x Input array view (must be accepted by `y.set`).
 * @param y Output array view, overwritten with the sorted rows.
 * @throws {Error} If the input shape is empty (a scalar).
 */
function runSort(type, [x], [y]) {
  const shape = type.inputShapes[0];
  if (shape.length === 0) throw new Error("sort: cannot sort a scalar");
  const rowLength = shape[shape.length - 1];
  y.set(x);
  let start = 0;
  while (start < y.length) {
    // `subarray` is a view, so sorting it reorders `y` itself.
    y.subarray(start, start + rowLength).sort();
    start += rowLength;
  }
}
|
|
1499
|
+
/**
 * CPU implementation of the `Argsort` routine.
 *
 * For every row (a consecutive run of `n` elements, `n` being the last entry
 * of `type.inputShapes[0]`) this writes:
 *  - into `yi`: the row-local indices that order the row ascending;
 *  - into `y`: the row's values gathered in that order.
 *
 * Ordering is a total order: numbers ascending, NaN after every number, and
 * ties (including `-0` vs `+0`, and NaN vs NaN) broken by original index, so
 * the result is stable and does not depend on the engine's sort algorithm.
 * The previous comparator `ar[a] - ar[b]` evaluated to NaN for NaN elements
 * and for `Infinity - Infinity`, making the order implementation-defined.
 *
 * @param type Routine type; only `inputShapes[0]` is read.
 * @param x Input array view.
 * @param y Output array view receiving the sorted values.
 * @param yi Output array view receiving the sorting indices.
 * @throws {Error} If the input shape is empty (a scalar).
 */
function runArgsort(type, [x], [y, yi]) {
  const xs = type.inputShapes[0];
  if (xs.length === 0) throw new Error("argsort: cannot sort a scalar");
  const n = xs[xs.length - 1];
  for (let offset = 0; offset < y.length; offset += n) {
    const ar = x.subarray(offset, offset + n);
    const out = y.subarray(offset, offset + n);
    const outi = yi.subarray(offset, offset + n);
    for (let i = 0; i < n; i++) outi[i] = i;
    outi.sort((a, b) => {
      const va = ar[a];
      const vb = ar[b];
      if (va < vb) return -1;
      if (va > vb) return 1;
      // Here the values are equal or at least one is NaN (`v !== v`).
      const aIsNaN = va !== va;
      const bIsNaN = vb !== vb;
      if (aIsNaN !== bIsNaN) return aIsNaN ? 1 : -1;
      // Index tie-break keeps equal elements in their original order.
      return a - b;
    });
    for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
  }
}
|
|
1512
|
+
/**
 * CPU implementation of the `TriangularSolve` routine.
 *
 * `a` holds one or more row-major `n x n` matrices (shape `[..., n, n]`) and
 * `b` holds, for each matrix, `batch` right-hand-side vectors of length `n`
 * (shape `[..., batch, n]`). For every vector `b1` the system `a1 * x1 = b1`
 * is solved by back substitution, reading only the entries of `a1` on and
 * above the diagonal (i.e. `a1` is treated as upper triangular).
 *
 * @param type Routine type; `inputShapes[0]` and `inputShapes[1]` are read.
 * @param a Array view with the coefficient matrices.
 * @param b Array view with the right-hand sides.
 * @param x Output array view, same layout as `b`.
 * @param unitDiagonal When true the diagonal of `a` is assumed to be 1 and is
 *   not read.
 * @throws {Error} If either input has fewer than two dimensions, if `a` is
 *   not square or its size differs from the last dimension of `b`, or if the
 *   leading (batch) dimensions of `a` and `b` describe a different number of
 *   systems. The last case previously ran on truncated views and silently
 *   produced NaN or partial output.
 */
function runTriangularSolve(type, [a, b], [x], { unitDiagonal }) {
  const as = type.inputShapes[0];
  const bs = type.inputShapes[1];
  if (as.length < 2) throw new Error(`triangular_solve: a must be at least 2D, got ${as}`);
  if (bs.length < 2) throw new Error(`triangular_solve: b must be at least 2D, got ${bs}`);
  const n = as[as.length - 2];
  if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
  const batch = bs[bs.length - 2];
  // Number of matrices / right-hand-side groups implied by the leading dims.
  const leadingCount = (shape) => shape.slice(0, -2).reduce((acc, dim) => acc * dim, 1);
  if (leadingCount(as) !== leadingCount(bs)) throw new Error(`triangular_solve: batch dimensions do not match, a=${as}, b=${bs}`);
  const matrixSize = n * n;
  // When n === 0 the bound is NaN and the loop body never runs.
  for (let counter = 0; counter < a.length / matrixSize; counter++) {
    const a1 = a.subarray(counter * matrixSize, (counter + 1) * matrixSize);
    for (let t = 0; t < batch; t++) {
      const rowStart = (counter * batch + t) * n;
      const b1 = b.subarray(rowStart, rowStart + n);
      const x1 = x.subarray(rowStart, rowStart + n);
      // Back substitution: the last unknown first, then upwards.
      for (let i = n - 1; i >= 0; i--) {
        let sum = b1[i];
        for (let j = i + 1; j < n; j++) sum -= a1[i * n + j] * x1[j];
        x1[i] = unitDiagonal ? sum : sum / a1[i * n + i];
      }
    }
  }
}
|
|
1533
|
+
/**
 * CPU implementation of the `Cholesky` routine.
 *
 * Treats `x` as a sequence of row-major `n x n` matrices and writes, for each
 * one, the lower-triangular factor `L` (with `L * L^T` equal to the input)
 * into the same position of `y`. Only the lower triangle of each input matrix
 * is read. A matrix that is not positive definite yields NaN entries (from
 * `Math.sqrt` of a negative number) rather than an exception.
 *
 * The strict upper triangle of every output matrix is set to zero explicitly,
 * so the result no longer depends on `y` having been zero-filled beforehand.
 *
 * @param type Routine type; only `inputShapes[0]` is read.
 * @param x Input array view.
 * @param y Output array view, same layout as `x`.
 * @throws {Error} If the input has fewer than two dimensions or its last two
 *   dimensions are not equal.
 */
function runCholesky(type, [x], [y]) {
  const xs = type.inputShapes[0];
  if (xs.length < 2) throw new Error("cholesky: input must be at least 2D");
  const n = xs[xs.length - 2];
  const m = xs[xs.length - 1];
  if (n !== m) throw new Error(`cholesky: input must be square, got [${n}, ${m}]`);
  const matrixSize = n * n;
  for (let offset = 0; offset < y.length; offset += matrixSize) {
    const ar = x.subarray(offset, offset + matrixSize);
    const out = y.subarray(offset, offset + matrixSize);
    for (let i = 0; i < n; i++) {
      for (let j = 0; j <= i; j++) {
        let sum = ar[i * n + j];
        for (let k = 0; k < j; k++) sum -= out[i * n + k] * out[j * n + k];
        out[i * n + j] = i === j ? Math.sqrt(sum) : sum / out[j * n + j];
      }
      // Entries above the diagonal are never read above, so clearing them is
      // safe even if `x` and `y` share storage.
      for (let j = i + 1; j < n; j++) out[i * n + j] = 0;
    }
  }
}
|
|
1549
|
+
|
|
1441
1550
|
//#endregion
|
|
1442
1551
|
//#region src/shape.ts
|
|
1443
1552
|
const jstr = JSON.stringify;
|
|
@@ -1908,7 +2017,7 @@ var ShapeTracker = class ShapeTracker {
|
|
|
1908
2017
|
let st = this;
|
|
1909
2018
|
if (axis.length > 0) {
|
|
1910
2019
|
const unsqueezed = [...st.shape];
|
|
1911
|
-
for (const i of axis
|
|
2020
|
+
for (const i of sorted(axis)) unsqueezed.splice(i, 0, 1);
|
|
1912
2021
|
st = st.reshape(unsqueezed);
|
|
1913
2022
|
}
|
|
1914
2023
|
return st.expand(newShape);
|
|
@@ -2067,7 +2176,8 @@ function tuneNullopt(kernel) {
|
|
|
2067
2176
|
if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
2068
2177
|
return {
|
|
2069
2178
|
exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
2070
|
-
|
|
2179
|
+
epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
|
|
2180
|
+
outputIdxExp: vars.gidx,
|
|
2071
2181
|
threadCount: kernel.size,
|
|
2072
2182
|
size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
|
|
2073
2183
|
};
|
|
@@ -2100,7 +2210,11 @@ function tuneWebgpu(kernel) {
|
|
|
2100
2210
|
while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
|
|
2101
2211
|
const choices = [];
|
|
2102
2212
|
const composedSts = sts.map((st) => st.compose(dim.st));
|
|
2103
|
-
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2213
|
+
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2214
|
+
3,
|
|
2215
|
+
4,
|
|
2216
|
+
5
|
|
2217
|
+
]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
|
|
2104
2218
|
let nonzeroStrides = 0;
|
|
2105
2219
|
let totalStrides = 0;
|
|
2106
2220
|
for (const st of composedSts) {
|
|
@@ -2128,7 +2242,7 @@ function tuneWebgpu(kernel) {
|
|
|
2128
2242
|
break;
|
|
2129
2243
|
}
|
|
2130
2244
|
}
|
|
2131
|
-
for (const ax of
|
|
2245
|
+
for (const ax of sorted(upcastedAxis)) {
|
|
2132
2246
|
const s = dim.st.shape[ax];
|
|
2133
2247
|
for (const amount of [8, 4]) if (s % amount === 0) {
|
|
2134
2248
|
dim.applyLocal(ax, amount);
|
|
@@ -2176,7 +2290,15 @@ function tuneWebgpu(kernel) {
|
|
|
2176
2290
|
});
|
|
2177
2291
|
const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
|
|
2178
2292
|
const outputUpcast = dim.outputSt.shape.slice(dim.groups);
|
|
2179
|
-
const
|
|
2293
|
+
const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
|
|
2294
|
+
const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
|
|
2295
|
+
const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
|
|
2296
|
+
if (exp$1.op === AluOp.GlobalView) {
|
|
2297
|
+
const gid = exp$1.arg[0];
|
|
2298
|
+
const st = exp$1.arg[1];
|
|
2299
|
+
return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
|
|
2300
|
+
}
|
|
2301
|
+
});
|
|
2180
2302
|
if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
|
|
2181
2303
|
const size = {
|
|
2182
2304
|
groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
|
|
@@ -2186,6 +2308,7 @@ function tuneWebgpu(kernel) {
|
|
|
2186
2308
|
};
|
|
2187
2309
|
return {
|
|
2188
2310
|
exp: newExp.simplify(),
|
|
2311
|
+
epilogue: newEpilogue.simplify(),
|
|
2189
2312
|
outputIdxExp: outputIdxExp.simplify(),
|
|
2190
2313
|
threadCount: kernel.size / size.upcast * size.groups,
|
|
2191
2314
|
size
|
|
@@ -2237,17 +2360,25 @@ var CpuBackend = class {
|
|
|
2237
2360
|
if (count === void 0) count = buffer.byteLength - start;
|
|
2238
2361
|
return buffer.slice(start, start + count);
|
|
2239
2362
|
}
|
|
2240
|
-
async
|
|
2241
|
-
return this.
|
|
2363
|
+
async prepareKernel(kernel) {
|
|
2364
|
+
return this.prepareKernelSync(kernel);
|
|
2242
2365
|
}
|
|
2243
|
-
|
|
2366
|
+
prepareKernelSync(kernel) {
|
|
2244
2367
|
return new Executable(kernel, void 0);
|
|
2245
2368
|
}
|
|
2246
|
-
|
|
2247
|
-
|
|
2369
|
+
async prepareRoutine(routine) {
|
|
2370
|
+
return this.prepareRoutineSync(routine);
|
|
2371
|
+
}
|
|
2372
|
+
prepareRoutineSync(routine) {
|
|
2373
|
+
return new Executable(routine, void 0);
|
|
2374
|
+
}
|
|
2375
|
+
dispatch(exe, inputs, outputs) {
|
|
2376
|
+
if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
2377
|
+
const kernel = exe.source;
|
|
2378
|
+
const { exp, epilogue } = tuneNullopt(kernel);
|
|
2248
2379
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
2249
2380
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
2250
|
-
const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2381
|
+
const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2251
2382
|
const inputArrays = inputBuffers.map((buf, i) => {
|
|
2252
2383
|
const dtype = usedArgs.get(i);
|
|
2253
2384
|
if (!dtype) return null;
|
|
@@ -2269,7 +2400,10 @@ var CpuBackend = class {
|
|
|
2269
2400
|
}, globals);
|
|
2270
2401
|
acc = kernel.reduction.evaluate(acc, item);
|
|
2271
2402
|
}
|
|
2272
|
-
outputArray[i] =
|
|
2403
|
+
outputArray[i] = epilogue.evaluate({
|
|
2404
|
+
acc,
|
|
2405
|
+
gidx: i
|
|
2406
|
+
}, globals);
|
|
2273
2407
|
}
|
|
2274
2408
|
}
|
|
2275
2409
|
#getBuffer(slot) {
|
|
@@ -2298,8 +2432,10 @@ var WasmAllocator = class {
|
|
|
2298
2432
|
const sizeClass = this.#findSizeClass(size);
|
|
2299
2433
|
const freeList = this.#freeLists.get(sizeClass);
|
|
2300
2434
|
let ptr;
|
|
2301
|
-
if (freeList && freeList.length > 0)
|
|
2302
|
-
|
|
2435
|
+
if (freeList && freeList.length > 0) {
|
|
2436
|
+
ptr = freeList.pop();
|
|
2437
|
+
new Uint8Array(this.#memory.buffer, ptr, sizeClass).fill(0);
|
|
2438
|
+
} else ptr = this.#bumpAlloc(sizeClass);
|
|
2303
2439
|
this.#allocatedBuffers.set(ptr, sizeClass);
|
|
2304
2440
|
return ptr;
|
|
2305
2441
|
}
|
|
@@ -2432,7 +2568,7 @@ function wasm_log(cg) {
|
|
|
2432
2568
|
const t2 = cg.local.declare(cg.f32);
|
|
2433
2569
|
cg.local.get(0);
|
|
2434
2570
|
cg.f32.const(0);
|
|
2435
|
-
cg.f32.
|
|
2571
|
+
cg.f32.lt();
|
|
2436
2572
|
cg.if(cg.void);
|
|
2437
2573
|
cg.f32.const(NaN);
|
|
2438
2574
|
cg.return();
|
|
@@ -2447,6 +2583,20 @@ function wasm_log(cg) {
|
|
|
2447
2583
|
cg.i32.const(127);
|
|
2448
2584
|
cg.i32.sub();
|
|
2449
2585
|
cg.local.set(e);
|
|
2586
|
+
cg.local.get(e);
|
|
2587
|
+
cg.i32.const(-127);
|
|
2588
|
+
cg.i32.eq();
|
|
2589
|
+
cg.if(cg.void);
|
|
2590
|
+
cg.f32.const(-Infinity);
|
|
2591
|
+
cg.return();
|
|
2592
|
+
cg.end();
|
|
2593
|
+
cg.local.get(e);
|
|
2594
|
+
cg.i32.const(128);
|
|
2595
|
+
cg.i32.eq();
|
|
2596
|
+
cg.if(cg.void);
|
|
2597
|
+
cg.local.get(0);
|
|
2598
|
+
cg.return();
|
|
2599
|
+
cg.end();
|
|
2450
2600
|
cg.local.get(bits);
|
|
2451
2601
|
cg.i32.const(8388607);
|
|
2452
2602
|
cg.i32.and();
|
|
@@ -2512,7 +2662,7 @@ function _sincos(cg) {
|
|
|
2512
2662
|
cg.f32.mul();
|
|
2513
2663
|
cg.f32.nearest();
|
|
2514
2664
|
cg.local.tee(qf);
|
|
2515
|
-
cg.i32.
|
|
2665
|
+
cg.i32.trunc_sat_f32_s();
|
|
2516
2666
|
cg.local.set(q);
|
|
2517
2667
|
cg.local.get(y);
|
|
2518
2668
|
cg.local.get(qf);
|
|
@@ -3599,6 +3749,7 @@ var F32x4 = class extends V128 {
|
|
|
3599
3749
|
|
|
3600
3750
|
//#endregion
|
|
3601
3751
|
//#region src/backend/wasm.ts
|
|
3752
|
+
const moduleCache = /* @__PURE__ */ new Map();
|
|
3602
3753
|
/** Backend that compiles into WebAssembly bytecode for immediate execution. */
|
|
3603
3754
|
var WasmBackend = class {
|
|
3604
3755
|
type = "wasm";
|
|
@@ -3650,15 +3801,25 @@ var WasmBackend = class {
|
|
|
3650
3801
|
if (count === void 0) count = buffer.byteLength - start;
|
|
3651
3802
|
return buffer.slice(start, start + count);
|
|
3652
3803
|
}
|
|
3653
|
-
async
|
|
3654
|
-
return this.
|
|
3804
|
+
async prepareKernel(kernel) {
|
|
3805
|
+
return this.prepareKernelSync(kernel);
|
|
3655
3806
|
}
|
|
3656
|
-
|
|
3657
|
-
const
|
|
3658
|
-
const module$1 =
|
|
3807
|
+
prepareKernelSync(kernel) {
|
|
3808
|
+
const kernelHash = FpHash.hash(kernel);
|
|
3809
|
+
const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => {
|
|
3810
|
+
const bytes = codegenWasm(kernel);
|
|
3811
|
+
return new WebAssembly.Module(bytes);
|
|
3812
|
+
});
|
|
3659
3813
|
return new Executable(kernel, { module: module$1 });
|
|
3660
3814
|
}
|
|
3815
|
+
async prepareRoutine(routine) {
|
|
3816
|
+
return this.prepareRoutineSync(routine);
|
|
3817
|
+
}
|
|
3818
|
+
prepareRoutineSync(routine) {
|
|
3819
|
+
return new Executable(routine, void 0);
|
|
3820
|
+
}
|
|
3661
3821
|
dispatch(exe, inputs, outputs) {
|
|
3822
|
+
if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
3662
3823
|
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
3663
3824
|
const func = instance.exports.kernel;
|
|
3664
3825
|
const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
|
|
@@ -3676,7 +3837,7 @@ function codegenWasm(kernel) {
|
|
|
3676
3837
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3677
3838
|
const cg = new CodeGenerator();
|
|
3678
3839
|
cg.memory.import("env", "memory");
|
|
3679
|
-
const distinctOps = mapSetUnion(tune.exp.distinctOps(),
|
|
3840
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
3680
3841
|
const funcs = {};
|
|
3681
3842
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3682
3843
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
@@ -3754,7 +3915,10 @@ function codegenWasm(kernel) {
|
|
|
3754
3915
|
cg.br(1);
|
|
3755
3916
|
cg.end();
|
|
3756
3917
|
cg.end();
|
|
3757
|
-
translateExp(cg, funcs,
|
|
3918
|
+
translateExp(cg, funcs, tune.epilogue, {
|
|
3919
|
+
acc,
|
|
3920
|
+
gidx
|
|
3921
|
+
});
|
|
3758
3922
|
} else translateExp(cg, funcs, tune.exp, { gidx });
|
|
3759
3923
|
dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
|
|
3760
3924
|
cg.local.get(gidx);
|
|
@@ -4003,7 +4167,7 @@ async function createBackend(device) {
|
|
|
4003
4167
|
if (!navigator.gpu) return null;
|
|
4004
4168
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4005
4169
|
if (!adapter) return null;
|
|
4006
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4170
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-Oj3Kd-kd.cjs"));
|
|
4007
4171
|
const importantLimits = [
|
|
4008
4172
|
"maxBufferSize",
|
|
4009
4173
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4037,8 +4201,8 @@ function getBackend(device) {
|
|
|
4037
4201
|
return backend;
|
|
4038
4202
|
}
|
|
4039
4203
|
var Executable = class {
|
|
4040
|
-
constructor(
|
|
4041
|
-
this.
|
|
4204
|
+
constructor(source, data) {
|
|
4205
|
+
this.source = source;
|
|
4042
4206
|
this.data = data;
|
|
4043
4207
|
}
|
|
4044
4208
|
};
|
|
@@ -4054,6 +4218,11 @@ var UnsupportedOpError = class extends Error {
|
|
|
4054
4218
|
super(msg);
|
|
4055
4219
|
}
|
|
4056
4220
|
};
|
|
4221
|
+
/**
 * Error whose message states that the routine called `name` is not supported
 * in the `device` backend.
 */
var UnsupportedRoutineError = class extends Error {
  constructor(name, device) {
    const message = `routine '${name}' is not supported in ${device} backend`;
    super(message);
  }
};
|
|
4057
4226
|
|
|
4058
4227
|
//#endregion
|
|
4059
4228
|
Object.defineProperty(exports, 'AluExp', {
|
|
@@ -4122,6 +4291,18 @@ Object.defineProperty(exports, 'Reduction', {
|
|
|
4122
4291
|
return Reduction;
|
|
4123
4292
|
}
|
|
4124
4293
|
});
|
|
4294
|
+
Object.defineProperty(exports, 'Routine', {
|
|
4295
|
+
enumerable: true,
|
|
4296
|
+
get: function () {
|
|
4297
|
+
return Routine;
|
|
4298
|
+
}
|
|
4299
|
+
});
|
|
4300
|
+
Object.defineProperty(exports, 'Routines', {
|
|
4301
|
+
enumerable: true,
|
|
4302
|
+
get: function () {
|
|
4303
|
+
return Routines;
|
|
4304
|
+
}
|
|
4305
|
+
});
|
|
4125
4306
|
Object.defineProperty(exports, 'ShapeTracker', {
|
|
4126
4307
|
enumerable: true,
|
|
4127
4308
|
get: function () {
|
|
@@ -4140,6 +4321,12 @@ Object.defineProperty(exports, 'UnsupportedOpError', {
|
|
|
4140
4321
|
return UnsupportedOpError;
|
|
4141
4322
|
}
|
|
4142
4323
|
});
|
|
4324
|
+
Object.defineProperty(exports, 'UnsupportedRoutineError', {
|
|
4325
|
+
enumerable: true,
|
|
4326
|
+
get: function () {
|
|
4327
|
+
return UnsupportedRoutineError;
|
|
4328
|
+
}
|
|
4329
|
+
});
|
|
4143
4330
|
Object.defineProperty(exports, 'accessorAluExp', {
|
|
4144
4331
|
enumerable: true,
|
|
4145
4332
|
get: function () {
|
|
@@ -4349,5 +4536,4 @@ Object.defineProperty(exports, 'zipn', {
|
|
|
4349
4536
|
get: function () {
|
|
4350
4537
|
return zipn;
|
|
4351
4538
|
}
|
|
4352
|
-
});
|
|
4353
|
-
//# sourceMappingURL=backend-DeVfWEFS.cjs.map
|
|
4539
|
+
});
|