@jax-js/jax 0.1.8 → 0.1.10
- package/README.md +46 -29
- package/dist/{backend-nEolvdLv.js → backend-Ctqs8la1.js} +122 -15
- package/dist/{backend-B3foXiV_.cjs → backend-DMauYnfl.cjs} +157 -14
- package/dist/index.cjs +331 -46
- package/dist/index.d.cts +175 -31
- package/dist/index.d.ts +175 -31
- package/dist/index.js +331 -47
- package/dist/{webgl-DweKSWEm.js → webgl-CvQ1QBX1.js} +1 -1
- package/dist/{webgl-DIIbKJ0G.cjs → webgl-kvVt7-T7.cjs} +1 -1
- package/dist/{webgpu-BykvF26B.cjs → webgpu-DMSx7a6M.cjs} +160 -15
- package/dist/{webgpu-B96vzWGE.js → webgpu-v_W_-oKw.js} +160 -15
- package/package.json +5 -16
package/README.md
CHANGED
@@ -43,6 +43,41 @@ way to get started on a blank HTML page.
 </script>
 ```
 
+### Platforms
+
+This table refers to latest versions of each browser. WebGPU has gained wide support in browsers as
+of late 2025.
+
+| Platform            | CPU (Wasm) | GPU (WebGPU)   | GPU (WebGL) |
+| ------------------- | ---------- | -------------- | ----------- |
+| Chrome / Edge       | ✅         | ✅             | ✅          |
+| Firefox             | ✅         | ✅ - macOS 26+ | ✅          |
+| Safari              | ✅         | ✅ - macOS 26+ | ✅          |
+| iOS                 | ✅         | ✅ - iOS 26+   | ✅          |
+| Chrome for Android  | ✅         | ✅             | ✅          |
+| Firefox for Android | ✅         | ❌             | ✅          |
+| Node.js             | ✅         | ❌             | ❌          |
+| Deno                | ✅         | ✅ - async     | ❌          |
+
+## Examples
+
+Community usage:
+
+- [**autoresearch-webgpu**: autoresearch, in the browser](https://autoresearch.lucasgelfond.online/)
+- [**tanh.xyz**: Interactive ML visualizations](https://tanh.xyz/)
+- [**jax-js-bayes**: Declarative Bayesian modeling library](https://github.com/StefanSko/jax-js-bayes)
+
+Demos on the jax-js website:
+
+- [Training neural networks on MNIST](https://jax-js.com/mnist)
+- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
+- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
+- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
+- [In-browser REPL](https://jax-js.com/repl)
+- [Matmul benchmark](https://jax-js.com/bench/matmul)
+- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
+- [Mandelbrot set](https://jax-js.com/mandelbrot)
+
 ## Feature comparison
 
 Here's a quick, high-level comparison with other popular web ML runtimes:
@@ -320,6 +355,8 @@ self-contained way in other projects.
 
 ### Performance
 
+To see per-kernel traces in browser development tools, call `jax.profiler.startTrace()`.
+
 The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
 browsers. Also, this library uniquely has the `jit()` feature that fuses operations together and
 records an execution graph. jax-js achieves **over 7000 GFLOP/s** for matrix multiplication on an
@@ -338,19 +375,6 @@ well as unique optimizations such as FlashAttention variants.
 That's all for this short tutorial. Please see the generated
 [API reference](https://jax-js.com/docs) for detailed documentation.
 
-## Examples
-
-If you make something cool with jax-js, don't be a stranger! We can feature it here.
-
-- [Training neural networks on MNIST](https://jax-js.com/mnist)
-- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
-- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
-- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
-- [In-browser REPL](https://jax-js.com/repl)
-- [Matmul benchmark](https://jax-js.com/bench/matmul)
-- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
-- [Mandelbrot set](https://jax-js.com/mandelbrot)
-
 ## Development
 
 _The following technical details are for contributing to jax-js and modifying its internals._
@@ -363,19 +387,6 @@ pnpm install
 pnpm run build:watch
 ```
 
-The `pnpm install` command automatically sets up Git hooks via
-[Husky](https://typicode.github.io/husky/). Pre-commit hooks will run ESLint and Prettier on staged
-files to ensure code quality.
-
-You can also run linting and formatting manually:
-
-```bash
-pnpm lint # Run ESLint
-pnpm format # Format all files with Prettier
-pnpm format:check # Check formatting without writing
-pnpm check # Run TypeScript type checking
-```
-
 Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
 
 ```bash
@@ -392,6 +403,15 @@ To start a Vite dev server running the website, demos and REPL:
 pnpm -C website dev
 ```
 
+You can run the linter, code formatter, and type checker with:
+
+```bash
+pnpm lint # Run ESLint
+pnpm format # Format all files with Prettier
+pnpm format:check # Check formatting without writing
+pnpm check # Run TypeScript type checking
+```
+
 ## Future work / help wanted
 
 Contributions are welcomed! Some fruitful areas to look into:
@@ -399,9 +419,6 @@ Contributions are welcomed! Some fruitful areas to look into:
 - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
 - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
   and multithreading. (Even single-threaded Wasm could be ~20x faster.)
-- Adding support for `jax.profiling`, in particular the start and end trace functions. We should be
-  able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
-  to help with model performance debugging.
 - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
 - Making a fast transformer inference engine, comparing against onnxruntime-web.
 
package/dist/{backend-nEolvdLv.js → backend-Ctqs8la1.js}
CHANGED
@@ -1312,6 +1312,10 @@ var Reduction = class {
     this.epilogue = epilogue;
     if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
     this.epilogue = epilogue.simplify();
+    if (this.dtype === DType.Float16 && this.op === AluOp.Add) {
+      this.epilogue = this.epilogue.substitute({ acc: AluExp.cast(this.dtype, AluVar.acc(DType.Float32)) });
+      this.dtype = DType.Float32;
+    }
   }
   hash(state) {
     state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
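The added branch widens Float16 `Add` reductions: the accumulator becomes Float32 and the epilogue casts the result back. The sketch below shows why a wider accumulator matters, one precision level up. It compares a float32 accumulator against a float64 one, because `Math.fround` is available in every JavaScript runtime. It is an analogy, not jax-js code, and `sumOnes` is a hypothetical helper. Float16 hits the same wall much earlier, at 2048.

```javascript
// Hypothetical helper: add 1 to `start` `steps` times, rounding the
// accumulator with `round` after every addition.
function sumOnes(start, steps, round) {
  let acc = round(start);
  for (let i = 0; i < steps; i++) acc = round(acc + 1);
  return acc;
}

// 2^24 is where float32 stops representing every integer exactly.
const start = 2 ** 24;

// Narrow accumulator: every partial sum is rounded to float32.
const narrow = sumOnes(start, 8, Math.fround);

// Wide accumulator: sum in float64, cast to float32 once at the end.
const wide = Math.fround(sumOnes(start, 8, (v) => v));

console.log(narrow); // 16777216, every +1 was rounded away
console.log(wide); // 16777224
```

The narrow accumulator stalls because each intermediate `2^24 + 1` rounds back down to `2^24`. Casting once at the end keeps all eight increments.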
@@ -1479,9 +1483,14 @@ var Routine = class {
 };
 /** One of the valid `Routine` that can be dispatched to backend. */
 let Routines = /* @__PURE__ */ function(Routines$1) {
-  /**
+  /**
+   * Sort along the last axis.
+   *
+   * This may be _unstable_ but it often doesn't matter, sorting numbers is
+   * bitwise unique up to signed zeros and NaNs.
+   */
   Routines$1["Sort"] = "Sort";
-  /**
+  /** Stable sorting, returns `int32` indices and values of the sorted array. */
   Routines$1["Argsort"] = "Argsort";
   /**
    * Solve a triangular system of equations.
@@ -1545,7 +1554,13 @@ function runArgsort(type, [x], [y, yi]) {
     const out = y.subarray(offset, offset + n);
     const outi = yi.subarray(offset, offset + n);
     for (let i = 0; i < n; i++) outi[i] = i;
-    outi.sort((a, b) =>
+    outi.sort((a, b) => {
+      const x$1 = ar[a];
+      const y$1 = ar[b];
+      if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
+      if (isNaN(y$1)) return -1;
+      return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
+    });
     for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
   }
 }
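The new comparator orders NaN after every number and treats two NaNs as equal, which gives the sort a consistent ordering to work with. A standalone sketch of the same comparison follows. `argsortNaNLast` is a hypothetical name, and it takes a plain array rather than the typed-array views used in `runArgsort`.

```javascript
// Return the indices that would sort `values` ascending, with NaN last.
// Array.prototype.sort is stable (ES2019), so ties keep their input order.
function argsortNaNLast(values) {
  const indices = Array.from(values, (_, i) => i);
  indices.sort((a, b) => {
    const x = values[a];
    const y = values[b];
    if (Number.isNaN(x)) return Number.isNaN(y) ? 0 : 1; // NaN sorts after numbers
    if (Number.isNaN(y)) return -1;
    return x === y ? 0 : x < y ? -1 : 1;
  });
  return indices;
}

console.log(argsortNaNLast([3, NaN, 1, NaN, 2])); // [ 2, 4, 0, 1, 3 ]
```

A subtraction-style comparator such as `(a, b) => values[a] - values[b]` returns NaN whenever either operand is NaN. The sort treats that as "equal", so NaN entries can end up anywhere in the result.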
@@ -2255,11 +2270,15 @@ var TuneDims = class {
 };
 /** Tuning step that does not apply any optimization. */
 function tuneNullopt(kernel) {
+  let exp = kernel.exp;
   const vars = {};
   vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
-  if (kernel.reduction)
+  if (kernel.reduction) {
+    vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
+    if (exp.dtype !== kernel.reduction.dtype) exp = AluExp.cast(kernel.reduction.dtype, exp);
+  }
   return {
-    exp:
+    exp: exp.substitute(vars).rewriteGlobalViews().simplify(),
     epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
     outputIdxExp: vars.gidx,
     threadCount: kernel.size,
@@ -2268,8 +2287,9 @@ function tuneNullopt(kernel) {
 }
 /** Tuning for WebGPU kernels. */
 function tuneWebgpu(kernel) {
-  const
+  const reduction = kernel.reduction;
   if (!reduction) return tuneNullopt(kernel);
+  const exp = AluExp.cast(reduction.dtype, kernel.exp);
   const globalIndexes = exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex);
   if (globalIndexes.length > 0) {
     if (DEBUG >= 4) console.info("Tuning: Found GlobalIndex ops, skipping opt.");
@@ -2321,7 +2341,7 @@ function tuneWebgpu(kernel) {
   if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
     const s = dim.st.shape[dim.unroll - 1];
     if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
-    else for (const splits of [
+    else for (const splits of [4, 2]) if (s % splits === 0) {
       dim.applyUnroll(dim.unroll - 1, splits);
       break;
     }
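The changed line tries the candidate factors `[4, 2]` in order and unrolls by the first one that divides the axis length `s`. The sketch below shows that selection on its own. `pickSplit` is a hypothetical name, and returning `null` when nothing divides `s` is an assumption, since the hunk does not show what follows the loop.

```javascript
// Return the first candidate factor that evenly divides `s`, or null.
function pickSplit(s, candidates = [4, 2]) {
  for (const splits of candidates) {
    if (s % splits === 0) return splits;
  }
  return null; // assumption: no unroll is applied in this case
}

console.log(pickSplit(48)); // 4
console.log(pickSplit(38)); // 2
console.log(pickSplit(35)); // null
```

Listing the larger factor first means an axis divisible by both 4 and 2 gets the wider unroll.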
@@ -2497,6 +2517,85 @@ var CpuBackend = class {
   }
 };
 
+//#endregion
+//#region src/tracing.ts
+let traceEnabled = false;
+const flushCallbacks = [];
+/**
+ * Start collecting kernel traces.
+ *
+ * Traces appear in developer tools under the "Performance" tab, and they are
+ * useful for measuring fine-grained kernel execution time.
+ */
+function startTrace() {
+  traceEnabled = true;
+}
+/**
+ * Stop collecting kernel traces.
+ *
+ * Traces appear in developer tools under the "Performance" tab, and they are
+ * useful for measuring fine-grained kernel execution time.
+ */
+function stopTrace() {
+  traceEnabled = false;
+  for (const cb of flushCallbacks) cb();
+}
+/** Check if tracing is currently enabled. */
+function isTracing() {
+  return traceEnabled;
+}
+/** Register a callback to flush pending trace data when tracing stops. */
+function onFlushTrace(cb) {
+  flushCallbacks.push(cb);
+}
+function humanSize(n) {
+  if (n >= 1e9) return `${(n / 1e9).toPrecision(3)}B`;
+  if (n >= 1e6) return `${(n / 1e6).toPrecision(3)}M`;
+  if (n >= 1e3) return `${(n / 1e3).toPrecision(3)}K`;
+  return `${n}`;
+}
+/** Build a trace label, properties, and color from a kernel or routine source. */
+function traceSourceInfo(source) {
+  const properties = [];
+  let label;
+  let color;
+  if (source instanceof Kernel) {
+    label = `Kernel[${humanSize(source.size)}]`;
+    properties.push(["exp", `${source.exp}`]);
+    properties.push(["size", `${source.size}`]);
+    properties.push(["nargs", `${source.nargs}`]);
+    if (!source.reduction) color = "primary";
+    else {
+      color = "secondary";
+      properties.push(["reduction", `${source.reduction.op}:${source.reduction.size}`]);
+    }
+  } else {
+    color = "tertiary";
+    label = source.name;
+    properties.push(["inputShapes", source.type.inputShapes.map((s) => `[${s}]`).join(", ")]);
+    properties.push(["outputShapes", source.type.outputShapes.map((s) => `[${s}]`).join(", ")]);
+    properties.push(["dtype", source.type.inputDtypes.join(", ")]);
+  }
+  return {
+    label,
+    color,
+    properties
+  };
+}
+/** Emit a trace entry as a `performance.measure` with devtools metadata. */
+function emitTrace(track, info, start, end) {
+  performance.measure(info.label, {
+    detail: { devtools: {
+      trackGroup: "JAX Profiler",
+      track,
+      color: info.color,
+      properties: info.properties
+    } },
+    start,
+    end
+  });
+}
+
 //#endregion
 //#region src/backend/wasm/allocator.ts
 /** Simple tensor memory allocator for WebAssembly linear memory. */
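The tracing region above reports each kernel as a `performance.measure` entry whose `detail.devtools` payload places it on a custom track in the browser's Performance panel. The sketch below reproduces that pattern outside the library, for a synchronous function. `buildMeasureOptions` and `traceCall` are hypothetical names, and the `info` object is made up for illustration.

```javascript
// Build the options object for performance.measure (User Timing Level 3).
// The detail.devtools shape mirrors the one used by emitTrace in the diff.
function buildMeasureOptions(track, info, start, end) {
  return {
    start,
    end,
    detail: {
      devtools: {
        trackGroup: "JAX Profiler",
        track,
        color: info.color,
        properties: info.properties,
      },
    },
  };
}

// Hypothetical wrapper: run fn, and if `enabled`, record its wall-clock span.
function traceCall(enabled, track, info, fn) {
  const start = enabled ? performance.now() : 0;
  const result = fn();
  if (enabled) {
    const end = performance.now();
    performance.measure(info.label, buildMeasureOptions(track, info, start, end));
  }
  return result;
}

// Made-up trace info, shaped like the return value of traceSourceInfo.
const info = { label: "Kernel[1.00K]", color: "primary", properties: [["size", "1000"]] };
const total = traceCall(true, "demo", info, () => {
  let acc = 0;
  for (let i = 0; i < 1000; i++) acc += i;
  return acc;
});
console.log(total); // 499500
```

Reading the enabled flag once up front keeps the untraced path to a single boolean check.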
@@ -3903,11 +4002,19 @@ var WasmBackend = class {
     return new Executable(routine, void 0);
   }
   dispatch(exe, inputs, outputs) {
-
-    const
-
-
-
+    const tracing = isTracing();
+    const start = tracing ? performance.now() : 0;
+    if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
+    else {
+      const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
+      const func = instance.exports.kernel;
+      const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
+      func(...ptrs);
+    }
+    if (tracing) {
+      const info = traceSourceInfo(exe.source);
+      emitTrace("wasm", info, start, performance.now());
+    }
   }
   #getBuffer(slot) {
     const buffer = this.#buffers.get(slot);
@@ -4252,7 +4359,7 @@ async function createBackend(device) {
     if (!navigator.gpu) return null;
     const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
     if (!adapter) return null;
-    const { WebGPUBackend } = await import("./webgpu-
+    const { WebGPUBackend } = await import("./webgpu-v_W_-oKw.js");
     const importantLimits = [
       "maxBufferSize",
       "maxComputeInvocationsPerWorkgroup",
@@ -4290,7 +4397,7 @@ async function createBackend(device) {
     });
     if (!gl) return null;
     if (!gl.getExtension("EXT_color_buffer_float")) return null;
-    const { WebGLBackend } = await import("./webgl-
+    const { WebGLBackend } = await import("./webgl-CvQ1QBX1.js");
     return new WebGLBackend(gl);
   } else throw new Error(`Backend not found: ${device}`);
 }
@@ -4326,4 +4433,4 @@ var UnsupportedRoutineError = class extends Error {
 };
 
 //#endregion
-export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
+export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, emitTrace, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, isTracing, mapSetUnion, normalizeAxis, onFlushTrace, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, strip1, toposort, traceSourceInfo, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
package/dist/{backend-B3foXiV_.cjs → backend-DMauYnfl.cjs}
CHANGED
@@ -1313,6 +1313,10 @@ var Reduction = class {
     this.epilogue = epilogue;
     if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
     this.epilogue = epilogue.simplify();
+    if (this.dtype === DType.Float16 && this.op === AluOp.Add) {
+      this.epilogue = this.epilogue.substitute({ acc: AluExp.cast(this.dtype, AluVar.acc(DType.Float32)) });
+      this.dtype = DType.Float32;
+    }
   }
   hash(state) {
     state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
@@ -1480,9 +1484,14 @@ var Routine = class {
 };
 /** One of the valid `Routine` that can be dispatched to backend. */
 let Routines = /* @__PURE__ */ function(Routines$1) {
-  /**
+  /**
+   * Sort along the last axis.
+   *
+   * This may be _unstable_ but it often doesn't matter, sorting numbers is
+   * bitwise unique up to signed zeros and NaNs.
+   */
   Routines$1["Sort"] = "Sort";
-  /**
+  /** Stable sorting, returns `int32` indices and values of the sorted array. */
   Routines$1["Argsort"] = "Argsort";
   /**
    * Solve a triangular system of equations.
@@ -1546,7 +1555,13 @@ function runArgsort(type, [x], [y, yi]) {
     const out = y.subarray(offset, offset + n);
     const outi = yi.subarray(offset, offset + n);
     for (let i = 0; i < n; i++) outi[i] = i;
-    outi.sort((a, b) =>
+    outi.sort((a, b) => {
+      const x$1 = ar[a];
+      const y$1 = ar[b];
+      if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
+      if (isNaN(y$1)) return -1;
+      return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
+    });
     for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
   }
 }
@@ -2256,11 +2271,15 @@ var TuneDims = class {
 };
 /** Tuning step that does not apply any optimization. */
 function tuneNullopt(kernel) {
+  let exp = kernel.exp;
   const vars = {};
   vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
-  if (kernel.reduction)
+  if (kernel.reduction) {
+    vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
+    if (exp.dtype !== kernel.reduction.dtype) exp = AluExp.cast(kernel.reduction.dtype, exp);
+  }
   return {
-    exp:
+    exp: exp.substitute(vars).rewriteGlobalViews().simplify(),
     epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
     outputIdxExp: vars.gidx,
     threadCount: kernel.size,
@@ -2269,8 +2288,9 @@ function tuneNullopt(kernel) {
 }
 /** Tuning for WebGPU kernels. */
 function tuneWebgpu(kernel) {
-  const
+  const reduction = kernel.reduction;
   if (!reduction) return tuneNullopt(kernel);
+  const exp = AluExp.cast(reduction.dtype, kernel.exp);
   const globalIndexes = exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex);
   if (globalIndexes.length > 0) {
     if (DEBUG >= 4) console.info("Tuning: Found GlobalIndex ops, skipping opt.");
@@ -2322,7 +2342,7 @@ function tuneWebgpu(kernel) {
   if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
     const s = dim.st.shape[dim.unroll - 1];
     if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
-    else for (const splits of [
+    else for (const splits of [4, 2]) if (s % splits === 0) {
       dim.applyUnroll(dim.unroll - 1, splits);
       break;
     }
@@ -2498,6 +2518,85 @@ var CpuBackend = class {
   }
 };
 
+//#endregion
+//#region src/tracing.ts
+let traceEnabled = false;
+const flushCallbacks = [];
+/**
+ * Start collecting kernel traces.
+ *
+ * Traces appear in developer tools under the "Performance" tab, and they are
+ * useful for measuring fine-grained kernel execution time.
+ */
+function startTrace() {
+  traceEnabled = true;
+}
+/**
+ * Stop collecting kernel traces.
+ *
+ * Traces appear in developer tools under the "Performance" tab, and they are
+ * useful for measuring fine-grained kernel execution time.
+ */
+function stopTrace() {
+  traceEnabled = false;
+  for (const cb of flushCallbacks) cb();
+}
+/** Check if tracing is currently enabled. */
+function isTracing() {
+  return traceEnabled;
+}
+/** Register a callback to flush pending trace data when tracing stops. */
+function onFlushTrace(cb) {
+  flushCallbacks.push(cb);
+}
+function humanSize(n) {
+  if (n >= 1e9) return `${(n / 1e9).toPrecision(3)}B`;
+  if (n >= 1e6) return `${(n / 1e6).toPrecision(3)}M`;
+  if (n >= 1e3) return `${(n / 1e3).toPrecision(3)}K`;
+  return `${n}`;
+}
+/** Build a trace label, properties, and color from a kernel or routine source. */
+function traceSourceInfo(source) {
+  const properties = [];
+  let label;
+  let color;
+  if (source instanceof Kernel) {
+    label = `Kernel[${humanSize(source.size)}]`;
+    properties.push(["exp", `${source.exp}`]);
+    properties.push(["size", `${source.size}`]);
+    properties.push(["nargs", `${source.nargs}`]);
+    if (!source.reduction) color = "primary";
+    else {
+      color = "secondary";
+      properties.push(["reduction", `${source.reduction.op}:${source.reduction.size}`]);
+    }
+  } else {
+    color = "tertiary";
+    label = source.name;
+    properties.push(["inputShapes", source.type.inputShapes.map((s) => `[${s}]`).join(", ")]);
+    properties.push(["outputShapes", source.type.outputShapes.map((s) => `[${s}]`).join(", ")]);
+    properties.push(["dtype", source.type.inputDtypes.join(", ")]);
+  }
+  return {
+    label,
+    color,
+    properties
+  };
+}
+/** Emit a trace entry as a `performance.measure` with devtools metadata. */
+function emitTrace(track, info, start, end) {
+  performance.measure(info.label, {
+    detail: { devtools: {
+      trackGroup: "JAX Profiler",
+      track,
+      color: info.color,
+      properties: info.properties
+    } },
+    start,
+    end
+  });
+}
+
 //#endregion
 //#region src/backend/wasm/allocator.ts
 /** Simple tensor memory allocator for WebAssembly linear memory. */
@@ -3904,11 +4003,19 @@ var WasmBackend = class {
     return new Executable(routine, void 0);
   }
   dispatch(exe, inputs, outputs) {
-
-    const
-
-
-
+    const tracing = isTracing();
+    const start = tracing ? performance.now() : 0;
+    if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
+    else {
+      const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
+      const func = instance.exports.kernel;
+      const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
+      func(...ptrs);
+    }
+    if (tracing) {
+      const info = traceSourceInfo(exe.source);
+      emitTrace("wasm", info, start, performance.now());
+    }
   }
   #getBuffer(slot) {
     const buffer = this.#buffers.get(slot);
@@ -4253,7 +4360,7 @@ async function createBackend(device) {
     if (!navigator.gpu) return null;
     const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
     if (!adapter) return null;
-    const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
+    const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-DMSx7a6M.cjs"));
     const importantLimits = [
       "maxBufferSize",
       "maxComputeInvocationsPerWorkgroup",
@@ -4291,7 +4398,7 @@ async function createBackend(device) {
     });
     if (!gl) return null;
     if (!gl.getExtension("EXT_color_buffer_float")) return null;
-    const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-
+    const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-kvVt7-T7.cjs"));
     return new WebGLBackend(gl);
   } else throw new Error(`Backend not found: ${device}`);
 }
@@ -4495,6 +4602,12 @@ Object.defineProperty(exports, 'dtypedJsArray', {
     return dtypedJsArray;
   }
 });
+Object.defineProperty(exports, 'emitTrace', {
+  enumerable: true,
+  get: function () {
+    return emitTrace;
+  }
+});
 Object.defineProperty(exports, 'findPow2', {
   enumerable: true,
   get: function () {
@@ -4543,6 +4656,12 @@ Object.defineProperty(exports, 'isPermutation', {
     return isPermutation;
   }
 });
+Object.defineProperty(exports, 'isTracing', {
+  enumerable: true,
+  get: function () {
+    return isTracing;
+  }
+});
 Object.defineProperty(exports, 'mapSetUnion', {
   enumerable: true,
   get: function () {
@@ -4555,6 +4674,12 @@ Object.defineProperty(exports, 'normalizeAxis', {
     return normalizeAxis;
   }
 });
+Object.defineProperty(exports, 'onFlushTrace', {
+  enumerable: true,
+  get: function () {
+    return onFlushTrace;
+  }
+});
 Object.defineProperty(exports, 'partitionList', {
   enumerable: true,
   get: function () {
@@ -4603,6 +4728,18 @@ Object.defineProperty(exports, 'setDebug', {
     return setDebug;
   }
 });
+Object.defineProperty(exports, 'startTrace', {
+  enumerable: true,
+  get: function () {
+    return startTrace;
+  }
+});
+Object.defineProperty(exports, 'stopTrace', {
+  enumerable: true,
+  get: function () {
+    return stopTrace;
+  }
+});
 Object.defineProperty(exports, 'strip1', {
   enumerable: true,
   get: function () {
@@ -4615,6 +4752,12 @@ Object.defineProperty(exports, 'toposort', {
     return toposort;
   }
 });
+Object.defineProperty(exports, 'traceSourceInfo', {
+  enumerable: true,
+  get: function () {
+    return traceSourceInfo;
+  }
+});
 Object.defineProperty(exports, 'tuneNullopt', {
   enumerable: true,
   get: function () {