@jax-js/jax 0.0.3 → 0.0.4

package/README.md CHANGED
@@ -37,42 +37,58 @@ const xnorm = norm(x.ref); // 1^2 + 2^2 + 3^2 = 14
  const xgrad = grad(norm)(x); // [2, 4, 6]
  ```
 
- The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu),
- you can switch to GPU for maximum performance.
+ The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu), you can
+ switch to GPU for better performance.
 
  ```js
- import { numpy as np, setDevice } from "@jax-js/jax";
+ import { defaultDevice, numpy as np } from "@jax-js/jax";
 
  // Change the default backend to GPU.
- setDevice("webgpu");
+ defaultDevice("webgpu");
 
  const x = np.ones([4096, 4096]);
  const y = np.dot(x.ref, x); // JIT-compiled into a matrix multiplication kernel
  ```
 
+ Most common JAX APIs are supported. See the [compatibility table](./FEATURES.md) for a full
+ breakdown of what features are available.
+
  ## Development
 
- Under construction.
+ This repository is managed by [`pnpm`](https://pnpm.io/). You can compile and build all packages in
+ watch mode with:
 
  ```bash
  pnpm install
  pnpm run build:watch
+ ```
+
+ Then you can run tests in a headless browser using [Vitest](https://vitest.dev/).
 
- # Run tests
+ ```bash
  pnpm exec playwright install
  pnpm test
  ```
 
+ _We are currently on an older version of Playwright that supports using WebGPU in headless mode;
+ newer versions seem to skip the WebGPU tests._
+
+ To start a Vite dev server running the website, demos and REPL:
+
+ ```bash
+ pnpm -C website dev
+ ```
+
  ## Next on Eric's mind
 
  - Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
- - Fix jit-of-grad returning very incorrect result
- - Improve perf of MNIST neural network
- - Optimize conv2d further (maybe blocks -> local dims?)
- - Add fused epilogue to JIT
- - Reduce kernel overhead of constants / inline expressions
- - Investigate why jax-js Matmul is 2x slower on Safari TP than unroll kernel
- - How many threads to create per workgroup, depends on hardware
+ - Performance
+ - Improve perf of MNIST neural network
+ - Optimize conv2d further (maybe blocks -> local dims?)
+ - Add fused epilogue to JIT
+ - Reduce kernel overhead of constants / inline expressions
+ - Investigate why jax-js Matmul is 2x slower on Safari TP than unroll kernel
+ - How many threads to create per workgroup, depends on hardware
 
  ## Milestones
 
@@ -91,9 +107,9 @@ pnpm test
  - [x] Other dtypes like int32 and bool
  - [x] `jit()` support via Jaxprs and kernel fusion
  - [x] We figure out the `dispose()` / refcount / linear types stuff
- - [ ] `dispose()` for saved "const" tracers in Jaxprs
- - [ ] Garbage collection for JIT programs
- - [ ] Memory scheduling, buffer allocation (can be tricky)
+ - [x] `dispose()` for saved "const" tracers in Jaxprs
+ - [x] Garbage collection for JIT programs
+ - [x] Debug grad-grad-jit test producing a UseAfterFreeError
  - [ ] Demos: Navier-Stokes, neural networks, statistics
  - [x] Features for neural networks
  - [x] Convolution
@@ -103,6 +119,21 @@ pnpm test
  - [x] Better memory allocation that frees buffers
  - [ ] SIMD support for Wasm backend
  - [ ] Async / multithreading Wasm support
- - [ ] Device switching with `.to()` between webgpu/cpu/wasm
- - [ ] numpy/jax API compatibility table
- - [ ] Import tfjs models
+ - [ ] Full support of weak types and committed devices
+ - [ ] High-level ops have automatic type promotion
+ - [ ] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
+ - [ ] Committed devices -
+ [ref](https://docs.jax.dev/en/latest/sharded-computation.html#sharded-data-placement)
+ - [ ] Device switching with `device_put()` between webgpu/cpu/wasm
+ - [x] numpy/jax API compatibility table
+
+ ## Future work / help wanted
+
+ Contributions are welcomed in the following areas:
+
+ - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
+ - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, using SIMD and
+ multithreading.
+ - Adding WebGL runtime for older browsers that don't support WebGPU.
+ - Making a fast transformer inference engine, comparing against onnxruntime-web.
+ - Ergonomics and API improvements.
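
Reviewer note on the first README hunk: the example there now calls `defaultDevice("webgpu")` instead of `setDevice("webgpu")`, and the surrounding text says WebGPU is only available on supported browsers. A minimal sketch of guarding that call with feature detection might look like the following. `pickDevice` is a hypothetical helper and not part of `@jax-js/jax`, and the `"cpu"` device name is an assumption taken from the backends named in the milestones (webgpu/cpu/wasm).

```typescript
// Hypothetical helper (not part of @jax-js/jax): choose a device name by
// checking whether the environment exposes the WebGPU entry point.
type DeviceName = "webgpu" | "cpu";

function pickDevice(nav: object | undefined): DeviceName {
  // Browsers with WebGPU expose a `gpu` property on `navigator`.
  const gpu = (nav as { gpu?: unknown } | undefined)?.gpu;
  if (gpu !== undefined && gpu !== null) {
    return "webgpu";
  }
  // Assumption: "cpu" names the default backend described in the README.
  return "cpu";
}

// Usage in a browser, with the 0.0.4 import shown in the diff
// (`import { defaultDevice } from "@jax-js/jax"`):
//   defaultDevice(pickDevice(globalThis.navigator));
```

The presence of `navigator.gpu` only shows that the API is exposed. Requesting an adapter can still fail on some hardware, so a stricter check would also await `navigator.gpu.requestAdapter()` and fall back when it resolves to `null`.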