@jax-js/jax 0.0.5 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-yEU0L_ig.cjs → backend-BbrKEB18.cjs} +378 -183
- package/dist/{backend-CdcTZEOF.js → backend-CoVtc9dx.js} +366 -177
- package/dist/index.cjs +385 -74
- package/dist/index.d.cts +115 -23
- package/dist/index.d.ts +115 -23
- package/dist/index.js +378 -74
- package/dist/{webgpu-CM-xNYzW.js → webgpu-B3UVme6n.js} +188 -153
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-DGYNVHma.cjs} +188 -153
- package/package.json +25 -15
package/README.md
@@ -1,10 +1,14 @@
-
+<h1 align="center">jax-js: JAX in pure JavaScript</h1>
 
-
+<p align="center"><strong>
+<a href="https://jax-js.com">Website</a> |
+<a href="https://jax-js.com/docs/">API Reference</a> |
+<a href="./FEATURES.md">Compatibility Table</a>
+</strong></p>
 
-**jax-js** is a machine learning framework for the browser. It aims to bring
-high-performance CPU and GPU kernels to JavaScript, so you can run
-web.
+**jax-js** is a machine learning framework for the browser. It aims to bring
+[JAX](https://jax.dev)-style, high-performance CPU and GPU kernels to JavaScript, so you can run
+numerical applications on the web.
 
 ```bash
 npm i @jax-js/jax
@@ -19,57 +23,269 @@ GPU ML framework, since it runs anywhere a browser can run.
 
 ## Quickstart
 
-You can use `jax-js` as an array API, just like NumPy.
-
 ```js
 import { numpy as np } from "@jax-js/jax";
 
-// Array operations, compatible with NumPy.
+// Array operations, compatible with JAX/NumPy.
 const x = np.array([1, 2, 3]);
 const y = x.mul(4); // [4, 8, 12]
 ```
 
-
+### Web usage (CDN)
 
-
-
+In vanilla JavaScript (without a bundler), just import from a module script tag. This is the easiest
+way to get started on a blank HTML page.
 
-
-
+```html
+<script type="module">
+  import { numpy as np } from "https://esm.sh/@jax-js/jax";
+</script>
+```
 
-
-
-
+## Tutorial
+
+Programming in `jax-js` looks [very similar to JAX](https://docs.jax.dev/en/latest/jax-101.html),
+just in JavaScript.
+
+### Arrays
+
+Create an array with `np.array()`:
+
+```ts
+import { numpy as np } from "@jax-js/jax";
+
+const ar = np.array([1, 2, 3]);
 ```
 
-
-Chrome and iOS Safari, you can switch to GPU for better performance.
+By default, this is a float32 array, but you can also specify a dtype explicitly:
 
-```
-
+```ts
+const ar = np.array([1, 2, 3], { dtype: np.float32 });
+```
+
+For more efficient construction, you can create an array from a JS `TypedArray` buffer:
+
+```ts
+const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
+const ar = np.array(buf).reshape([2, 3]);
+```
 
-
-
+Once you're done with it, you can unwrap a `jax.Array` back into JavaScript. This will also apply
+any pending operations or lazy updates:
 
-
-
+```ts
+// 1) Returns a possibly nested JavaScript array.
+ar.js();
+await ar.jsAsync(); // Faster, non-blocking
 
-
-
+// 2) Returns a flat TypedArray data buffer.
+ar.dataSync();
+await ar.data(); // Fastest, non-blocking
 ```
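The two unwrapping styles above differ only in layout. `dataSync()` and `data()` return one flat buffer, while `js()` returns nested arrays. The sketch below shows how the two relate. It assumes row-major (C-order) storage, which NumPy uses by default and which the `reshape([2, 3])` example suggests. `stridesOf`, `flatIndex`, and `toNested` are hypothetical helpers written for illustration, not jax-js API:

```typescript
// Hypothetical helpers (not jax-js API) relating a flat buffer to a nested
// array, assuming row-major (C-order) layout as in NumPy.
type Nested = number | Nested[];

// Row-major strides: the last axis is contiguous in memory.
function stridesOf(shape: number[]): number[] {
  const strides = new Array<number>(shape.length);
  let acc = 1;
  for (let i = shape.length - 1; i >= 0; i--) {
    strides[i] = acc;
    acc *= shape[i];
  }
  return strides;
}

// Offset of a multi-index into the flat buffer.
function flatIndex(index: number[], shape: number[]): number {
  const strides = stridesOf(shape);
  return index.reduce((off, i, axis) => off + i * strides[axis], 0);
}

// Rebuild a nested JS array from a flat buffer and a shape.
function toNested(buf: ArrayLike<number>, shape: number[], offset = 0): Nested {
  if (shape.length === 0) return buf[offset];
  const [n, ...rest] = shape;
  const step = rest.reduce((a, b) => a * b, 1); // elements per sub-array
  const out: Nested[] = [];
  for (let i = 0; i < n; i++) out.push(toNested(buf, rest, offset + i * step));
  return out;
}

const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
console.log(toNested(buf, [2, 3])); // [[10, 20, 30], [100, 200, 300]]
console.log(buf[flatIndex([1, 2], [2, 3])]); // 300
```

Under this layout, element `[i][j]` of a `[2, 3]` array lives at offset `i * 3 + j`. A flat `Float32Array` plus a shape is therefore enough to describe the whole array.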
 
-
-breakdown of what features are available.
+Arrays can have mathematical operations applied to them. For example:
 
-
+```ts
+import { numpy as np, scipySpecial as special } from "@jax-js/jax";
 
-
-script tag. This is the easiest way to get started on a blank HTML page.
+const x = np.arange(100).astype(np.float32); // array of integers [0..99]
 
-
-
-
-
+const y1 = x.ref.add(x.ref); // x + x
+const y2 = np.sin(x.ref); // sin(x)
+const y3 = np.tanh(x.ref).mul(5); // 5 * tanh(x)
+const y4 = special.erfc(x.ref); // erfc(x)
+```
+
+Notice that in the above code, we used `x.ref`. This is because of the memory model: jax-js uses
+reference-counted _ownership_ to track when the memory of an Array can be freed. More on this below.
+
+### Reference counting
+
+Big Arrays take up a lot of memory. Python ML libraries override the `__del__()` method to free
+memory, but JavaScript has no such API for running object destructors
+([cf.](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry)).
+This means that you have to track references manually. jax-js tries to make this as ergonomic as
+possible, so you don't accidentally leak memory in a loop.
+
+Every `jax.Array` has a reference count. This satisfies the following rules:
+
+- Whenever you create an Array, its reference count starts at `1`.
+- When an Array's reference count reaches `0`, it is freed and can no longer be used.
+- Given an Array `a`:
+  - Accessing `a.ref` returns `a` and changes its reference count by `+1`.
+  - Passing `a` into any function as argument changes its reference count by `-1`.
+  - Calling `a.dispose()` also changes its reference count by `-1`.
+
+What this means is that all functions in jax-js must _take ownership_ of their arguments as
+references. Whenever you would like to pass an Array as argument, you can pass it directly to
+dispose of it, or use `.ref` if you'd like to use it again later.
+
+**You must follow these rules on your own functions as well!** All combinators like `jvp`, `grad`,
+`jit` assume that you are following these conventions on how arguments are passed, and they will
+respect them as well.
+
+```ts
+// Bad: Uses `x` twice, decrementing its reference count twice.
+function foo_bad(x: np.Array, y: np.Array) {
+  return x.add(x.mul(y));
+}
+
+// Good: The first usage of `x` is `x.ref`, adding +1 to refcount.
+function foo_good(x: np.Array, y: np.Array) {
+  return x.ref.add(x.mul(y));
+}
+```
+
+Here's another example:
+
+```ts
+// Bad: Doesn't consume `x` in the `if`-branch.
+function bar_bad(x: np.Array, skip: boolean) {
+  if (skip) return np.zeros(x.shape);
+  return x;
+}
+
+// Good: Consumes `x` the one time in each branch.
+function bar_good(x: np.Array, skip: boolean) {
+  if (skip) {
+    const ret = np.zeros(x.shape);
+    x.dispose();
+    return ret;
+  }
+  return x;
+}
+```
+
+You can assume that every function in jax-js takes ownership properly, except with a couple of very
+rare exceptions that are documented.
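The ownership rules above can be traced with a toy model. The sketch below is hypothetical and only illustrates the bookkeeping: `Tracked`, `add`, and `read` are made-up names, and "freeing" just means the count reaches zero. Real jax-js arrays also manage backend memory, which this toy ignores.

```typescript
// Toy model of jax-js-style reference counting (hypothetical, not jax-js code).
class Tracked {
  readonly value: number;
  private count = 1; // A new value starts with refcount 1.

  constructor(value: number) {
    this.value = value;
  }

  get refCount(): number {
    return this.count;
  }

  get freed(): boolean {
    return this.count === 0;
  }

  // `.ref` returns the same object and increments the refcount.
  get ref(): this {
    this.check();
    this.count++;
    return this;
  }

  // `dispose()` decrements the refcount; at 0 the value counts as freed.
  dispose(): void {
    this.check();
    this.count--;
  }

  // Reading does not change the refcount, but is illegal after freeing.
  read(): number {
    this.check();
    return this.value;
  }

  private check(): void {
    if (this.count <= 0) throw new Error("use after free");
  }
}

// A "library function" that takes ownership: it reads both arguments, then
// releases one reference to each, and returns a fresh value (refcount 1).
function add(a: Tracked, b: Tracked): Tracked {
  const out = new Tracked(a.read() + b.read());
  a.dispose();
  b.dispose();
  return out;
}

const x = new Tracked(3);
const y = add(x.ref, x); // Like `x.ref.add(x)`: refcount goes 1 -> 2, then back to 0.
console.log(y.read(), y.refCount, x.freed); // 6 1 true
```

In this toy, calling `add(z, z)` on a fresh `z` throws on the second `dispose()`. That is the same double-consumption mistake as `foo_bad`.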
+
+### grad(), vmap() and jit()
+
+JAX's signature composable transformations are also supported in jax-js. Here is a simple example of
+using `grad` and `vmap` to compute the derivative of a function:
+
+```ts
+import { numpy as np, grad, vmap } from "@jax-js/jax";
+
+const x = np.linspace(-10, 10, 1000);
+
+const y1 = vmap(grad(np.sin))(x.ref); // d/dx sin(x) = cos(x)
+const y2 = np.cos(x);
+
+np.allclose(y1, y2); // => true
+```
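The claim in the `vmap(grad(np.sin))` example can be sanity-checked without jax-js. On the same 1000-point grid, a central finite difference of `Math.sin` should agree with `Math.cos`. This is only a numerical cross-check written from scratch. Like JAX's, `grad` is built on automatic differentiation (the README names `jvp` alongside it), not on finite differences.

```typescript
// Numerical cross-check (not jax-js): on the same grid as np.linspace(-10, 10, 1000),
// a central difference of sin should match cos.
function linspace(start: number, stop: number, num: number): number[] {
  const step = (stop - start) / (num - 1);
  return Array.from({ length: num }, (_, i) => start + i * step);
}

// (f(x + h) - f(x - h)) / 2h has O(h^2) truncation error.
function centralDiff(f: (x: number) => number, x: number, h = 1e-5): number {
  return (f(x + h) - f(x - h)) / (2 * h);
}

const xs = linspace(-10, 10, 1000);
const errors = xs.map((x) => Math.abs(centralDiff(Math.sin, x) - Math.cos(x)));
const maxErr = Math.max(...errors);
console.log(maxErr < 1e-8); // true
```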
+
+The `jit` function is especially useful when doing long sequences of primitives on GPU, since it
+fuses operations together into a single kernel dispatch. This
+[improves memory bandwidth usage](https://substack.com/home/post/p-163548742) on hardware
+accelerators, which is the bottleneck on GPU rather than raw FLOPs. For instance:
+
+```ts
+export const hypot = jit(function hypot(x1: np.Array, x2: np.Array) {
+  return np.sqrt(np.square(x1).add(np.square(x2)));
+});
+```
+
+Without JIT, the `hypot()` function would require four kernel dispatches: two multiplies, one add,
+and one sqrt. JIT fuses these together into a single kernel that does it all at once.
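The difference between four dispatches and one can be shown with a CPU-side analogy. This is a hypothetical sketch, not how jax-js generates kernels. Each call to `unary` or `binary` stands in for one kernel dispatch, and `passes` counts them. The unfused version also materializes three intermediate arrays, which is the extra memory traffic that fusion avoids.

```typescript
// CPU analogy of kernel fusion (illustrative only, not jax-js internals).
// Each helper call below stands in for one kernel dispatch.
let passes = 0;

function unary(a: Float32Array, f: (v: number) => number): Float32Array {
  passes++;
  const out = new Float32Array(a.length);
  for (let i = 0; i < a.length; i++) out[i] = f(a[i]);
  return out;
}

function binary(a: Float32Array, b: Float32Array, f: (u: number, v: number) => number): Float32Array {
  passes++;
  const out = new Float32Array(a.length);
  for (let i = 0; i < a.length; i++) out[i] = f(a[i], b[i]);
  return out;
}

// Unfused: four dispatches and three intermediate arrays.
function hypotUnfused(x1: Float32Array, x2: Float32Array): Float32Array {
  const s1 = unary(x1, (v) => v * v);
  const s2 = unary(x2, (v) => v * v);
  const sum = binary(s1, s2, (u, v) => u + v);
  return unary(sum, Math.sqrt);
}

// Fused: one dispatch that reads each input once and writes the output once.
function hypotFused(x1: Float32Array, x2: Float32Array): Float32Array {
  return binary(x1, x2, (u, v) => Math.sqrt(u * u + v * v));
}

const a = new Float32Array([3, 5, 8]);
const b = new Float32Array([4, 12, 15]);

passes = 0;
const r1 = hypotUnfused(a, b);
const unfusedPasses = passes;

passes = 0;
const r2 = hypotFused(a, b);
const fusedPasses = passes;

console.log(Array.from(r1), unfusedPasses); // [5, 13, 17] 4
console.log(Array.from(r2), fusedPasses); // [5, 13, 17] 1
```

Both versions produce the same values for these inputs. On a GPU, the skipped intermediates would otherwise be round-tripped through device memory.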
+
+All functional transformations can take typed `JsTree` of inputs and outputs. These are similar to
+[JAX's pytrees](https://docs.jax.dev/en/latest/pytrees.html), and it's basically just a structure of
+nested JavaScript objects and arrays. For instance:
+
+```ts
+import { grad, numpy as np } from "@jax-js/jax";
+
+type Params = {
+  foo: np.Array;
+  bar: np.Array[];
+};
+
+function getSums(p: Params) {
+  const fooSum = p.foo.sum();
+  const barSum = p.bar.map((x) => x.sum()).reduce(np.add);
+  return fooSum.add(barSum);
+}
+
+grad(getSums)({
+  foo: np.array([1, 2, 3]),
+  bar: [np.array([10]), np.array([11, 12])],
+});
+// => { foo: [1, 1, 1], bar: [[1], [1, 1]] }
+```
+
+Note that you need to use `type` alias syntax rather than `interface` to define fine-grained
+`JsTree` types.
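To make "a structure of nested objects and arrays" concrete, here is a hypothetical flatten/unflatten pair in the spirit of JAX's pytrees, with plain numbers as leaves. It is not jax-js's `JsTree` implementation. The names `flatten` and `unflatten` are made up, and so is the choice to visit object keys in sorted order, which mirrors what JAX does for dicts but is only an assumption here.

```typescript
// Hypothetical pytree-style helpers (not jax-js API). Leaves are numbers here;
// in jax-js they would be arrays.
type Tree = number | Tree[] | { [key: string]: Tree };

// Collect leaves depth-first; object keys are visited in sorted order so the
// leaf order does not depend on insertion order.
function flatten(tree: Tree, leaves: number[] = []): number[] {
  if (typeof tree === "number") {
    leaves.push(tree);
  } else if (Array.isArray(tree)) {
    for (const child of tree) flatten(child, leaves);
  } else {
    for (const key of Object.keys(tree).sort()) flatten(tree[key], leaves);
  }
  return leaves;
}

// Rebuild a tree with the same structure as `like`, consuming `leaves` in order.
function unflatten(like: Tree, leaves: number[]): Tree {
  let i = 0;
  const build = (t: Tree): Tree => {
    if (typeof t === "number") return leaves[i++];
    if (Array.isArray(t)) return t.map(build);
    const out: { [key: string]: Tree } = {};
    for (const key of Object.keys(t).sort()) out[key] = build(t[key]);
    return out;
  };
  return build(like);
}

const params = { foo: [1, 2, 3], bar: [[10], [11, 12]] };
const leaves = flatten(params);
console.log(leaves); // [10, 11, 12, 1, 2, 3]  ("bar" sorts before "foo")
console.log(unflatten(params, leaves.map(() => 1))); // { bar: [[1], [1, 1]], foo: [1, 1, 1] }
```

Filling every leaf with 1 and unflattening reproduces the structure of the `getSums` gradient, because the derivative of a sum with respect to each element is 1. The `params` literal is accepted as a `Tree` because anonymous object types get an implicit index signature. That is the same reason the README recommends `type` aliases over `interface`.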
+
+### Devices
+
+Similar to JAX, jax-js has a concept of "devices" which are a backend that stores Arrays in memory
+and determines how to execute compiled operations on them.
+
+There are currently 3 devices in jax-js:
+
+- `cpu`: Slow, mostly for debugging purposes.
+- `wasm`: [WebAssembly](https://webassembly.org/), currently single-threaded and blocking.
+- `webgpu`: [WebGPU](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API), available on
+  [supported browsers](https://caniuse.com/webgpu) (Chrome, Firefox, Safari, iOS).
+
+The default device is `wasm`, but you can change this at startup time:
+
+```ts
+import { defaultDevice, init } from "@jax-js/jax";
+
+const devices = await init(); // Starts all available backends.
+
+if (devices.includes("webgpu")) {
+  defaultDevice("webgpu");
+} else {
+  console.warn("WebGPU is not supported, falling back to Wasm.");
+}
+```
+
+You can also place individual arrays on specific devices:
+
+```ts
+import { devicePut, numpy as np } from "@jax-js/jax";
+
+const ar = np.array([1, 2, 3]); // Starts with device="wasm"
+await devicePut(ar, "webgpu"); // Now device="webgpu"
+```
+
+### Helper libraries
+
+There are other libraries in the `@jax-js` namespace that can work with jax-js, or be used in a
+self-contained way in other projects.
+
+**`@jax-js/optax`** provides implementations of optimizers like Adam and SGD.
+
+```ts
+import { adam } from "@jax-js/optax";
+
+let params = np.array([1.0, 2.0, 3.0]);
+
+const solver = adam(1e-3);
+let optState = solver.init(params.ref);
+let updates: np.Array;
+
+const f = (x: np.Array) => squaredError(x, np.ones([3])).sum();
+
+for (let i = 0; i < 100; i++) {
+  const paramsGrad = grad(f)(params.ref);
+  [updates, optState] = solver.update(paramsGrad, optState);
+  params = applyUpdates(params, updates);
+}
+```
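For reference, the standard Adam update rule (Kingma and Ba) that an optimizer like `adam(1e-3)` is expected to follow is sketched below on plain `number[]` parameters, with the same objective and loop count as the example above. This is a from-scratch illustration, not `@jax-js/optax` source code, and it rests on several assumptions:

- The hyperparameters are the paper's defaults (β1 = 0.9, β2 = 0.999, ε = 1e-8). Whether `@jax-js/optax` uses the same defaults is not confirmed here.
- `squaredError(x, target)` computes `(x - target) ** 2` elementwise, as `optax.squared_error` does in Python.
- `applyUpdates` adds the updates to the parameters, as `optax.apply_updates` does.

```typescript
// From-scratch Adam on plain number arrays (illustration, not @jax-js/optax code).
type AdamState = { m: number[]; v: number[]; t: number };

function adamInit(params: number[]): AdamState {
  return { m: params.map(() => 0), v: params.map(() => 0), t: 0 };
}

// Returns additive updates (to be added to the params) and the new state.
function adamUpdate(
  grads: number[],
  state: AdamState,
  lr = 1e-3,
  b1 = 0.9,
  b2 = 0.999,
  eps = 1e-8,
): [number[], AdamState] {
  const t = state.t + 1;
  const m = grads.map((g, i) => b1 * state.m[i] + (1 - b1) * g); // 1st moment
  const v = grads.map((g, i) => b2 * state.v[i] + (1 - b2) * g * g); // 2nd moment
  const updates = m.map((mi, i) => {
    const mHat = mi / (1 - b1 ** t); // bias correction
    const vHat = v[i] / (1 - b2 ** t);
    return (-lr * mHat) / (Math.sqrt(vHat) + eps);
  });
  return [updates, { m, v, t }];
}

// Objective: sum((x - 1)^2), whose gradient is 2 * (x - 1).
const loss = (x: number[]) => x.reduce((s, xi) => s + (xi - 1) ** 2, 0);
const lossGrad = (x: number[]) => x.map((xi) => 2 * (xi - 1));

let params = [1.0, 2.0, 3.0];
let state = adamInit(params);
const before = loss(params);

for (let i = 0; i < 100; i++) {
  const [updates, next] = adamUpdate(lossGrad(params), state);
  state = next;
  params = params.map((p, j) => p + updates[j]); // like applyUpdates
}

console.log(before, loss(params)); // 5, then roughly 4.4
```

While the gradient keeps the same sign, `mHat / sqrt(vHat)` stays close to 1, so each coordinate moves by roughly `lr` per step. This scale-invariance is the main reason Adam's step size is easy to reason about.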
+
+**`@jax-js/loaders`** can load tensors from various formats like Safetensors, includes a fast and
+compliant implementation of BPE, and caches HTTP requests for large assets like model weights in
+OPFS.
+
+```ts
+import { tokenizers } from "@jax-js/loaders";
+
+const enc = await tokenizers.getBpe("clip");
+const tokens = enc.encode("Hello, world!"); // => [ 49406, 3306, 267, 1002, ... ]
 ```
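Byte-pair encoding works by repeatedly merging the adjacent symbol pair with the best rank in a learned merge table. The sketch below shows that core loop on a made-up three-entry table. It is a simplified illustration, not `@jax-js/loaders` code. It has no byte-level pre-tokenization, no vocabulary lookup, and no special start/end tokens.

```typescript
// Simplified BPE merge loop (illustration only, not @jax-js/loaders code).
// `merges` maps a pair "a b" to its rank; lower rank = merged first.
function bpe(word: string, merges: Map<string, number>): string[] {
  let symbols = Array.from(word);
  while (symbols.length > 1) {
    // Find the adjacent pair with the best (lowest) rank.
    let best = -1;
    let bestRank = Infinity;
    for (let i = 0; i < symbols.length - 1; i++) {
      const rank = merges.get(`${symbols[i]} ${symbols[i + 1]}`);
      if (rank !== undefined && rank < bestRank) {
        bestRank = rank;
        best = i;
      }
    }
    if (best < 0) break; // No mergeable pair left.

    // Merge every occurrence of that pair, scanning left to right.
    const a = symbols[best];
    const b = symbols[best + 1];
    const next: string[] = [];
    for (let i = 0; i < symbols.length; i++) {
      if (i < symbols.length - 1 && symbols[i] === a && symbols[i + 1] === b) {
        next.push(a + b);
        i++; // Skip the second half of the merged pair.
      } else {
        next.push(symbols[i]);
      }
    }
    symbols = next;
  }
  return symbols;
}

// Hypothetical merge table; real tokenizers learn tens of thousands of merges.
const merges = new Map([
  ["l l", 0],
  ["h e", 1],
  ["he ll", 2],
]);
console.log(bpe("hello", merges)); // ["hell", "o"]
```

In a full tokenizer, each resulting symbol would then be mapped to an integer id through the vocabulary. This mapping is what produces token lists like the one in the `getBpe("clip")` example.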
 
 ### Performance
@@ -77,22 +293,28 @@ script tag. This is the easiest way to get started on a blank HTML page.
 We haven't spent a ton of time optimizing yet, but performance is generally pretty good. `jit` is
 very helpful for fusing operations together, and it's a feature only available on the web in jax-js.
 The default kernel-tuning heuristics get about 3000 GFLOP/s for matrix multiplication on an M4 Pro
-chip ([try it](https://
+chip ([try it](https://jax-js.com/bench/matmul)).
 
 For that example, it's around the same GFLOP/s as
 [TensorFlow.js](https://github.com/tensorflow/tfjs) and
 [ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
 libraries of custom kernels (versus jax-js, which generates kernels with an ML compiler).
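The 3000 GFLOP/s figure can be converted into wall-clock time with a common convention. That convention counts an M×K by K×N matrix multiplication as 2·M·N·K floating-point operations: one multiply and one add per inner-product term. The arithmetic below assumes that convention. The 4096×4096 size is a hypothetical example, not necessarily what the linked benchmark measures.

```typescript
// FLOP count for C = A @ B with A: M x K and B: K x N, counting one multiply
// and one add per inner-product term (a common convention, assumed here).
function matmulFlops(m: number, n: number, k: number): number {
  return 2 * m * n * k;
}

// Throughput in GFLOP/s given a measured time in seconds.
function gflops(m: number, n: number, k: number, seconds: number): number {
  return matmulFlops(m, n, k) / seconds / 1e9;
}

// Time a matmul would take at a given throughput in GFLOP/s.
function secondsAt(m: number, n: number, k: number, gflopsRate: number): number {
  return matmulFlops(m, n, k) / (gflopsRate * 1e9);
}

const n = 4096; // hypothetical square matrix size
const t = secondsAt(n, n, n, 3000);
console.log(t.toFixed(4)); // "0.0458" (seconds at 3000 GFLOP/s)
console.log(gflops(n, n, n, t).toFixed(0)); // "3000"
```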
 
+### API Reference
+
+That's all for this short tutorial. Please see the generated
+[API reference](https://jax-js.com/docs) for detailed documentation.
+
 ## Examples
 
 If you make something cool with jax-js, don't be a stranger! We can feature it here.
 
-- [
-- [
-- [
-- [
-- [
+- [Training neural networks on MNIST](https://jax-js.com/mnist)
+- [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
+- [In-browser REPL](https://jax-js.com/repl)
+- [Matmul benchmark](https://jax-js.com/bench/matmul)
+- [Conv2d benchmark](https://jax-js.com/bench/conv2d)
+- [Mandelbrot set](https://jax-js.com/mandelbrot)
 
 ## Development
 
@@ -124,59 +346,12 @@ pnpm -C website dev
 
 ## Future work / help wanted
 
-Contributions are welcomed in
+Contributions are welcomed! Especially in:
 
 - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
 - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
-  and multithreading.
-- Helping the JIT compiler to fuse operations in more cases
+  and multithreading. (Even single-threaded Wasm could be ~20x faster.)
+- Helping the JIT compiler to fuse operations in more cases, like `tanh` branches and adding
+  epilogue to reductions.
 - Adding WebGL runtime for older browsers that don't support WebGPU.
 - Making a fast transformer inference engine, comparing against onnxruntime-web.
-- Ergonomics and API improvements.
-
-## Next on Eric's mind
-
-- Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
-- Performance
-  - Improve perf of MobileCLIP neural network
-  - Add fused epilogue to JIT
-  - Fix fusion of activation functions with branches like tanh
-  - Reduce kernel overhead of constants / inline expressions
-  - How many threads to create per workgroup, depends on hardware
-
-## Milestones
-
-- [x] It works!
-- [x] Demos: Browser REPL / editor
-- [x] First custom kernel
-- [x] Custom WebGPU backend, removing tfjs dependency
-- [x] Low-level operations
-- [x] Create `class Array {}` wrappers
-- [x] Reduction operations
-- [ ] Kernel tuning (see `tuner.ts`)
-- [x] "Upcast" optimizations (compute a tile per thread, e.g., matmul)
-- [x] "Unroll" optimizations (multiple loop iters per thread, e.g., matmul)
-- [ ] "Group" optimizations (multiple threads per value, e.g., matvec)
-- [ ] Blocks respect local dimensions
-- [x] Other dtypes like int32 and bool
-- [x] `jit()` support via Jaxprs and kernel fusion
-- [x] We figure out the `dispose()` / refcount / linear types stuff
-- [x] `dispose()` for saved "const" tracers in Jaxprs
-- [x] Garbage collection for JIT programs
-- [x] Debug grad-grad-jit test producing a UseAfterFreeError
-- [ ] Demos: Navier-Stokes, neural networks, statistics
-- [x] Features for neural networks
-- [x] Convolution
-- [x] Random and initializers
-- [x] Optimizers (optax package?)
-- [x] Wasm backend (needs malloc)
-- [x] Better memory allocation that frees buffers
-- [ ] SIMD support for Wasm backend
-- [ ] Async / multithreading Wasm support
-- [ ] Full support of weak types and committed devices
-- [x] High-level ops have automatic type promotion
-- [x] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
-- [ ] Committed devices -
-  [ref](https://docs.jax.dev/en/latest/sharded-computation.html#sharded-data-placement)
-- [ ] Device switching with `device_put()` between webgpu/cpu/wasm
-- [x] numpy/jax API compatibility table