@jax-js/jax 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -3,7 +3,8 @@
3
3
  <p align="center"><strong>
4
4
  <a href="https://jax-js.com">Website</a> |
5
5
  <a href="https://jax-js.com/docs/">API Reference</a> |
6
- <a href="./FEATURES.md">Compatibility Table</a>
6
+ <a href="./FEATURES.md">Compatibility Table</a> |
7
+ <a href="https://discord.gg/BW6YsCd4Tf">Discord</a>
7
8
  </strong></p>
8
9
 
9
10
  **jax-js** is a machine learning framework for the browser. It aims to bring
@@ -323,7 +324,7 @@ pnpm -C website dev
323
324
 
324
325
  ## Future work / help wanted
325
326
 
326
- Contributions are welcomed! Especially in:
327
+ Contributions are welcomed! Some fruitful areas to look into:
327
328
 
328
329
  - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
329
330
  - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
@@ -334,3 +335,5 @@ Contributions are welcomed! Especially in:
334
335
  - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
335
336
  - Adding WebGL runtime for older browsers that don't support WebGPU.
336
337
  - Making a fast transformer inference engine, comparing against onnxruntime-web.
338
+
339
+ You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
@@ -69,6 +69,9 @@ function zipn(...arrays) {
69
69
  const minLength = Math.min(...arrays.map((x) => x.length));
70
70
  return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
71
71
  }
72
+ function sorted(arr) {
73
+ return [...arr].sort((a, b) => a - b);
74
+ }
72
75
  function rep(length, value) {
73
76
  if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
74
77
  return new Array(length).fill(value);
@@ -145,7 +148,7 @@ function normalizeAxis(axis, ndim) {
145
148
  if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
146
149
  seen.add(ca);
147
150
  }
148
- return [...seen].sort();
151
+ return sorted(seen);
149
152
  }
150
153
  }
