@jax-js/jax 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-Bu9GY6sK.cjs');
33
+ const require_backend = require('./backend-D7s-Retx.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -240,7 +240,7 @@ __export(tree_exports, {
240
240
  structure: () => structure,
241
241
  unflatten: () => unflatten
242
242
  });
243
- const JsArray$1 = globalThis.Array;
243
+ const JsArray$2 = globalThis.Array;
244
244
  let NodeType = /* @__PURE__ */ function(NodeType$1) {
245
245
  NodeType$1["Array"] = "Array";
246
246
  NodeType$1["Object"] = "Object";
@@ -288,7 +288,7 @@ function flatten(tree) {
288
288
  return [leaves$1, treedef];
289
289
  }
290
290
  function _flatten(tree, leaves$1) {
291
- if (JsArray$1.isArray(tree)) {
291
+ if (JsArray$2.isArray(tree)) {
292
292
  const childTrees = tree.map((c) => _flatten(c, leaves$1));
293
293
  return new JsTreeDef(NodeType.Array, null, childTrees);
294
294
  } else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
@@ -387,6 +387,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
387
387
  Primitive$1["PoolTranspose"] = "pool_transpose";
388
388
  Primitive$1["Compare"] = "compare";
389
389
  Primitive$1["Where"] = "where";
390
+ Primitive$1["Concatenate"] = "concatenate";
391
+ Primitive$1["Split"] = "split";
390
392
  Primitive$1["RandomBits"] = "random_bits";
391
393
  Primitive$1["Gather"] = "gather";
392
394
  Primitive$1["Transpose"] = "transpose";
@@ -399,6 +401,7 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
399
401
  Primitive$1["Argsort"] = "argsort";
400
402
  Primitive$1["TriangularSolve"] = "triangular_solve";
401
403
  Primitive$1["Cholesky"] = "cholesky";
404
+ Primitive$1["LU"] = "lu";
402
405
  Primitive$1["Jit"] = "jit";
403
406
  return Primitive$1;
404
407
  }({});
@@ -409,6 +412,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
409
412
  CompareOp$1["LessEqual"] = "less_equal";
410
413
  return CompareOp$1;
411
414
  }({});
415
+ const routinePrimitives = new Map([
416
+ [Primitive.Sort, require_backend.Routines.Sort],
417
+ [Primitive.Argsort, require_backend.Routines.Argsort],
418
+ [Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
419
+ [Primitive.Cholesky, require_backend.Routines.Cholesky],
420
+ [Primitive.LU, require_backend.Routines.LU]
421
+ ]);
412
422
  function add$1(x, y) {
413
423
  return bind1(Primitive.Add, [x, y]);
414
424
  }
@@ -530,7 +540,25 @@ function where$1(cond, x, y) {
530
540
  y
531
541
  ]);
532
542
  }
