@jax-js/jax 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -58,13 +58,13 @@ import { numpy as np } from "@jax-js/jax";
58
58
  const ar = np.array([1, 2, 3]);
59
59
  ```
60
60
 
61
- By default, this is a float32 array, but you can also specify a dtype explicitly:
61
+ By default, this is a float32 array, but you can specify a different dtype:
62
62
 
63
63
  ```ts
64
- const ar = np.array([1, 2, 3], { dtype: np.float32 });
64
+ const ar = np.array([1, 2, 3], { dtype: np.int32 });
65
65
  ```
66
66
 
67
- For more efficient construction, you can create an array from a JS `TypedArray` buffer:
67
+ For more efficient construction, create an array from a JS `TypedArray` buffer:
68
68
 
69
69
  ```ts
70
70
  const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
@@ -223,14 +223,18 @@ Note that you need to use `type` alias syntax rather than `interface` to define
223
223
  Similar to JAX, jax-js has a concept of "devices" which are a backend that stores Arrays in memory
224
224
  and determines how to execute compiled operations on them.
225
225
 
226
- There are currently 3 devices in jax-js:
226
+ There are currently 4 devices in jax-js:
227
227
 
228
- - `cpu`: Slow, mostly for debugging purposes.
228
+ - `cpu`: Slow, interpreted JS, only meant for debugging.
229
229
  - `wasm`: [WebAssembly](https://webassembly.org/), currently single-threaded and blocking.
230
230
  - `webgpu`: [WebGPU](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API), available on
231
231
  [supported browsers](https://caniuse.com/webgpu) (Chrome, Firefox, Safari, iOS).
232
+ - `webgl`: [WebGL2](https://developer.mozilla.org/en-US/docs/Web/API/WebGL2RenderingContext), via
233
+ fragment shaders. This is an older graphics API that runs on almost all browsers, but it is much
234
+ slower than WebGPU. It's offered on a best-effort basis and is not as well supported.
232
235
 
233
- The default device is `wasm`, but you can change this at startup time:
236
+ **We recommend `webgpu` for best performance, especially when running neural networks.** The default
237
+ device is `wasm`, but you can change this at startup time:
234
238
 
235
239
  ```ts
236
240
  import { defaultDevice, init } from "@jax-js/jax";
@@ -333,7 +337,6 @@ Contributions are welcomed! Some fruitful areas to look into:
333
337
  able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
334
338
  to help with model performance debugging.
335
339
  - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
336
- - Adding WebGL runtime for older browsers that don't support WebGPU.
337
340
  - Making a fast transformer inference engine, comparing against onnxruntime-web.
338
341
 
339
342
  You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
@@ -1470,10 +1470,37 @@ let Routines = /* @__PURE__ */ function(Routines$1) {
1470
1470
  Routines$1["Sort"] = "Sort";
1471
1471
  /** Returns `int32` indices of the stably sorted array. */
1472
1472
  Routines$1["Argsort"] = "Argsort";
1473
- /** Solve a triangular system of questions. */
1473
+ /**
1474
+ * Solve a triangular system of equations.
1475
+ *
1476
+ * The first batch of inputs `A` should be of shape `[..., N, N]` and upper
1477
+ * triangular, while the second batch `B` should be of shape `[..., M, N]`.
1478
+ *
1479
+ * Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
1480
+ * triangular matrix. This is equivalent to `X = B @ A^-T`.
1481
+ */
1474
1482
  Routines$1["TriangularSolve"] = "TriangularSolve";
1475
- /** Cholesky decomposition of 2D positive semi-definite matrices. */
1483
+ /**
1484
+ * Cholesky decomposition of 2D positive semi-definite matrices.
1485
+ *
1486
+ * The input batch should be of shape `[..., N, N]`, and the output batch is
1487
+ * of the same shape, containing the lower-triangular matrix `L` such that
1488
+ * `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
1489
+ */
1476
1490
  Routines$1["Cholesky"] = "Cholesky";
1491
+ /**
1492
+ * LU decomposition of 2D rectangular matrices.
1493
+ *
1494
+ * The input is a batch of shape `[..., M, N]`, and the output is a tuple of
1495
+ * three arrays: `LU, Pivots, Permutation`.
1496
+ *
1497
+ * - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
1498
+ * triangular matrices. (lower triangular = implicit unit diagonal)
1499
+ * - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
1500
+ * - `Permutation` is of shape `[..., M]`, containing the permutation vector
1501
+ * such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
1502
+ */
1503
+ Routines$1["LU"] = "LU";
1477
1504
  return Routines$1;
1478
1505
  }({});
1479
1506
  function runCpuRoutine(routine, inputs, outputs) {
@@ -1485,6 +1512,7 @@ function runCpuRoutine(routine, inputs, outputs) {
1485
1512
  case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
1486
1513
  case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
1487
1514
  case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
1515
+ case Routines.LU: return runLU(type, inputAr, outputAr);
1488
1516
  default:
1489
1517
  }
1490
1518
  }
@@ -1545,6 +1573,50 @@ function runCholesky(type, [x], [y]) {
1545
1573
  }
1546
1574
  }
1547
1575
  }
