@jax-js/jax 0.0.3 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BqDtPGaR.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-EBRGmEYw.js";
3
3
 
4
4
  //#region src/tree.ts
5
5
  var tree_exports = {};
@@ -323,6 +323,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
323
323
  Primitive$1["RandomBits"] = "random_bits";
324
324
  Primitive$1["Sin"] = "sin";
325
325
  Primitive$1["Cos"] = "cos";
326
+ Primitive$1["Asin"] = "asin";
327
+ Primitive$1["Atan"] = "atan";
326
328
  Primitive$1["Exp"] = "exp";
327
329
  Primitive$1["Log"] = "log";
328
330
  Primitive$1["Sqrt"] = "sqrt";
@@ -390,6 +392,12 @@ function sin$1(x) {
390
392
  function cos$1(x) {
391
393
  return bind1(Primitive.Cos, [x]);
392
394
  }
395
+ function asin$1(x) {
396
+ return bind1(Primitive.Asin, [x]);
397
+ }
398
+ function atan$1(x) {
399
+ return bind1(Primitive.Atan, [x]);
400
+ }
393
401
  function exp$1(x) {
394
402
  return bind1(Primitive.Exp, [x]);
395
403
  }
@@ -405,18 +413,16 @@ function min$1(x, y) {
405
413
  function max$1(x, y) {
406
414
  return bind1(Primitive.Max, [x, y]);
407
415
  }
408
- function reduce(x, op, axis, opts) {
416
+ function reduce(x, op, axis = null, opts) {
409
417
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
410
- if (axis === void 0) if (x instanceof Tracer) axis = range(x.shape.length);
411
- else axis = [];
412
- else if (typeof axis === "number") axis = [checkAxis(axis, ndim$1(x))];
413
- else axis = axis.map((a) => checkAxis(a, ndim$1(x)));
418
+ axis = normalizeAxis(axis, ndim$1(x));
414
419
  const originalShape = getShape(x);
415
- const result = bind1(Primitive.Reduce, [x], {
420
+ let result = bind1(Primitive.Reduce, [x], {
416
421
  op,
417
422
  axis
418
423
  });
419
- return opts?.keepDims ? broadcast(result, originalShape, axis) : result;
424
+ if (opts?.keepdims) result = result.reshape(originalShape.map((dim, i) => axis.includes(i) ? 1 : dim));
425
+ return result;
420
426
  }
421
427
  function dot$1(x, y) {
422
428
  return bind1(Primitive.Dot, [x, y]);
@@ -462,10 +468,11 @@ function where$1(cond, x, y) {
462
468
  }
463
469
  function transpose$1(x, perm) {
464
470
  perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
471
+ if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
465
472
  return bind1(Primitive.Transpose, [x], { perm });
466
473
  }
467
474
  function broadcast(x, shape$1, axis) {
468
- axis = axis.map((a) => checkAxis(a, shape$1.length));
475
+ axis = normalizeAxis(axis, shape$1.length);
469
476
  return bind1(Primitive.Broadcast, [x], {
470
477
  shape: shape$1,
471
478
  axis
@@ -484,7 +491,7 @@ function reshape$1(x, shape$1) {
484
491
  return bind1(Primitive.Reshape, [x], { shape: shape$1 });
485
492
  }
486
493
  function flip$1(x, axis) {
487
- axis = axis.map((a) => checkAxis(a, ndim$1(x)));
494
+ axis = normalizeAxis(axis, ndim$1(x));
488
495
  return bind1(Primitive.Flip, [x], { axis });
489
496
  }
490
497
  function shrink(x, slice) {
@@ -564,15 +571,19 @@ var Tracer = class Tracer {
564
571
  constructor(trace) {
565
572
  this._trace = trace;
566
573
  }
574
+ /** The shape of the array. */
567
575
  get shape() {
568
576
  return this.aval.shape;
569
577
  }
578
+ /** The total number of elements in the array. */
570
579
  get size() {
571
580
  return prod(this.shape);
572
581
  }
582
+ /** The dtype of the array. */
573
583
  get dtype() {
574
584
  return this.aval.dtype;
575
585
  }
586
+ /** The number of dimensions of the array. */
576
587
  get ndim() {
577
588
  return this.shape.length;
578
589
  }
@@ -608,22 +619,20 @@ var Tracer = class Tracer {
608
619
  return lessEqual$1(this, other);
609
620
  }
610
621
  /** Sum of the elements of the array over a given axis, or axes. */
611
- sum(axis, opts) {
622
+ sum(axis = null, opts) {
612
623
  return reduce(this, AluOp.Add, axis, opts);
613
624
  }
614
625
  /** Product of the array elements over a given axis. */
615
- prod(axis, opts) {
626
+ prod(axis = null, opts) {
616
627
  return reduce(this, AluOp.Mul, axis, opts);
617
628
  }
618
629
  /** Compute the average of the array elements along the specified axis. */
619
- mean(axis, opts) {
620
- if (axis === void 0) axis = range(this.ndim);
621
- else if (typeof axis === "number") axis = [checkAxis(axis, this.ndim)];
622
- else axis = axis.map((a) => checkAxis(a, this.ndim));
623
- let result = reduce(this, AluOp.Add, axis);
624
- result = result.mul(result.size / this.size);
625
- if (opts?.keepDims) result = broadcast(result, this.shape, axis);
626
- return result;
630
+ mean(axis = null, opts) {
631
+ axis = normalizeAxis(axis, this.ndim);
632
+ const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
633
+ if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
634
+ const result = reduce(this, AluOp.Add, axis, opts);
635
+ return result.mul(1 / n);
627
636
  }
628
637
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
629
638
  transpose(perm) {
@@ -1156,6 +1165,8 @@ const jitRules = {
1156
1165
  },
1157
1166
  [Primitive.Sin]: unopJit(AluExp.sin),
1158
1167
  [Primitive.Cos]: unopJit(AluExp.cos),
1168
+ [Primitive.Asin]: unopJit(AluExp.asin),
1169
+ [Primitive.Atan]: unopJit(AluExp.atan),
1159
1170
  [Primitive.Exp]: unopJit(AluExp.exp),
1160
1171
  [Primitive.Log]: unopJit(AluExp.log),
1161
1172
  [Primitive.Sqrt]: unopJit(AluExp.sqrt),
@@ -1397,7 +1408,7 @@ var Array$1 = class Array$1 extends Tracer {
1397
1408
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1398
1409
  * will be freed when the array is disposed.
1399
1410
  */
1400
- constructor(source, st, dtype, backend, pending = null) {
1411
+ constructor(source, st, dtype, backend, { pending = null } = {}) {
1401
1412
  super(baseArrayTrace);
1402
1413
  this.id = Array$1.#nextId++;
1403
1414
  this.#dtype = dtype;
@@ -1406,6 +1417,8 @@ var Array$1 = class Array$1 extends Tracer {
1406
1417
  this.#backend = backend;
1407
1418
  this.#rc = 1;
1408
1419
  this.#pendingSet = new Set(pending);
1420
+ if (this.#pendingSet.size === 0) this.#pendingSet = null;
1421
+ else if (source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1409
1422
  }
1410
1423
  /** @ignore */
1411
1424
  get aval() {
@@ -1460,7 +1473,7 @@ var Array$1 = class Array$1 extends Tracer {
1460
1473
  const pending = this.#pending;
1461
1474
  for (const exe of pending) exe.updateRc(1);
1462
1475
  if (typeof this.#source === "number") this.#backend.incRef(this.#source);
1463
- const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, pending);
1476
+ const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
1464
1477
  this.dispose();
1465
1478
  return ar;
1466
1479
  }
@@ -1509,7 +1522,7 @@ var Array$1 = class Array$1 extends Tracer {
1509
1522
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1510
1523
  this.dispose();
1511
1524
  for (const ar of indices) ar.dispose();
1512
- return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, pending);
1525
+ return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
1513
1526
  }
1514
1527
  /** Move axes to the rightmost dimension of the shape. */
1515
1528
  #moveAxesDown(axis) {
@@ -1546,7 +1559,7 @@ var Array$1 = class Array$1 extends Tracer {
1546
1559
  for (const exe of pending) exe.updateRc(1);
1547
1560
  pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
1548
1561
  this.dispose();
1549
- return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, pending);
1562
+ return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
1550
1563
  }
1551
1564
  #binary(op, other) {
1552
1565
  const custom = (src) => new AluExp(op, this.#dtype, src);
@@ -1611,7 +1624,7 @@ var Array$1 = class Array$1 extends Tracer {
1611
1624
  for (const exe of pending) exe.updateRc(1);
1612
1625
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1613
1626
  for (const ar of arrays) ar.dispose();
1614
- return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, pending);
1627
+ return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
1615
1628
  }
1616
1629
  /** Reduce the last dimension of the array by an operation. */
1617
1630
  #reduce(op) {
@@ -1635,7 +1648,7 @@ var Array$1 = class Array$1 extends Tracer {
1635
1648
  for (const exe of pending) exe.updateRc(1);
1636
1649
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1637
1650
  this.dispose();
1638
- return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, pending);
1651
+ return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
1639
1652
  }
1640
1653
  /**
1641
1654
  * Normalizes this array into one backed by a `Slot`.
@@ -1708,8 +1721,11 @@ var Array$1 = class Array$1 extends Tracer {
1708
1721
  *
1709
1722
  * If you are mapping from `data()` or `dataSync()`, it will also trigger
1710
1723
  * dispatch of operations as well.
1724
+ *
1725
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1726
+ * asynchronously for multiple arrays.
1711
1727
  */
1712
- async wait() {
1728
+ async blockUntilReady() {
1713
1729
  this.#check();
1714
1730
  if (this.#source instanceof AluExp) return this;
1715
1731
  const pending = this.#pending;
@@ -1775,7 +1791,7 @@ var Array$1 = class Array$1 extends Tracer {
1775
1791
  return [x.#binary(AluOp.Idiv, y)];
1776
1792
  },
1777
1793
  [Primitive.Neg]([x]) {
1778
- return [zerosLike(x.ref).#binary(AluOp.Sub, x)];
1794
+ return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
1779
1795
  },
1780
1796
  [Primitive.Reciprocal]([x]) {
1781
1797
  return [x.#unary(AluOp.Reciprocal)];
@@ -1795,7 +1811,7 @@ var Array$1 = class Array$1 extends Tracer {
1795
1811
  x.#backend.incRef(x.#source);
1796
1812
  const pending = x.#pending;
1797
1813
  for (const exe of pending) exe.updateRc(1);
1798
- const y = new Array$1(x.#source, x.#st, dtype, x.#backend, pending);
1814
+ const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
1799
1815
  x.dispose();
1800
1816
  return [y];
1801
1817
  }
@@ -1825,6 +1841,12 @@ var Array$1 = class Array$1 extends Tracer {
1825
1841
  [Primitive.Cos]([x]) {
1826
1842
  return [x.#unary(AluOp.Cos)];
1827
1843
  },
1844
+ [Primitive.Asin]([x]) {
1845
+ return [x.#unary(AluOp.Asin)];
1846
+ },
1847
+ [Primitive.Atan]([x]) {
1848
+ return [x.#unary(AluOp.Atan)];
1849
+ },
1828
1850
  [Primitive.Exp]([x]) {
1829
1851
  return [x.#unary(AluOp.Exp)];
1830
1852
  },
@@ -1910,7 +1932,7 @@ var Array$1 = class Array$1 extends Tracer {
1910
1932
  pending.splice(0, 0, ...prevPending);
1911
1933
  args.forEach((x) => x.dispose());
1912
1934
  return outputs.map((source, i) => {
1913
- return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, pending);
1935
+ return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
1914
1936
  });
1915
1937
  }
1916
1938
  };
@@ -2042,12 +2064,12 @@ var EvalTrace = class extends Trace {
2042
2064
  };
2043
2065
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
2044
2066
  const implRules = Array$1._implRules();
2045
- function zerosLike(val, dtype) {
2067
+ function zerosLike$1(val, dtype) {
2046
2068
  const aval = getAval(val);
2047
2069
  if (val instanceof Tracer) val.dispose();
2048
2070
  return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2049
2071
  }
2050
- function onesLike(val, dtype) {
2072
+ function onesLike$1(val, dtype) {
2051
2073
  const aval = getAval(val);
2052
2074
  if (val instanceof Tracer) val.dispose();
2053
2075
  return ones(aval.shape, { dtype: dtype ?? aval.dtype });
@@ -2110,7 +2132,7 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2110
2132
  const exp$2 = AluExp.cmplt(AluExp.mod(AluVar.idx, AluExp.i32(numCols + 1)), AluExp.i32(1));
2111
2133
  return new Array$1(AluExp.cast(dtype, exp$2), ShapeTracker.fromShape([numRows, numCols]), dtype, getBackend(device));
2112
2134
  }
2113
- /** Return the identity array, with ones on the main diagonal. */
2135
+ /** Return the identity matrix, with ones on the main diagonal. */
2114
2136
  function identity$1(n, { dtype, device } = {}) {
2115
2137
  return eye(n, n, {
2116
2138
  dtype,
@@ -2386,16 +2408,19 @@ var Jaxpr = class Jaxpr {
2386
2408
  varIds.set(v, FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
2387
2409
  return id;
2388
2410
  };
2389
- hasher.update(this.inBinders.length, ...this.inBinders.map(vi));
2390
- hasher.update(this.eqns.length, ...this.eqns.flatMap((eqn) => [
2391
- eqn.primitive,
2392
- eqn.inputs.length,
2393
- ...eqn.inputs.map((x) => x instanceof Var ? vi(x) : x.value),
2394
- JSON.stringify(eqn.params),
2395
- eqn.outBinders.length,
2396
- ...eqn.outBinders.map(vi)
2397
- ]));
2398
- hasher.update(this.outs.length, ...this.outs.map((x) => x instanceof Var ? vi(x) : x.value));
2411
+ hasher.update(this.inBinders.length);
2412
+ for (const x of this.inBinders) hasher.update(vi(x));
2413
+ hasher.update(this.eqns.length);
2414
+ for (const eqn of this.eqns) {
2415
+ hasher.update(eqn.primitive);
2416
+ hasher.update(eqn.inputs.length);
2417
+ for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
2418
+ hasher.update(JSON.stringify(eqn.params));
2419
+ hasher.update(eqn.outBinders.length);
2420
+ for (const x of eqn.outBinders) hasher.update(vi(x));
2421
+ }
2422
+ hasher.update(this.outs.length);
2423
+ for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
2399
2424
  return this.#hash = hasher.value;
2400
2425
  }
2401
2426
  hash(state) {
@@ -2432,7 +2457,7 @@ var Jaxpr = class Jaxpr {
2432
2457
  const c = eqn.outBinders[0];
2433
2458
  if (atomIsLit(b, 1)) context.set(c, a);
2434
2459
  else newEqns.push(eqn);
2435
- } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2460
+ } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2436
2461
  else newEqns.push(eqn);
2437
2462
  }
2438
2463
  const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
@@ -2698,6 +2723,8 @@ const abstractEvalRules = {
2698
2723
  },
2699
2724
  [Primitive.Sin]: vectorizedUnopAbstractEval,
2700
2725
  [Primitive.Cos]: vectorizedUnopAbstractEval,
2726
+ [Primitive.Asin]: vectorizedUnopAbstractEval,
2727
+ [Primitive.Atan]: vectorizedUnopAbstractEval,
2701
2728
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2702
2729
  [Primitive.Log]: vectorizedUnopAbstractEval,
2703
2730
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
@@ -2825,7 +2852,7 @@ function makeJaxpr$1(f, opts) {
2825
2852
  function jit$1(f, opts) {
2826
2853
  const cache = /* @__PURE__ */ new Map();
2827
2854
  const staticArgnums = new Set(opts?.staticArgnums ?? []);
2828
- return ((...args) => {
2855
+ const result = ((...args) => {
2829
2856
  const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
2830
2857
  const [argsFlat, inTree] = flatten(dynamicArgs);
2831
2858
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
@@ -2839,6 +2866,10 @@ function jit$1(f, opts) {
2839
2866
  });
2840
2867
  return unflatten(outTree, outs);
2841
2868
  });
2869
+ result.dispose = () => {
2870
+ for (const { consts } of cache.values()) for (const c of consts) c.dispose();
2871
+ };
2872
+ return result;
2842
2873
  }
2843
2874
 
2844
2875
  //#endregion
@@ -2869,7 +2900,7 @@ var JVPTrace = class extends Trace {
2869
2900
  return this.lift(pureArray(val));
2870
2901
  }
2871
2902
  lift(val) {
2872
- return new JVPTracer(this, val, zerosLike(val.ref));
2903
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
2873
2904
  }
2874
2905
  processPrimitive(primitive, tracers, params) {
2875
2906
  const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
@@ -2900,7 +2931,7 @@ function zeroTangentsJvp(primitive) {
2900
2931
  return (primals, tangents, params) => {
2901
2932
  for (const t of tangents) t.dispose();
2902
2933
  const ys = bind(primitive, primals, params);
2903
- return [ys, ys.map((y) => zerosLike(y.ref))];
2934
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
2904
2935
  };
2905
2936
  }
2906
2937
  const jvpRules = {
@@ -2918,13 +2949,13 @@ const jvpRules = {
2918
2949
  if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
2919
2950
  else {
2920
2951
  dx.dispose();
2921
- return [[cast(x.ref, dtype)], [zerosLike(x)]];
2952
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
2922
2953
  }
2923
2954
  },
2924
2955
  [Primitive.Bitcast]([x], [dx], { dtype }) {
2925
2956
  if (x.dtype === dtype) return [[x], [dx]];
2926
2957
  dx.dispose();
2927
- return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
2958
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
2928
2959
  },
2929
2960
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
2930
2961
  [Primitive.Sin]([x], [dx]) {
@@ -2933,6 +2964,14 @@ const jvpRules = {
2933
2964
  [Primitive.Cos]([x], [dx]) {
2934
2965
  return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
2935
2966
  },
2967
+ [Primitive.Asin]([x], [dx]) {
2968
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
2969
+ return [[asin$1(x)], [denom.mul(dx)]];
2970
+ },
2971
+ [Primitive.Atan]([x], [dx]) {
2972
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
2973
+ return [[atan$1(x)], [dx.div(denom)]];
2974
+ },
2936
2975
  [Primitive.Exp]([x], [dx]) {
2937
2976
  const z = exp$1(x);
2938
2977
  return [[z.ref], [z.mul(dx)]];
@@ -3048,7 +3087,10 @@ function mappedAval(batchDim, aval) {
3048
3087
  /** Move one axis to a different index. */
3049
3088
  function moveaxis$1(x, src, dst) {
3050
3089
  const t = pureArray(x);
3051
- const perm = range(t.shape.length);
3090
+ src = checkAxis(src, t.ndim);
3091
+ dst = checkAxis(dst, t.ndim);
3092
+ if (src === dst) return t;
3093
+ const perm = range(t.ndim);
3052
3094
  perm.splice(src, 1);
3053
3095
  perm.splice(dst, 0, src);
3054
3096
  return transpose$1(t, perm);
@@ -3141,6 +3183,8 @@ const vmapRules = {
3141
3183
  [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3142
3184
  [Primitive.Sin]: unopBatcher(sin$1),
3143
3185
  [Primitive.Cos]: unopBatcher(cos$1),
3186
+ [Primitive.Asin]: unopBatcher(asin$1),
3187
+ [Primitive.Atan]: unopBatcher(atan$1),
3144
3188
  [Primitive.Exp]: unopBatcher(exp$1),
3145
3189
  [Primitive.Log]: unopBatcher(log$1),
3146
3190
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
@@ -3326,20 +3370,28 @@ function linearizeFlatUtil(f, primalsIn) {
3326
3370
  function linearizeFlat(f, primalsIn) {
3327
3371
  const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
3328
3372
  const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3329
- return [primalsOut, fLin];
3373
+ const dispose$1 = () => {
3374
+ for (const c of consts) c.dispose();
3375
+ };
3376
+ return [
3377
+ primalsOut,
3378
+ fLin,
3379
+ dispose$1
3380
+ ];
3330
3381
  }
3331
3382
  function linearize$1(f, ...primalsIn) {
3332
3383
  const [primalsInFlat, inTree] = flatten(primalsIn);
3333
3384
  const [fFlat, outTree] = flattenFun(f, inTree);
3334
- const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3385
+ const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3335
3386
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
3336
3387
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3337
- const fLin = (...tangentsIn) => {
3388
+ const fLin = ((...tangentsIn) => {
3338
3389
  const [tangentsInFlat, inTree2] = flatten(tangentsIn);
3339
3390
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
3340
3391
  const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
3341
3392
  return unflatten(outTree.value, tangentsOutFlat);
3342
- };
3393
+ });
3394
+ fLin.dispose = dispose$1;
3343
3395
  return [primalsOut, fLin];
3344
3396
  }
3345
3397
  var PartialEvalTracer = class extends Tracer {
@@ -3455,7 +3507,10 @@ var PartialEvalTrace = class extends Trace {
3455
3507
  avalsOut: jaxpr2.outs.map((x) => x.aval),
3456
3508
  tracerRefsOut: []
3457
3509
  };
3458
- const outs2 = jaxpr2.outs.map((x) => new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe));
3510
+ const outs2 = jaxpr2.outs.map((x, i$1) => {
3511
+ if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
3512
+ return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
3513
+ });
3459
3514
  recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
3460
3515
  let i = 0;
3461
3516
  let j = 0;
@@ -3539,13 +3594,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
3539
3594
  const [consts, constvars] = unzip2(constToVar.entries());
3540
3595
  const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
3541
3596
  const outVars = tracersOut.map((t) => tracerToVar.get(t));
3542
- const jaxpr = new Jaxpr(inBinders, eqns, outVars);
3597
+ let jaxpr = new Jaxpr(inBinders, eqns, outVars);
3543
3598
  typecheckJaxpr(jaxpr);
3544
3599
  for (const t of consts) t.ref;
3545
3600
  for (const t of tracersIn) t.dispose();
3546
3601
  for (const t of tracersOut) t.dispose();
3602
+ jaxpr = jaxpr.simplify();
3603
+ if (DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
3547
3604
  return {
3548
- jaxpr: jaxpr.simplify(),
3605
+ jaxpr,
3549
3606
  consts
3550
3607
  };
3551
3608
  }
@@ -3811,20 +3868,28 @@ function vjpFlat(f, primalsIn) {
3811
3868
  const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
3812
3869
  return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
3813
3870
  };
3814
- return [primalsOut, fVjp];
3871
+ const dispose$1 = () => {
3872
+ for (const c of consts) c.dispose();
3873
+ };
3874
+ return [
3875
+ primalsOut,
3876
+ fVjp,
3877
+ dispose$1
3878
+ ];
3815
3879
  }
3816
3880
  function vjp$1(f, ...primalsIn) {
3817
3881
  const [primalsInFlat, inTree] = flatten(primalsIn);
3818
3882
  const [fFlat, outTree] = flattenFun(f, inTree);
3819
- const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3883
+ const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3820
3884
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
3821
3885
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3822
- const fVjp = (cotangentsOut) => {
3886
+ const fVjp = ((cotangentsOut) => {
3823
3887
  const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
3824
3888
  if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
3825
3889
  const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
3826
3890
  return unflatten(inTree, cotangentsInFlat);
3827
- };
3891
+ });
3892
+ fVjp.dispose = dispose$1;
3828
3893
  return [primalsOut, fVjp];
3829
3894
  }
3830
3895
  function grad$1(f) {
@@ -3842,7 +3907,8 @@ function valueAndGrad$1(f) {
3842
3907
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3843
3908
  if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3844
3909
  const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
3845
- for (const r of rest) r.dispose();
3910
+ for (const r of rest) dispose(r);
3911
+ fVjp.dispose();
3846
3912
  return [y, ct];
3847
3913
  };
3848
3914
  }
@@ -3850,7 +3916,13 @@ function jacrev$1(f) {
3850
3916
  return function jacobianReverse(x) {
3851
3917
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
3852
3918
  const [size$1] = x.shape;
3853
- const pullback = (ct) => vjp$1(f, x)[1](ct)[0];
3919
+ const pullback = (ct) => {
3920
+ const [y, fVjp] = vjp$1(f, x);
3921
+ y.dispose();
3922
+ const [ret] = fVjp(ct);
3923
+ fVjp.dispose();
3924
+ return ret;
3925
+ };
3854
3926
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
3855
3927
  };
3856
3928
  }
@@ -3930,19 +4002,38 @@ __export(numpy_exports, {
3930
4002
  DType: () => DType,
3931
4003
  abs: () => abs,
3932
4004
  absolute: () => absolute,
4005
+ acos: () => acos,
4006
+ acosh: () => acosh,
3933
4007
  add: () => add,
3934
4008
  allclose: () => allclose,
3935
4009
  arange: () => arange,
4010
+ arccos: () => arccos,
4011
+ arccosh: () => arccosh,
4012
+ arcsinh: () => arcsinh,
4013
+ arctan: () => arctan,
4014
+ arctan2: () => arctan2,
4015
+ arctanh: () => arctanh,
3936
4016
  argmax: () => argmax,
3937
4017
  argmin: () => argmin,
3938
4018
  array: () => array,
4019
+ asin: () => asin,
4020
+ asinh: () => asinh,
3939
4021
  astype: () => astype,
4022
+ atan: () => atan,
4023
+ atan2: () => atan2,
4024
+ atanh: () => atanh,
3940
4025
  bool: () => bool,
4026
+ broadcastArrays: () => broadcastArrays,
4027
+ broadcastShapes: () => broadcastShapes,
4028
+ broadcastTo: () => broadcastTo,
4029
+ cbrt: () => cbrt,
3941
4030
  clip: () => clip,
3942
4031
  columnStack: () => columnStack,
3943
4032
  concatenate: () => concatenate,
3944
4033
  cos: () => cos,
3945
4034
  cosh: () => cosh,
4035
+ deg2rad: () => deg2rad,
4036
+ degrees: () => degrees,
3946
4037
  diag: () => diag,
3947
4038
  diagonal: () => diagonal,
3948
4039
  divide: () => divide,
@@ -3953,6 +4044,7 @@ __export(numpy_exports, {
3953
4044
  eulerGamma: () => eulerGamma,
3954
4045
  exp: () => exp,
3955
4046
  exp2: () => exp2,
4047
+ expm1: () => expm1,
3956
4048
  eye: () => eye,
3957
4049
  flip: () => flip,
3958
4050
  fliplr: () => fliplr,
@@ -3964,14 +4056,17 @@ __export(numpy_exports, {
3964
4056
  greater: () => greater,
3965
4057
  greaterEqual: () => greaterEqual,
3966
4058
  hstack: () => hstack,
4059
+ hypot: () => hypot,
3967
4060
  identity: () => identity$1,
3968
4061
  inf: () => inf,
4062
+ inner: () => inner,
3969
4063
  int32: () => int32,
3970
4064
  less: () => less,
3971
4065
  lessEqual: () => lessEqual,
3972
4066
  linspace: () => linspace,
3973
4067
  log: () => log,
3974
4068
  log10: () => log10,
4069
+ log1p: () => log1p,
3975
4070
  log2: () => log2,
3976
4071
  matmul: () => matmul,
3977
4072
  max: () => max,
@@ -3987,35 +4082,49 @@ __export(numpy_exports, {
3987
4082
  negative: () => negative,
3988
4083
  notEqual: () => notEqual,
3989
4084
  ones: () => ones,
3990
- onesLike: () => onesLike$1,
4085
+ onesLike: () => onesLike,
4086
+ outer: () => outer,
3991
4087
  pad: () => pad,
3992
4088
  permuteDims: () => permuteDims,
3993
4089
  pi: () => pi,
4090
+ pow: () => pow,
4091
+ power: () => power,
3994
4092
  prod: () => prod$1,
4093
+ promoteTypes: () => promoteTypes,
4094
+ rad2deg: () => rad2deg,
4095
+ radians: () => radians,
3995
4096
  ravel: () => ravel,
3996
4097
  reciprocal: () => reciprocal,
4098
+ repeat: () => repeat,
3997
4099
  reshape: () => reshape,
3998
- scalar: () => scalar,
3999
4100
  shape: () => shape,
4101
+ sign: () => sign,
4000
4102
  sin: () => sin,
4001
4103
  sinh: () => sinh,
4002
4104
  size: () => size,
4003
4105
  sqrt: () => sqrt,
4004
4106
  square: () => square,
4005
4107
  stack: () => stack,
4108
+ std: () => std,
4109
+ subtract: () => subtract,
4006
4110
  sum: () => sum,
4007
4111
  tan: () => tan,
4008
4112
  tanh: () => tanh,
4113
+ tile: () => tile,
4009
4114
  transpose: () => transpose,
4115
+ tri: () => tri,
4116
+ tril: () => tril,
4117
+ triu: () => triu,
4010
4118
  trueDivide: () => trueDivide,
4011
4119
  trunc: () => trunc,
4012
4120
  uint32: () => uint32,
4121
+ var_: () => var_,
4013
4122
  vdot: () => vdot,
4014
4123
  vecdot: () => vecdot,
4015
4124
  vstack: () => vstack,
4016
4125
  where: () => where,
4017
4126
  zeros: () => zeros,
4018
- zerosLike: () => zerosLike$1
4127
+ zerosLike: () => zerosLike
4019
4128
  });
4020
4129
  const float32 = DType.Float32;
4021
4130
  const int32 = DType.Int32;
@@ -4032,54 +4141,66 @@ const inf = Number.POSITIVE_INFINITY;
4032
4141
  const nan = NaN;
4033
4142
  /** This is Pi, `π = 3.14159265358979...` */
4034
4143
  const pi = Math.PI;
4035
/** @function Add two arrays element-wise, broadcasting their shapes together. */
const add = add$1;
/** @function Multiply two arrays element-wise, broadcasting their shapes together. */
const multiply = mul;
/** @function Negate every element of an array, i.e. compute `-x`. */
const negative = neg;
/** @function Element-wise reciprocal `1/x` of the input. */
const reciprocal = reciprocal$1;
4043
/** @function Sine of each element; the input is in radians. */
const sin = sin$1;
/** @function Cosine of each element; the input is in radians. */
const cos = cos$1;
/** @function Arcsine of each element (the inverse of `sin`). */
const asin = asin$1;
/** @function Arctangent of each element (the inverse of `tan`). */
const atan = atan$1;
/** @function Exponential `e^x` of each element of the input array. */
const exp = exp$1;
/** @function Natural (base-e) logarithm of each element of the input array. */
const log = log$1;
/** @function Square root of each element of the input array. */
const sqrt = sqrt$1;
4053
/** @function Element-wise minimum of two arrays. */
const minimum = min$1;
/** @function Element-wise maximum of two arrays. */
const maximum = max$1;
4057
/** @function Element-wise comparison `x > y`. */
const greater = greater$1;
/** @function Element-wise comparison `x < y`. */
const less = less$1;
/** @function Element-wise comparison `x == y`. */
const equal = equal$1;
/** @function Element-wise comparison `x != y`. */
const notEqual = notEqual$1;
/** @function Element-wise comparison `x >= y`. */
const greaterEqual = greaterEqual$1;
/** @function Element-wise comparison `x <= y`. */
const lessEqual = lessEqual$1;
4069
/** @function Element-wise selection: take `x` where the condition holds, otherwise `y`. */
const where = where$1;
/**
 * @function
 * Permute the axes of an array; with no explicit order, the axes are reversed.
 */
const transpose = transpose$1;
4073
4189
  /**
4190
+ * @function
4074
4191
  * Give a new shape to an array without changing its data.
4075
4192
  *
4076
4193
  * One shape dimension can be -1. In this case, the value is inferred from the
4077
4194
  * length of the array and remaining dimensions.
4078
4195
  */
4079
4196
  const reshape = reshape$1;
4080
- /** Move axes of an array to new positions. Other axes retain original order. */
4197
+ /**
4198
+ * @function
4199
+ * Move axes of an array to new positions. Other axes retain original order.
4200
+ */
4081
4201
  const moveaxis = moveaxis$1;
4082
4202
  /**
4203
+ * @function
4083
4204
  * Add padding (zeros) to an array.
4084
4205
  *
4085
4206
  * The `width` argument is either an integer or pair of integers, in which case
@@ -4087,15 +4208,27 @@ const moveaxis = moveaxis$1;
4087
4208
  * pair specifies the padding for its corresponding axis.
4088
4209
  */
4089
4210
  const pad = pad$1;
4090
/**
 * @function
 * Number of dimensions of an array. Does not consume the array reference.
 */
const ndim = ndim$1;
/** @function Shape of an array. Does not consume the array reference. */
const shape = getShape;
4094
/**
 * @function
 * Array of zeros with the same shape and type as the given array.
 */
const zerosLike = zerosLike$1;
/**
 * @function
 * Array of ones with the same shape and type as the given array.
 */
const onesLike = onesLike$1;
/**
 * @function
 * Array filled with a given value, with the same shape and type as the given array.
 */
const fullLike$1 = fullLike;
4100
4233
  /**
4101
4234
  * Return the number of elements in an array, optionally along an axis.
@@ -4110,23 +4243,23 @@ function astype(a, dtype) {
4110
4243
  return fudgeArray(a).astype(dtype);
4111
4244
  }
4112
4245
  /** Sum of the elements of the array over a given axis, or axes. */
4113
- function sum(a, axis, opts) {
4246
+ function sum(a, axis = null, opts) {
4114
4247
  return reduce(a, AluOp.Add, axis, opts);
4115
4248
  }
4116
4249
  /** Product of the array elements over a given axis. */
4117
- function prod$1(a, axis, opts) {
4250
+ function prod$1(a, axis = null, opts) {
4118
4251
  return reduce(a, AluOp.Mul, axis, opts);
4119
4252
  }
4120
4253
  /** Return the minimum of array elements along a given axis. */
4121
- function min(a, axis, opts) {
4254
+ function min(a, axis = null, opts) {
4122
4255
  return reduce(a, AluOp.Min, axis, opts);
4123
4256
  }
4124
4257
  /** Return the maximum of array elements along a given axis. */
4125
- function max(a, axis, opts) {
4258
+ function max(a, axis = null, opts) {
4126
4259
  return reduce(a, AluOp.Max, axis, opts);
4127
4260
  }
4128
4261
  /** Compute the average of the array elements along the specified axis. */
4129
- function mean(a, axis, opts) {
4262
+ function mean(a, axis = null, opts) {
4130
4263
  return fudgeArray(a).mean(axis, opts);
4131
4264
  }
4132
4265
  /**
@@ -4142,7 +4275,7 @@ function argmin(a, axis, opts) {
4142
4275
  axis = 0;
4143
4276
  } else axis = checkAxis(axis, a.ndim);
4144
4277
  const shape$1 = a.shape;
4145
- const isMax = equal(a, min(a.ref, axis, { keepDims: true }));
4278
+ const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
4146
4279
  const length = scalar(shape$1[axis], {
4147
4280
  dtype: int32,
4148
4281
  device: a.device
@@ -4166,7 +4299,7 @@ function argmax(a, axis, opts) {
4166
4299
  axis = 0;
4167
4300
  } else axis = checkAxis(axis, a.ndim);
4168
4301
  const shape$1 = a.shape;
4169
- const isMax = equal(a, max(a.ref, axis, { keepDims: true }));
4302
+ const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
4170
4303
  const length = scalar(shape$1[axis], {
4171
4304
  dtype: int32,
4172
4305
  device: a.device
@@ -4178,17 +4311,9 @@ function argmax(a, axis, opts) {
4178
4311
  return length.sub(max(idx, axis, opts));
4179
4312
  }
4180
4313
  /** Reverse the elements in an array along the given axes. */
4181
- function flip(x, axis) {
4314
+ function flip(x, axis = null) {
4182
4315
  const nd = ndim(x);
4183
- if (axis === void 0) axis = range(nd);
4184
- else if (typeof axis === "number") axis = [axis];
4185
- const seen = /* @__PURE__ */ new Set();
4186
- for (let i = 0; i < axis.length; i++) {
4187
- if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
4188
- if (axis[i] < 0) axis[i] += nd;
4189
- if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
4190
- seen.add(axis[i]);
4191
- }
4316
+ axis = normalizeAxis(axis, nd);
4192
4317
  return flip$1(x, axis);
4193
4318
  }
4194
4319
  /**
@@ -4294,12 +4419,80 @@ function flipud(x) {
4294
4419
  function fliplr(x) {
4295
4420
  return flip(x, 1);
4296
4421
  }
4422
/** @function Another name for `numpy.transpose()`; behaves identically. */
const permuteDims = transpose;
4298
4424
  /** Return a 1-D flattened array containing the elements of the input. */
4299
4425
  function ravel(a) {
4300
4426
  return fudgeArray(a).ravel();
4301
4427
  }
4302
4428
  /**
4429
+ * Repeat each element of an array after themselves.
4430
+ *
4431
+ * If no axis is provided, use the flattened input array, and return a flat
4432
+ * output array.
4433
+ */
4434
+ function repeat(a, repeats, axis) {
4435
+ if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
4436
+ a = fudgeArray(a);
4437
+ if (axis === void 0) {
4438
+ a = ravel(a);
4439
+ axis = 0;
4440
+ }
4441
+ axis = checkAxis(axis, a.ndim);
4442
+ if (repeats === 1) return a;
4443
+ const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
4444
+ const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
4445
+ return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
4446
+ }
4447
/**
 * Construct an array by repeating `a` the number of times given by `reps`.
 *
 * If `a` has shape `(d1, d2, ..., dn)` and `reps` is a sequence of integers,
 * the result has shape `(reps[0] * d1, reps[1] * d2, ..., reps[n-1] * dn)`,
 * with `a` tiled along each dimension. If `reps` is shorter than `a.ndim` it
 * is padded with leading 1s; if it is longer, `a` gets leading axes of length
 * 1. A repetition count of 0 yields an empty axis.
 *
 * @throws {Error} if any entry of `reps` is not a non-negative integer.
 */
function tile(a, reps) {
  a = fudgeArray(a);
  if (typeof reps === "number") reps = [reps];
  if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
  const ndiff = reps.length - a.ndim;
  if (ndiff > 0) a = a.reshape([...rep(ndiff, 1), ...a.shape]);
  if (ndiff < 0) reps = [...rep(-ndiff, 1), ...reps];
  // Put a broadcast axis of length reps[i] in front of every input axis that is
  // not repeated exactly once. This must include reps[i] === 0: skipping it
  // would leave the element count unchanged, and the final reshape would then
  // be asked to produce a zero-size shape from a non-empty array.
  const broadcastedShape = [];
  const broadcastAxes = [];
  for (let i = 0; i < a.ndim; i++) {
    if (reps[i] !== 1) {
      broadcastAxes.push(broadcastedShape.length);
      broadcastedShape.push(reps[i]);
    }
    broadcastedShape.push(a.shape[i]);
  }
  const finalShape = a.shape.map((d, i) => reps[i] * d);
  return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
}
4473
/**
 * Broadcast an array to a target shape using NumPy-style broadcasting rules.
 *
 * New axes may be added on the left, and axes of length 1 may be expanded.
 */
function broadcastTo(a, shape$1) {
  const inputRank = ndim(a);
  const extraAxes = shape$1.length - inputRank;
  if (extraAxes < 0) {
    throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${inputRank}`);
  }
  return broadcast(a, shape$1, range(extraAxes));
}
4484
/** Compute the common shape that all of the given shapes broadcast to (`[]` for none). */
function broadcastShapes(...shapes) {
  return shapes.length === 0 ? [] : shapes.reduce(generalBroadcast);
}
4489
/** Broadcast several arrays against each other to one common shape. */
function broadcastArrays(...arrays) {
  const target = broadcastShapes(...arrays.map((arr) => shape(arr)));
  return arrays.map((arr) => broadcastTo(arr, target));
}
4495
+ /**
4303
4496
  * Return specified diagonals.
4304
4497
  *
4305
4498
  * If a is 2D, return the diagonal of the array with the given offset. If a is
@@ -4323,7 +4516,7 @@ function diag(v, k = 0) {
4323
4516
  if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
4324
4517
  if (a.ndim === 1) {
4325
4518
  const n = a.shape[0];
4326
- const ret = where(eye(n).equal(1), a.ref, zerosLike$1(a));
4519
+ const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
4327
4520
  if (k > 0) return pad(ret, [[0, k], [k, 0]]);
4328
4521
  else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
4329
4522
  else return ret;
@@ -4367,8 +4560,36 @@ function dot(x, y) {
4367
4560
  ]);
4368
4561
  return dot$1(x, y);
4369
4562
  }
4370
- /** Vector dot product of two arrays. */
4371
- function vecdot(x, y) {
4563
/**
 * Compute the inner product of two arrays.
 *
 * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
 * contraction on the last axis.
 *
 * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
 */
function inner(x, y) {
  // Insert ndim(y) - 1 unit axes just before x's last axis, so that the leading
  // axes of x and of y become separate axes of the output.
  const unitAxes = rep(ndim(y) - 1, 1);
  const expandedShape = shape(x).toSpliced(-1, 0, ...unitAxes);
  return dot$1(reshape(x, expandedShape), y);
}
4575
/**
 * Compute the outer product of two arrays.
 *
 * Non-1D inputs are flattened first; the result has shape `[x.size, y.size]`.
 */
function outer(x, y) {
  const xs = ravel(x);
  const ys = ravel(y);
  // A column vector times a row vector broadcasts to the full product grid.
  const column = xs.reshape([xs.shape[0], 1]);
  return multiply(column, ys);
}
4586
/**
 * Vector dot product of two arrays along a given axis.
 *
 * `axis` (default `-1`) is resolved separately against each input's rank, and
 * the two inputs must have the same size along it. That axis is moved to the
 * end of both inputs before calling `dot$1`.
 *
 * @throws {Error} if the sizes along the contracted axis differ.
 */
function vecdot(x, y, { axis } = {}) {
  // Resolve the default once so the error message reports the axis actually used
  // (previously it printed "axis undefined" when `axis` was omitted).
  const axisArg = axis ?? -1;
  const xaxis = checkAxis(axisArg, ndim(x));
  const yaxis = checkAxis(axisArg, ndim(y));
  const xshape = shape(x);
  const yshape = shape(y);
  if (xshape[xaxis] !== yshape[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(xshape)} and ${JSON.stringify(yshape)} not aligned along axis ${axisArg}: ${xshape[xaxis]} != ${yshape[yaxis]}`);
  x = moveaxis(x, xaxis, -1);
  y = moveaxis(y, yaxis, -1);
  return dot$1(x, y);
}
4374
4595
  /**
@@ -4377,7 +4598,7 @@ function vecdot(x, y) {
4377
4598
  * Like vecdot() but flattens the arguments first into vectors.
4378
4599
  */
4379
4600
  function vdot(x, y) {
4380
- return vecdot(ravel(x), ravel(y));
4601
+ return dot$1(ravel(x), ravel(y));
4381
4602
  }
4382
4603
  /**
4383
4604
  * Return a tuple of coordinate matrices from coordinate vectors.
@@ -4406,6 +4627,43 @@ function meshgrid(xs, { indexing } = {}) {
4406
4627
  return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
4407
4628
  }
4408
4629
  /**
4630
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
4631
+ *
4632
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
4633
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
4634
+ * `k>0` is above it.
4635
+ */
4636
+ function tri(n, m, k = 0, { dtype, device } = {}) {
4637
+ m ??= n;
4638
+ dtype ??= DType.Float32;
4639
+ if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
4640
+ if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
4641
+ if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
4642
+ const rows = arange(k, n + k, 1, {
4643
+ dtype: DType.Int32,
4644
+ device
4645
+ });
4646
+ const cols = arange(0, m, 1, {
4647
+ dtype: DType.Int32,
4648
+ device
4649
+ });
4650
+ return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
4651
+ }
4652
/**
 * Return the lower triangle of an array: entries above the `k`-th diagonal of
 * the last two axes are replaced with zeros.
 *
 * @throws {TypeError} if the input has fewer than 2 dimensions.
 */
function tril(a, k = 0) {
  if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
  a = fudgeArray(a);
  const [n, m] = a.shape.slice(-2);
  // Build the mask on the input's device rather than the default device.
  const mask = tri(n, m, k, { dtype: bool, device: a.device });
  return where(mask, a.ref, zerosLike(a));
}
4659
/**
 * Return the upper triangle of an array: entries below the `k`-th diagonal of
 * the last two axes are replaced with zeros.
 *
 * @throws {TypeError} if the input has fewer than 2 dimensions.
 */
function triu(a, k = 0) {
  if (ndim(a) < 2) throw new TypeError(`triu: input array must be at least 2D, got ${ndim(a)}D`);
  a = fudgeArray(a);
  const [n, m] = a.shape.slice(-2);
  // tri(k - 1) marks the strictly-lower part (j < i + k): zero it, keep the rest.
  // The mask is built on the input's device rather than the default device.
  const below = tri(n, m, k - 1, { dtype: bool, device: a.device });
  return where(below, zerosLike(a.ref), a);
}
4666
+ /**
4409
4667
  * Clip (limit) the values in an array.
4410
4668
  *
4411
4669
  * Given an interval, values outside the interval are clipped to the interval
@@ -4429,18 +4687,70 @@ function absolute(x) {
4429
4687
  x = fudgeArray(x);
4430
4688
  return where(less(x.ref, 0), x.ref.mul(-1), x);
4431
4689
  }
4432
/** @function Alias of `jax.numpy.absolute()`: element-wise `|x|`. */
const abs = absolute;
4692
/**
 * Return an element-wise indication of the sign of the input: -1 where
 * `x < 0`, 0 where `x == 0`, and 1 where `x > 0`.
 *
 * NOTE: NaN compares unequal to 0 and not less than 0, so NaN inputs map to 1
 * (NumPy would return NaN).
 */
function sign(x) {
  x = fudgeArray(x);
  // `x` is used twice: take one extra reference for the first use and hand over
  // the original reference on the second, so no reference to `x` is leaked.
  return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
}
4434
4697
  /** Calculate element-wise square of the input array. */
4435
4698
  function square(x) {
4436
4699
  x = fudgeArray(x);
4437
4700
  return x.ref.mul(x);
4438
4701
  }
4439
/** Element-wise tangent (input in radians), computed as `sin(x) / cos(x)`. */
function tan(x) {
  const arr = fudgeArray(x);
  const numerator = sin(arr.ref);
  return numerator.div(cos(arr));
}
4707
/** Element-wise inverse cosine, via the identity `acos(x) = π/2 - asin(x)`. */
function acos(x) {
  const halfPi = pi / 2;
  return subtract(halfPi, asin(x));
}
4711
/**
 * @function
 * Return element-wise hypotenuse for the given legs of a right triangle.
 *
 * Computed directly as `sqrt(x1**2 + x2**2)`. The NumPy/JAX versions rescale
 * the inputs for numerical stability; that is not implemented here, so very
 * large legs can overflow when squared.
 */
const hypot = jit$1((x1, x2) => {
  const sumOfSquares = square(x1).add(square(x2));
  return sqrt(sumOfSquares);
});
4722
/**
 * @function
 * Element-wise arc tangent of y/x with correct quadrant.
 *
 * Returns the angle in radians between the positive x-axis and the point (x, y).
 * The result is in the range [-π, π].
 *
 * Uses numerically stable half-angle formulas:
 * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
 * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
 *
 * The output is ill-defined when both x and y are zero.
 */
const atan2 = jit$1((y, x) => {
  const radius = sqrt(square(x.ref).add(square(y.ref)));
  const leftHalf = less(x.ref, 0);
  const num = where(leftHalf.ref, radius.ref.sub(x.ref), y.ref);
  const den = where(leftHalf, y, radius.add(x));
  const halfAngle = atan(num.div(den));
  return halfAngle.mul(2);
});
4742
/** @function NumPy-style name for `jax.numpy.acos()`. */
const arccos = acos;
/** @function NumPy-style name for `jax.numpy.atan()`. */
const arctan = atan;
/** @function NumPy-style name for `jax.numpy.atan2()`. */
const arctan2 = atan2;
4748
/** Element-wise difference `x - y`, with broadcasting. */
function subtract(x, y) {
  const lhs = fudgeArray(x);
  const rhs = fudgeArray(y);
  return lhs.sub(rhs);
}
4444
4754
  /** Calculates the floating-point division of x by y element-wise. */
4445
4755
  function trueDivide(x, y) {
4446
4756
  x = fudgeArray(x);
@@ -4448,7 +4758,7 @@ function trueDivide(x, y) {
4448
4758
  if (!isFloatDtype(x.dtype) || !isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
4449
4759
  return x.div(y);
4450
4760
  }
4451
/** @function Alias of `jax.numpy.trueDivide()` (floating-point division). */
const divide = trueDivide;
4453
4763
  /** Round input to the nearest integer towards zero. */
4454
4764
  function trunc(x) {
@@ -4466,36 +4776,134 @@ function log2(x) {
4466
4776
  function log10(x) {
4467
4777
  return log(x).mul(Math.LOG10E);
4468
4778
  }
4779
/**
 * Calculate `exp(x) - 1` element-wise.
 *
 * NOTE: computed directly as `exp(x) - 1`, so for `|x|` near 0 the subtraction
 * cancels and relative precision is lost (a dedicated expm1 kernel would not).
 */
function expm1(x) {
  const ex = exp(x);
  return ex.sub(1);
}
4783
/**
 * Calculate the natural logarithm of `1 + x` element-wise.
 *
 * NOTE: computed directly as `log(1 + x)`; for `|x|` much smaller than 1 the
 * addition rounds away low-order bits of `x`, losing relative precision.
 */
function log1p(x) {
  const onePlusX = add(1, x);
  return log(onePlusX);
}
4787
/** Convert angles from degrees to radians (multiply by `π / 180`). */
function deg2rad(x) {
  const radiansPerDegree = pi / 180;
  return multiply(x, radiansPerDegree);
}
/** @function Alias of `jax.numpy.deg2rad()`. */
const radians = deg2rad;
4793
/** Convert angles from radians to degrees (multiply by `180 / π`). */
function rad2deg(x) {
  const degreesPerRadian = 180 / pi;
  return multiply(x, degreesPerRadian);
}
/** @function Alias of `jax.numpy.rad2deg()`. */
const degrees = rad2deg;
4469
4799
  /**
4800
+ * @function
4801
+ * Computes first array raised to power of second array, element-wise.
4802
+ */
4803
+ const power = jit$1((x1, x2) => {
4804
+ return exp(log(x1).mul(x2));
4805
+ });
4806
+ /** @function Alias of `jax.numpy.power()`. */
4807
+ const pow = power;
4808
/**
 * @function
 * Calculate the element-wise cube root of the input array.
 *
 * Computed as `sign(x) * exp(log(|x|) * (1/3))`, so negative inputs get real
 * (negative) roots.
 */
const cbrt = jit$1((x) => {
  const sgn = where(less(x.ref, 0), -1, 1);
  const magnitude = x.mul(sgn.ref);
  const root = exp(log(magnitude).mul(1 / 3));
  return sgn.mul(root);
});
4813
/**
 * @function
 * Calculate element-wise hyperbolic sine of input.
 *
 * `sinh(x) = (exp(x) - exp(-x)) / 2`, with `exp(-x)` obtained as `1 / exp(x)`
 * so only one exponential is evaluated.
 */
const sinh = jit$1((x) => {
  const expPos = exp(x);
  const expNeg = reciprocal(expPos.ref);
  const diff = expPos.sub(expNeg);
  return diff.mul(.5);
});
4479
4824
  /**
4825
+ * @function
4480
4826
  * Calculate element-wise hyperbolic cosine of input.
4481
4827
  *
4482
4828
  * `cosh(x) = (exp(x) + exp(-x)) / 2`
4483
4829
  */
4484
- function cosh(x) {
4830
+ const cosh = jit$1((x) => {
4485
4831
  const ex = exp(x);
4486
4832
  const emx = reciprocal(ex.ref);
4487
4833
  return ex.add(emx).mul(.5);
4488
- }
4834
+ });
4489
4835
  /**
4836
+ * @function
4490
4837
  * Calculate element-wise hyperbolic tangent of input.
4491
4838
  *
4492
4839
  * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4493
4840
  */
4494
- function tanh(x) {
4495
- x = fudgeArray(x);
4841
+ const tanh = jit$1((x) => {
4496
4842
  const negsgn = where(less(x.ref, 0), 1, -1);
4497
4843
  const en2x = exp(x.mul(negsgn.ref).mul(2));
4498
4844
  return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
4845
+ });
4846
/**
 * @function
 * Calculate element-wise inverse hyperbolic sine of input.
 *
 * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
 *
 * Evaluated as `sign(x) * ln(|x| + sqrt(x^2 + 1))` using the odd symmetry
 * `arcsinh(-x) = -arcsinh(x)`. The direct formula cancels catastrophically for
 * large negative `x`: `x + sqrt(x^2 + 1)` rounds to 0 and the result is `-inf`.
 */
const arcsinh = jit$1((x) => {
  const sgn = where(less(x.ref, 0), -1, 1);
  const absx = x.mul(sgn.ref);
  return log(absx.ref.add(sqrt(square(absx).add(1)))).mul(sgn);
});
4855
/**
 * @function
 * Calculate element-wise inverse hyperbolic cosine of input.
 *
 * `arccosh(x) = ln(x + sqrt(x^2 - 1))`, real-valued for `x >= 1`.
 */
const arccosh = jit$1((x) => {
  const root = sqrt(square(x.ref).sub(1));
  return log(x.add(root));
});
4864
/**
 * @function
 * Calculate element-wise inverse hyperbolic tangent of input.
 *
 * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
 */
const arctanh = jit$1((x) => {
  const onePlus = add(1, x.ref);
  const oneMinus = subtract(1, x);
  return log(onePlus.div(oneMinus)).mul(.5);
});
4873
/** @function Short name for `jax.numpy.arcsinh()`. */
const asinh = arcsinh;
/** @function Short name for `jax.numpy.arccosh()`. */
const acosh = arccosh;
/** @function Short name for `jax.numpy.arctanh()`. */
const atanh = arctanh;
4879
/**
 * Compute the variance of an array.
 *
 * The variance is computed for the flattened array by default, otherwise over
 * the specified axis.
 *
 * The divisor is `N - correction`, where `N` is the number of reduced elements
 * and `correction` defaults to 0 (pass 1 for Bessel's correction). A
 * precomputed `mean` may be supplied via `opts.mean`; `opts.keepdims` is
 * forwarded to the final sum.
 *
 * @throws {Error} if the reduction covers zero elements.
 */
function var_(x, axis = null, opts) {
  x = fudgeArray(x);
  axis = normalizeAxis(axis, x.ndim);
  let count = 1;
  for (const ax of axis) count *= x.shape[ax];
  if (count === 0) throw new Error("var: cannot compute variance over zero-length axis");
  const mu = opts?.mean === void 0 ? mean(x.ref, axis, { keepdims: true }) : opts.mean;
  const sumSq = square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims });
  const correction = opts?.correction ?? 0;
  return sumSq.mul(1 / (count - correction));
}
4896
/**
 * Compute the standard deviation of an array: the square root of `var_()`
 * with the same arguments.
 *
 * It is computed for the flattened array by default, otherwise over the
 * specified axis. `correction` changes the divisor to `N - correction`.
 */
function std(x, axis = null, opts) {
  const variance = var_(x, axis, opts);
  return sqrt(variance);
}
4500
4908
 
4501
4909
  //#endregion
@@ -4510,6 +4918,7 @@ __export(nn_exports, {
4510
4918
  leakyRelu: () => leakyRelu,
4511
4919
  logSigmoid: () => logSigmoid,
4512
4920
  logSoftmax: () => logSoftmax,
4921
+ logmeanexp: () => logmeanexp,
4513
4922
  logsumexp: () => logsumexp,
4514
4923
  mish: () => mish,
4515
4924
  oneHot: () => oneHot,
@@ -4520,6 +4929,8 @@ __export(nn_exports, {
4520
4929
  softSign: () => softSign,
4521
4930
  softmax: () => softmax,
4522
4931
  softplus: () => softplus,
4932
+ squareplus: () => squareplus,
4933
+ standardize: () => standardize,
4523
4934
  swish: () => swish
4524
4935
  });
4525
4936
  /**
@@ -4563,6 +4974,7 @@ function softSign(x) {
4563
4974
  return x.ref.div(absolute(x).add(1));
4564
4975
  }
4565
4976
  /**
4977
+ * @function
4566
4978
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4567
4979
  * Swish, computed element-wise:
4568
4980
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4573,6 +4985,7 @@ function softSign(x) {
4573
4985
  */
4574
4986
  const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
4575
4987
  /**
4988
+ * @function
4576
4989
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4577
4990
  * Swish, computed element-wise:
4578
4991
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4589,7 +5002,10 @@ const swish = silu;
4589
5002
  function logSigmoid(x) {
4590
5003
  return negative(softplus(negative(x)));
4591
5004
  }
4592
/**
 * @function
 * Identity activation function: returns its argument unmodified. Implemented
 * directly as `fudgeArray`.
 */
const identity = fudgeArray;
4594
5010
  /** Leaky rectified linear (ReLU) activation function */
4595
5011
  function leakyRelu(x, negativeSlope = .01) {
@@ -4617,6 +5033,7 @@ function celu(x, alpha = 1) {
4617
5033
  return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
4618
5034
  }
4619
5035
  /**
5036
+ * @function
4620
5037
  * Gaussion error linear unit (GELU) activation function.
4621
5038
  *
4622
5039
  * This is computed element-wise. Currently jax-js does not support the erf() or
@@ -4648,6 +5065,16 @@ function glu(x, axis = -1) {
4648
5065
  return a.mul(sigmoid(b));
4649
5066
  }
4650
5067
  /**
5068
+ * Squareplus activation function.
5069
+ *
5070
+ * Computes the element-wise function:
5071
+ * `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
5072
+ */
5073
+ function squareplus(x, b = 4) {
5074
+ x = fudgeArray(x);
5075
+ return x.ref.add(sqrt(square(x).add(b))).mul(.5);
5076
+ }
5077
+ /**
4651
5078
  * Mish activation function.
4652
5079
  *
4653
5080
  * Computes the element-wise function:
@@ -4665,17 +5092,13 @@ function mish(x) {
4665
5092
  *
4666
5093
  * Reference: https://en.wikipedia.org/wiki/Softmax_function
4667
5094
  */
4668
function softmax(x, axis = -1) {
  x = fudgeArray(x);
  axis = normalizeAxis(axis, x.ndim);
  // Reducing over no axes: each slice holds one element, so softmax is 1.
  if (axis.length === 0) return onesLike(x);
  // Shift by the (gradient-stopped) max so every exponent is <= 0.
  const shift = stopGradient(max(x.ref, axis, { keepdims: true }));
  const unnormalized = exp(x.sub(shift));
  const total = unnormalized.ref.sum(axis, { keepdims: true });
  return unnormalized.div(total);
}
4680
5103
  /**
4681
5104
  * Log-Softmax function.
@@ -4685,17 +5108,13 @@ function softmax(x, axis) {
4685
5108
  *
4686
5109
  * If `axis` is not specified, it defaults to the last axis.
4687
5110
  */
4688
function logSoftmax(x, axis = -1) {
  x = fudgeArray(x);
  axis = normalizeAxis(axis, x.ndim);
  // Reducing over no axes: each slice holds one element, so log-softmax is 0.
  if (axis.length === 0) return zerosLike(x);
  // Shift by the (gradient-stopped) max so every exponent is <= 0.
  const shifted = x.sub(stopGradient(max(x.ref, axis, { keepdims: true })));
  const logNormalizer = log(exp(shifted.ref).sum(axis, { keepdims: true }));
  return shifted.sub(logNormalizer);
}
4701
5120
  /**
@@ -4706,16 +5125,39 @@ function logSoftmax(x, axis) {
4706
5125
  *
4707
5126
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
4708
5127
  */
4709
function logsumexp(x, axis = null) {
  x = fudgeArray(x);
  axis = normalizeAxis(axis, x.ndim);
  // Nothing to reduce: each "sum" is a single element, so the result is x.
  if (axis.length === 0) return x;
  // Factor out the (gradient-stopped) max m: lse(x) = m + log(sum(exp(x - m))).
  const peak = stopGradient(max(x.ref, axis));
  const shifted = x.sub(broadcast(peak.ref, x.shape, axis));
  return peak.add(log(exp(shifted).sum(axis)));
}
5137
/**
 * Log-mean-exp reduction: `logsumexp(x, axis) - log(n)`, where `n` is the
 * number of elements reduced over.
 */
function logmeanexp(x, axis = null) {
  x = fudgeArray(x);
  axis = normalizeAxis(axis, x.ndim);
  if (axis.length === 0) return x;
  let count = 1;
  for (const ax of axis) count *= x.shape[ax];
  return logsumexp(x, axis).sub(Math.log(count));
}
5145
/**
 * Standardizes input to zero mean and unit variance.
 *
 * By default, this is computed over the last axis. You can pass in a different
 * axis, or `null` to standardize over all elements.
 *
 * `opts.mean` / `opts.variance` may be supplied to skip computing those
 * statistics; otherwise they are computed over `axis` with kept dimensions,
 * the variance as `mean(x^2) - mean^2`.
 *
 * `opts.epsilon` is added to the variance before the square root; it defaults
 * to `1e-5` for stability.
 */
function standardize(x, axis = -1, opts = {}) {
  x = fudgeArray(x);
  axis = normalizeAxis(axis, x.ndim);
  // Averaging over an empty axis list leaves each element as its own mean, so
  // the standardized value is 0 (unless explicit statistics are supplied).
  // Previously `x` was returned unchanged in this case.
  const meanOverAxis = (t) => (axis.length === 0 ? t : t.mean(axis, { keepdims: true }));
  const mu = opts.mean !== void 0 ? fudgeArray(opts.mean) : meanOverAxis(x.ref);
  const sigma2 = opts.variance !== void 0 ? fudgeArray(opts.variance) : meanOverAxis(square(x.ref)).sub(square(mu.ref));
  return x.sub(mu).div(sqrt(sigma2.add(opts.epsilon ?? 1e-5)));
}
4719
5161
  /**
4720
5162
  * One-hot encodes the given indices.
4721
5163
  *
@@ -4733,7 +5175,7 @@ function logsumexp(x, axis) {
4733
5175
  * ```
4734
5176
  */
4735
5177
  function oneHot(x, numClasses) {
4736
- if (x.dtype !== DType.Int32) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
5178
+ if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
4737
5179
  return eye(numClasses, void 0, { device: x.device }).slice(x);
4738
5180
  }
4739
5181
 
@@ -4741,8 +5183,11 @@ function oneHot(x, numClasses) {
4741
5183
  //#region src/random.ts
4742
5184
  var random_exports = {};
4743
5185
  __export(random_exports, {
5186
+ bernoulli: () => bernoulli,
4744
5187
  bits: () => bits,
5188
+ exponential: () => exponential,
4745
5189
  key: () => key,
5190
+ normal: () => normal,
4746
5191
  split: () => split,
4747
5192
  uniform: () => uniform
4748
5193
  });
@@ -4773,11 +5218,11 @@ function bits(key$1, shape$1 = []) {
4773
5218
  /** Sample uniform random values in [minval, maxval) with given shape. */
4774
5219
  function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
4775
5220
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
4776
- const mantissa = bits(key$1, shape$1).div(scalar(512, {
5221
+ const mantissa = bits(key$1, shape$1).div(array(512, {
4777
5222
  dtype: DType.Uint32,
4778
5223
  device: key$1.device
4779
5224
  }));
4780
- const float12 = mantissa.add(scalar(1065353216, {
5225
+ const float12 = mantissa.add(array(1065353216, {
4781
5226
  dtype: DType.Uint32,
4782
5227
  device: key$1.device
4783
5228
  }));
@@ -4785,6 +5230,36 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
4785
5230
  if (minval === 0 && maxval === 1) return rand;
4786
5231
  else return rand.mul(maxval - minval).add(minval);
4787
5232
  }
5233
+ /**
5234
+ * Sample Bernoulli random variables with given mean (0,1 categorical).
5235
+ *
5236
+ * Returns a random Boolean array with the specified shape. `p` can be an array
5237
+ * and must be broadcastable to `shape`.
5238
+ */
5239
+ function bernoulli(key$1, p = .5, shape$1 = []) {
5240
+ p = fudgeArray(p);
5241
+ return uniform(key$1, shape$1).less(p);
5242
+ }
5243
+ /** Sample exponential random values according to `p(x) = exp(-x)`. */
5244
+ function exponential(key$1, shape$1 = []) {
5245
+ const u = uniform(key$1, shape$1);
5246
+ return negative(log1p(negative(u)));
5247
+ }
5248
+ /**
5249
+ * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
5250
+ *
5251
+ * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
5252
+ * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
5253
+ * bitwise identical to JAX.
5254
+ */
5255
+ function normal(key$1, shape$1 = []) {
5256
+ const [k1, k2] = split(key$1, 2);
5257
+ const u1 = uniform(k1, shape$1);
5258
+ const u2 = uniform(k2, shape$1);
5259
+ const radius = sqrt(log1p(negative(u1)).mul(-2));
5260
+ const theta = u2.mul(2 * Math.PI);
5261
+ return radius.mul(cos(theta));
5262
+ }
4788
5263
 
4789
5264
  //#endregion
4790
5265
  //#region src/polyfills.ts
@@ -4794,20 +5269,36 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
4794
5269
 
4795
5270
  //#endregion
4796
5271
  //#region src/index.ts
4797
- /** Compute the forward-mode Jacobian-vector product for a function. */
5272
+ /**
5273
+ * @function
5274
+ * Compute the forward-mode Jacobian-vector product for a function.
5275
+ */
4798
5276
  const jvp = jvp$1;
4799
- /** Vectorize an operation on a batched axis for one or more inputs. */
5277
+ /**
5278
+ * @function
5279
+ * Vectorize an operation on a batched axis for one or more inputs.
5280
+ */
4800
5281
  const vmap = vmap$1;
4801
- /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
5282
+ /**
5283
+ * @function
5284
+ * Compute the Jacobian evaluated column-by-column by forward-mode AD.
5285
+ */
4802
5286
  const jacfwd = jacfwd$1;
4803
- /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
5287
+ /**
5288
+ * @function
5289
+ * Construct a Jaxpr by dynamically tracing a function with example inputs.
5290
+ */
4804
5291
  const makeJaxpr = makeJaxpr$1;
4805
5292
  /**
5293
+ * @function
4806
5294
  * Mark a function for automatic JIT compilation, with operator fusion.
4807
5295
  *
4808
5296
  * The function will be compiled the first time it is called with a set of
4809
5297
  * argument shapes.
4810
5298
  *
5299
+ * You can call `.dispose()` on the returned, JIT-compiled function after all
5300
+ * calls to free memory associated with array constants.
5301
+ *
4811
5302
  * **Options:**
4812
5303
  * - `staticArgnums`: An array of argument indices to treat as static
4813
5304
  * (compile-time constant). These arguments must be hashable, won't be traced,
@@ -4817,23 +5308,52 @@ const makeJaxpr = makeJaxpr$1;
4817
5308
  */
4818
5309
  const jit = jit$1;
4819
5310
  /**
5311
+ * @function
4820
5312
  * Produce a local linear approximation to a function at a point using jvp() and
4821
5313
  * partial evaluation.
4822
5314
  */
4823
5315
  const linearize = linearize$1;
4824
- /** Calculate the reverse-mode vector-Jacobian product for a function. */
5316
+ /**
5317
+ * @function
5318
+ * Calculate the reverse-mode vector-Jacobian product for a function.
5319
+ */
4825
5320
  const vjp = vjp$1;
4826
5321
  /**
5322
+ * @function
4827
5323
  * Compute the gradient of a scalar-valued function `f` with respect to its
4828
5324
  * first argument.
4829
5325
  */
4830
5326
  const grad = grad$1;
4831
- /** Create a function that evaluates both `f` and the gradient of `f`. */
5327
+ /**
5328
+ * @function
5329
+ * Create a function that evaluates both `f` and the gradient of `f`.
5330
+ */
4832
5331
  const valueAndGrad = valueAndGrad$1;
4833
- /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
5332
+ /**
5333
+ * @function
5334
+ * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
5335
+ */
4834
5336
  const jacrev = jacrev$1;
4835
- /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
5337
+ /**
5338
+ * @function
5339
+ * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
5340
+ */
4836
5341
  const jacobian = jacrev;
5342
+ /**
5343
+ * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
5344
+ *
5345
+ * This can be used to wait for the results of an intermediate computation to
5346
+ * finish. It's recommended to call this regularly in an iterative computation
5347
+ * to avoid queueing up too many pending operations.
5348
+ *
5349
+ * Does not consume reference to the arrays.
5350
+ */
5351
+ async function blockUntilReady(x) {
5352
+ const promises = [];
5353
+ for (const leaf of leaves(x)) if (leaf instanceof Array$1) promises.push(leaf.blockUntilReady());
5354
+ await Promise.all(promises);
5355
+ return x;
5356
+ }
4837
5357
 
4838
5358
  //#endregion
4839
- export { DType, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDevice, tree_exports as tree, valueAndGrad, vjp, vmap };
5359
+ export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };