@jax-js/jax 0.1.2 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -3,7 +3,8 @@
3
3
  <p align="center"><strong>
4
4
  <a href="https://jax-js.com">Website</a> |
5
5
  <a href="https://jax-js.com/docs/">API Reference</a> |
6
- <a href="./FEATURES.md">Compatibility Table</a>
6
+ <a href="./FEATURES.md">Compatibility Table</a> |
7
+ <a href="https://discord.gg/BW6YsCd4Tf">Discord</a>
7
8
  </strong></p>
8
9
 
9
10
  **jax-js** is a machine learning framework for the browser. It aims to bring
@@ -257,36 +258,12 @@ await devicePut(ar, "webgpu"); // Now device="webgpu"
257
258
  There are other libraries in the `@jax-js` namespace that can work with jax-js, or be used in a
258
259
  self-contained way in other projects.
259
260
 
260
- **`@jax-js/optax`** provides implementations of optimizers like Adam and SGD.
261
-
262
- ```ts
263
- import { adam } from "@jax-js/optax";
264
-
265
- let params = np.array([1.0, 2.0, 3.0]);
266
-
267
- const solver = adam(1e-3);
268
- let optState = solver.init(params.ref);
269
- let updates: np.Array;
270
-
271
- const f = (x: np.Array) => squaredError(x, np.ones([3])).sum();
272
-
273
- for (let i = 0; i < 100; i++) {
274
- const paramsGrad = grad(f)(params.ref);
275
- [updates, optState] = solver.update(paramsGrad, optState);
276
- params = applyUpdates(params, updates);
277
- }
278
- ```
279
-
280
- **`@jax-js/loaders`** can load tensors from various formats like Safetensors, includes a fast and
281
- compliant implementation of BPE, and caches HTTP requests for large assets like model weights in
282
- OPFS.
283
-
284
- ```ts
285
- import { tokenizers } from "@jax-js/loaders";
286
-
287
- const enc = await tokenizers.getBpe("clip");
288
- const tokens = enc.encode("Hello, world!"); // => [ 49406, 3306, 267, 1002, ... ]
289
- ```
261
+ - [**`@jax-js/loaders`**](packages/loaders) can load tensors from various formats like Safetensors,
262
+ includes a fast and compliant implementation of BPE, and caches HTTP requests for large assets
263
+ like model weights in OPFS.
264
+ - [**`@jax-js/onnx`**](packages/onnx) is a model loader from the [ONNX](https://onnx.ai/) format
265
+ into native jax-js functions.
266
+ - [**`@jax-js/optax`**](packages/optax) provides implementations of optimizers like Adam and SGD.
290
267
 
291
268
  ### Performance
292
269
 
@@ -311,6 +288,7 @@ If you make something cool with jax-js, don't be a stranger! We can feature it h
311
288
 
312
289
  - [Training neural networks on MNIST](https://jax-js.com/mnist)
313
290
  - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
291
+ - [Object detection with DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
314
292
  - [In-browser REPL](https://jax-js.com/repl)
315
293
  - [Matmul benchmark](https://jax-js.com/bench/matmul)
316
294
  - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
@@ -346,12 +324,16 @@ pnpm -C website dev
346
324
 
347
325
  ## Future work / help wanted
348
326
 
349
- Contributions are welcomed! Especially in:
327
+ Contributions are welcomed! Some fruitful areas to look into:
350
328
 
351
329
  - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
352
330
  - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
353
331
  and multithreading. (Even single-threaded Wasm could be ~20x faster.)
354
- - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches and adding
355
- epilogue to reductions.
332
+ - Adding support for `jax.profiling`, in particular the start and end trace functions. We should be
333
+ able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
334
+ to help with model performance debugging.
335
+ - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
356
336
  - Adding WebGL runtime for older browsers that don't support WebGPU.
357
337
  - Making a fast transformer inference engine, comparing against onnxruntime-web.
338
+
339
+ You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
@@ -69,6 +69,9 @@ function zipn(...arrays) {
69
69
  const minLength = Math.min(...arrays.map((x) => x.length));
70
70
  return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
71
71
  }
72
+ function sorted(arr) {
73
+ return [...arr].sort((a, b) => a - b);
74
+ }
72
75
  function rep(length, value) {
73
76
  if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
74
77
  return new Array(length).fill(value);
@@ -145,7 +148,7 @@ function normalizeAxis(axis, ndim) {
145
148
  if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
146
149
  seen.add(ca);
147
150
  }
148
- return [...seen].sort();
151
+ return sorted(seen);
149
152
  }
150
153
  }
151
154
  function range(start, stop, step = 1) {
@@ -558,16 +561,16 @@ var AluExp = class AluExp {
558
561
  });
559
562
  }
560
563
  /** Reindex gid values in this expression as needed. */
561
- reindexGids(gidMap) {
564
+ reindexGids(newGids) {
562
565
  return this.rewrite((exp) => {
563
566
  if (exp.op === AluOp.GlobalIndex) {
564
567
  const [gid, len] = exp.arg;
565
- const newGid = gidMap.get(gid);
566
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
568
+ const newGid = newGids[gid];
569
+ if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
567
570
  } else if (exp.op === AluOp.GlobalView) {
568
571
  const gid = exp.arg[0];
569
- const newGid = gidMap.get(gid);
570
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
572
+ const newGid = newGids[gid];
573
+ if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
571
574
  }
572
575
  });
573
576
  }
@@ -781,7 +784,7 @@ var AluExp = class AluExp {
781
784
  if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
782
785
  if (op === AluOp.Mul && x === 1) return src[1 - i];
783
786
  if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
784
- if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
787
+ if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
785
788
  if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
786
789
  }
787
790
  if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
@@ -1328,7 +1331,7 @@ var Reduction = class {
1328
1331
  /** Evaluate this operation on CPU. */
1329
1332
  evaluate(...values) {
1330
1333
  if (this.dtype === DType.Bool) {
1331
- if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, true);
1334
+ if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, false);
1332
1335
  else if (this.op === AluOp.Mul || this.op === AluOp.Min) return values.reduce((a, b) => a && b, true);
1333
1336
  } else if (this.dtype === DType.Int32) {
1334
1337
  if (this.op === AluOp.Add) return values.reduce((a, b) => a + b | 0, 0);
@@ -1438,6 +1441,112 @@ function erfc(x) {
1438
1441
  else return 2 - _erfapprox$1(-x);
1439
1442
  }
1440
1443
 
1444
+ //#endregion
1445
+ //#region src/routine.ts
1446
+ /**
1447
+ * Advanced operations that don't fit into the `AluExp` compiler representation.
1448
+ *
1449
+ * Some routines like iterative matrix algorithms, FFTs, or sorting may not be
1450
+ * easy to express efficiently as a `Kernel` object. These also tend to be
1451
+ * somewhat expensive, so the benefit of kernel fusion and inlining is less
1452
+ * relevant.
1453
+ *
1454
+ * For these operations, we dispatch them as a custom operation on the backend,
1455
+ * which each backend implements in a specific way. These are listed in the
1456
+ * `Routines` enum below.
1457
+ *
1458
+ * Routines cannot be fused into other kernels and always operate on contiguous
1459
+ * arrays (default `ShapeTracker`).
1460
+ */
1461
+ var Routine = class {
1462
+ constructor(name, type, params) {
1463
+ this.name = name;
1464
+ this.type = type;
1465
+ this.params = params;
1466
+ }
1467
+ };
1468
+ /** One of the valid `Routine` that can be dispatched to backend. */
1469
+ let Routines = /* @__PURE__ */ function(Routines$1) {
1470
+ /** Stable sorting algorithm along the last axis. */
1471
+ Routines$1["Sort"] = "Sort";
1472
+ /** Returns `int32` indices of the stably sorted array. */
1473
+ Routines$1["Argsort"] = "Argsort";
1474
+ /** Solve a triangular system of questions. */
1475
+ Routines$1["TriangularSolve"] = "TriangularSolve";
1476
+ /** Cholesky decomposition of 2D positive semi-definite matrices. */
1477
+ Routines$1["Cholesky"] = "Cholesky";
1478
+ return Routines$1;
1479
+ }({});
1480
+ function runCpuRoutine(routine, inputs, outputs) {
1481
+ const { name, type } = routine;
1482
+ const inputAr = inputs.map((buf, i) => dtypedArray(type.inputDtypes[i], buf));
1483
+ const outputAr = outputs.map((buf, i) => dtypedArray(type.outputDtypes[i], buf));
1484
+ switch (name) {
1485
+ case Routines.Sort: return runSort(type, inputAr, outputAr);
1486
+ case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
1487
+ case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
1488
+ case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
1489
+ default:
1490
+ }
1491
+ }
1492
+ function runSort(type, [x], [y]) {
1493
+ const xs = type.inputShapes[0];
1494
+ if (xs.length === 0) throw new Error("sort: cannot sort a scalar");
1495
+ const n = xs[xs.length - 1];
1496
+ y.set(x);
1497
+ for (let i = 0; i < y.length; i += n) y.subarray(i, i + n).sort();
1498
+ }
1499
+ function runArgsort(type, [x], [y, yi]) {
1500
+ const xs = type.inputShapes[0];
1501
+ if (xs.length === 0) throw new Error("argsort: cannot sort a scalar");
1502
+ const n = xs[xs.length - 1];
1503
+ for (let offset = 0; offset < y.length; offset += n) {
1504
+ const ar = x.subarray(offset, offset + n);
1505
+ const out = y.subarray(offset, offset + n);
1506
+ const outi = yi.subarray(offset, offset + n);
1507
+ for (let i = 0; i < n; i++) outi[i] = i;
1508
+ outi.sort((a, b) => ar[a] - ar[b]);
1509
+ for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
1510
+ }
1511
+ }
1512
+ function runTriangularSolve(type, [a, b], [x], { unitDiagonal }) {
1513
+ const as = type.inputShapes[0];
1514
+ const bs = type.inputShapes[1];
1515
+ if (as.length < 2) throw new Error(`triangular_solve: a must be at least 2D, got ${as}`);
1516
+ if (bs.length < 2) throw new Error(`triangular_solve: b must be at least 2D, got ${bs}`);
1517
+ const n = as[as.length - 2];
1518
+ if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
1519
+ const batch = bs[bs.length - 2];
1520
+ for (let counter = 0; counter < a.length / (n * n); counter++) {
1521
+ const a1 = a.subarray(counter * n * n, (counter + 1) * n * n);
1522
+ for (let t = 0; t < batch; t++) {
1523
+ const b1 = b.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
1524
+ const x1 = x.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
1525
+ for (let i = n - 1; i >= 0; i--) {
1526
+ let sum = b1[i];
1527
+ for (let j = i + 1; j < n; j++) sum -= a1[i * n + j] * x1[j];
1528
+ x1[i] = unitDiagonal ? sum : sum / a1[i * n + i];
1529
+ }
1530
+ }
1531
+ }
1532
+ }
1533
+ function runCholesky(type, [x], [y]) {
1534
+ const xs = type.inputShapes[0];
1535
+ if (xs.length < 2) throw new Error("cholesky: input must be at least 2D");
1536
+ const n = xs[xs.length - 2];
1537
+ const m = xs[xs.length - 1];
1538
+ if (n !== m) throw new Error(`cholesky: input must be square, got [${n}, ${m}]`);
1539
+ for (let offset = 0; offset < y.length; offset += n * n) {
1540
+ const ar = x.subarray(offset, offset + n * n);
1541
+ const out = y.subarray(offset, offset + n * n);
1542
+ for (let i = 0; i < n; i++) for (let j = 0; j <= i; j++) {
1543
+ let sum = ar[i * n + j];
1544
+ for (let k = 0; k < j; k++) sum -= out[i * n + k] * out[j * n + k];
1545
+ out[i * n + j] = i === j ? Math.sqrt(sum) : sum / out[j * n + j];
1546
+ }
1547
+ }
1548
+ }
1549
+
1441
1550
  //#endregion
1442
1551
  //#region src/shape.ts
1443
1552
  const jstr = JSON.stringify;
@@ -1908,7 +2017,7 @@ var ShapeTracker = class ShapeTracker {
1908
2017
  let st = this;
1909
2018
  if (axis.length > 0) {
1910
2019
  const unsqueezed = [...st.shape];
1911
- for (const i of axis.toSorted()) unsqueezed.splice(i, 0, 1);
2020
+ for (const i of sorted(axis)) unsqueezed.splice(i, 0, 1);
1912
2021
  st = st.reshape(unsqueezed);
1913
2022
  }
1914
2023
  return st.expand(newShape);
@@ -2067,7 +2176,8 @@ function tuneNullopt(kernel) {
2067
2176
  if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2068
2177
  return {
2069
2178
  exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
2070
- outputIdxExp: AluExp.special(DType.Int32, "gidx", kernel.size),
2179
+ epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
2180
+ outputIdxExp: vars.gidx,
2071
2181
  threadCount: kernel.size,
2072
2182
  size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
2073
2183
  };
@@ -2100,7 +2210,11 @@ function tuneWebgpu(kernel) {
2100
2210
  while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
2101
2211
  const choices = [];
2102
2212
  const composedSts = sts.map((st) => st.compose(dim.st));
2103
- for (let axis = 0; axis < dim.groups; axis++) for (const amount of [3, 4]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
2213
+ for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
2214
+ 3,
2215
+ 4,
2216
+ 5
2217
+ ]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
2104
2218
  let nonzeroStrides = 0;
2105
2219
  let totalStrides = 0;
2106
2220
  for (const st of composedSts) {
@@ -2128,7 +2242,7 @@ function tuneWebgpu(kernel) {
2128
2242
  break;
2129
2243
  }
2130
2244
  }
2131
- for (const ax of Array.from(upcastedAxis).sort()) {
2245
+ for (const ax of sorted(upcastedAxis)) {
2132
2246
  const s = dim.st.shape[ax];
2133
2247
  for (const amount of [8, 4]) if (s % amount === 0) {
2134
2248
  dim.applyLocal(ax, amount);
@@ -2176,7 +2290,15 @@ function tuneWebgpu(kernel) {
2176
2290
  });
2177
2291
  const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
2178
2292
  const outputUpcast = dim.outputSt.shape.slice(dim.groups);
2179
- const [outputIdxExp, _] = dim.outputSt.toAluExp([...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)]);
2293
+ const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
2294
+ const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
2295
+ const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
2296
+ if (exp$1.op === AluOp.GlobalView) {
2297
+ const gid = exp$1.arg[0];
2298
+ const st = exp$1.arg[1];
2299
+ return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
2300
+ }
2301
+ });
2180
2302
  if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
2181
2303
  const size = {
2182
2304
  groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
@@ -2186,6 +2308,7 @@ function tuneWebgpu(kernel) {
2186
2308
  };
2187
2309
  return {
2188
2310
  exp: newExp.simplify(),
2311
+ epilogue: newEpilogue.simplify(),
2189
2312
  outputIdxExp: outputIdxExp.simplify(),
2190
2313
  threadCount: kernel.size / size.upcast * size.groups,
2191
2314
  size
@@ -2237,17 +2360,25 @@ var CpuBackend = class {
2237
2360
  if (count === void 0) count = buffer.byteLength - start;
2238
2361
  return buffer.slice(start, start + count);
2239
2362
  }
2240
- async prepare(kernel) {
2241
- return this.prepareSync(kernel);
2363
+ async prepareKernel(kernel) {
2364
+ return this.prepareKernelSync(kernel);
2242
2365
  }
2243
- prepareSync(kernel) {
2366
+ prepareKernelSync(kernel) {
2244
2367
  return new Executable(kernel, void 0);
2245
2368
  }
2246
- dispatch({ kernel }, inputs, outputs) {
2247
- const { exp } = tuneNullopt(kernel);
2369
+ async prepareRoutine(routine) {
2370
+ return this.prepareRoutineSync(routine);
2371
+ }
2372
+ prepareRoutineSync(routine) {
2373
+ return new Executable(routine, void 0);
2374
+ }
2375
+ dispatch(exe, inputs, outputs) {
2376
+ if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
2377
+ const kernel = exe.source;
2378
+ const { exp, epilogue } = tuneNullopt(kernel);
2248
2379
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
2249
2380
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
2250
- const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
2381
+ const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
2251
2382
  const inputArrays = inputBuffers.map((buf, i) => {
2252
2383
  const dtype = usedArgs.get(i);
2253
2384
  if (!dtype) return null;
@@ -2269,7 +2400,10 @@ var CpuBackend = class {
2269
2400
  }, globals);
2270
2401
  acc = kernel.reduction.evaluate(acc, item);
2271
2402
  }
2272
- outputArray[i] = kernel.reduction.epilogue.evaluate({ acc });
2403
+ outputArray[i] = epilogue.evaluate({
2404
+ acc,
2405
+ gidx: i
2406
+ }, globals);
2273
2407
  }
2274
2408
  }
2275
2409
  #getBuffer(slot) {
@@ -2298,8 +2432,10 @@ var WasmAllocator = class {
2298
2432
  const sizeClass = this.#findSizeClass(size);
2299
2433
  const freeList = this.#freeLists.get(sizeClass);
2300
2434
  let ptr;
2301
- if (freeList && freeList.length > 0) ptr = freeList.pop();
2302
- else ptr = this.#bumpAlloc(sizeClass);
2435
+ if (freeList && freeList.length > 0) {
2436
+ ptr = freeList.pop();
2437
+ new Uint8Array(this.#memory.buffer, ptr, sizeClass).fill(0);
2438
+ } else ptr = this.#bumpAlloc(sizeClass);
2303
2439
  this.#allocatedBuffers.set(ptr, sizeClass);
2304
2440
  return ptr;
2305
2441
  }
@@ -2432,7 +2568,7 @@ function wasm_log(cg) {
2432
2568
  const t2 = cg.local.declare(cg.f32);
2433
2569
  cg.local.get(0);
2434
2570
  cg.f32.const(0);
2435
- cg.f32.le();
2571
+ cg.f32.lt();
2436
2572
  cg.if(cg.void);
2437
2573
  cg.f32.const(NaN);
2438
2574
  cg.return();
@@ -2447,6 +2583,20 @@ function wasm_log(cg) {
2447
2583
  cg.i32.const(127);
2448
2584
  cg.i32.sub();
2449
2585
  cg.local.set(e);
2586
+ cg.local.get(e);
2587
+ cg.i32.const(-127);
2588
+ cg.i32.eq();
2589
+ cg.if(cg.void);
2590
+ cg.f32.const(-Infinity);
2591
+ cg.return();
2592
+ cg.end();
2593
+ cg.local.get(e);
2594
+ cg.i32.const(128);
2595
+ cg.i32.eq();
2596
+ cg.if(cg.void);
2597
+ cg.local.get(0);
2598
+ cg.return();
2599
+ cg.end();
2450
2600
  cg.local.get(bits);
2451
2601
  cg.i32.const(8388607);
2452
2602
  cg.i32.and();
@@ -2512,7 +2662,7 @@ function _sincos(cg) {
2512
2662
  cg.f32.mul();
2513
2663
  cg.f32.nearest();
2514
2664
  cg.local.tee(qf);
2515
- cg.i32.trunc_f32_s();
2665
+ cg.i32.trunc_sat_f32_s();
2516
2666
  cg.local.set(q);
2517
2667
  cg.local.get(y);
2518
2668
  cg.local.get(qf);
@@ -3599,6 +3749,7 @@ var F32x4 = class extends V128 {
3599
3749
 
3600
3750
  //#endregion
3601
3751
  //#region src/backend/wasm.ts
3752
+ const moduleCache = /* @__PURE__ */ new Map();
3602
3753
  /** Backend that compiles into WebAssembly bytecode for immediate execution. */
3603
3754
  var WasmBackend = class {
3604
3755
  type = "wasm";
@@ -3650,15 +3801,25 @@ var WasmBackend = class {
3650
3801
  if (count === void 0) count = buffer.byteLength - start;
3651
3802
  return buffer.slice(start, start + count);
3652
3803
  }
3653
- async prepare(kernel) {
3654
- return this.prepareSync(kernel);
3804
+ async prepareKernel(kernel) {
3805
+ return this.prepareKernelSync(kernel);
3655
3806
  }
3656
- prepareSync(kernel) {
3657
- const bytes = codegenWasm(kernel);
3658
- const module$1 = new WebAssembly.Module(bytes);
3807
+ prepareKernelSync(kernel) {
3808
+ const kernelHash = FpHash.hash(kernel);
3809
+ const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => {
3810
+ const bytes = codegenWasm(kernel);
3811
+ return new WebAssembly.Module(bytes);
3812
+ });
3659
3813
  return new Executable(kernel, { module: module$1 });
3660
3814
  }
3815
+ async prepareRoutine(routine) {
3816
+ return this.prepareRoutineSync(routine);
3817
+ }
3818
+ prepareRoutineSync(routine) {
3819
+ return new Executable(routine, void 0);
3820
+ }
3661
3821
  dispatch(exe, inputs, outputs) {
3822
+ if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
3662
3823
  const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
3663
3824
  const func = instance.exports.kernel;
3664
3825
  const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
@@ -3676,7 +3837,7 @@ function codegenWasm(kernel) {
3676
3837
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
3677
3838
  const cg = new CodeGenerator();
3678
3839
  cg.memory.import("env", "memory");
3679
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3840
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
3680
3841
  const funcs = {};
3681
3842
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3682
3843
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
@@ -3754,7 +3915,10 @@ function codegenWasm(kernel) {
3754
3915
  cg.br(1);
3755
3916
  cg.end();
3756
3917
  cg.end();
3757
- translateExp(cg, funcs, kernel.reduction.epilogue, { acc });
3918
+ translateExp(cg, funcs, tune.epilogue, {
3919
+ acc,
3920
+ gidx
3921
+ });
3758
3922
  } else translateExp(cg, funcs, tune.exp, { gidx });
3759
3923
  dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
3760
3924
  cg.local.get(gidx);
@@ -4003,7 +4167,7 @@ async function createBackend(device) {
4003
4167
  if (!navigator.gpu) return null;
4004
4168
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4005
4169
  if (!adapter) return null;
4006
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CcGP160M.cjs"));
4170
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-Oj3Kd-kd.cjs"));
4007
4171
  const importantLimits = [
4008
4172
  "maxBufferSize",
4009
4173
  "maxComputeInvocationsPerWorkgroup",
@@ -4037,8 +4201,8 @@ function getBackend(device) {
4037
4201
  return backend;
4038
4202
  }
4039
4203
  var Executable = class {
4040
- constructor(kernel, data) {
4041
- this.kernel = kernel;
4204
+ constructor(source, data) {
4205
+ this.source = source;
4042
4206
  this.data = data;
4043
4207
  }
4044
4208
  };
@@ -4054,6 +4218,11 @@ var UnsupportedOpError = class extends Error {
4054
4218
  super(msg);
4055
4219
  }
4056
4220
  };
4221
+ var UnsupportedRoutineError = class extends Error {
4222
+ constructor(name, device) {
4223
+ super(`routine '${name}' is not supported in ${device} backend`);
4224
+ }
4225
+ };
4057
4226
 
4058
4227
  //#endregion
4059
4228
  Object.defineProperty(exports, 'AluExp', {
@@ -4122,6 +4291,18 @@ Object.defineProperty(exports, 'Reduction', {
4122
4291
  return Reduction;
4123
4292
  }
4124
4293
  });
4294
+ Object.defineProperty(exports, 'Routine', {
4295
+ enumerable: true,
4296
+ get: function () {
4297
+ return Routine;
4298
+ }
4299
+ });
4300
+ Object.defineProperty(exports, 'Routines', {
4301
+ enumerable: true,
4302
+ get: function () {
4303
+ return Routines;
4304
+ }
4305
+ });
4125
4306
  Object.defineProperty(exports, 'ShapeTracker', {
4126
4307
  enumerable: true,
4127
4308
  get: function () {
@@ -4140,6 +4321,12 @@ Object.defineProperty(exports, 'UnsupportedOpError', {
4140
4321
  return UnsupportedOpError;
4141
4322
  }
4142
4323
  });
4324
+ Object.defineProperty(exports, 'UnsupportedRoutineError', {
4325
+ enumerable: true,
4326
+ get: function () {
4327
+ return UnsupportedRoutineError;
4328
+ }
4329
+ });
4143
4330
  Object.defineProperty(exports, 'accessorAluExp', {
4144
4331
  enumerable: true,
4145
4332
  get: function () {
@@ -4349,5 +4536,4 @@ Object.defineProperty(exports, 'zipn', {
4349
4536
  get: function () {
4350
4537
  return zipn;
4351
4538
  }
4352
- });
4353
- //# sourceMappingURL=backend-DeVfWEFS.cjs.map
4539
+ });