@jax-js/jax 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-tngXtWe4.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DaqL-MNz.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -356,6 +356,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
356
356
  Primitive$1["PoolTranspose"] = "pool_transpose";
357
357
  Primitive$1["Compare"] = "compare";
358
358
  Primitive$1["Where"] = "where";
359
+ Primitive$1["Concatenate"] = "concatenate";
360
+ Primitive$1["Split"] = "split";
359
361
  Primitive$1["RandomBits"] = "random_bits";
360
362
  Primitive$1["Gather"] = "gather";
361
363
  Primitive$1["Transpose"] = "transpose";
@@ -368,6 +370,7 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
368
370
  Primitive$1["Argsort"] = "argsort";
369
371
  Primitive$1["TriangularSolve"] = "triangular_solve";
370
372
  Primitive$1["Cholesky"] = "cholesky";
373
+ Primitive$1["LU"] = "lu";
371
374
  Primitive$1["Jit"] = "jit";
372
375
  return Primitive$1;
373
376
  }({});
@@ -499,7 +502,25 @@ function where$1(cond, x, y) {
499
502
  y
500
503
  ]);
501
504
  }
505
+ function concatenate$1(xs, axis) {
506
+ if (xs.length === 0) throw new Error("concatenate requires at least one input");
507
+ const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
508
+ axis = checkAxis(axis, avals[0].ndim);
509
+ for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
510
+ return bind1(Primitive.Concatenate, xs, { axis });
511
+ }
512
+ function split$2(x, axis, sizes) {
513
+ axis = checkAxis(axis, ndim$1(x));
514
+ if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
515
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
516
+ if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
517
+ return bind(Primitive.Split, [x], {
518
+ axis,
519
+ sizes
520
+ });
521
+ }
502
522
  function randomBits(k0, k1, shape$1, mode = "xor") {
523
+ if (!deepEqual(k0.shape, k1.shape) || k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
503
524
  return bind1(Primitive.RandomBits, [k0, k1], {
504
525
  shape: shape$1,
505
526
  mode
@@ -566,6 +587,11 @@ function pad$1(x, width) {
566
587
  return bind1(Primitive.Pad, [x], { width });
567
588
  }
568
589
  function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
590
+ const as = getShape(a);
591
+ const bs = getShape(b);
592
+ if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
593
+ const n = as[as.length - 2];
594
+ if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
569
595
  if (lower) {
570
596
  a = flip$1(a, [-2, -1]);
571
597
  b = flip$1(b, [-1]);
@@ -575,8 +601,15 @@ function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
575
601
  return x;
576
602
  }
577
603
  function cholesky$2(x) {
604
+ const aval = ShapedArray.fromAval(getAval(x));
605
+ if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
578
606
  return bind1(Primitive.Cholesky, [x]);
579
607
  }
608
+ function lu$1(x) {
609
+ const aval = ShapedArray.fromAval(getAval(x));
610
+ if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
611
+ return bind(Primitive.LU, [x]);
612
+ }
580
613
  function sort$1(x) {
581
614
  const nd = ndim$1(x);
582
615
  if (nd === 0) throw new Error("sort: requires at least 1D input");
@@ -685,6 +718,9 @@ var Tracer = class Tracer {
685
718
  mul(other) {
686
719
  return mul(this, other);
687
720
  }
721
+ mod(other) {
722
+ return mod(this, other);
723
+ }
688
724
  greater(other) {
689
725
  return greater$1(this, other);
690
726
  }
@@ -797,8 +833,14 @@ var Tracer = class Tracer {
797
833
  */
798
834
  *[Symbol.iterator]() {
799
835
  if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
800
- for (let i = 0; i < this.shape[0]; i++) yield this.ref.slice(i);
801
- this.dispose();
836
+ let residual = this;
837
+ const subarrayShape = this.shape.slice(1);
838
+ for (let i = 0; i < this.shape[0]; i++) {
839
+ const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
840
+ yield lr[0].reshape(subarrayShape);
841
+ residual = lr[1];
842
+ }
843
+ residual.dispose();
802
844
  }
803
845
  /**
804
846
  * Return a sorted copy of an array in ascending order.
@@ -948,6 +990,9 @@ var ShapedArray = class ShapedArray {
948
990
  get size() {
949
991
  return prod(this.shape);
950
992
  }
993
+ scalar() {
994
+ return new ShapedArray([], this.dtype, this.weakType);
995
+ }
951
996
  toString() {
952
997
  return `${this.dtype}[${this.shape.join(",")}]`;
953
998
  }
@@ -1553,7 +1598,7 @@ const abstractEvalRules = {
1553
1598
  return [new ShapedArray(shape$1, dtype, weakType)];
1554
1599
  },
1555
1600
  [Primitive.Conv]([lhs, rhs], params) {
1556
- const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
1601
+ const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
1557
1602
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
1558
1603
  return [new ShapedArray(shape$1, dtype, weakType)];
1559
1604
  },
@@ -1564,10 +1609,25 @@ const abstractEvalRules = {
1564
1609
  const shape$1 = generalBroadcast(cond.shape, xy.shape);
1565
1610
  return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
1566
1611
  },
1612
+ [Primitive.Concatenate](xs, { axis }) {
1613
+ if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
1614
+ for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
1615
+ const shape$1 = xs[0].shape.slice();
1616
+ shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
1617
+ const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
1618
+ return [new ShapedArray(shape$1, dtype, weakType)];
1619
+ },
1620
+ [Primitive.Split]([x], { axis, sizes }) {
1621
+ const totalSize = sizes.reduce((a, b) => a + b, 0);
1622
+ if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
1623
+ return sizes.map((size$1) => {
1624
+ return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
1625
+ });
1626
+ },
1567
1627
  [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
1568
1628
  if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
1569
- const keyShape = generalBroadcast(k0.shape, k1.shape);
1570
- if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1629
+ if (!deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
1630
+ if (!deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
1571
1631
  return [new ShapedArray(shape$1, DType.Uint32, false)];
1572
1632
  },
1573
1633
  [Primitive.Gather]([x, ...indices], { axis, outDim }) {
@@ -1624,6 +1684,16 @@ const abstractEvalRules = {
1624
1684
  if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
1625
1685
  return [ShapedArray.fromAval(a)];
1626
1686
  },
1687
+ [Primitive.LU]([a]) {
1688
+ if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
1689
+ const batch = a.shape.slice(0, -2);
1690
+ const [m, n] = a.shape.slice(-2);
1691
+ return [
1692
+ ShapedArray.fromAval(a),
1693
+ new ShapedArray([...batch, Math.min(m, n)], DType.Int32, false),
1694
+ new ShapedArray([...batch, m], DType.Int32, false)
1695
+ ];
1696
+ },
1627
1697
  [Primitive.Jit](args, { jaxpr }) {
1628
1698
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
1629
1699
  if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
@@ -1705,7 +1775,8 @@ const routinePrimitives = new Map([
1705
1775
  [Primitive.Sort, Routines.Sort],
1706
1776
  [Primitive.Argsort, Routines.Argsort],
1707
1777
  [Primitive.TriangularSolve, Routines.TriangularSolve],
1708
- [Primitive.Cholesky, Routines.Cholesky]
1778
+ [Primitive.Cholesky, Routines.Cholesky],
1779
+ [Primitive.LU, Routines.LU]
1709
1780
  ]);
1710
1781
  /** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
1711
1782
  var JitProgram = class {
@@ -1876,10 +1947,10 @@ function jitCompile(backend, jaxpr) {
1876
1947
  inputs.push(jv.arg);
1877
1948
  } else if (input instanceof Lit) inputs.push(builder.pushLit(input));
1878
1949
  const outputs = [];
1879
- for (const outVar$1 of eqn.outBinders) {
1880
- const outId = builder.pushBuffer(outVar$1.aval.size * byteWidth(outVar$1.aval.dtype));
1950
+ for (const outVar of eqn.outBinders) {
1951
+ const outId = builder.pushBuffer(outVar.aval.size * byteWidth(outVar.aval.dtype));
1881
1952
  outputs.push(outId);
1882
- ctx.set(outVar$1, {
1953
+ ctx.set(outVar, {
1883
1954
  type: "imm",
1884
1955
  arg: outId
1885
1956
  });
@@ -1930,35 +2001,37 @@ function jitCompile(backend, jaxpr) {
1930
2001
  let reduction;
1931
2002
  if (inputReduction) {
1932
2003
  const jv = inputReduction;
1933
- const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
1934
- exp$2 = jv.exp.reindexGids(addArgs(jv.args));
2004
+ const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
2005
+ exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
1935
2006
  reduction = new Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
1936
2007
  } else {
1937
2008
  const ruleOutput = rule(inputExps, inputAvals, eqn.params);
1938
2009
  exp$2 = ruleOutput.exp;
1939
2010
  reduction = ruleOutput.reduction;
1940
2011
  }
1941
- const outVar = eqn.outBinders[0];
1942
- if (blackNodes.has(outVar)) {
1943
- const nargs$1 = inputArgs.length;
1944
- const size$1 = outVar.aval.size;
1945
- const kernel = new Kernel(nargs$1, size$1, exp$2, reduction);
1946
- const outId = builder.pushKernel(kernel, inputArgs);
1947
- ctx.set(outVar, {
1948
- type: "imm",
1949
- arg: outId
2012
+ for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
2013
+ const outVar = eqn.outBinders[i$1];
2014
+ if (blackNodes.has(outVar)) {
2015
+ const nargs$1 = inputArgs.length;
2016
+ const size$1 = outVar.aval.size;
2017
+ const kernel = new Kernel(nargs$1, size$1, exp$2[i$1], reduction);
2018
+ const outId = builder.pushKernel(kernel, inputArgs);
2019
+ ctx.set(outVar, {
2020
+ type: "imm",
2021
+ arg: outId
2022
+ });
2023
+ } else if (reduction) ctx.set(outVar, {
2024
+ type: "red",
2025
+ exp: exp$2[i$1],
2026
+ reduction,
2027
+ args: inputArgs
1950
2028
  });
1951
- } else if (reduction) ctx.set(outVar, {
1952
- type: "red",
1953
- exp: exp$2,
1954
- reduction,
1955
- args: inputArgs
1956
- });
1957
- else ctx.set(outVar, {
1958
- type: "exp",
1959
- exp: exp$2,
1960
- args: inputArgs
1961
- });
2029
+ else ctx.set(outVar, {
2030
+ type: "exp",
2031
+ exp: exp$2[i$1],
2032
+ args: inputArgs
2033
+ });
2034
+ }
1962
2035
  }
1963
2036
  const outputIds = [];
1964
2037
  for (const out of jaxpr.outs) if (out instanceof Var) {
@@ -1999,17 +2072,17 @@ function broadcastedJit(fn, opts) {
1999
2072
  if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = AluExp.cast(newDtype, exp$2);
2000
2073
  return exp$2;
2001
2074
  });
2002
- return { exp: fn(exps, params) };
2075
+ return { exp: [fn(exps, params)] };
2003
2076
  };
2004
2077
  }
2005
2078
  function unopJit(fn) {
2006
2079
  return ([a], [_as], params) => {
2007
- return { exp: fn(a, params) };
2080
+ return { exp: [fn(a, params)] };
2008
2081
  };
2009
2082
  }
2010
2083
  function reshapeJit(fn) {
2011
2084
  return ([a], [_as], params) => {
2012
- return { exp: reshapeViews(a, (st) => fn(st, params)) };
2085
+ return { exp: [reshapeViews(a, (st) => fn(st, params))] };
2013
2086
  };
2014
2087
  }
2015
2088
  function routineNoJit() {
@@ -2055,7 +2128,7 @@ const jitRules = {
2055
2128
  a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
2056
2129
  const reduction = new Reduction(a.dtype, op, reductionSize);
2057
2130
  return {
2058
- exp: a,
2131
+ exp: [a],
2059
2132
  reduction
2060
2133
  };
2061
2134
  },
@@ -2066,13 +2139,13 @@ const jitRules = {
2066
2139
  a = reshapeViews(a, (st) => st.compose(stX), true);
2067
2140
  const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
2068
2141
  return {
2069
- exp: a,
2142
+ exp: [a],
2070
2143
  reduction
2071
2144
  };
2072
2145
  },
2073
2146
  [Primitive.Dot]([a, b], [as, bs]) {
2074
2147
  const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
2075
- const c = k1.exp;
2148
+ const [c] = k1.exp;
2076
2149
  const cs = promoteAvals(as, bs);
2077
2150
  return jitRules[Primitive.Reduce]([c], [cs], {
2078
2151
  op: AluOp.Add,
@@ -2089,16 +2162,41 @@ const jitRules = {
2089
2162
  },
2090
2163
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
2091
2164
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
2165
+ [Primitive.Concatenate](exps, avals, { axis }) {
2166
+ const ndim$2 = avals[0].ndim;
2167
+ const sizes = avals.map((x) => x.shape[axis]);
2168
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
2169
+ const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
2170
+ let cum = 0;
2171
+ const src = [];
2172
+ for (let i = 0; i < exps.length; i++) {
2173
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
2174
+ src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
2175
+ cum += sizes[i];
2176
+ }
2177
+ return { exp: [src.reduce(AluExp.add)] };
2178
+ },
2179
+ [Primitive.Split]([a], [as], { axis, sizes }) {
2180
+ const exp$2 = [];
2181
+ let start = 0;
2182
+ for (const size$1 of sizes) {
2183
+ const slice = range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
2184
+ exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
2185
+ start += size$1;
2186
+ }
2187
+ return { exp: exp$2 };
2188
+ },
2092
2189
  [Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
2190
+ const keyShape = keyShapes[0].shape;
2093
2191
  const mapping = (st) => {
2094
- if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
2192
+ if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(st.shape.length, shape$1.length));
2095
2193
  };
2096
2194
  const k0 = reshapeViews(keys[0], mapping);
2097
2195
  const k1 = reshapeViews(keys[1], mapping);
2098
2196
  const c0 = AluExp.u32(0);
2099
- const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
2197
+ const c1 = AluExp.mod(AluExp.cast(DType.Uint32, AluVar.gidx), AluExp.u32(Math.max(prod(shape$1.slice(keyShape.length)), 1)));
2100
2198
  const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
2101
- return { exp: exp$2 };
2199
+ return { exp: [exp$2] };
2102
2200
  },
2103
2201
  [Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
2104
2202
  const axisSet = new Set(axis);
@@ -2113,7 +2211,7 @@ const jitRules = {
2113
2211
  for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
2114
2212
  const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
2115
2213
  if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
2116
- return { exp: x.substitute({ gidx: index }) };
2214
+ return { exp: [x.substitute({ gidx: index })] };
2117
2215
  },
2118
2216
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
2119
2217
  [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
@@ -2129,6 +2227,7 @@ const jitRules = {
2129
2227
  [Primitive.Argsort]: routineNoJit(),
2130
2228
  [Primitive.TriangularSolve]: routineNoJit(),
2131
2229
  [Primitive.Cholesky]: routineNoJit(),
2230
+ [Primitive.LU]: routineNoJit(),
2132
2231
  [Primitive.Jit]() {
2133
2232
  throw new Error("internal: Jit should have been flattened before JIT compilation");
2134
2233
  }
@@ -2407,6 +2506,10 @@ var Array$1 = class Array$1 extends Tracer {
2407
2506
  this.#rc++;
2408
2507
  return this;
2409
2508
  }
2509
+ /** Get the current reference count (for debugging memory management). */
2510
+ get refCount() {
2511
+ return this.#rc;
2512
+ }
2410
2513
  dispose() {
2411
2514
  this.#check();
2412
2515
  if (--this.#rc === 0) {
@@ -2564,7 +2667,7 @@ var Array$1 = class Array$1 extends Tracer {
2564
2667
  } else if (castDtype === void 0) {
2565
2668
  castDtype = arrays[i].#dtype;
2566
2669
  castWeakType = arrays[i].#weakType;
2567
- } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
2670
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
2568
2671
  const weakType = castWeakType && !strongTypeOutput;
2569
2672
  const { backend, committed } = Array$1.#computeBackend(name, arrays);
2570
2673
  arrays = arrays.map((ar) => ar._putSync(backend));
@@ -2957,17 +3060,44 @@ var Array$1 = class Array$1 extends Tracer {
2957
3060
  y
2958
3061
  ], { dtypeOverride: [DType.Bool] })];
2959
3062
  },
3063
+ [Primitive.Concatenate](xs, { axis }) {
3064
+ const ndim$2 = xs[0].ndim;
3065
+ const sizes = xs.map((x) => x.shape[axis]);
3066
+ const finalSize = sizes.reduce((a, b) => a + b, 0);
3067
+ const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
3068
+ let cum = 0;
3069
+ const xsPadded = [];
3070
+ for (let i = 0; i < xs.length; i++) {
3071
+ const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
3072
+ xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
3073
+ cum += sizes[i];
3074
+ }
3075
+ const custom = (exps) => exps.reduce(AluExp.add);
3076
+ return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
3077
+ },
3078
+ [Primitive.Split]([x], { axis, sizes }) {
3079
+ const outputs = [];
3080
+ for (let i = 0, start = 0; i < sizes.length; i++) {
3081
+ const slice = range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
3082
+ outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
3083
+ start += sizes[i];
3084
+ }
3085
+ x.dispose();
3086
+ return outputs;
3087
+ },
2960
3088
  [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
2961
- const keyShape = generalBroadcast(k0.shape, k1.shape);
2962
- if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2963
- const c0 = zeros(shape$1, {
3089
+ const keyShape = k0.shape;
3090
+ const genShape = shape$1.slice(keyShape.length);
3091
+ const c0 = zeros(genShape, {
2964
3092
  dtype: DType.Uint32,
2965
3093
  device: k0.device
2966
3094
  });
2967
- const c1 = arange(0, prod(shape$1), 1, {
3095
+ const c1 = arange(0, prod(genShape), 1, {
2968
3096
  dtype: DType.Uint32,
2969
3097
  device: k0.device
2970
- }).reshape(shape$1);
3098
+ }).reshape(genShape);
3099
+ k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
3100
+ k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
2971
3101
  const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
2972
3102
  return [Array$1.#naryCustom("random_bits", custom, [
2973
3103
  k0,
@@ -3001,40 +3131,63 @@ var Array$1 = class Array$1 extends Tracer {
3001
3131
  },
3002
3132
  [Primitive.Sort]([x]) {
3003
3133
  const routine = new Routine(Routines.Sort, {
3004
- inputShapes: [x.aval.shape],
3005
- inputDtypes: [x.aval.dtype],
3006
- outputShapes: [x.aval.shape],
3007
- outputDtypes: [x.aval.dtype]
3134
+ inputShapes: [x.shape],
3135
+ inputDtypes: [x.dtype],
3136
+ outputShapes: [x.shape],
3137
+ outputDtypes: [x.dtype]
3008
3138
  });
3009
3139
  return Array$1.#routine(routine, [x], [x.#weakType]);
3010
3140
  },
3011
3141
  [Primitive.Argsort]([x]) {
3012
3142
  const routine = new Routine(Routines.Argsort, {
3013
- inputShapes: [x.aval.shape],
3014
- inputDtypes: [x.aval.dtype],
3015
- outputShapes: [x.aval.shape, x.aval.shape],
3016
- outputDtypes: [x.aval.dtype, DType.Int32]
3143
+ inputShapes: [x.shape],
3144
+ inputDtypes: [x.dtype],
3145
+ outputShapes: [x.shape, x.shape],
3146
+ outputDtypes: [x.dtype, DType.Int32]
3017
3147
  });
3018
3148
  return Array$1.#routine(routine, [x], [x.#weakType, false]);
3019
3149
  },
3020
3150
  [Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
3021
3151
  const routine = new Routine(Routines.TriangularSolve, {
3022
- inputShapes: [a.aval.shape, b.aval.shape],
3023
- inputDtypes: [a.aval.dtype, b.aval.dtype],
3024
- outputShapes: [b.aval.shape],
3025
- outputDtypes: [b.aval.dtype]
3152
+ inputShapes: [a.shape, b.shape],
3153
+ inputDtypes: [a.dtype, b.dtype],
3154
+ outputShapes: [b.shape],
3155
+ outputDtypes: [b.dtype]
3026
3156
  }, { unitDiagonal });
3027
3157
  return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
3028
3158
  },
3029
3159
  [Primitive.Cholesky]([a]) {
3030
3160
  const routine = new Routine(Routines.Cholesky, {
3031
- inputShapes: [a.aval.shape],
3032
- inputDtypes: [a.aval.dtype],
3033
- outputShapes: [a.aval.shape],
3034
- outputDtypes: [a.aval.dtype]
3161
+ inputShapes: [a.shape],
3162
+ inputDtypes: [a.dtype],
3163
+ outputShapes: [a.shape],
3164
+ outputDtypes: [a.dtype]
3035
3165
  });
3036
3166
  return Array$1.#routine(routine, [a], [a.#weakType]);
3037
3167
  },
3168
+ [Primitive.LU]([a]) {
3169
+ const batch = a.shape.slice(0, -2);
3170
+ const [m, n] = a.shape.slice(-2);
3171
+ const routine = new Routine(Routines.LU, {
3172
+ inputShapes: [a.shape],
3173
+ inputDtypes: [a.dtype],
3174
+ outputShapes: [
3175
+ a.shape,
3176
+ [...batch, Math.min(m, n)],
3177
+ [...batch, m]
3178
+ ],
3179
+ outputDtypes: [
3180
+ a.dtype,
3181
+ DType.Int32,
3182
+ DType.Int32
3183
+ ]
3184
+ });
3185
+ return Array$1.#routine(routine, [a], [
3186
+ a.#weakType,
3187
+ false,
3188
+ false
3189
+ ]);
3190
+ },
3038
3191
  [Primitive.Jit](args, { jaxpr }) {
3039
3192
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
3040
3193
  const { backend, committed } = Array$1.#computeBackend("jit", args);
@@ -3140,7 +3293,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3140
3293
  device
3141
3294
  });
3142
3295
  } else {
3143
- const weakType = dtype == void 0;
3296
+ const weakType = dtype == void 0 && shape$1.length === 0;
3144
3297
  dtype = dtype ?? DType.Float32;
3145
3298
  const data = dtypedJsArray(dtype, flat);
3146
3299
  return arrayFromData(data, shape$1, {
@@ -3254,7 +3407,7 @@ function ones(shape$1, { dtype, device } = {}) {
3254
3407
  }
3255
3408
  /** Return a new array of given shape and type, filled with `fill_value`. */
3256
3409
  function full(shape$1, fillValue, { dtype, device } = {}) {
3257
- let weakType = dtype == void 0;
3410
+ let weakType = dtype == void 0 && shape$1.length === 0;
3258
3411
  if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
3259
3412
  else if (typeof fillValue === "boolean") {
3260
3413
  dtype = dtype ?? DType.Bool;
@@ -3412,6 +3565,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
3412
3565
  committed: device != void 0
3413
3566
  });
3414
3567
  }
3568
+ /**
3569
+ * Return numbers spaced evenly on a log scale.
3570
+ *
3571
+ * In linear space, the sequence starts at `base ** start` and ends at
3572
+ * `base ** stop` (see `endpoint` below).
3573
+ *
3574
+ * @param start - `base ** start` is the starting value of the sequence.
3575
+ * @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
3576
+ * @param num - Number of samples to generate. Default is 50.
3577
+ * @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
3578
+ * @param base - The base of the log space. Default is 10.
3579
+ * @returns Array of evenly spaced values on a log scale.
3580
+ */
3581
+ function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
3582
+ const y = linspace(start, stop, num, endpoint, {
3583
+ dtype,
3584
+ device
3585
+ });
3586
+ const logBase = Math.log(base);
3587
+ return exp$1(mul(y, logBase));
3588
+ }
3415
3589
  function aluCompare(a, b, op) {
3416
3590
  switch (op) {
3417
3591
  case CompareOp.Less: return AluExp.cmplt(a, b);
@@ -3488,6 +3662,7 @@ var BatchTrace = class extends Trace {
3488
3662
  return valOuts$1.map((x) => new BatchTracer(this, x, null));
3489
3663
  }
3490
3664
  const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3665
+ if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
3491
3666
  return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3492
3667
  }
3493
3668
  get axisSize() {
@@ -3499,13 +3674,13 @@ var BatchTrace = class extends Trace {
3499
3674
  *
3500
3675
  * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3501
3676
  */
3502
- function broadcastBatcher(op) {
3503
- return (axisSize, args, dims) => {
3677
+ function broadcastBatcher(prim) {
3678
+ return (axisSize, args, dims, params) => {
3504
3679
  if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3505
3680
  const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3506
3681
  const firstIdx = dims.findIndex((d) => d !== null);
3507
3682
  const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3508
- if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3683
+ if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[bind1(prim, args, params)], [nd + firstBdim]];
3509
3684
  args = args.map((x, i) => {
3510
3685
  if (dims[i] === null) return x;
3511
3686
  x = moveBatchAxis(axisSize, dims[i], 0, x);
@@ -3516,37 +3691,45 @@ function broadcastBatcher(op) {
3516
3691
  ]);
3517
3692
  return x;
3518
3693
  });
3519
- return [[op(...args)], [0]];
3694
+ return [[bind1(prim, args, params)], [0]];
3520
3695
  };
3521
3696
  }
3522
- function unopBatcher(op) {
3697
+ function unopBatcher(prim) {
3523
3698
  return (axisSize, [x], [xBdim], params) => {
3524
- return [[op(x, params)], [xBdim]];
3699
+ return [[bind1(prim, [x], params)], [xBdim]];
3700
+ };
3701
+ }
3702
+ function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
3703
+ return (axisSize, [x], [xBdim], params) => {
3704
+ assertNonNull(xBdim);
3705
+ if (xBdim < x.ndim - inputDims) return [bind(prim, [x], params), rep(numOutputs, xBdim)];
3706
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3707
+ return [bind(prim, [x], params), rep(numOutputs, 0)];
3525
3708
  };
3526
3709
  }
3527
3710
  const vmapRules = {
3528
- [Primitive.Add]: broadcastBatcher(add$1),
3529
- [Primitive.Mul]: broadcastBatcher(mul),
3530
- [Primitive.Idiv]: broadcastBatcher(idiv),
3531
- [Primitive.Mod]: broadcastBatcher(mod),
3532
- [Primitive.Min]: broadcastBatcher(min$1),
3533
- [Primitive.Max]: broadcastBatcher(max$1),
3534
- [Primitive.Neg]: unopBatcher(neg),
3535
- [Primitive.Reciprocal]: unopBatcher(reciprocal$1),
3536
- [Primitive.Floor]: unopBatcher(floor$1),
3537
- [Primitive.Ceil]: unopBatcher(ceil$1),
3538
- [Primitive.StopGradient]: unopBatcher(stopGradient),
3539
- [Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
3540
- [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3541
- [Primitive.Sin]: unopBatcher(sin$1),
3542
- [Primitive.Cos]: unopBatcher(cos$1),
3543
- [Primitive.Asin]: unopBatcher(asin$1),
3544
- [Primitive.Atan]: unopBatcher(atan$1),
3545
- [Primitive.Exp]: unopBatcher(exp$1),
3546
- [Primitive.Log]: unopBatcher(log$1),
3547
- [Primitive.Erf]: unopBatcher(erf$1),
3548
- [Primitive.Erfc]: unopBatcher(erfc$1),
3549
- [Primitive.Sqrt]: unopBatcher(sqrt$1),
3711
+ [Primitive.Add]: broadcastBatcher(Primitive.Add),
3712
+ [Primitive.Mul]: broadcastBatcher(Primitive.Mul),
3713
+ [Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
3714
+ [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3715
+ [Primitive.Min]: broadcastBatcher(Primitive.Min),
3716
+ [Primitive.Max]: broadcastBatcher(Primitive.Max),
3717
+ [Primitive.Neg]: unopBatcher(Primitive.Neg),
3718
+ [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3719
+ [Primitive.Floor]: unopBatcher(Primitive.Floor),
3720
+ [Primitive.Ceil]: unopBatcher(Primitive.Ceil),
3721
+ [Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
3722
+ [Primitive.Cast]: unopBatcher(Primitive.Cast),
3723
+ [Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
3724
+ [Primitive.Sin]: unopBatcher(Primitive.Sin),
3725
+ [Primitive.Cos]: unopBatcher(Primitive.Cos),
3726
+ [Primitive.Asin]: unopBatcher(Primitive.Asin),
3727
+ [Primitive.Atan]: unopBatcher(Primitive.Atan),
3728
+ [Primitive.Exp]: unopBatcher(Primitive.Exp),
3729
+ [Primitive.Log]: unopBatcher(Primitive.Log),
3730
+ [Primitive.Erf]: unopBatcher(Primitive.Erf),
3731
+ [Primitive.Erfc]: unopBatcher(Primitive.Erfc),
3732
+ [Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
3550
3733
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3551
3734
  assertNonNull(xBdim);
3552
3735
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
@@ -3568,10 +3751,25 @@ const vmapRules = {
3568
3751
  });
3569
3752
  return [[z], [0]];
3570
3753
  },
3571
- [Primitive.Compare](axisSize, args, dims, { op }) {
3572
- return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3754
+ [Primitive.Compare]: broadcastBatcher(Primitive.Compare),
3755
+ [Primitive.Where]: broadcastBatcher(Primitive.Where),
3756
+ [Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
3757
+ const minBdim = Math.min(...xBdims.filter((d) => d !== null));
3758
+ xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
3759
+ const newAxis = axis + (minBdim <= axis ? 1 : 0);
3760
+ return [[concatenate$1(xs, newAxis)], [minBdim]];
3761
+ },
3762
+ [Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
3763
+ assertNonNull(xBdim);
3764
+ const newAxis = axis + (xBdim <= axis ? 1 : 0);
3765
+ const outs = split$2(x, newAxis, sizes);
3766
+ return [outs, rep(outs.length, xBdim)];
3767
+ },
3768
+ [Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
3769
+ k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
3770
+ k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
3771
+ return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
3573
3772
  },
3574
- [Primitive.Where]: broadcastBatcher(where$1),
3575
3773
  [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3576
3774
  if (indicesBdim.every((d) => d === null)) {
3577
3775
  assertNonNull(xBdim);
@@ -3633,18 +3831,8 @@ const vmapRules = {
3633
3831
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3634
3832
  return [[pad$1(x, newWidth)], [xBdim]];
3635
3833
  },
3636
- [Primitive.Sort](axisSize, [x], [xBdim]) {
3637
- assertNonNull(xBdim);
3638
- if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
3639
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3640
- return [[sort$1(x)], [0]];
3641
- },
3642
- [Primitive.Argsort](axisSize, [x], [xBdim]) {
3643
- assertNonNull(xBdim);
3644
- if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
3645
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3646
- return [argsort$1(x), [0, 0]];
3647
- },
3834
+ [Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
3835
+ [Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
3648
3836
  [Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
3649
3837
  if (aBdim === null) {
3650
3838
  b = moveBatchAxis(axisSize, bBdim, -3, b);
@@ -3668,12 +3856,8 @@ const vmapRules = {
3668
3856
  const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
3669
3857
  return [[x], [0]];
3670
3858
  },
3671
- [Primitive.Cholesky](axisSize, [x], [xBdim]) {
3672
- assertNonNull(xBdim);
3673
- if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
3674
- x = moveBatchAxis(axisSize, xBdim, 0, x);
3675
- return [[cholesky$2(x)], [0]];
3676
- },
3859
+ [Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
3860
+ [Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
3677
3861
  [Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
3678
3862
  const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
3679
3863
  const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
@@ -3823,6 +4007,16 @@ function batchMatmulT(a, b) {
3823
4007
  function mT(a) {
3824
4008
  return moveaxis(a, -2, -1);
3825
4009
  }
4010
+ function sliceAxis(a, axis, p) {
4011
+ const slices = Array(a.shape.length).fill([]);
4012
+ slices[checkAxis(axis, a.ndim)] = p;
4013
+ return a.slice(...slices);
4014
+ }
4015
+ function padAxis(a, axis, p) {
4016
+ const pads = Array(a.shape.length).fill([0, 0]);
4017
+ pads[checkAxis(axis, a.ndim)] = p;
4018
+ return pad$1(a, pads);
4019
+ }
3826
4020
  const jvpRules = {
3827
4021
  [Primitive.Add]: linearTangentsJvp(Primitive.Add),
3828
4022
  [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
@@ -3921,6 +4115,8 @@ const jvpRules = {
3921
4115
  dcond.dispose();
3922
4116
  return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
3923
4117
  },
4118
+ [Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
4119
+ [Primitive.Split]: linearTangentsJvp(Primitive.Split),
3924
4120
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
3925
4121
  [Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
3926
4122
  const indicesRef = indices.map((t) => t.ref);
@@ -3955,6 +4151,38 @@ const jvpRules = {
3955
4151
  const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
3956
4152
  return [[L], [dL]];
3957
4153
  },
4154
+ [Primitive.LU]([a], [da]) {
4155
+ const [luMatrix, pivots, permutation] = lu$1(a);
4156
+ const [m, n] = a.shape.slice(-2);
4157
+ const k = Math.min(m, n);
4158
+ const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
4159
+ const lLower = tril(luSliceL, -1);
4160
+ const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
4161
+ const L = lPadded.add(eye(m));
4162
+ const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
4163
+ const uUpper = triu(luSliceU);
4164
+ const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
4165
+ const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
4166
+ const U = uPadded.add(uEye);
4167
+ const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
4168
+ const pda = batchMatmulT(P, mT(da));
4169
+ const la = mT(triangularSolve$1(L.ref, mT(pda), {
4170
+ lower: true,
4171
+ unitDiagonal: true
4172
+ }));
4173
+ const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
4174
+ const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
4175
+ const uDot = batchMatmulT(triu(lau), mT(U));
4176
+ return [[
4177
+ luMatrix,
4178
+ pivots,
4179
+ permutation
4180
+ ], [
4181
+ lDot.add(uDot),
4182
+ zerosLike$1(pivots.ref),
4183
+ zerosLike$1(permutation.ref)
4184
+ ]];
4185
+ },
3958
4186
  [Primitive.Jit](primals, tangents, { name, jaxpr }) {
3959
4187
  const newJaxpr = jvpJaxpr(jaxpr);
3960
4188
  const outs = bind(Primitive.Jit, [
@@ -4492,6 +4720,15 @@ const transposeRules = {
4492
4720
  cond.dispose();
4493
4721
  return cts;
4494
4722
  },
4723
+ [Primitive.Concatenate]([ct], inputs, { axis }) {
4724
+ if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
4725
+ const sizes = inputs.map((x) => x.aval.shape[axis]);
4726
+ return split$2(ct, axis, sizes);
4727
+ },
4728
+ [Primitive.Split](cts, [x], { axis }) {
4729
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
4730
+ return [concatenate$1(cts, axis)];
4731
+ },
4495
4732
  [Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
4496
4733
  if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
4497
4734
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
@@ -4767,8 +5004,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
4767
5004
  const idx = lhsIndex[j];
4768
5005
  const dim = shape$1[j];
4769
5006
  const existing = sizeMap.get(idx);
4770
- if (existing === void 0) sizeMap.set(idx, dim);
4771
- else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
5007
+ if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
5008
+ else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
4772
5009
  }
4773
5010
  }
4774
5011
  for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
@@ -4924,27 +5161,53 @@ function ifft(a, axis = -1) {
4924
5161
  //#region src/library/numpy-linalg.ts
4925
5162
  var numpy_linalg_exports = {};
4926
5163
  __export(numpy_linalg_exports, {
4927
- cholesky: () => cholesky$1,
5164
+ cholesky: () => cholesky,
5165
+ det: () => det,
4928
5166
  diagonal: () => diagonal,
5167
+ inv: () => inv,
4929
5168
  lstsq: () => lstsq,
4930
5169
  matmul: () => matmul,
5170
+ matrixPower: () => matrixPower,
4931
5171
  matrixTranspose: () => matrixTranspose,
4932
5172
  outer: () => outer,
5173
+ slogdet: () => slogdet,
5174
+ solve: () => solve,
4933
5175
  tensordot: () => tensordot,
4934
5176
  trace: () => trace,
4935
5177
  vecdot: () => vecdot
4936
5178
  });
5179
+ function checkSquare(name, a) {
5180
+ if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
5181
+ return a.shape[a.ndim - 1];
5182
+ }
4937
5183
  /**
4938
5184
  * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
4939
5185
  *
4940
5186
  * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
4941
5187
  * the input matrix, which is on by default.
4942
5188
  */
4943
- function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
5189
+ function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
4944
5190
  a = fudgeArray(a);
4945
- if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`cholesky: input must be at least 2D square matrix, got ${a.aval}`);
5191
+ checkSquare("cholesky", a);
4946
5192
  if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
4947
- return cholesky(a, { upper });
5193
+ return cholesky$1(a, { upper });
5194
+ }
5195
+ /** Compute the determinant of a square matrix (batched). */
5196
+ function det(a) {
5197
+ a = fudgeArray(a);
5198
+ const n = checkSquare("det", a);
5199
+ const [lu$2, pivots, permutation] = lu(a);
5200
+ permutation.dispose();
5201
+ const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
5202
+ const sign$1 = parity.mul(-2).add(1);
5203
+ const diag$1 = lu$2.diagonal(0, -1, -2);
5204
+ return prod$1(diag$1, -1).mul(sign$1);
5205
+ }
5206
+ /** Compute the inverse of a square matrix (batched). */
5207
+ function inv(a) {
5208
+ a = fudgeArray(a);
5209
+ const n = checkSquare("inv", a);
5210
+ return solve(a, eye(n));
4948
5211
  }
4949
5212
  /**
4950
5213
  * Return the least-squares solution to a linear equation.
@@ -4968,7 +5231,7 @@ function lstsq(a, b) {
4968
5231
  const at = matrixTranspose(a.ref);
4969
5232
  if (m <= n) {
4970
5233
  const aat = matmul(a, at.ref);
4971
- const l = cholesky$1(aat, { symmetrizeInput: false });
5234
+ const l = cholesky(aat, { symmetrizeInput: false });
4972
5235
  const lb = triangularSolve(l.ref, b, {
4973
5236
  leftSide: true,
4974
5237
  lower: true
@@ -4980,7 +5243,7 @@ function lstsq(a, b) {
4980
5243
  return matmul(at, llb.ref);
4981
5244
  } else {
4982
5245
  const ata = matmul(at.ref, a);
4983
- const l = cholesky$1(ata, { symmetrizeInput: false });
5246
+ const l = cholesky(ata, { symmetrizeInput: false });
4984
5247
  const atb = matmul(at, b);
4985
5248
  const lb = triangularSolve(l.ref, atb, {
4986
5249
  leftSide: true,
@@ -4993,6 +5256,169 @@ function lstsq(a, b) {
4993
5256
  return llb;
4994
5257
  }
4995
5258
  }
5259
+ /** Raise a square matrix to an integer power, via repeated squarings. */
5260
+ function matrixPower(a, n) {
5261
+ if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
5262
+ a = fudgeArray(a);
5263
+ const m = checkSquare("matrixPower", a);
5264
+ if (n === 0) {
5265
+ a.dispose();
5266
+ return broadcastTo(eye(m), a.shape);
5267
+ }
5268
+ if (n < 0) {
5269
+ a = inv(a);
5270
+ n = -n;
5271
+ }
5272
+ let result = null;
5273
+ let a2k = a;
5274
+ for (let k = 0; n; k++) {
5275
+ if (k > 0) a2k = matmul(a2k.ref, a2k);
5276
+ if (n % 2 === 1) result = result === null ? a2k.ref : matmul(result, a2k.ref);
5277
+ n = Math.floor(n / 2);
5278
+ }
5279
+ a2k.dispose();
5280
+ return result;
5281
+ }
5282
+ /** Return sign and natural logarithm of the determinant of `a`. */
5283
+ function slogdet(a) {
5284
+ a = fudgeArray(a);
5285
+ const n = checkSquare("slogdet", a);
5286
+ const [lu$2, pivots, permutation] = lu(a);
5287
+ permutation.dispose();
5288
+ let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
5289
+ const diag$1 = lu$2.diagonal(0, -1, -2);
5290
+ parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
5291
+ const logabsdet = log(absolute(diag$1)).sum(-1);
5292
+ const sign$1 = parity.mul(-2).add(1);
5293
+ return [sign$1, logabsdet];
5294
+ }
5295
+ /**
5296
+ * Solve a linear system of equations.
5297
+ *
5298
+ * This solves a (batched) linear system of equations `a @ x = b` for `x` given
5299
+ * `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
5300
+ *
5301
+ * @param a - Coefficient matrix of shape `(..., N, N)`.
5302
+ * @param b - Values of shape `(N,)` or `(..., N, M)`.
5303
+ * @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
5304
+ */
5305
+ function solve(a, b) {
5306
+ a = fudgeArray(a);
5307
+ b = fudgeArray(b);
5308
+ const n = checkSquare("solve", a);
5309
+ if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
5310
+ const bIs1d = b.ndim === 1;
5311
+ if (bIs1d) b = b.reshape([...b.shape, 1]);
5312
+ if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
5313
+ const m = b.shape[b.ndim - 1];
5314
+ const batchDims = generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
5315
+ a = broadcastTo(a, [
5316
+ ...batchDims,
5317
+ n,
5318
+ n
5319
+ ]);
5320
+ b = broadcastTo(b, [
5321
+ ...batchDims,
5322
+ n,
5323
+ m
5324
+ ]);
5325
+ const [lu$2, pivots, permutation] = lu(a);
5326
+ pivots.dispose();
5327
+ const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
5328
+ const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
5329
+ leftSide: true,
5330
+ lower: true,
5331
+ unitDiagonal: true
5332
+ });
5333
+ let x = triangularSolve(lu$2, LPb.ref, {
5334
+ leftSide: true,
5335
+ lower: false
5336
+ });
5337
+ if (bIs1d) x = squeeze(x, -1);
5338
+ return x;
5339
+ }
5340
+
5341
+ //#endregion
5342
+ //#region src/library/numpy/dtype-info.ts
5343
+ /** Machine limits for floating-point types. */
5344
+ function finfo(dtype) {
5345
+ if (!isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
5346
+ switch (dtype) {
5347
+ case DType.Float16: return Object.freeze({
5348
+ bits: 16,
5349
+ dtype: DType.Float16,
5350
+ eps: 2 ** -10,
5351
+ epsneg: 2 ** -11,
5352
+ machep: -10,
5353
+ max: 65504,
5354
+ maxexp: 16,
5355
+ min: -65504,
5356
+ minexp: -14,
5357
+ negep: -24,
5358
+ nexp: 5,
5359
+ nmant: 10,
5360
+ precision: 3,
5361
+ resolution: .001,
5362
+ smallestNormal: 2 ** -14,
5363
+ smallestSubnormal: 2 ** -24
5364
+ });
5365
+ case DType.Float32: return Object.freeze({
5366
+ bits: 32,
5367
+ dtype: DType.Float32,
5368
+ eps: 2 ** -23,
5369
+ epsneg: 2 ** -24,
5370
+ machep: -23,
5371
+ max: 34028234663852886e22,
5372
+ maxexp: 128,
5373
+ min: -34028234663852886e22,
5374
+ minexp: -126,
5375
+ negep: -24,
5376
+ nexp: 8,
5377
+ nmant: 23,
5378
+ precision: 6,
5379
+ resolution: 1e-6,
5380
+ smallestNormal: 2 ** -126,
5381
+ smallestSubnormal: 2 ** -149
5382
+ });
5383
+ case DType.Float64: return Object.freeze({
5384
+ bits: 64,
5385
+ dtype: DType.Float64,
5386
+ eps: 2 ** -52,
5387
+ epsneg: 2 ** -53,
5388
+ machep: -52,
5389
+ max: Number.MAX_VALUE,
5390
+ maxexp: 1024,
5391
+ min: -Number.MAX_VALUE,
5392
+ minexp: -1022,
5393
+ negep: -53,
5394
+ nexp: 11,
5395
+ nmant: 52,
5396
+ precision: 15,
5397
+ resolution: 1e-15,
5398
+ smallestNormal: 2 ** -1022,
5399
+ smallestSubnormal: 2 ** -1074
5400
+ });
5401
+ default: throw new Error(`finfo: unsupported dtype ${dtype}`);
5402
+ }
5403
+ }
5404
+ /** Machine limits for integer types. */
5405
+ function iinfo(dtype) {
5406
+ switch (dtype) {
5407
+ case DType.Int32: return Object.freeze({
5408
+ bits: 32,
5409
+ dtype: DType.Int32,
5410
+ max: 2147483647,
5411
+ min: -2147483648
5412
+ });
5413
+ case DType.Uint32: return Object.freeze({
5414
+ bits: 32,
5415
+ dtype: DType.Uint32,
5416
+ max: 4294967295,
5417
+ min: 0
5418
+ });
5419
+ default: throw new Error(`iinfo: unsupported dtype ${dtype}`);
5420
+ }
5421
+ }
4996
5422
 
4997
5423
  //#endregion
4998
5424
  //#region src/library/numpy.ts
@@ -5048,6 +5474,7 @@ __export(numpy_exports, {
5048
5474
  diag: () => diag,
5049
5475
  diagonal: () => diagonal,
5050
5476
  divide: () => trueDivide,
5477
+ divmod: () => divmod,
5051
5478
  dot: () => dot$1,
5052
5479
  dstack: () => dstack,
5053
5480
  e: () => e,
@@ -5060,6 +5487,7 @@ __export(numpy_exports, {
5060
5487
  expm1: () => expm1,
5061
5488
  eye: () => eye,
5062
5489
  fft: () => numpy_fft_exports,
5490
+ finfo: () => finfo,
5063
5491
  flip: () => flip,
5064
5492
  fliplr: () => fliplr,
5065
5493
  flipud: () => flipud,
@@ -5067,6 +5495,7 @@ __export(numpy_exports, {
5067
5495
  float32: () => float32,
5068
5496
  float64: () => float64,
5069
5497
  floor: () => floor,
5498
+ floorDivide: () => floorDivide,
5070
5499
  fmod: () => fmod,
5071
5500
  frexp: () => frexp,
5072
5501
  full: () => full,
@@ -5079,6 +5508,7 @@ __export(numpy_exports, {
5079
5508
  hstack: () => hstack,
5080
5509
  hypot: () => hypot,
5081
5510
  identity: () => identity$1,
5511
+ iinfo: () => iinfo,
5082
5512
  inf: () => inf,
5083
5513
  inner: () => inner,
5084
5514
  int32: () => int32,
@@ -5096,6 +5526,7 @@ __export(numpy_exports, {
5096
5526
  log10: () => log10,
5097
5527
  log1p: () => log1p,
5098
5528
  log2: () => log2,
5529
+ logspace: () => logspace,
5099
5530
  matmul: () => matmul,
5100
5531
  matrixTranspose: () => matrixTranspose,
5101
5532
  max: () => max,
@@ -5132,9 +5563,11 @@ __export(numpy_exports, {
5132
5563
  shape: () => shape,
5133
5564
  sign: () => sign,
5134
5565
  sin: () => sin,
5566
+ sinc: () => sinc,
5135
5567
  sinh: () => sinh,
5136
5568
  size: () => size,
5137
5569
  sort: () => sort,
5570
+ split: () => split$1,
5138
5571
  sqrt: () => sqrt,
5139
5572
  square: () => square,
5140
5573
  squeeze: () => squeeze,
@@ -5142,6 +5575,7 @@ __export(numpy_exports, {
5142
5575
  std: () => std,
5143
5576
  subtract: () => subtract,
5144
5577
  sum: () => sum,
5578
+ take: () => take,
5145
5579
  tan: () => tan,
5146
5580
  tanh: () => tanh,
5147
5581
  tensordot: () => tensordot,
@@ -5400,6 +5834,45 @@ function flip(x, axis = null) {
5400
5834
  return flip$1(x, axis);
5401
5835
  }
5402
5836
  /**
5837
+ * Split an array into multiple sub-arrays along an axis.
5838
+ *
5839
+ * @param a - The input array to split.
5840
+ * @param indicesOrSections - If an integer, it indicates the number of equal
5841
+ * sections to create along the specified axis. If a list of integers, it
5842
+ * specifies the indices at which to split the array.
5843
+ * @param axis - The axis along which to split the array. Default is 0.
5844
+ */
5845
+ function split$1(a, indicesOrSections, axis = 0) {
5846
+ a = fudgeArray(a);
5847
+ axis = checkAxis(axis, a.ndim);
5848
+ const size$1 = a.shape[axis];
5849
+ let sizes;
5850
+ if (typeof indicesOrSections === "number") {
5851
+ if (size$1 % indicesOrSections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${indicesOrSections} equal parts`);
5852
+ const partSize = size$1 / indicesOrSections;
5853
+ sizes = rep(indicesOrSections, partSize);
5854
+ } else {
5855
+ const indices = indicesOrSections;
5856
+ sizes = [indices[0]];
5857
+ for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
5858
+ sizes.push(size$1 - indices[indices.length - 1]);
5859
+ }
5860
+ const results = [];
5861
+ for (let i = 0; i < sizes.length; i += 7) if (i === sizes.length) {
5862
+ results.push(a);
5863
+ break;
5864
+ } else if (i + 8 >= sizes.length) {
5865
+ results.push(...split$2(a, axis, sizes.slice(i)));
5866
+ break;
5867
+ } else {
5868
+ const groupSizes = [...sizes.slice(i, i + 7), sizes.slice(i + 7).reduce((x, y) => x + y, 0)];
5869
+ const outs = split$2(a, axis, groupSizes);
5870
+ results.push(...outs.slice(0, -1));
5871
+ a = outs[outs.length - 1];
5872
+ }
5873
+ return results;
5874
+ }
5875
+ /**
5403
5876
  * Join a sequence of arrays along an existing axis.
5404
5877
  *
5405
5878
  * The arrays must have the same shape, except in the dimension corresponding to
@@ -5411,13 +5884,11 @@ function concatenate(xs, axis = 0) {
5411
5884
  if (xs.length === 0) throw new Error("Need at least one array to concatenate");
5412
5885
  const shapes = xs.map(shape);
5413
5886
  axis = checkAxis(axis, shapes[0].length);
5414
- for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays with shapes ${JSON.stringify(shapes)} along axis ${axis}`);
5415
- const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
5887
+ for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
5416
5888
  let result = xs[0];
5417
- for (let i = 1; i < xs.length; i++) {
5418
- const len1 = result.shape[axis];
5419
- const len2 = shapes[i][axis];
5420
- result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
5889
+ for (let i = 1; i < xs.length; i += 7) {
5890
+ const group = xs.slice(i, i + 7);
5891
+ result = concatenate$1([result, ...group], axis);
5421
5892
  }
5422
5893
  return result;
5423
5894
  }
@@ -5669,6 +6140,20 @@ function sort(a, axis = -1) {
5669
6140
  function argsort(a, axis = -1) {
5670
6141
  return fudgeArray(a).argsort(axis);
5671
6142
  }
6143
+ /**
6144
+ * Take elements from an array along an axis.
6145
+ *
6146
+ * This is equivalent to advanced indexing with integer indices over that
6147
+ * numbered axis. By default, the flattened array is used.
6148
+ */
6149
+ function take(a, indices, axis = null) {
6150
+ if (axis === null) {
6151
+ a = ravel(a);
6152
+ axis = 0;
6153
+ }
6154
+ axis = checkAxis(axis, ndim(a));
6155
+ return gather(a, [indices], [axis], axis);
6156
+ }
5672
6157
  /** Return if two arrays are element-wise equal within a tolerance. */
5673
6158
  function allclose(actual, expected, options) {
5674
6159
  const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
@@ -5988,6 +6473,20 @@ function tan(x) {
5988
6473
  x = fudgeArray(x);
5989
6474
  return sin(x.ref).div(cos(x));
5990
6475
  }
6476
+ /**
6477
+ * @function
6478
+ * Return the normalized sinc function.
6479
+ *
6480
+ * The sinc function is defined as `sin(πx) / (πx)` for `x != 0`, and `1` for `x = 0`.
6481
+ * This is the normalized sinc function commonly used in signal processing.
6482
+ *
6483
+ * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
6484
+ * requires a custom JVP rule to handle properly (see JAX implementation).
6485
+ */
6486
+ const sinc = jit$1(function sinc$1(x) {
6487
+ const pix = x.ref.mul(Math.PI);
6488
+ return where(equal(x, 0), 1, sin(pix.ref).div(pix));
6489
+ });
5991
6490
  /** Element-wise inverse cosine function (inverse of cos). */
5992
6491
  function acos(x) {
5993
6492
  return subtract(pi / 2, asin(x));
@@ -6040,6 +6539,25 @@ function trueDivide(x, y) {
6040
6539
  return x.div(y);
6041
6540
  }
6042
6541
  /**
6542
+ * Return the largest integer smaller or equal to the division of the inputs.
6543
+ *
6544
+ * The result is always rounded towards negative infinity.
6545
+ *
6546
+ * For floating-point inputs, this is equivalent to `floor(x / y)`.
6547
+ * For integer inputs, we use `(x - remainder(x, y)) / y` to handle
6548
+ * negative values correctly (note: may overflow near int32 boundaries).
6549
+ *
6550
+ * @param x - Dividend array.
6551
+ * @param y - Divisor array.
6552
+ * @returns Element-wise floor division of x by y.
6553
+ */
6554
+ function floorDivide(x, y) {
6555
+ x = fudgeArray(x);
6556
+ y = fudgeArray(y);
6557
+ if (isFloatDtype(x.dtype) || isFloatDtype(y.dtype)) return floor(trueDivide(x, y));
6558
+ return subtract(x, remainder(x.ref, y.ref)).div(y);
6559
+ }
6560
+ /**
6043
6561
  * @function
6044
6562
  * Calculate element-wise floating-point modulo operation.
6045
6563
  */
@@ -6053,6 +6571,20 @@ const fmod = jit$1(function fmod$1(x, y) {
6053
6571
  const remainder = jit$1(function remainder$1(x, y) {
6054
6572
  return mod(mod(x, y.ref).add(y.ref), y);
6055
6573
  });
6574
+ /**
6575
+ * Return element-wise quotient and remainder simultaneously.
6576
+ *
6577
+ * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
6578
+ *
6579
+ * @param x - Dividend array.
6580
+ * @param y - Divisor array.
6581
+ * @returns Tuple of [quotient, remainder].
6582
+ */
6583
+ function divmod(x, y) {
6584
+ const xArr = fudgeArray(x);
6585
+ const yArr = fudgeArray(y);
6586
+ return [floorDivide(xArr.ref, yArr.ref), remainder(xArr, yArr)];
6587
+ }
6056
6588
  /** Round input to the nearest integer towards zero. */
6057
6589
  function trunc(x) {
6058
6590
  return idiv(x, 1);
@@ -6216,14 +6748,15 @@ function std(x, axis = null, opts) {
6216
6748
  return sqrt(var_(x, axis, opts));
6217
6749
  }
6218
6750
  /** Estimate the sample covariance of a set of variables. */
6219
- function cov(x, y) {
6751
+ function cov(x, y = null, { rowvar = true } = {}) {
6220
6752
  x = fudgeArray(x);
6221
6753
  if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
6222
- if (y !== void 0) {
6754
+ if (y !== null) {
6223
6755
  y = fudgeArray(y);
6224
6756
  if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
6225
6757
  x = vstack([x, y]);
6226
6758
  }
6759
+ if (!rowvar) x = x.transpose();
6227
6760
  const [_M, N] = x.shape;
6228
6761
  x = x.ref.sub(x.mean(1, { keepdims: true }));
6229
6762
  return dot$1(x.ref, x.transpose()).div(N - 1);
@@ -6268,7 +6801,8 @@ const isfinite = jit$1(function isfinite$1(x) {
6268
6801
  //#region src/library/lax-linalg.ts
6269
6802
  var lax_linalg_exports = {};
6270
6803
  __export(lax_linalg_exports, {
6271
- cholesky: () => cholesky,
6804
+ cholesky: () => cholesky$1,
6805
+ lu: () => lu,
6272
6806
  triangularSolve: () => triangularSolve
6273
6807
  });
6274
6808
  /**
@@ -6297,11 +6831,39 @@ __export(lax_linalg_exports, {
6297
6831
  * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
6298
6832
  * ```
6299
6833
  */
6300
- function cholesky(a, { upper = false } = {}) {
6834
+ function cholesky$1(a, { upper = false } = {}) {
6301
6835
  const L = cholesky$2(a);
6302
6836
  return upper ? moveaxis$1(L, -2, -1) : L;
6303
6837
  }
6304
6838
  /**
6839
+ * LU decomposition with partial pivoting.
6840
+ *
6841
+ * Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
6842
+ * permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
6843
+ * and `U` is upper-triangular.
6844
+ *
6845
+ * @param x - A batch of matrices with shape `[..., m, n]`.
6846
+ *
6847
+ * @returns A tuple `(lu, pivots, permutation)` where:
6848
+ * - `lu`: combined lower and upper triangular matrices.
6849
+ * - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
6850
+ * - `permutation`: the permutation generated by pivots with shape `[..., m]`.
6851
+ *
6852
+ * @example
6853
+ * ```ts
6854
+ * import { lax, numpy as np } from "@jax-js/jax";
6855
+ *
6856
+ * const A = np.array([[4., 3.], [6., 3.]]);
6857
+ * const [lu, pivots, permutation] = lax.linalg.lu(A);
6858
+ * // lu ≈ [[6., 3.], [0.6666667, 1.0]]
6859
+ * // pivots = [1, 1]
6860
+ * // permutation = [1, 0]
6861
+ * ```
6862
+ */
6863
+ function lu(x) {
6864
+ return lu$1(x);
6865
+ }
6866
+ /**
6305
6867
  * Solve a triangular linear system.
6306
6868
  *
6307
6869
  * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
@@ -6844,33 +7406,41 @@ __export(random_exports, {
6844
7406
  gumbel: () => gumbel,
6845
7407
  key: () => key,
6846
7408
  laplace: () => laplace,
7409
+ multivariateNormal: () => multivariateNormal,
6847
7410
  normal: () => normal,
6848
7411
  split: () => split,
6849
7412
  uniform: () => uniform
6850
7413
  });
6851
- function validateKeyShape(key$1) {
7414
+ function validateKeyShape(key$1, scalar = false) {
6852
7415
  if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
6853
7416
  if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
7417
+ if (scalar && key$1.shape.length > 1) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(key$1.shape)} - use jax.vmap for batching.`);
6854
7418
  return key$1.shape.slice(0, -1);
6855
7419
  }
7420
+ function getK01(key$1) {
7421
+ const keyShape = validateKeyShape(key$1, true);
7422
+ let [k0, k1] = split$2(key$1, -1, [1, 1]);
7423
+ k0 = k0.reshape(keyShape);
7424
+ k1 = k1.reshape(keyShape);
7425
+ return [k0, k1];
7426
+ }
6856
7427
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
6857
7428
  function key(seed) {
6858
- seed = seed >>> 0;
6859
- return array([0, seed], { dtype: DType.Uint32 });
7429
+ seed = array(seed, { dtype: DType.Uint32 });
7430
+ if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
7431
+ return stack([0, seed]);
6860
7432
  }
6861
7433
  /** Splits a PRNG key into `num` new keys by adding a leading axis. */
6862
7434
  function split(key$1, num = 2) {
6863
7435
  const shape$1 = typeof num === "number" ? [num] : num;
6864
7436
  for (const len of shape$1) if (len <= 0 || !Number.isInteger(len)) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
6865
- const keyShape = validateKeyShape(key$1);
6866
- const k0 = key$1.ref.slice(...keyShape.map(() => null), 0);
6867
- const k1 = key$1.slice(...keyShape.map(() => null), 1);
7437
+ const [k0, k1] = getK01(key$1);
6868
7438
  return stack([randomBits(k0.ref, k1.ref, shape$1, 0), randomBits(k0, k1, shape$1, 1)], -1);
6869
7439
  }
6870
7440
  /** Sample uniform bits in the form of unsigned integers. */
6871
7441
  function bits(key$1, shape$1 = []) {
6872
- const keyShape = validateKeyShape(key$1);
6873
- return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
7442
+ const [k0, k1] = getK01(key$1);
7443
+ return randomBits(k0, k1, shape$1);
6874
7444
  }
6875
7445
  /**
6876
7446
  * @function
@@ -6944,6 +7514,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
6944
7514
  }, { staticArgnums: [1] });
6945
7515
  /**
6946
7516
  * @function
7517
+ * Sample multivariate normal random values with given mean and covariance.
7518
+ *
7519
+ * The values are returned with the given shape, along with the final dimension
7520
+ * used to represent the n-dimensional multivariate normal factors.
7521
+ *
7522
+ * This uses Cholesky decomposition on the covariance matrix.
7523
+ *
7524
+ * - `key` - PRNG key
7525
+ * - `mean` - Mean vector of shape `[..., n]`
7526
+ * - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
7527
+ * - `shape` - Result batch shape, must be broadcastable with
7528
+ * `mean.shape[:-1]` and `cov.shape[:-2]`
7529
+ * @returns Random samples of shape `[...shape, n]`
7530
+ */
7531
+ const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
7532
+ mean$1 = fudgeArray(mean$1);
7533
+ cov$1 = fudgeArray(cov$1);
7534
+ const n = mean$1.shape[mean$1.ndim - 1];
7535
+ if (cov$1.shape[cov$1.ndim - 1] !== n || cov$1.shape[cov$1.ndim - 2] !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
7536
+ const outputShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2)).concat(n);
7537
+ const L = cholesky(cov$1);
7538
+ const z = normal(key$1, outputShape);
7539
+ return einsum("...ij,...j->...i", L, z).add(mean$1);
7540
+ }, { staticArgnums: [3] });
7541
+ /**
7542
+ * @function
6947
7543
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
6948
7544
  *
6949
7545
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and