@jax-js/jax 0.1.9 → 0.1.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +31 -18
- package/dist/{backend-BId79r5b.js → backend-Ctqs8la1.js} +107 -11
- package/dist/{backend-DpI0riom.cjs → backend-DMauYnfl.cjs} +142 -10
- package/dist/index.cjs +225 -18
- package/dist/index.d.cts +112 -11
- package/dist/index.d.ts +112 -11
- package/dist/index.js +225 -19
- package/dist/{webgl-DnGrclTz.js → webgl-CvQ1QBX1.js} +1 -1
- package/dist/{webgl-C5NjXc1p.cjs → webgl-kvVt7-T7.cjs} +1 -1
- package/dist/{webgpu-CdjiJSa7.cjs → webgpu-DMSx7a6M.cjs} +136 -6
- package/dist/{webgpu-AN0cG_nB.js → webgpu-v_W_-oKw.js} +136 -6
- package/package.json +5 -16
package/README.md
CHANGED
|
@@ -43,13 +43,31 @@ way to get started on a blank HTML page.
|
|
|
43
43
|
</script>
|
|
44
44
|
```
|
|
45
45
|
|
|
46
|
+
### Platforms
|
|
47
|
+
|
|
48
|
+
This table refers to latest versions of each browser. WebGPU has gained wide support in browsers as
|
|
49
|
+
of late 2025.
|
|
50
|
+
|
|
51
|
+
| Platform | CPU (Wasm) | GPU (WebGPU) | GPU (WebGL) |
|
|
52
|
+
| ------------------- | ---------- | -------------- | ----------- |
|
|
53
|
+
| Chrome / Edge | ✅ | ✅ | ✅ |
|
|
54
|
+
| Firefox | ✅ | ✅ - macOS 26+ | ✅ |
|
|
55
|
+
| Safari | ✅ | ✅ - macOS 26+ | ✅ |
|
|
56
|
+
| iOS | ✅ | ✅ - iOS 26+ | ✅ |
|
|
57
|
+
| Chrome for Android | ✅ | ✅ | ✅ |
|
|
58
|
+
| Firefox for Android | ✅ | ❌ | ✅ |
|
|
59
|
+
| Node.js | ✅ | ❌ | ❌ |
|
|
60
|
+
| Deno | ✅ | ✅ - async | ❌ |
|
|
61
|
+
|
|
46
62
|
## Examples
|
|
47
63
|
|
|
48
|
-
|
|
64
|
+
Community usage:
|
|
49
65
|
|
|
66
|
+
- [**autoresearch-webgpu**: autoresearch, in the browser](https://autoresearch.lucasgelfond.online/)
|
|
50
67
|
- [**tanh.xyz**: Interactive ML visualizations](https://tanh.xyz/)
|
|
68
|
+
- [**jax-js-bayes**: Declarative Bayesian modeling library](https://github.com/StefanSko/jax-js-bayes)
|
|
51
69
|
|
|
52
|
-
|
|
70
|
+
Demos on the jax-js website:
|
|
53
71
|
|
|
54
72
|
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
55
73
|
- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
|
|
@@ -337,6 +355,8 @@ self-contained way in other projects.
|
|
|
337
355
|
|
|
338
356
|
### Performance
|
|
339
357
|
|
|
358
|
+
To see per-kernel traces in browser development tools, call `jax.profiler.startTrace()`.
|
|
359
|
+
|
|
340
360
|
The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
|
|
341
361
|
browsers. Also, this library uniquely has the `jit()` feature that fuses operations together and
|
|
342
362
|
records an execution graph. jax-js achieves **over 7000 GFLOP/s** for matrix multiplication on an
|
|
@@ -367,19 +387,6 @@ pnpm install
|
|
|
367
387
|
pnpm run build:watch
|
|
368
388
|
```
|
|
369
389
|
|
|
370
|
-
The `pnpm install` command automatically sets up Git hooks via
|
|
371
|
-
[Husky](https://typicode.github.io/husky/). Pre-commit hooks will run ESLint and Prettier on staged
|
|
372
|
-
files to ensure code quality.
|
|
373
|
-
|
|
374
|
-
You can also run linting and formatting manually:
|
|
375
|
-
|
|
376
|
-
```bash
|
|
377
|
-
pnpm lint # Run ESLint
|
|
378
|
-
pnpm format # Format all files with Prettier
|
|
379
|
-
pnpm format:check # Check formatting without writing
|
|
380
|
-
pnpm check # Run TypeScript type checking
|
|
381
|
-
```
|
|
382
|
-
|
|
383
390
|
Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
|
|
384
391
|
|
|
385
392
|
```bash
|
|
@@ -396,6 +403,15 @@ To start a Vite dev server running the website, demos and REPL:
|
|
|
396
403
|
pnpm -C website dev
|
|
397
404
|
```
|
|
398
405
|
|
|
406
|
+
You can run the linter, code formatter, and type checker with:
|
|
407
|
+
|
|
408
|
+
```bash
|
|
409
|
+
pnpm lint # Run ESLint
|
|
410
|
+
pnpm format # Format all files with Prettier
|
|
411
|
+
pnpm format:check # Check formatting without writing
|
|
412
|
+
pnpm check # Run TypeScript type checking
|
|
413
|
+
```
|
|
414
|
+
|
|
399
415
|
## Future work / help wanted
|
|
400
416
|
|
|
401
417
|
Contributions are welcomed! Some fruitful areas to look into:
|
|
@@ -403,9 +419,6 @@ Contributions are welcomed! Some fruitful areas to look into:
|
|
|
403
419
|
- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
|
|
404
420
|
- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
|
|
405
421
|
and multithreading. (Even single-threaded Wasm could be ~20x faster.)
|
|
406
|
-
- Adding support for `jax.profiling`, in particular the start and end trace functions. We should be
|
|
407
|
-
able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
|
|
408
|
-
to help with model performance debugging.
|
|
409
422
|
- Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
|
|
410
423
|
- Making a fast transformer inference engine, comparing against onnxruntime-web.
|
|
411
424
|
|
|
@@ -1312,6 +1312,10 @@ var Reduction = class {
|
|
|
1312
1312
|
this.epilogue = epilogue;
|
|
1313
1313
|
if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
|
|
1314
1314
|
this.epilogue = epilogue.simplify();
|
|
1315
|
+
if (this.dtype === DType.Float16 && this.op === AluOp.Add) {
|
|
1316
|
+
this.epilogue = this.epilogue.substitute({ acc: AluExp.cast(this.dtype, AluVar.acc(DType.Float32)) });
|
|
1317
|
+
this.dtype = DType.Float32;
|
|
1318
|
+
}
|
|
1315
1319
|
}
|
|
1316
1320
|
hash(state) {
|
|
1317
1321
|
state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
|
|
@@ -2266,11 +2270,15 @@ var TuneDims = class {
|
|
|
2266
2270
|
};
|
|
2267
2271
|
/** Tuning step that does not apply any optimization. */
|
|
2268
2272
|
function tuneNullopt(kernel) {
|
|
2273
|
+
let exp = kernel.exp;
|
|
2269
2274
|
const vars = {};
|
|
2270
2275
|
vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
|
|
2271
|
-
if (kernel.reduction)
|
|
2276
|
+
if (kernel.reduction) {
|
|
2277
|
+
vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
2278
|
+
if (exp.dtype !== kernel.reduction.dtype) exp = AluExp.cast(kernel.reduction.dtype, exp);
|
|
2279
|
+
}
|
|
2272
2280
|
return {
|
|
2273
|
-
exp:
|
|
2281
|
+
exp: exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
2274
2282
|
epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
|
|
2275
2283
|
outputIdxExp: vars.gidx,
|
|
2276
2284
|
threadCount: kernel.size,
|
|
@@ -2279,8 +2287,9 @@ function tuneNullopt(kernel) {
|
|
|
2279
2287
|
}
|
|
2280
2288
|
/** Tuning for WebGPU kernels. */
|
|
2281
2289
|
function tuneWebgpu(kernel) {
|
|
2282
|
-
const
|
|
2290
|
+
const reduction = kernel.reduction;
|
|
2283
2291
|
if (!reduction) return tuneNullopt(kernel);
|
|
2292
|
+
const exp = AluExp.cast(reduction.dtype, kernel.exp);
|
|
2284
2293
|
const globalIndexes = exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex);
|
|
2285
2294
|
if (globalIndexes.length > 0) {
|
|
2286
2295
|
if (DEBUG >= 4) console.info("Tuning: Found GlobalIndex ops, skipping opt.");
|
|
@@ -2508,6 +2517,85 @@ var CpuBackend = class {
|
|
|
2508
2517
|
}
|
|
2509
2518
|
};
|
|
2510
2519
|
|
|
2520
|
+
//#endregion
|
|
2521
|
+
//#region src/tracing.ts
|
|
2522
|
+
// Module-level trace state: the on/off switch, plus callbacks that flush
// pending trace data whenever tracing is stopped.
let traceEnabled = false;
const flushCallbacks = [];
/**
 * Turn kernel tracing on.
 *
 * Recorded traces appear in the developer tools "Performance" tab, which makes
 * them handy for timing individual kernels.
 */
function startTrace() {
  traceEnabled = true;
}
/**
 * Turn kernel tracing off, then run every registered flush callback in the
 * order it was registered.
 *
 * Recorded traces appear in the developer tools "Performance" tab.
 */
function stopTrace() {
  // Disable first so callbacks observe `isTracing() === false`.
  traceEnabled = false;
  for (const flush of flushCallbacks) {
    flush();
  }
}
/**
 * Report whether kernel tracing is currently on.
 * @returns {boolean}
 */
function isTracing() {
  return traceEnabled;
}
/**
 * Register `cb` to run on every `stopTrace()` call. Registrations are never
 * removed; registering the same callback twice runs it twice.
 * @param {() => void} cb
 */
function onFlushTrace(cb) {
  flushCallbacks.push(cb);
}
|
|
2551
|
+
/**
 * Format a count compactly for trace labels: 999 -> "999", 1500 -> "1.50K",
 * 2_500_000 -> "2.50M", 3e9 -> "3.00B".
 *
 * `toPrecision(3)` switches to exponent notation when rounding carries a value
 * up to 1000 (e.g. 999.999 -> "1.00e+3"). Such values are promoted to the next
 * unit instead (999_999 -> "1.00M"). Past the largest unit, the whole-number
 * count of that unit is printed (2.5e12 -> "2500B").
 *
 * @param {number} n - Count to format (e.g. a kernel's `size`).
 * @returns {string}
 */
function humanSize(n) {
  // Ordered smallest to largest so a rounding carry can step up one unit.
  const units = [
    [1e3, "K"],
    [1e6, "M"],
    [1e9, "B"],
  ];
  for (let i = units.length - 1; i >= 0; i--) {
    const [scale, suffix] = units[i];
    if (n < scale) continue;
    const text = (n / scale).toPrecision(3);
    if (Number(text) < 1000) return `${text}${suffix}`;
    const next = units[i + 1];
    if (next) return `${(n / next[0]).toPrecision(3)}${next[1]}`;
    return `${Math.round(n / scale)}${suffix}`;
  }
  return `${n}`;
}
|
|
2557
|
+
/**
 * Describe a kernel or routine for the profiler. The result has a short label,
 * a devtools color token, and a list of `[key, value]` string properties.
 *
 * - `Kernel`: labelled `Kernel[<humanSize(size)>]` and lists exp/size/nargs.
 *   The color is "primary", or "secondary" when the kernel has a reduction;
 *   the reduction is then also listed as `op:size`.
 * - Anything else is treated as a routine. It is labelled by its `name`, uses
 *   color "tertiary", and lists its input/output shapes and input dtypes.
 *
 * @returns {{ label: string, color: string, properties: [string, string][] }}
 */
function traceSourceInfo(source) {
  if (!(source instanceof Kernel)) {
    const formatShapes = (shapes) => shapes.map((s) => `[${s}]`).join(", ");
    return {
      label: source.name,
      color: "tertiary",
      properties: [
        ["inputShapes", formatShapes(source.type.inputShapes)],
        ["outputShapes", formatShapes(source.type.outputShapes)],
        ["dtype", source.type.inputDtypes.join(", ")],
      ],
    };
  }
  const reduction = source.reduction;
  const properties = [
    ["exp", `${source.exp}`],
    ["size", `${source.size}`],
    ["nargs", `${source.nargs}`],
  ];
  if (reduction) properties.push(["reduction", `${reduction.op}:${reduction.size}`]);
  return {
    label: `Kernel[${humanSize(source.size)}]`,
    color: reduction ? "secondary" : "primary",
    properties,
  };
}
|
|
2585
|
+
/** Emit a trace entry as a `performance.measure` with devtools metadata. */
|
|
2586
|
+
function emitTrace(track, info, start, end) {
|
|
2587
|
+
performance.measure(info.label, {
|
|
2588
|
+
detail: { devtools: {
|
|
2589
|
+
trackGroup: "JAX Profiler",
|
|
2590
|
+
track,
|
|
2591
|
+
color: info.color,
|
|
2592
|
+
properties: info.properties
|
|
2593
|
+
} },
|
|
2594
|
+
start,
|
|
2595
|
+
end
|
|
2596
|
+
});
|
|
2597
|
+
}
|
|
2598
|
+
|
|
2511
2599
|
//#endregion
|
|
2512
2600
|
//#region src/backend/wasm/allocator.ts
|
|
2513
2601
|
/** Simple tensor memory allocator for WebAssembly linear memory. */
|
|
@@ -3914,11 +4002,19 @@ var WasmBackend = class {
|
|
|
3914
4002
|
return new Executable(routine, void 0);
|
|
3915
4003
|
}
|
|
3916
4004
|
dispatch(exe, inputs, outputs) {
|
|
3917
|
-
|
|
3918
|
-
const
|
|
3919
|
-
|
|
3920
|
-
|
|
3921
|
-
|
|
4005
|
+
const tracing = isTracing();
|
|
4006
|
+
const start = tracing ? performance.now() : 0;
|
|
4007
|
+
if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
4008
|
+
else {
|
|
4009
|
+
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
4010
|
+
const func = instance.exports.kernel;
|
|
4011
|
+
const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
|
|
4012
|
+
func(...ptrs);
|
|
4013
|
+
}
|
|
4014
|
+
if (tracing) {
|
|
4015
|
+
const info = traceSourceInfo(exe.source);
|
|
4016
|
+
emitTrace("wasm", info, start, performance.now());
|
|
4017
|
+
}
|
|
3922
4018
|
}
|
|
3923
4019
|
#getBuffer(slot) {
|
|
3924
4020
|
const buffer = this.#buffers.get(slot);
|
|
@@ -4263,7 +4359,7 @@ async function createBackend(device) {
|
|
|
4263
4359
|
if (!navigator.gpu) return null;
|
|
4264
4360
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4265
4361
|
if (!adapter) return null;
|
|
4266
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4362
|
+
const { WebGPUBackend } = await import("./webgpu-v_W_-oKw.js");
|
|
4267
4363
|
const importantLimits = [
|
|
4268
4364
|
"maxBufferSize",
|
|
4269
4365
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4301,7 +4397,7 @@ async function createBackend(device) {
|
|
|
4301
4397
|
});
|
|
4302
4398
|
if (!gl) return null;
|
|
4303
4399
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4304
|
-
const { WebGLBackend } = await import("./webgl-
|
|
4400
|
+
const { WebGLBackend } = await import("./webgl-CvQ1QBX1.js");
|
|
4305
4401
|
return new WebGLBackend(gl);
|
|
4306
4402
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4307
4403
|
}
|
|
@@ -4337,4 +4433,4 @@ var UnsupportedRoutineError = class extends Error {
|
|
|
4337
4433
|
};
|
|
4338
4434
|
|
|
4339
4435
|
//#endregion
|
|
4340
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4436
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, emitTrace, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, isTracing, mapSetUnion, normalizeAxis, onFlushTrace, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, strip1, toposort, traceSourceInfo, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
@@ -1313,6 +1313,10 @@ var Reduction = class {
|
|
|
1313
1313
|
this.epilogue = epilogue;
|
|
1314
1314
|
if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
|
|
1315
1315
|
this.epilogue = epilogue.simplify();
|
|
1316
|
+
if (this.dtype === DType.Float16 && this.op === AluOp.Add) {
|
|
1317
|
+
this.epilogue = this.epilogue.substitute({ acc: AluExp.cast(this.dtype, AluVar.acc(DType.Float32)) });
|
|
1318
|
+
this.dtype = DType.Float32;
|
|
1319
|
+
}
|
|
1316
1320
|
}
|
|
1317
1321
|
hash(state) {
|
|
1318
1322
|
state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
|
|
@@ -2267,11 +2271,15 @@ var TuneDims = class {
|
|
|
2267
2271
|
};
|
|
2268
2272
|
/** Tuning step that does not apply any optimization. */
|
|
2269
2273
|
function tuneNullopt(kernel) {
|
|
2274
|
+
let exp = kernel.exp;
|
|
2270
2275
|
const vars = {};
|
|
2271
2276
|
vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
|
|
2272
|
-
if (kernel.reduction)
|
|
2277
|
+
if (kernel.reduction) {
|
|
2278
|
+
vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
|
|
2279
|
+
if (exp.dtype !== kernel.reduction.dtype) exp = AluExp.cast(kernel.reduction.dtype, exp);
|
|
2280
|
+
}
|
|
2273
2281
|
return {
|
|
2274
|
-
exp:
|
|
2282
|
+
exp: exp.substitute(vars).rewriteGlobalViews().simplify(),
|
|
2275
2283
|
epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
|
|
2276
2284
|
outputIdxExp: vars.gidx,
|
|
2277
2285
|
threadCount: kernel.size,
|
|
@@ -2280,8 +2288,9 @@ function tuneNullopt(kernel) {
|
|
|
2280
2288
|
}
|
|
2281
2289
|
/** Tuning for WebGPU kernels. */
|
|
2282
2290
|
function tuneWebgpu(kernel) {
|
|
2283
|
-
const
|
|
2291
|
+
const reduction = kernel.reduction;
|
|
2284
2292
|
if (!reduction) return tuneNullopt(kernel);
|
|
2293
|
+
const exp = AluExp.cast(reduction.dtype, kernel.exp);
|
|
2285
2294
|
const globalIndexes = exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex);
|
|
2286
2295
|
if (globalIndexes.length > 0) {
|
|
2287
2296
|
if (DEBUG >= 4) console.info("Tuning: Found GlobalIndex ops, skipping opt.");
|
|
@@ -2509,6 +2518,85 @@ var CpuBackend = class {
|
|
|
2509
2518
|
}
|
|
2510
2519
|
};
|
|
2511
2520
|
|
|
2521
|
+
//#endregion
|
|
2522
|
+
//#region src/tracing.ts
|
|
2523
|
+
// Module-level trace state: the on/off switch, plus callbacks that flush
// pending trace data whenever tracing is stopped.
let traceEnabled = false;
const flushCallbacks = [];
/**
 * Turn kernel tracing on.
 *
 * Recorded traces appear in the developer tools "Performance" tab, which makes
 * them handy for timing individual kernels.
 */
function startTrace() {
  traceEnabled = true;
}
/**
 * Turn kernel tracing off, then run every registered flush callback in the
 * order it was registered.
 *
 * Recorded traces appear in the developer tools "Performance" tab.
 */
function stopTrace() {
  // Disable first so callbacks observe `isTracing() === false`.
  traceEnabled = false;
  for (const flush of flushCallbacks) {
    flush();
  }
}
/**
 * Report whether kernel tracing is currently on.
 * @returns {boolean}
 */
function isTracing() {
  return traceEnabled;
}
/**
 * Register `cb` to run on every `stopTrace()` call. Registrations are never
 * removed; registering the same callback twice runs it twice.
 * @param {() => void} cb
 */
function onFlushTrace(cb) {
  flushCallbacks.push(cb);
}
|
|
2552
|
+
/**
 * Format a count compactly for trace labels: 999 -> "999", 1500 -> "1.50K",
 * 2_500_000 -> "2.50M", 3e9 -> "3.00B".
 *
 * `toPrecision(3)` switches to exponent notation when rounding carries a value
 * up to 1000 (e.g. 999.999 -> "1.00e+3"). Such values are promoted to the next
 * unit instead (999_999 -> "1.00M"). Past the largest unit, the whole-number
 * count of that unit is printed (2.5e12 -> "2500B").
 *
 * @param {number} n - Count to format (e.g. a kernel's `size`).
 * @returns {string}
 */
function humanSize(n) {
  // Ordered smallest to largest so a rounding carry can step up one unit.
  const units = [
    [1e3, "K"],
    [1e6, "M"],
    [1e9, "B"],
  ];
  for (let i = units.length - 1; i >= 0; i--) {
    const [scale, suffix] = units[i];
    if (n < scale) continue;
    const text = (n / scale).toPrecision(3);
    if (Number(text) < 1000) return `${text}${suffix}`;
    const next = units[i + 1];
    if (next) return `${(n / next[0]).toPrecision(3)}${next[1]}`;
    return `${Math.round(n / scale)}${suffix}`;
  }
  return `${n}`;
}
|
|
2558
|
+
/**
 * Describe a kernel or routine for the profiler. The result has a short label,
 * a devtools color token, and a list of `[key, value]` string properties.
 *
 * - `Kernel`: labelled `Kernel[<humanSize(size)>]` and lists exp/size/nargs.
 *   The color is "primary", or "secondary" when the kernel has a reduction;
 *   the reduction is then also listed as `op:size`.
 * - Anything else is treated as a routine. It is labelled by its `name`, uses
 *   color "tertiary", and lists its input/output shapes and input dtypes.
 *
 * @returns {{ label: string, color: string, properties: [string, string][] }}
 */
function traceSourceInfo(source) {
  if (!(source instanceof Kernel)) {
    const formatShapes = (shapes) => shapes.map((s) => `[${s}]`).join(", ");
    return {
      label: source.name,
      color: "tertiary",
      properties: [
        ["inputShapes", formatShapes(source.type.inputShapes)],
        ["outputShapes", formatShapes(source.type.outputShapes)],
        ["dtype", source.type.inputDtypes.join(", ")],
      ],
    };
  }
  const reduction = source.reduction;
  const properties = [
    ["exp", `${source.exp}`],
    ["size", `${source.size}`],
    ["nargs", `${source.nargs}`],
  ];
  if (reduction) properties.push(["reduction", `${reduction.op}:${reduction.size}`]);
  return {
    label: `Kernel[${humanSize(source.size)}]`,
    color: reduction ? "secondary" : "primary",
    properties,
  };
}
|
|
2586
|
+
/** Emit a trace entry as a `performance.measure` with devtools metadata. */
|
|
2587
|
+
function emitTrace(track, info, start, end) {
|
|
2588
|
+
performance.measure(info.label, {
|
|
2589
|
+
detail: { devtools: {
|
|
2590
|
+
trackGroup: "JAX Profiler",
|
|
2591
|
+
track,
|
|
2592
|
+
color: info.color,
|
|
2593
|
+
properties: info.properties
|
|
2594
|
+
} },
|
|
2595
|
+
start,
|
|
2596
|
+
end
|
|
2597
|
+
});
|
|
2598
|
+
}
|
|
2599
|
+
|
|
2512
2600
|
//#endregion
|
|
2513
2601
|
//#region src/backend/wasm/allocator.ts
|
|
2514
2602
|
/** Simple tensor memory allocator for WebAssembly linear memory. */
|
|
@@ -3915,11 +4003,19 @@ var WasmBackend = class {
|
|
|
3915
4003
|
return new Executable(routine, void 0);
|
|
3916
4004
|
}
|
|
3917
4005
|
dispatch(exe, inputs, outputs) {
|
|
3918
|
-
|
|
3919
|
-
const
|
|
3920
|
-
|
|
3921
|
-
|
|
3922
|
-
|
|
4006
|
+
const tracing = isTracing();
|
|
4007
|
+
const start = tracing ? performance.now() : 0;
|
|
4008
|
+
if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
4009
|
+
else {
|
|
4010
|
+
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
4011
|
+
const func = instance.exports.kernel;
|
|
4012
|
+
const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
|
|
4013
|
+
func(...ptrs);
|
|
4014
|
+
}
|
|
4015
|
+
if (tracing) {
|
|
4016
|
+
const info = traceSourceInfo(exe.source);
|
|
4017
|
+
emitTrace("wasm", info, start, performance.now());
|
|
4018
|
+
}
|
|
3923
4019
|
}
|
|
3924
4020
|
#getBuffer(slot) {
|
|
3925
4021
|
const buffer = this.#buffers.get(slot);
|
|
@@ -4264,7 +4360,7 @@ async function createBackend(device) {
|
|
|
4264
4360
|
if (!navigator.gpu) return null;
|
|
4265
4361
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4266
4362
|
if (!adapter) return null;
|
|
4267
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4363
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-DMSx7a6M.cjs"));
|
|
4268
4364
|
const importantLimits = [
|
|
4269
4365
|
"maxBufferSize",
|
|
4270
4366
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4302,7 +4398,7 @@ async function createBackend(device) {
|
|
|
4302
4398
|
});
|
|
4303
4399
|
if (!gl) return null;
|
|
4304
4400
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4305
|
-
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-
|
|
4401
|
+
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-kvVt7-T7.cjs"));
|
|
4306
4402
|
return new WebGLBackend(gl);
|
|
4307
4403
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4308
4404
|
}
|
|
@@ -4506,6 +4602,12 @@ Object.defineProperty(exports, 'dtypedJsArray', {
|
|
|
4506
4602
|
return dtypedJsArray;
|
|
4507
4603
|
}
|
|
4508
4604
|
});
|
|
4605
|
+
Object.defineProperty(exports, 'emitTrace', {
  enumerable: true,
  get: () => emitTrace,
});
|
|
4509
4611
|
Object.defineProperty(exports, 'findPow2', {
|
|
4510
4612
|
enumerable: true,
|
|
4511
4613
|
get: function () {
|
|
@@ -4554,6 +4656,12 @@ Object.defineProperty(exports, 'isPermutation', {
|
|
|
4554
4656
|
return isPermutation;
|
|
4555
4657
|
}
|
|
4556
4658
|
});
|
|
4659
|
+
Object.defineProperty(exports, 'isTracing', {
  enumerable: true,
  get: () => isTracing,
});
|
|
4557
4665
|
Object.defineProperty(exports, 'mapSetUnion', {
|
|
4558
4666
|
enumerable: true,
|
|
4559
4667
|
get: function () {
|
|
@@ -4566,6 +4674,12 @@ Object.defineProperty(exports, 'normalizeAxis', {
|
|
|
4566
4674
|
return normalizeAxis;
|
|
4567
4675
|
}
|
|
4568
4676
|
});
|
|
4677
|
+
Object.defineProperty(exports, 'onFlushTrace', {
  enumerable: true,
  get: () => onFlushTrace,
});
|
|
4569
4683
|
Object.defineProperty(exports, 'partitionList', {
|
|
4570
4684
|
enumerable: true,
|
|
4571
4685
|
get: function () {
|
|
@@ -4614,6 +4728,18 @@ Object.defineProperty(exports, 'setDebug', {
|
|
|
4614
4728
|
return setDebug;
|
|
4615
4729
|
}
|
|
4616
4730
|
});
|
|
4731
|
+
Object.defineProperty(exports, 'startTrace', {
  enumerable: true,
  get: () => startTrace,
});
Object.defineProperty(exports, 'stopTrace', {
  enumerable: true,
  get: () => stopTrace,
});
|
|
4617
4743
|
Object.defineProperty(exports, 'strip1', {
|
|
4618
4744
|
enumerable: true,
|
|
4619
4745
|
get: function () {
|
|
@@ -4626,6 +4752,12 @@ Object.defineProperty(exports, 'toposort', {
|
|
|
4626
4752
|
return toposort;
|
|
4627
4753
|
}
|
|
4628
4754
|
});
|
|
4755
|
+
Object.defineProperty(exports, 'traceSourceInfo', {
  enumerable: true,
  get: () => traceSourceInfo,
});
|
|
4629
4761
|
Object.defineProperty(exports, 'tuneNullopt', {
|
|
4630
4762
|
enumerable: true,
|
|
4631
4763
|
get: function () {
|