@jax-js/jax 0.1.6 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -43,6 +43,55 @@ way to get started on a blank HTML page.
43
43
  </script>
44
44
  ```
45
45
 
46
+ ## Feature comparison
47
+
48
+ Here's a quick, high-level comparison with other popular web ML runtimes:
49
+
50
+ | Feature | jax-js | TensorFlow.js | onnxruntime-web |
51
+ | ------------------------------- | ---------- | --------------- | ------------------ |
52
+ | **Overview** | | | |
53
+ | API style | JAX/NumPy | TensorFlow-like | Static ONNX graphs |
54
+ | Latest release | 2026 | ⚠️ 2024 | 2026 |
55
+ | Speed | Fastest | Fast | Fastest |
56
+ | Bundle size (gzip) | 80 KB | 269 KB | 90 KB + 24 MB Wasm |
57
+ | **Autodiff & JIT** | | | |
58
+ | Gradients | ✅ | ✅ | ❌ |
59
+ | Jacobian and Hessian | ✅ | ❌ | ❌ |
60
+ | `jvp()` forward differentiation | ✅ | ❌ | ❌ |
61
+ | `jit()` kernel fusion | ✅ | ❌ | ❌ |
62
+ | `vmap()` auto-vectorization | ✅ | ❌ | ❌ |
63
+ | Graph capture | ✅ | ❌ | ✅ |
64
+ | **Backends & Data** | | | |
65
+ | WebGPU backend | ✅ | 🟡 Preview | ✅ |
66
+ | WebGL backend | ✅ | ✅ | ✅ |
67
+ | Wasm (CPU) backend | ✅ | ✅ | ✅ |
68
+ | Eager array API | ✅ | ✅ | ❌ |
69
+ | Run ONNX models | 🟡 Partial | ❌ | ✅ |
70
+ | Read safetensors | ✅ | ❌ | ❌ |
71
+ | Float64 | ✅ | ❌ | ❌ |
72
+ | Float32 | ✅ | ✅ | ✅ |
73
+ | Float16 | ✅ | ❌ | ✅ |
74
+ | BFloat16 | ❌ | ❌ | ❌ |
75
+ | Packed Uint8 | ❌ | ❌ | 🟡 Partial |
76
+ | Mixed precision | ✅ | ❌ | ✅ |
77
+ | Mixed devices | ✅ | ❌ | ❌ |
78
+ | **Ops & Numerics** | | | |
79
+ | Arithmetic functions | ✅ | ✅ | ✅ |
80
+ | Matrix multiplication | ✅ | ✅ | ✅ |
81
+ | General einsum | ✅ | 🟡 Partial | 🟡 Partial |
82
+ | Sorting | ✅ | ❌ | ❌ |
83
+ | Activation functions | ✅ | ✅ | ✅ |
84
+ | NaN/Inf numerics | ✅ | ✅ | ✅ |
85
+ | Basic convolutions | ✅ | ✅ | ✅ |
86
+ | n-d convolutions | ✅ | ❌ | ✅ |
87
+ | Strided/dilated convolution | ✅ | ✅ | ✅ |
88
+ | Cholesky, Lstsq | ✅ | ❌ | ❌ |
89
+ | LU, Solve, Determinant | ✅ | ❌ | ❌ |
90
+ | SVD | ❌ | ❌ | ❌ |
91
+ | FFT | ✅ | ✅ | ✅ |
92
+ | Basic RNG (Uniform, Normal) | ✅ | ✅ | ✅ |
93
+ | Advanced RNG | ✅ | ❌ | ❌ |
94
+
46
95
  ## Tutorial
47
96
 
48
97
  Programming in `jax-js` looks [very similar to JAX](https://docs.jax.dev/en/latest/jax-101.html),
@@ -271,15 +320,18 @@ self-contained way in other projects.
271
320
 
272
321
  ### Performance
273
322
 
274
- We haven't spent a ton of time optimizing yet, but performance is generally pretty good. `jit` is
275
- very helpful for fusing operations together, and it's a feature only available on the web in jax-js.
276
- The default kernel-tuning heuristics get about 3000 GFLOP/s for matrix multiplication on an M4 Pro
277
- chip ([try it](https://jax-js.com/bench/matmul)).
323
+ The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
324
+ browsers. Also, this library uniquely has the `jit()` feature that fuses operations together and
325
+ records an execution graph. jax-js achieves **over 7000 GFLOP/s** for matrix multiplication on an
326
+ Apple M4 Max chip ([try it](https://jax-js.com/bench/matmul)).
278
327
 
279
- For that example, it's around the same GFLOP/s as
328
+ For that example, it's significantly faster than both
280
329
  [TensorFlow.js](https://github.com/tensorflow/tfjs) and
281
330
  [ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
282
- libraries of custom kernels (versus jax-js, which generates kernels with an ML compiler).
331
+ libraries of custom kernels.
332
+
333
+ It's still early though. There's a lot of low-hanging fruit to continue optimizing the library, as
334
+ well as unique optimizations such as FlashAttention variants.
283
335
 
284
336
  ### API Reference
285
337
 
@@ -291,8 +343,9 @@ That's all for this short tutorial. Please see the generated
291
343
  If you make something cool with jax-js, don't be a stranger! We can feature it here.
292
344
 
293
345
  - [Training neural networks on MNIST](https://jax-js.com/mnist)
346
+ - [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
294
347
  - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
295
- - [Object detection with DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
348
+ - [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
296
349
  - [In-browser REPL](https://jax-js.com/repl)
297
350
  - [Matmul benchmark](https://jax-js.com/bench/matmul)
298
351
  - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
@@ -910,7 +910,7 @@ var AluExp = class AluExp {
910
910
  return ret.simplify(cache);
911
911
  }
912
912
  }
913
- if (y.arg > 0) {
913
+ if (y.arg > 0 && x.min >= 0) {
914
914
  let [xNoConst, constVal] = [x, 0];
915
915
  if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
916
916
  const terms = [];
@@ -932,7 +932,7 @@ var AluExp = class AluExp {
932
932
  rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
933
933
  quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
934
934
  }
935
- if (!((x.min < 0 || rem.min < 0) && remainders.some((r) => r !== 0))) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
935
+ if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
936
936
  else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
937
937
  }
938
938
  }
@@ -4253,7 +4253,7 @@ async function createBackend(device) {
4253
4253
  if (!navigator.gpu) return null;
4254
4254
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4255
4255
  if (!adapter) return null;
4256
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-rraa6dfz.cjs"));
4256
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BykvF26B.cjs"));
4257
4257
  const importantLimits = [
4258
4258
  "maxBufferSize",
4259
4259
  "maxComputeInvocationsPerWorkgroup",
@@ -4291,7 +4291,7 @@ async function createBackend(device) {
4291
4291
  });
4292
4292
  if (!gl) return null;
4293
4293
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4294
- const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-CyfzNW8T.cjs"));
4294
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-DIIbKJ0G.cjs"));
4295
4295
  return new WebGLBackend(gl);
4296
4296
  } else throw new Error(`Backend not found: ${device}`);
4297
4297
  }
@@ -909,7 +909,7 @@ var AluExp = class AluExp {
909
909
  return ret.simplify(cache);
910
910
  }
911
911
  }
912
- if (y.arg > 0) {
912
+ if (y.arg > 0 && x.min >= 0) {
913
913
  let [xNoConst, constVal] = [x, 0];
914
914
  if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
915
915
  const terms = [];
@@ -931,7 +931,7 @@ var AluExp = class AluExp {
931
931
  rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
932
932
  quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
933
933
  }
934
- if (!((x.min < 0 || rem.min < 0) && remainders.some((r) => r !== 0))) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
934
+ if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
935
935
  else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
936
936
  }
937
937
  }
@@ -4252,7 +4252,7 @@ async function createBackend(device) {
4252
4252
  if (!navigator.gpu) return null;
4253
4253
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4254
4254
  if (!adapter) return null;
4255
- const { WebGPUBackend } = await import("./webgpu-C-VfevQW.js");
4255
+ const { WebGPUBackend } = await import("./webgpu-B96vzWGE.js");
4256
4256
  const importantLimits = [
4257
4257
  "maxBufferSize",
4258
4258
  "maxComputeInvocationsPerWorkgroup",
@@ -4290,7 +4290,7 @@ async function createBackend(device) {
4290
4290
  });
4291
4291
  if (!gl) return null;
4292
4292
  if (!gl.getExtension("EXT_color_buffer_float")) return null;
4293
- const { WebGLBackend } = await import("./webgl-CLLvzJlO.js");
4293
+ const { WebGLBackend } = await import("./webgl-DweKSWEm.js");
4294
4294
  return new WebGLBackend(gl);
4295
4295
  } else throw new Error(`Backend not found: ${device}`);
4296
4296
  }
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-D7s-Retx.cjs');
33
+ const require_backend = require('./backend-B3foXiV_.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -337,11 +337,11 @@ function map(fn, tree, ...rest) {
337
337
  }
338
338
  /** Take a reference of every array in a tree. */
339
339
  function ref(tree) {
340
- return map((x) => x.ref, tree);
340
+ return map((x) => x instanceof Tracer ? x.ref : x, tree);
341
341
  }
342
342
  /** Dispose every array in a tree. */
343
343
  function dispose(tree) {
344
- if (tree) map((x) => x.dispose(), tree);
344
+ if (tree) map((x) => x instanceof Tracer ? x.dispose() : void 0, tree);
345
345
  }
346
346
 
347
347
  //#endregion
@@ -615,14 +615,20 @@ function shrink(x, slice) {
615
615
  }
616
616
  function pad$1(x, width) {
617
617
  const nd = ndim$1(x);
618
- if (typeof width === "number") width = [[width, width]];
619
- else if (require_backend.isNumberPair(width)) width = [width];
620
- else if (!Array.isArray(width) || !width.every(require_backend.isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
621
- if (width.length === 1) {
622
- const [w0, w1] = width[0];
623
- width = require_backend.rep(nd, () => [w0, w1]);
624
- } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
625
- return bind1(Primitive.Pad, [x], { width });
618
+ let w;
619
+ if (typeof width === "number") w = [[width, width]];
620
+ else if (require_backend.isNumberPair(width)) w = [width];
621
+ else if (!Array.isArray(width)) {
622
+ const indicesAndPairs = Object.entries(width);
623
+ w = require_backend.rep(nd, [0, 0]);
624
+ for (const [k, v] of indicesAndPairs) w[require_backend.checkAxis(parseInt(k), nd)] = v;
625
+ } else if (!width.every(require_backend.isNumberPair)) throw new TypeError(`Invalid pad() type: ${JSON.stringify(width)}`);
626
+ else w = width;
627
+ if (w.length === 1) {
628
+ const [w0, w1] = w[0];
629
+ w = require_backend.rep(nd, () => [w0, w1]);
630
+ } else if (w.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${w.length}`);
631
+ return bind1(Primitive.Pad, [x], { width: w });
626
632
  }
627
633
  function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
628
634
  const as = getShape(a);
@@ -798,6 +804,22 @@ var Tracer = class Tracer {
798
804
  const result = reduce(this.astype(castDtype), require_backend.AluOp.Add, axis, opts);
799
805
  return result.mul(1 / n).astype(originalDtype);
800
806
  }
807
+ /** Minimum of the elements of the array along a given axis. */
808
+ min(axis = null, opts) {
809
+ return reduce(this, require_backend.AluOp.Min, axis, opts);
810
+ }
811
+ /** Maximum of the elements of the array along a given axis. */
812
+ max(axis = null, opts) {
813
+ return reduce(this, require_backend.AluOp.Max, axis, opts);
814
+ }
815
+ /** Test whether all array elements along a given axis evaluate to true. */
816
+ all(axis = null, opts) {
817
+ return this.astype(require_backend.DType.Bool).min(axis, opts);
818
+ }
819
+ /** Test whether any array element along a given axis evaluates to true. */
820
+ any(axis = null, opts) {
821
+ return this.astype(require_backend.DType.Bool).max(axis, opts);
822
+ }
801
823
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
802
824
  transpose(perm) {
803
825
  return transpose$1(this, perm);
@@ -5607,6 +5629,7 @@ __export(numpy_exports, {
5607
5629
  moveaxis: () => moveaxis$1,
5608
5630
  multiply: () => multiply,
5609
5631
  nan: () => nan,
5632
+ nanToNum: () => nanToNum,
5610
5633
  ndim: () => ndim,
5611
5634
  negative: () => negative,
5612
5635
  notEqual: () => notEqual,
@@ -5804,24 +5827,22 @@ function max(a, axis = null, opts) {
5804
5827
  return reduce(a, require_backend.AluOp.Max, axis, opts);
5805
5828
  }
5806
5829
  /**
5807
- * Test whether all array elements along a given axis evaluate to True.
5830
+ * Test whether any array element along a given axis evaluates to True.
5808
5831
  *
5809
5832
  * Returns a boolean array with the same shape as `a` with the specified axis
5810
5833
  * removed. If axis is None, returns a scalar.
5811
5834
  */
5812
- function all(a, axis = null, opts) {
5813
- a = fudgeArray(a).astype(require_backend.DType.Bool);
5814
- return min(a, axis, opts);
5835
+ function any(a, axis = null, opts) {
5836
+ return fudgeArray(a).any(axis, opts);
5815
5837
  }
5816
5838
  /**
5817
- * Test whether any array element along a given axis evaluates to True.
5839
+ * Test whether all array elements along a given axis evaluate to True.
5818
5840
  *
5819
5841
  * Returns a boolean array with the same shape as `a` with the specified axis
5820
5842
  * removed. If axis is None, returns a scalar.
5821
5843
  */
5822
- function any(a, axis = null, opts) {
5823
- a = fudgeArray(a).astype(require_backend.DType.Bool);
5824
- return max(a, axis, opts);
5844
+ function all(a, axis = null, opts) {
5845
+ return fudgeArray(a).all(axis, opts);
5825
5846
  }
5826
5847
  /** Return the peak-to-peak range along a given axis (`max - min`). */
5827
5848
  function ptp(a, axis = null, opts) {
@@ -5922,7 +5943,7 @@ function split$1(a, indicesOrSections, axis = 0) {
5922
5943
  const partSize = size$1 / indicesOrSections;
5923
5944
  sizes = require_backend.rep(indicesOrSections, partSize);
5924
5945
  } else {
5925
- const indices = indicesOrSections;
5946
+ const indices = indicesOrSections.map((i) => i < 0 ? i + size$1 : i);
5926
5947
  sizes = [indices[0]];
5927
5948
  for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
5928
5949
  sizes.push(size$1 - indices[indices.length - 1]);
@@ -6870,6 +6891,21 @@ function isposinf(x) {
6870
6891
  return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
6871
6892
  }
6872
6893
  /**
6894
+ * Replace NaN and infinite entries in an array.
6895
+ *
6896
+ * By default, NaNs are replaced with `0.0`, and infinities are substituted
6897
+ * with the corresponding maximum or minimum finite values.
6898
+ */
6899
+ function nanToNum(x, { nan: nan$1 = 0, posinf = null, neginf = null } = {}) {
6900
+ x = fudgeArray(x);
6901
+ x = where(isnan(x.ref), nan$1, x);
6902
+ posinf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).max : iinfo(x.dtype).max;
6903
+ neginf ??= require_backend.isFloatDtype(x.dtype) ? finfo(x.dtype).min : iinfo(x.dtype).min;
6904
+ x = where(isposinf(x.ref), posinf, x);
6905
+ x = where(isneginf(x.ref), neginf, x);
6906
+ return x;
6907
+ }
6908
+ /**
6873
6909
  * @function
6874
6910
  * Test element-wise for finite values (not infinity or NaN).
6875
6911
  */
@@ -7578,8 +7614,6 @@ function oneHot(x, numClasses) {
7578
7614
  * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
7579
7615
  */
7580
7616
  function dotProductAttention(query, key$1, value, opts = {}) {
7581
- if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
7582
- if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
7583
7617
  query = fudgeArray(query);
7584
7618
  key$1 = fudgeArray(key$1);
7585
7619
  value = fudgeArray(value);
@@ -7617,6 +7651,38 @@ function dotProductAttention(query, key$1, value, opts = {}) {
7617
7651
  const causalMask = tri(L, S, 0, { dtype: require_backend.DType.Bool });
7618
7652
  scores = where(causalMask, scores, -Infinity);
7619
7653
  }
7654
+ if (opts.localWindowSize !== void 0) {
7655
+ const [before, after] = typeof opts.localWindowSize === "number" ? [opts.localWindowSize, opts.localWindowSize] : opts.localWindowSize;
7656
+ if (before < 0 || after < 0 || !Number.isInteger(before) || !Number.isInteger(after)) throw new Error(`dotProductAttention: localWindowSize values must be non-negative, got ${opts.localWindowSize}`);
7657
+ const localMask = tri(L, S, after, { dtype: require_backend.DType.Bool }).mul(tri(L, S, -before - 1, { dtype: require_backend.DType.Bool }).notEqual(true));
7658
+ scores = where(localMask, scores, -Infinity);
7659
+ }
7660
+ if (opts.querySeqLengths !== void 0) {
7661
+ const sl = expandDims(opts.querySeqLengths, [
7662
+ -1,
7663
+ -2,
7664
+ -3
7665
+ ]);
7666
+ scores = where(arange(L).reshape([
7667
+ 1,
7668
+ 1,
7669
+ L,
7670
+ 1
7671
+ ]).less(sl), scores, -Infinity);
7672
+ }
7673
+ if (opts.keyValueSeqLengths !== void 0) {
7674
+ const sl = expandDims(opts.keyValueSeqLengths, [
7675
+ -1,
7676
+ -2,
7677
+ -3
7678
+ ]);
7679
+ scores = where(arange(S).reshape([
7680
+ 1,
7681
+ 1,
7682
+ 1,
7683
+ S
7684
+ ]).less(sl), scores, -Infinity);
7685
+ }
7620
7686
  const attn = softmax(scores, -1);
7621
7687
  const out = einsum("BNLS,BSNH->BLNH", attn, value);
7622
7688
  return isRank3 ? out.reshape([