@jax-js/jax 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -1,6 +1,6 @@
  # jax-js: JAX in pure JavaScript
 
- [Website](https://www.ekzhang.com/jax-js/)
+ [Website](https://www.ekzhang.com/jax-js/) | [API Reference](https://www.ekzhang.com/jax-js/docs/)
 
  This is a machine learning framework for the browser. It aims to bring JAX-style, high-performance
  CPU and GPU kernels to JavaScript, so you can run numerical applications on the web.
@@ -37,43 +37,58 @@ const xnorm = norm(x.ref); // 1^2 + 2^2 + 3^2 = 14
  const xgrad = grad(norm)(x); // [2, 4, 6]
  ```
 
- The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu),
- you can switch to GPU for maximum performance.
+ The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu), you can
+ switch to GPU for better performance.
 
  ```js
- import { numpy as np, setDevice } from "@jax-js/jax";
+ import { defaultDevice, numpy as np } from "@jax-js/jax";
 
  // Change the default backend to GPU.
- setDevice("webgpu");
+ defaultDevice("webgpu");
 
  const x = np.ones([4096, 4096]);
  const y = np.dot(x.ref, x); // JIT-compiled into a matrix multiplication kernel
  ```
 
+ Most common JAX APIs are supported. See the [compatibility table](./FEATURES.md) for a full
+ breakdown of what features are available.
+
  ## Development
 
- Under construction.
+ This repository is managed by [`pnpm`](https://pnpm.io/). You can compile and build all packages in
+ watch mode with:
 
  ```bash
  pnpm install
  pnpm run build:watch
+ ```
+
+ Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
 
- # Run tests
+ ```bash
  pnpm exec playwright install
  pnpm test
  ```
 
+ _We are currently on an older version of Playwright that supports using WebGPU in headless mode;
+ newer versions seem to skip the WebGPU tests._
+
+ To start a Vite dev server running the website, demos and REPL:
+
+ ```bash
+ pnpm -C website dev
+ ```
+
  ## Next on Eric's mind
 
- - Fix jit-of-grad returning very incorrect result
- - Probably add static_argnums to jit() so that clip and some nn functions have jit added
- - Improve perf of MNIST neural network
- - Adding fused reductions to JIT
- - Reduce kernel overhead of constants / inline expressions
- - Investigate why jax-js Matmul is 2x slower on Safari TP than unroll kernel
- - How many threads to create per workgroup, depends on hardware
- - Think about two-stage `cumsum()`
- - Frontend transformations need to match backend type for pureArray() and zeros() calls
+ - Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
+ - Performance
+   - Improve perf of MNIST neural network
+   - Optimize conv2d further (maybe blocks -> local dims?)
+   - Add fused epilogue to JIT
+   - Reduce kernel overhead of constants / inline expressions
+   - Investigate why jax-js Matmul is 2x slower on Safari TP than unroll kernel
+   - How many threads to create per workgroup, depends on hardware
 
  ## Milestones
 
@@ -92,16 +107,33 @@ pnpm test
  - [x] Other dtypes like int32 and bool
  - [x] `jit()` support via Jaxprs and kernel fusion
  - [x] We figure out the `dispose()` / refcount / linear types stuff
- - [ ] `dispose()` for saved "const" tracers in Jaxprs
- - [ ] Garbage collection for JIT programs
- - [ ] Memory scheduling, buffer allocation (can be tricky)
+ - [x] `dispose()` for saved "const" tracers in Jaxprs
+ - [x] Garbage collection for JIT programs
+ - [x] Debug grad-grad-jit test producing a UseAfterFreeError
  - [ ] Demos: Navier-Stokes, neural networks, statistics
  - [x] Features for neural networks
- - [ ] Convolution
+ - [x] Convolution
  - [x] Random and initializers
- - [ ] Optimizers (optax package?)
- - [ ] Wasm backend (needs malloc)
+ - [x] Optimizers (optax package?)
+ - [x] Wasm backend (needs malloc)
+ - [x] Better memory allocation that frees buffers
  - [ ] SIMD support for Wasm backend
- - [ ] Device switching with `.to()` between webgpu/cpu/wasm
- - [ ] numpy/jax API compatibility table
- - [ ] Import tfjs models
+ - [ ] Async / multithreading Wasm support
+ - [ ] Full support of weak types and committed devices
+   - [ ] High-level ops have automatic type promotion
+   - [ ] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
+   - [ ] Committed devices -
+     [ref](https://docs.jax.dev/en/latest/sharded-computation.html#sharded-data-placement)
+ - [ ] Device switching with `device_put()` between webgpu/cpu/wasm
+ - [x] numpy/jax API compatibility table
+
+ ## Future work / help wanted
+
+ Contributions are welcomed in the following areas:
+
+ - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
+ - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, using SIMD and
+   multithreading.
+ - Adding WebGL runtime for older browsers that don't support WebGPU.
+ - Making a fast transformer inference engine, comparing against onnxruntime-web.
+ - Ergonomics and API improvements.
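
The unchanged `grad(norm)(x)` context line in the second README hunk states that the gradient of the sum-of-squares `norm` at `[1, 2, 3]` is `[2, 4, 6]`, because the derivative of x₁² + x₂² + x₃² with respect to xᵢ is 2xᵢ. As a rough, library-independent sanity check, the plain JavaScript sketch below estimates that gradient with central finite differences. It does not use jax-js, which computes gradients by automatic differentiation rather than numerically, and `numericalGrad` is a hypothetical helper introduced only for illustration.

```js
// Sum of squares, mirroring the README's `norm` example: norm([1, 2, 3]) = 14.
function norm(x) {
  return x.reduce((acc, v) => acc + v * v, 0);
}

// Hypothetical helper (not a jax-js API): estimates the gradient of a scalar
// function f at x using central differences, (f(x + h*e_i) - f(x - h*e_i)) / 2h.
function numericalGrad(f, x, h = 1e-5) {
  return x.map((_, i) => {
    const plus = x.slice();
    const minus = x.slice();
    plus[i] += h;
    minus[i] -= h;
    return (f(plus) - f(minus)) / (2 * h);
  });
}

const x = [1, 2, 3];
console.log(norm(x)); // 14
console.log(numericalGrad(norm, x)); // approximately [2, 4, 6]
```

For a quadratic like this one, central differences are exact up to floating-point rounding, so the estimate agrees closely with the `[2, 4, 6]` in the README comment; for general functions they only approximate what an autodiff `grad()` returns.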