151
154
  function range(start, stop, step = 1) {
@@ -1328,7 +1331,7 @@ var Reduction = class {
1328
1331
  /** Evaluate this operation on CPU. */
1329
1332
  evaluate(...values) {
1330
1333
  if (this.dtype === DType.Bool) {
1331
- if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, true);
1334
+ if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, false);
1332
1335
  else if (this.op === AluOp.Mul || this.op === AluOp.Min) return values.reduce((a, b) => a && b, true);
1333
1336
  } else if (this.dtype === DType.Int32) {
1334
1337
  if (this.op === AluOp.Add) return values.reduce((a, b) => a + b | 0, 0);
@@ -1438,6 +1441,112 @@ function erfc(x) {
1438
1441
  else return 2 - _erfapprox$1(-x);
1439
1442
  }
1440
1443
 
1444
+ //#endregion
1445
+ //#region src/routine.ts
1446
+ /**
1447
+ * Advanced operations that don't fit into the `AluExp` compiler representation.
1448
+ *
1449
+ * Some routines like iterative matrix algorithms, FFTs, or sorting may not be
1450
+ * easy to express efficiently as a `Kernel` object. These also tend to be
1451
+ * somewhat expensive, so the benefit of kernel fusion and inlining is less
1452
+ * relevant.
1453
+ *
1454
+ * For these operations, we dispatch them as a custom operation on the backend,
1455
+ * which each backend implements in a specific way. These are listed in the
1456
+ * `Routines` enum below.
1457
+ *
1458
+ * Routines cannot be fused into other kernels and always operate on contiguous
1459
+ * arrays (default `ShapeTracker`).
1460
+ */
1461
+ var Routine = class {
1462
+ constructor(name, type, params) {
1463
+ this.name = name;
1464
+ this.type = type;
1465
+ this.params = params;
1466
+ }
1467
+ };
1468
+ /** Names of the valid `Routine`s that can be dispatched to a backend. */
1469
+ let Routines = /* @__PURE__ */ function(Routines$1) {
1470
+ /** Stable sorting algorithm along the last axis. */
1471
+ Routines$1["Sort"] = "Sort";
1472
+ /** Returns `int32` indices of the stably sorted array. */
1473
+ Routines$1["Argsort"] = "Argsort";
1474
+ /** Solve a triangular system of equations. */
1475
+ Routines$1["TriangularSolve"] = "TriangularSolve";
1476
+ /** Cholesky decomposition of square positive-definite matrices (batched over leading dimensions). */
1477
+ Routines$1["Cholesky"] = "Cholesky";
1478
+ return Routines$1;
1479
+ }({});
1480
+ function runCpuRoutine(routine, inputs, outputs) {
1481
+ const { name, type } = routine;
1482
+ const inputAr = inputs.map((buf, i) => dtypedArray(type.inputDtypes[i], buf));
1483
+ const outputAr = outputs.map((buf, i) => dtypedArray(type.outputDtypes[i], buf));
1484
+ switch (name) {
1485
+ case Routines.Sort: return runSort(type, inputAr, outputAr);
1486
+ case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
1487
+ case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
1488
+ case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
1489
+ default:
1490
+ }
1491
+ }
1492
+ function runSort(type, [x], [y]) {
1493
+ const xs = type.inputShapes[0];
1494
+ if (xs.length === 0) throw new Error("sort: cannot sort a scalar");
1495
+ const n = xs[xs.length - 1];
1496
+ y.set(x);
1497
+ for (let i = 0; i < y.length; i += n) y.subarray(i, i + n).sort();
1498
+ }
1499
+ function runArgsort(type, [x], [y, yi]) {
1500
+ const xs = type.inputShapes[0];
1501
+ if (xs.length === 0) throw new Error("argsort: cannot sort a scalar");
1502
+ const n = xs[xs.length - 1];
1503
+ for (let offset = 0; offset < y.length; offset += n) {
1504
+ const ar = x.subarray(offset, offset + n);
1505
+ const out = y.subarray(offset, offset + n);
1506
+ const outi = yi.subarray(offset, offset + n);
1507
+ for (let i = 0; i < n; i++) outi[i] = i;
1508
+ outi.sort((a, b) => ar[a] - ar[b]);
1509
+ for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
1510
+ }
1511
+ }
1512
+ function runTriangularSolve(type, [a, b], [x], { unitDiagonal }) {
1513
+ const as = type.inputShapes[0];
1514
+ const bs = type.inputShapes[1];
1515
+ if (as.length < 2) throw new Error(`triangular_solve: a must be at least 2D, got ${as}`);
1516
+ if (bs.length < 2) throw new Error(`triangular_solve: b must be at least 2D, got ${bs}`);
1517
+ const n = as[as.length - 2];
1518
+ if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
1519
+ const batch = bs[bs.length - 2];
1520
+ for (let counter = 0; counter < a.length / (n * n); counter++) {
1521
+ const a1 = a.subarray(counter * n * n, (counter + 1) * n * n);
1522
+ for (let t = 0; t < batch; t++) {
1523
+ const b1 = b.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
1524
+ const x1 = x.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
1525
+ for (let i = n - 1; i >= 0; i--) {
1526
+ let sum = b1[i];
1527
+ for (let j = i + 1; j < n; j++) sum -= a1[i * n + j] * x1[j];
1528
+ x1[i] = unitDiagonal ? sum : sum / a1[i * n + i];
1529
+ }
1530
+ }
1531
+ }
1532
+ }
1533
+ function runCholesky(type, [x], [y]) {
1534
+ const xs = type.inputShapes[0];
1535
+ if (xs.length < 2) throw new Error("cholesky: input must be at least 2D");
1536
+ const n = xs[xs.length - 2];
1537
+ const m = xs[xs.length - 1];
1538
+ if (n !== m) throw new Error(`cholesky: input must be square, got [${n}, ${m}]`);
1539
+ for (let offset = 0; offset < y.length; offset += n * n) {
1540
+ const ar = x.subarray(offset, offset + n * n);
1541
+ const out = y.subarray(offset, offset + n * n);
1542
+ for (let i = 0; i < n; i++) for (let j = 0; j <= i; j++) {
1543
+ let sum = ar[i * n + j];
1544
+ for (let k = 0; k < j; k++) sum -= out[i * n + k] * out[j * n + k];
1545
+ out[i * n + j] = i === j ? Math.sqrt(sum) : sum / out[j * n + j];
1546
+ }
1547
+ }
1548
+ }
1549
+
1441
1550
  //#endregion
1442
1551
  //#region src/shape.ts
1443
1552
  const jstr = JSON.stringify;
@@ -1908,7 +2017,7 @@ var ShapeTracker = class ShapeTracker {
1908
2017
  let st = this;
1909
2018
  if (axis.length > 0) {
1910
2019
  const unsqueezed = [...st.shape];
1911
- for (const i of axis.toSorted()) unsqueezed.splice(i, 0, 1);
2020
+ for (const i of sorted(axis)) unsqueezed.splice(i, 0, 1);
1912
2021
  st = st.reshape(unsqueezed);
1913
2022
  }
1914
2023
  return st.expand(newShape);
@@ -2133,7 +2242,7 @@ function tuneWebgpu(kernel) {
2133
2242
  break;
2134
2243
  }
2135
2244
  }
2136
- for (const ax of Array.from(upcastedAxis).sort()) {
2245
+ for (const ax of sorted(upcastedAxis)) {
2137
2246
  const s = dim.st.shape[ax];
2138
2247
  for (const amount of [8, 4]) if (s % amount === 0) {
2139
2248
  dim.applyLocal(ax, amount);
@@ -2251,13 +2360,21 @@ var CpuBackend = class {
2251
2360
  if (count === void 0) count = buffer.byteLength - start;
2252
2361
  return buffer.slice(start, start + count);
2253
2362
  }
2254
- async prepare(kernel) {
2255
- return this.prepareSync(kernel);
2363
+ async prepareKernel(kernel) {
2364
+ return this.prepareKernelSync(kernel);
2256
2365
  }
2257
- prepareSync(kernel) {
2366
+ prepareKernelSync(kernel) {
2258
2367
  return new Executable(kernel, void 0);
2259
2368
  }
2260
- dispatch({ kernel }, inputs, outputs) {
2369
+ async prepareRoutine(routine) {
2370
+ return this.prepareRoutineSync(routine);
2371
+ }
2372
+ prepareRoutineSync(routine) {
2373
+ return new Executable(routine, void 0);
2374
+ }
2375
+ dispatch(exe, inputs, outputs) {
2376
+ if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
2377
+ const kernel = exe.source;
2261
2378
  const { exp, epilogue } = tuneNullopt(kernel);
2262
2379
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
2263
2380
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
@@ -2315,8 +2432,10 @@ var WasmAllocator = class {
2315
2432
  const sizeClass = this.#findSizeClass(size);
2316
2433
  const freeList = this.#freeLists.get(sizeClass);
2317
2434
  let ptr;
2318
- if (freeList && freeList.length > 0) ptr = freeList.pop();
2319
- else ptr = this.#bumpAlloc(sizeClass);
2435
+ if (freeList && freeList.length > 0) {
2436
+ ptr = freeList.pop();
2437
+ new Uint8Array(this.#memory.buffer, ptr, sizeClass).fill(0);
2438
+ } else ptr = this.#bumpAlloc(sizeClass);
2320
2439
  this.#allocatedBuffers.set(ptr, sizeClass);
2321
2440
  return ptr;
2322
2441
  }
@@ -3682,10 +3801,10 @@ var WasmBackend = class {
3682
3801
  if (count === void 0) count = buffer.byteLength - start;
3683
3802
  return buffer.slice(start, start + count);
3684
3803
  }
3685
- async prepare(kernel) {
3686
- return this.prepareSync(kernel);
3804
+ async prepareKernel(kernel) {
3805
+ return this.prepareKernelSync(kernel);
3687
3806
  }
3688
- prepareSync(kernel) {
3807
+ prepareKernelSync(kernel) {
3689
3808
  const kernelHash = FpHash.hash(kernel);
3690
3809
  const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => {
3691
3810
  const bytes = codegenWasm(kernel);
@@ -3693,7 +3812,14 @@ var WasmBackend = class {
3693
3812
  });
3694
3813
  return new Executable(kernel, { module: module$1 });
3695
3814
  }
3815
+ async prepareRoutine(routine) {
3816
+ return this.prepareRoutineSync(routine);
3817
+ }
3818
+ prepareRoutineSync(routine) {
3819
+ return new Executable(routine, void 0);
3820
+ }
3696
3821
  dispatch(exe, inputs, outputs) {
3822
+ if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
3697
3823
  const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
3698
3824
  const func = instance.exports.kernel;
3699
3825
  const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
@@ -4041,7 +4167,7 @@ async function createBackend(device) {
4041
4167
  if (!navigator.gpu) return null;
4042
4168
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4043
4169
  if (!adapter) return null;
4044
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVns4DbI.cjs"));
4170
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-Oj3Kd-kd.cjs"));
4045
4171
  const importantLimits = [
4046
4172
  "maxBufferSize",
4047
4173
  "maxComputeInvocationsPerWorkgroup",
@@ -4075,8 +4201,8 @@ function getBackend(device) {
4075
4201
  return backend;
4076
4202
  }
4077
4203
  var Executable = class {
4078
- constructor(kernel, data) {
4079
- this.kernel = kernel;
4204
+ constructor(source, data) {
4205
+ this.source = source;
4080
4206
  this.data = data;
4081
4207
  }
4082
4208
  };
@@ -4092,6 +4218,11 @@ var UnsupportedOpError = class extends Error {
4092
4218
  super(msg);
4093
4219
  }
4094
4220
  };
4221
+ var UnsupportedRoutineError = class extends Error {
4222
+ constructor(name, device) {
4223
+ super(`routine '${name}' is not supported in ${device} backend`);
4224
+ }
4225
+ };
4095
4226
 
4096
4227
  //#endregion
4097
4228
  Object.defineProperty(exports, 'AluExp', {
@@ -4160,6 +4291,18 @@ Object.defineProperty(exports, 'Reduction', {
4160
4291
  return Reduction;
4161
4292
  }
4162
4293
  });
4294
+ Object.defineProperty(exports, 'Routine', {
4295
+ enumerable: true,
4296
+ get: function () {
4297
+ return Routine;
4298
+ }
4299
+ });
4300
+ Object.defineProperty(exports, 'Routines', {
4301
+ enumerable: true,
4302
+ get: function () {
4303
+ return Routines;
4304
+ }
4305
+ });
4163
4306
  Object.defineProperty(exports, 'ShapeTracker', {
4164
4307
  enumerable: true,
4165
4308
  get: function () {
@@ -4178,6 +4321,12 @@ Object.defineProperty(exports, 'UnsupportedOpError', {
4178
4321
  return UnsupportedOpError;
4179
4322
  }
4180
4323
  });
4324
+ Object.defineProperty(exports, 'UnsupportedRoutineError', {
4325
+ enumerable: true,
4326
+ get: function () {
4327
+ return UnsupportedRoutineError;
4328
+ }
4329
+ });
4181
4330
  Object.defineProperty(exports, 'accessorAluExp', {
4182
4331
  enumerable: true,
4183
4332
  get: function () {
@@ -4387,5 +4536,4 @@ Object.defineProperty(exports, 'zipn', {
4387
4536
  get: function () {
4388
4537
  return zipn;
4389
4538
  }
4390
- });
4391
- //# sourceMappingURL=backend-CmaidnkQ.cjs.map
4539
+ });
@@ -68,6 +68,9 @@ function zipn(...arrays) {
68
68
  const minLength = Math.min(...arrays.map((x) => x.length));
69
69
  return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
70
70
  }
71
+ function sorted(arr) {
72
+ return [...arr].sort((a, b) => a - b);
73
+ }
71
74
  function rep(length, value) {
72
75
  if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
73
76
  return new Array(length).fill(value);
@@ -144,7 +147,7 @@ function normalizeAxis(axis, ndim) {
144
147
  if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
145
148
  seen.add(ca);
146
149
  }
147
- return [...seen].sort();
150
+ return sorted(seen);
148
151
  }
149
152
  }
150
153
  function range(start, stop, step = 1) {
@@ -1327,7 +1330,7 @@ var Reduction = class {
1327
1330
  /** Evaluate this operation on CPU. */
1328
1331
  evaluate(...values) {
1329
1332
  if (this.dtype === DType.Bool) {
1330
- if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, true);
1333
+ if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, false);
1331
1334
  else if (this.op === AluOp.Mul || this.op === AluOp.Min) return values.reduce((a, b) => a && b, true);
1332
1335
  } else if (this.dtype === DType.Int32) {
1333
1336
  if (this.op === AluOp.Add) return values.reduce((a, b) => a + b | 0, 0);
@@ -1437,6 +1440,112 @@ function erfc(x) {
1437
1440
  else return 2 - _erfapprox$1(-x);
1438
1441
  }
1439
1442
 
1443
+ //#endregion
1444
+ //#region src/routine.ts
1445
+ /**
1446
+ * Advanced operations that don't fit into the `AluExp` compiler representation.
1447
+ *
1448
+ * Some routines like iterative matrix algorithms, FFTs, or sorting may not be
1449
+ * easy to express efficiently as a `Kernel` object. These also tend to be
1450
+ * somewhat expensive, so the benefit of kernel fusion and inlining is less
1451
+ * relevant.
1452
+ *
1453
+ * For these operations, we dispatch them as a custom operation on the backend,
1454
+ * which each backend implements in a specific way. These are listed in the
1455
+ * `Routines` enum below.
1456
+ *
1457
+ * Routines cannot be fused into other kernels and always operate on contiguous
1458
+ * arrays (default `ShapeTracker`).
1459
+ */
1460
+ var Routine = class {
1461
+ constructor(name, type, params) {
1462
+ this.name = name;
1463
+ this.type = type;
1464
+ this.params = params;
1465
+ }
1466
+ };
1467
+ /** Names of the valid `Routine`s that can be dispatched to a backend. */
1468
+ let Routines = /* @__PURE__ */ function(Routines$1) {
1469
+ /** Stable sorting algorithm along the last axis. */
1470
+ Routines$1["Sort"] = "Sort";
1471
+ /** Returns `int32` indices of the stably sorted array. */
1472
+ Routines$1["Argsort"] = "Argsort";
1473
+ /** Solve a triangular system of equations. */
1474
+ Routines$1["TriangularSolve"] = "TriangularSolve";
1475
+ /** Cholesky decomposition of square positive-definite matrices (batched over leading dimensions). */
1476
+ Routines$1["Cholesky"] = "Cholesky";
1477
+ return Routines$1;
1478
+ }({});
1479
+ function runCpuRoutine(routine, inputs, outputs) {
1480
+ const { name, type } = routine;
1481
+ const inputAr = inputs.map((buf, i) => dtypedArray(type.inputDtypes[i], buf));
1482
+ const outputAr = outputs.map((buf, i) => dtypedArray(type.outputDtypes[i], buf));
1483
+ switch (name) {
1484
+ case Routines.Sort: return runSort(type, inputAr, outputAr);
1485
+ case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
1486
+ case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
1487
+ case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
1488
+ default:
1489
+ }
1490
+ }
1491
+ function runSort(type, [x], [y]) {
1492
+ const xs = type.inputShapes[0];
1493
+ if (xs.length === 0) throw new Error("sort: cannot sort a scalar");
1494
+ const n = xs[xs.length - 1];
1495
+ y.set(x);
1496
+ for (let i = 0; i < y.length; i += n) y.subarray(i, i + n).sort();
1497
+ }
1498
+ function runArgsort(type, [x], [y, yi]) {
1499
+ const xs = type.inputShapes[0];
1500
+ if (xs.length === 0) throw new Error("argsort: cannot sort a scalar");
1501
+ const n = xs[xs.length - 1];
1502
+ for (let offset = 0; offset < y.length; offset += n) {
1503
+ const ar = x.subarray(offset, offset + n);
1504
+ const out = y.subarray(offset, offset + n);
1505
+ const outi = yi.subarray(offset, offset + n);
1506
+ for (let i = 0; i < n; i++) outi[i] = i;
1507
+ outi.sort((a, b) => ar[a] - ar[b]);
1508
+ for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
1509
+ }
1510
+ }
1511
+ function runTriangularSolve(type, [a, b], [x], { unitDiagonal }) {
1512
+ const as = type.inputShapes[0];
1513
+ const bs = type.inputShapes[1];
1514
+ if (as.length < 2) throw new Error(`triangular_solve: a must be at least 2D, got ${as}`);
1515
+ if (bs.length < 2) throw new Error(`triangular_solve: b must be at least 2D, got ${bs}`);
1516
+ const n = as[as.length - 2];
1517
+ if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
1518
+ const batch = bs[bs.length - 2];
1519
+ for (let counter = 0; counter < a.length / (n * n); counter++) {
1520
+ const a1 = a.subarray(counter * n * n, (counter + 1) * n * n);
1521
+ for (let t = 0; t < batch; t++) {
1522
+ const b1 = b.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
1523
+ const x1 = x.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
1524
+ for (let i = n - 1; i >= 0; i--) {
1525
+ let sum = b1[i];
1526
+ for (let j = i + 1; j < n; j++) sum -= a1[i * n + j] * x1[j];
1527
+ x1[i] = unitDiagonal ? sum : sum / a1[i * n + i];
1528
+ }
1529
+ }
1530
+ }
1531
+ }
1532
+ function runCholesky(type, [x], [y]) {
1533
+ const xs = type.inputShapes[0];
1534
+ if (xs.length < 2) throw new Error("cholesky: input must be at least 2D");
1535
+ const n = xs[xs.length - 2];
1536
+ const m = xs[xs.length - 1];
1537
+ if (n !== m) throw new Error(`cholesky: input must be square, got [${n}, ${m}]`);
1538
+ for (let offset = 0; offset < y.length; offset += n * n) {
1539
+ const ar = x.subarray(offset, offset + n * n);
1540
+ const out = y.subarray(offset, offset + n * n);
1541
+ for (let i = 0; i < n; i++) for (let j = 0; j <= i; j++) {
1542
+ let sum = ar[i * n + j];
1543
+ for (let k = 0; k < j; k++) sum -= out[i * n + k] * out[j * n + k];
1544
+ out[i * n + j] = i === j ? Math.sqrt(sum) : sum / out[j * n + j];
1545
+ }
1546
+ }
1547
+ }
1548
+
1440
1549
  //#endregion
1441
1550
  //#region src/shape.ts
1442
1551
  const jstr = JSON.stringify;
@@ -1907,7 +2016,7 @@ var ShapeTracker = class ShapeTracker {
1907
2016
  let st = this;
1908
2017
  if (axis.length > 0) {
1909
2018
  const unsqueezed = [...st.shape];
1910
- for (const i of axis.toSorted()) unsqueezed.splice(i, 0, 1);
2019
+ for (const i of sorted(axis)) unsqueezed.splice(i, 0, 1);
1911
2020
  st = st.reshape(unsqueezed);
1912
2021
  }
1913
2022
  return st.expand(newShape);
@@ -2132,7 +2241,7 @@ function tuneWebgpu(kernel) {
2132
2241
  break;
2133
2242
  }
2134
2243
  }
2135
- for (const ax of Array.from(upcastedAxis).sort()) {
2244
+ for (const ax of sorted(upcastedAxis)) {
2136
2245
  const s = dim.st.shape[ax];
2137
2246
  for (const amount of [8, 4]) if (s % amount === 0) {
2138
2247
  dim.applyLocal(ax, amount);
@@ -2250,13 +2359,21 @@ var CpuBackend = class {
2250
2359
  if (count === void 0) count = buffer.byteLength - start;
2251
2360
  return buffer.slice(start, start + count);
2252
2361
  }
2253
- async prepare(kernel) {
2254
- return this.prepareSync(kernel);
2362
+ async prepareKernel(kernel) {
2363
+ return this.prepareKernelSync(kernel);
2255
2364
  }
2256
- prepareSync(kernel) {
2365
+ prepareKernelSync(kernel) {
2257
2366
  return new Executable(kernel, void 0);
2258
2367
  }
2259
- dispatch({ kernel }, inputs, outputs) {
2368
+ async prepareRoutine(routine) {
2369
+ return this.prepareRoutineSync(routine);
2370
+ }
2371
+ prepareRoutineSync(routine) {
2372
+ return new Executable(routine, void 0);
2373
+ }
2374
+ dispatch(exe, inputs, outputs) {
2375
+ if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
2376
+ const kernel = exe.source;
2260
2377
  const { exp, epilogue } = tuneNullopt(kernel);
2261
2378
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
2262
2379
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
@@ -2314,8 +2431,10 @@ var WasmAllocator = class {
2314
2431
  const sizeClass = this.#findSizeClass(size);
2315
2432
  const freeList = this.#freeLists.get(sizeClass);
2316
2433
  let ptr;
2317
- if (freeList && freeList.length > 0) ptr = freeList.pop();
2318
- else ptr = this.#bumpAlloc(sizeClass);
2434
+ if (freeList && freeList.length > 0) {
2435
+ ptr = freeList.pop();
2436
+ new Uint8Array(this.#memory.buffer, ptr, sizeClass).fill(0);
2437
+ } else ptr = this.#bumpAlloc(sizeClass);
2319
2438
  this.#allocatedBuffers.set(ptr, sizeClass);
2320
2439
  return ptr;
2321
2440
  }
@@ -3681,10 +3800,10 @@ var WasmBackend = class {
3681
3800
  if (count === void 0) count = buffer.byteLength - start;
3682
3801
  return buffer.slice(start, start + count);
3683
3802
  }
3684
- async prepare(kernel) {
3685
- return this.prepareSync(kernel);
3803
+ async prepareKernel(kernel) {
3804
+ return this.prepareKernelSync(kernel);
3686
3805
  }
3687
- prepareSync(kernel) {
3806
+ prepareKernelSync(kernel) {
3688
3807
  const kernelHash = FpHash.hash(kernel);
3689
3808
  const module = runWithCache(moduleCache, kernelHash.toString(), () => {
3690
3809
  const bytes = codegenWasm(kernel);
@@ -3692,7 +3811,14 @@ var WasmBackend = class {
3692
3811
  });
3693
3812
  return new Executable(kernel, { module });
3694
3813
  }
3814
+ async prepareRoutine(routine) {
3815
+ return this.prepareRoutineSync(routine);
3816
+ }
3817
+ prepareRoutineSync(routine) {
3818
+ return new Executable(routine, void 0);
3819
+ }
3695
3820
  dispatch(exe, inputs, outputs) {
3821
+ if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
3696
3822
  const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
3697
3823
  const func = instance.exports.kernel;
3698
3824
  const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
@@ -4040,7 +4166,7 @@ async function createBackend(device) {
4040
4166
  if (!navigator.gpu) return null;
4041
4167
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4042
4168
  if (!adapter) return null;
4043
- const { WebGPUBackend } = await import("./webgpu-C9iAP5h5.js");
4169
+ const { WebGPUBackend } = await import("./webgpu-ChVgx3b6.js");
4044
4170
  const importantLimits = [
4045
4171
  "maxBufferSize",
4046
4172
  "maxComputeInvocationsPerWorkgroup",
@@ -4074,8 +4200,8 @@ function getBackend(device) {
4074
4200
  return backend;
4075
4201
  }
4076
4202
  var Executable = class {
4077
- constructor(kernel, data) {
4078
- this.kernel = kernel;
4203
+ constructor(source, data) {
4204
+ this.source = source;
4079
4205
  this.data = data;
4080
4206
  }
4081
4207
  };
@@ -4091,7 +4217,11 @@ var UnsupportedOpError = class extends Error {
4091
4217
  super(msg);
4092
4218
  }
4093
4219
  };
4220
+ var UnsupportedRoutineError = class extends Error {
4221
+ constructor(name, device) {
4222
+ super(`routine '${name}' is not supported in ${device} backend`);
4223
+ }
4224
+ };
4094
4225
 
4095
4226
  //#endregion
4096
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
4097
- //# sourceMappingURL=backend-BY8wlLEl.js.map
4227
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };