@jax-js/jax 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -257,36 +257,12 @@ await devicePut(ar, "webgpu"); // Now device="webgpu"
257
257
  There are other libraries in the `@jax-js` namespace that can work with jax-js, or be used in a
258
258
  self-contained way in other projects.
259
259
 
260
- **`@jax-js/optax`** provides implementations of optimizers like Adam and SGD.
261
-
262
- ```ts
263
- import { adam } from "@jax-js/optax";
264
-
265
- let params = np.array([1.0, 2.0, 3.0]);
266
-
267
- const solver = adam(1e-3);
268
- let optState = solver.init(params.ref);
269
- let updates: np.Array;
270
-
271
- const f = (x: np.Array) => squaredError(x, np.ones([3])).sum();
272
-
273
- for (let i = 0; i < 100; i++) {
274
- const paramsGrad = grad(f)(params.ref);
275
- [updates, optState] = solver.update(paramsGrad, optState);
276
- params = applyUpdates(params, updates);
277
- }
278
- ```
279
-
280
- **`@jax-js/loaders`** can load tensors from various formats like Safetensors, includes a fast and
281
- compliant implementation of BPE, and caches HTTP requests for large assets like model weights in
282
- OPFS.
283
-
284
- ```ts
285
- import { tokenizers } from "@jax-js/loaders";
286
-
287
- const enc = await tokenizers.getBpe("clip");
288
- const tokens = enc.encode("Hello, world!"); // => [ 49406, 3306, 267, 1002, ... ]
289
- ```
260
+ - [**`@jax-js/loaders`**](packages/loaders) can load tensors from various formats like Safetensors,
261
+ includes a fast and compliant implementation of BPE, and caches HTTP requests for large assets
262
+ like model weights in OPFS.
263
+ - [**`@jax-js/onnx`**](packages/onnx) is a model loader from the [ONNX](https://onnx.ai/) format
264
+ into native jax-js functions.
265
+ - [**`@jax-js/optax`**](packages/optax) provides implementations of optimizers like Adam and SGD.
290
266
 
291
267
  ### Performance
292
268
 
@@ -311,6 +287,7 @@ If you make something cool with jax-js, don't be a stranger! We can feature it h
311
287
 
312
288
  - [Training neural networks on MNIST](https://jax-js.com/mnist)
313
289
  - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
290
+ - [Object detection with DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
314
291
  - [In-browser REPL](https://jax-js.com/repl)
315
292
  - [Matmul benchmark](https://jax-js.com/bench/matmul)
316
293
  - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
@@ -351,7 +328,9 @@ Contributions are welcomed! Especially in:
351
328
  - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
352
329
  - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
353
330
  and multithreading. (Even single-threaded Wasm could be ~20x faster.)
354
- - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches and adding
355
- epilogue to reductions.
331
+ - Adding support for `jax.profiling`, in particular the start and end trace functions. We should be
332
+ able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
333
+ to help with model performance debugging.
334
+ - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
356
335
  - Adding WebGL runtime for older browsers that don't support WebGPU.
357
336
  - Making a fast transformer inference engine, comparing against onnxruntime-web.
@@ -557,16 +557,16 @@ var AluExp = class AluExp {
557
557
  });
558
558
  }
559
559
  /** Reindex gid values in this expression as needed. */
560
- reindexGids(gidMap) {
560
+ reindexGids(newGids) {
561
561
  return this.rewrite((exp) => {
562
562
  if (exp.op === AluOp.GlobalIndex) {
563
563
  const [gid, len] = exp.arg;
564
- const newGid = gidMap.get(gid);
565
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
564
+ const newGid = newGids[gid];
565
+ if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
566
566
  } else if (exp.op === AluOp.GlobalView) {
567
567
  const gid = exp.arg[0];
568
- const newGid = gidMap.get(gid);
569
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
568
+ const newGid = newGids[gid];
569
+ if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
570
570
  }
571
571
  });
572
572
  }
@@ -780,7 +780,7 @@ var AluExp = class AluExp {
780
780
  if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
781
781
  if (op === AluOp.Mul && x === 1) return src[1 - i];
782
782
  if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
783
- if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
783
+ if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
784
784
  if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
785
785
  }
786
786
  if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
@@ -2066,7 +2066,8 @@ function tuneNullopt(kernel) {
2066
2066
  if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2067
2067
  return {
2068
2068
  exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
2069
- outputIdxExp: AluExp.special(DType.Int32, "gidx", kernel.size),
2069
+ epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
2070
+ outputIdxExp: vars.gidx,
2070
2071
  threadCount: kernel.size,
2071
2072
  size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
2072
2073
  };
@@ -2099,7 +2100,11 @@ function tuneWebgpu(kernel) {
2099
2100
  while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
2100
2101
  const choices = [];
2101
2102
  const composedSts = sts.map((st) => st.compose(dim.st));
2102
- for (let axis = 0; axis < dim.groups; axis++) for (const amount of [3, 4]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
2103
+ for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
2104
+ 3,
2105
+ 4,
2106
+ 5
2107
+ ]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
2103
2108
  let nonzeroStrides = 0;
2104
2109
  let totalStrides = 0;
2105
2110
  for (const st of composedSts) {
@@ -2175,7 +2180,15 @@ function tuneWebgpu(kernel) {
2175
2180
  });
2176
2181
  const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
2177
2182
  const outputUpcast = dim.outputSt.shape.slice(dim.groups);
2178
- const [outputIdxExp, _] = dim.outputSt.toAluExp([...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)]);
2183
+ const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
2184
+ const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
2185
+ const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
2186
+ if (exp$1.op === AluOp.GlobalView) {
2187
+ const gid = exp$1.arg[0];
2188
+ const st = exp$1.arg[1];
2189
+ return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
2190
+ }
2191
+ });
2179
2192
  if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
2180
2193
  const size = {
2181
2194
  groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
@@ -2185,6 +2198,7 @@ function tuneWebgpu(kernel) {
2185
2198
  };
2186
2199
  return {
2187
2200
  exp: newExp.simplify(),
2201
+ epilogue: newEpilogue.simplify(),
2188
2202
  outputIdxExp: outputIdxExp.simplify(),
2189
2203
  threadCount: kernel.size / size.upcast * size.groups,
2190
2204
  size
@@ -2243,10 +2257,10 @@ var CpuBackend = class {
2243
2257
  return new Executable(kernel, void 0);
2244
2258
  }
2245
2259
  dispatch({ kernel }, inputs, outputs) {
2246
- const { exp } = tuneNullopt(kernel);
2260
+ const { exp, epilogue } = tuneNullopt(kernel);
2247
2261
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
2248
2262
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
2249
- const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
2263
+ const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
2250
2264
  const inputArrays = inputBuffers.map((buf, i) => {
2251
2265
  const dtype = usedArgs.get(i);
2252
2266
  if (!dtype) return null;
@@ -2268,7 +2282,10 @@ var CpuBackend = class {
2268
2282
  }, globals);
2269
2283
  acc = kernel.reduction.evaluate(acc, item);
2270
2284
  }
2271
- outputArray[i] = kernel.reduction.epilogue.evaluate({ acc });
2285
+ outputArray[i] = epilogue.evaluate({
2286
+ acc,
2287
+ gidx: i
2288
+ }, globals);
2272
2289
  }
2273
2290
  }
2274
2291
  #getBuffer(slot) {
@@ -2431,7 +2448,7 @@ function wasm_log(cg) {
2431
2448
  const t2 = cg.local.declare(cg.f32);
2432
2449
  cg.local.get(0);
2433
2450
  cg.f32.const(0);
2434
- cg.f32.le();
2451
+ cg.f32.lt();
2435
2452
  cg.if(cg.void);
2436
2453
  cg.f32.const(NaN);
2437
2454
  cg.return();
@@ -2446,6 +2463,20 @@ function wasm_log(cg) {
2446
2463
  cg.i32.const(127);
2447
2464
  cg.i32.sub();
2448
2465
  cg.local.set(e);
2466
+ cg.local.get(e);
2467
+ cg.i32.const(-127);
2468
+ cg.i32.eq();
2469
+ cg.if(cg.void);
2470
+ cg.f32.const(-Infinity);
2471
+ cg.return();
2472
+ cg.end();
2473
+ cg.local.get(e);
2474
+ cg.i32.const(128);
2475
+ cg.i32.eq();
2476
+ cg.if(cg.void);
2477
+ cg.local.get(0);
2478
+ cg.return();
2479
+ cg.end();
2449
2480
  cg.local.get(bits);
2450
2481
  cg.i32.const(8388607);
2451
2482
  cg.i32.and();
@@ -2511,7 +2542,7 @@ function _sincos(cg) {
2511
2542
  cg.f32.mul();
2512
2543
  cg.f32.nearest();
2513
2544
  cg.local.tee(qf);
2514
- cg.i32.trunc_f32_s();
2545
+ cg.i32.trunc_sat_f32_s();
2515
2546
  cg.local.set(q);
2516
2547
  cg.local.get(y);
2517
2548
  cg.local.get(qf);
@@ -3598,6 +3629,7 @@ var F32x4 = class extends V128 {
3598
3629
 
3599
3630
  //#endregion
3600
3631
  //#region src/backend/wasm.ts
3632
+ const moduleCache = /* @__PURE__ */ new Map();
3601
3633
  /** Backend that compiles into WebAssembly bytecode for immediate execution. */
3602
3634
  var WasmBackend = class {
3603
3635
  type = "wasm";
@@ -3653,8 +3685,11 @@ var WasmBackend = class {
3653
3685
  return this.prepareSync(kernel);
3654
3686
  }
3655
3687
  prepareSync(kernel) {
3656
- const bytes = codegenWasm(kernel);
3657
- const module = new WebAssembly.Module(bytes);
3688
+ const kernelHash = FpHash.hash(kernel);
3689
+ const module = runWithCache(moduleCache, kernelHash.toString(), () => {
3690
+ const bytes = codegenWasm(kernel);
3691
+ return new WebAssembly.Module(bytes);
3692
+ });
3658
3693
  return new Executable(kernel, { module });
3659
3694
  }
3660
3695
  dispatch(exe, inputs, outputs) {
@@ -3675,7 +3710,7 @@ function codegenWasm(kernel) {
3675
3710
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
3676
3711
  const cg = new CodeGenerator();
3677
3712
  cg.memory.import("env", "memory");
3678
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3713
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
3679
3714
  const funcs = {};
3680
3715
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3681
3716
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
@@ -3753,7 +3788,10 @@ function codegenWasm(kernel) {
3753
3788
  cg.br(1);
3754
3789
  cg.end();
3755
3790
  cg.end();
3756
- translateExp(cg, funcs, kernel.reduction.epilogue, { acc });
3791
+ translateExp(cg, funcs, tune.epilogue, {
3792
+ acc,
3793
+ gidx
3794
+ });
3757
3795
  } else translateExp(cg, funcs, tune.exp, { gidx });
3758
3796
  dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
3759
3797
  cg.local.get(gidx);
@@ -4002,7 +4040,7 @@ async function createBackend(device) {
4002
4040
  if (!navigator.gpu) return null;
4003
4041
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4004
4042
  if (!adapter) return null;
4005
- const { WebGPUBackend } = await import("./webgpu-BGuG58KZ.js");
4043
+ const { WebGPUBackend } = await import("./webgpu-C9iAP5h5.js");
4006
4044
  const importantLimits = [
4007
4045
  "maxBufferSize",
4008
4046
  "maxComputeInvocationsPerWorkgroup",
@@ -4056,4 +4094,4 @@ var UnsupportedOpError = class extends Error {
4056
4094
 
4057
4095
  //#endregion
4058
4096
  export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
4059
- //# sourceMappingURL=backend-BqymqzuU.js.map
4097
+ //# sourceMappingURL=backend-BY8wlLEl.js.map
@@ -558,16 +558,16 @@ var AluExp = class AluExp {
558
558
  });
559
559
  }
560
560
  /** Reindex gid values in this expression as needed. */
561
- reindexGids(gidMap) {
561
+ reindexGids(newGids) {
562
562
  return this.rewrite((exp) => {
563
563
  if (exp.op === AluOp.GlobalIndex) {
564
564
  const [gid, len] = exp.arg;
565
- const newGid = gidMap.get(gid);
566
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
565
+ const newGid = newGids[gid];
566
+ if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
567
567
  } else if (exp.op === AluOp.GlobalView) {
568
568
  const gid = exp.arg[0];
569
- const newGid = gidMap.get(gid);
570
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
569
+ const newGid = newGids[gid];
570
+ if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
571
571
  }
572
572
  });
573
573
  }
@@ -781,7 +781,7 @@ var AluExp = class AluExp {
781
781
  if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
782
782
  if (op === AluOp.Mul && x === 1) return src[1 - i];
783
783
  if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
784
- if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
784
+ if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
785
785
  if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
786
786
  }
787
787
  if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
@@ -2067,7 +2067,8 @@ function tuneNullopt(kernel) {
2067
2067
  if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2068
2068
  return {
2069
2069
  exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
2070
- outputIdxExp: AluExp.special(DType.Int32, "gidx", kernel.size),
2070
+ epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
2071
+ outputIdxExp: vars.gidx,
2071
2072
  threadCount: kernel.size,
2072
2073
  size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
2073
2074
  };
@@ -2100,7 +2101,11 @@ function tuneWebgpu(kernel) {
2100
2101
  while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
2101
2102
  const choices = [];
2102
2103
  const composedSts = sts.map((st) => st.compose(dim.st));
2103
- for (let axis = 0; axis < dim.groups; axis++) for (const amount of [3, 4]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
2104
+ for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
2105
+ 3,
2106
+ 4,
2107
+ 5
2108
+ ]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
2104
2109
  let nonzeroStrides = 0;
2105
2110
  let totalStrides = 0;
2106
2111
  for (const st of composedSts) {
@@ -2176,7 +2181,15 @@ function tuneWebgpu(kernel) {
2176
2181
  });
2177
2182
  const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
2178
2183
  const outputUpcast = dim.outputSt.shape.slice(dim.groups);
2179
- const [outputIdxExp, _] = dim.outputSt.toAluExp([...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)]);
2184
+ const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
2185
+ const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
2186
+ const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
2187
+ if (exp$1.op === AluOp.GlobalView) {
2188
+ const gid = exp$1.arg[0];
2189
+ const st = exp$1.arg[1];
2190
+ return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
2191
+ }
2192
+ });
2180
2193
  if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
2181
2194
  const size = {
2182
2195
  groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
@@ -2186,6 +2199,7 @@ function tuneWebgpu(kernel) {
2186
2199
  };
2187
2200
  return {
2188
2201
  exp: newExp.simplify(),
2202
+ epilogue: newEpilogue.simplify(),
2189
2203
  outputIdxExp: outputIdxExp.simplify(),
2190
2204
  threadCount: kernel.size / size.upcast * size.groups,
2191
2205
  size
@@ -2244,10 +2258,10 @@ var CpuBackend = class {
2244
2258
  return new Executable(kernel, void 0);
2245
2259
  }
2246
2260
  dispatch({ kernel }, inputs, outputs) {
2247
- const { exp } = tuneNullopt(kernel);
2261
+ const { exp, epilogue } = tuneNullopt(kernel);
2248
2262
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
2249
2263
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
2250
- const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
2264
+ const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
2251
2265
  const inputArrays = inputBuffers.map((buf, i) => {
2252
2266
  const dtype = usedArgs.get(i);
2253
2267
  if (!dtype) return null;
@@ -2269,7 +2283,10 @@ var CpuBackend = class {
2269
2283
  }, globals);
2270
2284
  acc = kernel.reduction.evaluate(acc, item);
2271
2285
  }
2272
- outputArray[i] = kernel.reduction.epilogue.evaluate({ acc });
2286
+ outputArray[i] = epilogue.evaluate({
2287
+ acc,
2288
+ gidx: i
2289
+ }, globals);
2273
2290
  }
2274
2291
  }
2275
2292
  #getBuffer(slot) {
@@ -2432,7 +2449,7 @@ function wasm_log(cg) {
2432
2449
  const t2 = cg.local.declare(cg.f32);
2433
2450
  cg.local.get(0);
2434
2451
  cg.f32.const(0);
2435
- cg.f32.le();
2452
+ cg.f32.lt();
2436
2453
  cg.if(cg.void);
2437
2454
  cg.f32.const(NaN);
2438
2455
  cg.return();
@@ -2447,6 +2464,20 @@ function wasm_log(cg) {
2447
2464
  cg.i32.const(127);
2448
2465
  cg.i32.sub();
2449
2466
  cg.local.set(e);
2467
+ cg.local.get(e);
2468
+ cg.i32.const(-127);
2469
+ cg.i32.eq();
2470
+ cg.if(cg.void);
2471
+ cg.f32.const(-Infinity);
2472
+ cg.return();
2473
+ cg.end();
2474
+ cg.local.get(e);
2475
+ cg.i32.const(128);
2476
+ cg.i32.eq();
2477
+ cg.if(cg.void);
2478
+ cg.local.get(0);
2479
+ cg.return();
2480
+ cg.end();
2450
2481
  cg.local.get(bits);
2451
2482
  cg.i32.const(8388607);
2452
2483
  cg.i32.and();
@@ -2512,7 +2543,7 @@ function _sincos(cg) {
2512
2543
  cg.f32.mul();
2513
2544
  cg.f32.nearest();
2514
2545
  cg.local.tee(qf);
2515
- cg.i32.trunc_f32_s();
2546
+ cg.i32.trunc_sat_f32_s();
2516
2547
  cg.local.set(q);
2517
2548
  cg.local.get(y);
2518
2549
  cg.local.get(qf);
@@ -3599,6 +3630,7 @@ var F32x4 = class extends V128 {
3599
3630
 
3600
3631
  //#endregion
3601
3632
  //#region src/backend/wasm.ts
3633
+ const moduleCache = /* @__PURE__ */ new Map();
3602
3634
  /** Backend that compiles into WebAssembly bytecode for immediate execution. */
3603
3635
  var WasmBackend = class {
3604
3636
  type = "wasm";
@@ -3654,8 +3686,11 @@ var WasmBackend = class {
3654
3686
  return this.prepareSync(kernel);
3655
3687
  }
3656
3688
  prepareSync(kernel) {
3657
- const bytes = codegenWasm(kernel);
3658
- const module$1 = new WebAssembly.Module(bytes);
3689
+ const kernelHash = FpHash.hash(kernel);
3690
+ const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => {
3691
+ const bytes = codegenWasm(kernel);
3692
+ return new WebAssembly.Module(bytes);
3693
+ });
3659
3694
  return new Executable(kernel, { module: module$1 });
3660
3695
  }
3661
3696
  dispatch(exe, inputs, outputs) {
@@ -3676,7 +3711,7 @@ function codegenWasm(kernel) {
3676
3711
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
3677
3712
  const cg = new CodeGenerator();
3678
3713
  cg.memory.import("env", "memory");
3679
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3714
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
3680
3715
  const funcs = {};
3681
3716
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3682
3717
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
@@ -3754,7 +3789,10 @@ function codegenWasm(kernel) {
3754
3789
  cg.br(1);
3755
3790
  cg.end();
3756
3791
  cg.end();
3757
- translateExp(cg, funcs, kernel.reduction.epilogue, { acc });
3792
+ translateExp(cg, funcs, tune.epilogue, {
3793
+ acc,
3794
+ gidx
3795
+ });
3758
3796
  } else translateExp(cg, funcs, tune.exp, { gidx });
3759
3797
  dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
3760
3798
  cg.local.get(gidx);
@@ -4003,7 +4041,7 @@ async function createBackend(device) {
4003
4041
  if (!navigator.gpu) return null;
4004
4042
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4005
4043
  if (!adapter) return null;
4006
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CcGP160M.cjs"));
4044
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVns4DbI.cjs"));
4007
4045
  const importantLimits = [
4008
4046
  "maxBufferSize",
4009
4047
  "maxComputeInvocationsPerWorkgroup",
@@ -4350,4 +4388,4 @@ Object.defineProperty(exports, 'zipn', {
4350
4388
  return zipn;
4351
4389
  }
4352
4390
  });
4353
- //# sourceMappingURL=backend-DeVfWEFS.cjs.map
4391
+ //# sourceMappingURL=backend-CmaidnkQ.cjs.map