@jax-js/jax 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-tngXtWe4.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Dx6Ob2D1.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -209,7 +209,7 @@ __export(tree_exports, {
209
209
  structure: () => structure,
210
210
  unflatten: () => unflatten
211
211
  });
212
- const JsArray$1 = globalThis.Array;
212
+ const JsArray$2 = globalThis.Array;
213
213
  let NodeType = /* @__PURE__ */ function(NodeType$1) {
214
214
  NodeType$1["Array"] = "Array";
215
215
  NodeType$1["Object"] = "Object";
@@ -257,7 +257,7 @@ function flatten(tree) {
257
257
  return [leaves$1, treedef];
258
258
  }
259
259
  function _flatten(tree, leaves$1) {
260
- if (JsArray$1.isArray(tree)) {
260
+ if (JsArray$2.isArray(tree)) {
261
261
  const childTrees = tree.map((c) => _flatten(c, leaves$1));
262
262
  return new JsTreeDef(NodeType.Array, null, childTrees);
263
263
  } else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
@@ -356,6 +356,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
356
356
  Primitive$1["PoolTranspose"] = "pool_transpose";
357
357
  Primitive$1["Compare"] = "compare";
358
358
  Primitive$1["Where"] = "where";
359
+ Primitive$1["Concatenate"] = "concatenate";
360
+ Primitive$1["Split"] = "split";
359
361
  Primitive$1["RandomBits"] = "random_bits";
360
362
  Primitive$1["Gather"] = "gather";
361
363
  Primitive$1["Transpose"] = "transpose";
@@ -368,6 +370,7 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
368
370
  Primitive$1["Argsort"] = "argsort";
369
371
  Primitive$1["TriangularSolve"] = "triangular_solve";
370
372
  Primitive$1["Cholesky"] = "cholesky";
373
+ Primitive$1["LU"] = "lu";
371
374
  Primitive$1["Jit"] = "jit";
372
375
  return Primitive$1;
373
376
  }({});
@@ -378,6 +381,13 @@ let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
378
381
  CompareOp$1["LessEqual"] = "less_equal";
379
382
  return CompareOp$1;
380
383
  }({});
384
+ const routinePrimitives = new Map([
385
+ [Primitive.Sort, Routines.Sort],
386
+ [Primitive.Argsort, Routines.Argsort],
387
+ [Primitive.TriangularSolve, Routines.TriangularSolve],
388
+ [Primitive.Cholesky, Routines.Cholesky],
389
+ [Primitive.LU, Routines.LU]
390
+ ]);
381
391
  function add$1(x, y) {
382
392
  return bind1(Primitive.Add, [x, y]);
383
393
  }
@@ -499,7 +509,25 @@ function where$1(cond, x, y) {
499
509
  y
500
510
  ]);
501
511
  }
