@jax-js/jax 0.0.5 → 0.1.1

package/README.md CHANGED
@@ -1,10 +1,14 @@
1
- # jax-js: JAX in pure JavaScript
1
+ <h1 align="center">jax-js: JAX in pure JavaScript</h1>
2
2
 
3
- [Website](https://www.ekzhang.com/jax-js/) | [API Reference](https://www.ekzhang.com/jax-js/docs/)
3
+ <p align="center"><strong>
4
+ <a href="https://jax-js.com">Website</a> |
5
+ <a href="https://jax-js.com/docs/">API Reference</a> |
6
+ <a href="./FEATURES.md">Compatibility Table</a>
7
+ </strong></p>
4
8
 
5
- **jax-js** is a machine learning framework for the browser. It aims to bring JAX-style,
6
- high-performance CPU and GPU kernels to JavaScript, so you can run numerical applications on the
7
- web.
9
+ **jax-js** is a machine learning framework for the browser. It aims to bring
10
+ [JAX](https://jax.dev)-style, high-performance CPU and GPU kernels to JavaScript, so you can run
11
+ numerical applications on the web.
8
12
 
9
13
  ```bash
10
14
  npm i @jax-js/jax
@@ -19,57 +23,269 @@ GPU ML framework, since it runs anywhere a browser can run.
19
23
 
20
24
  ## Quickstart
21
25
 
22
- You can use `jax-js` as an array API, just like NumPy.
23
-
24
26
  ```js
25
27
  import { numpy as np } from "@jax-js/jax";
26
28
 
27
- // Array operations, compatible with NumPy.
29
+ // Array operations, compatible with JAX/NumPy.
28
30
  const x = np.array([1, 2, 3]);
29
31
  const y = x.mul(4); // [4, 8, 12]
30
32
  ```
31
33
 
32
- It also lets you take derivatives with `grad` like in JAX (as well as `vmap`, `jit`).
34
+ ### Web usage (CDN)
33
35
 
34
- ```js
35
- import { grad, numpy as np } from "@jax-js/jax";
36
+ In vanilla JavaScript (without a bundler), just import from a module script tag. This is the easiest
37
+ way to get started on a blank HTML page.
36
38
 
37
- // Calculate derivatives with reverse-mode AD.
38
- const norm = (a) => a.ref.mul(a).sum();
39
+ ```html
40
+ <script type="module">
41
+   import { numpy as np } from "https://esm.sh/@jax-js/jax";
42
+ </script>
43
+ ```
39
44
 
40
- const x = np.array([1, 2, 3]);
41
- const xnorm = norm(x.ref); // 1^2 + 2^2 + 3^2 = 14
42
- const xgrad = grad(norm)(x); // [2, 4, 6]
45
+ ## Tutorial
46
+
47
+ Programming in `jax-js` looks [very similar to JAX](https://docs.jax.dev/en/latest/jax-101.html),
48
+ just in JavaScript.
49
+
50
+ ### Arrays
51
+
52
+ Create an array with `np.array()`:
53
+
54
+ ```ts
55
+ import { numpy as np } from "@jax-js/jax";
56
+
57
+ const ar = np.array([1, 2, 3]);
43
58
  ```
44
59
 
45
- The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu) including
46
- Chrome and iOS Safari, you can switch to GPU for better performance.
60
+ By default, this is a float32 array, but you can also specify a dtype explicitly:
47
61
 
48
- ```js
49
- import { defaultDevice, init, numpy as np } from "@jax-js/jax";
62
+ ```ts
63
+ const ar = np.array([1, 2, 3], { dtype: np.float32 });
64
+ ```
65
+
66
+ For more efficient construction, you can create an array from a JS `TypedArray` buffer:
67
+
68
+ ```ts
69
+ const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
70
+ const ar = np.array(buf).reshape([2, 3]);
71
+ ```
50
72
 
51
- // Initialize the GPU backend.
52
- await init("webgpu");
73
+ Once you're done with it, you can unwrap a `jax.Array` back into JavaScript. This will also apply
74
+ any pending operations or lazy updates:
53
75
 
54
- // Change the default backend to GPU.
55
- defaultDevice("webgpu");
76
+ ```ts
77
+ // 1) Returns a possibly nested JavaScript array.
78
+ ar.js();
79
+ await ar.jsAsync(); // Faster, non-blocking
56
80
 
57
- const x = np.ones([4096, 4096]);
58
- const y = np.dot(x.ref, x); // JIT-compiled into a matrix multiplication kernel
81
+ // 2) Returns a flat TypedArray data buffer.
82
+ ar.dataSync();
83
+ await ar.data(); // Fastest, non-blocking
59
84
  ```
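To make the difference between the nested result of `ar.js()` and the flat buffer of `ar.data()` concrete, here is a minimal sketch in plain TypeScript. It does not call jax-js: `nest` is a hypothetical helper, and it assumes the flat buffer is laid out in row-major (C) order, as in NumPy's default.

```ts
// Toy helper (not part of jax-js): nest a flat buffer according to `shape`,
// assuming row-major (C-order) layout.
type Nested = number | Nested[];

function nest(data: ArrayLike<number>, shape: number[]): Nested {
  const size = shape.reduce((a, b) => a * b, 1);
  if (data.length !== size) {
    throw new Error(`expected ${size} elements, got ${data.length}`);
  }
  // Recursively slice out each sub-block along the leading axis.
  function build(offset: number, dim: number): Nested {
    if (dim === shape.length) return data[offset];
    const stride = shape.slice(dim + 1).reduce((a, b) => a * b, 1);
    const out: Nested[] = [];
    for (let i = 0; i < shape[dim]; i++) {
      out.push(build(offset + i * stride, dim + 1));
    }
    return out;
  }
  return build(0, 0);
}

const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
const nested = nest(buf, [2, 3]);
console.log(JSON.stringify(nested)); // [[10,20,30],[100,200,300]]
```

The flat form avoids allocating one JavaScript array per row, which is one plausible reason the README marks the `data()` path as the fastest.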
60
85
 
61
- Most common JAX APIs are supported. See the [compatibility table](./FEATURES.md) for a full
62
- breakdown of what features are available.
86
+ Arrays can have mathematical operations applied to them. For example:
63
87
 
64
- ### Web usage (CDN)
88
+ ```ts
89
+ import { numpy as np, scipySpecial as special } from "@jax-js/jax";
65
90
 
66
- If you want to use `jax-js` in vanilla JavaScript (without a bundler), just import from a module
67
- script tag. This is the easiest way to get started on a blank HTML page.
91
+ const x = np.arange(100).astype(np.float32); // [0, 1, ..., 99] as float32
68
92
 
69
- ```html
70
- <script type="module">
71
- import { numpy as np } from "https://esm.sh/@jax-js/jax";
72
- </script>
93
+ const y1 = x.ref.add(x.ref); // x + x
94
+ const y2 = np.sin(x.ref); // sin(x)
95
+ const y3 = np.tanh(x.ref).mul(5); // 5 * tanh(x)
96
+ const y4 = special.erfc(x.ref); // erfc(x)
97
+ ```
98
+
99
+ Notice that in the above code, we used `x.ref`. This is because of the memory model: jax-js uses
100
+ reference-counted _ownership_ to track when the memory of an Array can be freed. More on this below.
101
+
102
+ ### Reference counting
103
+
104
+ Big Arrays take up a lot of memory. Python ML libraries override the `__del__()` method to free
105
+ memory, but JavaScript has no such API for running object destructors
106
+ ([cf.](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry)).
107
+ This means that you have to track references manually. jax-js tries to make this as ergonomic as
108
+ possible, so you don't accidentally leak memory in a loop.
109
+
110
+ Every `jax.Array` has a reference count. This satisfies the following rules:
111
+
112
+ - Whenever you create an Array, its reference count starts at `1`.
113
+ - When an Array's reference count reaches `0`, it is freed and can no longer be used.
114
+ - Given an Array `a`:
115
+   - Accessing `a.ref` returns `a` and changes its reference count by `+1`.
116
+   - Passing `a` into any function as argument changes its reference count by `-1`.
117
+   - Calling `a.dispose()` also changes its reference count by `-1`.
118
+
119
+ What this means is that all functions in jax-js must _take ownership_ of their arguments as
120
+ references. Whenever you would like to pass an Array as argument, you can pass it directly to
121
+ dispose of it, or use `.ref` if you'd like to use it again later.
122
+
123
+ **You must follow these rules on your own functions as well!** All combinators like `jvp`, `grad`,
124
+ `jit` assume that you are following these conventions on how arguments are passed, and they will
125
+ respect them as well.
126
+
127
+ ```ts
128
+ // Bad: Uses `x` twice, decrementing its reference count twice.
129
+ function foo_bad(x: np.Array, y: np.Array) {
130
+   return x.add(x.mul(y));
131
+ }
132
+
133
+ // Good: The first usage of `x` is `x.ref`, adding +1 to refcount.
134
+ function foo_good(x: np.Array, y: np.Array) {
135
+   return x.ref.add(x.mul(y));
136
+ }
137
+ ```
138
+
139
+ Here's another example:
140
+
141
+ ```ts
142
+ // Bad: Doesn't consume `x` in the `if`-branch.
143
+ function bar_bad(x: np.Array, skip: boolean) {
144
+   if (skip) return np.zeros(x.shape);
145
+   return x;
146
+ }
147
+
148
+ // Good: Consumes `x` the one time in each branch.
149
+ function bar_good(x: np.Array, skip: boolean) {
150
+   if (skip) {
151
+     const ret = np.zeros(x.shape);
152
+     x.dispose();
153
+     return ret;
154
+   }
155
+   return x;
156
+ }
157
+ ```
158
+
159
+ You can assume that every function in jax-js takes ownership of its arguments in this way, with a few
160
+ rare exceptions that are documented.
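The ownership rules can be checked with a toy model. The sketch below is plain TypeScript and is not jax-js's actual Array implementation: `Counted`, `binary`, `add`, and `mul` are hypothetical names that only mimic the three rules (a new array starts at 1, `.ref` adds 1, passing as an argument or calling `.dispose()` subtracts 1).

```ts
// Toy model of the ownership rules (NOT jax-js's actual Array implementation).
class Counted {
  private refs = 1; // A freshly created array starts with refcount 1.

  constructor(private readonly data: number[]) {}

  get refCount(): number {
    return this.refs;
  }

  // Reading the data of a freed array is an error.
  get values(): number[] {
    this.assertAlive();
    return this.data;
  }

  // `.ref` returns the same object and adds +1 to its refcount.
  get ref(): Counted {
    this.assertAlive();
    this.refs++;
    return this;
  }

  // `.dispose()` gives up one reference (-1).
  dispose(): void {
    this.assertAlive();
    this.refs--;
  }

  private assertAlive(): void {
    if (this.refs <= 0) throw new Error("use after free");
  }
}

// Each operation consumes every argument exactly once.
function binary(op: (a: number, b: number) => number) {
  return (a: Counted, b: Counted): Counted => {
    const bv = b.values;
    const out = new Counted(a.values.map((v, i) => op(v, bv[i])));
    a.dispose();
    b.dispose();
    return out;
  };
}
const add = binary((a, b) => a + b);
const mul = binary((a, b) => a * b);

// Mirrors `foo_good`: `x` is needed twice, so its first use is `x.ref`.
const fooGood = (x: Counted, y: Counted) => add(x.ref, mul(x, y));
// Mirrors `foo_bad`: `x` is consumed by `mul`, then used again by `add`.
const fooBad = (x: Counted, y: Counted) => add(x, mul(x, y));

const x = new Counted([1, 2, 3]);
const y = new Counted([4, 5, 6]);
const good = fooGood(x, y); // values 5, 12, 21; x and y end with refcount 0

let badThrew = false;
try {
  fooBad(new Counted([1, 2, 3]), new Counted([4, 5, 6]));
} catch (err) {
  badThrew = true; // "use after free" raised inside `add`
}
console.log(good.values, x.refCount, y.refCount, badThrew);
```

In this model the bad variant fails loudly on the second use of `x`, while the good variant leaves exactly one live reference (the result).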
161
+
162
+ ### grad(), vmap() and jit()
163
+
164
+ JAX's signature composable transformations are also supported in jax-js. Here is a simple example of
165
+ using `grad` and `vmap` to compute the derivative of a function:
166
+
167
+ ```ts
168
+ import { numpy as np, grad, vmap } from "@jax-js/jax";
169
+
170
+ const x = np.linspace(-10, 10, 1000);
171
+
172
+ const y1 = vmap(grad(np.sin))(x.ref); // d/dx sin(x) = cos(x)
173
+ const y2 = np.cos(x);
174
+
175
+ np.allclose(y1, y2); // => true
176
+ ```
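As a sanity check on the identity used above, the same comparison can be done without jax-js. This sketch swaps automatic differentiation for a central finite difference (a different technique from `grad`), and `Array.prototype.map` stands in for `vmap`; `numericalGrad` is a hypothetical helper.

```ts
// Central finite difference: f'(x) ≈ (f(x + h) - f(x - h)) / (2h).
// This is numerical differentiation, not the autodiff that `grad` performs.
function numericalGrad(f: (x: number) => number, h = 1e-5) {
  return (x: number) => (f(x + h) - f(x - h)) / (2 * h);
}

// Same grid as np.linspace(-10, 10, 1000), endpoints included.
const xs = Array.from({ length: 1000 }, (_, i) => -10 + (20 * i) / 999);

// Mapping the scalar derivative over every element plays the role of vmap.
const dsin = xs.map(numericalGrad(Math.sin));
const maxErr = Math.max(...xs.map((x, i) => Math.abs(dsin[i] - Math.cos(x))));

console.log(maxErr < 1e-6); // true
```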
177
+
178
+ The `jit` function is especially useful when doing long sequences of primitives on GPU, since it
179
+ fuses operations together into a single kernel dispatch. This
180
+ [improves memory bandwidth usage](https://substack.com/home/post/p-163548742) on hardware
181
+ accelerators, where memory bandwidth rather than raw FLOPs is usually the bottleneck. For instance:
182
+
183
+ ```ts
184
+ import { jit, numpy as np } from "@jax-js/jax";
+
+ export const hypot = jit(function hypot(x1: np.Array, x2: np.Array) {
185
+   return np.sqrt(np.square(x1).add(np.square(x2)));
186
+ });
187
+ ```
188
+
189
+ Without JIT, the `hypot()` function would require four kernel dispatches: two multiplies, one add,
190
+ and one sqrt. JIT fuses these together into a single kernel that does it all at once.
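A CPU analogy may help picture what fusion saves. The sketch below is plain TypeScript over `Float32Array`, not the kernels jax-js generates: the unfused version makes four passes and materializes three temporaries, while the fused version makes one pass with no intermediates. Both function names are hypothetical.

```ts
// Unfused: each step is a separate pass that writes a temporary buffer,
// analogous to one kernel dispatch per primitive.
function hypotUnfused(x1: Float32Array, x2: Float32Array): Float32Array {
  const sq1 = x1.map((v) => v * v); // pass 1: multiply
  const sq2 = x2.map((v) => v * v); // pass 2: multiply
  const sum = sq1.map((v, i) => v + sq2[i]); // pass 3: add
  return sum.map((v) => Math.sqrt(v)); // pass 4: sqrt
}

// Fused: a single pass that reads each input once and writes the output once.
function hypotFused(x1: Float32Array, x2: Float32Array): Float32Array {
  const out = new Float32Array(x1.length);
  for (let i = 0; i < x1.length; i++) {
    out[i] = Math.sqrt(x1[i] * x1[i] + x2[i] * x2[i]);
  }
  return out;
}

const a = new Float32Array([3, 5, 8]);
const b = new Float32Array([4, 12, 15]);
console.log(Array.from(hypotUnfused(a, b))); // [ 5, 13, 17 ]
console.log(Array.from(hypotFused(a, b))); // [ 5, 13, 17 ]
```

For inputs that are not exactly representable, the two versions can differ in the last float32 bit, because the unfused version rounds to float32 after every pass.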
191
+
192
+ All functional transformations can take a typed `JsTree` of inputs and outputs. A `JsTree` is similar to
193
+ [JAX's pytrees](https://docs.jax.dev/en/latest/pytrees.html): it is just a structure of
194
+ nested JavaScript objects and arrays. For instance:
195
+
196
+ ```ts
197
+ import { grad, numpy as np } from "@jax-js/jax";
198
+
199
+ type Params = {
200
+   foo: np.Array;
201
+   bar: np.Array[];
202
+ };
203
+
204
+ function getSums(p: Params) {
205
+   const fooSum = p.foo.sum();
206
+   const barSum = p.bar.map((x) => x.sum()).reduce(np.add);
207
+   return fooSum.add(barSum);
208
+ }
209
+
210
+ grad(getSums)({
211
+   foo: np.array([1, 2, 3]),
212
+   bar: [np.array([10]), np.array([11, 12])],
213
+ });
214
+ // => { foo: [1, 1, 1], bar: [[1], [1, 1]] }
215
+ ```
216
+
217
+ Note that you need to use `type` alias syntax rather than `interface` to define fine-grained
218
+ `JsTree` types.
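To illustrate what "a structure of nested objects and arrays" means in practice, here is a toy tree map in plain TypeScript. It is not jax-js's `JsTree` machinery: `Tree`, `treeMap`, and `treeLeaves` are hypothetical, and the leaves are plain numbers rather than Arrays.

```ts
// Toy stand-in for a pytree-like structure (not jax-js's `JsTree` type):
// leaves are numbers, containers are arrays or plain objects.
type Tree = number | Tree[] | { [key: string]: Tree };

// Apply `fn` to every leaf while preserving the container structure.
function treeMap(fn: (leaf: number) => number, tree: Tree): Tree {
  if (typeof tree === "number") return fn(tree);
  if (Array.isArray(tree)) return tree.map((child) => treeMap(fn, child));
  const out: { [key: string]: Tree } = {};
  for (const key of Object.keys(tree)) out[key] = treeMap(fn, tree[key]);
  return out;
}

// Collect the leaves in traversal order.
function treeLeaves(tree: Tree): number[] {
  if (typeof tree === "number") return [tree];
  const leaves: number[] = [];
  if (Array.isArray(tree)) {
    for (const child of tree) leaves.push(...treeLeaves(child));
  } else {
    for (const key of Object.keys(tree)) leaves.push(...treeLeaves(tree[key]));
  }
  return leaves;
}

const tree: Tree = { foo: [1, 2, 3], bar: [[10], [11, 12]] };

// The gradient of a plain sum with respect to every leaf is 1, so the
// gradient tree has the same shape as the input with every leaf set to 1.
const ones = treeMap(() => 1, tree);
console.log(JSON.stringify(ones)); // {"foo":[1,1,1],"bar":[[1],[1,1]]}
console.log(treeLeaves(tree)); // [ 1, 2, 3, 10, 11, 12 ]
```

The output has the same structure as the input, which matches the shape of the `grad(getSums)` result shown in the README.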
219
+
220
+ ### Devices
221
+
222
+ Similar to JAX, jax-js has a concept of "devices". Each device is a backend that stores Arrays in
223
+ memory and determines how to execute compiled operations on them.
224
+
225
+ There are currently 3 devices in jax-js:
226
+
227
+ - `cpu`: Slow, mostly for debugging purposes.
228
+ - `wasm`: [WebAssembly](https://webassembly.org/), currently single-threaded and blocking.
229
+ - `webgpu`: [WebGPU](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API), available on
230
+ [supported browsers](https://caniuse.com/webgpu) (Chrome, Firefox, Safari, iOS).
231
+
232
+ The default device is `wasm`, but you can change this at startup time:
233
+
234
+ ```ts
235
+ import { defaultDevice, init } from "@jax-js/jax";
236
+
237
+ const devices = await init(); // Starts all available backends.
238
+
239
+ if (devices.includes("webgpu")) {
240
+   defaultDevice("webgpu");
241
+ } else {
242
+   console.warn("WebGPU is not supported, falling back to Wasm.");
243
+ }
244
+ ```
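The fallback logic generalizes to a preference list. The helper below is a hypothetical sketch in plain TypeScript, not a jax-js API: it assumes only that `init()` resolves to a list of device names, as the snippet above suggests, and picks the first preferred device that is present.

```ts
// Hypothetical helper (not a jax-js API): choose the first preferred device
// that appears in the list of initialized backends.
type Device = "cpu" | "wasm" | "webgpu";

function pickDevice(
  available: Device[],
  preference: Device[] = ["webgpu", "wasm", "cpu"],
): Device {
  for (const device of preference) {
    if (available.indexOf(device) !== -1) return device;
  }
  throw new Error("no supported device available");
}

console.log(pickDevice(["cpu", "wasm"])); // "wasm"
console.log(pickDevice(["cpu", "wasm", "webgpu"])); // "webgpu"
```

With jax-js this would presumably be used as `defaultDevice(pickDevice(await init()))`, though that composition is an assumption and untested here.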
245
+
246
+ You can also place individual arrays on specific devices:
247
+
248
+ ```ts
249
+ import { devicePut, numpy as np } from "@jax-js/jax";
250
+
251
+ const ar = np.array([1, 2, 3]); // Starts with device="wasm"
252
+ await devicePut(ar, "webgpu"); // Now device="webgpu"
253
+ ```
254
+
255
+ ### Helper libraries
256
+
257
+ There are other libraries in the `@jax-js` namespace that can work with jax-js, or be used in a
258
+ self-contained way in other projects.
259
+
260
+ **`@jax-js/optax`** provides implementations of optimizers like Adam and SGD.
261
+
262
+ ```ts
263
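+ import { grad, numpy as np } from "@jax-js/jax";
+ // Assumption: `applyUpdates` and `squaredError` are exported by @jax-js/optax alongside `adam`.
+ import { adam, applyUpdates, squaredError } from "@jax-js/optax";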
264
+
265
+ let params = np.array([1.0, 2.0, 3.0]);
266
+
267
+ const solver = adam(1e-3);
268
+ let optState = solver.init(params.ref);
269
+ let updates: np.Array;
270
+
271
+ const f = (x: np.Array) => squaredError(x, np.ones([3])).sum();
272
+
273
+ for (let i = 0; i < 100; i++) {
274
+   const paramsGrad = grad(f)(params.ref);
275
+   [updates, optState] = solver.update(paramsGrad, optState);
276
+   params = applyUpdates(params, updates);
277
+ }
278
+ ```
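For readers unfamiliar with what an Adam step computes, here is a minimal sketch of the standard update rule in plain TypeScript. It is not the `@jax-js/optax` implementation: `adamInit` and `adamUpdate` are hypothetical names, the hyperparameters are the commonly used defaults, and the gradient of the squared-error loss is written by hand instead of using `grad`.

```ts
// Plain-TypeScript sketch of the standard Adam update rule (not @jax-js/optax).
interface AdamState {
  step: number;
  m: number[]; // running mean of gradients
  v: number[]; // running mean of squared gradients
}

function adamInit(params: number[]): AdamState {
  return { step: 0, m: params.map(() => 0), v: params.map(() => 0) };
}

function adamUpdate(
  grads: number[],
  state: AdamState,
  lr = 1e-3,
  b1 = 0.9,
  b2 = 0.999,
  eps = 1e-8,
): [number[], AdamState] {
  const step = state.step + 1;
  const m = state.m.map((mi, i) => b1 * mi + (1 - b1) * grads[i]);
  const v = state.v.map((vi, i) => b2 * vi + (1 - b2) * grads[i] * grads[i]);
  const updates = m.map((mi, i) => {
    const mHat = mi / (1 - Math.pow(b1, step)); // bias correction
    const vHat = v[i] / (1 - Math.pow(b2, step));
    return (-lr * mHat) / (Math.sqrt(vHat) + eps);
  });
  return [updates, { step, m, v }];
}

// Loss: sum((x - 1)^2), whose gradient is 2 * (x - 1).
const loss = (p: number[]) => p.reduce((s, x) => s + (x - 1) * (x - 1), 0);
const lossGrad = (p: number[]) => p.map((x) => 2 * (x - 1));

let weights = [1.0, 2.0, 3.0];
let adamState = adamInit(weights);
const initialLoss = loss(weights);

for (let i = 0; i < 100; i++) {
  const [updates, next] = adamUpdate(lossGrad(weights), adamState);
  weights = weights.map((w, j) => w + updates[j]);
  adamState = next;
}
console.log(initialLoss, loss(weights), weights);
```

With a learning rate of `1e-3`, each nonzero coordinate moves by roughly `1e-3` per step, so 100 steps shift the second and third weights by about 0.1 toward the target while the first stays at 1.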
279
+
280
+ **`@jax-js/loaders`** can load tensors from various formats like Safetensors, includes a fast and
281
+ compliant implementation of BPE, and caches HTTP requests for large assets like model weights in
282
+ OPFS.
283
+
284
+ ```ts
285
+ import { tokenizers } from "@jax-js/loaders";
286
+
287
+ const enc = await tokenizers.getBpe("clip");
288
+ const tokens = enc.encode("Hello, world!"); // => [ 49406, 3306, 267, 1002, ... ]
73
289
  ```
74
290
 
75
291
  ### Performance
@@ -77,22 +293,28 @@ script tag. This is the easiest way to get started on a blank HTML page.
77
293
  We haven't spent a ton of time optimizing yet, but performance is generally pretty good. `jit` is
78
294
  very helpful for fusing operations together, and it's a feature only available on the web in jax-js.
79
295
  The default kernel-tuning heuristics get about 3000 GFLOP/s for matrix multiplication on an M4 Pro
80
- chip ([try it](https://www.ekzhang.com/jax-js/bench/matmul)).
296
+ chip ([try it](https://jax-js.com/bench/matmul)).
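To put the GFLOP/s figure in perspective, the arithmetic can be sketched as follows. The matrix size `4096` is an assumption for illustration (the README does not state the benchmark size here), and `matmulGflops` is a hypothetical helper using the usual count of `2 * n^3` floating-point operations for a dense `n x n` matmul.

```ts
// A dense (n x n) @ (n x n) matmul performs n^3 multiply-adds,
// conventionally counted as 2 * n^3 floating-point operations.
function matmulGflops(n: number, seconds: number): number {
  return (2 * n * n * n) / seconds / 1e9;
}

const size = 4096; // assumed size, for illustration only
const flops = 2 * size * size * size; // 137,438,953,472
const secondsAt3000 = flops / 3000e9; // time per matmul at 3000 GFLOP/s

console.log((secondsAt3000 * 1000).toFixed(1) + " ms"); // "45.8 ms"
```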
81
297
 
82
298
  For that example, it's around the same GFLOP/s as
83
299
  [TensorFlow.js](https://github.com/tensorflow/tfjs) and
84
300
  [ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
85
301
  libraries of custom kernels (versus jax-js, which generates kernels with an ML compiler).
86
302
 
303
+ ### API Reference
304
+
305
+ That's all for this short tutorial. Please see the generated
306
+ [API reference](https://jax-js.com/docs) for detailed documentation.
307
+
87
308
  ## Examples
88
309
 
89
310
  If you make something cool with jax-js, don't be a stranger! We can feature it here.
90
311
 
91
- - [In-browser REPL](https://www.ekzhang.com/jax-js/repl)
92
- - [Interactive MNIST training](https://www.ekzhang.com/jax-js/mnist)
93
- - [Matmul benchmark](https://www.ekzhang.com/jax-js/bench/matmul)
94
- - [Conv2d benchmark](https://www.ekzhang.com/jax-js/bench/conv2d)
95
- - [Mandelbrot set](https://www.ekzhang.com/jax-js/mandelbrot)
312
+ - [Training neural networks on MNIST](https://jax-js.com/mnist)
313
+ - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
314
+ - [In-browser REPL](https://jax-js.com/repl)
315
+ - [Matmul benchmark](https://jax-js.com/bench/matmul)
316
+ - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
317
+ - [Mandelbrot set](https://jax-js.com/mandelbrot)
96
318
 
97
319
  ## Development
98
320
 
@@ -124,59 +346,12 @@ pnpm -C website dev
124
346
 
125
347
  ## Future work / help wanted
126
348
 
127
- Contributions are welcomed in the following areas:
349
+ Contributions are welcomed! Especially in:
128
350
 
129
351
  - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
130
352
  - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
131
- and multithreading.
132
- - Helping the JIT compiler to fuse operations in more cases.
353
+ and multithreading. (Even single-threaded Wasm could be ~20x faster.)
354
+ - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches and adding
355
+ epilogue to reductions.
133
356
  - Adding WebGL runtime for older browsers that don't support WebGPU.
134
357
  - Making a fast transformer inference engine, comparing against onnxruntime-web.
135
- - Ergonomics and API improvements.
136
-
137
- ## Next on Eric's mind
138
-
139
- - Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
140
- - Performance
141
- - Improve perf of MobileCLIP neural network
142
- - Add fused epilogue to JIT
143
- - Fix fusion of activation functions with branches like tanh
144
- - Reduce kernel overhead of constants / inline expressions
145
- - How many threads to create per workgroup, depends on hardware
146
-
147
- ## Milestones
148
-
149
- - [x] It works!
150
- - [x] Demos: Browser REPL / editor
151
- - [x] First custom kernel
152
- - [x] Custom WebGPU backend, removing tfjs dependency
153
- - [x] Low-level operations
154
- - [x] Create `class Array {}` wrappers
155
- - [x] Reduction operations
156
- - [ ] Kernel tuning (see `tuner.ts`)
157
- - [x] "Upcast" optimizations (compute a tile per thread, e.g., matmul)
158
- - [x] "Unroll" optimizations (multiple loop iters per thread, e.g., matmul)
159
- - [ ] "Group" optimizations (multiple threads per value, e.g., matvec)
160
- - [ ] Blocks respect local dimensions
161
- - [x] Other dtypes like int32 and bool
162
- - [x] `jit()` support via Jaxprs and kernel fusion
163
- - [x] We figure out the `dispose()` / refcount / linear types stuff
164
- - [x] `dispose()` for saved "const" tracers in Jaxprs
165
- - [x] Garbage collection for JIT programs
166
- - [x] Debug grad-grad-jit test producing a UseAfterFreeError
167
- - [ ] Demos: Navier-Stokes, neural networks, statistics
168
- - [x] Features for neural networks
169
- - [x] Convolution
170
- - [x] Random and initializers
171
- - [x] Optimizers (optax package?)
172
- - [x] Wasm backend (needs malloc)
173
- - [x] Better memory allocation that frees buffers
174
- - [ ] SIMD support for Wasm backend
175
- - [ ] Async / multithreading Wasm support
176
- - [ ] Full support of weak types and committed devices
177
- - [x] High-level ops have automatic type promotion
178
- - [x] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
179
- - [ ] Committed devices -
180
- [ref](https://docs.jax.dev/en/latest/sharded-computation.html#sharded-data-placement)
181
- - [ ] Device switching with `device_put()` between webgpu/cpu/wasm
182
- - [x] numpy/jax API compatibility table