@jax-js/jax 0.0.2 → 0.0.4
- package/README.md +57 -25
- package/dist/backend-EBRGmEYw.js +3816 -0
- package/dist/{backend-BK21PBVP.cjs → backend-Ss1Mev_-.cjs} +2075 -107
- package/dist/index.cjs +1393 -250
- package/dist/index.d.cts +651 -102
- package/dist/index.d.ts +651 -102
- package/dist/index.js +1377 -245
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-BVdMaO9T.cjs} +62 -35
- package/dist/{webgpu-JVpVad6g.js → webgpu-ow0Pn_6q.js} +62 -35
- package/package.json +21 -9
- package/dist/backend-1eVbAoaV.js +0 -1890
package/README.md
@@ -1,6 +1,6 @@
 # jax-js: JAX in pure JavaScript
 
-[Website](https://www.ekzhang.com/jax-js/)
+[Website](https://www.ekzhang.com/jax-js/) | [API Reference](https://www.ekzhang.com/jax-js/docs/)
 
 This is a machine learning framework for the browser. It aims to bring JAX-style, high-performance
 CPU and GPU kernels to JavaScript, so you can run numerical applications on the web.
@@ -37,43 +37,58 @@ const xnorm = norm(x.ref); // 1^2 + 2^2 + 3^2 = 14
 const xgrad = grad(norm)(x); // [2, 4, 6]
 ```
 
-The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu),
-
+The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu), you can
+switch to GPU for better performance.
 
 ```js
-import { numpy as np
+import { defaultDevice, numpy as np } from "@jax-js/jax";
 
 // Change the default backend to GPU.
-
+defaultDevice("webgpu");
 
 const x = np.ones([4096, 4096]);
 const y = np.dot(x.ref, x); // JIT-compiled into a matrix multiplication kernel
 ```
 
+Most common JAX APIs are supported. See the [compatibility table](./FEATURES.md) for a full
+breakdown of what features are available.
+
 ## Development
 
-
+This repository is managed by [`pnpm`](https://pnpm.io/). You can compile and build all packages in
+watch mode with:
 
 ```bash
 pnpm install
 pnpm run build:watch
+```
+
+Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
 
-
+```bash
 pnpm exec playwright install
 pnpm test
 ```
 
+_We are currently on an older version of Playwright that supports using WebGPU in headless mode;
+newer versions seem to skip the WebGPU tests._
+
+To start a Vite dev server running the website, demos and REPL:
+
+```bash
+pnpm -C website dev
+```
+
 ## Next on Eric's mind
 
--
--
-- Improve perf of MNIST neural network
-
-
--
--
--
-- Frontend transformations need to match backend type for pureArray() and zeros() calls
+- Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
+- Performance
+- Improve perf of MNIST neural network
+- Optimize conv2d further (maybe blocks -> local dims?)
+- Add fused epilogue to JIT
+- Reduce kernel overhead of constants / inline expressions
+- Investigate why jax-js Matmul is 2x slower on Safari TP than unroll kernel
+- How many threads to create per workgroup, depends on hardware
 
 ## Milestones
 
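The hunk header and first context line quote the README's running example: `norm(x.ref)` evaluates to 14 and `grad(norm)(x)` to `[2, 4, 6]`. The sketch below reproduces those numbers in plain JavaScript without calling jax-js, assuming `norm` is the sum of squares as the comment `1^2 + 2^2 + 3^2 = 14` implies. `numericalGrad` is a hypothetical helper that estimates the gradient with central finite differences.

```javascript
// Plain JavaScript, no jax-js: a numerical check of the README example.
// Assumption: norm(x) is the sum of squares, matching "1^2 + 2^2 + 3^2 = 14".
function norm(x) {
  return x.reduce((acc, v) => acc + v * v, 0);
}

// Hypothetical helper: central-difference gradient,
// df/dx_i ~= (f(x + h*e_i) - f(x - h*e_i)) / (2h).
function numericalGrad(f, x, h = 1e-5) {
  return x.map((_, i) => {
    const plus = x.slice();
    const minus = x.slice();
    plus[i] += h;
    minus[i] -= h;
    return (f(plus) - f(minus)) / (2 * h);
  });
}

const x = [1, 2, 3];
const value = norm(x); // 14
const g = numericalGrad(norm, x); // close to [2, 4, 6], up to rounding error
console.log(value, g);
```

For a quadratic, the central difference equals the true derivative apart from floating-point rounding, which makes it a convenient sanity check on the expected `[2, 4, 6]`. It is only a numerical check, not the automatic differentiation that `grad()` performs.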
@@ -92,16 +107,33 @@ pnpm test
 - [x] Other dtypes like int32 and bool
 - [x] `jit()` support via Jaxprs and kernel fusion
 - [x] We figure out the `dispose()` / refcount / linear types stuff
-- [
-- [
-- [
+- [x] `dispose()` for saved "const" tracers in Jaxprs
+- [x] Garbage collection for JIT programs
+- [x] Debug grad-grad-jit test producing a UseAfterFreeError
 - [ ] Demos: Navier-Stokes, neural networks, statistics
 - [x] Features for neural networks
-- [
+- [x] Convolution
 - [x] Random and initializers
-- [
-- [
+- [x] Optimizers (optax package?)
+- [x] Wasm backend (needs malloc)
+- [x] Better memory allocation that frees buffers
 - [ ] SIMD support for Wasm backend
-- [ ]
-- [ ]
-- [ ]
+- [ ] Async / multithreading Wasm support
+- [ ] Full support of weak types and committed devices
+- [ ] High-level ops have automatic type promotion
+- [ ] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
+- [ ] Committed devices -
+[ref](https://docs.jax.dev/en/latest/sharded-computation.html#sharded-data-placement)
+- [ ] Device switching with `device_put()` between webgpu/cpu/wasm
+- [x] numpy/jax API compatibility table
+
+## Future work / help wanted
+
+Contributions are welcomed in the following areas:
+
+- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
+- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, using SIMD and
+multithreading.
+- Adding WebGL runtime for older browsers that don't support WebGPU.
+- Making a fast transformer inference engine, comparing against onnxruntime-web.
+- Ergonomics and API improvements.
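The milestones track the `dispose()` / refcount work and a `UseAfterFreeError`, and the README writes `np.dot(x.ref, x)` rather than `np.dot(x, x)`. The sketch below is a toy model of that ownership style. It assumes, since this diff does not spell it out, that each operation consumes one reference to every argument and that `.ref` adds one. `ToyArray`, `dot`, and the `UseAfterFreeError` class here are hypothetical stand-ins, not jax-js source.

```javascript
// Toy model of consume-on-use arrays with an explicit `.ref` (hypothetical, not jax-js source).
class UseAfterFreeError extends Error {}

class ToyArray {
  #data;
  #refCount = 1;

  constructor(data) {
    this.#data = data;
  }

  get refCount() {
    return this.#refCount;
  }

  // Mirrors `x.ref` in the README: take one more reference, return the same array.
  get ref() {
    this.#assertAlive();
    this.#refCount += 1;
    return this;
  }

  // Read the underlying data without changing the reference count.
  read() {
    this.#assertAlive();
    return this.#data;
  }

  // Give up one reference; the buffer is freed when the count reaches zero.
  dispose() {
    this.#assertAlive();
    this.#refCount -= 1;
    if (this.#refCount === 0) this.#data = null;
  }

  #assertAlive() {
    if (this.#refCount <= 0) {
      throw new UseAfterFreeError("array used after its buffer was freed");
    }
  }
}

// Toy operation: reads both arguments, then consumes one reference to each.
function dot(a, b) {
  const av = a.read();
  const bv = b.read();
  const result = av.reduce((sum, v, i) => sum + v * bv[i], 0);
  a.dispose();
  b.dispose();
  return result;
}

// `x.ref` keeps x alive for the second argument, like np.dot(x.ref, x).
const x = new ToyArray([1, 2, 3]);
console.log(dot(x.ref, x)); // 14, and x is now freed

// Without `.ref`, the second dispose() hits an already-freed array.
const y = new ToyArray([1, 2, 3]);
try {
  dot(y, y);
} catch (e) {
  console.log(e instanceof UseAfterFreeError); // true
}
```

In this model a value used twice in one call needs exactly one extra `.ref`, and any later use of a freed array fails loudly instead of reading stale memory. The real library's rules may differ in detail.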