@jax-js/jax 0.0.4 → 0.1.0

package/README.md CHANGED
@@ -1,9 +1,14 @@
- # jax-js: JAX in pure JavaScript
+ <h1 align="center">jax-js: JAX in pure JavaScript</h1>
 
- [Website](https://www.ekzhang.com/jax-js/) | [API Reference](https://www.ekzhang.com/jax-js/docs/)
+ <p align="center"><strong>
+ <a href="https://jax-js.com">Website</a> |
+ <a href="https://jax-js.com/docs/">API Reference</a> |
+ <a href="./FEATURES.md">Compatibility Table</a>
+ </strong></p>
 
- This is a machine learning framework for the browser. It aims to bring JAX-style, high-performance
- CPU and GPU kernels to JavaScript, so you can run numerical applications on the web.
+ **jax-js** is a machine learning framework for the browser. It aims to bring
+ [JAX](https://jax.dev)-style, high-performance CPU and GPU kernels to JavaScript, so you can run
+ numerical applications on the web.
 
  ```bash
  npm i @jax-js/jax
@@ -12,49 +17,309 @@ npm i @jax-js/jax
  Under the hood, it translates array operations into a compiler representation, then synthesizes
  kernels in WebAssembly and WebGPU.
 
- ## Quickstart
+ The library is written from scratch, with zero external dependencies. It maintains close API
+ compatibility with NumPy/JAX. Because everything runs client-side, jax-js is likely the most
+ portable GPU ML framework: it runs anywhere a browser can run.
 
- You can use `jax-js` as an array API, just like NumPy.
+ ## Quickstart
 
  ```js
  import { numpy as np } from "@jax-js/jax";
 
- // Array operations, compatible with NumPy.
+ // Array operations, compatible with JAX/NumPy.
  const x = np.array([1, 2, 3]);
  const y = x.mul(4); // [4, 8, 12]
  ```
 
- It also lets you take derivatives like in JAX.
+ ### Web usage (CDN)
 
- ```js
+ In vanilla JavaScript (without a bundler), just import from a module script tag. This is the easiest
+ way to get started on a blank HTML page.
+
+ ```html
+ <script type="module">
+   import { numpy as np } from "https://esm.sh/@jax-js/jax";
+ </script>
+ ```
+
+ ## Tutorial
+
+ Programming in `jax-js` looks [very similar to JAX](https://docs.jax.dev/en/latest/jax-101.html),
+ just in JavaScript.
+
+ ### Arrays
+
+ Create an array with `np.array()`:
+
+ ```ts
+ import { numpy as np } from "@jax-js/jax";
+
+ const ar = np.array([1, 2, 3]);
+ ```
+
+ By default, this is a float32 array, but you can also specify a dtype explicitly:
+
+ ```ts
+ const ar = np.array([1, 2, 3], { dtype: np.float32 });
+ ```
+
+ For more efficient construction, you can create an array from a JS `TypedArray` buffer:
+
+ ```ts
+ const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
+ const ar = np.array(buf).reshape([2, 3]);
+ ```
+
+ Once you're done with it, you can unwrap a `jax.Array` back into JavaScript. This will also apply
+ any pending operations or lazy updates:
+
+ ```ts
+ // 1) Returns a possibly nested JavaScript array.
+ ar.js();
+ await ar.jsAsync(); // Faster, non-blocking
+
+ // 2) Returns a flat TypedArray data buffer.
+ ar.dataSync();
+ await ar.data(); // Fastest, non-blocking
+ ```
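As a mental model only (this is not jax-js's implementation), the relationship between the flat buffer returned by `dataSync()` and the nested array returned by `js()` is row-major indexing over the array's shape. The helper `toNested` below is hypothetical:

```ts
// Hypothetical helper (not part of jax-js): rebuilds a nested JS array from a
// flat row-major buffer, like going from `ar.dataSync()` to `ar.js()`.
type Nested = number | Nested[];

function toNested(data: ArrayLike<number>, shape: number[], offset = 0): Nested {
  // A zero-dimensional shape is a single scalar element.
  if (shape.length === 0) return data[offset];
  const [n, ...rest] = shape;
  // Number of flat elements spanned by one step along the leading axis.
  const stride = rest.reduce((a, b) => a * b, 1);
  const out: Nested[] = [];
  for (let i = 0; i < n; i++) {
    out.push(toNested(data, rest, offset + i * stride));
  }
  return out;
}

const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
const nested = toNested(buf, [2, 3]); // => [[10, 20, 30], [100, 200, 300]]
```

This is why a flat buffer plus a shape is enough to describe any Array: the nested form can always be recovered from them.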
+
+ Arrays can have mathematical operations applied to them. For example:
+
+ ```ts
+ import { numpy as np, scipySpecial as special } from "@jax-js/jax";
+
+ const x = np.arange(100).astype(np.float32); // float32 array [0, 1, ..., 99]
+
+ const y1 = x.ref.add(x.ref); // x + x
+ const y2 = np.sin(x.ref); // sin(x)
+ const y3 = np.tanh(x.ref).mul(5); // 5 * tanh(x)
+ const y4 = special.erfc(x); // erfc(x); this last use consumes `x` (no `.ref`)
+ ```
+
+ Notice that the code above uses `x.ref`. This is because of jax-js's memory model: it uses
+ reference-counted _ownership_ to track when an Array's memory can be freed. More on this below.
+
+ ### Reference counting
+
+ Big Arrays take up a lot of memory. Python ML libraries override the `__del__()` method to free
+ memory, but JavaScript has no reliable way to run object destructors: cleanup callbacks registered
+ with [`FinalizationRegistry`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry)
+ are not guaranteed to run promptly, or at all. This means that you have to track references
+ manually. jax-js tries to make this as ergonomic as possible, so you don't accidentally leak memory
+ in a loop.
+
+ Every `jax.Array` has a reference count, which follows these rules:
+
+ - Whenever you create an Array, its reference count starts at `1`.
+ - When an Array's reference count reaches `0`, it is freed and can no longer be used.
+ - Given an Array `a`:
+   - Accessing `a.ref` returns `a` and changes its reference count by `+1`.
+   - Passing `a` into any function as an argument changes its reference count by `-1`.
+   - Calling `a.dispose()` also changes its reference count by `-1`.
+
+ In other words, every function in jax-js _takes ownership_ of its arguments. When you pass an Array
+ as an argument, pass it directly if you are done with it, or pass `.ref` if you'd like to use it
+ again later.
+
+ **You must follow these rules in your own functions as well!** Combinators such as `jvp`, `grad`,
+ and `jit` assume that your functions follow these conventions for passing arguments, and they
+ respect the conventions themselves.
+
+ ```ts
+ // Bad: Uses `x` twice, decrementing its reference count twice.
+ function foo_bad(x: np.Array, y: np.Array) {
+   return x.add(x.mul(y));
+ }
+
+ // Good: The first usage of `x` is `x.ref`, adding +1 to its reference count.
+ function foo_good(x: np.Array, y: np.Array) {
+   return x.ref.add(x.mul(y));
+ }
+ ```
+
+ Here's another example:
+
+ ```ts
+ // Bad: Doesn't consume `x` in the `if`-branch.
+ function bar_bad(x: np.Array, skip: boolean) {
+   if (skip) return np.zeros(x.shape);
+   return x;
+ }
+
+ // Good: Consumes `x` exactly once in each branch.
+ function bar_good(x: np.Array, skip: boolean) {
+   if (skip) {
+     const ret = np.zeros(x.shape);
+     x.dispose();
+     return ret;
+   }
+   return x;
+ }
+ ```
+
+ You can assume that every function in jax-js takes ownership properly, except for a few rare
+ exceptions, which are documented.
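To make these rules concrete, here is a toy model in plain TypeScript. `RcBox` and `add` are hypothetical stand-ins for `jax.Array` and a jax-js function; the real library manages device buffers rather than single numbers, so treat this only as an illustration of the counting rules.

```ts
// Toy model of the reference-counting rules. `RcBox` is hypothetical and only
// illustrates the bookkeeping; it is not how jax-js implements `Array`.
class RcBox {
  readonly value: number;
  private rc = 1; // A new value starts with a reference count of 1.

  constructor(value: number) {
    this.value = value;
  }

  get refCount(): number {
    return this.rc;
  }

  // `.ref` returns the same object and increments its count by 1.
  get ref(): RcBox {
    this.assertLive();
    this.rc++;
    return this;
  }

  // `dispose()` (or being consumed by a function) decrements the count by 1.
  dispose(): void {
    this.assertLive();
    this.rc--;
  }

  private assertLive(): void {
    if (this.rc <= 0) throw new Error("use after free");
  }
}

// Follows the ownership convention: consumes each argument exactly once.
function add(a: RcBox, b: RcBox): RcBox {
  const out = new RcBox(a.value + b.value);
  a.dispose();
  b.dispose();
  return out;
}

const x = new RcBox(3);
const y = add(x.ref, x); // count of x: 1 -> 2 (via .ref) -> 0 (consumed twice)
console.log(y.value, x.refCount); // 6 0

const z = new RcBox(5);
try {
  add(z, z); // Like `foo_bad`: consumes `z` twice without a `.ref`.
} catch (e) {
  console.log((e as Error).message); // use after free
}
```

The second call fails because the count of `z` is already zero when it is consumed the second time, which is exactly the mistake `foo_bad` makes.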
+
+ ### grad(), vmap() and jit()
+
+ JAX's signature composable transformations are also supported in jax-js. Here is a simple example
+ of using `grad` and `vmap` to compute the derivative of a function:
+
+ ```ts
+ import { numpy as np, grad, vmap } from "@jax-js/jax";
+
+ const x = np.linspace(-10, 10, 1000);
+
+ const y1 = vmap(grad(np.sin))(x.ref); // d/dx sin(x) = cos(x)
+ const y2 = np.cos(x);
+
+ np.allclose(y1, y2); // => true
+ ```
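`grad` computes exact derivatives by automatic differentiation. As an independent sanity check of the identity in that example, here is a plain TypeScript sketch (no jax-js) that approximates the derivative of `Math.sin` with central finite differences and compares it against `Math.cos`. The helpers are hypothetical; `allclose` only mimics NumPy's tolerance rule `|a - b| <= atol + rtol * |b|`.

```ts
// Plain-TypeScript sanity check (not jax-js): d/dx sin(x) ≈ cos(x),
// using central differences instead of automatic differentiation.
function linspace(start: number, stop: number, num: number): number[] {
  const step = (stop - start) / (num - 1);
  return Array.from({ length: num }, (_, i) => start + i * step);
}

// Central difference: (f(x + h) - f(x - h)) / (2h), with O(h^2) truncation error.
function centralDiff(f: (x: number) => number, x: number, h = 1e-5): number {
  return (f(x + h) - f(x - h)) / (2 * h);
}

// Same tolerance rule as NumPy's allclose: |a - b| <= atol + rtol * |b|.
function allclose(a: number[], b: number[], rtol = 1e-5, atol = 1e-8): boolean {
  return (
    a.length === b.length &&
    a.every((ai, i) => Math.abs(ai - b[i]) <= atol + rtol * Math.abs(b[i]))
  );
}

const xs = linspace(-10, 10, 1000);
const dy = xs.map((x) => centralDiff(Math.sin, x)); // elementwise, like vmap
const expected = xs.map((x) => Math.cos(x));
console.log(allclose(dy, expected)); // true
```

Finite differences only approximate the derivative and lose accuracy as `h` shrinks, which is one reason frameworks like JAX and jax-js use automatic differentiation instead.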
+
+ The `jit` function is especially useful for long sequences of primitives on GPU, since it fuses
+ operations together into a single kernel dispatch. This
+ [improves memory bandwidth usage](https://substack.com/home/post/p-163548742) on hardware
+ accelerators, where memory bandwidth, rather than raw FLOPs, is typically the bottleneck. For
+ instance:
+
+ ```ts
+ import { jit, numpy as np } from "@jax-js/jax";
+
+ export const hypot = jit(function hypot(x1: np.Array, x2: np.Array) {
+   return np.sqrt(np.square(x1).add(np.square(x2)));
+ });
+ ```
+
+ Without JIT, the `hypot()` function would require four kernel dispatches: two squares, one add,
+ and one sqrt. JIT fuses these into a single kernel that does all four steps at once.
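The sketch below illustrates the idea of fusion on the CPU in plain TypeScript; it is not jax-js's compiler, and real GPU kernels look different. Both function names are hypothetical. The unfused version runs four loops (one per primitive) and materializes three intermediate buffers, touching about 9n elements of memory for n inputs; the fused version reads each input once and writes the output once, about 3n.

```ts
// Toy illustration of kernel fusion on the CPU (not jax-js's compiler).
// Computes hypot(x1, x2) = sqrt(x1^2 + x2^2) elementwise.

// Unfused: one loop ("kernel") per primitive, with intermediate buffers.
function hypotUnfused(x1: Float32Array, x2: Float32Array): Float32Array {
  const n = x1.length;
  const a = new Float32Array(n);
  for (let i = 0; i < n; i++) a[i] = x1[i] * x1[i]; // kernel 1: square
  const b = new Float32Array(n);
  for (let i = 0; i < n; i++) b[i] = x2[i] * x2[i]; // kernel 2: square
  const c = new Float32Array(n);
  for (let i = 0; i < n; i++) c[i] = a[i] + b[i]; // kernel 3: add
  const out = new Float32Array(n);
  for (let i = 0; i < n; i++) out[i] = Math.sqrt(c[i]); // kernel 4: sqrt
  return out;
}

// Fused: a single loop, no intermediate buffers.
function hypotFused(x1: Float32Array, x2: Float32Array): Float32Array {
  const out = new Float32Array(x1.length);
  for (let i = 0; i < x1.length; i++) {
    out[i] = Math.sqrt(x1[i] * x1[i] + x2[i] * x2[i]);
  }
  return out;
}

const x1 = new Float32Array([3, 5, 1.5]);
const x2 = new Float32Array([4, 12, 2]);
console.log(Array.from(hypotFused(x1, x2))); // [5, 13, 2.5]
```

One subtle difference: the fused loop keeps intermediates in 64-bit JavaScript numbers, while the unfused version rounds them to float32 between steps, so the two can differ in the last bit for some inputs.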
+
+ All functional transformations can take typed `JsTree`s of inputs and outputs. These are similar to
+ [JAX's pytrees](https://docs.jax.dev/en/latest/pytrees.html): a `JsTree` is a structure of nested
+ JavaScript objects and arrays. For instance:
+
+ ```ts
  import { grad, numpy as np } from "@jax-js/jax";
 
- // Calculate derivatives with reverse-mode AD.
- const norm = (a) => a.ref.mul(a).sum();
+ type Params = {
+   foo: np.Array;
+   bar: np.Array[];
+ };
+
+ function getSums(p: Params) {
+   const fooSum = p.foo.sum();
+   const barSum = p.bar.map((x) => x.sum()).reduce(np.add);
+   return fooSum.add(barSum);
+ }
+
+ grad(getSums)({
+   foo: np.array([1, 2, 3]),
+   bar: [np.array([10]), np.array([11, 12])],
+ });
+ // => { foo: [1, 1, 1], bar: [[1], [1, 1]] }
+ ```
 
- const x = np.array([1, 2, 3]);
- const xnorm = norm(x.ref); // 1^2 + 2^2 + 3^2 = 14
- const xgrad = grad(norm)(x); // [2, 4, 6]
+ Note that you need to use `type` alias syntax rather than `interface` to define fine-grained
+ `JsTree` types.
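To illustrate what a pytree-style structure means, here is a hypothetical sketch that flattens a tree into its leaves and rebuilds a tree of the same shape, roughly how a transformation can map over every leaf while preserving structure (as in the `grad(getSums)` output above). Leaves are plain numbers instead of Arrays, and the sketch sorts object keys the way JAX does for dicts; jax-js's actual `JsTree` handling may differ.

```ts
// Hypothetical pytree-style helpers (not jax-js's JsTree implementation).
// Leaves are numbers here; in jax-js they would be Arrays.
type Tree = number | Tree[] | { [key: string]: Tree };

// Collects leaves depth-first; object keys are visited in sorted order.
function flatten(tree: Tree, out: number[] = []): number[] {
  if (typeof tree === "number") {
    out.push(tree);
  } else if (Array.isArray(tree)) {
    for (const child of tree) flatten(child, out);
  } else {
    for (const key of Object.keys(tree).sort()) flatten(tree[key], out);
  }
  return out;
}

// Rebuilds a tree with the same structure as `structure`, taking leaves in order.
function unflatten(structure: Tree, leaves: number[]): Tree {
  let i = 0;
  const build = (t: Tree): Tree => {
    if (typeof t === "number") return leaves[i++];
    if (Array.isArray(t)) return t.map(build);
    const out: { [key: string]: Tree } = {};
    for (const key of Object.keys(t).sort()) out[key] = build(t[key]);
    return out;
  };
  return build(structure);
}

type Params = { foo: number[]; bar: number[][] };
const params: Params = { foo: [1, 2, 3], bar: [[10], [11, 12]] };

const leaves = flatten(params); // => [10, 11, 12, 1, 2, 3] (key "bar" sorts first)
const ones = unflatten(params, leaves.map(() => 1));
// => { bar: [[1], [1, 1]], foo: [1, 1, 1] }
```

This also hints at a plausible reason for the `type` versus `interface` note: TypeScript gives object types declared with a `type` alias an implicit index signature, so `Params` above is assignable to the indexed `Tree` type, whereas the same shape declared as an `interface` would not be.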
+
+ ### Devices
+
+ Similar to JAX, jax-js has a concept of "devices." A device is a backend that stores Arrays in
+ memory and determines how compiled operations are executed on them.
+
+ There are currently 3 devices in jax-js:
+
+ - `cpu`: Slow, mostly for debugging purposes.
+ - `wasm`: [WebAssembly](https://webassembly.org/), currently single-threaded and blocking.
+ - `webgpu`: [WebGPU](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API), available on
+   [supported browsers](https://caniuse.com/webgpu) (Chrome, Firefox, Safari, iOS).
+
+ The default device is `wasm`, but you can change this at startup time:
+
+ ```ts
+ import { defaultDevice, init } from "@jax-js/jax";
+
+ const devices = await init(); // Starts all available backends.
+
+ if (devices.includes("webgpu")) {
+   defaultDevice("webgpu");
+ } else {
+   console.warn("WebGPU is not supported, falling back to Wasm.");
+ }
  ```
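If you want to try several backends in order rather than only checking for `webgpu`, a small helper can pick the first available device from a preference list. `pickDevice` is hypothetical; it only assumes, as in the example above, that `init()` resolves to an array of device names, and you would pass its result to `defaultDevice()`.

```ts
// Hypothetical helper (not part of jax-js): picks the first preferred device
// that appears in the list of available devices, e.g. the result of `init()`.
type Device = "cpu" | "wasm" | "webgpu";

function pickDevice(
  available: readonly string[],
  preference: readonly Device[] = ["webgpu", "wasm", "cpu"],
): Device {
  for (const device of preference) {
    if (available.includes(device)) return device;
  }
  throw new Error(`None of [${preference.join(", ")}] is available`);
}

console.log(pickDevice(["cpu", "wasm"])); // wasm
```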
 
- The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu), you can
- switch to GPU for better performance.
+ You can also place individual arrays on specific devices:
 
- ```js
- import { defaultDevice, numpy as np } from "@jax-js/jax";
+ ```ts
+ import { devicePut, numpy as np } from "@jax-js/jax";
 
- // Change the default backend to GPU.
- defaultDevice("webgpu");
+ const ar = np.array([1, 2, 3]); // Starts with device="wasm"
+ await devicePut(ar, "webgpu"); // Now device="webgpu"
+ ```
+
+ ### Helper libraries
+
+ There are other libraries in the `@jax-js` namespace that work with jax-js or can be used on their
+ own in other projects.
+
+ **`@jax-js/optax`** provides implementations of optimizers like Adam and SGD.
+
+ ```ts
+ import { grad, numpy as np } from "@jax-js/jax";
+ import { adam, applyUpdates, squaredError } from "@jax-js/optax";
+
+ let params = np.array([1.0, 2.0, 3.0]);
 
- const x = np.ones([4096, 4096]);
- const y = np.dot(x.ref, x); // JIT-compiled into a matrix multiplication kernel
+ const solver = adam(1e-3);
+ let optState = solver.init(params.ref);
+ let updates: np.Array;
+
+ const f = (x: np.Array) => squaredError(x, np.ones([3])).sum();
+
+ for (let i = 0; i < 100; i++) {
+   const paramsGrad = grad(f)(params.ref);
+   [updates, optState] = solver.update(paramsGrad, optState);
+   params = applyUpdates(params, updates);
+ }
  ```
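For intuition about what an optimizer like `adam(1e-3)` computes at each step, here is the standard Adam update rule written out over plain `number[]` values. This is a sketch, not `@jax-js/optax`'s implementation; the hyperparameter defaults (`b1 = 0.9`, `b2 = 0.999`, `eps = 1e-8`) are the ones commonly used by Python's Optax and are assumed here. The objective mirrors the example above: minimize the sum of `(x - 1)^2`, whose gradient is `2 * (x - 1)`.

```ts
// Sketch of the standard Adam update on plain arrays (not @jax-js/optax's code).
type AdamState = { m: number[]; v: number[]; t: number };

function adamInit(params: number[]): AdamState {
  return { m: params.map(() => 0), v: params.map(() => 0), t: 0 };
}

function adamStep(
  params: number[],
  grads: number[],
  state: AdamState,
  lr = 1e-3,
  b1 = 0.9,
  b2 = 0.999,
  eps = 1e-8,
): [number[], AdamState] {
  const t = state.t + 1;
  // Exponential moving averages of the gradient and the squared gradient.
  const m = state.m.map((mi, i) => b1 * mi + (1 - b1) * grads[i]);
  const v = state.v.map((vi, i) => b2 * vi + (1 - b2) * grads[i] * grads[i]);
  const next = params.map((p, i) => {
    // Bias correction compensates for initializing m and v at zero.
    const mHat = m[i] / (1 - Math.pow(b1, t));
    const vHat = v[i] / (1 - Math.pow(b2, t));
    return p - (lr * mHat) / (Math.sqrt(vHat) + eps);
  });
  return [next, { m, v, t }];
}

// Minimize sum((x - 1)^2), as in the optax example; its gradient is 2 * (x - 1).
let params = [1.0, 2.0, 3.0];
let state = adamInit(params);
for (let i = 0; i < 100; i++) {
  const grads = params.map((x) => 2 * (x - 1));
  [params, state] = adamStep(params, grads, state);
}
```

Because Adam normalizes by the running gradient magnitude, each step moves a parameter by roughly `lr` when its gradient keeps the same sign, so after 100 steps at `1e-3` the parameters have moved only about 0.1 toward 1.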
 
- Most common JAX APIs are supported. See the [compatibility table](./FEATURES.md) for a full
- breakdown of what features are available.
+ **`@jax-js/loaders`** loads tensors from formats like Safetensors, includes a fast and compliant
+ implementation of BPE tokenization, and caches HTTP requests for large assets like model weights in
+ OPFS (the Origin Private File System).
+
+ ```ts
+ import { tokenizers } from "@jax-js/loaders";
+
+ const enc = await tokenizers.getBpe("clip");
+ const tokens = enc.encode("Hello, world!"); // => [ 49406, 3306, 267, 1002, ... ]
+ ```
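For readers new to BPE, the core of encoding is a greedy merge loop: split a word into symbols, then repeatedly merge the adjacent pair with the highest-priority (lowest-rank) entry in a learned merge table. The sketch below shows only that loop and is not `@jax-js/loaders`' implementation; the merge table is made up, and real tokenizers such as CLIP's add byte-level preprocessing, special tokens, and a vocabulary lookup that are omitted here.

```ts
// Toy BPE merge loop (not @jax-js/loaders' implementation).
// `merges` is ordered by priority: earlier pairs are merged first.
function bpe(word: string, merges: [string, string][]): string[] {
  const rank = new Map<string, number>();
  merges.forEach(([a, b], i) => rank.set(`${a} ${b}`, i));

  let parts = Array.from(word); // Start from individual characters.
  while (true) {
    // Find the adjacent pair with the lowest merge rank.
    let best = -1;
    let bestRank = Infinity;
    for (let i = 0; i + 1 < parts.length; i++) {
      const r = rank.get(`${parts[i]} ${parts[i + 1]}`);
      if (r !== undefined && r < bestRank) {
        bestRank = r;
        best = i;
      }
    }
    if (best < 0) return parts; // No mergeable pair remains.
    parts = [...parts.slice(0, best), parts[best] + parts[best + 1], ...parts.slice(best + 2)];
  }
}

// Made-up merge table for illustration.
const merges: [string, string][] = [["l", "l"], ["h", "e"], ["he", "ll"], ["hell", "o"]];
console.log(bpe("hello", merges)); // ["hello"]
console.log(bpe("yell", merges)); // ["y", "e", "ll"]
```

In a real tokenizer, each resulting symbol would then be mapped to an integer ID through the vocabulary, which is what produces token arrays like the one in the example above.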
+
+ ### Performance
+
+ We haven't spent a ton of time optimizing yet, but performance is generally pretty good. `jit` is
+ very helpful for fusing operations together, and jax-js is the only web ML framework that offers it.
+ The default kernel-tuning heuristics reach about 3000 GFLOP/s for matrix multiplication on an M4 Pro
+ chip ([try it](https://jax-js.com/bench/matmul)).
+
+ On that benchmark, jax-js reaches around the same GFLOP/s as
+ [TensorFlow.js](https://github.com/tensorflow/tfjs) and
+ [ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
+ libraries of custom kernels (whereas jax-js generates its kernels with an ML compiler).
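To put a figure like 3000 GFLOP/s in perspective: multiplying an m×k matrix by a k×n matrix takes m·n·k multiply-adds, conventionally counted as 2·m·n·k floating-point operations. The sketch below does that arithmetic; the 4096×4096 size is only an example, since the benchmark's exact sizes are not stated here.

```ts
// Back-of-the-envelope matmul arithmetic: an (m x k) @ (k x n) product
// performs m * n * k multiply-adds, counted as 2 * m * n * k FLOPs.
function matmulFlops(m: number, n: number, k: number): number {
  return 2 * m * n * k;
}

// Achieved throughput in GFLOP/s, given a measured wall-clock time in seconds.
function gflops(m: number, n: number, k: number, seconds: number): number {
  return matmulFlops(m, n, k) / seconds / 1e9;
}

const flops = matmulFlops(4096, 4096, 4096); // 137438953472 (about 137.4 GFLOP)
const secondsAt3000 = flops / 3000e9; // about 0.0458 s, i.e. roughly 46 ms
console.log(gflops(4096, 4096, 4096, secondsAt3000)); // ≈ 3000
```

So at the quoted throughput, a single 4096×4096 matmul would take on the order of 46 ms.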
+
+ ### API Reference
+
+ That's all for this short tutorial. Please see the generated
+ [API reference](https://jax-js.com/docs) for detailed documentation.
+
+ ## Examples
+
+ If you make something cool with jax-js, don't be a stranger! We can feature it here.
+
+ - [Training neural networks on MNIST](https://jax-js.com/mnist)
+ - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
+ - [In-browser REPL](https://jax-js.com/repl)
+ - [Matmul benchmark](https://jax-js.com/bench/matmul)
+ - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
+ - [Mandelbrot set](https://jax-js.com/mandelbrot)
 
  ## Development
 
+ _The following technical details are for contributing to jax-js and modifying its internals._
+
  This repository is managed by [`pnpm`](https://pnpm.io/). You can compile and build all packages in
  watch mode with:
 
@@ -70,8 +335,8 @@ pnpm exec playwright install
  pnpm test
  ```
 
- _We are currently on an older version of Playwright that supports using WebGPU in headless mode;
- newer versions seem to skip the WebGPU tests._
+ We are currently on an older version of Playwright that supports using WebGPU in headless mode;
+ newer versions skip the WebGPU tests.
 
  To start a Vite dev server running the website, demos and REPL:
 
@@ -79,61 +344,14 @@ To start a Vite dev server running the website, demos and REPL:
  pnpm -C website dev
  ```
 
- ## Next on Eric's mind
-
- - Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
- - Performance
- - Improve perf of MNIST neural network
- - Optimize conv2d further (maybe blocks -> local dims?)
- - Add fused epilogue to JIT
- - Reduce kernel overhead of constants / inline expressions
- - Investigate why jax-js Matmul is 2x slower on Safari TP than unroll kernel
- - How many threads to create per workgroup, depends on hardware
-
- ## Milestones
-
- - [x] It works!
- - [x] Demos: Browser REPL / editor
- - [x] First custom kernel
- - [x] Custom WebGPU backend, removing tfjs dependency
- - [x] Low-level operations
- - [x] Create `class Array {}` wrappers
- - [x] Reduction operations
- - [ ] Kernel tuning (see `tuner.ts`)
- - [x] "Upcast" optimizations (compute a tile per thread, e.g., matmul)
- - [x] "Unroll" optimizations (multiple loop iters per thread, e.g., matmul)
- - [ ] "Group" optimizations (multiple threads per value, e.g., matvec)
- - [ ] Blocks respect local dimensions
- - [x] Other dtypes like int32 and bool
- - [x] `jit()` support via Jaxprs and kernel fusion
- - [x] We figure out the `dispose()` / refcount / linear types stuff
- - [x] `dispose()` for saved "const" tracers in Jaxprs
- - [x] Garbage collection for JIT programs
- - [x] Debug grad-grad-jit test producing a UseAfterFreeError
- - [ ] Demos: Navier-Stokes, neural networks, statistics
- - [x] Features for neural networks
- - [x] Convolution
- - [x] Random and initializers
- - [x] Optimizers (optax package?)
- - [x] Wasm backend (needs malloc)
- - [x] Better memory allocation that frees buffers
- - [ ] SIMD support for Wasm backend
- - [ ] Async / multithreading Wasm support
- - [ ] Full support of weak types and committed devices
- - [ ] High-level ops have automatic type promotion
- - [ ] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
- - [ ] Committed devices -
- [ref](https://docs.jax.dev/en/latest/sharded-computation.html#sharded-data-placement)
- - [ ] Device switching with `device_put()` between webgpu/cpu/wasm
- - [x] numpy/jax API compatibility table
-
  ## Future work / help wanted
 
- Contributions are welcomed in the following areas:
+ Contributions are welcome, especially in the following areas:
 
  - Adding support for more JAX functions and operations; see the [compatibility table](./FEATURES.md).
- - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, using SIMD and
- multithreading.
+ - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
+   and multithreading. (Even single-threaded Wasm could be ~20x faster.)
+ - Helping the JIT compiler fuse operations in more cases, such as `tanh` branches, and adding fused
+   epilogues to reductions.
  - Adding a WebGL runtime for older browsers that don't support WebGPU.
  - Making a fast transformer inference engine and comparing it against onnxruntime-web.
- - Ergonomics and API improvements.