@jax-js/jax 0.1.8 → 0.1.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -43,6 +43,41 @@ way to get started on a blank HTML page.
43
43
  </script>
44
44
  ```
45
45
 
46
+ ### Platforms
47
+
48
+ This table refers to the latest version of each browser. WebGPU has gained wide support in browsers as
49
+ of late 2025.
50
+
51
+ | Platform | CPU (Wasm) | GPU (WebGPU) | GPU (WebGL) |
52
+ | ------------------- | ---------- | -------------- | ----------- |
53
+ | Chrome / Edge | ✅ | ✅ | ✅ |
54
+ | Firefox | ✅ | ✅ - macOS 26+ | ✅ |
55
+ | Safari | ✅ | ✅ - macOS 26+ | ✅ |
56
+ | iOS | ✅ | ✅ - iOS 26+ | ✅ |
57
+ | Chrome for Android | ✅ | ✅ | ✅ |
58
+ | Firefox for Android | ✅ | ❌ | ✅ |
59
+ | Node.js | ✅ | ❌ | ❌ |
60
+ | Deno | ✅ | ✅ - async | ❌ |
61
+
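The table can also be checked at runtime. The sketch below mirrors the probes visible in `createBackend` later in this diff: `navigator.gpu.requestAdapter` with a high-performance preference, and a WebGL context that must expose `EXT_color_buffer_float`. The function name `detectBackends`, its `env` parameter, and the use of a `webgl2` context are assumptions made for illustration. They are not part of the jax-js API.

```javascript
// Hypothetical helper: reports which jax-js backend families this environment could use.
// `env` defaults to the global scope; pass a stub object to probe something else.
async function detectBackends(env = globalThis) {
  // CPU (Wasm): only needs the WebAssembly global.
  const wasm = typeof env.WebAssembly === "object";

  // GPU (WebGPU): navigator.gpu must exist and return an adapter.
  let webgpu = false;
  const gpu = env.navigator?.gpu;
  if (gpu) {
    const adapter = await gpu.requestAdapter({ powerPreference: "high-performance" });
    webgpu = adapter != null;
  }

  // GPU (WebGL): assumes a WebGL2 canvas context that supports float render targets.
  let webgl = false;
  if (typeof env.document?.createElement === "function") {
    const gl = env.document.createElement("canvas").getContext("webgl2");
    webgl = gl != null && gl.getExtension("EXT_color_buffer_float") != null;
  }

  return { wasm, webgpu, webgl };
}

detectBackends().then((support) => console.log(support));
```

In Node.js this should report `wasm: true` with both GPU entries `false`, which matches the table.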
62
+ ## Examples
63
+
64
+ Community usage:
65
+
66
+ - [**autoresearch-webgpu**: autoresearch, in the browser](https://autoresearch.lucasgelfond.online/)
67
+ - [**tanh.xyz**: Interactive ML visualizations](https://tanh.xyz/)
68
+ - [**jax-js-bayes**: Declarative Bayesian modeling library](https://github.com/StefanSko/jax-js-bayes)
69
+
70
+ Demos on the jax-js website:
71
+
72
+ - [Training neural networks on MNIST](https://jax-js.com/mnist)
73
+ - [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
74
+ - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
75
+ - [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
76
+ - [In-browser REPL](https://jax-js.com/repl)
77
+ - [Matmul benchmark](https://jax-js.com/bench/matmul)
78
+ - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
79
+ - [Mandelbrot set](https://jax-js.com/mandelbrot)
80
+
46
81
  ## Feature comparison
47
82
 
48
83
  Here's a quick, high-level comparison with other popular web ML runtimes:
@@ -320,6 +355,8 @@ self-contained way in other projects.
320
355
 
321
356
  ### Performance
322
357
 
358
+ To see per-kernel traces in browser development tools, call `jax.profiler.startTrace()`.
359
+
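You can also read the trace entries back in code. `emitTrace`, which is added further down in this diff, records each kernel as a `performance.measure` whose `detail.devtools.trackGroup` is `"JAX Profiler"`. A summary can therefore filter the standard Performance Timeline on that field. `summarizeKernelTraces` is a hypothetical helper. The sketch assumes the measures have not been cleared from the timeline buffer.

```javascript
// Hypothetical helper: total time and call count per label for "JAX Profiler" measures.
function summarizeKernelTraces() {
  const totals = new Map();
  for (const entry of performance.getEntriesByType("measure")) {
    // emitTrace tags its entries with detail.devtools.trackGroup = "JAX Profiler".
    if (entry.detail?.devtools?.trackGroup !== "JAX Profiler") continue;
    const prev = totals.get(entry.name) ?? { calls: 0, ms: 0 };
    totals.set(entry.name, { calls: prev.calls + 1, ms: prev.ms + entry.duration });
  }
  // Sort labels by total time, slowest first.
  return [...totals].sort((a, b) => b[1].ms - a[1].ms);
}

// Usage: after tracing a workload, print a per-kernel summary.
console.table(summarizeKernelTraces().map(([label, s]) => ({ label, ...s })));
```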
323
360
  The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
324
361
  browsers. Also, this library uniquely has the `jit()` feature that fuses operations together and
325
362
  records an execution graph. jax-js achieves **over 7000 GFLOP/s** for matrix multiplication on an
@@ -338,19 +375,6 @@ well as unique optimizations such as FlashAttention variants.
338
375
  That's all for this short tutorial. Please see the generated
339
376
  [API reference](https://jax-js.com/docs) for detailed documentation.
340
377
 
341
- ## Examples
342
-
343
- If you make something cool with jax-js, don't be a stranger! We can feature it here.
344
-
345
- - [Training neural networks on MNIST](https://jax-js.com/mnist)
346
- - [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
347
- - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
348
- - [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
349
- - [In-browser REPL](https://jax-js.com/repl)
350
- - [Matmul benchmark](https://jax-js.com/bench/matmul)
351
- - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
352
- - [Mandelbrot set](https://jax-js.com/mandelbrot)
353
-
354
378
  ## Development
355
379
 
356
380
  _The following technical details are for contributing to jax-js and modifying its internals._
@@ -363,19 +387,6 @@ pnpm install
363
387
  pnpm run build:watch
364
388
  ```
365
389
 
366
- The `pnpm install` command automatically sets up Git hooks via
367
- [Husky](https://typicode.github.io/husky/). Pre-commit hooks will run ESLint and Prettier on staged
368
- files to ensure code quality.
369
-
370
- You can also run linting and formatting manually:
371
-
372
- ```bash
373
- pnpm lint # Run ESLint
374
- pnpm format # Format all files with Prettier
375
- pnpm format:check # Check formatting without writing
376
- pnpm check # Run TypeScript type checking
377
- ```
378
-
379
390
  Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
380
391
 
381
392
  ```bash
@@ -392,6 +403,15 @@ To start a Vite dev server running the website, demos and REPL:
392
403
  pnpm -C website dev
393
404
  ```
394
405
 
406
+ You can run the linter, code formatter, and type checker with:
407
+
408
+ ```bash
409
+ pnpm lint # Run ESLint
410
+ pnpm format # Format all files with Prettier
411
+ pnpm format:check # Check formatting without writing
412
+ pnpm check # Run TypeScript type checking
413
+ ```
414
+
395
415
  ## Future work / help wanted
396
416
 
397
417
  Contributions are welcomed! Some fruitful areas to look into:
@@ -399,9 +419,6 @@ Contributions are welcomed! Some fruitful areas to look into:
399
419
  - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
400
420
  - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
401
421
  and multithreading. (Even single-threaded Wasm could be ~20x faster.)
402
- - Adding support for `jax.profiling`, in particular the start and end trace functions. We should be
403
- able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
404
- to help with model performance debugging.
405
422
  - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
406
423
  - Making a fast transformer inference engine, comparing against onnxruntime-web.
407
424
 
@@ -1312,6 +1312,10 @@ var Reduction = class {
1312
1312
  this.epilogue = epilogue;
1313
1313
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
1314
1314
  this.epilogue = epilogue.simplify();
1315
+ if (this.dtype === DType.Float16 && this.op === AluOp.Add) {
1316
+ this.epilogue = this.epilogue.substitute({ acc: AluExp.cast(this.dtype, AluVar.acc(DType.Float32)) });
1317
+ this.dtype = DType.Float32;
1318
+ }
1315
1319
  }
1316
1320
  hash(state) {
1317
1321
  state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
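The new branch in the `Reduction` constructor promotes `Float16` sums to a `Float32` accumulator, and the epilogue casts the result back. The sketch below shows why this matters, using plain JavaScript numbers. `toFloat16` is a hypothetical helper that rounds a double to the nearest half-precision value with ties to even, so the two loops model a float16 accumulator and a float32 accumulator. This illustrates the rounding effect only. It is not jax-js code.

```javascript
// Hypothetical helper: round to the nearest IEEE 754 binary16 value (ties to even).
function toFloat16(x) {
  if (x === 0 || !Number.isFinite(x)) return x;
  // Exponent of x, clamped at -14 so subnormals use the fixed 2^-24 spacing.
  const e = Math.max(Math.floor(Math.log2(Math.abs(x))), -14);
  const ulp = 2 ** (e - 10); // float16 keeps 10 fraction bits
  const q = x / ulp;
  const f = Math.floor(q);
  const d = q - f;
  // Round half to even.
  const n = d < 0.5 ? f : d > 0.5 ? f + 1 : f % 2 === 0 ? f : f + 1;
  const r = n * ulp;
  return Math.abs(r) > 65504 ? Math.sign(r) * Infinity : r; // overflow to infinity
}

// Float16 accumulator: every partial sum is rounded to half precision.
function sumFloat16Acc(xs) {
  let acc = 0;
  for (const x of xs) acc = toFloat16(acc + toFloat16(x));
  return acc;
}

// Float32 accumulator, with the final result cast back to float16.
function sumFloat32Acc(xs) {
  let acc = 0;
  for (const x of xs) acc = Math.fround(acc + toFloat16(x));
  return toFloat16(acc);
}

const ones = new Array(4096).fill(1);
console.log(sumFloat16Acc(ones)); // 2048: above 2048 the float16 spacing is 2, so +1 rounds away
console.log(sumFloat32Acc(ones)); // 4096
```

The `tuneNullopt` and `tuneWebgpu` hunks further down add the matching cast. They convert the per-element expression to the reduction's dtype before it is accumulated.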
@@ -1479,9 +1483,14 @@ var Routine = class {
1479
1483
  };
1480
1484
  /** One of the valid `Routine` that can be dispatched to backend. */
1481
1485
  let Routines = /* @__PURE__ */ function(Routines$1) {
1482
- /** Stable sorting algorithm along the last axis. */
1486
+ /**
1487
+ * Sort along the last axis.
1488
+ *
1489
+ * This may be _unstable_, but that rarely matters: the sorted output is
1490
+ * bitwise unique except for the order of signed zeros and NaNs.
1491
+ */
1483
1492
  Routines$1["Sort"] = "Sort";
1484
- /** Returns `int32` indices of the stably sorted array. */
1493
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
1485
1494
  Routines$1["Argsort"] = "Argsort";
1486
1495
  /**
1487
1496
  * Solve a triangular system of equations.
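The revised `Sort` comment depends on the fact that numbers which compare equal are almost always bit-for-bit identical. An unstable sort therefore cannot produce a visibly different result. There are two exceptions. `+0` and `-0` are equal under `===` but have different sign bits, and NaN has many bit patterns. The snippet below demonstrates the signed-zero case with a hypothetical `float64Bits` helper.

```javascript
// Hypothetical helper: reinterpret a double's 64 bits as an unsigned BigInt.
function float64Bits(x) {
  const f64 = new Float64Array([x]);
  return new BigUint64Array(f64.buffer)[0];
}

// A numeric comparator treats 0 and -0 as equal...
const cmp = (a, b) => (a === b ? 0 : a < b ? -1 : 1);
console.log(cmp(0, -0)); // 0

// ...but their bit patterns differ in the sign bit.
console.log(float64Bits(0).toString(16)); // 0
console.log(float64Bits(-0).toString(16)); // 8000000000000000

// A stable sort keeps their input order; an unstable one is free to swap them.
const sorted = [0, -0].sort(cmp);
console.log(Object.is(sorted[0], 0) && Object.is(sorted[1], -0)); // true
```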
@@ -1545,7 +1554,13 @@ function runArgsort(type, [x], [y, yi]) {
1545
1554
  const out = y.subarray(offset, offset + n);
1546
1555
  const outi = yi.subarray(offset, offset + n);
1547
1556
  for (let i = 0; i < n; i++) outi[i] = i;
1548
- outi.sort((a, b) => ar[a] - ar[b]);
1557
+ outi.sort((a, b) => {
1558
+ const x$1 = ar[a];
1559
+ const y$1 = ar[b];
1560
+ if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
1561
+ if (isNaN(y$1)) return -1;
1562
+ return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
1563
+ });
1549
1564
  for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
1550
1565
  }
1551
1566
  }
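The old comparator `ar[a] - ar[b]` returns `NaN` whenever either element is `NaN`. That is not a consistent comparator, and ECMAScript leaves the resulting sort order implementation-defined. The replacement treats NaN as larger than every number and returns `0` for ties. The sketch below applies the same comparator to a plain array through a hypothetical `argsort` helper.

```javascript
// Hypothetical helper: indices that sort `values` ascending, with NaNs placed last.
function argsort(values) {
  const idx = Int32Array.from(values, (_, i) => i);
  idx.sort((a, b) => {
    const x = values[a];
    const y = values[b];
    if (Number.isNaN(x)) return Number.isNaN(y) ? 0 : 1; // NaN goes after numbers
    if (Number.isNaN(y)) return -1;
    return x === y ? 0 : x < y ? -1 : 1; // ties return 0, so a stable sort keeps input order
  });
  return idx;
}

const values = [3, NaN, 1, 2];
const order = argsort(values);
console.log(Array.from(order)); // [ 2, 3, 0, 1 ]
console.log(Array.from(order, (i) => values[i])); // [ 1, 2, 3, NaN ]
```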
@@ -2255,11 +2270,15 @@ var TuneDims = class {
2255
2270
  };
2256
2271
  /** Tuning step that does not apply any optimization. */
2257
2272
  function tuneNullopt(kernel) {
2273
+ let exp = kernel.exp;
2258
2274
  const vars = {};
2259
2275
  vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
2260
- if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2276
+ if (kernel.reduction) {
2277
+ vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2278
+ if (exp.dtype !== kernel.reduction.dtype) exp = AluExp.cast(kernel.reduction.dtype, exp);
2279
+ }
2261
2280
  return {
2262
- exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
2281
+ exp: exp.substitute(vars).rewriteGlobalViews().simplify(),
2263
2282
  epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
2264
2283
  outputIdxExp: vars.gidx,
2265
2284
  threadCount: kernel.size,
@@ -2268,8 +2287,9 @@ function tuneNullopt(kernel) {
2268
2287
  }
2269
2288
  /** Tuning for WebGPU kernels. */
2270
2289
  function tuneWebgpu(kernel) {
2271
- const { exp, reduction } = kernel;
2290
+ const reduction = kernel.reduction;
2272
2291
  if (!reduction) return tuneNullopt(kernel);
2292
+ const exp = AluExp.cast(reduction.dtype, kernel.exp);
2273
2293
  const globalIndexes = exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex);
2274
2294
  if (globalIndexes.length > 0) {
2275
2295
  if (DEBUG >= 4) console.info("Tuning: Found GlobalIndex ops, skipping opt.");
@@ -2321,7 +2341,7 @@ function tuneWebgpu(kernel) {
2321
2341
  if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2322
2342
  const s = dim.st.shape[dim.unroll - 1];
2323
2343
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2324
- else for (const splits of [8, 4]) if (s % splits === 0) {
2344
+ else for (const splits of [4, 2]) if (s % splits === 0) {
2325
2345
  dim.applyUnroll(dim.unroll - 1, splits);
2326
2346
  break;
2327
2347
  }
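In `tuneWebgpu`, when the unrolled dimension has size `s` outside `(0, 32]`, the tuner now tries split factors `[4, 2]` instead of `[8, 4]` and takes the first one that divides `s`. Dimensions that are even but not multiples of 4 can now be split, and a factor of 8 is no longer chosen. The selection rule is shown in isolation below as a hypothetical `pickUnrollSplit` helper.

```javascript
// Hypothetical helper mirroring the split choice in tuneWebgpu's unroll step.
// Returns the unroll amount the source would apply, or 0 if none is applied.
function pickUnrollSplit(s, candidates = [4, 2]) {
  if (0 < s && s <= 32) return s; // the source calls applyUnroll(dim.reduce, s) here
  for (const splits of candidates) {
    if (s % splits === 0) return splits; // first factor that divides s
  }
  return 0;
}

// Columns: size, old choice ([8, 4]), new choice ([4, 2]).
for (const s of [24, 40, 34, 33]) {
  console.log(s, pickUnrollSplit(s, [8, 4]), pickUnrollSplit(s));
}
// 24 24 24
// 40 8 4
// 34 0 2
// 33 0 0
```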
@@ -2497,6 +2517,85 @@ var CpuBackend = class {
2497
2517
  }
2498
2518
  };
2499
2519
 
2520
+ //#endregion
2521
+ //#region src/tracing.ts
2522
+ let traceEnabled = false;
2523
+ const flushCallbacks = [];
2524
+ /**
2525
+ * Start collecting kernel traces.
2526
+ *
2527
+ * Traces appear in developer tools under the "Performance" tab, and they are
2528
+ * useful for measuring fine-grained kernel execution time.
2529
+ */
2530
+ function startTrace() {
2531
+ traceEnabled = true;
2532
+ }
2533
+ /**
2534
+ * Stop collecting kernel traces.
2535
+ *
2536
+ * Traces appear in developer tools under the "Performance" tab, and they are
2537
+ * useful for measuring fine-grained kernel execution time.
2538
+ */
2539
+ function stopTrace() {
2540
+ traceEnabled = false;
2541
+ for (const cb of flushCallbacks) cb();
2542
+ }
2543
+ /** Check if tracing is currently enabled. */
2544
+ function isTracing() {
2545
+ return traceEnabled;
2546
+ }
2547
+ /** Register a callback to flush pending trace data when tracing stops. */
2548
+ function onFlushTrace(cb) {
2549
+ flushCallbacks.push(cb);
2550
+ }
2551
+ function humanSize(n) {
2552
+ if (n >= 1e9) return `${(n / 1e9).toPrecision(3)}B`;
2553
+ if (n >= 1e6) return `${(n / 1e6).toPrecision(3)}M`;
2554
+ if (n >= 1e3) return `${(n / 1e3).toPrecision(3)}K`;
2555
+ return `${n}`;
2556
+ }
2557
+ /** Build a trace label, properties, and color from a kernel or routine source. */
2558
+ function traceSourceInfo(source) {
2559
+ const properties = [];
2560
+ let label;
2561
+ let color;
2562
+ if (source instanceof Kernel) {
2563
+ label = `Kernel[${humanSize(source.size)}]`;
2564
+ properties.push(["exp", `${source.exp}`]);
2565
+ properties.push(["size", `${source.size}`]);
2566
+ properties.push(["nargs", `${source.nargs}`]);
2567
+ if (!source.reduction) color = "primary";
2568
+ else {
2569
+ color = "secondary";
2570
+ properties.push(["reduction", `${source.reduction.op}:${source.reduction.size}`]);
2571
+ }
2572
+ } else {
2573
+ color = "tertiary";
2574
+ label = source.name;
2575
+ properties.push(["inputShapes", source.type.inputShapes.map((s) => `[${s}]`).join(", ")]);
2576
+ properties.push(["outputShapes", source.type.outputShapes.map((s) => `[${s}]`).join(", ")]);
2577
+ properties.push(["dtype", source.type.inputDtypes.join(", ")]);
2578
+ }
2579
+ return {
2580
+ label,
2581
+ color,
2582
+ properties
2583
+ };
2584
+ }
2585
+ /** Emit a trace entry as a `performance.measure` with devtools metadata. */
2586
+ function emitTrace(track, info, start, end) {
2587
+ performance.measure(info.label, {
2588
+ detail: { devtools: {
2589
+ trackGroup: "JAX Profiler",
2590
+ track,
2591
+ color: info.color,
2592
+ properties: info.properties
2593
+ } },
2594
+ start,
2595
+ end
2596
+ });
2597
+ }
2598
+
2500
2599
  //#endregion
2501
2600
  //#region src/backend/wasm/allocator.ts
2502
2601
  /** Simple tensor memory allocator for WebAssembly linear memory. */
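The new `src/tracing.ts` module keeps a global on/off flag. Backends such as `WasmBackend.dispatch` (next hunk) call `performance.now()` around each kernel only while the flag is set. They then report the timing through `performance.measure` with a `detail.devtools` payload, which Chrome DevTools draws on a custom track in the Performance panel. Labels follow `Kernel[${humanSize(size)}]`, so a kernel with 1,048,576 elements appears as `Kernel[1.05M]`. The sketch below condenses this pattern into one self-contained function. `tracedCall` and its arguments are hypothetical names. The `trackGroup`, `track`, `color`, and `properties` fields match what `emitTrace` passes.

```javascript
let tracing = false; // stands in for traceEnabled in src/tracing.ts

// Hypothetical wrapper: run `fn`; if tracing is on, record its duration as a
// measure on a custom DevTools track under the "JAX Profiler" group.
function tracedCall(track, label, fn) {
  if (!tracing) return fn(); // only a flag check when tracing is off
  const start = performance.now();
  const result = fn();
  performance.measure(label, {
    start,
    end: performance.now(),
    detail: {
      devtools: {
        trackGroup: "JAX Profiler",
        track,
        color: "primary",
        properties: [["note", "illustrative entry"]],
      },
    },
  });
  return result;
}

tracing = true;
const total = tracedCall("wasm", "Kernel[1.05M]", () => {
  let s = 0;
  for (let i = 0; i < 1_048_576; i++) s += i;
  return s;
});
tracing = false;

const [entry] = performance.getEntriesByName("Kernel[1.05M]", "measure");
console.log(total, entry.detail.devtools.track, entry.duration >= 0);
```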
@@ -3903,11 +4002,19 @@ var WasmBackend = class {
3903
4002
  return new Executable(routine, void 0);
3904
4003
  }
3905
4004
  dispatch(exe, inputs, outputs) {
3906
- if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
3907
- const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
3908
- const func = instance.exports.kernel;
3909
- const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
3910
- func(...ptrs);
4005
+ const tracing = isTracing();
4006
+ const start = tracing ? performance.now() : 0;
4007
+ if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
4008
+ else {
4009
+ const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
4010
+ const func = instance.exports.kernel;
4011
+ const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
4012
+ func(...ptrs);
4013
+ }
4014
+ if (tracing) {
4015
+ const info = traceSourceInfo(exe.source);
4016
+ emitTrace("wasm", info, start, performance.now());
4017
+ }
3911
4018
  }
3912
4019
  #getBuffer(slot) {
3913
4020
  const buffer = this.#buffers.get(slot);
@@ -4252,7 +4359,7 @@ async function createBackend(device) {
4252
4359
  if (!navigator.gpu) return null;
4253
4360
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4254
4361
  if (!adapter) return null;
4255
- const { WebGPUBackend } = await import("./webgpu-B96vzWGE.js");
4362
+ const { WebGPUBackend } = await import("./webgpu-v_W_-oKw.js");
4256
4363
  const importantLimits = [
4257
4364
  "maxBufferSize",
4258
4365
  "maxComputeInvocationsPerWorkgroup",
@@ -4290,7 +4397,7 @@ async function createBackend(device) {
4290
4397
  });
4291
4398
  if (!gl) return null;
4292
4399
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4293
- const { WebGLBackend } = await import("./webgl-DweKSWEm.js");
4400
+ const { WebGLBackend } = await import("./webgl-CvQ1QBX1.js");
4294
4401
  return new WebGLBackend(gl);
4295
4402
  } else throw new Error(`Backend not found: ${device}`);
4296
4403
  }
@@ -4326,4 +4433,4 @@ var UnsupportedRoutineError = class extends Error {
4326
4433
  };
4327
4434
 
4328
4435
  //#endregion
4329
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
4436
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, emitTrace, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, isTracing, mapSetUnion, normalizeAxis, onFlushTrace, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, strip1, toposort, traceSourceInfo, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
@@ -1313,6 +1313,10 @@ var Reduction = class {
1313
1313
  this.epilogue = epilogue;
1314
1314
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
1315
1315
  this.epilogue = epilogue.simplify();
1316
+ if (this.dtype === DType.Float16 && this.op === AluOp.Add) {
1317
+ this.epilogue = this.epilogue.substitute({ acc: AluExp.cast(this.dtype, AluVar.acc(DType.Float32)) });
1318
+ this.dtype = DType.Float32;
1319
+ }
1316
1320
  }
1317
1321
  hash(state) {
1318
1322
  state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
@@ -1480,9 +1484,14 @@ var Routine = class {
1480
1484
  };
1481
1485
  /** One of the valid `Routine` that can be dispatched to backend. */
1482
1486
  let Routines = /* @__PURE__ */ function(Routines$1) {
1483
- /** Stable sorting algorithm along the last axis. */
1487
+ /**
1488
+ * Sort along the last axis.
1489
+ *
1490
+ * This may be _unstable_, but that rarely matters: the sorted output is
1491
+ * bitwise unique except for the order of signed zeros and NaNs.
1492
+ */
1484
1493
  Routines$1["Sort"] = "Sort";
1485
- /** Returns `int32` indices of the stably sorted array. */
1494
+ /** Stable sorting, returns `int32` indices and values of the sorted array. */
1486
1495
  Routines$1["Argsort"] = "Argsort";
1487
1496
  /**
1488
1497
  * Solve a triangular system of equations.
@@ -1546,7 +1555,13 @@ function runArgsort(type, [x], [y, yi]) {
1546
1555
  const out = y.subarray(offset, offset + n);
1547
1556
  const outi = yi.subarray(offset, offset + n);
1548
1557
  for (let i = 0; i < n; i++) outi[i] = i;
1549
- outi.sort((a, b) => ar[a] - ar[b]);
1558
+ outi.sort((a, b) => {
1559
+ const x$1 = ar[a];
1560
+ const y$1 = ar[b];
1561
+ if (isNaN(x$1)) return isNaN(y$1) ? 0 : 1;
1562
+ if (isNaN(y$1)) return -1;
1563
+ return x$1 === y$1 ? 0 : x$1 < y$1 ? -1 : 1;
1564
+ });
1550
1565
  for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
1551
1566
  }
1552
1567
  }
@@ -2256,11 +2271,15 @@ var TuneDims = class {
2256
2271
  };
2257
2272
  /** Tuning step that does not apply any optimization. */
2258
2273
  function tuneNullopt(kernel) {
2274
+ let exp = kernel.exp;
2259
2275
  const vars = {};
2260
2276
  vars.gidx = AluExp.special(DType.Int32, "gidx", kernel.size);
2261
- if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2277
+ if (kernel.reduction) {
2278
+ vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2279
+ if (exp.dtype !== kernel.reduction.dtype) exp = AluExp.cast(kernel.reduction.dtype, exp);
2280
+ }
2262
2281
  return {
2263
- exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
2282
+ exp: exp.substitute(vars).rewriteGlobalViews().simplify(),
2264
2283
  epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
2265
2284
  outputIdxExp: vars.gidx,
2266
2285
  threadCount: kernel.size,
@@ -2269,8 +2288,9 @@ function tuneNullopt(kernel) {
2269
2288
  }
2270
2289
  /** Tuning for WebGPU kernels. */
2271
2290
  function tuneWebgpu(kernel) {
2272
- const { exp, reduction } = kernel;
2291
+ const reduction = kernel.reduction;
2273
2292
  if (!reduction) return tuneNullopt(kernel);
2293
+ const exp = AluExp.cast(reduction.dtype, kernel.exp);
2274
2294
  const globalIndexes = exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex);
2275
2295
  if (globalIndexes.length > 0) {
2276
2296
  if (DEBUG >= 4) console.info("Tuning: Found GlobalIndex ops, skipping opt.");
@@ -2322,7 +2342,7 @@ function tuneWebgpu(kernel) {
2322
2342
  if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2323
2343
  const s = dim.st.shape[dim.unroll - 1];
2324
2344
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2325
- else for (const splits of [8, 4]) if (s % splits === 0) {
2345
+ else for (const splits of [4, 2]) if (s % splits === 0) {
2326
2346
  dim.applyUnroll(dim.unroll - 1, splits);
2327
2347
  break;
2328
2348
  }
@@ -2498,6 +2518,85 @@ var CpuBackend = class {
2498
2518
  }
2499
2519
  };
2500
2520
 
2521
+ //#endregion
2522
+ //#region src/tracing.ts
2523
+ let traceEnabled = false;
2524
+ const flushCallbacks = [];
2525
+ /**
2526
+ * Start collecting kernel traces.
2527
+ *
2528
+ * Traces appear in developer tools under the "Performance" tab, and they are
2529
+ * useful for measuring fine-grained kernel execution time.
2530
+ */
2531
+ function startTrace() {
2532
+ traceEnabled = true;
2533
+ }
2534
+ /**
2535
+ * Stop collecting kernel traces.
2536
+ *
2537
+ * Traces appear in developer tools under the "Performance" tab, and they are
2538
+ * useful for measuring fine-grained kernel execution time.
2539
+ */
2540
+ function stopTrace() {
2541
+ traceEnabled = false;
2542
+ for (const cb of flushCallbacks) cb();
2543
+ }
2544
+ /** Check if tracing is currently enabled. */
2545
+ function isTracing() {
2546
+ return traceEnabled;
2547
+ }
2548
+ /** Register a callback to flush pending trace data when tracing stops. */
2549
+ function onFlushTrace(cb) {
2550
+ flushCallbacks.push(cb);
2551
+ }
2552
+ function humanSize(n) {
2553
+ if (n >= 1e9) return `${(n / 1e9).toPrecision(3)}B`;
2554
+ if (n >= 1e6) return `${(n / 1e6).toPrecision(3)}M`;
2555
+ if (n >= 1e3) return `${(n / 1e3).toPrecision(3)}K`;
2556
+ return `${n}`;
2557
+ }
2558
+ /** Build a trace label, properties, and color from a kernel or routine source. */
2559
+ function traceSourceInfo(source) {
2560
+ const properties = [];
2561
+ let label;
2562
+ let color;
2563
+ if (source instanceof Kernel) {
2564
+ label = `Kernel[${humanSize(source.size)}]`;
2565
+ properties.push(["exp", `${source.exp}`]);
2566
+ properties.push(["size", `${source.size}`]);
2567
+ properties.push(["nargs", `${source.nargs}`]);
2568
+ if (!source.reduction) color = "primary";
2569
+ else {
2570
+ color = "secondary";
2571
+ properties.push(["reduction", `${source.reduction.op}:${source.reduction.size}`]);
2572
+ }
2573
+ } else {
2574
+ color = "tertiary";
2575
+ label = source.name;
2576
+ properties.push(["inputShapes", source.type.inputShapes.map((s) => `[${s}]`).join(", ")]);
2577
+ properties.push(["outputShapes", source.type.outputShapes.map((s) => `[${s}]`).join(", ")]);
2578
+ properties.push(["dtype", source.type.inputDtypes.join(", ")]);
2579
+ }
2580
+ return {
2581
+ label,
2582
+ color,
2583
+ properties
2584
+ };
2585
+ }
2586
+ /** Emit a trace entry as a `performance.measure` with devtools metadata. */
2587
+ function emitTrace(track, info, start, end) {
2588
+ performance.measure(info.label, {
2589
+ detail: { devtools: {
2590
+ trackGroup: "JAX Profiler",
2591
+ track,
2592
+ color: info.color,
2593
+ properties: info.properties
2594
+ } },
2595
+ start,
2596
+ end
2597
+ });
2598
+ }
2599
+
2501
2600
  //#endregion
2502
2601
  //#region src/backend/wasm/allocator.ts
2503
2602
  /** Simple tensor memory allocator for WebAssembly linear memory. */
@@ -3904,11 +4003,19 @@ var WasmBackend = class {
3904
4003
  return new Executable(routine, void 0);
3905
4004
  }
3906
4005
  dispatch(exe, inputs, outputs) {
3907
- if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
3908
- const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
3909
- const func = instance.exports.kernel;
3910
- const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
3911
- func(...ptrs);
4006
+ const tracing = isTracing();
4007
+ const start = tracing ? performance.now() : 0;
4008
+ if (exe.source instanceof Routine) runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
4009
+ else {
4010
+ const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
4011
+ const func = instance.exports.kernel;
4012
+ const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
4013
+ func(...ptrs);
4014
+ }
4015
+ if (tracing) {
4016
+ const info = traceSourceInfo(exe.source);
4017
+ emitTrace("wasm", info, start, performance.now());
4018
+ }
3912
4019
  }
3913
4020
  #getBuffer(slot) {
3914
4021
  const buffer = this.#buffers.get(slot);
@@ -4253,7 +4360,7 @@ async function createBackend(device) {
4253
4360
  if (!navigator.gpu) return null;
4254
4361
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4255
4362
  if (!adapter) return null;
4256
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BykvF26B.cjs"));
4363
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-DMSx7a6M.cjs"));
4257
4364
  const importantLimits = [
4258
4365
  "maxBufferSize",
4259
4366
  "maxComputeInvocationsPerWorkgroup",
@@ -4291,7 +4398,7 @@ async function createBackend(device) {
4291
4398
  });
4292
4399
  if (!gl) return null;
4293
4400
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4294
- const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-DIIbKJ0G.cjs"));
4401
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-kvVt7-T7.cjs"));
4295
4402
  return new WebGLBackend(gl);
4296
4403
  } else throw new Error(`Backend not found: ${device}`);
4297
4404
  }
@@ -4495,6 +4602,12 @@ Object.defineProperty(exports, 'dtypedJsArray', {
4495
4602
  return dtypedJsArray;
4496
4603
  }
4497
4604
  });
4605
+ Object.defineProperty(exports, 'emitTrace', {
4606
+ enumerable: true,
4607
+ get: function () {
4608
+ return emitTrace;
4609
+ }
4610
+ });
4498
4611
  Object.defineProperty(exports, 'findPow2', {
4499
4612
  enumerable: true,
4500
4613
  get: function () {
@@ -4543,6 +4656,12 @@ Object.defineProperty(exports, 'isPermutation', {
4543
4656
  return isPermutation;
4544
4657
  }
4545
4658
  });
4659
+ Object.defineProperty(exports, 'isTracing', {
4660
+ enumerable: true,
4661
+ get: function () {
4662
+ return isTracing;
4663
+ }
4664
+ });
4546
4665
  Object.defineProperty(exports, 'mapSetUnion', {
4547
4666
  enumerable: true,
4548
4667
  get: function () {
@@ -4555,6 +4674,12 @@ Object.defineProperty(exports, 'normalizeAxis', {
4555
4674
  return normalizeAxis;
4556
4675
  }
4557
4676
  });
4677
+ Object.defineProperty(exports, 'onFlushTrace', {
4678
+ enumerable: true,
4679
+ get: function () {
4680
+ return onFlushTrace;
4681
+ }
4682
+ });
4558
4683
  Object.defineProperty(exports, 'partitionList', {
4559
4684
  enumerable: true,
4560
4685
  get: function () {
@@ -4603,6 +4728,18 @@ Object.defineProperty(exports, 'setDebug', {
4603
4728
  return setDebug;
4604
4729
  }
4605
4730
  });
4731
+ Object.defineProperty(exports, 'startTrace', {
4732
+ enumerable: true,
4733
+ get: function () {
4734
+ return startTrace;
4735
+ }
4736
+ });
4737
+ Object.defineProperty(exports, 'stopTrace', {
4738
+ enumerable: true,
4739
+ get: function () {
4740
+ return stopTrace;
4741
+ }
4742
+ });
4606
4743
  Object.defineProperty(exports, 'strip1', {
4607
4744
  enumerable: true,
4608
4745
  get: function () {
@@ -4615,6 +4752,12 @@ Object.defineProperty(exports, 'toposort', {
4615
4752
  return toposort;
4616
4753
  }
4617
4754
  });
4755
+ Object.defineProperty(exports, 'traceSourceInfo', {
4756
+ enumerable: true,
4757
+ get: function () {
4758
+ return traceSourceInfo;
4759
+ }
4760
+ });
4618
4761
  Object.defineProperty(exports, 'tuneNullopt', {
4619
4762
  enumerable: true,
4620
4763
  get: function () {