@jax-js/jax 0.0.4 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
30
30
  }) : target, mod));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-Ss1Mev_-.cjs');
33
+ const require_backend = require('./backend-yEU0L_ig.cjs');
34
34
 
35
35
  //#region src/tree.ts
36
36
  var tree_exports = {};
@@ -596,6 +596,21 @@ var Trace = class {
596
596
  this.main = main;
597
597
  }
598
598
  };
599
+ /**
600
+ * Broadcast shapes and promote types with casting for two avals.
601
+ *
602
+ * This implements the weak type behavior described in `promoteTypes()`, but not
603
+ * implemented in that function as `weakType` is not passed.
604
+ */
605
+ function promoteAvals(a, b) {
606
+ const shape$1 = require_backend.generalBroadcast(a.shape, b.shape);
607
+ const weakType = a.weakType && b.weakType;
608
+ let dtype;
609
+ if (a.weakType === b.weakType) dtype = require_backend.promoteTypes(a.dtype, b.dtype);
610
+ else if (a.weakType) dtype = require_backend.promoteTypes(b.dtype, require_backend.DType.Uint32);
611
+ else dtype = require_backend.promoteTypes(a.dtype, require_backend.DType.Uint32);
612
+ return new ShapedArray(shape$1, dtype, weakType);
613
+ }
599
614
  var Tracer = class Tracer {
600
615
  /** @ignore */
601
616
  _trace;
@@ -610,10 +625,19 @@ var Tracer = class Tracer {
610
625
  get size() {
611
626
  return require_backend.prod(this.shape);
612
627
  }
613
- /** The dtype of the array. */
628
+ /** The dtype of elements stored in the array. */
614
629
  get dtype() {
615
630
  return this.aval.dtype;
616
631
  }
632
+ /**
633
+ * Whether the array is weakly typed.
634
+ *
635
+ * Weakly typed arrays will cast to the dtype of the other operand. See
636
+ * `promoteTypes()` for details.
637
+ */
638
+ get weakType() {
639
+ return this.aval.weakType;
640
+ }
617
641
  /** The number of dimensions of the array. */
618
642
  get ndim() {
619
643
  return this.shape.length;
@@ -850,12 +874,13 @@ function getShape(x) {
850
874
  return x instanceof Tracer ? x.shape : [];
851
875
  }
852
876
  var ShapedArray = class ShapedArray {
853
- constructor(shape$1, dtype) {
877
+ constructor(shape$1, dtype, weakType) {
854
878
  this.shape = shape$1;
855
879
  this.dtype = dtype;
880
+ this.weakType = weakType;
856
881
  }
857
882
  static fromAval(aval) {
858
- return new ShapedArray(aval.shape, aval.dtype);
883
+ return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
859
884
  }
860
885
  get ndim() {
861
886
  return this.shape.length;
@@ -869,7 +894,7 @@ var ShapedArray = class ShapedArray {
869
894
  };
870
895
  function getAval(x) {
871
896
  if (x instanceof Tracer) return x.aval;
872
- else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32);
897
+ else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32, typeof x === "boolean" ? false : true);
873
898
  else throw new TypeError(`Unknown value: ${x}`);
874
899
  }
875
900
  function bind(prim, args, params = {}) {
@@ -1154,7 +1179,7 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1154
1179
  }
1155
1180
  function broadcastedJit(fn) {
1156
1181
  return (nargs, exps, avals, params) => {
1157
- const newShape = avals.map((aval) => aval.shape).reduce(generalBroadcast);
1182
+ const newShape = avals.map((aval) => aval.shape).reduce(require_backend.generalBroadcast);
1158
1183
  exps = exps.map((exp$3) => reshapeViews(exp$3, (st) => {
1159
1184
  if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
1160
1185
  }));
@@ -1191,7 +1216,7 @@ const jitRules = {
1191
1216
  const k1 = reshapeViews(keys[1], mapping);
1192
1217
  const c0 = require_backend.AluExp.u32(0);
1193
1218
  const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
1194
- const exp$2 = require_backend.AluExp.threefry2x32(c0, c1, k0, k1, mode);
1219
+ const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
1195
1220
  return new require_backend.Kernel(nargs, require_backend.prod(shape$1), exp$2);
1196
1221
  },
1197
1222
  [Primitive.Sin]: unopJit(require_backend.AluExp.sin),
@@ -1232,7 +1257,7 @@ const jitRules = {
1232
1257
  [Primitive.Dot](nargs, [a, b], [as, bs]) {
1233
1258
  const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
1234
1259
  const c = k1.exp;
1235
- const cs = new ShapedArray(generalBroadcast(as.shape, bs.shape), c.dtype);
1260
+ const cs = promoteAvals(as, bs);
1236
1261
  return jitRules[Primitive.Reduce](nargs, [c], [cs], {
1237
1262
  op: require_backend.AluOp.Add,
1238
1263
  axis: [cs.ndim - 1]
@@ -1242,8 +1267,8 @@ const jitRules = {
1242
1267
  const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
1243
1268
  a = reshapeViews(a, (st) => st.compose(stX));
1244
1269
  b = reshapeViews(b, (st) => st.compose(stY));
1245
- as = new ShapedArray(stX.shape, as.dtype);
1246
- bs = new ShapedArray(stY.shape, bs.dtype);
1270
+ as = new ShapedArray(stX.shape, as.dtype, as.weakType);
1271
+ bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
1247
1272
  return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1248
1273
  },
1249
1274
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
@@ -1260,7 +1285,7 @@ const jitRules = {
1260
1285
  [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
1261
1286
  [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1262
1287
  const axisSet = new Set(axis);
1263
- const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
1288
+ const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
1264
1289
  const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
1265
1290
  finalShape.splice(outDim, 0, ...indexShape);
1266
1291
  const idxAll = require_backend.unravelAlu(finalShape, require_backend.AluVar.gidx);
@@ -1296,9 +1321,10 @@ function splitGraphDataflow(backend, jaxpr) {
1296
1321
  Primitive.Conv,
1297
1322
  Primitive.PoolTranspose
1298
1323
  ];
1324
+ const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
1299
1325
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
1300
1326
  const eqn = jaxpr.eqns[i];
1301
- if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
1327
+ if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
1302
1328
  for (const v of eqn.outBinders) {
1303
1329
  blackNodes.add(v);
1304
1330
  p1NextBlack.set(v, v);
@@ -1428,6 +1454,7 @@ var Array$1 = class Array$1 extends Tracer {
1428
1454
  static #nextId = 1001;
1429
1455
  id;
1430
1456
  #dtype;
1457
+ #weakType;
1431
1458
  #source;
1432
1459
  #st;
1433
1460
  #backend;
@@ -1439,21 +1466,22 @@ var Array$1 = class Array$1 extends Tracer {
1439
1466
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1440
1467
  * will be freed when the array is disposed.
1441
1468
  */
1442
- constructor(source, st, dtype, backend, { pending = null } = {}) {
1469
+ constructor(args) {
1443
1470
  super(baseArrayTrace);
1444
1471
  this.id = Array$1.#nextId++;
1445
- this.#dtype = dtype;
1446
- this.#source = source;
1447
- this.#st = st;
1448
- this.#backend = backend;
1472
+ this.#dtype = args.dtype;
1473
+ this.#weakType = args.weakType;
1474
+ this.#source = args.source;
1475
+ this.#st = args.st;
1476
+ this.#backend = args.backend;
1449
1477
  this.#rc = 1;
1450
- this.#pendingSet = new Set(pending);
1478
+ this.#pendingSet = new Set(args.pending);
1451
1479
  if (this.#pendingSet.size === 0) this.#pendingSet = null;
1452
- else if (source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1480
+ else if (this.#source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1453
1481
  }
1454
1482
  /** @ignore */
1455
1483
  get aval() {
1456
- return new ShapedArray(this.#st.shape, this.#dtype);
1484
+ return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
1457
1485
  }
1458
1486
  /** Return a simple string representation of the array's dimensions. */
1459
1487
  toString() {
@@ -1465,6 +1493,17 @@ var Array$1 = class Array$1 extends Tracer {
1465
1493
  #check() {
1466
1494
  if (this.#rc <= 0) throw new UseAfterFreeError(this);
1467
1495
  }
1496
+ /** Construct an array, copying fields from `this`. */
1497
+ #newArrayFrom(args) {
1498
+ return new Array$1({
1499
+ source: args.source ?? this.#source,
1500
+ st: args.st ?? this.#st,
1501
+ dtype: args.dtype ?? this.#dtype,
1502
+ weakType: this.#weakType,
1503
+ backend: args.backend ?? this.#backend,
1504
+ pending: args.pending ?? this.#pending ?? void 0
1505
+ });
1506
+ }
1468
1507
  get ref() {
1469
1508
  this.#check();
1470
1509
  this.#rc++;
@@ -1504,7 +1543,10 @@ var Array$1 = class Array$1 extends Tracer {
1504
1543
  const pending = this.#pending;
1505
1544
  for (const exe of pending) exe.updateRc(1);
1506
1545
  if (typeof this.#source === "number") this.#backend.incRef(this.#source);
1507
- const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
1546
+ const ar = this.#newArrayFrom({
1547
+ st,
1548
+ pending
1549
+ });
1508
1550
  this.dispose();
1509
1551
  return ar;
1510
1552
  }
@@ -1553,7 +1595,11 @@ var Array$1 = class Array$1 extends Tracer {
1553
1595
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1554
1596
  this.dispose();
1555
1597
  for (const ar of indices) ar.dispose();
1556
- return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
1598
+ return this.#newArrayFrom({
1599
+ source: output,
1600
+ st: require_backend.ShapeTracker.fromShape(finalShape),
1601
+ pending
1602
+ });
1557
1603
  }
1558
1604
  /** Move axes to the rightmost dimension of the shape. */
1559
1605
  #moveAxesDown(axis) {
@@ -1576,11 +1622,16 @@ var Array$1 = class Array$1 extends Tracer {
1576
1622
  return this.#reshape(this.#st.permute(perm));
1577
1623
  }
1578
1624
  #unary(op, dtypeOutput) {
1625
+ const weakType = !dtypeOutput && this.#weakType;
1579
1626
  dtypeOutput ??= this.#dtype;
1580
1627
  this.#check();
1581
1628
  if (this.#source instanceof require_backend.AluExp) {
1582
1629
  const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
1583
- return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
1630
+ return this.#newArrayFrom({
1631
+ source: exp$3.simplify(),
1632
+ dtype: dtypeOutput,
1633
+ weakType
1634
+ });
1584
1635
  }
1585
1636
  const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
1586
1637
  const exp$2 = new require_backend.AluExp(op, dtypeOutput, [require_backend.AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
@@ -1590,41 +1641,65 @@ var Array$1 = class Array$1 extends Tracer {
1590
1641
  for (const exe of pending) exe.updateRc(1);
1591
1642
  pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
1592
1643
  this.dispose();
1593
- return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
1644
+ return this.#newArrayFrom({
1645
+ source: output,
1646
+ st: require_backend.ShapeTracker.fromShape(this.shape),
1647
+ dtype: dtypeOutput,
1648
+ weakType,
1649
+ pending
1650
+ });
1594
1651
  }
1595
1652
  #binary(op, other) {
1596
- const custom = (src) => new require_backend.AluExp(op, this.#dtype, src);
1653
+ const custom = (src) => new require_backend.AluExp(op, src[0].dtype, src);
1597
1654
  return Array$1.#naryCustom(op, custom, [this, other]);
1598
1655
  }
1599
- static #naryCustom(name, custom, arrays, { dtypeOverride, dtypeOutput, reduceAxis } = {}) {
1656
+ static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
1600
1657
  const n = arrays.length;
1601
1658
  const backend = arrays[0].#backend;
1602
1659
  if (n === 0) throw new TypeError(`No inputs for ${name}`);
1603
1660
  for (const ar of arrays) ar.#check();
1604
- let dtype;
1661
+ let castDtype;
1662
+ let castWeakType = true;
1605
1663
  for (let i = 0; i < n; i++) {
1606
1664
  if (dtypeOverride?.[i]) {
1607
1665
  if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1608
- } else if (!dtype) dtype = arrays[i].#dtype;
1609
- else if (arrays[i].#dtype !== dtype) throw new TypeError(`Dtype mismatch in ${name}: ${dtype} vs ${arrays[i].#dtype}`);
1666
+ } else if (castDtype === void 0) {
1667
+ castDtype = arrays[i].#dtype;
1668
+ castWeakType = arrays[i].#weakType;
1669
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1610
1670
  if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
1611
1671
  }
1612
- dtypeOutput ??= dtype;
1613
- if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
1672
+ const weakType = castWeakType && !strongTypeOutput;
1614
1673
  arrays = Array$1.#broadcastArrays(arrays);
1615
1674
  const newShape = [...arrays[0].shape];
1616
1675
  if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
1676
+ const sources = arrays.map((ar, i) => {
1677
+ if (!dtypeOverride?.[i]) return require_backend.AluExp.cast(castDtype, ar.#source);
1678
+ else return ar.#source;
1679
+ });
1617
1680
  if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
1618
- const exp$4 = custom(arrays.map((ar) => ar.#source));
1619
- return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
1681
+ const exp$4 = custom(sources);
1682
+ return new Array$1({
1683
+ source: exp$4.simplify(),
1684
+ st: arrays[0].#st,
1685
+ dtype: exp$4.dtype,
1686
+ weakType,
1687
+ backend
1688
+ });
1620
1689
  }
1621
- const exp$3 = custom(arrays.map((ar) => {
1622
- const src$1 = ar.#source;
1690
+ const exp$3 = custom(arrays.map((ar, i) => {
1691
+ const src$1 = sources[i];
1623
1692
  if (ar.#st.contiguous) return src$1;
1624
1693
  return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
1625
1694
  }));
1626
1695
  const st = require_backend.ShapeTracker.fromShape(newShape);
1627
- return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
1696
+ return new Array$1({
1697
+ source: exp$3.simplify(),
1698
+ st,
1699
+ dtype: exp$3.dtype,
1700
+ weakType,
1701
+ backend
1702
+ });
1628
1703
  }
1629
1704
  let indices;
1630
1705
  if (!reduceAxis) indices = require_backend.unravelAlu(newShape, require_backend.AluVar.gidx);
@@ -1634,14 +1709,19 @@ var Array$1 = class Array$1 extends Tracer {
1634
1709
  }
1635
1710
  const inputs = [];
1636
1711
  const src = [];
1637
- for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#source, ar.#st, indices));
1638
- else {
1639
- let gid = inputs.indexOf(ar.#source);
1640
- if (gid === -1) {
1641
- gid = inputs.length;
1642
- inputs.push(ar.#source);
1712
+ for (const [i, ar] of arrays.entries()) {
1713
+ let nextSrc;
1714
+ if (ar.#source instanceof require_backend.AluExp) nextSrc = require_backend.accessorAluExp(ar.#source, ar.#st, indices);
1715
+ else {
1716
+ let gid = inputs.indexOf(ar.#source);
1717
+ if (gid === -1) {
1718
+ gid = inputs.length;
1719
+ inputs.push(ar.#source);
1720
+ }
1721
+ nextSrc = require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
1643
1722
  }
1644
- src.push(require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, indices));
1723
+ if (!dtypeOverride?.[i]) nextSrc = require_backend.AluExp.cast(castDtype, nextSrc);
1724
+ src.push(nextSrc);
1645
1725
  }
1646
1726
  const exp$2 = custom(src);
1647
1727
  let re = void 0;
@@ -1655,12 +1735,17 @@ var Array$1 = class Array$1 extends Tracer {
1655
1735
  for (const exe of pending) exe.updateRc(1);
1656
1736
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1657
1737
  for (const ar of arrays) ar.dispose();
1658
- return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
1738
+ return new Array$1({
1739
+ source: output,
1740
+ st: require_backend.ShapeTracker.fromShape(newShape),
1741
+ dtype: kernel.dtype,
1742
+ weakType,
1743
+ backend,
1744
+ pending
1745
+ });
1659
1746
  }
1660
1747
  /** Reduce the last dimension of the array by an operation. */
1661
1748
  #reduce(op) {
1662
- this.#check();
1663
- if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
1664
1749
  const shape$1 = this.shape;
1665
1750
  const reduction = new require_backend.Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
1666
1751
  const newShape = shape$1.slice(0, -1);
@@ -1679,7 +1764,11 @@ var Array$1 = class Array$1 extends Tracer {
1679
1764
  for (const exe of pending) exe.updateRc(1);
1680
1765
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1681
1766
  this.dispose();
1682
- return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
1767
+ return this.#newArrayFrom({
1768
+ source: output,
1769
+ st: require_backend.ShapeTracker.fromShape(newShape),
1770
+ pending
1771
+ });
1683
1772
  }
1684
1773
  /**
1685
1774
  * Normalizes this array into one backed by a `Slot`.
@@ -1715,15 +1804,15 @@ var Array$1 = class Array$1 extends Tracer {
1715
1804
  }
1716
1805
  #dataInline() {
1717
1806
  this.#check();
1718
- const exp$2 = this.#source;
1719
- const ar = new Array$1(exp$2, this.#st, this.dtype, require_backend.getBackend("cpu"));
1807
+ if (!(this.#source instanceof require_backend.AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
1808
+ const ar = this.#newArrayFrom({ backend: require_backend.getBackend("cpu") });
1720
1809
  this.dispose();
1721
1810
  return ar.dataSync();
1722
1811
  }
1723
1812
  static #broadcastArrays(arrays) {
1724
1813
  if (arrays.length === 0) throw new Error("Need at least one array to broadcast");
1725
1814
  if (arrays.length === 1) return arrays;
1726
- const newShape = arrays.map((a) => a.shape).reduce(generalBroadcast);
1815
+ const newShape = arrays.map((a) => a.shape).reduce(require_backend.generalBroadcast);
1727
1816
  return arrays.map((ar) => {
1728
1817
  if (require_backend.deepEqual(ar.shape, newShape)) return ar;
1729
1818
  return ar.#reshape(ar.#st.broadcast(newShape, require_backend.range(newShape.length - ar.ndim)));
@@ -1842,14 +1931,18 @@ var Array$1 = class Array$1 extends Tracer {
1842
1931
  x.#backend.incRef(x.#source);
1843
1932
  const pending = x.#pending;
1844
1933
  for (const exe of pending) exe.updateRc(1);
1845
- const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
1934
+ const y = x.#newArrayFrom({
1935
+ dtype,
1936
+ weakType: false,
1937
+ pending
1938
+ });
1846
1939
  x.dispose();
1847
1940
  return [y];
1848
1941
  }
1849
1942
  },
1850
1943
  [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
1851
- const keyShape = generalBroadcast(k0.shape, k1.shape);
1852
- if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1944
+ const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
1945
+ if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1853
1946
  const c0 = zeros(shape$1, {
1854
1947
  dtype: require_backend.DType.Uint32,
1855
1948
  device: k0.device
@@ -1917,7 +2010,7 @@ var Array$1 = class Array$1 extends Tracer {
1917
2010
  },
1918
2011
  [Primitive.Compare]([x, y], { op }) {
1919
2012
  const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
1920
- return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: require_backend.DType.Bool })];
2013
+ return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
1921
2014
  },
1922
2015
  [Primitive.Where]([cond, x, y]) {
1923
2016
  const custom = ([cond$1, x$1, y$1]) => require_backend.AluExp.where(cond$1, x$1, y$1);
@@ -1963,7 +2056,14 @@ var Array$1 = class Array$1 extends Tracer {
1963
2056
  pending.splice(0, 0, ...prevPending);
1964
2057
  args.forEach((x) => x.dispose());
1965
2058
  return outputs.map((source, i) => {
1966
- return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
2059
+ return new Array$1({
2060
+ source,
2061
+ st: require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
2062
+ dtype: jaxpr.outs[i].aval.dtype,
2063
+ weakType: jaxpr.outs[i].aval.weakType,
2064
+ backend,
2065
+ pending
2066
+ });
1967
2067
  });
1968
2068
  }
1969
2069
  };
@@ -1973,33 +2073,11 @@ var Array$1 = class Array$1 extends Tracer {
1973
2073
  return this.#source;
1974
2074
  }
1975
2075
  };
1976
- /** Construct an array from a single scalar constant. */
1977
- function scalar(value, { dtype, device } = {}) {
1978
- if (typeof value === "number") {
1979
- dtype ??= require_backend.DType.Float32;
1980
- if (![
1981
- require_backend.DType.Float32,
1982
- require_backend.DType.Float16,
1983
- require_backend.DType.Int32,
1984
- require_backend.DType.Uint32
1985
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1986
- } else if (typeof value === "boolean") {
1987
- dtype ??= require_backend.DType.Bool;
1988
- if (![
1989
- require_backend.DType.Float32,
1990
- require_backend.DType.Float16,
1991
- require_backend.DType.Int32,
1992
- require_backend.DType.Uint32,
1993
- require_backend.DType.Bool
1994
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1995
- } else throw new TypeError(`Invalid type for scalar ${value}`);
1996
- return new Array$1(require_backend.AluExp.const(dtype, value), require_backend.ShapeTracker.fromShape([]), dtype, require_backend.getBackend(device));
1997
- }
1998
2076
  /** Constructor for creating a new array from data. */
1999
2077
  function array(values, { shape: shape$1, dtype, device } = {}) {
2000
2078
  if (values instanceof Tracer) {
2001
2079
  if (shape$1 && !require_backend.deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
2002
- if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
2080
+ if (dtype && values.dtype !== dtype) values = values.astype(dtype);
2003
2081
  return values;
2004
2082
  } else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
2005
2083
  dtype,
@@ -2021,6 +2099,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
2021
2099
  dtype,
2022
2100
  device
2023
2101
  });
2102
+ if (size$1 === 1) return full(shape$1, flat[0], {
2103
+ dtype,
2104
+ device
2105
+ });
2024
2106
  if (typeof flat[0] === "boolean") {
2025
2107
  dtype = dtype ?? require_backend.DType.Bool;
2026
2108
  const data = new Int32Array(flat.map((x) => x ? 1 : 0));
@@ -2029,46 +2111,51 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
2029
2111
  device
2030
2112
  });
2031
2113
  } else {
2114
+ const weakType = dtype == void 0;
2032
2115
  dtype = dtype ?? require_backend.DType.Float32;
2033
2116
  const data = require_backend.dtypedJsArray(dtype, flat);
2034
2117
  return arrayFromData(data, shape$1, {
2035
2118
  dtype,
2036
2119
  device
2037
- });
2120
+ }, weakType);
2038
2121
  }
2039
2122
  }
2040
2123
  }
2041
- function arrayFromData(data, shape$1, { dtype, device } = {}) {
2124
+ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2125
+ if (data instanceof Float32Array) {
2126
+ if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
2127
+ dtype ??= require_backend.DType.Float32;
2128
+ } else if (data instanceof Int32Array) {
2129
+ if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2130
+ dtype ??= require_backend.DType.Int32;
2131
+ } else if (data instanceof Uint32Array) {
2132
+ if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2133
+ dtype ??= require_backend.DType.Uint32;
2134
+ } else if (data instanceof Float16Array) {
2135
+ if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
2136
+ dtype ??= require_backend.DType.Float16;
2137
+ } else throw new Error("Unsupported data array type: " + data.constructor.name);
2042
2138
  if (data.length < inlineArrayLimit) {
2043
2139
  let allEqual = true;
2044
2140
  for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
2045
2141
  allEqual = false;
2046
2142
  break;
2047
2143
  }
2048
- if (allEqual) return full(shape$1, data[0], {
2049
- dtype,
2050
- device
2051
- });
2144
+ if (allEqual) {
2145
+ const sa = new ShapedArray(shape$1, dtype, weakType);
2146
+ return fullInternal(sa, data[0], device);
2147
+ }
2052
2148
  }
2053
2149
  const backend = require_backend.getBackend(device);
2054
- if (ArrayBuffer.isView(data)) {
2055
- const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2056
- if (data instanceof Float32Array) {
2057
- if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
2058
- dtype ??= require_backend.DType.Float32;
2059
- } else if (data instanceof Int32Array) {
2060
- if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2061
- dtype ??= require_backend.DType.Int32;
2062
- } else if (data instanceof Uint32Array) {
2063
- if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2064
- dtype ??= require_backend.DType.Uint32;
2065
- } else if (data instanceof Float16Array) {
2066
- if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
2067
- dtype ??= require_backend.DType.Float16;
2068
- } else throw new Error("Unsupported data array type: " + data.constructor.name);
2069
- const slot = backend.malloc(data.byteLength, buf);
2070
- return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype, backend);
2071
- } else throw new Error("Unsupported data type: " + data.constructor.name);
2150
+ const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2151
+ const slot = backend.malloc(data.byteLength, buf);
2152
+ return new Array$1({
2153
+ source: slot,
2154
+ st: require_backend.ShapeTracker.fromShape(shape$1),
2155
+ dtype,
2156
+ weakType,
2157
+ backend
2158
+ });
2072
2159
  }
2073
2160
  function dataToJs(dtype, data, shape$1) {
2074
2161
  if (shape$1.length === 0) return dtype === require_backend.DType.Bool ? Boolean(data[0]) : data[0];
@@ -2084,7 +2171,7 @@ function dataToJs(dtype, data, shape$1) {
2084
2171
  /** If x is a value, lift it into an array, otherwise leave it be. */
2085
2172
  function pureArray(x) {
2086
2173
  if (x instanceof Tracer) return x;
2087
- else return scalar(x);
2174
+ else return array(x);
2088
2175
  }
2089
2176
  var EvalTrace = class extends Trace {
2090
2177
  pure = (x) => pureArray(x);
@@ -2095,20 +2182,27 @@ var EvalTrace = class extends Trace {
2095
2182
  };
2096
2183
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
2097
2184
  const implRules = Array$1._implRules();
2185
+ function fullInternal(aval, fillValue, device) {
2186
+ return new Array$1({
2187
+ source: require_backend.AluExp.const(aval.dtype, fillValue),
2188
+ st: require_backend.ShapeTracker.fromShape(aval.shape),
2189
+ dtype: aval.dtype,
2190
+ weakType: aval.weakType,
2191
+ backend: require_backend.getBackend(device)
2192
+ });
2193
+ }
2098
2194
  function zerosLike$1(val, dtype) {
2099
- const aval = getAval(val);
2100
- if (val instanceof Tracer) val.dispose();
2101
- return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2195
+ return fullLike(val, 0, dtype);
2102
2196
  }
2103
2197
  function onesLike$1(val, dtype) {
2104
- const aval = getAval(val);
2105
- if (val instanceof Tracer) val.dispose();
2106
- return ones(aval.shape, { dtype: dtype ?? aval.dtype });
2198
+ return fullLike(val, 1, dtype);
2107
2199
  }
2108
2200
  function fullLike(val, fillValue, dtype) {
2109
2201
  const aval = getAval(val);
2110
2202
  if (val instanceof Tracer) val.dispose();
2111
- return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
2203
+ if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
2204
+ const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
2205
+ return fullInternal(sa, fillValue);
2112
2206
  }
2113
2207
  /** Return a new array of given shape and type, filled with zeros. */
2114
2208
  function zeros(shape$1, { dtype, device } = {}) {
@@ -2126,19 +2220,14 @@ function ones(shape$1, { dtype, device } = {}) {
2126
2220
  }
2127
2221
  /** Return a new array of given shape and type, filled with `fill_value`. */
2128
2222
  function full(shape$1, fillValue, { dtype, device } = {}) {
2129
- let source;
2130
- if (typeof fillValue === "number") {
2131
- dtype = dtype ?? require_backend.DType.Float32;
2132
- source = require_backend.AluExp.const(dtype, fillValue);
2133
- } else if (typeof fillValue === "bigint") {
2134
- dtype = dtype ?? require_backend.DType.Int32;
2135
- source = require_backend.AluExp.const(dtype, Number(fillValue));
2136
- } else if (typeof fillValue === "boolean") {
2223
+ let weakType = dtype == void 0;
2224
+ if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
2225
+ else if (typeof fillValue === "boolean") {
2137
2226
  dtype = dtype ?? require_backend.DType.Bool;
2138
- source = require_backend.AluExp.const(dtype, fillValue ? 1 : 0);
2227
+ weakType = false;
2139
2228
  } else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
2140
2229
  else throw new TypeError(`Invalid type for full: ${fillValue}`);
2141
- return new Array$1(source, require_backend.ShapeTracker.fromShape(shape$1), dtype ?? require_backend.DType.Float32, require_backend.getBackend(device));
2230
+ return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
2142
2231
  }
2143
2232
  /**
2144
2233
  * Create an identity matrix.
@@ -2148,6 +2237,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
2148
2237
  */
2149
2238
  function eye(numRows, numCols, { dtype, device } = {}) {
2150
2239
  numCols = numCols ?? numRows;
2240
+ const weakType = dtype == void 0;
2151
2241
  dtype = dtype ?? require_backend.DType.Float32;
2152
2242
  if (numCols < numRows) {
2153
2243
  const arr = eye(numCols, numRows, {
@@ -2161,7 +2251,13 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2161
2251
  device
2162
2252
  });
2163
2253
  const exp$2 = require_backend.AluExp.cmplt(require_backend.AluExp.mod(require_backend.AluVar.idx, require_backend.AluExp.i32(numCols + 1)), require_backend.AluExp.i32(1));
2164
- return new Array$1(require_backend.AluExp.cast(dtype, exp$2), require_backend.ShapeTracker.fromShape([numRows, numCols]), dtype, require_backend.getBackend(device));
2254
+ return new Array$1({
2255
+ source: require_backend.AluExp.cast(dtype, exp$2),
2256
+ st: require_backend.ShapeTracker.fromShape([numRows, numCols]),
2257
+ dtype,
2258
+ weakType,
2259
+ backend: require_backend.getBackend(device)
2260
+ });
2165
2261
  }
2166
2262
  /** Return the identity matrix, with ones on the main diagonal. */
2167
2263
  function identity$1(n, { dtype, device } = {}) {
@@ -2198,7 +2294,13 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
2198
2294
  });
2199
2295
  const exp$2 = require_backend.AluExp.add(require_backend.AluExp.const(dtype, start), require_backend.AluExp.mul(require_backend.AluExp.cast(dtype, require_backend.AluVar.idx), require_backend.AluExp.const(dtype, step)));
2200
2296
  const st = require_backend.ShapeTracker.fromShape([size$1]);
2201
- return new Array$1(exp$2, st, dtype, require_backend.getBackend(device));
2297
+ return new Array$1({
2298
+ source: exp$2,
2299
+ st,
2300
+ dtype,
2301
+ weakType: false,
2302
+ backend: require_backend.getBackend(device)
2303
+ });
2202
2304
  }
2203
2305
  /**
2204
2306
  * Return evenly spaced numbers over a specified interval.
@@ -2216,10 +2318,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2216
2318
  dtype,
2217
2319
  device
2218
2320
  });
2219
- else if (num === 1) return scalar(start, {
2321
+ else if (num === 1) return full([1], start, {
2220
2322
  dtype,
2221
2323
  device
2222
- }).reshape([1]);
2324
+ });
2223
2325
  else if (start === stop) return full([num], start, {
2224
2326
  dtype,
2225
2327
  device
@@ -2228,7 +2330,13 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2228
2330
  const denom = endpoint ? num - 1 : num;
2229
2331
  const exp$2 = require_backend.AluExp.cast(dtype, require_backend.AluExp.add(require_backend.AluExp.f32(start), require_backend.AluExp.mul(require_backend.AluExp.f32(delta / denom), require_backend.AluExp.cast(require_backend.DType.Float32, require_backend.AluVar.idx))));
2230
2332
  const st = require_backend.ShapeTracker.fromShape([num]);
2231
- return new Array$1(exp$2, st, dtype, require_backend.getBackend(device));
2333
+ return new Array$1({
2334
+ source: exp$2,
2335
+ st,
2336
+ dtype,
2337
+ weakType: false,
2338
+ backend: require_backend.getBackend(device)
2339
+ });
2232
2340
  }
2233
2341
  function aluCompare(a, b, op) {
2234
2342
  switch (op) {
@@ -2240,35 +2348,6 @@ function aluCompare(a, b, op) {
2240
2348
  case CompareOp.LessEqual: return require_backend.AluExp.add(require_backend.AluExp.cmplt(a, b), require_backend.AluExp.cmpne(a, b).not());
2241
2349
  }
2242
2350
  }
2243
- /**
2244
- * Implements a NumPy-style generalized broadcast rule on two array shapes.
2245
- *
2246
- * "When operating on two arrays, NumPy compares their shapes element-wise. It
2247
- * starts with the trailing (i.e. rightmost) dimension and works its way left.
2248
- * Two dimensions are compatible when:
2249
- * 1. they are equal, or
2250
- * 2. one of them is 1."
2251
- *
2252
- * Throws a TypeError if the broadcast is not possible.
2253
- *
2254
- * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
2255
- */
2256
- function generalBroadcast(a, b) {
2257
- const out = [];
2258
- let i = a.length - 1;
2259
- let j = b.length - 1;
2260
- for (; i >= 0 && j >= 0; i--, j--) {
2261
- const x = a[i];
2262
- const y = b[j];
2263
- if (x === y) out.push(x);
2264
- else if (x === 1) out.push(y);
2265
- else if (y === 1) out.push(x);
2266
- else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
2267
- }
2268
- for (; i >= 0; i--) out.push(a[i]);
2269
- for (; j >= 0; j--) out.push(b[j]);
2270
- return out.reverse();
2271
- }
2272
2351
 
2273
2352
  //#endregion
2274
2353
  //#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
@@ -2348,13 +2427,15 @@ var Var = class Var {
2348
2427
  };
2349
2428
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
2350
2429
  var Lit = class {
2351
- dtype;
2352
2430
  value;
2353
2431
  aval;
2354
- constructor(dtype, value) {
2355
- this.dtype = dtype;
2432
+ get dtype() {
2433
+ return this.aval.dtype;
2434
+ }
2435
+ constructor(aval, value) {
2436
+ if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
2356
2437
  this.value = value;
2357
- this.aval = new ShapedArray([], dtype);
2438
+ this.aval = ShapedArray.fromAval(aval);
2358
2439
  }
2359
2440
  };
2360
2441
  function atomIsLit(atom, literal) {
@@ -2478,14 +2559,19 @@ var Jaxpr = class Jaxpr {
2478
2559
  const c = eqn.outBinders[0];
2479
2560
  if (atomIsLit(a, 0)) context.set(c, b);
2480
2561
  else if (atomIsLit(b, 0)) context.set(c, a);
2481
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.dtype === require_backend.DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2562
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === require_backend.DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2563
+ else newEqns.push(eqn);
2564
+ } else if (eqn.primitive === Primitive.Neg) {
2565
+ const [a] = inputs;
2566
+ const c = eqn.outBinders[0];
2567
+ if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
2482
2568
  else newEqns.push(eqn);
2483
2569
  } else if (eqn.primitive === Primitive.Mul) {
2484
2570
  const [a, b] = inputs;
2485
2571
  const c = eqn.outBinders[0];
2486
2572
  if (atomIsLit(a, 1)) context.set(c, b);
2487
2573
  else if (atomIsLit(b, 1)) context.set(c, a);
2488
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.value * b.value));
2574
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
2489
2575
  else newEqns.push(eqn);
2490
2576
  } else if (eqn.primitive === Primitive.Idiv) {
2491
2577
  const [a, b] = inputs;
@@ -2583,7 +2669,7 @@ function evalJaxpr(jaxpr, args) {
2583
2669
  if (x instanceof Var) {
2584
2670
  remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
2585
2671
  return env.get(x);
2586
- } else return scalar(x.value, { dtype: x.dtype });
2672
+ } else return array(x.value, { dtype: x.dtype });
2587
2673
  };
2588
2674
  const write = (v, val) => {
2589
2675
  if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
@@ -2642,7 +2728,7 @@ var JaxprTrace = class extends Trace {
2642
2728
  let tracer = this.builder.constTracers.get(val);
2643
2729
  if (tracer === void 0) {
2644
2730
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
2645
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : scalar(val));
2731
+ this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
2646
2732
  }
2647
2733
  return tracer;
2648
2734
  }
@@ -2711,7 +2797,7 @@ function _inlineLiterals(jaxpr, consts) {
2711
2797
  const newConsts = [];
2712
2798
  for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
2713
2799
  const ar = consts[i];
2714
- literals.set(jaxpr.inBinders[i], new Lit(ar.dtype, ar.dataSync()[0]));
2800
+ literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
2715
2801
  } else {
2716
2802
  constBinders.push(jaxpr.inBinders[i]);
2717
2803
  newConsts.push(consts[i]);
@@ -2724,13 +2810,12 @@ function _inlineLiterals(jaxpr, consts) {
2724
2810
  }
2725
2811
  function binopAbstractEval([x, y]) {
2726
2812
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
2727
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2728
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
2813
+ return [promoteAvals(x, y)];
2729
2814
  }
2730
2815
  function compareAbstractEval([x, y]) {
2731
2816
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
2732
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2733
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), require_backend.DType.Bool)];
2817
+ const aval = promoteAvals(x, y);
2818
+ return [new ShapedArray(aval.shape, require_backend.DType.Bool, false)];
2734
2819
  }
2735
2820
  function vectorizedUnopAbstractEval([x]) {
2736
2821
  return [ShapedArray.fromAval(x)];
@@ -2743,18 +2828,18 @@ const abstractEvalRules = {
2743
2828
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
2744
2829
  [Primitive.StopGradient]: vectorizedUnopAbstractEval,
2745
2830
  [Primitive.Cast]([x], { dtype }) {
2746
- return [new ShapedArray(x.shape, dtype)];
2831
+ return [new ShapedArray(x.shape, dtype, false)];
2747
2832
  },
2748
2833
  [Primitive.Bitcast]([x], { dtype }) {
2749
2834
  if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
2750
2835
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
2751
- return [new ShapedArray(x.shape, dtype)];
2836
+ return [new ShapedArray(x.shape, dtype, false)];
2752
2837
  },
2753
2838
  [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
2754
2839
  if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
2755
- const keyShape = generalBroadcast(k0.shape, k1.shape);
2756
- if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2757
- return [new ShapedArray(shape$1, require_backend.DType.Uint32)];
2840
+ const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2841
+ if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2842
+ return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
2758
2843
  },
2759
2844
  [Primitive.Sin]: vectorizedUnopAbstractEval,
2760
2845
  [Primitive.Cos]: vectorizedUnopAbstractEval,
@@ -2768,55 +2853,54 @@ const abstractEvalRules = {
2768
2853
  [Primitive.Reduce]([x], { axis }) {
2769
2854
  const axisSet = new Set(axis);
2770
2855
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2771
- return [new ShapedArray(newShape, x.dtype)];
2856
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2772
2857
  },
2773
2858
  [Primitive.Pool]([x], { window, strides }) {
2774
2859
  const shape$1 = checkPoolShape(x.shape, window, strides);
2775
- return [new ShapedArray(shape$1, x.dtype)];
2860
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2776
2861
  },
2777
2862
  [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
2778
2863
  const shape$1 = checkPoolShape(inShape, window, strides);
2779
2864
  if (!require_backend.deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
2780
- return [new ShapedArray(inShape, x.dtype)];
2865
+ return [new ShapedArray(inShape, x.dtype, x.weakType)];
2781
2866
  },
2782
2867
  [Primitive.Dot]([x, y]) {
2783
- if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
2784
2868
  if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
2785
- const shape$1 = generalBroadcast(x.shape, y.shape);
2869
+ const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
2786
2870
  shape$1.splice(-1, 1);
2787
- return [new ShapedArray(shape$1, x.dtype)];
2871
+ return [new ShapedArray(shape$1, dtype, weakType)];
2788
2872
  },
2789
2873
  [Primitive.Conv]([lhs, rhs], params) {
2790
- if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
2874
+ const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
2791
2875
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
2792
- return [new ShapedArray(shape$1, lhs.dtype)];
2876
+ return [new ShapedArray(shape$1, dtype, weakType)];
2793
2877
  },
2794
2878
  [Primitive.Compare]: compareAbstractEval,
2795
2879
  [Primitive.Where]([cond, x, y]) {
2796
2880
  if (cond.dtype !== require_backend.DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
2797
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2798
- const shape$1 = generalBroadcast(cond.shape, generalBroadcast(x.shape, y.shape));
2799
- return [new ShapedArray(shape$1, x.dtype)];
2881
+ const xy = promoteAvals(x, y);
2882
+ const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
2883
+ return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
2800
2884
  },
2801
2885
  [Primitive.Transpose]([x], { perm }) {
2802
- return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
2886
+ return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
2803
2887
  },
2804
2888
  [Primitive.Broadcast]([x], { shape: shape$1 }) {
2805
- return [new ShapedArray(shape$1, x.dtype)];
2889
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2806
2890
  },
2807
2891
  [Primitive.Reshape]([x], { shape: shape$1 }) {
2808
- return [new ShapedArray(shape$1, x.dtype)];
2892
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2809
2893
  },
2810
2894
  [Primitive.Flip]([x], _) {
2811
- return [new ShapedArray(x.shape, x.dtype)];
2895
+ return [ShapedArray.fromAval(x)];
2812
2896
  },
2813
2897
  [Primitive.Shrink]([x], { slice }) {
2814
2898
  const newShape = slice.map((s) => s[1] - s[0]);
2815
- return [new ShapedArray(newShape, x.dtype)];
2899
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2816
2900
  },
2817
2901
  [Primitive.Pad]([x], { width }) {
2818
2902
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
2819
- return [new ShapedArray(newShape, x.dtype)];
2903
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2820
2904
  },
2821
2905
  [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2822
2906
  for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
@@ -2826,10 +2910,10 @@ const abstractEvalRules = {
2826
2910
  if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
2827
2911
  const axisSet = new Set(axis);
2828
2912
  if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
2829
- const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
2913
+ const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
2830
2914
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2831
2915
  newShape.splice(outDim, 0, ...gatherShape);
2832
- return [new ShapedArray(newShape, x.dtype)];
2916
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2833
2917
  },
2834
2918
  [Primitive.JitCall](args, { jaxpr }) {
2835
2919
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
@@ -2896,6 +2980,7 @@ function jit$1(f, opts) {
2896
2980
  const cacheKey = JSON.stringify(jaxprArgs);
2897
2981
  const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
2898
2982
  const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
2983
+ name: f.name || "closure",
2899
2984
  jaxpr,
2900
2985
  numConsts: consts.length
2901
2986
  });
@@ -3058,13 +3143,14 @@ const jvpRules = {
3058
3143
  const indicesRef = indices.map((t) => t.ref);
3059
3144
  return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3060
3145
  },
3061
- [Primitive.JitCall](primals, tangents, { jaxpr }) {
3146
+ [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3062
3147
  const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3063
3148
  const outs = bind(Primitive.JitCall, [
3064
3149
  ...newConsts.map((c) => c.ref),
3065
3150
  ...primals,
3066
3151
  ...tangents
3067
3152
  ], {
3153
+ name: `${name}_jvp`,
3068
3154
  jaxpr: newJaxpr,
3069
3155
  numConsts: newConsts.length
3070
3156
  });
@@ -3119,7 +3205,7 @@ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3119
3205
  function mappedAval(batchDim, aval) {
3120
3206
  const shape$1 = [...aval.shape];
3121
3207
  shape$1.splice(batchDim, 1);
3122
- return new ShapedArray(shape$1, aval.dtype);
3208
+ return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3123
3209
  }
3124
3210
  /** Move one axis to a different index. */
3125
3211
  function moveaxis$1(x, src, dst) {
@@ -3263,9 +3349,10 @@ const vmapRules = {
3263
3349
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3264
3350
  return [[pad$1(x, newWidth)], [xBdim]];
3265
3351
  },
3266
- [Primitive.JitCall](axisSize, args, dims, { jaxpr }) {
3352
+ [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3267
3353
  const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3268
3354
  const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3355
+ name: `${name}_vmap`,
3269
3356
  jaxpr: newJaxpr,
3270
3357
  numConsts: newConsts.length
3271
3358
  });
@@ -3281,7 +3368,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3281
3368
  if (dims[i] === null) return v.aval;
3282
3369
  const shape$1 = [...v.aval.shape];
3283
3370
  shape$1.splice(dims[i], 0, axisSize);
3284
- return new ShapedArray(shape$1, v.aval.dtype);
3371
+ return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3285
3372
  });
3286
3373
  const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3287
3374
  const result = {
@@ -3494,8 +3581,8 @@ var PartialEvalTrace = class extends Trace {
3494
3581
  processPrimitive(primitive, tracers, params) {
3495
3582
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3496
3583
  if (primitive === Primitive.JitCall) {
3497
- const { jaxpr, numConsts } = params;
3498
- return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
3584
+ const { name, jaxpr, numConsts } = params;
3585
+ return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3499
3586
  }
3500
3587
  const tracersIn = tracers.map((t) => this.instantiateConst(t));
3501
3588
  const avalsIn = tracersIn.map((t) => t.pval.aval);
@@ -3521,12 +3608,13 @@ var PartialEvalTrace = class extends Trace {
3521
3608
  *
3522
3609
  * Used when encountering a JitCall rule during the trace.
3523
3610
  */
3524
- #partialEvalJaxpr(jaxpr, numConsts, tracers) {
3611
+ #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3525
3612
  jaxpr = jaxpr.flatten();
3526
3613
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3527
3614
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3528
3615
  const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
3529
3616
  const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
3617
+ name: `${name}_peval`,
3530
3618
  jaxpr: jaxpr1,
3531
3619
  numConsts: 0
3532
3620
  });
@@ -3538,6 +3626,7 @@ var PartialEvalTrace = class extends Trace {
3538
3626
  prim: Primitive.JitCall,
3539
3627
  tracersIn: resTracers.concat(unknownTracers),
3540
3628
  params: {
3629
+ name: `${name}_resid`,
3541
3630
  jaxpr: jaxpr2,
3542
3631
  numConsts: 0
3543
3632
  },
@@ -3680,7 +3769,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
3680
3769
  }
3681
3770
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
3682
3771
  const eqn = jaxpr.eqns[i];
3683
- const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? scalar(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3772
+ const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3684
3773
  const cotangentsOut = eqn.outBinders.map(readCotangent);
3685
3774
  const rule = transposeRules[eqn.primitive];
3686
3775
  if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
@@ -3765,7 +3854,7 @@ const transposeRules = {
3765
3854
  },
3766
3855
  [Primitive.Dot]([ct], [x, y]) {
3767
3856
  if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
3768
- const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
3857
+ const axisSize = require_backend.generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
3769
3858
  ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
3770
3859
  return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
3771
3860
  },
@@ -3860,7 +3949,7 @@ const transposeRules = {
3860
3949
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
3861
3950
  throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
3862
3951
  },
3863
- [Primitive.JitCall](cts, args, { jaxpr }) {
3952
+ [Primitive.JitCall](cts, args, { name, jaxpr }) {
3864
3953
  const undefPrimals = args.map((x) => x instanceof UndefPrimal);
3865
3954
  const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
3866
3955
  const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
@@ -3869,6 +3958,7 @@ const transposeRules = {
3869
3958
  ...residuals,
3870
3959
  ...cts
3871
3960
  ], {
3961
+ name: `${name}_t`,
3872
3962
  jaxpr: newJaxpr,
3873
3963
  numConsts: newConsts.length
3874
3964
  });
@@ -3943,7 +4033,7 @@ function valueAndGrad$1(f) {
3943
4033
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3944
4034
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3945
4035
  if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3946
- const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
4036
+ const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
3947
4037
  for (const r of rest) dispose(r);
3948
4038
  fVjp.dispose();
3949
4039
  return [y, ct];
@@ -4313,7 +4403,7 @@ function argmin(a, axis, opts) {
4313
4403
  } else axis = require_backend.checkAxis(axis, a.ndim);
4314
4404
  const shape$1 = a.shape;
4315
4405
  const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
4316
- const length = scalar(shape$1[axis], {
4406
+ const length = array(shape$1[axis], {
4317
4407
  dtype: int32,
4318
4408
  device: a.device
4319
4409
  });
@@ -4337,7 +4427,7 @@ function argmax(a, axis, opts) {
4337
4427
  } else axis = require_backend.checkAxis(axis, a.ndim);
4338
4428
  const shape$1 = a.shape;
4339
4429
  const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
4340
- const length = scalar(shape$1[axis], {
4430
+ const length = array(shape$1[axis], {
4341
4431
  dtype: int32,
4342
4432
  device: a.device
4343
4433
  });
@@ -4521,7 +4611,7 @@ function broadcastTo(a, shape$1) {
4521
4611
  /** Broadcast input shapes to a common output shape. */
4522
4612
  function broadcastShapes(...shapes) {
4523
4613
  if (shapes.length === 0) return [];
4524
- return shapes.reduce(generalBroadcast);
4614
+ return shapes.reduce(require_backend.generalBroadcast);
4525
4615
  }
4526
4616
  /** Broadcast arrays to a common shape. */
4527
4617
  function broadcastArrays(...arrays) {
@@ -4753,7 +4843,7 @@ function acos(x) {
4753
4843
  * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4754
4844
  * improvements.
4755
4845
  */
4756
- const hypot = jit$1((x1, x2) => {
4846
+ const hypot = jit$1(function hypot$1(x1, x2) {
4757
4847
  return sqrt(square(x1).add(square(x2)));
4758
4848
  });
4759
4849
  /**
@@ -4769,7 +4859,7 @@ const hypot = jit$1((x1, x2) => {
4769
4859
  *
4770
4860
  * The output is ill-defined when both x and y are zero.
4771
4861
  */
4772
- const atan2 = jit$1((y, x) => {
4862
+ const atan2 = jit$1(function atan2$1(y, x) {
4773
4863
  const r = sqrt(square(x.ref).add(square(y.ref)));
4774
4864
  const xNeg = less(x.ref, 0);
4775
4865
  const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
@@ -4837,13 +4927,13 @@ const degrees = rad2deg;
4837
4927
  * @function
4838
4928
  * Computes first array raised to power of second array, element-wise.
4839
4929
  */
4840
- const power = jit$1((x1, x2) => {
4930
+ const power = jit$1(function power$1(x1, x2) {
4841
4931
  return exp(log(x1).mul(x2));
4842
4932
  });
4843
4933
  /** @function Alias of `jax.numpy.power()`. */
4844
4934
  const pow = power;
4845
4935
  /** @function Calculate the element-wise cube root of the input array. */
4846
- const cbrt = jit$1((x) => {
4936
+ const cbrt = jit$1(function cbrt$1(x) {
4847
4937
  const sgn = where(less(x.ref, 0), -1, 1);
4848
4938
  return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
4849
4939
  });
@@ -4853,7 +4943,7 @@ const cbrt = jit$1((x) => {
4853
4943
  *
4854
4944
  * `sinh(x) = (exp(x) - exp(-x)) / 2`
4855
4945
  */
4856
- const sinh = jit$1((x) => {
4946
+ const sinh = jit$1(function sinh$1(x) {
4857
4947
  const ex = exp(x);
4858
4948
  const emx = reciprocal(ex.ref);
4859
4949
  return ex.sub(emx).mul(.5);
@@ -4864,7 +4954,7 @@ const sinh = jit$1((x) => {
4864
4954
  *
4865
4955
  * `cosh(x) = (exp(x) + exp(-x)) / 2`
4866
4956
  */
4867
- const cosh = jit$1((x) => {
4957
+ const cosh = jit$1(function cosh$1(x) {
4868
4958
  const ex = exp(x);
4869
4959
  const emx = reciprocal(ex.ref);
4870
4960
  return ex.add(emx).mul(.5);
@@ -4875,7 +4965,7 @@ const cosh = jit$1((x) => {
4875
4965
  *
4876
4966
  * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4877
4967
  */
4878
- const tanh = jit$1((x) => {
4968
+ const tanh = jit$1(function tanh$1(x) {
4879
4969
  const negsgn = where(less(x.ref, 0), 1, -1);
4880
4970
  const en2x = exp(x.mul(negsgn.ref).mul(2));
4881
4971
  return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
@@ -4886,7 +4976,7 @@ const tanh = jit$1((x) => {
4886
4976
  *
4887
4977
  * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
4888
4978
  */
4889
- const arcsinh = jit$1((x) => {
4979
+ const arcsinh = jit$1(function arcsinh$1(x) {
4890
4980
  return log(x.ref.add(sqrt(square(x).add(1))));
4891
4981
  });
4892
4982
  /**
@@ -4895,7 +4985,7 @@ const arcsinh = jit$1((x) => {
4895
4985
  *
4896
4986
  * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
4897
4987
  */
4898
- const arccosh = jit$1((x) => {
4988
+ const arccosh = jit$1(function arccosh$1(x) {
4899
4989
  return log(x.ref.add(sqrt(square(x).sub(1))));
4900
4990
  });
4901
4991
  /**
@@ -4904,7 +4994,7 @@ const arccosh = jit$1((x) => {
4904
4994
  *
4905
4995
  * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
4906
4996
  */
4907
- const arctanh = jit$1((x) => {
4997
+ const arctanh = jit$1(function arctanh$1(x) {
4908
4998
  return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
4909
4999
  });
4910
5000
  /** @function Alias of `jax.numpy.arcsinh()`. */
@@ -5020,7 +5110,9 @@ function softSign(x) {
5020
5110
  *
5021
5111
  * Reference: https://en.wikipedia.org/wiki/Swish_function
5022
5112
  */
5023
- const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
5113
+ const silu = jit$1(function silu$1(x) {
5114
+ return x.ref.mul(sigmoid(x));
5115
+ });
5024
5116
  /**
5025
5117
  * @function
5026
5118
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
@@ -5081,7 +5173,7 @@ function celu(x, alpha = 1) {
5081
5173
  *
5082
5174
  * This will be improved in the future.
5083
5175
  */
5084
- const gelu = jit$1((x) => {
5176
+ const gelu = jit$1(function gelu$1(x) {
5085
5177
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5086
5178
  return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5087
5179
  });
@@ -5252,8 +5344,11 @@ function bits(key$1, shape$1 = []) {
5252
5344
  const keyShape = validateKeyShape(key$1);
5253
5345
  return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
5254
5346
  }
5255
- /** Sample uniform random values in [minval, maxval) with given shape. */
5256
- function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5347
+ /**
5348
+ * @function
5349
+ * Sample uniform random values in [minval, maxval) with given shape.
5350
+ */
5351
+ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5257
5352
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
5258
5353
  const mantissa = bits(key$1, shape$1).div(array(512, {
5259
5354
  dtype: require_backend.DType.Uint32,
@@ -5266,7 +5361,7 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5266
5361
  const rand = bitcast(float12, require_backend.DType.Float32).sub(1);
5267
5362
  if (minval === 0 && maxval === 1) return rand;
5268
5363
  else return rand.mul(maxval - minval).add(minval);
5269
- }
5364
+ }, { staticArgnums: [1, 2] });
5270
5365
  /**
5271
5366
  * Sample Bernoulli random variables with given mean (0,1 categorical).
5272
5367
  *
@@ -5277,26 +5372,30 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
5277
5372
  p = fudgeArray(p);
5278
5373
  return uniform(key$1, shape$1).less(p);
5279
5374
  }
5280
- /** Sample exponential random values according to `p(x) = exp(-x)`. */
5281
- function exponential(key$1, shape$1 = []) {
5375
+ /**
5376
+ * @function
5377
+ * Sample exponential random values according to `p(x) = exp(-x)`.
5378
+ */
5379
+ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
5282
5380
  const u = uniform(key$1, shape$1);
5283
5381
  return negative(log1p(negative(u)));
5284
- }
5382
+ }, { staticArgnums: [1] });
5285
5383
  /**
5384
+ * @function
5286
5385
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
5287
5386
  *
5288
5387
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
5289
5388
  * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
5290
5389
  * bitwise identical to JAX.
5291
5390
  */
5292
- function normal(key$1, shape$1 = []) {
5391
+ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
5293
5392
  const [k1, k2] = split(key$1, 2);
5294
5393
  const u1 = uniform(k1, shape$1);
5295
5394
  const u2 = uniform(k2, shape$1);
5296
5395
  const radius = sqrt(log1p(negative(u1)).mul(-2));
5297
5396
  const theta = u2.mul(2 * Math.PI);
5298
5397
  return radius.mul(cos(theta));
5299
- }
5398
+ }, { staticArgnums: [1] });
5300
5399
 
5301
5400
  //#endregion
5302
5401
  //#region src/polyfills.ts