@jax-js/jax 0.1.2 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -8,9 +8,9 @@ var __hasOwnProp = Object.prototype.hasOwnProperty;
8
8
  var __commonJS = (cb, mod$1) => function() {
9
9
  return mod$1 || (0, cb[__getOwnPropNames(cb)[0]])((mod$1 = { exports: {} }).exports, mod$1), mod$1.exports;
10
10
  };
11
- var __export = (target, all) => {
12
- for (var name in all) __defProp(target, name, {
13
- get: all[name],
11
+ var __export = (target, all$1) => {
12
+ for (var name in all$1) __defProp(target, name, {
13
+ get: all$1[name],
14
14
  enumerable: true
15
15
  });
16
16
  };
@@ -30,30 +30,38 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-DeVfWEFS.cjs');
33
+ const require_backend = require('./backend-Bu9GY6sK.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
37
37
  * Check that the shapes and parameters passed to convolution are valid.
38
+ * Expected shapes of the lhs and rhs of the convolution are:
39
+ *
40
+ * - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
41
+ * - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
38
42
  *
39
43
  * If the check succeeds, returns the output shape.
40
44
  */
41
- function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
45
+ function checkConvShape(lhsShape, rhsShape, { vmapDims, strides, padding, lhsDilation, rhsDilation }) {
42
46
  if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
43
- const n = lhsShape.length - 2;
47
+ const n = lhsShape.length - 2 - vmapDims;
44
48
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
45
49
  if (strides.length !== n) throw new Error("conv() strides != spatial dims");
46
50
  if (padding.length !== n) throw new Error("conv() padding != spatial dims");
47
51
  if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
48
52
  if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
49
- if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
50
- const outShape = [lhsShape[0], rhsShape[0]];
53
+ if (lhsShape[vmapDims + 1] !== rhsShape[vmapDims + 1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
54
+ const outShape = [
55
+ ...require_backend.generalBroadcast(lhsShape.slice(0, vmapDims), rhsShape.slice(0, vmapDims)),
56
+ lhsShape[vmapDims],
57
+ rhsShape[vmapDims]
58
+ ];
51
59
  for (let i = 0; i < n; i++) {
52
60
  if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
53
61
  if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
54
62
  if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
55
63
  if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
56
- const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
64
+ const [x, k] = [lhsShape[i + vmapDims + 2], rhsShape[i + vmapDims + 2]];
57
65
  if (k <= 0) throw new Error("conv() kernel size must be positive");
58
66
  const [pl, pr] = padding[i];
59
67
  if (pl < -x || pr < -x || pl + pr < -x) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${x}`);
@@ -178,27 +186,13 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
178
186
  function applyDilation(st, dilation) {
179
187
  if (dilation.every((s) => s === 1)) return st;
180
188
  const s_ = dilation;
181
- const [a, b, ...k_] = st.shape;
182
- st = st.reshape([
183
- a,
184
- b,
185
- ...k_.flatMap((k) => [k, 1])
186
- ]);
187
- st = st.pad([
188
- [0, 0],
189
- [0, 0],
190
- ...s_.flatMap((s) => [[0, 0], [0, s - 1]])
191
- ]);
192
- st = st.reshape([
193
- a,
194
- b,
195
- ...k_.map((k, i) => k * s_[i])
196
- ]);
197
- st = st.shrink([
198
- [0, a],
199
- [0, b],
200
- ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])
201
- ]);
189
+ const n = s_.length;
190
+ const prefix = st.shape.slice(0, -n);
191
+ const k_ = st.shape.slice(-n);
192
+ st = st.reshape([...prefix, ...k_.flatMap((k) => [k, 1])]);
193
+ st = st.pad([...prefix.map(() => [0, 0]), ...s_.flatMap((s) => [[0, 0], [0, s - 1]])]);
194
+ st = st.reshape([...prefix, ...k_.map((k, i) => k * s_[i])]);
195
+ st = st.shrink([...prefix.map((p) => [0, p]), ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])]);
202
196
  return st;
203
197
  }
204
198
  /**
@@ -208,25 +202,26 @@ function applyDilation(st, dilation) {
208
202
  * beforehand using `checkConvShape()`.
209
203
  */
210
204
  function prepareConv(stX, stY, params) {
211
- const n = stX.shape.length - 2;
205
+ const v = params.vmapDims;
206
+ const n = stX.shape.length - 2 - v;
207
+ const vmapShape = stX.shape.slice(0, v);
212
208
  stX = applyDilation(stX, params.lhsDilation);
213
- const ks = stY.shape.slice(2);
214
- stX = stX.padOrShrink([
215
- [0, 0],
216
- [0, 0],
217
- ...params.padding
218
- ]);
209
+ const ks = stY.shape.slice(v + 2);
210
+ stX = stX.padOrShrink([...require_backend.rep(v + 2, [0, 0]), ...params.padding]);
219
211
  stX = pool(stX, ks, params.strides, params.rhsDilation);
220
- stX = stX.moveaxis(1, n + 1).reshape([
221
- stX.shape[0],
212
+ stX = stX.moveaxis(v + 1, v + n + 1).reshape([
213
+ ...vmapShape,
214
+ stX.shape[v],
222
215
  1,
223
- ...stX.shape.slice(2, n + 2),
224
- stX.shape[1] * require_backend.prod(ks)
216
+ ...stX.shape.slice(v + 2, v + n + 2),
217
+ stX.shape[v + 1] * require_backend.prod(ks)
225
218
  ]);
226
219
  stY = stY.reshape([
227
- stY.shape[0],
220
+ ...vmapShape,
221
+ 1,
222
+ stY.shape[v],
228
223
  ...require_backend.rep(n, 1),
229
- stY.shape[1] * require_backend.prod(ks)
224
+ stY.shape[v + 1] * require_backend.prod(ks)
230
225
  ]);
231
226
  return [stX, stY];
232
227
  }
@@ -367,6 +362,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
367
362
  Primitive$1["Mul"] = "mul";
368
363
  Primitive$1["Idiv"] = "idiv";
369
364
  Primitive$1["Mod"] = "mod";
365
+ Primitive$1["Min"] = "min";
366
+ Primitive$1["Max"] = "max";
370
367
  Primitive$1["Neg"] = "neg";
371
368
  Primitive$1["Reciprocal"] = "reciprocal";
372
369
  Primitive$1["Floor"] = "floor";
@@ -374,7 +371,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
374
371
  Primitive$1["StopGradient"] = "stop_gradient";
375
372
  Primitive$1["Cast"] = "cast";
376
373
  Primitive$1["Bitcast"] = "bitcast";
377
- Primitive$1["RandomBits"] = "random_bits";
378
374
  Primitive$1["Sin"] = "sin";
379
375
  Primitive$1["Cos"] = "cos";
380
376
  Primitive$1["Asin"] = "asin";
@@ -384,8 +380,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
384
380
  Primitive$1["Erf"] = "erf";
385
381
  Primitive$1["Erfc"] = "erfc";
386
382
  Primitive$1["Sqrt"] = "sqrt";
387
- Primitive$1["Min"] = "min";
388
- Primitive$1["Max"] = "max";
389
383
  Primitive$1["Reduce"] = "reduce";
390
384
  Primitive$1["Dot"] = "dot";
391
385
  Primitive$1["Conv"] = "conv";
@@ -393,14 +387,19 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
393
387
  Primitive$1["PoolTranspose"] = "pool_transpose";
394
388
  Primitive$1["Compare"] = "compare";
395
389
  Primitive$1["Where"] = "where";
390
+ Primitive$1["RandomBits"] = "random_bits";
391
+ Primitive$1["Gather"] = "gather";
396
392
  Primitive$1["Transpose"] = "transpose";
397
393
  Primitive$1["Broadcast"] = "broadcast";
398
394
  Primitive$1["Reshape"] = "reshape";
399
395
  Primitive$1["Flip"] = "flip";
400
396
  Primitive$1["Shrink"] = "shrink";
401
397
  Primitive$1["Pad"] = "pad";
402
- Primitive$1["Gather"] = "gather";
403
- Primitive$1["JitCall"] = "jit_call";
398
+ Primitive$1["Sort"] = "sort";
399
+ Primitive$1["Argsort"] = "argsort";
400
+ Primitive$1["TriangularSolve"] = "triangular_solve";
401
+ Primitive$1["Cholesky"] = "cholesky";
402
+ Primitive$1["Jit"] = "jit";
404
403
  return Primitive$1;
405
404
  }({});
406
405
  let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
@@ -422,6 +421,12 @@ function idiv(x, y) {
422
421
  function mod(x, y) {
423
422
  return bind1(Primitive.Mod, [x, y]);
424
423
  }
424
+ function min$1(x, y) {
425
+ return bind1(Primitive.Min, [x, y]);
426
+ }
427
+ function max$1(x, y) {
428
+ return bind1(Primitive.Max, [x, y]);
429
+ }
425
430
  function neg(x) {
426
431
  return bind1(Primitive.Neg, [x]);
427
432
  }
@@ -443,12 +448,6 @@ function cast(x, dtype) {
443
448
  function bitcast(x, dtype) {
444
449
  return bind1(Primitive.Bitcast, [x], { dtype });
445
450
  }
446
- function randomBits(k0, k1, shape$1, mode = "xor") {
447
- return bind1(Primitive.RandomBits, [k0, k1], {
448
- shape: shape$1,
449
- mode
450
- });
451
- }
452
451
  function sin$1(x) {
453
452
  return bind1(Primitive.Sin, [x]);
454
453
  }
@@ -476,12 +475,6 @@ function erfc$1(x) {
476
475
  function sqrt$1(x) {
477
476
  return bind1(Primitive.Sqrt, [x]);
478
477
  }
479
- function min$1(x, y) {
480
- return bind1(Primitive.Min, [x, y]);
481
- }
482
- function max$1(x, y) {
483
- return bind1(Primitive.Max, [x, y]);
484
- }
485
478
  function reduce(x, op, axis = null, opts) {
486
479
  if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
487
480
  axis = require_backend.normalizeAxis(axis, ndim$1(x));
@@ -498,9 +491,11 @@ function dot$2(x, y) {
498
491
  }
499
492
  function conv$1(x, y, params = {}) {
500
493
  if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
501
- const n = x.ndim - 2;
494
+ const vmapDims = params.vmapDims ?? 0;
495
+ const n = x.ndim - 2 - vmapDims;
502
496
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
503
497
  return bind1(Primitive.Conv, [x, y], {
498
+ vmapDims,
504
499
  strides: params.strides ?? require_backend.rep(n, 1),
505
500
  padding: params.padding ?? require_backend.rep(n, [0, 0]),
506
501
  lhsDilation: params.lhsDilation ?? require_backend.rep(n, 1),
@@ -535,6 +530,23 @@ function where$1(cond, x, y) {
535
530
  y
536
531
  ]);
537
532
  }
533
+ function randomBits(k0, k1, shape$1, mode = "xor") {
534
+ return bind1(Primitive.RandomBits, [k0, k1], {
535
+ shape: shape$1,
536
+ mode
537
+ });
538
+ }
539
+ function gather(x, indices, axis, outDim) {
540
+ if (indices.length === 0) throw new Error("gather() requires at least one index");
541
+ if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
542
+ axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
543
+ if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
544
+ outDim = require_backend.checkAxis(outDim, ndim$1(x) - axis.length + 1);
545
+ return bind1(Primitive.Gather, [x, ...indices], {
546
+ axis,
547
+ outDim
548
+ });
549
+ }
538
550
  function transpose$1(x, perm) {
539
551
  perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
540
552
  if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
@@ -584,16 +596,27 @@ function pad$1(x, width) {
584
596
  } else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
585
597
  return bind1(Primitive.Pad, [x], { width });
586
598
  }
587
- function gather(x, indices, axis, outDim) {
588
- if (indices.length === 0) throw new Error("gather() requires at least one index");
589
- if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
590
- axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
591
- if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
592
- outDim = require_backend.checkAxis(outDim, ndim$1(x) - axis.length + 1);
593
- return bind1(Primitive.Gather, [x, ...indices], {
594
- axis,
595
- outDim
596
- });
599
+ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
600
+ if (lower) {
601
+ a = flip$1(a, [-2, -1]);
602
+ b = flip$1(b, [-1]);
603
+ }
604
+ let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
605
+ if (lower) x = flip$1(x, [-1]);
606
+ return x;
607
+ }
608
+ function cholesky$2(x) {
609
+ return bind1(Primitive.Cholesky, [x]);
610
+ }
611
+ function sort$1(x) {
612
+ const nd = ndim$1(x);
613
+ if (nd === 0) throw new Error("sort: requires at least 1D input");
614
+ return bind1(Primitive.Sort, [x]);
615
+ }
616
+ function argsort$1(x) {
617
+ const nd = ndim$1(x);
618
+ if (nd === 0) throw new Error("argsort: requires at least 1D input");
619
+ return bind(Primitive.Argsort, [x]);
597
620
  }
598
621
  function bind1(prim, args, params = {}) {
599
622
  const [results] = bind(prim, args, params);
@@ -724,8 +747,10 @@ var Tracer = class Tracer {
724
747
  axis = require_backend.normalizeAxis(axis, this.ndim);
725
748
  const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
726
749
  if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
727
- const result = reduce(this, require_backend.AluOp.Add, axis, opts);
728
- return result.mul(1 / n);
750
+ const originalDtype = this.dtype;
751
+ const castDtype = require_backend.promoteTypes(originalDtype, require_backend.DType.Float32);
752
+ const result = reduce(this.astype(castDtype), require_backend.AluOp.Add, axis, opts);
753
+ return result.mul(1 / n).astype(originalDtype);
729
754
  }
730
755
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
731
756
  transpose(perm) {
@@ -754,7 +779,7 @@ var Tracer = class Tracer {
754
779
  if (require_backend.isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
755
780
  return idiv(this, other);
756
781
  }
757
- /** Return specified diagonals. See `numpy.diagonal` for full docs. */
782
+ /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
758
783
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
759
784
  if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
760
785
  if (offset < 0) return this.diagonal(-offset, axis2, axis1);
@@ -807,6 +832,34 @@ var Tracer = class Tracer {
807
832
  this.dispose();
808
833
  }
809
834
  /**
835
+ * Return a sorted copy of an array in ascending order.
836
+ *
837
+ * See `jax.numpy.sort` for full docs.
838
+ */
839
+ sort(axis = -1) {
840
+ axis = require_backend.checkAxis(axis, this.ndim);
841
+ if (this.shape[axis] <= 1) return this;
842
+ if (axis === this.ndim - 1) return sort$1(this);
843
+ const perm = require_backend.range(this.ndim);
844
+ perm.splice(axis, 1);
845
+ perm.push(axis);
846
+ return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
847
+ }
848
+ /**
849
+ * Return the indices that would sort an array. This may not be a stable
850
+ * sorting algorithm; it need not preserve order of indices in ties.
851
+ *
852
+ * See `jax.numpy.argsort` for full docs.
853
+ */
854
+ argsort(axis = -1) {
855
+ axis = require_backend.checkAxis(axis, this.ndim);
856
+ if (axis === this.ndim - 1) return argsort$1(this)[1];
857
+ const perm = require_backend.range(this.ndim);
858
+ perm.splice(axis, 1);
859
+ perm.push(axis);
860
+ return argsort$1(this.transpose(perm))[1].transpose(require_backend.invertPermutation(perm));
861
+ }
862
+ /**
810
863
  * Slice an array along one or more axes.
811
864
  *
812
865
  * This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
@@ -923,6 +976,9 @@ var ShapedArray = class ShapedArray {
923
976
  get ndim() {
924
977
  return this.shape.length;
925
978
  }
979
+ get size() {
980
+ return require_backend.prod(this.shape);
981
+ }
926
982
  toString() {
927
983
  return `${this.dtype}[${this.shape.join(",")}]`;
928
984
  }
@@ -1205,7 +1261,7 @@ var Jaxpr = class Jaxpr {
1205
1261
  } else if (eqn.primitive === Primitive.Idiv) {
1206
1262
  const [a, b] = inputs;
1207
1263
  const c = eqn.outBinders[0];
1208
- if (atomIsLit(b, 1)) context.set(c, a);
1264
+ if (atomIsLit(b, 1) && !require_backend.isFloatDtype(a.aval.dtype)) context.set(c, a);
1209
1265
  else newEqns.push(eqn);
1210
1266
  } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
1211
1267
  else newEqns.push(eqn);
@@ -1222,13 +1278,13 @@ var Jaxpr = class Jaxpr {
1222
1278
  }
1223
1279
  return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
1224
1280
  }
1225
- /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1281
+ /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
1226
1282
  flatten() {
1227
- if (!this.eqns.some((eqn) => eqn.primitive === Primitive.JitCall)) return this;
1283
+ if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
1228
1284
  const newEqns = [];
1229
1285
  const varMap = /* @__PURE__ */ new Map();
1230
1286
  const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
1231
- for (const eqn of this.eqns) if (eqn.primitive === Primitive.JitCall) {
1287
+ for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
1232
1288
  const jaxpr = eqn.params.jaxpr.flatten();
1233
1289
  const translation = /* @__PURE__ */ new Map();
1234
1290
  const translationF = (x) => x instanceof Var ? translation.get(x) : x;
@@ -1329,19 +1385,48 @@ function evalJaxpr(jaxpr, args) {
1329
1385
  function jaxprAsFun(jaxpr) {
1330
1386
  return (...args) => evalJaxpr(jaxpr, args);
1331
1387
  }
1388
+ /** Jaxpr with a collection of associated, traced constants. */
1389
+ var ClosedJaxpr = class ClosedJaxpr {
1390
+ constructor(jaxpr, consts) {
1391
+ this.jaxpr = jaxpr;
1392
+ this.consts = consts;
1393
+ }
1394
+ /** String representation of this Jaxpr. */
1395
+ toString() {
1396
+ return this.jaxpr.toString();
1397
+ }
1398
+ /** Apply a function to the underlying Jaxpr. */
1399
+ mapJaxpr(f) {
1400
+ return new ClosedJaxpr(f(this.jaxpr), this.consts);
1401
+ }
1402
+ /** Dispose of the constants in this Jaxpr. */
1403
+ dispose() {
1404
+ for (const c of this.consts) c.dispose();
1405
+ }
1406
+ };
1332
1407
  /** Tracer that records its operations to dynamically construct a Jaxpr. */
1333
1408
  var JaxprTracer = class extends Tracer {
1409
+ #rc;
1334
1410
  constructor(trace$1, aval) {
1335
1411
  super(trace$1);
1336
1412
  this.aval = aval;
1413
+ this.#rc = 1;
1337
1414
  }
1338
1415
  toString() {
1339
1416
  return `JaxprTracer(${this.aval.toString()})`;
1340
1417
  }
1341
1418
  get ref() {
1419
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1420
+ this.#rc++;
1342
1421
  return this;
1343
1422
  }
1344
- dispose() {}
1423
+ dispose() {
1424
+ if (this.#rc <= 0) throw new UseAfterFreeError(this);
1425
+ this.#rc--;
1426
+ }
1427
+ trackLiftedConstant() {
1428
+ this.#rc++;
1429
+ }
1345
1430
  };
1346
1431
  /** Analogous to the 'DynamicJaxprTrace' class in JAX. */
1347
1432
  var JaxprTrace = class extends Trace {
@@ -1354,17 +1439,24 @@ var JaxprTrace = class extends Trace {
1354
1439
  }
1355
1440
  /** Register a constant / literal in this Jaxpr. */
1356
1441
  getOrMakeConstTracer(val) {
1442
+ if (!(val instanceof Tracer)) val = pureArray(val);
1357
1443
  let tracer = this.builder.constTracers.get(val);
1358
1444
  if (tracer === void 0) {
1359
1445
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
1360
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
1446
+ this.builder.addConst(tracer, val);
1447
+ } else {
1448
+ val.dispose();
1449
+ tracer.trackLiftedConstant();
1361
1450
  }
1362
1451
  return tracer;
1363
1452
  }
1364
1453
  pure = this.getOrMakeConstTracer;
1365
1454
  lift = this.getOrMakeConstTracer;
1366
1455
  processPrimitive(primitive, tracers, params) {
1367
- const avalsIn = tracers.map((t) => t.aval);
1456
+ const avalsIn = tracers.map((t) => {
1457
+ t.dispose();
1458
+ return t.aval;
1459
+ });
1368
1460
  const avalsOut = abstractEvalRules[primitive](avalsIn, params);
1369
1461
  const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
1370
1462
  this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
@@ -1407,20 +1499,17 @@ var JaxprBuilder = class {
1407
1499
  return v;
1408
1500
  }
1409
1501
  build(inTracers, outTracers) {
1410
- let [constVars, consts] = require_backend.unzip2(this.constVals.entries());
1502
+ const [constVars, consts] = require_backend.unzip2(this.constVals.entries());
1411
1503
  const t2v = this.getVar.bind(this);
1412
1504
  const inBinders = [...constVars, ...inTracers.map(t2v)];
1413
1505
  const outVars = outTracers.map(t2v);
1414
- let jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1506
+ const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
1415
1507
  typecheckJaxpr(jaxpr);
1416
- [jaxpr, consts] = _inlineLiterals(jaxpr, consts);
1417
- return {
1418
- jaxpr,
1419
- consts
1420
- };
1508
+ const cjaxpr = new ClosedJaxpr(jaxpr, consts);
1509
+ return _inlineLiterals(cjaxpr);
1421
1510
  }
1422
1511
  };
1423
- function _inlineLiterals(jaxpr, consts) {
1512
+ function _inlineLiterals({ jaxpr, consts }) {
1424
1513
  const literals = /* @__PURE__ */ new Map();
1425
1514
  const constBinders = [];
1426
1515
  const newConsts = [];
@@ -1435,7 +1524,7 @@ function _inlineLiterals(jaxpr, consts) {
1435
1524
  const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
1436
1525
  const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
1437
1526
  typecheckJaxpr(newJaxpr);
1438
- return [newJaxpr, newConsts];
1527
+ return new ClosedJaxpr(newJaxpr, newConsts);
1439
1528
  }
1440
1529
  function binopAbstractEval([x, y]) {
1441
1530
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
@@ -1454,6 +1543,8 @@ const abstractEvalRules = {
1454
1543
  [Primitive.Mul]: binopAbstractEval,
1455
1544
  [Primitive.Idiv]: binopAbstractEval,
1456
1545
  [Primitive.Mod]: binopAbstractEval,
1546
+ [Primitive.Min]: binopAbstractEval,
1547
+ [Primitive.Max]: binopAbstractEval,
1457
1548
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1458
1549
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1459
1550
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -1467,12 +1558,6 @@ const abstractEvalRules = {
1467
1558
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
1468
1559
  return [new ShapedArray(x.shape, dtype, false)];
1469
1560
  },
1470
- [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1471
- if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1472
- const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
1473
- if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1474
- return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
1475
- },
1476
1561
  [Primitive.Sin]: vectorizedUnopAbstractEval,
1477
1562
  [Primitive.Cos]: vectorizedUnopAbstractEval,
1478
1563
  [Primitive.Asin]: vectorizedUnopAbstractEval,
@@ -1482,8 +1567,6 @@ const abstractEvalRules = {
1482
1567
  [Primitive.Erf]: vectorizedUnopAbstractEval,
1483
1568
  [Primitive.Erfc]: vectorizedUnopAbstractEval,
1484
1569
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
1485
- [Primitive.Min]: binopAbstractEval,
1486
- [Primitive.Max]: binopAbstractEval,
1487
1570
  [Primitive.Reduce]([x], { axis }) {
1488
1571
  const axisSet = new Set(axis);
1489
1572
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
@@ -1516,6 +1599,25 @@ const abstractEvalRules = {
1516
1599
  const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
1517
1600
  return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
1518
1601
  },
1602
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1603
+ if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1604
+ const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
1605
+ if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1606
+ return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
1607
+ },
1608
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1609
+ for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1610
+ if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1611
+ if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1612
+ if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1613
+ if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1614
+ const axisSet = new Set(axis);
1615
+ if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1616
+ const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
1617
+ const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1618
+ newShape.splice(outDim, 0, ...gatherShape);
1619
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
1620
+ },
1519
1621
  [Primitive.Transpose]([x], { perm }) {
1520
1622
  return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
1521
1623
  },
@@ -1536,23 +1638,31 @@ const abstractEvalRules = {
1536
1638
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
1537
1639
  return [new ShapedArray(newShape, x.dtype, x.weakType)];
1538
1640
  },
1539
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
1540
- for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
1541
- if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
1542
- if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
1543
- if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
1544
- if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
1545
- const axisSet = new Set(axis);
1546
- if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
1547
- const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
1548
- const newShape = x.shape.filter((_, i) => !axisSet.has(i));
1549
- newShape.splice(outDim, 0, ...gatherShape);
1550
- return [new ShapedArray(newShape, x.dtype, x.weakType)];
1641
+ [Primitive.Sort]([x]) {
1642
+ if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
1643
+ return [ShapedArray.fromAval(x)];
1644
+ },
1645
+ [Primitive.Argsort]([x]) {
1646
+ if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
1647
+ return [ShapedArray.fromAval(x), new ShapedArray(x.shape, require_backend.DType.Int32, false)];
1648
+ },
1649
+ [Primitive.TriangularSolve]([a, b]) {
1650
+ if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
1651
+ if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
1652
+ const [m, n] = a.shape.slice(-2);
1653
+ const [_batch, q] = b.shape.slice(-2);
1654
+ if (!require_backend.deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
1655
+ return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
1656
+ },
1657
+ [Primitive.Cholesky]([a]) {
1658
+ if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
1659
+ if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
1660
+ return [ShapedArray.fromAval(a)];
1551
1661
  },
1552
- [Primitive.JitCall](args, { jaxpr }) {
1662
+ [Primitive.Jit](args, { jaxpr }) {
1553
1663
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
1554
- if (args.length !== inTypes.length) throw new TypeError(`jit_call expected ${inTypes.length} arguments, got ${args.length}`);
1555
- for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit_call argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1664
+ if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
1665
+ for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
1556
1666
  return outTypes;
1557
1667
  }
1558
1668
  };
@@ -1588,11 +1698,10 @@ function makeJaxpr$1(f, opts) {
1588
1698
  const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
1589
1699
  const outs = fFlat(...tracersIn);
1590
1700
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
1591
- const { jaxpr, consts } = builder.build(tracersIn, tracersOut);
1701
+ const jaxpr = builder.build(tracersIn, tracersOut);
1592
1702
  if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
1593
1703
  return {
1594
- jaxpr: jaxpr.simplify(),
1595
- consts,
1704
+ jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
1596
1705
  treedef: outTree.value
1597
1706
  };
1598
1707
  } catch (_) {
@@ -1611,22 +1720,28 @@ function jit$1(f, opts) {
1611
1720
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
1612
1721
  const avalsIn = unflatten(inTree, avalsInFlat);
1613
1722
  const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
1614
- const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1615
- const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
1723
+ const { jaxpr, treedef: outTree } = require_backend.runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
1724
+ const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
1616
1725
  name: f.name || "closure",
1617
- jaxpr,
1618
- numConsts: consts.length
1726
+ jaxpr: jaxpr.jaxpr,
1727
+ numConsts: jaxpr.consts.length
1619
1728
  });
1620
1729
  return unflatten(outTree, outs);
1621
1730
  });
1622
1731
  result.dispose = () => {
1623
- for (const { consts } of cache.values()) for (const c of consts) c.dispose();
1732
+ for (const { jaxpr } of cache.values()) jaxpr.dispose();
1624
1733
  };
1625
1734
  return result;
1626
1735
  }
1627
1736
 
1628
1737
  //#endregion
1629
1738
  //#region src/frontend/jit.ts
1739
+ const routinePrimitives = new Map([
1740
+ [Primitive.Sort, require_backend.Routines.Sort],
1741
+ [Primitive.Argsort, require_backend.Routines.Argsort],
1742
+ [Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
1743
+ [Primitive.Cholesky, require_backend.Routines.Cholesky]
1744
+ ]);
1630
1745
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1631
1746
  var JitProgram = class {
1632
1747
  constructor(backend, steps, inputs, outputs) {
@@ -1641,9 +1756,14 @@ var JitProgram = class {
1641
1756
  case "execute": {
1642
1757
  const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
1643
1758
  const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
1644
- return require_backend.PPrint.pp(`execute (${inputsNice}) -> ${outputsNice}, kernel`).concat(step.kernel.pprint().indent(2));
1759
+ const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
1760
+ if (step.source instanceof require_backend.Kernel) return require_backend.PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
1761
+ else if (step.source instanceof require_backend.Routine) return require_backend.PPrint.pp(`${executeText}, routine ${step.source.name}`);
1762
+ else {
1763
+ step.source;
1764
+ return require_backend.PPrint.pp(executeText);
1765
+ }
1645
1766
  }
1646
- case "const": return require_backend.PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
1647
1767
  case "malloc": return require_backend.PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
1648
1768
  case "incref": return require_backend.PPrint.pp(`incref ${step.input}`);
1649
1769
  case "free": return require_backend.PPrint.pp(`free ${step.input}`);
@@ -1666,12 +1786,9 @@ var JitProgram = class {
1666
1786
  const inputs$1 = step.inputs.map((id) => scope.get(id));
1667
1787
  const outputs = step.outputs.map((id) => scope.get(id));
1668
1788
  if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
1669
- pending.push(new PendingExecute(this.backend, step.kernel, inputs$1, outputs));
1789
+ pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
1670
1790
  break;
1671
1791
  }
1672
- case "const":
1673
- scope.set(step.output, step.slot);
1674
- break;
1675
1792
  case "malloc": {
1676
1793
  const slot = this.backend.malloc(step.size);
1677
1794
  scope.set(step.output, slot);
@@ -1705,34 +1822,37 @@ var JitProgramBuilder = class {
1705
1822
  this.#nextId = nargs;
1706
1823
  this.steps = [];
1707
1824
  }
1708
- pushConst(slot) {
1709
- const id = this.#nextId++;
1710
- this.steps.push({
1711
- type: "const",
1712
- slot,
1713
- output: id
1714
- });
1715
- return id;
1716
- }
1717
1825
  pushLit(lit) {
1718
- const kernel = new require_backend.Kernel(0, require_backend.prod(lit.aval.shape), require_backend.AluExp.const(lit.dtype, lit.value));
1826
+ const kernel = new require_backend.Kernel(0, lit.aval.size, require_backend.AluExp.const(lit.dtype, lit.value));
1719
1827
  return this.pushKernel(kernel, []);
1720
1828
  }
1721
- pushKernel(kernel, inputs) {
1829
+ pushBuffer(size$1) {
1722
1830
  const id = this.#nextId++;
1723
1831
  this.steps.push({
1724
1832
  type: "malloc",
1725
- size: kernel.bytes,
1833
+ size: size$1,
1726
1834
  output: id
1727
1835
  });
1836
+ return id;
1837
+ }
1838
+ pushKernel(kernel, inputs) {
1839
+ const id = this.pushBuffer(kernel.bytes);
1728
1840
  this.steps.push({
1729
1841
  type: "execute",
1730
- kernel,
1842
+ source: kernel,
1731
1843
  inputs,
1732
1844
  outputs: [id]
1733
1845
  });
1734
1846
  return id;
1735
1847
  }
1848
+ pushRoutine(routine, inputs, outputs) {
1849
+ this.steps.push({
1850
+ type: "execute",
1851
+ source: routine,
1852
+ inputs,
1853
+ outputs
1854
+ });
1855
+ }
1736
1856
  pushIncref(id) {
1737
1857
  this.steps.push({
1738
1858
  type: "incref",
@@ -1758,28 +1878,18 @@ var JitProgramBuilder = class {
1758
1878
  }
1759
1879
  };
1760
1880
  const jitCompileCache = /* @__PURE__ */ new Map();
1761
- function jitCompile(backend, jaxpr, consts) {
1762
- if (jaxpr.inBinders.length < consts.length) throw new TypeError(`Jaxpr has ${jaxpr.inBinders.length} inputs, but ${consts.length} consts were provided`);
1763
- for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
1764
- const cacheKey = backend.type + require_backend.FpHash.hash(jaxpr, ...consts.map((c) => c.id));
1881
+ function jitCompile(backend, jaxpr) {
1882
+ const cacheKey = backend.type + "," + require_backend.FpHash.hash(jaxpr);
1765
1883
  const cached = jitCompileCache.get(cacheKey);
1766
1884
  if (cached) return cached;
1767
1885
  if (require_backend.DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
1768
1886
  jaxpr = jaxpr.flatten().simplify();
1769
- const nargs = jaxpr.inBinders.length - consts.length;
1887
+ const nargs = jaxpr.inBinders.length;
1770
1888
  const builder = new JitProgramBuilder(backend, nargs);
1771
1889
  const blackNodes = splitGraphDataflow(backend, jaxpr);
1772
1890
  const ctx = /* @__PURE__ */ new Map();
1773
- for (let i = 0; i < consts.length; i++) {
1774
- const v = jaxpr.inBinders[i];
1775
- const slot = consts[i]._realizeSource();
1776
- ctx.set(v, {
1777
- type: "imm",
1778
- arg: builder.pushConst(slot)
1779
- });
1780
- }
1781
1891
  for (let i = 0; i < nargs; i++) {
1782
- const v = jaxpr.inBinders[consts.length + i];
1892
+ const v = jaxpr.inBinders[i];
1783
1893
  ctx.set(v, {
1784
1894
  type: "imm",
1785
1895
  arg: i
@@ -1787,51 +1897,101 @@ function jitCompile(backend, jaxpr, consts) {
1787
1897
  }
1788
1898
  for (let i = 0; i < jaxpr.eqns.length; i++) {
1789
1899
  const eqn = jaxpr.eqns[i];
1900
+ if (routinePrimitives.has(eqn.primitive)) {
1901
+ const routine = new require_backend.Routine(routinePrimitives.get(eqn.primitive), {
1902
+ inputShapes: eqn.inputs.map((x) => x.aval.shape),
1903
+ inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
1904
+ outputShapes: eqn.outBinders.map((x) => x.aval.shape),
1905
+ outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
1906
+ }, eqn.params);
1907
+ const inputs = [];
1908
+ for (const input of eqn.inputs) if (input instanceof Var) {
1909
+ const jv = ctx.get(input);
1910
+ if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
1911
+ inputs.push(jv.arg);
1912
+ } else if (input instanceof Lit) inputs.push(builder.pushLit(input));
1913
+ const outputs = [];
1914
+ for (const outVar$1 of eqn.outBinders) {
1915
+ const outId = builder.pushBuffer(outVar$1.aval.size * require_backend.byteWidth(outVar$1.aval.dtype));
1916
+ outputs.push(outId);
1917
+ ctx.set(outVar$1, {
1918
+ type: "imm",
1919
+ arg: outId
1920
+ });
1921
+ }
1922
+ builder.pushRoutine(routine, inputs, outputs);
1923
+ continue;
1924
+ }
1790
1925
  const inputExps = [];
1791
1926
  const inputAvals = [];
1792
1927
  const inputArgs = [];
1793
- for (const input of eqn.inputs) if (input instanceof Var) {
1794
- const jitValue = ctx.get(input);
1795
- if (jitValue.type === "exp") {
1796
- const gidMap = /* @__PURE__ */ new Map();
1797
- for (const [gid, jitId] of jitValue.args.entries()) {
1798
- let newGid = inputArgs.indexOf(jitId);
1799
- if (newGid === -1) {
1800
- newGid = inputArgs.length;
1801
- inputArgs.push(jitId);
1802
- }
1803
- gidMap.set(gid, newGid);
1804
- }
1805
- inputExps.push(jitValue.exp.reindexGids(gidMap));
1806
- } else if (jitValue.type === "imm") {
1807
- let gid = inputArgs.indexOf(jitValue.arg);
1808
- if (gid === -1) {
1809
- gid = inputArgs.length;
1810
- inputArgs.push(jitValue.arg);
1928
+ let inputReduction = null;
1929
+ const addArgs = (args) => {
1930
+ const newGids = [];
1931
+ for (const jitId of args) {
1932
+ let newGid = inputArgs.indexOf(jitId);
1933
+ if (newGid === -1) {
1934
+ newGid = inputArgs.length;
1935
+ inputArgs.push(jitId);
1811
1936
  }
1937
+ newGids.push(newGid);
1938
+ }
1939
+ return newGids;
1940
+ };
1941
+ for (const input of eqn.inputs) if (input instanceof Var) {
1942
+ const jv = ctx.get(input);
1943
+ if (jv.type === "exp") {
1944
+ const newGids = addArgs(jv.args);
1945
+ inputExps.push(jv.exp.reindexGids(newGids));
1946
+ } else if (jv.type === "imm") {
1947
+ const [gid] = addArgs([jv.arg]);
1812
1948
  const st = require_backend.ShapeTracker.fromShape(input.aval.shape);
1813
1949
  const indices = require_backend.unravelAlu(st.shape, require_backend.AluVar.gidx);
1814
1950
  inputExps.push(require_backend.AluExp.globalView(input.aval.dtype, gid, st, indices));
1951
+ } else if (jv.type === "red") {
1952
+ if (inputReduction) throw new Error("jit: unexpected, multiple red inputs");
1953
+ const newGids = addArgs(jv.args);
1954
+ inputExps.push(jv.reduction.epilogue.reindexGids(newGids));
1955
+ inputReduction = jv;
1815
1956
  }
1816
1957
  inputAvals.push(input.aval);
1817
1958
  } else if (input instanceof Lit) {
1818
1959
  inputExps.push(require_backend.AluExp.const(input.dtype, input.value));
1819
1960
  inputAvals.push(input.aval);
1820
1961
  } else throw new TypeError(`Unexpected input in Jaxpr: ${input}`);
1821
- const nargs$1 = inputArgs.length;
1822
1962
  const rule = jitRules[eqn.primitive];
1823
1963
  if (!rule) throw new TypeError(`JIT not implemented for primitive ${eqn.primitive}`);
1824
- const kernel = rule(nargs$1, inputExps, inputAvals, eqn.params);
1964
+ let exp$2;
1965
+ let reduction;
1966
+ if (inputReduction) {
1967
+ const jv = inputReduction;
1968
+ const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
1969
+ exp$2 = jv.exp.reindexGids(addArgs(jv.args));
1970
+ reduction = new require_backend.Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
1971
+ } else {
1972
+ const ruleOutput = rule(inputExps, inputAvals, eqn.params);
1973
+ exp$2 = ruleOutput.exp;
1974
+ reduction = ruleOutput.reduction;
1975
+ }
1825
1976
  const outVar = eqn.outBinders[0];
1826
- if (kernel.reduction || blackNodes.has(outVar)) {
1977
+ if (blackNodes.has(outVar)) {
1978
+ const nargs$1 = inputArgs.length;
1979
+ const size$1 = outVar.aval.size;
1980
+ const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2, reduction);
1827
1981
  const outId = builder.pushKernel(kernel, inputArgs);
1828
1982
  ctx.set(outVar, {
1829
1983
  type: "imm",
1830
1984
  arg: outId
1831
1985
  });
1832
- } else ctx.set(outVar, {
1986
+ } else if (reduction) ctx.set(outVar, {
1987
+ type: "red",
1988
+ exp: exp$2,
1989
+ reduction,
1990
+ args: inputArgs
1991
+ });
1992
+ else ctx.set(outVar, {
1833
1993
  type: "exp",
1834
- exp: kernel.exp,
1994
+ exp: exp$2,
1835
1995
  args: inputArgs
1836
1996
  });
1837
1997
  }
@@ -1841,7 +2001,7 @@ function jitCompile(backend, jaxpr, consts) {
1841
2001
  if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
1842
2002
  outputIds.push(jitValue.arg);
1843
2003
  } else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
1844
- const outputNeedsRef = new Set([...require_backend.range(nargs), ...builder.steps.filter((s) => s.type === "const").map((s) => s.output)]);
2004
+ const outputNeedsRef = new Set(require_backend.range(nargs));
1845
2005
  for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
1846
2006
  else outputNeedsRef.add(outputId);
1847
2007
  builder.insertFreeSteps(outputIds);
@@ -1863,31 +2023,33 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1863
2023
  });
1864
2024
  }
1865
2025
  function broadcastedJit(fn, opts) {
1866
- return (nargs, exps, avals, params) => {
2026
+ return (exps, avals, params) => {
1867
2027
  let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
1868
2028
  const skipCastIdx = opts?.skipCastIdx ?? [];
1869
2029
  if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
1870
- exps = exps.map((exp$3, i) => {
1871
- exp$3 = reshapeViews(exp$3, (st) => {
2030
+ exps = exps.map((exp$2, i) => {
2031
+ exp$2 = reshapeViews(exp$2, (st) => {
1872
2032
  if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
1873
2033
  });
1874
- if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = require_backend.AluExp.cast(newDtype, exp$3);
1875
- return exp$3;
2034
+ if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = require_backend.AluExp.cast(newDtype, exp$2);
2035
+ return exp$2;
1876
2036
  });
1877
- const exp$2 = fn(exps, params);
1878
- return new require_backend.Kernel(nargs, require_backend.prod(newShape), exp$2);
2037
+ return { exp: fn(exps, params) };
1879
2038
  };
1880
2039
  }
1881
2040
  function unopJit(fn) {
1882
- return (nargs, [a], [as], params) => {
1883
- return new require_backend.Kernel(nargs, require_backend.prod(as.shape), fn(a, params));
2041
+ return ([a], [_as], params) => {
2042
+ return { exp: fn(a, params) };
1884
2043
  };
1885
2044
  }
1886
2045
  function reshapeJit(fn) {
1887
- return (nargs, [a], [as], params) => {
1888
- a = reshapeViews(a, (st) => fn(st, params));
1889
- const newShape = fn(require_backend.ShapeTracker.fromShape(as.shape), params).shape;
1890
- return new require_backend.Kernel(nargs, require_backend.prod(newShape), a);
2046
+ return ([a], [_as], params) => {
2047
+ return { exp: reshapeViews(a, (st) => fn(st, params)) };
2048
+ };
2049
+ }
2050
+ function routineNoJit() {
2051
+ return () => {
2052
+ throw new Error("jit: rule is not implemented for routines");
1891
2053
  };
1892
2054
  }
1893
2055
  const jitRules = {
@@ -1895,6 +2057,8 @@ const jitRules = {
1895
2057
  [Primitive.Mul]: broadcastedJit(([a, b]) => require_backend.AluExp.mul(a, b)),
1896
2058
  [Primitive.Idiv]: broadcastedJit(([a, b]) => require_backend.AluExp.idiv(a, b)),
1897
2059
  [Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
2060
+ [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
2061
+ [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
1898
2062
  [Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
1899
2063
  [Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
1900
2064
  [Primitive.Floor]: unopJit(require_backend.AluExp.floor),
@@ -1902,17 +2066,6 @@ const jitRules = {
1902
2066
  [Primitive.StopGradient]: unopJit((a) => a),
1903
2067
  [Primitive.Cast]: unopJit((a, { dtype }) => require_backend.AluExp.cast(dtype, a)),
1904
2068
  [Primitive.Bitcast]: unopJit((a, { dtype }) => require_backend.AluExp.bitcast(dtype, a)),
1905
- [Primitive.RandomBits]: (nargs, keys, keyShapes, { shape: shape$1, mode }) => {
1906
- const mapping = (st) => {
1907
- if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
1908
- };
1909
- const k0 = reshapeViews(keys[0], mapping);
1910
- const k1 = reshapeViews(keys[1], mapping);
1911
- const c0 = require_backend.AluExp.u32(0);
1912
- const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
1913
- const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
1914
- return new require_backend.Kernel(nargs, require_backend.prod(shape$1), exp$2);
1915
- },
1916
2069
  [Primitive.Sin]: unopJit(require_backend.AluExp.sin),
1917
2070
  [Primitive.Cos]: unopJit(require_backend.AluExp.cos),
1918
2071
  [Primitive.Asin]: unopJit(require_backend.AluExp.asin),
@@ -1922,9 +2075,7 @@ const jitRules = {
1922
2075
  [Primitive.Erf]: unopJit(require_backend.AluExp.erf),
1923
2076
  [Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
1924
2077
  [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
1925
- [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
1926
- [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
1927
- [Primitive.Reduce](nargs, [a], [as], { op, axis }) {
2078
+ [Primitive.Reduce]([a], [as], { op, axis }) {
1928
2079
  const keptAxes = [];
1929
2080
  const shiftedAxes = [];
1930
2081
  const newShape = [];
@@ -1933,53 +2084,58 @@ const jitRules = {
1933
2084
  keptAxes.push(i);
1934
2085
  newShape.push(as.shape[i]);
1935
2086
  }
1936
- const size$1 = require_backend.prod(newShape);
1937
2087
  const reductionSize = require_backend.prod(shiftedAxes.map((ax) => as.shape[ax]));
1938
2088
  newShape.push(reductionSize);
1939
2089
  const perm = keptAxes.concat(shiftedAxes);
1940
2090
  a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
1941
2091
  const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
1942
- return new require_backend.Kernel(nargs, size$1, a, reduction);
2092
+ return {
2093
+ exp: a,
2094
+ reduction
2095
+ };
1943
2096
  },
1944
2097
  [Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
1945
- [Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
2098
+ [Primitive.PoolTranspose]([a], [as], { inShape, window, strides }) {
1946
2099
  let stX = poolTranspose(require_backend.ShapeTracker.fromShape(as.shape), inShape, window, strides);
1947
- const size$1 = require_backend.prod(inShape);
1948
2100
  stX = stX.reshape([...inShape, require_backend.prod(stX.shape.slice(inShape.length))]);
1949
2101
  a = reshapeViews(a, (st) => st.compose(stX), true);
1950
2102
  const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
1951
- return new require_backend.Kernel(nargs, size$1, a, reduction);
2103
+ return {
2104
+ exp: a,
2105
+ reduction
2106
+ };
1952
2107
  },
1953
- [Primitive.Dot](nargs, [a, b], [as, bs]) {
1954
- const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
2108
+ [Primitive.Dot]([a, b], [as, bs]) {
2109
+ const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
1955
2110
  const c = k1.exp;
1956
2111
  const cs = promoteAvals(as, bs);
1957
- return jitRules[Primitive.Reduce](nargs, [c], [cs], {
2112
+ return jitRules[Primitive.Reduce]([c], [cs], {
1958
2113
  op: require_backend.AluOp.Add,
1959
2114
  axis: [cs.ndim - 1]
1960
2115
  });
1961
2116
  },
1962
- [Primitive.Conv](nargs, [a, b], [as, bs], params) {
2117
+ [Primitive.Conv]([a, b], [as, bs], params) {
1963
2118
  const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
1964
2119
  a = reshapeViews(a, (st) => st.compose(stX));
1965
2120
  b = reshapeViews(b, (st) => st.compose(stY));
1966
2121
  as = new ShapedArray(stX.shape, as.dtype, as.weakType);
1967
2122
  bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
1968
- return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
2123
+ return jitRules[Primitive.Dot]([a, b], [as, bs], {});
1969
2124
  },
1970
2125
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1971
2126
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1972
- [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1973
- [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1974
- [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
1975
- [Primitive.Flip]: reshapeJit((st, { axis }) => {
1976
- const arg = require_backend.rep(st.shape.length, false);
1977
- for (const ax of axis) arg[ax] = true;
1978
- return st.flip(arg);
1979
- }),
1980
- [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
1981
- [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
1982
- [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
2127
+ [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
2128
+ const mapping = (st) => {
2129
+ if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
2130
+ };
2131
+ const k0 = reshapeViews(keys[0], mapping);
2132
+ const k1 = reshapeViews(keys[1], mapping);
2133
+ const c0 = require_backend.AluExp.u32(0);
2134
+ const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
2135
+ const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
2136
+ return { exp: exp$2 };
2137
+ },
2138
+ [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1983
2139
  const axisSet = new Set(axis);
1984
2140
  const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
1985
2141
  const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
@@ -1992,24 +2148,38 @@ const jitRules = {
1992
2148
  for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
1993
2149
  const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
1994
2150
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
1995
- return new require_backend.Kernel(nargs, require_backend.prod(finalShape), x.substitute({ gidx: index }));
2151
+ return { exp: x.substitute({ gidx: index }) };
1996
2152
  },
1997
- [Primitive.JitCall]() {
1998
- throw new Error("internal: JitCall should have been flattened before JIT compilation");
2153
+ [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
2154
+ [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
2155
+ [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
2156
+ [Primitive.Flip]: reshapeJit((st, { axis }) => {
2157
+ const arg = require_backend.rep(st.shape.length, false);
2158
+ for (const ax of axis) arg[ax] = true;
2159
+ return st.flip(arg);
2160
+ }),
2161
+ [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
2162
+ [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
2163
+ [Primitive.Sort]: routineNoJit(),
2164
+ [Primitive.Argsort]: routineNoJit(),
2165
+ [Primitive.TriangularSolve]: routineNoJit(),
2166
+ [Primitive.Cholesky]: routineNoJit(),
2167
+ [Primitive.Jit]() {
2168
+ throw new Error("internal: Jit should have been flattened before JIT compilation");
1999
2169
  }
2000
2170
  };
2001
2171
  /** Determines how to split the Jaxpr into kernels via dataflow analysis. */
2002
2172
  function splitGraphDataflow(backend, jaxpr) {
2003
- const varToEqn = /* @__PURE__ */ new Map();
2173
+ const varToDefn = /* @__PURE__ */ new Map();
2174
+ const varToUsages = /* @__PURE__ */ new Map();
2004
2175
  for (let i = 0; i < jaxpr.eqns.length; i++) {
2005
2176
  const eqn = jaxpr.eqns[i];
2006
- for (const v of eqn.outBinders) if (v instanceof Var) varToEqn.set(v, i);
2007
- }
2008
- const blackNodes = /* @__PURE__ */ new Set();
2009
- const p1NextBlack = /* @__PURE__ */ new Map();
2010
- for (const v of jaxpr.outs) if (v instanceof Var) {
2011
- blackNodes.add(v);
2012
- p1NextBlack.set(v, v);
2177
+ for (const v of eqn.outBinders) if (v instanceof Var) varToDefn.set(v, i);
2178
+ for (const input of eqn.inputs) if (input instanceof Var) {
2179
+ const usages = varToUsages.get(input);
2180
+ if (usages) usages.push(i);
2181
+ else varToUsages.set(input, [i]);
2182
+ }
2013
2183
  }
2014
2184
  const reducePrimitives = [
2015
2185
  Primitive.Reduce,
@@ -2017,28 +2187,94 @@ function splitGraphDataflow(backend, jaxpr) {
2017
2187
  Primitive.Conv,
2018
2188
  Primitive.PoolTranspose
2019
2189
  ];
2020
- const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
2021
- for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2190
+ const reductionEpilogueEqns = /* @__PURE__ */ new Set();
2191
+ const reductionEndpointEqns = /* @__PURE__ */ new Set();
2192
+ for (let i = 0; i < jaxpr.eqns.length; i++) {
2022
2193
  const eqn = jaxpr.eqns[i];
2023
- if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2024
- for (const v of eqn.outBinders) {
2025
- blackNodes.add(v);
2026
- p1NextBlack.set(v, v);
2194
+ if (reducePrimitives.includes(eqn.primitive)) {
2195
+ let head = i;
2196
+ while (true) {
2197
+ reductionEpilogueEqns.add(head);
2198
+ const outVar = jaxpr.eqns[head].outBinders[0];
2199
+ const usages = varToUsages.get(outVar) ?? [];
2200
+ if (jaxpr.outs.includes(outVar) || usages.length !== 1) break;
2201
+ if (reductionEpilogueEqns.has(usages[0])) break;
2202
+ const nextEqn = jaxpr.eqns[usages[0]];
2203
+ switch (nextEqn.primitive) {
2204
+ case Primitive.Neg:
2205
+ case Primitive.Reciprocal:
2206
+ case Primitive.Floor:
2207
+ case Primitive.Ceil:
2208
+ case Primitive.StopGradient:
2209
+ case Primitive.Cast:
2210
+ case Primitive.Bitcast:
2211
+ case Primitive.Sin:
2212
+ case Primitive.Cos:
2213
+ case Primitive.Asin:
2214
+ case Primitive.Atan:
2215
+ case Primitive.Exp:
2216
+ case Primitive.Log:
2217
+ case Primitive.Erf:
2218
+ case Primitive.Erfc:
2219
+ case Primitive.Sqrt:
2220
+ head = usages[0];
2221
+ continue;
2222
+ case Primitive.Add:
2223
+ case Primitive.Mul:
2224
+ case Primitive.Idiv:
2225
+ case Primitive.Mod:
2226
+ case Primitive.Min:
2227
+ case Primitive.Max: {
2228
+ const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2229
+ if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2230
+ head = usages[0];
2231
+ continue;
2232
+ }
2233
+ break;
2234
+ }
2235
+ }
2236
+ break;
2027
2237
  }
2028
- continue;
2238
+ reductionEndpointEqns.add(head);
2029
2239
  }
2030
- const reach = /* @__PURE__ */ new Set();
2031
- for (let j = i + 1; j < jaxpr.eqns.length; j++) for (const v of jaxpr.eqns[j].inputs) if (v instanceof Var && eqn.outBinders.includes(v)) for (const o of jaxpr.eqns[j].outBinders) {
2032
- const u = p1NextBlack.get(o);
2033
- if (u) reach.add(u);
2034
- }
2035
- if (reach.size === 1) {
2036
- const b = reach.values().next().value;
2037
- for (const v of eqn.outBinders) p1NextBlack.set(v, b);
2038
- } else if (reach.size > 1) for (const v of eqn.outBinders) {
2240
+ }
2241
+ const blackNodes = /* @__PURE__ */ new Set();
2242
+ const p1NextBlack = /* @__PURE__ */ new Map();
2243
+ for (const v of jaxpr.outs) if (v instanceof Var) {
2244
+ blackNodes.add(v);
2245
+ p1NextBlack.set(v, v);
2246
+ }
2247
+ const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2248
+ const needsCleanShapePrimitives = [Primitive.Pad];
2249
+ for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2250
+ const eqn = jaxpr.eqns[i];
2251
+ if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
2252
+ for (const v of eqn.outBinders) {
2253
+ blackNodes.add(v);
2254
+ p1NextBlack.set(v, v);
2255
+ }
2256
+ continue;
2257
+ }
2258
+ const reach = /* @__PURE__ */ new Set();
2259
+ let needsCleanOutput = false;
2260
+ outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
2261
+ if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
2262
+ needsCleanOutput = true;
2263
+ break outer;
2264
+ }
2265
+ for (const o of jaxpr.eqns[j].outBinders) {
2266
+ const u = p1NextBlack.get(o);
2267
+ if (u) reach.add(u);
2268
+ }
2269
+ }
2270
+ if (reach.size > 1 || needsCleanOutput) for (const v of eqn.outBinders) {
2039
2271
  blackNodes.add(v);
2040
2272
  p1NextBlack.set(v, v);
2041
2273
  }
2274
+ else if (reach.size === 1) {
2275
+ const b = reach.values().next().value;
2276
+ for (const v of eqn.outBinders) p1NextBlack.set(v, b);
2277
+ }
2042
2278
  }
2043
2279
  const p2Deps = /* @__PURE__ */ new Map();
2044
2280
  for (const v of jaxpr.inBinders) p2Deps.set(v, new Set([v]));
@@ -2046,7 +2282,6 @@ function splitGraphDataflow(backend, jaxpr) {
2046
2282
  while (p2idx < jaxpr.eqns.length) {
2047
2283
  const eqn = jaxpr.eqns[p2idx++];
2048
2284
  const deps = [];
2049
- if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
2050
2285
  for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
2051
2286
  else deps.push(p2Deps.get(input));
2052
2287
  else deps.push(/* @__PURE__ */ new Set());
@@ -2057,7 +2292,7 @@ function splitGraphDataflow(backend, jaxpr) {
2057
2292
  let assocInput = -1;
2058
2293
  for (let i = 0; i < eqn.inputs.length; i++) {
2059
2294
  const input = eqn.inputs[i];
2060
- if (input instanceof Var && varToEqn.has(input)) {
2295
+ if (input instanceof Var && varToDefn.has(input)) {
2061
2296
  let uniqueDeps = 0;
2062
2297
  for (const dep of deps[i]) if (depCounter.get(dep) === 1) uniqueDeps++;
2063
2298
  if (uniqueDeps > maxUniqueDeps) {
@@ -2068,8 +2303,8 @@ function splitGraphDataflow(backend, jaxpr) {
2068
2303
  }
2069
2304
  if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
2070
2305
  const assocVar = eqn.inputs[assocInput];
2071
- p2idx = varToEqn.get(assocVar);
2072
- for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
2306
+ p2idx = varToDefn.get(assocVar);
2307
+ for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
2073
2308
  } else {
2074
2309
  const s = new Set(depCounter.keys());
2075
2310
  for (const out of eqn.outBinders) p2Deps.set(out, s);
@@ -2095,9 +2330,9 @@ var PendingExecute = class {
2095
2330
  submitted = false;
2096
2331
  #promise = null;
2097
2332
  #rc = 1;
2098
- constructor(backend, kernel, inputs, outputs) {
2333
+ constructor(backend, source, inputs, outputs) {
2099
2334
  this.backend = backend;
2100
- this.kernel = kernel;
2335
+ this.source = source;
2101
2336
  this.inputs = inputs;
2102
2337
  this.outputs = outputs;
2103
2338
  for (const slot of inputs) this.backend.incRef(slot);
@@ -2118,13 +2353,15 @@ var PendingExecute = class {
2118
2353
  return;
2119
2354
  }
2120
2355
  this.#promise = (async () => {
2121
- this.prepared = await this.backend.prepare(this.kernel);
2356
+ if (this.source instanceof require_backend.Kernel) this.prepared = await this.backend.prepareKernel(this.source);
2357
+ else this.prepared = await this.backend.prepareRoutine(this.source);
2122
2358
  })();
2123
2359
  await this.#promise;
2124
2360
  }
2125
2361
  prepareSync() {
2126
2362
  if (this.prepared) return;
2127
- this.prepared = this.backend.prepareSync(this.kernel);
2363
+ if (this.source instanceof require_backend.Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
2364
+ else this.prepared = this.backend.prepareRoutineSync(this.source);
2128
2365
  }
2129
2366
  submit() {
2130
2367
  if (this.submitted) return;
@@ -2147,8 +2384,6 @@ var PendingExecute = class {
2147
2384
  * "Array" type by name.
2148
2385
  */
2149
2386
  var Array$1 = class Array$1 extends Tracer {
2150
- static #nextId = 1001;
2151
- id;
2152
2387
  #dtype;
2153
2388
  #weakType;
2154
2389
  #source;
@@ -2165,7 +2400,6 @@ var Array$1 = class Array$1 extends Tracer {
2165
2400
  */
2166
2401
  constructor(args) {
2167
2402
  super(baseArrayTrace);
2168
- this.id = Array$1.#nextId++;
2169
2403
  this.#dtype = args.dtype;
2170
2404
  this.#weakType = args.weakType;
2171
2405
  this.#source = args.source;
@@ -2474,6 +2708,27 @@ var Array$1 = class Array$1 extends Tracer {
2474
2708
  pending
2475
2709
  });
2476
2710
  }
2711
+ /** Apply an operation with custom lowering to this array. */
2712
+ static #routine(routine, arrays, outputWeakType) {
2713
+ const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2714
+ for (const ar of arrays) ar.#realize();
2715
+ const inputs = arrays.map((ar) => ar.#source);
2716
+ const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(require_backend.byteWidth(dtype) * require_backend.prod(routine.type.outputShapes[i])));
2717
+ const pending = arrays.flatMap((ar) => ar.#pending);
2718
+ for (const exe of pending) exe.updateRc(+outputs.length);
2719
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2720
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2721
+ arrays.forEach((ar) => ar.dispose());
2722
+ return outputs.map((output, i) => new Array$1({
2723
+ source: output,
2724
+ st: require_backend.ShapeTracker.fromShape(routine.type.outputShapes[i]),
2725
+ dtype: routine.type.outputDtypes[i],
2726
+ weakType: outputWeakType[i],
2727
+ backend,
2728
+ committed,
2729
+ pending
2730
+ }));
2731
+ }
2477
2732
  /**
2478
2733
  * Normalizes this array into one backed by a `Slot`.
2479
2734
  *
@@ -2634,6 +2889,12 @@ var Array$1 = class Array$1 extends Tracer {
2634
2889
  [Primitive.Mod]([x, y]) {
2635
2890
  return [x.#binary(require_backend.AluOp.Mod, y)];
2636
2891
  },
2892
+ [Primitive.Min]([x, y]) {
2893
+ return [x.#binary(require_backend.AluOp.Min, y)];
2894
+ },
2895
+ [Primitive.Max]([x, y]) {
2896
+ return [x.#binary(require_backend.AluOp.Max, y)];
2897
+ },
2637
2898
  [Primitive.Neg]([x]) {
2638
2899
  return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
2639
2900
  },
@@ -2670,25 +2931,6 @@ var Array$1 = class Array$1 extends Tracer {
2670
2931
  return [y];
2671
2932
  }
2672
2933
  },
2673
- [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2674
- const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2675
- if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2676
- const c0 = zeros(shape$1, {
2677
- dtype: require_backend.DType.Uint32,
2678
- device: k0.device
2679
- });
2680
- const c1 = arange(0, require_backend.prod(shape$1), 1, {
2681
- dtype: require_backend.DType.Uint32,
2682
- device: k0.device
2683
- }).reshape(shape$1);
2684
- const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
2685
- return [Array$1.#naryCustom("random_bits", custom, [
2686
- k0,
2687
- k1,
2688
- c0,
2689
- c1
2690
- ])];
2691
- },
2692
2934
  [Primitive.Sin]([x]) {
2693
2935
  return [x.#unary(require_backend.AluOp.Sin)];
2694
2936
  },
@@ -2716,12 +2958,6 @@ var Array$1 = class Array$1 extends Tracer {
2716
2958
  [Primitive.Sqrt]([x]) {
2717
2959
  return [x.#unary(require_backend.AluOp.Sqrt)];
2718
2960
  },
2719
- [Primitive.Min]([x, y]) {
2720
- return [x.#binary(require_backend.AluOp.Min, y)];
2721
- },
2722
- [Primitive.Max]([x, y]) {
2723
- return [x.#binary(require_backend.AluOp.Max, y)];
2724
- },
2725
2961
  [Primitive.Reduce]([x], { op, axis }) {
2726
2962
  if (axis.length === 0) return [x];
2727
2963
  return [x.#moveAxesDown(axis).#reduce(op)];
@@ -2756,6 +2992,28 @@ var Array$1 = class Array$1 extends Tracer {
2756
2992
  y
2757
2993
  ], { dtypeOverride: [require_backend.DType.Bool] })];
2758
2994
  },
2995
+ [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2996
+ const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2997
+ if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2998
+ const c0 = zeros(shape$1, {
2999
+ dtype: require_backend.DType.Uint32,
3000
+ device: k0.device
3001
+ });
3002
+ const c1 = arange(0, require_backend.prod(shape$1), 1, {
3003
+ dtype: require_backend.DType.Uint32,
3004
+ device: k0.device
3005
+ }).reshape(shape$1);
3006
+ const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
3007
+ return [Array$1.#naryCustom("random_bits", custom, [
3008
+ k0,
3009
+ k1,
3010
+ c0,
3011
+ c1
3012
+ ])];
3013
+ },
3014
+ [Primitive.Gather]([x, ...indices], { axis, outDim }) {
3015
+ return [x.#gather(indices, axis, outDim)];
3016
+ },
2759
3017
  [Primitive.Transpose]([x], { perm }) {
2760
3018
  return [x.#transpose(perm)];
2761
3019
  },
@@ -2776,17 +3034,48 @@ var Array$1 = class Array$1 extends Tracer {
2776
3034
  [Primitive.Pad]([x], { width }) {
2777
3035
  return [x.#reshape(x.#st.pad(width))];
2778
3036
  },
2779
- [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2780
- return [x.#gather(indices, axis, outDim)];
3037
+ [Primitive.Sort]([x]) {
3038
+ const routine = new require_backend.Routine(require_backend.Routines.Sort, {
3039
+ inputShapes: [x.aval.shape],
3040
+ inputDtypes: [x.aval.dtype],
3041
+ outputShapes: [x.aval.shape],
3042
+ outputDtypes: [x.aval.dtype]
3043
+ });
3044
+ return Array$1.#routine(routine, [x], [x.#weakType]);
3045
+ },
3046
+ [Primitive.Argsort]([x]) {
3047
+ const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
3048
+ inputShapes: [x.aval.shape],
3049
+ inputDtypes: [x.aval.dtype],
3050
+ outputShapes: [x.aval.shape, x.aval.shape],
3051
+ outputDtypes: [x.aval.dtype, require_backend.DType.Int32]
3052
+ });
3053
+ return Array$1.#routine(routine, [x], [x.#weakType, false]);
3054
+ },
3055
+ [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3056
+ const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
3057
+ inputShapes: [a.aval.shape, b.aval.shape],
3058
+ inputDtypes: [a.aval.dtype, b.aval.dtype],
3059
+ outputShapes: [b.aval.shape],
3060
+ outputDtypes: [b.aval.dtype]
3061
+ }, { unitDiagonal });
3062
+ return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
2781
3063
  },
2782
- [Primitive.JitCall](args, { jaxpr, numConsts }) {
2783
- if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
2784
- const { backend, committed } = Array$1.#computeBackend("jit_call", args);
3064
+ [Primitive.Cholesky]([a]) {
3065
+ const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
3066
+ inputShapes: [a.aval.shape],
3067
+ inputDtypes: [a.aval.dtype],
3068
+ outputShapes: [a.aval.shape],
3069
+ outputDtypes: [a.aval.dtype]
3070
+ });
3071
+ return Array$1.#routine(routine, [a], [a.#weakType]);
3072
+ },
3073
+ [Primitive.Jit](args, { jaxpr }) {
3074
+ if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3075
+ const { backend, committed } = Array$1.#computeBackend("jit", args);
2785
3076
  args = args.map((ar) => ar._putSync(backend));
2786
- const consts = args.slice(0, numConsts);
2787
- const tracers = args.slice(numConsts);
2788
- const jp = jitCompile(backend, jaxpr, consts);
2789
- const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
3077
+ const jp = jitCompile(backend, jaxpr);
3078
+ const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
2790
3079
  for (const exe of pending) exe.updateRc(+outputs.length - 1);
2791
3080
  const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
2792
3081
  for (const exe of prevPending) exe.updateRc(+outputs.length);
@@ -3085,6 +3374,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
3085
3374
  });
3086
3375
  }
3087
3376
  /**
3377
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
3378
+ *
3379
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
3380
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
3381
+ * `k>0` is above it.
3382
+ */
3383
+ function tri(n, m, k = 0, { dtype, device } = {}) {
3384
+ m ??= n;
3385
+ dtype ??= require_backend.DType.Float32;
3386
+ if (!Number.isInteger(n) || n < 0) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
3387
+ if (!Number.isInteger(m) || m < 0) throw new Error(`tri: m must be a non-negative integer, got ${m}`);
3388
+ if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
3389
+ const rows = arange(k, n + k, 1, {
3390
+ dtype: require_backend.DType.Int32,
3391
+ device
3392
+ });
3393
+ const cols = arange(0, m, 1, {
3394
+ dtype: require_backend.DType.Int32,
3395
+ device
3396
+ });
3397
+ return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
3398
+ }
3399
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
3400
+ function tril(a, k = 0) {
3401
+ if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
3402
+ a = fudgeArray(a);
3403
+ const [n, m] = a.shape.slice(-2);
3404
+ return where$1(tri(n, m, k, { dtype: require_backend.DType.Bool }), a.ref, zerosLike$1(a));
3405
+ }
3406
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
3407
+ function triu(a, k = 0) {
3408
+ if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
3409
+ a = fudgeArray(a);
3410
+ const [n, m] = a.shape.slice(-2);
3411
+ return where$1(tri(n, m, k - 1, { dtype: require_backend.DType.Bool }), zerosLike$1(a.ref), a);
3412
+ }
3413
+ /**
3088
3414
  * Return evenly spaced numbers over a specified interval.
3089
3415
  *
3090
3416
  * Returns _num_ evenly spaced samples, calculated over the interval
@@ -3131,335 +3457,107 @@ function aluCompare(a, b, op) {
3131
3457
  }
3132
3458
 
3133
3459
  //#endregion
3134
- //#region src/frontend/jvp.ts
3460
+ //#region src/frontend/vmap.ts
3135
3461
  var import_usingCtx$1 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3136
- var JVPTracer = class extends Tracer {
3137
- constructor(trace$1, primal, tangent) {
3462
+ function mappedAval(batchDim, aval) {
3463
+ const shape$1 = [...aval.shape];
3464
+ shape$1.splice(batchDim, 1);
3465
+ return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3466
+ }
3467
+ /** Move one axis to a different index. */
3468
+ function moveaxis(x, src, dst) {
3469
+ const t = pureArray(x);
3470
+ src = require_backend.checkAxis(src, t.ndim);
3471
+ dst = require_backend.checkAxis(dst, t.ndim);
3472
+ if (src === dst) return t;
3473
+ const perm = require_backend.range(t.ndim);
3474
+ perm.splice(src, 1);
3475
+ perm.splice(dst, 0, src);
3476
+ return transpose$1(t, perm);
3477
+ }
3478
+ function moveBatchAxis(axisSize, src, dst, x) {
3479
+ if (src === null) {
3480
+ const targetShape = [...x.shape];
3481
+ targetShape.splice(dst, 0, axisSize);
3482
+ return broadcast(x, targetShape, [dst]);
3483
+ } else if (src === dst) return x;
3484
+ else return moveaxis(x, src, dst);
3485
+ }
3486
+ var BatchTracer = class extends Tracer {
3487
+ constructor(trace$1, val, batchDim) {
3138
3488
  super(trace$1);
3139
- this.primal = primal;
3140
- this.tangent = tangent;
3489
+ this.val = val;
3490
+ this.batchDim = batchDim;
3141
3491
  }
3142
3492
  get aval() {
3143
- return this.primal.aval;
3493
+ if (this.batchDim === null) return this.val.aval;
3494
+ else return mappedAval(this.batchDim, this.val.aval);
3144
3495
  }
3145
3496
  toString() {
3146
- return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
3497
+ return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
3147
3498
  }
3148
3499
  get ref() {
3149
- this.primal.ref, this.tangent.ref;
3500
+ this.val.ref;
3150
3501
  return this;
3151
3502
  }
3152
3503
  dispose() {
3153
- this.primal.dispose();
3154
- this.tangent.dispose();
3504
+ this.val.dispose();
3505
+ }
3506
+ fullLower() {
3507
+ if (this.batchDim === null) return this.val.fullLower();
3508
+ else return this;
3155
3509
  }
3156
3510
  };
3157
- var JVPTrace = class extends Trace {
3511
+ var BatchTrace = class extends Trace {
3158
3512
  pure(val) {
3159
3513
  return this.lift(pureArray(val));
3160
3514
  }
3161
3515
  lift(val) {
3162
- return new JVPTracer(this, val, zerosLike$1(val.ref));
3516
+ return new BatchTracer(this, val, null);
3163
3517
  }
3164
3518
  processPrimitive(primitive, tracers, params) {
3165
- const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
3166
- const jvpRule = jvpRules[primitive];
3167
- if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
3168
- const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
3169
- return require_backend.zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
3519
+ const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
3520
+ const vmapRule = vmapRules[primitive];
3521
+ if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3522
+ if (bdimsIn.every((d) => d === null)) {
3523
+ const valOuts$1 = bind(primitive, valsIn, params);
3524
+ return valOuts$1.map((x) => new BatchTracer(this, x, null));
3525
+ }
3526
+ const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3527
+ return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3528
+ }
3529
+ get axisSize() {
3530
+ return this.main.globalData;
3170
3531
  }
3171
3532
  };
3172
- /** Rule that applies the same operation to primals and tangents. */
3173
- function linearTangentsJvp(primitive) {
3174
- return (primals, tangents, params) => {
3175
- const ys = bind(primitive, primals, params);
3176
- const dys = bind(primitive, tangents, params);
3177
- return [ys, dys];
3178
- };
3179
- }
3180
- /** Rule for product of gradients in bilinear operations. */
3181
- function bilinearTangentsJvp(primitive) {
3182
- return ([x, y], [dx, dy], params) => {
3183
- const primal = bind1(primitive, [x.ref, y.ref], params);
3184
- const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
3185
- return [[primal], [tangent]];
3533
+ /**
3534
+ * Process a primitive with built-in broadcasting.
3535
+ *
3536
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3537
+ */
3538
+ function broadcastBatcher(op) {
3539
+ return (axisSize, args, dims) => {
3540
+ if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3541
+ const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3542
+ const firstIdx = dims.findIndex((d) => d !== null);
3543
+ const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3544
+ if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3545
+ args = args.map((x, i) => {
3546
+ if (dims[i] === null) return x;
3547
+ x = moveBatchAxis(axisSize, dims[i], 0, x);
3548
+ if (x.ndim < nd) x = x.reshape([
3549
+ x.shape[0],
3550
+ ...require_backend.rep(nd - x.ndim, 1),
3551
+ ...x.shape.slice(1)
3552
+ ]);
3553
+ return x;
3554
+ });
3555
+ return [[op(...args)], [0]];
3186
3556
  };
3187
3557
  }
3188
- /** Rule that zeros out any tangents. */
3189
- function zeroTangentsJvp(primitive) {
3190
- return (primals, tangents, params) => {
3191
- for (const t of tangents) t.dispose();
3192
- const ys = bind(primitive, primals, params);
3193
- return [ys, ys.map((y) => zerosLike$1(y.ref))];
3194
- };
3195
- }
3196
- const jvpRules = {
3197
- [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3198
- [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
3199
- [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
3200
- [Primitive.Mod]([x, y], [dx, dy]) {
3201
- if (!require_backend.isFloatDtype(x.dtype) && !require_backend.isFloatDtype(y.dtype)) {
3202
- dx.dispose();
3203
- dy.dispose();
3204
- return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
3205
- }
3206
- const q = idiv(x.ref, y.ref);
3207
- return [[mod(x, y)], [dx.sub(dy.mul(q))]];
3208
- },
3209
- [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
3210
- [Primitive.Reciprocal]([x], [dx]) {
3211
- const xRecip = reciprocal$1(x.ref);
3212
- return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
3213
- },
3214
- [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
3215
- [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
3216
- [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
3217
- [Primitive.Cast]([x], [dx], { dtype }) {
3218
- if (x.dtype === dtype) return [[x], [dx]];
3219
- if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
3220
- else {
3221
- dx.dispose();
3222
- return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
3223
- }
3224
- },
3225
- [Primitive.Bitcast]([x], [dx], { dtype }) {
3226
- if (x.dtype === dtype) return [[x], [dx]];
3227
- dx.dispose();
3228
- return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
3229
- },
3230
- [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3231
- [Primitive.Sin]([x], [dx]) {
3232
- return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
3233
- },
3234
- [Primitive.Cos]([x], [dx]) {
3235
- return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
3236
- },
3237
- [Primitive.Asin]([x], [dx]) {
3238
- const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3239
- return [[asin$1(x)], [denom.mul(dx)]];
3240
- },
3241
- [Primitive.Atan]([x], [dx]) {
3242
- const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3243
- return [[atan$1(x)], [dx.div(denom)]];
3244
- },
3245
- [Primitive.Exp]([x], [dx]) {
3246
- const z = exp$1(x);
3247
- return [[z.ref], [z.mul(dx)]];
3248
- },
3249
- [Primitive.Log]([x], [dx]) {
3250
- return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3251
- },
3252
- [Primitive.Erf]([x], [dx]) {
3253
- const coeff = 2 / Math.sqrt(Math.PI);
3254
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3255
- return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3256
- },
3257
- [Primitive.Erfc]([x], [dx]) {
3258
- const coeff = -2 / Math.sqrt(Math.PI);
3259
- const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3260
- return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3261
- },
3262
- [Primitive.Sqrt]([x], [dx]) {
3263
- const z = sqrt$1(x);
3264
- return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
3265
- },
3266
- [Primitive.Min]([x, y], [dx, dy]) {
3267
- return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
3268
- },
3269
- [Primitive.Max]([x, y], [dx, dy]) {
3270
- return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
3271
- },
3272
- [Primitive.Reduce]([x], [dx], { op, axis }) {
3273
- if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
3274
- else if (op === require_backend.AluOp.Mul) {
3275
- const primal = reduce(x.ref, op, axis);
3276
- const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
3277
- return [[primal], [tangent]];
3278
- } else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
3279
- const primal = reduce(x.ref, op, axis);
3280
- const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
3281
- const minCount = where$1(notMin.ref, 0, 1).sum(axis);
3282
- const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
3283
- return [[primal], [tangent]];
3284
- } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
3285
- },
3286
- [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3287
- [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3288
- [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3289
- [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
3290
- [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
3291
- [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
3292
- dcond.dispose();
3293
- return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3294
- },
3295
- [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
3296
- [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
3297
- [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
3298
- [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
3299
- [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
3300
- [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
3301
- [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3302
- const indicesRef = indices.map((t) => t.ref);
3303
- return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3304
- },
3305
- [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3306
- const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3307
- const outs = bind(Primitive.JitCall, [
3308
- ...newConsts.map((c) => c.ref),
3309
- ...primals,
3310
- ...tangents
3311
- ], {
3312
- name: `${name}_jvp`,
3313
- jaxpr: newJaxpr,
3314
- numConsts: newConsts.length
3315
- });
3316
- const n = outs.length / 2;
3317
- if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
3318
- const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
3319
- return [primalsOut, tangentsOut];
3320
- }
3321
- };
3322
- const jvpJaxprCache = /* @__PURE__ */ new Map();
3323
- function jvpJaxpr(jaxpr) {
3324
- if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
3325
- const inAvals = jaxpr.inBinders.map((v) => v.aval);
3326
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
3327
- const result = {
3328
- newJaxpr,
3329
- newConsts
3330
- };
3331
- jvpJaxprCache.set(jaxpr, result);
3332
- return result;
3333
- }
3334
- function jvpFlat(f, primals, tangents) {
3335
- try {
3336
- var _usingCtx$1 = (0, import_usingCtx$1.default)();
3337
- const main = _usingCtx$1.u(newMain(JVPTrace));
3338
- const trace$1 = new JVPTrace(main);
3339
- const tracersIn = require_backend.zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
3340
- const outs = f(...tracersIn);
3341
- const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3342
- return require_backend.unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
3343
- } catch (_) {
3344
- _usingCtx$1.e = _;
3345
- } finally {
3346
- _usingCtx$1.d();
3347
- }
3348
- }
3349
- function jvp$1(f, primals, tangents) {
3350
- const [primalsFlat, inTree] = flatten(primals);
3351
- const [tangentsFlat, inTree2] = flatten(tangents);
3352
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
3353
- const [flatFun, outTree] = flattenFun(f, inTree);
3354
- const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
3355
- if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
3356
- const primalsOut = unflatten(outTree.value, primalsOutFlat);
3357
- const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
3358
- return [primalsOut, tangentsOut];
3359
- }
3360
-
3361
- //#endregion
3362
- //#region src/frontend/vmap.ts
3363
- var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3364
- function mappedAval(batchDim, aval) {
3365
- const shape$1 = [...aval.shape];
3366
- shape$1.splice(batchDim, 1);
3367
- return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3368
- }
3369
- /** Move one axis to a different index. */
3370
- function moveaxis(x, src, dst) {
3371
- const t = pureArray(x);
3372
- src = require_backend.checkAxis(src, t.ndim);
3373
- dst = require_backend.checkAxis(dst, t.ndim);
3374
- if (src === dst) return t;
3375
- const perm = require_backend.range(t.ndim);
3376
- perm.splice(src, 1);
3377
- perm.splice(dst, 0, src);
3378
- return transpose$1(t, perm);
3379
- }
3380
- function moveBatchAxis(axisSize, src, dst, x) {
3381
- if (src === null) {
3382
- const targetShape = [...x.shape];
3383
- targetShape.splice(dst, 0, axisSize);
3384
- return broadcast(x, targetShape, [dst]);
3385
- } else if (src === dst) return x;
3386
- else return moveaxis(x, src, dst);
3387
- }
3388
- var BatchTracer = class extends Tracer {
3389
- constructor(trace$1, val, batchDim) {
3390
- super(trace$1);
3391
- this.val = val;
3392
- this.batchDim = batchDim;
3393
- }
3394
- get aval() {
3395
- if (this.batchDim === null) return this.val.aval;
3396
- else return mappedAval(this.batchDim, this.val.aval);
3397
- }
3398
- toString() {
3399
- return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
3400
- }
3401
- get ref() {
3402
- this.val.ref;
3403
- return this;
3404
- }
3405
- dispose() {
3406
- this.val.dispose();
3407
- }
3408
- fullLower() {
3409
- if (this.batchDim === null) return this.val.fullLower();
3410
- else return this;
3411
- }
3412
- };
3413
- var BatchTrace = class extends Trace {
3414
- pure(val) {
3415
- return this.lift(pureArray(val));
3416
- }
3417
- lift(val) {
3418
- return new BatchTracer(this, val, null);
3419
- }
3420
- processPrimitive(primitive, tracers, params) {
3421
- const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
3422
- const vmapRule = vmapRules[primitive];
3423
- if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3424
- if (bdimsIn.every((d) => d === null)) {
3425
- const valOuts$1 = bind(primitive, valsIn, params);
3426
- return valOuts$1.map((x) => new BatchTracer(this, x, null));
3427
- }
3428
- const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3429
- return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3430
- }
3431
- get axisSize() {
3432
- return this.main.globalData;
3433
- }
3434
- };
3435
- /**
3436
- * Process a primitive with built-in broadcasting.
3437
- *
3438
- * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3439
- */
3440
- function broadcastBatcher(op) {
3441
- return (axisSize, args, dims) => {
3442
- if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3443
- const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3444
- const firstIdx = dims.findIndex((d) => d !== null);
3445
- const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3446
- if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3447
- args = args.map((x, i) => {
3448
- if (dims[i] === null) return x;
3449
- x = moveBatchAxis(axisSize, dims[i], 0, x);
3450
- if (x.ndim < nd) x = x.reshape([
3451
- x.shape[0],
3452
- ...require_backend.rep(nd - x.ndim, 1),
3453
- ...x.shape.slice(1)
3454
- ]);
3455
- return x;
3456
- });
3457
- return [[op(...args)], [0]];
3458
- };
3459
- }
3460
- function unopBatcher(op) {
3461
- return (axisSize, [x], [xBdim], params) => {
3462
- return [[op(x, params)], [xBdim]];
3558
+ function unopBatcher(op) {
3559
+ return (axisSize, [x], [xBdim], params) => {
3560
+ return [[op(x, params)], [xBdim]];
3463
3561
  };
3464
3562
  }
3465
3563
  const vmapRules = {
@@ -3467,6 +3565,8 @@ const vmapRules = {
3467
3565
  [Primitive.Mul]: broadcastBatcher(mul),
3468
3566
  [Primitive.Idiv]: broadcastBatcher(idiv),
3469
3567
  [Primitive.Mod]: broadcastBatcher(mod),
3568
+ [Primitive.Min]: broadcastBatcher(min$1),
3569
+ [Primitive.Max]: broadcastBatcher(max$1),
3470
3570
  [Primitive.Neg]: unopBatcher(neg),
3471
3571
  [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3472
3572
  [Primitive.Floor]: unopBatcher(floor$1),
@@ -3483,8 +3583,6 @@ const vmapRules = {
3483
3583
  [Primitive.Erf]: unopBatcher(erf$1),
3484
3584
  [Primitive.Erfc]: unopBatcher(erfc$1),
3485
3585
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
3486
- [Primitive.Min]: broadcastBatcher(min$1),
3487
- [Primitive.Max]: broadcastBatcher(max$1),
3488
3586
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3489
3587
  require_backend.assertNonNull(xBdim);
3490
3588
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
@@ -3497,10 +3595,49 @@ const vmapRules = {
3497
3595
  const z = dot$2(x, y);
3498
3596
  return [[z], [z.ndim - 1]];
3499
3597
  },
3598
+ [Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
3599
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3600
+ y = moveBatchAxis(axisSize, yBdim, 0, y);
3601
+ const z = conv$1(x, y, {
3602
+ ...params,
3603
+ vmapDims: params.vmapDims + 1
3604
+ });
3605
+ return [[z], [0]];
3606
+ },
3500
3607
  [Primitive.Compare](axisSize, args, dims, { op }) {
3501
3608
  return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3502
3609
  },
3503
- [Primitive.Where]: broadcastBatcher(where$1),
3610
+ [Primitive.Where]: broadcastBatcher(where$1),
3611
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3612
+ if (indicesBdim.every((d) => d === null)) {
3613
+ require_backend.assertNonNull(xBdim);
3614
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3615
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3616
+ let newOutDim = outDim;
3617
+ if (newOutDim < newBdim) newBdim += axis.length;
3618
+ else newOutDim += 1;
3619
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3620
+ }
3621
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3622
+ indices = indices.map((m, i) => {
3623
+ if (indicesBdim[i] === null) return m;
3624
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3625
+ if (m.ndim < nd) m = m.reshape([
3626
+ m.shape[0],
3627
+ ...require_backend.rep(nd - m.ndim, 1),
3628
+ ...m.shape.slice(1)
3629
+ ]);
3630
+ return m;
3631
+ });
3632
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3633
+ else {
3634
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3635
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3636
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
3637
+ indices.splice(0, 0, extraBatchIndex);
3638
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3639
+ }
3640
+ },
3504
3641
  [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3505
3642
  require_backend.assertNonNull(xBdim);
3506
3643
  const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
@@ -3532,42 +3669,53 @@ const vmapRules = {
3532
3669
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3533
3670
  return [[pad$1(x, newWidth)], [xBdim]];
3534
3671
  },
3535
- [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3536
- if (indicesBdim.every((d) => d === null)) {
3537
- require_backend.assertNonNull(xBdim);
3538
- const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3539
- let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3540
- let newOutDim = outDim;
3541
- if (newOutDim < newBdim) newBdim += axis.length;
3542
- else newOutDim += 1;
3543
- return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3544
- }
3545
- const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3546
- indices = indices.map((m, i) => {
3547
- if (indicesBdim[i] === null) return m;
3548
- m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3549
- if (m.ndim < nd) m = m.reshape([
3550
- m.shape[0],
3551
- ...require_backend.rep(nd - m.ndim, 1),
3552
- ...m.shape.slice(1)
3672
+ [Primitive.Sort](axisSize, [x], [xBdim]) {
3673
+ require_backend.assertNonNull(xBdim);
3674
+ if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
3675
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3676
+ return [[sort$1(x)], [0]];
3677
+ },
3678
+ [Primitive.Argsort](axisSize, [x], [xBdim]) {
3679
+ require_backend.assertNonNull(xBdim);
3680
+ if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
3681
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3682
+ return [argsort$1(x), [0, 0]];
3683
+ },
3684
+ [Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
3685
+ if (aBdim === null) {
3686
+ b = moveBatchAxis(axisSize, bBdim, -3, b);
3687
+ const [s, m, n] = b.shape.slice(-3);
3688
+ b = b.reshape([
3689
+ ...b.shape.slice(0, -3),
3690
+ s * m,
3691
+ n
3553
3692
  ]);
3554
- return m;
3555
- });
3556
- if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3557
- else {
3558
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3559
- const newAxis = [0, ...axis.map((ax) => ax + 1)];
3560
- const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
3561
- indices.splice(0, 0, extraBatchIndex);
3562
- return [[gather(x, indices, newAxis, outDim)], [outDim]];
3693
+ let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3694
+ x$1 = x$1.reshape([
3695
+ ...b.shape.slice(0, -2),
3696
+ s,
3697
+ m,
3698
+ n
3699
+ ]);
3700
+ return [[x$1], [x$1.ndim - 3]];
3563
3701
  }
3702
+ a = moveBatchAxis(axisSize, aBdim, 0, a);
3703
+ b = moveBatchAxis(axisSize, bBdim, 0, b);
3704
+ const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3705
+ return [[x], [0]];
3564
3706
  },
3565
- [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3566
- const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3567
- const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3707
+ [Primitive.Cholesky](axisSize, [x], [xBdim]) {
3708
+ require_backend.assertNonNull(xBdim);
3709
+ if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
3710
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3711
+ return [[cholesky$2(x)], [0]];
3712
+ },
3713
+ [Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
3714
+ const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
3715
+ const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
3568
3716
  name: `${name}_vmap`,
3569
- jaxpr: newJaxpr,
3570
- numConsts: newConsts.length
3717
+ jaxpr: newJaxpr.jaxpr,
3718
+ numConsts: newJaxpr.consts.length
3571
3719
  });
3572
3720
  return [outs, require_backend.rep(outs.length, 0)];
3573
3721
  }
@@ -3583,14 +3731,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3583
3731
  shape$1.splice(dims[i], 0, axisSize);
3584
3732
  return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3585
3733
  });
3586
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3587
- const result = {
3588
- newJaxpr,
3589
- newConsts
3590
- };
3734
+ const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3591
3735
  if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
3592
- vmapJaxprCache.get(jaxpr).set(cacheKey, result);
3593
- return result;
3736
+ vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
3737
+ return newJaxpr;
3594
3738
  }
3595
3739
  function vmapFlat(f, inAxes, args) {
3596
3740
  let axisSize = void 0;
@@ -3604,7 +3748,7 @@ function vmapFlat(f, inAxes, args) {
3604
3748
  if (axisSize === void 0) throw new TypeError("vmap requires at least one mapped axis");
3605
3749
  let valsOut, bdimsOut;
3606
3750
  try {
3607
- var _usingCtx$1 = (0, import_usingCtx.default)();
3751
+ var _usingCtx$1 = (0, import_usingCtx$1.default)();
3608
3752
  const main = _usingCtx$1.u(newMain(BatchTrace, axisSize));
3609
3753
  const trace$1 = new BatchTrace(main);
3610
3754
  const tracersIn = args.map((x, i) => inAxes[i] === null ? pureArray(x) : new BatchTracer(trace$1, pureArray(x), inAxes[i]));
@@ -3645,6 +3789,261 @@ function jacfwd$1(f) {
3645
3789
  };
3646
3790
  }
3647
3791
 
3792
+ //#endregion
3793
+ //#region src/frontend/jvp.ts
3794
+ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3795
+ var JVPTracer = class extends Tracer {
3796
+ constructor(trace$1, primal, tangent) {
3797
+ super(trace$1);
3798
+ this.primal = primal;
3799
+ this.tangent = tangent;
3800
+ }
3801
+ get aval() {
3802
+ return this.primal.aval;
3803
+ }
3804
+ toString() {
3805
+ return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
3806
+ }
3807
+ get ref() {
3808
+ this.primal.ref, this.tangent.ref;
3809
+ return this;
3810
+ }
3811
+ dispose() {
3812
+ this.primal.dispose();
3813
+ this.tangent.dispose();
3814
+ }
3815
+ };
3816
+ var JVPTrace = class extends Trace {
3817
+ pure(val) {
3818
+ return this.lift(pureArray(val));
3819
+ }
3820
+ lift(val) {
3821
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
3822
+ }
3823
+ processPrimitive(primitive, tracers, params) {
3824
+ const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
3825
+ const jvpRule = jvpRules[primitive];
3826
+ if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
3827
+ const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
3828
+ return require_backend.zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
3829
+ }
3830
+ };
3831
+ /** Rule that applies the same operation to primals and tangents. */
3832
+ function linearTangentsJvp(primitive) {
3833
+ return (primals, tangents, params) => {
3834
+ const ys = bind(primitive, primals, params);
3835
+ const dys = bind(primitive, tangents, params);
3836
+ return [ys, dys];
3837
+ };
3838
+ }
3839
+ /** Rule for product of gradients in bilinear operations. */
3840
+ function bilinearTangentsJvp(primitive) {
3841
+ return ([x, y], [dx, dy], params) => {
3842
+ const primal = bind1(primitive, [x.ref, y.ref], params);
3843
+ const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
3844
+ return [[primal], [tangent]];
3845
+ };
3846
+ }
3847
+ /** Rule that zeros out any tangents. */
3848
+ function zeroTangentsJvp(primitive) {
3849
+ return (primals, tangents, params) => {
3850
+ for (const t of tangents) t.dispose();
3851
+ const ys = bind(primitive, primals, params);
3852
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
3853
+ };
3854
+ }
3855
+ /** Compute `a @ b.T`, batched to last two axes. */
3856
+ function batchMatmulT(a, b) {
3857
+ return dot$2(a.reshape(a.shape.toSpliced(-1, 0, 1)), b.reshape(b.shape.toSpliced(-2, 0, 1)));
3858
+ }
3859
+ /** Batch matrix transpose. */
3860
+ function mT(a) {
3861
+ return moveaxis(a, -2, -1);
3862
+ }
3863
+ const jvpRules = {
3864
+ [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3865
+ [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
3866
+ [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
3867
+ [Primitive.Mod]([x, y], [dx, dy]) {
3868
+ if (!require_backend.isFloatDtype(x.dtype) && !require_backend.isFloatDtype(y.dtype)) {
3869
+ dx.dispose();
3870
+ dy.dispose();
3871
+ return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
3872
+ }
3873
+ const q = idiv(x.ref, y.ref);
3874
+ return [[mod(x, y)], [dx.sub(dy.mul(q))]];
3875
+ },
3876
+ [Primitive.Min]([x, y], [dx, dy]) {
3877
+ return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
3878
+ },
3879
+ [Primitive.Max]([x, y], [dx, dy]) {
3880
+ return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
3881
+ },
3882
+ [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
3883
+ [Primitive.Reciprocal]([x], [dx]) {
3884
+ const xRecip = reciprocal$1(x.ref);
3885
+ return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
3886
+ },
3887
+ [Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
3888
+ [Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
3889
+ [Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
3890
+ [Primitive.Cast]([x], [dx], { dtype }) {
3891
+ if (x.dtype === dtype) return [[x], [dx]];
3892
+ if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
3893
+ else {
3894
+ dx.dispose();
3895
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
3896
+ }
3897
+ },
3898
+ [Primitive.Bitcast]([x], [dx], { dtype }) {
3899
+ if (x.dtype === dtype) return [[x], [dx]];
3900
+ dx.dispose();
3901
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
3902
+ },
3903
+ [Primitive.Sin]([x], [dx]) {
3904
+ return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
3905
+ },
3906
+ [Primitive.Cos]([x], [dx]) {
3907
+ return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
3908
+ },
3909
+ [Primitive.Asin]([x], [dx]) {
3910
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3911
+ return [[asin$1(x)], [denom.mul(dx)]];
3912
+ },
3913
+ [Primitive.Atan]([x], [dx]) {
3914
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3915
+ return [[atan$1(x)], [dx.div(denom)]];
3916
+ },
3917
+ [Primitive.Exp]([x], [dx]) {
3918
+ const z = exp$1(x);
3919
+ return [[z.ref], [z.mul(dx)]];
3920
+ },
3921
+ [Primitive.Log]([x], [dx]) {
3922
+ return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3923
+ },
3924
+ [Primitive.Erf]([x], [dx]) {
3925
+ const coeff = 2 / Math.sqrt(Math.PI);
3926
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3927
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3928
+ },
3929
+ [Primitive.Erfc]([x], [dx]) {
3930
+ const coeff = -2 / Math.sqrt(Math.PI);
3931
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3932
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3933
+ },
3934
+ [Primitive.Sqrt]([x], [dx]) {
3935
+ const z = sqrt$1(x);
3936
+ return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
3937
+ },
3938
+ [Primitive.Reduce]([x], [dx], { op, axis }) {
3939
+ if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
3940
+ else if (op === require_backend.AluOp.Mul) {
3941
+ const primal = reduce(x.ref, op, axis);
3942
+ const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
3943
+ return [[primal], [tangent]];
3944
+ } else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
3945
+ const primal = reduce(x.ref, op, axis);
3946
+ const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
3947
+ const minCount = where$1(notMin.ref, 0, 1).sum(axis);
3948
+ const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
3949
+ return [[primal], [tangent]];
3950
+ } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
3951
+ },
3952
+ [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3953
+ [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3954
+ [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3955
+ [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
3956
+ [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
3957
+ [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
3958
+ dcond.dispose();
3959
+ return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3960
+ },
3961
+ [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3962
+ [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3963
+ const indicesRef = indices.map((t) => t.ref);
3964
+ return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3965
+ },
3966
+ [Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
3967
+ [Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
3968
+ [Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
3969
+ [Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
3970
+ [Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
3971
+ [Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
3972
+ [Primitive.Sort]([x], [dx]) {
3973
+ const [y, idx] = argsort$1(x);
3974
+ return [[y], [gather(dx, [idx], [-1], -1)]];
3975
+ },
3976
+ [Primitive.Argsort]([x], [dx]) {
3977
+ const [y, idx] = argsort$1(x);
3978
+ return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
3979
+ },
3980
+ [Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
3981
+ const x = triangularSolve$1(a.ref, b, { unitDiagonal });
3982
+ const dax = batchMatmulT(da, x.ref);
3983
+ const rhsT = db.sub(mT(dax));
3984
+ const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
3985
+ return [[x], [dx]];
3986
+ },
3987
+ [Primitive.Cholesky]([a], [da]) {
3988
+ const L = cholesky$2(a.ref);
3989
+ da = da.ref.add(mT(da)).mul(.5);
3990
+ const W = triangularSolve$1(L.ref, da, { lower: true });
3991
+ const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
3992
+ const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
3993
+ return [[L], [dL]];
3994
+ },
3995
+ [Primitive.Jit](primals, tangents, { name, jaxpr }) {
3996
+ const newJaxpr = jvpJaxpr(jaxpr);
3997
+ const outs = bind(Primitive.Jit, [
3998
+ ...newJaxpr.consts.map((c) => c.ref),
3999
+ ...primals,
4000
+ ...tangents
4001
+ ], {
4002
+ name: `${name}_jvp`,
4003
+ jaxpr: newJaxpr.jaxpr,
4004
+ numConsts: newJaxpr.consts.length
4005
+ });
4006
+ const n = outs.length / 2;
4007
+ if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
4008
+ const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
4009
+ return [primalsOut, tangentsOut];
4010
+ }
4011
+ };
4012
+ const jvpJaxprCache = /* @__PURE__ */ new Map();
4013
+ function jvpJaxpr(jaxpr) {
4014
+ if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
4015
+ const inAvals = jaxpr.inBinders.map((v) => v.aval);
4016
+ const { jaxpr: newJaxpr } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
4017
+ jvpJaxprCache.set(jaxpr, newJaxpr);
4018
+ return newJaxpr;
4019
+ }
4020
+ function jvpFlat(f, primals, tangents) {
4021
+ try {
4022
+ var _usingCtx$1 = (0, import_usingCtx.default)();
4023
+ const main = _usingCtx$1.u(newMain(JVPTrace));
4024
+ const trace$1 = new JVPTrace(main);
4025
+ const tracersIn = require_backend.zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
4026
+ const outs = f(...tracersIn);
4027
+ const tracersOut = outs.map((out) => fullRaise(trace$1, out));
4028
+ return require_backend.unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
4029
+ } catch (_) {
4030
+ _usingCtx$1.e = _;
4031
+ } finally {
4032
+ _usingCtx$1.d();
4033
+ }
4034
+ }
4035
+ function jvp$1(f, primals, tangents) {
4036
+ const [primalsFlat, inTree] = flatten(primals);
4037
+ const [tangentsFlat, inTree2] = flatten(tangents);
4038
+ if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
4039
+ const [flatFun, outTree] = flattenFun(f, inTree);
4040
+ const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
4041
+ if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
4042
+ const primalsOut = unflatten(outTree.value, primalsOutFlat);
4043
+ const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
4044
+ return [primalsOut, tangentsOut];
4045
+ }
4046
+
3648
4047
  //#endregion
3649
4048
  //#region src/frontend/linearize.ts
3650
4049
  /** Array value that can either be known or unknown. */
@@ -3675,11 +4074,10 @@ function partialEvalFlat(f, pvalsIn) {
3675
4074
  const tracersOut = outs.map((out) => fullRaise(trace$1, out));
3676
4075
  const pvalsOut = tracersOut.map((t) => t.pval);
3677
4076
  const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
3678
- const { jaxpr, consts } = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
4077
+ const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
3679
4078
  return {
3680
4079
  jaxpr,
3681
- pvalsOut,
3682
- consts
4080
+ pvalsOut
3683
4081
  };
3684
4082
  }
3685
4083
  /**
@@ -3696,22 +4094,19 @@ function linearizeFlatUtil(f, primalsIn) {
3696
4094
  const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
3697
4095
  return [...primalsOut$1, ...tangentsOut];
3698
4096
  };
3699
- const { jaxpr, pvalsOut, consts } = partialEvalFlat(fJvp, pvalsIn);
4097
+ const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
3700
4098
  const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
3701
4099
  if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
3702
4100
  const primalsOut = primalPvals.map((pval) => pval.val);
3703
4101
  return {
3704
4102
  primalsOut,
3705
- jaxpr,
3706
- consts
4103
+ jaxpr
3707
4104
  };
3708
4105
  }
3709
4106
  function linearizeFlat(f, primalsIn) {
3710
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
3711
- const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3712
- const dispose$1 = () => {
3713
- for (const c of consts) c.dispose();
3714
- };
4107
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4108
+ const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
4109
+ const dispose$1 = () => jaxpr.dispose();
3715
4110
  return [
3716
4111
  primalsOut,
3717
4112
  fLin,
@@ -3795,7 +4190,7 @@ var PartialEvalTrace = class extends Trace {
3795
4190
  }
3796
4191
  processPrimitive(primitive, tracers, params) {
3797
4192
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3798
- if (primitive === Primitive.JitCall) {
4193
+ if (primitive === Primitive.Jit) {
3799
4194
  const { name, jaxpr, numConsts } = params;
3800
4195
  return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3801
4196
  }
@@ -3821,14 +4216,14 @@ var PartialEvalTrace = class extends Trace {
3821
4216
  * Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
3822
4217
  * values as possible (with JIT) and forwarding the unknown ones.
3823
4218
  *
3824
- * Used when encountering a JitCall rule during the trace.
4219
+ * Used when encountering a Jit rule during the trace.
3825
4220
  */
3826
4221
  #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3827
4222
  jaxpr = jaxpr.flatten();
3828
4223
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3829
4224
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3830
4225
  const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
3831
- const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
4226
+ const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
3832
4227
  name: `${name}_peval`,
3833
4228
  jaxpr: jaxpr1,
3834
4229
  numConsts: 0
@@ -3838,7 +4233,7 @@ var PartialEvalTrace = class extends Trace {
3838
4233
  const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
3839
4234
  const recipe = {
3840
4235
  type: "JaxprEqn",
3841
- prim: Primitive.JitCall,
4236
+ prim: Primitive.Jit,
3842
4237
  tracersIn: resTracers.concat(unknownTracers),
3843
4238
  params: {
3844
4239
  name: `${name}_resid`,
@@ -3867,7 +4262,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
3867
4262
  const eqns1 = [];
3868
4263
  const eqns2 = [];
3869
4264
  for (const eqn of jaxpr.eqns) {
3870
- if (eqn.primitive === Primitive.JitCall) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
4265
+ if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
3871
4266
  const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
3872
4267
  if (hasUnknowns) {
3873
4268
  for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
@@ -3941,11 +4336,8 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
3941
4336
  for (const t of tracersIn) t.dispose();
3942
4337
  for (const t of tracersOut) t.dispose();
3943
4338
  jaxpr = jaxpr.simplify();
3944
- if (require_backend.DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
3945
- return {
3946
- jaxpr,
3947
- consts
3948
- };
4339
+ if (require_backend.DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
4340
+ return new ClosedJaxpr(jaxpr, consts);
3949
4341
  }
3950
4342
  /** Marker type for pullback, used by transpose rules. */
3951
4343
  var UndefPrimal = class {
@@ -4075,22 +4467,25 @@ const transposeRules = {
4075
4467
  },
4076
4468
  [Primitive.Conv]([ct], [lhs, rhs], params) {
4077
4469
  if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
4470
+ const v = params.vmapDims;
4078
4471
  const rev01 = [
4079
- 1,
4080
- 0,
4081
- ...require_backend.range(2, ct.ndim)
4472
+ ...require_backend.range(v),
4473
+ v + 1,
4474
+ v,
4475
+ ...require_backend.range(v + 2, ct.ndim)
4082
4476
  ];
4083
4477
  if (lhs instanceof UndefPrimal) {
4084
4478
  let kernel = rhs;
4085
4479
  kernel = transpose$1(kernel, rev01);
4086
- kernel = flip$1(kernel, require_backend.range(2, kernel.ndim));
4480
+ kernel = flip$1(kernel, require_backend.range(v + 2, kernel.ndim));
4087
4481
  const result = conv$1(ct, kernel, {
4482
+ vmapDims: v,
4088
4483
  strides: params.lhsDilation,
4089
4484
  padding: params.padding.map(([pl, _pr], i) => {
4090
- const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
4091
- const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
4485
+ const dilatedKernel = (kernel.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
4486
+ const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
4092
4487
  const padBefore = dilatedKernel - 1 - pl;
4093
- const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
4488
+ const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
4094
4489
  const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
4095
4490
  return [padBefore, padAfter];
4096
4491
  }),
@@ -4102,11 +4497,12 @@ const transposeRules = {
4102
4497
  const newLhs = transpose$1(lhs, rev01);
4103
4498
  const newRhs = transpose$1(ct, rev01);
4104
4499
  let result = conv$1(newLhs, newRhs, {
4500
+ vmapDims: v,
4105
4501
  strides: params.rhsDilation,
4106
4502
  padding: params.padding.map(([pl, _pr], i) => {
4107
- const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
4108
- const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
4109
- const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
4503
+ const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
4504
+ const dilatedKernel = (rhs.aval.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
4505
+ const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
4110
4506
  const padFromLhs = dilatedCt - dilatedLhs;
4111
4507
  const padFromRhs = dilatedKernel - pl - 1;
4112
4508
  return [pl, padFromLhs + padFromRhs];
@@ -4133,6 +4529,11 @@ const transposeRules = {
4133
4529
  cond.dispose();
4134
4530
  return cts;
4135
4531
  },
4532
+ [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4533
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4534
+ if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4535
+ throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
4536
+ },
4136
4537
  [Primitive.Transpose]([ct], [x], { perm }) {
4137
4538
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
4138
4539
  return [transpose$1(ct, require_backend.invertPermutation(perm))];
@@ -4159,23 +4560,26 @@ const transposeRules = {
4159
4560
  const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
4160
4561
  return [shrink(ct, slice)];
4161
4562
  },
4162
- [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4163
- if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4164
- if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4165
- throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
4563
+ [Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
4564
+ if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
4565
+ const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
4566
+ lower: true,
4567
+ unitDiagonal
4568
+ });
4569
+ return [null, ctB];
4166
4570
  },
4167
- [Primitive.JitCall](cts, args, { name, jaxpr }) {
4571
+ [Primitive.Jit](cts, args, { name, jaxpr }) {
4168
4572
  const undefPrimals = args.map((x) => x instanceof UndefPrimal);
4169
- const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
4573
+ const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
4170
4574
  const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
4171
- const outs = bind(Primitive.JitCall, [
4172
- ...newConsts.map((c) => c.ref),
4575
+ const outs = bind(Primitive.Jit, [
4576
+ ...newJaxpr.consts.map((c) => c.ref),
4173
4577
  ...residuals,
4174
4578
  ...cts
4175
4579
  ], {
4176
4580
  name: `${name}_t`,
4177
- jaxpr: newJaxpr,
4178
- numConsts: newConsts.length
4581
+ jaxpr: newJaxpr.jaxpr,
4582
+ numConsts: newJaxpr.consts.length
4179
4583
  });
4180
4584
  let i = 0;
4181
4585
  return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
@@ -4188,31 +4592,25 @@ function transposeJaxpr(jaxpr, undefPrimals) {
4188
4592
  if (prevResult) return prevResult;
4189
4593
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
4190
4594
  const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
4191
- const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((forwardIn, cotangents) => {
4595
+ const { jaxpr: newJaxpr } = makeJaxpr$1((forwardIn, cotangents) => {
4192
4596
  const args = [];
4193
4597
  let forwardInIdx = 0;
4194
4598
  for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
4195
4599
  else args.push(forwardIn[forwardInIdx++]);
4196
4600
  return evalJaxprTransposed(jaxpr, args, cotangents);
4197
4601
  })(forwardInTypes, outTypes);
4198
- typecheckJaxpr(newJaxpr);
4199
- const result = {
4200
- newJaxpr,
4201
- newConsts
4202
- };
4602
+ typecheckJaxpr(newJaxpr.jaxpr);
4203
4603
  if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
4204
- transposeJaxprCache.get(jaxpr).set(cacheKey, result);
4205
- return result;
4604
+ transposeJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
4605
+ return newJaxpr;
4206
4606
  }
4207
4607
  function vjpFlat(f, primalsIn) {
4208
- const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
4608
+ const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
4209
4609
  const fVjp = (...cotangents) => {
4210
- const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4211
- return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
4212
- };
4213
- const dispose$1 = () => {
4214
- for (const c of consts) c.dispose();
4610
+ const transposeInputs = [...jaxpr.consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
4611
+ return evalJaxprTransposed(jaxpr.jaxpr, transposeInputs, cotangents);
4215
4612
  };
4613
+ const dispose$1 = () => jaxpr.dispose();
4216
4614
  return [
4217
4615
  primalsOut,
4218
4616
  fVjp,
@@ -4269,150 +4667,6 @@ function jacrev$1(f) {
4269
4667
  };
4270
4668
  }
4271
4669
 
4272
- //#endregion
4273
- //#region src/library/lax.ts
4274
- var lax_exports = {};
4275
- __export(lax_exports, {
4276
- conv: () => conv,
4277
- convGeneralDilated: () => convGeneralDilated,
4278
- convWithGeneralPadding: () => convWithGeneralPadding,
4279
- dot: () => dot$1,
4280
- erf: () => erf,
4281
- erfc: () => erfc,
4282
- reduceWindow: () => reduceWindow,
4283
- stopGradient: () => stopGradient$1
4284
- });
4285
- /**
4286
- * General dot product/contraction operator.
4287
- *
4288
- * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
4289
- * `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
4290
- */
4291
- function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
4292
- if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
4293
- else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
4294
- lc = lc.map((a) => require_backend.checkAxis(a, lhs.ndim));
4295
- rc = rc.map((a) => require_backend.checkAxis(a, rhs.ndim));
4296
- lb = lb.map((a) => require_backend.checkAxis(a, lhs.ndim));
4297
- rb = rb.map((a) => require_backend.checkAxis(a, rhs.ndim));
4298
- if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
4299
- else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
4300
- const lf = require_backend.range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
4301
- const rf = require_backend.range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
4302
- const lhs2 = lhs.transpose([
4303
- ...lb,
4304
- ...lf,
4305
- ...lc
4306
- ]);
4307
- const rhs2 = rhs.transpose([
4308
- ...rb,
4309
- ...rf,
4310
- ...rc
4311
- ]);
4312
- if (lc.length === 0) return mul(lhs2.reshape([
4313
- ...lb.map((a) => lhs.shape[a]),
4314
- ...lf.map((a) => lhs.shape[a]),
4315
- ...require_backend.rep(rf.length, 1)
4316
- ]), rhs2.reshape([
4317
- ...rb.map((a) => rhs.shape[a]),
4318
- ...require_backend.rep(lf.length, 1),
4319
- ...rf.map((a) => rhs.shape[a])
4320
- ]));
4321
- const dotShapeX = lc.map((a) => lhs.shape[a]);
4322
- const dotShapeY = rc.map((a) => rhs.shape[a]);
4323
- if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
4324
- return dot$2(lhs2.reshape([
4325
- ...lb.map((a) => lhs.shape[a]),
4326
- ...lf.map((a) => lhs.shape[a]),
4327
- ...require_backend.rep(rf.length, 1),
4328
- require_backend.prod(dotShapeX)
4329
- ]), rhs2.reshape([
4330
- ...rb.map((a) => rhs.shape[a]),
4331
- ...require_backend.rep(lf.length, 1),
4332
- ...rf.map((a) => rhs.shape[a]),
4333
- require_backend.prod(dotShapeY)
4334
- ]));
4335
- }
4336
- function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
4337
- const padType = padding.toUpperCase();
4338
- switch (padType) {
4339
- case "VALID": return require_backend.rep(inShape.length, [0, 0]);
4340
- case "SAME":
4341
- case "SAME_LOWER": {
4342
- const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
4343
- const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
4344
- if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
4345
- else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
4346
- }
4347
- default: throw new Error(`Unknown padding type: ${padType}`);
4348
- }
4349
- }
4350
- /**
4351
- * General n-dimensional convolution operator, with optional dilation.
4352
- *
4353
- * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
4354
- * function in JAX, which wraps XLA's general convolution operator.
4355
- *
4356
- * Grouped convolutions are not supported right now.
4357
- */
4358
- function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
4359
- if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
4360
- if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
4361
- if (typeof padding === "string") {
4362
- if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
4363
- padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
4364
- }
4365
- return conv$1(lhs, rhs, {
4366
- strides: windowStrides,
4367
- padding,
4368
- lhsDilation,
4369
- rhsDilation
4370
- });
4371
- }
4372
- /** Convenience wrapper around `convGeneralDilated`. */
4373
- function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
4374
- return convGeneralDilated(lhs, rhs, windowStrides, padding, {
4375
- lhsDilation,
4376
- rhsDilation
4377
- });
4378
- }
4379
- /** Convenience wrapper around `convGeneralDilated`. */
4380
- function conv(lhs, rhs, windowStrides, padding) {
4381
- return convGeneralDilated(lhs, rhs, windowStrides, padding);
4382
- }
4383
- /** Reduce a computation over padded windows. */
4384
- function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4385
- if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
4386
- if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
4387
- for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
4388
- return computation(bind1(Primitive.Pool, [operand], {
4389
- window: windowDimensions,
4390
- strides: windowStrides
4391
- }));
4392
- }
4393
- /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
4394
- function erf(x) {
4395
- return erf$1(x);
4396
- }
4397
- /**
4398
- * The complementary error function: `erfc(x) = 1 - erf(x)`.
4399
- *
4400
- * This function is more accurate than `1 - erf(x)` for large values of `x`,
4401
- * where `erf(x)` is very close to 1.
4402
- */
4403
- function erfc(x) {
4404
- return erfc$1(x);
4405
- }
4406
- /**
4407
- * Stops gradient computation.
4408
- *
4409
- * Behaves as the identity function but prevents the flow of gradients during
4410
- * forward or reverse-mode automatic differentiation.
4411
- */
4412
- function stopGradient$1(x) {
4413
- return stopGradient(x);
4414
- }
4415
-
4416
4670
  //#endregion
4417
4671
  //#region src/library/numpy/einsum.ts
4418
4672
  const bprod = (...xs) => xs.reduce((acc, x) => acc * BigInt(x), 1n);
@@ -4608,34 +4862,207 @@ function* allPaths(tensors, next) {
4608
4862
  }
4609
4863
  }
4610
4864
 
4865
+ //#endregion
4866
+ //#region src/library/numpy-fft.ts
4867
+ var numpy_fft_exports = {};
4868
+ __export(numpy_fft_exports, {
4869
+ fft: () => fft,
4870
+ ifft: () => ifft
4871
+ });
4872
+ function checkPairInput(name, a) {
4873
+ const fullName = `jax.numpy.fft.${name}`;
4874
+ if (!require_backend.deepEqual(a.real.shape, a.imag.shape)) throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(a.real.shape)} and ${JSON.stringify(a.imag.shape)}`);
4875
+ if (a.real.dtype !== a.imag.dtype) throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${a.real.dtype} and ${a.imag.dtype}`);
4876
+ if (!require_backend.isFloatDtype(a.real.dtype)) throw new Error(`${fullName}: input must have a float dtype, got ${a.real.dtype}`);
4877
+ }
4878
+ function checkPowerOfTwo(name, n) {
4879
+ if ((n & n - 1) !== 0) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
4880
+ }
4881
+ const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
4882
+ const half = 2 ** i;
4883
+ real = real.reshape([-1, 2 * half]);
4884
+ imag = imag.reshape([-1, 2 * half]);
4885
+ const k = arange(0, half, 1, { dtype: real.dtype });
4886
+ const theta = k.mul(-Math.PI / half);
4887
+ const wr = cos(theta.ref);
4888
+ const wi = sin(theta);
4889
+ const ur = real.ref.slice([], [0, half]);
4890
+ const ui = imag.ref.slice([], [0, half]);
4891
+ const vr = real.slice([], [half, 2 * half]);
4892
+ const vi = imag.slice([], [half, 2 * half]);
4893
+ const tr = vr.ref.mul(wr.ref).sub(vi.ref.mul(wi.ref));
4894
+ const ti = vr.mul(wi).add(vi.mul(wr));
4895
+ return {
4896
+ real: concatenate([ur.ref.add(tr.ref), ur.sub(tr)], -1),
4897
+ imag: concatenate([ui.ref.add(ti.ref), ui.sub(ti)], -1)
4898
+ };
4899
+ }, { staticArgnums: [0] });
4900
+ /**
4901
+ * Compute a one-dimensional discrete Fourier transform.
4902
+ *
4903
+ * Currently, the size of the axis must be a power of two.
4904
+ */
4905
+ function fft(a, axis = -1) {
4906
+ checkPairInput("fft", a);
4907
+ let { real, imag } = a;
4908
+ axis = require_backend.checkAxis(axis, real.ndim);
4909
+ const n = real.shape[axis];
4910
+ checkPowerOfTwo("fft", n);
4911
+ const logN = Math.log2(n);
4912
+ let perm = null;
4913
+ if (axis !== real.ndim - 1) {
4914
+ perm = require_backend.range(real.ndim);
4915
+ perm.splice(axis, 1);
4916
+ perm.push(axis);
4917
+ real = real.transpose(perm);
4918
+ imag = imag.transpose(perm);
4919
+ }
4920
+ const originalShape = real.shape;
4921
+ real = real.reshape([-1, ...require_backend.rep(logN, 2)]).transpose([0, ...require_backend.range(1, logN + 1).reverse()]).flatten();
4922
+ imag = imag.reshape([-1, ...require_backend.rep(logN, 2)]).transpose([0, ...require_backend.range(1, logN + 1).reverse()]).flatten();
4923
+ for (let i = 0; i < logN; i++) ({real, imag} = fftUpdate(i, {
4924
+ real,
4925
+ imag
4926
+ }));
4927
+ real = real.reshape(originalShape);
4928
+ imag = imag.reshape(originalShape);
4929
+ if (perm !== null) {
4930
+ real = real.transpose(require_backend.invertPermutation(perm));
4931
+ imag = imag.transpose(require_backend.invertPermutation(perm));
4932
+ }
4933
+ return {
4934
+ real,
4935
+ imag
4936
+ };
4937
+ }
4938
+ /**
4939
+ * Compute a one-dimensional inverse discrete Fourier transform.
4940
+ *
4941
+ * Currently, the size of the axis must be a power of two.
4942
+ */
4943
+ function ifft(a, axis = -1) {
4944
+ checkPairInput("ifft", a);
4945
+ let { real, imag } = a;
4946
+ axis = require_backend.checkAxis(axis, real.ndim);
4947
+ const n = real.shape[axis];
4948
+ checkPowerOfTwo("ifft", n);
4949
+ imag = imag.mul(-1);
4950
+ const result = fft({
4951
+ real,
4952
+ imag
4953
+ }, axis);
4954
+ return {
4955
+ real: result.real.div(n),
4956
+ imag: result.imag.mul(-1).div(n)
4957
+ };
4958
+ }
4959
+
4960
+ //#endregion
4961
+ //#region src/library/numpy-linalg.ts
4962
+ var numpy_linalg_exports = {};
4963
+ __export(numpy_linalg_exports, {
4964
+ cholesky: () => cholesky$1,
4965
+ diagonal: () => diagonal,
4966
+ lstsq: () => lstsq,
4967
+ matmul: () => matmul,
4968
+ matrixTranspose: () => matrixTranspose,
4969
+ outer: () => outer,
4970
+ tensordot: () => tensordot,
4971
+ trace: () => trace,
4972
+ vecdot: () => vecdot
4973
+ });
4974
+ /**
4975
+ * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
4976
+ *
4977
+ * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
4978
+ * the input matrix, which is on by default.
4979
+ */
4980
+ function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
4981
+ a = fudgeArray(a);
4982
+ if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`cholesky: input must be at least 2D square matrix, got ${a.aval}`);
4983
+ if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
4984
+ return cholesky(a, { upper });
4985
+ }
4986
+ /**
4987
+ * Return the least-squares solution to a linear equation.
4988
+ *
4989
+ * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
4990
+ * For underdetermined systems, this finds the minimum-norm solution for `x`.
4991
+ *
4992
+ * This currently uses Cholesky decomposition to solve the normal equations,
4993
+ * under the hood. The method is not as robust as QR or SVD.
4994
+ *
4995
+ * @param a coefficient matrix of shape `(M, N)`
4996
+ * @param b right-hand side of shape `(M,)` or `(M, K)`
4997
+ * @return least-squares solution of shape `(N,)` or `(N, K)`
4998
+ */
4999
+ function lstsq(a, b) {
5000
+ a = fudgeArray(a);
5001
+ b = fudgeArray(b);
5002
+ if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
5003
+ const [m, n] = a.shape;
5004
+ if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
5005
+ const at = matrixTranspose(a.ref);
5006
+ if (m <= n) {
5007
+ const aat = matmul(a, at.ref);
5008
+ const l = cholesky$1(aat, { symmetrizeInput: false });
5009
+ const lb = triangularSolve(l.ref, b, {
5010
+ leftSide: true,
5011
+ lower: true
5012
+ });
5013
+ const llb = triangularSolve(l, lb, {
5014
+ leftSide: true,
5015
+ transposeA: true
5016
+ });
5017
+ return matmul(at, llb.ref);
5018
+ } else {
5019
+ const ata = matmul(at.ref, a);
5020
+ const l = cholesky$1(ata, { symmetrizeInput: false });
5021
+ const atb = matmul(at, b);
5022
+ const lb = triangularSolve(l.ref, atb, {
5023
+ leftSide: true,
5024
+ lower: true
5025
+ });
5026
+ const llb = triangularSolve(l, lb, {
5027
+ leftSide: true,
5028
+ transposeA: true
5029
+ });
5030
+ return llb;
5031
+ }
5032
+ }
5033
+
4611
5034
  //#endregion
