@jax-js/jax 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -32
- package/dist/{backend-BqymqzuU.js → backend-BY8wlLEl.js} +58 -20
- package/dist/{backend-DeVfWEFS.cjs → backend-CmaidnkQ.cjs} +58 -20
- package/dist/index.cjs +298 -134
- package/dist/index.d.cts +21 -5
- package/dist/index.d.ts +21 -5
- package/dist/index.js +298 -134
- package/dist/{webgpu-CcGP160M.cjs → webgpu-BVns4DbI.cjs} +14 -6
- package/dist/{webgpu-BGuG58KZ.js → webgpu-C9iAP5h5.js} +14 -6
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -257,36 +257,12 @@ await devicePut(ar, "webgpu"); // Now device="webgpu"
|
|
|
257
257
|
There are other libraries in the `@jax-js` namespace that can work with jax-js, or be used in a
|
|
258
258
|
self-contained way in other projects.
|
|
259
259
|
|
|
260
|
-
**`@jax-js/
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
const solver = adam(1e-3);
|
|
268
|
-
let optState = solver.init(params.ref);
|
|
269
|
-
let updates: np.Array;
|
|
270
|
-
|
|
271
|
-
const f = (x: np.Array) => squaredError(x, np.ones([3])).sum();
|
|
272
|
-
|
|
273
|
-
for (let i = 0; i < 100; i++) {
|
|
274
|
-
const paramsGrad = grad(f)(params.ref);
|
|
275
|
-
[updates, optState] = solver.update(paramsGrad, optState);
|
|
276
|
-
params = applyUpdates(params, updates);
|
|
277
|
-
}
|
|
278
|
-
```
|
|
279
|
-
|
|
280
|
-
**`@jax-js/loaders`** can load tensors from various formats like Safetensors, includes a fast and
|
|
281
|
-
compliant implementation of BPE, and caches HTTP requests for large assets like model weights in
|
|
282
|
-
OPFS.
|
|
283
|
-
|
|
284
|
-
```ts
|
|
285
|
-
import { tokenizers } from "@jax-js/loaders";
|
|
286
|
-
|
|
287
|
-
const enc = await tokenizers.getBpe("clip");
|
|
288
|
-
const tokens = enc.encode("Hello, world!"); // => [ 49406, 3306, 267, 1002, ... ]
|
|
289
|
-
```
|
|
260
|
+
- [**`@jax-js/loaders`**](packages/loaders) can load tensors from various formats like Safetensors,
|
|
261
|
+
includes a fast and compliant implementation of BPE, and caches HTTP requests for large assets
|
|
262
|
+
like model weights in OPFS.
|
|
263
|
+
- [**`@jax-js/onnx`**](packages/onnx) is a model loader from the [ONNX](https://onnx.ai/) format
|
|
264
|
+
into native jax-js functions.
|
|
265
|
+
- [**`@jax-js/optax`**](packages/optax) provides implementations of optimizers like Adam and SGD.
|
|
290
266
|
|
|
291
267
|
### Performance
|
|
292
268
|
|
|
@@ -311,6 +287,7 @@ If you make something cool with jax-js, don't be a stranger! We can feature it h
|
|
|
311
287
|
|
|
312
288
|
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
313
289
|
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
290
|
+
- [Object detection with DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
314
291
|
- [In-browser REPL](https://jax-js.com/repl)
|
|
315
292
|
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
316
293
|
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
@@ -351,7 +328,9 @@ Contributions are welcomed! Especially in:
|
|
|
351
328
|
- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
|
|
352
329
|
- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
|
|
353
330
|
and multithreading. (Even single-threaded Wasm could be ~20x faster.)
|
|
354
|
-
-
|
|
355
|
-
|
|
331
|
+
- Adding support for `jax.profiling`, in particular the start and end trace functions. We should be
|
|
332
|
+
able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
|
|
333
|
+
to help with model performance debugging.
|
|
334
|
+
- Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
|
|
356
335
|
- Adding WebGL runtime for older browsers that don't support WebGPU.
|
|
357
336
|
- Making a fast transformer inference engine, comparing against onnxruntime-web.
|
|
@@ -557,16 +557,16 @@ var AluExp = class AluExp {
|
|
|
557
557
|
});
|
|
558
558
|
}
|
|
559
559
|
/** Reindex gid values in this expression as needed. */
|
|
560
|
-
reindexGids(
|
|
560
|
+
reindexGids(newGids) {
|
|
561
561
|
return this.rewrite((exp) => {
|
|
562
562
|
if (exp.op === AluOp.GlobalIndex) {
|
|
563
563
|
const [gid, len] = exp.arg;
|
|
564
|
-
const newGid =
|
|
565
|
-
if (newGid !==
|
|
564
|
+
const newGid = newGids[gid];
|
|
565
|
+
if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
|
|
566
566
|
} else if (exp.op === AluOp.GlobalView) {
|
|
567
567
|
const gid = exp.arg[0];
|
|
568
|
-
const newGid =
|
|
569
|
-
if (newGid !==
|
|
568
|
+
const newGid = newGids[gid];
|
|
569
|
+
if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
|
|
570
570
|
}
|
|
571
571
|
});
|
|
572
572
|
}
|
|
@@ -780,7 +780,7 @@ var AluExp = class AluExp {
|
|
|
780
780
|
if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
|
|
781
781
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
782
782
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
783
|
-
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
783
|
+
if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
|
|
784
784
|
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
785
785
|
}
|
|
786
786
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
@@ -2066,7 +2066,8 @@ function tuneNullopt(kernel) {
|
|
|
2066
2066
|
if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
2067
2067
|
return {
|
|
2068
2068
|
exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
2069
|
-
|
|
2069
|
+
epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
|
|
2070
|
+
outputIdxExp: vars.gidx,
|
|
2070
2071
|
threadCount: kernel.size,
|
|
2071
2072
|
size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
|
|
2072
2073
|
};
|
|
@@ -2099,7 +2100,11 @@ function tuneWebgpu(kernel) {
|
|
|
2099
2100
|
while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
|
|
2100
2101
|
const choices = [];
|
|
2101
2102
|
const composedSts = sts.map((st) => st.compose(dim.st));
|
|
2102
|
-
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2103
|
+
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2104
|
+
3,
|
|
2105
|
+
4,
|
|
2106
|
+
5
|
|
2107
|
+
]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
|
|
2103
2108
|
let nonzeroStrides = 0;
|
|
2104
2109
|
let totalStrides = 0;
|
|
2105
2110
|
for (const st of composedSts) {
|
|
@@ -2175,7 +2180,15 @@ function tuneWebgpu(kernel) {
|
|
|
2175
2180
|
});
|
|
2176
2181
|
const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
|
|
2177
2182
|
const outputUpcast = dim.outputSt.shape.slice(dim.groups);
|
|
2178
|
-
const
|
|
2183
|
+
const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
|
|
2184
|
+
const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
|
|
2185
|
+
const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
|
|
2186
|
+
if (exp$1.op === AluOp.GlobalView) {
|
|
2187
|
+
const gid = exp$1.arg[0];
|
|
2188
|
+
const st = exp$1.arg[1];
|
|
2189
|
+
return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
|
|
2190
|
+
}
|
|
2191
|
+
});
|
|
2179
2192
|
if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
|
|
2180
2193
|
const size = {
|
|
2181
2194
|
groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
|
|
@@ -2185,6 +2198,7 @@ function tuneWebgpu(kernel) {
|
|
|
2185
2198
|
};
|
|
2186
2199
|
return {
|
|
2187
2200
|
exp: newExp.simplify(),
|
|
2201
|
+
epilogue: newEpilogue.simplify(),
|
|
2188
2202
|
outputIdxExp: outputIdxExp.simplify(),
|
|
2189
2203
|
threadCount: kernel.size / size.upcast * size.groups,
|
|
2190
2204
|
size
|
|
@@ -2243,10 +2257,10 @@ var CpuBackend = class {
|
|
|
2243
2257
|
return new Executable(kernel, void 0);
|
|
2244
2258
|
}
|
|
2245
2259
|
dispatch({ kernel }, inputs, outputs) {
|
|
2246
|
-
const { exp } = tuneNullopt(kernel);
|
|
2260
|
+
const { exp, epilogue } = tuneNullopt(kernel);
|
|
2247
2261
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
2248
2262
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
2249
|
-
const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2263
|
+
const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2250
2264
|
const inputArrays = inputBuffers.map((buf, i) => {
|
|
2251
2265
|
const dtype = usedArgs.get(i);
|
|
2252
2266
|
if (!dtype) return null;
|
|
@@ -2268,7 +2282,10 @@ var CpuBackend = class {
|
|
|
2268
2282
|
}, globals);
|
|
2269
2283
|
acc = kernel.reduction.evaluate(acc, item);
|
|
2270
2284
|
}
|
|
2271
|
-
outputArray[i] =
|
|
2285
|
+
outputArray[i] = epilogue.evaluate({
|
|
2286
|
+
acc,
|
|
2287
|
+
gidx: i
|
|
2288
|
+
}, globals);
|
|
2272
2289
|
}
|
|
2273
2290
|
}
|
|
2274
2291
|
#getBuffer(slot) {
|
|
@@ -2431,7 +2448,7 @@ function wasm_log(cg) {
|
|
|
2431
2448
|
const t2 = cg.local.declare(cg.f32);
|
|
2432
2449
|
cg.local.get(0);
|
|
2433
2450
|
cg.f32.const(0);
|
|
2434
|
-
cg.f32.
|
|
2451
|
+
cg.f32.lt();
|
|
2435
2452
|
cg.if(cg.void);
|
|
2436
2453
|
cg.f32.const(NaN);
|
|
2437
2454
|
cg.return();
|
|
@@ -2446,6 +2463,20 @@ function wasm_log(cg) {
|
|
|
2446
2463
|
cg.i32.const(127);
|
|
2447
2464
|
cg.i32.sub();
|
|
2448
2465
|
cg.local.set(e);
|
|
2466
|
+
cg.local.get(e);
|
|
2467
|
+
cg.i32.const(-127);
|
|
2468
|
+
cg.i32.eq();
|
|
2469
|
+
cg.if(cg.void);
|
|
2470
|
+
cg.f32.const(-Infinity);
|
|
2471
|
+
cg.return();
|
|
2472
|
+
cg.end();
|
|
2473
|
+
cg.local.get(e);
|
|
2474
|
+
cg.i32.const(128);
|
|
2475
|
+
cg.i32.eq();
|
|
2476
|
+
cg.if(cg.void);
|
|
2477
|
+
cg.local.get(0);
|
|
2478
|
+
cg.return();
|
|
2479
|
+
cg.end();
|
|
2449
2480
|
cg.local.get(bits);
|
|
2450
2481
|
cg.i32.const(8388607);
|
|
2451
2482
|
cg.i32.and();
|
|
@@ -2511,7 +2542,7 @@ function _sincos(cg) {
|
|
|
2511
2542
|
cg.f32.mul();
|
|
2512
2543
|
cg.f32.nearest();
|
|
2513
2544
|
cg.local.tee(qf);
|
|
2514
|
-
cg.i32.
|
|
2545
|
+
cg.i32.trunc_sat_f32_s();
|
|
2515
2546
|
cg.local.set(q);
|
|
2516
2547
|
cg.local.get(y);
|
|
2517
2548
|
cg.local.get(qf);
|
|
@@ -3598,6 +3629,7 @@ var F32x4 = class extends V128 {
|
|
|
3598
3629
|
|
|
3599
3630
|
//#endregion
|
|
3600
3631
|
//#region src/backend/wasm.ts
|
|
3632
|
+
const moduleCache = /* @__PURE__ */ new Map();
|
|
3601
3633
|
/** Backend that compiles into WebAssembly bytecode for immediate execution. */
|
|
3602
3634
|
var WasmBackend = class {
|
|
3603
3635
|
type = "wasm";
|
|
@@ -3653,8 +3685,11 @@ var WasmBackend = class {
|
|
|
3653
3685
|
return this.prepareSync(kernel);
|
|
3654
3686
|
}
|
|
3655
3687
|
prepareSync(kernel) {
|
|
3656
|
-
const
|
|
3657
|
-
const module =
|
|
3688
|
+
const kernelHash = FpHash.hash(kernel);
|
|
3689
|
+
const module = runWithCache(moduleCache, kernelHash.toString(), () => {
|
|
3690
|
+
const bytes = codegenWasm(kernel);
|
|
3691
|
+
return new WebAssembly.Module(bytes);
|
|
3692
|
+
});
|
|
3658
3693
|
return new Executable(kernel, { module });
|
|
3659
3694
|
}
|
|
3660
3695
|
dispatch(exe, inputs, outputs) {
|
|
@@ -3675,7 +3710,7 @@ function codegenWasm(kernel) {
|
|
|
3675
3710
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3676
3711
|
const cg = new CodeGenerator();
|
|
3677
3712
|
cg.memory.import("env", "memory");
|
|
3678
|
-
const distinctOps = mapSetUnion(tune.exp.distinctOps(),
|
|
3713
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
3679
3714
|
const funcs = {};
|
|
3680
3715
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3681
3716
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
@@ -3753,7 +3788,10 @@ function codegenWasm(kernel) {
|
|
|
3753
3788
|
cg.br(1);
|
|
3754
3789
|
cg.end();
|
|
3755
3790
|
cg.end();
|
|
3756
|
-
translateExp(cg, funcs,
|
|
3791
|
+
translateExp(cg, funcs, tune.epilogue, {
|
|
3792
|
+
acc,
|
|
3793
|
+
gidx
|
|
3794
|
+
});
|
|
3757
3795
|
} else translateExp(cg, funcs, tune.exp, { gidx });
|
|
3758
3796
|
dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
|
|
3759
3797
|
cg.local.get(gidx);
|
|
@@ -4002,7 +4040,7 @@ async function createBackend(device) {
|
|
|
4002
4040
|
if (!navigator.gpu) return null;
|
|
4003
4041
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4004
4042
|
if (!adapter) return null;
|
|
4005
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4043
|
+
const { WebGPUBackend } = await import("./webgpu-C9iAP5h5.js");
|
|
4006
4044
|
const importantLimits = [
|
|
4007
4045
|
"maxBufferSize",
|
|
4008
4046
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4056,4 +4094,4 @@ var UnsupportedOpError = class extends Error {
|
|
|
4056
4094
|
|
|
4057
4095
|
//#endregion
|
|
4058
4096
|
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4059
|
-
//# sourceMappingURL=backend-
|
|
4097
|
+
//# sourceMappingURL=backend-BY8wlLEl.js.map
|
|
@@ -558,16 +558,16 @@ var AluExp = class AluExp {
|
|
|
558
558
|
});
|
|
559
559
|
}
|
|
560
560
|
/** Reindex gid values in this expression as needed. */
|
|
561
|
-
reindexGids(
|
|
561
|
+
reindexGids(newGids) {
|
|
562
562
|
return this.rewrite((exp) => {
|
|
563
563
|
if (exp.op === AluOp.GlobalIndex) {
|
|
564
564
|
const [gid, len] = exp.arg;
|
|
565
|
-
const newGid =
|
|
566
|
-
if (newGid !==
|
|
565
|
+
const newGid = newGids[gid];
|
|
566
|
+
if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
|
|
567
567
|
} else if (exp.op === AluOp.GlobalView) {
|
|
568
568
|
const gid = exp.arg[0];
|
|
569
|
-
const newGid =
|
|
570
|
-
if (newGid !==
|
|
569
|
+
const newGid = newGids[gid];
|
|
570
|
+
if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
|
|
571
571
|
}
|
|
572
572
|
});
|
|
573
573
|
}
|
|
@@ -781,7 +781,7 @@ var AluExp = class AluExp {
|
|
|
781
781
|
if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
|
|
782
782
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
783
783
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
784
|
-
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
784
|
+
if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
|
|
785
785
|
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
786
786
|
}
|
|
787
787
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
@@ -2067,7 +2067,8 @@ function tuneNullopt(kernel) {
|
|
|
2067
2067
|
if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
2068
2068
|
return {
|
|
2069
2069
|
exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
2070
|
-
|
|
2070
|
+
epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
|
|
2071
|
+
outputIdxExp: vars.gidx,
|
|
2071
2072
|
threadCount: kernel.size,
|
|
2072
2073
|
size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
|
|
2073
2074
|
};
|
|
@@ -2100,7 +2101,11 @@ function tuneWebgpu(kernel) {
|
|
|
2100
2101
|
while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
|
|
2101
2102
|
const choices = [];
|
|
2102
2103
|
const composedSts = sts.map((st) => st.compose(dim.st));
|
|
2103
|
-
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2104
|
+
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2105
|
+
3,
|
|
2106
|
+
4,
|
|
2107
|
+
5
|
|
2108
|
+
]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
|
|
2104
2109
|
let nonzeroStrides = 0;
|
|
2105
2110
|
let totalStrides = 0;
|
|
2106
2111
|
for (const st of composedSts) {
|
|
@@ -2176,7 +2181,15 @@ function tuneWebgpu(kernel) {
|
|
|
2176
2181
|
});
|
|
2177
2182
|
const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
|
|
2178
2183
|
const outputUpcast = dim.outputSt.shape.slice(dim.groups);
|
|
2179
|
-
const
|
|
2184
|
+
const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
|
|
2185
|
+
const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
|
|
2186
|
+
const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
|
|
2187
|
+
if (exp$1.op === AluOp.GlobalView) {
|
|
2188
|
+
const gid = exp$1.arg[0];
|
|
2189
|
+
const st = exp$1.arg[1];
|
|
2190
|
+
return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
|
|
2191
|
+
}
|
|
2192
|
+
});
|
|
2180
2193
|
if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
|
|
2181
2194
|
const size = {
|
|
2182
2195
|
groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
|
|
@@ -2186,6 +2199,7 @@ function tuneWebgpu(kernel) {
|
|
|
2186
2199
|
};
|
|
2187
2200
|
return {
|
|
2188
2201
|
exp: newExp.simplify(),
|
|
2202
|
+
epilogue: newEpilogue.simplify(),
|
|
2189
2203
|
outputIdxExp: outputIdxExp.simplify(),
|
|
2190
2204
|
threadCount: kernel.size / size.upcast * size.groups,
|
|
2191
2205
|
size
|
|
@@ -2244,10 +2258,10 @@ var CpuBackend = class {
|
|
|
2244
2258
|
return new Executable(kernel, void 0);
|
|
2245
2259
|
}
|
|
2246
2260
|
dispatch({ kernel }, inputs, outputs) {
|
|
2247
|
-
const { exp } = tuneNullopt(kernel);
|
|
2261
|
+
const { exp, epilogue } = tuneNullopt(kernel);
|
|
2248
2262
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
2249
2263
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
2250
|
-
const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2264
|
+
const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2251
2265
|
const inputArrays = inputBuffers.map((buf, i) => {
|
|
2252
2266
|
const dtype = usedArgs.get(i);
|
|
2253
2267
|
if (!dtype) return null;
|
|
@@ -2269,7 +2283,10 @@ var CpuBackend = class {
|
|
|
2269
2283
|
}, globals);
|
|
2270
2284
|
acc = kernel.reduction.evaluate(acc, item);
|
|
2271
2285
|
}
|
|
2272
|
-
outputArray[i] =
|
|
2286
|
+
outputArray[i] = epilogue.evaluate({
|
|
2287
|
+
acc,
|
|
2288
|
+
gidx: i
|
|
2289
|
+
}, globals);
|
|
2273
2290
|
}
|
|
2274
2291
|
}
|
|
2275
2292
|
#getBuffer(slot) {
|
|
@@ -2432,7 +2449,7 @@ function wasm_log(cg) {
|
|
|
2432
2449
|
const t2 = cg.local.declare(cg.f32);
|
|
2433
2450
|
cg.local.get(0);
|
|
2434
2451
|
cg.f32.const(0);
|
|
2435
|
-
cg.f32.
|
|
2452
|
+
cg.f32.lt();
|
|
2436
2453
|
cg.if(cg.void);
|
|
2437
2454
|
cg.f32.const(NaN);
|
|
2438
2455
|
cg.return();
|
|
@@ -2447,6 +2464,20 @@ function wasm_log(cg) {
|
|
|
2447
2464
|
cg.i32.const(127);
|
|
2448
2465
|
cg.i32.sub();
|
|
2449
2466
|
cg.local.set(e);
|
|
2467
|
+
cg.local.get(e);
|
|
2468
|
+
cg.i32.const(-127);
|
|
2469
|
+
cg.i32.eq();
|
|
2470
|
+
cg.if(cg.void);
|
|
2471
|
+
cg.f32.const(-Infinity);
|
|
2472
|
+
cg.return();
|
|
2473
|
+
cg.end();
|
|
2474
|
+
cg.local.get(e);
|
|
2475
|
+
cg.i32.const(128);
|
|
2476
|
+
cg.i32.eq();
|
|
2477
|
+
cg.if(cg.void);
|
|
2478
|
+
cg.local.get(0);
|
|
2479
|
+
cg.return();
|
|
2480
|
+
cg.end();
|
|
2450
2481
|
cg.local.get(bits);
|
|
2451
2482
|
cg.i32.const(8388607);
|
|
2452
2483
|
cg.i32.and();
|
|
@@ -2512,7 +2543,7 @@ function _sincos(cg) {
|
|
|
2512
2543
|
cg.f32.mul();
|
|
2513
2544
|
cg.f32.nearest();
|
|
2514
2545
|
cg.local.tee(qf);
|
|
2515
|
-
cg.i32.
|
|
2546
|
+
cg.i32.trunc_sat_f32_s();
|
|
2516
2547
|
cg.local.set(q);
|
|
2517
2548
|
cg.local.get(y);
|
|
2518
2549
|
cg.local.get(qf);
|
|
@@ -3599,6 +3630,7 @@ var F32x4 = class extends V128 {
|
|
|
3599
3630
|
|
|
3600
3631
|
//#endregion
|
|
3601
3632
|
//#region src/backend/wasm.ts
|
|
3633
|
+
const moduleCache = /* @__PURE__ */ new Map();
|
|
3602
3634
|
/** Backend that compiles into WebAssembly bytecode for immediate execution. */
|
|
3603
3635
|
var WasmBackend = class {
|
|
3604
3636
|
type = "wasm";
|
|
@@ -3654,8 +3686,11 @@ var WasmBackend = class {
|
|
|
3654
3686
|
return this.prepareSync(kernel);
|
|
3655
3687
|
}
|
|
3656
3688
|
prepareSync(kernel) {
|
|
3657
|
-
const
|
|
3658
|
-
const module$1 =
|
|
3689
|
+
const kernelHash = FpHash.hash(kernel);
|
|
3690
|
+
const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => {
|
|
3691
|
+
const bytes = codegenWasm(kernel);
|
|
3692
|
+
return new WebAssembly.Module(bytes);
|
|
3693
|
+
});
|
|
3659
3694
|
return new Executable(kernel, { module: module$1 });
|
|
3660
3695
|
}
|
|
3661
3696
|
dispatch(exe, inputs, outputs) {
|
|
@@ -3676,7 +3711,7 @@ function codegenWasm(kernel) {
|
|
|
3676
3711
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3677
3712
|
const cg = new CodeGenerator();
|
|
3678
3713
|
cg.memory.import("env", "memory");
|
|
3679
|
-
const distinctOps = mapSetUnion(tune.exp.distinctOps(),
|
|
3714
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
3680
3715
|
const funcs = {};
|
|
3681
3716
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3682
3717
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
@@ -3754,7 +3789,10 @@ function codegenWasm(kernel) {
|
|
|
3754
3789
|
cg.br(1);
|
|
3755
3790
|
cg.end();
|
|
3756
3791
|
cg.end();
|
|
3757
|
-
translateExp(cg, funcs,
|
|
3792
|
+
translateExp(cg, funcs, tune.epilogue, {
|
|
3793
|
+
acc,
|
|
3794
|
+
gidx
|
|
3795
|
+
});
|
|
3758
3796
|
} else translateExp(cg, funcs, tune.exp, { gidx });
|
|
3759
3797
|
dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
|
|
3760
3798
|
cg.local.get(gidx);
|
|
@@ -4003,7 +4041,7 @@ async function createBackend(device) {
|
|
|
4003
4041
|
if (!navigator.gpu) return null;
|
|
4004
4042
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4005
4043
|
if (!adapter) return null;
|
|
4006
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4044
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVns4DbI.cjs"));
|
|
4007
4045
|
const importantLimits = [
|
|
4008
4046
|
"maxBufferSize",
|
|
4009
4047
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4350,4 +4388,4 @@ Object.defineProperty(exports, 'zipn', {
|
|
|
4350
4388
|
return zipn;
|
|
4351
4389
|
}
|
|
4352
4390
|
});
|
|
4353
|
-
//# sourceMappingURL=backend-
|
|
4391
|
+
//# sourceMappingURL=backend-CmaidnkQ.cjs.map
|