@jax-js/jax 0.0.3 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BqDtPGaR.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-CdcTZEOF.js";
3
3
 
4
4
  //#region src/tree.ts
5
5
  var tree_exports = {};
@@ -323,6 +323,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
323
323
  Primitive$1["RandomBits"] = "random_bits";
324
324
  Primitive$1["Sin"] = "sin";
325
325
  Primitive$1["Cos"] = "cos";
326
+ Primitive$1["Asin"] = "asin";
327
+ Primitive$1["Atan"] = "atan";
326
328
  Primitive$1["Exp"] = "exp";
327
329
  Primitive$1["Log"] = "log";
328
330
  Primitive$1["Sqrt"] = "sqrt";
@@ -390,6 +392,12 @@ function sin$1(x) {
390
392
  function cos$1(x) {
391
393
  return bind1(Primitive.Cos, [x]);
392
394
  }
395
+ function asin$1(x) {
396
+ return bind1(Primitive.Asin, [x]);
397
+ }
398
+ function atan$1(x) {
399
+ return bind1(Primitive.Atan, [x]);
400
+ }
393
401
  function exp$1(x) {
394
402
  return bind1(Primitive.Exp, [x]);
395
403
  }
@@ -405,18 +413,16 @@ function min$1(x, y) {
405
413
  function max$1(x, y) {
406
414
  return bind1(Primitive.Max, [x, y]);
407
415
  }
408
- function reduce(x, op, axis, opts) {
416
+ function reduce(x, op, axis = null, opts) {
409
417
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
410
- if (axis === void 0) if (x instanceof Tracer) axis = range(x.shape.length);
411
- else axis = [];
412
- else if (typeof axis === "number") axis = [checkAxis(axis, ndim$1(x))];
413
- else axis = axis.map((a) => checkAxis(a, ndim$1(x)));
418
+ axis = normalizeAxis(axis, ndim$1(x));
414
419
  const originalShape = getShape(x);
415
- const result = bind1(Primitive.Reduce, [x], {
420
+ let result = bind1(Primitive.Reduce, [x], {
416
421
  op,
417
422
  axis
418
423
  });
419
- return opts?.keepDims ? broadcast(result, originalShape, axis) : result;
424
+ if (opts?.keepdims) result = result.reshape(originalShape.map((dim, i) => axis.includes(i) ? 1 : dim));
425
+ return result;
420
426
  }
421
427
  function dot$1(x, y) {
422
428
  return bind1(Primitive.Dot, [x, y]);
@@ -462,10 +468,11 @@ function where$1(cond, x, y) {
462
468
  }
463
469
  function transpose$1(x, perm) {
464
470
  perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
471
+ if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
465
472
  return bind1(Primitive.Transpose, [x], { perm });
466
473
  }
467
474
  function broadcast(x, shape$1, axis) {
468
- axis = axis.map((a) => checkAxis(a, shape$1.length));
475
+ axis = normalizeAxis(axis, shape$1.length);
469
476
  return bind1(Primitive.Broadcast, [x], {
470
477
  shape: shape$1,
471
478
  axis
@@ -484,7 +491,7 @@ function reshape$1(x, shape$1) {
484
491
  return bind1(Primitive.Reshape, [x], { shape: shape$1 });
485
492
  }
486
493
  function flip$1(x, axis) {
487
- axis = axis.map((a) => checkAxis(a, ndim$1(x)));
494
+ axis = normalizeAxis(axis, ndim$1(x));
488
495
  return bind1(Primitive.Flip, [x], { axis });
489
496
  }
490
497
  function shrink(x, slice) {
@@ -558,21 +565,49 @@ var Trace = class {
558
565
  this.main = main;
559
566
  }
560
567
  };
568
+ /**
569
+ * Broadcast shapes and promote types with casting for two avals.
570
+ *
571
+ * This implements the weak type behavior described in `promoteTypes()`, but not
572
+ * implemented in that function as `weakType` is not passed.
573
+ */
574
+ function promoteAvals(a, b) {
575
+ const shape$1 = generalBroadcast(a.shape, b.shape);
576
+ const weakType = a.weakType && b.weakType;
577
+ let dtype;
578
+ if (a.weakType === b.weakType) dtype = promoteTypes(a.dtype, b.dtype);
579
+ else if (a.weakType) dtype = promoteTypes(b.dtype, DType.Uint32);
580
+ else dtype = promoteTypes(a.dtype, DType.Uint32);
581
+ return new ShapedArray(shape$1, dtype, weakType);
582
+ }
561
583
  var Tracer = class Tracer {
562
584
  /** @ignore */
563
585
  _trace;
564
586
  constructor(trace) {
565
587
  this._trace = trace;
566
588
  }
589
+ /** The shape of the array. */
567
590
  get shape() {
568
591
  return this.aval.shape;
569
592
  }
593
+ /** The total number of elements in the array. */
570
594
  get size() {
571
595
  return prod(this.shape);
572
596
  }
597
+ /** The dtype of elements stored in the array. */
573
598
  get dtype() {
574
599
  return this.aval.dtype;
575
600
  }
601
+ /**
602
+ * Whether the array is weakly typed.
603
+ *
604
+ * Weakly typed arrays will cast to the dtype of the other operand. See
605
+ * `promoteTypes()` for details.
606
+ */
607
+ get weakType() {
608
+ return this.aval.weakType;
609
+ }
610
+ /** The number of dimensions of the array. */
576
611
  get ndim() {
577
612
  return this.shape.length;
578
613
  }
@@ -608,22 +643,20 @@ var Tracer = class Tracer {
608
643
  return lessEqual$1(this, other);
609
644
  }
610
645
  /** Sum of the elements of the array over a given axis, or axes. */
611
- sum(axis, opts) {
646
+ sum(axis = null, opts) {
612
647
  return reduce(this, AluOp.Add, axis, opts);
613
648
  }
614
649
  /** Product of the array elements over a given axis. */
615
- prod(axis, opts) {
650
+ prod(axis = null, opts) {
616
651
  return reduce(this, AluOp.Mul, axis, opts);
617
652
  }
618
653
  /** Compute the average of the array elements along the specified axis. */
619
- mean(axis, opts) {
620
- if (axis === void 0) axis = range(this.ndim);
621
- else if (typeof axis === "number") axis = [checkAxis(axis, this.ndim)];
622
- else axis = axis.map((a) => checkAxis(a, this.ndim));
623
- let result = reduce(this, AluOp.Add, axis);
624
- result = result.mul(result.size / this.size);
625
- if (opts?.keepDims) result = broadcast(result, this.shape, axis);
626
- return result;
654
+ mean(axis = null, opts) {
655
+ axis = normalizeAxis(axis, this.ndim);
656
+ const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
657
+ if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
658
+ const result = reduce(this, AluOp.Add, axis, opts);
659
+ return result.mul(1 / n);
627
660
  }
628
661
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
629
662
  transpose(perm) {
@@ -810,12 +843,13 @@ function getShape(x) {
810
843
  return x instanceof Tracer ? x.shape : [];
811
844
  }
812
845
  var ShapedArray = class ShapedArray {
813
- constructor(shape$1, dtype) {
846
+ constructor(shape$1, dtype, weakType) {
814
847
  this.shape = shape$1;
815
848
  this.dtype = dtype;
849
+ this.weakType = weakType;
816
850
  }
817
851
  static fromAval(aval) {
818
- return new ShapedArray(aval.shape, aval.dtype);
852
+ return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
819
853
  }
820
854
  get ndim() {
821
855
  return this.shape.length;
@@ -829,7 +863,7 @@ var ShapedArray = class ShapedArray {
829
863
  };
830
864
  function getAval(x) {
831
865
  if (x instanceof Tracer) return x.aval;
832
- else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32);
866
+ else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32, typeof x === "boolean" ? false : true);
833
867
  else throw new TypeError(`Unknown value: ${x}`);
834
868
  }
835
869
  function bind(prim, args, params = {}) {
@@ -1151,11 +1185,13 @@ const jitRules = {
1151
1185
  const k1 = reshapeViews(keys[1], mapping);
1152
1186
  const c0 = AluExp.u32(0);
1153
1187
  const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
1154
- const exp$2 = AluExp.threefry2x32(c0, c1, k0, k1, mode);
1188
+ const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
1155
1189
  return new Kernel(nargs, prod(shape$1), exp$2);
1156
1190
  },
1157
1191
  [Primitive.Sin]: unopJit(AluExp.sin),
1158
1192
  [Primitive.Cos]: unopJit(AluExp.cos),
1193
+ [Primitive.Asin]: unopJit(AluExp.asin),
1194
+ [Primitive.Atan]: unopJit(AluExp.atan),
1159
1195
  [Primitive.Exp]: unopJit(AluExp.exp),
1160
1196
  [Primitive.Log]: unopJit(AluExp.log),
1161
1197
  [Primitive.Sqrt]: unopJit(AluExp.sqrt),
@@ -1190,7 +1226,7 @@ const jitRules = {
1190
1226
  [Primitive.Dot](nargs, [a, b], [as, bs]) {
1191
1227
  const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
1192
1228
  const c = k1.exp;
1193
- const cs = new ShapedArray(generalBroadcast(as.shape, bs.shape), c.dtype);
1229
+ const cs = promoteAvals(as, bs);
1194
1230
  return jitRules[Primitive.Reduce](nargs, [c], [cs], {
1195
1231
  op: AluOp.Add,
1196
1232
  axis: [cs.ndim - 1]
@@ -1200,8 +1236,8 @@ const jitRules = {
1200
1236
  const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
1201
1237
  a = reshapeViews(a, (st) => st.compose(stX));
1202
1238
  b = reshapeViews(b, (st) => st.compose(stY));
1203
- as = new ShapedArray(stX.shape, as.dtype);
1204
- bs = new ShapedArray(stY.shape, bs.dtype);
1239
+ as = new ShapedArray(stX.shape, as.dtype, as.weakType);
1240
+ bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
1205
1241
  return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1206
1242
  },
1207
1243
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
@@ -1254,9 +1290,10 @@ function splitGraphDataflow(backend, jaxpr) {
1254
1290
  Primitive.Conv,
1255
1291
  Primitive.PoolTranspose
1256
1292
  ];
1293
+ const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
1257
1294
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
1258
1295
  const eqn = jaxpr.eqns[i];
1259
- if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
1296
+ if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
1260
1297
  for (const v of eqn.outBinders) {
1261
1298
  blackNodes.add(v);
1262
1299
  p1NextBlack.set(v, v);
@@ -1386,6 +1423,7 @@ var Array$1 = class Array$1 extends Tracer {
1386
1423
  static #nextId = 1001;
1387
1424
  id;
1388
1425
  #dtype;
1426
+ #weakType;
1389
1427
  #source;
1390
1428
  #st;
1391
1429
  #backend;
@@ -1397,19 +1435,22 @@ var Array$1 = class Array$1 extends Tracer {
1397
1435
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1398
1436
  * will be freed when the array is disposed.
1399
1437
  */
1400
- constructor(source, st, dtype, backend, pending = null) {
1438
+ constructor(args) {
1401
1439
  super(baseArrayTrace);
1402
1440
  this.id = Array$1.#nextId++;
1403
- this.#dtype = dtype;
1404
- this.#source = source;
1405
- this.#st = st;
1406
- this.#backend = backend;
1441
+ this.#dtype = args.dtype;
1442
+ this.#weakType = args.weakType;
1443
+ this.#source = args.source;
1444
+ this.#st = args.st;
1445
+ this.#backend = args.backend;
1407
1446
  this.#rc = 1;
1408
- this.#pendingSet = new Set(pending);
1447
+ this.#pendingSet = new Set(args.pending);
1448
+ if (this.#pendingSet.size === 0) this.#pendingSet = null;
1449
+ else if (this.#source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1409
1450
  }
1410
1451
  /** @ignore */
1411
1452
  get aval() {
1412
- return new ShapedArray(this.#st.shape, this.#dtype);
1453
+ return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
1413
1454
  }
1414
1455
  /** Return a simple string representation of the array's dimensions. */
1415
1456
  toString() {
@@ -1421,6 +1462,17 @@ var Array$1 = class Array$1 extends Tracer {
1421
1462
  #check() {
1422
1463
  if (this.#rc <= 0) throw new UseAfterFreeError(this);
1423
1464
  }
1465
+ /** Construct an array, copying fields from `this`. */
1466
+ #newArrayFrom(args) {
1467
+ return new Array$1({
1468
+ source: args.source ?? this.#source,
1469
+ st: args.st ?? this.#st,
1470
+ dtype: args.dtype ?? this.#dtype,
1471
+ weakType: this.#weakType,
1472
+ backend: args.backend ?? this.#backend,
1473
+ pending: args.pending ?? this.#pending ?? void 0
1474
+ });
1475
+ }
1424
1476
  get ref() {
1425
1477
  this.#check();
1426
1478
  this.#rc++;
@@ -1460,7 +1512,10 @@ var Array$1 = class Array$1 extends Tracer {
1460
1512
  const pending = this.#pending;
1461
1513
  for (const exe of pending) exe.updateRc(1);
1462
1514
  if (typeof this.#source === "number") this.#backend.incRef(this.#source);
1463
- const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, pending);
1515
+ const ar = this.#newArrayFrom({
1516
+ st,
1517
+ pending
1518
+ });
1464
1519
  this.dispose();
1465
1520
  return ar;
1466
1521
  }
@@ -1509,7 +1564,11 @@ var Array$1 = class Array$1 extends Tracer {
1509
1564
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1510
1565
  this.dispose();
1511
1566
  for (const ar of indices) ar.dispose();
1512
- return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, pending);
1567
+ return this.#newArrayFrom({
1568
+ source: output,
1569
+ st: ShapeTracker.fromShape(finalShape),
1570
+ pending
1571
+ });
1513
1572
  }
1514
1573
  /** Move axes to the rightmost dimension of the shape. */
1515
1574
  #moveAxesDown(axis) {
@@ -1532,11 +1591,16 @@ var Array$1 = class Array$1 extends Tracer {
1532
1591
  return this.#reshape(this.#st.permute(perm));
1533
1592
  }
1534
1593
  #unary(op, dtypeOutput) {
1594
+ const weakType = !dtypeOutput && this.#weakType;
1535
1595
  dtypeOutput ??= this.#dtype;
1536
1596
  this.#check();
1537
1597
  if (this.#source instanceof AluExp) {
1538
1598
  const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
1539
- return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
1599
+ return this.#newArrayFrom({
1600
+ source: exp$3.simplify(),
1601
+ dtype: dtypeOutput,
1602
+ weakType
1603
+ });
1540
1604
  }
1541
1605
  const indices = unravelAlu(this.#st.shape, AluVar.gidx);
1542
1606
  const exp$2 = new AluExp(op, dtypeOutput, [AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
@@ -1546,41 +1610,65 @@ var Array$1 = class Array$1 extends Tracer {
1546
1610
  for (const exe of pending) exe.updateRc(1);
1547
1611
  pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
1548
1612
  this.dispose();
1549
- return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, pending);
1613
+ return this.#newArrayFrom({
1614
+ source: output,
1615
+ st: ShapeTracker.fromShape(this.shape),
1616
+ dtype: dtypeOutput,
1617
+ weakType,
1618
+ pending
1619
+ });
1550
1620
  }
1551
1621
  #binary(op, other) {
1552
- const custom = (src) => new AluExp(op, this.#dtype, src);
1622
+ const custom = (src) => new AluExp(op, src[0].dtype, src);
1553
1623
  return Array$1.#naryCustom(op, custom, [this, other]);
1554
1624
  }
1555
- static #naryCustom(name, custom, arrays, { dtypeOverride, dtypeOutput, reduceAxis } = {}) {
1625
+ static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
1556
1626
  const n = arrays.length;
1557
1627
  const backend = arrays[0].#backend;
1558
1628
  if (n === 0) throw new TypeError(`No inputs for ${name}`);
1559
1629
  for (const ar of arrays) ar.#check();
1560
- let dtype;
1630
+ let castDtype;
1631
+ let castWeakType = true;
1561
1632
  for (let i = 0; i < n; i++) {
1562
1633
  if (dtypeOverride?.[i]) {
1563
1634
  if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1564
- } else if (!dtype) dtype = arrays[i].#dtype;
1565
- else if (arrays[i].#dtype !== dtype) throw new TypeError(`Dtype mismatch in ${name}: ${dtype} vs ${arrays[i].#dtype}`);
1635
+ } else if (castDtype === void 0) {
1636
+ castDtype = arrays[i].#dtype;
1637
+ castWeakType = arrays[i].#weakType;
1638
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1566
1639
  if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
1567
1640
  }
1568
- dtypeOutput ??= dtype;
1569
- if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
1641
+ const weakType = castWeakType && !strongTypeOutput;
1570
1642
  arrays = Array$1.#broadcastArrays(arrays);
1571
1643
  const newShape = [...arrays[0].shape];
1572
1644
  if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
1645
+ const sources = arrays.map((ar, i) => {
1646
+ if (!dtypeOverride?.[i]) return AluExp.cast(castDtype, ar.#source);
1647
+ else return ar.#source;
1648
+ });
1573
1649
  if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
1574
- const exp$4 = custom(arrays.map((ar) => ar.#source));
1575
- return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
1650
+ const exp$4 = custom(sources);
1651
+ return new Array$1({
1652
+ source: exp$4.simplify(),
1653
+ st: arrays[0].#st,
1654
+ dtype: exp$4.dtype,
1655
+ weakType,
1656
+ backend
1657
+ });
1576
1658
  }
1577
- const exp$3 = custom(arrays.map((ar) => {
1578
- const src$1 = ar.#source;
1659
+ const exp$3 = custom(arrays.map((ar, i) => {
1660
+ const src$1 = sources[i];
1579
1661
  if (ar.#st.contiguous) return src$1;
1580
1662
  return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
1581
1663
  }));
1582
1664
  const st = ShapeTracker.fromShape(newShape);
1583
- return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
1665
+ return new Array$1({
1666
+ source: exp$3.simplify(),
1667
+ st,
1668
+ dtype: exp$3.dtype,
1669
+ weakType,
1670
+ backend
1671
+ });
1584
1672
  }
1585
1673
  let indices;
1586
1674
  if (!reduceAxis) indices = unravelAlu(newShape, AluVar.gidx);
@@ -1590,14 +1678,19 @@ var Array$1 = class Array$1 extends Tracer {
1590
1678
  }
1591
1679
  const inputs = [];
1592
1680
  const src = [];
1593
- for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#source, ar.#st, indices));
1594
- else {
1595
- let gid = inputs.indexOf(ar.#source);
1596
- if (gid === -1) {
1597
- gid = inputs.length;
1598
- inputs.push(ar.#source);
1681
+ for (const [i, ar] of arrays.entries()) {
1682
+ let nextSrc;
1683
+ if (ar.#source instanceof AluExp) nextSrc = accessorAluExp(ar.#source, ar.#st, indices);
1684
+ else {
1685
+ let gid = inputs.indexOf(ar.#source);
1686
+ if (gid === -1) {
1687
+ gid = inputs.length;
1688
+ inputs.push(ar.#source);
1689
+ }
1690
+ nextSrc = AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
1599
1691
  }
1600
- src.push(AluExp.globalView(ar.#dtype, gid, ar.#st, indices));
1692
+ if (!dtypeOverride?.[i]) nextSrc = AluExp.cast(castDtype, nextSrc);
1693
+ src.push(nextSrc);
1601
1694
  }
1602
1695
  const exp$2 = custom(src);
1603
1696
  let re = void 0;
@@ -1611,12 +1704,17 @@ var Array$1 = class Array$1 extends Tracer {
1611
1704
  for (const exe of pending) exe.updateRc(1);
1612
1705
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1613
1706
  for (const ar of arrays) ar.dispose();
1614
- return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, pending);
1707
+ return new Array$1({
1708
+ source: output,
1709
+ st: ShapeTracker.fromShape(newShape),
1710
+ dtype: kernel.dtype,
1711
+ weakType,
1712
+ backend,
1713
+ pending
1714
+ });
1615
1715
  }
1616
1716
  /** Reduce the last dimension of the array by an operation. */
1617
1717
  #reduce(op) {
1618
- this.#check();
1619
- if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
1620
1718
  const shape$1 = this.shape;
1621
1719
  const reduction = new Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
1622
1720
  const newShape = shape$1.slice(0, -1);
@@ -1635,7 +1733,11 @@ var Array$1 = class Array$1 extends Tracer {
1635
1733
  for (const exe of pending) exe.updateRc(1);
1636
1734
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1637
1735
  this.dispose();
1638
- return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, pending);
1736
+ return this.#newArrayFrom({
1737
+ source: output,
1738
+ st: ShapeTracker.fromShape(newShape),
1739
+ pending
1740
+ });
1639
1741
  }
1640
1742
  /**
1641
1743
  * Normalizes this array into one backed by a `Slot`.
@@ -1671,8 +1773,8 @@ var Array$1 = class Array$1 extends Tracer {
1671
1773
  }
1672
1774
  #dataInline() {
1673
1775
  this.#check();
1674
- const exp$2 = this.#source;
1675
- const ar = new Array$1(exp$2, this.#st, this.dtype, getBackend("cpu"));
1776
+ if (!(this.#source instanceof AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
1777
+ const ar = this.#newArrayFrom({ backend: getBackend("cpu") });
1676
1778
  this.dispose();
1677
1779
  return ar.dataSync();
1678
1780
  }
@@ -1708,8 +1810,11 @@ var Array$1 = class Array$1 extends Tracer {
1708
1810
  *
1709
1811
  * If you are mapping from `data()` or `dataSync()`, it will also trigger
1710
1812
  * dispatch of operations as well.
1813
+ *
1814
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1815
+ * asynchronously for multiple arrays.
1711
1816
  */
1712
- async wait() {
1817
+ async blockUntilReady() {
1713
1818
  this.#check();
1714
1819
  if (this.#source instanceof AluExp) return this;
1715
1820
  const pending = this.#pending;
@@ -1775,7 +1880,7 @@ var Array$1 = class Array$1 extends Tracer {
1775
1880
  return [x.#binary(AluOp.Idiv, y)];
1776
1881
  },
1777
1882
  [Primitive.Neg]([x]) {
1778
- return [zerosLike(x.ref).#binary(AluOp.Sub, x)];
1883
+ return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
1779
1884
  },
1780
1885
  [Primitive.Reciprocal]([x]) {
1781
1886
  return [x.#unary(AluOp.Reciprocal)];
@@ -1795,7 +1900,11 @@ var Array$1 = class Array$1 extends Tracer {
1795
1900
  x.#backend.incRef(x.#source);
1796
1901
  const pending = x.#pending;
1797
1902
  for (const exe of pending) exe.updateRc(1);
1798
- const y = new Array$1(x.#source, x.#st, dtype, x.#backend, pending);
1903
+ const y = x.#newArrayFrom({
1904
+ dtype,
1905
+ weakType: false,
1906
+ pending
1907
+ });
1799
1908
  x.dispose();
1800
1909
  return [y];
1801
1910
  }
@@ -1825,6 +1934,12 @@ var Array$1 = class Array$1 extends Tracer {
1825
1934
  [Primitive.Cos]([x]) {
1826
1935
  return [x.#unary(AluOp.Cos)];
1827
1936
  },
1937
+ [Primitive.Asin]([x]) {
1938
+ return [x.#unary(AluOp.Asin)];
1939
+ },
1940
+ [Primitive.Atan]([x]) {
1941
+ return [x.#unary(AluOp.Atan)];
1942
+ },
1828
1943
  [Primitive.Exp]([x]) {
1829
1944
  return [x.#unary(AluOp.Exp)];
1830
1945
  },
@@ -1864,7 +1979,7 @@ var Array$1 = class Array$1 extends Tracer {
1864
1979
  },
1865
1980
  [Primitive.Compare]([x, y], { op }) {
1866
1981
  const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
1867
- return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: DType.Bool })];
1982
+ return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
1868
1983
  },
1869
1984
  [Primitive.Where]([cond, x, y]) {
1870
1985
  const custom = ([cond$1, x$1, y$1]) => AluExp.where(cond$1, x$1, y$1);
@@ -1910,7 +2025,14 @@ var Array$1 = class Array$1 extends Tracer {
1910
2025
  pending.splice(0, 0, ...prevPending);
1911
2026
  args.forEach((x) => x.dispose());
1912
2027
  return outputs.map((source, i) => {
1913
- return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, pending);
2028
+ return new Array$1({
2029
+ source,
2030
+ st: ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
2031
+ dtype: jaxpr.outs[i].aval.dtype,
2032
+ weakType: jaxpr.outs[i].aval.weakType,
2033
+ backend,
2034
+ pending
2035
+ });
1914
2036
  });
1915
2037
  }
1916
2038
  };
@@ -1920,33 +2042,11 @@ var Array$1 = class Array$1 extends Tracer {
1920
2042
  return this.#source;
1921
2043
  }
1922
2044
  };
1923
- /** Construct an array from a single scalar constant. */
1924
- function scalar(value, { dtype, device } = {}) {
1925
- if (typeof value === "number") {
1926
- dtype ??= DType.Float32;
1927
- if (![
1928
- DType.Float32,
1929
- DType.Float16,
1930
- DType.Int32,
1931
- DType.Uint32
1932
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1933
- } else if (typeof value === "boolean") {
1934
- dtype ??= DType.Bool;
1935
- if (![
1936
- DType.Float32,
1937
- DType.Float16,
1938
- DType.Int32,
1939
- DType.Uint32,
1940
- DType.Bool
1941
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1942
- } else throw new TypeError(`Invalid type for scalar ${value}`);
1943
- return new Array$1(AluExp.const(dtype, value), ShapeTracker.fromShape([]), dtype, getBackend(device));
1944
- }
1945
2045
  /** Constructor for creating a new array from data. */
1946
2046
  function array(values, { shape: shape$1, dtype, device } = {}) {
1947
2047
  if (values instanceof Tracer) {
1948
2048
  if (shape$1 && !deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
1949
- if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
2049
+ if (dtype && values.dtype !== dtype) values = values.astype(dtype);
1950
2050
  return values;
1951
2051
  } else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
1952
2052
  dtype,
@@ -1968,6 +2068,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1968
2068
  dtype,
1969
2069
  device
1970
2070
  });
2071
+ if (size$1 === 1) return full(shape$1, flat[0], {
2072
+ dtype,
2073
+ device
2074
+ });
1971
2075
  if (typeof flat[0] === "boolean") {
1972
2076
  dtype = dtype ?? DType.Bool;
1973
2077
  const data = new Int32Array(flat.map((x) => x ? 1 : 0));
@@ -1976,46 +2080,51 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1976
2080
  device
1977
2081
  });
1978
2082
  } else {
2083
+ const weakType = dtype == void 0;
1979
2084
  dtype = dtype ?? DType.Float32;
1980
2085
  const data = dtypedJsArray(dtype, flat);
1981
2086
  return arrayFromData(data, shape$1, {
1982
2087
  dtype,
1983
2088
  device
1984
- });
2089
+ }, weakType);
1985
2090
  }
1986
2091
  }
1987
2092
  }
1988
- function arrayFromData(data, shape$1, { dtype, device } = {}) {
2093
+ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2094
+ if (data instanceof Float32Array) {
2095
+ if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
2096
+ dtype ??= DType.Float32;
2097
+ } else if (data instanceof Int32Array) {
2098
+ if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2099
+ dtype ??= DType.Int32;
2100
+ } else if (data instanceof Uint32Array) {
2101
+ if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2102
+ dtype ??= DType.Uint32;
2103
+ } else if (data instanceof Float16Array) {
2104
+ if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
2105
+ dtype ??= DType.Float16;
2106
+ } else throw new Error("Unsupported data array type: " + data.constructor.name);
1989
2107
  if (data.length < inlineArrayLimit) {
1990
2108
  let allEqual = true;
1991
2109
  for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
1992
2110
  allEqual = false;
1993
2111
  break;
1994
2112
  }
1995
- if (allEqual) return full(shape$1, data[0], {
1996
- dtype,
1997
- device
1998
- });
2113
+ if (allEqual) {
2114
+ const sa = new ShapedArray(shape$1, dtype, weakType);
2115
+ return fullInternal(sa, data[0], device);
2116
+ }
1999
2117
  }
2000
2118
  const backend = getBackend(device);
2001
- if (ArrayBuffer.isView(data)) {
2002
- const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2003
- if (data instanceof Float32Array) {
2004
- if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
2005
- dtype ??= DType.Float32;
2006
- } else if (data instanceof Int32Array) {
2007
- if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2008
- dtype ??= DType.Int32;
2009
- } else if (data instanceof Uint32Array) {
2010
- if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2011
- dtype ??= DType.Uint32;
2012
- } else if (data instanceof Float16Array) {
2013
- if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
2014
- dtype ??= DType.Float16;
2015
- } else throw new Error("Unsupported data array type: " + data.constructor.name);
2016
- const slot = backend.malloc(data.byteLength, buf);
2017
- return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype, backend);
2018
- } else throw new Error("Unsupported data type: " + data.constructor.name);
2119
+ const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2120
+ const slot = backend.malloc(data.byteLength, buf);
2121
+ return new Array$1({
2122
+ source: slot,
2123
+ st: ShapeTracker.fromShape(shape$1),
2124
+ dtype,
2125
+ weakType,
2126
+ backend
2127
+ });
2019
2128
  }
2020
2129
  function dataToJs(dtype, data, shape$1) {
2021
2130
  if (shape$1.length === 0) return dtype === DType.Bool ? Boolean(data[0]) : data[0];
@@ -2031,7 +2140,7 @@ function dataToJs(dtype, data, shape$1) {
2031
2140
  /** If x is a value, lift it into an array, otherwise leave it be. */
2032
2141
  function pureArray(x) {
2033
2142
  if (x instanceof Tracer) return x;
2034
- else return scalar(x);
2143
+ else return array(x);
2035
2144
  }
2036
2145
  var EvalTrace = class extends Trace {
2037
2146
  pure = (x) => pureArray(x);
@@ -2042,20 +2151,27 @@ var EvalTrace = class extends Trace {
2042
2151
  };
2043
2152
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
2044
2153
  const implRules = Array$1._implRules();
2045
- function zerosLike(val, dtype) {
2046
- const aval = getAval(val);
2047
- if (val instanceof Tracer) val.dispose();
2048
- return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2154
+ function fullInternal(aval, fillValue, device) {
2155
+ return new Array$1({
2156
+ source: AluExp.const(aval.dtype, fillValue),
2157
+ st: ShapeTracker.fromShape(aval.shape),
2158
+ dtype: aval.dtype,
2159
+ weakType: aval.weakType,
2160
+ backend: getBackend(device)
2161
+ });
2049
2162
  }
2050
- function onesLike(val, dtype) {
2051
- const aval = getAval(val);
2052
- if (val instanceof Tracer) val.dispose();
2053
- return ones(aval.shape, { dtype: dtype ?? aval.dtype });
2163
+ function zerosLike$1(val, dtype) {
2164
+ return fullLike(val, 0, dtype);
2165
+ }
2166
+ function onesLike$1(val, dtype) {
2167
+ return fullLike(val, 1, dtype);
2054
2168
  }
2055
2169
  function fullLike(val, fillValue, dtype) {
2056
2170
  const aval = getAval(val);
2057
2171
  if (val instanceof Tracer) val.dispose();
2058
- return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
2172
+ if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
2173
+ const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
2174
+ return fullInternal(sa, fillValue);
2059
2175
  }
2060
2176
  /** Return a new array of given shape and type, filled with zeros. */
2061
2177
  function zeros(shape$1, { dtype, device } = {}) {
@@ -2073,19 +2189,14 @@ function ones(shape$1, { dtype, device } = {}) {
2073
2189
  }
2074
2190
  /** Return a new array of given shape and type, filled with `fill_value`. */
2075
2191
  function full(shape$1, fillValue, { dtype, device } = {}) {
2076
- let source;
2077
- if (typeof fillValue === "number") {
2078
- dtype = dtype ?? DType.Float32;
2079
- source = AluExp.const(dtype, fillValue);
2080
- } else if (typeof fillValue === "bigint") {
2081
- dtype = dtype ?? DType.Int32;
2082
- source = AluExp.const(dtype, Number(fillValue));
2083
- } else if (typeof fillValue === "boolean") {
2192
+ let weakType = dtype == void 0;
2193
+ if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
2194
+ else if (typeof fillValue === "boolean") {
2084
2195
  dtype = dtype ?? DType.Bool;
2085
- source = AluExp.const(dtype, fillValue ? 1 : 0);
2196
+ weakType = false;
2086
2197
  } else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
2087
2198
  else throw new TypeError(`Invalid type for full: ${fillValue}`);
2088
- return new Array$1(source, ShapeTracker.fromShape(shape$1), dtype ?? DType.Float32, getBackend(device));
2199
+ return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
2089
2200
  }
2090
2201
  /**
2091
2202
  * Create an identity matrix.
@@ -2095,6 +2206,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
2095
2206
  */
2096
2207
  function eye(numRows, numCols, { dtype, device } = {}) {
2097
2208
  numCols = numCols ?? numRows;
2209
+ const weakType = dtype == void 0;
2098
2210
  dtype = dtype ?? DType.Float32;
2099
2211
  if (numCols < numRows) {
2100
2212
  const arr = eye(numCols, numRows, {
@@ -2108,9 +2220,15 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2108
2220
  device
2109
2221
  });
2110
2222
  const exp$2 = AluExp.cmplt(AluExp.mod(AluVar.idx, AluExp.i32(numCols + 1)), AluExp.i32(1));
2111
- return new Array$1(AluExp.cast(dtype, exp$2), ShapeTracker.fromShape([numRows, numCols]), dtype, getBackend(device));
2223
+ return new Array$1({
2224
+ source: AluExp.cast(dtype, exp$2),
2225
+ st: ShapeTracker.fromShape([numRows, numCols]),
2226
+ dtype,
2227
+ weakType,
2228
+ backend: getBackend(device)
2229
+ });
2112
2230
  }
2113
- /** Return the identity array, with ones on the main diagonal. */
2231
+ /** Return the identity matrix, with ones on the main diagonal. */
2114
2232
  function identity$1(n, { dtype, device } = {}) {
2115
2233
  return eye(n, n, {
2116
2234
  dtype,
@@ -2145,7 +2263,13 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
2145
2263
  });
2146
2264
  const exp$2 = AluExp.add(AluExp.const(dtype, start), AluExp.mul(AluExp.cast(dtype, AluVar.idx), AluExp.const(dtype, step)));
2147
2265
  const st = ShapeTracker.fromShape([size$1]);
2148
- return new Array$1(exp$2, st, dtype, getBackend(device));
2266
+ return new Array$1({
2267
+ source: exp$2,
2268
+ st,
2269
+ dtype,
2270
+ weakType: false,
2271
+ backend: getBackend(device)
2272
+ });
2149
2273
  }
2150
2274
  /**
2151
2275
  * Return evenly spaced numbers over a specified interval.
@@ -2163,10 +2287,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2163
2287
  dtype,
2164
2288
  device
2165
2289
  });
2166
- else if (num === 1) return scalar(start, {
2290
+ else if (num === 1) return full([1], start, {
2167
2291
  dtype,
2168
2292
  device
2169
- }).reshape([1]);
2293
+ });
2170
2294
  else if (start === stop) return full([num], start, {
2171
2295
  dtype,
2172
2296
  device
@@ -2175,7 +2299,13 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2175
2299
  const denom = endpoint ? num - 1 : num;
2176
2300
  const exp$2 = AluExp.cast(dtype, AluExp.add(AluExp.f32(start), AluExp.mul(AluExp.f32(delta / denom), AluExp.cast(DType.Float32, AluVar.idx))));
2177
2301
  const st = ShapeTracker.fromShape([num]);
2178
- return new Array$1(exp$2, st, dtype, getBackend(device));
2302
+ return new Array$1({
2303
+ source: exp$2,
2304
+ st,
2305
+ dtype,
2306
+ weakType: false,
2307
+ backend: getBackend(device)
2308
+ });
2179
2309
  }
2180
2310
  function aluCompare(a, b, op) {
2181
2311
  switch (op) {
@@ -2187,35 +2317,6 @@ function aluCompare(a, b, op) {
2187
2317
  case CompareOp.LessEqual: return AluExp.add(AluExp.cmplt(a, b), AluExp.cmpne(a, b).not());
2188
2318
  }
2189
2319
  }
2190
- /**
2191
- * Implements a NumPy-style generalized broadcast rule on two array shapes.
2192
- *
2193
- * "When operating on two arrays, NumPy compares their shapes element-wise. It
2194
- * starts with the trailing (i.e. rightmost) dimension and works its way left.
2195
- * Two dimensions are compatible when:
2196
- * 1. they are equal, or
2197
- * 2. one of them is 1."
2198
- *
2199
- * Throws a TypeError if the broadcast is not possible.
2200
- *
2201
- * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
2202
- */
2203
- function generalBroadcast(a, b) {
2204
- const out = [];
2205
- let i = a.length - 1;
2206
- let j = b.length - 1;
2207
- for (; i >= 0 && j >= 0; i--, j--) {
2208
- const x = a[i];
2209
- const y = b[j];
2210
- if (x === y) out.push(x);
2211
- else if (x === 1) out.push(y);
2212
- else if (y === 1) out.push(x);
2213
- else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
2214
- }
2215
- for (; i >= 0; i--) out.push(a[i]);
2216
- for (; j >= 0; j--) out.push(b[j]);
2217
- return out.reverse();
2218
- }
2219
2320
 
2220
2321
  //#endregion
2221
2322
  //#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
@@ -2291,13 +2392,15 @@ var Var = class Var {
2291
2392
  };
2292
2393
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
2293
2394
  var Lit = class {
2294
- dtype;
2295
2395
  value;
2296
2396
  aval;
2297
- constructor(dtype, value) {
2298
- this.dtype = dtype;
2397
+ get dtype() {
2398
+ return this.aval.dtype;
2399
+ }
2400
+ constructor(aval, value) {
2401
+ if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
2299
2402
  this.value = value;
2300
- this.aval = new ShapedArray([], dtype);
2403
+ this.aval = ShapedArray.fromAval(aval);
2301
2404
  }
2302
2405
  };
2303
2406
  function atomIsLit(atom, literal) {
@@ -2386,16 +2489,19 @@ var Jaxpr = class Jaxpr {
2386
2489
  varIds.set(v, FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
2387
2490
  return id;
2388
2491
  };
2389
- hasher.update(this.inBinders.length, ...this.inBinders.map(vi));
2390
- hasher.update(this.eqns.length, ...this.eqns.flatMap((eqn) => [
2391
- eqn.primitive,
2392
- eqn.inputs.length,
2393
- ...eqn.inputs.map((x) => x instanceof Var ? vi(x) : x.value),
2394
- JSON.stringify(eqn.params),
2395
- eqn.outBinders.length,
2396
- ...eqn.outBinders.map(vi)
2397
- ]));
2398
- hasher.update(this.outs.length, ...this.outs.map((x) => x instanceof Var ? vi(x) : x.value));
2492
+ hasher.update(this.inBinders.length);
2493
+ for (const x of this.inBinders) hasher.update(vi(x));
2494
+ hasher.update(this.eqns.length);
2495
+ for (const eqn of this.eqns) {
2496
+ hasher.update(eqn.primitive);
2497
+ hasher.update(eqn.inputs.length);
2498
+ for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
2499
+ hasher.update(JSON.stringify(eqn.params));
2500
+ hasher.update(eqn.outBinders.length);
2501
+ for (const x of eqn.outBinders) hasher.update(vi(x));
2502
+ }
2503
+ hasher.update(this.outs.length);
2504
+ for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
2399
2505
  return this.#hash = hasher.value;
2400
2506
  }
2401
2507
  hash(state) {
@@ -2418,21 +2524,26 @@ var Jaxpr = class Jaxpr {
2418
2524
  const c = eqn.outBinders[0];
2419
2525
  if (atomIsLit(a, 0)) context.set(c, b);
2420
2526
  else if (atomIsLit(b, 0)) context.set(c, a);
2421
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.dtype === DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2527
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2528
+ else newEqns.push(eqn);
2529
+ } else if (eqn.primitive === Primitive.Neg) {
2530
+ const [a] = inputs;
2531
+ const c = eqn.outBinders[0];
2532
+ if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
2422
2533
  else newEqns.push(eqn);
2423
2534
  } else if (eqn.primitive === Primitive.Mul) {
2424
2535
  const [a, b] = inputs;
2425
2536
  const c = eqn.outBinders[0];
2426
2537
  if (atomIsLit(a, 1)) context.set(c, b);
2427
2538
  else if (atomIsLit(b, 1)) context.set(c, a);
2428
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.value * b.value));
2539
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
2429
2540
  else newEqns.push(eqn);
2430
2541
  } else if (eqn.primitive === Primitive.Idiv) {
2431
2542
  const [a, b] = inputs;
2432
2543
  const c = eqn.outBinders[0];
2433
2544
  if (atomIsLit(b, 1)) context.set(c, a);
2434
2545
  else newEqns.push(eqn);
2435
- } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2546
+ } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2436
2547
  else newEqns.push(eqn);
2437
2548
  }
2438
2549
  const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
@@ -2523,7 +2634,7 @@ function evalJaxpr(jaxpr, args) {
2523
2634
  if (x instanceof Var) {
2524
2635
  remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
2525
2636
  return env.get(x);
2526
- } else return scalar(x.value, { dtype: x.dtype });
2637
+ } else return array(x.value, { dtype: x.dtype });
2527
2638
  };
2528
2639
  const write = (v, val) => {
2529
2640
  if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
@@ -2582,7 +2693,7 @@ var JaxprTrace = class extends Trace {
2582
2693
  let tracer = this.builder.constTracers.get(val);
2583
2694
  if (tracer === void 0) {
2584
2695
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
2585
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : scalar(val));
2696
+ this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
2586
2697
  }
2587
2698
  return tracer;
2588
2699
  }
@@ -2651,7 +2762,7 @@ function _inlineLiterals(jaxpr, consts) {
2651
2762
  const newConsts = [];
2652
2763
  for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
2653
2764
  const ar = consts[i];
2654
- literals.set(jaxpr.inBinders[i], new Lit(ar.dtype, ar.dataSync()[0]));
2765
+ literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
2655
2766
  } else {
2656
2767
  constBinders.push(jaxpr.inBinders[i]);
2657
2768
  newConsts.push(consts[i]);
@@ -2664,13 +2775,12 @@ function _inlineLiterals(jaxpr, consts) {
2664
2775
  }
2665
2776
  function binopAbstractEval([x, y]) {
2666
2777
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
2667
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2668
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
2778
+ return [promoteAvals(x, y)];
2669
2779
  }
2670
2780
  function compareAbstractEval([x, y]) {
2671
2781
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
2672
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2673
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), DType.Bool)];
2782
+ const aval = promoteAvals(x, y);
2783
+ return [new ShapedArray(aval.shape, DType.Bool, false)];
2674
2784
  }
2675
2785
  function vectorizedUnopAbstractEval([x]) {
2676
2786
  return [ShapedArray.fromAval(x)];
@@ -2683,21 +2793,23 @@ const abstractEvalRules = {
2683
2793
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
2684
2794
  [Primitive.StopGradient]: vectorizedUnopAbstractEval,
2685
2795
  [Primitive.Cast]([x], { dtype }) {
2686
- return [new ShapedArray(x.shape, dtype)];
2796
+ return [new ShapedArray(x.shape, dtype, false)];
2687
2797
  },
2688
2798
  [Primitive.Bitcast]([x], { dtype }) {
2689
2799
  if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
2690
2800
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
2691
- return [new ShapedArray(x.shape, dtype)];
2801
+ return [new ShapedArray(x.shape, dtype, false)];
2692
2802
  },
2693
2803
  [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
2694
2804
  if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
2695
2805
  const keyShape = generalBroadcast(k0.shape, k1.shape);
2696
2806
  if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2697
- return [new ShapedArray(shape$1, DType.Uint32)];
2807
+ return [new ShapedArray(shape$1, DType.Uint32, false)];
2698
2808
  },
2699
2809
  [Primitive.Sin]: vectorizedUnopAbstractEval,
2700
2810
  [Primitive.Cos]: vectorizedUnopAbstractEval,
2811
+ [Primitive.Asin]: vectorizedUnopAbstractEval,
2812
+ [Primitive.Atan]: vectorizedUnopAbstractEval,
2701
2813
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2702
2814
  [Primitive.Log]: vectorizedUnopAbstractEval,
2703
2815
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
@@ -2706,55 +2818,54 @@ const abstractEvalRules = {
2706
2818
  [Primitive.Reduce]([x], { axis }) {
2707
2819
  const axisSet = new Set(axis);
2708
2820
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2709
- return [new ShapedArray(newShape, x.dtype)];
2821
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2710
2822
  },
2711
2823
  [Primitive.Pool]([x], { window, strides }) {
2712
2824
  const shape$1 = checkPoolShape(x.shape, window, strides);
2713
- return [new ShapedArray(shape$1, x.dtype)];
2825
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2714
2826
  },
2715
2827
  [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
2716
2828
  const shape$1 = checkPoolShape(inShape, window, strides);
2717
2829
  if (!deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
2718
- return [new ShapedArray(inShape, x.dtype)];
2830
+ return [new ShapedArray(inShape, x.dtype, x.weakType)];
2719
2831
  },
2720
2832
  [Primitive.Dot]([x, y]) {
2721
- if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
2722
2833
  if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
2723
- const shape$1 = generalBroadcast(x.shape, y.shape);
2834
+ const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
2724
2835
  shape$1.splice(-1, 1);
2725
- return [new ShapedArray(shape$1, x.dtype)];
2836
+ return [new ShapedArray(shape$1, dtype, weakType)];
2726
2837
  },
2727
2838
  [Primitive.Conv]([lhs, rhs], params) {
2728
- if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
2839
+ const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
2729
2840
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
2730
- return [new ShapedArray(shape$1, lhs.dtype)];
2841
+ return [new ShapedArray(shape$1, dtype, weakType)];
2731
2842
  },
2732
2843
  [Primitive.Compare]: compareAbstractEval,
2733
2844
  [Primitive.Where]([cond, x, y]) {
2734
2845
  if (cond.dtype !== DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
2735
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2736
- const shape$1 = generalBroadcast(cond.shape, generalBroadcast(x.shape, y.shape));
2737
- return [new ShapedArray(shape$1, x.dtype)];
2846
+ const xy = promoteAvals(x, y);
2847
+ const shape$1 = generalBroadcast(cond.shape, xy.shape);
2848
+ return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
2738
2849
  },
2739
2850
  [Primitive.Transpose]([x], { perm }) {
2740
- return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
2851
+ return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
2741
2852
  },
2742
2853
  [Primitive.Broadcast]([x], { shape: shape$1 }) {
2743
- return [new ShapedArray(shape$1, x.dtype)];
2854
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2744
2855
  },
2745
2856
  [Primitive.Reshape]([x], { shape: shape$1 }) {
2746
- return [new ShapedArray(shape$1, x.dtype)];
2857
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2747
2858
  },
2748
2859
  [Primitive.Flip]([x], _) {
2749
- return [new ShapedArray(x.shape, x.dtype)];
2860
+ return [ShapedArray.fromAval(x)];
2750
2861
  },
2751
2862
  [Primitive.Shrink]([x], { slice }) {
2752
2863
  const newShape = slice.map((s) => s[1] - s[0]);
2753
- return [new ShapedArray(newShape, x.dtype)];
2864
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2754
2865
  },
2755
2866
  [Primitive.Pad]([x], { width }) {
2756
2867
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
2757
- return [new ShapedArray(newShape, x.dtype)];
2868
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2758
2869
  },
2759
2870
  [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2760
2871
  for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
@@ -2767,7 +2878,7 @@ const abstractEvalRules = {
2767
2878
  const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
2768
2879
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2769
2880
  newShape.splice(outDim, 0, ...gatherShape);
2770
- return [new ShapedArray(newShape, x.dtype)];
2881
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2771
2882
  },
2772
2883
  [Primitive.JitCall](args, { jaxpr }) {
2773
2884
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
@@ -2825,7 +2936,7 @@ function makeJaxpr$1(f, opts) {
2825
2936
  function jit$1(f, opts) {
2826
2937
  const cache = /* @__PURE__ */ new Map();
2827
2938
  const staticArgnums = new Set(opts?.staticArgnums ?? []);
2828
- return ((...args) => {
2939
+ const result = ((...args) => {
2829
2940
  const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
2830
2941
  const [argsFlat, inTree] = flatten(dynamicArgs);
2831
2942
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
@@ -2834,11 +2945,16 @@ function jit$1(f, opts) {
2834
2945
  const cacheKey = JSON.stringify(jaxprArgs);
2835
2946
  const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
2836
2947
  const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
2948
+ name: f.name || "closure",
2837
2949
  jaxpr,
2838
2950
  numConsts: consts.length
2839
2951
  });
2840
2952
  return unflatten(outTree, outs);
2841
2953
  });
2954
+ result.dispose = () => {
2955
+ for (const { consts } of cache.values()) for (const c of consts) c.dispose();
2956
+ };
2957
+ return result;
2842
2958
  }
2843
2959
 
2844
2960
  //#endregion
@@ -2869,7 +2985,7 @@ var JVPTrace = class extends Trace {
2869
2985
  return this.lift(pureArray(val));
2870
2986
  }
2871
2987
  lift(val) {
2872
- return new JVPTracer(this, val, zerosLike(val.ref));
2988
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
2873
2989
  }
2874
2990
  processPrimitive(primitive, tracers, params) {
2875
2991
  const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
@@ -2900,7 +3016,7 @@ function zeroTangentsJvp(primitive) {
2900
3016
  return (primals, tangents, params) => {
2901
3017
  for (const t of tangents) t.dispose();
2902
3018
  const ys = bind(primitive, primals, params);
2903
- return [ys, ys.map((y) => zerosLike(y.ref))];
3019
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
2904
3020
  };
2905
3021
  }
2906
3022
  const jvpRules = {
@@ -2918,13 +3034,13 @@ const jvpRules = {
2918
3034
  if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
2919
3035
  else {
2920
3036
  dx.dispose();
2921
- return [[cast(x.ref, dtype)], [zerosLike(x)]];
3037
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
2922
3038
  }
2923
3039
  },
2924
3040
  [Primitive.Bitcast]([x], [dx], { dtype }) {
2925
3041
  if (x.dtype === dtype) return [[x], [dx]];
2926
3042
  dx.dispose();
2927
- return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
3043
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
2928
3044
  },
2929
3045
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
2930
3046
  [Primitive.Sin]([x], [dx]) {
@@ -2933,6 +3049,14 @@ const jvpRules = {
2933
3049
  [Primitive.Cos]([x], [dx]) {
2934
3050
  return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
2935
3051
  },
3052
+ [Primitive.Asin]([x], [dx]) {
3053
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3054
+ return [[asin$1(x)], [denom.mul(dx)]];
3055
+ },
3056
+ [Primitive.Atan]([x], [dx]) {
3057
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3058
+ return [[atan$1(x)], [dx.div(denom)]];
3059
+ },
2936
3060
  [Primitive.Exp]([x], [dx]) {
2937
3061
  const z = exp$1(x);
2938
3062
  return [[z.ref], [z.mul(dx)]];
@@ -2983,13 +3107,14 @@ const jvpRules = {
2983
3107
  const indicesRef = indices.map((t) => t.ref);
2984
3108
  return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
2985
3109
  },
2986
- [Primitive.JitCall](primals, tangents, { jaxpr }) {
3110
+ [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
2987
3111
  const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
2988
3112
  const outs = bind(Primitive.JitCall, [
2989
3113
  ...newConsts.map((c) => c.ref),
2990
3114
  ...primals,
2991
3115
  ...tangents
2992
3116
  ], {
3117
+ name: `${name}_jvp`,
2993
3118
  jaxpr: newJaxpr,
2994
3119
  numConsts: newConsts.length
2995
3120
  });
@@ -3043,12 +3168,15 @@ function jvp$1(f, primals, tangents) {
3043
3168
  function mappedAval(batchDim, aval) {
3044
3169
  const shape$1 = [...aval.shape];
3045
3170
  shape$1.splice(batchDim, 1);
3046
- return new ShapedArray(shape$1, aval.dtype);
3171
+ return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3047
3172
  }
3048
3173
  /** Move one axis to a different index. */
3049
3174
  function moveaxis$1(x, src, dst) {
3050
3175
  const t = pureArray(x);
3051
- const perm = range(t.shape.length);
3176
+ src = checkAxis(src, t.ndim);
3177
+ dst = checkAxis(dst, t.ndim);
3178
+ if (src === dst) return t;
3179
+ const perm = range(t.ndim);
3052
3180
  perm.splice(src, 1);
3053
3181
  perm.splice(dst, 0, src);
3054
3182
  return transpose$1(t, perm);
@@ -3141,6 +3269,8 @@ const vmapRules = {
3141
3269
  [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3142
3270
  [Primitive.Sin]: unopBatcher(sin$1),
3143
3271
  [Primitive.Cos]: unopBatcher(cos$1),
3272
+ [Primitive.Asin]: unopBatcher(asin$1),
3273
+ [Primitive.Atan]: unopBatcher(atan$1),
3144
3274
  [Primitive.Exp]: unopBatcher(exp$1),
3145
3275
  [Primitive.Log]: unopBatcher(log$1),
3146
3276
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
@@ -3182,9 +3312,10 @@ const vmapRules = {
3182
3312
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3183
3313
  return [[pad$1(x, newWidth)], [xBdim]];
3184
3314
  },
3185
- [Primitive.JitCall](axisSize, args, dims, { jaxpr }) {
3315
+ [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3186
3316
  const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3187
3317
  const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3318
+ name: `${name}_vmap`,
3188
3319
  jaxpr: newJaxpr,
3189
3320
  numConsts: newConsts.length
3190
3321
  });
@@ -3200,7 +3331,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3200
3331
  if (dims[i] === null) return v.aval;
3201
3332
  const shape$1 = [...v.aval.shape];
3202
3333
  shape$1.splice(dims[i], 0, axisSize);
3203
- return new ShapedArray(shape$1, v.aval.dtype);
3334
+ return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3204
3335
  });
3205
3336
  const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3206
3337
  const result = {
@@ -3326,20 +3457,28 @@ function linearizeFlatUtil(f, primalsIn) {
3326
3457
  function linearizeFlat(f, primalsIn) {
3327
3458
  const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
3328
3459
  const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3329
- return [primalsOut, fLin];
3460
+ const dispose$1 = () => {
3461
+ for (const c of consts) c.dispose();
3462
+ };
3463
+ return [
3464
+ primalsOut,
3465
+ fLin,
3466
+ dispose$1
3467
+ ];
3330
3468
  }
3331
3469
  function linearize$1(f, ...primalsIn) {
3332
3470
  const [primalsInFlat, inTree] = flatten(primalsIn);
3333
3471
  const [fFlat, outTree] = flattenFun(f, inTree);
3334
- const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3472
+ const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3335
3473
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
3336
3474
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3337
- const fLin = (...tangentsIn) => {
3475
+ const fLin = ((...tangentsIn) => {
3338
3476
  const [tangentsInFlat, inTree2] = flatten(tangentsIn);
3339
3477
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
3340
3478
  const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
3341
3479
  return unflatten(outTree.value, tangentsOutFlat);
3342
- };
3480
+ });
3481
+ fLin.dispose = dispose$1;
3343
3482
  return [primalsOut, fLin];
3344
3483
  }
3345
3484
  var PartialEvalTracer = class extends Tracer {
@@ -3405,8 +3544,8 @@ var PartialEvalTrace = class extends Trace {
3405
3544
  processPrimitive(primitive, tracers, params) {
3406
3545
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3407
3546
  if (primitive === Primitive.JitCall) {
3408
- const { jaxpr, numConsts } = params;
3409
- return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
3547
+ const { name, jaxpr, numConsts } = params;
3548
+ return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3410
3549
  }
3411
3550
  const tracersIn = tracers.map((t) => this.instantiateConst(t));
3412
3551
  const avalsIn = tracersIn.map((t) => t.pval.aval);
@@ -3432,12 +3571,13 @@ var PartialEvalTrace = class extends Trace {
3432
3571
  *
3433
3572
  * Used when encountering a JitCall rule during the trace.
3434
3573
  */
3435
- #partialEvalJaxpr(jaxpr, numConsts, tracers) {
3574
+ #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3436
3575
  jaxpr = jaxpr.flatten();
3437
3576
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3438
3577
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3439
3578
  const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
3440
3579
  const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
3580
+ name: `${name}_peval`,
3441
3581
  jaxpr: jaxpr1,
3442
3582
  numConsts: 0
3443
3583
  });
@@ -3449,13 +3589,17 @@ var PartialEvalTrace = class extends Trace {
3449
3589
  prim: Primitive.JitCall,
3450
3590
  tracersIn: resTracers.concat(unknownTracers),
3451
3591
  params: {
3592
+ name: `${name}_resid`,
3452
3593
  jaxpr: jaxpr2,
3453
3594
  numConsts: 0
3454
3595
  },
3455
3596
  avalsOut: jaxpr2.outs.map((x) => x.aval),
3456
3597
  tracerRefsOut: []
3457
3598
  };
3458
- const outs2 = jaxpr2.outs.map((x) => new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe));
3599
+ const outs2 = jaxpr2.outs.map((x, i$1) => {
3600
+ if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
3601
+ return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
3602
+ });
3459
3603
  recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
3460
3604
  let i = 0;
3461
3605
  let j = 0;
@@ -3539,13 +3683,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
3539
3683
  const [consts, constvars] = unzip2(constToVar.entries());
3540
3684
  const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
3541
3685
  const outVars = tracersOut.map((t) => tracerToVar.get(t));
3542
- const jaxpr = new Jaxpr(inBinders, eqns, outVars);
3686
+ let jaxpr = new Jaxpr(inBinders, eqns, outVars);
3543
3687
  typecheckJaxpr(jaxpr);
3544
3688
  for (const t of consts) t.ref;
3545
3689
  for (const t of tracersIn) t.dispose();
3546
3690
  for (const t of tracersOut) t.dispose();
3691
+ jaxpr = jaxpr.simplify();
3692
+ if (DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
3547
3693
  return {
3548
- jaxpr: jaxpr.simplify(),
3694
+ jaxpr,
3549
3695
  consts
3550
3696
  };
3551
3697
  }
@@ -3586,7 +3732,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
3586
3732
  }
3587
3733
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
3588
3734
  const eqn = jaxpr.eqns[i];
3589
- const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? scalar(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3735
+ const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3590
3736
  const cotangentsOut = eqn.outBinders.map(readCotangent);
3591
3737
  const rule = transposeRules[eqn.primitive];
3592
3738
  if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
@@ -3766,7 +3912,7 @@ const transposeRules = {
3766
3912
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
3767
3913
  throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
3768
3914
  },
3769
- [Primitive.JitCall](cts, args, { jaxpr }) {
3915
+ [Primitive.JitCall](cts, args, { name, jaxpr }) {
3770
3916
  const undefPrimals = args.map((x) => x instanceof UndefPrimal);
3771
3917
  const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
3772
3918
  const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
@@ -3775,6 +3921,7 @@ const transposeRules = {
3775
3921
  ...residuals,
3776
3922
  ...cts
3777
3923
  ], {
3924
+ name: `${name}_t`,
3778
3925
  jaxpr: newJaxpr,
3779
3926
  numConsts: newConsts.length
3780
3927
  });
@@ -3811,20 +3958,28 @@ function vjpFlat(f, primalsIn) {
3811
3958
  const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
3812
3959
  return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
3813
3960
  };
3814
- return [primalsOut, fVjp];
3961
+ const dispose$1 = () => {
3962
+ for (const c of consts) c.dispose();
3963
+ };
3964
+ return [
3965
+ primalsOut,
3966
+ fVjp,
3967
+ dispose$1
3968
+ ];
3815
3969
  }
3816
3970
  function vjp$1(f, ...primalsIn) {
3817
3971
  const [primalsInFlat, inTree] = flatten(primalsIn);
3818
3972
  const [fFlat, outTree] = flattenFun(f, inTree);
3819
- const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3973
+ const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3820
3974
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
3821
3975
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3822
- const fVjp = (cotangentsOut) => {
3976
+ const fVjp = ((cotangentsOut) => {
3823
3977
  const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
3824
3978
  if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
3825
3979
  const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
3826
3980
  return unflatten(inTree, cotangentsInFlat);
3827
- };
3981
+ });
3982
+ fVjp.dispose = dispose$1;
3828
3983
  return [primalsOut, fVjp];
3829
3984
  }
3830
3985
  function grad$1(f) {
@@ -3841,8 +3996,9 @@ function valueAndGrad$1(f) {
3841
3996
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3842
3997
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3843
3998
  if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3844
- const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
3845
- for (const r of rest) r.dispose();
3999
+ const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
4000
+ for (const r of rest) dispose(r);
4001
+ fVjp.dispose();
3846
4002
  return [y, ct];
3847
4003
  };
3848
4004
  }
@@ -3850,7 +4006,13 @@ function jacrev$1(f) {
3850
4006
  return function jacobianReverse(x) {
3851
4007
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
3852
4008
  const [size$1] = x.shape;
3853
- const pullback = (ct) => vjp$1(f, x)[1](ct)[0];
4009
+ const pullback = (ct) => {
4010
+ const [y, fVjp] = vjp$1(f, x);
4011
+ y.dispose();
4012
+ const [ret] = fVjp(ct);
4013
+ fVjp.dispose();
4014
+ return ret;
4015
+ };
3854
4016
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
3855
4017
  };
3856
4018
  }
@@ -3930,19 +4092,38 @@ __export(numpy_exports, {
3930
4092
  DType: () => DType,
3931
4093
  abs: () => abs,
3932
4094
  absolute: () => absolute,
4095
+ acos: () => acos,
4096
+ acosh: () => acosh,
3933
4097
  add: () => add,
3934
4098
  allclose: () => allclose,
3935
4099
  arange: () => arange,
4100
+ arccos: () => arccos,
4101
+ arccosh: () => arccosh,
4102
+ arcsinh: () => arcsinh,
4103
+ arctan: () => arctan,
4104
+ arctan2: () => arctan2,
4105
+ arctanh: () => arctanh,
3936
4106
  argmax: () => argmax,
3937
4107
  argmin: () => argmin,
3938
4108
  array: () => array,
4109
+ asin: () => asin,
4110
+ asinh: () => asinh,
3939
4111
  astype: () => astype,
4112
+ atan: () => atan,
4113
+ atan2: () => atan2,
4114
+ atanh: () => atanh,
3940
4115
  bool: () => bool,
4116
+ broadcastArrays: () => broadcastArrays,
4117
+ broadcastShapes: () => broadcastShapes,
4118
+ broadcastTo: () => broadcastTo,
4119
+ cbrt: () => cbrt,
3941
4120
  clip: () => clip,
3942
4121
  columnStack: () => columnStack,
3943
4122
  concatenate: () => concatenate,
3944
4123
  cos: () => cos,
3945
4124
  cosh: () => cosh,
4125
+ deg2rad: () => deg2rad,
4126
+ degrees: () => degrees,
3946
4127
  diag: () => diag,
3947
4128
  diagonal: () => diagonal,
3948
4129
  divide: () => divide,
@@ -3953,6 +4134,7 @@ __export(numpy_exports, {
3953
4134
  eulerGamma: () => eulerGamma,
3954
4135
  exp: () => exp,
3955
4136
  exp2: () => exp2,
4137
+ expm1: () => expm1,
3956
4138
  eye: () => eye,
3957
4139
  flip: () => flip,
3958
4140
  fliplr: () => fliplr,
@@ -3964,14 +4146,17 @@ __export(numpy_exports, {
3964
4146
  greater: () => greater,
3965
4147
  greaterEqual: () => greaterEqual,
3966
4148
  hstack: () => hstack,
4149
+ hypot: () => hypot,
3967
4150
  identity: () => identity$1,
3968
4151
  inf: () => inf,
4152
+ inner: () => inner,
3969
4153
  int32: () => int32,
3970
4154
  less: () => less,
3971
4155
  lessEqual: () => lessEqual,
3972
4156
  linspace: () => linspace,
3973
4157
  log: () => log,
3974
4158
  log10: () => log10,
4159
+ log1p: () => log1p,
3975
4160
  log2: () => log2,
3976
4161
  matmul: () => matmul,
3977
4162
  max: () => max,
@@ -3987,35 +4172,49 @@ __export(numpy_exports, {
3987
4172
  negative: () => negative,
3988
4173
  notEqual: () => notEqual,
3989
4174
  ones: () => ones,
3990
- onesLike: () => onesLike$1,
4175
+ onesLike: () => onesLike,
4176
+ outer: () => outer,
3991
4177
  pad: () => pad,
3992
4178
  permuteDims: () => permuteDims,
3993
4179
  pi: () => pi,
4180
+ pow: () => pow,
4181
+ power: () => power,
3994
4182
  prod: () => prod$1,
4183
+ promoteTypes: () => promoteTypes,
4184
+ rad2deg: () => rad2deg,
4185
+ radians: () => radians,
3995
4186
  ravel: () => ravel,
3996
4187
  reciprocal: () => reciprocal,
4188
+ repeat: () => repeat,
3997
4189
  reshape: () => reshape,
3998
- scalar: () => scalar,
3999
4190
  shape: () => shape,
4191
+ sign: () => sign,
4000
4192
  sin: () => sin,
4001
4193
  sinh: () => sinh,
4002
4194
  size: () => size,
4003
4195
  sqrt: () => sqrt,
4004
4196
  square: () => square,
4005
4197
  stack: () => stack,
4198
+ std: () => std,
4199
+ subtract: () => subtract,
4006
4200
  sum: () => sum,
4007
4201
  tan: () => tan,
4008
4202
  tanh: () => tanh,
4203
+ tile: () => tile,
4009
4204
  transpose: () => transpose,
4205
+ tri: () => tri,
4206
+ tril: () => tril,
4207
+ triu: () => triu,
4010
4208
  trueDivide: () => trueDivide,
4011
4209
  trunc: () => trunc,
4012
4210
  uint32: () => uint32,
4211
+ var_: () => var_,
4013
4212
  vdot: () => vdot,
4014
4213
  vecdot: () => vecdot,
4015
4214
  vstack: () => vstack,
4016
4215
  where: () => where,
4017
4216
  zeros: () => zeros,
4018
- zerosLike: () => zerosLike$1
4217
+ zerosLike: () => zerosLike
4019
4218
  });
4020
4219
  const float32 = DType.Float32;
4021
4220
  const int32 = DType.Int32;
@@ -4032,54 +4231,66 @@ const inf = Number.POSITIVE_INFINITY;
4032
4231
  const nan = NaN;
4033
4232
  /** This is Pi, `π = 3.14159265358979...` */
4034
4233
  const pi = Math.PI;
4035
- /** Element-wise addition, with broadcasting. */
4234
+ /** @function Element-wise addition, with broadcasting. */
4036
4235
  const add = add$1;
4037
- /** Element-wise multiplication, with broadcasting. */
4236
+ /** @function Element-wise multiplication, with broadcasting. */
4038
4237
  const multiply = mul;
4039
- /** Numerical negative of every element of an array. */
4238
+ /** @function Numerical negative of every element of an array. */
4040
4239
  const negative = neg;
4041
- /** Calculate element-wise reciprocal of the input. This is `1/x`. */
4240
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
4042
4241
  const reciprocal = reciprocal$1;
4043
- /** Element-wise sine function (takes radians). */
4242
+ /** @function Element-wise sine function (takes radians). */
4044
4243
  const sin = sin$1;
4045
- /** Element-wise cosine function (takes radians). */
4244
+ /** @function Element-wise cosine function (takes radians). */
4046
4245
  const cos = cos$1;
4047
- /** Calculate the exponential of all elements in the input array. */
4246
+ /** @function Element-wise inverse sine function (inverse of sin). */
4247
+ const asin = asin$1;
4248
+ /** @function Element-wise inverse tangent function (inverse of tan). */
4249
+ const atan = atan$1;
4250
+ /** @function Calculate the exponential of all elements in the input array. */
4048
4251
  const exp = exp$1;
4049
- /** Calculate the natural logarithm of all elements in the input array. */
4252
+ /** @function Calculate the natural logarithm of all elements in the input array. */
4050
4253
  const log = log$1;
4051
- /** Calculate the square root of all elements in the input array. */
4254
+ /** @function Calculate the square root of all elements in the input array. */
4052
4255
  const sqrt = sqrt$1;
4053
- /** Return element-wise minimum of the input arrays. */
4256
+ /** @function Return element-wise minimum of the input arrays. */
4054
4257
  const minimum = min$1;
4055
- /** Return element-wise maximum of the input arrays. */
4258
+ /** @function Return element-wise maximum of the input arrays. */
4056
4259
  const maximum = max$1;
4057
- /** Compare two arrays element-wise. */
4260
+ /** @function Compare two arrays element-wise. */
4058
4261
  const greater = greater$1;
4059
- /** Compare two arrays element-wise. */
4262
+ /** @function Compare two arrays element-wise. */
4060
4263
  const less = less$1;
4061
- /** Compare two arrays element-wise. */
4264
+ /** @function Compare two arrays element-wise. */
4062
4265
  const equal = equal$1;
4063
- /** Compare two arrays element-wise. */
4266
+ /** @function Compare two arrays element-wise. */
4064
4267
  const notEqual = notEqual$1;
4065
- /** Compare two arrays element-wise. */
4268
+ /** @function Compare two arrays element-wise. */
4066
4269
  const greaterEqual = greaterEqual$1;
4067
- /** Compare two arrays element-wise. */
4270
+ /** @function Compare two arrays element-wise. */
4068
4271
  const lessEqual = lessEqual$1;
4069
- /** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
4272
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
4070
4273
  const where = where$1;
4071
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
4274
+ /**
4275
+ * @function
4276
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
4277
+ */
4072
4278
  const transpose = transpose$1;
4073
4279
  /**
4280
+ * @function
4074
4281
  * Give a new shape to an array without changing its data.
4075
4282
  *
4076
4283
  * One shape dimension can be -1. In this case, the value is inferred from the
4077
4284
  * length of the array and remaining dimensions.
4078
4285
  */
4079
4286
  const reshape = reshape$1;
4080
- /** Move axes of an array to new positions. Other axes retain original order. */
4287
+ /**
4288
+ * @function
4289
+ * Move axes of an array to new positions. Other axes retain original order.
4290
+ */
4081
4291
  const moveaxis = moveaxis$1;
4082
4292
  /**
4293
+ * @function
4083
4294
  * Add padding (zeros) to an array.
4084
4295
  *
4085
4296
  * The `width` argument is either an integer or pair of integers, in which case
@@ -4087,15 +4298,27 @@ const moveaxis = moveaxis$1;
4087
4298
  * pair specifies the padding for its corresponding axis.
4088
4299
  */
4089
4300
  const pad = pad$1;
4090
- /** Return the number of dimensions of an array. Does not consume array reference. */
4301
+ /**
4302
+ * @function
4303
+ * Return the number of dimensions of an array. Does not consume array reference.
4304
+ */
4091
4305
  const ndim = ndim$1;
4092
- /** Return the shape of an array. Does not consume array reference. */
4306
+ /** @function Return the shape of an array. Does not consume array reference. */
4093
4307
  const shape = getShape;
4094
- /** Return an array of zeros with the same shape and type as a given array. */
4095
- const zerosLike$1 = zerosLike;
4096
- /** Return an array of ones with the same shape and type as a given array. */
4097
- const onesLike$1 = onesLike;
4098
- /** Return a full array with the same shape and type as a given array. */
4308
+ /**
4309
+ * @function
4310
+ * Return an array of zeros with the same shape and type as a given array.
4311
+ */
4312
+ const zerosLike = zerosLike$1;
4313
+ /**
4314
+ * @function
4315
+ * Return an array of ones with the same shape and type as a given array.
4316
+ */
4317
+ const onesLike = onesLike$1;
4318
+ /**
4319
+ * @function
4320
+ * Return a full array with the same shape and type as a given array.
4321
+ */
4099
4322
  const fullLike$1 = fullLike;
4100
4323
  /**
4101
4324
  * Return the number of elements in an array, optionally along an axis.
@@ -4110,23 +4333,23 @@ function astype(a, dtype) {
4110
4333
  return fudgeArray(a).astype(dtype);
4111
4334
  }
4112
4335
  /** Sum of the elements of the array over a given axis, or axes. */
4113
- function sum(a, axis, opts) {
4336
+ function sum(a, axis = null, opts) {
4114
4337
  return reduce(a, AluOp.Add, axis, opts);
4115
4338
  }
4116
4339
  /** Product of the array elements over a given axis. */
4117
- function prod$1(a, axis, opts) {
4340
+ function prod$1(a, axis = null, opts) {
4118
4341
  return reduce(a, AluOp.Mul, axis, opts);
4119
4342
  }
4120
4343
  /** Return the minimum of array elements along a given axis. */
4121
- function min(a, axis, opts) {
4344
+ function min(a, axis = null, opts) {
4122
4345
  return reduce(a, AluOp.Min, axis, opts);
4123
4346
  }
4124
4347
  /** Return the maximum of array elements along a given axis. */
4125
- function max(a, axis, opts) {
4348
+ function max(a, axis = null, opts) {
4126
4349
  return reduce(a, AluOp.Max, axis, opts);
4127
4350
  }
4128
4351
  /** Compute the average of the array elements along the specified axis. */
4129
- function mean(a, axis, opts) {
4352
+ function mean(a, axis = null, opts) {
4130
4353
  return fudgeArray(a).mean(axis, opts);
4131
4354
  }
4132
4355
  /**
@@ -4142,8 +4365,8 @@ function argmin(a, axis, opts) {
4142
4365
  axis = 0;
4143
4366
  } else axis = checkAxis(axis, a.ndim);
4144
4367
  const shape$1 = a.shape;
4145
- const isMax = equal(a, min(a.ref, axis, { keepDims: true }));
4146
- const length = scalar(shape$1[axis], {
4368
+ const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
4369
+ const length = array(shape$1[axis], {
4147
4370
  dtype: int32,
4148
4371
  device: a.device
4149
4372
  });
@@ -4166,8 +4389,8 @@ function argmax(a, axis, opts) {
4166
4389
  axis = 0;
4167
4390
  } else axis = checkAxis(axis, a.ndim);
4168
4391
  const shape$1 = a.shape;
4169
- const isMax = equal(a, max(a.ref, axis, { keepDims: true }));
4170
- const length = scalar(shape$1[axis], {
4392
+ const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
4393
+ const length = array(shape$1[axis], {
4171
4394
  dtype: int32,
4172
4395
  device: a.device
4173
4396
  });
@@ -4178,17 +4401,9 @@ function argmax(a, axis, opts) {
4178
4401
  return length.sub(max(idx, axis, opts));
4179
4402
  }
4180
4403
  /** Reverse the elements in an array along the given axes. */
4181
- function flip(x, axis) {
4404
+ function flip(x, axis = null) {
4182
4405
  const nd = ndim(x);
4183
- if (axis === void 0) axis = range(nd);
4184
- else if (typeof axis === "number") axis = [axis];
4185
- const seen = /* @__PURE__ */ new Set();
4186
- for (let i = 0; i < axis.length; i++) {
4187
- if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
4188
- if (axis[i] < 0) axis[i] += nd;
4189
- if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
4190
- seen.add(axis[i]);
4191
- }
4406
+ axis = normalizeAxis(axis, nd);
4192
4407
  return flip$1(x, axis);
4193
4408
  }
4194
4409
  /**
@@ -4294,12 +4509,80 @@ function flipud(x) {
4294
4509
  function fliplr(x) {
4295
4510
  return flip(x, 1);
4296
4511
  }
4512
+ /** @function Alternative name for `numpy.transpose()`. */
4297
4513
  const permuteDims = transpose;
4298
4514
  /** Return a 1-D flattened array containing the elements of the input. */
4299
4515
  function ravel(a) {
4300
4516
  return fudgeArray(a).ravel();
4301
4517
  }
4302
4518
  /**
4519
+ * Repeat each element of an array, placing the copies immediately after it.
4520
+ *
4521
+ * If no axis is provided, use the flattened input array, and return a flat
4522
+ * output array.
4523
+ */
4524
+ function repeat(a, repeats, axis) {
4525
+ if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
4526
+ a = fudgeArray(a);
4527
+ if (axis === void 0) {
4528
+ a = ravel(a);
4529
+ axis = 0;
4530
+ }
4531
+ axis = checkAxis(axis, a.ndim);
4532
+ if (repeats === 1) return a;
4533
+ const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
4534
+ const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
4535
+ return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
4536
+ }
4537
+ /**
4538
+ * Construct an array by repeating A the number of times given by reps.
4539
+ *
4540
+ * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
4541
+ * integers, the resulting array will have a shape of `(reps[0] * d1,
4542
+ * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
4543
+ */
4544
+ function tile(a, reps) {
4545
+ a = fudgeArray(a);
4546
+ if (typeof reps === "number") reps = [reps];
4547
+ if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
4548
+ const ndiff = reps.length - a.ndim;
4549
+ if (ndiff > 0) a = a.reshape([...rep(ndiff, 1), ...a.shape]);
4550
+ if (ndiff < 0) reps = [...rep(-ndiff, 1), ...reps];
4551
+ const broadcastedShape = [];
4552
+ const broadcastAxes = [];
4553
+ for (let i = 0; i < a.ndim; i++) {
4554
+ if (reps[i] > 1) {
4555
+ broadcastedShape.push(reps[i]);
4556
+ broadcastAxes.push(broadcastedShape.length - 1);
4557
+ }
4558
+ broadcastedShape.push(a.shape[i]);
4559
+ }
4560
+ const finalShape = a.shape.map((d, i) => reps[i] * d);
4561
+ return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
4562
+ }
4563
+ /**
4564
+ * Broadcast an array to a shape, with NumPy-style broadcasting rules.
4565
+ *
4566
+ * In other words, this lets you append axes to the left, and/or expand
4567
+ * dimensions where the shape is 1.
4568
+ */
4569
+ function broadcastTo(a, shape$1) {
4570
+ const nd = ndim(a);
4571
+ if (shape$1.length < nd) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${nd}`);
4572
+ return broadcast(a, shape$1, range(shape$1.length - nd));
4573
+ }
4574
+ /** Broadcast input shapes to a common output shape. */
4575
+ function broadcastShapes(...shapes) {
4576
+ if (shapes.length === 0) return [];
4577
+ return shapes.reduce(generalBroadcast);
4578
+ }
4579
+ /** Broadcast arrays to a common shape. */
4580
+ function broadcastArrays(...arrays) {
4581
+ const shapes = arrays.map((a) => shape(a));
4582
+ const outShape = broadcastShapes(...shapes);
4583
+ return arrays.map((a) => broadcastTo(a, outShape));
4584
+ }
4585
+ /**
4303
4586
  * Return specified diagonals.
4304
4587
  *
4305
4588
  * If a is 2D, return the diagonal of the array with the given offset. If a is
@@ -4323,7 +4606,7 @@ function diag(v, k = 0) {
4323
4606
  if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
4324
4607
  if (a.ndim === 1) {
4325
4608
  const n = a.shape[0];
4326
- const ret = where(eye(n).equal(1), a.ref, zerosLike$1(a));
4609
+ const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
4327
4610
  if (k > 0) return pad(ret, [[0, k], [k, 0]]);
4328
4611
  else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
4329
4612
  else return ret;
@@ -4367,8 +4650,36 @@ function dot(x, y) {
4367
4650
  ]);
4368
4651
  return dot$1(x, y);
4369
4652
  }
4370
- /** Vector dot product of two arrays. */
4371
- function vecdot(x, y) {
4653
+ /**
4654
+ * Compute the inner product of two arrays.
4655
+ *
4656
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
4657
+ * contraction on the last axis.
4658
+ *
4659
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
4660
+ */
4661
+ function inner(x, y) {
4662
+ x = reshape(x, shape(x).toSpliced(-1, 0, ...rep(ndim(y) - 1, 1)));
4663
+ return dot$1(x, y);
4664
+ }
4665
+ /**
4666
+ * Compute the outer product of two arrays.
4667
+ *
4668
+ * If the input arrays are not 1D, they will be flattened. Returned array will
4669
+ * be of shape `[x.size, y.size]`.
4670
+ */
4671
+ function outer(x, y) {
4672
+ x = ravel(x);
4673
+ y = ravel(y);
4674
+ return multiply(x.reshape([x.shape[0], 1]), y);
4675
+ }
4676
+ /** Vector dot product of two arrays along a given axis. */
4677
+ function vecdot(x, y, { axis } = {}) {
4678
+ const xaxis = checkAxis(axis ?? -1, ndim(x));
4679
+ const yaxis = checkAxis(axis ?? -1, ndim(y));
4680
+ if (shape(x)[xaxis] !== shape(y)[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(shape(x))} and ${JSON.stringify(shape(y))} not aligned along axis ${axis}: ${shape(x)[xaxis]} != ${shape(y)[yaxis]}`);
4681
+ x = moveaxis(x, xaxis, -1);
4682
+ y = moveaxis(y, yaxis, -1);
4372
4683
  return dot$1(x, y);
4373
4684
  }
4374
4685
  /**
@@ -4377,7 +4688,7 @@ function vecdot(x, y) {
4377
4688
  * Like vecdot() but flattens the arguments first into vectors.
4378
4689
  */
4379
4690
  function vdot(x, y) {
4380
- return vecdot(ravel(x), ravel(y));
4691
+ return dot$1(ravel(x), ravel(y));
4381
4692
  }
4382
4693
  /**
4383
4694
  * Return a tuple of coordinate matrices from coordinate vectors.
@@ -4406,6 +4717,43 @@ function meshgrid(xs, { indexing } = {}) {
4406
4717
  return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
4407
4718
  }
4408
4719
  /**
4720
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
4721
+ *
4722
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
4723
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
4724
+ * `k>0` is above it.
4725
+ */
4726
+ function tri(n, m, k = 0, { dtype, device } = {}) {
4727
+ m ??= n;
4728
+ dtype ??= DType.Float32;
4729
+ if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
4730
+ if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
4731
+ if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
4732
+ const rows = arange(k, n + k, 1, {
4733
+ dtype: DType.Int32,
4734
+ device
4735
+ });
4736
+ const cols = arange(0, m, 1, {
4737
+ dtype: DType.Int32,
4738
+ device
4739
+ });
4740
+ return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
4741
+ }
4742
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
4743
+ function tril(a, k = 0) {
4744
+ if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
4745
+ a = fudgeArray(a);
4746
+ const [n, m] = a.shape.slice(-2);
4747
+ return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
4748
+ }
4749
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
4750
+ function triu(a, k = 0) {
4751
+ if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
4752
+ a = fudgeArray(a);
4753
+ const [n, m] = a.shape.slice(-2);
4754
+ return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
4755
+ }
4756
+ /**
4409
4757
  * Clip (limit) the values in an array.
4410
4758
  *
4411
4759
  * Given an interval, values outside the interval are clipped to the interval
@@ -4429,18 +4777,70 @@ function absolute(x) {
4429
4777
  x = fudgeArray(x);
4430
4778
  return where(less(x.ref, 0), x.ref.mul(-1), x);
4431
4779
  }
4432
- /** Alias of `jax.numpy.absolute()`. */
4780
+ /** @function Alias of `jax.numpy.absolute()`. */
4433
4781
  const abs = absolute;
4782
+ /** Return an element-wise indication of sign of the input. */
4783
+ function sign(x) {
4784
+ x = fudgeArray(x);
4785
+ return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
4786
+ }
4434
4787
  /** Calculate element-wise square of the input array. */
4435
4788
  function square(x) {
4436
4789
  x = fudgeArray(x);
4437
4790
  return x.ref.mul(x);
4438
4791
  }
4439
- /** Compute a trigonometric tangent of each element of input. */
4792
+ /** Element-wise tangent function (takes radians). */
4440
4793
  function tan(x) {
4441
4794
  x = fudgeArray(x);
4442
4795
  return sin(x.ref).div(cos(x));
4443
4796
  }
4797
+ /** Element-wise inverse cosine function (inverse of cos). */
4798
+ function acos(x) {
4799
+ return subtract(pi / 2, asin(x));
4800
+ }
4801
+ /**
4802
+ * @function
4803
+ * Return element-wise hypotenuse for the given legs of a right triangle.
4804
+ *
4805
+ * In the original NumPy/JAX implementation, this function is more numerically
4806
+ * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4807
+ * improvements.
4808
+ */
4809
+ const hypot = jit$1(function hypot$1(x1, x2) {
4810
+ return sqrt(square(x1).add(square(x2)));
4811
+ });
4812
+ /**
4813
+ * @function
4814
+ * Element-wise arc tangent of y/x with correct quadrant.
4815
+ *
4816
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
4817
+ * The result is in the range [-π, π].
4818
+ *
4819
+ * Uses numerically stable formulas:
4820
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
4821
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
4822
+ *
4823
+ * The output is ill-defined when both x and y are zero.
4824
+ */
4825
+ const atan2 = jit$1(function atan2$1(y, x) {
4826
+ const r = sqrt(square(x.ref).add(square(y.ref)));
4827
+ const xNeg = less(x.ref, 0);
4828
+ const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
4829
+ const denom = where(xNeg, y, r.add(x));
4830
+ return atan(numer.div(denom)).mul(2);
4831
+ });
4832
+ /** @function Alias of `jax.numpy.acos()`. */
4833
+ const arccos = acos;
4834
+ /** @function Alias of `jax.numpy.atan()`. */
4835
+ const arctan = atan;
4836
+ /** @function Alias of `jax.numpy.atan2()`. */
4837
+ const arctan2 = atan2;
4838
+ /** Element-wise subtraction, with broadcasting. */
4839
+ function subtract(x, y) {
4840
+ x = fudgeArray(x);
4841
+ y = fudgeArray(y);
4842
+ return x.sub(y);
4843
+ }
4444
4844
  /** Calculates the floating-point division of x by y element-wise. */
4445
4845
  function trueDivide(x, y) {
4446
4846
  x = fudgeArray(x);
@@ -4448,7 +4848,7 @@ function trueDivide(x, y) {
4448
4848
  if (!isFloatDtype(x.dtype) || !isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
4449
4849
  return x.div(y);
4450
4850
  }
4451
- /** Alias of `jax.numpy.trueDivide()`. */
4851
+ /** @function Alias of `jax.numpy.trueDivide()`. */
4452
4852
  const divide = trueDivide;
4453
4853
  /** Round input to the nearest integer towards zero. */
4454
4854
  function trunc(x) {
@@ -4466,36 +4866,134 @@ function log2(x) {
4466
4866
  function log10(x) {
4467
4867
  return log(x).mul(Math.LOG10E);
4468
4868
  }
4869
+ /** Calculate `exp(x) - 1` element-wise. */
4870
+ function expm1(x) {
4871
+ return exp(x).sub(1);
4872
+ }
4873
+ /** Calculate the natural logarithm of `1 + x` element-wise. */
4874
+ function log1p(x) {
4875
+ return log(add(1, x));
4876
+ }
4877
+ /** Convert angles from degrees to radians. */
4878
+ function deg2rad(x) {
4879
+ return multiply(x, pi / 180);
4880
+ }
4881
+ /** @function Alias of `jax.numpy.deg2rad()`. */
4882
+ const radians = deg2rad;
4883
+ /** Convert angles from radians to degrees. */
4884
+ function rad2deg(x) {
4885
+ return multiply(x, 180 / pi);
4886
+ }
4887
+ /** @function Alias of `jax.numpy.rad2deg()`. */
4888
+ const degrees = rad2deg;
4889
+ /**
4890
+ * @function
4891
+ * Computes first array raised to power of second array, element-wise.
4892
+ */
4893
+ const power = jit$1(function power$1(x1, x2) {
4894
+ return exp(log(x1).mul(x2));
4895
+ });
4896
+ /** @function Alias of `jax.numpy.power()`. */
4897
+ const pow = power;
4898
+ /** @function Calculate the element-wise cube root of the input array. */
4899
+ const cbrt = jit$1(function cbrt$1(x) {
4900
+ const sgn = where(less(x.ref, 0), -1, 1);
4901
+ return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
4902
+ });
4469
4903
  /**
4904
+ * @function
4470
4905
  * Calculate element-wise hyperbolic sine of input.
4471
4906
  *
4472
4907
  * `sinh(x) = (exp(x) - exp(-x)) / 2`
4473
4908
  */
4474
- function sinh(x) {
4909
+ const sinh = jit$1(function sinh$1(x) {
4475
4910
  const ex = exp(x);
4476
4911
  const emx = reciprocal(ex.ref);
4477
4912
  return ex.sub(emx).mul(.5);
4478
- }
4913
+ });
4479
4914
  /**
4915
+ * @function
4480
4916
  * Calculate element-wise hyperbolic cosine of input.
4481
4917
  *
4482
4918
  * `cosh(x) = (exp(x) + exp(-x)) / 2`
4483
4919
  */
4484
- function cosh(x) {
4920
+ const cosh = jit$1(function cosh$1(x) {
4485
4921
  const ex = exp(x);
4486
4922
  const emx = reciprocal(ex.ref);
4487
4923
  return ex.add(emx).mul(.5);
4488
- }
4924
+ });
4489
4925
  /**
4926
+ * @function
4490
4927
  * Calculate element-wise hyperbolic tangent of input.
4491
4928
  *
4492
4929
  * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4493
4930
  */
4494
- function tanh(x) {
4495
- x = fudgeArray(x);
4931
+ const tanh = jit$1(function tanh$1(x) {
4496
4932
  const negsgn = where(less(x.ref, 0), 1, -1);
4497
4933
  const en2x = exp(x.mul(negsgn.ref).mul(2));
4498
4934
  return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
4935
+ });
4936
+ /**
4937
+ * @function
4938
+ * Calculate element-wise inverse hyperbolic sine of input.
4939
+ *
4940
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
4941
+ */
4942
+ const arcsinh = jit$1(function arcsinh$1(x) {
4943
+ return log(x.ref.add(sqrt(square(x).add(1))));
4944
+ });
4945
+ /**
4946
+ * @function
4947
+ * Calculate element-wise inverse hyperbolic cosine of input.
4948
+ *
4949
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
4950
+ */
4951
+ const arccosh = jit$1(function arccosh$1(x) {
4952
+ return log(x.ref.add(sqrt(square(x).sub(1))));
4953
+ });
4954
+ /**
4955
+ * @function
4956
+ * Calculate element-wise inverse hyperbolic tangent of input.
4957
+ *
4958
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
4959
+ */
4960
+ const arctanh = jit$1(function arctanh$1(x) {
4961
+ return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
4962
+ });
4963
+ /** @function Alias of `jax.numpy.arcsinh()`. */
4964
+ const asinh = arcsinh;
4965
+ /** @function Alias of `jax.numpy.arccosh()`. */
4966
+ const acosh = arccosh;
4967
+ /** @function Alias of `jax.numpy.arctanh()`. */
4968
+ const atanh = arctanh;
4969
+ /**
4970
+ * Compute the variance of an array.
4971
+ *
4972
+ * The variance is computed for the flattened array by default, otherwise over
4973
+ * the specified axis.
4974
+ *
4975
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
4976
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
4977
+ */
4978
+ function var_(x, axis = null, opts) {
4979
+ x = fudgeArray(x);
4980
+ axis = normalizeAxis(axis, x.ndim);
4981
+ const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
4982
+ if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
4983
+ const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
4984
+ return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
4985
+ }
4986
+ /**
4987
+ * Compute the standard deviation of an array.
4988
+ *
4989
+ * The standard deviation is computed for the flattened array by default,
4990
+ * otherwise over the specified axis.
4991
+ *
4992
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
4993
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
4994
+ */
4995
+ function std(x, axis = null, opts) {
4996
+ return sqrt(var_(x, axis, opts));
4499
4997
  }
4500
4998
 
4501
4999
  //#endregion
@@ -4510,6 +5008,7 @@ __export(nn_exports, {
4510
5008
  leakyRelu: () => leakyRelu,
4511
5009
  logSigmoid: () => logSigmoid,
4512
5010
  logSoftmax: () => logSoftmax,
5011
+ logmeanexp: () => logmeanexp,
4513
5012
  logsumexp: () => logsumexp,
4514
5013
  mish: () => mish,
4515
5014
  oneHot: () => oneHot,
@@ -4520,6 +5019,8 @@ __export(nn_exports, {
4520
5019
  softSign: () => softSign,
4521
5020
  softmax: () => softmax,
4522
5021
  softplus: () => softplus,
5022
+ squareplus: () => squareplus,
5023
+ standardize: () => standardize,
4523
5024
  swish: () => swish
4524
5025
  });
4525
5026
  /**
@@ -4563,6 +5064,7 @@ function softSign(x) {
4563
5064
  return x.ref.div(absolute(x).add(1));
4564
5065
  }
4565
5066
  /**
5067
+ * @function
4566
5068
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4567
5069
  * Swish, computed element-wise:
4568
5070
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4571,8 +5073,11 @@ function softSign(x) {
4571
5073
  *
4572
5074
  * Reference: https://en.wikipedia.org/wiki/Swish_function
4573
5075
  */
4574
- const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
5076
+ const silu = jit$1(function silu$1(x) {
5077
+ return x.ref.mul(sigmoid(x));
5078
+ });
4575
5079
  /**
5080
+ * @function
4576
5081
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4577
5082
  * Swish, computed element-wise:
4578
5083
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4589,7 +5094,10 @@ const swish = silu;
4589
5094
  function logSigmoid(x) {
4590
5095
  return negative(softplus(negative(x)));
4591
5096
  }
4592
- /** Identity activation function. Returns the argument unmodified. */
5097
+ /**
5098
+ * @function
5099
+ * Identity activation function. Returns the argument unmodified.
5100
+ */
4593
5101
  const identity = fudgeArray;
4594
5102
  /** Leaky rectified linear (ReLU) activation function */
4595
5103
  function leakyRelu(x, negativeSlope = .01) {
@@ -4617,6 +5125,7 @@ function celu(x, alpha = 1) {
4617
5125
  return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
4618
5126
  }
4619
5127
  /**
5128
+ * @function
4620
5129
  * Gaussian error linear unit (GELU) activation function.
4621
5130
  *
4622
5131
  * This is computed element-wise. Currently jax-js does not support the erf() or
@@ -4627,7 +5136,7 @@ function celu(x, alpha = 1) {
4627
5136
  *
4628
5137
  * This will be improved in the future.
4629
5138
  */
4630
- const gelu = jit$1((x) => {
5139
+ const gelu = jit$1(function gelu$1(x) {
4631
5140
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
4632
5141
  return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
4633
5142
  });
@@ -4648,6 +5157,16 @@ function glu(x, axis = -1) {
4648
5157
  return a.mul(sigmoid(b));
4649
5158
  }
4650
5159
  /**
5160
+ * Squareplus activation function.
5161
+ *
5162
+ * Computes the element-wise function:
5163
+ * `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`, where `b` defaults to 4.
5164
+ */
5165
+ function squareplus(x, b = 4) {
5166
+ x = fudgeArray(x);
5167
+ return x.ref.add(sqrt(square(x).add(b))).mul(.5); // x is used twice: via x.ref for the linear term, directly inside square()
5168
+ }
5169
+ /**
4651
5170
  * Mish activation function.
4652
5171
  *
4653
5172
  * Computes the element-wise function:
@@ -4665,17 +5184,13 @@ function mish(x) {
4665
5184
  *
4666
5185
  * Reference: https://en.wikipedia.org/wiki/Softmax_function
4667
5186
  */
4668
- function softmax(x, axis) {
5187
+ function softmax(x, axis = -1) {
4669
5188
  x = fudgeArray(x);
4670
- if (axis === void 0) axis = x.ndim ? [x.ndim - 1] : [];
4671
- else if (typeof axis === "number") axis = [axis];
4672
- if (axis.length === 0) {
4673
- x.dispose();
4674
- return ones(x.shape);
4675
- }
4676
- const xMax = max(x.ref, axis, { keepDims: true });
5189
+ axis = normalizeAxis(axis, x.ndim);
5190
+ if (axis.length === 0) return onesLike(x);
5191
+ const xMax = max(x.ref, axis, { keepdims: true });
4677
5192
  const unnormalized = exp(x.sub(stopGradient(xMax)));
4678
- return unnormalized.ref.div(unnormalized.sum(axis, { keepDims: true }));
5193
+ return unnormalized.ref.div(unnormalized.sum(axis, { keepdims: true }));
4679
5194
  }
4680
5195
  /**
4681
5196
  * Log-Softmax function.
@@ -4685,17 +5200,13 @@ function softmax(x, axis) {
4685
5200
  *
4686
5201
  * If `axis` is not specified, it defaults to the last axis.
4687
5202
  */
4688
- function logSoftmax(x, axis) {
5203
+ function logSoftmax(x, axis = -1) {
4689
5204
  x = fudgeArray(x);
4690
- if (axis === void 0) axis = x.ndim ? [x.ndim - 1] : [];
4691
- else if (typeof axis === "number") axis = [axis];
4692
- if (axis.length === 0) {
4693
- x.dispose();
4694
- return zeros(x.shape);
4695
- }
4696
- const xMax = max(x.ref, axis, { keepDims: true });
5205
+ axis = normalizeAxis(axis, x.ndim);
5206
+ if (axis.length === 0) return zerosLike(x);
5207
+ const xMax = max(x.ref, axis, { keepdims: true });
4697
5208
  const shifted = x.sub(stopGradient(xMax));
4698
- const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepDims: true }));
5209
+ const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepdims: true }));
4699
5210
  return shifted.sub(shiftedLogsumexp);
4700
5211
  }
4701
5212
  /**
@@ -4706,16 +5217,39 @@ function logSoftmax(x, axis) {
4706
5217
  *
4707
5218
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
4708
5219
  */
4709
- function logsumexp(x, axis) {
5220
+ function logsumexp(x, axis = null) {
4710
5221
  x = fudgeArray(x);
4711
- if (axis === void 0) axis = range(x.ndim);
4712
- else if (typeof axis === "number") axis = [axis];
5222
+ axis = normalizeAxis(axis, x.ndim);
4713
5223
  if (axis.length === 0) return x;
4714
5224
  const xMax = stopGradient(max(x.ref, axis));
4715
5225
  const xMaxDims = broadcast(xMax.ref, x.shape, axis);
4716
5226
  const shifted = x.sub(xMaxDims);
4717
5227
  return xMax.add(log(exp(shifted).sum(axis)));
4718
5228
  }
5229
+ /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`, where `n` is the product of the sizes of the reduced axes. When the normalized axis list is empty, the input is returned after `fudgeArray` without any reduction. */
5230
+ function logmeanexp(x, axis = null) {
5231
+ x = fudgeArray(x);
5232
+ axis = normalizeAxis(axis, x.ndim);
5233
+ if (axis.length === 0) return x; // nothing to reduce over
5234
+ const n = axis.reduce((acc, a) => acc * x.shape[a], 1); // read x.shape before x is handed to logsumexp
5235
+ return logsumexp(x, axis).sub(Math.log(n)); // Math.log(n) is a plain JS number computed up front
5236
+ }
5237
+ /**
5238
+ * Standardizes input to zero mean and unit variance.
5239
+ *
5240
+ * By default, this is computed over the last axis. You can pass in a different
5241
+ * axis, or `null` to standardize over all elements.
5242
+ *
5243
+ * Epsilon (`opts.epsilon`, default `1e-5`) is added to the variance before the square root is taken, for stability. `opts.mean` and `opts.variance`, when provided, are used in place of the computed statistics.
5244
+ */
5245
+ function standardize(x, axis = -1, opts = {}) {
5246
+ x = fudgeArray(x);
5247
+ axis = normalizeAxis(axis, x.ndim);
5248
+ if (axis.length === 0) return x; // no axes to standardize over: input is returned as-is
5249
+ const mu = opts.mean !== void 0 ? fudgeArray(opts.mean) : x.ref.mean(axis, { keepdims: true });
5250
+ const sigma2 = opts.variance !== void 0 ? fudgeArray(opts.variance) : square(x.ref).mean(axis, { keepdims: true }).sub(square(mu.ref)); // computed variance is mean(x^2) - mean(x)^2
5251
+ return x.sub(mu).div(sqrt(sigma2.add(opts.epsilon ?? 1e-5))); // (x - mean) / sqrt(variance + epsilon)
5252
+ }
4719
5253
  /**
4720
5254
  * One-hot encodes the given indices.
4721
5255
  *
@@ -4733,7 +5267,7 @@ function logsumexp(x, axis) {
4733
5267
  * ```
4734
5268
  */
4735
5269
  function oneHot(x, numClasses) {
4736
- if (x.dtype !== DType.Int32) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
5270
+ if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
4737
5271
  return eye(numClasses, void 0, { device: x.device }).slice(x);
4738
5272
  }
4739
5273
 
@@ -4741,8 +5275,11 @@ function oneHot(x, numClasses) {
4741
5275
  //#region src/random.ts
4742
5276
  var random_exports = {};
4743
5277
  __export(random_exports, {
5278
+ bernoulli: () => bernoulli,
4744
5279
  bits: () => bits,
5280
+ exponential: () => exponential,
4745
5281
  key: () => key,
5282
+ normal: () => normal,
4746
5283
  split: () => split,
4747
5284
  uniform: () => uniform
4748
5285
  });
@@ -4770,21 +5307,58 @@ function bits(key$1, shape$1 = []) {
4770
5307
  const keyShape = validateKeyShape(key$1);
4771
5308
  return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
4772
5309
  }
4773
- /** Sample uniform random values in [minval, maxval) with given shape. */
4774
- function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5310
+ /**
5311
+ * @function
5312
+ * Sample uniform random values in [minval, maxval) with given shape.
5313
+ */
5314
+ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
4775
5315
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
4776
- const mantissa = bits(key$1, shape$1).div(scalar(512, {
5316
+ const mantissa = bits(key$1, shape$1).div(array(512, {
4777
5317
  dtype: DType.Uint32,
4778
5318
  device: key$1.device
4779
5319
  }));
4780
- const float12 = mantissa.add(scalar(1065353216, {
5320
+ const float12 = mantissa.add(array(1065353216, {
4781
5321
  dtype: DType.Uint32,
4782
5322
  device: key$1.device
4783
5323
  }));
4784
5324
  const rand = bitcast(float12, DType.Float32).sub(1);
4785
5325
  if (minval === 0 && maxval === 1) return rand;
4786
5326
  else return rand.mul(maxval - minval).add(minval);
5327
+ }, { staticArgnums: [1, 2] });
5328
+ /**
5329
+ * Sample Bernoulli random variables with given mean (0,1 categorical).
5330
+ *
5331
+ * Returns a random Boolean array with the specified shape. `p` can be an array
5332
+ * and must be broadcastable to `shape`. `p` defaults to 0.5 and `shape` to `[]`.
5333
+ */
5334
+ function bernoulli(key$1, p = .5, shape$1 = []) {
5335
+ p = fudgeArray(p);
5336
+ return uniform(key$1, shape$1).less(p); // compares a draw from uniform's default [0, 1) range against p
4787
5337
  }
5338
+ /**
5339
+ * @function
5340
+ * Sample exponential random values according to `p(x) = exp(-x)`. Computed as `-log1p(-u)` for a uniform draw `u` of the requested shape.
5341
+ */
5342
+ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
5343
+ const u = uniform(key$1, shape$1); // uniform draw using the default [0, 1) range
5344
+ return negative(log1p(negative(u)));
5345
+ }, { staticArgnums: [1] }); // argument 1 (the shape) is passed to jit as a static argument
5346
+ /**
5347
+ * @function
5348
+ * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
5349
+ *
5350
+ * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
5351
+ * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
5352
+ * bitwise identical to JAX. Only the cosine branch of the transform is returned.
5353
+ */
5354
+ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
5355
+ const [k1, k2] = split(key$1, 2); // two subkeys, one per uniform draw
5356
+ const u1 = uniform(k1, shape$1);
5357
+ const u2 = uniform(k2, shape$1);
5358
+ const radius = sqrt(log1p(negative(u1)).mul(-2)); // sqrt(-2 * log1p(-u1))
5359
+ const theta = u2.mul(2 * Math.PI); // angle = 2 * pi * u2
5360
+ return radius.mul(cos(theta));
5361
+ }, { staticArgnums: [1] }); // argument 1 (the shape) is passed to jit as a static argument
4788
5362
 
4789
5363
  //#endregion
4790
5364
  //#region src/polyfills.ts
@@ -4794,20 +5368,36 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
4794
5368
 
4795
5369
  //#endregion
4796
5370
  //#region src/index.ts
4797
- /** Compute the forward-mode Jacobian-vector product for a function. */
5371
+ /**
5372
+ * @function
5373
+ * Compute the forward-mode Jacobian-vector product for a function.
5374
+ */
4798
5375
  const jvp = jvp$1;
4799
- /** Vectorize an operation on a batched axis for one or more inputs. */
5376
+ /**
5377
+ * @function
5378
+ * Vectorize an operation on a batched axis for one or more inputs.
5379
+ */
4800
5380
  const vmap = vmap$1;
4801
- /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
5381
+ /**
5382
+ * @function
5383
+ * Compute the Jacobian evaluated column-by-column by forward-mode AD.
5384
+ */
4802
5385
  const jacfwd = jacfwd$1;
4803
- /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
5386
+ /**
5387
+ * @function
5388
+ * Construct a Jaxpr by dynamically tracing a function with example inputs.
5389
+ */
4804
5390
  const makeJaxpr = makeJaxpr$1;
4805
5391
  /**
5392
+ * @function
4806
5393
  * Mark a function for automatic JIT compilation, with operator fusion.
4807
5394
  *
4808
5395
  * The function will be compiled the first time it is called with a set of
4809
5396
  * argument shapes.
4810
5397
  *
5398
+ * You can call `.dispose()` on the returned, JIT-compiled function after all
5399
+ * calls to free memory associated with array constants.
5400
+ *
4811
5401
  * **Options:**
4812
5402
  * - `staticArgnums`: An array of argument indices to treat as static
4813
5403
  * (compile-time constant). These arguments must be hashable, won't be traced,
@@ -4817,23 +5407,52 @@ const makeJaxpr = makeJaxpr$1;
4817
5407
  */
4818
5408
  const jit = jit$1;
4819
5409
  /**
5410
+ * @function
4820
5411
  * Produce a local linear approximation to a function at a point using jvp() and
4821
5412
  * partial evaluation.
4822
5413
  */
4823
5414
  const linearize = linearize$1;
4824
- /** Calculate the reverse-mode vector-Jacobian product for a function. */
5415
+ /**
5416
+ * @function
5417
+ * Calculate the reverse-mode vector-Jacobian product for a function.
5418
+ */
4825
5419
  const vjp = vjp$1;
4826
5420
  /**
5421
+ * @function
4827
5422
  * Compute the gradient of a scalar-valued function `f` with respect to its
4828
5423
  * first argument.
4829
5424
  */
4830
5425
  const grad = grad$1;
4831
- /** Create a function that evaluates both `f` and the gradient of `f`. */
5426
+ /**
5427
+ * @function
5428
+ * Create a function that evaluates both `f` and the gradient of `f`.
5429
+ */
4832
5430
  const valueAndGrad = valueAndGrad$1;
4833
- /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
5431
+ /**
5432
+ * @function
5433
+ * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
5434
+ */
4834
5435
  const jacrev = jacrev$1;
4835
- /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
5436
+ /**
5437
+ * @function
5438
+ * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
5439
+ */
4836
5440
  const jacobian = jacrev;
5441
+ /**
5442
+ * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
5443
+ *
5444
+ * This can be used to wait for the results of an intermediate computation to
5445
+ * finish. It's recommended to call this regularly in an iterative computation
5446
+ * to avoid queueing up too many pending operations.
5447
+ *
5448
+ * Does not consume references to the arrays. Resolves to the same `x` that was passed in.
5449
+ */
5450
+ async function blockUntilReady(x) {
5451
+ const promises = [];
5452
+ for (const leaf of leaves(x)) if (leaf instanceof Array$1) promises.push(leaf.blockUntilReady()); // leaves that are not Array$1 instances are skipped
5453
+ await Promise.all(promises); // all collected promises are awaited together
5454
+ return x;
5455
+ }
4837
5456
 
4838
5457
  //#endregion
4839
- export { DType, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDevice, tree_exports as tree, valueAndGrad, vjp, vmap };
5458
+ export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };