@jax-js/jax 0.0.4 → 0.0.5
- package/README.md +67 -24
- package/dist/{backend-EBRGmEYw.js → backend-CdcTZEOF.js} +35 -6
- package/dist/{backend-Ss1Mev_-.cjs → backend-yEU0L_ig.cjs} +40 -5
- package/dist/index.cjs +324 -225
- package/dist/index.d.cts +71 -26
- package/dist/index.d.ts +71 -26
- package/dist/index.js +314 -215
- package/dist/{webgpu-ow0Pn_6q.js → webgpu-CM-xNYzW.js} +1 -1
- package/dist/{webgpu-BVdMaO9T.cjs → webgpu-CNOpiO5T.cjs} +1 -1
- package/package.json +1 -1
package/README.md
CHANGED
@@ -2,8 +2,9 @@
 
 [Website](https://www.ekzhang.com/jax-js/) | [API Reference](https://www.ekzhang.com/jax-js/docs/)
 
-
-CPU and GPU kernels to JavaScript, so you can run numerical applications on the
+**jax-js** is a machine learning framework for the browser. It aims to bring JAX-style,
+high-performance CPU and GPU kernels to JavaScript, so you can run numerical applications on the
+web.
 
 ```bash
 npm i @jax-js/jax
@@ -12,6 +13,10 @@ npm i @jax-js/jax
 Under the hood, it translates array operations into a compiler representation, then synthesizes
 kernels in WebAssembly and WebGPU.
 
+The library is written from scratch, with zero external dependencies. It maintains close API
+compatibility with NumPy/JAX. Since everything runs client-side, jax-js is likely the most portable
+GPU ML framework, since it runs anywhere a browser can run.
+
 ## Quickstart
 
 You can use `jax-js` as an array API, just like NumPy.
@@ -24,7 +29,7 @@ const x = np.array([1, 2, 3]);
 const y = x.mul(4); // [4, 8, 12]
 ```
 
-It also lets you take derivatives like in JAX.
+It also lets you take derivatives with `grad` like in JAX (as well as `vmap`, `jit`).
 
 ```js
 import { grad, numpy as np } from "@jax-js/jax";
@@ -37,11 +42,14 @@ const xnorm = norm(x.ref); // 1^2 + 2^2 + 3^2 = 14
 const xgrad = grad(norm)(x); // [2, 4, 6]
 ```
 
-The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu)
-switch to GPU for better performance.
+The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu) including
+Chrome and iOS Safari, you can switch to GPU for better performance.
 
 ```js
-import { defaultDevice, numpy as np } from "@jax-js/jax";
+import { defaultDevice, init, numpy as np } from "@jax-js/jax";
+
+// Initialize the GPU backend.
+await init("webgpu");
 
 // Change the default backend to GPU.
 defaultDevice("webgpu");
@@ -53,8 +61,43 @@ const y = np.dot(x.ref, x); // JIT-compiled into a matrix multiplication kernel
 Most common JAX APIs are supported. See the [compatibility table](./FEATURES.md) for a full
 breakdown of what features are available.
 
+### Web usage (CDN)
+
+If you want to use `jax-js` in vanilla JavaScript (without a bundler), just import from a module
+script tag. This is the easiest way to get started on a blank HTML page.
+
+```html
+<script type="module">
+import { numpy as np } from "https://esm.sh/@jax-js/jax";
+</script>
+```
+
+### Performance
+
+We haven't spent a ton of time optimizing yet, but performance is generally pretty good. `jit` is
+very helpful for fusing operations together, and it's a feature only available on the web in jax-js.
+The default kernel-tuning heuristics get about 3000 GFLOP/s for matrix multiplication on an M4 Pro
+chip ([try it](https://www.ekzhang.com/jax-js/bench/matmul)).
+
+For that example, it's around the same GFLOP/s as
+[TensorFlow.js](https://github.com/tensorflow/tfjs) and
+[ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
+libraries of custom kernels (versus jax-js, which generates kernels with an ML compiler).
+
+## Examples
+
+If you make something cool with jax-js, don't be a stranger! We can feature it here.
+
+- [In-browser REPL](https://www.ekzhang.com/jax-js/repl)
+- [Interactive MNIST training](https://www.ekzhang.com/jax-js/mnist)
+- [Matmul benchmark](https://www.ekzhang.com/jax-js/bench/matmul)
+- [Conv2d benchmark](https://www.ekzhang.com/jax-js/bench/conv2d)
+- [Mandelbrot set](https://www.ekzhang.com/jax-js/mandelbrot)
+
 ## Development
 
+_The following technical details are for contributing to jax-js and modifying its internals._
+
 This repository is managed by [`pnpm`](https://pnpm.io/). You can compile and build all packages in
 watch mode with:
 
@@ -70,8 +113,8 @@ pnpm exec playwright install
 pnpm test
 ```
 
-
-newer versions
+We are currently on an older version of Playwright that supports using WebGPU in headless mode;
+newer versions skip the WebGPU tests.
 
 To start a Vite dev server running the website, demos and REPL:
 
@@ -79,15 +122,26 @@ To start a Vite dev server running the website, demos and REPL:
 pnpm -C website dev
 ```
 
+## Future work / help wanted
+
+Contributions are welcomed in the following areas:
+
+- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
+- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
+and multithreading.
+- Helping the JIT compiler to fuse operations in more cases.
+- Adding WebGL runtime for older browsers that don't support WebGPU.
+- Making a fast transformer inference engine, comparing against onnxruntime-web.
+- Ergonomics and API improvements.
+
 ## Next on Eric's mind
 
 - Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
 - Performance
-- Improve perf of
-- Optimize conv2d further (maybe blocks -> local dims?)
+- Improve perf of MobileCLIP neural network
 - Add fused epilogue to JIT
+- Fix fusion of activation functions with branches like tanh
 - Reduce kernel overhead of constants / inline expressions
-- Investigate why jax-js Matmul is 2x slower on Safari TP than unroll kernel
 - How many threads to create per workgroup, depends on hardware
 
 ## Milestones
@@ -120,20 +174,9 @@ pnpm -C website dev
 - [ ] SIMD support for Wasm backend
 - [ ] Async / multithreading Wasm support
 - [ ] Full support of weak types and committed devices
-- [
-- [
+- [x] High-level ops have automatic type promotion
+- [x] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
 - [ ] Committed devices -
 [ref](https://docs.jax.dev/en/latest/sharded-computation.html#sharded-data-placement)
 - [ ] Device switching with `device_put()` between webgpu/cpu/wasm
 - [x] numpy/jax API compatibility table
-
-## Future work / help wanted
-
-Contributions are welcomed in the following areas:
-
-- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
-- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, using SIMD and
-multithreading.
-- Adding WebGL runtime for older browsers that don't support WebGPU.
-- Making a fast transformer inference engine, comparing against onnxruntime-web.
-- Ergonomics and API improvements.
package/dist/{backend-EBRGmEYw.js → backend-CdcTZEOF.js}
CHANGED
@@ -206,6 +206,35 @@ function findPow2(hint, max) {
   while (ret < hint && 2 * ret <= max) ret *= 2;
   return ret;
 }
+/**
+ * Implements a NumPy-style generalized broadcast rule on two array shapes.
+ *
+ * "When operating on two arrays, NumPy compares their shapes element-wise. It
+ * starts with the trailing (i.e. rightmost) dimension and works its way left.
+ * Two dimensions are compatible when:
+ * 1. they are equal, or
+ * 2. one of them is 1."
+ *
+ * Throws a TypeError if the broadcast is not possible.
+ *
+ * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
+ */
+function generalBroadcast(a, b) {
+  const out = [];
+  let i = a.length - 1;
+  let j = b.length - 1;
+  for (; i >= 0 && j >= 0; i--, j--) {
+    const x = a[i];
+    const y = b[j];
+    if (x === y) out.push(x);
+    else if (x === 1) out.push(y);
+    else if (y === 1) out.push(x);
+    else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
+  }
+  for (; i >= 0; i--) out.push(a[i]);
+  for (; j >= 0; j--) out.push(b[j]);
+  return out.reverse();
+}
 function recursiveFlatten(ar) {
   if (!Array.isArray(ar)) return [ar];
   return ar.flat(Infinity);
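As a reading aid for the hunk above, here is a small standalone sketch that exercises the broadcast rule. The function body is copied from the added lines so the snippet runs on its own; the example shapes are illustrative assumptions, not taken from the package's tests.

```javascript
// Copy of the generalBroadcast helper added in this hunk, so the sketch is self-contained.
function generalBroadcast(a, b) {
  const out = [];
  let i = a.length - 1;
  let j = b.length - 1;
  // Walk both shapes from the trailing dimension toward the front.
  for (; i >= 0 && j >= 0; i--, j--) {
    const x = a[i];
    const y = b[j];
    if (x === y) out.push(x);
    else if (x === 1) out.push(y);
    else if (y === 1) out.push(x);
    else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
  }
  // Any leftover leading dimensions of the longer shape are kept as-is.
  for (; i >= 0; i--) out.push(a[i]);
  for (; j >= 0; j--) out.push(b[j]);
  return out.reverse();
}

// Illustrative shapes (assumed for this example only).
console.log(generalBroadcast([3, 1], [1, 4]).join(",")); // 3,4
console.log(generalBroadcast([5, 1, 6], [7, 1]).join(",")); // 5,7,6
console.log(generalBroadcast([2, 3], []).join(",")); // 2,3

try {
  generalBroadcast([2, 3], [4, 3]); // 2 vs 4: neither equal nor 1
} catch (e) {
  console.log(e instanceof TypeError); // true
}
```

Because the output is built back to front and reversed once at the end, the helper needs no padding of the shorter shape.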
@@ -294,12 +323,12 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
  * **Type lattice:**
  * ```text
  * bool -> uint32 -> int32 -> float16 -> float32
- *
+ * weakType --^
  * ```
  *
- *
- *
- *
+ * `weakType` represents weakly typed arrays. These are created for JS numbers,
+ * which default to float32 but "weak" so they cast to the dtype of any array
+ * they are first combined with, except `bool`.
  *
  * **Examples:**
  * - `promoteTypes(bool, int32) → int32`
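The doc comment above can be illustrated with a rough sketch. This is not the library's `promoteTypes`; `promoteSketch`, `strong`, and `weakNumber` are hypothetical names, the `{ dtype, weak }` operand shape is an assumption, and the result for a weak number combined with `bool` (staying weak float32) is a guess modelled on JAX's behaviour.

```javascript
// Lattice order as written in the doc comment. Everything else here is an assumption.
const LATTICE = ["bool", "uint32", "int32", "float16", "float32"];

// Hypothetical operand constructors: a typed array operand vs. a plain JS number.
const strong = (dtype) => ({ dtype, weak: false });
const weakNumber = () => ({ dtype: "float32", weak: true });

// Hypothetical promotion rule, NOT the package's promoteTypes implementation.
function promoteSketch(a, b) {
  if (a.weak === b.weak) {
    // Same strength: take whichever dtype sits later in the lattice.
    const higher = LATTICE.indexOf(a.dtype) >= LATTICE.indexOf(b.dtype) ? a.dtype : b.dtype;
    return { dtype: higher, weak: a.weak };
  }
  const [w, s] = a.weak ? [a, b] : [b, a];
  // Assumption: bool is the one dtype a weak value does not adopt.
  if (s.dtype === "bool") return { dtype: w.dtype, weak: true };
  // Otherwise the weak operand casts to the strongly typed operand's dtype.
  return { dtype: s.dtype, weak: false };
}

console.log(promoteSketch(strong("bool"), strong("int32")).dtype); // int32
console.log(promoteSketch(weakNumber(), strong("int32")).dtype); // int32
console.log(promoteSketch(weakNumber(), strong("bool")).dtype); // float32
```

The first call mirrors the `promoteTypes(bool, int32) → int32` example from the comment; the other two only show the weak-type idea under the stated assumptions.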
@@ -3760,7 +3789,7 @@ async function createBackend(device) {
 if (!navigator.gpu) return null;
 const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
 if (!adapter) return null;
-const { WebGPUBackend } = await import("./webgpu-
+const { WebGPUBackend } = await import("./webgpu-CM-xNYzW.js");
 const importantLimits = [
 "maxBufferSize",
 "maxComputeInvocationsPerWorkgroup",
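The context lines of this hunk show `createBackend` returning `null` when `navigator.gpu` is missing or no adapter is found. A minimal sketch of the same availability check, usable before asking for the GPU backend, might look like the following; `detectWebGPU` is a hypothetical helper name and not part of the package.

```javascript
// Hypothetical helper (not part of @jax-js/jax): resolves to true only when the
// standard WebGPU entry point exists and the browser hands back an adapter.
async function detectWebGPU() {
  // Outside a browser, or in browsers without WebGPU, navigator.gpu is absent.
  if (typeof navigator === "undefined" || !navigator.gpu) return false;
  // requestAdapter() resolves to null when no suitable adapter is available.
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
  return adapter !== null;
}

detectWebGPU().then((available) => {
  console.log("WebGPU available:", available);
});
```

In a runtime without WebGPU this resolves to `false` without throwing, which matches the early `return null` paths in the hunk.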
@@ -3813,4 +3842,4 @@ var UnsupportedOpError = class extends Error {
 };
 
 //#endregion
-export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
+export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
package/dist/{backend-Ss1Mev_-.cjs → backend-yEU0L_ig.cjs}
CHANGED
@@ -207,6 +207,35 @@ function findPow2(hint, max) {
   while (ret < hint && 2 * ret <= max) ret *= 2;
   return ret;
 }
+/**
+ * Implements a NumPy-style generalized broadcast rule on two array shapes.
+ *
+ * "When operating on two arrays, NumPy compares their shapes element-wise. It
+ * starts with the trailing (i.e. rightmost) dimension and works its way left.
+ * Two dimensions are compatible when:
+ * 1. they are equal, or
+ * 2. one of them is 1."
+ *
+ * Throws a TypeError if the broadcast is not possible.
+ *
+ * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
+ */
+function generalBroadcast(a, b) {
+  const out = [];
+  let i = a.length - 1;
+  let j = b.length - 1;
+  for (; i >= 0 && j >= 0; i--, j--) {
+    const x = a[i];
+    const y = b[j];
+    if (x === y) out.push(x);
+    else if (x === 1) out.push(y);
+    else if (y === 1) out.push(x);
+    else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
+  }
+  for (; i >= 0; i--) out.push(a[i]);
+  for (; j >= 0; j--) out.push(b[j]);
+  return out.reverse();
+}
 function recursiveFlatten(ar) {
   if (!Array.isArray(ar)) return [ar];
   return ar.flat(Infinity);
@@ -295,12 +324,12 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
  * **Type lattice:**
  * ```text
  * bool -> uint32 -> int32 -> float16 -> float32
- *
+ * weakType --^
  * ```
  *
- *
- *
- *
+ * `weakType` represents weakly typed arrays. These are created for JS numbers,
+ * which default to float32 but "weak" so they cast to the dtype of any array
+ * they are first combined with, except `bool`.
  *
  * **Examples:**
  * - `promoteTypes(bool, int32) → int32`
@@ -3761,7 +3790,7 @@ async function createBackend(device) {
 if (!navigator.gpu) return null;
 const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
 if (!adapter) return null;
-const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
+const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CNOpiO5T.cjs"));
 const importantLimits = [
 "maxBufferSize",
 "maxComputeInvocationsPerWorkgroup",
@@ -3958,6 +3987,12 @@ Object.defineProperty(exports, 'findPow2', {
     return findPow2;
   }
 });
+Object.defineProperty(exports, 'generalBroadcast', {
+  enumerable: true,
+  get: function () {
+    return generalBroadcast;
+  }
+});
 Object.defineProperty(exports, 'getBackend', {
   enumerable: true,
   get: function () {