4612
5035
  //#region src/library/numpy.ts
4613
5036
  var numpy_exports = {};
4614
5037
  __export(numpy_exports, {
4615
5038
  Array: () => Array$1,
4616
5039
  DType: () => require_backend.DType,
4617
- abs: () => abs,
5040
+ abs: () => absolute,
4618
5041
  absolute: () => absolute,
4619
5042
  acos: () => acos,
4620
- acosh: () => acosh,
5043
+ acosh: () => arccosh,
4621
5044
  add: () => add,
5045
+ all: () => all,
4622
5046
  allclose: () => allclose,
5047
+ any: () => any,
4623
5048
  arange: () => arange,
4624
- arccos: () => arccos,
5049
+ arccos: () => acos,
4625
5050
  arccosh: () => arccosh,
5051
+ arcsin: () => asin,
4626
5052
  arcsinh: () => arcsinh,
4627
- arctan: () => arctan,
4628
- arctan2: () => arctan2,
5053
+ arctan: () => atan,
5054
+ arctan2: () => atan2,
4629
5055
  arctanh: () => arctanh,
4630
5056
  argmax: () => argmax,
4631
5057
  argmin: () => argmin,
5058
+ argsort: () => argsort,
4632
5059
  array: () => array,
4633
5060
  asin: () => asin,
4634
- asinh: () => asinh,
5061
+ asinh: () => arcsinh,
4635
5062
  astype: () => astype,
4636
5063
  atan: () => atan,
4637
5064
  atan2: () => atan2,
4638
- atanh: () => atanh,
5065
+ atanh: () => arctanh,
4639
5066
  bool: () => bool,
4640
5067
  broadcastArrays: () => broadcastArrays,
4641
5068
  broadcastShapes: () => broadcastShapes,
@@ -4645,14 +5072,20 @@ __export(numpy_exports, {
4645
5072
  clip: () => clip,
4646
5073
  columnStack: () => columnStack,
4647
5074
  concatenate: () => concatenate,
5075
+ convolve: () => convolve,
5076
+ corrcoef: () => corrcoef,
5077
+ correlate: () => correlate,
4648
5078
  cos: () => cos,
4649
5079
  cosh: () => cosh,
5080
+ cov: () => cov,
5081
+ cumsum: () => cumsum,
5082
+ cumulativeSum: () => cumsum,
4650
5083
  deg2rad: () => deg2rad,
4651
5084
  degrees: () => degrees,
4652
5085
  diag: () => diag,
4653
5086
  diagonal: () => diagonal,
4654
- divide: () => divide,
4655
- dot: () => dot,
5087
+ divide: () => trueDivide,
5088
+ dot: () => dot$1,
4656
5089
  dstack: () => dstack,
4657
5090
  e: () => e,
4658
5091
  einsum: () => einsum,
@@ -4660,8 +5093,10 @@ __export(numpy_exports, {
4660
5093
  eulerGamma: () => eulerGamma,
4661
5094
  exp: () => exp,
4662
5095
  exp2: () => exp2,
5096
+ expandDims: () => expandDims,
4663
5097
  expm1: () => expm1,
4664
5098
  eye: () => eye,
5099
+ fft: () => numpy_fft_exports,
4665
5100
  flip: () => flip,
4666
5101
  fliplr: () => fliplr,
4667
5102
  flipud: () => flipud,
@@ -4692,12 +5127,14 @@ __export(numpy_exports, {
4692
5127
  ldexp: () => ldexp,
4693
5128
  less: () => less,
4694
5129
  lessEqual: () => lessEqual,
5130
+ linalg: () => numpy_linalg_exports,
4695
5131
  linspace: () => linspace,
4696
5132
  log: () => log,
4697
5133
  log10: () => log10,
4698
5134
  log1p: () => log1p,
4699
5135
  log2: () => log2,
4700
5136
  matmul: () => matmul,
5137
+ matrixTranspose: () => matrixTranspose,
4701
5138
  max: () => max,
4702
5139
  maximum: () => maximum,
4703
5140
  mean: () => mean,
@@ -4714,10 +5151,10 @@ __export(numpy_exports, {
4714
5151
  onesLike: () => onesLike,
4715
5152
  outer: () => outer,
4716
5153
  pad: () => pad,
4717
- permuteDims: () => permuteDims,
5154
+ permuteDims: () => transpose,
4718
5155
  pi: () => pi,
4719
5156
  positive: () => positive,
4720
- pow: () => pow,
5157
+ pow: () => power,
4721
5158
  power: () => power,
4722
5159
  prod: () => prod$1,
4723
5160
  promoteTypes: () => require_backend.promoteTypes,
@@ -4734,6 +5171,7 @@ __export(numpy_exports, {
4734
5171
  sin: () => sin,
4735
5172
  sinh: () => sinh,
4736
5173
  size: () => size,
5174
+ sort: () => sort,
4737
5175
  sqrt: () => sqrt,
4738
5176
  square: () => square,
4739
5177
  squeeze: () => squeeze,
@@ -4898,6 +5336,26 @@ function min(a, axis = null, opts) {
4898
5336
  function max(a, axis = null, opts) {
4899
5337
  return reduce(a, require_backend.AluOp.Max, axis, opts);
4900
5338
  }
5339
+ /**
5340
+ * Test whether all array elements along a given axis evaluate to `true`.
5341
+ *
5342
+ * Returns a boolean array with the same shape as `a` with the specified axis
5343
+ * removed. If `axis` is `null`, returns a scalar.
5344
+ */
5345
+ function all(a, axis = null, opts) {
5346
+ a = fudgeArray(a).astype(require_backend.DType.Bool);
5347
+ return min(a, axis, opts);
5348
+ }
5349
+ /**
5350
+ * Test whether any array element along a given axis evaluates to `true`.
5351
+ *
5352
+ * Returns a boolean array with the same shape as `a` with the specified axis
5353
+ * removed. If `axis` is `null`, returns a scalar.
5354
+ */
5355
+ function any(a, axis = null, opts) {
5356
+ a = fudgeArray(a).astype(require_backend.DType.Bool);
5357
+ return max(a, axis, opts);
5358
+ }
4901
5359
  /** Return the peak-to-peak range along a given axis (`max - min`). */
4902
5360
  function ptp(a, axis = null, opts) {
4903
5361
  a = fudgeArray(a);
@@ -4955,6 +5413,23 @@ function argmax(a, axis, opts) {
4955
5413
  }).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
4956
5414
  return length.sub(max(idx, axis, opts));
4957
5415
  }
5416
+ /**
5417
+ * Cumulative sum of elements along an axis.
5418
+ *
5419
+ * Currently this function is `O(n^2)`, we'll improve this later on with a
5420
+ * two-phase parallel reduction algorithm.
5421
+ */
5422
+ function cumsum(a, axis) {
5423
+ a = fudgeArray(a);
5424
+ if (axis === void 0) {
5425
+ a = a.ravel();
5426
+ axis = 0;
5427
+ } else axis = require_backend.checkAxis(axis, a.ndim);
5428
+ const n = a.shape[axis];
5429
+ a = moveaxis$1(a, axis, -1);
5430
+ a = broadcast(a, a.shape.concat(n), [-2]);
5431
+ return moveaxis$1(tril(a).sum(-1), -1, axis);
5432
+ }
4958
5433
  /** Reverse the elements in an array along the given axes. */
4959
5434
  function flip(x, axis = null) {
4960
5435
  const nd = ndim(x);
@@ -5064,8 +5539,11 @@ function flipud(x) {
5064
5539
  function fliplr(x) {
5065
5540
  return flip(x, 1);
5066
5541
  }
5067
- /** @function Alternative name for `numpy.transpose()`. */
5068
- const permuteDims = transpose;
5542
+ /** Transpose the last two dimensions of an array. */
5543
+ function matrixTranspose(a) {
5544
+ if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
5545
+ return moveaxis$1(a, -1, -2);
5546
+ }
5069
5547
  /** Return a 1-D flattened array containing the elements of the input. */
5070
5548
  function ravel(a) {
5071
5549
  return fudgeArray(a).ravel();
@@ -5081,6 +5559,32 @@ function squeeze(a, axis = null) {
5081
5559
  return reshape(a, newShape);
5082
5560
  }
5083
5561
  /**
5562
+ * Expand the shape of an array by inserting new axes of length 1.
5563
+ *
5564
+ * @param a - Input array.
5565
+ * @param axis - Position(s) in the expanded axes where the new axis (or axes)
5566
+ * is placed. Can be a single integer or an array of integers.
5567
+ * @returns Array with the number of dimensions increased.
5568
+ *
5569
+ * @example
5570
+ * ```ts
5571
+ * const x = np.array([1, 2]);
5572
+ * np.expandDims(x, 0); // Shape [1, 2]
5573
+ * np.expandDims(x, 1); // Shape [2, 1]
5574
+ * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
5575
+ * ```
5576
+ */
5577
+ function expandDims(a, axis) {
5578
+ const as = shape(a);
5579
+ axis = typeof axis === "number" ? [axis] : axis;
5580
+ axis = require_backend.normalizeAxis(axis, as.length + axis.length);
5581
+ const newShape = [];
5582
+ let srcIdx = 0;
5583
+ for (let i = 0; i < as.length + axis.length; i++) if (axis.includes(i)) newShape.push(1);
5584
+ else newShape.push(as[srcIdx++]);
5585
+ return reshape(a, newShape);
5586
+ }
5587
+ /**
5084
5588
  * Repeat each element of an array after themselves.
5085
5589
  *
5086
5590
  * If no axis is provided, use the flattened input array, and return a flat
@@ -5168,7 +5672,7 @@ function diagonal(a, offset, axis1, axis2) {
5168
5672
  */
5169
5673
  function diag(v, k = 0) {
5170
5674
  const a = fudgeArray(v);
5171
- if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
5675
+ if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
5172
5676
  if (a.ndim === 1) {
5173
5677
  const n = a.shape[0];
5174
5678
  const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
@@ -5176,12 +5680,32 @@ function diag(v, k = 0) {
5176
5680
  else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
5177
5681
  else return ret;
5178
5682
  } else if (a.ndim === 2) return diagonal(a, k);
5179
- else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
5683
+ else throw new Error("numpy.diag only supports 1D and 2D arrays");
5180
5684
  }
5181
5685
  /** Calculate the sum of the diagonal of an array along the given axes. */
5182
5686
  function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
5183
5687
  return diagonal(a, offset, axis1, axis2).sum(-1);
5184
5688
  }
5689
+ /**
5690
+ * Return a sorted copy of an array.
5691
+ *
5692
+ * The array is sorted along a specified axis (the last by default). This may be
5693
+ * an unstable sort, and it dispatches to a device-specific implementation.
5694
+ */
5695
+ function sort(a, axis = -1) {
5696
+ return fudgeArray(a).sort(axis);
5697
+ }
5698
+ /**
5699
+ * Return indices that would sort an array. This may be an unstable sorting
5700
+ * algorithm; it need not preserve order of indices in ties.
5701
+ *
5702
+ * Returns an array of `int32` indices.
5703
+ *
5704
+ * The array is sorted along a specified axis (the last by default).
5705
+ */
5706
+ function argsort(a, axis = -1) {
5707
+ return fudgeArray(a).argsort(axis);
5708
+ }
5185
5709
  /** Return if two arrays are element-wise equal within a tolerance. */
5186
5710
  function allclose(actual, expected, options) {
5187
5711
  const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
@@ -5190,16 +5714,19 @@ function allclose(actual, expected, options) {
5190
5714
  if (!require_backend.deepEqual(x.shape, y.shape)) return false;
5191
5715
  const xData = x.dataSync();
5192
5716
  const yData = y.dataSync();
5193
- for (let i = 0; i < xData.length; i++) if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
5717
+ for (let i = 0; i < xData.length; i++) {
5718
+ if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
5719
+ if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
5720
+ }
5194
5721
  return true;
5195
5722
  }
5196
5723
  /** Matrix product of two arrays. */
5197
5724
  function matmul(x, y) {
5198
- if (ndim(x) === 0 || ndim(y) === 0) throw new TypeError("matmul: x and y must be at least 1D");
5725
+ if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
5199
5726
  x = x, y = y;
5200
5727
  if (y.ndim === 1) return dot$2(x, y);
5201
5728
  const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
5202
- return dot$1(x, y, {
5729
+ return dot(x, y, {
5203
5730
  lhsContractingDims: [-1],
5204
5731
  rhsContractingDims: [-2],
5205
5732
  lhsBatchDims: require_backend.range(-2 - numBatchDims, -2),
@@ -5207,11 +5734,11 @@ function matmul(x, y) {
5207
5734
  });
5208
5735
  }
5209
5736
  /** Dot product of two arrays. */
5210
- function dot(x, y) {
5737
+ function dot$1(x, y) {
5211
5738
  if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
5212
5739
  x = x, y = y;
5213
5740
  if (y.ndim === 1) return dot$2(x, y);
5214
- return dot$1(x, y, {
5741
+ return dot(x, y, {
5215
5742
  lhsContractingDims: [-1],
5216
5743
  rhsContractingDims: [-2]
5217
5744
  });
@@ -5227,7 +5754,7 @@ function tensordot(x, y, axes = 2) {
5227
5754
  x = fudgeArray(x);
5228
5755
  y = fudgeArray(y);
5229
5756
  if (typeof axes === "number") axes = [require_backend.range(-axes, 0), require_backend.range(axes)];
5230
- return dot$1(x, y, {
5757
+ return dot(x, y, {
5231
5758
  lhsContractingDims: axes[0],
5232
5759
  rhsContractingDims: axes[1]
5233
5760
  });
@@ -5320,7 +5847,7 @@ function einsum(...args) {
5320
5847
  const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
5321
5848
  indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
5322
5849
  const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
5323
- const result = dot$1(a, b, {
5850
+ const result = dot(a, b, {
5324
5851
  lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
5325
5852
  rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
5326
5853
  lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
@@ -5348,7 +5875,7 @@ function einsum(...args) {
5348
5875
  * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
5349
5876
  */
5350
5877
  function inner(x, y) {
5351
- return dot$1(fudgeArray(x), fudgeArray(y), {
5878
+ return dot(fudgeArray(x), fudgeArray(y), {
5352
5879
  lhsContractingDims: [-1],
5353
5880
  rhsContractingDims: [-1]
5354
5881
  });
@@ -5381,6 +5908,30 @@ function vecdot(x, y, { axis } = {}) {
5381
5908
  function vdot(x, y) {
5382
5909
  return dot$2(ravel(x), ravel(y));
5383
5910
  }
5911
+ function _convImpl(name, x, y, mode) {
5912
+ if (x.ndim !== 1 || y.ndim !== 1) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
5913
+ let flipOutput = false;
5914
+ if (x.shape[0] < y.shape[0]) {
5915
+ [x, y] = [y, x];
5916
+ if (name === "correlate") flipOutput = true;
5917
+ }
5918
+ if (name === "convolve") y = flip(y);
5919
+ let padding;
5920
+ if (mode === "valid") padding = "VALID";
5921
+ else if (mode === "same") padding = "SAME_LOWER";
5922
+ else if (mode === "full") padding = [[y.shape[0] - 1, y.shape[0] - 1]];
5923
+ else throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
5924
+ const z = conv(x.slice(null, null), y.slice(null, null), [1], padding).slice(0, 0);
5925
+ return flipOutput ? flip(z) : z;
5926
+ }
5927
+ /** Convolution of two one-dimensional arrays. */
5928
+ function convolve(x, y, mode = "full") {
5929
+ return _convImpl("convolve", x, y, mode);
5930
+ }
5931
+ /** Correlation of two one-dimensional arrays. */
5932
+ function correlate(x, y, mode = "valid") {
5933
+ return _convImpl("correlate", x, y, mode);
5934
+ }
5384
5935
  /**
5385
5936
  * Return a tuple of coordinate matrices from coordinate vectors.
5386
5937
  *
@@ -5389,7 +5940,7 @@ function vdot(x, y) {
5389
5940
  */
5390
5941
  function meshgrid(xs, { indexing } = {}) {
5391
5942
  indexing ??= "xy";
5392
- for (const x of xs) if (x.ndim !== 1) throw new TypeError(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
5943
+ for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
5393
5944
  if (xs.length <= 1) return xs;
5394
5945
  if (indexing === "xy") {
5395
5946
  const [a, b, ...rest] = xs;
@@ -5408,43 +5959,6 @@ function meshgrid(xs, { indexing } = {}) {
5408
5959
  return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
5409
5960
  }
5410
5961
  /**
5411
- * Return an array with ones on and below the diagonal and zeros elsewhere.
5412
- *
5413
- * If `k` is provided, it specifies the sub-diagonal on and below which the
5414
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
5415
- * `k>0` is above it.
5416
- */
5417
- function tri(n, m, k = 0, { dtype, device } = {}) {
5418
- m ??= n;
5419
- dtype ??= require_backend.DType.Float32;
5420
- if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
5421
- if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
5422
- if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
5423
- const rows = arange(k, n + k, 1, {
5424
- dtype: require_backend.DType.Int32,
5425
- device
5426
- });
5427
- const cols = arange(0, m, 1, {
5428
- dtype: require_backend.DType.Int32,
5429
- device
5430
- });
5431
- return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
5432
- }
5433
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
5434
- function tril(a, k = 0) {
5435
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5436
- a = fudgeArray(a);
5437
- const [n, m] = a.shape.slice(-2);
5438
- return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
5439
- }
5440
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
5441
- function triu(a, k = 0) {
5442
- if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
5443
- a = fudgeArray(a);
5444
- const [n, m] = a.shape.slice(-2);
5445
- return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
5446
- }
5447
- /**
5448
5962
  * Clip (limit) the values in an array.
5449
5963
  *
5450
5964
  * Given an interval, values outside the interval are clipped to the interval
@@ -5468,8 +5982,6 @@ function absolute(x) {
5468
5982
  x = fudgeArray(x);
5469
5983
  return where(less(x.ref, 0), x.ref.mul(-1), x);
5470
5984
  }
5471
- /** @function Alias of `jax.numpy.absolute()`. */
5472
- const abs = absolute;
5473
5985
  /** Return an element-wise indication of sign of the input. */
5474
5986
  function sign(x) {
5475
5987
  x = fudgeArray(x);
@@ -5548,12 +6060,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
5548
6060
  const denom = where(xNeg, y, r.add(x));
5549
6061
  return atan(numer.div(denom)).mul(2);
5550
6062
  });
5551
- /** @function Alias of `jax.numpy.acos()`. */
5552
- const arccos = acos;
5553
- /** @function Alias of `jax.numpy.atan()`. */
5554
- const arctan = atan;
5555
- /** @function Alias of `jax.numpy.atan2()`. */
5556
- const arctan2 = atan2;
5557
6063
  /** Element-wise subtraction, with broadcasting. */
5558
6064
  function subtract(x, y) {
5559
6065
  x = fudgeArray(x);
@@ -5584,8 +6090,6 @@ const fmod = jit$1(function fmod$1(x, y) {
5584
6090
  const remainder = jit$1(function remainder$1(x, y) {
5585
6091
  return mod(mod(x, y.ref).add(y.ref), y);
5586
6092
  });
5587
- /** @function Alias of `jax.numpy.trueDivide()`. */
5588
- const divide = trueDivide;
5589
6093
  /** Round input to the nearest integer towards zero. */
5590
6094
  function trunc(x) {
5591
6095
  return idiv(x, 1);
@@ -5607,9 +6111,9 @@ function ldexp(x1, x2) {
5607
6111
  */
5608
6112
  function frexp(x) {
5609
6113
  x = fudgeArray(x);
5610
- const absx = abs(x.ref);
6114
+ const absx = absolute(x.ref);
5611
6115
  const exponent = where(equal(x.ref, 0), 0, floor(log2(absx)).add(1).astype(require_backend.DType.Int32));
5612
- const mantissa = divide(x, exp2(exponent.ref.astype(x.dtype)));
6116
+ const mantissa = x.div(exp2(exponent.ref.astype(x.dtype)));
5613
6117
  return [mantissa, exponent];
5614
6118
  }
5615
6119
  /** Calculate `2**p` for all p in the input array. */
@@ -5649,10 +6153,11 @@ const degrees = rad2deg;
5649
6153
  * Computes first array raised to power of second array, element-wise.
5650
6154
  */
5651
6155
  const power = jit$1(function power$1(x1, x2) {
5652
- return exp(log(x1).mul(x2));
6156
+ const x2i = trunc(x2.ref);
6157
+ const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
6158
+ const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
6159
+ return where(shouldBeNaN, nan, exp(log(absolute(x1)).mul(x2)).mul(resultSign));
5653
6160
  });
5654
- /** @function Alias of `jax.numpy.power()`. */
5655
- const pow = power;
5656
6161
  /** @function Calculate the element-wise cube root of the input array. */
5657
6162
  const cbrt = jit$1(function cbrt$1(x) {
5658
6163
  const sgn = where(less(x.ref, 0), -1, 1);
@@ -5718,12 +6223,6 @@ const arccosh = jit$1(function arccosh$1(x) {
5718
6223
  const arctanh = jit$1(function arctanh$1(x) {
5719
6224
  return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
5720
6225
  });
5721
- /** @function Alias of `jax.numpy.arcsinh()`. */
5722
- const asinh = arcsinh;
5723
- /** @function Alias of `jax.numpy.arccosh()`. */
5724
- const acosh = arccosh;
5725
- /** @function Alias of `jax.numpy.arctanh()`. */
5726
- const atanh = arctanh;
5727
6226
  /**
5728
6227
  * Compute the variance of an array.
5729
6228
  *
@@ -5753,6 +6252,26 @@ function var_(x, axis = null, opts) {
5753
6252
  function std(x, axis = null, opts) {
5754
6253
  return sqrt(var_(x, axis, opts));
5755
6254
  }
6255
+ /** Estimate the sample covariance of a set of variables. */
6256
+ function cov(x, y) {
6257
+ x = fudgeArray(x);
6258
+ if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
6259
+ if (y !== void 0) {
6260
+ y = fudgeArray(y);
6261
+ if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
6262
+ x = vstack([x, y]);
6263
+ }
6264
+ const [_M, N] = x.shape;
6265
+ x = x.ref.sub(x.mean(1, { keepdims: true }));
6266
+ return dot$1(x.ref, x.transpose()).div(N - 1);
6267
+ }
6268
+ /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
6269
+ function corrcoef(x, y) {
6270
+ const c = cov(x, y);
6271
+ const variances = diag(c.ref);
6272
+ const norm = sqrt(outer(variances.ref, variances));
6273
+ return c.div(norm);
6274
+ }
5756
6275
  /** Test element-wise for positive or negative infinity, return bool array. */
5757
6276
  function isinf(x) {
5758
6277
  x = fudgeArray(x);
@@ -5782,6 +6301,253 @@ const isfinite = jit$1(function isfinite$1(x) {
5782
6301
  return isnan(x.ref).add(isinf(x)).notEqual(true);
5783
6302
  });
5784
6303
 
6304
+ //#endregion
6305
+ //#region src/library/lax-linalg.ts
6306
+ var lax_linalg_exports = {};
6307
+ __export(lax_linalg_exports, {
6308
+ cholesky: () => cholesky,
6309
+ triangularSolve: () => triangularSolve
6310
+ });
6311
+ /**
6312
+ * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
6313
+ *
6314
+ * The Cholesky decomposition of a matrix `A` is:
6315
+ *
6316
+ * - A = L @ L^T (for upper=false, default)
6317
+ * - A = U^T @ U (for upper=true)
6318
+ *
6319
+ * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
6320
+ * The input matrix must be symmetric and positive-definite.
6321
+ *
6322
+ * @example
6323
+ * ```ts
6324
+ * import { lax, numpy as np } from "@jax-js/jax";
6325
+ *
6326
+ * const x = np.array([[2., 1.], [1., 2.]]);
6327
+ *
6328
+ * // Lower Cholesky factorization (default):
6329
+ * const L = lax.linalg.cholesky(x);
6330
+ * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
6331
+ *
6332
+ * // Upper Cholesky factorization:
6333
+ * const U = lax.linalg.cholesky(x, { upper: true });
6334
+ * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
6335
+ * ```
6336
+ */
6337
+ function cholesky(a, { upper = false } = {}) {
6338
+ const L = cholesky$2(a);
6339
+ return upper ? moveaxis$1(L, -2, -1) : L;
6340
+ }
6341
+ /**
6342
+ * Solve a triangular linear system.
6343
+ *
6344
+ * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
6345
+ * where `a` is a triangular matrix.
6346
+ *
6347
+ * @example
6348
+ * ```ts
6349
+ * import { lax, numpy as np } from "@jax-js/jax";
6350
+ *
6351
+ * const L = np.array([[2., 0.], [1., 3.]]);
6352
+ * const b = np.array([4., 7.]).reshape([2, 1]);
6353
+ *
6354
+ * // Solve L @ x = b
6355
+ * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
6356
+ * // x = [[2.], [5./3.]]
6357
+ * ```
6358
+ */
6359
+ function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
6360
+ a = fudgeArray(a);
6361
+ b = fudgeArray(b);
6362
+ if (!leftSide) transposeA = !transposeA;
6363
+ else b = moveaxis$1(b, -2, -1);
6364
+ if (transposeA) a = moveaxis$1(a, -2, -1);
6365
+ let x = triangularSolve$1(a, b, {
6366
+ lower,
6367
+ unitDiagonal
6368
+ });
6369
+ if (leftSide) x = moveaxis$1(x, -2, -1);
6370
+ return x;
6371
+ }
6372
+
6373
+ //#endregion
6374
+ //#region src/library/lax.ts
6375
+ var lax_exports = {};
6376
+ __export(lax_exports, {
6377
+ conv: () => conv,
6378
+ convGeneralDilated: () => convGeneralDilated,
6379
+ convWithGeneralPadding: () => convWithGeneralPadding,
6380
+ dot: () => dot,
6381
+ erf: () => erf,
6382
+ erfc: () => erfc,
6383
+ linalg: () => lax_linalg_exports,
6384
+ reduceWindow: () => reduceWindow,
6385
+ stopGradient: () => stopGradient$1
6386
+ });
6387
+ /**
6388
+ * General dot product/contraction operator.
6389
+ *
6390
+ * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
6391
+ * `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
6392
+ */
6393
+ function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
6394
+ if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
6395
+ else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
6396
+ lc = lc.map((a) => require_backend.checkAxis(a, lhs.ndim));
6397
+ rc = rc.map((a) => require_backend.checkAxis(a, rhs.ndim));
6398
+ lb = lb.map((a) => require_backend.checkAxis(a, lhs.ndim));
6399
+ rb = rb.map((a) => require_backend.checkAxis(a, rhs.ndim));
6400
+ if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
6401
+ else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
6402
+ const lf = require_backend.range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
6403
+ const rf = require_backend.range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
6404
+ const lhs2 = lhs.transpose([
6405
+ ...lb,
6406
+ ...lf,
6407
+ ...lc
6408
+ ]);
6409
+ const rhs2 = rhs.transpose([
6410
+ ...rb,
6411
+ ...rf,
6412
+ ...rc
6413
+ ]);
6414
+ if (lc.length === 0) return mul(lhs2.reshape([
6415
+ ...lb.map((a) => lhs.shape[a]),
6416
+ ...lf.map((a) => lhs.shape[a]),
6417
+ ...require_backend.rep(rf.length, 1)
6418
+ ]), rhs2.reshape([
6419
+ ...rb.map((a) => rhs.shape[a]),
6420
+ ...require_backend.rep(lf.length, 1),
6421
+ ...rf.map((a) => rhs.shape[a])
6422
+ ]));
6423
+ const dotShapeX = lc.map((a) => lhs.shape[a]);
6424
+ const dotShapeY = rc.map((a) => rhs.shape[a]);
6425
+ if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
6426
+ return dot$2(lhs2.reshape([
6427
+ ...lb.map((a) => lhs.shape[a]),
6428
+ ...lf.map((a) => lhs.shape[a]),
6429
+ ...require_backend.rep(rf.length, 1),
6430
+ require_backend.prod(dotShapeX)
6431
+ ]), rhs2.reshape([
6432
+ ...rb.map((a) => rhs.shape[a]),
6433
+ ...require_backend.rep(lf.length, 1),
6434
+ ...rf.map((a) => rhs.shape[a]),
6435
+ require_backend.prod(dotShapeY)
6436
+ ]));
6437
+ }
6438
+ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
6439
+ const padType = padding.toUpperCase();
6440
+ switch (padType) {
6441
+ case "VALID": return require_backend.rep(inShape.length, [0, 0]);
6442
+ case "SAME":
6443
+ case "SAME_LOWER": {
6444
+ const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
6445
+ const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
6446
+ if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
6447
+ else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
6448
+ }
6449
+ default: throw new Error(`Unknown padding type: ${padType}`);
6450
+ }
6451
+ }
6452
+ /**
6453
+ * General n-dimensional convolution operator, with optional dilation.
6454
+ *
6455
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
6456
+ * function in JAX, which wraps XLA's general convolution operator.
6457
+ *
6458
+ * Grouped convolutions are supported through `featureGroupCount`, which must divide both the input and output channel counts.
6459
+ */
6460
+ function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
6461
+ if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
6462
+ if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
6463
+ if (typeof padding === "string") {
6464
+ if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
6465
+ padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
6466
+ }
6467
+ if (featureGroupCount !== 1) {
6468
+ const G = featureGroupCount;
6469
+ const [N, C_in, ...xs] = lhs.shape;
6470
+ const [C_out, C_in_per_group, ...ks] = rhs.shape;
6471
+ if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
6472
+ if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
6473
+ if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
6474
+ const lhsGrouped = moveaxis(lhs.reshape([
6475
+ N,
6476
+ G,
6477
+ C_in / G,
6478
+ ...xs
6479
+ ]), 1, 0);
6480
+ const rhsGrouped = rhs.reshape([
6481
+ G,
6482
+ C_out / G,
6483
+ C_in_per_group,
6484
+ ...ks
6485
+ ]);
6486
+ const result = conv$1(lhsGrouped, rhsGrouped, {
6487
+ vmapDims: 1,
6488
+ strides: windowStrides,
6489
+ padding,
6490
+ lhsDilation,
6491
+ rhsDilation
6492
+ });
6493
+ const ys = result.shape.slice(3);
6494
+ return moveaxis(result, 0, 1).reshape([
6495
+ N,
6496
+ C_out,
6497
+ ...ys
6498
+ ]);
6499
+ }
6500
+ return conv$1(lhs, rhs, {
6501
+ strides: windowStrides,
6502
+ padding,
6503
+ lhsDilation,
6504
+ rhsDilation
6505
+ });
6506
+ }
6507
+ /** Convenience wrapper around `convGeneralDilated`. */
6508
+ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
6509
+ return convGeneralDilated(lhs, rhs, windowStrides, padding, {
6510
+ lhsDilation,
6511
+ rhsDilation
6512
+ });
6513
+ }
6514
+ /** Convenience wrapper around `convGeneralDilated`. */
6515
+ function conv(lhs, rhs, windowStrides, padding) {
6516
+ return convGeneralDilated(lhs, rhs, windowStrides, padding);
6517
+ }
6518
+ /** Reduce a computation over padded windows. */
6519
+ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
6520
+ if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
6521
+ if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
6522
+ for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
6523
+ return computation(bind1(Primitive.Pool, [operand], {
6524
+ window: windowDimensions,
6525
+ strides: windowStrides
6526
+ }));
6527
+ }
6528
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
6529
+ function erf(x) {
6530
+ return erf$1(x);
6531
+ }
6532
+ /**
6533
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
6534
+ *
6535
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
6536
+ * where `erf(x)` is very close to 1.
6537
+ */
6538
+ function erfc(x) {
6539
+ return erfc$1(x);
6540
+ }
6541
+ /**
6542
+ * Stops gradient computation.
6543
+ *
6544
+ * Behaves as the identity function but prevents the flow of gradients during
6545
+ * forward or reverse-mode automatic differentiation.
6546
+ */
6547
+ function stopGradient$1(x) {
6548
+ return stopGradient(x);
6549
+ }
6550
+
5785
6551
  //#endregion
5786
6552
  //#region src/library/nn.ts
5787
6553
  var nn_exports = {};
@@ -5790,6 +6556,10 @@ __export(nn_exports, {
5790
6556
  elu: () => elu,
5791
6557
  gelu: () => gelu,
5792
6558
  glu: () => glu,
6559
+ hardSigmoid: () => hardSigmoid,
6560
+ hardSilu: () => hardSilu,
6561
+ hardSwish: () => hardSilu,
6562
+ hardTanh: () => hardTanh,
5793
6563
  identity: () => identity,
5794
6564
  leakyRelu: () => leakyRelu,
5795
6565
  logSigmoid: () => logSigmoid,
@@ -5800,14 +6570,17 @@ __export(nn_exports, {
5800
6570
  oneHot: () => oneHot,
5801
6571
  relu: () => relu,
5802
6572
  relu6: () => relu6,
6573
+ selu: () => selu,
5803
6574
  sigmoid: () => sigmoid,
5804
6575
  silu: () => silu,
5805
6576
  softSign: () => softSign,
5806
6577
  softmax: () => softmax,
5807
6578
  softplus: () => softplus,
6579
+ sparsePlus: () => sparsePlus,
6580
+ sparseSigmoid: () => sparseSigmoid,
5808
6581
  squareplus: () => squareplus,
5809
6582
  standardize: () => standardize,
5810
- swish: () => swish
6583
+ swish: () => silu
5811
6584
  });
5812
6585
  /**
5813
6586
  * Rectified Linear Unit (ReLU) activation function:
@@ -5842,6 +6615,28 @@ function softplus(x) {
5842
6615
  return log(exp(x).add(1));
5843
6616
  }
5844
6617
  /**
6618
+ * @function
6619
+ * Sparse plus function:
6620
+ *
6621
+ * - When `x <= -1`: `0`
6622
+ * - When `-1 < x < 1`: `(x+1)**2 / 4`
6623
+ * - When `x >= 1`: `x`
6624
+ */
6625
+ const sparsePlus = jit$1((x) => {
6626
+ return where(x.ref.lessEqual(-1), 0, where(x.ref.less(1), square(x.ref.add(1)).mul(.25), x));
6627
+ });
6628
+ /**
6629
+ * @function
6630
+ * Sparse sigmoid activation function.
6631
+ *
6632
+ * - When `x <= -1`: `0`
6633
+ * - When `-1 < x < 1`: `(x + 1) / 2`
6634
+ * - When `x >= 1`: `1`
6635
+ */
6636
+ const sparseSigmoid = jit$1((x) => {
6637
+ return clip(x.add(1).mul(.5), 0, 1);
6638
+ });
6639
+ /**
5845
6640
  * Soft-sign activation function, computed element-wise:
5846
6641
  * `softsign(x) = x / (|x| + 1)`.
5847
6642
  */
@@ -5863,17 +6658,6 @@ const silu = jit$1(function silu$1(x) {
5863
6658
  return x.ref.mul(sigmoid(x));
5864
6659
  });
5865
6660
  /**
5866
- * @function
5867
- * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
5868
- * Swish, computed element-wise:
5869
- * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
5870
- *
5871
- * `swish()` and `silu()` are both aliases for the same function.
5872
- *
5873
- * Reference: https://en.wikipedia.org/wiki/Swish_function
5874
- */
5875
- const swish = silu;
5876
- /**
5877
6661
  * Log-sigmoid activation function, computed element-wise:
5878
6662
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
5879
6663
  */
@@ -5890,6 +6674,19 @@ function leakyRelu(x, negativeSlope = .01) {
5890
6674
  x = fudgeArray(x);
5891
6675
  return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
5892
6676
  }
6677
+ /** Hard sigmoid activation function: `relu6(x+3)/6`. */
6678
+ function hardSigmoid(x) {
6679
+ return relu6(add(x, 3)).mul(1 / 6);
6680
+ }
6681
+ /** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
6682
+ function hardSilu(x) {
6683
+ x = fudgeArray(x);
6684
+ return x.ref.mul(hardSigmoid(x));
6685
+ }
6686
+ /** Hard tanh activation function: `clip(x, -1, 1)`. */
6687
+ function hardTanh(x) {
6688
+ return clip(x, -1, 1);
6689
+ }
5893
6690
  /**
5894
6691
  * Exponential linear unit activation function.
5895
6692
  *
@@ -5912,6 +6709,20 @@ function celu(x, alpha = 1) {
5912
6709
  }
5913
6710
  /**
5914
6711
  * @function
6712
+ * Scaled exponential linear unit activation.
6713
+ *
6714
+ * Computes the element-wise function:
6715
+ * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
6716
+ *
6717
+ * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
6718
+ */
6719
+ const selu = jit$1(function selu$1(x) {
6720
+ const alpha = 1.6732632423543772;
6721
+ const lambda = 1.0507009873554805;
6722
+ return where(x.ref.less(0), expm1(x.ref).mul(alpha), x).mul(lambda);
6723
+ });
6724
+ /**
6725
+ * @function
5915
6726
  * Gaussian error linear unit (GELU) activation function.
5916
6727
  *
5917
6728
  * This is computed element-wise. There are two variants depending on whether
@@ -6005,22 +6816,22 @@ function logSoftmax(x, axis = -1) {
6005
6816
  *
6006
6817
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
6007
6818
  */
6008
- function logsumexp(x, axis = null) {
6819
+ function logsumexp(x, axis = null, opts) {
6009
6820
  x = fudgeArray(x);
6010
6821
  axis = require_backend.normalizeAxis(axis, x.ndim);
6011
6822
  if (axis.length === 0) return x;
6012
- const xMax = stopGradient(max(x.ref, axis));
6013
- const xMaxDims = broadcast(xMax.ref, x.shape, axis);
6014
- const shifted = x.sub(xMaxDims);
6015
- return xMax.add(log(exp(shifted).sum(axis)));
6823
+ const xMax = stopGradient(max(x.ref, axis, { keepdims: true }));
6824
+ const shifted = x.sub(xMax.ref);
6825
+ const result = xMax.add(log(exp(shifted).sum(axis, { keepdims: true })));
6826
+ return opts?.keepdims ? result : squeeze(result, axis);
6016
6827
  }
6017
6828
  /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
6018
- function logmeanexp(x, axis = null) {
6829
+ function logmeanexp(x, axis = null, opts) {
6019
6830
  x = fudgeArray(x);
6020
6831
  axis = require_backend.normalizeAxis(axis, x.ndim);
6021
6832
  if (axis.length === 0) return x;
6022
6833
  const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
6023
- return logsumexp(x, axis).sub(Math.log(n));
6834
+ return logsumexp(x, axis, opts).sub(Math.log(n));
6024
6835
  }
6025
6836
  /**
6026
6837
  * Standardizes input to zero mean and unit variance.
@@ -6065,8 +6876,11 @@ var random_exports = {};
6065
6876
  __export(random_exports, {
6066
6877
  bernoulli: () => bernoulli,
6067
6878
  bits: () => bits,
6879
+ cauchy: () => cauchy,
6068
6880
  exponential: () => exponential,
6881
+ gumbel: () => gumbel,
6069
6882
  key: () => key,
6883
+ laplace: () => laplace,
6070
6884
  normal: () => normal,
6071
6885
  split: () => split,
6072
6886
  uniform: () => uniform
@@ -6125,6 +6939,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
6125
6939
  }
6126
6940
  /**
6127
6941
  * @function
6942
+ * Sample from a Cauchy distribution with location 0 and scale 1.
6943
+ *
6944
+ * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
6945
+ */
6946
+ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
6947
+ const u = uniform(key$1, shape$1);
6948
+ return tan(u.sub(.5).mul(Math.PI));
6949
+ }, { staticArgnums: [1] });
6950
+ /**
6951
+ * @function
6128
6952
  * Sample exponential random values according to `p(x) = exp(-x)`.
6129
6953
  */
6130
6954
  const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
@@ -6133,6 +6957,30 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
6133
6957
  }, { staticArgnums: [1] });
6134
6958
  /**
6135
6959
  * @function
6960
+ * Sample from a Gumbel distribution with location 0 and scale 1.
6961
+ *
6962
+ * Uses inverse transform sampling: `x = -log(-log(1 - u))` where u ~ Uniform(0, 1) (`1 - u` has the same distribution as `u`).
6963
+ */
6964
+ const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
6965
+ const u = uniform(key$1, shape$1);
6966
+ return negative(log(negative(log1p(negative(u)))));
6967
+ }, { staticArgnums: [1] });
6968
+ /**
6969
+ * @function
6970
+ * Sample from a Laplace distribution with location 0 and scale 1.
6971
+ *
6972
+ * Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
6973
+ * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
6974
+ */
6975
+ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
6976
+ const u = uniform(key$1, shape$1);
6977
+ const centered = u.sub(.5);
6978
+ const s = sign(centered.ref);
6979
+ const absVal = absolute(centered);
6980
+ return s.mul(log1p(absVal.mul(-2)).mul(-1));
6981
+ }, { staticArgnums: [1] });
6982
+ /**
6983
+ * @function
6136
6984
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
6137
6985
  *
6138
6986
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
@@ -6241,11 +7089,6 @@ const valueAndGrad = valueAndGrad$1;
6241
7089
  */
6242
7090
  const jacrev = jacrev$1;
6243
7091
  /**
6244
- * @function
6245
- * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
6246
- */
6247
- const jacobian = jacrev;
6248
- /**
6249
7092
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
6250
7093
  *
6251
7094
  * This can be used to wait for the results of an intermediate computation to
@@ -6281,6 +7124,7 @@ async function devicePut(x, device) {
6281
7124
 
6282
7125
  //#endregion
6283
7126
  exports.Array = Array$1;
7127
+ exports.ClosedJaxpr = ClosedJaxpr;
6284
7128
  exports.DType = require_backend.DType;
6285
7129
  exports.Jaxpr = Jaxpr;
6286
7130
  exports.blockUntilReady = blockUntilReady;
@@ -6290,7 +7134,7 @@ exports.devices = require_backend.devices;
6290
7134
  exports.grad = grad;
6291
7135
  exports.init = require_backend.init;
6292
7136
  exports.jacfwd = jacfwd;
6293
- exports.jacobian = jacobian;
7137
+ exports.jacobian = jacrev;
6294
7138
  exports.jacrev = jacrev;
6295
7139
  exports.jit = jit;
6296
7140
  exports.jvp = jvp;
@@ -6335,5 +7179,4 @@ Object.defineProperty(exports, 'tree', {
6335
7179
  });
6336
7180
  exports.valueAndGrad = valueAndGrad;
6337
7181
  exports.vjp = vjp;
6338
- exports.vmap = vmap;
6339
- //# sourceMappingURL=index.cjs.map
7182
+ exports.vmap = vmap;