@jax-js/jax 0.1.5 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -43,6 +43,55 @@ way to get started on a blank HTML page.
43
43
  </script>
44
44
  ```
45
45
 
46
+ ## Feature comparison
47
+
48
+ Here's a quick, high-level comparison with other popular web ML runtimes:
49
+
50
+ | Feature | jax-js | TensorFlow.js | onnxruntime-web |
51
+ | ------------------------------- | ---------- | --------------- | ------------------ |
52
+ | **Overview** | | | |
53
+ | API style | JAX/NumPy | TensorFlow-like | Static ONNX graphs |
54
+ | Latest release | 2026 | ⚠️ 2024 | 2026 |
55
+ | Speed | Fastest | Fast | Fastest |
56
+ | Bundle size (gzip) | 80 KB | 269 KB | 90 KB + 24 MB Wasm |
57
+ | **Autodiff & JIT** | | | |
58
+ | Gradients | ✅ | ✅ | ❌ |
59
+ | Jacobian and Hessian | ✅ | ❌ | ❌ |
60
+ | `jvp()` forward differentiation | ✅ | ❌ | ❌ |
61
+ | `jit()` kernel fusion | ✅ | ❌ | ❌ |
62
+ | `vmap()` auto-vectorization | ✅ | ❌ | ❌ |
63
+ | Graph capture | ✅ | ❌ | ✅ |
64
+ | **Backends & Data** | | | |
65
+ | WebGPU backend | ✅ | 🟡 Preview | ✅ |
66
+ | WebGL backend | ✅ | ✅ | ✅ |
67
+ | Wasm (CPU) backend | ✅ | ✅ | ✅ |
68
+ | Eager array API | ✅ | ✅ | ❌ |
69
+ | Run ONNX models | 🟡 Partial | ❌ | ✅ |
70
+ | Read safetensors | ✅ | ❌ | ❌ |
71
+ | Float64 | ✅ | ❌ | ❌ |
72
+ | Float32 | ✅ | ✅ | ✅ |
73
+ | Float16 | ✅ | ❌ | ✅ |
74
+ | BFloat16 | ❌ | ❌ | ❌ |
75
+ | Packed Uint8 | ❌ | ❌ | 🟡 Partial |
76
+ | Mixed precision | ✅ | ❌ | ✅ |
77
+ | Mixed devices | ✅ | ❌ | ❌ |
78
+ | **Ops & Numerics** | | | |
79
+ | Arithmetic functions | ✅ | ✅ | ✅ |
80
+ | Matrix multiplication | ✅ | ✅ | ✅ |
81
+ | General einsum | ✅ | 🟡 Partial | 🟡 Partial |
82
+ | Sorting | ✅ | ❌ | ❌ |
83
+ | Activation functions | ✅ | ✅ | ✅ |
84
+ | NaN/Inf numerics | ✅ | ✅ | ✅ |
85
+ | Basic convolutions | ✅ | ✅ | ✅ |
86
+ | n-d convolutions | ✅ | ❌ | ✅ |
87
+ | Strided/dilated convolution | ✅ | ✅ | ✅ |
88
+ | Cholesky, Lstsq | ✅ | ❌ | ❌ |
89
+ | LU, Solve, Determinant | ✅ | ❌ | ❌ |
90
+ | SVD | ❌ | ❌ | ❌ |
91
+ | FFT | ✅ | ✅ | ✅ |
92
+ | Basic RNG (Uniform, Normal) | ✅ | ✅ | ✅ |
93
+ | Advanced RNG | ✅ | ❌ | ❌ |
94
+
46
95
  ## Tutorial
47
96
 
48
97
  Programming in `jax-js` looks [very similar to JAX](https://docs.jax.dev/en/latest/jax-101.html),
@@ -271,15 +320,18 @@ self-contained way in other projects.
271
320
 
272
321
  ### Performance
273
322
 
274
- We haven't spent a ton of time optimizing yet, but performance is generally pretty good. `jit` is
275
- very helpful for fusing operations together, and it's a feature only available on the web in jax-js.
276
- The default kernel-tuning heuristics get about 3000 GFLOP/s for matrix multiplication on an M4 Pro
277
- chip ([try it](https://jax-js.com/bench/matmul)).
323
+ The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
324
+ browsers. Also, this library uniquely has the `jit()` feature that fuses operations together and
325
+ records an execution graph. jax-js achieves **over 7000 GFLOP/s** for matrix multiplication on an
326
+ Apple M4 Max chip ([try it](https://jax-js.com/bench/matmul)).
278
327
 
279
- For that example, it's around the same GFLOP/s as
328
+ For that example, it's significantly faster than both
280
329
  [TensorFlow.js](https://github.com/tensorflow/tfjs) and
281
330
  [ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
282
- libraries of custom kernels (versus jax-js, which generates kernels with an ML compiler).
331
+ libraries of custom kernels.
332
+
333
+ It's still early though. There's a lot of low-hanging fruit to continue optimizing the library, as
334
+ well as unique optimizations such as FlashAttention variants.
283
335
 
284
336
  ### API Reference
285
337
 
@@ -291,8 +343,9 @@ That's all for this short tutorial. Please see the generated
291
343
  If you make something cool with jax-js, don't be a stranger! We can feature it here.
292
344
 
293
345
  - [Training neural networks on MNIST](https://jax-js.com/mnist)
346
+ - [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
294
347
  - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
295
- - [Object detection with DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
348
+ - [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
296
349
  - [In-browser REPL](https://jax-js.com/repl)
297
350
  - [Matmul benchmark](https://jax-js.com/bench/matmul)
298
351
  - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
@@ -151,6 +151,19 @@ function normalizeAxis(axis, ndim) {
151
151
  return sorted(seen);
152
152
  }
153
153
  }
154
+ /** Check for a single integer, or an array of integers with no duplicates. */
155
+ function checkInts(indices) {
156
+ if (typeof indices === "number") {
157
+ if (!Number.isInteger(indices)) throw new TypeError(`Expected integer index, got ${indices}`);
158
+ } else {
159
+ const seen = /* @__PURE__ */ new Set();
160
+ for (const i of indices) {
161
+ if (!Number.isInteger(i)) throw new TypeError(`Expected integer indices, got ${i}`);
162
+ if (seen.has(i)) throw new Error(`Duplicate index ${i} passed to function`);
163
+ seen.add(i);
164
+ }
165
+ }
166
+ }
154
167
  function range(start, stop, step = 1) {
155
168
  if (stop === void 0) {
156
169
  stop = start;
@@ -897,7 +910,7 @@ var AluExp = class AluExp {
897
910
  return ret.simplify(cache);
898
911
  }
899
912
  }
900
- if (y.arg > 0) {
913
+ if (y.arg > 0 && x.min >= 0) {
901
914
  let [xNoConst, constVal] = [x, 0];
902
915
  if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
903
916
  const terms = [];
@@ -919,7 +932,7 @@ var AluExp = class AluExp {
919
932
  rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
920
933
  quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
921
934
  }
922
- if (!((x.min < 0 || rem.min < 0) && remainders.some((r) => r !== 0))) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
935
+ if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
923
936
  else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
924
937
  }
925
938
  }
@@ -2306,10 +2319,10 @@ function tuneWebgpu(kernel) {
2306
2319
  upcastedAxis.add(choices[0][2]);
2307
2320
  } else break;
2308
2321
  }
2309
- if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2322
+ if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2310
2323
  const s = dim.st.shape[dim.unroll - 1];
2311
2324
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2312
- else for (const splits of [4]) if (s % splits === 0) {
2325
+ else for (const splits of [8, 4]) if (s % splits === 0) {
2313
2326
  dim.applyUnroll(dim.unroll - 1, splits);
2314
2327
  break;
2315
2328
  }
@@ -4240,7 +4253,7 @@ async function createBackend(device) {
4240
4253
  if (!navigator.gpu) return null;
4241
4254
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4242
4255
  if (!adapter) return null;
4243
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-Db2JrNBr.cjs"));
4256
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BykvF26B.cjs"));
4244
4257
  const importantLimits = [
4245
4258
  "maxBufferSize",
4246
4259
  "maxComputeInvocationsPerWorkgroup",
@@ -4278,7 +4291,7 @@ async function createBackend(device) {
4278
4291
  });
4279
4292
  if (!gl) return null;
4280
4293
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4281
- const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-ClIYb8jP.cjs"));
4294
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-DIIbKJ0G.cjs"));
4282
4295
  return new WebGLBackend(gl);
4283
4296
  } else throw new Error(`Backend not found: ${device}`);
4284
4297
  }
@@ -4446,6 +4459,12 @@ Object.defineProperty(exports, 'checkAxis', {
4446
4459
  return checkAxis;
4447
4460
  }
4448
4461
  });
4462
+ Object.defineProperty(exports, 'checkInts', {
4463
+ enumerable: true,
4464
+ get: function () {
4465
+ return checkInts;
4466
+ }
4467
+ });
4449
4468
  Object.defineProperty(exports, 'deepEqual', {
4450
4469
  enumerable: true,
4451
4470
  get: function () {
@@ -150,6 +150,19 @@ function normalizeAxis(axis, ndim) {
150
150
  return sorted(seen);
151
151
  }
152
152
  }
153
+ /** Check for a single integer, or an array of integers with no duplicates. */
154
+ function checkInts(indices) {
155
+ if (typeof indices === "number") {
156
+ if (!Number.isInteger(indices)) throw new TypeError(`Expected integer index, got ${indices}`);
157
+ } else {
158
+ const seen = /* @__PURE__ */ new Set();
159
+ for (const i of indices) {
160
+ if (!Number.isInteger(i)) throw new TypeError(`Expected integer indices, got ${i}`);
161
+ if (seen.has(i)) throw new Error(`Duplicate index ${i} passed to function`);
162
+ seen.add(i);
163
+ }
164
+ }
165
+ }
153
166
  function range(start, stop, step = 1) {
154
167
  if (stop === void 0) {
155
168
  stop = start;
@@ -896,7 +909,7 @@ var AluExp = class AluExp {
896
909
  return ret.simplify(cache);
897
910
  }
898
911
  }
899
- if (y.arg > 0) {
912
+ if (y.arg > 0 && x.min >= 0) {
900
913
  let [xNoConst, constVal] = [x, 0];
901
914
  if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
902
915
  const terms = [];
@@ -918,7 +931,7 @@ var AluExp = class AluExp {
918
931
  rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
919
932
  quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
920
933
  }
921
- if (!((x.min < 0 || rem.min < 0) && remainders.some((r) => r !== 0))) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
934
+ if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
922
935
  else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
923
936
  }
924
937
  }
@@ -2305,10 +2318,10 @@ function tuneWebgpu(kernel) {
2305
2318
  upcastedAxis.add(choices[0][2]);
2306
2319
  } else break;
2307
2320
  }
2308
- if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2321
+ if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2309
2322
  const s = dim.st.shape[dim.unroll - 1];
2310
2323
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2311
- else for (const splits of [4]) if (s % splits === 0) {
2324
+ else for (const splits of [8, 4]) if (s % splits === 0) {
2312
2325
  dim.applyUnroll(dim.unroll - 1, splits);
2313
2326
  break;
2314
2327
  }
@@ -4239,7 +4252,7 @@ async function createBackend(device) {
4239
4252
  if (!navigator.gpu) return null;
4240
4253
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4241
4254
  if (!adapter) return null;
4242
- const { WebGPUBackend } = await import("./webgpu-Dh7k9io0.js");
4255
+ const { WebGPUBackend } = await import("./webgpu-B96vzWGE.js");
4243
4256
  const importantLimits = [
4244
4257
  "maxBufferSize",
4245
4258
  "maxComputeInvocationsPerWorkgroup",
@@ -4277,7 +4290,7 @@ async function createBackend(device) {
4277
4290
  });
4278
4291
  if (!gl) return null;
4279
4292
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4280
- const { WebGLBackend } = await import("./webgl-RSuZKvgc.js");
4293
+ const { WebGLBackend } = await import("./webgl-DweKSWEm.js");
4281
4294
  return new WebGLBackend(gl);
4282
4295
  } else throw new Error(`Backend not found: ${device}`);
4283
4296
  }
@@ -4313,4 +4326,4 @@ var UnsupportedRoutineError = class extends Error {
4313
4326
  };
4314
4327
 
4315
4328
  //#endregion
4316
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
4329
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };