@jax-js/jax 0.1.1 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -257,36 +257,12 @@ await devicePut(ar, "webgpu"); // Now device="webgpu"
257
257
  There are other libraries in the `@jax-js` namespace that can work with jax-js, or be used in a
258
258
  self-contained way in other projects.
259
259
 
260
- **`@jax-js/optax`** provides implementations of optimizers like Adam and SGD.
261
-
262
- ```ts
263
- import { adam } from "@jax-js/optax";
264
-
265
- let params = np.array([1.0, 2.0, 3.0]);
266
-
267
- const solver = adam(1e-3);
268
- let optState = solver.init(params.ref);
269
- let updates: np.Array;
270
-
271
- const f = (x: np.Array) => squaredError(x, np.ones([3])).sum();
272
-
273
- for (let i = 0; i < 100; i++) {
274
- const paramsGrad = grad(f)(params.ref);
275
- [updates, optState] = solver.update(paramsGrad, optState);
276
- params = applyUpdates(params, updates);
277
- }
278
- ```
279
-
280
- **`@jax-js/loaders`** can load tensors from various formats like Safetensors, includes a fast and
281
- compliant implementation of BPE, and caches HTTP requests for large assets like model weights in
282
- OPFS.
283
-
284
- ```ts
285
- import { tokenizers } from "@jax-js/loaders";
286
-
287
- const enc = await tokenizers.getBpe("clip");
288
- const tokens = enc.encode("Hello, world!"); // => [ 49406, 3306, 267, 1002, ... ]
289
- ```
260
+ - [**`@jax-js/loaders`**](packages/loaders) can load tensors from various formats like Safetensors,
261
+ includes a fast and compliant implementation of BPE, and caches HTTP requests for large assets
262
+ like model weights in OPFS.
263
+ - [**`@jax-js/onnx`**](packages/onnx) is a model loader from the [ONNX](https://onnx.ai/) format
264
+ into native jax-js functions.
265
+ - [**`@jax-js/optax`**](packages/optax) provides implementations of optimizers like Adam and SGD.
290
266
 
291
267
  ### Performance
292
268
 
@@ -311,6 +287,7 @@ If you make something cool with jax-js, don't be a stranger! We can feature it h
311
287
 
312
288
  - [Training neural networks on MNIST](https://jax-js.com/mnist)
313
289
  - [CLIP embeddings for books in-browser](https://jax-js.com/mobileclip)
290
+ - [Object detection with DETR ResNet-50 (ONNX)](https://jax-js.com/detr-resnet-50)
314
291
  - [In-browser REPL](https://jax-js.com/repl)
315
292
  - [Matmul benchmark](https://jax-js.com/bench/matmul)
316
293
  - [Conv2d benchmark](https://jax-js.com/bench/conv2d)
@@ -351,7 +328,9 @@ Contributions are welcomed! Especially in:
351
328
  - Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
352
329
  - Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
353
330
  and multithreading. (Even single-threaded Wasm could be ~20x faster.)
354
- - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches and adding
355
- epilogue to reductions.
331
+ - Adding support for `jax.profiling`, in particular the start and end trace functions. We should be
332
+ able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
333
+ to help with model performance debugging.
334
+ - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
356
335
  - Adding WebGL runtime for older browsers that don't support WebGPU.
357
336
  - Making a fast transformer inference engine, comparing against onnxruntime-web.
@@ -289,10 +289,11 @@ var FpHash = class FpHash {
289
289
  };
290
290
  /** Run a function while caching it inline inside a `Map`. */
291
291
  function runWithCache(cache, key, thunk) {
292
- if (cache.has(key)) return cache.get(key);
292
+ const keyStr = JSON.stringify(key);
293
+ if (cache.has(keyStr)) return cache.get(keyStr);
293
294
  else {
294
295
  const value = thunk();
295
- cache.set(key, value);
296
+ cache.set(keyStr, value);
296
297
  return value;
297
298
  }
298
299
  }
@@ -449,6 +450,14 @@ var AluExp = class AluExp {
449
450
  static sqrt(a) {
450
451
  return new AluExp(AluOp.Sqrt, a.dtype, [a]);
451
452
  }
453
+ static floor(a) {
454
+ if (!isFloatDtype(a.dtype)) return a;
455
+ return new AluExp(AluOp.Floor, a.dtype, [a]);
456
+ }
457
+ static ceil(a) {
458
+ if (!isFloatDtype(a.dtype)) return a;
459
+ return new AluExp(AluOp.Ceil, a.dtype, [a]);
460
+ }
452
461
  static reciprocal(a) {
453
462
  return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
454
463
  }
@@ -548,16 +557,16 @@ var AluExp = class AluExp {
548
557
  });
549
558
  }
550
559
  /** Reindex gid values in this expression as needed. */
551
- reindexGids(gidMap) {
560
+ reindexGids(newGids) {
552
561
  return this.rewrite((exp) => {
553
562
  if (exp.op === AluOp.GlobalIndex) {
554
563
  const [gid, len] = exp.arg;
555
- const newGid = gidMap.get(gid);
556
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
564
+ const newGid = newGids[gid];
565
+ if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
557
566
  } else if (exp.op === AluOp.GlobalView) {
558
567
  const gid = exp.arg[0];
559
- const newGid = gidMap.get(gid);
560
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
568
+ const newGid = newGids[gid];
569
+ if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
561
570
  }
562
571
  });
563
572
  }
@@ -629,6 +638,12 @@ var AluExp = class AluExp {
629
638
  case AluOp.Sqrt:
630
639
  ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
631
640
  break;
641
+ case AluOp.Floor:
642
+ ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
643
+ break;
644
+ case AluOp.Ceil:
645
+ ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
646
+ break;
632
647
  case AluOp.Reciprocal:
633
648
  if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
634
649
  ret = [1 / src[0].max, 1 / src[0].min];
@@ -765,7 +780,7 @@ var AluExp = class AluExp {
765
780
  if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
766
781
  if (op === AluOp.Mul && x === 1) return src[1 - i];
767
782
  if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
768
- if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
783
+ if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
769
784
  if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
770
785
  }
771
786
  if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
@@ -861,7 +876,7 @@ var AluExp = class AluExp {
861
876
  else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
862
877
  if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
863
878
  }
864
- if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
879
+ if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
865
880
  const [x, y] = src;
866
881
  {
867
882
  const factors = [];
@@ -951,6 +966,8 @@ var AluExp = class AluExp {
951
966
  case AluOp.Erf: return erf(x);
952
967
  case AluOp.Erfc: return erfc(x);
953
968
  case AluOp.Sqrt: return Math.sqrt(x);
969
+ case AluOp.Floor: return Math.floor(x);
970
+ case AluOp.Ceil: return Math.ceil(x);
954
971
  case AluOp.Reciprocal: return 1 / x;
955
972
  case AluOp.Cast: {
956
973
  const wasFloat = isFloatDtype(this.src[0].dtype);
@@ -1140,6 +1157,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
1140
1157
  AluOp$1["Erf"] = "Erf";
1141
1158
  AluOp$1["Erfc"] = "Erfc";
1142
1159
  AluOp$1["Sqrt"] = "Sqrt";
1160
+ AluOp$1["Floor"] = "Floor";
1161
+ AluOp$1["Ceil"] = "Ceil";
1143
1162
  AluOp$1["Reciprocal"] = "Reciprocal";
1144
1163
  AluOp$1["Cast"] = "Cast";
1145
1164
  AluOp$1["Bitcast"] = "Bitcast";
@@ -1174,6 +1193,8 @@ const AluGroup = {
1174
1193
  AluOp.Erf,
1175
1194
  AluOp.Erfc,
1176
1195
  AluOp.Sqrt,
1196
+ AluOp.Floor,
1197
+ AluOp.Ceil,
1177
1198
  AluOp.Reciprocal,
1178
1199
  AluOp.Cast,
1179
1200
  AluOp.Bitcast
@@ -1201,7 +1222,9 @@ const AluGroup = {
1201
1222
  AluOp.Erf,
1202
1223
  AluOp.Erfc,
1203
1224
  AluOp.Sqrt,
1204
- AluOp.Reciprocal
1225
+ AluOp.Reciprocal,
1226
+ AluOp.Floor,
1227
+ AluOp.Ceil
1205
1228
  ])
1206
1229
  };
1207
1230
  /** Common variables that can be substituted in expressions. */
@@ -2043,7 +2066,8 @@ function tuneNullopt(kernel) {
2043
2066
  if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2044
2067
  return {
2045
2068
  exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
2046
- outputIdxExp: AluExp.special(DType.Int32, "gidx", kernel.size),
2069
+ epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
2070
+ outputIdxExp: vars.gidx,
2047
2071
  threadCount: kernel.size,
2048
2072
  size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
2049
2073
  };
@@ -2076,7 +2100,11 @@ function tuneWebgpu(kernel) {
2076
2100
  while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
2077
2101
  const choices = [];
2078
2102
  const composedSts = sts.map((st) => st.compose(dim.st));
2079
- for (let axis = 0; axis < dim.groups; axis++) for (const amount of [3, 4]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
2103
+ for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
2104
+ 3,
2105
+ 4,
2106
+ 5
2107
+ ]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
2080
2108
  let nonzeroStrides = 0;
2081
2109
  let totalStrides = 0;
2082
2110
  for (const st of composedSts) {
@@ -2152,7 +2180,15 @@ function tuneWebgpu(kernel) {
2152
2180
  });
2153
2181
  const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
2154
2182
  const outputUpcast = dim.outputSt.shape.slice(dim.groups);
2155
- const [outputIdxExp, _] = dim.outputSt.toAluExp([...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)]);
2183
+ const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
2184
+ const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
2185
+ const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
2186
+ if (exp$1.op === AluOp.GlobalView) {
2187
+ const gid = exp$1.arg[0];
2188
+ const st = exp$1.arg[1];
2189
+ return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
2190
+ }
2191
+ });
2156
2192
  if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
2157
2193
  const size = {
2158
2194
  groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
@@ -2162,6 +2198,7 @@ function tuneWebgpu(kernel) {
2162
2198
  };
2163
2199
  return {
2164
2200
  exp: newExp.simplify(),
2201
+ epilogue: newEpilogue.simplify(),
2165
2202
  outputIdxExp: outputIdxExp.simplify(),
2166
2203
  threadCount: kernel.size / size.upcast * size.groups,
2167
2204
  size
@@ -2220,10 +2257,10 @@ var CpuBackend = class {
2220
2257
  return new Executable(kernel, void 0);
2221
2258
  }
2222
2259
  dispatch({ kernel }, inputs, outputs) {
2223
- const { exp } = tuneNullopt(kernel);
2260
+ const { exp, epilogue } = tuneNullopt(kernel);
2224
2261
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
2225
2262
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
2226
- const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
2263
+ const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
2227
2264
  const inputArrays = inputBuffers.map((buf, i) => {
2228
2265
  const dtype = usedArgs.get(i);
2229
2266
  if (!dtype) return null;
@@ -2245,7 +2282,10 @@ var CpuBackend = class {
2245
2282
  }, globals);
2246
2283
  acc = kernel.reduction.evaluate(acc, item);
2247
2284
  }
2248
- outputArray[i] = kernel.reduction.epilogue.evaluate({ acc });
2285
+ outputArray[i] = epilogue.evaluate({
2286
+ acc,
2287
+ gidx: i
2288
+ }, globals);
2249
2289
  }
2250
2290
  }
2251
2291
  #getBuffer(slot) {
@@ -2408,7 +2448,7 @@ function wasm_log(cg) {
2408
2448
  const t2 = cg.local.declare(cg.f32);
2409
2449
  cg.local.get(0);
2410
2450
  cg.f32.const(0);
2411
- cg.f32.le();
2451
+ cg.f32.lt();
2412
2452
  cg.if(cg.void);
2413
2453
  cg.f32.const(NaN);
2414
2454
  cg.return();
@@ -2423,6 +2463,20 @@ function wasm_log(cg) {
2423
2463
  cg.i32.const(127);
2424
2464
  cg.i32.sub();
2425
2465
  cg.local.set(e);
2466
+ cg.local.get(e);
2467
+ cg.i32.const(-127);
2468
+ cg.i32.eq();
2469
+ cg.if(cg.void);
2470
+ cg.f32.const(-Infinity);
2471
+ cg.return();
2472
+ cg.end();
2473
+ cg.local.get(e);
2474
+ cg.i32.const(128);
2475
+ cg.i32.eq();
2476
+ cg.if(cg.void);
2477
+ cg.local.get(0);
2478
+ cg.return();
2479
+ cg.end();
2426
2480
  cg.local.get(bits);
2427
2481
  cg.i32.const(8388607);
2428
2482
  cg.i32.and();
@@ -2488,7 +2542,7 @@ function _sincos(cg) {
2488
2542
  cg.f32.mul();
2489
2543
  cg.f32.nearest();
2490
2544
  cg.local.tee(qf);
2491
- cg.i32.trunc_f32_s();
2545
+ cg.i32.trunc_sat_f32_s();
2492
2546
  cg.local.set(q);
2493
2547
  cg.local.get(y);
2494
2548
  cg.local.get(qf);
@@ -3575,6 +3629,7 @@ var F32x4 = class extends V128 {
3575
3629
 
3576
3630
  //#endregion
3577
3631
  //#region src/backend/wasm.ts
3632
+ const moduleCache = /* @__PURE__ */ new Map();
3578
3633
  /** Backend that compiles into WebAssembly bytecode for immediate execution. */
3579
3634
  var WasmBackend = class {
3580
3635
  type = "wasm";
@@ -3630,8 +3685,11 @@ var WasmBackend = class {
3630
3685
  return this.prepareSync(kernel);
3631
3686
  }
3632
3687
  prepareSync(kernel) {
3633
- const bytes = codegenWasm(kernel);
3634
- const module = new WebAssembly.Module(bytes);
3688
+ const kernelHash = FpHash.hash(kernel);
3689
+ const module = runWithCache(moduleCache, kernelHash.toString(), () => {
3690
+ const bytes = codegenWasm(kernel);
3691
+ return new WebAssembly.Module(bytes);
3692
+ });
3635
3693
  return new Executable(kernel, { module });
3636
3694
  }
3637
3695
  dispatch(exe, inputs, outputs) {
@@ -3652,7 +3710,7 @@ function codegenWasm(kernel) {
3652
3710
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
3653
3711
  const cg = new CodeGenerator();
3654
3712
  cg.memory.import("env", "memory");
3655
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3713
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
3656
3714
  const funcs = {};
3657
3715
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3658
3716
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
@@ -3730,7 +3788,10 @@ function codegenWasm(kernel) {
3730
3788
  cg.br(1);
3731
3789
  cg.end();
3732
3790
  cg.end();
3733
- translateExp(cg, funcs, kernel.reduction.epilogue, { acc });
3791
+ translateExp(cg, funcs, tune.epilogue, {
3792
+ acc,
3793
+ gidx
3794
+ });
3734
3795
  } else translateExp(cg, funcs, tune.exp, { gidx });
3735
3796
  dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
3736
3797
  cg.local.get(gidx);
@@ -3831,7 +3892,9 @@ function translateExp(cg, funcs, exp, ctx) {
3831
3892
  else if (op === AluOp.Reciprocal) {
3832
3893
  const dt = dtyF(cg, op, dtype);
3833
3894
  dt.const(1), gen(src[0]), dt.div();
3834
- } else if (op === AluOp.Cast) {
3895
+ } else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
3896
+ else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
3897
+ else if (op === AluOp.Cast) {
3835
3898
  gen(src[0]);
3836
3899
  const dtype0 = src[0].dtype;
3837
3900
  const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
@@ -3977,7 +4040,7 @@ async function createBackend(device) {
3977
4040
  if (!navigator.gpu) return null;
3978
4041
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3979
4042
  if (!adapter) return null;
3980
- const { WebGPUBackend } = await import("./webgpu-B3UVme6n.js");
4043
+ const { WebGPUBackend } = await import("./webgpu-C9iAP5h5.js");
3981
4044
  const importantLimits = [
3982
4045
  "maxBufferSize",
3983
4046
  "maxComputeInvocationsPerWorkgroup",
@@ -4031,4 +4094,4 @@ var UnsupportedOpError = class extends Error {
4031
4094
 
4032
4095
  //#endregion
4033
4096
  export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
4034
- //# sourceMappingURL=backend-CoVtc9dx.js.map
4097
+ //# sourceMappingURL=backend-BY8wlLEl.js.map
@@ -290,10 +290,11 @@ var FpHash = class FpHash {
290
290
  };
291
291
  /** Run a function while caching it inline inside a `Map`. */
292
292
  function runWithCache(cache, key, thunk) {
293
- if (cache.has(key)) return cache.get(key);
293
+ const keyStr = JSON.stringify(key);
294
+ if (cache.has(keyStr)) return cache.get(keyStr);
294
295
  else {
295
296
  const value = thunk();
296
- cache.set(key, value);
297
+ cache.set(keyStr, value);
297
298
  return value;
298
299
  }
299
300
  }
@@ -450,6 +451,14 @@ var AluExp = class AluExp {
450
451
  static sqrt(a) {
451
452
  return new AluExp(AluOp.Sqrt, a.dtype, [a]);
452
453
  }
454
+ static floor(a) {
455
+ if (!isFloatDtype(a.dtype)) return a;
456
+ return new AluExp(AluOp.Floor, a.dtype, [a]);
457
+ }
458
+ static ceil(a) {
459
+ if (!isFloatDtype(a.dtype)) return a;
460
+ return new AluExp(AluOp.Ceil, a.dtype, [a]);
461
+ }
453
462
  static reciprocal(a) {
454
463
  return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
455
464
  }
@@ -549,16 +558,16 @@ var AluExp = class AluExp {
549
558
  });
550
559
  }
551
560
  /** Reindex gid values in this expression as needed. */
552
- reindexGids(gidMap) {
561
+ reindexGids(newGids) {
553
562
  return this.rewrite((exp) => {
554
563
  if (exp.op === AluOp.GlobalIndex) {
555
564
  const [gid, len] = exp.arg;
556
- const newGid = gidMap.get(gid);
557
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
565
+ const newGid = newGids[gid];
566
+ if (newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, len, exp.src[0]);
558
567
  } else if (exp.op === AluOp.GlobalView) {
559
568
  const gid = exp.arg[0];
560
- const newGid = gidMap.get(gid);
561
- if (newGid !== void 0 && newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
569
+ const newGid = newGids[gid];
570
+ if (newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
562
571
  }
563
572
  });
564
573
  }
@@ -630,6 +639,12 @@ var AluExp = class AluExp {
630
639
  case AluOp.Sqrt:
631
640
  ret = [Math.sqrt(src[0].min), Math.sqrt(src[0].max)];
632
641
  break;
642
+ case AluOp.Floor:
643
+ ret = [Math.floor(src[0].min), Math.floor(src[0].max)];
644
+ break;
645
+ case AluOp.Ceil:
646
+ ret = [Math.ceil(src[0].min), Math.ceil(src[0].max)];
647
+ break;
633
648
  case AluOp.Reciprocal:
634
649
  if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
635
650
  ret = [1 / src[0].max, 1 / src[0].min];
@@ -766,7 +781,7 @@ var AluExp = class AluExp {
766
781
  if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
767
782
  if (op === AluOp.Mul && x === 1) return src[1 - i];
768
783
  if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
769
- if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
784
+ if (op === AluOp.Idiv && i === 1 && x === 1 && !isFloatDtype(this.dtype)) return src[1 - i];
770
785
  if (op === AluOp.Cmpne && src[i].dtype === DType.Bool && x === 0) return src[1 - i];
771
786
  }
772
787
  if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
@@ -862,7 +877,7 @@ var AluExp = class AluExp {
862
877
  else return p(p(src[0].src[0], src[1]), src[0].src[1]).simplify(cache);
863
878
  if (src[1].op === op && src[1].src[1].op === AluOp.Const) return p(p(src[0], src[1].src[0]), src[1].src[1]).simplify(cache);
864
879
  }
865
- if (op === AluOp.Mod || op === AluOp.Idiv && src[1].#isConstInt()) {
880
+ if ((op === AluOp.Mod || op === AluOp.Idiv) && src[1].#isConstInt()) {
866
881
  const [x, y] = src;
867
882
  {
868
883
  const factors = [];
@@ -952,6 +967,8 @@ var AluExp = class AluExp {
952
967
  case AluOp.Erf: return erf(x);
953
968
  case AluOp.Erfc: return erfc(x);
954
969
  case AluOp.Sqrt: return Math.sqrt(x);
970
+ case AluOp.Floor: return Math.floor(x);
971
+ case AluOp.Ceil: return Math.ceil(x);
955
972
  case AluOp.Reciprocal: return 1 / x;
956
973
  case AluOp.Cast: {
957
974
  const wasFloat = isFloatDtype(this.src[0].dtype);
@@ -1141,6 +1158,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
1141
1158
  AluOp$1["Erf"] = "Erf";
1142
1159
  AluOp$1["Erfc"] = "Erfc";
1143
1160
  AluOp$1["Sqrt"] = "Sqrt";
1161
+ AluOp$1["Floor"] = "Floor";
1162
+ AluOp$1["Ceil"] = "Ceil";
1144
1163
  AluOp$1["Reciprocal"] = "Reciprocal";
1145
1164
  AluOp$1["Cast"] = "Cast";
1146
1165
  AluOp$1["Bitcast"] = "Bitcast";
@@ -1175,6 +1194,8 @@ const AluGroup = {
1175
1194
  AluOp.Erf,
1176
1195
  AluOp.Erfc,
1177
1196
  AluOp.Sqrt,
1197
+ AluOp.Floor,
1198
+ AluOp.Ceil,
1178
1199
  AluOp.Reciprocal,
1179
1200
  AluOp.Cast,
1180
1201
  AluOp.Bitcast
@@ -1202,7 +1223,9 @@ const AluGroup = {
1202
1223
  AluOp.Erf,
1203
1224
  AluOp.Erfc,
1204
1225
  AluOp.Sqrt,
1205
- AluOp.Reciprocal
1226
+ AluOp.Reciprocal,
1227
+ AluOp.Floor,
1228
+ AluOp.Ceil
1206
1229
  ])
1207
1230
  };
1208
1231
  /** Common variables that can be substituted in expressions. */
@@ -2044,7 +2067,8 @@ function tuneNullopt(kernel) {
2044
2067
  if (kernel.reduction) vars.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
2045
2068
  return {
2046
2069
  exp: kernel.exp.substitute(vars).rewriteGlobalViews().simplify(),
2047
- outputIdxExp: AluExp.special(DType.Int32, "gidx", kernel.size),
2070
+ epilogue: kernel.reduction?.epilogue.substitute({ gidx: vars.gidx }).rewriteGlobalViews().simplify(),
2071
+ outputIdxExp: vars.gidx,
2048
2072
  threadCount: kernel.size,
2049
2073
  size: { reduce: kernel.reduction ? kernel.reduction.size : 0 }
2050
2074
  };
@@ -2077,7 +2101,11 @@ function tuneWebgpu(kernel) {
2077
2101
  while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
2078
2102
  const choices = [];
2079
2103
  const composedSts = sts.map((st) => st.compose(dim.st));
2080
- for (let axis = 0; axis < dim.groups; axis++) for (const amount of [3, 4]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
2104
+ for (let axis = 0; axis < dim.groups; axis++) for (const amount of [
2105
+ 3,
2106
+ 4,
2107
+ 5
2108
+ ]) if (!upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0))) {
2081
2109
  let nonzeroStrides = 0;
2082
2110
  let totalStrides = 0;
2083
2111
  for (const st of composedSts) {
@@ -2153,7 +2181,15 @@ function tuneWebgpu(kernel) {
2153
2181
  });
2154
2182
  const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
2155
2183
  const outputUpcast = dim.outputSt.shape.slice(dim.groups);
2156
- const [outputIdxExp, _] = dim.outputSt.toAluExp([...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)]);
2184
+ const outputIndices = [...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)];
2185
+ const [outputIdxExp, _] = dim.outputSt.toAluExp(outputIndices);
2186
+ const newEpilogue = reduction.epilogue.rewrite((exp$1) => {
2187
+ if (exp$1.op === AluOp.GlobalView) {
2188
+ const gid = exp$1.arg[0];
2189
+ const st = exp$1.arg[1];
2190
+ return accessorGlobal(exp$1.dtype, gid, st.compose(dim.outputSt), outputIndices);
2191
+ }
2192
+ });
2157
2193
  if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
2158
2194
  const size = {
2159
2195
  groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
@@ -2163,6 +2199,7 @@ function tuneWebgpu(kernel) {
2163
2199
  };
2164
2200
  return {
2165
2201
  exp: newExp.simplify(),
2202
+ epilogue: newEpilogue.simplify(),
2166
2203
  outputIdxExp: outputIdxExp.simplify(),
2167
2204
  threadCount: kernel.size / size.upcast * size.groups,
2168
2205
  size
@@ -2221,10 +2258,10 @@ var CpuBackend = class {
2221
2258
  return new Executable(kernel, void 0);
2222
2259
  }
2223
2260
  dispatch({ kernel }, inputs, outputs) {
2224
- const { exp } = tuneNullopt(kernel);
2261
+ const { exp, epilogue } = tuneNullopt(kernel);
2225
2262
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
2226
2263
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
2227
- const usedArgs = new Map(exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex).map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
2264
+ const usedArgs = new Map([...exp.collect((exp$1) => exp$1.op === AluOp.GlobalIndex), ...epilogue ? epilogue.collect((exp$1) => exp$1.op === AluOp.GlobalIndex) : []].map((exp$1) => [exp$1.arg[0], exp$1.dtype]));
2228
2265
  const inputArrays = inputBuffers.map((buf, i) => {
2229
2266
  const dtype = usedArgs.get(i);
2230
2267
  if (!dtype) return null;
@@ -2246,7 +2283,10 @@ var CpuBackend = class {
2246
2283
  }, globals);
2247
2284
  acc = kernel.reduction.evaluate(acc, item);
2248
2285
  }
2249
- outputArray[i] = kernel.reduction.epilogue.evaluate({ acc });
2286
+ outputArray[i] = epilogue.evaluate({
2287
+ acc,
2288
+ gidx: i
2289
+ }, globals);
2250
2290
  }
2251
2291
  }
2252
2292
  #getBuffer(slot) {
@@ -2409,7 +2449,7 @@ function wasm_log(cg) {
2409
2449
  const t2 = cg.local.declare(cg.f32);
2410
2450
  cg.local.get(0);
2411
2451
  cg.f32.const(0);
2412
- cg.f32.le();
2452
+ cg.f32.lt();
2413
2453
  cg.if(cg.void);
2414
2454
  cg.f32.const(NaN);
2415
2455
  cg.return();
@@ -2424,6 +2464,20 @@ function wasm_log(cg) {
2424
2464
  cg.i32.const(127);
2425
2465
  cg.i32.sub();
2426
2466
  cg.local.set(e);
2467
+ cg.local.get(e);
2468
+ cg.i32.const(-127);
2469
+ cg.i32.eq();
2470
+ cg.if(cg.void);
2471
+ cg.f32.const(-Infinity);
2472
+ cg.return();
2473
+ cg.end();
2474
+ cg.local.get(e);
2475
+ cg.i32.const(128);
2476
+ cg.i32.eq();
2477
+ cg.if(cg.void);
2478
+ cg.local.get(0);
2479
+ cg.return();
2480
+ cg.end();
2427
2481
  cg.local.get(bits);
2428
2482
  cg.i32.const(8388607);
2429
2483
  cg.i32.and();
@@ -2489,7 +2543,7 @@ function _sincos(cg) {
2489
2543
  cg.f32.mul();
2490
2544
  cg.f32.nearest();
2491
2545
  cg.local.tee(qf);
2492
- cg.i32.trunc_f32_s();
2546
+ cg.i32.trunc_sat_f32_s();
2493
2547
  cg.local.set(q);
2494
2548
  cg.local.get(y);
2495
2549
  cg.local.get(qf);
@@ -3576,6 +3630,7 @@ var F32x4 = class extends V128 {
3576
3630
 
3577
3631
  //#endregion
3578
3632
  //#region src/backend/wasm.ts
3633
+ const moduleCache = /* @__PURE__ */ new Map();
3579
3634
  /** Backend that compiles into WebAssembly bytecode for immediate execution. */
3580
3635
  var WasmBackend = class {
3581
3636
  type = "wasm";
@@ -3631,8 +3686,11 @@ var WasmBackend = class {
3631
3686
  return this.prepareSync(kernel);
3632
3687
  }
3633
3688
  prepareSync(kernel) {
3634
- const bytes = codegenWasm(kernel);
3635
- const module$1 = new WebAssembly.Module(bytes);
3689
+ const kernelHash = FpHash.hash(kernel);
3690
+ const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => {
3691
+ const bytes = codegenWasm(kernel);
3692
+ return new WebAssembly.Module(bytes);
3693
+ });
3636
3694
  return new Executable(kernel, { module: module$1 });
3637
3695
  }
3638
3696
  dispatch(exe, inputs, outputs) {
@@ -3653,7 +3711,7 @@ function codegenWasm(kernel) {
3653
3711
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
3654
3712
  const cg = new CodeGenerator();
3655
3713
  cg.memory.import("env", "memory");
3656
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
3714
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
3657
3715
  const funcs = {};
3658
3716
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3659
3717
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
@@ -3731,7 +3789,10 @@ function codegenWasm(kernel) {
3731
3789
  cg.br(1);
3732
3790
  cg.end();
3733
3791
  cg.end();
3734
- translateExp(cg, funcs, kernel.reduction.epilogue, { acc });
3792
+ translateExp(cg, funcs, tune.epilogue, {
3793
+ acc,
3794
+ gidx
3795
+ });
3735
3796
  } else translateExp(cg, funcs, tune.exp, { gidx });
3736
3797
  dty(cg, null, kernel.dtype).store(Math.log2(byteWidth(kernel.dtype)));
3737
3798
  cg.local.get(gidx);
@@ -3832,7 +3893,9 @@ function translateExp(cg, funcs, exp, ctx) {
3832
3893
  else if (op === AluOp.Reciprocal) {
3833
3894
  const dt = dtyF(cg, op, dtype);
3834
3895
  dt.const(1), gen(src[0]), dt.div();
3835
- } else if (op === AluOp.Cast) {
3896
+ } else if (op === AluOp.Floor) gen(src[0]), dtyF(cg, op, dtype).floor();
3897
+ else if (op === AluOp.Ceil) gen(src[0]), dtyF(cg, op, dtype).ceil();
3898
+ else if (op === AluOp.Cast) {
3836
3899
  gen(src[0]);
3837
3900
  const dtype0 = src[0].dtype;
3838
3901
  const i32repr = dtype0 === DType.Int32 || dtype0 === DType.Uint32 || dtype0 === DType.Bool;
@@ -3978,7 +4041,7 @@ async function createBackend(device) {
3978
4041
  if (!navigator.gpu) return null;
3979
4042
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3980
4043
  if (!adapter) return null;
3981
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-DGYNVHma.cjs"));
4044
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVns4DbI.cjs"));
3982
4045
  const importantLimits = [
3983
4046
  "maxBufferSize",
3984
4047
  "maxComputeInvocationsPerWorkgroup",
@@ -4325,4 +4388,4 @@ Object.defineProperty(exports, 'zipn', {
4325
4388
  return zipn;
4326
4389
  }
4327
4390
  });
4328
- //# sourceMappingURL=backend-BbrKEB18.cjs.map
4391
+ //# sourceMappingURL=backend-CmaidnkQ.cjs.map