1576
/**
 * CPU implementation of the `LU` routine: LU decomposition with partial (row)
 * pivoting, applied independently to every matrix in a batch.
 *
 * The input shape is read from `type.inputShapes[0]` and is interpreted as
 * `[..., M, N]`; all leading dimensions are treated as batch dimensions.
 *
 * @param type Routine type descriptor; only `type.inputShapes[0]` is read.
 * @param inputs One-element list holding the flat, row-major input buffer `a`.
 * @param outputs Three flat output buffers `[lu, pivots, perm]`:
 *   - `lu` receives the combined factors per matrix (`M * N` entries each):
 *     the upper triangle holds `U`, the strict lower triangle holds the
 *     multipliers of `L` (its unit diagonal is implicit).
 *   - `pivots` receives, per matrix, `min(M, N)` entries: the row that was
 *     swapped into position `j` at elimination step `j`.
 *   - `perm` receives, per matrix, `M` entries: the identity permutation with
 *     the same row swaps applied.
 * @throws {Error} If the input shape has fewer than two dimensions.
 */
function runLU(type, [a], [lu, pivots, perm]) {
  const shape = type.inputShapes[0];
  if (shape.length < 2) throw new Error("lu: input must be at least 2D");
  const m = shape[shape.length - 2];
  const n = shape[shape.length - 1];
  const r = Math.min(m, n);
  const matrixSize = m * n;

  // Derive the number of matrices from the leading dimensions instead of from
  // `a.length`. When `m * n === 0` the input holds no elements, yet `perm`
  // (shape `[..., M]`) must still be filled with the identity permutation, and
  // stepping an offset by zero could never terminate on a non-empty buffer.
  let batchCount = 1;
  for (let d = 0; d < shape.length - 2; d++) batchCount *= shape[d];

  for (let batchIdx = 0; batchIdx < batchCount; batchIdx++) {
    const offset = batchIdx * matrixSize;
    const out = lu.subarray(offset, offset + matrixSize);
    const piv = pivots.subarray(batchIdx * r, (batchIdx + 1) * r);
    const p = perm.subarray(batchIdx * m, (batchIdx + 1) * m);

    // Factor in place inside the output buffer, starting from a copy of the input.
    out.set(a.subarray(offset, offset + matrixSize));
    for (let i = 0; i < m; i++) p[i] = i;

    for (let j = 0; j < r; j++) {
      // Partial pivoting: pick the row at or below `j` with the largest
      // magnitude in column `j`.
      let maxVal = Math.abs(out[j * n + j]);
      let maxRow = j;
      for (let i = j + 1; i < m; i++) {
        const val = Math.abs(out[i * n + j]);
        if (val > maxVal) {
          maxVal = val;
          maxRow = i;
        }
      }
      piv[j] = maxRow;

      if (maxRow !== j) {
        // Swap the full rows (including already-computed multipliers) and
        // mirror the swap in the permutation vector.
        for (let col = 0; col < n; col++) {
          const tmp = out[j * n + col];
          out[j * n + col] = out[maxRow * n + col];
          out[maxRow * n + col] = tmp;
        }
        const tmpP = p[j];
        p[j] = p[maxRow];
        p[maxRow] = tmpP;
      }

      // An exactly-zero pivot means there is nothing to eliminate in this
      // column; skipping also avoids dividing by zero.
      const diag = out[j * n + j];
      if (diag === 0) continue;
      for (let i = j + 1; i < m; i++) {
        const factor = out[i * n + j] / diag;
        out[i * n + j] = factor;
        for (let col = j + 1; col < n; col++) out[i * n + col] -= factor * out[j * n + col];
      }
    }
  }
}
1548
1620
 