543
+ function concatenate$1(xs, axis) {
544
+ if (xs.length === 0) throw new Error("concatenate requires at least one input");
545
+ const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
546
+ axis = require_backend.checkAxis(axis, avals[0].ndim);
547
+ for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
548
+ return bind1(Primitive.Concatenate, xs, { axis });
549
+ }
550
+ function split$2(x, axis, sizes) {
551
+ axis = require_backend.checkAxis(axis, ndim$1(x));
552
+ if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
553
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
554
+ if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
555
+ return bind(Primitive.Split, [x], {
556
+ axis,
557
+ sizes
558
+ });
559
+ }
533
560
  function randomBits(k0, k1, shape$1, mode = "xor") {
561
+ if (!require_backend.deepEqual(k0.shape, k1.shape) || k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
534
562
  return bind1(Primitive.RandomBits, [k0, k1], {
535
563
  shape: shape$1,
536
564
  mode
@@ -597,6 +625,11 @@ function pad$1(x, width) {
597
625
  return bind1(Primitive.Pad, [x], { width });
598
626
  }
599
627
  function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
628
+ const as = getShape(a);
629
+ const bs = getShape(b);
630
+ if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
631
+ const n = as[as.length - 2];
632
+ if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
600
633
  if (lower) {
601
634
  a = flip$1(a, [-2, -1]);
602
635
  b = flip$1(b, [-1]);
@@ -606,8 +639,15 @@ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
606
639
  return x;
607
640
  }
608
641
  function cholesky$2(x) {
642
+ const aval = ShapedArray.fromAval(getAval(x));
643
+ if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
609
644
  return bind1(Primitive.Cholesky, [x]);
610
645
  }
646
+ function lu$1(x) {
647
+ const aval = ShapedArray.fromAval(getAval(x));
648
+ if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
649
+ return bind(Primitive.LU, [x]);
650
+ }
611
651
  function sort$1(x) {
612
652
  const nd = ndim$1(x);
613
653
  if (nd === 0) throw new Error("sort: requires at least 1D input");
@@ -652,6 +692,9 @@ function newDynamic(main) {
652
692
  dynamicTrace = prevDynamicTrace;
653
693
  } };
654
694
  }
695
+ function currentTraceLevel() {
696
+ return traceStack[traceStack.length - 1].level;
697
+ }
655
698
  var Trace = class {
656
699
  constructor(main) {
657
700
  this.main = main;
@@ -716,6 +759,9 @@ var Tracer = class Tracer {
716
759
  mul(other) {
717
760
  return mul(this, other);
718
761
  }
762
+ mod(other) {
763
+ return mod(this, other);
764
+ }
719
765
  greater(other) {
720
766
  return greater$1(this, other);
721
767
  }
@@ -828,8 +874,14 @@ var Tracer = class Tracer {
828
874
  */
829
875
  *[Symbol.iterator]() {
830
876
  if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
831
- for (let i = 0; i < this.shape[0]; i++) yield this.ref.slice(i);
832
- this.dispose();
877
+ let residual = this;
878
+ const subarrayShape = this.shape.slice(1);
879
+ for (let i = 0; i < this.shape[0]; i++) {
880
+ const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
881
+ yield lr[0].reshape(subarrayShape);
882
+ residual = lr[1];
883
+ }
884
+ residual.dispose();
833
885
  }
834
886
  /**
835
887
  * Return a sorted copy of an array in ascending order.
@@ -979,6 +1031,9 @@ var ShapedArray = class ShapedArray {
979
1031
  get size() {
980
1032
  return require_backend.prod(this.shape);
981
1033
  }
1034
+ scalar() {
1035
+ return new ShapedArray([], this.dtype, this.weakType);
1036
+ }
982
1037
  toString() {
983
1038
  return `${this.dtype}[${this.shape.join(",")}]`;
984
1039
  }
@@ -1017,6 +1072,7 @@ var TreeMismatchError = class extends TypeError {
1017
1072
  super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
1018
1073
  }
1019
1074
  };
1075
+ /** Flatten a function of `JsTree` input/output for use in tracing. */
1020
1076
  function flattenFun(f, inTree) {
1021
1077
  const store = { value: void 0 };
1022
1078
  const flatFun = (...argsFlat) => {
@@ -1028,6 +1084,26 @@ function flattenFun(f, inTree) {
1028
1084
  };
1029
1085
  return [flatFun, store];
1030
1086
  }
1087
+ /** Like flattenFun, but expects f to return [main, aux] tuple. */
1088
+ function flattenFunWithAux(f, inTree) {
1089
+ const store = { value: void 0 };
1090
+ const auxStore = { value: void 0 };
1091
+ const flatFun = (...argsFlat) => {
1092
+ const pytreeArgs = unflatten(inTree, argsFlat);
1093
+ const result = f(...pytreeArgs);
1094
+ if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
1095
+ const [out, aux] = result;
1096
+ const [outFlat, outTree] = flatten(out);
1097
+ store.value = outTree;
1098
+ auxStore.value = aux;
1099
+ return outFlat;
1100
+ };
1101
+ return [
1102
+ flatFun,
1103
+ store,
1104
+ auxStore
1105
+ ];
1106
+ }
1031
1107
  var UseAfterFreeError = class extends ReferenceError {
1032
1108
  constructor(tracer) {
1033
1109
  super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
@@ -1588,7 +1664,7 @@ const abstractEvalRules = {
1588
1664
  return [new ShapedArray(shape$1, dtype, weakType)];
1589
1665
  },
1590
1666
  [Primitive.Conv]([lhs, rhs], params) {
1591
- const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
1667
+ const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
1592
1668
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
1593
1669
  return [new ShapedArray(shape$1, dtype, weakType)];
1594
1670
  },
@@ -1599,10 +1675,25 @@ const abstractEvalRules = {
1599
1675
  const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
1600
1676
  return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
1601
1677
  },
1678
+ [Primitive.Concatenate](xs, { axis }) {
1679
+ if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
1680
+ for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
1681
+ const shape$1 = xs[0].shape.slice();
1682
+ shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
1683
+ const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
1684
+ return [new ShapedArray(shape$1, dtype, weakType)];
1685
+ },
1686
+ [Primitive.Split]([x], { axis, sizes }) {
1687
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
1688
+ if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
1689
+ return sizes.map((size$1) => {
1690
+ return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
1691
+ });
1692
+ },
1602
1693
  [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1603
1694
  if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1604
- const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
1605
- if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1695
+ if (!require_backend.deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
1696
+ if (!require_backend.deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
1606
1697
  return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
1607
1698
  },
1608
1699
  [Primitive.Gather]([x, ...indices], { axis, outDim }) {
@@ -1659,6 +1750,16 @@ const abstractEvalRules = {
1659
1750
  if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
1660
1751
  return [ShapedArray.fromAval(a)];
1661
1752
  },
1753
+ [Primitive.LU]([a]) {
1754
+ if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
1755
+ const batch = a.shape.slice(0, -2);
1756
+ const [m, n] = a.shape.slice(-2);
1757
+ return [
1758
+ ShapedArray.fromAval(a),
1759
+ new ShapedArray([...batch, Math.min(m, n)], require_backend.DType.Int32, false),
1760
+ new ShapedArray([...batch, m], require_backend.DType.Int32, false)
1761
+ ];
1762
+ },
1662
1763
  [Primitive.Jit](args, { jaxpr }) {
1663
1764
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
1664
1765
  if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
@@ -1736,12 +1837,6 @@ function jit$1(f, opts) {
1736
1837
 
1737
1838
  //#endregion
1738
1839
  //#region src/frontend/jit.ts
1739
- const routinePrimitives = new Map([
1740
- [Primitive.Sort, require_backend.Routines.Sort],
1741
- [Primitive.Argsort, require_backend.Routines.Argsort],
1742
- [Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
1743
- [Primitive.Cholesky, require_backend.Routines.Cholesky]
1744
- ]);
1745
1840
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1746
1841
  var JitProgram = class {
1747
1842
  constructor(backend, steps, inputs, outputs) {
@@ -1911,10 +2006,10 @@ function jitCompile(backend, jaxpr) {
1911
2006
  inputs.push(jv.arg);
1912
2007
  } else if (input instanceof Lit) inputs.push(builder.pushLit(input));
1913
2008
  const outputs = [];
1914
- for (const outVar$1 of eqn.outBinders) {
1915
- const outId = builder.pushBuffer(outVar$1.aval.size * require_backend.byteWidth(outVar$1.aval.dtype));
2009
+ for (const outVar of eqn.outBinders) {
2010
+ const outId = builder.pushBuffer(outVar.aval.size * require_backend.byteWidth(outVar.aval.dtype));
1916
2011
  outputs.push(outId);
1917
- ctx.set(outVar$1, {
2012
+ ctx.set(outVar, {
1918
2013
  type: "imm",
1919
2014
  arg: outId
1920
2015
  });
@@ -1965,35 +2060,37 @@ function jitCompile(backend, jaxpr) {
1965
2060
  let reduction;
1966
2061
  if (inputReduction) {
1967
2062
  const jv = inputReduction;
1968
- const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
1969
- exp$2 = jv.exp.reindexGids(addArgs(jv.args));
2063
+ const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
2064
+ exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
1970
2065
  reduction = new require_backend.Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
1971
2066
  } else {
1972
2067
  const ruleOutput = rule(inputExps, inputAvals, eqn.params);
1973
2068
  exp$2 = ruleOutput.exp;
1974
2069
  reduction = ruleOutput.reduction;
1975
2070
  }
1976
- const outVar = eqn.outBinders[0];
1977
- if (blackNodes.has(outVar)) {
1978
- const nargs$1 = inputArgs.length;
1979
- const size$1 = outVar.aval.size;
1980
- const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2, reduction);
1981
- const outId = builder.pushKernel(kernel, inputArgs);
1982
- ctx.set(outVar, {
1983
- type: "imm",
1984
- arg: outId
2071
+ for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
2072
+ const outVar = eqn.outBinders[i$1];
2073
+ if (blackNodes.has(outVar)) {
2074
+ const nargs$1 = inputArgs.length;
2075
+ const size$1 = outVar.aval.size;
2076
+ const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2[i$1], reduction);
2077
+ const outId = builder.pushKernel(kernel, inputArgs);
2078
+ ctx.set(outVar, {
2079
+ type: "imm",
2080
+ arg: outId
2081
+ });
2082
+ } else if (reduction) ctx.set(outVar, {
2083
+ type: "red",
2084
+ exp: exp$2[i$1],
2085
+ reduction,
2086
+ args: inputArgs
1985
2087
  });
1986
- } else if (reduction) ctx.set(outVar, {
1987
- type: "red",
1988
- exp: exp$2,
1989
- reduction,
1990
- args: inputArgs
1991
- });
1992
- else ctx.set(outVar, {
1993
- type: "exp",
1994
- exp: exp$2,
1995
- args: inputArgs
1996
- });
2088
+ else ctx.set(outVar, {
2089
+ type: "exp",
2090
+ exp: exp$2[i$1],
2091
+ args: inputArgs
2092
+ });
2093
+ }
1997
2094
  }
1998
2095
  const outputIds = [];
1999
2096
  for (const out of jaxpr.outs) if (out instanceof Var) {
@@ -2034,17 +2131,17 @@ function broadcastedJit(fn, opts) {
2034
2131
  if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = require_backend.AluExp.cast(newDtype, exp$2);
2035
2132
  return exp$2;
2036
2133
  });
2037
- return { exp: fn(exps, params) };
2134
+ return { exp: [fn(exps, params)] };
2038
2135
  };
2039
2136
  }
2040
2137
  function unopJit(fn) {
2041
2138
  return ([a], [_as], params) => {
2042
- return { exp: fn(a, params) };
2139
+ return { exp: [fn(a, params)] };
2043
2140
  };
2044
2141
  }
2045
2142
  function reshapeJit(fn) {
2046
2143
  return ([a], [_as], params) => {
2047
- return { exp: reshapeViews(a, (st) => fn(st, params)) };
2144
+ return { exp: [reshapeViews(a, (st) => fn(st, params))] };
2048
2145
  };
2049
2146
  }
2050
2147
  function routineNoJit() {
@@ -2090,7 +2187,7 @@ const jitRules = {
2090
2187
  a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
2091
2188
  const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
2092
2189
  return {
2093
- exp: a,
2190
+ exp: [a],
2094
2191
  reduction
2095
2192
  };
2096
2193
  },
@@ -2101,13 +2198,13 @@ const jitRules = {
2101
2198
  a = reshapeViews(a, (st) => st.compose(stX), true);
2102
2199
  const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
2103
2200
  return {
2104
- exp: a,
2201
+ exp: [a],
2105
2202
  reduction
2106
2203
  };
2107
2204
  },
2108
2205
  [Primitive.Dot]([a, b], [as, bs]) {
2109
2206
  const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
2110
- const c = k1.exp;
2207
+ const [c] = k1.exp;
2111
2208
  const cs = promoteAvals(as, bs);
2112
2209
  return jitRules[Primitive.Reduce]([c], [cs], {
2113
2210
  op: require_backend.AluOp.Add,
@@ -2124,16 +2221,42 @@ const jitRules = {
2124
2221
  },
2125
2222
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
2126
2223
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
2224
+ [Primitive.Concatenate](exps, avals, { axis }) {
2225
+ const ndim$2 = avals[0].ndim;
2226
+ const sizes = avals.map((x) => x.shape[axis]);
2227
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
2228
+ const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
2229
+ const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
2230
+ let cum = 0;
2231
+ const src = [];
2232
+ for (let i = 0; i < exps.length; i++) {
2233
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
2234
+ src.push(reshapeViews(require_backend.AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
2235
+ cum += sizes[i];
2236
+ }
2237
+ return { exp: [src.reduce(require_backend.AluExp.add)] };
2238
+ },
2239
+ [Primitive.Split]([a], [as], { axis, sizes }) {
2240
+ const exp$2 = [];
2241
+ let start = 0;
2242
+ for (const size$1 of sizes) {
2243
+ const slice = require_backend.range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
2244
+ exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
2245
+ start += size$1;
2246
+ }
2247
+ return { exp: exp$2 };
2248
+ },
2127
2249
  [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
2250
+ const keyShape = keyShapes[0].shape;
2128
2251
  const mapping = (st) => {
2129
- if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
2252
+ if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(st.shape.length, shape$1.length));
2130
2253
  };
2131
2254
  const k0 = reshapeViews(keys[0], mapping);
2132
2255
  const k1 = reshapeViews(keys[1], mapping);
2133
2256
  const c0 = require_backend.AluExp.u32(0);
2134
- const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
2257
+ const c1 = require_backend.AluExp.mod(require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx), require_backend.AluExp.u32(Math.max(require_backend.prod(shape$1.slice(keyShape.length)), 1)));
2135
2258
  const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
2136
- return { exp: exp$2 };
2259
+ return { exp: [exp$2] };
2137
2260
  },
2138
2261
  [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
2139
2262
  const axisSet = new Set(axis);
@@ -2148,7 +2271,7 @@ const jitRules = {
2148
2271
  for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
2149
2272
  const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
2150
2273
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
2151
- return { exp: x.substitute({ gidx: index }) };
2274
+ return { exp: [x.substitute({ gidx: index })] };
2152
2275
  },
2153
2276
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
2154
2277
  [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
@@ -2164,6 +2287,7 @@ const jitRules = {
2164
2287
  [Primitive.Argsort]: routineNoJit(),
2165
2288
  [Primitive.TriangularSolve]: routineNoJit(),
2166
2289
  [Primitive.Cholesky]: routineNoJit(),
2290
+ [Primitive.LU]: routineNoJit(),
2167
2291
  [Primitive.Jit]() {
2168
2292
  throw new Error("internal: Jit should have been flattened before JIT compilation");
2169
2293
  }
@@ -2245,7 +2369,7 @@ function splitGraphDataflow(backend, jaxpr) {
2245
2369
  p1NextBlack.set(v, v);
2246
2370
  }
2247
2371
  const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2248
- const needsCleanShapePrimitives = [Primitive.Pad];
2372
+ const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
2249
2373
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2250
2374
  const eqn = jaxpr.eqns[i];
2251
2375
  if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
@@ -2315,7 +2439,7 @@ function splitGraphDataflow(backend, jaxpr) {
2315
2439
 
2316
2440
  //#endregion
2317
2441
  //#region src/frontend/array.ts
2318
- const JsArray = globalThis.Array;
2442
+ const JsArray$1 = globalThis.Array;
2319
2443
  const inlineArrayLimit = 128;
2320
2444
  /** Version of pureArray with fudged types. */
2321
2445
  const fudgeArray = pureArray;
@@ -2442,6 +2566,10 @@ var Array$1 = class Array$1 extends Tracer {
2442
2566
  this.#rc++;
2443
2567
  return this;
2444
2568
  }
2569
+ /** Get the current reference count (for debugging memory management). */
2570
+ get refCount() {
2571
+ return this.#rc;
2572
+ }
2445
2573
  dispose() {
2446
2574
  this.#check();
2447
2575
  if (--this.#rc === 0) {
@@ -2599,7 +2727,7 @@ var Array$1 = class Array$1 extends Tracer {
2599
2727
  } else if (castDtype === void 0) {
2600
2728
  castDtype = arrays[i].#dtype;
2601
2729
  castWeakType = arrays[i].#weakType;
2602
- } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
2730
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
2603
2731
  const weakType = castWeakType && !strongTypeOutput;
2604
2732
  const { backend, committed } = Array$1.#computeBackend(name, arrays);
2605
2733
  arrays = arrays.map((ar) => ar._putSync(backend));
@@ -2709,25 +2837,35 @@ var Array$1 = class Array$1 extends Tracer {
2709
2837
  });
2710
2838
  }
2711
2839
  /** Apply an operation with custom lowering to this array. */
2712
- static #routine(routine, arrays, outputWeakType) {
2713
- const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2714
- for (const ar of arrays) ar.#realize();
2715
- const inputs = arrays.map((ar) => ar.#source);
2716
- const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(require_backend.byteWidth(dtype) * require_backend.prod(routine.type.outputShapes[i])));
2717
- const pending = arrays.flatMap((ar) => ar.#pending);
2718
- for (const exe of pending) exe.updateRc(+outputs.length);
2719
- pending.push(new PendingExecute(backend, routine, inputs, outputs));
2720
- pending[pending.length - 1].updateRc(+outputs.length - 1);
2721
- arrays.forEach((ar) => ar.dispose());
2722
- return outputs.map((output, i) => new Array$1({
2723
- source: output,
2724
- st: require_backend.ShapeTracker.fromShape(routine.type.outputShapes[i]),
2725
- dtype: routine.type.outputDtypes[i],
2726
- weakType: outputWeakType[i],
2727
- backend,
2728
- committed,
2729
- pending
2730
- }));
2840
+ static #routine(prim) {
2841
+ return (arrays, params) => {
2842
+ const { backend, committed } = Array$1.#computeBackend(prim, arrays);
2843
+ for (const ar of arrays) ar.#realize();
2844
+ const avals = arrays.map((ar) => ar.aval);
2845
+ const avalsOut = abstractEvalRules[prim](avals, params);
2846
+ const routine = new require_backend.Routine(routinePrimitives.get(prim), {
2847
+ inputShapes: avals.map((a) => a.shape),
2848
+ inputDtypes: avals.map((a) => a.dtype),
2849
+ outputShapes: avalsOut.map((a) => a.shape),
2850
+ outputDtypes: avalsOut.map((a) => a.dtype)
2851
+ }, params);
2852
+ const inputs = arrays.map((ar) => ar.#source);
2853
+ const outputs = avalsOut.map((x) => backend.malloc(require_backend.byteWidth(x.dtype) * x.size));
2854
+ const pending = arrays.flatMap((ar) => ar.#pending);
2855
+ for (const exe of pending) exe.updateRc(+outputs.length);
2856
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2857
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2858
+ arrays.forEach((ar) => ar.dispose());
2859
+ return outputs.map((output, i) => new Array$1({
2860
+ source: output,
2861
+ st: require_backend.ShapeTracker.fromShape(avalsOut[i].shape),
2862
+ dtype: avalsOut[i].dtype,
2863
+ weakType: avalsOut[i].weakType,
2864
+ backend,
2865
+ committed,
2866
+ pending
2867
+ }));
2868
+ };
2731
2869
  }
2732
2870
  /**
2733
2871
  * Normalizes this array into one backed by a `Slot`.
@@ -2992,17 +3130,44 @@ var Array$1 = class Array$1 extends Tracer {
2992
3130
  y
2993
3131
  ], { dtypeOverride: [require_backend.DType.Bool] })];
2994
3132
  },
3133
+ [Primitive.Concatenate](xs, { axis }) {
3134
+ const ndim$2 = xs[0].ndim;
3135
+ const sizes = xs.map((x) => x.shape[axis]);
3136
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
3137
+ const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
3138
+ let cum = 0;
3139
+ const xsPadded = [];
3140
+ for (let i = 0; i < xs.length; i++) {
3141
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
3142
+ xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
3143
+ cum += sizes[i];
3144
+ }
3145
+ const custom = (exps) => exps.reduce(require_backend.AluExp.add);
3146
+ return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
3147
+ },
3148
+ [Primitive.Split]([x], { axis, sizes }) {
3149
+ const outputs = [];
3150
+ for (let i = 0, start = 0; i < sizes.length; i++) {
3151
+ const slice = require_backend.range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
3152
+ outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
3153
+ start += sizes[i];
3154
+ }
3155
+ x.dispose();
3156
+ return outputs;
3157
+ },
2995
3158
  [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2996
- const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2997
- if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2998
- const c0 = zeros(shape$1, {
3159
+ const keyShape = k0.shape;
3160
+ const genShape = shape$1.slice(keyShape.length);
3161
+ const c0 = zeros(genShape, {
2999
3162
  dtype: require_backend.DType.Uint32,
3000
3163
  device: k0.device
3001
3164
  });
3002
- const c1 = arange(0, require_backend.prod(shape$1), 1, {
3165
+ const c1 = arange(0, require_backend.prod(genShape), 1, {
3003
3166
  dtype: require_backend.DType.Uint32,
3004
3167
  device: k0.device
3005
- }).reshape(shape$1);
3168
+ }).reshape(genShape);
3169
+ k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
3170
+ k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
3006
3171
  const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
3007
3172
  return [Array$1.#naryCustom("random_bits", custom, [
3008
3173
  k0,
@@ -3034,42 +3199,11 @@ var Array$1 = class Array$1 extends Tracer {
3034
3199
  [Primitive.Pad]([x], { width }) {
3035
3200
  return [x.#reshape(x.#st.pad(width))];
3036
3201
  },
3037
- [Primitive.Sort]([x]) {
3038
- const routine = new require_backend.Routine(require_backend.Routines.Sort, {
3039
- inputShapes: [x.aval.shape],
3040
- inputDtypes: [x.aval.dtype],
3041
- outputShapes: [x.aval.shape],
3042
- outputDtypes: [x.aval.dtype]
3043
- });
3044
- return Array$1.#routine(routine, [x], [x.#weakType]);
3045
- },
3046
- [Primitive.Argsort]([x]) {
3047
- const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
3048
- inputShapes: [x.aval.shape],
3049
- inputDtypes: [x.aval.dtype],
3050
- outputShapes: [x.aval.shape, x.aval.shape],
3051
- outputDtypes: [x.aval.dtype, require_backend.DType.Int32]
3052
- });
3053
- return Array$1.#routine(routine, [x], [x.#weakType, false]);
3054
- },
3055
- [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3056
- const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
3057
- inputShapes: [a.aval.shape, b.aval.shape],
3058
- inputDtypes: [a.aval.dtype, b.aval.dtype],
3059
- outputShapes: [b.aval.shape],
3060
- outputDtypes: [b.aval.dtype]
3061
- }, { unitDiagonal });
3062
- return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
3063
- },
3064
- [Primitive.Cholesky]([a]) {
3065
- const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
3066
- inputShapes: [a.aval.shape],
3067
- inputDtypes: [a.aval.dtype],
3068
- outputShapes: [a.aval.shape],
3069
- outputDtypes: [a.aval.dtype]
3070
- });
3071
- return Array$1.#routine(routine, [a], [a.#weakType]);
3072
- },
3202
+ [Primitive.Sort]: Array$1.#routine(Primitive.Sort),
3203
+ [Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
3204
+ [Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
3205
+ [Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
3206
+ [Primitive.LU]: Array$1.#routine(Primitive.LU),
3073
3207
  [Primitive.Jit](args, { jaxpr }) {
3074
3208
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3075
3209
  const { backend, committed } = Array$1.#computeBackend("jit", args);
@@ -3151,7 +3285,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3151
3285
  if (!shape$1) {
3152
3286
  shape$1 = [];
3153
3287
  let cur = values;
3154
- while (JsArray.isArray(cur)) {
3288
+ while (JsArray$1.isArray(cur)) {
3155
3289
  shape$1.push(cur.length);
3156
3290
  cur = cur[0];
3157
3291
  }
@@ -3175,7 +3309,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3175
3309
  device
3176
3310
  });
3177
3311
  } else {
3178
- const weakType = dtype == void 0;
3312
+ const weakType = dtype == void 0 && shape$1.length === 0;
3179
3313
  dtype = dtype ?? require_backend.DType.Float32;
3180
3314
  const data = require_backend.dtypedJsArray(dtype, flat);
3181
3315
  return arrayFromData(data, shape$1, {
@@ -3289,7 +3423,7 @@ function ones(shape$1, { dtype, device } = {}) {
3289
3423
  }
3290
3424
  /** Return a new array of given shape and type, filled with `fill_value`. */
3291
3425
  function full(shape$1, fillValue, { dtype, device } = {}) {
3292
- let weakType = dtype == void 0;
3426
+ let weakType = dtype == void 0 && shape$1.length === 0;
3293
3427
  if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
3294
3428
  else if (typeof fillValue === "boolean") {
3295
3429
  dtype = dtype ?? require_backend.DType.Bool;
@@ -3447,6 +3581,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
3447
3581
  committed: device != void 0
3448
3582
  });
3449
3583
  }
3584
+ /**
3585
+ * Return numbers spaced evenly on a log scale.
3586
+ *
3587
+ * In linear space, the sequence starts at `base ** start` and ends at
3588
+ * `base ** stop` (see `endpoint` below).
3589
+ *
3590
+ * @param start - `base ** start` is the starting value of the sequence.
3591
+ * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
3592
+ * @param num - Number of samples to generate. Default is 50.
3593
+ * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
3594
+ * @param base - The base of the log space. Default is 10.
3595
+ * @returns Array of evenly spaced values on a log scale.
3596
+ */
3597
+ function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
3598
+ const y = linspace(start, stop, num, endpoint, {
3599
+ dtype,
3600
+ device
3601
+ });
3602
+ const logBase = Math.log(base);
3603
+ return exp$1(mul(y, logBase));
3604
+ }
3450
3605
  function aluCompare(a, b, op) {
3451
3606
  switch (op) {
3452
3607
  case CompareOp.Less: return require_backend.AluExp.cmplt(a, b);
@@ -3524,6 +3679,7 @@ var BatchTrace = class extends Trace {
3524
3679
  return valOuts$1.map((x) => new BatchTracer(this, x, null));
3525
3680
  }
3526
3681
  const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3682
+ if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
3527
3683
  return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3528
3684
  }
3529
3685
  get axisSize() {
@@ -3535,13 +3691,13 @@ var BatchTrace = class extends Trace {
3535
3691
  *
3536
3692
  * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3537
3693
  */
3538
- function broadcastBatcher(op) {
3539
- return (axisSize, args, dims) => {
3694
+ function broadcastBatcher(prim) {
3695
+ return (axisSize, args, dims, params) => {
3540
3696
  if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3541
3697
  const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3542
3698
  const firstIdx = dims.findIndex((d) => d !== null);
3543
3699
  const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3544
- if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3700
+ if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[bind1(prim, args, params)], [nd + firstBdim]];
3545
3701
  args = args.map((x, i) => {
3546
3702
  if (dims[i] === null) return x;
3547
3703
  x = moveBatchAxis(axisSize, dims[i], 0, x);
@@ -3552,37 +3708,45 @@ function broadcastBatcher(op) {
3552
3708
  ]);
3553
3709
  return x;
3554
3710
  });
3555
- return [[op(...args)], [0]];
3711
+ return [[bind1(prim, args, params)], [0]];
3556
3712
  };
3557
3713
  }
3558
- function unopBatcher(op) {
3714
+ function unopBatcher(prim) {
3559
3715
  return (axisSize, [x], [xBdim], params) => {
3560
- return [[op(x, params)], [xBdim]];
3716
+ return [[bind1(prim, [x], params)], [xBdim]];
3717
+ };
3718
+ }
3719
+ function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
3720
+ return (axisSize, [x], [xBdim], params) => {
3721
+ require_backend.assertNonNull(xBdim);
3722
+ if (xBdim < x.ndim - inputDims) return [bind(prim, [x], params), require_backend.rep(numOutputs, xBdim)];
3723
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3724
+ return [bind(prim, [x], params), require_backend.rep(numOutputs, 0)];
3561
3725
  };
3562
3726
  }
3563
3727
  const vmapRules = {
3564
- [Primitive.Add]: broadcastBatcher(add$1),
3565
- [Primitive.Mul]: broadcastBatcher(mul),
3566
- [Primitive.Idiv]: broadcastBatcher(idiv),
3567
- [Primitive.Mod]: broadcastBatcher(mod),
3568
- [Primitive.Min]: broadcastBatcher(min$1),
3569
- [Primitive.Max]: broadcastBatcher(max$1),
3570
- [Primitive.Neg]: unopBatcher(neg),
3571
- [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3572
- [Primitive.Floor]: unopBatcher(floor$1),
3573
- [Primitive.Ceil]: unopBatcher(ceil$1),
3574
- [Primitive.StopGradient]: unopBatcher(stopGradient),
3575
- [Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
3576
- [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3577
- [Primitive.Sin]: unopBatcher(sin$1),
3578
- [Primitive.Cos]: unopBatcher(cos$1),
3579
- [Primitive.Asin]: unopBatcher(asin$1),
3580
- [Primitive.Atan]: unopBatcher(atan$1),
3581
- [Primitive.Exp]: unopBatcher(exp$1),
3582
- [Primitive.Log]: unopBatcher(log$1),
3583
- [Primitive.Erf]: unopBatcher(erf$1),
3584
- [Primitive.Erfc]: unopBatcher(erfc$1),
3585
- [Primitive.Sqrt]: unopBatcher(sqrt$1),
3728
+ [Primitive.Add]: broadcastBatcher(Primitive.Add),
3729
+ [Primitive.Mul]: broadcastBatcher(Primitive.Mul),
3730
+ [Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
3731
+ [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3732
+ [Primitive.Min]: broadcastBatcher(Primitive.Min),
3733
+ [Primitive.Max]: broadcastBatcher(Primitive.Max),
3734
+ [Primitive.Neg]: unopBatcher(Primitive.Neg),
3735
+ [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3736
+ [Primitive.Floor]: unopBatcher(Primitive.Floor),
3737
+ [Primitive.Ceil]: unopBatcher(Primitive.Ceil),
3738
+ [Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
3739
+ [Primitive.Cast]: unopBatcher(Primitive.Cast),
3740
+ [Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
3741
+ [Primitive.Sin]: unopBatcher(Primitive.Sin),
3742
+ [Primitive.Cos]: unopBatcher(Primitive.Cos),
3743
+ [Primitive.Asin]: unopBatcher(Primitive.Asin),
3744
+ [Primitive.Atan]: unopBatcher(Primitive.Atan),
3745
+ [Primitive.Exp]: unopBatcher(Primitive.Exp),
3746
+ [Primitive.Log]: unopBatcher(Primitive.Log),
3747
+ [Primitive.Erf]: unopBatcher(Primitive.Erf),
3748
+ [Primitive.Erfc]: unopBatcher(Primitive.Erfc),
3749
+ [Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
3586
3750
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3587
3751
  require_backend.assertNonNull(xBdim);
3588
3752
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
@@ -3604,10 +3768,25 @@ const vmapRules = {
3604
3768
  });
3605
3769
  return [[z], [0]];
3606
3770
  },
3607
- [Primitive.Compare](axisSize, args, dims, { op }) {
3608
- return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3771
+ [Primitive.Compare]: broadcastBatcher(Primitive.Compare),
3772
+ [Primitive.Where]: broadcastBatcher(Primitive.Where),
3773
+ [Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
3774
+ const minBdim = Math.min(...xBdims.filter((d) => d !== null));
3775
+ xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
3776
+ const newAxis = axis + (minBdim <= axis ? 1 : 0);
3777
+ return [[concatenate$1(xs, newAxis)], [minBdim]];
3778
+ },
3779
+ [Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
3780
+ require_backend.assertNonNull(xBdim);
3781
+ const newAxis = axis + (xBdim <= axis ? 1 : 0);
3782
+ const outs = split$2(x, newAxis, sizes);
3783
+ return [outs, require_backend.rep(outs.length, xBdim)];
3784
+ },
3785
+ [Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
3786
+ k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
3787
+ k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
3788
+ return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
3609
3789
  },
3610
- [Primitive.Where]: broadcastBatcher(where$1),
3611
3790
  [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3612
3791
  if (indicesBdim.every((d) => d === null)) {
3613
3792
  require_backend.assertNonNull(xBdim);
@@ -3669,18 +3848,8 @@ const vmapRules = {
3669
3848
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3670
3849
  return [[pad$1(x, newWidth)], [xBdim]];
3671
3850
  },
3672
- [Primitive.Sort](axisSize, [x], [xBdim]) {
3673
- require_backend.assertNonNull(xBdim);
3674
- if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
3675
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3676
- return [[sort$1(x)], [0]];
3677
- },
3678
- [Primitive.Argsort](axisSize, [x], [xBdim]) {
3679
- require_backend.assertNonNull(xBdim);
3680
- if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
3681
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3682
- return [argsort$1(x), [0, 0]];
3683
- },
3851
+ [Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
3852
+ [Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
3684
3853
  [Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
3685
3854
  if (aBdim === null) {
3686
3855
  b = moveBatchAxis(axisSize, bBdim, -3, b);
@@ -3704,12 +3873,8 @@ const vmapRules = {
3704
3873
  const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3705
3874
  return [[x], [0]];
3706
3875
  },
3707
- [Primitive.Cholesky](axisSize, [x], [xBdim]) {
3708
- require_backend.assertNonNull(xBdim);
3709
- if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
3710
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3711
- return [[cholesky$2(x)], [0]];
3712
- },
3876
+ [Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
3877
+ [Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
3713
3878
  [Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
3714
3879
  const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
3715
3880
  const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
@@ -3860,6 +4025,16 @@ function batchMatmulT(a, b) {
3860
4025
  function mT(a) {
3861
4026
  return moveaxis(a, -2, -1);
3862
4027
  }
4028
+ function sliceAxis(a, axis, p) {
4029
+ const slices = Array(a.shape.length).fill([]);
4030
+ slices[require_backend.checkAxis(axis, a.ndim)] = p;
4031
+ return a.slice(...slices);
4032
+ }
4033
+ function padAxis(a, axis, p) {
4034
+ const pads = Array(a.shape.length).fill([0, 0]);
4035
+ pads[require_backend.checkAxis(axis, a.ndim)] = p;
4036
+ return pad$1(a, pads);
4037
+ }
3863
4038
  const jvpRules = {
3864
4039
  [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3865
4040
  [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
@@ -3958,6 +4133,8 @@ const jvpRules = {
3958
4133
  dcond.dispose();
3959
4134
  return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3960
4135
  },
4136
+ [Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
4137
+ [Primitive.Split]: linearTangentsJvp(Primitive.Split),
3961
4138
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3962
4139
  [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3963
4140
  const indicesRef = indices.map((t) => t.ref);
@@ -3992,6 +4169,38 @@ const jvpRules = {
3992
4169
  const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
3993
4170
  return [[L], [dL]];
3994
4171
  },
4172
+ [Primitive.LU]([a], [da]) {
4173
+ const [luMatrix, pivots, permutation] = lu$1(a);
4174
+ const [m, n] = a.shape.slice(-2);
4175
+ const k = Math.min(m, n);
4176
+ const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
4177
+ const lLower = tril(luSliceL, -1);
4178
+ const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
4179
+ const L = lPadded.add(eye(m));
4180
+ const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
4181
+ const uUpper = triu(luSliceU);
4182
+ const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
4183
+ const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
4184
+ const U = uPadded.add(uEye);
4185
+ const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
4186
+ const pda = batchMatmulT(P, mT(da));
4187
+ const la = mT(triangularSolve$1(L.ref, mT(pda), {
4188
+ lower: true,
4189
+ unitDiagonal: true
4190
+ }));
4191
+ const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
4192
+ const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
4193
+ const uDot = batchMatmulT(triu(lau), mT(U));
4194
+ return [[
4195
+ luMatrix,
4196
+ pivots,
4197
+ permutation
4198
+ ], [
4199
+ lDot.add(uDot),
4200
+ zerosLike$1(pivots.ref),
4201
+ zerosLike$1(permutation.ref)
4202
+ ]];
4203
+ },
3995
4204
  [Primitive.Jit](primals, tangents, { name, jaxpr }) {
3996
4205
  const newJaxpr = jvpJaxpr(jaxpr);
3997
4206
  const outs = bind(Primitive.Jit, [
@@ -4032,17 +4241,39 @@ function jvpFlat(f, primals, tangents) {
4032
4241
  _usingCtx$1.d();
4033
4242
  }
4034
4243
  }
4035
- function jvp$1(f, primals, tangents) {
4244
+ function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
4036
4245
  const [primalsFlat, inTree] = flatten(primals);
4037
4246
  const [tangentsFlat, inTree2] = flatten(tangents);
4038
4247
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
4039
- const [flatFun, outTree] = flattenFun(f, inTree);
4248
+ let flatFun, outTree, aux;
4249
+ if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
4250
+ else [flatFun, outTree] = flattenFun(f, inTree);
4040
4251
  const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
4041
4252
  if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
4042
4253
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
4043
4254
  const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
4255
+ if (hasAux) return [
4256
+ primalsOut,
4257
+ tangentsOut,
4258
+ lowerAux(aux.value)
4259
+ ];
4044
4260
  return [primalsOut, tangentsOut];
4045
4261
  }
4262
+ /** Lowering for auxiliary data returned in `hasAux: true` methods. */
4263
+ function lowerAux(aux) {
4264
+ const level = currentTraceLevel();
4265
+ return map((x) => {
4266
+ if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
4267
+ x.tangent.dispose();
4268
+ x = x.primal;
4269
+ } else {
4270
+ const y = x.fullLower();
4271
+ if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
4272
+ x = y;
4273
+ }
4274
+ return x;
4275
+ }, aux);
4276
+ }
4046
4277
 
4047
4278
  //#endregion
4048
4279
  //#region src/frontend/linearize.ts
@@ -4113,9 +4344,11 @@ function linearizeFlat(f, primalsIn) {
4113
4344
  dispose$1
4114
4345
  ];
4115
4346
  }
4116
- function linearize$1(f, ...primalsIn) {
4347
+ function linearize$1(f, primalsIn, { hasAux = false } = {}) {
4117
4348
  const [primalsInFlat, inTree] = flatten(primalsIn);
4118
- const [fFlat, outTree] = flattenFun(f, inTree);
4349
+ let fFlat, outTree, aux;
4350
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4351
+ else [fFlat, outTree] = flattenFun(f, inTree);
4119
4352
  const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
4120
4353
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
4121
4354
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4126,6 +4359,11 @@ function linearize$1(f, ...primalsIn) {
4126
4359
  return unflatten(outTree.value, tangentsOutFlat);
4127
4360
  });
4128
4361
  fLin.dispose = dispose$1;
4362
+ if (hasAux) return [
4363
+ primalsOut,
4364
+ fLin,
4365
+ lowerAux(aux.value)
4366
+ ];
4129
4367
  return [primalsOut, fLin];
4130
4368
  }
4131
4369
  var PartialEvalTracer = class extends Tracer {
@@ -4529,6 +4767,15 @@ const transposeRules = {
4529
4767
  cond.dispose();
4530
4768
  return cts;
4531
4769
  },
4770
+ [Primitive.Concatenate]([ct], inputs, { axis }) {
4771
+ if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
4772
+ const sizes = inputs.map((x) => x.aval.shape[axis]);
4773
+ return split$2(ct, axis, sizes);
4774
+ },
4775
+ [Primitive.Split](cts, [x], { axis }) {
4776
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
4777
+ return [concatenate$1(cts, axis)];
4778
+ },
4532
4779
  [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4533
4780
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4534
4781
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
@@ -4617,9 +4864,11 @@ function vjpFlat(f, primalsIn) {
4617
4864
  dispose$1
4618
4865
  ];
4619
4866
  }
4620
- function vjp$1(f, ...primalsIn) {
4867
+ function vjp$1(f, primalsIn, { hasAux = false } = {}) {
4621
4868
  const [primalsInFlat, inTree] = flatten(primalsIn);
4622
- const [fFlat, outTree] = flattenFun(f, inTree);
4869
+ let fFlat, outTree, aux;
4870
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4871
+ else [fFlat, outTree] = flattenFun(f, inTree);
4623
4872
  const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4624
4873
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4625
4874
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4630,26 +4879,43 @@ function vjp$1(f, ...primalsIn) {
4630
4879
  return unflatten(inTree, cotangentsInFlat);
4631
4880
  });
4632
4881
  fVjp.dispose = dispose$1;
4882
+ if (hasAux) return [
4883
+ primalsOut,
4884
+ fVjp,
4885
+ lowerAux(aux.value)
4886
+ ];
4633
4887
  return [primalsOut, fVjp];
4634
4888
  }
4635
- function grad$1(f) {
4636
- const valueAndGradFn = valueAndGrad$1(f);
4889
+ function grad$1(f, opts) {
4890
+ const valueAndGradFn = valueAndGrad$1(f, opts);
4637
4891
  return (...x) => {
4638
- const [y, dx] = valueAndGradFn(...x);
4639
- y.dispose();
4640
- return dx;
4892
+ if (opts?.hasAux) {
4893
+ const [[y, aux], dx] = valueAndGradFn(...x);
4894
+ y.dispose();
4895
+ return [dx, aux];
4896
+ } else {
4897
+ const [y, dx] = valueAndGradFn(...x);
4898
+ y.dispose();
4899
+ return dx;
4900
+ }
4641
4901
  };
4642
4902
  }
4643
- function valueAndGrad$1(f) {
4903
+ function valueAndGrad$1(f, opts) {
4904
+ const argnums = opts?.argnums ?? 0;
4905
+ const hasAux = opts?.hasAux ?? false;
4906
+ require_backend.checkInts(argnums);
4907
+ const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
4644
4908
  return (...x) => {
4645
4909
  if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4646
- const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4910
+ for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
4911
+ const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
4647
4912
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4648
4913
  if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4649
- const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4650
- for (const r of rest) dispose(r);
4914
+ const cts = fVjp(onesLike$1(y.ref));
4651
4915
  fVjp.dispose();
4652
- return [y, ct];
4916
+ for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
4917
+ const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
4918
+ return hasAux ? [[y, aux], grads] : [y, grads];
4653
4919
  };
4654
4920
  }
4655
4921
  function jacrev$1(f) {
@@ -4657,7 +4923,7 @@ function jacrev$1(f) {
4657
4923
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4658
4924
  const [size$1] = x.shape;
4659
4925
  const pullback = (ct) => {
4660
- const [y, fVjp] = vjp$1(f, x);
4926
+ const [y, fVjp] = vjp$1(f, [x]);
4661
4927
  y.dispose();
4662
4928
  const [ret] = fVjp(ct);
4663
4929
  fVjp.dispose();
@@ -4666,6 +4932,9 @@ function jacrev$1(f) {
4666
4932
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4667
4933
  };
4668
4934
  }
4935
+ function hessian$1(f) {
4936
+ return jacfwd$1(grad$1(f));
4937
+ }
4669
4938
 
4670
4939
  //#endregion
4671
4940
  //#region src/library/numpy/einsum.ts
@@ -4804,8 +5073,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
4804
5073
  const idx = lhsIndex[j];
4805
5074
  const dim = shape$1[j];
4806
5075
  const existing = sizeMap.get(idx);
4807
- if (existing === void 0) sizeMap.set(idx, dim);
4808
- else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
5076
+ if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
5077
+ else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
4809
5078
  }
4810
5079
  }
4811
5080
  for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
@@ -4961,27 +5230,53 @@ function ifft(a, axis = -1) {
4961
5230
  //#region src/library/numpy-linalg.ts
4962
5231
  var numpy_linalg_exports = {};
4963
5232
  __export(numpy_linalg_exports, {
4964
- cholesky: () => cholesky$1,
5233
+ cholesky: () => cholesky,
5234
+ det: () => det,
4965
5235
  diagonal: () => diagonal,
5236
+ inv: () => inv,
4966
5237
  lstsq: () => lstsq,
4967
5238
  matmul: () => matmul,
5239
+ matrixPower: () => matrixPower,
4968
5240
  matrixTranspose: () => matrixTranspose,
4969
5241
  outer: () => outer,
5242
+ slogdet: () => slogdet,
5243
+ solve: () => solve,
4970
5244
  tensordot: () => tensordot,
4971
5245
  trace: () => trace,
4972
5246
  vecdot: () => vecdot
4973
5247
  });
5248
+ function checkSquare(name, a) {
5249
+ if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
5250
+ return a.shape[a.ndim - 1];
5251
+ }
4974
5252
  /**
4975
5253
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
4976
5254
  *
4977
5255
  * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
4978
5256
  * the input matrix, which is on by default.
4979
5257
  */
4980
- function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
5258
+ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
4981
5259
  a = fudgeArray(a);
4982
- if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`cholesky: input must be at least 2D square matrix, got ${a.aval}`);
5260
+ checkSquare("cholesky", a);
4983
5261
  if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
4984
- return cholesky(a, { upper });
5262
+ return cholesky$1(a, { upper });
5263
+ }
5264
+ /** Compute the determinant of a square matrix (batched). */
5265
+ function det(a) {
5266
+ a = fudgeArray(a);
5267
+ const n = checkSquare("det", a);
5268
+ const [lu$2, pivots, permutation] = lu(a);
5269
+ permutation.dispose();
5270
+ const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
5271
+ const sign$1 = parity.mul(-2).add(1);
5272
+ const diag$1 = lu$2.diagonal(0, -1, -2);
5273
+ return prod$1(diag$1, -1).mul(sign$1);
5274
+ }
5275
+ /** Compute the inverse of a square matrix (batched). */
5276
+ function inv(a) {
5277
+ a = fudgeArray(a);
5278
+ const n = checkSquare("inv", a);
5279
+ return solve(a, eye(n));
4985
5280
  }
4986
5281
  /**
4987
5282
  * Return the least-squares solution to a linear equation.
@@ -5005,7 +5300,7 @@ function lstsq(a, b) {
5005
5300
  const at = matrixTranspose(a.ref);
5006
5301
  if (m <= n) {
5007
5302
  const aat = matmul(a, at.ref);
5008
- const l = cholesky$1(aat, { symmetrizeInput: false });
5303
+ const l = cholesky(aat, { symmetrizeInput: false });
5009
5304
  const lb = triangularSolve(l.ref, b, {
5010
5305
  leftSide: true,
5011
5306
  lower: true
@@ -5017,7 +5312,7 @@ function lstsq(a, b) {
5017
5312
  return matmul(at, llb.ref);
5018
5313
  } else {
5019
5314
  const ata = matmul(at.ref, a);
5020
- const l = cholesky$1(ata, { symmetrizeInput: false });
5315
+ const l = cholesky(ata, { symmetrizeInput: false });
5021
5316
  const atb = matmul(at, b);
5022
5317
  const lb = triangularSolve(l.ref, atb, {
5023
5318
  leftSide: true,
@@ -5030,6 +5325,169 @@ function lstsq(a, b) {
5030
5325
  return llb;
5031
5326
  }
5032
5327
  }
5328
+ /** Raise a square matrix to an integer power, via repeated squarings. */
5329
+ function matrixPower(a, n) {
5330
+ if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
5331
+ a = fudgeArray(a);
5332
+ const m = checkSquare("matrixPower", a);
5333
+ if (n === 0) {
5334
+ a.dispose();
5335
+ return broadcastTo(eye(m), a.shape);
5336
+ }
5337
+ if (n < 0) {
5338
+ a = inv(a);
5339
+ n = -n;
5340
+ }
5341
+ let result = null;
5342
+ let a2k = a;
5343
+ for (let k = 0; n; k++) {
5344
+ if (k > 0) a2k = matmul(a2k.ref, a2k);
5345
+ if (n % 2 === 1) result = result === null ? a2k.ref : matmul(result, a2k.ref);
5346
+ n = Math.floor(n / 2);
5347
+ }
5348
+ a2k.dispose();
5349
+ return result;
5350
+ }
5351
+ /** Return sign and natural logarithm of the determinant of `a`. */
5352
+ function slogdet(a) {
5353
+ a = fudgeArray(a);
5354
+ const n = checkSquare("slogdet", a);
5355
+ const [lu$2, pivots, permutation] = lu(a);
5356
+ permutation.dispose();
5357
+ let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
5358
+ const diag$1 = lu$2.diagonal(0, -1, -2);
5359
+ parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
5360
+ const logabsdet = log(absolute(diag$1)).sum(-1);
5361
+ const sign$1 = parity.mul(-2).add(1);
5362
+ return [sign$1, logabsdet];
5363
+ }
5364
+ /**
5365
+ * Solve a linear system of equations.
5366
+ *
5367
+ * This solves a (batched) linear system of equations `a @ x = b` for `x` given
5368
+ * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
5369
+ *
5370
+ * @param a - Coefficient matrix of shape `(..., N, N)`.
5371
+ * @param b - Values of shape `(N,)` or `(..., N, M)`.
5372
+ * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
5373
+ */
5374
+ function solve(a, b) {
5375
+ a = fudgeArray(a);
5376
+ b = fudgeArray(b);
5377
+ const n = checkSquare("solve", a);
5378
+ if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
5379
+ const bIs1d = b.ndim === 1;
5380
+ if (bIs1d) b = b.reshape([...b.shape, 1]);
5381
+ if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
5382
+ const m = b.shape[b.ndim - 1];
5383
+ const batchDims = require_backend.generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
5384
+ a = broadcastTo(a, [
5385
+ ...batchDims,
5386
+ n,
5387
+ n
5388
+ ]);
5389
+ b = broadcastTo(b, [
5390
+ ...batchDims,
5391
+ n,
5392
+ m
5393
+ ]);
5394
+ const [lu$2, pivots, permutation] = lu(a);
5395
+ pivots.dispose();
5396
+ const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
5397
+ const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
5398
+ leftSide: true,
5399
+ lower: true,
5400
+ unitDiagonal: true
5401
+ });
5402
+ let x = triangularSolve(lu$2, LPb.ref, {
5403
+ leftSide: true,
5404
+ lower: false
5405
+ });
5406
+ if (bIs1d) x = squeeze(x, -1);
5407
+ return x;
5408
+ }
5409
+
5410
+ //#endregion
5411
+ //#region src/library/numpy/dtype-info.ts
5412
+ /** Machine limits for floating-point types. */
5413
+ function finfo(dtype) {
5414
+ if (!require_backend.isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
5415
+ switch (dtype) {
5416
+ case require_backend.DType.Float16: return Object.freeze({
5417
+ bits: 16,
5418
+ dtype: require_backend.DType.Float16,
5419
+ eps: 2 ** -10,
5420
+ epsneg: 2 ** -11,
5421
+ machep: -10,
5422
+ max: 65504,
5423
+ maxexp: 16,
5424
+ min: -65504,
5425
+ minexp: -14,
5426
+ negep: -24,
5427
+ nexp: 5,
5428
+ nmant: 10,
5429
+ precision: 3,
5430
+ resolution: .001,
5431
+ smallestNormal: 2 ** -14,
5432
+ smallestSubnormal: 2 ** -24
5433
+ });
5434
+ case require_backend.DType.Float32: return Object.freeze({
5435
+ bits: 32,
5436
+ dtype: require_backend.DType.Float32,
5437
+ eps: 2 ** -23,
5438
+ epsneg: 2 ** -24,
5439
+ machep: -23,
5440
+ max: 34028234663852886e22,
5441
+ maxexp: 128,
5442
+ min: -34028234663852886e22,
5443
+ minexp: -126,
5444
+ negep: -24,
5445
+ nexp: 8,
5446
+ nmant: 23,
5447
+ precision: 6,
5448
+ resolution: 1e-6,
5449
+ smallestNormal: 2 ** -126,
5450
+ smallestSubnormal: 2 ** -149
5451
+ });
5452
+ case require_backend.DType.Float64: return Object.freeze({
5453
+ bits: 64,
5454
+ dtype: require_backend.DType.Float64,
5455
+ eps: 2 ** -52,
5456
+ epsneg: 2 ** -53,
5457
+ machep: -52,
5458
+ max: Number.MAX_VALUE,
5459
+ maxexp: 1024,
5460
+ min: -Number.MAX_VALUE,
5461
+ minexp: -1022,
5462
+ negep: -53,
5463
+ nexp: 11,
5464
+ nmant: 52,
5465
+ precision: 15,
5466
+ resolution: 1e-15,
5467
+ smallestNormal: 2 ** -1022,
5468
+ smallestSubnormal: 2 ** -1074
5469
+ });
5470
+ default: throw new Error(`finfo: unsupported dtype ${dtype}`);
5471
+ }
5472
+ }
5473
+ /** Machine limits for integer types. */
5474
+ function iinfo(dtype) {
5475
+ switch (dtype) {
5476
+ case require_backend.DType.Int32: return Object.freeze({
5477
+ bits: 32,
5478
+ dtype: require_backend.DType.Int32,
5479
+ max: 2147483647,
5480
+ min: -2147483648
5481
+ });
5482
+ case require_backend.DType.Uint32: return Object.freeze({
5483
+ bits: 32,
5484
+ dtype: require_backend.DType.Uint32,
5485
+ max: 4294967295,
5486
+ min: 0
5487
+ });
5488
+ default: throw new Error(`iinfo: unsupported dtype ${dtype}`);
5489
+ }
5490
+ }
5033
5491
 
5034
5492
  //#endregion
5035
5493
  //#region src/library/numpy.ts
@@ -5085,6 +5543,7 @@ __export(numpy_exports, {
5085
5543
  diag: () => diag,
5086
5544
  diagonal: () => diagonal,
5087
5545
  divide: () => trueDivide,
5546
+ divmod: () => divmod,
5088
5547
  dot: () => dot$1,
5089
5548
  dstack: () => dstack,
5090
5549
  e: () => e,
@@ -5097,6 +5556,7 @@ __export(numpy_exports, {
5097
5556
  expm1: () => expm1,
5098
5557
  eye: () => eye,
5099
5558
  fft: () => numpy_fft_exports,
5559
+ finfo: () => finfo,
5100
5560
  flip: () => flip,
5101
5561
  fliplr: () => fliplr,
5102
5562
  flipud: () => flipud,
@@ -5104,6 +5564,7 @@ __export(numpy_exports, {
5104
5564
  float32: () => float32,
5105
5565
  float64: () => float64,
5106
5566
  floor: () => floor,
5567
+ floorDivide: () => floorDivide,
5107
5568
  fmod: () => fmod,
5108
5569
  frexp: () => frexp,
5109
5570
  full: () => full,
@@ -5116,6 +5577,7 @@ __export(numpy_exports, {
5116
5577
  hstack: () => hstack,
5117
5578
  hypot: () => hypot,
5118
5579
  identity: () => identity$1,
5580
+ iinfo: () => iinfo,
5119
5581
  inf: () => inf,
5120
5582
  inner: () => inner,
5121
5583
  int32: () => int32,
@@ -5133,6 +5595,7 @@ __export(numpy_exports, {
5133
5595
  log10: () => log10,
5134
5596
  log1p: () => log1p,
5135
5597
  log2: () => log2,
5598
+ logspace: () => logspace,
5136
5599
  matmul: () => matmul,
5137
5600
  matrixTranspose: () => matrixTranspose,
5138
5601
  max: () => max,
@@ -5169,9 +5632,11 @@ __export(numpy_exports, {
5169
5632
  shape: () => shape,
5170
5633
  sign: () => sign,
5171
5634
  sin: () => sin,
5635
+ sinc: () => sinc,
5172
5636
  sinh: () => sinh,
5173
5637
  size: () => size,
5174
5638
  sort: () => sort,
5639
+ split: () => split$1,
5175
5640
  sqrt: () => sqrt,
5176
5641
  square: () => square,
5177
5642
  squeeze: () => squeeze,
@@ -5179,6 +5644,8 @@ __export(numpy_exports, {
5179
5644
  std: () => std,
5180
5645
  subtract: () => subtract,
5181
5646
  sum: () => sum,
5647
+ swapaxes: () => swapaxes,
5648
+ take: () => take,
5182
5649
  tan: () => tan,
5183
5650
  tanh: () => tanh,
5184
5651
  tensordot: () => tensordot,
@@ -5437,6 +5904,45 @@ function flip(x, axis = null) {
5437
5904
  return flip$1(x, axis);
5438
5905
  }
5439
5906
  /**
5907
+ * Split an array into multiple sub-arrays along an axis.
5908
+ *
5909
+ * @param a - The input array to split.
5910
+ * @param indicesOrSections - If an integer, it indicates the number of equal
5911
+ * sections to create along the specified axis. If a list of integers, it
5912
+ * specifies the indices at which to split the array.
5913
+ * @param axis - The axis along which to split the array. Default is 0.
5914
+ */
5915
+ function split$1(a, indicesOrSections, axis = 0) {
5916
+ a = fudgeArray(a);
5917
+ axis = require_backend.checkAxis(axis, a.ndim);
5918
+ const size$1 = a.shape[axis];
5919
+ let sizes;
5920
+ if (typeof indicesOrSections === "number") {
5921
+ if (size$1 % indicesOrSections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${indicesOrSections} equal parts`);
5922
+ const partSize = size$1 / indicesOrSections;
5923
+ sizes = require_backend.rep(indicesOrSections, partSize);
5924
+ } else {
5925
+ const indices = indicesOrSections;
5926
+ sizes = [indices[0]];
5927
+ for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
5928
+ sizes.push(size$1 - indices[indices.length - 1]);
5929
+ }
5930
+ const results = [];
5931
+ for (let i = 0; i < sizes.length; i += 7) if (i === sizes.length) {
5932
+ results.push(a);
5933
+ break;
5934
+ } else if (i + 8 >= sizes.length) {
5935
+ results.push(...split$2(a, axis, sizes.slice(i)));
5936
+ break;
5937
+ } else {
5938
+ const groupSizes = [...sizes.slice(i, i + 7), sizes.slice(i + 7).reduce((x, y) => x + y, 0)];
5939
+ const outs = split$2(a, axis, groupSizes);
5940
+ results.push(...outs.slice(0, -1));
5941
+ a = outs[outs.length - 1];
5942
+ }
5943
+ return results;
5944
+ }
5945
+ /**
5440
5946
  * Join a sequence of arrays along an existing axis.
5441
5947
  *
5442
5948
  * The arrays must have the same shape, except in the dimension corresponding to
@@ -5448,13 +5954,11 @@ function concatenate(xs, axis = 0) {
5448
5954
  if (xs.length === 0) throw new Error("Need at least one array to concatenate");
5449
5955
  const shapes = xs.map(shape);
5450
5956
  axis = require_backend.checkAxis(axis, shapes[0].length);
5451
- for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays with shapes ${JSON.stringify(shapes)} along axis ${axis}`);
5452
- const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
5957
+ for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
5453
5958
  let result = xs[0];
5454
- for (let i = 1; i < xs.length; i++) {
5455
- const len1 = result.shape[axis];
5456
- const len2 = shapes[i][axis];
5457
- result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
5959
+ for (let i = 1; i < xs.length; i += 7) {
5960
+ const group = xs.slice(i, i + 7);
5961
+ result = concatenate$1([result, ...group], axis);
5458
5962
  }
5459
5963
  return result;
5460
5964
  }
@@ -5539,6 +6043,17 @@ function flipud(x) {
5539
6043
  function fliplr(x) {
5540
6044
  return flip(x, 1);
5541
6045
  }
6046
+ /** Interchange two axes of an array. */
6047
+ function swapaxes(a, axis1, axis2) {
6048
+ a = fudgeArray(a);
6049
+ axis1 = require_backend.checkAxis(axis1, a.ndim);
6050
+ axis2 = require_backend.checkAxis(axis2, a.ndim);
6051
+ if (axis1 === axis2) return a;
6052
+ const perm = require_backend.range(a.ndim);
6053
+ perm[axis1] = axis2;
6054
+ perm[axis2] = axis1;
6055
+ return transpose(a, perm);
6056
+ }
5542
6057
  /** Transpose the last two dimensions of an array. */
5543
6058
  function matrixTranspose(a) {
5544
6059
  if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
@@ -5706,6 +6221,20 @@ function sort(a, axis = -1) {
5706
6221
  function argsort(a, axis = -1) {
5707
6222
  return fudgeArray(a).argsort(axis);
5708
6223
  }
6224
+ /**
6225
+ * Take elements from an array along an axis.
6226
+ *
6227
+ * This is equivalent to advanced indexing with integer indices over that
6228
+ * numbered axis. By default, the flattened array is used.
6229
+ */
6230
+ function take(a, indices, axis = null) {
6231
+ if (axis === null) {
6232
+ a = ravel(a);
6233
+ axis = 0;
6234
+ }
6235
+ axis = require_backend.checkAxis(axis, ndim(a));
6236
+ return gather(a, [indices], [axis], axis);
6237
+ }
5709
6238
  /** Return if two arrays are element-wise equal within a tolerance. */
5710
6239
  function allclose(actual, expected, options) {
5711
6240
  const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
@@ -6025,6 +6554,20 @@ function tan(x) {
6025
6554
  x = fudgeArray(x);
6026
6555
  return sin(x.ref).div(cos(x));
6027
6556
  }
6557
+ /**
6558
+ * @function
6559
+ * Return the normalized sinc function.
6560
+ *
6561
+ * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
6562
+ * This is the normalized sinc function commonly used in signal processing.
6563
+ *
6564
+ * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
6565
+ * requires a custom JVP rule to handle properly (see JAX implementation).
6566
+ */
6567
+ const sinc = jit$1(function sinc$1(x) {
6568
+ const pix = x.ref.mul(Math.PI);
6569
+ return where(equal(x, 0), 1, sin(pix.ref).div(pix));
6570
+ });
6028
6571
  /** Element-wise inverse cosine function (inverse of cos). */
6029
6572
  function acos(x) {
6030
6573
  return subtract(pi / 2, asin(x));
@@ -6077,6 +6620,25 @@ function trueDivide(x, y) {
6077
6620
  return x.div(y);
6078
6621
  }
6079
6622
  /**
6623
+ * Return the largest integer smaller or equal to the division of the inputs.
6624
+ *
6625
+ * The result is always rounded towards negative infinity.
6626
+ *
6627
+ * For floating-point inputs, this is equivalent to `floor(x / y)`.
6628
+ * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
6629
+ * negative values correctly (note: may overflow near int32 boundaries).
6630
+ *
6631
+ * @param x - Dividend array.
6632
+ * @param y - Divisor array.
6633
+ * @returns Element-wise floor division of x by y.
6634
+ */
6635
+ function floorDivide(x, y) {
6636
+ x = fudgeArray(x);
6637
+ y = fudgeArray(y);
6638
+ if (require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype)) return floor(trueDivide(x, y));
6639
+ return subtract(x, remainder(x.ref, y.ref)).div(y);
6640
+ }
6641
+ /**
6080
6642
  * @function
6081
6643
  * Calculate element-wise floating-point modulo operation.
6082
6644
  */
@@ -6090,6 +6652,20 @@ const fmod = jit$1(function fmod$1(x, y) {
6090
6652
  const remainder = jit$1(function remainder$1(x, y) {
6091
6653
  return mod(mod(x, y.ref).add(y.ref), y);
6092
6654
  });
6655
+ /**
6656
+ * Return element-wise quotient and remainder simultaneously.
6657
+ *
6658
+ * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
6659
+ *
6660
+ * @param x - Dividend array.
6661
+ * @param y - Divisor array.
6662
+ * @returns Tuple of [quotient, remainder].
6663
+ */
6664
+ function divmod(x, y) {
6665
+ const xArr = fudgeArray(x);
6666
+ const yArr = fudgeArray(y);
6667
+ return [floorDivide(xArr.ref, yArr.ref), remainder(xArr, yArr)];
6668
+ }
6093
6669
  /** Round input to the nearest integer towards zero. */
6094
6670
  function trunc(x) {
6095
6671
  return idiv(x, 1);
@@ -6253,14 +6829,15 @@ function std(x, axis = null, opts) {
6253
6829
  return sqrt(var_(x, axis, opts));
6254
6830
  }
6255
6831
  /** Estimate the sample covariance of a set of variables. */
6256
- function cov(x, y) {
6832
+ function cov(x, y = null, { rowvar = true } = {}) {
6257
6833
  x = fudgeArray(x);
6258
6834
  if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
6259
- if (y !== void 0) {
6835
+ if (y !== null) {
6260
6836
  y = fudgeArray(y);
6261
6837
  if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
6262
6838
  x = vstack([x, y]);
6263
6839
  }
6840
+ if (!rowvar) x = x.transpose();
6264
6841
  const [_M, N] = x.shape;
6265
6842
  x = x.ref.sub(x.mean(1, { keepdims: true }));
6266
6843
  return dot$1(x.ref, x.transpose()).div(N - 1);
@@ -6305,7 +6882,8 @@ const isfinite = jit$1(function isfinite$1(x) {
6305
6882
  //#region src/library/lax-linalg.ts
6306
6883
  var lax_linalg_exports = {};
6307
6884
  __export(lax_linalg_exports, {
6308
- cholesky: () => cholesky,
6885
+ cholesky: () => cholesky$1,
6886
+ lu: () => lu,
6309
6887
  triangularSolve: () => triangularSolve
6310
6888
  });
6311
6889
  /**
@@ -6334,11 +6912,39 @@ __export(lax_linalg_exports, {
6334
6912
  * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
6335
6913
  * ```
6336
6914
  */
6337
- function cholesky(a, { upper = false } = {}) {
6915
+ function cholesky$1(a, { upper = false } = {}) {
6338
6916
  const L = cholesky$2(a);
6339
6917
  return upper ? moveaxis$1(L, -2, -1) : L;
6340
6918
  }
6341
6919
  /**
6920
+ * LU decomposition with partial pivoting.
6921
+ *
6922
+ * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
6923
+ * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
6924
+ * and `U` is upper-triangular.
6925
+ *
6926
+ * @param x - A batch of matrices with shape `[..., m, n]`.
6927
+ *
6928
+ * @returns A tuple `(lu, pivots, permutation)` where:
6929
+ * - `lu`: combined lower and upper triangular matrices.
6930
+ * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
6931
+ * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
6932
+ *
6933
+ * @example
6934
+ * ```ts
6935
+ * import { lax, numpy as np } from "@jax-js/jax";
6936
+ *
6937
+ * const A = np.array([[4., 3.], [6., 3.]]);
6938
+ * const [lu, pivots, permutation] = lax.linalg.lu(A);
6939
+ * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
6940
+ * // pivots = [1, 1]
6941
+ * // permutation = [1, 0]
6942
+ * ```
6943
+ */
6944
+ function lu(x) {
6945
+ return lu$1(x);
6946
+ }
6947
+ /**
6342
6948
  * Solve a triangular linear system.
6343
6949
  *
6344
6950
  * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
@@ -6376,6 +6982,7 @@ var lax_exports = {};
6376
6982
  __export(lax_exports, {
6377
6983
  conv: () => conv,
6378
6984
  convGeneralDilated: () => convGeneralDilated,
6985
+ convTranspose: () => convTranspose,
6379
6986
  convWithGeneralPadding: () => convWithGeneralPadding,
6380
6987
  dot: () => dot,
6381
6988
  erf: () => erf,
@@ -6384,6 +6991,7 @@ __export(lax_exports, {
6384
6991
  reduceWindow: () => reduceWindow,
6385
6992
  stopGradient: () => stopGradient$1
6386
6993
  });
6994
+ const JsArray = globalThis.Array;
6387
6995
  /**
6388
6996
  * General dot product/contraction operator.
6389
6997
  *
@@ -6455,7 +7063,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
6455
7063
  * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
6456
7064
  * function in JAX, which wraps XLA's general convolution operator.
6457
7065
  *
6458
- * Grouped convolutions are not supported right now.
7066
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
7067
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
7068
+ * @param windowStrides - Strides for each spatial dimension
7069
+ * @param padding - Padding for each spatial dimension, or a string
7070
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
6459
7071
  */
6460
7072
  function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
6461
7073
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
@@ -6515,6 +7127,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
6515
7127
  function conv(lhs, rhs, windowStrides, padding) {
6516
7128
  return convGeneralDilated(lhs, rhs, windowStrides, padding);
6517
7129
  }
7130
+ /**
7131
+ * Convenience wrapper for calculating the N-d convolution "transpose".
7132
+ *
7133
+ * This function directly calculates a fractionally strided conv rather than
7134
+ * indirectly calculating the gradient (transpose) of a forward convolution.
7135
+ * It is equivalent to the JAX version, except:
7136
+ *
7137
+ * - The `use_consistent_padding` option is not available. We only have the
7138
+ * consistent padding case (JAX version >0.8.4).
7139
+ * - The order of dimensions matches `lax.conv_general_dilated`.
7140
+ *
7141
+ * Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial
7142
+ * dimensions or the `(C_out, C_in)` axis order. To get this behavior, set
7143
+ * `transposeKernel` to true.
7144
+ *
7145
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
7146
+ * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
7147
+ * @param strides - Sequence of n integers, sets fractional stride
7148
+ * @param padding - Apply padding of `dilation * (kernel_size - 1) - padding` to
7149
+ * each side of the input, so it acts like gradient of `conv()`
7150
+ * @param rhsDilation - Atrous dilation for the kernel
7151
+ * @param transposeKernel - Flip spatial axes and swap the input/output channels
7152
+ * of the kernel; its shape should be `[C_in, C_out, ...ks]`
7153
+ */
7154
+ function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
7155
+ const kernelShape = rhs.shape.slice(2);
7156
+ rhsDilation = rhsDilation ?? require_backend.rep(kernelShape.length, 1);
7157
+ const effectiveKernel = kernelShape.map((k, i) => Math.max(0, (k - 1) * rhsDilation[i] + 1));
7158
+ const pads = effectiveKernel.map((k, i) => convTransposePadding(k, strides[i], typeof padding === "string" ? padding : padding[i]));
7159
+ if (transposeKernel) {
7160
+ rhs = flip$1(rhs, require_backend.range(2, rhs.ndim));
7161
+ rhs = moveaxis(rhs, 0, 1);
7162
+ }
7163
+ return convGeneralDilated(lhs, rhs, require_backend.rep(lhs.ndim - 2, 1), pads, {
7164
+ lhsDilation: strides,
7165
+ rhsDilation
7166
+ });
7167
+ }
7168
+ function convTransposePadding(k, s, padding) {
7169
+ let padLen;
7170
+ let pad1;
7171
+ if (padding === "SAME") {
7172
+ padLen = k + s - 2;
7173
+ pad1 = s > k - 1 ? k - 1 : Math.ceil(padLen / 2);
7174
+ } else if (padding === "VALID") {
7175
+ padLen = k + s - 2 + Math.max(k - s, 0);
7176
+ pad1 = k - 1;
7177
+ } else if (JsArray.isArray(padding)) {
7178
+ const pads = [k - 1 - padding[0], k - 1 - padding[1]];
7179
+ pad1 = pads[0];
7180
+ padLen = pads[0] + pads[1];
7181
+ } else throw new Error(`convTranspose: Invalid padding type ${padding}`);
7182
+ return [pad1, padLen - pad1];
7183
+ }
6518
7184
  /** Reduce a computation over padded windows. */
6519
7185
  function reduceWindow(operand, computation, windowDimensions, windowStrides) {
6520
7186
  if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
@@ -6553,6 +7219,7 @@ function stopGradient$1(x) {
6553
7219
  var nn_exports = {};
6554
7220
  __export(nn_exports, {
6555
7221
  celu: () => celu,
7222
+ dotProductAttention: () => dotProductAttention,
6556
7223
  elu: () => elu,
6557
7224
  gelu: () => gelu,
6558
7225
  glu: () => glu,
@@ -6869,6 +7536,95 @@ function oneHot(x, numClasses) {
6869
7536
  if (require_backend.isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
6870
7537
  return eye(numClasses, void 0, { device: x.device }).slice(x);
6871
7538
  }
7539
+ /**
7540
+ * Scaled dot product attention (SDPA).
7541
+ *
7542
+ * Computes `softmax((Q @ K^T) / sqrt(d) + bias) @ V`, where `Q` is the query,
7543
+ * `K` is the key, `V` is the value, and `d` is the dimensionality of each key
7544
+ * and query vector.
7545
+ *
7546
+ * Multi-query attention is applied when input `key` and `value` tensors have
7547
+ * fewer heads than `query`.
7548
+ *
7549
+ * We use the following uppercase letters to denote array shapes:
7550
+ * - `B` = batch size
7551
+ * - `S` = length of key/value sequences (source)
7552
+ * - `L` = length of query sequences
7553
+ * - `N` = number of attention heads
7554
+ * - `H` = dimensionality of each attention head
7555
+ * - `K` = number of key/value heads (for grouped-query attention)
7556
+ *
7557
+ * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
7558
+ * case it must be omitted from all inputs.
7559
+ *
7560
+ * @param query - Query array; shape `[B, L, N, H]`
7561
+ * @param key - Key array; shape `[B, S, K, H]`
7562
+ * @param value - Value array; same shape as `key`
7563
+ * @param opts.bias - Optional bias to add to the attention logits; shape
7564
+ * `[B, N, L, S]` or broadcastable to it.
7565
+ * @param opts.mask - Optional mask to apply to the attention logits; should be
7566
+ * a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates
7567
+ * the element should take part in attention.
7568
+ * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
7569
+ * @param opts.isCausal - If true, applies a casual mask.
7570
+ * @param opts.querySeqLengths - Optional sequence lengths for the queries;
7571
+ * shape `(B,)`. Taken from the beginning of the tensor.
7572
+ * @param opts.keyValueSeqLengths - Optional sequence lengths for the keys and
7573
+ * values; shape `(B,)`. Taken from the beginning of the tensor.
7574
+ * @param opts.localWindowSize - If specified, applies a local attention window
7575
+ * of the given size. Can be a single number or a tuple `[left, right]`.
7576
+ *
7577
+ * @returns The result of the attention operation; shape is the same as query
7578
+ * `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
7579
+ */
7580
+ function dotProductAttention(query, key$1, value, opts = {}) {
7581
+ if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
7582
+ if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
7583
+ query = fudgeArray(query);
7584
+ key$1 = fudgeArray(key$1);
7585
+ value = fudgeArray(value);
7586
+ if (query.ndim !== 3 && query.ndim !== 4 || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
7587
+ if (!require_backend.deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
7588
+ const isRank3 = query.ndim === 3;
7589
+ if (isRank3) {
7590
+ query = expandDims(query, 0);
7591
+ key$1 = expandDims(key$1, 0);
7592
+ value = expandDims(value, 0);
7593
+ }
7594
+ const [B, L, N, H] = query.shape;
7595
+ if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
7596
+ const S = key$1.shape[1];
7597
+ const K = key$1.shape[2];
7598
+ if (N < K || N != K && N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
7599
+ const G = N / K;
7600
+ key$1 = tile(key$1, [
7601
+ 1,
7602
+ 1,
7603
+ G,
7604
+ 1
7605
+ ]);
7606
+ value = tile(value, [
7607
+ 1,
7608
+ 1,
7609
+ G,
7610
+ 1
7611
+ ]);
7612
+ const scale = opts.scale ?? 1 / Math.sqrt(H);
7613
+ let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
7614
+ if (opts.bias !== void 0) scores = scores.add(opts.bias);
7615
+ if (opts.mask !== void 0) scores = where(opts.mask, scores, -Infinity);
7616
+ if (opts.isCausal) {
7617
+ const causalMask = tri(L, S, 0, { dtype: require_backend.DType.Bool });
7618
+ scores = where(causalMask, scores, -Infinity);
7619
+ }
7620
+ const attn = softmax(scores, -1);
7621
+ const out = einsum("BNLS,BSNH->BLNH", attn, value);
7622
+ return isRank3 ? out.reshape([
7623
+ L,
7624
+ N,
7625
+ H
7626
+ ]) : out;
7627
+ }
6872
7628
 
6873
7629
  //#endregion
6874
7630
  //#region src/library/random.ts
@@ -6881,33 +7637,41 @@ __export(random_exports, {
6881
7637
  gumbel: () => gumbel,
6882
7638
  key: () => key,
6883
7639
  laplace: () => laplace,
7640
+ multivariateNormal: () => multivariateNormal,
6884
7641
  normal: () => normal,
6885
7642
  split: () => split,
6886
7643
  uniform: () => uniform
6887
7644
  });
6888
- function validateKeyShape(key$1) {
7645
+ function validateKeyShape(key$1, scalar = false) {
6889
7646
  if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
6890
7647
  if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
7648
+ if (scalar && key$1.shape.length > 1) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(key$1.shape)} - use jax.vmap for batching.`);
6891
7649
  return key$1.shape.slice(0, -1);
6892
7650
  }
7651
+ function getK01(key$1) {
7652
+ const keyShape = validateKeyShape(key$1, true);
7653
+ let [k0, k1] = split$2(key$1, -1, [1, 1]);
7654
+ k0 = k0.reshape(keyShape);
7655
+ k1 = k1.reshape(keyShape);
7656
+ return [k0, k1];
7657
+ }
6893
7658
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
6894
7659
  function key(seed) {
6895
- seed = seed >>> 0;
6896
- return array([0, seed], { dtype: require_backend.DType.Uint32 });
7660
+ seed = array(seed, { dtype: require_backend.DType.Uint32 });
7661
+ if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7662
+ return stack([0, seed]);
6897
7663
  }
6898
7664
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
6899
7665
  function split(key$1, num = 2) {
6900
7666
  const shape$1 = typeof num === "number" ? [num] : num;
6901
7667
  for (const len of shape$1) if (len <= 0 || !Number.isInteger(len)) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
6902
- const keyShape = validateKeyShape(key$1);
6903
- const k0 = key$1.ref.slice(...keyShape.map(() => null), 0);
6904
- const k1 = key$1.slice(...keyShape.map(() => null), 1);
7668
+ const [k0, k1] = getK01(key$1);
6905
7669
  return stack([randomBits(k0.ref, k1.ref, shape$1, 0), randomBits(k0, k1, shape$1, 1)], -1);
6906
7670
  }
6907
7671
  /** Sample uniform bits in the form of unsigned integers. */
6908
7672
  function bits(key$1, shape$1 = []) {
6909
- const keyShape = validateKeyShape(key$1);
6910
- return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
7673
+ const [k0, k1] = getK01(key$1);
7674
+ return randomBits(k0, k1, shape$1);
6911
7675
  }
6912
7676
  /**
6913
7677
  * @function
@@ -6981,6 +7745,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
6981
7745
  }, { staticArgnums: [1] });
6982
7746
  /**
6983
7747
  * @function
7748
+ * Sample multivariate normal random values with given mean and covariance.
7749
+ *
7750
+ * The values are returned with the given shape, along with the final dimension
7751
+ * used to represent the n-dimensional multivariate normal factors.
7752
+ *
7753
+ * This uses Cholesky decomposition on the covariance matrix.
7754
+ *
7755
+ * - `key` - PRNG key
7756
+ * - `mean` - Mean vector of shape `[..., n]`
7757
+ * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
7758
+ * - `shape` - Result batch shape, must be broadcastable with
7759
+ * `mean.shape[:-1]` and `cov.shape[:-2]`
7760
+ * @returns Random samples of shape `[...shape, n]`
7761
+ */
7762
+ const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
7763
+ mean$1 = fudgeArray(mean$1);
7764
+ cov$1 = fudgeArray(cov$1);
7765
+ const n = mean$1.shape[mean$1.ndim - 1];
7766
+ if (cov$1.shape[cov$1.ndim - 1] !== n || cov$1.shape[cov$1.ndim - 2] !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
7767
+ const outputShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2)).concat(n);
7768
+ const L = cholesky(cov$1);
7769
+ const z = normal(key$1, outputShape);
7770
+ return einsum("...ij,...j->...i", L, z).add(mean$1);
7771
+ }, { staticArgnums: [3] });
7772
+ /**
7773
+ * @function
6984
7774
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
6985
7775
  *
6986
7776
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
@@ -7070,17 +7860,62 @@ const linearize = linearize$1;
7070
7860
  /**
7071
7861
  * @function
7072
7862
  * Calculate the reverse-mode vector-Jacobian product for a function.
7863
+ *
7864
+ * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
7865
+ * `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
7866
+ * output and returns the cotangents for each input.
7867
+ *
7868
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
7869
+ * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
7870
+ *
7871
+ * @example
7872
+ * ```ts
7873
+ * const [y, vjpFn] = vjp(f, [x]);
7874
+ *
7875
+ * // With hasAux
7876
+ * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
7877
+ * ```
7073
7878
  */
7074
7879
  const vjp = vjp$1;
7075
7880
  /**
7076
7881
  * @function
7077
7882
  * Compute the gradient of a scalar-valued function `f` with respect to its
7078
7883
  * first argument.
7884
+ *
7885
+ * Pass in different `argnums` to differentiate with respect to other
7886
+ * arguments. If a tuple is provided, the return value will be a tuple of
7887
+ * gradients corresponding to each argument index.
7888
+ *
7889
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return a
7890
+ * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
7891
+ *
7892
+ * @example
7893
+ * ```ts
7894
+ * const gradient = grad(f)(x);
7895
+ *
7896
+ * // With `argnums`
7897
+ * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
7898
+ *
7899
+ * // With `hasAux`
7900
+ * const [gradient, aux] = grad(f, { hasAux: true })(x);
7901
+ * ```
7079
7902
  */
7080
7903
  const grad = grad$1;
7081
7904
  /**
7082
7905
  * @function
7083
7906
  * Create a function that evaluates both `f` and the gradient of `f`.
7907
+ *
7908
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
7909
+ * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
7910
+ *
7911
+ * @example
7912
+ * ```ts
7913
+ * // Without hasAux
7914
+ * const [value, gradient] = valueAndGrad(f)(x);
7915
+ *
7916
+ * // With hasAux
7917
+ * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
7918
+ * ```
7084
7919
  */
7085
7920
  const valueAndGrad = valueAndGrad$1;
7086
7921
  /**
@@ -7089,6 +7924,21 @@ const valueAndGrad = valueAndGrad$1;
7089
7924
  */
7090
7925
  const jacrev = jacrev$1;
7091
7926
  /**
7927
+ * @function
7928
+ * Compute the Hessian matrix of a scalar-valued function.
7929
+ *
7930
+ * The Hessian is the matrix of second-order partial derivatives of a function.
7931
+ * This is implemented as `jacfwd(grad(f))`.
7932
+ *
7933
+ * @example
7934
+ * ```ts
7935
+ * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
7936
+ * const H = hessian(f)(np.array([1, 2, 3]));
7937
+ * // H[i,j] = d^2f / dx_i dx_j
7938
+ * ```
7939
+ */
7940
+ const hessian = hessian$1;
7941
+ /**
7092
7942
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
7093
7943
  *
7094
7944
  * This can be used to wait for the results of an intermediate computation to
@@ -7132,6 +7982,7 @@ exports.defaultDevice = require_backend.defaultDevice;
7132
7982
  exports.devicePut = devicePut;
7133
7983
  exports.devices = require_backend.devices;
7134
7984
  exports.grad = grad;
7985
+ exports.hessian = hessian;
7135
7986
  exports.init = require_backend.init;
7136
7987
  exports.jacfwd = jacfwd;
7137
7988
  exports.jacobian = jacrev;