@jax-js/jax 0.1.5 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +60 -7
- package/dist/{backend-DziQSaoQ.cjs → backend-B3foXiV_.cjs} +25 -6
- package/dist/{backend-DaqL-MNz.js → backend-nEolvdLv.js} +20 -7
- package/dist/index.cjs +450 -129
- package/dist/index.d.cts +1669 -1467
- package/dist/index.d.ts +1669 -1467
- package/dist/index.js +450 -130
- package/dist/{webgl-ClIYb8jP.cjs → webgl-DIIbKJ0G.cjs} +1 -1
- package/dist/{webgl-RSuZKvgc.js → webgl-DweKSWEm.js} +1 -1
- package/dist/{webgpu-Dh7k9io0.js → webgpu-B96vzWGE.js} +1 -1
- package/dist/{webgpu-Db2JrNBr.cjs → webgpu-BykvF26B.cjs} +1 -1
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -43,6 +43,55 @@ way to get started on a blank HTML page.
|
|
|
43
43
|
</script>
|
|
44
44
|
```
|
|
45
45
|
|
|
46
|
+
## Feature comparison
|
|
47
|
+
|
|
48
|
+
Here's a quick, high-level comparison with other popular web ML runtimes:
|
|
49
|
+
|
|
50
|
+
| Feature | jax-js | TensorFlow.js | onnxruntime-web |
|
|
51
|
+
| ------------------------------- | ---------- | --------------- | ------------------ |
|
|
52
|
+
| **Overview** | | | |
|
|
53
|
+
| API style | JAX/NumPy | TensorFlow-like | Static ONNX graphs |
|
|
54
|
+
| Latest release | 2026 | ⚠️ 2024 | 2026 |
|
|
55
|
+
| Speed | Fastest | Fast | Fastest |
|
|
56
|
+
| Bundle size (gzip) | 80 KB | 269 KB | 90 KB + 24 MB Wasm |
|
|
57
|
+
| **Autodiff & JIT** | | | |
|
|
58
|
+
| Gradients | ✅ | ✅ | ❌ |
|
|
59
|
+
| Jacobian and Hessian | ✅ | ❌ | ❌ |
|
|
60
|
+
| `jvp()` forward differentiation | ✅ | ❌ | ❌ |
|
|
61
|
+
| `jit()` kernel fusion | ✅ | ❌ | ❌ |
|
|
62
|
+
| `vmap()` auto-vectorization | ✅ | ❌ | ❌ |
|
|
63
|
+
| Graph capture | ✅ | ❌ | ✅ |
|
|
64
|
+
| **Backends & Data** | | | |
|
|
65
|
+
| WebGPU backend | ✅ | 🟡 Preview | ✅ |
|
|
66
|
+
| WebGL backend | ✅ | ✅ | ✅ |
|
|
67
|
+
| Wasm (CPU) backend | ✅ | ✅ | ✅ |
|
|
68
|
+
| Eager array API | ✅ | ✅ | ❌ |
|
|
69
|
+
| Run ONNX models | 🟡 Partial | ❌ | ✅ |
|
|
70
|
+
| Read safetensors | ✅ | ❌ | ❌ |
|
|
71
|
+
| Float64 | ✅ | ❌ | ❌ |
|
|
72
|
+
| Float32 | ✅ | ✅ | ✅ |
|
|
73
|
+
| Float16 | ✅ | ❌ | ✅ |
|
|
74
|
+
| BFloat16 | ❌ | ❌ | ❌ |
|
|
75
|
+
| Packed Uint8 | ❌ | ❌ | 🟡 Partial |
|
|
76
|
+
| Mixed precision | ✅ | ❌ | ✅ |
|
|
77
|
+
| Mixed devices | ✅ | ❌ | ❌ |
|
|
78
|
+
| **Ops & Numerics** | | | |
|
|
79
|
+
| Arithmetic functions | ✅ | ✅ | ✅ |
|
|
80
|
+
| Matrix multiplication | ✅ | ✅ | ✅ |
|
|
81
|
+
| General einsum | ✅ | 🟡 Partial | 🟡 Partial |
|
|
82
|
+
| Sorting | ✅ | ❌ | ❌ |
|
|
83
|
+
| Activation functions | ✅ | ✅ | ✅ |
|
|
84
|
+
| NaN/Inf numerics | ✅ | ✅ | ✅ |
|
|
85
|
+
| Basic convolutions | ✅ | ✅ | ✅ |
|
|
86
|
+
| n-d convolutions | ✅ | ❌ | ✅ |
|
|
87
|
+
| Strided/dilated convolution | ✅ | ✅ | ✅ |
|
|
88
|
+
| Cholesky, Lstsq | ✅ | ❌ | ❌ |
|
|
89
|
+
| LU, Solve, Determinant | ✅ | ❌ | ❌ |
|
|
90
|
+
| SVD | ❌ | ❌ | ❌ |
|
|
91
|
+
| FFT | ✅ | ✅ | ✅ |
|
|
92
|
+
| Basic RNG (Uniform, Normal) | ✅ | ✅ | ✅ |
|
|
93
|
+
| Advanced RNG | ✅ | ❌ | ❌ |
|
|
94
|
+
|
|
46
95
|
## Tutorial
|
|
47
96
|
|
|
48
97
|
Programming in `jax-js` looks [very similar to JAX](https://docs.jax.dev/en/latest/jax-101.html),
|
|
@@ -271,15 +320,18 @@ self-contained way in other projects.
|
|
|
271
320
|
|
|
272
321
|
### Performance
|
|
273
322
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
chip ([try it](https://jax-js.com/bench/matmul)).
|
|
323
|
+
The WebGPU runtime includes an ML compiler with tile-aware optimizations, tuned for individual
|
|
324
|
+
browsers. Also, this library uniquely has the `jit()` feature that fuses operations together and
|
|
325
|
+
records an execution graph. jax-js achieves **over 7000 GFLOP/s** for matrix multiplication on an
|
|
326
|
+
Apple M4 Max chip ([try it](https://jax-js.com/bench/matmul)).
|
|
278
327
|
|
|
279
|
-
For that example, it's
|
|
328
|
+
For that example, it's significantly faster than both
|
|
280
329
|
[TensorFlow.js](https://github.com/tensorflow/tfjs) and
|
|
281
330
|
[ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
|
|
282
|
-
libraries of custom kernels
|
|
331
|
+
libraries of custom kernels.
|
|
332
|
+
|
|
333
|
+
It's still early though. There's a lot of low-hanging fruit to continue optimizing the library, as
|
|
334
|
+
well as unique optimizations such as FlashAttention variants.
|
|
283
335
|
|
|
284
336
|
### API Reference
|
|
285
337
|
|
|
@@ -291,8 +343,9 @@ That's all for this short tutorial. Please see the generated
|
|
|
291
343
|
If you make something cool with jax-js, don't be a stranger! We can feature it here.
|
|
292
344
|
|
|
293
345
|
- [Training neural networks on MNIST](https://jax-js.com/mnist)
|
|
346
|
+
- [Voice cloning: Kyutai Pocket TTS](https://jax-js.com/tts)
|
|
294
347
|
- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
|
|
295
|
-
- [Object detection
|
|
348
|
+
- [Object detection: DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
|
|
296
349
|
- [In-browser REPL](https://jax-js.com/repl)
|
|
297
350
|
- [Matmul benchmark](https://jax-js.com/bench/matmul)
|
|
298
351
|
- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
|
|
@@ -151,6 +151,19 @@ function normalizeAxis(axis, ndim) {
|
|
|
151
151
|
return sorted(seen);
|
|
152
152
|
}
|
|
153
153
|
}
|
|
154
|
+
/** Check for an array of integers with no duplicates. */
|
|
155
|
+
function checkInts(indices) {
|
|
156
|
+
if (typeof indices === "number") {
|
|
157
|
+
if (!Number.isInteger(indices)) throw new TypeError(`Expected integer index, got ${indices}`);
|
|
158
|
+
} else {
|
|
159
|
+
const seen = /* @__PURE__ */ new Set();
|
|
160
|
+
for (const i of indices) {
|
|
161
|
+
if (!Number.isInteger(i)) throw new TypeError(`Expected integer indices, got ${i}`);
|
|
162
|
+
if (seen.has(i)) throw new Error(`Duplicate index ${i} passed to function`);
|
|
163
|
+
seen.add(i);
|
|
164
|
+
}
|
|
165
|
+
}
|
|
166
|
+
}
|
|
154
167
|
function range(start, stop, step = 1) {
|
|
155
168
|
if (stop === void 0) {
|
|
156
169
|
stop = start;
|
|
@@ -897,7 +910,7 @@ var AluExp = class AluExp {
|
|
|
897
910
|
return ret.simplify(cache);
|
|
898
911
|
}
|
|
899
912
|
}
|
|
900
|
-
if (y.arg > 0) {
|
|
913
|
+
if (y.arg > 0 && x.min >= 0) {
|
|
901
914
|
let [xNoConst, constVal] = [x, 0];
|
|
902
915
|
if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
|
|
903
916
|
const terms = [];
|
|
@@ -919,7 +932,7 @@ var AluExp = class AluExp {
|
|
|
919
932
|
rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
|
|
920
933
|
quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
|
|
921
934
|
}
|
|
922
|
-
if (
|
|
935
|
+
if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
|
|
923
936
|
else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
|
|
924
937
|
}
|
|
925
938
|
}
|
|
@@ -2306,10 +2319,10 @@ function tuneWebgpu(kernel) {
|
|
|
2306
2319
|
upcastedAxis.add(choices[0][2]);
|
|
2307
2320
|
} else break;
|
|
2308
2321
|
}
|
|
2309
|
-
if (/
|
|
2322
|
+
if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2310
2323
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2311
2324
|
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2312
|
-
else for (const splits of [4]) if (s % splits === 0) {
|
|
2325
|
+
else for (const splits of [8, 4]) if (s % splits === 0) {
|
|
2313
2326
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2314
2327
|
break;
|
|
2315
2328
|
}
|
|
@@ -4240,7 +4253,7 @@ async function createBackend(device) {
|
|
|
4240
4253
|
if (!navigator.gpu) return null;
|
|
4241
4254
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4242
4255
|
if (!adapter) return null;
|
|
4243
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4256
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BykvF26B.cjs"));
|
|
4244
4257
|
const importantLimits = [
|
|
4245
4258
|
"maxBufferSize",
|
|
4246
4259
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4278,7 +4291,7 @@ async function createBackend(device) {
|
|
|
4278
4291
|
});
|
|
4279
4292
|
if (!gl) return null;
|
|
4280
4293
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4281
|
-
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-
|
|
4294
|
+
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-DIIbKJ0G.cjs"));
|
|
4282
4295
|
return new WebGLBackend(gl);
|
|
4283
4296
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4284
4297
|
}
|
|
@@ -4446,6 +4459,12 @@ Object.defineProperty(exports, 'checkAxis', {
|
|
|
4446
4459
|
return checkAxis;
|
|
4447
4460
|
}
|
|
4448
4461
|
});
|
|
4462
|
+
Object.defineProperty(exports, 'checkInts', {
|
|
4463
|
+
enumerable: true,
|
|
4464
|
+
get: function () {
|
|
4465
|
+
return checkInts;
|
|
4466
|
+
}
|
|
4467
|
+
});
|
|
4449
4468
|
Object.defineProperty(exports, 'deepEqual', {
|
|
4450
4469
|
enumerable: true,
|
|
4451
4470
|
get: function () {
|
|
@@ -150,6 +150,19 @@ function normalizeAxis(axis, ndim) {
|
|
|
150
150
|
return sorted(seen);
|
|
151
151
|
}
|
|
152
152
|
}
|
|
153
|
+
/** Check for an array of integers with no duplicates. */
|
|
154
|
+
function checkInts(indices) {
|
|
155
|
+
if (typeof indices === "number") {
|
|
156
|
+
if (!Number.isInteger(indices)) throw new TypeError(`Expected integer index, got ${indices}`);
|
|
157
|
+
} else {
|
|
158
|
+
const seen = /* @__PURE__ */ new Set();
|
|
159
|
+
for (const i of indices) {
|
|
160
|
+
if (!Number.isInteger(i)) throw new TypeError(`Expected integer indices, got ${i}`);
|
|
161
|
+
if (seen.has(i)) throw new Error(`Duplicate index ${i} passed to function`);
|
|
162
|
+
seen.add(i);
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
}
|
|
153
166
|
function range(start, stop, step = 1) {
|
|
154
167
|
if (stop === void 0) {
|
|
155
168
|
stop = start;
|
|
@@ -896,7 +909,7 @@ var AluExp = class AluExp {
|
|
|
896
909
|
return ret.simplify(cache);
|
|
897
910
|
}
|
|
898
911
|
}
|
|
899
|
-
if (y.arg > 0) {
|
|
912
|
+
if (y.arg > 0 && x.min >= 0) {
|
|
900
913
|
let [xNoConst, constVal] = [x, 0];
|
|
901
914
|
if (x.op === AluOp.Add && x.src[1].op === AluOp.Const) [xNoConst, constVal] = [x.src[0], x.src[1].arg];
|
|
902
915
|
const terms = [];
|
|
@@ -918,7 +931,7 @@ var AluExp = class AluExp {
|
|
|
918
931
|
rem = AluExp.add(rem, AluExp.mul(AluExp.const(x.dtype, Math.floor(remainders[i] / gcdVal)), terms[i]));
|
|
919
932
|
quo = AluExp.add(quo, AluExp.mul(AluExp.const(x.dtype, quotients[i]), terms[i]));
|
|
920
933
|
}
|
|
921
|
-
if (
|
|
934
|
+
if (rem.min >= 0) if (op === AluOp.Mod) return AluExp.add(AluExp.mul(AluExp.const(x.dtype, gcdVal), AluExp.mod(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal)))), AluExp.const(x.dtype, constVal % gcdVal)).simplify(cache);
|
|
922
935
|
else return AluExp.add(AluExp.idiv(rem, AluExp.const(x.dtype, Math.floor(y.arg / gcdVal))), quo).simplify(cache);
|
|
923
936
|
}
|
|
924
937
|
}
|
|
@@ -2305,10 +2318,10 @@ function tuneWebgpu(kernel) {
|
|
|
2305
2318
|
upcastedAxis.add(choices[0][2]);
|
|
2306
2319
|
} else break;
|
|
2307
2320
|
}
|
|
2308
|
-
if (/
|
|
2321
|
+
if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2309
2322
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2310
2323
|
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2311
|
-
else for (const splits of [4]) if (s % splits === 0) {
|
|
2324
|
+
else for (const splits of [8, 4]) if (s % splits === 0) {
|
|
2312
2325
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2313
2326
|
break;
|
|
2314
2327
|
}
|
|
@@ -4239,7 +4252,7 @@ async function createBackend(device) {
|
|
|
4239
4252
|
if (!navigator.gpu) return null;
|
|
4240
4253
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4241
4254
|
if (!adapter) return null;
|
|
4242
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4255
|
+
const { WebGPUBackend } = await import("./webgpu-B96vzWGE.js");
|
|
4243
4256
|
const importantLimits = [
|
|
4244
4257
|
"maxBufferSize",
|
|
4245
4258
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4277,7 +4290,7 @@ async function createBackend(device) {
|
|
|
4277
4290
|
});
|
|
4278
4291
|
if (!gl) return null;
|
|
4279
4292
|
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4280
|
-
const { WebGLBackend } = await import("./webgl-
|
|
4293
|
+
const { WebGLBackend } = await import("./webgl-DweKSWEm.js");
|
|
4281
4294
|
return new WebGLBackend(gl);
|
|
4282
4295
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4283
4296
|
}
|
|
@@ -4313,4 +4326,4 @@ var UnsupportedRoutineError = class extends Error {
|
|
|
4313
4326
|
};
|
|
4314
4327
|
|
|
4315
4328
|
//#endregion
|
|
4316
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4329
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|