@jax-js/jax 0.1.1 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -32
- package/dist/{backend-CoVtc9dx.js → backend-BY8wlLEl.js} +88 -25
- package/dist/{backend-BbrKEB18.cjs → backend-CmaidnkQ.cjs} +88 -25
- package/dist/index.cjs +2901 -2252
- package/dist/index.d.cts +1101 -979
- package/dist/index.d.ts +1101 -979
- package/dist/index.js +2892 -2243
- package/dist/{webgpu-DGYNVHma.cjs → webgpu-BVns4DbI.cjs} +25 -15
- package/dist/{webgpu-B3UVme6n.js → webgpu-C9iAP5h5.js} +25 -15
- package/package.json +13 -21
package/README.md
CHANGED
|
@@ -257,36 +257,12 @@ await devicePut(ar, "webgpu"); // Now device="webgpu"
|
|
|
257
257
|
There are other libraries in the `@jax-js` namespace that can work with jax-js, or be used in a
|
|
258
258
|
self-contained way in other projects.
|
|
259
259
|
|
|
260
|
-
**`@jax-js/
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
const solver = adam(1e-3);
|
|
268
|
-
let optState = solver.init(params.ref);
|
|
269
|
-
let updates: np.Array;
|
|
270
|
-
|
|
271
|
-
const f = (x: np.Array) => squaredError(x, np.ones([3])).sum();
|
|
272
|
-
|
|
273
|
-
for (let i = 0; i < 100; i++) {
|
|
274
|
-
const paramsGrad = grad(f)(params.ref);
|
|
275
|
-
[updates, optState] = solver.update(paramsGrad, optState);
|
|
276
|
-
params = applyUpdates(params, updates);
|
|
277
|
-
}
|
|
278
|
-
```
|
|
279
|
-
|
|
280
|
-
**`@jax-js/loaders`** can load tensors from various formats like Safetensors, includes a fast and
|
|
281
|
-
compliant implementation of BPE, and caches HTTP requests for large assets like model weights in
|
|
282
|
-
OPFS.
|
|
283
|
-
|
|
284
|
-
```ts
|
|
285
|
-
import { tokenizers } from "@jax-js/loaders";
|
|
286
|
-
|
|
287
|
-
const enc = await tokenizers.getBpe("clip");
|
|
288
|
-
const tokens = enc.encode("Hello, world!"); // => [ 49406, 3306, 267, 1002, ... ]
|
|
289
|
-
```
|
|
260
|
+
- [**`@jax-js/loaders`**](packages/loaders) can load tensors from various formats like Safetensors,
|
|
261
|
+
includes a fast and compliant implementation of BPE, and caches HTTP requests for large assets
|
|
262
|
+
like model weights in OPFS.
|
|
263
|
+
- [**`@jax-js/onnx`**](packages/onnx) is a model loader from the [ONNX](https://onnx.ai/) format
|
|
264
|
+
into native jax-js functions.
|
|
265
|
+
- [**`@jax-js/optax`**](packages/optax) provides implementations of optimizers like Adam and SGD.
|
|
290
266
|
|
|
291
267
|
### Performance
|
|
292
268
|
|
|
@@ -311,6 +287,7 @@ If you make something cool with jax-js, don't be a stranger! We can feature it h
|
|
|
311
287
|
|
|
312
288
|
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
313
289
|
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
290
|
+
- [Object detection with DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
314
291
|
- [In-browser REPL](https://jax-js.com/repl)
|
|
315
292
|
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
316
293
|
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
@@ -351,7 +328,9 @@ Contributions are welcomed! Especially in:
|
|
|
351
328
|
- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
|
|
352
329
|
- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
|
|
353
330
|
and multithreading. (Even single-threaded Wasm could be ~20x faster.)
|
|
354
|
-
-
|
|
355
|
-
|
|
331
|
+
- Adding support for `jax.profiling`, in particular the start and end trace functions. We should be
|
|
332
|
+
able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
|
|
333
|
+
to help with model performance debugging.
|
|
334
|
+
- Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
|
|
356
335
|
- Adding WebGL runtime for older browsers that don't support WebGPU.
|
|
357
336
|
- Making a fast transformer inference engine, comparing against onnxruntime-web.
|
|
@@ -289,10 +289,11 @@ var FpHash = class FpHash {
|
|
|
289
289
|
};
|
|
290
290
|
/** Run a function while caching it inline inside a `Map`. */
|
|
291
291
|
function runWithCache(cache, key, thunk) {
|
|
292
|
-
|
|
292
|
+
const keyStr = JSON.stringify(key);
|
|
293
|
+
if (cache.has(keyStr)) return cache.get(keyStr);
|
|
293
294
|
else {
|
|
294
295
|
const value = thunk();
|
|
295
|
-
cache.set(
|
|
296
|
+
cache.set(keyStr, value);
|
|
296
297
|
return value;
|
|
297
298
|
}
|
|
298
299
|
}
|
|
@@ -449,6 +450,14 @@ var AluExp = class AluExp {
|
|
|
449
450
|
static sqrt(a) {
|
|
450
451
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
451
452
|
}
|
|
453
|
+
static floor(a) {
|
|
454
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
455
|
+
return new AluExp(AluOp.Floor, a.dtype, [a]);
|
|
456
|
+
}
|
|
457
|
+
static ceil(a) {
|
|
458
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
459
|
+
return new AluExp(AluOp.Ceil, a.dtype, [a]);
|
|
460
|
+
}
|
|
452
461
|
static reciprocal(a) {
|
|
453
462
|
return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
|
|
454
463
|
}
|
|
@@ -548,16 +557,16 @@ var AluExp = class AluExp {
|
|
|
548
557
|
});
|
|
549
558
|
}
|
|
550
559
|
/** Reindex gid values in this expression as needed. */
|
|
551
|
-
reindexGids(
|
|
560
|
+
reindexGids(newGids) {
|
|
552
561
|
return this.rewrite((exp) => {
|
|
553
562
|
if (exp.op === AluOp.GlobalIndex) {
|
|
554
563
|
const [gid, len] = exp.arg;
|
|
555
|
-
const newGid =
|
|
556
|
-
if (newGid !==
|
|
564
|
+
const newGid = newGids[gid];
|
|
565
|
+
if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
|
|
557
566
|
} else if (exp.op === AluOp.GlobalView) {
|
|
558
567
|
const gid = exp.arg[0];
|
|
559
|
-
const newGid =
|
|
560
|
-
if (newGid !==
|
|
568
|
+
const newGid = newGids[gid];
|
|
569
|
+
if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
|
|
561
570
|
}
|
|
562
571
|
});
|
|
563
572
|
}
|
|
@@ -629,6 +638,12 @@ var AluExp = class AluExp {
|
|
|
629
638
|
case AluOp.Sqrt:
|
|
630
639
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
631
640
|
break;
|
|
641
|
+
case AluOp.Floor:
|
|
642
|
+
ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
|
|
643
|
+
break;
|
|
644
|
+
case AluOp.Ceil:
|
|
645
|
+
ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
|
|
646
|
+
break;
|
|
632
647
|
case AluOp.Reciprocal:
|
|
633
648
|
if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
|
|
634
649
|
ret = [1 / src[0].max, 1 / src[0].min];
|
|
@@ -765,7 +780,7 @@ var AluExp = class AluExp {
|
|
|
765
780
|
if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
|
|
766
781
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
767
782
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
768
|
-
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
783
|
+
if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
|
|
769
784
|
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
770
785
|
}
|
|
771
786
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
@@ -861,7 +876,7 @@ var AluExp = class AluExp {
|
|
|
861
876
|
else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
|
|
862
877
|
if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
|
|
863
878
|
}
|
|
864
|
-
if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
|
|
879
|
+
if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
|
|
865
880
|
const [x, y] = src;
|
|
866
881
|
{
|
|
867
882
|
const factors = [];
|
|
@@ -951,6 +966,8 @@ var AluExp = class AluExp {
|
|
|
951
966
|
case AluOp.Erf: return erf(x);
|
|
952
967
|
case AluOp.Erfc: return erfc(x);
|
|
953
968
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
969
|
+
case AluOp.Floor: return Math.floor(x);
|
|
970
|
+
case AluOp.Ceil: return Math.ceil(x);
|
|
954
971
|
case AluOp.Reciprocal: return 1 / x;
|
|
955
972
|
case AluOp.Cast: {
|
|
956
973
|
const wasFloat = isFloatDtype(this.src[0].dtype);
|
|
@@ -1140,6 +1157,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1140
1157
|
AluOp$1["Erf"] = "Erf";
|
|
1141
1158
|
AluOp$1["Erfc"] = "Erfc";
|
|
1142
1159
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1160
|
+
AluOp$1["Floor"] = "Floor";
|
|
1161
|
+
AluOp$1["Ceil"] = "Ceil";
|
|
1143
1162
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1144
1163
|
AluOp$1["Cast"] = "Cast";
|
|
1145
1164
|
AluOp$1["Bitcast"] = "Bitcast";
|
|
@@ -1174,6 +1193,8 @@ const AluGroup = {
|
|
|
1174
1193
|
AluOp.Erf,
|
|
1175
1194
|
AluOp.Erfc,
|
|
1176
1195
|
AluOp.Sqrt,
|
|
1196
|
+
AluOp.Floor,
|
|
1197
|
+
AluOp.Ceil,
|
|
1177
1198
|
AluOp.Reciprocal,
|
|
1178
1199
|
AluOp.Cast,
|
|
1179
1200
|
AluOp.Bitcast
|
|
@@ -1201,7 +1222,9 @@ const AluGroup = {
|
|
|
1201
1222
|
AluOp.Erf,
|
|
1202
1223
|
AluOp.Erfc,
|
|
1203
1224
|
AluOp.Sqrt,
|
|
1204
|
-
AluOp.Reciprocal
|
|
1225
|
+
AluOp.Reciprocal,
|
|
1226
|
+
AluOp.Floor,
|
|
1227
|
+
AluOp.Ceil
|
|
1205
1228
|
])
|
|
1206
1229
|
};
|
|
1207
1230
|
/** Common variables that can be substituted in expressions. */
|
|
@@ -2043,7 +2066,8 @@ function tuneNullopt(kernel) {
|
|
|
2043
2066
|
if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
2044
2067
|
return {
|
|
2045
2068
|
exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
2046
|
-
|
|
2069
|
+
epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
|
|
2070
|
+
outputIdxExp: vars.gidx,
|
|
2047
2071
|
threadCount: kernel.size,
|
|
2048
2072
|
size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
|
|
2049
2073
|
};
|
|
@@ -2076,7 +2100,11 @@ function tuneWebgpu(kernel) {
|
|
|
2076
2100
|
while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
|
|
2077
2101
|
const choices = [];
|
|
2078
2102
|
const composedSts = sts.map((st) => st.compose(dim.st));
|
|
2079
|
-
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2103
|
+
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2104
|
+
3,
|
|
2105
|
+
4,
|
|
2106
|
+
5
|
|
2107
|
+
]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
|
|
2080
2108
|
let nonzeroStrides = 0;
|
|
2081
2109
|
let totalStrides = 0;
|
|
2082
2110
|
for (const st of composedSts) {
|
|
@@ -2152,7 +2180,15 @@ function tuneWebgpu(kernel) {
|
|
|
2152
2180
|
});
|
|
2153
2181
|
const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
|
|
2154
2182
|
const outputUpcast = dim.outputSt.shape.slice(dim.groups);
|
|
2155
|
-
const
|
|
2183
|
+
const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
|
|
2184
|
+
const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
|
|
2185
|
+
const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
|
|
2186
|
+
if (exp$1.op === AluOp.GlobalView) {
|
|
2187
|
+
const gid = exp$1.arg[0];
|
|
2188
|
+
const st = exp$1.arg[1];
|
|
2189
|
+
return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
|
|
2190
|
+
}
|
|
2191
|
+
});
|
|
2156
2192
|
if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
|
|
2157
2193
|
const size = {
|
|
2158
2194
|
groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
|
|
@@ -2162,6 +2198,7 @@ function tuneWebgpu(kernel) {
|
|
|
2162
2198
|
};
|
|
2163
2199
|
return {
|
|
2164
2200
|
exp: newExp.simplify(),
|
|
2201
|
+
epilogue: newEpilogue.simplify(),
|
|
2165
2202
|
outputIdxExp: outputIdxExp.simplify(),
|
|
2166
2203
|
threadCount: kernel.size / size.upcast * size.groups,
|
|
2167
2204
|
size
|
|
@@ -2220,10 +2257,10 @@ var CpuBackend = class {
|
|
|
2220
2257
|
return new Executable(kernel, void 0);
|
|
2221
2258
|
}
|
|
2222
2259
|
dispatch({ kernel }, inputs, outputs) {
|
|
2223
|
-
const { exp } = tuneNullopt(kernel);
|
|
2260
|
+
const { exp, epilogue } = tuneNullopt(kernel);
|
|
2224
2261
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
2225
2262
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
2226
|
-
const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2263
|
+
const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2227
2264
|
const inputArrays = inputBuffers.map((buf, i) => {
|
|
2228
2265
|
const dtype = usedArgs.get(i);
|
|
2229
2266
|
if (!dtype) return null;
|
|
@@ -2245,7 +2282,10 @@ var CpuBackend = class {
|
|
|
2245
2282
|
}, globals);
|
|
2246
2283
|
acc = kernel.reduction.evaluate(acc, item);
|
|
2247
2284
|
}
|
|
2248
|
-
outputArray[i] =
|
|
2285
|
+
outputArray[i] = epilogue.evaluate({
|
|
2286
|
+
acc,
|
|
2287
|
+
gidx: i
|
|
2288
|
+
}, globals);
|
|
2249
2289
|
}
|
|
2250
2290
|
}
|
|
2251
2291
|
#getBuffer(slot) {
|
|
@@ -2408,7 +2448,7 @@ function wasm_log(cg) {
|
|
|
2408
2448
|
const t2 = cg.local.declare(cg.f32);
|
|
2409
2449
|
cg.local.get(0);
|
|
2410
2450
|
cg.f32.const(0);
|
|
2411
|
-
cg.f32.
|
|
2451
|
+
cg.f32.lt();
|
|
2412
2452
|
cg.if(cg.void);
|
|
2413
2453
|
cg.f32.const(NaN);
|
|
2414
2454
|
cg.return();
|
|
@@ -2423,6 +2463,20 @@ function wasm_log(cg) {
|
|
|
2423
2463
|
cg.i32.const(127);
|
|
2424
2464
|
cg.i32.sub();
|
|
2425
2465
|
cg.local.set(e);
|
|
2466
|
+
cg.local.get(e);
|
|
2467
|
+
cg.i32.const(-127);
|
|
2468
|
+
cg.i32.eq();
|
|
2469
|
+
cg.if(cg.void);
|
|
2470
|
+
cg.f32.const(-Infinity);
|
|
2471
|
+
cg.return();
|
|
2472
|
+
cg.end();
|
|
2473
|
+
cg.local.get(e);
|
|
2474
|
+
cg.i32.const(128);
|
|
2475
|
+
cg.i32.eq();
|
|
2476
|
+
cg.if(cg.void);
|
|
2477
|
+
cg.local.get(0);
|
|
2478
|
+
cg.return();
|
|
2479
|
+
cg.end();
|
|
2426
2480
|
cg.local.get(bits);
|
|
2427
2481
|
cg.i32.const(8388607);
|
|
2428
2482
|
cg.i32.and();
|
|
@@ -2488,7 +2542,7 @@ function _sincos(cg) {
|
|
|
2488
2542
|
cg.f32.mul();
|
|
2489
2543
|
cg.f32.nearest();
|
|
2490
2544
|
cg.local.tee(qf);
|
|
2491
|
-
cg.i32.
|
|
2545
|
+
cg.i32.trunc_sat_f32_s();
|
|
2492
2546
|
cg.local.set(q);
|
|
2493
2547
|
cg.local.get(y);
|
|
2494
2548
|
cg.local.get(qf);
|
|
@@ -3575,6 +3629,7 @@ var F32x4 = class extends V128 {
|
|
|
3575
3629
|
|
|
3576
3630
|
//#endregion
|
|
3577
3631
|
//#region src/backend/wasm.ts
|
|
3632
|
+
const moduleCache = /* @__PURE__ */ new Map();
|
|
3578
3633
|
/** Backend that compiles into WebAssembly bytecode for immediate execution. */
|
|
3579
3634
|
var WasmBackend = class {
|
|
3580
3635
|
type = "wasm";
|
|
@@ -3630,8 +3685,11 @@ var WasmBackend = class {
|
|
|
3630
3685
|
return this.prepareSync(kernel);
|
|
3631
3686
|
}
|
|
3632
3687
|
prepareSync(kernel) {
|
|
3633
|
-
const
|
|
3634
|
-
const module =
|
|
3688
|
+
const kernelHash = FpHash.hash(kernel);
|
|
3689
|
+
const module = runWithCache(moduleCache, kernelHash.toString(), () => {
|
|
3690
|
+
const bytes = codegenWasm(kernel);
|
|
3691
|
+
return new WebAssembly.Module(bytes);
|
|
3692
|
+
});
|
|
3635
3693
|
return new Executable(kernel, { module });
|
|
3636
3694
|
}
|
|
3637
3695
|
dispatch(exe, inputs, outputs) {
|
|
@@ -3652,7 +3710,7 @@ function codegenWasm(kernel) {
|
|
|
3652
3710
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3653
3711
|
const cg = new CodeGenerator();
|
|
3654
3712
|
cg.memory.import("env", "memory");
|
|
3655
|
-
const distinctOps = mapSetUnion(tune.exp.distinctOps(),
|
|
3713
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
3656
3714
|
const funcs = {};
|
|
3657
3715
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3658
3716
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
@@ -3730,7 +3788,10 @@ function codegenWasm(kernel) {
|
|
|
3730
3788
|
cg.br(1);
|
|
3731
3789
|
cg.end();
|
|
3732
3790
|
cg.end();
|
|
3733
|
-
translateExp(cg, funcs,
|
|
3791
|
+
translateExp(cg, funcs, tune.epilogue, {
|
|
3792
|
+
acc,
|
|
3793
|
+
gidx
|
|
3794
|
+
});
|
|
3734
3795
|
} else translateExp(cg, funcs, tune.exp, { gidx });
|
|
3735
3796
|
dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
|
|
3736
3797
|
cg.local.get(gidx);
|
|
@@ -3831,7 +3892,9 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3831
3892
|
else if (op === AluOp.Reciprocal) {
|
|
3832
3893
|
const dt = dtyF(cg, op, dtype);
|
|
3833
3894
|
dt.const(1), gen(src[0]), dt.div();
|
|
3834
|
-
} else if (op === AluOp.
|
|
3895
|
+
} else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
|
|
3896
|
+
else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
|
|
3897
|
+
else if (op === AluOp.Cast) {
|
|
3835
3898
|
gen(src[0]);
|
|
3836
3899
|
const dtype0 = src[0].dtype;
|
|
3837
3900
|
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
|
|
@@ -3977,7 +4040,7 @@ async function createBackend(device) {
|
|
|
3977
4040
|
if (!navigator.gpu) return null;
|
|
3978
4041
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3979
4042
|
if (!adapter) return null;
|
|
3980
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4043
|
+
const { WebGPUBackend } = await import("./webgpu-C9iAP5h5.js");
|
|
3981
4044
|
const importantLimits = [
|
|
3982
4045
|
"maxBufferSize",
|
|
3983
4046
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4031,4 +4094,4 @@ var UnsupportedOpError = class extends Error {
|
|
|
4031
4094
|
|
|
4032
4095
|
//#endregion
|
|
4033
4096
|
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4034
|
-
//# sourceMappingURL=backend-
|
|
4097
|
+
//# sourceMappingURL=backend-BY8wlLEl.js.map
|
|
@@ -290,10 +290,11 @@ var FpHash = class FpHash {
|
|
|
290
290
|
};
|
|
291
291
|
/** Run a function while caching it inline inside a `Map`. */
|
|
292
292
|
function runWithCache(cache, key, thunk) {
|
|
293
|
-
|
|
293
|
+
const keyStr = JSON.stringify(key);
|
|
294
|
+
if (cache.has(keyStr)) return cache.get(keyStr);
|
|
294
295
|
else {
|
|
295
296
|
const value = thunk();
|
|
296
|
-
cache.set(
|
|
297
|
+
cache.set(keyStr, value);
|
|
297
298
|
return value;
|
|
298
299
|
}
|
|
299
300
|
}
|
|
@@ -450,6 +451,14 @@ var AluExp = class AluExp {
|
|
|
450
451
|
static sqrt(a) {
|
|
451
452
|
return new AluExp(AluOp.Sqrt, a.dtype, [a]);
|
|
452
453
|
}
|
|
454
|
+
static floor(a) {
|
|
455
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
456
|
+
return new AluExp(AluOp.Floor, a.dtype, [a]);
|
|
457
|
+
}
|
|
458
|
+
static ceil(a) {
|
|
459
|
+
if (!isFloatDtype(a.dtype)) return a;
|
|
460
|
+
return new AluExp(AluOp.Ceil, a.dtype, [a]);
|
|
461
|
+
}
|
|
453
462
|
static reciprocal(a) {
|
|
454
463
|
return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
|
|
455
464
|
}
|
|
@@ -549,16 +558,16 @@ var AluExp = class AluExp {
|
|
|
549
558
|
});
|
|
550
559
|
}
|
|
551
560
|
/** Reindex gid values in this expression as needed. */
|
|
552
|
-
reindexGids(
|
|
561
|
+
reindexGids(newGids) {
|
|
553
562
|
return this.rewrite((exp) => {
|
|
554
563
|
if (exp.op === AluOp.GlobalIndex) {
|
|
555
564
|
const [gid, len] = exp.arg;
|
|
556
|
-
const newGid =
|
|
557
|
-
if (newGid !==
|
|
565
|
+
const newGid = newGids[gid];
|
|
566
|
+
if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
|
|
558
567
|
} else if (exp.op === AluOp.GlobalView) {
|
|
559
568
|
const gid = exp.arg[0];
|
|
560
|
-
const newGid =
|
|
561
|
-
if (newGid !==
|
|
569
|
+
const newGid = newGids[gid];
|
|
570
|
+
if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
|
|
562
571
|
}
|
|
563
572
|
});
|
|
564
573
|
}
|
|
@@ -630,6 +639,12 @@ var AluExp = class AluExp {
|
|
|
630
639
|
case AluOp.Sqrt:
|
|
631
640
|
ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
|
|
632
641
|
break;
|
|
642
|
+
case AluOp.Floor:
|
|
643
|
+
ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
|
|
644
|
+
break;
|
|
645
|
+
case AluOp.Ceil:
|
|
646
|
+
ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
|
|
647
|
+
break;
|
|
633
648
|
case AluOp.Reciprocal:
|
|
634
649
|
if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
|
|
635
650
|
ret = [1 / src[0].max, 1 / src[0].min];
|
|
@@ -766,7 +781,7 @@ var AluExp = class AluExp {
|
|
|
766
781
|
if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
|
|
767
782
|
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
768
783
|
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
769
|
-
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
784
|
+
if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
|
|
770
785
|
if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
|
|
771
786
|
}
|
|
772
787
|
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
@@ -862,7 +877,7 @@ var AluExp = class AluExp {
|
|
|
862
877
|
else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
|
|
863
878
|
if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
|
|
864
879
|
}
|
|
865
|
-
if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
|
|
880
|
+
if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
|
|
866
881
|
const [x, y] = src;
|
|
867
882
|
{
|
|
868
883
|
const factors = [];
|
|
@@ -952,6 +967,8 @@ var AluExp = class AluExp {
|
|
|
952
967
|
case AluOp.Erf: return erf(x);
|
|
953
968
|
case AluOp.Erfc: return erfc(x);
|
|
954
969
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
970
|
+
case AluOp.Floor: return Math.floor(x);
|
|
971
|
+
case AluOp.Ceil: return Math.ceil(x);
|
|
955
972
|
case AluOp.Reciprocal: return 1 / x;
|
|
956
973
|
case AluOp.Cast: {
|
|
957
974
|
const wasFloat = isFloatDtype(this.src[0].dtype);
|
|
@@ -1141,6 +1158,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
1141
1158
|
AluOp$1["Erf"] = "Erf";
|
|
1142
1159
|
AluOp$1["Erfc"] = "Erfc";
|
|
1143
1160
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
1161
|
+
AluOp$1["Floor"] = "Floor";
|
|
1162
|
+
AluOp$1["Ceil"] = "Ceil";
|
|
1144
1163
|
AluOp$1["Reciprocal"] = "Reciprocal";
|
|
1145
1164
|
AluOp$1["Cast"] = "Cast";
|
|
1146
1165
|
AluOp$1["Bitcast"] = "Bitcast";
|
|
@@ -1175,6 +1194,8 @@ const AluGroup = {
|
|
|
1175
1194
|
AluOp.Erf,
|
|
1176
1195
|
AluOp.Erfc,
|
|
1177
1196
|
AluOp.Sqrt,
|
|
1197
|
+
AluOp.Floor,
|
|
1198
|
+
AluOp.Ceil,
|
|
1178
1199
|
AluOp.Reciprocal,
|
|
1179
1200
|
AluOp.Cast,
|
|
1180
1201
|
AluOp.Bitcast
|
|
@@ -1202,7 +1223,9 @@ const AluGroup = {
|
|
|
1202
1223
|
AluOp.Erf,
|
|
1203
1224
|
AluOp.Erfc,
|
|
1204
1225
|
AluOp.Sqrt,
|
|
1205
|
-
AluOp.Reciprocal
|
|
1226
|
+
AluOp.Reciprocal,
|
|
1227
|
+
AluOp.Floor,
|
|
1228
|
+
AluOp.Ceil
|
|
1206
1229
|
])
|
|
1207
1230
|
};
|
|
1208
1231
|
/** Common variables that can be substituted in expressions. */
|
|
@@ -2044,7 +2067,8 @@ function tuneNullopt(kernel) {
|
|
|
2044
2067
|
if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
2045
2068
|
return {
|
|
2046
2069
|
exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
2047
|
-
|
|
2070
|
+
epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
|
|
2071
|
+
outputIdxExp: vars.gidx,
|
|
2048
2072
|
threadCount: kernel.size,
|
|
2049
2073
|
size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
|
|
2050
2074
|
};
|
|
@@ -2077,7 +2101,11 @@ function tuneWebgpu(kernel) {
|
|
|
2077
2101
|
while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
|
|
2078
2102
|
const choices = [];
|
|
2079
2103
|
const composedSts = sts.map((st) => st.compose(dim.st));
|
|
2080
|
-
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2104
|
+
for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
|
|
2105
|
+
3,
|
|
2106
|
+
4,
|
|
2107
|
+
5
|
|
2108
|
+
]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
|
|
2081
2109
|
let nonzeroStrides = 0;
|
|
2082
2110
|
let totalStrides = 0;
|
|
2083
2111
|
for (const st of composedSts) {
|
|
@@ -2153,7 +2181,15 @@ function tuneWebgpu(kernel) {
|
|
|
2153
2181
|
});
|
|
2154
2182
|
const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
|
|
2155
2183
|
const outputUpcast = dim.outputSt.shape.slice(dim.groups);
|
|
2156
|
-
const
|
|
2184
|
+
const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
|
|
2185
|
+
const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
|
|
2186
|
+
const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
|
|
2187
|
+
if (exp$1.op === AluOp.GlobalView) {
|
|
2188
|
+
const gid = exp$1.arg[0];
|
|
2189
|
+
const st = exp$1.arg[1];
|
|
2190
|
+
return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
|
|
2191
|
+
}
|
|
2192
|
+
});
|
|
2157
2193
|
if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
|
|
2158
2194
|
const size = {
|
|
2159
2195
|
groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
|
|
@@ -2163,6 +2199,7 @@ function tuneWebgpu(kernel) {
|
|
|
2163
2199
|
};
|
|
2164
2200
|
return {
|
|
2165
2201
|
exp: newExp.simplify(),
|
|
2202
|
+
epilogue: newEpilogue.simplify(),
|
|
2166
2203
|
outputIdxExp: outputIdxExp.simplify(),
|
|
2167
2204
|
threadCount: kernel.size / size.upcast * size.groups,
|
|
2168
2205
|
size
|
|
@@ -2221,10 +2258,10 @@ var CpuBackend = class {
|
|
|
2221
2258
|
return new Executable(kernel, void 0);
|
|
2222
2259
|
}
|
|
2223
2260
|
dispatch({ kernel }, inputs, outputs) {
|
|
2224
|
-
const { exp } = tuneNullopt(kernel);
|
|
2261
|
+
const { exp, epilogue } = tuneNullopt(kernel);
|
|
2225
2262
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
2226
2263
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
2227
|
-
const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2264
|
+
const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
|
|
2228
2265
|
const inputArrays = inputBuffers.map((buf, i) => {
|
|
2229
2266
|
const dtype = usedArgs.get(i);
|
|
2230
2267
|
if (!dtype) return null;
|
|
@@ -2246,7 +2283,10 @@ var CpuBackend = class {
|
|
|
2246
2283
|
}, globals);
|
|
2247
2284
|
acc = kernel.reduction.evaluate(acc, item);
|
|
2248
2285
|
}
|
|
2249
|
-
outputArray[i] =
|
|
2286
|
+
outputArray[i] = epilogue.evaluate({
|
|
2287
|
+
acc,
|
|
2288
|
+
gidx: i
|
|
2289
|
+
}, globals);
|
|
2250
2290
|
}
|
|
2251
2291
|
}
|
|
2252
2292
|
#getBuffer(slot) {
|
|
@@ -2409,7 +2449,7 @@ function wasm_log(cg) {
|
|
|
2409
2449
|
const t2 = cg.local.declare(cg.f32);
|
|
2410
2450
|
cg.local.get(0);
|
|
2411
2451
|
cg.f32.const(0);
|
|
2412
|
-
cg.f32.
|
|
2452
|
+
cg.f32.lt();
|
|
2413
2453
|
cg.if(cg.void);
|
|
2414
2454
|
cg.f32.const(NaN);
|
|
2415
2455
|
cg.return();
|
|
@@ -2424,6 +2464,20 @@ function wasm_log(cg) {
|
|
|
2424
2464
|
cg.i32.const(127);
|
|
2425
2465
|
cg.i32.sub();
|
|
2426
2466
|
cg.local.set(e);
|
|
2467
|
+
cg.local.get(e);
|
|
2468
|
+
cg.i32.const(-127);
|
|
2469
|
+
cg.i32.eq();
|
|
2470
|
+
cg.if(cg.void);
|
|
2471
|
+
cg.f32.const(-Infinity);
|
|
2472
|
+
cg.return();
|
|
2473
|
+
cg.end();
|
|
2474
|
+
cg.local.get(e);
|
|
2475
|
+
cg.i32.const(128);
|
|
2476
|
+
cg.i32.eq();
|
|
2477
|
+
cg.if(cg.void);
|
|
2478
|
+
cg.local.get(0);
|
|
2479
|
+
cg.return();
|
|
2480
|
+
cg.end();
|
|
2427
2481
|
cg.local.get(bits);
|
|
2428
2482
|
cg.i32.const(8388607);
|
|
2429
2483
|
cg.i32.and();
|
|
@@ -2489,7 +2543,7 @@ function _sincos(cg) {
|
|
|
2489
2543
|
cg.f32.mul();
|
|
2490
2544
|
cg.f32.nearest();
|
|
2491
2545
|
cg.local.tee(qf);
|
|
2492
|
-
cg.i32.
|
|
2546
|
+
cg.i32.trunc_sat_f32_s();
|
|
2493
2547
|
cg.local.set(q);
|
|
2494
2548
|
cg.local.get(y);
|
|
2495
2549
|
cg.local.get(qf);
|
|
@@ -3576,6 +3630,7 @@ var F32x4 = class extends V128 {
|
|
|
3576
3630
|
|
|
3577
3631
|
//#endregion
|
|
3578
3632
|
//#region src/backend/wasm.ts
|
|
3633
|
+
const moduleCache = /* @__PURE__ */ new Map();
|
|
3579
3634
|
/** Backend that compiles into WebAssembly bytecode for immediate execution. */
|
|
3580
3635
|
var WasmBackend = class {
|
|
3581
3636
|
type = "wasm";
|
|
@@ -3631,8 +3686,11 @@ var WasmBackend = class {
|
|
|
3631
3686
|
return this.prepareSync(kernel);
|
|
3632
3687
|
}
|
|
3633
3688
|
prepareSync(kernel) {
|
|
3634
|
-
const
|
|
3635
|
-
const module$1 =
|
|
3689
|
+
const kernelHash = FpHash.hash(kernel);
|
|
3690
|
+
const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => {
|
|
3691
|
+
const bytes = codegenWasm(kernel);
|
|
3692
|
+
return new WebAssembly.Module(bytes);
|
|
3693
|
+
});
|
|
3636
3694
|
return new Executable(kernel, { module: module$1 });
|
|
3637
3695
|
}
|
|
3638
3696
|
dispatch(exe, inputs, outputs) {
|
|
@@ -3653,7 +3711,7 @@ function codegenWasm(kernel) {
|
|
|
3653
3711
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
3654
3712
|
const cg = new CodeGenerator();
|
|
3655
3713
|
cg.memory.import("env", "memory");
|
|
3656
|
-
const distinctOps = mapSetUnion(tune.exp.distinctOps(),
|
|
3714
|
+
const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
|
|
3657
3715
|
const funcs = {};
|
|
3658
3716
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3659
3717
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
@@ -3731,7 +3789,10 @@ function codegenWasm(kernel) {
|
|
|
3731
3789
|
cg.br(1);
|
|
3732
3790
|
cg.end();
|
|
3733
3791
|
cg.end();
|
|
3734
|
-
translateExp(cg, funcs,
|
|
3792
|
+
translateExp(cg, funcs, tune.epilogue, {
|
|
3793
|
+
acc,
|
|
3794
|
+
gidx
|
|
3795
|
+
});
|
|
3735
3796
|
} else translateExp(cg, funcs, tune.exp, { gidx });
|
|
3736
3797
|
dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
|
|
3737
3798
|
cg.local.get(gidx);
|
|
@@ -3832,7 +3893,9 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3832
3893
|
else if (op === AluOp.Reciprocal) {
|
|
3833
3894
|
const dt = dtyF(cg, op, dtype);
|
|
3834
3895
|
dt.const(1), gen(src[0]), dt.div();
|
|
3835
|
-
} else if (op === AluOp.
|
|
3896
|
+
} else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
|
|
3897
|
+
else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
|
|
3898
|
+
else if (op === AluOp.Cast) {
|
|
3836
3899
|
gen(src[0]);
|
|
3837
3900
|
const dtype0 = src[0].dtype;
|
|
3838
3901
|
const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
|
|
@@ -3978,7 +4041,7 @@ async function createBackend(device) {
|
|
|
3978
4041
|
if (!navigator.gpu) return null;
|
|
3979
4042
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3980
4043
|
if (!adapter) return null;
|
|
3981
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4044
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVns4DbI.cjs"));
|
|
3982
4045
|
const importantLimits = [
|
|
3983
4046
|
"maxBufferSize",
|
|
3984
4047
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4325,4 +4388,4 @@ Object.defineProperty(exports, 'zipn', {
|
|
|
4325
4388
|
return zipn;
|
|
4326
4389
|
}
|
|
4327
4390
|
});
|
|
4328
|
-
//# sourceMappingURL=backend-
|
|
4391
|
+
//# sourceMappingURL=backend-CmaidnkQ.cjs.map
|