@jax-js/jax 0.0.2 → 0.0.3

package/README.md CHANGED
@@ -1,6 +1,6 @@
  # jax-js: JAX in pure JavaScript
 
- [Website](https://www.ekzhang.com/jax-js/)
+ [Website](https://www.ekzhang.com/jax-js/) | [API Reference](https://www.ekzhang.com/jax-js/docs/)
 
  This is a machine learning framework for the browser. It aims to bring JAX-style, high-performance
  CPU and GPU kernels to JavaScript, so you can run numerical applications on the web.
@@ -65,15 +65,14 @@ pnpm test
 
  ## Next on Eric's mind
 
+ - Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
  - Fix jit-of-grad returning very incorrect result
- - Probably add static_argnums to jit() so that clip and some nn functions have jit added
  - Improve perf of MNIST neural network
- - Adding fused reductions to JIT
+ - Optimize conv2d further (maybe blocks -> local dims?)
+ - Add fused epilogue to JIT
  - Reduce kernel overhead of constants / inline expressions
  - Investigate why jax-js Matmul is 2x slower on Safari TP than unroll kernel
  - How many threads to create per workgroup, depends on hardware
- - Think about two-stage `cumsum()`
- - Frontend transformations need to match backend type for pureArray() and zeros() calls
 
  ## Milestones
 
@@ -97,11 +96,13 @@ pnpm test
  - [ ] Memory scheduling, buffer allocation (can be tricky)
  - [ ] Demos: Navier-Stokes, neural networks, statistics
  - [x] Features for neural networks
- - [ ] Convolution
+ - [x] Convolution
  - [x] Random and initializers
- - [ ] Optimizers (optax package?)
- - [ ] Wasm backend (needs malloc)
+ - [x] Optimizers (optax package?)
+ - [x] Wasm backend (needs malloc)
+ - [x] Better memory allocation that frees buffers
  - [ ] SIMD support for Wasm backend
+ - [ ] Async / multithreading Wasm support
  - [ ] Device switching with `.to()` between webgpu/cpu/wasm
  - [ ] numpy/jax API compatibility table
  - [ ] Import tfjs models