@jax-js/jax 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-Bu9GY6sK.cjs');
33
+ const require_backend = require('./backend-DziQSaoQ.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -387,6 +387,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
387
387
  Primitive$1["PoolTranspose"] = "pool_transpose";
388
388
  Primitive$1["Compare"] = "compare";
389
389
  Primitive$1["Where"] = "where";
390
+ Primitive$1["Concatenate"] = "concatenate";
391
+ Primitive$1["Split"] = "split";
390
392
  Primitive$1["RandomBits"] = "random_bits";
391
393
  Primitive$1["Gather"] = "gather";
392
394
  Primitive$1["Transpose"] = "transpose";
@@ -399,6 +401,7 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
399
401
  Primitive$1["Argsort"] = "argsort";
400
402
  Primitive$1["TriangularSolve"] = "triangular_solve";
401
403
  Primitive$1["Cholesky"] = "cholesky";
404
+ Primitive$1["LU"] = "lu";
402
405
  Primitive$1["Jit"] = "jit";
403
406
  return Primitive$1;
404
407
  }({});
@@ -530,7 +533,25 @@ function where$1(cond, x, y) {
530
533
  y
531
534
  ]);
532
535
  }
536
+ function concatenate$1(xs, axis) {
537
+ if (xs.length === 0) throw new Error("concatenate requires at least one input");
538
+ const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
539
+ axis = require_backend.checkAxis(axis, avals[0].ndim);
540
+ for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
541
+ return bind1(Primitive.Concatenate, xs, { axis });
542
+ }
543
+ function split$2(x, axis, sizes) {
544
+ axis = require_backend.checkAxis(axis, ndim$1(x));
545
+ if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
546
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
547
+ if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
548
+ return bind(Primitive.Split, [x], {
549
+ axis,
550
+ sizes
551
+ });
552
+ }
533
553
  function randomBits(k0, k1, shape$1, mode = "xor") {
554
+ if (!require_backend.deepEqual(k0.shape, k1.shape) || k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
534
555
  return bind1(Primitive.RandomBits, [k0, k1], {
535
556
  shape: shape$1,
536
557
  mode
@@ -597,6 +618,11 @@ function pad$1(x, width) {
597
618
  return bind1(Primitive.Pad, [x], { width });
598
619
  }
599
620
  function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
621
+ const as = getShape(a);
622
+ const bs = getShape(b);
623
+ if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
624
+ const n = as[as.length - 2];
625
+ if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
600
626
  if (lower) {
601
627
  a = flip$1(a, [-2, -1]);
602
628
  b = flip$1(b, [-1]);
@@ -606,8 +632,15 @@ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
606
632
  return x;
607
633
  }
608
634
  function cholesky$2(x) {
635
+ const aval = ShapedArray.fromAval(getAval(x));
636
+ if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
609
637
  return bind1(Primitive.Cholesky, [x]);
610
638
  }
639
+ function lu$1(x) {
640
+ const aval = ShapedArray.fromAval(getAval(x));
641
+ if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
642
+ return bind(Primitive.LU, [x]);
643
+ }
611
644
  function sort$1(x) {
612
645
  const nd = ndim$1(x);
613
646
  if (nd === 0) throw new Error("sort: requires at least 1D input");
@@ -716,6 +749,9 @@ var Tracer = class Tracer {
716
749
  mul(other) {
717
750
  return mul(this, other);
718
751
  }
752
+ mod(other) {
753
+ return mod(this, other);
754
+ }
719
755
  greater(other) {
720
756
  return greater$1(this, other);
721
757
  }
@@ -828,8 +864,14 @@ var Tracer = class Tracer {
828
864
  */
829
865
  *[Symbol.iterator]() {
830
866
  if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
831
- for (let i = 0; i < this.shape[0]; i++) yield this.ref.slice(i);
832
- this.dispose();
867
+ let residual = this;
868
+ const subarrayShape = this.shape.slice(1);
869
+ for (let i = 0; i < this.shape[0]; i++) {
870
+ const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
871
+ yield lr[0].reshape(subarrayShape);
872
+ residual = lr[1];
873
+ }
874
+ residual.dispose();
833
875
  }
834
876
  /**
835
877
  * Return a sorted copy of an array in ascending order.
@@ -979,6 +1021,9 @@ var ShapedArray = class ShapedArray {
979
1021
  get size() {
980
1022
  return require_backend.prod(this.shape);
981
1023
  }
1024
+ scalar() {
1025
+ return new ShapedArray([], this.dtype, this.weakType);
1026
+ }
982
1027
  toString() {
983
1028
  return `${this.dtype}[${this.shape.join(",")}]`;
984
1029
  }
@@ -1588,7 +1633,7 @@ const abstractEvalRules = {
1588
1633
  return [new ShapedArray(shape$1, dtype, weakType)];
1589
1634
  },
1590
1635
  [Primitive.Conv]([lhs, rhs], params) {
1591
- const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
1636
+ const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
1592
1637
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
1593
1638
  return [new ShapedArray(shape$1, dtype, weakType)];
1594
1639
  },
@@ -1599,10 +1644,25 @@ const abstractEvalRules = {
1599
1644
  const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
1600
1645
  return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
1601
1646
  },
1647
+ [Primitive.Concatenate](xs, { axis }) {
1648
+ if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
1649
+ for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
1650
+ const shape$1 = xs[0].shape.slice();
1651
+ shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
1652
+ const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
1653
+ return [new ShapedArray(shape$1, dtype, weakType)];
1654
+ },
1655
+ [Primitive.Split]([x], { axis, sizes }) {
1656
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
1657
+ if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
1658
+ return sizes.map((size$1) => {
1659
+ return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
1660
+ });
1661
+ },
1602
1662
  [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1603
1663
  if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1604
- const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
1605
- if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1664
+ if (!require_backend.deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
1665
+ if (!require_backend.deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
1606
1666
  return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
1607
1667
  },
1608
1668
  [Primitive.Gather]([x, ...indices], { axis, outDim }) {
@@ -1659,6 +1719,16 @@ const abstractEvalRules = {
1659
1719
  if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
1660
1720
  return [ShapedArray.fromAval(a)];
1661
1721
  },
1722
+ [Primitive.LU]([a]) {
1723
+ if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
1724
+ const batch = a.shape.slice(0, -2);
1725
+ const [m, n] = a.shape.slice(-2);
1726
+ return [
1727
+ ShapedArray.fromAval(a),
1728
+ new ShapedArray([...batch, Math.min(m, n)], require_backend.DType.Int32, false),
1729
+ new ShapedArray([...batch, m], require_backend.DType.Int32, false)
1730
+ ];
1731
+ },
1662
1732
  [Primitive.Jit](args, { jaxpr }) {
1663
1733
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
1664
1734
  if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
@@ -1740,7 +1810,8 @@ const routinePrimitives = new Map([
1740
1810
  [Primitive.Sort, require_backend.Routines.Sort],
1741
1811
  [Primitive.Argsort, require_backend.Routines.Argsort],
1742
1812
  [Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
1743
- [Primitive.Cholesky, require_backend.Routines.Cholesky]
1813
+ [Primitive.Cholesky, require_backend.Routines.Cholesky],
1814
+ [Primitive.LU, require_backend.Routines.LU]
1744
1815
  ]);
1745
1816
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1746
1817
  var JitProgram = class {
@@ -1911,10 +1982,10 @@ function jitCompile(backend, jaxpr) {
1911
1982
  inputs.push(jv.arg);
1912
1983
  } else if (input instanceof Lit) inputs.push(builder.pushLit(input));
1913
1984
  const outputs = [];
1914
- for (const outVar$1 of eqn.outBinders) {
1915
- const outId = builder.pushBuffer(outVar$1.aval.size * require_backend.byteWidth(outVar$1.aval.dtype));
1985
+ for (const outVar of eqn.outBinders) {
1986
+ const outId = builder.pushBuffer(outVar.aval.size * require_backend.byteWidth(outVar.aval.dtype));
1916
1987
  outputs.push(outId);
1917
- ctx.set(outVar$1, {
1988
+ ctx.set(outVar, {
1918
1989
  type: "imm",
1919
1990
  arg: outId
1920
1991
  });
@@ -1965,35 +2036,37 @@ function jitCompile(backend, jaxpr) {
1965
2036
  let reduction;
1966
2037
  if (inputReduction) {
1967
2038
  const jv = inputReduction;
1968
- const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
1969
- exp$2 = jv.exp.reindexGids(addArgs(jv.args));
2039
+ const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
2040
+ exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
1970
2041
  reduction = new require_backend.Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
1971
2042
  } else {
1972
2043
  const ruleOutput = rule(inputExps, inputAvals, eqn.params);
1973
2044
  exp$2 = ruleOutput.exp;
1974
2045
  reduction = ruleOutput.reduction;
1975
2046
  }
1976
- const outVar = eqn.outBinders[0];
1977
- if (blackNodes.has(outVar)) {
1978
- const nargs$1 = inputArgs.length;
1979
- const size$1 = outVar.aval.size;
1980
- const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2, reduction);
1981
- const outId = builder.pushKernel(kernel, inputArgs);
1982
- ctx.set(outVar, {
1983
- type: "imm",
1984
- arg: outId
2047
+ for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
2048
+ const outVar = eqn.outBinders[i$1];
2049
+ if (blackNodes.has(outVar)) {
2050
+ const nargs$1 = inputArgs.length;
2051
+ const size$1 = outVar.aval.size;
2052
+ const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2[i$1], reduction);
2053
+ const outId = builder.pushKernel(kernel, inputArgs);
2054
+ ctx.set(outVar, {
2055
+ type: "imm",
2056
+ arg: outId
2057
+ });
2058
+ } else if (reduction) ctx.set(outVar, {
2059
+ type: "red",
2060
+ exp: exp$2[i$1],
2061
+ reduction,
2062
+ args: inputArgs
1985
2063
  });
1986
- } else if (reduction) ctx.set(outVar, {
1987
- type: "red",
1988
- exp: exp$2,
1989
- reduction,
1990
- args: inputArgs
1991
- });
1992
- else ctx.set(outVar, {
1993
- type: "exp",
1994
- exp: exp$2,
1995
- args: inputArgs
1996
- });
2064
+ else ctx.set(outVar, {
2065
+ type: "exp",
2066
+ exp: exp$2[i$1],
2067
+ args: inputArgs
2068
+ });
2069
+ }
1997
2070
  }
1998
2071
  const outputIds = [];
1999
2072
  for (const out of jaxpr.outs) if (out instanceof Var) {
@@ -2034,17 +2107,17 @@ function broadcastedJit(fn, opts) {
2034
2107
  if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = require_backend.AluExp.cast(newDtype, exp$2);
2035
2108
  return exp$2;
2036
2109
  });
2037
- return { exp: fn(exps, params) };
2110
+ return { exp: [fn(exps, params)] };
2038
2111
  };
2039
2112
  }
2040
2113
  function unopJit(fn) {
2041
2114
  return ([a], [_as], params) => {
2042
- return { exp: fn(a, params) };
2115
+ return { exp: [fn(a, params)] };
2043
2116
  };
2044
2117
  }
2045
2118
  function reshapeJit(fn) {
2046
2119
  return ([a], [_as], params) => {
2047
- return { exp: reshapeViews(a, (st) => fn(st, params)) };
2120
+ return { exp: [reshapeViews(a, (st) => fn(st, params))] };
2048
2121
  };
2049
2122
  }
2050
2123
  function routineNoJit() {
@@ -2090,7 +2163,7 @@ const jitRules = {
2090
2163
  a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
2091
2164
  const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
2092
2165
  return {
2093
- exp: a,
2166
+ exp: [a],
2094
2167
  reduction
2095
2168
  };
2096
2169
  },
@@ -2101,13 +2174,13 @@ const jitRules = {
2101
2174
  a = reshapeViews(a, (st) => st.compose(stX), true);
2102
2175
  const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
2103
2176
  return {
2104
- exp: a,
2177
+ exp: [a],
2105
2178
  reduction
2106
2179
  };
2107
2180
  },
2108
2181
  [Primitive.Dot]([a, b], [as, bs]) {
2109
2182
  const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
2110
- const c = k1.exp;
2183
+ const [c] = k1.exp;
2111
2184
  const cs = promoteAvals(as, bs);
2112
2185
  return jitRules[Primitive.Reduce]([c], [cs], {
2113
2186
  op: require_backend.AluOp.Add,
@@ -2124,16 +2197,41 @@ const jitRules = {
2124
2197
  },
2125
2198
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
2126
2199
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
2200
+ [Primitive.Concatenate](exps, avals, { axis }) {
2201
+ const ndim$2 = avals[0].ndim;
2202
+ const sizes = avals.map((x) => x.shape[axis]);
2203
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
2204
+ const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
2205
+ let cum = 0;
2206
+ const src = [];
2207
+ for (let i = 0; i < exps.length; i++) {
2208
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
2209
+ src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
2210
+ cum += sizes[i];
2211
+ }
2212
+ return { exp: [src.reduce(require_backend.AluExp.add)] };
2213
+ },
2214
+ [Primitive.Split]([a], [as], { axis, sizes }) {
2215
+ const exp$2 = [];
2216
+ let start = 0;
2217
+ for (const size$1 of sizes) {
2218
+ const slice = require_backend.range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
2219
+ exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
2220
+ start += size$1;
2221
+ }
2222
+ return { exp: exp$2 };
2223
+ },
2127
2224
  [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
2225
+ const keyShape = keyShapes[0].shape;
2128
2226
  const mapping = (st) => {
2129
- if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
2227
+ if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(st.shape.length, shape$1.length));
2130
2228
  };
2131
2229
  const k0 = reshapeViews(keys[0], mapping);
2132
2230
  const k1 = reshapeViews(keys[1], mapping);
2133
2231
  const c0 = require_backend.AluExp.u32(0);
2134
- const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
2232
+ const c1 = require_backend.AluExp.mod(require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx), require_backend.AluExp.u32(Math.max(require_backend.prod(shape$1.slice(keyShape.length)), 1)));
2135
2233
  const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
2136
- return { exp: exp$2 };
2234
+ return { exp: [exp$2] };
2137
2235
  },
2138
2236
  [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
2139
2237
  const axisSet = new Set(axis);
@@ -2148,7 +2246,7 @@ const jitRules = {
2148
2246
  for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
2149
2247
  const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
2150
2248
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
2151
- return { exp: x.substitute({ gidx: index }) };
2249
+ return { exp: [x.substitute({ gidx: index })] };
2152
2250
  },
2153
2251
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
2154
2252
  [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
@@ -2164,6 +2262,7 @@ const jitRules = {
2164
2262
  [Primitive.Argsort]: routineNoJit(),
2165
2263
  [Primitive.TriangularSolve]: routineNoJit(),
2166
2264
  [Primitive.Cholesky]: routineNoJit(),
2265
+ [Primitive.LU]: routineNoJit(),
2167
2266
  [Primitive.Jit]() {
2168
2267
  throw new Error("internal: Jit should have been flattened before JIT compilation");
2169
2268
  }
@@ -2442,6 +2541,10 @@ var Array$1 = class Array$1 extends Tracer {
2442
2541
  this.#rc++;
2443
2542
  return this;
2444
2543
  }
2544
+ /** Get the current reference count (for debugging memory management). */
2545
+ get refCount() {
2546
+ return this.#rc;
2547
+ }
2445
2548
  dispose() {
2446
2549
  this.#check();
2447
2550
  if (--this.#rc === 0) {
@@ -2599,7 +2702,7 @@ var Array$1 = class Array$1 extends Tracer {
2599
2702
  } else if (castDtype === void 0) {
2600
2703
  castDtype = arrays[i].#dtype;
2601
2704
  castWeakType = arrays[i].#weakType;
2602
- } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
2705
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
2603
2706
  const weakType = castWeakType && !strongTypeOutput;
2604
2707
  const { backend, committed } = Array$1.#computeBackend(name, arrays);
2605
2708
  arrays = arrays.map((ar) => ar._putSync(backend));
@@ -2992,17 +3095,44 @@ var Array$1 = class Array$1 extends Tracer {
2992
3095
  y
2993
3096
  ], { dtypeOverride: [require_backend.DType.Bool] })];
2994
3097
  },
3098
+ [Primitive.Concatenate](xs, { axis }) {
3099
+ const ndim$2 = xs[0].ndim;
3100
+ const sizes = xs.map((x) => x.shape[axis]);
3101
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
3102
+ const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
3103
+ let cum = 0;
3104
+ const xsPadded = [];
3105
+ for (let i = 0; i < xs.length; i++) {
3106
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
3107
+ xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
3108
+ cum += sizes[i];
3109
+ }
3110
+ const custom = (exps) => exps.reduce(require_backend.AluExp.add);
3111
+ return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
3112
+ },
3113
+ [Primitive.Split]([x], { axis, sizes }) {
3114
+ const outputs = [];
3115
+ for (let i = 0, start = 0; i < sizes.length; i++) {
3116
+ const slice = require_backend.range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
3117
+ outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
3118
+ start += sizes[i];
3119
+ }
3120
+ x.dispose();
3121
+ return outputs;
3122
+ },
2995
3123
  [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2996
- const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2997
- if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2998
- const c0 = zeros(shape$1, {
3124
+ const keyShape = k0.shape;
3125
+ const genShape = shape$1.slice(keyShape.length);
3126
+ const c0 = zeros(genShape, {
2999
3127
  dtype: require_backend.DType.Uint32,
3000
3128
  device: k0.device
3001
3129
  });
3002
- const c1 = arange(0, require_backend.prod(shape$1), 1, {
3130
+ const c1 = arange(0, require_backend.prod(genShape), 1, {
3003
3131
  dtype: require_backend.DType.Uint32,
3004
3132
  device: k0.device
3005
- }).reshape(shape$1);
3133
+ }).reshape(genShape);
3134
+ k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
3135
+ k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
3006
3136
  const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
3007
3137
  return [Array$1.#naryCustom("random_bits", custom, [
3008
3138
  k0,
@@ -3036,40 +3166,63 @@ var Array$1 = class Array$1 extends Tracer {
3036
3166
  },
3037
3167
  [Primitive.Sort]([x]) {
3038
3168
  const routine = new require_backend.Routine(require_backend.Routines.Sort, {
3039
- inputShapes: [x.aval.shape],
3040
- inputDtypes: [x.aval.dtype],
3041
- outputShapes: [x.aval.shape],
3042
- outputDtypes: [x.aval.dtype]
3169
+ inputShapes: [x.shape],
3170
+ inputDtypes: [x.dtype],
3171
+ outputShapes: [x.shape],
3172
+ outputDtypes: [x.dtype]
3043
3173
  });
3044
3174
  return Array$1.#routine(routine, [x], [x.#weakType]);
3045
3175
  },
3046
3176
  [Primitive.Argsort]([x]) {
3047
3177
  const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
3048
- inputShapes: [x.aval.shape],
3049
- inputDtypes: [x.aval.dtype],
3050
- outputShapes: [x.aval.shape, x.aval.shape],
3051
- outputDtypes: [x.aval.dtype, require_backend.DType.Int32]
3178
+ inputShapes: [x.shape],
3179
+ inputDtypes: [x.dtype],
3180
+ outputShapes: [x.shape, x.shape],
3181
+ outputDtypes: [x.dtype, require_backend.DType.Int32]
3052
3182
  });
3053
3183
  return Array$1.#routine(routine, [x], [x.#weakType, false]);
3054
3184
  },
3055
3185
  [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3056
3186
  const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
3057
- inputShapes: [a.aval.shape, b.aval.shape],
3058
- inputDtypes: [a.aval.dtype, b.aval.dtype],
3059
- outputShapes: [b.aval.shape],
3060
- outputDtypes: [b.aval.dtype]
3187
+ inputShapes: [a.shape, b.shape],
3188
+ inputDtypes: [a.dtype, b.dtype],
3189
+ outputShapes: [b.shape],
3190
+ outputDtypes: [b.dtype]
3061
3191
  }, { unitDiagonal });
3062
3192
  return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
3063
3193
  },
3064
3194
  [Primitive.Cholesky]([a]) {
3065
3195
  const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
3066
- inputShapes: [a.aval.shape],
3067
- inputDtypes: [a.aval.dtype],
3068
- outputShapes: [a.aval.shape],
3069
- outputDtypes: [a.aval.dtype]
3196
+ inputShapes: [a.shape],
3197
+ inputDtypes: [a.dtype],
3198
+ outputShapes: [a.shape],
3199
+ outputDtypes: [a.dtype]
3070
3200
  });
3071
3201
  return Array$1.#routine(routine, [a], [a.#weakType]);
3072
3202
  },
3203
+ [Primitive.LU]([a]) {
3204
+ const batch = a.shape.slice(0, -2);
3205
+ const [m, n] = a.shape.slice(-2);
3206
+ const routine = new require_backend.Routine(require_backend.Routines.LU, {
3207
+ inputShapes: [a.shape],
3208
+ inputDtypes: [a.dtype],
3209
+ outputShapes: [
3210
+ a.shape,
3211
+ [...batch, Math.min(m, n)],
3212
+ [...batch, m]
3213
+ ],
3214
+ outputDtypes: [
3215
+ a.dtype,
3216
+ require_backend.DType.Int32,
3217
+ require_backend.DType.Int32
3218
+ ]
3219
+ });
3220
+ return Array$1.#routine(routine, [a], [
3221
+ a.#weakType,
3222
+ false,
3223
+ false
3224
+ ]);
3225
+ },
3073
3226
  [Primitive.Jit](args, { jaxpr }) {
3074
3227
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3075
3228
  const { backend, committed } = Array$1.#computeBackend("jit", args);
@@ -3175,7 +3328,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3175
3328
  device
3176
3329
  });
3177
3330
  } else {
3178
- const weakType = dtype == void 0;
3331
+ const weakType = dtype == void 0 && shape$1.length === 0;
3179
3332
  dtype = dtype ?? require_backend.DType.Float32;
3180
3333
  const data = require_backend.dtypedJsArray(dtype, flat);
3181
3334
  return arrayFromData(data, shape$1, {
@@ -3289,7 +3442,7 @@ function ones(shape$1, { dtype, device } = {}) {
3289
3442
  }
3290
3443
  /** Return a new array of given shape and type, filled with `fill_value`. */
3291
3444
  function full(shape$1, fillValue, { dtype, device } = {}) {
3292
- let weakType = dtype == void 0;
3445
+ let weakType = dtype == void 0 && shape$1.length === 0;
3293
3446
  if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
3294
3447
  else if (typeof fillValue === "boolean") {
3295
3448
  dtype = dtype ?? require_backend.DType.Bool;
@@ -3447,6 +3600,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
3447
3600
  committed: device != void 0
3448
3601
  });
3449
3602
  }
3603
+ /**
3604
+ * Return numbers spaced evenly on a log scale.
3605
+ *
3606
+ * In linear space, the sequence starts at `base ** start` and ends at
3607
+ * `base ** stop` (see `endpoint` below).
3608
+ *
3609
+ * @param start - `base ** start` is the starting value of the sequence.
3610
+ * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
3611
+ * @param num - Number of samples to generate. Default is 50.
3612
+ * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
3613
+ * @param base - The base of the log space. Default is 10.
3614
+ * @returns Array of evenly spaced values on a log scale.
3615
+ */
3616
+ function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
3617
+ const y = linspace(start, stop, num, endpoint, {
3618
+ dtype,
3619
+ device
3620
+ });
3621
+ const logBase = Math.log(base);
3622
+ return exp$1(mul(y, logBase));
3623
+ }
3450
3624
  function aluCompare(a, b, op) {
3451
3625
  switch (op) {
3452
3626
  case CompareOp.Less: return require_backend.AluExp.cmplt(a, b);
@@ -3524,6 +3698,7 @@ var BatchTrace = class extends Trace {
3524
3698
  return valOuts$1.map((x) => new BatchTracer(this, x, null));
3525
3699
  }
3526
3700
  const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3701
+ if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
3527
3702
  return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3528
3703
  }
3529
3704
  get axisSize() {
@@ -3535,13 +3710,13 @@ var BatchTrace = class extends Trace {
3535
3710
  *
3536
3711
  * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3537
3712
  */
3538
- function broadcastBatcher(op) {
3539
- return (axisSize, args, dims) => {
3713
+ function broadcastBatcher(prim) {
3714
+ return (axisSize, args, dims, params) => {
3540
3715
  if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3541
3716
  const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3542
3717
  const firstIdx = dims.findIndex((d) => d !== null);
3543
3718
  const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3544
- if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3719
+ if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[bind1(prim, args, params)], [nd + firstBdim]];
3545
3720
  args = args.map((x, i) => {
3546
3721
  if (dims[i] === null) return x;
3547
3722
  x = moveBatchAxis(axisSize, dims[i], 0, x);
@@ -3552,37 +3727,45 @@ function broadcastBatcher(op) {
3552
3727
  ]);
3553
3728
  return x;
3554
3729
  });
3555
- return [[op(...args)], [0]];
3730
+ return [[bind1(prim, args, params)], [0]];
3556
3731
  };
3557
3732
  }
3558
- function unopBatcher(op) {
3733
+ function unopBatcher(prim) {
3559
3734
  return (axisSize, [x], [xBdim], params) => {
3560
- return [[op(x, params)], [xBdim]];
3735
+ return [[bind1(prim, [x], params)], [xBdim]];
3736
+ };
3737
+ }
3738
+ function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
3739
+ return (axisSize, [x], [xBdim], params) => {
3740
+ require_backend.assertNonNull(xBdim);
3741
+ if (xBdim < x.ndim - inputDims) return [bind(prim, [x], params), require_backend.rep(numOutputs, xBdim)];
3742
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3743
+ return [bind(prim, [x], params), require_backend.rep(numOutputs, 0)];
3561
3744
  };
3562
3745
  }
3563
3746
  const vmapRules = {
3564
- [Primitive.Add]: broadcastBatcher(add$1),
3565
- [Primitive.Mul]: broadcastBatcher(mul),
3566
- [Primitive.Idiv]: broadcastBatcher(idiv),
3567
- [Primitive.Mod]: broadcastBatcher(mod),
3568
- [Primitive.Min]: broadcastBatcher(min$1),
3569
- [Primitive.Max]: broadcastBatcher(max$1),
3570
- [Primitive.Neg]: unopBatcher(neg),
3571
- [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3572
- [Primitive.Floor]: unopBatcher(floor$1),
3573
- [Primitive.Ceil]: unopBatcher(ceil$1),
3574
- [Primitive.StopGradient]: unopBatcher(stopGradient),
3575
- [Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
3576
- [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3577
- [Primitive.Sin]: unopBatcher(sin$1),
3578
- [Primitive.Cos]: unopBatcher(cos$1),
3579
- [Primitive.Asin]: unopBatcher(asin$1),
3580
- [Primitive.Atan]: unopBatcher(atan$1),
3581
- [Primitive.Exp]: unopBatcher(exp$1),
3582
- [Primitive.Log]: unopBatcher(log$1),
3583
- [Primitive.Erf]: unopBatcher(erf$1),
3584
- [Primitive.Erfc]: unopBatcher(erfc$1),
3585
- [Primitive.Sqrt]: unopBatcher(sqrt$1),
3747
+ [Primitive.Add]: broadcastBatcher(Primitive.Add),
3748
+ [Primitive.Mul]: broadcastBatcher(Primitive.Mul),
3749
+ [Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
3750
+ [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3751
+ [Primitive.Min]: broadcastBatcher(Primitive.Min),
3752
+ [Primitive.Max]: broadcastBatcher(Primitive.Max),
3753
+ [Primitive.Neg]: unopBatcher(Primitive.Neg),
3754
+ [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3755
+ [Primitive.Floor]: unopBatcher(Primitive.Floor),
3756
+ [Primitive.Ceil]: unopBatcher(Primitive.Ceil),
3757
+ [Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
3758
+ [Primitive.Cast]: unopBatcher(Primitive.Cast),
3759
+ [Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
3760
+ [Primitive.Sin]: unopBatcher(Primitive.Sin),
3761
+ [Primitive.Cos]: unopBatcher(Primitive.Cos),
3762
+ [Primitive.Asin]: unopBatcher(Primitive.Asin),
3763
+ [Primitive.Atan]: unopBatcher(Primitive.Atan),
3764
+ [Primitive.Exp]: unopBatcher(Primitive.Exp),
3765
+ [Primitive.Log]: unopBatcher(Primitive.Log),
3766
+ [Primitive.Erf]: unopBatcher(Primitive.Erf),
3767
+ [Primitive.Erfc]: unopBatcher(Primitive.Erfc),
3768
+ [Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
3586
3769
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3587
3770
  require_backend.assertNonNull(xBdim);
3588
3771
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
@@ -3604,10 +3787,25 @@ const vmapRules = {
3604
3787
  });
3605
3788
  return [[z], [0]];
3606
3789
  },
// Compare and Where are elementwise over several operands: broadcast the batch dims.
[Primitive.Compare]: broadcastBatcher(Primitive.Compare),
[Primitive.Where]: broadcastBatcher(Primitive.Where),
// Batched concatenate: move every operand's batch axis to the smallest batch
// axis present among the inputs, then concatenate along the original axis.
// That axis shifts by one if the batch axis was inserted at or before it.
// NOTE(review): assumes at least one operand is batched; if every bdim were
// null, Math.min() would return Infinity. Unbatched operands (bdim null) are
// presumably materialized by moveBatchAxis — confirm.
[Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
  const minBdim = Math.min(...xBdims.filter((d) => d !== null));
  xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
  const newAxis = axis + (minBdim <= axis ? 1 : 0);
  return [[concatenate$1(xs, newAxis)], [minBdim]];
},
// Batched split: split along the (shifted) axis; every output keeps the
// input's batch axis position.
[Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
  require_backend.assertNonNull(xBdim);
  const newAxis = axis + (xBdim <= axis ? 1 : 0);
  const outs = split$2(x, newAxis, sizes);
  return [outs, require_backend.rep(outs.length, xBdim)];
},
// Batched random bits: put the batch axis of both key halves first and
// prepend `axisSize` to the requested sample shape, so the output is batched
// along axis 0.
[Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
  k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
  k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
  return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
},
3611
3809
  [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3612
3810
  if (indicesBdim.every((d) => d === null)) {
3613
3811
  require_backend.assertNonNull(xBdim);
@@ -3669,18 +3867,8 @@ const vmapRules = {
3669
3867
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3670
3868
  return [[pad$1(x, newWidth)], [xBdim]];
3671
3869
  },
// Sort and Argsort act on the last axis, so batching is delegated to the shared
// lastDimsBatcher helper. NOTE(review): the numeric arguments appear to be
// (number of trailing core axes, number of outputs); Argsort yields two
// outputs — confirm against lastDimsBatcher's definition.
[Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
[Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
3684
3872
  [Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
3685
3873
  if (aBdim === null) {
3686
3874
  b = moveBatchAxis(axisSize, bBdim, -3, b);
@@ -3704,12 +3892,8 @@ const vmapRules = {
3704
3892
  const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3705
3893
  return [[x], [0]];
3706
3894
  },
// Matrix factorizations act on the trailing two (matrix) axes; batching is
// delegated to lastDimsBatcher. NOTE(review): LU's extra argument appears to be
// its output count (lu, pivots, permutation) — confirm against lastDimsBatcher.
[Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
[Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
3713
3897
  [Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
3714
3898
  const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
3715
3899
  const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
@@ -3860,6 +4044,16 @@ function batchMatmulT(a, b) {
3860
4044
  function mT(a) {
3861
4045
  return moveaxis(a, -2, -1);
3862
4046
  }
/**
 * Slice `a` along a single axis, keeping every other axis whole.
 *
 * @param a - Array to slice.
 * @param axis - Axis to slice; negative values count from the end.
 * @param p - Slice specification for that axis, e.g. `[start, end]`.
 * @returns The sliced array.
 */
function sliceAxis(a, axis, p) {
  const target = require_backend.checkAxis(axis, a.ndim);
  // An empty spec `[]` leaves the corresponding axis untouched.
  const slices = a.shape.map((_, i) => (i === target ? p : []));
  return a.slice(...slices);
}
/**
 * Pad `a` along a single axis with `[before, after]` widths, leaving other axes unpadded.
 *
 * @param a - Array to pad.
 * @param axis - Axis to pad; negative values count from the end.
 * @param p - `[before, after]` padding widths for that axis.
 * @returns The padded array.
 */
function padAxis(a, axis, p) {
  const target = require_backend.checkAxis(axis, a.ndim);
  const widths = a.shape.map((_, i) => (i === target ? p : [0, 0]));
  return pad$1(a, widths);
}
3863
4057
  const jvpRules = {
3864
4058
  [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3865
4059
  [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
@@ -3958,6 +4152,8 @@ const jvpRules = {
3958
4152
  dcond.dispose();
3959
4153
  return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3960
4154
  },
// Concatenate and Split only rearrange elements, so they are linear in their
// array inputs and use the linear-tangent JVP helper.
[Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
[Primitive.Split]: linearTangentsJvp(Primitive.Split),
3961
4157
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3962
4158
  [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3963
4159
  const indicesRef = indices.map((t) => t.ref);
@@ -3992,6 +4188,38 @@ const jvpRules = {
3992
4188
  const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
3993
4189
  return [[L], [dL]];
3994
4190
  },
// JVP of LU decomposition with partial pivoting. Unpacks L and U from the packed
// LU matrix, then forms the tangent of the packed factors from `da`.
[Primitive.LU]([a], [da]) {
  // Primal outputs: packed LU factors, pivot indices, and row permutation.
  const [luMatrix, pivots, permutation] = lu$1(a);
  // NOTE(review): `a` was handed to lu$1 above; reading `a.shape` afterwards
  // assumes shape metadata stays readable after the array is consumed — confirm.
  const [m, n] = a.shape.slice(-2);
  const k = Math.min(m, n);
  // L (m x m): strictly-lower part of LU's first k columns, padded with zero
  // columns when m > k, plus the unit diagonal.
  const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
  const lLower = tril(luSliceL, -1);
  const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
  const L = lPadded.add(eye(m));
  // U (n x n): upper part of LU's first k rows, padded with zero rows when
  // n > k. The padded bottom-right block receives an identity (uEye), so the
  // later triangular solve against U has a nonzero diagonal there.
  const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
  const uUpper = triu(luSliceU);
  const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
  const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
  const U = uPadded.add(uEye);
  // One-hot permutation matrix: P[..., i, j] = (permutation[..., i] == j).
  const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
  // NOTE(review): assumes batchMatmulT(x, y) computes x @ y^T, making this P @ da.
  const pda = batchMatmulT(P, mT(da));
  // Presumably la = L^{-1} (P da), via a unit-lower-triangular solve.
  const la = mT(triangularSolve$1(L.ref, mT(pda), {
    lower: true,
    unitDiagonal: true
  }));
  // Presumably lau = L^{-1} P da U^{-1}, the standard intermediate of the LU JVP.
  const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
  // Tangent of the packed factors: L @ tril(lau, -1) + triu(lau) @ U.
  const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
  const uDot = batchMatmulT(triu(lau), mT(U));
  // Pivots and permutation are integer index outputs; their tangents are zeros.
  return [[
    luMatrix,
    pivots,
    permutation
  ], [
    lDot.add(uDot),
    zerosLike$1(pivots.ref),
    zerosLike$1(permutation.ref)
  ]];
},
3995
4223
  [Primitive.Jit](primals, tangents, { name, jaxpr }) {
3996
4224
  const newJaxpr = jvpJaxpr(jaxpr);
3997
4225
  const outs = bind(Primitive.Jit, [
@@ -4529,6 +4757,15 @@ const transposeRules = {
4529
4757
  cond.dispose();
4530
4758
  return cts;
4531
4759
  },
// Transpose of concatenate: split the output cotangent back into pieces whose
// lengths along `axis` match each input.
[Primitive.Concatenate]([ct], inputs, { axis }) {
  // NOTE(review): this rejects any known (non-UndefPrimal) operand, so a
  // concatenate mixing a linear input with a constant cannot be transposed.
  // JAX's transpose rule instead returns no cotangent for such operands —
  // confirm whether that case can arise here.
  if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
  const sizes = inputs.map((x) => x.aval.shape[axis]);
  return split$2(ct, axis, sizes);
},
// Transpose of split: concatenate the output cotangents back along the same axis.
[Primitive.Split](cts, [x], { axis }) {
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
  return [concatenate$1(cts, axis)];
},
4532
4769
  [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4533
4770
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4534
4771
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
@@ -4804,8 +5041,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
4804
5041
  const idx = lhsIndex[j];
4805
5042
  const dim = shape$1[j];
4806
5043
  const existing = sizeMap.get(idx);
4807
- if (existing === void 0) sizeMap.set(idx, dim);
4808
- else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
5044
+ if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
5045
+ else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
4809
5046
  }
4810
5047
  }
4811
5048
  for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
@@ -4961,27 +5198,53 @@ function ifft(a, axis = -1) {
//#region src/library/numpy-linalg.ts
// Namespace object for the `numpy.linalg` module. The bundler helper
// `__export` presumably installs each entry as a getter that resolves to the
// implementation defined in this region.
var numpy_linalg_exports = {};
__export(numpy_linalg_exports, {
  cholesky: () => cholesky,
  det: () => det,
  diagonal: () => diagonal,
  inv: () => inv,
  lstsq: () => lstsq,
  matmul: () => matmul,
  matrixPower: () => matrixPower,
  matrixTranspose: () => matrixTranspose,
  outer: () => outer,
  slogdet: () => slogdet,
  solve: () => solve,
  tensordot: () => tensordot,
  trace: () => trace,
  vecdot: () => vecdot
});
/**
 * Validate that `a` is a square matrix or a batch of square matrices.
 *
 * @param name - Calling function's name, used as the error-message prefix.
 * @param a - Array with `ndim`, `shape`, and `aval` properties.
 * @returns The size `n` of the trailing `n x n` matrix axes.
 * @throws If `a` has fewer than two axes or its last two axes differ in length.
 */
function checkSquare(name, a) {
  const { ndim, shape: dims } = a;
  const isSquare = ndim >= 2 && dims[ndim - 1] === dims[ndim - 2];
  if (!isSquare) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
  return dims[ndim - 1];
}
/**
 * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
 *
 * Like `jax.lax.linalg.cholesky()`, but by default the input is first
 * symmetrized as `(a + a^T) / 2`, which guards against small asymmetries.
 *
 * @param a - Square matrix (or batch of them).
 * @param options.upper - Return the upper factor instead of the lower one.
 * @param options.symmetrizeInput - Average `a` with its transpose first (default true).
 */
function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
  let mat = fudgeArray(a);
  checkSquare("cholesky", mat);
  if (symmetrizeInput) {
    mat = mat.ref.add(matrixTranspose(mat)).mul(.5);
  }
  return cholesky$1(mat, { upper });
}
/**
 * Compute the determinant of a square matrix (batched).
 *
 * Uses the LU factorization: L has a unit diagonal, so the determinant is the
 * product of U's diagonal times the sign of the row permutation.
 */
function det(a) {
  a = fudgeArray(a);
  const n = checkSquare("det", a);
  const [lu$2, pivots, permutation] = lu(a);
  // Only the pivot sequence is needed to count row swaps.
  permutation.dispose();
  // Each pivot with pivots[i] != i is one row interchange, flipping the sign.
  const swapCount = pivots.notEqual(arange(n)).astype(int32).sum(-1);
  const sign$1 = swapCount.mod(2).mul(-2).add(1);
  const diag$1 = lu$2.diagonal(0, -1, -2);
  return prod$1(diag$1, -1).mul(sign$1);
}
/** Compute the inverse of a square matrix (batched). */
function inv(a) {
  const arr = fudgeArray(a);
  const n = checkSquare("inv", arr);
  // The inverse is the solution X of A @ X = I.
  return solve(arr, eye(n));
}
4986
5249
  /**
4987
5250
  * Return the least-squares solution to a linear equation.
@@ -5005,7 +5268,7 @@ function lstsq(a, b) {
5005
5268
  const at = matrixTranspose(a.ref);
5006
5269
  if (m <= n) {
5007
5270
  const aat = matmul(a, at.ref);
5008
- const l = cholesky$1(aat, { symmetrizeInput: false });
5271
+ const l = cholesky(aat, { symmetrizeInput: false });
5009
5272
  const lb = triangularSolve(l.ref, b, {
5010
5273
  leftSide: true,
5011
5274
  lower: true
@@ -5017,7 +5280,7 @@ function lstsq(a, b) {
5017
5280
  return matmul(at, llb.ref);
5018
5281
  } else {
5019
5282
  const ata = matmul(at.ref, a);
5020
- const l = cholesky$1(ata, { symmetrizeInput: false });
5283
+ const l = cholesky(ata, { symmetrizeInput: false });
5021
5284
  const atb = matmul(at, b);
5022
5285
  const lb = triangularSolve(l.ref, atb, {
5023
5286
  leftSide: true,
@@ -5030,6 +5293,169 @@ function lstsq(a, b) {
5030
5293
  return llb;
5031
5294
  }
5032
5295
  }
/**
 * Raise a square matrix to an integer power, via repeated squarings.
 *
 * @param a - Square matrix (or batch of them) with shape `[..., m, m]`.
 * @param n - Integer exponent. `0` yields the identity broadcast to `a`'s
 *   shape; a negative exponent inverts `a` first.
 * @returns `a` raised to the `n`-th power.
 * @throws If `n` is not an integer or `a` is not square.
 */
function matrixPower(a, n) {
  if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
  a = fudgeArray(a);
  const m = checkSquare("matrixPower", a);
  if (n === 0) {
    // Read the shape before releasing `a`, so no metadata is accessed on a
    // disposed array.
    const shape$1 = [...a.shape];
    a.dispose();
    return broadcastTo(eye(m), shape$1);
  }
  if (n < 0) {
    a = inv(a);
    n = -n;
  }
  // Binary exponentiation: `a2k` holds a^(2^k). It is multiplied into
  // `result` whenever bit k of n is set.
  let result = null;
  let a2k = a;
  for (let k = 0; n; k++) {
    if (k > 0) a2k = matmul(a2k.ref, a2k);
    if (n % 2 === 1) result = result === null ? a2k.ref : matmul(result, a2k.ref);
    n = Math.floor(n / 2);
  }
  a2k.dispose();
  return result;
}
/**
 * Return sign and natural logarithm of the determinant of `a`.
 *
 * Matching numpy, `sign` is +1 or -1 for a nonsingular matrix and 0 for a
 * singular one (a zero on U's diagonal). In the singular case `logabsdet`
 * is `-inf`.
 *
 * @param a - Square matrix (or batch of them).
 * @returns `[sign, logabsdet]` with `det(a) = sign * exp(logabsdet)`.
 */
function slogdet(a) {
  a = fudgeArray(a);
  const n = checkSquare("slogdet", a);
  const [lu$2, pivots, permutation] = lu(a);
  permutation.dispose();
  // Each row interchange (pivots[i] != i) flips the sign of the determinant.
  let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
  const diag$1 = lu$2.diagonal(0, -1, -2);
  // Each negative entry on U's diagonal flips the sign as well.
  parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
  // 1 if every diagonal entry is nonzero, else 0 (singular => sign 0).
  const nonsingular = prod$1(diag$1.ref.notEqual(0).astype(int32), -1);
  const logabsdet = log(absolute(diag$1)).sum(-1);
  const sign$1 = parity.mul(-2).add(1).mul(nonsingular);
  return [sign$1, logabsdet];
}
/**
 * Solve a linear system of equations.
 *
 * This solves a (batched) linear system of equations `a @ x = b` for `x` given
 * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
 *
 * @param a - Coefficient matrix of shape `(..., N, N)`.
 * @param b - Values of shape `(N,)` or `(..., N, M)`.
 * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
 * @throws If `a` is not square, `b` is a scalar, or `b`'s leading matrix
 *   dimension does not match `N`.
 */
function solve(a, b) {
  a = fudgeArray(a);
  b = fudgeArray(b);
  const n = checkSquare("solve", a);
  if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
  // Treat a vector right-hand side as a single column; undone at the end.
  const bIs1d = b.ndim === 1;
  if (bIs1d) b = b.reshape([...b.shape, 1]);
  if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
  const m = b.shape[b.ndim - 1];
  // Broadcast the batch dimensions of `a` and `b` against each other.
  const batchDims = require_backend.generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
  a = broadcastTo(a, [
    ...batchDims,
    n,
    n
  ]);
  b = broadcastTo(b, [
    ...batchDims,
    n,
    m
  ]);
  // Factor a with partial pivoting; the raw pivots are not needed here.
  const [lu$2, pivots, permutation] = lu(a);
  pivots.dispose();
  // One-hot permutation matrix with P[..., i, j] = (j == permutation[..., i]),
  // so P @ b reorders the rows of b.
  const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
  // Forward substitution with the unit-diagonal lower factor: L y = P b.
  const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
    leftSide: true,
    lower: true,
    unitDiagonal: true
  });
  // Back substitution with the upper factor: U x = y. LPb has no further use,
  // so it is passed without `.ref`; an extra reference here would never be
  // released.
  let x = triangularSolve(lu$2, LPb, {
    leftSide: true,
    lower: false
  });
  if (bIs1d) x = squeeze(x, -1);
  return x;
}
5377
+
5378
+ //#endregion
5379
+ //#region src/library/numpy/dtype-info.ts
/**
 * Machine limits for floating-point types.
 *
 * Mirrors the fields of `numpy.finfo`. For every supported dtype,
 * `eps = 2 ** machep` and `epsneg = 2 ** negep`.
 *
 * @param dtype - A floating-point dtype (Float16, Float32 or Float64).
 * @returns A frozen record of machine limits for `dtype`.
 * @throws If `dtype` is not a floating-point type.
 */
function finfo(dtype) {
  if (!require_backend.isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
  switch (dtype) {
    case require_backend.DType.Float16: return Object.freeze({
      bits: 16,
      dtype: require_backend.DType.Float16,
      eps: 2 ** -10,
      epsneg: 2 ** -11,
      machep: -10,
      max: 65504,
      maxexp: 16,
      min: -65504,
      minexp: -14,
      // Exponent of epsneg (2 ** -11); numpy reports -11 for float16.
      negep: -11,
      nexp: 5,
      nmant: 10,
      precision: 3,
      resolution: .001,
      smallestNormal: 2 ** -14,
      smallestSubnormal: 2 ** -24
    });
    case require_backend.DType.Float32: return Object.freeze({
      bits: 32,
      dtype: require_backend.DType.Float32,
      eps: 2 ** -23,
      epsneg: 2 ** -24,
      machep: -23,
      max: 34028234663852886e22,
      maxexp: 128,
      min: -34028234663852886e22,
      minexp: -126,
      negep: -24,
      nexp: 8,
      nmant: 23,
      precision: 6,
      resolution: 1e-6,
      smallestNormal: 2 ** -126,
      smallestSubnormal: 2 ** -149
    });
    case require_backend.DType.Float64: return Object.freeze({
      bits: 64,
      dtype: require_backend.DType.Float64,
      eps: 2 ** -52,
      epsneg: 2 ** -53,
      machep: -52,
      max: Number.MAX_VALUE,
      maxexp: 1024,
      min: -Number.MAX_VALUE,
      minexp: -1022,
      negep: -53,
      nexp: 11,
      nmant: 52,
      precision: 15,
      resolution: 1e-15,
      smallestNormal: 2 ** -1022,
      smallestSubnormal: 2 ** -1074
    });
    default: throw new Error(`finfo: unsupported dtype ${dtype}`);
  }
}
/**
 * Machine limits for integer types.
 *
 * @param dtype - An integer dtype (Int32 or Uint32).
 * @returns A fresh frozen record with `bits`, `dtype`, `max` and `min`.
 * @throws If `dtype` is not a supported integer type.
 */
function iinfo(dtype) {
  const { DType } = require_backend;
  if (dtype === DType.Int32) {
    return Object.freeze({ bits: 32, dtype: DType.Int32, max: 2 ** 31 - 1, min: -(2 ** 31) });
  }
  if (dtype === DType.Uint32) {
    return Object.freeze({ bits: 32, dtype: DType.Uint32, max: 2 ** 32 - 1, min: 0 });
  }
  throw new Error(`iinfo: unsupported dtype ${dtype}`);
}
5033
5459
 
5034
5460
  //#endregion
5035
5461
  //#region src/library/numpy.ts
@@ -5085,6 +5511,7 @@ __export(numpy_exports, {
5085
5511
  diag: () => diag,
5086
5512
  diagonal: () => diagonal,
5087
5513
  divide: () => trueDivide,
5514
+ divmod: () => divmod,
5088
5515
  dot: () => dot$1,
5089
5516
  dstack: () => dstack,
5090
5517
  e: () => e,
@@ -5097,6 +5524,7 @@ __export(numpy_exports, {
5097
5524
  expm1: () => expm1,
5098
5525
  eye: () => eye,
5099
5526
  fft: () => numpy_fft_exports,
5527
+ finfo: () => finfo,
5100
5528
  flip: () => flip,
5101
5529
  fliplr: () => fliplr,
5102
5530
  flipud: () => flipud,
@@ -5104,6 +5532,7 @@ __export(numpy_exports, {
5104
5532
  float32: () => float32,
5105
5533
  float64: () => float64,
5106
5534
  floor: () => floor,
5535
+ floorDivide: () => floorDivide,
5107
5536
  fmod: () => fmod,
5108
5537
  frexp: () => frexp,
5109
5538
  full: () => full,
@@ -5116,6 +5545,7 @@ __export(numpy_exports, {
5116
5545
  hstack: () => hstack,
5117
5546
  hypot: () => hypot,
5118
5547
  identity: () => identity$1,
5548
+ iinfo: () => iinfo,
5119
5549
  inf: () => inf,
5120
5550
  inner: () => inner,
5121
5551
  int32: () => int32,
@@ -5133,6 +5563,7 @@ __export(numpy_exports, {
5133
5563
  log10: () => log10,
5134
5564
  log1p: () => log1p,
5135
5565
  log2: () => log2,
5566
+ logspace: () => logspace,
5136
5567
  matmul: () => matmul,
5137
5568
  matrixTranspose: () => matrixTranspose,
5138
5569
  max: () => max,
@@ -5169,9 +5600,11 @@ __export(numpy_exports, {
5169
5600
  shape: () => shape,
5170
5601
  sign: () => sign,
5171
5602
  sin: () => sin,
5603
+ sinc: () => sinc,
5172
5604
  sinh: () => sinh,
5173
5605
  size: () => size,
5174
5606
  sort: () => sort,
5607
+ split: () => split$1,
5175
5608
  sqrt: () => sqrt,
5176
5609
  square: () => square,
5177
5610
  squeeze: () => squeeze,
@@ -5179,6 +5612,7 @@ __export(numpy_exports, {
5179
5612
  std: () => std,
5180
5613
  subtract: () => subtract,
5181
5614
  sum: () => sum,
5615
+ take: () => take,
5182
5616
  tan: () => tan,
5183
5617
  tanh: () => tanh,
5184
5618
  tensordot: () => tensordot,
@@ -5437,6 +5871,45 @@ function flip(x, axis = null) {
5437
5871
  return flip$1(x, axis);
5438
5872
  }
/**
 * Split an array into multiple sub-arrays along an axis.
 *
 * @param a - The input array to split.
 * @param indicesOrSections - If an integer, the number of equal sections to
 *   create along `axis`; the axis length must be divisible by it. If a list
 *   of integers, the indices at which to split the array. An empty list
 *   returns the whole array as a single piece.
 * @param axis - The axis along which to split the array. Default is 0.
 * @returns The list of sub-arrays, in order along `axis`.
 * @throws If an integer section count does not evenly divide the axis length.
 */
function split$1(a, indicesOrSections, axis = 0) {
  a = fudgeArray(a);
  axis = require_backend.checkAxis(axis, a.ndim);
  const size$1 = a.shape[axis];
  let sizes;
  if (typeof indicesOrSections === "number") {
    if (size$1 % indicesOrSections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${indicesOrSections} equal parts`);
    sizes = require_backend.rep(indicesOrSections, size$1 / indicesOrSections);
  } else {
    // Convert split points into piece lengths: [i0, i1 - i0, ..., size - iLast].
    // With no split points this yields the single piece [size].
    sizes = [];
    let prev = 0;
    for (const idx of indicesOrSections) {
      sizes.push(idx - prev);
      prev = idx;
    }
    sizes.push(size$1 - prev);
  }
  // The split primitive is applied to at most 8 pieces at a time. While more
  // than 8 pieces remain, peel off 7 of them and carry the rest as one 8th piece.
  const maxOutputs = 8;
  const results = [];
  let start = 0;
  while (sizes.length - start > maxOutputs) {
    const head = sizes.slice(start, start + maxOutputs - 1);
    const rest = sizes.slice(start + maxOutputs - 1).reduce((x, y) => x + y, 0);
    const outs = split$2(a, axis, [...head, rest]);
    a = outs.pop();
    results.push(...outs);
    start += maxOutputs - 1;
  }
  results.push(...split$2(a, axis, sizes.slice(start)));
  return results;
}
5912
+ /**
5440
5913
  * Join a sequence of arrays along an existing axis.
5441
5914
  *
5442
5915
  * The arrays must have the same shape, except in the dimension corresponding to
@@ -5448,13 +5921,11 @@ function concatenate(xs, axis = 0) {
5448
5921
  if (xs.length === 0) throw new Error("Need at least one array to concatenate");
5449
5922
  const shapes = xs.map(shape);
5450
5923
  axis = require_backend.checkAxis(axis, shapes[0].length);
5451
- for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays with shapes ${JSON.stringify(shapes)} along axis ${axis}`);
5452
- const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
5924
+ for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
5453
5925
  let result = xs[0];
5454
- for (let i = 1; i < xs.length; i++) {
5455
- const len1 = result.shape[axis];
5456
- const len2 = shapes[i][axis];
5457
- result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
5926
+ for (let i = 1; i < xs.length; i += 7) {
5927
+ const group = xs.slice(i, i + 7);
5928
+ result = concatenate$1([result, ...group], axis);
5458
5929
  }
5459
5930
  return result;
5460
5931
  }
@@ -5706,6 +6177,20 @@ function sort(a, axis = -1) {
function argsort(a, axis = -1) {
  // Coerce array-likes to an array, then delegate to the method form.
  const arr = fudgeArray(a);
  return arr.argsort(axis);
}
/**
 * Take elements from an array along an axis.
 *
 * Equivalent to advanced indexing with integer `indices` along `axis`. When
 * `axis` is `null` (the default), the array is flattened first and indexed
 * along axis 0.
 *
 * @param a - Source array.
 * @param indices - Integer indices to gather.
 * @param axis - Axis to index along (negative counts from the end), or `null`.
 */
function take(a, indices, axis = null) {
  let src = a;
  let ax = axis;
  if (ax === null) {
    src = ravel(src);
    ax = 0;
  }
  ax = require_backend.checkAxis(ax, ndim(src));
  return gather(src, [indices], [ax], ax);
}
5709
6194
  /** Return if two arrays are element-wise equal within a tolerance. */
5710
6195
  function allclose(actual, expected, options) {
5711
6196
  const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
@@ -6025,6 +6510,20 @@ function tan(x) {
6025
6510
  x = fudgeArray(x);
6026
6511
  return sin(x.ref).div(cos(x));
6027
6512
  }
/**
 * @function
 * Return the normalized sinc function.
 *
 * Computes `sin(πx) / (πx)` elementwise, with the value `1` at `x = 0` (the
 * limit of the ratio). This normalized form is the one commonly used in
 * signal processing.
 *
 * **Note:** derivatives at `x = 0` are not reliable. The discarded branch
 * still evaluates 0/0 there, and that NaN can leak into gradients even though
 * `where` drops its value. A custom JVP rule (as in JAX) would be needed to
 * handle this properly.
 */
const sinc = jit$1(function sinc$1(x) {
  const pix = x.ref.mul(Math.PI);
  const ratio = sin(pix.ref).div(pix);
  // Replace the 0/0 entry at the origin with the limit value 1.
  return where(equal(x, 0), 1, ratio);
});
6028
6527
  /** Element-wise inverse cosine function (inverse of cos). */
6029
6528
  function acos(x) {
6030
6529
  return subtract(pi / 2, asin(x));
@@ -6077,6 +6576,25 @@ function trueDivide(x, y) {
6077
6576
  return x.div(y);
6078
6577
  }
/**
 * Return the largest integer smaller or equal to the division of the inputs.
 *
 * The result is always rounded towards negative infinity.
 *
 * If either input is floating-point, this is `floor(x / y)`. For integer
 * inputs it computes `(x - remainder(x, y)) / y`, which handles negative
 * values correctly (note: this may overflow near int32 boundaries).
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Element-wise floor division of x by y.
 */
function floorDivide(x, y) {
  x = fudgeArray(x);
  y = fudgeArray(y);
  const anyFloat = require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype);
  if (anyFloat) return floor(trueDivide(x, y));
  // Subtracting the floored remainder leaves an exact multiple of y.
  // NOTE(review): this relies on `.div` of two integer arrays giving the exact
  // integer quotient without promoting to float — confirm.
  const rem = remainder(x.ref, y.ref);
  return subtract(x, rem).div(y);
}
6597
+ /**
6080
6598
  * @function
6081
6599
  * Calculate element-wise floating-point modulo operation.
6082
6600
  */
@@ -6090,6 +6608,20 @@ const fmod = jit$1(function fmod$1(x, y) {
const remainder = jit$1(function remainder$1(x, y) {
  // Shift the first remainder by y and reduce again. Assuming `mod` truncates
  // like C's `%`, the result takes the sign of the divisor y.
  const shifted = mod(x, y.ref).add(y.ref);
  return mod(shifted, y);
});
/**
 * Return element-wise quotient and remainder simultaneously.
 *
 * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Tuple of [quotient, remainder].
 */
function divmod(x, y) {
  const dividend = fudgeArray(x);
  const divisor = fudgeArray(y);
  // Keep extra references for the quotient; the remainder consumes the originals.
  const quotient = floorDivide(dividend.ref, divisor.ref);
  return [quotient, remainder(dividend, divisor)];
}
6093
6625
  /** Round input to the nearest integer towards zero. */
6094
6626
  function trunc(x) {
6095
6627
  return idiv(x, 1);
@@ -6253,14 +6785,15 @@ function std(x, axis = null, opts) {
6253
6785
  return sqrt(var_(x, axis, opts));
6254
6786
  }
6255
6787
  /** Estimate the sample covariance of a set of variables. */
6256
- function cov(x, y) {
6788
+ function cov(x, y = null, { rowvar = true } = {}) {
6257
6789
  x = fudgeArray(x);
6258
6790
  if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
6259
- if (y !== void 0) {
6791
+ if (y !== null) {
6260
6792
  y = fudgeArray(y);
6261
6793
  if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
6262
6794
  x = vstack([x, y]);
6263
6795
  }
6796
+ if (!rowvar) x = x.transpose();
6264
6797
  const [_M, N] = x.shape;
6265
6798
  x = x.ref.sub(x.mean(1, { keepdims: true }));
6266
6799
  return dot$1(x.ref, x.transpose()).div(N - 1);
@@ -6305,7 +6838,8 @@ const isfinite = jit$1(function isfinite$1(x) {
//#region src/library/lax-linalg.ts
// Namespace object for `lax.linalg`. The bundler helper `__export` presumably
// installs each entry as a getter resolving to the implementation in this region.
var lax_linalg_exports = {};
__export(lax_linalg_exports, {
  cholesky: () => cholesky$1,
  lu: () => lu,
  triangularSolve: () => triangularSolve
});
6311
6845
  /**
@@ -6334,11 +6868,39 @@ __export(lax_linalg_exports, {
6334
6868
  * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
6335
6869
  * ```
6336
6870
  */
function cholesky$1(a, { upper = false } = {}) {
  const lower = cholesky$2(a);
  // The upper factor is the lower factor with its last two axes swapped.
  if (upper) return moveaxis$1(lower, -2, -1);
  return lower;
}
/**
 * LU decomposition with partial pivoting.
 *
 * Computes the matrix decomposition `P @ A = L @ U`, where `P` is a
 * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
 * and `U` is upper-triangular.
 *
 * @param x - A batch of matrices with shape `[..., m, n]`.
 *
 * @returns A tuple `(lu, pivots, permutation)` where:
 * - `lu`: combined lower and upper triangular matrices (L's unit diagonal is
 *   not stored).
 * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
 * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const A = np.array([[4., 3.], [6., 3.]]);
 * const [lu, pivots, permutation] = lax.linalg.lu(A);
 * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
 * // pivots = [1, 1]
 * // permutation = [1, 0]
 * ```
 */
function lu(x) {
  // Public entry point; the factorization itself lives in `lu$1`.
  const factors = lu$1(x);
  return factors;
}
6903
+ /**
6342
6904
  * Solve a triangular linear system.
6343
6905
  *
6344
6906
  * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
@@ -6881,33 +7443,41 @@ __export(random_exports, {
6881
7443
  gumbel: () => gumbel,
6882
7444
  key: () => key,
6883
7445
  laplace: () => laplace,
7446
+ multivariateNormal: () => multivariateNormal,
6884
7447
  normal: () => normal,
6885
7448
  split: () => split,
6886
7449
  uniform: () => uniform
6887
7450
  });
/**
 * Check that `key$1` is a PRNG key array, i.e. its trailing axis has length 2.
 *
 * @param key$1 - Key array of shape `[..., 2]`.
 * @param scalar - If true, reject batched keys (any leading axes).
 * @returns The batch shape of the key (all axes except the last).
 * @throws If the key is 0-D, its last axis is not 2, or (with `scalar`) it is batched.
 */
function validateKeyShape(key$1, scalar = false) {
  const dims = key$1.shape;
  if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
  if (dims[dims.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
  const isBatched = dims.length > 1;
  if (scalar && isBatched) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(key$1.shape)} - use jax.vmap for batching.`);
  return dims.slice(0, -1);
}
/**
 * Split a single (unbatched) PRNG key into its two 32-bit halves.
 *
 * @param key$1 - Key array of shape `[2]`.
 * @returns `[k0, k1]`, each reshaped to the key's batch shape (empty for a single key).
 */
function getK01(key$1) {
  const keyShape = validateKeyShape(key$1, true);
  const halves = split$2(key$1, -1, [1, 1]);
  return halves.map((half) => half.reshape(keyShape));
}
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
function key(seed) {
  // The key is the pair [0, seed] stored as uint32.
  const seedArr = array(seed, { dtype: require_backend.DType.Uint32 });
  if (seedArr.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seedArr.shape} - use jax.vmap for batching.`);
  return stack([0, seedArr]);
}
/**
 * Splits a PRNG key into `num` new keys by adding a leading axis.
 *
 * `num` may be a count or a shape array; every entry must be a positive
 * integer. The result has shape `[...shape, 2]`.
 */
function split(key$1, num = 2) {
  const shape$1 = typeof num === "number" ? [num] : num;
  for (const len of shape$1) {
    const valid = Number.isInteger(len) && len > 0;
    if (!valid) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
  }
  const [k0, k1] = getK01(key$1);
  // Mode 0 and mode 1 produce the two 32-bit halves of each new key.
  const lo = randomBits(k0.ref, k1.ref, shape$1, 0);
  const hi = randomBits(k0, k1, shape$1, 1);
  return stack([lo, hi], -1);
}
/** Sample uniform bits in the form of unsigned integers. */
function bits(key$1, shape$1 = []) {
  // getK01 yields [k0, k1], which become the first two arguments of randomBits.
  return randomBits(...getK01(key$1), shape$1);
}
6912
7482
  /**
6913
7483
  * @function
@@ -6981,6 +7551,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
6981
7551
  }, { staticArgnums: [1] });
/**
 * @function
 * Sample multivariate normal random values with given mean and covariance.
 *
 * Draws standard-normal samples `z` and maps them through the lower Cholesky
 * factor `L` of the covariance (`cov = L @ L^T`), returning `L @ z + mean`.
 *
 * - `key` - PRNG key
 * - `mean` - Mean vector of shape `[..., n]`
 * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
 * - `shape` - Result batch shape, must be broadcastable with
 *   `mean.shape[:-1]` and `cov.shape[:-2]`
 * @returns Random samples of shape `[...batch, n]`, where `batch` is the
 *   broadcast of `shape`, `mean.shape[:-1]` and `cov.shape[:-2]`.
 * @throws If the last two dimensions of `cov` are not both `n`.
 */
const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
  mean$1 = fudgeArray(mean$1);
  cov$1 = fudgeArray(cov$1);
  const n = mean$1.shape[mean$1.ndim - 1];
  const lastDim = cov$1.shape[cov$1.ndim - 1];
  const secondLastDim = cov$1.shape[cov$1.ndim - 2];
  if (lastDim !== n || secondLastDim !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
  const batchShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2));
  const outputShape = batchShape.concat(n);
  const L = cholesky(cov$1);
  const z = normal(key$1, outputShape);
  // Batched matrix-vector product L @ z, then shift by the mean.
  return einsum("...ij,...j->...i", L, z).add(mean$1);
}, { staticArgnums: [3] });
7578
+ /**
7579
+ * @function
6984
7580
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
6985
7581
  *
6986
7582
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and