@jax-js/jax 0.0.3 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
30
30
  }) : target, mod));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-D2C4MJRP.cjs');
33
+ const require_backend = require('./backend-yEU0L_ig.cjs');
34
34
 
35
35
  //#region src/tree.ts
36
36
  var tree_exports = {};
@@ -354,6 +354,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
354
354
  Primitive$1["RandomBits"] = "random_bits";
355
355
  Primitive$1["Sin"] = "sin";
356
356
  Primitive$1["Cos"] = "cos";
357
+ Primitive$1["Asin"] = "asin";
358
+ Primitive$1["Atan"] = "atan";
357
359
  Primitive$1["Exp"] = "exp";
358
360
  Primitive$1["Log"] = "log";
359
361
  Primitive$1["Sqrt"] = "sqrt";
@@ -421,6 +423,12 @@ function sin$1(x) {
421
423
  function cos$1(x) {
422
424
  return bind1(Primitive.Cos, [x]);
423
425
  }
426
+ function asin$1(x) {
427
+ return bind1(Primitive.Asin, [x]);
428
+ }
429
+ function atan$1(x) {
430
+ return bind1(Primitive.Atan, [x]);
431
+ }
424
432
  function exp$1(x) {
425
433
  return bind1(Primitive.Exp, [x]);
426
434
  }
@@ -436,18 +444,16 @@ function min$1(x, y) {
436
444
  function max$1(x, y) {
437
445
  return bind1(Primitive.Max, [x, y]);
438
446
  }
439
- function reduce(x, op, axis, opts) {
447
+ function reduce(x, op, axis = null, opts) {
440
448
  if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
441
- if (axis === void 0) if (x instanceof Tracer) axis = require_backend.range(x.shape.length);
442
- else axis = [];
443
- else if (typeof axis === "number") axis = [require_backend.checkAxis(axis, ndim$1(x))];
444
- else axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
449
+ axis = require_backend.normalizeAxis(axis, ndim$1(x));
445
450
  const originalShape = getShape(x);
446
- const result = bind1(Primitive.Reduce, [x], {
451
+ let result = bind1(Primitive.Reduce, [x], {
447
452
  op,
448
453
  axis
449
454
  });
450
- return opts?.keepDims ? broadcast(result, originalShape, axis) : result;
455
+ if (opts?.keepdims) result = result.reshape(originalShape.map((dim, i) => axis.includes(i) ? 1 : dim));
456
+ return result;
451
457
  }
452
458
  function dot$1(x, y) {
453
459
  return bind1(Primitive.Dot, [x, y]);
@@ -493,10 +499,11 @@ function where$1(cond, x, y) {
493
499
  }
494
500
  function transpose$1(x, perm) {
495
501
  perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
502
+ if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
496
503
  return bind1(Primitive.Transpose, [x], { perm });
497
504
  }
498
505
  function broadcast(x, shape$1, axis) {
499
- axis = axis.map((a) => require_backend.checkAxis(a, shape$1.length));
506
+ axis = require_backend.normalizeAxis(axis, shape$1.length);
500
507
  return bind1(Primitive.Broadcast, [x], {
501
508
  shape: shape$1,
502
509
  axis
@@ -515,7 +522,7 @@ function reshape$1(x, shape$1) {
515
522
  return bind1(Primitive.Reshape, [x], { shape: shape$1 });
516
523
  }
517
524
  function flip$1(x, axis) {
518
- axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
525
+ axis = require_backend.normalizeAxis(axis, ndim$1(x));
519
526
  return bind1(Primitive.Flip, [x], { axis });
520
527
  }
521
528
  function shrink(x, slice) {
@@ -589,21 +596,49 @@ var Trace = class {
589
596
  this.main = main;
590
597
  }
591
598
  };
599
+ /**
600
+ * Broadcast shapes and promote types with casting for two avals.
601
+ *
602
+ * This implements the weak type behavior described in `promoteTypes()`, but not
603
+ * implemented in that function as `weakType` is not passed.
604
+ */
605
+ function promoteAvals(a, b) {
606
+ const shape$1 = require_backend.generalBroadcast(a.shape, b.shape);
607
+ const weakType = a.weakType && b.weakType;
608
+ let dtype;
609
+ if (a.weakType === b.weakType) dtype = require_backend.promoteTypes(a.dtype, b.dtype);
610
+ else if (a.weakType) dtype = require_backend.promoteTypes(b.dtype, require_backend.DType.Uint32);
611
+ else dtype = require_backend.promoteTypes(a.dtype, require_backend.DType.Uint32);
612
+ return new ShapedArray(shape$1, dtype, weakType);
613
+ }
592
614
  var Tracer = class Tracer {
593
615
  /** @ignore */
594
616
  _trace;
595
617
  constructor(trace) {
596
618
  this._trace = trace;
597
619
  }
620
+ /** The shape of the array. */
598
621
  get shape() {
599
622
  return this.aval.shape;
600
623
  }
624
+ /** The total number of elements in the array. */
601
625
  get size() {
602
626
  return require_backend.prod(this.shape);
603
627
  }
628
+ /** The dtype of elements stored in the array. */
604
629
  get dtype() {
605
630
  return this.aval.dtype;
606
631
  }
632
+ /**
633
+ * Whether the array is weakly typed.
634
+ *
635
+ * Weakly typed arrays will cast to the dtype of the other operand. See
636
+ * `promoteTypes()` for details.
637
+ */
638
+ get weakType() {
639
+ return this.aval.weakType;
640
+ }
641
+ /** The number of dimensions of the array. */
607
642
  get ndim() {
608
643
  return this.shape.length;
609
644
  }
@@ -639,22 +674,20 @@ var Tracer = class Tracer {
639
674
  return lessEqual$1(this, other);
640
675
  }
641
676
  /** Sum of the elements of the array over a given axis, or axes. */
642
- sum(axis, opts) {
677
+ sum(axis = null, opts) {
643
678
  return reduce(this, require_backend.AluOp.Add, axis, opts);
644
679
  }
645
680
  /** Product of the array elements over a given axis. */
646
- prod(axis, opts) {
681
+ prod(axis = null, opts) {
647
682
  return reduce(this, require_backend.AluOp.Mul, axis, opts);
648
683
  }
649
684
  /** Compute the average of the array elements along the specified axis. */
650
- mean(axis, opts) {
651
- if (axis === void 0) axis = require_backend.range(this.ndim);
652
- else if (typeof axis === "number") axis = [require_backend.checkAxis(axis, this.ndim)];
653
- else axis = axis.map((a) => require_backend.checkAxis(a, this.ndim));
654
- let result = reduce(this, require_backend.AluOp.Add, axis);
655
- result = result.mul(result.size / this.size);
656
- if (opts?.keepDims) result = broadcast(result, this.shape, axis);
657
- return result;
685
+ mean(axis = null, opts) {
686
+ axis = require_backend.normalizeAxis(axis, this.ndim);
687
+ const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
688
+ if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
689
+ const result = reduce(this, require_backend.AluOp.Add, axis, opts);
690
+ return result.mul(1 / n);
658
691
  }
659
692
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
660
693
  transpose(perm) {
@@ -841,12 +874,13 @@ function getShape(x) {
841
874
  return x instanceof Tracer ? x.shape : [];
842
875
  }
843
876
  var ShapedArray = class ShapedArray {
844
- constructor(shape$1, dtype) {
877
+ constructor(shape$1, dtype, weakType) {
845
878
  this.shape = shape$1;
846
879
  this.dtype = dtype;
880
+ this.weakType = weakType;
847
881
  }
848
882
  static fromAval(aval) {
849
- return new ShapedArray(aval.shape, aval.dtype);
883
+ return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
850
884
  }
851
885
  get ndim() {
852
886
  return this.shape.length;
@@ -860,7 +894,7 @@ var ShapedArray = class ShapedArray {
860
894
  };
861
895
  function getAval(x) {
862
896
  if (x instanceof Tracer) return x.aval;
863
- else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32);
897
+ else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32, typeof x === "boolean" ? false : true);
864
898
  else throw new TypeError(`Unknown value: ${x}`);
865
899
  }
866
900
  function bind(prim, args, params = {}) {
@@ -1145,7 +1179,7 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1145
1179
  }
1146
1180
  function broadcastedJit(fn) {
1147
1181
  return (nargs, exps, avals, params) => {
1148
- const newShape = avals.map((aval) => aval.shape).reduce(generalBroadcast);
1182
+ const newShape = avals.map((aval) => aval.shape).reduce(require_backend.generalBroadcast);
1149
1183
  exps = exps.map((exp$3) => reshapeViews(exp$3, (st) => {
1150
1184
  if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
1151
1185
  }));
@@ -1182,11 +1216,13 @@ const jitRules = {
1182
1216
  const k1 = reshapeViews(keys[1], mapping);
1183
1217
  const c0 = require_backend.AluExp.u32(0);
1184
1218
  const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
1185
- const exp$2 = require_backend.AluExp.threefry2x32(c0, c1, k0, k1, mode);
1219
+ const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
1186
1220
  return new require_backend.Kernel(nargs, require_backend.prod(shape$1), exp$2);
1187
1221
  },
1188
1222
  [Primitive.Sin]: unopJit(require_backend.AluExp.sin),
1189
1223
  [Primitive.Cos]: unopJit(require_backend.AluExp.cos),
1224
+ [Primitive.Asin]: unopJit(require_backend.AluExp.asin),
1225
+ [Primitive.Atan]: unopJit(require_backend.AluExp.atan),
1190
1226
  [Primitive.Exp]: unopJit(require_backend.AluExp.exp),
1191
1227
  [Primitive.Log]: unopJit(require_backend.AluExp.log),
1192
1228
  [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
@@ -1221,7 +1257,7 @@ const jitRules = {
1221
1257
  [Primitive.Dot](nargs, [a, b], [as, bs]) {
1222
1258
  const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
1223
1259
  const c = k1.exp;
1224
- const cs = new ShapedArray(generalBroadcast(as.shape, bs.shape), c.dtype);
1260
+ const cs = promoteAvals(as, bs);
1225
1261
  return jitRules[Primitive.Reduce](nargs, [c], [cs], {
1226
1262
  op: require_backend.AluOp.Add,
1227
1263
  axis: [cs.ndim - 1]
@@ -1231,8 +1267,8 @@ const jitRules = {
1231
1267
  const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
1232
1268
  a = reshapeViews(a, (st) => st.compose(stX));
1233
1269
  b = reshapeViews(b, (st) => st.compose(stY));
1234
- as = new ShapedArray(stX.shape, as.dtype);
1235
- bs = new ShapedArray(stY.shape, bs.dtype);
1270
+ as = new ShapedArray(stX.shape, as.dtype, as.weakType);
1271
+ bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
1236
1272
  return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1237
1273
  },
1238
1274
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
@@ -1249,7 +1285,7 @@ const jitRules = {
1249
1285
  [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
1250
1286
  [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1251
1287
  const axisSet = new Set(axis);
1252
- const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
1288
+ const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
1253
1289
  const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
1254
1290
  finalShape.splice(outDim, 0, ...indexShape);
1255
1291
  const idxAll = require_backend.unravelAlu(finalShape, require_backend.AluVar.gidx);
@@ -1285,9 +1321,10 @@ function splitGraphDataflow(backend, jaxpr) {
1285
1321
  Primitive.Conv,
1286
1322
  Primitive.PoolTranspose
1287
1323
  ];
1324
+ const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
1288
1325
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
1289
1326
  const eqn = jaxpr.eqns[i];
1290
- if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
1327
+ if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
1291
1328
  for (const v of eqn.outBinders) {
1292
1329
  blackNodes.add(v);
1293
1330
  p1NextBlack.set(v, v);
@@ -1417,6 +1454,7 @@ var Array$1 = class Array$1 extends Tracer {
1417
1454
  static #nextId = 1001;
1418
1455
  id;
1419
1456
  #dtype;
1457
+ #weakType;
1420
1458
  #source;
1421
1459
  #st;
1422
1460
  #backend;
@@ -1428,19 +1466,22 @@ var Array$1 = class Array$1 extends Tracer {
1428
1466
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1429
1467
  * will be freed when the array is disposed.
1430
1468
  */
1431
- constructor(source, st, dtype, backend, pending = null) {
1469
+ constructor(args) {
1432
1470
  super(baseArrayTrace);
1433
1471
  this.id = Array$1.#nextId++;
1434
- this.#dtype = dtype;
1435
- this.#source = source;
1436
- this.#st = st;
1437
- this.#backend = backend;
1472
+ this.#dtype = args.dtype;
1473
+ this.#weakType = args.weakType;
1474
+ this.#source = args.source;
1475
+ this.#st = args.st;
1476
+ this.#backend = args.backend;
1438
1477
  this.#rc = 1;
1439
- this.#pendingSet = new Set(pending);
1478
+ this.#pendingSet = new Set(args.pending);
1479
+ if (this.#pendingSet.size === 0) this.#pendingSet = null;
1480
+ else if (this.#source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1440
1481
  }
1441
1482
  /** @ignore */
1442
1483
  get aval() {
1443
- return new ShapedArray(this.#st.shape, this.#dtype);
1484
+ return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
1444
1485
  }
1445
1486
  /** Return a simple string representation of the array's dimensions. */
1446
1487
  toString() {
@@ -1452,6 +1493,17 @@ var Array$1 = class Array$1 extends Tracer {
1452
1493
  #check() {
1453
1494
  if (this.#rc <= 0) throw new UseAfterFreeError(this);
1454
1495
  }
1496
+ /** Construct an array, copying fields from `this`. */
1497
+ #newArrayFrom(args) {
1498
+ return new Array$1({
1499
+ source: args.source ?? this.#source,
1500
+ st: args.st ?? this.#st,
1501
+ dtype: args.dtype ?? this.#dtype,
1502
+ weakType: this.#weakType,
1503
+ backend: args.backend ?? this.#backend,
1504
+ pending: args.pending ?? this.#pending ?? void 0
1505
+ });
1506
+ }
1455
1507
  get ref() {
1456
1508
  this.#check();
1457
1509
  this.#rc++;
@@ -1491,7 +1543,10 @@ var Array$1 = class Array$1 extends Tracer {
1491
1543
  const pending = this.#pending;
1492
1544
  for (const exe of pending) exe.updateRc(1);
1493
1545
  if (typeof this.#source === "number") this.#backend.incRef(this.#source);
1494
- const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, pending);
1546
+ const ar = this.#newArrayFrom({
1547
+ st,
1548
+ pending
1549
+ });
1495
1550
  this.dispose();
1496
1551
  return ar;
1497
1552
  }
@@ -1540,7 +1595,11 @@ var Array$1 = class Array$1 extends Tracer {
1540
1595
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1541
1596
  this.dispose();
1542
1597
  for (const ar of indices) ar.dispose();
1543
- return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, pending);
1598
+ return this.#newArrayFrom({
1599
+ source: output,
1600
+ st: require_backend.ShapeTracker.fromShape(finalShape),
1601
+ pending
1602
+ });
1544
1603
  }
1545
1604
  /** Move axes to the rightmost dimension of the shape. */
1546
1605
  #moveAxesDown(axis) {
@@ -1563,11 +1622,16 @@ var Array$1 = class Array$1 extends Tracer {
1563
1622
  return this.#reshape(this.#st.permute(perm));
1564
1623
  }
1565
1624
  #unary(op, dtypeOutput) {
1625
+ const weakType = !dtypeOutput && this.#weakType;
1566
1626
  dtypeOutput ??= this.#dtype;
1567
1627
  this.#check();
1568
1628
  if (this.#source instanceof require_backend.AluExp) {
1569
1629
  const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
1570
- return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
1630
+ return this.#newArrayFrom({
1631
+ source: exp$3.simplify(),
1632
+ dtype: dtypeOutput,
1633
+ weakType
1634
+ });
1571
1635
  }
1572
1636
  const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
1573
1637
  const exp$2 = new require_backend.AluExp(op, dtypeOutput, [require_backend.AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
@@ -1577,41 +1641,65 @@ var Array$1 = class Array$1 extends Tracer {
1577
1641
  for (const exe of pending) exe.updateRc(1);
1578
1642
  pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
1579
1643
  this.dispose();
1580
- return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, pending);
1644
+ return this.#newArrayFrom({
1645
+ source: output,
1646
+ st: require_backend.ShapeTracker.fromShape(this.shape),
1647
+ dtype: dtypeOutput,
1648
+ weakType,
1649
+ pending
1650
+ });
1581
1651
  }
1582
1652
  #binary(op, other) {
1583
- const custom = (src) => new require_backend.AluExp(op, this.#dtype, src);
1653
+ const custom = (src) => new require_backend.AluExp(op, src[0].dtype, src);
1584
1654
  return Array$1.#naryCustom(op, custom, [this, other]);
1585
1655
  }
1586
- static #naryCustom(name, custom, arrays, { dtypeOverride, dtypeOutput, reduceAxis } = {}) {
1656
+ static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
1587
1657
  const n = arrays.length;
1588
1658
  const backend = arrays[0].#backend;
1589
1659
  if (n === 0) throw new TypeError(`No inputs for ${name}`);
1590
1660
  for (const ar of arrays) ar.#check();
1591
- let dtype;
1661
+ let castDtype;
1662
+ let castWeakType = true;
1592
1663
  for (let i = 0; i < n; i++) {
1593
1664
  if (dtypeOverride?.[i]) {
1594
1665
  if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1595
- } else if (!dtype) dtype = arrays[i].#dtype;
1596
- else if (arrays[i].#dtype !== dtype) throw new TypeError(`Dtype mismatch in ${name}: ${dtype} vs ${arrays[i].#dtype}`);
1666
+ } else if (castDtype === void 0) {
1667
+ castDtype = arrays[i].#dtype;
1668
+ castWeakType = arrays[i].#weakType;
1669
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1597
1670
  if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
1598
1671
  }
1599
- dtypeOutput ??= dtype;
1600
- if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
1672
+ const weakType = castWeakType && !strongTypeOutput;
1601
1673
  arrays = Array$1.#broadcastArrays(arrays);
1602
1674
  const newShape = [...arrays[0].shape];
1603
1675
  if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
1676
+ const sources = arrays.map((ar, i) => {
1677
+ if (!dtypeOverride?.[i]) return require_backend.AluExp.cast(castDtype, ar.#source);
1678
+ else return ar.#source;
1679
+ });
1604
1680
  if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
1605
- const exp$4 = custom(arrays.map((ar) => ar.#source));
1606
- return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
1681
+ const exp$4 = custom(sources);
1682
+ return new Array$1({
1683
+ source: exp$4.simplify(),
1684
+ st: arrays[0].#st,
1685
+ dtype: exp$4.dtype,
1686
+ weakType,
1687
+ backend
1688
+ });
1607
1689
  }
1608
- const exp$3 = custom(arrays.map((ar) => {
1609
- const src$1 = ar.#source;
1690
+ const exp$3 = custom(arrays.map((ar, i) => {
1691
+ const src$1 = sources[i];
1610
1692
  if (ar.#st.contiguous) return src$1;
1611
1693
  return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
1612
1694
  }));
1613
1695
  const st = require_backend.ShapeTracker.fromShape(newShape);
1614
- return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
1696
+ return new Array$1({
1697
+ source: exp$3.simplify(),
1698
+ st,
1699
+ dtype: exp$3.dtype,
1700
+ weakType,
1701
+ backend
1702
+ });
1615
1703
  }
1616
1704
  let indices;
1617
1705
  if (!reduceAxis) indices = require_backend.unravelAlu(newShape, require_backend.AluVar.gidx);
@@ -1621,14 +1709,19 @@ var Array$1 = class Array$1 extends Tracer {
1621
1709
  }
1622
1710
  const inputs = [];
1623
1711
  const src = [];
1624
- for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#source, ar.#st, indices));
1625
- else {
1626
- let gid = inputs.indexOf(ar.#source);
1627
- if (gid === -1) {
1628
- gid = inputs.length;
1629
- inputs.push(ar.#source);
1712
+ for (const [i, ar] of arrays.entries()) {
1713
+ let nextSrc;
1714
+ if (ar.#source instanceof require_backend.AluExp) nextSrc = require_backend.accessorAluExp(ar.#source, ar.#st, indices);
1715
+ else {
1716
+ let gid = inputs.indexOf(ar.#source);
1717
+ if (gid === -1) {
1718
+ gid = inputs.length;
1719
+ inputs.push(ar.#source);
1720
+ }
1721
+ nextSrc = require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
1630
1722
  }
1631
- src.push(require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, indices));
1723
+ if (!dtypeOverride?.[i]) nextSrc = require_backend.AluExp.cast(castDtype, nextSrc);
1724
+ src.push(nextSrc);
1632
1725
  }
1633
1726
  const exp$2 = custom(src);
1634
1727
  let re = void 0;
@@ -1642,12 +1735,17 @@ var Array$1 = class Array$1 extends Tracer {
1642
1735
  for (const exe of pending) exe.updateRc(1);
1643
1736
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1644
1737
  for (const ar of arrays) ar.dispose();
1645
- return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, pending);
1738
+ return new Array$1({
1739
+ source: output,
1740
+ st: require_backend.ShapeTracker.fromShape(newShape),
1741
+ dtype: kernel.dtype,
1742
+ weakType,
1743
+ backend,
1744
+ pending
1745
+ });
1646
1746
  }
1647
1747
  /** Reduce the last dimension of the array by an operation. */
1648
1748
  #reduce(op) {
1649
- this.#check();
1650
- if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
1651
1749
  const shape$1 = this.shape;
1652
1750
  const reduction = new require_backend.Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
1653
1751
  const newShape = shape$1.slice(0, -1);
@@ -1666,7 +1764,11 @@ var Array$1 = class Array$1 extends Tracer {
1666
1764
  for (const exe of pending) exe.updateRc(1);
1667
1765
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1668
1766
  this.dispose();
1669
- return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, pending);
1767
+ return this.#newArrayFrom({
1768
+ source: output,
1769
+ st: require_backend.ShapeTracker.fromShape(newShape),
1770
+ pending
1771
+ });
1670
1772
  }
1671
1773
  /**
1672
1774
  * Normalizes this array into one backed by a `Slot`.
@@ -1702,15 +1804,15 @@ var Array$1 = class Array$1 extends Tracer {
1702
1804
  }
1703
1805
  #dataInline() {
1704
1806
  this.#check();
1705
- const exp$2 = this.#source;
1706
- const ar = new Array$1(exp$2, this.#st, this.dtype, require_backend.getBackend("cpu"));
1807
+ if (!(this.#source instanceof require_backend.AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
1808
+ const ar = this.#newArrayFrom({ backend: require_backend.getBackend("cpu") });
1707
1809
  this.dispose();
1708
1810
  return ar.dataSync();
1709
1811
  }
1710
1812
  static #broadcastArrays(arrays) {
1711
1813
  if (arrays.length === 0) throw new Error("Need at least one array to broadcast");
1712
1814
  if (arrays.length === 1) return arrays;
1713
- const newShape = arrays.map((a) => a.shape).reduce(generalBroadcast);
1815
+ const newShape = arrays.map((a) => a.shape).reduce(require_backend.generalBroadcast);
1714
1816
  return arrays.map((ar) => {
1715
1817
  if (require_backend.deepEqual(ar.shape, newShape)) return ar;
1716
1818
  return ar.#reshape(ar.#st.broadcast(newShape, require_backend.range(newShape.length - ar.ndim)));
@@ -1739,8 +1841,11 @@ var Array$1 = class Array$1 extends Tracer {
1739
1841
  *
1740
1842
  * If you are mapping from `data()` or `dataSync()`, it will also trigger
1741
1843
  * dispatch of operations as well.
1844
+ *
1845
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1846
+ * asynchronously for multiple arrays.
1742
1847
  */
1743
- async wait() {
1848
+ async blockUntilReady() {
1744
1849
  this.#check();
1745
1850
  if (this.#source instanceof require_backend.AluExp) return this;
1746
1851
  const pending = this.#pending;
@@ -1806,7 +1911,7 @@ var Array$1 = class Array$1 extends Tracer {
1806
1911
  return [x.#binary(require_backend.AluOp.Idiv, y)];
1807
1912
  },
1808
1913
  [Primitive.Neg]([x]) {
1809
- return [zerosLike(x.ref).#binary(require_backend.AluOp.Sub, x)];
1914
+ return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
1810
1915
  },
1811
1916
  [Primitive.Reciprocal]([x]) {
1812
1917
  return [x.#unary(require_backend.AluOp.Reciprocal)];
@@ -1826,14 +1931,18 @@ var Array$1 = class Array$1 extends Tracer {
1826
1931
  x.#backend.incRef(x.#source);
1827
1932
  const pending = x.#pending;
1828
1933
  for (const exe of pending) exe.updateRc(1);
1829
- const y = new Array$1(x.#source, x.#st, dtype, x.#backend, pending);
1934
+ const y = x.#newArrayFrom({
1935
+ dtype,
1936
+ weakType: false,
1937
+ pending
1938
+ });
1830
1939
  x.dispose();
1831
1940
  return [y];
1832
1941
  }
1833
1942
  },
1834
1943
  [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
1835
- const keyShape = generalBroadcast(k0.shape, k1.shape);
1836
- if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1944
+ const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
1945
+ if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1837
1946
  const c0 = zeros(shape$1, {
1838
1947
  dtype: require_backend.DType.Uint32,
1839
1948
  device: k0.device
@@ -1856,6 +1965,12 @@ var Array$1 = class Array$1 extends Tracer {
1856
1965
  [Primitive.Cos]([x]) {
1857
1966
  return [x.#unary(require_backend.AluOp.Cos)];
1858
1967
  },
1968
+ [Primitive.Asin]([x]) {
1969
+ return [x.#unary(require_backend.AluOp.Asin)];
1970
+ },
1971
+ [Primitive.Atan]([x]) {
1972
+ return [x.#unary(require_backend.AluOp.Atan)];
1973
+ },
1859
1974
  [Primitive.Exp]([x]) {
1860
1975
  return [x.#unary(require_backend.AluOp.Exp)];
1861
1976
  },
@@ -1895,7 +2010,7 @@ var Array$1 = class Array$1 extends Tracer {
1895
2010
  },
1896
2011
  [Primitive.Compare]([x, y], { op }) {
1897
2012
  const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
1898
- return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: require_backend.DType.Bool })];
2013
+ return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
1899
2014
  },
1900
2015
  [Primitive.Where]([cond, x, y]) {
1901
2016
  const custom = ([cond$1, x$1, y$1]) => require_backend.AluExp.where(cond$1, x$1, y$1);
@@ -1941,7 +2056,14 @@ var Array$1 = class Array$1 extends Tracer {
1941
2056
  pending.splice(0, 0, ...prevPending);
1942
2057
  args.forEach((x) => x.dispose());
1943
2058
  return outputs.map((source, i) => {
1944
- return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, pending);
2059
+ return new Array$1({
2060
+ source,
2061
+ st: require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
2062
+ dtype: jaxpr.outs[i].aval.dtype,
2063
+ weakType: jaxpr.outs[i].aval.weakType,
2064
+ backend,
2065
+ pending
2066
+ });
1945
2067
  });
1946
2068
  }
1947
2069
  };
@@ -1951,33 +2073,11 @@ var Array$1 = class Array$1 extends Tracer {
1951
2073
  return this.#source;
1952
2074
  }
1953
2075
  };
1954
- /** Construct an array from a single scalar constant. */
1955
- function scalar(value, { dtype, device } = {}) {
1956
- if (typeof value === "number") {
1957
- dtype ??= require_backend.DType.Float32;
1958
- if (![
1959
- require_backend.DType.Float32,
1960
- require_backend.DType.Float16,
1961
- require_backend.DType.Int32,
1962
- require_backend.DType.Uint32
1963
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1964
- } else if (typeof value === "boolean") {
1965
- dtype ??= require_backend.DType.Bool;
1966
- if (![
1967
- require_backend.DType.Float32,
1968
- require_backend.DType.Float16,
1969
- require_backend.DType.Int32,
1970
- require_backend.DType.Uint32,
1971
- require_backend.DType.Bool
1972
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1973
- } else throw new TypeError(`Invalid type for scalar ${value}`);
1974
- return new Array$1(require_backend.AluExp.const(dtype, value), require_backend.ShapeTracker.fromShape([]), dtype, require_backend.getBackend(device));
1975
- }
1976
2076
  /** Constructor for creating a new array from data. */
1977
2077
  function array(values, { shape: shape$1, dtype, device } = {}) {
1978
2078
  if (values instanceof Tracer) {
1979
2079
  if (shape$1 && !require_backend.deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
1980
- if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
2080
+ if (dtype && values.dtype !== dtype) values = values.astype(dtype);
1981
2081
  return values;
1982
2082
  } else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
1983
2083
  dtype,
@@ -1999,6 +2099,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1999
2099
  dtype,
2000
2100
  device
2001
2101
  });
2102
+ if (size$1 === 1) return full(shape$1, flat[0], {
2103
+ dtype,
2104
+ device
2105
+ });
2002
2106
  if (typeof flat[0] === "boolean") {
2003
2107
  dtype = dtype ?? require_backend.DType.Bool;
2004
2108
  const data = new Int32Array(flat.map((x) => x ? 1 : 0));
@@ -2007,46 +2111,51 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
2007
2111
  device
2008
2112
  });
2009
2113
  } else {
2114
+ const weakType = dtype == void 0;
2010
2115
  dtype = dtype ?? require_backend.DType.Float32;
2011
2116
  const data = require_backend.dtypedJsArray(dtype, flat);
2012
2117
  return arrayFromData(data, shape$1, {
2013
2118
  dtype,
2014
2119
  device
2015
- });
2120
+ }, weakType);
2016
2121
  }
2017
2122
  }
2018
2123
  }
2019
- function arrayFromData(data, shape$1, { dtype, device } = {}) {
2124
+ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2125
+ if (data instanceof Float32Array) {
2126
+ if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
2127
+ dtype ??= require_backend.DType.Float32;
2128
+ } else if (data instanceof Int32Array) {
2129
+ if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2130
+ dtype ??= require_backend.DType.Int32;
2131
+ } else if (data instanceof Uint32Array) {
2132
+ if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2133
+ dtype ??= require_backend.DType.Uint32;
2134
+ } else if (data instanceof Float16Array) {
2135
+ if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
2136
+ dtype ??= require_backend.DType.Float16;
2137
+ } else throw new Error("Unsupported data array type: " + data.constructor.name);
2020
2138
  if (data.length < inlineArrayLimit) {
2021
2139
  let allEqual = true;
2022
2140
  for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
2023
2141
  allEqual = false;
2024
2142
  break;
2025
2143
  }
2026
- if (allEqual) return full(shape$1, data[0], {
2027
- dtype,
2028
- device
2029
- });
2144
+ if (allEqual) {
2145
+ const sa = new ShapedArray(shape$1, dtype, weakType);
2146
+ return fullInternal(sa, data[0], device);
2147
+ }
2030
2148
  }
2031
2149
  const backend = require_backend.getBackend(device);
2032
- if (ArrayBuffer.isView(data)) {
2033
- const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2034
- if (data instanceof Float32Array) {
2035
- if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
2036
- dtype ??= require_backend.DType.Float32;
2037
- } else if (data instanceof Int32Array) {
2038
- if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2039
- dtype ??= require_backend.DType.Int32;
2040
- } else if (data instanceof Uint32Array) {
2041
- if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2042
- dtype ??= require_backend.DType.Uint32;
2043
- } else if (data instanceof Float16Array) {
2044
- if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
2045
- dtype ??= require_backend.DType.Float16;
2046
- } else throw new Error("Unsupported data array type: " + data.constructor.name);
2047
- const slot = backend.malloc(data.byteLength, buf);
2048
- return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype, backend);
2049
- } else throw new Error("Unsupported data type: " + data.constructor.name);
2150
+ const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2151
+ const slot = backend.malloc(data.byteLength, buf);
2152
+ return new Array$1({
2153
+ source: slot,
2154
+ st: require_backend.ShapeTracker.fromShape(shape$1),
2155
+ dtype,
2156
+ weakType,
2157
+ backend
2158
+ });
2050
2159
  }
2051
2160
  function dataToJs(dtype, data, shape$1) {
2052
2161
  if (shape$1.length === 0) return dtype === require_backend.DType.Bool ? Boolean(data[0]) : data[0];
@@ -2062,7 +2171,7 @@ function dataToJs(dtype, data, shape$1) {
2062
2171
  /** If x is a value, lift it into an array, otherwise leave it be. */
2063
2172
  function pureArray(x) {
2064
2173
  if (x instanceof Tracer) return x;
2065
- else return scalar(x);
2174
+ else return array(x);
2066
2175
  }
2067
2176
  var EvalTrace = class extends Trace {
2068
2177
  pure = (x) => pureArray(x);
@@ -2073,20 +2182,27 @@ var EvalTrace = class extends Trace {
2073
2182
  };
2074
2183
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
2075
2184
  const implRules = Array$1._implRules();
2076
- function zerosLike(val, dtype) {
2077
- const aval = getAval(val);
2078
- if (val instanceof Tracer) val.dispose();
2079
- return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2185
+ function fullInternal(aval, fillValue, device) {
2186
+ return new Array$1({
2187
+ source: require_backend.AluExp.const(aval.dtype, fillValue),
2188
+ st: require_backend.ShapeTracker.fromShape(aval.shape),
2189
+ dtype: aval.dtype,
2190
+ weakType: aval.weakType,
2191
+ backend: require_backend.getBackend(device)
2192
+ });
2080
2193
  }
2081
- function onesLike(val, dtype) {
2082
- const aval = getAval(val);
2083
- if (val instanceof Tracer) val.dispose();
2084
- return ones(aval.shape, { dtype: dtype ?? aval.dtype });
2194
+ function zerosLike$1(val, dtype) {
2195
+ return fullLike(val, 0, dtype);
2196
+ }
2197
+ function onesLike$1(val, dtype) {
2198
+ return fullLike(val, 1, dtype);
2085
2199
  }
2086
2200
  function fullLike(val, fillValue, dtype) {
2087
2201
  const aval = getAval(val);
2088
2202
  if (val instanceof Tracer) val.dispose();
2089
- return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
2203
+ if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
2204
+ const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
2205
+ return fullInternal(sa, fillValue);
2090
2206
  }
2091
2207
  /** Return a new array of given shape and type, filled with zeros. */
2092
2208
  function zeros(shape$1, { dtype, device } = {}) {
@@ -2104,19 +2220,14 @@ function ones(shape$1, { dtype, device } = {}) {
2104
2220
  }
2105
2221
  /** Return a new array of given shape and type, filled with `fill_value`. */
2106
2222
  function full(shape$1, fillValue, { dtype, device } = {}) {
2107
- let source;
2108
- if (typeof fillValue === "number") {
2109
- dtype = dtype ?? require_backend.DType.Float32;
2110
- source = require_backend.AluExp.const(dtype, fillValue);
2111
- } else if (typeof fillValue === "bigint") {
2112
- dtype = dtype ?? require_backend.DType.Int32;
2113
- source = require_backend.AluExp.const(dtype, Number(fillValue));
2114
- } else if (typeof fillValue === "boolean") {
2223
+ let weakType = dtype == void 0;
2224
+ if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
2225
+ else if (typeof fillValue === "boolean") {
2115
2226
  dtype = dtype ?? require_backend.DType.Bool;
2116
- source = require_backend.AluExp.const(dtype, fillValue ? 1 : 0);
2227
+ weakType = false;
2117
2228
  } else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
2118
2229
  else throw new TypeError(`Invalid type for full: ${fillValue}`);
2119
- return new Array$1(source, require_backend.ShapeTracker.fromShape(shape$1), dtype ?? require_backend.DType.Float32, require_backend.getBackend(device));
2230
+ return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
2120
2231
  }
2121
2232
  /**
2122
2233
  * Create an identity matrix.
@@ -2126,6 +2237,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
2126
2237
  */
2127
2238
  function eye(numRows, numCols, { dtype, device } = {}) {
2128
2239
  numCols = numCols ?? numRows;
2240
+ const weakType = dtype == void 0;
2129
2241
  dtype = dtype ?? require_backend.DType.Float32;
2130
2242
  if (numCols < numRows) {
2131
2243
  const arr = eye(numCols, numRows, {
@@ -2139,9 +2251,15 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2139
2251
  device
2140
2252
  });
2141
2253
  const exp$2 = require_backend.AluExp.cmplt(require_backend.AluExp.mod(require_backend.AluVar.idx, require_backend.AluExp.i32(numCols + 1)), require_backend.AluExp.i32(1));
2142
- return new Array$1(require_backend.AluExp.cast(dtype, exp$2), require_backend.ShapeTracker.fromShape([numRows, numCols]), dtype, require_backend.getBackend(device));
2254
+ return new Array$1({
2255
+ source: require_backend.AluExp.cast(dtype, exp$2),
2256
+ st: require_backend.ShapeTracker.fromShape([numRows, numCols]),
2257
+ dtype,
2258
+ weakType,
2259
+ backend: require_backend.getBackend(device)
2260
+ });
2143
2261
  }
2144
- /** Return the identity array, with ones on the main diagonal. */
2262
+ /** Return the identity matrix, with ones on the main diagonal. */
2145
2263
  function identity$1(n, { dtype, device } = {}) {
2146
2264
  return eye(n, n, {
2147
2265
  dtype,
@@ -2176,7 +2294,13 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
2176
2294
  });
2177
2295
  const exp$2 = require_backend.AluExp.add(require_backend.AluExp.const(dtype, start), require_backend.AluExp.mul(require_backend.AluExp.cast(dtype, require_backend.AluVar.idx), require_backend.AluExp.const(dtype, step)));
2178
2296
  const st = require_backend.ShapeTracker.fromShape([size$1]);
2179
- return new Array$1(exp$2, st, dtype, require_backend.getBackend(device));
2297
+ return new Array$1({
2298
+ source: exp$2,
2299
+ st,
2300
+ dtype,
2301
+ weakType: false,
2302
+ backend: require_backend.getBackend(device)
2303
+ });
2180
2304
  }
2181
2305
  /**
2182
2306
  * Return evenly spaced numbers over a specified interval.
@@ -2194,10 +2318,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2194
2318
  dtype,
2195
2319
  device
2196
2320
  });
2197
- else if (num === 1) return scalar(start, {
2321
+ else if (num === 1) return full([1], start, {
2198
2322
  dtype,
2199
2323
  device
2200
- }).reshape([1]);
2324
+ });
2201
2325
  else if (start === stop) return full([num], start, {
2202
2326
  dtype,
2203
2327
  device
@@ -2206,7 +2330,13 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2206
2330
  const denom = endpoint ? num - 1 : num;
2207
2331
  const exp$2 = require_backend.AluExp.cast(dtype, require_backend.AluExp.add(require_backend.AluExp.f32(start), require_backend.AluExp.mul(require_backend.AluExp.f32(delta / denom), require_backend.AluExp.cast(require_backend.DType.Float32, require_backend.AluVar.idx))));
2208
2332
  const st = require_backend.ShapeTracker.fromShape([num]);
2209
- return new Array$1(exp$2, st, dtype, require_backend.getBackend(device));
2333
+ return new Array$1({
2334
+ source: exp$2,
2335
+ st,
2336
+ dtype,
2337
+ weakType: false,
2338
+ backend: require_backend.getBackend(device)
2339
+ });
2210
2340
  }
2211
2341
  function aluCompare(a, b, op) {
2212
2342
  switch (op) {
@@ -2218,35 +2348,6 @@ function aluCompare(a, b, op) {
2218
2348
  case CompareOp.LessEqual: return require_backend.AluExp.add(require_backend.AluExp.cmplt(a, b), require_backend.AluExp.cmpne(a, b).not());
2219
2349
  }
2220
2350
  }
2221
- /**
2222
- * Implements a NumPy-style generalized broadcast rule on two array shapes.
2223
- *
2224
- * "When operating on two arrays, NumPy compares their shapes element-wise. It
2225
- * starts with the trailing (i.e. rightmost) dimension and works its way left.
2226
- * Two dimensions are compatible when:
2227
- * 1. they are equal, or
2228
- * 2. one of them is 1."
2229
- *
2230
- * Throws a TypeError if the broadcast is not possible.
2231
- *
2232
- * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
2233
- */
2234
- function generalBroadcast(a, b) {
2235
- const out = [];
2236
- let i = a.length - 1;
2237
- let j = b.length - 1;
2238
- for (; i >= 0 && j >= 0; i--, j--) {
2239
- const x = a[i];
2240
- const y = b[j];
2241
- if (x === y) out.push(x);
2242
- else if (x === 1) out.push(y);
2243
- else if (y === 1) out.push(x);
2244
- else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
2245
- }
2246
- for (; i >= 0; i--) out.push(a[i]);
2247
- for (; j >= 0; j--) out.push(b[j]);
2248
- return out.reverse();
2249
- }
2250
2351
 
2251
2352
  //#endregion
2252
2353
  //#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
@@ -2326,13 +2427,15 @@ var Var = class Var {
2326
2427
  };
2327
2428
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
2328
2429
  var Lit = class {
2329
- dtype;
2330
2430
  value;
2331
2431
  aval;
2332
- constructor(dtype, value) {
2333
- this.dtype = dtype;
2432
+ get dtype() {
2433
+ return this.aval.dtype;
2434
+ }
2435
+ constructor(aval, value) {
2436
+ if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
2334
2437
  this.value = value;
2335
- this.aval = new ShapedArray([], dtype);
2438
+ this.aval = ShapedArray.fromAval(aval);
2336
2439
  }
2337
2440
  };
2338
2441
  function atomIsLit(atom, literal) {
@@ -2421,16 +2524,19 @@ var Jaxpr = class Jaxpr {
2421
2524
  varIds.set(v, require_backend.FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
2422
2525
  return id;
2423
2526
  };
2424
- hasher.update(this.inBinders.length, ...this.inBinders.map(vi));
2425
- hasher.update(this.eqns.length, ...this.eqns.flatMap((eqn) => [
2426
- eqn.primitive,
2427
- eqn.inputs.length,
2428
- ...eqn.inputs.map((x) => x instanceof Var ? vi(x) : x.value),
2429
- JSON.stringify(eqn.params),
2430
- eqn.outBinders.length,
2431
- ...eqn.outBinders.map(vi)
2432
- ]));
2433
- hasher.update(this.outs.length, ...this.outs.map((x) => x instanceof Var ? vi(x) : x.value));
2527
+ hasher.update(this.inBinders.length);
2528
+ for (const x of this.inBinders) hasher.update(vi(x));
2529
+ hasher.update(this.eqns.length);
2530
+ for (const eqn of this.eqns) {
2531
+ hasher.update(eqn.primitive);
2532
+ hasher.update(eqn.inputs.length);
2533
+ for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
2534
+ hasher.update(JSON.stringify(eqn.params));
2535
+ hasher.update(eqn.outBinders.length);
2536
+ for (const x of eqn.outBinders) hasher.update(vi(x));
2537
+ }
2538
+ hasher.update(this.outs.length);
2539
+ for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
2434
2540
  return this.#hash = hasher.value;
2435
2541
  }
2436
2542
  hash(state) {
@@ -2453,21 +2559,26 @@ var Jaxpr = class Jaxpr {
2453
2559
  const c = eqn.outBinders[0];
2454
2560
  if (atomIsLit(a, 0)) context.set(c, b);
2455
2561
  else if (atomIsLit(b, 0)) context.set(c, a);
2456
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.dtype === require_backend.DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2562
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === require_backend.DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2563
+ else newEqns.push(eqn);
2564
+ } else if (eqn.primitive === Primitive.Neg) {
2565
+ const [a] = inputs;
2566
+ const c = eqn.outBinders[0];
2567
+ if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
2457
2568
  else newEqns.push(eqn);
2458
2569
  } else if (eqn.primitive === Primitive.Mul) {
2459
2570
  const [a, b] = inputs;
2460
2571
  const c = eqn.outBinders[0];
2461
2572
  if (atomIsLit(a, 1)) context.set(c, b);
2462
2573
  else if (atomIsLit(b, 1)) context.set(c, a);
2463
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.value * b.value));
2574
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
2464
2575
  else newEqns.push(eqn);
2465
2576
  } else if (eqn.primitive === Primitive.Idiv) {
2466
2577
  const [a, b] = inputs;
2467
2578
  const c = eqn.outBinders[0];
2468
2579
  if (atomIsLit(b, 1)) context.set(c, a);
2469
2580
  else newEqns.push(eqn);
2470
- } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2581
+ } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2471
2582
  else newEqns.push(eqn);
2472
2583
  }
2473
2584
  const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
@@ -2558,7 +2669,7 @@ function evalJaxpr(jaxpr, args) {
2558
2669
  if (x instanceof Var) {
2559
2670
  remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
2560
2671
  return env.get(x);
2561
- } else return scalar(x.value, { dtype: x.dtype });
2672
+ } else return array(x.value, { dtype: x.dtype });
2562
2673
  };
2563
2674
  const write = (v, val) => {
2564
2675
  if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
@@ -2617,7 +2728,7 @@ var JaxprTrace = class extends Trace {
2617
2728
  let tracer = this.builder.constTracers.get(val);
2618
2729
  if (tracer === void 0) {
2619
2730
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
2620
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : scalar(val));
2731
+ this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
2621
2732
  }
2622
2733
  return tracer;
2623
2734
  }
@@ -2686,7 +2797,7 @@ function _inlineLiterals(jaxpr, consts) {
2686
2797
  const newConsts = [];
2687
2798
  for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
2688
2799
  const ar = consts[i];
2689
- literals.set(jaxpr.inBinders[i], new Lit(ar.dtype, ar.dataSync()[0]));
2800
+ literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
2690
2801
  } else {
2691
2802
  constBinders.push(jaxpr.inBinders[i]);
2692
2803
  newConsts.push(consts[i]);
@@ -2699,13 +2810,12 @@ function _inlineLiterals(jaxpr, consts) {
2699
2810
  }
2700
2811
  function binopAbstractEval([x, y]) {
2701
2812
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
2702
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2703
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
2813
+ return [promoteAvals(x, y)];
2704
2814
  }
2705
2815
  function compareAbstractEval([x, y]) {
2706
2816
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
2707
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2708
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), require_backend.DType.Bool)];
2817
+ const aval = promoteAvals(x, y);
2818
+ return [new ShapedArray(aval.shape, require_backend.DType.Bool, false)];
2709
2819
  }
2710
2820
  function vectorizedUnopAbstractEval([x]) {
2711
2821
  return [ShapedArray.fromAval(x)];
@@ -2718,21 +2828,23 @@ const abstractEvalRules = {
2718
2828
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
2719
2829
  [Primitive.StopGradient]: vectorizedUnopAbstractEval,
2720
2830
  [Primitive.Cast]([x], { dtype }) {
2721
- return [new ShapedArray(x.shape, dtype)];
2831
+ return [new ShapedArray(x.shape, dtype, false)];
2722
2832
  },
2723
2833
  [Primitive.Bitcast]([x], { dtype }) {
2724
2834
  if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
2725
2835
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
2726
- return [new ShapedArray(x.shape, dtype)];
2836
+ return [new ShapedArray(x.shape, dtype, false)];
2727
2837
  },
2728
2838
  [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
2729
2839
  if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
2730
- const keyShape = generalBroadcast(k0.shape, k1.shape);
2731
- if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2732
- return [new ShapedArray(shape$1, require_backend.DType.Uint32)];
2840
+ const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2841
+ if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2842
+ return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
2733
2843
  },
2734
2844
  [Primitive.Sin]: vectorizedUnopAbstractEval,
2735
2845
  [Primitive.Cos]: vectorizedUnopAbstractEval,
2846
+ [Primitive.Asin]: vectorizedUnopAbstractEval,
2847
+ [Primitive.Atan]: vectorizedUnopAbstractEval,
2736
2848
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2737
2849
  [Primitive.Log]: vectorizedUnopAbstractEval,
2738
2850
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
@@ -2741,55 +2853,54 @@ const abstractEvalRules = {
2741
2853
  [Primitive.Reduce]([x], { axis }) {
2742
2854
  const axisSet = new Set(axis);
2743
2855
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2744
- return [new ShapedArray(newShape, x.dtype)];
2856
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2745
2857
  },
2746
2858
  [Primitive.Pool]([x], { window, strides }) {
2747
2859
  const shape$1 = checkPoolShape(x.shape, window, strides);
2748
- return [new ShapedArray(shape$1, x.dtype)];
2860
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2749
2861
  },
2750
2862
  [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
2751
2863
  const shape$1 = checkPoolShape(inShape, window, strides);
2752
2864
  if (!require_backend.deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
2753
- return [new ShapedArray(inShape, x.dtype)];
2865
+ return [new ShapedArray(inShape, x.dtype, x.weakType)];
2754
2866
  },
2755
2867
  [Primitive.Dot]([x, y]) {
2756
- if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
2757
2868
  if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
2758
- const shape$1 = generalBroadcast(x.shape, y.shape);
2869
+ const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
2759
2870
  shape$1.splice(-1, 1);
2760
- return [new ShapedArray(shape$1, x.dtype)];
2871
+ return [new ShapedArray(shape$1, dtype, weakType)];
2761
2872
  },
2762
2873
  [Primitive.Conv]([lhs, rhs], params) {
2763
- if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
2874
+ const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
2764
2875
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
2765
- return [new ShapedArray(shape$1, lhs.dtype)];
2876
+ return [new ShapedArray(shape$1, dtype, weakType)];
2766
2877
  },
2767
2878
  [Primitive.Compare]: compareAbstractEval,
2768
2879
  [Primitive.Where]([cond, x, y]) {
2769
2880
  if (cond.dtype !== require_backend.DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
2770
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2771
- const shape$1 = generalBroadcast(cond.shape, generalBroadcast(x.shape, y.shape));
2772
- return [new ShapedArray(shape$1, x.dtype)];
2881
+ const xy = promoteAvals(x, y);
2882
+ const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
2883
+ return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
2773
2884
  },
2774
2885
  [Primitive.Transpose]([x], { perm }) {
2775
- return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
2886
+ return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
2776
2887
  },
2777
2888
  [Primitive.Broadcast]([x], { shape: shape$1 }) {
2778
- return [new ShapedArray(shape$1, x.dtype)];
2889
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2779
2890
  },
2780
2891
  [Primitive.Reshape]([x], { shape: shape$1 }) {
2781
- return [new ShapedArray(shape$1, x.dtype)];
2892
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2782
2893
  },
2783
2894
  [Primitive.Flip]([x], _) {
2784
- return [new ShapedArray(x.shape, x.dtype)];
2895
+ return [ShapedArray.fromAval(x)];
2785
2896
  },
2786
2897
  [Primitive.Shrink]([x], { slice }) {
2787
2898
  const newShape = slice.map((s) => s[1] - s[0]);
2788
- return [new ShapedArray(newShape, x.dtype)];
2899
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2789
2900
  },
2790
2901
  [Primitive.Pad]([x], { width }) {
2791
2902
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
2792
- return [new ShapedArray(newShape, x.dtype)];
2903
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2793
2904
  },
2794
2905
  [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2795
2906
  for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
@@ -2799,10 +2910,10 @@ const abstractEvalRules = {
2799
2910
  if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
2800
2911
  const axisSet = new Set(axis);
2801
2912
  if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
2802
- const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
2913
+ const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
2803
2914
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2804
2915
  newShape.splice(outDim, 0, ...gatherShape);
2805
- return [new ShapedArray(newShape, x.dtype)];
2916
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2806
2917
  },
2807
2918
  [Primitive.JitCall](args, { jaxpr }) {
2808
2919
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
@@ -2860,7 +2971,7 @@ function makeJaxpr$1(f, opts) {
2860
2971
  function jit$1(f, opts) {
2861
2972
  const cache = /* @__PURE__ */ new Map();
2862
2973
  const staticArgnums = new Set(opts?.staticArgnums ?? []);
2863
- return ((...args) => {
2974
+ const result = ((...args) => {
2864
2975
  const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
2865
2976
  const [argsFlat, inTree] = flatten(dynamicArgs);
2866
2977
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
@@ -2869,11 +2980,16 @@ function jit$1(f, opts) {
2869
2980
  const cacheKey = JSON.stringify(jaxprArgs);
2870
2981
  const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
2871
2982
  const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
2983
+ name: f.name || "closure",
2872
2984
  jaxpr,
2873
2985
  numConsts: consts.length
2874
2986
  });
2875
2987
  return unflatten(outTree, outs);
2876
2988
  });
2989
+ result.dispose = () => {
2990
+ for (const { consts } of cache.values()) for (const c of consts) c.dispose();
2991
+ };
2992
+ return result;
2877
2993
  }
2878
2994
 
2879
2995
  //#endregion
@@ -2905,7 +3021,7 @@ var JVPTrace = class extends Trace {
2905
3021
  return this.lift(pureArray(val));
2906
3022
  }
2907
3023
  lift(val) {
2908
- return new JVPTracer(this, val, zerosLike(val.ref));
3024
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
2909
3025
  }
2910
3026
  processPrimitive(primitive, tracers, params) {
2911
3027
  const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
@@ -2936,7 +3052,7 @@ function zeroTangentsJvp(primitive) {
2936
3052
  return (primals, tangents, params) => {
2937
3053
  for (const t of tangents) t.dispose();
2938
3054
  const ys = bind(primitive, primals, params);
2939
- return [ys, ys.map((y) => zerosLike(y.ref))];
3055
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
2940
3056
  };
2941
3057
  }
2942
3058
  const jvpRules = {
@@ -2954,13 +3070,13 @@ const jvpRules = {
2954
3070
  if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
2955
3071
  else {
2956
3072
  dx.dispose();
2957
- return [[cast(x.ref, dtype)], [zerosLike(x)]];
3073
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
2958
3074
  }
2959
3075
  },
2960
3076
  [Primitive.Bitcast]([x], [dx], { dtype }) {
2961
3077
  if (x.dtype === dtype) return [[x], [dx]];
2962
3078
  dx.dispose();
2963
- return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
3079
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
2964
3080
  },
2965
3081
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
2966
3082
  [Primitive.Sin]([x], [dx]) {
@@ -2969,6 +3085,14 @@ const jvpRules = {
2969
3085
  [Primitive.Cos]([x], [dx]) {
2970
3086
  return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
2971
3087
  },
3088
+ [Primitive.Asin]([x], [dx]) {
3089
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3090
+ return [[asin$1(x)], [denom.mul(dx)]];
3091
+ },
3092
+ [Primitive.Atan]([x], [dx]) {
3093
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3094
+ return [[atan$1(x)], [dx.div(denom)]];
3095
+ },
2972
3096
  [Primitive.Exp]([x], [dx]) {
2973
3097
  const z = exp$1(x);
2974
3098
  return [[z.ref], [z.mul(dx)]];
@@ -3019,13 +3143,14 @@ const jvpRules = {
3019
3143
  const indicesRef = indices.map((t) => t.ref);
3020
3144
  return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3021
3145
  },
3022
- [Primitive.JitCall](primals, tangents, { jaxpr }) {
3146
+ [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3023
3147
  const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3024
3148
  const outs = bind(Primitive.JitCall, [
3025
3149
  ...newConsts.map((c) => c.ref),
3026
3150
  ...primals,
3027
3151
  ...tangents
3028
3152
  ], {
3153
+ name: `${name}_jvp`,
3029
3154
  jaxpr: newJaxpr,
3030
3155
  numConsts: newConsts.length
3031
3156
  });
@@ -3080,12 +3205,15 @@ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3080
3205
  function mappedAval(batchDim, aval) {
3081
3206
  const shape$1 = [...aval.shape];
3082
3207
  shape$1.splice(batchDim, 1);
3083
- return new ShapedArray(shape$1, aval.dtype);
3208
+ return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3084
3209
  }
3085
3210
  /** Move one axis to a different index. */
3086
3211
  function moveaxis$1(x, src, dst) {
3087
3212
  const t = pureArray(x);
3088
- const perm = require_backend.range(t.shape.length);
3213
+ src = require_backend.checkAxis(src, t.ndim);
3214
+ dst = require_backend.checkAxis(dst, t.ndim);
3215
+ if (src === dst) return t;
3216
+ const perm = require_backend.range(t.ndim);
3089
3217
  perm.splice(src, 1);
3090
3218
  perm.splice(dst, 0, src);
3091
3219
  return transpose$1(t, perm);
@@ -3178,6 +3306,8 @@ const vmapRules = {
3178
3306
  [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
3179
3307
  [Primitive.Sin]: unopBatcher(sin$1),
3180
3308
  [Primitive.Cos]: unopBatcher(cos$1),
3309
+ [Primitive.Asin]: unopBatcher(asin$1),
3310
+ [Primitive.Atan]: unopBatcher(atan$1),
3181
3311
  [Primitive.Exp]: unopBatcher(exp$1),
3182
3312
  [Primitive.Log]: unopBatcher(log$1),
3183
3313
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
@@ -3219,9 +3349,10 @@ const vmapRules = {
3219
3349
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3220
3350
  return [[pad$1(x, newWidth)], [xBdim]];
3221
3351
  },
3222
- [Primitive.JitCall](axisSize, args, dims, { jaxpr }) {
3352
+ [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3223
3353
  const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3224
3354
  const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3355
+ name: `${name}_vmap`,
3225
3356
  jaxpr: newJaxpr,
3226
3357
  numConsts: newConsts.length
3227
3358
  });
@@ -3237,7 +3368,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3237
3368
  if (dims[i] === null) return v.aval;
3238
3369
  const shape$1 = [...v.aval.shape];
3239
3370
  shape$1.splice(dims[i], 0, axisSize);
3240
- return new ShapedArray(shape$1, v.aval.dtype);
3371
+ return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3241
3372
  });
3242
3373
  const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3243
3374
  const result = {
@@ -3363,20 +3494,28 @@ function linearizeFlatUtil(f, primalsIn) {
3363
3494
  function linearizeFlat(f, primalsIn) {
3364
3495
  const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
3365
3496
  const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3366
- return [primalsOut, fLin];
3497
+ const dispose$1 = () => {
3498
+ for (const c of consts) c.dispose();
3499
+ };
3500
+ return [
3501
+ primalsOut,
3502
+ fLin,
3503
+ dispose$1
3504
+ ];
3367
3505
  }
3368
3506
  function linearize$1(f, ...primalsIn) {
3369
3507
  const [primalsInFlat, inTree] = flatten(primalsIn);
3370
3508
  const [fFlat, outTree] = flattenFun(f, inTree);
3371
- const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3509
+ const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3372
3510
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
3373
3511
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3374
- const fLin = (...tangentsIn) => {
3512
+ const fLin = ((...tangentsIn) => {
3375
3513
  const [tangentsInFlat, inTree2] = flatten(tangentsIn);
3376
3514
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
3377
3515
  const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
3378
3516
  return unflatten(outTree.value, tangentsOutFlat);
3379
- };
3517
+ });
3518
+ fLin.dispose = dispose$1;
3380
3519
  return [primalsOut, fLin];
3381
3520
  }
3382
3521
  var PartialEvalTracer = class extends Tracer {
@@ -3442,8 +3581,8 @@ var PartialEvalTrace = class extends Trace {
3442
3581
  processPrimitive(primitive, tracers, params) {
3443
3582
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3444
3583
  if (primitive === Primitive.JitCall) {
3445
- const { jaxpr, numConsts } = params;
3446
- return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
3584
+ const { name, jaxpr, numConsts } = params;
3585
+ return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3447
3586
  }
3448
3587
  const tracersIn = tracers.map((t) => this.instantiateConst(t));
3449
3588
  const avalsIn = tracersIn.map((t) => t.pval.aval);
@@ -3469,12 +3608,13 @@ var PartialEvalTrace = class extends Trace {
3469
3608
  *
3470
3609
  * Used when encountering a JitCall rule during the trace.
3471
3610
  */
3472
- #partialEvalJaxpr(jaxpr, numConsts, tracers) {
3611
+ #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3473
3612
  jaxpr = jaxpr.flatten();
3474
3613
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3475
3614
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3476
3615
  const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
3477
3616
  const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
3617
+ name: `${name}_peval`,
3478
3618
  jaxpr: jaxpr1,
3479
3619
  numConsts: 0
3480
3620
  });
@@ -3486,13 +3626,17 @@ var PartialEvalTrace = class extends Trace {
3486
3626
  prim: Primitive.JitCall,
3487
3627
  tracersIn: resTracers.concat(unknownTracers),
3488
3628
  params: {
3629
+ name: `${name}_resid`,
3489
3630
  jaxpr: jaxpr2,
3490
3631
  numConsts: 0
3491
3632
  },
3492
3633
  avalsOut: jaxpr2.outs.map((x) => x.aval),
3493
3634
  tracerRefsOut: []
3494
3635
  };
3495
- const outs2 = jaxpr2.outs.map((x) => new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe));
3636
+ const outs2 = jaxpr2.outs.map((x, i$1) => {
3637
+ if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
3638
+ return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
3639
+ });
3496
3640
  recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
3497
3641
  let i = 0;
3498
3642
  let j = 0;
@@ -3576,13 +3720,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
3576
3720
  const [consts, constvars] = require_backend.unzip2(constToVar.entries());
3577
3721
  const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
3578
3722
  const outVars = tracersOut.map((t) => tracerToVar.get(t));
3579
- const jaxpr = new Jaxpr(inBinders, eqns, outVars);
3723
+ let jaxpr = new Jaxpr(inBinders, eqns, outVars);
3580
3724
  typecheckJaxpr(jaxpr);
3581
3725
  for (const t of consts) t.ref;
3582
3726
  for (const t of tracersIn) t.dispose();
3583
3727
  for (const t of tracersOut) t.dispose();
3728
+ jaxpr = jaxpr.simplify();
3729
+ if (require_backend.DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
3584
3730
  return {
3585
- jaxpr: jaxpr.simplify(),
3731
+ jaxpr,
3586
3732
  consts
3587
3733
  };
3588
3734
  }
@@ -3623,7 +3769,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
3623
3769
  }
3624
3770
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
3625
3771
  const eqn = jaxpr.eqns[i];
3626
- const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? scalar(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3772
+ const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3627
3773
  const cotangentsOut = eqn.outBinders.map(readCotangent);
3628
3774
  const rule = transposeRules[eqn.primitive];
3629
3775
  if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
@@ -3708,7 +3854,7 @@ const transposeRules = {
3708
3854
  },
3709
3855
  [Primitive.Dot]([ct], [x, y]) {
3710
3856
  if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
3711
- const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
3857
+ const axisSize = require_backend.generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
3712
3858
  ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
3713
3859
  return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
3714
3860
  },
@@ -3803,7 +3949,7 @@ const transposeRules = {
3803
3949
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
3804
3950
  throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
3805
3951
  },
3806
- [Primitive.JitCall](cts, args, { jaxpr }) {
3952
+ [Primitive.JitCall](cts, args, { name, jaxpr }) {
3807
3953
  const undefPrimals = args.map((x) => x instanceof UndefPrimal);
3808
3954
  const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
3809
3955
  const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
@@ -3812,6 +3958,7 @@ const transposeRules = {
3812
3958
  ...residuals,
3813
3959
  ...cts
3814
3960
  ], {
3961
+ name: `${name}_t`,
3815
3962
  jaxpr: newJaxpr,
3816
3963
  numConsts: newConsts.length
3817
3964
  });
@@ -3848,20 +3995,28 @@ function vjpFlat(f, primalsIn) {
3848
3995
  const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
3849
3996
  return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
3850
3997
  };
3851
- return [primalsOut, fVjp];
3998
+ const dispose$1 = () => {
3999
+ for (const c of consts) c.dispose();
4000
+ };
4001
+ return [
4002
+ primalsOut,
4003
+ fVjp,
4004
+ dispose$1
4005
+ ];
3852
4006
  }
3853
4007
  function vjp$1(f, ...primalsIn) {
3854
4008
  const [primalsInFlat, inTree] = flatten(primalsIn);
3855
4009
  const [fFlat, outTree] = flattenFun(f, inTree);
3856
- const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
4010
+ const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3857
4011
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
3858
4012
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3859
- const fVjp = (cotangentsOut) => {
4013
+ const fVjp = ((cotangentsOut) => {
3860
4014
  const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
3861
4015
  if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
3862
4016
  const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
3863
4017
  return unflatten(inTree, cotangentsInFlat);
3864
- };
4018
+ });
4019
+ fVjp.dispose = dispose$1;
3865
4020
  return [primalsOut, fVjp];
3866
4021
  }
3867
4022
  function grad$1(f) {
@@ -3878,8 +4033,9 @@ function valueAndGrad$1(f) {
3878
4033
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3879
4034
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3880
4035
  if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3881
- const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
3882
- for (const r of rest) r.dispose();
4036
+ const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
4037
+ for (const r of rest) dispose(r);
4038
+ fVjp.dispose();
3883
4039
  return [y, ct];
3884
4040
  };
3885
4041
  }
@@ -3887,7 +4043,13 @@ function jacrev$1(f) {
3887
4043
  return function jacobianReverse(x) {
3888
4044
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
3889
4045
  const [size$1] = x.shape;
3890
- const pullback = (ct) => vjp$1(f, x)[1](ct)[0];
4046
+ const pullback = (ct) => {
4047
+ const [y, fVjp] = vjp$1(f, x);
4048
+ y.dispose();
4049
+ const [ret] = fVjp(ct);
4050
+ fVjp.dispose();
4051
+ return ret;
4052
+ };
3891
4053
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
3892
4054
  };
3893
4055
  }
@@ -3967,19 +4129,38 @@ __export(numpy_exports, {
3967
4129
  DType: () => require_backend.DType,
3968
4130
  abs: () => abs,
3969
4131
  absolute: () => absolute,
4132
+ acos: () => acos,
4133
+ acosh: () => acosh,
3970
4134
  add: () => add,
3971
4135
  allclose: () => allclose,
3972
4136
  arange: () => arange,
4137
+ arccos: () => arccos,
4138
+ arccosh: () => arccosh,
4139
+ arcsinh: () => arcsinh,
4140
+ arctan: () => arctan,
4141
+ arctan2: () => arctan2,
4142
+ arctanh: () => arctanh,
3973
4143
  argmax: () => argmax,
3974
4144
  argmin: () => argmin,
3975
4145
  array: () => array,
4146
+ asin: () => asin,
4147
+ asinh: () => asinh,
3976
4148
  astype: () => astype,
4149
+ atan: () => atan,
4150
+ atan2: () => atan2,
4151
+ atanh: () => atanh,
3977
4152
  bool: () => bool,
4153
+ broadcastArrays: () => broadcastArrays,
4154
+ broadcastShapes: () => broadcastShapes,
4155
+ broadcastTo: () => broadcastTo,
4156
+ cbrt: () => cbrt,
3978
4157
  clip: () => clip,
3979
4158
  columnStack: () => columnStack,
3980
4159
  concatenate: () => concatenate,
3981
4160
  cos: () => cos,
3982
4161
  cosh: () => cosh,
4162
+ deg2rad: () => deg2rad,
4163
+ degrees: () => degrees,
3983
4164
  diag: () => diag,
3984
4165
  diagonal: () => diagonal,
3985
4166
  divide: () => divide,
@@ -3990,6 +4171,7 @@ __export(numpy_exports, {
3990
4171
  eulerGamma: () => eulerGamma,
3991
4172
  exp: () => exp,
3992
4173
  exp2: () => exp2,
4174
+ expm1: () => expm1,
3993
4175
  eye: () => eye,
3994
4176
  flip: () => flip,
3995
4177
  fliplr: () => fliplr,
@@ -4001,14 +4183,17 @@ __export(numpy_exports, {
4001
4183
  greater: () => greater,
4002
4184
  greaterEqual: () => greaterEqual,
4003
4185
  hstack: () => hstack,
4186
+ hypot: () => hypot,
4004
4187
  identity: () => identity$1,
4005
4188
  inf: () => inf,
4189
+ inner: () => inner,
4006
4190
  int32: () => int32,
4007
4191
  less: () => less,
4008
4192
  lessEqual: () => lessEqual,
4009
4193
  linspace: () => linspace,
4010
4194
  log: () => log,
4011
4195
  log10: () => log10,
4196
+ log1p: () => log1p,
4012
4197
  log2: () => log2,
4013
4198
  matmul: () => matmul,
4014
4199
  max: () => max,
@@ -4024,35 +4209,49 @@ __export(numpy_exports, {
4024
4209
  negative: () => negative,
4025
4210
  notEqual: () => notEqual,
4026
4211
  ones: () => ones,
4027
- onesLike: () => onesLike$1,
4212
+ onesLike: () => onesLike,
4213
+ outer: () => outer,
4028
4214
  pad: () => pad,
4029
4215
  permuteDims: () => permuteDims,
4030
4216
  pi: () => pi,
4217
+ pow: () => pow,
4218
+ power: () => power,
4031
4219
  prod: () => prod$1,
4220
+ promoteTypes: () => require_backend.promoteTypes,
4221
+ rad2deg: () => rad2deg,
4222
+ radians: () => radians,
4032
4223
  ravel: () => ravel,
4033
4224
  reciprocal: () => reciprocal,
4225
+ repeat: () => repeat,
4034
4226
  reshape: () => reshape,
4035
- scalar: () => scalar,
4036
4227
  shape: () => shape,
4228
+ sign: () => sign,
4037
4229
  sin: () => sin,
4038
4230
  sinh: () => sinh,
4039
4231
  size: () => size,
4040
4232
  sqrt: () => sqrt,
4041
4233
  square: () => square,
4042
4234
  stack: () => stack,
4235
+ std: () => std,
4236
+ subtract: () => subtract,
4043
4237
  sum: () => sum,
4044
4238
  tan: () => tan,
4045
4239
  tanh: () => tanh,
4240
+ tile: () => tile,
4046
4241
  transpose: () => transpose,
4242
+ tri: () => tri,
4243
+ tril: () => tril,
4244
+ triu: () => triu,
4047
4245
  trueDivide: () => trueDivide,
4048
4246
  trunc: () => trunc,
4049
4247
  uint32: () => uint32,
4248
+ var_: () => var_,
4050
4249
  vdot: () => vdot,
4051
4250
  vecdot: () => vecdot,
4052
4251
  vstack: () => vstack,
4053
4252
  where: () => where,
4054
4253
  zeros: () => zeros,
4055
- zerosLike: () => zerosLike$1
4254
+ zerosLike: () => zerosLike
4056
4255
  });
4057
4256
  const float32 = require_backend.DType.Float32;
4058
4257
  const int32 = require_backend.DType.Int32;
@@ -4069,54 +4268,66 @@ const inf = Number.POSITIVE_INFINITY;
4069
4268
  const nan = NaN;
4070
4269
  /** This is Pi, `π = 3.14159265358979...` */
4071
4270
  const pi = Math.PI;
4072
- /** Element-wise addition, with broadcasting. */
4271
+ /** @function Element-wise addition, with broadcasting. */
4073
4272
  const add = add$1;
4074
- /** Element-wise multiplication, with broadcasting. */
4273
+ /** @function Element-wise multiplication, with broadcasting. */
4075
4274
  const multiply = mul;
4076
- /** Numerical negative of every element of an array. */
4275
+ /** @function Numerical negative of every element of an array. */
4077
4276
  const negative = neg;
4078
- /** Calculate element-wise reciprocal of the input. This is `1/x`. */
4277
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
4079
4278
  const reciprocal = reciprocal$1;
4080
- /** Element-wise sine function (takes radians). */
4279
+ /** @function Element-wise sine function (takes radians). */
4081
4280
  const sin = sin$1;
4082
- /** Element-wise cosine function (takes radians). */
4281
+ /** @function Element-wise cosine function (takes radians). */
4083
4282
  const cos = cos$1;
4084
- /** Calculate the exponential of all elements in the input array. */
4283
+ /** @function Element-wise inverse sine function (inverse of sin). */
4284
+ const asin = asin$1;
4285
+ /** @function Element-wise inverse tangent function (inverse of tan). */
4286
+ const atan = atan$1;
4287
+ /** @function Calculate the exponential of all elements in the input array. */
4085
4288
  const exp = exp$1;
4086
- /** Calculate the natural logarithm of all elements in the input array. */
4289
+ /** @function Calculate the natural logarithm of all elements in the input array. */
4087
4290
  const log = log$1;
4088
- /** Calculate the square root of all elements in the input array. */
4291
+ /** @function Calculate the square root of all elements in the input array. */
4089
4292
  const sqrt = sqrt$1;
4090
- /** Return element-wise minimum of the input arrays. */
4293
+ /** @function Return element-wise minimum of the input arrays. */
4091
4294
  const minimum = min$1;
4092
- /** Return element-wise maximum of the input arrays. */
4295
+ /** @function Return element-wise maximum of the input arrays. */
4093
4296
  const maximum = max$1;
4094
- /** Compare two arrays element-wise. */
4297
+ /** @function Compare two arrays element-wise. */
4095
4298
  const greater = greater$1;
4096
- /** Compare two arrays element-wise. */
4299
+ /** @function Compare two arrays element-wise. */
4097
4300
  const less = less$1;
4098
- /** Compare two arrays element-wise. */
4301
+ /** @function Compare two arrays element-wise. */
4099
4302
  const equal = equal$1;
4100
- /** Compare two arrays element-wise. */
4303
+ /** @function Compare two arrays element-wise. */
4101
4304
  const notEqual = notEqual$1;
4102
- /** Compare two arrays element-wise. */
4305
+ /** @function Compare two arrays element-wise. */
4103
4306
  const greaterEqual = greaterEqual$1;
4104
- /** Compare two arrays element-wise. */
4307
+ /** @function Compare two arrays element-wise. */
4105
4308
  const lessEqual = lessEqual$1;
4106
- /** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
4309
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
4107
4310
  const where = where$1;
4108
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
4311
+ /**
4312
+ * @function
4313
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
4314
+ */
4109
4315
  const transpose = transpose$1;
4110
4316
  /**
4317
+ * @function
4111
4318
  * Give a new shape to an array without changing its data.
4112
4319
  *
4113
4320
  * One shape dimension can be -1. In this case, the value is inferred from the
4114
4321
  * length of the array and remaining dimensions.
4115
4322
  */
4116
4323
  const reshape = reshape$1;
4117
- /** Move axes of an array to new positions. Other axes retain original order. */
4324
+ /**
4325
+ * @function
4326
+ * Move axes of an array to new positions. Other axes retain original order.
4327
+ */
4118
4328
  const moveaxis = moveaxis$1;
4119
4329
  /**
4330
+ * @function
4120
4331
  * Add padding (zeros) to an array.
4121
4332
  *
4122
4333
  * The `width` argument is either an integer or pair of integers, in which case
@@ -4124,15 +4335,27 @@ const moveaxis = moveaxis$1;
4124
4335
  * pair specifies the padding for its corresponding axis.
4125
4336
  */
4126
4337
  const pad = pad$1;
4127
- /** Return the number of dimensions of an array. Does not consume array reference. */
4338
+ /**
4339
+ * @function
4340
+ * Return the number of dimensions of an array. Does not consume array reference.
4341
+ */
4128
4342
  const ndim = ndim$1;
4129
- /** Return the shape of an array. Does not consume array reference. */
4343
+ /** @function Return the shape of an array. Does not consume array reference. */
4130
4344
  const shape = getShape;
4131
- /** Return an array of zeros with the same shape and type as a given array. */
4132
- const zerosLike$1 = zerosLike;
4133
- /** Return an array of ones with the same shape and type as a given array. */
4134
- const onesLike$1 = onesLike;
4135
- /** Return a full array with the same shape and type as a given array. */
4345
+ /**
4346
+ * @function
4347
+ * Return an array of zeros with the same shape and type as a given array.
4348
+ */
4349
+ const zerosLike = zerosLike$1;
4350
+ /**
4351
+ * @function
4352
+ * Return an array of ones with the same shape and type as a given array.
4353
+ */
4354
+ const onesLike = onesLike$1;
4355
+ /**
4356
+ * @function
4357
+ * Return a full array with the same shape and type as a given array.
4358
+ */
4136
4359
  const fullLike$1 = fullLike;
4137
4360
  /**
4138
4361
  * Return the number of elements in an array, optionally along an axis.
@@ -4147,23 +4370,23 @@ function astype(a, dtype) {
4147
4370
  return fudgeArray(a).astype(dtype);
4148
4371
  }
4149
4372
  /** Sum of the elements of the array over a given axis, or axes. */
4150
- function sum(a, axis, opts) {
4373
+ function sum(a, axis = null, opts) {
4151
4374
  return reduce(a, require_backend.AluOp.Add, axis, opts);
4152
4375
  }
4153
4376
  /** Product of the array elements over a given axis. */
4154
- function prod$1(a, axis, opts) {
4377
+ function prod$1(a, axis = null, opts) {
4155
4378
  return reduce(a, require_backend.AluOp.Mul, axis, opts);
4156
4379
  }
4157
4380
  /** Return the minimum of array elements along a given axis. */
4158
- function min(a, axis, opts) {
4381
+ function min(a, axis = null, opts) {
4159
4382
  return reduce(a, require_backend.AluOp.Min, axis, opts);
4160
4383
  }
4161
4384
  /** Return the maximum of array elements along a given axis. */
4162
- function max(a, axis, opts) {
4385
+ function max(a, axis = null, opts) {
4163
4386
  return reduce(a, require_backend.AluOp.Max, axis, opts);
4164
4387
  }
4165
4388
  /** Compute the average of the array elements along the specified axis. */
4166
- function mean(a, axis, opts) {
4389
+ function mean(a, axis = null, opts) {
4167
4390
  return fudgeArray(a).mean(axis, opts);
4168
4391
  }
4169
4392
  /**
@@ -4179,8 +4402,8 @@ function argmin(a, axis, opts) {
4179
4402
  axis = 0;
4180
4403
  } else axis = require_backend.checkAxis(axis, a.ndim);
4181
4404
  const shape$1 = a.shape;
4182
- const isMax = equal(a, min(a.ref, axis, { keepDims: true }));
4183
- const length = scalar(shape$1[axis], {
4405
+ const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
4406
+ const length = array(shape$1[axis], {
4184
4407
  dtype: int32,
4185
4408
  device: a.device
4186
4409
  });
@@ -4203,8 +4426,8 @@ function argmax(a, axis, opts) {
4203
4426
  axis = 0;
4204
4427
  } else axis = require_backend.checkAxis(axis, a.ndim);
4205
4428
  const shape$1 = a.shape;
4206
- const isMax = equal(a, max(a.ref, axis, { keepDims: true }));
4207
- const length = scalar(shape$1[axis], {
4429
+ const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
4430
+ const length = array(shape$1[axis], {
4208
4431
  dtype: int32,
4209
4432
  device: a.device
4210
4433
  });
@@ -4215,17 +4438,9 @@ function argmax(a, axis, opts) {
4215
4438
  return length.sub(max(idx, axis, opts));
4216
4439
  }
4217
4440
  /** Reverse the elements in an array along the given axes. */
4218
- function flip(x, axis) {
4441
+ function flip(x, axis = null) {
4219
4442
  const nd = ndim(x);
4220
- if (axis === void 0) axis = require_backend.range(nd);
4221
- else if (typeof axis === "number") axis = [axis];
4222
- const seen = /* @__PURE__ */ new Set();
4223
- for (let i = 0; i < axis.length; i++) {
4224
- if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
4225
- if (axis[i] < 0) axis[i] += nd;
4226
- if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
4227
- seen.add(axis[i]);
4228
- }
4443
+ axis = require_backend.normalizeAxis(axis, nd);
4229
4444
  return flip$1(x, axis);
4230
4445
  }
4231
4446
  /**
@@ -4331,12 +4546,80 @@ function flipud(x) {
4331
4546
  function fliplr(x) {
4332
4547
  return flip(x, 1);
4333
4548
  }
4549
+ /** @function Alternative name for `numpy.transpose()`. */
4334
4550
  const permuteDims = transpose;
4335
4551
  /** Return a 1-D flattened array containing the elements of the input. */
4336
4552
  function ravel(a) {
4337
4553
  return fudgeArray(a).ravel();
4338
4554
  }
4339
4555
  /**
4556
+ * Repeat each element of an array after themselves.
4557
+ *
4558
+ * If no axis is provided, use the flattened input array, and return a flat
4559
+ * output array.
4560
+ */
4561
+ function repeat(a, repeats, axis) {
4562
+ if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
4563
+ a = fudgeArray(a);
4564
+ if (axis === void 0) {
4565
+ a = ravel(a);
4566
+ axis = 0;
4567
+ }
4568
+ axis = require_backend.checkAxis(axis, a.ndim);
4569
+ if (repeats === 1) return a;
4570
+ const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
4571
+ const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
4572
+ return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
4573
+ }
4574
+ /**
4575
+ * Construct an array by repeating A the number of times given by reps.
4576
+ *
4577
+ * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
4578
+ * integers, the resulting array will have a shape of `(reps[0] * d1,
4579
+ * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
4580
+ */
4581
+ function tile(a, reps) {
4582
+ a = fudgeArray(a);
4583
+ if (typeof reps === "number") reps = [reps];
4584
+ if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
4585
+ const ndiff = reps.length - a.ndim;
4586
+ if (ndiff > 0) a = a.reshape([...require_backend.rep(ndiff, 1), ...a.shape]);
4587
+ if (ndiff < 0) reps = [...require_backend.rep(-ndiff, 1), ...reps];
4588
+ const broadcastedShape = [];
4589
+ const broadcastAxes = [];
4590
+ for (let i = 0; i < a.ndim; i++) {
4591
+ if (reps[i] > 1) {
4592
+ broadcastedShape.push(reps[i]);
4593
+ broadcastAxes.push(broadcastedShape.length - 1);
4594
+ }
4595
+ broadcastedShape.push(a.shape[i]);
4596
+ }
4597
+ const finalShape = a.shape.map((d, i) => reps[i] * d);
4598
+ return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
4599
+ }
4600
+ /**
4601
+ * Broadcast an array to a shape, with NumPy-style broadcasting rules.
4602
+ *
4603
+ * In other words, this lets you prepend axes on the left, and/or expand
4604
+ * dimensions where the shape is 1.
4605
+ */
4606
+ function broadcastTo(a, shape$1) {
4607
+ const nd = ndim(a);
4608
+ if (shape$1.length < nd) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${nd}`);
4609
+ return broadcast(a, shape$1, require_backend.range(shape$1.length - nd));
4610
+ }
4611
+ /** Broadcast input shapes to a common output shape. */
4612
+ function broadcastShapes(...shapes) {
4613
+ if (shapes.length === 0) return [];
4614
+ return shapes.reduce(require_backend.generalBroadcast);
4615
+ }
4616
+ /** Broadcast arrays to a common shape. */
4617
+ function broadcastArrays(...arrays) {
4618
+ const shapes = arrays.map((a) => shape(a));
4619
+ const outShape = broadcastShapes(...shapes);
4620
+ return arrays.map((a) => broadcastTo(a, outShape));
4621
+ }
4622
+ /**
4340
4623
  * Return specified diagonals.
4341
4624
  *
4342
4625
  * If a is 2D, return the diagonal of the array with the given offset. If a is
@@ -4360,7 +4643,7 @@ function diag(v, k = 0) {
4360
4643
  if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
4361
4644
  if (a.ndim === 1) {
4362
4645
  const n = a.shape[0];
4363
- const ret = where(eye(n).equal(1), a.ref, zerosLike$1(a));
4646
+ const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
4364
4647
  if (k > 0) return pad(ret, [[0, k], [k, 0]]);
4365
4648
  else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
4366
4649
  else return ret;
@@ -4404,8 +4687,36 @@ function dot(x, y) {
4404
4687
  ]);
4405
4688
  return dot$1(x, y);
4406
4689
  }
4407
- /** Vector dot product of two arrays. */
4408
- function vecdot(x, y) {
4690
+ /**
4691
+ * Compute the inner product of two arrays.
4692
+ *
4693
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
4694
+ * contraction on the last axis.
4695
+ *
4696
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
4697
+ */
4698
+ function inner(x, y) {
4699
+ x = reshape(x, shape(x).toSpliced(-1, 0, ...require_backend.rep(ndim(y) - 1, 1)));
4700
+ return dot$1(x, y);
4701
+ }
4702
+ /**
4703
+ * Compute the outer product of two arrays.
4704
+ *
4705
+ * If the input arrays are not 1D, they will be flattened. Returned array will
4706
+ * be of shape `[x.size, y.size]`.
4707
+ */
4708
+ function outer(x, y) {
4709
+ x = ravel(x);
4710
+ y = ravel(y);
4711
+ return multiply(x.reshape([x.shape[0], 1]), y);
4712
+ }
4713
+ /** Vector dot product of two arrays along a given axis. */
4714
+ function vecdot(x, y, { axis } = {}) {
4715
+ const xaxis = require_backend.checkAxis(axis ?? -1, ndim(x));
4716
+ const yaxis = require_backend.checkAxis(axis ?? -1, ndim(y));
4717
+ if (shape(x)[xaxis] !== shape(y)[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(shape(x))} and ${JSON.stringify(shape(y))} not aligned along axis ${axis}: ${shape(x)[xaxis]} != ${shape(y)[yaxis]}`);
4718
+ x = moveaxis(x, xaxis, -1);
4719
+ y = moveaxis(y, yaxis, -1);
4409
4720
  return dot$1(x, y);
4410
4721
  }
4411
4722
  /**
@@ -4414,7 +4725,7 @@ function vecdot(x, y) {
4414
4725
  * Like vecdot() but flattens the arguments first into vectors.
4415
4726
  */
4416
4727
  function vdot(x, y) {
4417
- return vecdot(ravel(x), ravel(y));
4728
+ return dot$1(ravel(x), ravel(y));
4418
4729
  }
4419
4730
  /**
4420
4731
  * Return a tuple of coordinate matrices from coordinate vectors.
@@ -4443,6 +4754,43 @@ function meshgrid(xs, { indexing } = {}) {
4443
4754
  return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
4444
4755
  }
4445
4756
  /**
4757
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
4758
+ *
4759
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
4760
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
4761
+ * `k>0` is above it.
4762
+ */
4763
+ function tri(n, m, k = 0, { dtype, device } = {}) {
4764
+ m ??= n;
4765
+ dtype ??= require_backend.DType.Float32;
4766
+ if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
4767
+ if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
4768
+ if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
4769
+ const rows = arange(k, n + k, 1, {
4770
+ dtype: require_backend.DType.Int32,
4771
+ device
4772
+ });
4773
+ const cols = arange(0, m, 1, {
4774
+ dtype: require_backend.DType.Int32,
4775
+ device
4776
+ });
4777
+ return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
4778
+ }
4779
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
4780
+ function tril(a, k = 0) {
4781
+ if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
4782
+ a = fudgeArray(a);
4783
+ const [n, m] = a.shape.slice(-2);
4784
+ return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
4785
+ }
4786
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
4787
+ function triu(a, k = 0) {
4788
+ if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
4789
+ a = fudgeArray(a);
4790
+ const [n, m] = a.shape.slice(-2);
4791
+ return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
4792
+ }
4793
+ /**
4446
4794
  * Clip (limit) the values in an array.
4447
4795
  *
4448
4796
  * Given an interval, values outside the interval are clipped to the interval
@@ -4466,18 +4814,70 @@ function absolute(x) {
4466
4814
  x = fudgeArray(x);
4467
4815
  return where(less(x.ref, 0), x.ref.mul(-1), x);
4468
4816
  }
4469
- /** Alias of `jax.numpy.absolute()`. */
4817
+ /** @function Alias of `jax.numpy.absolute()`. */
4470
4818
  const abs = absolute;
4819
+ /** Return an element-wise indication of sign of the input. */
4820
+ function sign(x) {
4821
+ x = fudgeArray(x);
4822
+ return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
4823
+ }
4471
4824
  /** Calculate element-wise square of the input array. */
4472
4825
  function square(x) {
4473
4826
  x = fudgeArray(x);
4474
4827
  return x.ref.mul(x);
4475
4828
  }
4476
- /** Compute a trigonometric tangent of each element of input. */
4829
+ /** Element-wise tangent function (takes radians). */
4477
4830
  function tan(x) {
4478
4831
  x = fudgeArray(x);
4479
4832
  return sin(x.ref).div(cos(x));
4480
4833
  }
4834
+ /** Element-wise inverse cosine function (inverse of cos), evaluated as `pi / 2 - asin(x)`. */
4835
+ function acos(x) {
4836
+ return subtract(pi / 2, asin(x));
4837
+ }
4838
+ /**
4839
+ * @function
4840
+ * Return element-wise hypotenuse for the given legs of a right triangle.
4841
+ *
4842
+ * In the original NumPy/JAX implementation, this function is more numerically
4843
+ * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4844
+ * improvements.
4845
+ */
4846
+ const hypot = jit$1(function hypot$1(x1, x2) {
4847
+ return sqrt(square(x1).add(square(x2)));
4848
+ });
4849
+ /**
4849
+ * @function
4850
+ * Element-wise arc tangent of y/x with correct quadrant. Note the argument order: `atan2(y, x)`.
4851
+ *
4852
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
4853
+ * The result is in the range [-π, π].
4854
+ *
4855
+ * Uses numerically stable formulas:
4856
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
4857
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
4858
+ *
4859
+ * The output is ill-defined when both x and y are zero (the x >= 0 branch then divides 0 by 0).
4860
+ */
4861
+ const atan2 = jit$1(function atan2$1(y, x) {
4862
+ const r = sqrt(square(x.ref).add(square(y.ref)));
4863
+ const xNeg = less(x.ref, 0);
4864
+ const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
4865
+ const denom = where(xNeg, y, r.add(x));
4866
+ return atan(numer.div(denom)).mul(2);
4867
+ });
4869
+ /** @function Alias of `jax.numpy.acos()`. */
4870
+ const arccos = acos;
4871
+ /** @function Alias of `jax.numpy.atan()`. */
4872
+ const arctan = atan;
4873
+ /** @function Alias of `jax.numpy.atan2()`. */
4874
+ const arctan2 = atan2;
4875
+ /** Element-wise subtraction, with broadcasting. */
4876
+ function subtract(x, y) {
4877
+ x = fudgeArray(x);
4878
+ y = fudgeArray(y);
4879
+ return x.sub(y);
4880
+ }
4481
4881
  /** Calculates the floating-point division of x by y element-wise. */
4482
4882
  function trueDivide(x, y) {
4483
4883
  x = fudgeArray(x);
@@ -4485,7 +4885,7 @@ function trueDivide(x, y) {
4485
4885
  if (!require_backend.isFloatDtype(x.dtype) || !require_backend.isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
4486
4886
  return x.div(y);
4487
4887
  }
4488
- /** Alias of `jax.numpy.trueDivide()`. */
4888
+ /** @function Alias of `jax.numpy.trueDivide()`. */
4489
4889
  const divide = trueDivide;
4490
4890
  /** Round input to the nearest integer towards zero. */
4491
4891
  function trunc(x) {
@@ -4503,36 +4903,134 @@ function log2(x) {
4503
4903
  function log10(x) {
4504
4904
  return log(x).mul(Math.LOG10E);
4505
4905
  }
4906
+ /** Calculate `exp(x) - 1` element-wise. Evaluated literally as `exp(x).sub(1)`, with no special handling for small `x`. */
4907
+ function expm1(x) {
4908
+ return exp(x).sub(1);
4909
+ }
4910
+ /** Calculate the natural logarithm of `1 + x` element-wise. Evaluated literally as `log(add(1, x))`, with no special handling for small `x`. */
4911
+ function log1p(x) {
4912
+ return log(add(1, x));
4913
+ }
4914
+ /** Convert angles from degrees to radians. */
4915
+ function deg2rad(x) {
4916
+ return multiply(x, pi / 180);
4917
+ }
4918
+ /** @function Alias of `jax.numpy.deg2rad()`. */
4919
+ const radians = deg2rad;
4920
+ /** Convert angles from radians to degrees. */
4921
+ function rad2deg(x) {
4922
+ return multiply(x, 180 / pi);
4923
+ }
4924
+ /** @function Alias of `jax.numpy.rad2deg()`. */
4925
+ const degrees = rad2deg;
4926
+ /**
4927
+ * @function
4928
+ * Computes first array raised to power of second array, element-wise. Evaluated as `exp(log(x1) * x2)`. NOTE(review): this form presumably does not give real results for negative `x1` - confirm how `log` treats non-positive input.
4929
+ */
4930
+ const power = jit$1(function power$1(x1, x2) {
4931
+ return exp(log(x1).mul(x2));
4932
+ });
4933
+ /** @function Alias of `jax.numpy.power()`. */
4934
+ const pow = power;
4935
+ /** @function Calculate the element-wise cube root of the input array, as `sgn * exp(log(x * sgn) / 3)` where `sgn` is -1 for `x < 0` and 1 otherwise. */
4936
+ const cbrt = jit$1(function cbrt$1(x) {
4937
+ const sgn = where(less(x.ref, 0), -1, 1);
4938
+ return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
4939
+ });
4506
4940
  /**
4941
+ * @function
4507
4942
  * Calculate element-wise hyperbolic sine of input.
4508
4943
  *
4509
4944
  * `sinh(x) = (exp(x) - exp(-x)) / 2`
4510
4945
  */
4511
- function sinh(x) {
4946
+ const sinh = jit$1(function sinh$1(x) {
4512
4947
  const ex = exp(x);
4513
4948
  const emx = reciprocal(ex.ref);
4514
4949
  return ex.sub(emx).mul(.5);
4515
- }
4950
+ });
4516
4951
  /**
4952
+ * @function
4517
4953
  * Calculate element-wise hyperbolic cosine of input.
4518
4954
  *
4519
4955
  * `cosh(x) = (exp(x) + exp(-x)) / 2`
4520
4956
  */
4521
- function cosh(x) {
4957
+ const cosh = jit$1(function cosh$1(x) {
4522
4958
  const ex = exp(x);
4523
4959
  const emx = reciprocal(ex.ref);
4524
4960
  return ex.add(emx).mul(.5);
4525
- }
4961
+ });
4526
4962
  /**
4963
+ * @function
4527
4964
  * Calculate element-wise hyperbolic tangent of input.
4528
4965
  *
4529
4966
  * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4530
4967
  */
4531
- function tanh(x) {
4532
- x = fudgeArray(x);
4968
+ const tanh = jit$1(function tanh$1(x) {
4533
4969
  const negsgn = where(less(x.ref, 0), 1, -1);
4534
4970
  const en2x = exp(x.mul(negsgn.ref).mul(2));
4535
4971
  return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
4972
+ });
4973
+ /**
4974
+ * @function
4975
+ * Calculate element-wise inverse hyperbolic sine of input.
4976
+ *
4977
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
4978
+ */
4979
+ const arcsinh = jit$1(function arcsinh$1(x) {
4980
+ return log(x.ref.add(sqrt(square(x).add(1))));
4981
+ });
4982
+ /**
4983
+ * @function
4984
+ * Calculate element-wise inverse hyperbolic cosine of input.
4985
+ *
4986
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
4987
+ */
4988
+ const arccosh = jit$1(function arccosh$1(x) {
4989
+ return log(x.ref.add(sqrt(square(x).sub(1))));
4990
+ });
4991
+ /**
4992
+ * @function
4993
+ * Calculate element-wise inverse hyperbolic tangent of input.
4994
+ *
4995
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
4996
+ */
4997
+ const arctanh = jit$1(function arctanh$1(x) {
4998
+ return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
4999
+ });
5000
+ /** @function Alias of `jax.numpy.arcsinh()`. */
5001
+ const asinh = arcsinh;
5002
+ /** @function Alias of `jax.numpy.arccosh()`. */
5003
+ const acosh = arccosh;
5004
+ /** @function Alias of `jax.numpy.arctanh()`. */
5005
+ const atanh = arctanh;
5006
+ /**
5007
+ * Compute the variance of an array.
5008
+ *
5009
+ * The variance is computed for the flattened array by default, otherwise over
5010
+ * the specified axis.
5011
+ *
5012
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
5013
+ * where `N` represents the number of elements (e.g., for Bessel's correction). `opts.mean` may supply a precomputed mean (otherwise it is computed with `keepdims: true`), `opts.keepdims` is forwarded to the final sum, and an `Error` is thrown when the reduced axes contain zero elements.
5014
+ */
5015
+ function var_(x, axis = null, opts) {
5016
+ x = fudgeArray(x);
5017
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5018
+ const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
5019
+ if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
5020
+ const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
5021
+ return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
5022
+ }
5023
+ /**
5023
+ * Compute the standard deviation of an array.
5024
+ *
5025
+ * The standard deviation is computed for the flattened array by default,
5026
+ * otherwise over the specified axis.
5027
+ *
5028
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
5029
+ * where `N` represents the number of elements (e.g., for Bessel's correction). Evaluated as `sqrt(var_(x, axis, opts))`, so `opts` is passed through to `var_` unchanged.
5030
+ */
5031
+ function std(x, axis = null, opts) {
5032
+ return sqrt(var_(x, axis, opts));
4536
5034
  }
4537
5035
 
4538
5036
  //#endregion
@@ -4547,6 +5045,7 @@ __export(nn_exports, {
4547
5045
  leakyRelu: () => leakyRelu,
4548
5046
  logSigmoid: () => logSigmoid,
4549
5047
  logSoftmax: () => logSoftmax,
5048
+ logmeanexp: () => logmeanexp,
4550
5049
  logsumexp: () => logsumexp,
4551
5050
  mish: () => mish,
4552
5051
  oneHot: () => oneHot,
@@ -4557,6 +5056,8 @@ __export(nn_exports, {
4557
5056
  softSign: () => softSign,
4558
5057
  softmax: () => softmax,
4559
5058
  softplus: () => softplus,
5059
+ squareplus: () => squareplus,
5060
+ standardize: () => standardize,
4560
5061
  swish: () => swish
4561
5062
  });
4562
5063
  /**
@@ -4600,6 +5101,7 @@ function softSign(x) {
4600
5101
  return x.ref.div(absolute(x).add(1));
4601
5102
  }
4602
5103
  /**
5104
+ * @function
4603
5105
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4604
5106
  * Swish, computed element-wise:
4605
5107
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4608,8 +5110,11 @@ function softSign(x) {
4608
5110
  *
4609
5111
  * Reference: https://en.wikipedia.org/wiki/Swish_function
4610
5112
  */
4611
- const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
5113
+ const silu = jit$1(function silu$1(x) {
5114
+ return x.ref.mul(sigmoid(x));
5115
+ });
4612
5116
  /**
5117
+ * @function
4613
5118
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4614
5119
  * Swish, computed element-wise:
4615
5120
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4626,7 +5131,10 @@ const swish = silu;
4626
5131
  function logSigmoid(x) {
4627
5132
  return negative(softplus(negative(x)));
4628
5133
  }
4629
- /** Identity activation function. Returns the argument unmodified. */
5134
+ /**
5135
+ * @function
5136
+ * Identity activation function. Returns the argument unmodified.
5137
+ */
4630
5138
  const identity = fudgeArray;
4631
5139
  /** Leaky rectified linear (ReLU) activation function */
4632
5140
  function leakyRelu(x, negativeSlope = .01) {
@@ -4654,6 +5162,7 @@ function celu(x, alpha = 1) {
4654
5162
  return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
4655
5163
  }
4656
5164
  /**
5165
+ * @function
4657
5166
  * Gaussian error linear unit (GELU) activation function.
4658
5167
  *
4659
5168
  * This is computed element-wise. Currently jax-js does not support the erf() or
@@ -4664,7 +5173,7 @@ function celu(x, alpha = 1) {
4664
5173
  *
4665
5174
  * This will be improved in the future.
4666
5175
  */
4667
- const gelu = jit$1((x) => {
5176
+ const gelu = jit$1(function gelu$1(x) {
4668
5177
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
4669
5178
  return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
4670
5179
  });
@@ -4685,6 +5194,16 @@ function glu(x, axis = -1) {
4685
5194
  return a.mul(sigmoid(b));
4686
5195
  }
4687
5196
  /**
5197
+ * Squareplus activation function.
5198
+ *
5199
+ * Computes the element-wise function:
5200
+ * `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`, where `b` defaults to 4.
5201
+ */
5202
+ function squareplus(x, b = 4) {
5203
+ x = fudgeArray(x);
5204
+ return x.ref.add(sqrt(square(x).add(b))).mul(.5);
5205
+ }
5206
+ /**
4688
5207
  * Mish activation function.
4689
5208
  *
4690
5209
  * Computes the element-wise function:
@@ -4702,17 +5221,13 @@ function mish(x) {
4702
5221
  *
4703
5222
  * Reference: https://en.wikipedia.org/wiki/Softmax_function
4704
5223
  */
4705
- function softmax(x, axis) {
5224
+ function softmax(x, axis = -1) {
4706
5225
  x = fudgeArray(x);
4707
- if (axis === void 0) axis = x.ndim ? [x.ndim - 1] : [];
4708
- else if (typeof axis === "number") axis = [axis];
4709
- if (axis.length === 0) {
4710
- x.dispose();
4711
- return ones(x.shape);
4712
- }
4713
- const xMax = max(x.ref, axis, { keepDims: true });
5226
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5227
+ if (axis.length === 0) return onesLike(x);
5228
+ const xMax = max(x.ref, axis, { keepdims: true });
4714
5229
  const unnormalized = exp(x.sub(stopGradient(xMax)));
4715
- return unnormalized.ref.div(unnormalized.sum(axis, { keepDims: true }));
5230
+ return unnormalized.ref.div(unnormalized.sum(axis, { keepdims: true }));
4716
5231
  }
4717
5232
  /**
4718
5233
  * Log-Softmax function.
@@ -4722,17 +5237,13 @@ function softmax(x, axis) {
4722
5237
  *
4723
5238
  * If `axis` is not specified, it defaults to the last axis.
4724
5239
  */
4725
- function logSoftmax(x, axis) {
5240
+ function logSoftmax(x, axis = -1) {
4726
5241
  x = fudgeArray(x);
4727
- if (axis === void 0) axis = x.ndim ? [x.ndim - 1] : [];
4728
- else if (typeof axis === "number") axis = [axis];
4729
- if (axis.length === 0) {
4730
- x.dispose();
4731
- return zeros(x.shape);
4732
- }
4733
- const xMax = max(x.ref, axis, { keepDims: true });
5242
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5243
+ if (axis.length === 0) return zerosLike(x);
5244
+ const xMax = max(x.ref, axis, { keepdims: true });
4734
5245
  const shifted = x.sub(stopGradient(xMax));
4735
- const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepDims: true }));
5246
+ const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepdims: true }));
4736
5247
  return shifted.sub(shiftedLogsumexp);
4737
5248
  }
4738
5249
  /**
@@ -4743,16 +5254,39 @@ function logSoftmax(x, axis) {
4743
5254
  *
4744
5255
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
4745
5256
  */
4746
- function logsumexp(x, axis) {
5257
+ function logsumexp(x, axis = null) {
4747
5258
  x = fudgeArray(x);
4748
- if (axis === void 0) axis = require_backend.range(x.ndim);
4749
- else if (typeof axis === "number") axis = [axis];
5259
+ axis = require_backend.normalizeAxis(axis, x.ndim);
4750
5260
  if (axis.length === 0) return x;
4751
5261
  const xMax = stopGradient(max(x.ref, axis));
4752
5262
  const xMaxDims = broadcast(xMax.ref, x.shape, axis);
4753
5263
  const shifted = x.sub(xMaxDims);
4754
5264
  return xMax.add(log(exp(shifted).sum(axis)));
4755
5265
  }
5266
+ /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`, where `n` is the product of the sizes of the reduced axes. If there are no axes to reduce, `x` is returned as-is. */
5267
+ function logmeanexp(x, axis = null) {
5268
+ x = fudgeArray(x);
5269
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5270
+ if (axis.length === 0) return x;
5271
+ const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
5272
+ return logsumexp(x, axis).sub(Math.log(n));
5273
+ }
5274
+ /**
5275
+ * Standardizes input to zero mean and unit variance.
5276
+ *
5277
+ * By default, this is computed over the last axis. You can pass in a different
5278
+ * axis, or `null` to standardize over all elements.
5279
+ *
5280
+ * `opts.epsilon` (default `1e-5`) is added to the variance inside the square root for stability. `opts.mean` and `opts.variance` may supply precomputed statistics; otherwise the variance is computed as `mean(x^2) - mean(x)^2`.
5281
+ */
5282
+ function standardize(x, axis = -1, opts = {}) {
5283
+ x = fudgeArray(x);
5284
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5285
+ if (axis.length === 0) return x;
5286
+ const mu = opts.mean !== void 0 ? fudgeArray(opts.mean) : x.ref.mean(axis, { keepdims: true });
5287
+ const sigma2 = opts.variance !== void 0 ? fudgeArray(opts.variance) : square(x.ref).mean(axis, { keepdims: true }).sub(square(mu.ref));
5288
+ return x.sub(mu).div(sqrt(sigma2.add(opts.epsilon ?? 1e-5)));
5289
+ }
4756
5290
  /**
4757
5291
  * One-hot encodes the given indices.
4758
5292
  *
@@ -4770,7 +5304,7 @@ function logsumexp(x, axis) {
4770
5304
  * ```
4771
5305
  */
4772
5306
  function oneHot(x, numClasses) {
4773
- if (x.dtype !== require_backend.DType.Int32) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
5307
+ if (require_backend.isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
4774
5308
  return eye(numClasses, void 0, { device: x.device }).slice(x);
4775
5309
  }
4776
5310
 
@@ -4778,8 +5312,11 @@ function oneHot(x, numClasses) {
4778
5312
  //#region src/random.ts
4779
5313
  var random_exports = {};
4780
5314
  __export(random_exports, {
5315
+ bernoulli: () => bernoulli,
4781
5316
  bits: () => bits,
5317
+ exponential: () => exponential,
4782
5318
  key: () => key,
5319
+ normal: () => normal,
4783
5320
  split: () => split,
4784
5321
  uniform: () => uniform
4785
5322
  });
@@ -4807,21 +5344,58 @@ function bits(key$1, shape$1 = []) {
4807
5344
  const keyShape = validateKeyShape(key$1);
4808
5345
  return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
4809
5346
  }
4810
- /** Sample uniform random values in [minval, maxval) with given shape. */
4811
- function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5347
+ /**
5348
+ * @function
5349
+ * Sample uniform random values in [minval, maxval) with given shape.
5350
+ */
5351
+ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
4812
5352
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
4813
- const mantissa = bits(key$1, shape$1).div(scalar(512, {
5353
+ const mantissa = bits(key$1, shape$1).div(array(512, {
4814
5354
  dtype: require_backend.DType.Uint32,
4815
5355
  device: key$1.device
4816
5356
  }));
4817
- const float12 = mantissa.add(scalar(1065353216, {
5357
+ const float12 = mantissa.add(array(1065353216, {
4818
5358
  dtype: require_backend.DType.Uint32,
4819
5359
  device: key$1.device
4820
5360
  }));
4821
5361
  const rand = bitcast(float12, require_backend.DType.Float32).sub(1);
4822
5362
  if (minval === 0 && maxval === 1) return rand;
4823
5363
  else return rand.mul(maxval - minval).add(minval);
5364
+ }, { staticArgnums: [1, 2] });
5365
+ /**
5366
+ * Sample Bernoulli random variables with given mean (0,1 categorical).
5367
+ *
5368
+ * Returns a random Boolean array with the specified shape. `p` can be an array
5369
+ * and must be broadcastable to `shape`. `p` defaults to 0.5; the sample is computed as `uniform(key, shape) < p`.
5370
+ */
5371
+ function bernoulli(key$1, p = .5, shape$1 = []) {
5372
+ p = fudgeArray(p);
5373
+ return uniform(key$1, shape$1).less(p);
4824
5374
  }
5375
+ /**
5376
+ * @function
5377
+ * Sample exponential random values according to `p(x) = exp(-x)`. Computed as `-log1p(-u)` for `u` drawn from `uniform(key, shape)`; `shape` is a static argument of the jitted function.
5378
+ */
5379
+ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
5380
+ const u = uniform(key$1, shape$1);
5381
+ return negative(log1p(negative(u)));
5382
+ }, { staticArgnums: [1] });
5383
+ /**
5384
+ * @function
5385
+ * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
5386
+ *
5387
+ * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
5388
+ * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
5389
+ * bitwise identical to JAX.
5390
+ */
5391
+ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
5392
+ const [k1, k2] = split(key$1, 2);
5393
+ const u1 = uniform(k1, shape$1);
5394
+ const u2 = uniform(k2, shape$1);
5395
+ const radius = sqrt(log1p(negative(u1)).mul(-2));
5396
+ const theta = u2.mul(2 * Math.PI);
5397
+ return radius.mul(cos(theta));
5398
+ }, { staticArgnums: [1] });
4825
5399
 
4826
5400
  //#endregion
4827
5401
  //#region src/polyfills.ts
@@ -4831,20 +5405,36 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
4831
5405
 
4832
5406
  //#endregion
4833
5407
  //#region src/index.ts
4834
- /** Compute the forward-mode Jacobian-vector product for a function. */
5408
+ /**
5409
+ * @function
5410
+ * Compute the forward-mode Jacobian-vector product for a function.
5411
+ */
4835
5412
  const jvp = jvp$1;
4836
- /** Vectorize an operation on a batched axis for one or more inputs. */
5413
+ /**
5414
+ * @function
5415
+ * Vectorize an operation on a batched axis for one or more inputs.
5416
+ */
4837
5417
  const vmap = vmap$1;
4838
- /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
5418
+ /**
5419
+ * @function
5420
+ * Compute the Jacobian evaluated column-by-column by forward-mode AD.
5421
+ */
4839
5422
  const jacfwd = jacfwd$1;
4840
- /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
5423
+ /**
5424
+ * @function
5425
+ * Construct a Jaxpr by dynamically tracing a function with example inputs.
5426
+ */
4841
5427
  const makeJaxpr = makeJaxpr$1;
4842
5428
  /**
5429
+ * @function
4843
5430
  * Mark a function for automatic JIT compilation, with operator fusion.
4844
5431
  *
4845
5432
  * The function will be compiled the first time it is called with a set of
4846
5433
  * argument shapes.
4847
5434
  *
5435
+ * You can call `.dispose()` on the returned, JIT-compiled function after all
5436
+ * calls to free memory associated with array constants.
5437
+ *
4848
5438
  * **Options:**
4849
5439
  * - `staticArgnums`: An array of argument indices to treat as static
4850
5440
  * (compile-time constant). These arguments must be hashable, won't be traced,
@@ -4854,26 +5444,59 @@ const makeJaxpr = makeJaxpr$1;
4854
5444
  */
4855
5445
  const jit = jit$1;
4856
5446
  /**
5447
+ * @function
4857
5448
  * Produce a local linear approximation to a function at a point using jvp() and
4858
5449
  * partial evaluation.
4859
5450
  */
4860
5451
  const linearize = linearize$1;
4861
- /** Calculate the reverse-mode vector-Jacobian product for a function. */
5452
+ /**
5453
+ * @function
5454
+ * Calculate the reverse-mode vector-Jacobian product for a function.
5455
+ */
4862
5456
  const vjp = vjp$1;
4863
5457
  /**
5458
+ * @function
4864
5459
  * Compute the gradient of a scalar-valued function `f` with respect to its
4865
5460
  * first argument.
4866
5461
  */
4867
5462
  const grad = grad$1;
4868
- /** Create a function that evaluates both `f` and the gradient of `f`. */
5463
+ /**
5464
+ * @function
5465
+ * Create a function that evaluates both `f` and the gradient of `f`.
5466
+ */
4869
5467
  const valueAndGrad = valueAndGrad$1;
4870
- /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
5468
+ /**
5469
+ * @function
5470
+ * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
5471
+ */
4871
5472
  const jacrev = jacrev$1;
4872
- /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
5473
+ /**
5474
+ * @function
5475
+ * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
5476
+ */
4873
5477
  const jacobian = jacrev;
5478
+ /**
5479
+ * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
5480
+ *
5481
+ * This can be used to wait for the results of an intermediate computation to
5482
+ * finish. It's recommended to call this regularly in an iterative computation
5483
+ * to avoid queueing up too many pending operations.
5484
+ *
5485
+ * Does not consume reference to the arrays. Leaves that are not `Array` instances are skipped, all waits run concurrently via `Promise.all`, and the promise resolves to the same `x` that was passed in.
5486
+ */
5487
+ async function blockUntilReady(x) {
5488
+ const promises = [];
5489
+ for (const leaf of leaves(x)) if (leaf instanceof Array$1) promises.push(leaf.blockUntilReady());
5490
+ await Promise.all(promises);
5491
+ return x;
5492
+ }
4874
5493
 
4875
5494
  //#endregion
5495
+ exports.Array = Array$1;
4876
5496
  exports.DType = require_backend.DType;
5497
+ exports.Jaxpr = Jaxpr;
5498
+ exports.blockUntilReady = blockUntilReady;
5499
+ exports.defaultDevice = require_backend.defaultDevice;
4877
5500
  exports.devices = require_backend.devices;
4878
5501
  exports.grad = grad;
4879
5502
  exports.init = require_backend.init;
@@ -4908,7 +5531,7 @@ Object.defineProperty(exports, 'random', {
4908
5531
  return random_exports;
4909
5532
  }
4910
5533
  });
4911
- exports.setDevice = require_backend.setDevice;
5534
+ exports.setDebug = require_backend.setDebug;
4912
5535
  Object.defineProperty(exports, 'tree', {
4913
5536
  enumerable: true,
4914
5537
  get: function () {