@jax-js/jax 0.0.4 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -2,8 +2,9 @@
 
  [Website](https://www.ekzhang.com/jax-js/) | [API Reference](https://www.ekzhang.com/jax-js/docs/)
 
- This is a machine learning framework for the browser. It aims to bring JAX-style, high-performance
- CPU and GPU kernels to JavaScript, so you can run numerical applications on the web.
+ **jax-js** is a machine learning framework for the browser. It aims to bring JAX-style,
+ high-performance CPU and GPU kernels to JavaScript, so you can run numerical applications on the
+ web.
 
  ```bash
  npm i @jax-js/jax
@@ -12,6 +13,10 @@ npm i @jax-js/jax
  Under the hood, it translates array operations into a compiler representation, then synthesizes
  kernels in WebAssembly and WebGPU.
 
+ The library is written from scratch, with zero external dependencies. It maintains close API
+ compatibility with NumPy/JAX. Since everything runs client-side, jax-js is likely the most portable
+ GPU ML framework, since it runs anywhere a browser can run.
+
  ## Quickstart
 
  You can use `jax-js` as an array API, just like NumPy.
@@ -24,7 +29,7 @@ const x = np.array([1, 2, 3]);
  const y = x.mul(4); // [4, 8, 12]
  ```
 
- It also lets you take derivatives like in JAX.
+ It also lets you take derivatives with `grad` like in JAX (as well as `vmap`, `jit`).
 
  ```js
  import { grad, numpy as np } from "@jax-js/jax";
@@ -37,11 +42,14 @@ const xnorm = norm(x.ref); // 1^2 + 2^2 + 3^2 = 14
  const xgrad = grad(norm)(x); // [2, 4, 6]
  ```
 
- The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu), you can
- switch to GPU for better performance.
+ The default backend runs on CPU, but on [supported browsers](https://caniuse.com/webgpu) including
+ Chrome and iOS Safari, you can switch to GPU for better performance.
 
  ```js
- import { defaultDevice, numpy as np } from "@jax-js/jax";
+ import { defaultDevice, init, numpy as np } from "@jax-js/jax";
+
+ // Initialize the GPU backend.
+ await init("webgpu");
 
  // Change the default backend to GPU.
  defaultDevice("webgpu");
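The quickstart's `grad` example states that the gradient of `norm` (the sum of squares, `1^2 + 2^2 + 3^2 = 14`) at `[1, 2, 3]` is `[2, 4, 6]`. As a sanity check that does not depend on jax-js at all, a central finite-difference approximation in plain JavaScript recovers the same values. `normSq` and `numericalGrad` are hypothetical helpers for illustration, not part of the jax-js API.

```javascript
// Plain-JS sketch (no jax-js): approximate the gradient of a scalar function
// with central differences, to check what `grad(norm)` should return.
function normSq(x) {
  // Sum of squares, mirroring the README's `norm` (1^2 + 2^2 + 3^2 = 14).
  return x.reduce((acc, v) => acc + v * v, 0);
}

function numericalGrad(f, x, h = 1e-4) {
  return x.map((_, i) => {
    const plus = x.slice();
    const minus = x.slice();
    plus[i] += h;
    minus[i] -= h;
    // Central difference: (f(x + h*e_i) - f(x - h*e_i)) / (2h)
    return (f(plus) - f(minus)) / (2 * h);
  });
}

const g = numericalGrad(normSq, [1, 2, 3]);
console.log(g.map((v) => Number(v.toFixed(6)))); // [ 2, 4, 6 ]
```

Because the function is quadratic, the central difference matches the analytic gradient `2x` up to floating-point rounding.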
@@ -53,8 +61,43 @@ const y = np.dot(x.ref, x); // JIT-compiled into a matrix multiplication kernel
  Most common JAX APIs are supported. See the [compatibility table](./FEATURES.md) for a full
  breakdown of what features are available.
 
+ ### Web usage (CDN)
+
+ If you want to use `jax-js` in vanilla JavaScript (without a bundler), just import from a module
+ script tag. This is the easiest way to get started on a blank HTML page.
+
+ ```html
+ <script type="module">
+ import { numpy as np } from "https://esm.sh/@jax-js/jax";
+ </script>
+ ```
+
+ ### Performance
+
+ We haven't spent a ton of time optimizing yet, but performance is generally pretty good. `jit` is
+ very helpful for fusing operations together, and it's a feature only available on the web in jax-js.
+ The default kernel-tuning heuristics get about 3000 GFLOP/s for matrix multiplication on an M4 Pro
+ chip ([try it](https://www.ekzhang.com/jax-js/bench/matmul)).
+
+ For that example, it's around the same GFLOP/s as
+ [TensorFlow.js](https://github.com/tensorflow/tfjs) and
+ [ONNX Runtime Web](https://www.npmjs.com/package/onnxruntime-web), which both use handwritten
+ libraries of custom kernels (versus jax-js, which generates kernels with an ML compiler).
+
+ ## Examples
+
+ If you make something cool with jax-js, don't be a stranger! We can feature it here.
+
+ - [In-browser REPL](https://www.ekzhang.com/jax-js/repl)
+ - [Interactive MNIST training](https://www.ekzhang.com/jax-js/mnist)
+ - [Matmul benchmark](https://www.ekzhang.com/jax-js/bench/matmul)
+ - [Conv2d benchmark](https://www.ekzhang.com/jax-js/bench/conv2d)
+ - [Mandelbrot set](https://www.ekzhang.com/jax-js/mandelbrot)
+
  ## Development
 
+ _The following technical details are for contributing to jax-js and modifying its internals._
+
  This repository is managed by [`pnpm`](https://pnpm.io/). You can compile and build all packages in
  watch mode with:
 
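The Performance section quotes matrix-multiplication throughput in GFLOP/s. For reference, multiplying two n × n matrices takes about 2n³ floating-point operations (n³ multiplies and roughly n³ adds), so throughput is 2n³ divided by elapsed time. Taking n = 4096 purely as an illustrative size (the linked benchmark's actual sizes are not stated in this diff), 3000 GFLOP/s corresponds to roughly 46 ms per multiplication:

```latex
\text{throughput} = \frac{2n^3}{t},
\qquad
n = 4096:\quad 2n^3 = 2 \cdot 4096^3 \approx 1.374 \times 10^{11}\ \text{FLOP},
\qquad
t \approx \frac{1.374 \times 10^{11}\ \text{FLOP}}{3000 \times 10^{9}\ \text{FLOP/s}} \approx 45.8\ \text{ms}
```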
@@ -70,8 +113,8 @@ pnpm exec playwright install
  pnpm test
  ```
 
- _We are currently on an older version of Playwright that supports using WebGPU in headless mode;
- newer versions seem to skip the WebGPU tests._
+ We are currently on an older version of Playwright that supports using WebGPU in headless mode;
+ newer versions skip the WebGPU tests.
 
  To start a Vite dev server running the website, demos and REPL:
 
@@ -79,15 +122,26 @@ To start a Vite dev server running the website, demos and REPL:
  pnpm -C website dev
  ```
 
+ ## Future work / help wanted
+
+ Contributions are welcomed in the following areas:
+
+ - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
+ - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
+ and multithreading.
+ - Helping the JIT compiler to fuse operations in more cases.
+ - Adding WebGL runtime for older browsers that don't support WebGPU.
+ - Making a fast transformer inference engine, comparing against onnxruntime-web.
+ - Ergonomics and API improvements.
+
  ## Next on Eric's mind
 
  - Finish CLIP inference demo and associated features (depthwise convolution, vmap of gather, etc.)
  - Performance
- - Improve perf of MNIST neural network
- - Optimize conv2d further (maybe blocks -> local dims?)
+ - Improve perf of MobileCLIP neural network
  - Add fused epilogue to JIT
+ - Fix fusion of activation functions with branches like tanh
  - Reduce kernel overhead of constants / inline expressions
- - Investigate why jax-js Matmul is 2x slower on Safari TP than unroll kernel
  - How many threads to create per workgroup, depends on hardware
 
  ## Milestones
@@ -120,20 +174,9 @@ pnpm -C website dev
  - [ ] SIMD support for Wasm backend
  - [ ] Async / multithreading Wasm support
  - [ ] Full support of weak types and committed devices
- - [ ] High-level ops have automatic type promotion
- - [ ] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
+ - [x] High-level ops have automatic type promotion
+ - [x] Weak types - [ref](https://docs.jax.dev/en/latest/type_promotion.html#weak-types)
  - [ ] Committed devices -
  [ref](https://docs.jax.dev/en/latest/sharded-computation.html#sharded-data-placement)
  - [ ] Device switching with `device_put()` between webgpu/cpu/wasm
  - [x] numpy/jax API compatibility table
-
- ## Future work / help wanted
-
- Contributions are welcomed in the following areas:
-
- - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
- - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, using SIMD and
- multithreading.
- - Adding WebGL runtime for older browsers that don't support WebGPU.
- - Making a fast transformer inference engine, comparing against onnxruntime-web.
- - Ergonomics and API improvements.
@@ -206,6 +206,35 @@ function findPow2(hint, max) {
    while (ret < hint && 2 * ret <= max) ret *= 2;
    return ret;
  }
+ /**
+ * Implements a NumPy-style generalized broadcast rule on two array shapes.
+ *
+ * "When operating on two arrays, NumPy compares their shapes element-wise. It
+ * starts with the trailing (i.e. rightmost) dimension and works its way left.
+ * Two dimensions are compatible when:
+ * 1. they are equal, or
+ * 2. one of them is 1."
+ *
+ * Throws a TypeError if the broadcast is not possible.
+ *
+ * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
+ */
+ function generalBroadcast(a, b) {
+   const out = [];
+   let i = a.length - 1;
+   let j = b.length - 1;
+   for (; i >= 0 && j >= 0; i--, j--) {
+     const x = a[i];
+     const y = b[j];
+     if (x === y) out.push(x);
+     else if (x === 1) out.push(y);
+     else if (y === 1) out.push(x);
+     else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
+   }
+   for (; i >= 0; i--) out.push(a[i]);
+   for (; j >= 0; j--) out.push(b[j]);
+   return out.reverse();
+ }
  function recursiveFlatten(ar) {
    if (!Array.isArray(ar)) return [ar];
    return ar.flat(Infinity);
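To see the new broadcast rule on concrete shapes, the sketch below runs a standalone copy of the `generalBroadcast` body from this hunk on shapes similar to those in NumPy's broadcasting guide. It copies the function rather than importing it, because this diff does not show which public entry point of `@jax-js/jax`, if any, re-exports it.

```javascript
// Standalone copy of generalBroadcast from the 0.0.5 diff, exercised on sample shapes.
function generalBroadcast(a, b) {
  const out = [];
  let i = a.length - 1;
  let j = b.length - 1;
  // Walk both shapes from the trailing (rightmost) dimension leftward.
  for (; i >= 0 && j >= 0; i--, j--) {
    const x = a[i];
    const y = b[j];
    if (x === y) out.push(x);
    else if (x === 1) out.push(y);
    else if (y === 1) out.push(x);
    else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
  }
  // Leading dimensions of the longer shape pass through unchanged.
  for (; i >= 0; i--) out.push(a[i]);
  for (; j >= 0; j--) out.push(b[j]);
  return out.reverse();
}

console.log(generalBroadcast([8, 1, 6, 1], [7, 1, 5])); // [ 8, 7, 6, 5 ]
console.log(generalBroadcast([5, 4], [1])); // [ 5, 4 ]
console.log(generalBroadcast([], [3])); // [ 3 ]
try {
  generalBroadcast([2, 1], [8, 4, 3]);
} catch (e) {
  console.log(e.message); // Incompatible array broadcast shapes: 2,1 vs 8,4,3
}
```

Since `out` is filled from the rightmost dimension, the final `reverse()` restores the usual left-to-right shape order.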
@@ -294,12 +323,12 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
  * **Type lattice:**
  * ```text
  * bool -> uint32 -> int32 -> float16 -> float32
- * weak f* --^
+ * weakType --^
  * ```
  *
- * The asterisk f* is a weak type used for JS number constants. When creating
- * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
- * any array they are first combined with.
+ * `weakType` represents weakly typed arrays. These are created for JS numbers,
+ * which default to float32 but "weak" so they cast to the dtype of any array
+ * they are first combined with, except `bool`.
  *
  * **Examples:**
  * - `promoteTypes(bool, int32) → int32`
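One reading of the updated docstring: strong dtypes form a chain, so promoting two of them picks whichever comes later in `bool -> uint32 -> int32 -> float16 -> float32`, while a weak value adopts the other operand's dtype. The sketch below encodes that reading. `promoteSketch`, the `"weak"` tag, and the string dtype names are hypothetical and are not the library's `promoteTypes` or `DType`; returning `float32` when a weak value meets `bool` is an assumption drawn from "default to float32 ... except `bool`".

```javascript
// Hypothetical sketch of the promotion rule described in the promoteTypes docstring.
// Strong dtypes form a chain, so promotion picks the later of the two.
const CHAIN = ["bool", "uint32", "int32", "float16", "float32"];

function promoteSketch(a, b) {
  if (a === "weak" && b === "weak") return "weak"; // both operands came from JS numbers
  if (a === "weak" || b === "weak") {
    const other = a === "weak" ? b : a;
    // Assumption: a weak value adopts the other dtype, except bool,
    // where it keeps its float32 default.
    return other === "bool" ? "float32" : other;
  }
  const ia = CHAIN.indexOf(a);
  const ib = CHAIN.indexOf(b);
  if (ia < 0 || ib < 0) throw new TypeError(`Unknown dtype: ${a}, ${b}`);
  return CHAIN[Math.max(ia, ib)];
}

console.log(promoteSketch("bool", "int32")); // int32 (matches the docstring example)
console.log(promoteSketch("weak", "float16")); // float16
console.log(promoteSketch("weak", "bool")); // float32
```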
@@ -3760,7 +3789,7 @@ async function createBackend(device) {
  if (!navigator.gpu) return null;
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
  if (!adapter) return null;
- const { WebGPUBackend } = await import("./webgpu-ow0Pn_6q.js");
+ const { WebGPUBackend } = await import("./webgpu-CM-xNYzW.js");
  const importantLimits = [
    "maxBufferSize",
    "maxComputeInvocationsPerWorkgroup",
@@ -3813,4 +3842,4 @@ var UnsupportedOpError = class extends Error {
  };
 
  //#endregion
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
@@ -207,6 +207,35 @@ function findPow2(hint, max) {
    while (ret < hint && 2 * ret <= max) ret *= 2;
    return ret;
  }
+ /**
+ * Implements a NumPy-style generalized broadcast rule on two array shapes.
+ *
+ * "When operating on two arrays, NumPy compares their shapes element-wise. It
+ * starts with the trailing (i.e. rightmost) dimension and works its way left.
+ * Two dimensions are compatible when:
+ * 1. they are equal, or
+ * 2. one of them is 1."
+ *
+ * Throws a TypeError if the broadcast is not possible.
+ *
+ * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
+ */
+ function generalBroadcast(a, b) {
+   const out = [];
+   let i = a.length - 1;
+   let j = b.length - 1;
+   for (; i >= 0 && j >= 0; i--, j--) {
+     const x = a[i];
+     const y = b[j];
+     if (x === y) out.push(x);
+     else if (x === 1) out.push(y);
+     else if (y === 1) out.push(x);
+     else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
+   }
+   for (; i >= 0; i--) out.push(a[i]);
+   for (; j >= 0; j--) out.push(b[j]);
+   return out.reverse();
+ }
  function recursiveFlatten(ar) {
    if (!Array.isArray(ar)) return [ar];
    return ar.flat(Infinity);
@@ -295,12 +324,12 @@ const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float
  * **Type lattice:**
  * ```text
  * bool -> uint32 -> int32 -> float16 -> float32
- * weak f* --^
+ * weakType --^
  * ```
  *
- * The asterisk f* is a weak type used for JS number constants. When creating
- * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
- * any array they are first combined with.
+ * `weakType` represents weakly typed arrays. These are created for JS numbers,
+ * which default to float32 but "weak" so they cast to the dtype of any array
+ * they are first combined with, except `bool`.
  *
  * **Examples:**
  * - `promoteTypes(bool, int32) → int32`
@@ -3761,7 +3790,7 @@ async function createBackend(device) {
  if (!navigator.gpu) return null;
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
  if (!adapter) return null;
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVdMaO9T.cjs"));
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CNOpiO5T.cjs"));
  const importantLimits = [
    "maxBufferSize",
    "maxComputeInvocationsPerWorkgroup",
@@ -3958,6 +3987,12 @@ Object.defineProperty(exports, 'findPow2', {
      return findPow2;
    }
  });
+ Object.defineProperty(exports, 'generalBroadcast', {
+   enumerable: true,
+   get: function () {
+     return generalBroadcast;
+   }
+ });
  Object.defineProperty(exports, 'getBackend', {
    enumerable: true,
    get: function () {