@jax-js/jax 0.0.3 → 0.0.5

package/README.md CHANGED
@@ -2,8 +2,9 @@
2
2
 
3
3
  [Website](https://www.ekzhang.com/jax-js/) | [API Reference](https://www.ekzhang.com/jax-js/docs/)
4
4
 
5
- This is a machine learning framework for the browser. It aims to bring JAX-style, high-performance
6
- CPU and GPU kernels to JavaScript, so you can run numerical applications on the web.
5
+ **jax-js** is a machine learning framework for the browser. It aims to bring JAX-style,
6
+ high-performance CPU and GPU kernels to JavaScript, so you can run numerical applications on the
7
+ web.
7
8
 
8
9
  ```bash
9
10
  npm i @jax-js/jax
@@ -12,6 +13,10 @@ npm i @jax-js/jax
12
13
  Under the hood, it translates array operations into a compiler representation, then synthesizes
13
14
  kernels in WebAssembly and WebGPU.
14
15
 
16
+ The library is written from scratch, with zero external dependencies. It maintains close API
17
+ compatibility with NumPy/JAX. Because everything runs client-side, jax-js is likely the most
18
+ portable GPU ML framework: it runs anywhere a browser can run.
19
+
15
20
  ## Quickstart
16
21
 
17
22
  You can use `jax-js` as an array API, just like NumPy.
@@ -24,7 +29,7 @@ const x = np.array([1, 2, 3]);
24
29
  const y = x.mul(4); // [4, 8, 12]
25
30
  ```
26
31
 
27
- It also lets you take derivatives like in JAX.
32
+ It also supports JAX-style transformations: `grad` for derivatives, as well as `vmap` and `jit`.
28
33
 
29
34
  ```js
30
35
  import { grad, numpy as np } from "@jax-js/jax";
@@ -37,42 +42,107 @@ const xnorm = norm(x.ref); // 1^2 + 2^2 + 3^2 = 14
37
42
  const xgrad = grad(norm)(x); // [2, 4, 6]
38
43
  ```
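The `[2, 4, 6]` result in the example above can be sanity-checked without the library. The sketch below is plain JavaScript and does not use jax-js: `norm` here is a stand-in that mirrors the README's sum-of-squares function, and `numericalGrad` is a hypothetical helper that approximates the gradient with central finite differences.

```js
// Stand-in for the README's `norm`: sum of squares, so norm([1, 2, 3]) = 14.
function norm(x) {
  return x.reduce((acc, v) => acc + v * v, 0);
}

// Hypothetical helper: central finite differences,
// df/dx_i ≈ (f(x + h·e_i) - f(x - h·e_i)) / (2h).
function numericalGrad(f, x, h = 1e-4) {
  return x.map((_, i) => {
    const plus = x.slice();
    const minus = x.slice();
    plus[i] += h;
    minus[i] -= h;
    return (f(plus) - f(minus)) / (2 * h);
  });
}

const x = [1, 2, 3];
const approx = numericalGrad(norm, x); // close to [2, 4, 6], up to floating-point error
```

Since the gradient of a sum of squares is `2 * x`, the approximation should agree with `grad(norm)(x)` to within rounding error.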
39
44
 
40
- The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu),
41
- you can switch to GPU for maximum performance.
45
+ The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu) including
46
+ Chrome and iOS Safari, you can switch to GPU for better performance.
42
47
 
43
48
  ```js
44
- import { numpy as np, setDevice } from "@jax-js/jax";
49
+ import { defaultDevice, init, numpy as np } from "@jax-js/jax";
50
+
51
+ // Initialize the GPU backend.
52
+ await init("webgpu");
45
53
 
46
54
  // Change the default backend to GPU.
47
- setDevice("webgpu");
55
+ defaultDevice("webgpu");
48
56
 
49
57
  const x = np.ones([4096, 4096]);
50
58
  const y = np.dot(x.ref, x); // JIT-compiled into a matrix multiplication kernel
51
59
  ```
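Because WebGPU is only available on supported browsers, it can help to check for it before switching backends. The sketch below uses the standard `navigator.gpu.requestAdapter()` browser API; `hasWebGPU` is a hypothetical helper, not part of jax-js.

```js
// Hypothetical helper (not part of jax-js): resolves to true when a WebGPU
// adapter can be obtained. `nav` defaults to the global `navigator`, if any.
async function hasWebGPU(nav = globalThis.navigator) {
  // The WebGPU API is not exposed in this environment.
  if (!nav || !nav.gpu) return false;
  try {
    // Resolves to null when no suitable adapter is available.
    const adapter = await nav.gpu.requestAdapter();
    return adapter != null;
  } catch {
    return false;
  }
}
```

A page might then call `await init("webgpu")` and `defaultDevice("webgpu")` only when `await hasWebGPU()` is true, and otherwise stay on the default CPU backend.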
52
60
 
61
+ Most common JAX APIs are supported. See the [compatibility table](./FEATURES.md) for a full
62
+ breakdown of what features are available.
63
+
64
+ ### Web usage (CDN)
65
+
66
+ If you want to use `jax-js` in vanilla JavaScript (without a bundler), just import from a module
67
+ script tag. This is the easiest way to get started on a blank HTML page.
68
+
69
+ ```html
70
+ <script type="module">
71
+ import { numpy as np } from "https://esm.sh/@jax-js/jax";
72
+ </script>
73
+ ```
74
+
75
+ ### Performance
76
+
77
+ We haven't spent a ton of time optimizing yet, but performance is generally pretty good. `jit` is
77
+ very helpful for fusing operations together, and on the web it's a feature available only in jax-js.
79
+ The default kernel-tuning heuristics get about 3000 GFLOP/s for matrix multiplication on an M4 Pro
80
+ chip ([try it](https://www.ekzhang.com/jax-js/bench/matmul)).
81
+
82
+ For that example, it's around the same GFLOP/s as
83
+ [TensorFlow.js](https://github.com/tensorflow/tfjs) and
84
+ [ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
85
+ libraries of custom kernels (versus jax-js, which generates kernels with an ML compiler).
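To put the quoted throughput in perspective, here is a back-of-the-envelope estimate for the 4096 × 4096 matrix product from the GPU example. It assumes the common convention of counting each multiply-add as 2 floating-point operations, which may not be how the linked benchmark counts them. `matmulFlops` and `estimateMs` are hypothetical helpers.

```js
// Operations for an (m x k) by (k x n) matmul, counting each
// multiply-add as 2 operations (an assumed convention).
function matmulFlops(m, k, n) {
  return 2 * m * k * n;
}

// Milliseconds needed at a sustained throughput given in GFLOP/s.
function estimateMs(flops, gflops) {
  return (flops / (gflops * 1e9)) * 1e3;
}

const flops = matmulFlops(4096, 4096, 4096); // 2 * 4096^3 = 137,438,953,472
const ms = estimateMs(flops, 3000); // roughly 46 ms at 3000 GFLOP/s
```

Under these assumptions, one such matmul would take on the order of tens of milliseconds at the stated rate.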
86
+
87
+ ## Examples
88
+
89
+ If you make something cool with jax-js, don't be a stranger! We can feature it here.
90
+
91
+ - [In-browser REPL](https://www.ekzhang.com/jax-js/repl)
92
+ - [Interactive MNIST training](https://www.ekzhang.com/jax-js/mnist)
93
+ - [Matmul benchmark](https://www.ekzhang.com/jax-js/bench/matmul)
94
+ - [Conv2d benchmark](https://www.ekzhang.com/jax-js/bench/conv2d)
95
+ - [Mandelbrot set](https://www.ekzhang.com/jax-js/mandelbrot)
96
+
53
97
  ## Development
54
98
 
55
- Under construction.
99
+ _The following technical details are for contributing to jax-js and modifying its internals._
100
+
101
+ This repository is managed by [`pnpm`](https://pnpm.io/). You can compile and build all packages in
102
+ watch mode with:
56
103
 
57
104
  ```bash
58
105
  pnpm install
59
106
  pnpm run build:watch
107
+ ```
108
+
109
+ Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
60
110
 
61
- # Run tests
111
+ ```bash
62
112
  pnpm exec playwright install
63
113
  pnpm test
64
114
  ```
65
115
 
116
+ We are currently on an older version of Playwright that supports using WebGPU in headless mode;
117
+ newer versions skip the WebGPU tests.
118
+
119
+ To start a Vite dev server running the website, demos and REPL:
120
+
121
+ ```bash
122
+ pnpm -C website dev
123
+ ```
124
+
125
+ ## Future work / help wanted
126
+
127
+ Contributions are welcomed in the following areas:
128
+
129
+ - Adding support for more JAX functions and operations (see the [compatibility table](./FEATURES.md)).
130
+ - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
131
+ and multithreading.
132
+ - Helping the JIT compiler to fuse operations in more cases.
133
+ - Adding a WebGL runtime for older browsers that don't support WebGPU.
134
+ - Making a fast transformer inference engine, comparing against onnxruntime-web.
135
+ - Ergonomics and API improvements.
136
+
66
137
  ## Next on Eric's mind
67
138
 
68
139
  - Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
69
- - Fix jit-of-grad returning very incorrect result
70
- - Improve perf of MNIST neural network
71
- - Optimize conv2d further (maybe blocks -> local dims?)
72
- - Add fused epilogue to JIT
73
- - Reduce kernel overhead of constants / inline expressions
74
- - Investigate why jax-js Matmul is 2x slower on Safari TP than unroll kernel
75
- - How many threads to create per workgroup, depends on hardware
140
+ - Performance
141
+ - Improve perf of MobileCLIP neural network
142
+ - Add fused epilogue to JIT
143
+ - Fix fusion of activation functions with branches like tanh
144
+ - Reduce kernel overhead of constants / inline expressions
145
+ - How many threads to create per workgroup, depends on hardware
76
146
 
77
147
  ## Milestones
78
148
 
@@ -91,9 +161,9 @@ pnpm test
91
161
  - [x] Other dtypes like int32 and bool
92
162
  - [x] `jit()` support via Jaxprs and kernel fusion
93
163
  - [x] We figure out the `dispose()` / refcount / linear types stuff
94
- - [ ] `dispose()` for saved "const" tracers in Jaxprs
95
- - [ ] Garbage collection for JIT programs
96
- - [ ] Memory scheduling, buffer allocation (can be tricky)
164
+ - [x] `dispose()` for saved "const" tracers in Jaxprs
165
+ - [x] Garbage collection for JIT programs
166
+ - [x] Debug grad-grad-jit test producing a UseAfterFreeError
97
167
  - [ ] Demos: Navier-Stokes, neural networks, statistics
98
168
  - [x] Features for neural networks
99
169
  - [x] Convolution
@@ -103,6 +173,10 @@ pnpm test
103
173
  - [x] Better memory allocation that frees buffers
104
174
  - [ ] SIMD support for Wasm backend
105
175
  - [ ] Async / multithreading Wasm support
106
- - [ ] Device switching with `.to()` between webgpu/cpu/wasm
107
- - [ ] numpy/jax API compatibility table
108
- - [ ] Import tfjs models
176
+ - [ ] Full support of weak types and committed devices
177
+ - [x] High-level ops have automatic type promotion
178
+ - [x] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
179
+ - [ ] Committed devices -
180
+ [ref](https://docs.jax.dev/en/latest/sharded-computation.html#sharded-data-placement)
181
+ - [ ] Device switching with `device_put()` between webgpu/cpu/wasm
182
+ - [x] numpy/jax API compatibility table