@jax-js/jax 0.0.3 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
30
30
  }) : target, mod));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-D2C4MJRP.cjs');
33
+ const require_backend = require('./backend-Ss1Mev_-.cjs');
34
34
 
35
35
  //#region src/tree.ts
36
36
  var tree_exports = {};
@@ -354,6 +354,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
354
354
  Primitive$1["RandomBits"] = "random_bits";
355
355
  Primitive$1["Sin"] = "sin";
356
356
  Primitive$1["Cos"] = "cos";
357
+ Primitive$1["Asin"] = "asin";
358
+ Primitive$1["Atan"] = "atan";
357
359
  Primitive$1["Exp"] = "exp";
358
360
  Primitive$1["Log"] = "log";
359
361
  Primitive$1["Sqrt"] = "sqrt";
@@ -421,6 +423,12 @@ function sin$1(x) {
421
423
  function cos$1(x) {
422
424
  return bind1(Primitive.Cos, [x]);
423
425
  }
426
+ function asin$1(x) {
427
+ return bind1(Primitive.Asin, [x]);
428
+ }
429
+ function atan$1(x) {
430
+ return bind1(Primitive.Atan, [x]);
431
+ }
424
432
  function exp$1(x) {
425
433
  return bind1(Primitive.Exp, [x]);
426
434
  }
@@ -436,18 +444,16 @@ function min$1(x, y) {
436
444
  function max$1(x, y) {
437
445
  return bind1(Primitive.Max, [x, y]);
438
446
  }
439
- function reduce(x, op, axis, opts) {
447
+ function reduce(x, op, axis = null, opts) {
440
448
  if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
441
- if (axis === void 0) if (x instanceof Tracer) axis = require_backend.range(x.shape.length);
442
- else axis = [];
443
- else if (typeof axis === "number") axis = [require_backend.checkAxis(axis, ndim$1(x))];
444
- else axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
449
+ axis = require_backend.normalizeAxis(axis, ndim$1(x));
445
450
  const originalShape = getShape(x);
446
- const result = bind1(Primitive.Reduce, [x], {
451
+ let result = bind1(Primitive.Reduce, [x], {
447
452
  op,
448
453
  axis
449
454
  });
450
- return opts?.keepDims ? broadcast(result, originalShape, axis) : result;
455
+ if (opts?.keepdims) result = result.reshape(originalShape.map((dim, i) => axis.includes(i) ? 1 : dim));
456
+ return result;
451
457
  }
452
458
  function dot$1(x, y) {
453
459
  return bind1(Primitive.Dot, [x, y]);
@@ -493,10 +499,11 @@ function where$1(cond, x, y) {
493
499
  }
494
500
  function transpose$1(x, perm) {
495
501
  perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
502
+ if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
496
503
  return bind1(Primitive.Transpose, [x], { perm });
497
504
  }
498
505
  function broadcast(x, shape$1, axis) {
499
- axis = axis.map((a) => require_backend.checkAxis(a, shape$1.length));
506
+ axis = require_backend.normalizeAxis(axis, shape$1.length);
500
507
  return bind1(Primitive.Broadcast, [x], {
501
508
  shape: shape$1,
502
509
  axis
@@ -515,7 +522,7 @@ function reshape$1(x, shape$1) {
515
522
  return bind1(Primitive.Reshape, [x], { shape: shape$1 });
516
523
  }
517
524
  function flip$1(x, axis) {
518
- axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
525
+ axis = require_backend.normalizeAxis(axis, ndim$1(x));
519
526
  return bind1(Primitive.Flip, [x], { axis });
520
527
  }
521
528
  function shrink(x, slice) {
@@ -595,15 +602,19 @@ var Tracer = class Tracer {
595
602
  constructor(trace) {
596
603
  this._trace = trace;
597
604
  }
605
+ /** The shape of the array. */
598
606
  get shape() {
599
607
  return this.aval.shape;
600
608
  }
609
+ /** The total number of elements in the array. */
601
610
  get size() {
602
611
  return require_backend.prod(this.shape);
603
612
  }
613
+ /** The dtype of the array. */
604
614
  get dtype() {
605
615
  return this.aval.dtype;
606
616
  }
617
+ /** The number of dimensions of the array. */
607
618
  get ndim() {
608
619
  return this.shape.length;
609
620
  }
@@ -639,22 +650,20 @@ var Tracer = class Tracer {
639
650
  return lessEqual$1(this, other);
640
651
  }
641
652
  /** Sum of the elements of the array over a given axis, or axes. */
642
- sum(axis, opts) {
653
+ sum(axis = null, opts) {
643
654
  return reduce(this, require_backend.AluOp.Add, axis, opts);
644
655
  }
645
656
  /** Product of the array elements over a given axis. */
646
- prod(axis, opts) {
657
+ prod(axis = null, opts) {
647
658
  return reduce(this, require_backend.AluOp.Mul, axis, opts);
648
659
  }
649
660
  /** Compute the average of the array elements along the specified axis. */
650
- mean(axis, opts) {
651
- if (axis === void 0) axis = require_backend.range(this.ndim);
652
- else if (typeof axis === "number") axis = [require_backend.checkAxis(axis, this.ndim)];
653
- else axis = axis.map((a) => require_backend.checkAxis(a, this.ndim));
654
- let result = reduce(this, require_backend.AluOp.Add, axis);
655
- result = result.mul(result.size / this.size);
656
- if (opts?.keepDims) result = broadcast(result, this.shape, axis);
657
- return result;
661
+ mean(axis = null, opts) {
662
+ axis = require_backend.normalizeAxis(axis, this.ndim);
663
+ const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
664
+ if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
665
+ const result = reduce(this, require_backend.AluOp.Add, axis, opts);
666
+ return result.mul(1 / n);
658
667
  }
659
668
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
660
669
  transpose(perm) {
@@ -1187,6 +1196,8 @@ const jitRules = {
1187
1196
  },
1188
1197
  [Primitive.Sin]: unopJit(require_backend.AluExp.sin),
1189
1198
  [Primitive.Cos]: unopJit(require_backend.AluExp.cos),
1199
+ [Primitive.Asin]: unopJit(require_backend.AluExp.asin),
1200
+ [Primitive.Atan]: unopJit(require_backend.AluExp.atan),
1190
1201
  [Primitive.Exp]: unopJit(require_backend.AluExp.exp),
1191
1202
  [Primitive.Log]: unopJit(require_backend.AluExp.log),
1192
1203
  [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
@@ -1428,7 +1439,7 @@ var Array$1 = class Array$1 extends Tracer {
1428
1439
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1429
1440
  * will be freed when the array is disposed.
1430
1441
  */
1431
- constructor(source, st, dtype, backend, pending = null) {
1442
+ constructor(source, st, dtype, backend, { pending = null } = {}) {
1432
1443
  super(baseArrayTrace);
1433
1444
  this.id = Array$1.#nextId++;
1434
1445
  this.#dtype = dtype;
@@ -1437,6 +1448,8 @@ var Array$1 = class Array$1 extends Tracer {
1437
1448
  this.#backend = backend;
1438
1449
  this.#rc = 1;
1439
1450
  this.#pendingSet = new Set(pending);
1451
+ if (this.#pendingSet.size === 0) this.#pendingSet = null;
1452
+ else if (source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1440
1453
  }
1441
1454
  /** @ignore */
1442
1455
  get aval() {
@@ -1491,7 +1504,7 @@ var Array$1 = class Array$1 extends Tracer {
1491
1504
  const pending = this.#pending;
1492
1505
  for (const exe of pending) exe.updateRc(1);
1493
1506
  if (typeof this.#source === "number") this.#backend.incRef(this.#source);
1494
- const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, pending);
1507
+ const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
1495
1508
  this.dispose();
1496
1509
  return ar;
1497
1510
  }
@@ -1540,7 +1553,7 @@ var Array$1 = class Array$1 extends Tracer {
1540
1553
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1541
1554
  this.dispose();
1542
1555
  for (const ar of indices) ar.dispose();
1543
- return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, pending);
1556
+ return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
1544
1557
  }
1545
1558
  /** Move axes to the rightmost dimension of the shape. */
1546
1559
  #moveAxesDown(axis) {
@@ -1577,7 +1590,7 @@ var Array$1 = class Array$1 extends Tracer {
1577
1590
  for (const exe of pending) exe.updateRc(1);
1578
1591
  pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
1579
1592
  this.dispose();
1580
- return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, pending);
1593
+ return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
1581
1594
  }
1582
1595
  #binary(op, other) {
1583
1596
  const custom = (src) => new require_backend.AluExp(op, this.#dtype, src);
@@ -1642,7 +1655,7 @@ var Array$1 = class Array$1 extends Tracer {
1642
1655
  for (const exe of pending) exe.updateRc(1);
1643
1656
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1644
1657
  for (const ar of arrays) ar.dispose();
1645
- return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, pending);
1658
+ return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
1646
1659
  }
1647
1660
  /** Reduce the last dimension of the array by an operation. */
1648
1661
  #reduce(op) {
@@ -1666,7 +1679,7 @@ var Array$1 = class Array$1 extends Tracer {
1666
1679
  for (const exe of pending) exe.updateRc(1);
1667
1680
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1668
1681
  this.dispose();
1669
- return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, pending);
1682
+ return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
1670
1683
  }
1671
1684
  /**
1672
1685
  * Normalizes this array into one backed by a `Slot`.
@@ -1739,8 +1752,11 @@ var Array$1 = class Array$1 extends Tracer {
1739
1752
  *
1740
1753
  * If you are mapping from `data()` or `dataSync()`, it will also trigger
1741
1754
  * dispatch of operations as well.
1755
+ *
1756
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1757
+ * asynchronously for multiple arrays.
1742
1758
  */
1743
- async wait() {
1759
+ async blockUntilReady() {
1744
1760
  this.#check();
1745
1761
  if (this.#source instanceof require_backend.AluExp) return this;
1746
1762
  const pending = this.#pending;
@@ -1806,7 +1822,7 @@ var Array$1 = class Array$1 extends Tracer {
1806
1822
  return [x.#binary(require_backend.AluOp.Idiv, y)];
1807
1823
  },
1808
1824
  [Primitive.Neg]([x]) {
1809
- return [zerosLike(x.ref).#binary(require_backend.AluOp.Sub, x)];
1825
+ return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
1810
1826
  },
1811
1827
  [Primitive.Reciprocal]([x]) {
1812
1828
  return [x.#unary(require_backend.AluOp.Reciprocal)];
@@ -1826,7 +1842,7 @@ var Array$1 = class Array$1 extends Tracer {
1826
1842
  x.#backend.incRef(x.#source);
1827
1843
  const pending = x.#pending;
1828
1844
  for (const exe of pending) exe.updateRc(1);
1829
- const y = new Array$1(x.#source, x.#st, dtype, x.#backend, pending);
1845
+ const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
1830
1846
  x.dispose();
1831
1847
  return [y];
1832
1848
  }
@@ -1856,6 +1872,12 @@ var Array$1 = class Array$1 extends Tracer {
1856
1872
  [Primitive.Cos]([x]) {
1857
1873
  return [x.#unary(require_backend.AluOp.Cos)];
1858
1874
  },
1875
+ [Primitive.Asin]([x]) {
1876
+ return [x.#unary(require_backend.AluOp.Asin)];
1877
+ },
1878
+ [Primitive.Atan]([x]) {
1879
+ return [x.#unary(require_backend.AluOp.Atan)];
1880
+ },
1859
1881
  [Primitive.Exp]([x]) {
1860
1882
  return [x.#unary(require_backend.AluOp.Exp)];
1861
1883
  },
@@ -1941,7 +1963,7 @@ var Array$1 = class Array$1 extends Tracer {
1941
1963
  pending.splice(0, 0, ...prevPending);
1942
1964
  args.forEach((x) => x.dispose());
1943
1965
  return outputs.map((source, i) => {
1944
- return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, pending);
1966
+ return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
1945
1967
  });
1946
1968
  }
1947
1969
  };
@@ -2073,12 +2095,12 @@ var EvalTrace = class extends Trace {
2073
2095
  };
2074
2096
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
2075
2097
  const implRules = Array$1._implRules();
2076
- function zerosLike(val, dtype) {
2098
+ function zerosLike$1(val, dtype) {
2077
2099
  const aval = getAval(val);
2078
2100
  if (val instanceof Tracer) val.dispose();
2079
2101
  return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2080
2102
  }
2081
- function onesLike(val, dtype) {
2103
+ function onesLike$1(val, dtype) {
2082
2104
  const aval = getAval(val);
2083
2105
  if (val instanceof Tracer) val.dispose();
2084
2106
  return ones(aval.shape, { dtype: dtype ?? aval.dtype });
@@ -2141,7 +2163,7 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2141
2163
  const exp$2 = require_backend.AluExp.cmplt(require_backend.AluExp.mod(require_backend.AluVar.idx, require_backend.AluExp.i32(numCols + 1)), require_backend.AluExp.i32(1));
2142
2164
  return new Array$1(require_backend.AluExp.cast(dtype, exp$2), require_backend.ShapeTracker.fromShape([numRows, numCols]), dtype, require_backend.getBackend(device));
2143
2165
  }
2144
- /** Return the identity array, with ones on the main diagonal. */
2166
+ /** Return the identity matrix, with ones on the main diagonal. */
2145
2167
  function identity$1(n, { dtype, device } = {}) {
2146
2168
  return eye(n, n, {
2147
2169
  dtype,
@@ -2421,16 +2443,19 @@ var Jaxpr = class Jaxpr {
2421
2443
  varIds.set(v, require_backend.FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
2422
2444
  return id;
2423
2445
  };
2424
- hasher.update(this.inBinders.length, ...this.inBinders.map(vi));
2425
- hasher.update(this.eqns.length, ...this.eqns.flatMap((eqn) => [
2426
- eqn.primitive,
2427
- eqn.inputs.length,
2428
- ...eqn.inputs.map((x) => x instanceof Var ? vi(x) : x.value),
2429
- JSON.stringify(eqn.params),
2430
- eqn.outBinders.length,
2431
- ...eqn.outBinders.map(vi)
2432
- ]));
2433
- hasher.update(this.outs.length, ...this.outs.map((x) => x instanceof Var ? vi(x) : x.value));
2446
+ hasher.update(this.inBinders.length);
2447
+ for (const x of this.inBinders) hasher.update(vi(x));
2448
+ hasher.update(this.eqns.length);
2449
+ for (const eqn of this.eqns) {
2450
+ hasher.update(eqn.primitive);
2451
+ hasher.update(eqn.inputs.length);
2452
+ for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
2453
+ hasher.update(JSON.stringify(eqn.params));
2454
+ hasher.update(eqn.outBinders.length);
2455
+ for (const x of eqn.outBinders) hasher.update(vi(x));
2456
+ }
2457
+ hasher.update(this.outs.length);
2458
+ for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
2434
2459
  return this.#hash = hasher.value;
2435
2460
  }
2436
2461
  hash(state) {
@@ -2467,7 +2492,7 @@ var Jaxpr = class Jaxpr {
2467
2492
  const c = eqn.outBinders[0];
2468
2493
  if (atomIsLit(b, 1)) context.set(c, a);
2469
2494
  else newEqns.push(eqn);
2470
- } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2495
+ } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2471
2496
  else newEqns.push(eqn);
2472
2497
  }
2473
2498
  const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
@@ -2733,6 +2758,8 @@ const abstractEvalRules = {
2733
2758
  },
2734
2759
  [Primitive.Sin]: vectorizedUnopAbstractEval,
2735
2760
  [Primitive.Cos]: vectorizedUnopAbstractEval,
2761
+ [Primitive.Asin]: vectorizedUnopAbstractEval,
2762
+ [Primitive.Atan]: vectorizedUnopAbstractEval,
2736
2763
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2737
2764
  [Primitive.Log]: vectorizedUnopAbstractEval,
2738
2765
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
@@ -2860,7 +2887,7 @@ function makeJaxpr$1(f, opts) {
2860
2887
  function jit$1(f, opts) {
2861
2888
  const cache = /* @__PURE__ */ new Map();
2862
2889
  const staticArgnums = new Set(opts?.staticArgnums ?? []);
2863
- return ((...args) => {
2890
+ const result = ((...args) => {
2864
2891
  const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
2865
2892
  const [argsFlat, inTree] = flatten(dynamicArgs);
2866
2893
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
@@ -2874,6 +2901,10 @@ function jit$1(f, opts) {
2874
2901
  });
2875
2902
  return unflatten(outTree, outs);
2876
2903
  });
2904
+ result.dispose = () => {
2905
+ for (const { consts } of cache.values()) for (const c of consts) c.dispose();
2906
+ };
2907
+ return result;
2877
2908
  }
2878
2909
 
2879
2910
  //#endregion
@@ -2905,7 +2936,7 @@ var JVPTrace = class extends Trace {
2905
2936
  return this.lift(pureArray(val));
2906
2937
  }
2907
2938
  lift(val) {
2908
- return new JVPTracer(this, val, zerosLike(val.ref));
2939
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
2909
2940
  }
2910
2941
  processPrimitive(primitive, tracers, params) {
2911
2942
  const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
@@ -2936,7 +2967,7 @@ function zeroTangentsJvp(primitive) {
2936
2967
  return (primals, tangents, params) => {
2937
2968
  for (const t of tangents) t.dispose();
2938
2969
  const ys = bind(primitive, primals, params);
2939
- return [ys, ys.map((y) => zerosLike(y.ref))];
2970
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
2940
2971
  };
2941
2972
  }
2942
2973
  const jvpRules = {
@@ -2954,13 +2985,13 @@ const jvpRules = {
2954
2985
  if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
2955
2986
  else {
2956
2987
  dx.dispose();
2957
- return [[cast(x.ref, dtype)], [zerosLike(x)]];
2988
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
2958
2989
  }
2959
2990
  },
2960
2991
  [Primitive.Bitcast]([x], [dx], { dtype }) {
2961
2992
  if (x.dtype === dtype) return [[x], [dx]];
2962
2993
  dx.dispose();
2963
- return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
2994
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
2964
2995
  },
2965
2996
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
2966
2997
  [Primitive.Sin]([x], [dx]) {
@@ -2969,6 +3000,14 @@ const jvpRules = {
2969
3000
  [Primitive.Cos]([x], [dx]) {
2970
3001
  return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
2971
3002
  },
3003
+ [Primitive.Asin]([x], [dx]) {
3004
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3005
+ return [[asin$1(x)], [denom.mul(dx)]];
3006
+ },
3007
+ [Primitive.Atan]([x], [dx]) {
3008
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3009
+ return [[atan$1(x)], [dx.div(denom)]];
3010
+ },
2972
3011
  [Primitive.Exp]([x], [dx]) {
2973
3012
  const z = exp$1(x);
2974
3013
  return [[z.ref], [z.mul(dx)]];
@@ -3085,7 +3124,10 @@ function mappedAval(batchDim, aval) {
3085
3124
  /** Move one axis to a different index. */
3086
3125
  function moveaxis$1(x, src, dst) {
3087
3126
  const t = pureArray(x);
3088
- const perm = require_backend.range(t.shape.length);
3127
+ src = require_backend.checkAxis(src, t.ndim);
3128
+ dst = require_backend.checkAxis(dst, t.ndim);
3129
+ if (src === dst) return t;
3130
+ const perm = require_backend.range(t.ndim);
3089
3131
  perm.splice(src, 1);
3090
3132
  perm.splice(dst, 0, src);
3091
3133
  return transpose$1(t, perm);
@@ -3178,6 +3220,8 @@ const vmapRules = {
3178
3220
  [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3179
3221
  [Primitive.Sin]: unopBatcher(sin$1),
3180
3222
  [Primitive.Cos]: unopBatcher(cos$1),
3223
+ [Primitive.Asin]: unopBatcher(asin$1),
3224
+ [Primitive.Atan]: unopBatcher(atan$1),
3181
3225
  [Primitive.Exp]: unopBatcher(exp$1),
3182
3226
  [Primitive.Log]: unopBatcher(log$1),
3183
3227
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
@@ -3363,20 +3407,28 @@ function linearizeFlatUtil(f, primalsIn) {
3363
3407
  function linearizeFlat(f, primalsIn) {
3364
3408
  const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
3365
3409
  const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3366
- return [primalsOut, fLin];
3410
+ const dispose$1 = () => {
3411
+ for (const c of consts) c.dispose();
3412
+ };
3413
+ return [
3414
+ primalsOut,
3415
+ fLin,
3416
+ dispose$1
3417
+ ];
3367
3418
  }
3368
3419
  function linearize$1(f, ...primalsIn) {
3369
3420
  const [primalsInFlat, inTree] = flatten(primalsIn);
3370
3421
  const [fFlat, outTree] = flattenFun(f, inTree);
3371
- const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3422
+ const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3372
3423
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
3373
3424
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3374
- const fLin = (...tangentsIn) => {
3425
+ const fLin = ((...tangentsIn) => {
3375
3426
  const [tangentsInFlat, inTree2] = flatten(tangentsIn);
3376
3427
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
3377
3428
  const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
3378
3429
  return unflatten(outTree.value, tangentsOutFlat);
3379
- };
3430
+ });
3431
+ fLin.dispose = dispose$1;
3380
3432
  return [primalsOut, fLin];
3381
3433
  }
3382
3434
  var PartialEvalTracer = class extends Tracer {
@@ -3492,7 +3544,10 @@ var PartialEvalTrace = class extends Trace {
3492
3544
  avalsOut: jaxpr2.outs.map((x) => x.aval),
3493
3545
  tracerRefsOut: []
3494
3546
  };
3495
- const outs2 = jaxpr2.outs.map((x) => new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe));
3547
+ const outs2 = jaxpr2.outs.map((x, i$1) => {
3548
+ if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
3549
+ return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
3550
+ });
3496
3551
  recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
3497
3552
  let i = 0;
3498
3553
  let j = 0;
@@ -3576,13 +3631,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
3576
3631
  const [consts, constvars] = require_backend.unzip2(constToVar.entries());
3577
3632
  const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
3578
3633
  const outVars = tracersOut.map((t) => tracerToVar.get(t));
3579
- const jaxpr = new Jaxpr(inBinders, eqns, outVars);
3634
+ let jaxpr = new Jaxpr(inBinders, eqns, outVars);
3580
3635
  typecheckJaxpr(jaxpr);
3581
3636
  for (const t of consts) t.ref;
3582
3637
  for (const t of tracersIn) t.dispose();
3583
3638
  for (const t of tracersOut) t.dispose();
3639
+ jaxpr = jaxpr.simplify();
3640
+ if (require_backend.DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
3584
3641
  return {
3585
- jaxpr: jaxpr.simplify(),
3642
+ jaxpr,
3586
3643
  consts
3587
3644
  };
3588
3645
  }
@@ -3848,20 +3905,28 @@ function vjpFlat(f, primalsIn) {
3848
3905
  const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
3849
3906
  return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
3850
3907
  };
3851
- return [primalsOut, fVjp];
3908
+ const dispose$1 = () => {
3909
+ for (const c of consts) c.dispose();
3910
+ };
3911
+ return [
3912
+ primalsOut,
3913
+ fVjp,
3914
+ dispose$1
3915
+ ];
3852
3916
  }
3853
3917
  function vjp$1(f, ...primalsIn) {
3854
3918
  const [primalsInFlat, inTree] = flatten(primalsIn);
3855
3919
  const [fFlat, outTree] = flattenFun(f, inTree);
3856
- const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3920
+ const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3857
3921
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
3858
3922
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3859
- const fVjp = (cotangentsOut) => {
3923
+ const fVjp = ((cotangentsOut) => {
3860
3924
  const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
3861
3925
  if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
3862
3926
  const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
3863
3927
  return unflatten(inTree, cotangentsInFlat);
3864
- };
3928
+ });
3929
+ fVjp.dispose = dispose$1;
3865
3930
  return [primalsOut, fVjp];
3866
3931
  }
3867
3932
  function grad$1(f) {
@@ -3879,7 +3944,8 @@ function valueAndGrad$1(f) {
3879
3944
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3880
3945
  if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3881
3946
  const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
3882
- for (const r of rest) r.dispose();
3947
+ for (const r of rest) dispose(r);
3948
+ fVjp.dispose();
3883
3949
  return [y, ct];
3884
3950
  };
3885
3951
  }
@@ -3887,7 +3953,13 @@ function jacrev$1(f) {
3887
3953
  return function jacobianReverse(x) {
3888
3954
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
3889
3955
  const [size$1] = x.shape;
3890
- const pullback = (ct) => vjp$1(f, x)[1](ct)[0];
3956
+ const pullback = (ct) => {
3957
+ const [y, fVjp] = vjp$1(f, x);
3958
+ y.dispose();
3959
+ const [ret] = fVjp(ct);
3960
+ fVjp.dispose();
3961
+ return ret;
3962
+ };
3891
3963
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
3892
3964
  };
3893
3965
  }
@@ -3967,19 +4039,38 @@ __export(numpy_exports, {
3967
4039
  DType: () => require_backend.DType,
3968
4040
  abs: () => abs,
3969
4041
  absolute: () => absolute,
4042
+ acos: () => acos,
4043
+ acosh: () => acosh,
3970
4044
  add: () => add,
3971
4045
  allclose: () => allclose,
3972
4046
  arange: () => arange,
4047
+ arccos: () => arccos,
4048
+ arccosh: () => arccosh,
4049
+ arcsinh: () => arcsinh,
4050
+ arctan: () => arctan,
4051
+ arctan2: () => arctan2,
4052
+ arctanh: () => arctanh,
3973
4053
  argmax: () => argmax,
3974
4054
  argmin: () => argmin,
3975
4055
  array: () => array,
4056
+ asin: () => asin,
4057
+ asinh: () => asinh,
3976
4058
  astype: () => astype,
4059
+ atan: () => atan,
4060
+ atan2: () => atan2,
4061
+ atanh: () => atanh,
3977
4062
  bool: () => bool,
4063
+ broadcastArrays: () => broadcastArrays,
4064
+ broadcastShapes: () => broadcastShapes,
4065
+ broadcastTo: () => broadcastTo,
4066
+ cbrt: () => cbrt,
3978
4067
  clip: () => clip,
3979
4068
  columnStack: () => columnStack,
3980
4069
  concatenate: () => concatenate,
3981
4070
  cos: () => cos,
3982
4071
  cosh: () => cosh,
4072
+ deg2rad: () => deg2rad,
4073
+ degrees: () => degrees,
3983
4074
  diag: () => diag,
3984
4075
  diagonal: () => diagonal,
3985
4076
  divide: () => divide,
@@ -3990,6 +4081,7 @@ __export(numpy_exports, {
3990
4081
  eulerGamma: () => eulerGamma,
3991
4082
  exp: () => exp,
3992
4083
  exp2: () => exp2,
4084
+ expm1: () => expm1,
3993
4085
  eye: () => eye,
3994
4086
  flip: () => flip,
3995
4087
  fliplr: () => fliplr,
@@ -4001,14 +4093,17 @@ __export(numpy_exports, {
4001
4093
  greater: () => greater,
4002
4094
  greaterEqual: () => greaterEqual,
4003
4095
  hstack: () => hstack,
4096
+ hypot: () => hypot,
4004
4097
  identity: () => identity$1,
4005
4098
  inf: () => inf,
4099
+ inner: () => inner,
4006
4100
  int32: () => int32,
4007
4101
  less: () => less,
4008
4102
  lessEqual: () => lessEqual,
4009
4103
  linspace: () => linspace,
4010
4104
  log: () => log,
4011
4105
  log10: () => log10,
4106
+ log1p: () => log1p,
4012
4107
  log2: () => log2,
4013
4108
  matmul: () => matmul,
4014
4109
  max: () => max,
@@ -4024,35 +4119,49 @@ __export(numpy_exports, {
4024
4119
  negative: () => negative,
4025
4120
  notEqual: () => notEqual,
4026
4121
  ones: () => ones,
4027
- onesLike: () => onesLike$1,
4122
+ onesLike: () => onesLike,
4123
+ outer: () => outer,
4028
4124
  pad: () => pad,
4029
4125
  permuteDims: () => permuteDims,
4030
4126
  pi: () => pi,
4127
+ pow: () => pow,
4128
+ power: () => power,
4031
4129
  prod: () => prod$1,
4130
+ promoteTypes: () => require_backend.promoteTypes,
4131
+ rad2deg: () => rad2deg,
4132
+ radians: () => radians,
4032
4133
  ravel: () => ravel,
4033
4134
  reciprocal: () => reciprocal,
4135
+ repeat: () => repeat,
4034
4136
  reshape: () => reshape,
4035
- scalar: () => scalar,
4036
4137
  shape: () => shape,
4138
+ sign: () => sign,
4037
4139
  sin: () => sin,
4038
4140
  sinh: () => sinh,
4039
4141
  size: () => size,
4040
4142
  sqrt: () => sqrt,
4041
4143
  square: () => square,
4042
4144
  stack: () => stack,
4145
+ std: () => std,
4146
+ subtract: () => subtract,
4043
4147
  sum: () => sum,
4044
4148
  tan: () => tan,
4045
4149
  tanh: () => tanh,
4150
+ tile: () => tile,
4046
4151
  transpose: () => transpose,
4152
+ tri: () => tri,
4153
+ tril: () => tril,
4154
+ triu: () => triu,
4047
4155
  trueDivide: () => trueDivide,
4048
4156
  trunc: () => trunc,
4049
4157
  uint32: () => uint32,
4158
+ var_: () => var_,
4050
4159
  vdot: () => vdot,
4051
4160
  vecdot: () => vecdot,
4052
4161
  vstack: () => vstack,
4053
4162
  where: () => where,
4054
4163
  zeros: () => zeros,
4055
- zerosLike: () => zerosLike$1
4164
+ zerosLike: () => zerosLike
4056
4165
  });
4057
4166
  const float32 = require_backend.DType.Float32;
4058
4167
  const int32 = require_backend.DType.Int32;
@@ -4069,54 +4178,66 @@ const inf = Number.POSITIVE_INFINITY;
4069
4178
  const nan = NaN;
4070
4179
  /** This is Pi, `π = 3.14159265358979...` */
4071
4180
  const pi = Math.PI;
4072
- /** Element-wise addition, with broadcasting. */
4181
+ /** @function Element-wise addition, with broadcasting. */
4073
4182
  const add = add$1;
4074
- /** Element-wise multiplication, with broadcasting. */
4183
+ /** @function Element-wise multiplication, with broadcasting. */
4075
4184
  const multiply = mul;
4076
- /** Numerical negative of every element of an array. */
4185
+ /** @function Numerical negative of every element of an array. */
4077
4186
  const negative = neg;
4078
- /** Calculate element-wise reciprocal of the input. This is `1/x`. */
4187
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
4079
4188
  const reciprocal = reciprocal$1;
4080
- /** Element-wise sine function (takes radians). */
4189
+ /** @function Element-wise sine function (takes radians). */
4081
4190
  const sin = sin$1;
4082
- /** Element-wise cosine function (takes radians). */
4191
+ /** @function Element-wise cosine function (takes radians). */
4083
4192
  const cos = cos$1;
4084
- /** Calculate the exponential of all elements in the input array. */
4193
+ /** @function Element-wise inverse sine function (inverse of sin). */
4194
+ const asin = asin$1;
4195
+ /** @function Element-wise inverse tangent function (inverse of tan). */
4196
+ const atan = atan$1;
4197
+ /** @function Calculate the exponential of all elements in the input array. */
4085
4198
  const exp = exp$1;
4086
- /** Calculate the natural logarithm of all elements in the input array. */
4199
+ /** @function Calculate the natural logarithm of all elements in the input array. */
4087
4200
  const log = log$1;
4088
- /** Calculate the square root of all elements in the input array. */
4201
+ /** @function Calculate the square root of all elements in the input array. */
4089
4202
  const sqrt = sqrt$1;
4090
- /** Return element-wise minimum of the input arrays. */
4203
+ /** @function Return element-wise minimum of the input arrays. */
4091
4204
  const minimum = min$1;
4092
- /** Return element-wise maximum of the input arrays. */
4205
+ /** @function Return element-wise maximum of the input arrays. */
4093
4206
  const maximum = max$1;
4094
- /** Compare two arrays element-wise. */
4207
+ /** @function Compare two arrays element-wise. */
4095
4208
  const greater = greater$1;
4096
- /** Compare two arrays element-wise. */
4209
+ /** @function Compare two arrays element-wise. */
4097
4210
  const less = less$1;
4098
- /** Compare two arrays element-wise. */
4211
+ /** @function Compare two arrays element-wise. */
4099
4212
  const equal = equal$1;
4100
- /** Compare two arrays element-wise. */
4213
+ /** @function Compare two arrays element-wise. */
4101
4214
  const notEqual = notEqual$1;
4102
- /** Compare two arrays element-wise. */
4215
+ /** @function Compare two arrays element-wise. */
4103
4216
  const greaterEqual = greaterEqual$1;
4104
- /** Compare two arrays element-wise. */
4217
+ /** @function Compare two arrays element-wise. */
4105
4218
  const lessEqual = lessEqual$1;
4106
- /** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
4219
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
4107
4220
  const where = where$1;
4108
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
4221
+ /**
4222
+ * @function
4223
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
4224
+ */
4109
4225
  const transpose = transpose$1;
4110
4226
  /**
4227
+ * @function
4111
4228
  * Give a new shape to an array without changing its data.
4112
4229
  *
4113
4230
  * One shape dimension can be -1. In this case, the value is inferred from the
4114
4231
  * length of the array and remaining dimensions.
4115
4232
  */
4116
4233
  const reshape = reshape$1;
4117
- /** Move axes of an array to new positions. Other axes retain original order. */
4234
+ /**
4235
+ * @function
4236
+ * Move axes of an array to new positions. Other axes retain original order.
4237
+ */
4118
4238
  const moveaxis = moveaxis$1;
4119
4239
  /**
4240
+ * @function
4120
4241
  * Add padding (zeros) to an array.
4121
4242
  *
4122
4243
  * The `width` argument is either an integer or pair of integers, in which case
@@ -4124,15 +4245,27 @@ const moveaxis = moveaxis$1;
4124
4245
  * pair specifies the padding for its corresponding axis.
4125
4246
  */
4126
4247
  const pad = pad$1;
4127
- /** Return the number of dimensions of an array. Does not consume array reference. */
4248
+ /**
4249
+ * @function
4250
+ * Return the number of dimensions of an array. Does not consume array reference.
4251
+ */
4128
4252
  const ndim = ndim$1;
4129
- /** Return the shape of an array. Does not consume array reference. */
4253
+ /** @function Return the shape of an array. Does not consume array reference. */
4130
4254
  const shape = getShape;
4131
- /** Return an array of zeros with the same shape and type as a given array. */
4132
- const zerosLike$1 = zerosLike;
4133
- /** Return an array of ones with the same shape and type as a given array. */
4134
- const onesLike$1 = onesLike;
4135
- /** Return a full array with the same shape and type as a given array. */
4255
+ /**
4256
+ * @function
4257
+ * Return an array of zeros with the same shape and type as a given array.
4258
+ */
4259
+ const zerosLike = zerosLike$1;
4260
+ /**
4261
+ * @function
4262
+ * Return an array of ones with the same shape and type as a given array.
4263
+ */
4264
+ const onesLike = onesLike$1;
4265
+ /**
4266
+ * @function
4267
+ * Return a full array with the same shape and type as a given array.
4268
+ */
4136
4269
  const fullLike$1 = fullLike;
4137
4270
  /**
4138
4271
  * Return the number of elements in an array, optionally along an axis.
@@ -4147,23 +4280,23 @@ function astype(a, dtype) {
4147
4280
  return fudgeArray(a).astype(dtype);
4148
4281
  }
4149
4282
  /** Sum of the elements of the array over a given axis, or axes. */
4150
- function sum(a, axis, opts) {
4283
+ function sum(a, axis = null, opts) {
4151
4284
  return reduce(a, require_backend.AluOp.Add, axis, opts);
4152
4285
  }
4153
4286
  /** Product of the array elements over a given axis. */
4154
- function prod$1(a, axis, opts) {
4287
+ function prod$1(a, axis = null, opts) {
4155
4288
  return reduce(a, require_backend.AluOp.Mul, axis, opts);
4156
4289
  }
4157
4290
  /** Return the minimum of array elements along a given axis. */
4158
- function min(a, axis, opts) {
4291
+ function min(a, axis = null, opts) {
4159
4292
  return reduce(a, require_backend.AluOp.Min, axis, opts);
4160
4293
  }
4161
4294
  /** Return the maximum of array elements along a given axis. */
4162
- function max(a, axis, opts) {
4295
+ function max(a, axis = null, opts) {
4163
4296
  return reduce(a, require_backend.AluOp.Max, axis, opts);
4164
4297
  }
4165
4298
  /** Compute the average of the array elements along the specified axis. */
4166
- function mean(a, axis, opts) {
4299
+ function mean(a, axis = null, opts) {
4167
4300
  return fudgeArray(a).mean(axis, opts);
4168
4301
  }
4169
4302
  /**
@@ -4179,7 +4312,7 @@ function argmin(a, axis, opts) {
4179
4312
  axis = 0;
4180
4313
  } else axis = require_backend.checkAxis(axis, a.ndim);
4181
4314
  const shape$1 = a.shape;
4182
- const isMax = equal(a, min(a.ref, axis, { keepDims: true }));
4315
+ const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
4183
4316
  const length = scalar(shape$1[axis], {
4184
4317
  dtype: int32,
4185
4318
  device: a.device
@@ -4203,7 +4336,7 @@ function argmax(a, axis, opts) {
4203
4336
  axis = 0;
4204
4337
  } else axis = require_backend.checkAxis(axis, a.ndim);
4205
4338
  const shape$1 = a.shape;
4206
- const isMax = equal(a, max(a.ref, axis, { keepDims: true }));
4339
+ const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
4207
4340
  const length = scalar(shape$1[axis], {
4208
4341
  dtype: int32,
4209
4342
  device: a.device
@@ -4215,17 +4348,9 @@ function argmax(a, axis, opts) {
4215
4348
  return length.sub(max(idx, axis, opts));
4216
4349
  }
4217
4350
  /** Reverse the elements in an array along the given axes. */
4218
- function flip(x, axis) {
4351
+ function flip(x, axis = null) {
4219
4352
  const nd = ndim(x);
4220
- if (axis === void 0) axis = require_backend.range(nd);
4221
- else if (typeof axis === "number") axis = [axis];
4222
- const seen = /* @__PURE__ */ new Set();
4223
- for (let i = 0; i < axis.length; i++) {
4224
- if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
4225
- if (axis[i] < 0) axis[i] += nd;
4226
- if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
4227
- seen.add(axis[i]);
4228
- }
4353
+ axis = require_backend.normalizeAxis(axis, nd);
4229
4354
  return flip$1(x, axis);
4230
4355
  }
4231
4356
  /**
@@ -4331,12 +4456,80 @@ function flipud(x) {
4331
4456
  function fliplr(x) {
4332
4457
  return flip(x, 1);
4333
4458
  }
4459
+ /** @function Alternative name for `numpy.transpose()`. */
4334
4460
  const permuteDims = transpose;
4335
4461
  /** Return a 1-D flattened array containing the elements of the input. */
4336
4462
  function ravel(a) {
4337
4463
  return fudgeArray(a).ravel();
4338
4464
  }
4339
4465
  /**
4466
+ * Repeat each element of an array after themselves.
4467
+ *
4468
+ * If no axis is provided, use the flattened input array, and return a flat
4469
+ * output array.
4470
+ */
4471
+ function repeat(a, repeats, axis) {
4472
+ if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
4473
+ a = fudgeArray(a);
4474
+ if (axis === void 0) {
4475
+ a = ravel(a);
4476
+ axis = 0;
4477
+ }
4478
+ axis = require_backend.checkAxis(axis, a.ndim);
4479
+ if (repeats === 1) return a;
4480
+ const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
4481
+ const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
4482
+ return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
4483
+ }
4484
/**
 * Construct an array by repeating `a` the number of times given by `reps`.
 *
 * If `a` has shape `(d1, d2, ..., dn)` and `reps` is a sequence of `n`
 * integers, the result has shape `(reps[0] * d1, reps[1] * d2, ...,
 * reps[n - 1] * dn)`, with `a` tiled along each dimension. When the lengths
 * differ, the shorter of `a.shape` / `reps` is left-padded with ones.
 *
 * @param a - Input array (anything accepted by `fudgeArray`).
 * @param reps - A non-negative integer, or an array of them, one per axis.
 * @returns The tiled array.
 * @throws {Error} If any entry of `reps` is not a non-negative integer.
 */
function tile(a, reps) {
  a = fudgeArray(a);
  if (typeof reps === "number") reps = [reps];
  if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
  const ndiff = reps.length - a.ndim;
  if (ndiff > 0) a = a.reshape([...require_backend.rep(ndiff, 1), ...a.shape]);
  if (ndiff < 0) reps = [...require_backend.rep(-ndiff, 1), ...reps];
  const broadcastedShape = [];
  const broadcastAxes = [];
  for (let i = 0; i < a.ndim; i++) {
    // Only a count of exactly 1 needs no extra axis. A count of 0 must still
    // add a (zero-length) axis, otherwise the element count of the broadcast
    // result would not match `finalShape` below.
    if (reps[i] !== 1) {
      broadcastedShape.push(reps[i]);
      broadcastAxes.push(broadcastedShape.length - 1);
    }
    broadcastedShape.push(a.shape[i]);
  }
  // Merge every (reps[i], d_i) pair back into a single axis.
  const finalShape = a.shape.map((d, i) => reps[i] * d);
  return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
}
4510
/**
 * Broadcast an array to a shape, with NumPy-style broadcasting rules.
 *
 * The target shape may have more dimensions than the input; the extra leading
 * axes are passed to `broadcast()` as the axes to add.
 *
 * @throws {Error} If the target shape has fewer dimensions than the input.
 */
function broadcastTo(a, shape$1) {
  const rank = ndim(a);
  const extraDims = shape$1.length - rank;
  if (extraDims < 0) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${rank}`);
  return broadcast(a, shape$1, require_backend.range(extraDims));
}
4521
/**
 * Broadcast input shapes to a common output shape.
 *
 * Shapes are combined pairwise, left to right, with `generalBroadcast()`.
 *
 * @param shapes - Any number of shapes.
 * @returns The combined shape; `[]` when called with no arguments.
 */
function broadcastShapes(...shapes) {
  if (shapes.length === 0) return [];
  // Wrap the helper so it only ever receives the two shapes, not the extra
  // (index, array) arguments that Array.prototype.reduce passes along.
  return shapes.reduce((acc, next) => generalBroadcast(acc, next));
}
4526
/**
 * Broadcast arrays to a common shape.
 *
 * The common shape is computed with `broadcastShapes()` from the shape of
 * every input, and each input is then passed through `broadcastTo()`.
 */
function broadcastArrays(...arrays) {
  const target = broadcastShapes(...arrays.map((arr) => shape(arr)));
  return arrays.map((arr) => broadcastTo(arr, target));
}
4532
+ /**
4340
4533
  * Return specified diagonals.
4341
4534
  *
4342
4535
  * If a is 2D, return the diagonal of the array with the given offset. If a is
@@ -4360,7 +4553,7 @@ function diag(v, k = 0) {
4360
4553
  if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
4361
4554
  if (a.ndim === 1) {
4362
4555
  const n = a.shape[0];
4363
- const ret = where(eye(n).equal(1), a.ref, zerosLike$1(a));
4556
+ const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
4364
4557
  if (k > 0) return pad(ret, [[0, k], [k, 0]]);
4365
4558
  else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
4366
4559
  else return ret;
@@ -4404,8 +4597,36 @@ function dot(x, y) {
4404
4597
  ]);
4405
4598
  return dot$1(x, y);
4406
4599
  }
4407
- /** Vector dot product of two arrays. */
4408
- function vecdot(x, y) {
4600
/**
 * Compute the inner product of two arrays.
 *
 * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
 * contraction on the last axis.
 *
 * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
 */
function inner(x, y) {
  const xShape = shape(x);
  // One singleton axis per non-contracted axis of `y`, inserted just before
  // the last axis of `x`.
  const singletons = require_backend.rep(ndim(y) - 1, 1);
  const expanded = xShape.toSpliced(-1, 0, ...singletons);
  return dot$1(reshape(x, expanded), y);
}
4612
/**
 * Compute the outer product of two arrays.
 *
 * Both inputs are flattened with `ravel()`; the result has shape
 * `[x.size, y.size]`.
 */
function outer(x, y) {
  const xFlat = ravel(x);
  const yFlat = ravel(y);
  // Column vector times row vector, relying on broadcasting in multiply().
  const column = xFlat.reshape([xFlat.shape[0], 1]);
  return multiply(column, yFlat);
}
4623
/**
 * Vector dot product of two arrays along a given axis.
 *
 * @param x - First input.
 * @param y - Second input.
 * @param options - `axis`: the axis to contract in both inputs; defaults to
 *   the last axis (`-1`) when omitted or `null`.
 * @returns The result of `dot$1` after moving the chosen axis of each input to
 *   the last position.
 * @throws {Error} If the two inputs have different lengths along the axis.
 */
function vecdot(x, y, { axis } = {}) {
  // Resolve the default once so the error message reports the axis that was
  // actually used rather than `undefined`.
  const ax = axis ?? -1;
  const xaxis = require_backend.checkAxis(ax, ndim(x));
  const yaxis = require_backend.checkAxis(ax, ndim(y));
  const xShape = shape(x);
  const yShape = shape(y);
  if (xShape[xaxis] !== yShape[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(xShape)} and ${JSON.stringify(yShape)} not aligned along axis ${ax}: ${xShape[xaxis]} != ${yShape[yaxis]}`);
  x = moveaxis(x, xaxis, -1);
  y = moveaxis(y, yaxis, -1);
  return dot$1(x, y);
}
4411
4632
  /**
@@ -4414,7 +4635,7 @@ function vecdot(x, y) {
4414
4635
  * Like vecdot() but flattens the arguments first into vectors.
4415
4636
  */
4416
4637
  function vdot(x, y) {
4417
- return vecdot(ravel(x), ravel(y));
4638
+ return dot$1(ravel(x), ravel(y));
4418
4639
  }
4419
4640
  /**
4420
4641
  * Return a tuple of coordinate matrices from coordinate vectors.
@@ -4443,6 +4664,43 @@ function meshgrid(xs, { indexing } = {}) {
4443
4664
  return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
4444
4665
  }
4445
4666
/**
 * Return an `n` by `m` array with ones on and below a diagonal and zeros
 * elsewhere.
 *
 * `k` selects the diagonal: `k=0` is the main diagonal, `k<0` is below it and
 * `k>0` is above it. `m` defaults to `n`, and `dtype` defaults to float32.
 *
 * @throws {TypeError} If `n` or `m` is not a non-negative integer, or `k` is
 *   not an integer.
 */
function tri(n, m, k = 0, { dtype, device } = {}) {
  m ??= n;
  dtype ??= require_backend.DType.Float32;
  const isCount = (v) => Number.isInteger(v) && v >= 0;
  if (!isCount(n)) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
  if (!isCount(m)) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
  if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
  // Row indices are shifted by k, so entry (i, j) is one when i + k >= j.
  const rowIdx = arange(k, n + k, 1, {
    dtype: require_backend.DType.Int32,
    device
  });
  const colIdx = arange(0, m, 1, {
    dtype: require_backend.DType.Int32,
    device
  });
  const mask = rowIdx.reshape([n, 1]).greaterEqual(colIdx);
  return mask.astype(dtype);
}
4689
/**
 * Return the lower triangle of an array. Must be of dimension >= 2.
 *
 * Entries above the `k`-th diagonal of the last two axes are replaced by
 * zeros, using a boolean `tri()` mask.
 *
 * @throws {TypeError} If the input has fewer than two dimensions.
 */
function tril(a, k = 0) {
  const rank = ndim(a);
  if (rank < 2) throw new TypeError(`tril: input array must be at least 2D, got ${rank}D`);
  a = fudgeArray(a);
  const [n, m] = a.shape.slice(-2);
  const keep = tri(n, m, k, { dtype: bool });
  return where(keep, a.ref, zerosLike(a));
}
4696
/**
 * Return the upper triangle of an array. Must be of dimension >= 2.
 *
 * Entries below the `k`-th diagonal of the last two axes are replaced by
 * zeros. The mask is `tri(n, m, k - 1)`, i.e. everything strictly below the
 * `k`-th diagonal, and those positions are the ones zeroed out.
 *
 * @throws {TypeError} If the input has fewer than two dimensions.
 */
function triu(a, k = 0) {
  // Report this function's own name (the message previously said "tril").
  if (ndim(a) < 2) throw new TypeError(`triu: input array must be at least 2D, got ${ndim(a)}D`);
  a = fudgeArray(a);
  const [n, m] = a.shape.slice(-2);
  return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
}
4703
+ /**
4446
4704
  * Clip (limit) the values in an array.
4447
4705
  *
4448
4706
  * Given an interval, values outside the interval are clipped to the interval
@@ -4466,18 +4724,70 @@ function absolute(x) {
4466
4724
  x = fudgeArray(x);
4467
4725
  return where(less(x.ref, 0), x.ref.mul(-1), x);
4468
4726
  }
4469
- /** Alias of `jax.numpy.absolute()`. */
4727
+ /** @function Alias of `jax.numpy.absolute()`. */
4470
4728
  const abs = absolute;
4729
/**
 * Return an element-wise indication of sign of the input.
 *
 * Yields `0` where the input equals zero, `-1` where it is less than zero and
 * `1` otherwise.
 */
function sign(x) {
  x = fudgeArray(x);
  // Two comparisons need the input: the first takes `x.ref`, the last takes
  // `x` itself, matching how absolute()/square()/tan() in this module hand
  // over their final use. Previously both took `x.ref` and `x` was never
  // handed over. NOTE(review): presumably each op consumes one reference.
  return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
}
4471
4734
  /** Calculate element-wise square of the input array. */
4472
4735
  function square(x) {
4473
4736
  x = fudgeArray(x);
4474
4737
  return x.ref.mul(x);
4475
4738
  }
4476
- /** Compute a trigonometric tangent of each element of input. */
4739
+ /** Element-wise tangent function (takes radians). */
4477
4740
  function tan(x) {
4478
4741
  x = fudgeArray(x);
4479
4742
  return sin(x.ref).div(cos(x));
4480
4743
  }
4744
/**
 * Element-wise inverse cosine function (inverse of cos).
 *
 * Computed from the inverse sine as `pi / 2 - asin(x)`.
 */
function acos(x) {
  const halfPi = pi / 2;
  return subtract(halfPi, asin(x));
}
4748
/**
 * @function
 * Return element-wise hypotenuse for the given legs of a right triangle.
 *
 * This is evaluated directly as `sqrt(x1^2 + x2^2)`. The original NumPy/JAX
 * implementation is more numerically stable than that formula; those
 * stability improvements are not currently implemented here.
 */
const hypot = jit$1((x1, x2) => sqrt(square(x1).add(square(x2))));
4759
/**
 * @function
 * Element-wise arc tangent of y/x with correct quadrant.
 *
 * Returns the angle in radians between the positive x-axis and the point
 * (x, y). The result is in the range [-π, π].
 *
 * With `r = sqrt(x^2 + y^2)`, the half-angle formulas used are:
 * - When x >= 0: atan2(y, x) = 2 * atan(y / (r + x))
 * - When x < 0: atan2(y, x) = 2 * atan((r - x) / y)
 *
 * The output is ill-defined when both x and y are zero.
 */
const atan2 = jit$1((y, x) => {
  const radius = sqrt(square(x.ref).add(square(y.ref)));
  const leftHalf = less(x.ref, 0);
  // Pick numerator and denominator per element depending on the sign of x.
  const top = where(leftHalf.ref, radius.ref.sub(x.ref), y.ref);
  const bottom = where(leftHalf, y, radius.add(x));
  const halfAngle = atan(top.div(bottom));
  return halfAngle.mul(2);
});
4779
+ /** @function Alias of `jax.numpy.acos()`. */
4780
+ const arccos = acos;
4781
+ /** @function Alias of `jax.numpy.atan()`. */
4782
+ const arctan = atan;
4783
+ /** @function Alias of `jax.numpy.atan2()`. */
4784
+ const arctan2 = atan2;
4785
/**
 * Element-wise subtraction, with broadcasting.
 *
 * Both operands are passed through `fudgeArray()` before `x.sub(y)` is taken.
 */
function subtract(x, y) {
  const lhs = fudgeArray(x);
  const rhs = fudgeArray(y);
  return lhs.sub(rhs);
}
4481
4791
  /** Calculates the floating-point division of x by y element-wise. */
4482
4792
  function trueDivide(x, y) {
4483
4793
  x = fudgeArray(x);
@@ -4485,7 +4795,7 @@ function trueDivide(x, y) {
4485
4795
  if (!require_backend.isFloatDtype(x.dtype) || !require_backend.isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
4486
4796
  return x.div(y);
4487
4797
  }
4488
- /** Alias of `jax.numpy.trueDivide()`. */
4798
+ /** @function Alias of `jax.numpy.trueDivide()`. */
4489
4799
  const divide = trueDivide;
4490
4800
  /** Round input to the nearest integer towards zero. */
4491
4801
  function trunc(x) {
@@ -4503,36 +4813,134 @@ function log2(x) {
4503
4813
  function log10(x) {
4504
4814
  return log(x).mul(Math.LOG10E);
4505
4815
  }
4816
/**
 * Calculate `exp(x) - 1` element-wise.
 *
 * Evaluated literally as the exponential followed by a subtraction of one.
 */
function expm1(x) {
  const ex = exp(x);
  return ex.sub(1);
}
4820
/**
 * Calculate the natural logarithm of `1 + x` element-wise.
 *
 * Evaluated literally as `log(1 + x)`.
 */
function log1p(x) {
  const shifted = add(1, x);
  return log(shifted);
}
4824
/**
 * Convert angles from degrees to radians.
 *
 * Multiplies the input by `pi / 180`.
 */
function deg2rad(x) {
  const radiansPerDegree = pi / 180;
  return multiply(x, radiansPerDegree);
}
4828
+ /** @function Alias of `jax.numpy.deg2rad()`. */
4829
+ const radians = deg2rad;
4830
/**
 * Convert angles from radians to degrees.
 *
 * Multiplies the input by `180 / pi`.
 */
function rad2deg(x) {
  const degreesPerRadian = 180 / pi;
  return multiply(x, degreesPerRadian);
}
4834
+ /** @function Alias of `jax.numpy.rad2deg()`. */
4835
+ const degrees = rad2deg;
4506
4836
/**
 * @function
 * Computes first array raised to power of second array, element-wise.
 *
 * Evaluated as `exp(log(x1) * x2)`.
 */
const power = jit$1((x1, x2) => exp(log(x1).mul(x2)));
4843
+ /** @function Alias of `jax.numpy.power()`. */
4844
+ const pow = power;
4845
/**
 * @function
 * Calculate the element-wise cube root of the input array.
 *
 * The sign (`-1` where the input is below zero, `1` otherwise) is split off,
 * the magnitude is raised to the power 1/3 via `exp(log(.) / 3)`, and the sign
 * is multiplied back in.
 */
const cbrt = jit$1((x) => {
  const sgn = where(less(x.ref, 0), -1, 1);
  const magnitude = x.mul(sgn.ref);
  const root = exp(log(magnitude).mul(1 / 3));
  return sgn.mul(root);
});
4850
+ /**
4851
+ * @function
4507
4852
  * Calculate element-wise hyperbolic sine of input.
4508
4853
  *
4509
4854
  * `sinh(x) = (exp(x) - exp(-x)) / 2`
4510
4855
  */
4511
- function sinh(x) {
4856
+ const sinh = jit$1((x) => {
4512
4857
  const ex = exp(x);
4513
4858
  const emx = reciprocal(ex.ref);
4514
4859
  return ex.sub(emx).mul(.5);
4515
- }
4860
+ });
4516
4861
  /**
4862
+ * @function
4517
4863
  * Calculate element-wise hyperbolic cosine of input.
4518
4864
  *
4519
4865
  * `cosh(x) = (exp(x) + exp(-x)) / 2`
4520
4866
  */
4521
- function cosh(x) {
4867
+ const cosh = jit$1((x) => {
4522
4868
  const ex = exp(x);
4523
4869
  const emx = reciprocal(ex.ref);
4524
4870
  return ex.add(emx).mul(.5);
4525
- }
4871
+ });
4526
4872
  /**
4873
+ * @function
4527
4874
  * Calculate element-wise hyperbolic tangent of input.
4528
4875
  *
4529
4876
  * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4530
4877
  */
4531
- function tanh(x) {
4532
- x = fudgeArray(x);
4878
+ const tanh = jit$1((x) => {
4533
4879
  const negsgn = where(less(x.ref, 0), 1, -1);
4534
4880
  const en2x = exp(x.mul(negsgn.ref).mul(2));
4535
4881
  return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
4882
+ });
4883
/**
 * @function
 * Calculate element-wise inverse hyperbolic sine of input.
 *
 * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
 */
const arcsinh = jit$1((x) => {
  const root = sqrt(square(x.ref).add(1));
  return log(x.add(root));
});
4892
/**
 * @function
 * Calculate element-wise inverse hyperbolic cosine of input.
 *
 * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
 */
const arccosh = jit$1((x) => {
  const root = sqrt(square(x.ref).sub(1));
  return log(x.add(root));
});
4901
/**
 * @function
 * Calculate element-wise inverse hyperbolic tangent of input.
 *
 * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
 */
const arctanh = jit$1((x) => {
  const onePlus = add(1, x.ref);
  const oneMinus = subtract(1, x);
  return log(onePlus.div(oneMinus)).mul(.5);
});
4910
+ /** @function Alias of `jax.numpy.arcsinh()`. */
4911
+ const asinh = arcsinh;
4912
+ /** @function Alias of `jax.numpy.arccosh()`. */
4913
+ const acosh = arccosh;
4914
+ /** @function Alias of `jax.numpy.arctanh()`. */
4915
+ const atanh = arctanh;
4916
/**
 * Compute the variance of an array.
 *
 * By default the variance is taken over all elements; pass `axis` to reduce
 * over specific axes instead.
 *
 * Options read from `opts`:
 * - `mean`: a precomputed mean to use instead of computing one.
 * - `keepdims`: forwarded to the final sum.
 * - `correction`: the divisor becomes `N - correction`, where `N` is the
 *   number of elements reduced over (e.g., for Bessel's correction).
 *
 * @throws {Error} If the reduced axes contain zero elements.
 */
function var_(x, axis = null, opts) {
  x = fudgeArray(x);
  axis = require_backend.normalizeAxis(axis, x.ndim);
  // Number of elements that take part in each reduction.
  let n = 1;
  for (const ax of axis) n *= x.shape[ax];
  if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
  let mu;
  if (opts?.mean !== void 0) mu = opts.mean;
  else mu = mean(x.ref, axis, { keepdims: true });
  const sumOfSquares = square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims });
  const divisor = n - (opts?.correction ?? 0);
  return sumOfSquares.mul(1 / divisor);
}
4933
/**
 * Compute the standard deviation of an array.
 *
 * This is the square root of `var_()` called with the same arguments, so the
 * defaults (all elements) and the `correction` option (divisor
 * `N - correction`) behave exactly as they do there.
 */
function std(x, axis = null, opts) {
  const variance = var_(x, axis, opts);
  return sqrt(variance);
}
4537
4945
 
4538
4946
  //#endregion
@@ -4547,6 +4955,7 @@ __export(nn_exports, {
4547
4955
  leakyRelu: () => leakyRelu,
4548
4956
  logSigmoid: () => logSigmoid,
4549
4957
  logSoftmax: () => logSoftmax,
4958
+ logmeanexp: () => logmeanexp,
4550
4959
  logsumexp: () => logsumexp,
4551
4960
  mish: () => mish,
4552
4961
  oneHot: () => oneHot,
@@ -4557,6 +4966,8 @@ __export(nn_exports, {
4557
4966
  softSign: () => softSign,
4558
4967
  softmax: () => softmax,
4559
4968
  softplus: () => softplus,
4969
+ squareplus: () => squareplus,
4970
+ standardize: () => standardize,
4560
4971
  swish: () => swish
4561
4972
  });
4562
4973
  /**
@@ -4600,6 +5011,7 @@ function softSign(x) {
4600
5011
  return x.ref.div(absolute(x).add(1));
4601
5012
  }
4602
5013
  /**
5014
+ * @function
4603
5015
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4604
5016
  * Swish, computed element-wise:
4605
5017
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4610,6 +5022,7 @@ function softSign(x) {
4610
5022
  */
4611
5023
  const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
4612
5024
  /**
5025
+ * @function
4613
5026
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4614
5027
  * Swish, computed element-wise:
4615
5028
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4626,7 +5039,10 @@ const swish = silu;
4626
5039
  function logSigmoid(x) {
4627
5040
  return negative(softplus(negative(x)));
4628
5041
  }
4629
- /** Identity activation function. Returns the argument unmodified. */
5042
+ /**
5043
+ * @function
5044
+ * Identity activation function. Returns the argument unmodified.
5045
+ */
4630
5046
  const identity = fudgeArray;
4631
5047
  /** Leaky rectified linear (ReLU) activation function */
4632
5048
  function leakyRelu(x, negativeSlope = .01) {
@@ -4654,6 +5070,7 @@ function celu(x, alpha = 1) {
4654
5070
  return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
4655
5071
  }
4656
5072
  /**
5073
+ * @function
4657
5074
  * Gaussian error linear unit (GELU) activation function.
4658
5075
  *
4659
5076
  * This is computed element-wise. Currently jax-js does not support the erf() or
@@ -4685,6 +5102,16 @@ function glu(x, axis = -1) {
4685
5102
  return a.mul(sigmoid(b));
4686
5103
  }
4687
5104
/**
 * Squareplus activation function.
 *
 * Computes the element-wise function:
 * `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
 *
 * `b` defaults to 4.
 */
function squareplus(x, b = 4) {
  x = fudgeArray(x);
  const root = sqrt(square(x.ref).add(b));
  return x.add(root).mul(.5);
}
5114
+ /**
4688
5115
  * Mish activation function.
4689
5116
  *
4690
5117
  * Computes the element-wise function:
@@ -4702,17 +5129,13 @@ function mish(x) {
4702
5129
  *
4703
5130
  * Reference: https://en.wikipedia.org/wiki/Softmax_function
4704
5131
  */
4705
- function softmax(x, axis) {
5132
+ function softmax(x, axis = -1) {
4706
5133
  x = fudgeArray(x);
4707
- if (axis === void 0) axis = x.ndim ? [x.ndim - 1] : [];
4708
- else if (typeof axis === "number") axis = [axis];
4709
- if (axis.length === 0) {
4710
- x.dispose();
4711
- return ones(x.shape);
4712
- }
4713
- const xMax = max(x.ref, axis, { keepDims: true });
5134
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5135
+ if (axis.length === 0) return onesLike(x);
5136
+ const xMax = max(x.ref, axis, { keepdims: true });
4714
5137
  const unnormalized = exp(x.sub(stopGradient(xMax)));
4715
- return unnormalized.ref.div(unnormalized.sum(axis, { keepDims: true }));
5138
+ return unnormalized.ref.div(unnormalized.sum(axis, { keepdims: true }));
4716
5139
  }
4717
5140
  /**
4718
5141
  * Log-Softmax function.
@@ -4722,17 +5145,13 @@ function softmax(x, axis) {
4722
5145
  *
4723
5146
  * If `axis` is not specified, it defaults to the last axis.
4724
5147
  */
4725
- function logSoftmax(x, axis) {
5148
+ function logSoftmax(x, axis = -1) {
4726
5149
  x = fudgeArray(x);
4727
- if (axis === void 0) axis = x.ndim ? [x.ndim - 1] : [];
4728
- else if (typeof axis === "number") axis = [axis];
4729
- if (axis.length === 0) {
4730
- x.dispose();
4731
- return zeros(x.shape);
4732
- }
4733
- const xMax = max(x.ref, axis, { keepDims: true });
5150
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5151
+ if (axis.length === 0) return zerosLike(x);
5152
+ const xMax = max(x.ref, axis, { keepdims: true });
4734
5153
  const shifted = x.sub(stopGradient(xMax));
4735
- const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepDims: true }));
5154
+ const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepdims: true }));
4736
5155
  return shifted.sub(shiftedLogsumexp);
4737
5156
  }
4738
5157
  /**
@@ -4743,16 +5162,39 @@ function logSoftmax(x, axis) {
4743
5162
  *
4744
5163
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
4745
5164
  */
4746
- function logsumexp(x, axis) {
5165
+ function logsumexp(x, axis = null) {
4747
5166
  x = fudgeArray(x);
4748
- if (axis === void 0) axis = require_backend.range(x.ndim);
4749
- else if (typeof axis === "number") axis = [axis];
5167
+ axis = require_backend.normalizeAxis(axis, x.ndim);
4750
5168
  if (axis.length === 0) return x;
4751
5169
  const xMax = stopGradient(max(x.ref, axis));
4752
5170
  const xMaxDims = broadcast(xMax.ref, x.shape, axis);
4753
5171
  const shifted = x.sub(xMaxDims);
4754
5172
  return xMax.add(log(exp(shifted).sum(axis)));
4755
5173
  }
5174
/**
 * Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`,
 * where `n` is the number of elements reduced over.
 *
 * When the normalized axis list is empty the input is returned as is.
 */
function logmeanexp(x, axis = null) {
  x = fudgeArray(x);
  axis = require_backend.normalizeAxis(axis, x.ndim);
  if (axis.length === 0) return x;
  let n = 1;
  for (const ax of axis) n *= x.shape[ax];
  return logsumexp(x, axis).sub(Math.log(n));
}
5182
/**
 * Standardizes input to zero mean and unit variance.
 *
 * The statistics are taken over the last axis by default. Pass a different
 * axis, or `null` to standardize over all elements.
 *
 * Options read from `opts`:
 * - `mean` / `variance`: precomputed statistics to use instead of computing
 *   them. The variance is otherwise computed as `mean(x^2) - mean(x)^2`.
 * - `epsilon`: added to the variance before the square root, defaults to
 *   `1e-5` for stability.
 */
function standardize(x, axis = -1, opts = {}) {
  x = fudgeArray(x);
  axis = require_backend.normalizeAxis(axis, x.ndim);
  if (axis.length === 0) return x;
  let mu;
  if (opts.mean !== void 0) mu = fudgeArray(opts.mean);
  else mu = x.ref.mean(axis, { keepdims: true });
  let sigma2;
  if (opts.variance !== void 0) sigma2 = fudgeArray(opts.variance);
  else {
    const meanOfSquares = square(x.ref).mean(axis, { keepdims: true });
    sigma2 = meanOfSquares.sub(square(mu.ref));
  }
  const centered = x.sub(mu);
  const scale = sqrt(sigma2.add(opts.epsilon ?? 1e-5));
  return centered.div(scale);
}
4756
5198
  /**
4757
5199
  * One-hot encodes the given indices.
4758
5200
  *
@@ -4770,7 +5212,7 @@ function logsumexp(x, axis) {
4770
5212
  * ```
4771
5213
  */
4772
5214
  function oneHot(x, numClasses) {
4773
- if (x.dtype !== require_backend.DType.Int32) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
5215
+ if (require_backend.isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
4774
5216
  return eye(numClasses, void 0, { device: x.device }).slice(x);
4775
5217
  }
4776
5218
 
@@ -4778,8 +5220,11 @@ function oneHot(x, numClasses) {
4778
5220
  //#region src/random.ts
4779
5221
  var random_exports = {};
4780
5222
  __export(random_exports, {
5223
+ bernoulli: () => bernoulli,
4781
5224
  bits: () => bits,
5225
+ exponential: () => exponential,
4782
5226
  key: () => key,
5227
+ normal: () => normal,
4783
5228
  split: () => split,
4784
5229
  uniform: () => uniform
4785
5230
  });
@@ -4810,11 +5255,11 @@ function bits(key$1, shape$1 = []) {
4810
5255
  /** Sample uniform random values in [minval, maxval) with given shape. */
4811
5256
  function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
4812
5257
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
4813
- const mantissa = bits(key$1, shape$1).div(scalar(512, {
5258
+ const mantissa = bits(key$1, shape$1).div(array(512, {
4814
5259
  dtype: require_backend.DType.Uint32,
4815
5260
  device: key$1.device
4816
5261
  }));
4817
- const float12 = mantissa.add(scalar(1065353216, {
5262
+ const float12 = mantissa.add(array(1065353216, {
4818
5263
  dtype: require_backend.DType.Uint32,
4819
5264
  device: key$1.device
4820
5265
  }));
@@ -4822,6 +5267,36 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
4822
5267
  if (minval === 0 && maxval === 1) return rand;
4823
5268
  else return rand.mul(maxval - minval).add(minval);
4824
5269
  }
5270
+ /**
5271
+ * Sample Bernoulli random variables with given mean (0,1 categorical).
5272
+ *
5273
+ * Returns a random Boolean array with the specified shape. `p` can be an array
5274
+ * and must be broadcastable to `shape`.
5275
+ */
5276
+ function bernoulli(key$1, p = .5, shape$1 = []) {
5277
+ p = fudgeArray(p);
5278
+ return uniform(key$1, shape$1).less(p);
5279
+ }
5280
+ /** Sample exponential random values according to `p(x) = exp(-x)`. */
5281
+ function exponential(key$1, shape$1 = []) {
5282
+ const u = uniform(key$1, shape$1);
5283
+ return negative(log1p(negative(u)));
5284
+ }
5285
+ /**
5286
+ * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
5287
+ *
5288
+ * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
5289
+ * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
5290
+ * bitwise identical to JAX.
5291
+ */
5292
+ function normal(key$1, shape$1 = []) {
5293
+ const [k1, k2] = split(key$1, 2);
5294
+ const u1 = uniform(k1, shape$1);
5295
+ const u2 = uniform(k2, shape$1);
5296
+ const radius = sqrt(log1p(negative(u1)).mul(-2));
5297
+ const theta = u2.mul(2 * Math.PI);
5298
+ return radius.mul(cos(theta));
5299
+ }
4825
5300
 
4826
5301
  //#endregion
4827
5302
  //#region src/polyfills.ts
@@ -4831,20 +5306,36 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
4831
5306
 
4832
5307
  //#endregion
4833
5308
  //#region src/index.ts
4834
- /** Compute the forward-mode Jacobian-vector product for a function. */
5309
+ /**
5310
+ * @function
5311
+ * Compute the forward-mode Jacobian-vector product for a function.
5312
+ */
4835
5313
  const jvp = jvp$1;
4836
- /** Vectorize an operation on a batched axis for one or more inputs. */
5314
+ /**
5315
+ * @function
5316
+ * Vectorize an operation on a batched axis for one or more inputs.
5317
+ */
4837
5318
  const vmap = vmap$1;
4838
- /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
5319
+ /**
5320
+ * @function
5321
+ * Compute the Jacobian evaluated column-by-column by forward-mode AD.
5322
+ */
4839
5323
  const jacfwd = jacfwd$1;
4840
- /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
5324
+ /**
5325
+ * @function
5326
+ * Construct a Jaxpr by dynamically tracing a function with example inputs.
5327
+ */
4841
5328
  const makeJaxpr = makeJaxpr$1;
4842
5329
  /**
5330
+ * @function
4843
5331
  * Mark a function for automatic JIT compilation, with operator fusion.
4844
5332
  *
4845
5333
  * The function will be compiled the first time it is called with a set of
4846
5334
  * argument shapes.
4847
5335
  *
5336
+ * You can call `.dispose()` on the returned, JIT-compiled function after all
5337
+ * calls to free memory associated with array constants.
5338
+ *
4848
5339
  * **Options:**
4849
5340
  * - `staticArgnums`: An array of argument indices to treat as static
4850
5341
  * (compile-time constant). These arguments must be hashable, won't be traced,
@@ -4854,26 +5345,59 @@ const makeJaxpr = makeJaxpr$1;
4854
5345
  */
4855
5346
  const jit = jit$1;
4856
5347
  /**
5348
+ * @function
4857
5349
  * Produce a local linear approximation to a function at a point using jvp() and
4858
5350
  * partial evaluation.
4859
5351
  */
4860
5352
  const linearize = linearize$1;
4861
- /** Calculate the reverse-mode vector-Jacobian product for a function. */
5353
+ /**
5354
+ * @function
5355
+ * Calculate the reverse-mode vector-Jacobian product for a function.
5356
+ */
4862
5357
  const vjp = vjp$1;
4863
5358
  /**
5359
+ * @function
4864
5360
  * Compute the gradient of a scalar-valued function `f` with respect to its
4865
5361
  * first argument.
4866
5362
  */
4867
5363
  const grad = grad$1;
4868
- /** Create a function that evaluates both `f` and the gradient of `f`. */
5364
+ /**
5365
+ * @function
5366
+ * Create a function that evaluates both `f` and the gradient of `f`.
5367
+ */
4869
5368
  const valueAndGrad = valueAndGrad$1;
4870
- /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
5369
+ /**
5370
+ * @function
5371
+ * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
5372
+ */
4871
5373
  const jacrev = jacrev$1;
4872
- /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
5374
+ /**
5375
+ * @function
5376
+ * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
5377
+ */
4873
5378
  const jacobian = jacrev;
5379
+ /**
5380
+ * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
5381
+ *
5382
+ * This can be used to wait for the results of an intermediate computation to
5383
+ * finish. It's recommended to call this regularly in an iterative computation
5384
+ * to avoid queueing up too many pending operations.
5385
+ *
5386
+ * Does not consume references to the arrays.
5387
+ */
5388
+ async function blockUntilReady(x) {
5389
+ const promises = [];
5390
+ for (const leaf of leaves(x)) if (leaf instanceof Array$1) promises.push(leaf.blockUntilReady());
5391
+ await Promise.all(promises);
5392
+ return x;
5393
+ }
4874
5394
 
4875
5395
  //#endregion
5396
+ exports.Array = Array$1;
4876
5397
  exports.DType = require_backend.DType;
5398
+ exports.Jaxpr = Jaxpr;
5399
+ exports.blockUntilReady = blockUntilReady;
5400
+ exports.defaultDevice = require_backend.defaultDevice;
4877
5401
  exports.devices = require_backend.devices;
4878
5402
  exports.grad = grad;
4879
5403
  exports.init = require_backend.init;
@@ -4908,7 +5432,7 @@ Object.defineProperty(exports, 'random', {
4908
5432
  return random_exports;
4909
5433
  }
4910
5434
  });
4911
- exports.setDevice = require_backend.setDevice;
5435
+ exports.setDebug = require_backend.setDebug;
4912
5436
  Object.defineProperty(exports, 'tree', {
4913
5437
  enumerable: true,
4914
5438
  get: function () {