@jax-js/jax 0.0.3 → 0.0.5
- package/README.md +96 -22
- package/dist/{backend-BqDtPGaR.js → backend-CdcTZEOF.js} +325 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-yEU0L_ig.cjs} +350 -154
- package/dist/index.cjs +977 -354
- package/dist/index.d.cts +479 -88
- package/dist/index.d.ts +479 -88
- package/dist/index.js +964 -345
- package/dist/{webgpu-CNg9JGva.js → webgpu-CM-xNYzW.js} +9 -3
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-CNOpiO5T.cjs} +9 -3
- package/package.json +15 -4
package/README.md
CHANGED
@@ -2,8 +2,9 @@
 
 [Website](https://www.ekzhang.com/jax-js/) | [API Reference](https://www.ekzhang.com/jax-js/docs/)
 
-
-CPU and GPU kernels to JavaScript, so you can run numerical applications on the
+**jax-js** is a machine learning framework for the browser. It aims to bring JAX-style,
+high-performance CPU and GPU kernels to JavaScript, so you can run numerical applications on the
+web.
 
 ```bash
 npm i @jax-js/jax
@@ -12,6 +13,10 @@ npm i @jax-js/jax
 Under the hood, it translates array operations into a compiler representation, then synthesizes
 kernels in WebAssembly and WebGPU.
 
+The library is written from scratch, with zero external dependencies. It maintains close API
+compatibility with NumPy/JAX. Since everything runs client-side, jax-js is likely the most portable
+GPU ML framework, since it runs anywhere a browser can run.
+
 ## Quickstart
 
 You can use `jax-js` as an array API, just like NumPy.
@@ -24,7 +29,7 @@ const x = np.array([1, 2, 3]);
 const y = x.mul(4); // [4, 8, 12]
 ```
 
-It also lets you take derivatives like in JAX.
+It also lets you take derivatives with `grad` like in JAX (as well as `vmap`, `jit`).
 
 ```js
 import { grad, numpy as np } from "@jax-js/jax";
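The `grad` example quoted in the surrounding hunks reports a value of 14 and a gradient of `[2, 4, 6]` for `x = [1, 2, 3]`. Those numbers can be cross-checked without the library. The sketch below is plain JavaScript and does not call `@jax-js/jax`; it assumes `norm` is a sum of squares (as the README comment `1^2 + 2^2 + 3^2 = 14` suggests), and `numericalGrad` is a hypothetical helper introduced only for illustration.

```js
// Plain-JavaScript sanity check; does not use @jax-js/jax.
// Assumption: norm(x) is the sum of squares of the entries of x.
const norm = (x) => x.reduce((acc, v) => acc + v * v, 0);

// Hypothetical helper: central finite-difference approximation of the gradient.
function numericalGrad(f, x, eps = 1e-4) {
  return x.map((_, i) => {
    const hi = x.slice();
    hi[i] += eps;
    const lo = x.slice();
    lo[i] -= eps;
    return (f(hi) - f(lo)) / (2 * eps);
  });
}

const x = [1, 2, 3];
const value = norm(x);
// Round away floating-point noise from the finite-difference step.
const gradient = numericalGrad(norm, x).map((g) => Math.round(g * 1e6) / 1e6);
console.log(value, gradient);
```

For a quadratic function the central difference is exact up to rounding, so the result should agree with the analytic gradient `2 * x` that `grad(norm)` is expected to return.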
@@ -37,42 +42,107 @@ const xnorm = norm(x.ref); // 1^2 + 2^2 + 3^2 = 14
 const xgrad = grad(norm)(x); // [2, 4, 6]
 ```
 
-The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu)
-you can switch to GPU for
+The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu) including
+Chrome and iOS Safari, you can switch to GPU for better performance.
 
 ```js
-import { numpy as np
+import { defaultDevice, init, numpy as np } from "@jax-js/jax";
+
+// Initialize the GPU backend.
+await init("webgpu");
 
 // Change the default backend to GPU.
-
+defaultDevice("webgpu");
 
 const x = np.ones([4096, 4096]);
 const y = np.dot(x.ref, x); // JIT-compiled into a matrix multiplication kernel
 ```
 
+Most common JAX APIs are supported. See the [compatibility table](./FEATURES.md) for a full
+breakdown of what features are available.
+
+### Web usage (CDN)
+
+If you want to use `jax-js` in vanilla JavaScript (without a bundler), just import from a module
+script tag. This is the easiest way to get started on a blank HTML page.
+
+```html
+<script type="module">
+import { numpy as np } from "https://esm.sh/@jax-js/jax";
+</script>
+```
+
+### Performance
+
+We haven't spent a ton of time optimizing yet, but performance is generally pretty good. `jit` is
+very helpful for fusing operations together, and it's a feature only available on the web in jax-js.
+The default kernel-tuning heuristics get about 3000 GFLOP/s for matrix multiplication on an M4 Pro
+chip ([try it](https://www.ekzhang.com/jax-js/bench/matmul)).
+
+For that example, it's around the same GFLOP/s as
+[TensorFlow.js](https://github.com/tensorflow/tfjs) and
+[ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
+libraries of custom kernels (versus jax-js, which generates kernels with an ML compiler).
+
+## Examples
+
+If you make something cool with jax-js, don't be a stranger! We can feature it here.
+
+- [In-browser REPL](https://www.ekzhang.com/jax-js/repl)
+- [Interactive MNIST training](https://www.ekzhang.com/jax-js/mnist)
+- [Matmul benchmark](https://www.ekzhang.com/jax-js/bench/matmul)
+- [Conv2d benchmark](https://www.ekzhang.com/jax-js/bench/conv2d)
+- [Mandelbrot set](https://www.ekzhang.com/jax-js/mandelbrot)
+
 ## Development
 
-
+_The following technical details are for contributing to jax-js and modifying its internals._
+
+This repository is managed by [`pnpm`](https://pnpm.io/). You can compile and build all packages in
+watch mode with:
 
 ```bash
 pnpm install
 pnpm run build:watch
+```
+
+Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
 
-
+```bash
 pnpm exec playwright install
 pnpm test
 ```
 
+We are currently on an older version of Playwright that supports using WebGPU in headless mode;
+newer versions skip the WebGPU tests.
+
+To start a Vite dev server running the website, demos and REPL:
+
+```bash
+pnpm -C website dev
+```
+
+## Future work / help wanted
+
+Contributions are welcomed in the following areas:
+
+- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
+- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
+and multithreading.
+- Helping the JIT compiler to fuse operations in more cases.
+- Adding WebGL runtime for older browsers that don't support WebGPU.
+- Making a fast transformer inference engine, comparing against onnxruntime-web.
+- Ergonomics and API improvements.
+
 ## Next on Eric's mind
 
 - Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
--
-- Improve perf of
-
-
-
--
-- How many threads to create per workgroup, depends on hardware
+- Performance
+- Improve perf of MobileCLIP neural network
+- Add fused epilogue to JIT
+- Fix fusion of activation functions with branches like tanh
+- Reduce kernel overhead of constants / inline expressions
+- How many threads to create per workgroup, depends on hardware
 
 ## Milestones
 
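The Performance section added in this hunk quotes about 3000 GFLOP/s for matrix multiplication. As a rough illustration of what that figure means, the sketch below converts between throughput and wall-clock time using the standard count of `2 * n^3` floating-point operations for a dense `n x n` matmul. The helper names are hypothetical and not part of `@jax-js/jax`, and the size 4096 is an assumption borrowed from the README's `np.ones([4096, 4096])` example; the diff does not state which size the benchmark page uses.

```js
// Hypothetical helpers for back-of-the-envelope arithmetic; not part of @jax-js/jax.
// A dense (n x n) by (n x n) matmul does n^3 multiply-adds, i.e. 2 * n^3 flops.
function matmulGflops(n, seconds) {
  const flops = 2 * n ** 3;
  return flops / seconds / 1e9;
}

// Time one n x n matmul would take at a given sustained throughput.
function matmulSeconds(n, gflops) {
  return (2 * n ** 3) / (gflops * 1e9);
}

const n = 4096; // assumed size, taken from the README's np.ones([4096, 4096]) example
const ms = matmulSeconds(n, 3000) * 1000;
console.log(ms.toFixed(1)); // prints "45.8" (milliseconds per matmul at 3000 GFLOP/s)
```

Under these assumptions, one 4096 x 4096 matmul at the quoted throughput would take roughly 46 ms. Going the other way, `matmulGflops(n, elapsedSeconds)` turns a measured time into a GFLOP/s figure.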
@@ -91,9 +161,9 @@ pnpm test
 - [x] Other dtypes like int32 and bool
 - [x] `jit()` support via Jaxprs and kernel fusion
 - [x] We figure out the `dispose()` / refcount / linear types stuff
-- [
-- [
-- [
+- [x] `dispose()` for saved "const" tracers in Jaxprs
+- [x] Garbage collection for JIT programs
+- [x] Debug grad-grad-jit test producing a UseAfterFreeError
 - [ ] Demos: Navier-Stokes, neural networks, statistics
 - [x] Features for neural networks
 - [x] Convolution
@@ -103,6 +173,10 @@ pnpm test
 - [x] Better memory allocation that frees buffers
 - [ ] SIMD support for Wasm backend
 - [ ] Async / multithreading Wasm support
-- [ ]
-- [
-- [
+- [ ] Full support of weak types and committed devices
+- [x] High-level ops have automatic type promotion
+- [x] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
+- [ ] Committed devices -
+[ref](https://docs.jax.dev/en/latest/sharded-computation.html#sharded-data-placement)
+- [ ] Device switching with `device_put()` between webgpu/cpu/wasm
+- [x] numpy/jax API compatibility table