1549
1621
  //#endregion
1550
1622
  //#region src/shape.ts
@@ -3512,7 +3584,7 @@ var I32 = class {
3512
3584
  shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
3513
3585
  rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
3514
3586
  rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
3515
- eqz = BINARY_OP("eqz", 69, "i32", "i32", "i32");
3587
+ eqz = UNARY_OP("eqz", 69, "i32", "i32");
3516
3588
  eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
3517
3589
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
3518
3590
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
@@ -3977,7 +4049,7 @@ function translateExp(cg, funcs, exp, ctx) {
3977
4049
  else throw new UnsupportedOpError(op, dtype, "wasm");
3978
4050
  else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
3979
4051
  else dtyF(cg, op, dtype).max();
3980
- else if (dtype === DType.Int32 || dtype === DType.Uint32) {
4052
+ else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
3981
4053
  const a = cg.local.declare(cg.i32);
3982
4054
  const b = cg.local.declare(cg.i32);
3983
4055
  cg.local.set(b);
@@ -4127,7 +4199,8 @@ function dtyF(cg, op, dtype) {
4127
4199
  const devices = [
4128
4200
  "cpu",
4129
4201
  "wasm",
4130
- "webgpu"
4202
+ "webgpu",
4203
+ "webgl"
4131
4204
  ];
4132
4205
  const initializedBackends = /* @__PURE__ */ new Map();
4133
4206
  initializedBackends.set("cpu", new CpuBackend());
@@ -4166,7 +4239,7 @@ async function createBackend(device) {
4166
4239
  if (!navigator.gpu) return null;
4167
4240
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4168
4241
  if (!adapter) return null;
4169
- const { WebGPUBackend } = await import("./webgpu-ChVgx3b6.js");
4242
+ const { WebGPUBackend } = await import("./webgpu-Dh7k9io0.js");
4170
4243
  const importantLimits = [
4171
4244
  "maxBufferSize",
4172
4245
  "maxComputeInvocationsPerWorkgroup",
@@ -4190,6 +4263,22 @@ async function createBackend(device) {
4190
4263
  console.error("Unexpected error requesting WebGPU device:", error);
4191
4264
  return null;
4192
4265
  }
4266
+ } else if (device === "webgl") {
4267
+ if (typeof WebGL2RenderingContext === "undefined") return null;
4268
+ const canvas = new OffscreenCanvas(0, 0);
4269
+ const gl = canvas.getContext("webgl2", {
4270
+ alpha: false,
4271
+ antialias: false,
4272
+ premultipliedAlpha: false,
4273
+ preserveDrawingBuffer: false,
4274
+ depth: false,
4275
+ stencil: false,
4276
+ failIfMajorPerformanceCaveat: true
4277
+ });
4278
+ if (!gl) return null;
4279
+ if (!gl.getExtension("EXT_color_buffer_float")) return null;
4280
+ const { WebGLBackend } = await import("./webgl-RSuZKvgc.js");
4281
+ return new WebGLBackend(gl);
4193
4282
  } else throw new Error(`Backend not found: ${device}`);
4194
4283
  }
4195
4284
  /** Retrieve a backend that has been initialized. */
@@ -4224,4 +4313,4 @@ var UnsupportedRoutineError = class extends Error {
4224
4313
  };
4225
4314
 
4226
4315
  //#endregion
4227
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
4316
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
@@ -1471,10 +1471,37 @@ let Routines = /* @__PURE__ */ function(Routines$1) {
1471
1471
  Routines$1["Sort"] = "Sort";
1472
1472
  /** Returns `int32` indices of the stably sorted array. */
1473
1473
  Routines$1["Argsort"] = "Argsort";
1474
- /** Solve a triangular system of questions. */
1474
+ /**
1475
+ * Solve a triangular system of equations.
1476
+ *
1477
+ * The first batch of inputs `A` should be of shape `[..., N, N]` and upper
1478
+ * triangular, while the second batch `B` should be of shape `[..., M, N]`.
1479
+ *
1480
+ * Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
1481
+ * triangular matrix. This is equivalent to `X = B @ A^-T`.
1482
+ */
1475
1483
  Routines$1["TriangularSolve"] = "TriangularSolve";
1476
- /** Cholesky decomposition of 2D positive semi-definite matrices. */
1484
+ /**
1485
+ * Cholesky decomposition of 2D positive semi-definite matrices.
1486
+ *
1487
+ * The input batch should be of shape `[..., N, N]`, and the output batch is
1488
+ * of the same shape, containing the lower-triangular matrix `L` such that
1489
+ * `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
1490
+ */
1477
1491
  Routines$1["Cholesky"] = "Cholesky";
1492
+ /**
1493
+ * LU decomposition of 2D rectangular matrices.
1494
+ *
1495
+ * The input is a batch of shape `[..., M, N]`, and the output is a tuple of
1496
+ * three arrays: `LU, Pivots, Permutation`.
1497
+ *
1498
+ * - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
1499
+ * triangular matrices. (lower triangular = implicit unit diagonal)
1500
+ * - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
1501
+ * - `Permutation` is of shape `[..., M]`, containing the permutation vector
1502
+ * such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
1503
+ */
1504
+ Routines$1["LU"] = "LU";
1478
1505
  return Routines$1;
1479
1506
  }({});
1480
1507
  function runCpuRoutine(routine, inputs, outputs) {
@@ -1486,6 +1513,7 @@ function runCpuRoutine(routine, inputs, outputs) {
1486
1513
  case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
1487
1514
  case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
1488
1515
  case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
1516
+ case Routines.LU: return runLU(type, inputAr, outputAr);
1489
1517
  default:
1490
1518
  }
1491
1519
  }
@@ -1546,6 +1574,50 @@ function runCholesky(type, [x], [y]) {
1546
1574
  }
1547
1575
  }
1548
1576
  }
1577
/**
 * CPU implementation of the `LU` routine: LU decomposition with partial (row)
 * pivoting, applied independently to every matrix in a batch.
 *
 * The input shape is read from `type.inputShapes[0]` and is interpreted as
 * `[..., M, N]`; all leading dimensions are treated as batch dimensions.
 *
 * @param type Routine type descriptor; only `type.inputShapes[0]` is read.
 * @param inputs One-element list holding the flat, row-major input buffer `a`.
 * @param outputs Three flat output buffers `[lu, pivots, perm]`:
 *   - `lu` receives the combined factors per matrix (`M * N` entries each):
 *     the upper triangle holds `U`, the strict lower triangle holds the
 *     multipliers of `L` (its unit diagonal is implicit).
 *   - `pivots` receives, per matrix, `min(M, N)` entries: the row that was
 *     swapped into position `j` at elimination step `j`.
 *   - `perm` receives, per matrix, `M` entries: the identity permutation with
 *     the same row swaps applied.
 * @throws {Error} If the input shape has fewer than two dimensions.
 */
function runLU(type, [a], [lu, pivots, perm]) {
  const shape = type.inputShapes[0];
  if (shape.length < 2) throw new Error("lu: input must be at least 2D");
  const m = shape[shape.length - 2];
  const n = shape[shape.length - 1];
  const r = Math.min(m, n);
  const matrixSize = m * n;

  // Count matrices from the leading dimensions rather than from `a.length`:
  // with `m * n === 0` there are no input elements to iterate over, but the
  // `[..., M]` permutation output still has to be set to the identity, and an
  // offset advanced by zero would never reach the end of a non-empty buffer.
  const batchCount = shape.slice(0, -2).reduce((acc, dim) => acc * dim, 1);

  for (let batchIdx = 0; batchIdx < batchCount; batchIdx++) {
    const offset = batchIdx * matrixSize;
    const out = lu.subarray(offset, offset + matrixSize);
    const piv = pivots.subarray(batchIdx * r, (batchIdx + 1) * r);
    const p = perm.subarray(batchIdx * m, (batchIdx + 1) * m);

    // Work in place on the output buffer, seeded with a copy of the input.
    out.set(a.subarray(offset, offset + matrixSize));
    for (let i = 0; i < m; i++) p[i] = i;

    for (let j = 0; j < r; j++) {
      // Partial pivoting: choose the row at or below `j` whose entry in
      // column `j` has the largest magnitude.
      let maxVal = Math.abs(out[j * n + j]);
      let maxRow = j;
      for (let i = j + 1; i < m; i++) {
        const val = Math.abs(out[i * n + j]);
        if (val > maxVal) {
          maxVal = val;
          maxRow = i;
        }
      }
      piv[j] = maxRow;

      if (maxRow !== j) {
        // Exchange entire rows (stored multipliers included) and record the
        // same exchange in the permutation vector.
        for (let col = 0; col < n; col++) {
          const tmp = out[j * n + col];
          out[j * n + col] = out[maxRow * n + col];
          out[maxRow * n + col] = tmp;
        }
        const tmpP = p[j];
        p[j] = p[maxRow];
        p[maxRow] = tmpP;
      }

      // With an exactly-zero pivot there is nothing to eliminate below it;
      // skipping the step also avoids a division by zero.
      const diag = out[j * n + j];
      if (diag === 0) continue;
      for (let i = j + 1; i < m; i++) {
        const factor = out[i * n + j] / diag;
        out[i * n + j] = factor;
        for (let col = j + 1; col < n; col++) out[i * n + col] -= factor * out[j * n + col];
      }
    }
  }
}
1549
1621
 
1550
1622
  //#endregion
1551
1623
  //#region src/shape.ts
@@ -3513,7 +3585,7 @@ var I32 = class {
3513
3585
  shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
3514
3586
  rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
3515
3587
  rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
3516
- eqz = BINARY_OP("eqz", 69, "i32", "i32", "i32");
3588
+ eqz = UNARY_OP("eqz", 69, "i32", "i32");
3517
3589
  eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
3518
3590
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
3519
3591
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
@@ -3978,7 +4050,7 @@ function translateExp(cg, funcs, exp, ctx) {
3978
4050
  else throw new UnsupportedOpError(op, dtype, "wasm");
3979
4051
  else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
3980
4052
  else dtyF(cg, op, dtype).max();
3981
- else if (dtype === DType.Int32 || dtype === DType.Uint32) {
4053
+ else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
3982
4054
  const a = cg.local.declare(cg.i32);
3983
4055
  const b = cg.local.declare(cg.i32);
3984
4056
  cg.local.set(b);
@@ -4128,7 +4200,8 @@ function dtyF(cg, op, dtype) {
4128
4200
  const devices = [
4129
4201
  "cpu",
4130
4202
  "wasm",
4131
- "webgpu"
4203
+ "webgpu",
4204
+ "webgl"
4132
4205
  ];
4133
4206
  const initializedBackends = /* @__PURE__ */ new Map();
4134
4207
  initializedBackends.set("cpu", new CpuBackend());
@@ -4167,7 +4240,7 @@ async function createBackend(device) {
4167
4240
  if (!navigator.gpu) return null;
4168
4241
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4169
4242
  if (!adapter) return null;
4170
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-Oj3Kd-kd.cjs"));
4243
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-Db2JrNBr.cjs"));
4171
4244
  const importantLimits = [
4172
4245
  "maxBufferSize",
4173
4246
  "maxComputeInvocationsPerWorkgroup",
@@ -4191,6 +4264,22 @@ async function createBackend(device) {
4191
4264
  console.error("Unexpected error requesting WebGPU device:", error);
4192
4265
  return null;
4193
4266
  }
4267
+ } else if (device === "webgl") {
4268
+ if (typeof WebGL2RenderingContext === "undefined") return null;
4269
+ const canvas = new OffscreenCanvas(0, 0);
4270
+ const gl = canvas.getContext("webgl2", {
4271
+ alpha: false,
4272
+ antialias: false,
4273
+ premultipliedAlpha: false,
4274
+ preserveDrawingBuffer: false,
4275
+ depth: false,
4276
+ stencil: false,
4277
+ failIfMajorPerformanceCaveat: true
4278
+ });
4279
+ if (!gl) return null;
4280
+ if (!gl.getExtension("EXT_color_buffer_float")) return null;
4281
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-ClIYb8jP.cjs"));
4282
+ return new WebGLBackend(gl);
4194
4283
  } else throw new Error(`Backend not found: ${device}`);
4195
4284
  }
4196
4285
  /** Retrieve a backend that has been initialized. */
@@ -4507,6 +4596,12 @@ Object.defineProperty(exports, 'toposort', {
4507
4596
  return toposort;
4508
4597
  }
4509
4598
  });
4599
+ Object.defineProperty(exports, 'tuneNullopt', {
4600
+ enumerable: true,
4601
+ get: function () {
4602
+ return tuneNullopt;
4603
+ }
4604
+ });
4510
4605
  Object.defineProperty(exports, 'tuneWebgpu', {
4511
4606
  enumerable: true,
4512
4607
  get: function () {