512
+ function concatenate$1(xs, axis) {
513
+ if (xs.length === 0) throw new Error("concatenate requires at least one input");
514
+ const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
515
+ axis = checkAxis(axis, avals[0].ndim);
516
+ for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
517
+ return bind1(Primitive.Concatenate, xs, { axis });
518
+ }
519
+ function split$2(x, axis, sizes) {
520
+ axis = checkAxis(axis, ndim$1(x));
521
+ if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
522
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
523
+ if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
524
+ return bind(Primitive.Split, [x], {
525
+ axis,
526
+ sizes
527
+ });
528
+ }
502
529
  function randomBits(k0, k1, shape$1, mode = "xor") {
530
+ if (!deepEqual(k0.shape, k1.shape) || k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
503
531
  return bind1(Primitive.RandomBits, [k0, k1], {
504
532
  shape: shape$1,
505
533
  mode
@@ -566,6 +594,11 @@ function pad$1(x, width) {
566
594
  return bind1(Primitive.Pad, [x], { width });
567
595
  }
568
596
  function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
597
+ const as = getShape(a);
598
+ const bs = getShape(b);
599
+ if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
600
+ const n = as[as.length - 2];
601
+ if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
569
602
  if (lower) {
570
603
  a = flip$1(a, [-2, -1]);
571
604
  b = flip$1(b, [-1]);
@@ -575,8 +608,15 @@ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
575
608
  return x;
576
609
  }
577
610
  function cholesky$2(x) {
611
+ const aval = ShapedArray.fromAval(getAval(x));
612
+ if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
578
613
  return bind1(Primitive.Cholesky, [x]);
579
614
  }
615
+ function lu$1(x) {
616
+ const aval = ShapedArray.fromAval(getAval(x));
617
+ if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
618
+ return bind(Primitive.LU, [x]);
619
+ }
580
620
  function sort$1(x) {
581
621
  const nd = ndim$1(x);
582
622
  if (nd === 0) throw new Error("sort: requires at least 1D input");
@@ -621,6 +661,9 @@ function newDynamic(main) {
621
661
  dynamicTrace = prevDynamicTrace;
622
662
  } };
623
663
  }
664
+ function currentTraceLevel() {
665
+ return traceStack[traceStack.length - 1].level;
666
+ }
624
667
  var Trace = class {
625
668
  constructor(main) {
626
669
  this.main = main;
@@ -685,6 +728,9 @@ var Tracer = class Tracer {
685
728
  mul(other) {
686
729
  return mul(this, other);
687
730
  }
731
+ mod(other) {
732
+ return mod(this, other);
733
+ }
688
734
  greater(other) {
689
735
  return greater$1(this, other);
690
736
  }
@@ -797,8 +843,14 @@ var Tracer = class Tracer {
797
843
  */
798
844
  *[Symbol.iterator]() {
799
845
  if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
800
- for (let i = 0; i < this.shape[0]; i++) yield this.ref.slice(i);
801
- this.dispose();
846
+ let residual = this;
847
+ const subarrayShape = this.shape.slice(1);
848
+ for (let i = 0; i < this.shape[0]; i++) {
849
+ const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
850
+ yield lr[0].reshape(subarrayShape);
851
+ residual = lr[1];
852
+ }
853
+ residual.dispose();
802
854
  }
803
855
  /**
804
856
  * Return a sorted copy of an array in ascending order.
@@ -948,6 +1000,9 @@ var ShapedArray = class ShapedArray {
948
1000
  get size() {
949
1001
  return prod(this.shape);
950
1002
  }
1003
+ scalar() {
1004
+ return new ShapedArray([], this.dtype, this.weakType);
1005
+ }
951
1006
  toString() {
952
1007
  return `${this.dtype}[${this.shape.join(",")}]`;
953
1008
  }
@@ -986,6 +1041,7 @@ var TreeMismatchError = class extends TypeError {
986
1041
  super(`Mismatched tree structures in ${where$2}: ${left} != ${right}`);
987
1042
  }
988
1043
  };
1044
+ /** Flatten a function of `JsTree` input/output for use in tracing. */
989
1045
  function flattenFun(f, inTree) {
990
1046
  const store = { value: void 0 };
991
1047
  const flatFun = (...argsFlat) => {
@@ -997,6 +1053,26 @@ function flattenFun(f, inTree) {
997
1053
  };
998
1054
  return [flatFun, store];
999
1055
  }
1056
+ /** Like flattenFun, but expects f to return [main, aux] tuple. */
1057
+ function flattenFunWithAux(f, inTree) {
1058
+ const store = { value: void 0 };
1059
+ const auxStore = { value: void 0 };
1060
+ const flatFun = (...argsFlat) => {
1061
+ const pytreeArgs = unflatten(inTree, argsFlat);
1062
+ const result = f(...pytreeArgs);
1063
+ if (!Array.isArray(result) || result.length !== 2) throw new Error("Function with `hasAux: true` must return [output, aux] tuple");
1064
+ const [out, aux] = result;
1065
+ const [outFlat, outTree] = flatten(out);
1066
+ store.value = outTree;
1067
+ auxStore.value = aux;
1068
+ return outFlat;
1069
+ };
1070
+ return [
1071
+ flatFun,
1072
+ store,
1073
+ auxStore
1074
+ ];
1075
+ }
1000
1076
  var UseAfterFreeError = class extends ReferenceError {
1001
1077
  constructor(tracer) {
1002
1078
  super(`Referenced tracer ${tracer.toString()} freed, please use .ref move semantics`);
@@ -1553,7 +1629,7 @@ const abstractEvalRules = {
1553
1629
  return [new ShapedArray(shape$1, dtype, weakType)];
1554
1630
  },
1555
1631
  [Primitive.Conv]([lhs, rhs], params) {
1556
- const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
1632
+ const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
1557
1633
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
1558
1634
  return [new ShapedArray(shape$1, dtype, weakType)];
1559
1635
  },
@@ -1564,10 +1640,25 @@ const abstractEvalRules = {
1564
1640
  const shape$1 = generalBroadcast(cond.shape, xy.shape);
1565
1641
  return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
1566
1642
  },
1643
+ [Primitive.Concatenate](xs, { axis }) {
1644
+ if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
1645
+ for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
1646
+ const shape$1 = xs[0].shape.slice();
1647
+ shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
1648
+ const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
1649
+ return [new ShapedArray(shape$1, dtype, weakType)];
1650
+ },
1651
+ [Primitive.Split]([x], { axis, sizes }) {
1652
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
1653
+ if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
1654
+ return sizes.map((size$1) => {
1655
+ return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
1656
+ });
1657
+ },
1567
1658
  [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1568
1659
  if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1569
- const keyShape = generalBroadcast(k0.shape, k1.shape);
1570
- if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1660
+ if (!deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
1661
+ if (!deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
1571
1662
  return [new ShapedArray(shape$1, DType.Uint32, false)];
1572
1663
  },
1573
1664
  [Primitive.Gather]([x, ...indices], { axis, outDim }) {
@@ -1624,6 +1715,16 @@ const abstractEvalRules = {
1624
1715
  if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
1625
1716
  return [ShapedArray.fromAval(a)];
1626
1717
  },
1718
+ [Primitive.LU]([a]) {
1719
+ if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
1720
+ const batch = a.shape.slice(0, -2);
1721
+ const [m, n] = a.shape.slice(-2);
1722
+ return [
1723
+ ShapedArray.fromAval(a),
1724
+ new ShapedArray([...batch, Math.min(m, n)], DType.Int32, false),
1725
+ new ShapedArray([...batch, m], DType.Int32, false)
1726
+ ];
1727
+ },
1627
1728
  [Primitive.Jit](args, { jaxpr }) {
1628
1729
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
1629
1730
  if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
@@ -1701,12 +1802,6 @@ function jit$1(f, opts) {
1701
1802
 
1702
1803
  //#endregion
1703
1804
  //#region src/frontend/jit.ts
1704
- const routinePrimitives = new Map([
1705
- [Primitive.Sort, Routines.Sort],
1706
- [Primitive.Argsort, Routines.Argsort],
1707
- [Primitive.TriangularSolve, Routines.TriangularSolve],
1708
- [Primitive.Cholesky, Routines.Cholesky]
1709
- ]);
1710
1805
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1711
1806
  var JitProgram = class {
1712
1807
  constructor(backend, steps, inputs, outputs) {
@@ -1876,10 +1971,10 @@ function jitCompile(backend, jaxpr) {
1876
1971
  inputs.push(jv.arg);
1877
1972
  } else if (input instanceof Lit) inputs.push(builder.pushLit(input));
1878
1973
  const outputs = [];
1879
- for (const outVar$1 of eqn.outBinders) {
1880
- const outId = builder.pushBuffer(outVar$1.aval.size * byteWidth(outVar$1.aval.dtype));
1974
+ for (const outVar of eqn.outBinders) {
1975
+ const outId = builder.pushBuffer(outVar.aval.size * byteWidth(outVar.aval.dtype));
1881
1976
  outputs.push(outId);
1882
- ctx.set(outVar$1, {
1977
+ ctx.set(outVar, {
1883
1978
  type: "imm",
1884
1979
  arg: outId
1885
1980
  });
@@ -1930,35 +2025,37 @@ function jitCompile(backend, jaxpr) {
1930
2025
  let reduction;
1931
2026
  if (inputReduction) {
1932
2027
  const jv = inputReduction;
1933
- const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
1934
- exp$2 = jv.exp.reindexGids(addArgs(jv.args));
2028
+ const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
2029
+ exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
1935
2030
  reduction = new Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
1936
2031
  } else {
1937
2032
  const ruleOutput = rule(inputExps, inputAvals, eqn.params);
1938
2033
  exp$2 = ruleOutput.exp;
1939
2034
  reduction = ruleOutput.reduction;
1940
2035
  }
1941
- const outVar = eqn.outBinders[0];
1942
- if (blackNodes.has(outVar)) {
1943
- const nargs$1 = inputArgs.length;
1944
- const size$1 = outVar.aval.size;
1945
- const kernel = new Kernel(nargs$1, size$1, exp$2, reduction);
1946
- const outId = builder.pushKernel(kernel, inputArgs);
1947
- ctx.set(outVar, {
1948
- type: "imm",
1949
- arg: outId
2036
+ for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
2037
+ const outVar = eqn.outBinders[i$1];
2038
+ if (blackNodes.has(outVar)) {
2039
+ const nargs$1 = inputArgs.length;
2040
+ const size$1 = outVar.aval.size;
2041
+ const kernel = new Kernel(nargs$1, size$1, exp$2[i$1], reduction);
2042
+ const outId = builder.pushKernel(kernel, inputArgs);
2043
+ ctx.set(outVar, {
2044
+ type: "imm",
2045
+ arg: outId
2046
+ });
2047
+ } else if (reduction) ctx.set(outVar, {
2048
+ type: "red",
2049
+ exp: exp$2[i$1],
2050
+ reduction,
2051
+ args: inputArgs
1950
2052
  });
1951
- } else if (reduction) ctx.set(outVar, {
1952
- type: "red",
1953
- exp: exp$2,
1954
- reduction,
1955
- args: inputArgs
1956
- });
1957
- else ctx.set(outVar, {
1958
- type: "exp",
1959
- exp: exp$2,
1960
- args: inputArgs
1961
- });
2053
+ else ctx.set(outVar, {
2054
+ type: "exp",
2055
+ exp: exp$2[i$1],
2056
+ args: inputArgs
2057
+ });
2058
+ }
1962
2059
  }
1963
2060
  const outputIds = [];
1964
2061
  for (const out of jaxpr.outs) if (out instanceof Var) {
@@ -1999,17 +2096,17 @@ function broadcastedJit(fn, opts) {
1999
2096
  if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = AluExp.cast(newDtype, exp$2);
2000
2097
  return exp$2;
2001
2098
  });
2002
- return { exp: fn(exps, params) };
2099
+ return { exp: [fn(exps, params)] };
2003
2100
  };
2004
2101
  }
2005
2102
  function unopJit(fn) {
2006
2103
  return ([a], [_as], params) => {
2007
- return { exp: fn(a, params) };
2104
+ return { exp: [fn(a, params)] };
2008
2105
  };
2009
2106
  }
2010
2107
  function reshapeJit(fn) {
2011
2108
  return ([a], [_as], params) => {
2012
- return { exp: reshapeViews(a, (st) => fn(st, params)) };
2109
+ return { exp: [reshapeViews(a, (st) => fn(st, params))] };
2013
2110
  };
2014
2111
  }
2015
2112
  function routineNoJit() {
@@ -2055,7 +2152,7 @@ const jitRules = {
2055
2152
  a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
2056
2153
  const reduction = new Reduction(a.dtype, op, reductionSize);
2057
2154
  return {
2058
- exp: a,
2155
+ exp: [a],
2059
2156
  reduction
2060
2157
  };
2061
2158
  },
@@ -2066,13 +2163,13 @@ const jitRules = {
2066
2163
  a = reshapeViews(a, (st) => st.compose(stX), true);
2067
2164
  const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
2068
2165
  return {
2069
- exp: a,
2166
+ exp: [a],
2070
2167
  reduction
2071
2168
  };
2072
2169
  },
2073
2170
  [Primitive.Dot]([a, b], [as, bs]) {
2074
2171
  const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
2075
- const c = k1.exp;
2172
+ const [c] = k1.exp;
2076
2173
  const cs = promoteAvals(as, bs);
2077
2174
  return jitRules[Primitive.Reduce]([c], [cs], {
2078
2175
  op: AluOp.Add,
@@ -2089,16 +2186,42 @@ const jitRules = {
2089
2186
  },
2090
2187
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
2091
2188
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
2189
+ [Primitive.Concatenate](exps, avals, { axis }) {
2190
+ const ndim$2 = avals[0].ndim;
2191
+ const sizes = avals.map((x) => x.shape[axis]);
2192
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
2193
+ const { dtype: dtypeOut } = avals.map((x) => x.scalar()).reduce(promoteAvals);
2194
+ const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
2195
+ let cum = 0;
2196
+ const src = [];
2197
+ for (let i = 0; i < exps.length; i++) {
2198
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
2199
+ src.push(reshapeViews(AluExp.cast(dtypeOut, exps[i]), (st) => st.pad(padding)));
2200
+ cum += sizes[i];
2201
+ }
2202
+ return { exp: [src.reduce(AluExp.add)] };
2203
+ },
2204
+ [Primitive.Split]([a], [as], { axis, sizes }) {
2205
+ const exp$2 = [];
2206
+ let start = 0;
2207
+ for (const size$1 of sizes) {
2208
+ const slice = range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
2209
+ exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
2210
+ start += size$1;
2211
+ }
2212
+ return { exp: exp$2 };
2213
+ },
2092
2214
  [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
2215
+ const keyShape = keyShapes[0].shape;
2093
2216
  const mapping = (st) => {
2094
- if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
2217
+ if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(st.shape.length, shape$1.length));
2095
2218
  };
2096
2219
  const k0 = reshapeViews(keys[0], mapping);
2097
2220
  const k1 = reshapeViews(keys[1], mapping);
2098
2221
  const c0 = AluExp.u32(0);
2099
- const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
2222
+ const c1 = AluExp.mod(AluExp.cast(DType.Uint32, AluVar.gidx), AluExp.u32(Math.max(prod(shape$1.slice(keyShape.length)), 1)));
2100
2223
  const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
2101
- return { exp: exp$2 };
2224
+ return { exp: [exp$2] };
2102
2225
  },
2103
2226
  [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
2104
2227
  const axisSet = new Set(axis);
@@ -2113,7 +2236,7 @@ const jitRules = {
2113
2236
  for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
2114
2237
  const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
2115
2238
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
2116
- return { exp: x.substitute({ gidx: index }) };
2239
+ return { exp: [x.substitute({ gidx: index })] };
2117
2240
  },
2118
2241
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
2119
2242
  [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
@@ -2129,6 +2252,7 @@ const jitRules = {
2129
2252
  [Primitive.Argsort]: routineNoJit(),
2130
2253
  [Primitive.TriangularSolve]: routineNoJit(),
2131
2254
  [Primitive.Cholesky]: routineNoJit(),
2255
+ [Primitive.LU]: routineNoJit(),
2132
2256
  [Primitive.Jit]() {
2133
2257
  throw new Error("internal: Jit should have been flattened before JIT compilation");
2134
2258
  }
@@ -2210,7 +2334,7 @@ function splitGraphDataflow(backend, jaxpr) {
2210
2334
  p1NextBlack.set(v, v);
2211
2335
  }
2212
2336
  const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
2213
- const needsCleanShapePrimitives = [Primitive.Pad];
2337
+ const needsCleanShapePrimitives = [Primitive.Concatenate, Primitive.Pad];
2214
2338
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
2215
2339
  const eqn = jaxpr.eqns[i];
2216
2340
  if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
@@ -2280,7 +2404,7 @@ function splitGraphDataflow(backend, jaxpr) {
2280
2404
 
2281
2405
  //#endregion
2282
2406
  //#region src/frontend/array.ts
2283
- const JsArray = globalThis.Array;
2407
+ const JsArray$1 = globalThis.Array;
2284
2408
  const inlineArrayLimit = 128;
2285
2409
  /** Version of pureArray with fudged types. */
2286
2410
  const fudgeArray = pureArray;
@@ -2407,6 +2531,10 @@ var Array$1 = class Array$1 extends Tracer {
2407
2531
  this.#rc++;
2408
2532
  return this;
2409
2533
  }
2534
+ /** Get the current reference count (for debugging memory management). */
2535
+ get refCount() {
2536
+ return this.#rc;
2537
+ }
2410
2538
  dispose() {
2411
2539
  this.#check();
2412
2540
  if (--this.#rc === 0) {
@@ -2564,7 +2692,7 @@ var Array$1 = class Array$1 extends Tracer {
2564
2692
  } else if (castDtype === void 0) {
2565
2693
  castDtype = arrays[i].#dtype;
2566
2694
  castWeakType = arrays[i].#weakType;
2567
- } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
2695
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
2568
2696
  const weakType = castWeakType && !strongTypeOutput;
2569
2697
  const { backend, committed } = Array$1.#computeBackend(name, arrays);
2570
2698
  arrays = arrays.map((ar) => ar._putSync(backend));
@@ -2674,25 +2802,35 @@ var Array$1 = class Array$1 extends Tracer {
2674
2802
  });
2675
2803
  }
2676
2804
  /** Apply an operation with custom lowering to this array. */
2677
- static #routine(routine, arrays, outputWeakType) {
2678
- const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
2679
- for (const ar of arrays) ar.#realize();
2680
- const inputs = arrays.map((ar) => ar.#source);
2681
- const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(byteWidth(dtype) * prod(routine.type.outputShapes[i])));
2682
- const pending = arrays.flatMap((ar) => ar.#pending);
2683
- for (const exe of pending) exe.updateRc(+outputs.length);
2684
- pending.push(new PendingExecute(backend, routine, inputs, outputs));
2685
- pending[pending.length - 1].updateRc(+outputs.length - 1);
2686
- arrays.forEach((ar) => ar.dispose());
2687
- return outputs.map((output, i) => new Array$1({
2688
- source: output,
2689
- st: ShapeTracker.fromShape(routine.type.outputShapes[i]),
2690
- dtype: routine.type.outputDtypes[i],
2691
- weakType: outputWeakType[i],
2692
- backend,
2693
- committed,
2694
- pending
2695
- }));
2805
+ static #routine(prim) {
2806
+ return (arrays, params) => {
2807
+ const { backend, committed } = Array$1.#computeBackend(prim, arrays);
2808
+ for (const ar of arrays) ar.#realize();
2809
+ const avals = arrays.map((ar) => ar.aval);
2810
+ const avalsOut = abstractEvalRules[prim](avals, params);
2811
+ const routine = new Routine(routinePrimitives.get(prim), {
2812
+ inputShapes: avals.map((a) => a.shape),
2813
+ inputDtypes: avals.map((a) => a.dtype),
2814
+ outputShapes: avalsOut.map((a) => a.shape),
2815
+ outputDtypes: avalsOut.map((a) => a.dtype)
2816
+ }, params);
2817
+ const inputs = arrays.map((ar) => ar.#source);
2818
+ const outputs = avalsOut.map((x) => backend.malloc(byteWidth(x.dtype) * x.size));
2819
+ const pending = arrays.flatMap((ar) => ar.#pending);
2820
+ for (const exe of pending) exe.updateRc(+outputs.length);
2821
+ pending.push(new PendingExecute(backend, routine, inputs, outputs));
2822
+ pending[pending.length - 1].updateRc(+outputs.length - 1);
2823
+ arrays.forEach((ar) => ar.dispose());
2824
+ return outputs.map((output, i) => new Array$1({
2825
+ source: output,
2826
+ st: ShapeTracker.fromShape(avalsOut[i].shape),
2827
+ dtype: avalsOut[i].dtype,
2828
+ weakType: avalsOut[i].weakType,
2829
+ backend,
2830
+ committed,
2831
+ pending
2832
+ }));
2833
+ };
2696
2834
  }
2697
2835
  /**
2698
2836
  * Normalizes this array into one backed by a `Slot`.
@@ -2957,17 +3095,44 @@ var Array$1 = class Array$1 extends Tracer {
2957
3095
  y
2958
3096
  ], { dtypeOverride: [DType.Bool] })];
2959
3097
  },
3098
+ [Primitive.Concatenate](xs, { axis }) {
3099
+ const ndim$2 = xs[0].ndim;
3100
+ const sizes = xs.map((x) => x.shape[axis]);
3101
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
3102
+ const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
3103
+ let cum = 0;
3104
+ const xsPadded = [];
3105
+ for (let i = 0; i < xs.length; i++) {
3106
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
3107
+ xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
3108
+ cum += sizes[i];
3109
+ }
3110
+ const custom = (exps) => exps.reduce(AluExp.add);
3111
+ return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
3112
+ },
3113
+ [Primitive.Split]([x], { axis, sizes }) {
3114
+ const outputs = [];
3115
+ for (let i = 0, start = 0; i < sizes.length; i++) {
3116
+ const slice = range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
3117
+ outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
3118
+ start += sizes[i];
3119
+ }
3120
+ x.dispose();
3121
+ return outputs;
3122
+ },
2960
3123
  [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2961
- const keyShape = generalBroadcast(k0.shape, k1.shape);
2962
- if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2963
- const c0 = zeros(shape$1, {
3124
+ const keyShape = k0.shape;
3125
+ const genShape = shape$1.slice(keyShape.length);
3126
+ const c0 = zeros(genShape, {
2964
3127
  dtype: DType.Uint32,
2965
3128
  device: k0.device
2966
3129
  });
2967
- const c1 = arange(0, prod(shape$1), 1, {
3130
+ const c1 = arange(0, prod(genShape), 1, {
2968
3131
  dtype: DType.Uint32,
2969
3132
  device: k0.device
2970
- }).reshape(shape$1);
3133
+ }).reshape(genShape);
3134
+ k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
3135
+ k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
2971
3136
  const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
2972
3137
  return [Array$1.#naryCustom("random_bits", custom, [
2973
3138
  k0,
@@ -2999,42 +3164,11 @@ var Array$1 = class Array$1 extends Tracer {
2999
3164
  [Primitive.Pad]([x], { width }) {
3000
3165
  return [x.#reshape(x.#st.pad(width))];
3001
3166
  },
3002
- [Primitive.Sort]([x]) {
3003
- const routine = new Routine(Routines.Sort, {
3004
- inputShapes: [x.aval.shape],
3005
- inputDtypes: [x.aval.dtype],
3006
- outputShapes: [x.aval.shape],
3007
- outputDtypes: [x.aval.dtype]
3008
- });
3009
- return Array$1.#routine(routine, [x], [x.#weakType]);
3010
- },
3011
- [Primitive.Argsort]([x]) {
3012
- const routine = new Routine(Routines.Argsort, {
3013
- inputShapes: [x.aval.shape],
3014
- inputDtypes: [x.aval.dtype],
3015
- outputShapes: [x.aval.shape, x.aval.shape],
3016
- outputDtypes: [x.aval.dtype, DType.Int32]
3017
- });
3018
- return Array$1.#routine(routine, [x], [x.#weakType, false]);
3019
- },
3020
- [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3021
- const routine = new Routine(Routines.TriangularSolve, {
3022
- inputShapes: [a.aval.shape, b.aval.shape],
3023
- inputDtypes: [a.aval.dtype, b.aval.dtype],
3024
- outputShapes: [b.aval.shape],
3025
- outputDtypes: [b.aval.dtype]
3026
- }, { unitDiagonal });
3027
- return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
3028
- },
3029
- [Primitive.Cholesky]([a]) {
3030
- const routine = new Routine(Routines.Cholesky, {
3031
- inputShapes: [a.aval.shape],
3032
- inputDtypes: [a.aval.dtype],
3033
- outputShapes: [a.aval.shape],
3034
- outputDtypes: [a.aval.dtype]
3035
- });
3036
- return Array$1.#routine(routine, [a], [a.#weakType]);
3037
- },
3167
+ [Primitive.Sort]: Array$1.#routine(Primitive.Sort),
3168
+ [Primitive.Argsort]: Array$1.#routine(Primitive.Argsort),
3169
+ [Primitive.TriangularSolve]: Array$1.#routine(Primitive.TriangularSolve),
3170
+ [Primitive.Cholesky]: Array$1.#routine(Primitive.Cholesky),
3171
+ [Primitive.LU]: Array$1.#routine(Primitive.LU),
3038
3172
  [Primitive.Jit](args, { jaxpr }) {
3039
3173
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3040
3174
  const { backend, committed } = Array$1.#computeBackend("jit", args);
@@ -3116,7 +3250,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3116
3250
  if (!shape$1) {
3117
3251
  shape$1 = [];
3118
3252
  let cur = values;
3119
- while (JsArray.isArray(cur)) {
3253
+ while (JsArray$1.isArray(cur)) {
3120
3254
  shape$1.push(cur.length);
3121
3255
  cur = cur[0];
3122
3256
  }
@@ -3140,7 +3274,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3140
3274
  device
3141
3275
  });
3142
3276
  } else {
3143
- const weakType = dtype == void 0;
3277
+ const weakType = dtype == void 0 && shape$1.length === 0;
3144
3278
  dtype = dtype ?? DType.Float32;
3145
3279
  const data = dtypedJsArray(dtype, flat);
3146
3280
  return arrayFromData(data, shape$1, {
@@ -3254,7 +3388,7 @@ function ones(shape$1, { dtype, device } = {}) {
3254
3388
  }
3255
3389
  /** Return a new array of given shape and type, filled with `fill_value`. */
3256
3390
  function full(shape$1, fillValue, { dtype, device } = {}) {
3257
- let weakType = dtype == void 0;
3391
+ let weakType = dtype == void 0 && shape$1.length === 0;
3258
3392
  if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
3259
3393
  else if (typeof fillValue === "boolean") {
3260
3394
  dtype = dtype ?? DType.Bool;
@@ -3412,6 +3546,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
3412
3546
  committed: device != void 0
3413
3547
  });
3414
3548
  }
3549
+ /**
3550
+ * Return numbers spaced evenly on a log scale.
3551
+ *
3552
+ * In linear space, the sequence starts at `base ** start` and ends at
3553
+ * `base ** stop` (see `endpoint` below).
3554
+ *
3555
+ * @param start - `base ** start` is the starting value of the sequence.
3556
+ * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
3557
+ * @param num - Number of samples to generate. Default is 50.
3558
+ * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
3559
+ * @param base - The base of the log space. Default is 10.
3560
+ * @returns Array of evenly spaced values on a log scale.
3561
+ */
3562
+ function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
3563
+ const y = linspace(start, stop, num, endpoint, {
3564
+ dtype,
3565
+ device
3566
+ });
3567
+ const logBase = Math.log(base);
3568
+ return exp$1(mul(y, logBase));
3569
+ }
3415
3570
  function aluCompare(a, b, op) {
3416
3571
  switch (op) {
3417
3572
  case CompareOp.Less: return AluExp.cmplt(a, b);
@@ -3488,6 +3643,7 @@ var BatchTrace = class extends Trace {
3488
3643
  return valOuts$1.map((x) => new BatchTracer(this, x, null));
3489
3644
  }
3490
3645
  const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3646
+ if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
3491
3647
  return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3492
3648
  }
3493
3649
  get axisSize() {
@@ -3499,13 +3655,13 @@ var BatchTrace = class extends Trace {
3499
3655
  *
3500
3656
  * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3501
3657
  */
3502
- function broadcastBatcher(op) {
3503
- return (axisSize, args, dims) => {
3658
+ function broadcastBatcher(prim) {
3659
+ return (axisSize, args, dims, params) => {
3504
3660
  if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3505
3661
  const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3506
3662
  const firstIdx = dims.findIndex((d) => d !== null);
3507
3663
  const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3508
- if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3664
+ if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[bind1(prim, args, params)], [nd + firstBdim]];
3509
3665
  args = args.map((x, i) => {
3510
3666
  if (dims[i] === null) return x;
3511
3667
  x = moveBatchAxis(axisSize, dims[i], 0, x);
@@ -3516,37 +3672,45 @@ function broadcastBatcher(op) {
3516
3672
  ]);
3517
3673
  return x;
3518
3674
  });
3519
- return [[op(...args)], [0]];
3675
+ return [[bind1(prim, args, params)], [0]];
3520
3676
  };
3521
3677
  }
3522
- function unopBatcher(op) {
3678
+ function unopBatcher(prim) {
3523
3679
  return (axisSize, [x], [xBdim], params) => {
3524
- return [[op(x, params)], [xBdim]];
3680
+ return [[bind1(prim, [x], params)], [xBdim]];
3681
+ };
3682
+ }
3683
+ function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
3684
+ return (axisSize, [x], [xBdim], params) => {
3685
+ assertNonNull(xBdim);
3686
+ if (xBdim < x.ndim - inputDims) return [bind(prim, [x], params), rep(numOutputs, xBdim)];
3687
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3688
+ return [bind(prim, [x], params), rep(numOutputs, 0)];
3525
3689
  };
3526
3690
  }
3527
3691
  const vmapRules = {
3528
- [Primitive.Add]: broadcastBatcher(add$1),
3529
- [Primitive.Mul]: broadcastBatcher(mul),
3530
- [Primitive.Idiv]: broadcastBatcher(idiv),
3531
- [Primitive.Mod]: broadcastBatcher(mod),
3532
- [Primitive.Min]: broadcastBatcher(min$1),
3533
- [Primitive.Max]: broadcastBatcher(max$1),
3534
- [Primitive.Neg]: unopBatcher(neg),
3535
- [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3536
- [Primitive.Floor]: unopBatcher(floor$1),
3537
- [Primitive.Ceil]: unopBatcher(ceil$1),
3538
- [Primitive.StopGradient]: unopBatcher(stopGradient),
3539
- [Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
3540
- [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3541
- [Primitive.Sin]: unopBatcher(sin$1),
3542
- [Primitive.Cos]: unopBatcher(cos$1),
3543
- [Primitive.Asin]: unopBatcher(asin$1),
3544
- [Primitive.Atan]: unopBatcher(atan$1),
3545
- [Primitive.Exp]: unopBatcher(exp$1),
3546
- [Primitive.Log]: unopBatcher(log$1),
3547
- [Primitive.Erf]: unopBatcher(erf$1),
3548
- [Primitive.Erfc]: unopBatcher(erfc$1),
3549
- [Primitive.Sqrt]: unopBatcher(sqrt$1),
3692
+ [Primitive.Add]: broadcastBatcher(Primitive.Add),
3693
+ [Primitive.Mul]: broadcastBatcher(Primitive.Mul),
3694
+ [Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
3695
+ [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3696
+ [Primitive.Min]: broadcastBatcher(Primitive.Min),
3697
+ [Primitive.Max]: broadcastBatcher(Primitive.Max),
3698
+ [Primitive.Neg]: unopBatcher(Primitive.Neg),
3699
+ [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3700
+ [Primitive.Floor]: unopBatcher(Primitive.Floor),
3701
+ [Primitive.Ceil]: unopBatcher(Primitive.Ceil),
3702
+ [Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
3703
+ [Primitive.Cast]: unopBatcher(Primitive.Cast),
3704
+ [Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
3705
+ [Primitive.Sin]: unopBatcher(Primitive.Sin),
3706
+ [Primitive.Cos]: unopBatcher(Primitive.Cos),
3707
+ [Primitive.Asin]: unopBatcher(Primitive.Asin),
3708
+ [Primitive.Atan]: unopBatcher(Primitive.Atan),
3709
+ [Primitive.Exp]: unopBatcher(Primitive.Exp),
3710
+ [Primitive.Log]: unopBatcher(Primitive.Log),
3711
+ [Primitive.Erf]: unopBatcher(Primitive.Erf),
3712
+ [Primitive.Erfc]: unopBatcher(Primitive.Erfc),
3713
+ [Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
3550
3714
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3551
3715
  assertNonNull(xBdim);
3552
3716
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
@@ -3568,10 +3732,25 @@ const vmapRules = {
3568
3732
  });
3569
3733
  return [[z], [0]];
3570
3734
  },
3571
- [Primitive.Compare](axisSize, args, dims, { op }) {
3572
- return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3735
+ [Primitive.Compare]: broadcastBatcher(Primitive.Compare),
3736
+ [Primitive.Where]: broadcastBatcher(Primitive.Where),
3737
+ [Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
3738
+ const minBdim = Math.min(...xBdims.filter((d) => d !== null));
3739
+ xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
3740
+ const newAxis = axis + (minBdim <= axis ? 1 : 0);
3741
+ return [[concatenate$1(xs, newAxis)], [minBdim]];
3742
+ },
3743
+ [Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
3744
+ assertNonNull(xBdim);
3745
+ const newAxis = axis + (xBdim <= axis ? 1 : 0);
3746
+ const outs = split$2(x, newAxis, sizes);
3747
+ return [outs, rep(outs.length, xBdim)];
3748
+ },
3749
+ [Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
3750
+ k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
3751
+ k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
3752
+ return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
3573
3753
  },
3574
- [Primitive.Where]: broadcastBatcher(where$1),
3575
3754
  [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3576
3755
  if (indicesBdim.every((d) => d === null)) {
3577
3756
  assertNonNull(xBdim);
@@ -3633,18 +3812,8 @@ const vmapRules = {
3633
3812
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3634
3813
  return [[pad$1(x, newWidth)], [xBdim]];
3635
3814
  },
3636
- [Primitive.Sort](axisSize, [x], [xBdim]) {
3637
- assertNonNull(xBdim);
3638
- if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
3639
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3640
- return [[sort$1(x)], [0]];
3641
- },
3642
- [Primitive.Argsort](axisSize, [x], [xBdim]) {
3643
- assertNonNull(xBdim);
3644
- if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
3645
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3646
- return [argsort$1(x), [0, 0]];
3647
- },
3815
+ [Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
3816
+ [Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
3648
3817
  [Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
3649
3818
  if (aBdim === null) {
3650
3819
  b = moveBatchAxis(axisSize, bBdim, -3, b);
@@ -3668,12 +3837,8 @@ const vmapRules = {
3668
3837
  const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3669
3838
  return [[x], [0]];
3670
3839
  },
3671
- [Primitive.Cholesky](axisSize, [x], [xBdim]) {
3672
- assertNonNull(xBdim);
3673
- if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
3674
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3675
- return [[cholesky$2(x)], [0]];
3676
- },
3840
+ [Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
3841
+ [Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
3677
3842
  [Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
3678
3843
  const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
3679
3844
  const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
@@ -3823,6 +3988,16 @@ function batchMatmulT(a, b) {
3823
3988
  function mT(a) {
3824
3989
  return moveaxis(a, -2, -1);
3825
3990
  }
3991
+ function sliceAxis(a, axis, p) {
3992
+ const slices = Array(a.shape.length).fill([]);
3993
+ slices[checkAxis(axis, a.ndim)] = p;
3994
+ return a.slice(...slices);
3995
+ }
3996
+ function padAxis(a, axis, p) {
3997
+ const pads = Array(a.shape.length).fill([0, 0]);
3998
+ pads[checkAxis(axis, a.ndim)] = p;
3999
+ return pad$1(a, pads);
4000
+ }
3826
4001
  const jvpRules = {
3827
4002
  [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3828
4003
  [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
@@ -3921,6 +4096,8 @@ const jvpRules = {
3921
4096
  dcond.dispose();
3922
4097
  return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3923
4098
  },
4099
+ [Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
4100
+ [Primitive.Split]: linearTangentsJvp(Primitive.Split),
3924
4101
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3925
4102
  [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3926
4103
  const indicesRef = indices.map((t) => t.ref);
@@ -3955,6 +4132,38 @@ const jvpRules = {
3955
4132
  const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
3956
4133
  return [[L], [dL]];
3957
4134
  },
4135
+ [Primitive.LU]([a], [da]) {
4136
+ const [luMatrix, pivots, permutation] = lu$1(a);
4137
+ const [m, n] = a.shape.slice(-2);
4138
+ const k = Math.min(m, n);
4139
+ const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
4140
+ const lLower = tril(luSliceL, -1);
4141
+ const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
4142
+ const L = lPadded.add(eye(m));
4143
+ const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
4144
+ const uUpper = triu(luSliceU);
4145
+ const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
4146
+ const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
4147
+ const U = uPadded.add(uEye);
4148
+ const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
4149
+ const pda = batchMatmulT(P, mT(da));
4150
+ const la = mT(triangularSolve$1(L.ref, mT(pda), {
4151
+ lower: true,
4152
+ unitDiagonal: true
4153
+ }));
4154
+ const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
4155
+ const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
4156
+ const uDot = batchMatmulT(triu(lau), mT(U));
4157
+ return [[
4158
+ luMatrix,
4159
+ pivots,
4160
+ permutation
4161
+ ], [
4162
+ lDot.add(uDot),
4163
+ zerosLike$1(pivots.ref),
4164
+ zerosLike$1(permutation.ref)
4165
+ ]];
4166
+ },
3958
4167
  [Primitive.Jit](primals, tangents, { name, jaxpr }) {
3959
4168
  const newJaxpr = jvpJaxpr(jaxpr);
3960
4169
  const outs = bind(Primitive.Jit, [
@@ -3995,17 +4204,39 @@ function jvpFlat(f, primals, tangents) {
3995
4204
  _usingCtx$1.d();
3996
4205
  }
3997
4206
  }
3998
- function jvp$1(f, primals, tangents) {
4207
+ function jvp$1(f, primals, tangents, { hasAux = false } = {}) {
3999
4208
  const [primalsFlat, inTree] = flatten(primals);
4000
4209
  const [tangentsFlat, inTree2] = flatten(tangents);
4001
4210
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
4002
- const [flatFun, outTree] = flattenFun(f, inTree);
4211
+ let flatFun, outTree, aux;
4212
+ if (hasAux) [flatFun, outTree, aux] = flattenFunWithAux(f, inTree);
4213
+ else [flatFun, outTree] = flattenFun(f, inTree);
4003
4214
  const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
4004
4215
  if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
4005
4216
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
4006
4217
  const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
4218
+ if (hasAux) return [
4219
+ primalsOut,
4220
+ tangentsOut,
4221
+ lowerAux(aux.value)
4222
+ ];
4007
4223
  return [primalsOut, tangentsOut];
4008
4224
  }
4225
+ /** Lowering for auxiliary data returned in `hasAux: true` methods. */
4226
+ function lowerAux(aux) {
4227
+ const level = currentTraceLevel();
4228
+ return map((x) => {
4229
+ if (x instanceof Tracer) while (x._trace.main.level > level) if (x instanceof JVPTracer) {
4230
+ x.tangent.dispose();
4231
+ x = x.primal;
4232
+ } else {
4233
+ const y = x.fullLower();
4234
+ if (y._trace.main.level >= x._trace.main.level) throw new Error("internal: lowerAux did not reduce trace level");
4235
+ x = y;
4236
+ }
4237
+ return x;
4238
+ }, aux);
4239
+ }
4009
4240
 
4010
4241
  //#endregion
4011
4242
  //#region src/frontend/linearize.ts
@@ -4076,9 +4307,11 @@ function linearizeFlat(f, primalsIn) {
4076
4307
  dispose$1
4077
4308
  ];
4078
4309
  }
4079
- function linearize$1(f, ...primalsIn) {
4310
+ function linearize$1(f, primalsIn, { hasAux = false } = {}) {
4080
4311
  const [primalsInFlat, inTree] = flatten(primalsIn);
4081
- const [fFlat, outTree] = flattenFun(f, inTree);
4312
+ let fFlat, outTree, aux;
4313
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4314
+ else [fFlat, outTree] = flattenFun(f, inTree);
4082
4315
  const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
4083
4316
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
4084
4317
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4089,6 +4322,11 @@ function linearize$1(f, ...primalsIn) {
4089
4322
  return unflatten(outTree.value, tangentsOutFlat);
4090
4323
  });
4091
4324
  fLin.dispose = dispose$1;
4325
+ if (hasAux) return [
4326
+ primalsOut,
4327
+ fLin,
4328
+ lowerAux(aux.value)
4329
+ ];
4092
4330
  return [primalsOut, fLin];
4093
4331
  }
4094
4332
  var PartialEvalTracer = class extends Tracer {
@@ -4492,6 +4730,15 @@ const transposeRules = {
4492
4730
  cond.dispose();
4493
4731
  return cts;
4494
4732
  },
4733
+ [Primitive.Concatenate]([ct], inputs, { axis }) {
4734
+ if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
4735
+ const sizes = inputs.map((x) => x.aval.shape[axis]);
4736
+ return split$2(ct, axis, sizes);
4737
+ },
4738
+ [Primitive.Split](cts, [x], { axis }) {
4739
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
4740
+ return [concatenate$1(cts, axis)];
4741
+ },
4495
4742
  [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4496
4743
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4497
4744
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
@@ -4580,9 +4827,11 @@ function vjpFlat(f, primalsIn) {
4580
4827
  dispose$1
4581
4828
  ];
4582
4829
  }
4583
- function vjp$1(f, ...primalsIn) {
4830
+ function vjp$1(f, primalsIn, { hasAux = false } = {}) {
4584
4831
  const [primalsInFlat, inTree] = flatten(primalsIn);
4585
- const [fFlat, outTree] = flattenFun(f, inTree);
4832
+ let fFlat, outTree, aux;
4833
+ if (hasAux) [fFlat, outTree, aux] = flattenFunWithAux(f, inTree);
4834
+ else [fFlat, outTree] = flattenFun(f, inTree);
4586
4835
  const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4587
4836
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
4588
4837
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
@@ -4593,26 +4842,43 @@ function vjp$1(f, ...primalsIn) {
4593
4842
  return unflatten(inTree, cotangentsInFlat);
4594
4843
  });
4595
4844
  fVjp.dispose = dispose$1;
4845
+ if (hasAux) return [
4846
+ primalsOut,
4847
+ fVjp,
4848
+ lowerAux(aux.value)
4849
+ ];
4596
4850
  return [primalsOut, fVjp];
4597
4851
  }
4598
- function grad$1(f) {
4599
- const valueAndGradFn = valueAndGrad$1(f);
4852
+ function grad$1(f, opts) {
4853
+ const valueAndGradFn = valueAndGrad$1(f, opts);
4600
4854
  return (...x) => {
4601
- const [y, dx] = valueAndGradFn(...x);
4602
- y.dispose();
4603
- return dx;
4855
+ if (opts?.hasAux) {
4856
+ const [[y, aux], dx] = valueAndGradFn(...x);
4857
+ y.dispose();
4858
+ return [dx, aux];
4859
+ } else {
4860
+ const [y, dx] = valueAndGradFn(...x);
4861
+ y.dispose();
4862
+ return dx;
4863
+ }
4604
4864
  };
4605
4865
  }
4606
- function valueAndGrad$1(f) {
4866
+ function valueAndGrad$1(f, opts) {
4867
+ const argnums = opts?.argnums ?? 0;
4868
+ const hasAux = opts?.hasAux ?? false;
4869
+ checkInts(argnums);
4870
+ const argnumsSet = new Set(typeof argnums === "number" ? [argnums] : argnums);
4607
4871
  return (...x) => {
4608
4872
  if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
4609
- const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
4873
+ for (let i = 0; i < x.length; i++) if (!argnumsSet.has(i)) x[i] = map(stopGradient, x[i]);
4874
+ const [y, fVjp, aux] = vjp$1(f, x, { hasAux });
4610
4875
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
4611
4876
  if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
4612
- const [ct, ...rest] = fVjp(onesLike$1(y.ref));
4613
- for (const r of rest) dispose(r);
4877
+ const cts = fVjp(onesLike$1(y.ref));
4614
4878
  fVjp.dispose();
4615
- return [y, ct];
4879
+ for (let i = 0; i < cts.length; i++) if (!argnumsSet.has(i)) dispose(cts[i]);
4880
+ const grads = typeof argnums === "number" ? cts[argnums] : argnums.map((i) => cts[i]);
4881
+ return hasAux ? [[y, aux], grads] : [y, grads];
4616
4882
  };
4617
4883
  }
4618
4884
  function jacrev$1(f) {
@@ -4620,7 +4886,7 @@ function jacrev$1(f) {
4620
4886
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
4621
4887
  const [size$1] = x.shape;
4622
4888
  const pullback = (ct) => {
4623
- const [y, fVjp] = vjp$1(f, x);
4889
+ const [y, fVjp] = vjp$1(f, [x]);
4624
4890
  y.dispose();
4625
4891
  const [ret] = fVjp(ct);
4626
4892
  fVjp.dispose();
@@ -4629,6 +4895,9 @@ function jacrev$1(f) {
4629
4895
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
4630
4896
  };
4631
4897
  }
4898
+ function hessian$1(f) {
4899
+ return jacfwd$1(grad$1(f));
4900
+ }
4632
4901
 
4633
4902
  //#endregion
4634
4903
  //#region src/library/numpy/einsum.ts
@@ -4767,8 +5036,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
4767
5036
  const idx = lhsIndex[j];
4768
5037
  const dim = shape$1[j];
4769
5038
  const existing = sizeMap.get(idx);
4770
- if (existing === void 0) sizeMap.set(idx, dim);
4771
- else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
5039
+ if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
5040
+ else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
4772
5041
  }
4773
5042
  }
4774
5043
  for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
@@ -4924,27 +5193,53 @@ function ifft(a, axis = -1) {
4924
5193
  //#region src/library/numpy-linalg.ts
4925
5194
  var numpy_linalg_exports = {};
4926
5195
  __export(numpy_linalg_exports, {
4927
- cholesky: () => cholesky$1,
5196
+ cholesky: () => cholesky,
5197
+ det: () => det,
4928
5198
  diagonal: () => diagonal,
5199
+ inv: () => inv,
4929
5200
  lstsq: () => lstsq,
4930
5201
  matmul: () => matmul,
5202
+ matrixPower: () => matrixPower,
4931
5203
  matrixTranspose: () => matrixTranspose,
4932
5204
  outer: () => outer,
5205
+ slogdet: () => slogdet,
5206
+ solve: () => solve,
4933
5207
  tensordot: () => tensordot,
4934
5208
  trace: () => trace,
4935
5209
  vecdot: () => vecdot
4936
5210
  });
5211
+ function checkSquare(name, a) {
5212
+ if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
5213
+ return a.shape[a.ndim - 1];
5214
+ }
4937
5215
  /**
4938
5216
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
4939
5217
  *
4940
5218
  * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
4941
5219
  * the input matrix, which is on by default.
4942
5220
  */
4943
- function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
5221
+ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
4944
5222
  a = fudgeArray(a);
4945
- if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`cholesky: input must be at least 2D square matrix, got ${a.aval}`);
5223
+ checkSquare("cholesky", a);
4946
5224
  if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
4947
- return cholesky(a, { upper });
5225
+ return cholesky$1(a, { upper });
5226
+ }
5227
+ /** Compute the determinant of a square matrix (batched). */
5228
+ function det(a) {
5229
+ a = fudgeArray(a);
5230
+ const n = checkSquare("det", a);
5231
+ const [lu$2, pivots, permutation] = lu(a);
5232
+ permutation.dispose();
5233
+ const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
5234
+ const sign$1 = parity.mul(-2).add(1);
5235
+ const diag$1 = lu$2.diagonal(0, -1, -2);
5236
+ return prod$1(diag$1, -1).mul(sign$1);
5237
+ }
5238
+ /** Compute the inverse of a square matrix (batched). */
5239
+ function inv(a) {
5240
+ a = fudgeArray(a);
5241
+ const n = checkSquare("inv", a);
5242
+ return solve(a, eye(n));
4948
5243
  }
4949
5244
  /**
4950
5245
  * Return the least-squares solution to a linear equation.
@@ -4968,7 +5263,7 @@ function lstsq(a, b) {
4968
5263
  const at = matrixTranspose(a.ref);
4969
5264
  if (m <= n) {
4970
5265
  const aat = matmul(a, at.ref);
4971
- const l = cholesky$1(aat, { symmetrizeInput: false });
5266
+ const l = cholesky(aat, { symmetrizeInput: false });
4972
5267
  const lb = triangularSolve(l.ref, b, {
4973
5268
  leftSide: true,
4974
5269
  lower: true
@@ -4980,7 +5275,7 @@ function lstsq(a, b) {
4980
5275
  return matmul(at, llb.ref);
4981
5276
  } else {
4982
5277
  const ata = matmul(at.ref, a);
4983
- const l = cholesky$1(ata, { symmetrizeInput: false });
5278
+ const l = cholesky(ata, { symmetrizeInput: false });
4984
5279
  const atb = matmul(at, b);
4985
5280
  const lb = triangularSolve(l.ref, atb, {
4986
5281
  leftSide: true,
@@ -4993,6 +5288,169 @@ function lstsq(a, b) {
4993
5288
  return llb;
4994
5289
  }
4995
5290
  }
5291
+ /** Raise a square matrix to an integer power, via repeated squarings. */
5292
+ function matrixPower(a, n) {
5293
+ if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
5294
+ a = fudgeArray(a);
5295
+ const m = checkSquare("matrixPower", a);
5296
+ if (n === 0) {
5297
+ a.dispose();
5298
+ return broadcastTo(eye(m), a.shape);
5299
+ }
5300
+ if (n < 0) {
5301
+ a = inv(a);
5302
+ n = -n;
5303
+ }
5304
+ let result = null;
5305
+ let a2k = a;
5306
+ for (let k = 0; n; k++) {
5307
+ if (k > 0) a2k = matmul(a2k.ref, a2k);
5308
+ if (n % 2 === 1) result = result === null ? a2k.ref : matmul(result, a2k.ref);
5309
+ n = Math.floor(n / 2);
5310
+ }
5311
+ a2k.dispose();
5312
+ return result;
5313
+ }
5314
+ /** Return sign and natural logarithm of the determinant of `a`. */
5315
+ function slogdet(a) {
5316
+ a = fudgeArray(a);
5317
+ const n = checkSquare("slogdet", a);
5318
+ const [lu$2, pivots, permutation] = lu(a);
5319
+ permutation.dispose();
5320
+ let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
5321
+ const diag$1 = lu$2.diagonal(0, -1, -2);
5322
+ parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
5323
+ const logabsdet = log(absolute(diag$1)).sum(-1);
5324
+ const sign$1 = parity.mul(-2).add(1);
5325
+ return [sign$1, logabsdet];
5326
+ }
5327
+ /**
5328
+ * Solve a linear system of equations.
5329
+ *
5330
+ * This solves a (batched) linear system of equations `a @ x = b` for `x` given
5331
+ * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
5332
+ *
5333
+ * @param a - Coefficient matrix of shape `(..., N, N)`.
5334
+ * @param b - Values of shape `(N,)` or `(..., N, M)`.
5335
+ * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
5336
+ */
5337
+ function solve(a, b) {
5338
+ a = fudgeArray(a);
5339
+ b = fudgeArray(b);
5340
+ const n = checkSquare("solve", a);
5341
+ if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
5342
+ const bIs1d = b.ndim === 1;
5343
+ if (bIs1d) b = b.reshape([...b.shape, 1]);
5344
+ if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
5345
+ const m = b.shape[b.ndim - 1];
5346
+ const batchDims = generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
5347
+ a = broadcastTo(a, [
5348
+ ...batchDims,
5349
+ n,
5350
+ n
5351
+ ]);
5352
+ b = broadcastTo(b, [
5353
+ ...batchDims,
5354
+ n,
5355
+ m
5356
+ ]);
5357
+ const [lu$2, pivots, permutation] = lu(a);
5358
+ pivots.dispose();
5359
+ const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
5360
+ const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
5361
+ leftSide: true,
5362
+ lower: true,
5363
+ unitDiagonal: true
5364
+ });
5365
+ let x = triangularSolve(lu$2, LPb.ref, {
5366
+ leftSide: true,
5367
+ lower: false
5368
+ });
5369
+ if (bIs1d) x = squeeze(x, -1);
5370
+ return x;
5371
+ }
5372
+
5373
+ //#endregion
5374
+ //#region src/library/numpy/dtype-info.ts
5375
+ /** Machine limits for floating-point types. */
5376
+ function finfo(dtype) {
5377
+ if (!isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
5378
+ switch (dtype) {
5379
+ case DType.Float16: return Object.freeze({
5380
+ bits: 16,
5381
+ dtype: DType.Float16,
5382
+ eps: 2 ** -10,
5383
+ epsneg: 2 ** -11,
5384
+ machep: -10,
5385
+ max: 65504,
5386
+ maxexp: 16,
5387
+ min: -65504,
5388
+ minexp: -14,
5389
+ negep: -24,
5390
+ nexp: 5,
5391
+ nmant: 10,
5392
+ precision: 3,
5393
+ resolution: .001,
5394
+ smallestNormal: 2 ** -14,
5395
+ smallestSubnormal: 2 ** -24
5396
+ });
5397
+ case DType.Float32: return Object.freeze({
5398
+ bits: 32,
5399
+ dtype: DType.Float32,
5400
+ eps: 2 ** -23,
5401
+ epsneg: 2 ** -24,
5402
+ machep: -23,
5403
+ max: 34028234663852886e22,
5404
+ maxexp: 128,
5405
+ min: -34028234663852886e22,
5406
+ minexp: -126,
5407
+ negep: -24,
5408
+ nexp: 8,
5409
+ nmant: 23,
5410
+ precision: 6,
5411
+ resolution: 1e-6,
5412
+ smallestNormal: 2 ** -126,
5413
+ smallestSubnormal: 2 ** -149
5414
+ });
5415
+ case DType.Float64: return Object.freeze({
5416
+ bits: 64,
5417
+ dtype: DType.Float64,
5418
+ eps: 2 ** -52,
5419
+ epsneg: 2 ** -53,
5420
+ machep: -52,
5421
+ max: Number.MAX_VALUE,
5422
+ maxexp: 1024,
5423
+ min: -Number.MAX_VALUE,
5424
+ minexp: -1022,
5425
+ negep: -53,
5426
+ nexp: 11,
5427
+ nmant: 52,
5428
+ precision: 15,
5429
+ resolution: 1e-15,
5430
+ smallestNormal: 2 ** -1022,
5431
+ smallestSubnormal: 2 ** -1074
5432
+ });
5433
+ default: throw new Error(`finfo: unsupported dtype ${dtype}`);
5434
+ }
5435
+ }
5436
+ /** Machine limits for integer types. */
5437
+ function iinfo(dtype) {
5438
+ switch (dtype) {
5439
+ case DType.Int32: return Object.freeze({
5440
+ bits: 32,
5441
+ dtype: DType.Int32,
5442
+ max: 2147483647,
5443
+ min: -2147483648
5444
+ });
5445
+ case DType.Uint32: return Object.freeze({
5446
+ bits: 32,
5447
+ dtype: DType.Uint32,
5448
+ max: 4294967295,
5449
+ min: 0
5450
+ });
5451
+ default: throw new Error(`iinfo: unsupported dtype ${dtype}`);
5452
+ }
5453
+ }
4996
5454
 
4997
5455
  //#endregion
4998
5456
  //#region src/library/numpy.ts
@@ -5048,6 +5506,7 @@ __export(numpy_exports, {
5048
5506
  diag: () => diag,
5049
5507
  diagonal: () => diagonal,
5050
5508
  divide: () => trueDivide,
5509
+ divmod: () => divmod,
5051
5510
  dot: () => dot$1,
5052
5511
  dstack: () => dstack,
5053
5512
  e: () => e,
@@ -5060,6 +5519,7 @@ __export(numpy_exports, {
5060
5519
  expm1: () => expm1,
5061
5520
  eye: () => eye,
5062
5521
  fft: () => numpy_fft_exports,
5522
+ finfo: () => finfo,
5063
5523
  flip: () => flip,
5064
5524
  fliplr: () => fliplr,
5065
5525
  flipud: () => flipud,
@@ -5067,6 +5527,7 @@ __export(numpy_exports, {
5067
5527
  float32: () => float32,
5068
5528
  float64: () => float64,
5069
5529
  floor: () => floor,
5530
+ floorDivide: () => floorDivide,
5070
5531
  fmod: () => fmod,
5071
5532
  frexp: () => frexp,
5072
5533
  full: () => full,
@@ -5079,6 +5540,7 @@ __export(numpy_exports, {
5079
5540
  hstack: () => hstack,
5080
5541
  hypot: () => hypot,
5081
5542
  identity: () => identity$1,
5543
+ iinfo: () => iinfo,
5082
5544
  inf: () => inf,
5083
5545
  inner: () => inner,
5084
5546
  int32: () => int32,
@@ -5096,6 +5558,7 @@ __export(numpy_exports, {
5096
5558
  log10: () => log10,
5097
5559
  log1p: () => log1p,
5098
5560
  log2: () => log2,
5561
+ logspace: () => logspace,
5099
5562
  matmul: () => matmul,
5100
5563
  matrixTranspose: () => matrixTranspose,
5101
5564
  max: () => max,
@@ -5132,9 +5595,11 @@ __export(numpy_exports, {
5132
5595
  shape: () => shape,
5133
5596
  sign: () => sign,
5134
5597
  sin: () => sin,
5598
+ sinc: () => sinc,
5135
5599
  sinh: () => sinh,
5136
5600
  size: () => size,
5137
5601
  sort: () => sort,
5602
+ split: () => split$1,
5138
5603
  sqrt: () => sqrt,
5139
5604
  square: () => square,
5140
5605
  squeeze: () => squeeze,
@@ -5142,6 +5607,8 @@ __export(numpy_exports, {
5142
5607
  std: () => std,
5143
5608
  subtract: () => subtract,
5144
5609
  sum: () => sum,
5610
+ swapaxes: () => swapaxes,
5611
+ take: () => take,
5145
5612
  tan: () => tan,
5146
5613
  tanh: () => tanh,
5147
5614
  tensordot: () => tensordot,
@@ -5400,6 +5867,45 @@ function flip(x, axis = null) {
5400
5867
  return flip$1(x, axis);
5401
5868
  }
5402
5869
  /**
5870
+ * Split an array into multiple sub-arrays along an axis.
5871
+ *
5872
+ * @param a - The input array to split.
5873
+ * @param indicesOrSections - If an integer, it indicates the number of equal
5874
+ * sections to create along the specified axis. If a list of integers, it
5875
+ * specifies the indices at which to split the array.
5876
+ * @param axis - The axis along which to split the array. Default is 0.
5877
+ */
5878
+ function split$1(a, indicesOrSections, axis = 0) {
5879
+ a = fudgeArray(a);
5880
+ axis = checkAxis(axis, a.ndim);
5881
+ const size$1 = a.shape[axis];
5882
+ let sizes;
5883
+ if (typeof indicesOrSections === "number") {
5884
+ if (size$1 % indicesOrSections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${indicesOrSections} equal parts`);
5885
+ const partSize = size$1 / indicesOrSections;
5886
+ sizes = rep(indicesOrSections, partSize);
5887
+ } else {
5888
+ const indices = indicesOrSections;
5889
+ sizes = [indices[0]];
5890
+ for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
5891
+ sizes.push(size$1 - indices[indices.length - 1]);
5892
+ }
5893
+ const results = [];
5894
+ for (let i = 0; i < sizes.length; i += 7) if (i === sizes.length) {
5895
+ results.push(a);
5896
+ break;
5897
+ } else if (i + 8 >= sizes.length) {
5898
+ results.push(...split$2(a, axis, sizes.slice(i)));
5899
+ break;
5900
+ } else {
5901
+ const groupSizes = [...sizes.slice(i, i + 7), sizes.slice(i + 7).reduce((x, y) => x + y, 0)];
5902
+ const outs = split$2(a, axis, groupSizes);
5903
+ results.push(...outs.slice(0, -1));
5904
+ a = outs[outs.length - 1];
5905
+ }
5906
+ return results;
5907
+ }
5908
+ /**
5403
5909
  * Join a sequence of arrays along an existing axis.
5404
5910
  *
5405
5911
  * The arrays must have the same shape, except in the dimension corresponding to
@@ -5411,13 +5917,11 @@ function concatenate(xs, axis = 0) {
5411
5917
  if (xs.length === 0) throw new Error("Need at least one array to concatenate");
5412
5918
  const shapes = xs.map(shape);
5413
5919
  axis = checkAxis(axis, shapes[0].length);
5414
- for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays with shapes ${JSON.stringify(shapes)} along axis ${axis}`);
5415
- const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
5920
+ for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
5416
5921
  let result = xs[0];
5417
- for (let i = 1; i < xs.length; i++) {
5418
- const len1 = result.shape[axis];
5419
- const len2 = shapes[i][axis];
5420
- result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
5922
+ for (let i = 1; i < xs.length; i += 7) {
5923
+ const group = xs.slice(i, i + 7);
5924
+ result = concatenate$1([result, ...group], axis);
5421
5925
  }
5422
5926
  return result;
5423
5927
  }
@@ -5502,6 +6006,17 @@ function flipud(x) {
5502
6006
  function fliplr(x) {
5503
6007
  return flip(x, 1);
5504
6008
  }
6009
/**
 * Interchange two axes of an array.
 *
 * Both axes are normalized with `checkAxis`. When they refer to the same axis,
 * the input (after `fudgeArray`) is returned without a transpose.
 */
function swapaxes(a, axis1, axis2) {
  const arr = fudgeArray(a);
  const first = checkAxis(axis1, arr.ndim);
  const second = checkAxis(axis2, arr.ndim);
  if (first === second) return arr;
  // Identity permutation with the two chosen axes exchanged.
  const perm = range(arr.ndim).map((d) => (d === first ? second : d === second ? first : d));
  return transpose(arr, perm);
}
5505
6020
  /** Transpose the last two dimensions of an array. */
5506
6021
  function matrixTranspose(a) {
5507
6022
  if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
@@ -5669,6 +6184,20 @@ function sort(a, axis = -1) {
5669
6184
  function argsort(a, axis = -1) {
5670
6185
  return fudgeArray(a).argsort(axis);
5671
6186
  }
6187
/**
 * Take elements from an array along an axis.
 *
 * Equivalent to advanced indexing with integer `indices` along `axis`. When
 * `axis` is `null` (the default), the array is flattened with `ravel` first and
 * indexed along axis 0.
 */
function take(a, indices, axis = null) {
  let src = a;
  let ax = axis;
  if (ax === null) {
    src = ravel(src);
    ax = 0;
  }
  const resolved = checkAxis(ax, ndim(src));
  return gather(src, [indices], [resolved], resolved);
}
5672
6201
  /** Return if two arrays are element-wise equal within a tolerance. */
5673
6202
  function allclose(actual, expected, options) {
5674
6203
  const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
@@ -5988,6 +6517,20 @@ function tan(x) {
5988
6517
  x = fudgeArray(x);
5989
6518
  return sin(x.ref).div(cos(x));
5990
6519
  }
6520
/**
 * @function
 * Return the normalized sinc function.
 *
 * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
 * This is the normalized sinc function commonly used in signal processing.
 *
 * **Note:** the denominator is replaced by 1 wherever `x == 0`, so the unused
 * branch never evaluates `0/0`. This keeps first derivatives finite everywhere
 * (giving `sinc'(0) = 0`). Higher-order derivatives at exactly `x = 0` are not
 * exact, because JAX's custom JVP rule for sinc is not replicated here.
 */
const sinc = jit$1(function sinc$1(x) {
  const isZero = equal(x.ref, 0);
  // "Double where": feed a safe denominator to sin(t)/t at x == 0 so no NaN
  // appears in either the value or the tangent of the discarded branch.
  const safePix = where(isZero.ref, 1, x.mul(Math.PI));
  return where(isZero, 1, sin(safePix.ref).div(safePix));
});
5991
6534
  /** Element-wise inverse cosine function (inverse of cos). */
5992
6535
  function acos(x) {
5993
6536
  return subtract(pi / 2, asin(x));
@@ -6040,6 +6583,25 @@ function trueDivide(x, y) {
6040
6583
  return x.div(y);
6041
6584
  }
6042
6585
  /**
6586
+ * Return the largest integer smaller or equal to the division of the inputs.
6587
+ *
6588
+ * The result is always rounded towards negative infinity.
6589
+ *
6590
+ * For floating-point inputs, this is equivalent to `floor(x / y)`.
6591
+ * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
6592
+ * negative values correctly (note: may overflow near int32 boundaries).
6593
+ *
6594
+ * @param x - Dividend array.
6595
+ * @param y - Divisor array.
6596
+ * @returns Element-wise floor division of x by y.
6597
+ */
6598
+ function floorDivide(x, y) {
6599
+ x = fudgeArray(x);
6600
+ y = fudgeArray(y);
6601
+ if (isFloatDtype(x.dtype) || isFloatDtype(y.dtype)) return floor(trueDivide(x, y));
6602
+ return subtract(x, remainder(x.ref, y.ref)).div(y);
6603
+ }
6604
+ /**
6043
6605
  * @function
6044
6606
  * Calculate element-wise floating-point modulo operation.
6045
6607
  */
@@ -6053,6 +6615,20 @@ const fmod = jit$1(function fmod$1(x, y) {
6053
6615
  const remainder = jit$1(function remainder$1(x, y) {
6054
6616
  return mod(mod(x, y.ref).add(y.ref), y);
6055
6617
  });
6618
/**
 * Return the element-wise quotient and remainder together.
 *
 * Same result as `[floorDivide(x, y), remainder(x, y)]`, with the inputs
 * normalized once and shared between both computations.
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Tuple of [quotient, remainder].
 */
function divmod(x, y) {
  const dividend = fudgeArray(x);
  const divisor = fudgeArray(y);
  const quotient = floorDivide(dividend.ref, divisor.ref);
  const rest = remainder(dividend, divisor);
  return [quotient, rest];
}
6056
6632
  /** Round input to the nearest integer towards zero. */
6057
6633
  function trunc(x) {
6058
6634
  return idiv(x, 1);
@@ -6216,14 +6792,15 @@ function std(x, axis = null, opts) {
6216
6792
  return sqrt(var_(x, axis, opts));
6217
6793
  }
6218
6794
  /** Estimate the sample covariance of a set of variables. */
6219
- function cov(x, y) {
6795
+ function cov(x, y = null, { rowvar = true } = {}) {
6220
6796
  x = fudgeArray(x);
6221
6797
  if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
6222
- if (y !== void 0) {
6798
+ if (y !== null) {
6223
6799
  y = fudgeArray(y);
6224
6800
  if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
6225
6801
  x = vstack([x, y]);
6226
6802
  }
6803
+ if (!rowvar) x = x.transpose();
6227
6804
  const [_M, N] = x.shape;
6228
6805
  x = x.ref.sub(x.mean(1, { keepdims: true }));
6229
6806
  return dot$1(x.ref, x.transpose()).div(N - 1);
@@ -6268,7 +6845,8 @@ const isfinite = jit$1(function isfinite$1(x) {
6268
6845
  //#region src/library/lax-linalg.ts
6269
6846
  var lax_linalg_exports = {};
  // Public `lax.linalg` namespace. Each entry maps a public name to a thunk that
  // returns the module-local binding (presumably installed as getters by
  // `__export` — TODO confirm). `cholesky$1` is the wrapper defined in this
  // region; the `$1` suffix looks like a bundler rename to avoid a clash.
  __export(lax_linalg_exports, {
    cholesky: () => cholesky$1,
    lu: () => lu,
    triangularSolve: () => triangularSolve
  });
6274
6852
  /**
@@ -6297,11 +6875,39 @@ __export(lax_linalg_exports, {
6297
6875
  * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
6298
6876
  * ```
6299
6877
  */
6300
function cholesky$1(a, { upper = false } = {}) {
  // cholesky$2 produces the lower factor L. For `upper`, swap the last two
  // axes to return its transpose instead.
  const lower = cholesky$2(a);
  if (!upper) return lower;
  return moveaxis$1(lower, -2, -1);
}
6304
6882
  /**
6883
+ * LU decomposition with partial pivoting.
6884
+ *
6885
+ * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
6886
+ * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
6887
+ * and `U` is upper-triangular.
6888
+ *
6889
+ * @param x - A batch of matrices with shape `[..., m, n]`.
6890
+ *
6891
+ * @returns A tuple `(lu, pivots, permutation)` where:
6892
+ * - `lu`: combined lower and upper triangular matrices.
6893
+ * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
6894
+ * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
6895
+ *
6896
+ * @example
6897
+ * ```ts
6898
+ * import { lax, numpy as np } from "@jax-js/jax";
6899
+ *
6900
+ * const A = np.array([[4., 3.], [6., 3.]]);
6901
+ * const [lu, pivots, permutation] = lax.linalg.lu(A);
6902
+ * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
6903
+ * // pivots = [1, 1]
6904
+ * // permutation = [1, 0]
6905
+ * ```
6906
+ */
6907
+ function lu(x) {
6908
+ return lu$1(x);
6909
+ }
6910
+ /**
6305
6911
  * Solve a triangular linear system.
6306
6912
  *
6307
6913
  * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
@@ -6339,6 +6945,7 @@ var lax_exports = {};
6339
6945
  __export(lax_exports, {
6340
6946
  conv: () => conv,
6341
6947
  convGeneralDilated: () => convGeneralDilated,
6948
+ convTranspose: () => convTranspose,
6342
6949
  convWithGeneralPadding: () => convWithGeneralPadding,
6343
6950
  dot: () => dot,
6344
6951
  erf: () => erf,
@@ -6347,6 +6954,7 @@ __export(lax_exports, {
6347
6954
  reduceWindow: () => reduceWindow,
6348
6955
  stopGradient: () => stopGradient$1
6349
6956
  });
6957
// Handle on the built-in JS Array constructor, kept distinct from the jax-js
// `Array` class that this module exports under the name `Array`.
const { Array: JsArray } = globalThis;
6350
6958
  /**
6351
6959
  * General dot product/contraction operator.
6352
6960
  *
@@ -6418,7 +7026,11 @@ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
6418
7026
  * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
6419
7027
  * function in JAX, which wraps XLA's general convolution operator.
6420
7028
  *
6421
- * Grouped convolutions are not supported right now.
7029
+ * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
7030
+ * @param rhs - Convolution kernel; shape `[C_out, C_in / G, ...ks]`
7031
+ * @param windowStrides - Strides for each spatial dimension
7032
+ * @param padding - Padding for each spatial dimension, or a string
7033
+ * (`"VALID"`, `"SAME"`, or `"SAME_LOWER"`)
6422
7034
  */
6423
7035
  function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
6424
7036
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
@@ -6478,6 +7090,60 @@ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, r
6478
7090
  function conv(lhs, rhs, windowStrides, padding) {
6479
7091
  return convGeneralDilated(lhs, rhs, windowStrides, padding);
6480
7092
  }
7093
/**
 * Convenience wrapper for computing an N-d convolution "transpose".
 *
 * Rather than taking the gradient of a forward convolution, this directly
 * computes a fractionally strided convolution. The input is dilated by
 * `strides` (as `lhsDilation`), padded according to `padding`, and convolved
 * with unit window strides. It matches the JAX version, except:
 *
 * - There is no `use_consistent_padding` option. Only the consistent padding
 *   behavior is implemented (JAX version >0.8.4).
 * - The order of dimensions matches `lax.conv_general_dilated`.
 *
 * Unlike PyTorch/TensorFlow, by default the kernel's spatial dimensions and its
 * `(C_out, C_in)` axis order are not reversed. Set `transposeKernel` to true
 * to get that behavior.
 *
 * @param lhs - Input tensor; shape `[N, C_in, ...xs]`
 * @param rhs - Convolution kernel; shape `[C_out, C_in, ...ks]`
 * @param strides - Sequence of n integers, sets the fractional stride
 * @param padding - `"SAME"`, `"VALID"`, or per-dimension `[lo, hi]` pairs. A pair
 *   applies `dilation * (kernel_size - 1) - padding` to each side of the input,
 *   so the result acts like the gradient of `conv()`.
 * @param rhsDilation - Atrous dilation for the kernel (defaults to all ones)
 * @param transposeKernel - Flip the spatial axes and swap the input/output
 *   channels of the kernel; its shape should then be `[C_in, C_out, ...ks]`
 */
function convTranspose(lhs, rhs, strides, padding, { rhsDilation, transposeKernel = false } = {}) {
  const spatialKernel = rhs.shape.slice(2);
  const dilation = rhsDilation ?? rep(spatialKernel.length, 1);
  const pads = [];
  for (let i = 0; i < spatialKernel.length; i++) {
    // Size of the kernel footprint once dilated (0 for an empty kernel axis).
    const dilatedSize = Math.max(0, (spatialKernel[i] - 1) * dilation[i] + 1);
    const spec = typeof padding === "string" ? padding : padding[i];
    pads.push(convTransposePadding(dilatedSize, strides[i], spec));
  }
  let kernel = rhs;
  if (transposeKernel) kernel = moveaxis(flip$1(kernel, range(2, kernel.ndim)), 0, 1);
  return convGeneralDilated(lhs, kernel, rep(lhs.ndim - 2, 1), pads, {
    lhsDilation: strides,
    rhsDilation: dilation
  });
}
7131
/**
 * Compute `[before, after]` padding for one spatial dimension of
 * `convTranspose`.
 *
 * @param k - Dilated kernel size.
 * @param s - Stride.
 * @param padding - `"SAME"`, `"VALID"`, or an explicit `[lo, hi]` pair (applied
 *   as `k - 1 - lo` / `k - 1 - hi`).
 */
function convTransposePadding(k, s, padding) {
  let before;
  let total;
  switch (padding) {
    case "SAME":
      total = k + s - 2;
      before = s > k - 1 ? k - 1 : Math.ceil(total / 2);
      break;
    case "VALID":
      total = k + s - 2 + Math.max(k - s, 0);
      before = k - 1;
      break;
    default:
      if (!JsArray.isArray(padding)) throw new Error(`convTranspose: Invalid padding type ${padding}`);
      before = k - 1 - padding[0];
      total = before + (k - 1 - padding[1]);
  }
  return [before, total - before];
}
6481
7147
  /** Reduce a computation over padded windows. */
6482
7148
  function reduceWindow(operand, computation, windowDimensions, windowStrides) {
6483
7149
  if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
@@ -6516,6 +7182,7 @@ function stopGradient$1(x) {
6516
7182
  var nn_exports = {};
6517
7183
  __export(nn_exports, {
6518
7184
  celu: () => celu,
7185
+ dotProductAttention: () => dotProductAttention,
6519
7186
  elu: () => elu,
6520
7187
  gelu: () => gelu,
6521
7188
  glu: () => glu,
@@ -6832,6 +7499,95 @@ function oneHot(x, numClasses) {
6832
7499
  if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
6833
7500
  return eye(numClasses, void 0, { device: x.device }).slice(x);
6834
7501
  }
7502
/**
 * Scaled dot product attention (SDPA).
 *
 * Computes `softmax((Q @ K^T) * scale + bias) @ V`, where `Q` is the query,
 * `K` is the key, `V` is the value, and `scale` defaults to `1 / sqrt(H)`.
 *
 * Grouped-query (and multi-query) attention is applied when `key` and `value`
 * have fewer heads than `query`. With `G = N / K`, query heads
 * `k * G, ..., k * G + G - 1` all attend with key/value head `k`. This matches
 * the head grouping of `jax.nn.dot_product_attention`.
 *
 * We use the following uppercase letters to denote array shapes:
 * - `B` = batch size
 * - `S` = length of key/value sequences (source)
 * - `L` = length of query sequences
 * - `N` = number of attention heads
 * - `H` = dimensionality of each attention head
 * - `K` = number of key/value heads (for grouped-query attention)
 *
 * The batch size `B` may be omitted, which is equivalent to `B = 1`. In this
 * case it must be omitted from all inputs.
 *
 * @param query - Query array; shape `[B, L, N, H]`
 * @param key - Key array; shape `[B, S, K, H]`, with `N` divisible by `K`
 * @param value - Value array; same shape as `key`
 * @param opts.bias - Optional bias to add to the attention logits; shape
 *   `[B, N, L, S]` or broadcastable to it.
 * @param opts.mask - Optional mask to apply to the attention logits. It should be
 *   a boolean array broadcastable to `[B, N, L, S]`, where `true` indicates the
 *   element takes part in attention. Masked logits are set to `-Infinity`, so a
 *   row with every element masked is undefined (0/0) and is expected to come
 *   out as NaN.
 * @param opts.scale - Scaling factor override, default is `1 / sqrt(H)`.
 * @param opts.isCausal - If true, applies a causal (lower-triangular) mask.
 * @param opts.querySeqLengths - Not yet implemented (throws if given).
 * @param opts.keyValueSeqLengths - Not yet implemented (throws if given).
 * @param opts.localWindowSize - Not yet implemented (throws if given).
 *
 * @returns The result of the attention operation. Its shape is the same as query,
 *   `[B, L, N, H]`, or `[L, N, H]` if `B` is omitted.
 */
function dotProductAttention(query, key$1, value, opts = {}) {
  if (opts.querySeqLengths !== void 0 || opts.keyValueSeqLengths !== void 0) throw new Error("Sequence length masking is not yet implemented");
  if (opts.localWindowSize !== void 0) throw new Error("Local attention is not yet implemented");
  query = fudgeArray(query);
  key$1 = fudgeArray(key$1);
  value = fudgeArray(value);
  if ((query.ndim !== 3 && query.ndim !== 4) || query.ndim !== key$1.ndim || query.ndim !== value.ndim) throw new Error(`dotProductAttention: expected all tensors to have rank 3 or 4, got Q=${query.aval}, K=${key$1.aval}, V=${value.aval}`);
  if (!deepEqual(key$1.shape, value.shape)) throw new Error(`dotProductAttention: key and value shapes must match, got K=${key$1.shape}, V=${value.shape}`);
  const isRank3 = query.ndim === 3;
  if (isRank3) {
    // Add a unit batch axis so the rest of the function only deals with rank 4.
    query = expandDims(query, 0);
    key$1 = expandDims(key$1, 0);
    value = expandDims(value, 0);
  }
  const [B, L, N, H] = query.shape;
  if (key$1.shape[0] !== B || key$1.shape[3] !== H) throw new Error(`dotProductAttention: query and key shapes mismatch, got Q=${query.aval}, K=${key$1.aval}`);
  const S = key$1.shape[1];
  const K = key$1.shape[2];
  if (N < K || N % K !== 0) throw new Error(`dotProductAttention: number of query heads N=${N} must be divisible by number of key/value heads K=${K} for GQA`);
  const G = N / K;
  if (G > 1) {
    // Repeat each key/value head G times *consecutively*:
    // [B, S, K, H] -> [B, S, K, G, H] -> [B, S, N, H], so query head n uses
    // key/value head floor(n / G). Tiling the head axis directly would
    // interleave the heads instead (query head n would use head n % K).
    const repeatHeads = (x) => tile(expandDims(x, 3), [1, 1, 1, G, 1]).reshape([B, S, N, H]);
    key$1 = repeatHeads(key$1);
    value = repeatHeads(value);
  }
  const scale = opts.scale ?? 1 / Math.sqrt(H);
  let scores = einsum("BLNH,BSNH->BNLS", query, key$1).mul(scale);
  if (opts.bias !== void 0) scores = scores.add(opts.bias);
  if (opts.mask !== void 0) scores = where(opts.mask, scores, -Infinity);
  if (opts.isCausal) {
    const causalMask = tri(L, S, 0, { dtype: DType.Bool });
    scores = where(causalMask, scores, -Infinity);
  }
  const attn = softmax(scores, -1);
  const out = einsum("BNLS,BSNH->BLNH", attn, value);
  return isRank3 ? out.reshape([L, N, H]) : out;
}
6835
7591
 
6836
7592
  //#endregion
6837
7593
  //#region src/library/random.ts
@@ -6844,33 +7600,41 @@ __export(random_exports, {
6844
7600
  gumbel: () => gumbel,
6845
7601
  key: () => key,
6846
7602
  laplace: () => laplace,
7603
+ multivariateNormal: () => multivariateNormal,
6847
7604
  normal: () => normal,
6848
7605
  split: () => split,
6849
7606
  uniform: () => uniform
6850
7607
  });
6851
/**
 * Check that `key$1` looks like a PRNG key (trailing dimension of size 2) and
 * return its batch shape, i.e. the shape without that trailing dimension.
 * When `scalar` is true, batched keys are rejected.
 */
function validateKeyShape(key$1, scalar = false) {
  if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
  const dims = key$1.shape;
  if (dims[dims.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
  const batched = dims.length > 1;
  if (scalar && batched) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(key$1.shape)} - use jax.vmap for batching.`);
  return dims.slice(0, -1);
}
7614
/**
 * Split a single (unbatched) PRNG key into its two words `[k0, k1]`, each
 * reshaped to the key's batch shape (the shape minus the trailing 2).
 */
function getK01(key$1) {
  const batchShape = validateKeyShape(key$1, true);
  const halves = split$2(key$1, -1, [1, 1]);
  return halves.map((half) => half.reshape(batchShape));
}
6856
7621
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
6857
7622
  function key(seed) {
6858
- seed = seed >>> 0;
6859
- return array([0, seed], { dtype: DType.Uint32 });
7623
+ seed = array(seed, { dtype: DType.Uint32 });
7624
+ if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7625
+ return stack([0, seed]);
6860
7626
  }
6861
7627
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
6862
7628
  function split(key$1, num = 2) {
6863
7629
  const shape$1 = typeof num === "number" ? [num] : num;
6864
7630
  for (const len of shape$1) if (len <= 0 || !Number.isInteger(len)) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
6865
- const keyShape = validateKeyShape(key$1);
6866
- const k0 = key$1.ref.slice(...keyShape.map(() => null), 0);
6867
- const k1 = key$1.slice(...keyShape.map(() => null), 1);
7631
+ const [k0, k1] = getK01(key$1);
6868
7632
  return stack([randomBits(k0.ref, k1.ref, shape$1, 0), randomBits(k0, k1, shape$1, 1)], -1);
6869
7633
  }
6870
7634
  /** Sample uniform bits in the form of unsigned integers. */
6871
7635
  function bits(key$1, shape$1 = []) {
6872
- const keyShape = validateKeyShape(key$1);
6873
- return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
7636
+ const [k0, k1] = getK01(key$1);
7637
+ return randomBits(k0, k1, shape$1);
6874
7638
  }
6875
7639
  /**
6876
7640
  * @function
@@ -6944,6 +7708,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
6944
7708
  }, { staticArgnums: [1] });
6945
7709
  /**
6946
7710
  * @function
7711
+ * Sample multivariate normal random values with given mean and covariance.
7712
+ *
7713
+ * The values are returned with the given shape, along with the final dimension
7714
+ * used to represent the n-dimensional multivariate normal factors.
7715
+ *
7716
+ * This uses Cholesky decomposition on the covariance matrix.
7717
+ *
7718
+ * - `key` - PRNG key
7719
+ * - `mean` - Mean vector of shape `[..., n]`
7720
+ * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
7721
+ * - `shape` - Result batch shape, must be broadcastable with
7722
+ * `mean.shape[:-1]` and `cov.shape[:-2]`
7723
+ * @returns Random samples of shape `[...shape, n]`
7724
+ */
7725
+ const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
7726
+ mean$1 = fudgeArray(mean$1);
7727
+ cov$1 = fudgeArray(cov$1);
7728
+ const n = mean$1.shape[mean$1.ndim - 1];
7729
+ if (cov$1.shape[cov$1.ndim - 1] !== n || cov$1.shape[cov$1.ndim - 2] !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
7730
+ const outputShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2)).concat(n);
7731
+ const L = cholesky(cov$1);
7732
+ const z = normal(key$1, outputShape);
7733
+ return einsum("...ij,...j->...i", L, z).add(mean$1);
7734
+ }, { staticArgnums: [3] });
7735
+ /**
7736
+ * @function
6947
7737
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
6948
7738
  *
6949
7739
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
@@ -7033,17 +7823,62 @@ const linearize = linearize$1;
7033
7823
  /**
7034
7824
  * @function
7035
7825
  * Calculate the reverse-mode vector-Jacobian product for a function.
7826
+ *
7827
+ * The return value is a tuple of `[out, vjpFn]`, where `out` is the output of
7828
+ * `f(primals)`, and `vjpFn` is a function that takes in cotangents for each
7829
+ * output and returns the cotangents for each input.
7830
+ *
7831
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
7832
+ * `[out, aux]` tuple, and `vjp` returns `[out, vjpFn, aux]`.
7833
+ *
7834
+ * @example
7835
+ * ```ts
7836
+ * const [y, vjpFn] = vjp(f, [x]);
7837
+ *
7838
+ * // With hasAux
7839
+ * const [y, vjpFn, aux] = vjp(f, [x], { hasAux: true });
7840
+ * ```
7036
7841
  */
7037
7842
  const vjp = vjp$1;
7038
7843
  /**
7039
7844
  * @function
7040
7845
  * Compute the gradient of a scalar-valued function `f` with respect to its
7041
7846
  * first argument.
7847
+ *
7848
+ * Pass in different `argnums` to differentiate with respect to other
7849
+ * arguments. If a tuple is provided, the return value will be a tuple of
7850
+ * gradients corresponding to each argument index.
7851
+ *
7852
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return a
7853
+ * `[out, aux]` tuple, and the return value will be `[gradient, aux]`.
7854
+ *
7855
+ * @example
7856
+ * ```ts
7857
+ * const gradient = grad(f)(x);
7858
+ *
7859
+ * // With `argnums`
7860
+ * const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);
7861
+ *
7862
+ * // With `hasAux`
7863
+ * const [gradient, aux] = grad(f, { hasAux: true })(x);
7864
+ * ```
7042
7865
  */
7043
7866
  const grad = grad$1;
7044
7867
  /**
7045
7868
  * @function
7046
7869
  * Create a function that evaluates both `f` and the gradient of `f`.
7870
+ *
7871
+ * When `{ hasAux: true }` is passed, the function `f` is expected to return an
7872
+ * `[out, aux]` tuple, and the return value will be `[[out, aux], gradient]`.
7873
+ *
7874
+ * @example
7875
+ * ```ts
7876
+ * // Without hasAux
7877
+ * const [value, gradient] = valueAndGrad(f)(x);
7878
+ *
7879
+ * // With hasAux
7880
+ * const [[value, aux], gradient] = valueAndGrad(f, { hasAux: true })(x);
7881
+ * ```
7047
7882
  */
7048
7883
  const valueAndGrad = valueAndGrad$1;
7049
7884
  /**
@@ -7052,6 +7887,21 @@ const valueAndGrad = valueAndGrad$1;
7052
7887
  */
7053
7888
  const jacrev = jacrev$1;
7054
7889
  /**
   * @function
   * Compute the Hessian matrix of a scalar-valued function.
   *
   * The Hessian is the matrix of second-order partial derivatives of a function,
   * `H[i, j] = d^2 f / (dx_i dx_j)`. It is implemented as `jacfwd(grad(f))`,
   * i.e. forward-mode differentiation of the reverse-mode gradient.
   *
   * @example
   * ```ts
   * const f = (x: np.Array) => np.sum(x.ref.mul(x.ref).mul(x)); // x^3
   * const H = hessian(f)(np.array([1, 2, 3]));
   * // H[i,j] = d^2f / dx_i dx_j
   * ```
   */
  const hessian = hessian$1;
7904
+ /**
7055
7905
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
7056
7906
  *
7057
7907
  * This can be used to wait for the results of an intermediate computation to
@@ -7086,4 +7936,4 @@ async function devicePut(x, device) {
7086
7936
  }
7087
7937
 
7088
7938
  //#endregion
7089
- export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
7939
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };