@jax-js/jax 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -69,6 +69,9 @@ function zipn(...arrays) {
69
69
  const minLength = Math.min(...arrays.map((x) => x.length));
70
70
  return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
71
71
  }
/**
 * Return a new array holding the elements of `arr` in ascending numeric order.
 *
 * The argument is copied with spread before sorting, so any iterable (for
 * example a `Set`) is accepted and the argument itself is never mutated.
 *
 * @param arr - Iterable of numbers.
 * @returns A freshly allocated, numerically sorted array.
 */
function sorted(arr) {
  // An explicit comparator is required: the default sort compares elements as
  // strings, which would order 10 before 2.
  return [...arr].sort((a, b) => a - b);
}
72
75
  function rep(length, value) {
73
76
  if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
74
77
  return new Array(length).fill(value);
@@ -145,7 +148,7 @@ function normalizeAxis(axis, ndim) {
145
148
  if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
146
149
  seen.add(ca);
147
150
  }
148
- return [...seen].sort();
151
+ return sorted(seen);
149
152
  }
150
153
  }
151
154
  function range(start, stop, step = 1) {
@@ -1328,7 +1331,7 @@ var Reduction = class {
1328
1331
  /** Evaluate this operation on CPU. */
1329
1332
  evaluate(...values) {
1330
1333
  if (this.dtype === DType.Bool) {
1331
- if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, true);
1334
+ if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, false);
1332
1335
  else if (this.op === AluOp.Mul || this.op === AluOp.Min) return values.reduce((a, b) => a && b, true);
1333
1336
  } else if (this.dtype === DType.Int32) {
1334
1337
  if (this.op === AluOp.Add) return values.reduce((a, b) => a + b | 0, 0);
@@ -1438,6 +1441,184 @@ function erfc(x) {
1438
1441
  else return 2 - _erfapprox$1(-x);
1439
1442
  }
1440
1443
 
1444
+ //#endregion
1445
+ //#region src/routine.ts
/**
 * Advanced operations that don't fit into the `AluExp` compiler representation.
 *
 * Some routines like iterative matrix algorithms, FFTs, or sorting may not be
 * easy to express efficiently as a `Kernel` object. These also tend to be
 * somewhat expensive, so the benefit of kernel fusion and inlining is less
 * relevant.
 *
 * For these operations, we dispatch them as a custom operation on the backend,
 * which each backend implements in a specific way. These are listed in the
 * `Routines` enum below.
 *
 * Routines cannot be fused into other kernels and always operate on contiguous
 * arrays (default `ShapeTracker`).
 */
var Routine = class {
  /**
   * @param name - Which routine to run; `runCpuRoutine` compares it against
   *   the members of `Routines`.
   * @param type - Description of the operands. The CPU implementations read
   *   `type.inputShapes`, `type.inputDtypes` and `type.outputDtypes` from it.
   * @param params - Routine-specific options; `runCpuRoutine` forwards them
   *   only to the triangular solve, which reads `unitDiagonal`.
   */
  constructor(name, type, params) {
    this.name = name;
    this.type = type;
    this.params = params;
  }
};
/**
 * One of the valid `Routine` that can be dispatched to backend.
 *
 * Every member maps its own name to an identical string value.
 */
let Routines = {
  /** Stable sorting algorithm along the last axis. */
  Sort: "Sort",

  /** Returns `int32` indices of the stably sorted array. */
  Argsort: "Argsort",

  /**
   * Solve a triangular system of equations.
   *
   * The first batch of inputs `A` should be of shape `[..., N, N]` and upper
   * triangular, while the second batch `B` should be of shape `[..., M, N]`.
   *
   * Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
   * triangular matrix. This is equivalent to `X = B @ A^-T`.
   */
  TriangularSolve: "TriangularSolve",

  /**
   * Cholesky decomposition of 2D positive semi-definite matrices.
   *
   * The input batch should be of shape `[..., N, N]`, and the output batch is
   * of the same shape, containing the lower-triangular matrix `L` such that
   * `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
   */
  Cholesky: "Cholesky",

  /**
   * LU decomposition of 2D rectangular matrices.
   *
   * The input is a batch of shape `[..., M, N]`, and the output is a tuple of
   * three arrays: `LU, Pivots, Permutation`.
   *
   * - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
   *   triangular matrices. (lower triangular = implicit unit diagonal)
   * - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
   * - `Permutation` is of shape `[..., M]`, containing the permutation vector
   *   such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
   */
  LU: "LU",
};
/**
 * Execute a `Routine` on the CPU by dispatching on its name.
 *
 * Every raw buffer is first wrapped with `dtypedArray`, using the dtype found
 * at the same position in `type.inputDtypes` / `type.outputDtypes`
 * (presumably yielding a typed-array view; confirm against `dtypedArray`).
 * The wrapped arrays are then handed to the matching `run*` implementation,
 * which writes its results into the output arrays.
 *
 * @param routine - Object with `name`, `type` and (for the triangular solve)
 *   `params`.
 * @param inputs - Raw input buffers, one per entry of `type.inputDtypes`.
 * @param outputs - Raw output buffers, one per entry of `type.outputDtypes`.
 * @returns Whatever the selected implementation returns.
 * @throws {Error} If `routine.name` is not a member of `Routines`.
 */
function runCpuRoutine(routine, inputs, outputs) {
  const { name, type } = routine;
  const inputAr = inputs.map((buf, i) => dtypedArray(type.inputDtypes[i], buf));
  const outputAr = outputs.map((buf, i) => dtypedArray(type.outputDtypes[i], buf));
  switch (name) {
    case Routines.Sort:
      return runSort(type, inputAr, outputAr);
    case Routines.Argsort:
      return runArgsort(type, inputAr, outputAr);
    case Routines.TriangularSolve:
      return runTriangularSolve(type, inputAr, outputAr, routine.params);
    case Routines.Cholesky:
      return runCholesky(type, inputAr, outputAr);
    case Routines.LU:
      return runLU(type, inputAr, outputAr);
    default:
      // Fail loudly: silently returning here would leave the output buffers
      // unwritten and hand the caller stale data.
      throw new Error(`Unknown routine: ${name}`);
  }
}
/**
 * Sort every run of `n` consecutive elements, where `n` is the size of the
 * last axis of the input shape.
 *
 * The input is first copied into `y`, then each run is sorted in place inside
 * `y`; `x` is left untouched. `x` and `y` must provide `set`, `subarray` and a
 * numeric `sort`, as typed arrays do.
 *
 * @param type - Reads `type.inputShapes[0]`, the shape of the input.
 * @param x - Single-element input list: the values to sort.
 * @param y - Single-element output list: receives the sorted values.
 * @throws {Error} If the input shape is empty (a scalar).
 */
function runSort(type, [x], [y]) {
  const xs = type.inputShapes[0];
  if (xs.length === 0) throw new Error("sort: cannot sort a scalar");
  const n = xs[xs.length - 1];
  y.set(x);
  // `subarray` returns a view, so sorting it reorders `y` itself. The sort is
  // called without a comparator, which on typed arrays compares numerically.
  for (let start = 0; start < y.length; start += n) {
    y.subarray(start, start + n).sort();
  }
}
/**
 * Argsort every run of `n` consecutive elements, where `n` is the size of the
 * last axis of the input shape.
 *
 * For each run, `yi` receives the positions (relative to the start of that
 * run) of the elements in ascending order, and `y` receives the values in that
 * order. Equal values keep their original relative order, and NaN values are
 * placed after every other value.
 *
 * @param type - Reads `type.inputShapes[0]`, the shape of the input.
 * @param x - Single-element input list: the values to rank.
 * @param y - First output: the sorted values.
 * @param yi - Second output: the indices that sort each run.
 * @throws {Error} If the input shape is empty (a scalar).
 */
function runArgsort(type, [x], [y, yi]) {
  const xs = type.inputShapes[0];
  if (xs.length === 0) throw new Error("argsort: cannot sort a scalar");
  const n = xs[xs.length - 1];
  for (let offset = 0; offset < y.length; offset += n) {
    const ar = x.subarray(offset, offset + n);
    const out = y.subarray(offset, offset + n);
    const outi = yi.subarray(offset, offset + n);
    for (let i = 0; i < n; i++) outi[i] = i;
    outi.sort((a, b) => {
      const u = ar[a];
      const v = ar[b];
      if (u < v) return -1;
      if (u > v) return 1;
      // Neither ordered comparison held: the values are equal or at least one
      // is NaN. A plain subtraction would yield NaN here and make the
      // comparator inconsistent, so NaN is ranked last explicitly.
      const uNaN = Number.isNaN(u);
      const vNaN = Number.isNaN(v);
      if (uNaN !== vNaN) return uNaN ? 1 : -1;
      // Tie-break on the original position so the result is stable.
      return a - b;
    });
    for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
  }
}
/**
 * Solve upper-triangular systems by back substitution.
 *
 * `a` holds `a.length / (n * n)` row-major `n x n` matrices, and `b` holds
 * `batch` right-hand-side rows of length `n` per matrix, where `batch` is the
 * second-to-last dimension of `b`'s shape. For each row `b1`, the solution
 * `x1` of `A @ x1 = b1` is written to the same position in `x`. Only entries
 * on or above the diagonal of each matrix are read.
 *
 * @param type - Reads `type.inputShapes[0]` (shape of `a`) and
 *   `type.inputShapes[1]` (shape of `b`).
 * @param a - First input: the triangular matrices.
 * @param b - Second input: the right-hand sides.
 * @param x - Single output: receives the solutions, laid out like `b`.
 * @param params - Options. When `unitDiagonal` is truthy the diagonal is
 *   treated as all ones and never read. Defaults to `false` when omitted.
 * @throws {Error} If either shape has fewer than two dimensions, if `a` is not
 *   square, or if the last dimension of `b` differs from `n`.
 */
function runTriangularSolve(type, [a, b], [x], { unitDiagonal = false } = {}) {
  const as = type.inputShapes[0];
  const bs = type.inputShapes[1];
  if (as.length < 2) throw new Error(`triangular_solve: a must be at least 2D, got ${as}`);
  if (bs.length < 2) throw new Error(`triangular_solve: b must be at least 2D, got ${bs}`);
  const n = as[as.length - 2];
  if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) {
    throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
  }
  const batch = bs[bs.length - 2];
  const matrixSize = n * n;
  for (let counter = 0; counter < a.length / matrixSize; counter++) {
    const a1 = a.subarray(counter * matrixSize, (counter + 1) * matrixSize);
    for (let t = 0; t < batch; t++) {
      const rowStart = (counter * batch + t) * n;
      const b1 = b.subarray(rowStart, rowStart + n);
      const x1 = x.subarray(rowStart, rowStart + n);
      // Walk from the last unknown upwards: x1[j] for j > i is already final
      // by the time row i is processed.
      for (let i = n - 1; i >= 0; i--) {
        let sum = b1[i];
        for (let j = i + 1; j < n; j++) sum -= a1[i * n + j] * x1[j];
        x1[i] = unitDiagonal ? sum : sum / a1[i * n + i];
      }
    }
  }
}
/**
 * Compute the lower-triangular Cholesky factor of each `n x n` matrix.
 *
 * `x` holds row-major `n x n` matrices back to back. For every matrix the
 * factor `L` (with `A = L @ L.T`) is written to the same position in `y`:
 * entries on and below the diagonal hold `L`, and entries above the diagonal
 * are set to zero. Only the lower triangle of each input matrix is read.
 *
 * @param type - Reads `type.inputShapes[0]`, the shape of the input.
 * @param x - Single-element input list: the matrices to factor.
 * @param y - Single-element output list: receives the factors.
 * @throws {Error} If the shape has fewer than two dimensions or the last two
 *   dimensions differ.
 */
function runCholesky(type, [x], [y]) {
  const xs = type.inputShapes[0];
  if (xs.length < 2) throw new Error("cholesky: input must be at least 2D");
  const n = xs[xs.length - 2];
  const m = xs[xs.length - 1];
  if (n !== m) throw new Error(`cholesky: input must be square, got [${n}, ${m}]`);
  const matrixSize = n * n;
  for (let offset = 0; offset < y.length; offset += matrixSize) {
    const ar = x.subarray(offset, offset + matrixSize);
    const out = y.subarray(offset, offset + matrixSize);
    for (let i = 0; i < n; i++) {
      for (let j = 0; j <= i; j++) {
        let sum = ar[i * n + j];
        for (let k = 0; k < j; k++) sum -= out[i * n + k] * out[j * n + k];
        out[i * n + j] = i === j ? Math.sqrt(sum) : sum / out[j * n + j];
      }
      // Clear the strict upper triangle explicitly so the result does not
      // depend on the output buffer having been zero-initialised.
      for (let j = i + 1; j < n; j++) out[i * n + j] = 0;
    }
  }
}
/**
 * LU-decompose each `m x n` matrix using partial (row) pivoting.
 *
 * `a` holds row-major `m x n` matrices back to back. For every matrix:
 *
 * - `lu` receives the combined factors: the multipliers of `L` below the
 *   diagonal (its unit diagonal is implicit) and `U` on and above it.
 * - `pivots` receives `min(m, n)` entries; entry `j` is the row that was
 *   swapped with row `j` at elimination step `j`.
 * - `perm` receives `m` entries: the row order of the input that results from
 *   applying all of those swaps to `0..m-1`.
 *
 * @param type - Reads `type.inputShapes[0]`, the shape of the input.
 * @param a - Single-element input list: the matrices to decompose.
 * @param lu - First output: combined `L`/`U` factors.
 * @param pivots - Second output: row swaps per step.
 * @param perm - Third output: resulting row permutation.
 * @throws {Error} If the shape has fewer than two dimensions.
 */
function runLU(type, [a], [lu, pivots, perm]) {
  const shape = type.inputShapes[0];
  if (shape.length < 2) throw new Error("lu: input must be at least 2D");
  const m = shape[shape.length - 2];
  const n = shape[shape.length - 1];
  const r = Math.min(m, n);
  const matrixSize = m * n;
  for (let offset = 0; offset < a.length; offset += matrixSize) {
    const ar = a.subarray(offset, offset + matrixSize);
    const out = lu.subarray(offset, offset + matrixSize);
    const batchIdx = offset / matrixSize;
    const piv = pivots.subarray(batchIdx * r, (batchIdx + 1) * r);
    const p = perm.subarray(batchIdx * m, (batchIdx + 1) * m);

    // Work on a copy of the input and start from the identity permutation.
    out.set(ar);
    for (let i = 0; i < m; i++) p[i] = i;

    for (let j = 0; j < r; j++) {
      // Pick the row at or below j with the largest magnitude in column j.
      // The strict comparison keeps the earliest row on ties.
      let maxVal = Math.abs(out[j * n + j]);
      let maxRow = j;
      for (let i = j + 1; i < m; i++) {
        const val = Math.abs(out[i * n + j]);
        if (val > maxVal) {
          maxVal = val;
          maxRow = i;
        }
      }
      piv[j] = maxRow;

      if (maxRow !== j) {
        // Swap the full rows (including multipliers already stored) and record
        // the swap in the permutation.
        for (let col = 0; col < n; col++) {
          const tmp = out[j * n + col];
          out[j * n + col] = out[maxRow * n + col];
          out[maxRow * n + col] = tmp;
        }
        const tmpP = p[j];
        p[j] = p[maxRow];
        p[maxRow] = tmpP;
      }

      // A zero pivot means the whole remaining column is zero; elimination is
      // skipped for this step to avoid dividing by zero.
      const diag = out[j * n + j];
      if (diag !== 0) {
        for (let i = j + 1; i < m; i++) {
          const factor = out[i * n + j] / diag;
          out[i * n + j] = factor;
          for (let col = j + 1; col < n; col++) out[i * n + col] -= factor * out[j * n + col];
        }
      }
    }
  }
}
1621
+
1441
1622
  //#endregion
1442
1623
  //#region src/shape.ts
1443
1624
  const jstr = JSON.stringify;
@@ -1908,7 +2089,7 @@ var ShapeTracker = class ShapeTracker {
1908
2089
  let st = this;
1909
2090
  if (axis.length > 0) {
1910
2091
  const unsqueezed = [...st.shape];
1911
- for (const i of axis.toSorted()) unsqueezed.splice(i, 0, 1);
2092
+ for (const i of sorted(axis)) unsqueezed.splice(i, 0, 1);
1912
2093
  st = st.reshape(unsqueezed);
1913
2094
  }
1914
2095
  return st.expand(newShape);
@@ -2133,7 +2314,7 @@ function tuneWebgpu(kernel) {
2133
2314
  break;
2134
2315
  }
2135
2316
  }
2136
- for (const ax of Array.from(upcastedAxis).sort()) {
2317
+ for (const ax of sorted(upcastedAxis)) {
2137
2318
  const s = dim.st.shape[ax];
2138
2319
  for (const amount of [8, 4]) if (s % amount === 0) {
2139
2320
  dim.applyLocal(ax, amount);
@@ -2251,13 +2432,21 @@ var CpuBackend = class {
2251
2432
  if (count === void 0) count = buffer.byteLength - start;
2252
2433
  return buffer.slice(start, start + count);
2253
2434
  }
2254
- async prepare(kernel) {
2255
- return this.prepareSync(kernel);
2435
+ async prepareKernel(kernel) {
2436
+ return this.prepareKernelSync(kernel);
2256
2437
  }
2257
- prepareSync(kernel) {
2438
+ prepareKernelSync(kernel) {
2258
2439
  return new Executable(kernel, void 0);
2259
2440
  }
2260
- dispatch({ kernel }, inputs, outputs) {
2441
+ async prepareRoutine(routine) {
2442
+ return this.prepareRoutineSync(routine);
2443
+ }
2444
+ prepareRoutineSync(routine) {
2445
+ return new Executable(routine, void 0);
2446
+ }
2447
+ dispatch(exe, inputs, outputs) {
2448
+ if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
2449
+ const kernel = exe.source;
2261
2450
  const { exp, epilogue } = tuneNullopt(kernel);
2262
2451
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
2263
2452
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
@@ -2315,8 +2504,10 @@ var WasmAllocator = class {
2315
2504
  const sizeClass = this.#findSizeClass(size);
2316
2505
  const freeList = this.#freeLists.get(sizeClass);
2317
2506
  let ptr;
2318
- if (freeList && freeList.length > 0) ptr = freeList.pop();
2319
- else ptr = this.#bumpAlloc(sizeClass);
2507
+ if (freeList && freeList.length > 0) {
2508
+ ptr = freeList.pop();
2509
+ new Uint8Array(this.#memory.buffer, ptr, sizeClass).fill(0);
2510
+ } else ptr = this.#bumpAlloc(sizeClass);
2320
2511
  this.#allocatedBuffers.set(ptr, sizeClass);
2321
2512
  return ptr;
2322
2513
  }
@@ -3394,7 +3585,7 @@ var I32 = class {
3394
3585
  shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
3395
3586
  rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
3396
3587
  rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
3397
- eqz = BINARY_OP("eqz", 69, "i32", "i32", "i32");
3588
+ eqz = UNARY_OP("eqz", 69, "i32", "i32");
3398
3589
  eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
3399
3590
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
3400
3591
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
@@ -3682,10 +3873,10 @@ var WasmBackend = class {
3682
3873
  if (count === void 0) count = buffer.byteLength - start;
3683
3874
  return buffer.slice(start, start + count);
3684
3875
  }
3685
- async prepare(kernel) {
3686
- return this.prepareSync(kernel);
3876
+ async prepareKernel(kernel) {
3877
+ return this.prepareKernelSync(kernel);
3687
3878
  }
3688
- prepareSync(kernel) {
3879
+ prepareKernelSync(kernel) {
3689
3880
  const kernelHash = FpHash.hash(kernel);
3690
3881
  const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => {
3691
3882
  const bytes = codegenWasm(kernel);
@@ -3693,7 +3884,14 @@ var WasmBackend = class {
3693
3884
  });
3694
3885
  return new Executable(kernel, { module: module$1 });
3695
3886
  }
3887
+ async prepareRoutine(routine) {
3888
+ return this.prepareRoutineSync(routine);
3889
+ }
3890
+ prepareRoutineSync(routine) {
3891
+ return new Executable(routine, void 0);
3892
+ }
3696
3893
  dispatch(exe, inputs, outputs) {
3894
+ if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
3697
3895
  const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
3698
3896
  const func = instance.exports.kernel;
3699
3897
  const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
@@ -3852,7 +4050,7 @@ function translateExp(cg, funcs, exp, ctx) {
3852
4050
  else throw new UnsupportedOpError(op, dtype, "wasm");
3853
4051
  else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
3854
4052
  else dtyF(cg, op, dtype).max();
3855
- else if (dtype === DType.Int32 || dtype === DType.Uint32) {
4053
+ else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
3856
4054
  const a = cg.local.declare(cg.i32);
3857
4055
  const b = cg.local.declare(cg.i32);
3858
4056
  cg.local.set(b);
@@ -4002,7 +4200,8 @@ function dtyF(cg, op, dtype) {
4002
4200
  const devices = [
4003
4201
  "cpu",
4004
4202
  "wasm",
4005
- "webgpu"
4203
+ "webgpu",
4204
+ "webgl"
4006
4205
  ];
4007
4206
  const initializedBackends = /* @__PURE__ */ new Map();
4008
4207
  initializedBackends.set("cpu", new CpuBackend());
@@ -4041,7 +4240,7 @@ async function createBackend(device) {
4041
4240
  if (!navigator.gpu) return null;
4042
4241
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4043
4242
  if (!adapter) return null;
4044
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVns4DbI.cjs"));
4243
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-Db2JrNBr.cjs"));
4045
4244
  const importantLimits = [
4046
4245
  "maxBufferSize",
4047
4246
  "maxComputeInvocationsPerWorkgroup",
@@ -4065,6 +4264,22 @@ async function createBackend(device) {
4065
4264
  console.error("Unexpected error requesting WebGPU device:", error);
4066
4265
  return null;
4067
4266
  }
4267
+ } else if (device === "webgl") {
4268
+ if (typeof WebGL2RenderingContext === "undefined") return null;
4269
+ const canvas = new OffscreenCanvas(0, 0);
4270
+ const gl = canvas.getContext("webgl2", {
4271
+ alpha: false,
4272
+ antialias: false,
4273
+ premultipliedAlpha: false,
4274
+ preserveDrawingBuffer: false,
4275
+ depth: false,
4276
+ stencil: false,
4277
+ failIfMajorPerformanceCaveat: true
4278
+ });
4279
+ if (!gl) return null;
4280
+ if (!gl.getExtension("EXT_color_buffer_float")) return null;
4281
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-ClIYb8jP.cjs"));
4282
+ return new WebGLBackend(gl);
4068
4283
  } else throw new Error(`Backend not found: ${device}`);
4069
4284
  }
4070
4285
  /** Retrieve a backend that has been initialized. */
@@ -4075,8 +4290,8 @@ function getBackend(device) {
4075
4290
  return backend;
4076
4291
  }
4077
4292
  var Executable = class {
4078
- constructor(kernel, data) {
4079
- this.kernel = kernel;
4293
+ constructor(source, data) {
4294
+ this.source = source;
4080
4295
  this.data = data;
4081
4296
  }
4082
4297
  };
@@ -4092,6 +4307,11 @@ var UnsupportedOpError = class extends Error {
4092
4307
  super(msg);
4093
4308
  }
4094
4309
  };
/**
 * Error raised when a backend has no implementation for a routine.
 *
 * The message has the form
 * `routine '<name>' is not supported in <device> backend`.
 */
var UnsupportedRoutineError = class extends Error {
  /**
   * @param name - Name of the routine that was requested.
   * @param device - Name of the backend that cannot run it.
   */
  constructor(name, device) {
    super(`routine '${name}' is not supported in ${device} backend`);
    // Without this, `err.name` and printed stack traces would read "Error",
    // hiding which kind of failure occurred.
    this.name = "UnsupportedRoutineError";
  }
};
4095
4315
 
4096
4316
  //#endregion
4097
4317
  Object.defineProperty(exports, 'AluExp', {
@@ -4160,6 +4380,18 @@ Object.defineProperty(exports, 'Reduction', {
4160
4380
  return Reduction;
4161
4381
  }
4162
4382
  });
4383
+ Object.defineProperty(exports, 'Routine', {
4384
+ enumerable: true,
4385
+ get: function () {
4386
+ return Routine;
4387
+ }
4388
+ });
4389
+ Object.defineProperty(exports, 'Routines', {
4390
+ enumerable: true,
4391
+ get: function () {
4392
+ return Routines;
4393
+ }
4394
+ });
4163
4395
  Object.defineProperty(exports, 'ShapeTracker', {
4164
4396
  enumerable: true,
4165
4397
  get: function () {
@@ -4178,6 +4410,12 @@ Object.defineProperty(exports, 'UnsupportedOpError', {
4178
4410
  return UnsupportedOpError;
4179
4411
  }
4180
4412
  });
4413
+ Object.defineProperty(exports, 'UnsupportedRoutineError', {
4414
+ enumerable: true,
4415
+ get: function () {
4416
+ return UnsupportedRoutineError;
4417
+ }
4418
+ });
4181
4419
  Object.defineProperty(exports, 'accessorAluExp', {
4182
4420
  enumerable: true,
4183
4421
  get: function () {
@@ -4358,6 +4596,12 @@ Object.defineProperty(exports, 'toposort', {
4358
4596
  return toposort;
4359
4597
  }
4360
4598
  });
4599
+ Object.defineProperty(exports, 'tuneNullopt', {
4600
+ enumerable: true,
4601
+ get: function () {
4602
+ return tuneNullopt;
4603
+ }
4604
+ });
4361
4605
  Object.defineProperty(exports, 'tuneWebgpu', {
4362
4606
  enumerable: true,
4363
4607
  get: function () {
@@ -4387,5 +4631,4 @@ Object.defineProperty(exports, 'zipn', {
4387
4631
  get: function () {
4388
4632
  return zipn;
4389
4633
  }
4390
- });
4391
- //# sourceMappingURL=backend-CmaidnkQ.cjs.map
4634
+ });