@jax-js/jax 0.0.4 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-EBRGmEYw.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-CdcTZEOF.js";
3
3
 
4
4
  //#region src/tree.ts
5
5
  var tree_exports = {};
@@ -565,6 +565,21 @@ var Trace = class {
565
565
  this.main = main;
566
566
  }
567
567
  };
568
+ /**
569
+ * Broadcast shapes and promote types with casting for two avals.
570
+ *
571
+ * This implements the weak type behavior described in `promoteTypes()`, but not
572
+ * implemented in that function as `weakType` is not passed.
573
+ */
574
+ function promoteAvals(a, b) {
575
+ const shape$1 = generalBroadcast(a.shape, b.shape);
576
+ const weakType = a.weakType && b.weakType;
577
+ let dtype;
578
+ if (a.weakType === b.weakType) dtype = promoteTypes(a.dtype, b.dtype);
579
+ else if (a.weakType) dtype = promoteTypes(b.dtype, DType.Uint32);
580
+ else dtype = promoteTypes(a.dtype, DType.Uint32);
581
+ return new ShapedArray(shape$1, dtype, weakType);
582
+ }
568
583
  var Tracer = class Tracer {
569
584
  /** @ignore */
570
585
  _trace;
@@ -579,10 +594,19 @@ var Tracer = class Tracer {
579
594
  get size() {
580
595
  return prod(this.shape);
581
596
  }
582
- /** The dtype of the array. */
597
+ /** The dtype of elements stored in the array. */
583
598
  get dtype() {
584
599
  return this.aval.dtype;
585
600
  }
601
+ /**
602
+ * Whether the array is weakly typed.
603
+ *
604
+ * Weakly typed arrays will cast to the dtype of the other operand. See
605
+ * `promoteTypes()` for details.
606
+ */
607
+ get weakType() {
608
+ return this.aval.weakType;
609
+ }
586
610
  /** The number of dimensions of the array. */
587
611
  get ndim() {
588
612
  return this.shape.length;
@@ -819,12 +843,13 @@ function getShape(x) {
819
843
  return x instanceof Tracer ? x.shape : [];
820
844
  }
821
845
  var ShapedArray = class ShapedArray {
822
- constructor(shape$1, dtype) {
846
+ constructor(shape$1, dtype, weakType) {
823
847
  this.shape = shape$1;
824
848
  this.dtype = dtype;
849
+ this.weakType = weakType;
825
850
  }
826
851
  static fromAval(aval) {
827
- return new ShapedArray(aval.shape, aval.dtype);
852
+ return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
828
853
  }
829
854
  get ndim() {
830
855
  return this.shape.length;
@@ -838,7 +863,7 @@ var ShapedArray = class ShapedArray {
838
863
  };
839
864
  function getAval(x) {
840
865
  if (x instanceof Tracer) return x.aval;
841
- else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32);
866
+ else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32, typeof x === "boolean" ? false : true);
842
867
  else throw new TypeError(`Unknown value: ${x}`);
843
868
  }
844
869
  function bind(prim, args, params = {}) {
@@ -1160,7 +1185,7 @@ const jitRules = {
1160
1185
  const k1 = reshapeViews(keys[1], mapping);
1161
1186
  const c0 = AluExp.u32(0);
1162
1187
  const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
1163
- const exp$2 = AluExp.threefry2x32(c0, c1, k0, k1, mode);
1188
+ const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
1164
1189
  return new Kernel(nargs, prod(shape$1), exp$2);
1165
1190
  },
1166
1191
  [Primitive.Sin]: unopJit(AluExp.sin),
@@ -1201,7 +1226,7 @@ const jitRules = {
1201
1226
  [Primitive.Dot](nargs, [a, b], [as, bs]) {
1202
1227
  const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
1203
1228
  const c = k1.exp;
1204
- const cs = new ShapedArray(generalBroadcast(as.shape, bs.shape), c.dtype);
1229
+ const cs = promoteAvals(as, bs);
1205
1230
  return jitRules[Primitive.Reduce](nargs, [c], [cs], {
1206
1231
  op: AluOp.Add,
1207
1232
  axis: [cs.ndim - 1]
@@ -1211,8 +1236,8 @@ const jitRules = {
1211
1236
  const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
1212
1237
  a = reshapeViews(a, (st) => st.compose(stX));
1213
1238
  b = reshapeViews(b, (st) => st.compose(stY));
1214
- as = new ShapedArray(stX.shape, as.dtype);
1215
- bs = new ShapedArray(stY.shape, bs.dtype);
1239
+ as = new ShapedArray(stX.shape, as.dtype, as.weakType);
1240
+ bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
1216
1241
  return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1217
1242
  },
1218
1243
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
@@ -1265,9 +1290,10 @@ function splitGraphDataflow(backend, jaxpr) {
1265
1290
  Primitive.Conv,
1266
1291
  Primitive.PoolTranspose
1267
1292
  ];
1293
+ const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
1268
1294
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
1269
1295
  const eqn = jaxpr.eqns[i];
1270
- if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
1296
+ if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
1271
1297
  for (const v of eqn.outBinders) {
1272
1298
  blackNodes.add(v);
1273
1299
  p1NextBlack.set(v, v);
@@ -1397,6 +1423,7 @@ var Array$1 = class Array$1 extends Tracer {
1397
1423
  static #nextId = 1001;
1398
1424
  id;
1399
1425
  #dtype;
1426
+ #weakType;
1400
1427
  #source;
1401
1428
  #st;
1402
1429
  #backend;
@@ -1408,21 +1435,22 @@ var Array$1 = class Array$1 extends Tracer {
1408
1435
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1409
1436
  * will be freed when the array is disposed.
1410
1437
  */
1411
- constructor(source, st, dtype, backend, { pending = null } = {}) {
1438
+ constructor(args) {
1412
1439
  super(baseArrayTrace);
1413
1440
  this.id = Array$1.#nextId++;
1414
- this.#dtype = dtype;
1415
- this.#source = source;
1416
- this.#st = st;
1417
- this.#backend = backend;
1441
+ this.#dtype = args.dtype;
1442
+ this.#weakType = args.weakType;
1443
+ this.#source = args.source;
1444
+ this.#st = args.st;
1445
+ this.#backend = args.backend;
1418
1446
  this.#rc = 1;
1419
- this.#pendingSet = new Set(pending);
1447
+ this.#pendingSet = new Set(args.pending);
1420
1448
  if (this.#pendingSet.size === 0) this.#pendingSet = null;
1421
- else if (source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1449
+ else if (this.#source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1422
1450
  }
1423
1451
  /** @ignore */
1424
1452
  get aval() {
1425
- return new ShapedArray(this.#st.shape, this.#dtype);
1453
+ return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
1426
1454
  }
1427
1455
  /** Return a simple string representation of the array's dimensions. */
1428
1456
  toString() {
@@ -1434,6 +1462,17 @@ var Array$1 = class Array$1 extends Tracer {
1434
1462
  #check() {
1435
1463
  if (this.#rc <= 0) throw new UseAfterFreeError(this);
1436
1464
  }
1465
+ /** Construct an array, copying fields from `this`. */
1466
+ #newArrayFrom(args) {
1467
+ return new Array$1({
1468
+ source: args.source ?? this.#source,
1469
+ st: args.st ?? this.#st,
1470
+ dtype: args.dtype ?? this.#dtype,
1471
+ weakType: this.#weakType,
1472
+ backend: args.backend ?? this.#backend,
1473
+ pending: args.pending ?? this.#pending ?? void 0
1474
+ });
1475
+ }
1437
1476
  get ref() {
1438
1477
  this.#check();
1439
1478
  this.#rc++;
@@ -1473,7 +1512,10 @@ var Array$1 = class Array$1 extends Tracer {
1473
1512
  const pending = this.#pending;
1474
1513
  for (const exe of pending) exe.updateRc(1);
1475
1514
  if (typeof this.#source === "number") this.#backend.incRef(this.#source);
1476
- const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
1515
+ const ar = this.#newArrayFrom({
1516
+ st,
1517
+ pending
1518
+ });
1477
1519
  this.dispose();
1478
1520
  return ar;
1479
1521
  }
@@ -1522,7 +1564,11 @@ var Array$1 = class Array$1 extends Tracer {
1522
1564
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1523
1565
  this.dispose();
1524
1566
  for (const ar of indices) ar.dispose();
1525
- return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
1567
+ return this.#newArrayFrom({
1568
+ source: output,
1569
+ st: ShapeTracker.fromShape(finalShape),
1570
+ pending
1571
+ });
1526
1572
  }
1527
1573
  /** Move axes to the rightmost dimension of the shape. */
1528
1574
  #moveAxesDown(axis) {
@@ -1545,11 +1591,16 @@ var Array$1 = class Array$1 extends Tracer {
1545
1591
  return this.#reshape(this.#st.permute(perm));
1546
1592
  }
1547
1593
  #unary(op, dtypeOutput) {
1594
+ const weakType = !dtypeOutput && this.#weakType;
1548
1595
  dtypeOutput ??= this.#dtype;
1549
1596
  this.#check();
1550
1597
  if (this.#source instanceof AluExp) {
1551
1598
  const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
1552
- return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
1599
+ return this.#newArrayFrom({
1600
+ source: exp$3.simplify(),
1601
+ dtype: dtypeOutput,
1602
+ weakType
1603
+ });
1553
1604
  }
1554
1605
  const indices = unravelAlu(this.#st.shape, AluVar.gidx);
1555
1606
  const exp$2 = new AluExp(op, dtypeOutput, [AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
@@ -1559,41 +1610,65 @@ var Array$1 = class Array$1 extends Tracer {
1559
1610
  for (const exe of pending) exe.updateRc(1);
1560
1611
  pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
1561
1612
  this.dispose();
1562
- return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
1613
+ return this.#newArrayFrom({
1614
+ source: output,
1615
+ st: ShapeTracker.fromShape(this.shape),
1616
+ dtype: dtypeOutput,
1617
+ weakType,
1618
+ pending
1619
+ });
1563
1620
  }
1564
1621
  #binary(op, other) {
1565
- const custom = (src) => new AluExp(op, this.#dtype, src);
1622
+ const custom = (src) => new AluExp(op, src[0].dtype, src);
1566
1623
  return Array$1.#naryCustom(op, custom, [this, other]);
1567
1624
  }
1568
- static #naryCustom(name, custom, arrays, { dtypeOverride, dtypeOutput, reduceAxis } = {}) {
1625
+ static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
1569
1626
  const n = arrays.length;
1570
1627
  const backend = arrays[0].#backend;
1571
1628
  if (n === 0) throw new TypeError(`No inputs for ${name}`);
1572
1629
  for (const ar of arrays) ar.#check();
1573
- let dtype;
1630
+ let castDtype;
1631
+ let castWeakType = true;
1574
1632
  for (let i = 0; i < n; i++) {
1575
1633
  if (dtypeOverride?.[i]) {
1576
1634
  if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1577
- } else if (!dtype) dtype = arrays[i].#dtype;
1578
- else if (arrays[i].#dtype !== dtype) throw new TypeError(`Dtype mismatch in ${name}: ${dtype} vs ${arrays[i].#dtype}`);
1635
+ } else if (castDtype === void 0) {
1636
+ castDtype = arrays[i].#dtype;
1637
+ castWeakType = arrays[i].#weakType;
1638
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1579
1639
  if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
1580
1640
  }
1581
- dtypeOutput ??= dtype;
1582
- if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
1641
+ const weakType = castWeakType && !strongTypeOutput;
1583
1642
  arrays = Array$1.#broadcastArrays(arrays);
1584
1643
  const newShape = [...arrays[0].shape];
1585
1644
  if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
1645
+ const sources = arrays.map((ar, i) => {
1646
+ if (!dtypeOverride?.[i]) return AluExp.cast(castDtype, ar.#source);
1647
+ else return ar.#source;
1648
+ });
1586
1649
  if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
1587
- const exp$4 = custom(arrays.map((ar) => ar.#source));
1588
- return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
1650
+ const exp$4 = custom(sources);
1651
+ return new Array$1({
1652
+ source: exp$4.simplify(),
1653
+ st: arrays[0].#st,
1654
+ dtype: exp$4.dtype,
1655
+ weakType,
1656
+ backend
1657
+ });
1589
1658
  }
1590
- const exp$3 = custom(arrays.map((ar) => {
1591
- const src$1 = ar.#source;
1659
+ const exp$3 = custom(arrays.map((ar, i) => {
1660
+ const src$1 = sources[i];
1592
1661
  if (ar.#st.contiguous) return src$1;
1593
1662
  return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
1594
1663
  }));
1595
1664
  const st = ShapeTracker.fromShape(newShape);
1596
- return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
1665
+ return new Array$1({
1666
+ source: exp$3.simplify(),
1667
+ st,
1668
+ dtype: exp$3.dtype,
1669
+ weakType,
1670
+ backend
1671
+ });
1597
1672
  }
1598
1673
  let indices;
1599
1674
  if (!reduceAxis) indices = unravelAlu(newShape, AluVar.gidx);
@@ -1603,14 +1678,19 @@ var Array$1 = class Array$1 extends Tracer {
1603
1678
  }
1604
1679
  const inputs = [];
1605
1680
  const src = [];
1606
- for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#source, ar.#st, indices));
1607
- else {
1608
- let gid = inputs.indexOf(ar.#source);
1609
- if (gid === -1) {
1610
- gid = inputs.length;
1611
- inputs.push(ar.#source);
1681
+ for (const [i, ar] of arrays.entries()) {
1682
+ let nextSrc;
1683
+ if (ar.#source instanceof AluExp) nextSrc = accessorAluExp(ar.#source, ar.#st, indices);
1684
+ else {
1685
+ let gid = inputs.indexOf(ar.#source);
1686
+ if (gid === -1) {
1687
+ gid = inputs.length;
1688
+ inputs.push(ar.#source);
1689
+ }
1690
+ nextSrc = AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
1612
1691
  }
1613
- src.push(AluExp.globalView(ar.#dtype, gid, ar.#st, indices));
1692
+ if (!dtypeOverride?.[i]) nextSrc = AluExp.cast(castDtype, nextSrc);
1693
+ src.push(nextSrc);
1614
1694
  }
1615
1695
  const exp$2 = custom(src);
1616
1696
  let re = void 0;
@@ -1624,12 +1704,17 @@ var Array$1 = class Array$1 extends Tracer {
1624
1704
  for (const exe of pending) exe.updateRc(1);
1625
1705
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1626
1706
  for (const ar of arrays) ar.dispose();
1627
- return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
1707
+ return new Array$1({
1708
+ source: output,
1709
+ st: ShapeTracker.fromShape(newShape),
1710
+ dtype: kernel.dtype,
1711
+ weakType,
1712
+ backend,
1713
+ pending
1714
+ });
1628
1715
  }
1629
1716
  /** Reduce the last dimension of the array by an operation. */
1630
1717
  #reduce(op) {
1631
- this.#check();
1632
- if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
1633
1718
  const shape$1 = this.shape;
1634
1719
  const reduction = new Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
1635
1720
  const newShape = shape$1.slice(0, -1);
@@ -1648,7 +1733,11 @@ var Array$1 = class Array$1 extends Tracer {
1648
1733
  for (const exe of pending) exe.updateRc(1);
1649
1734
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1650
1735
  this.dispose();
1651
- return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
1736
+ return this.#newArrayFrom({
1737
+ source: output,
1738
+ st: ShapeTracker.fromShape(newShape),
1739
+ pending
1740
+ });
1652
1741
  }
1653
1742
  /**
1654
1743
  * Normalizes this array into one backed by a `Slot`.
@@ -1684,8 +1773,8 @@ var Array$1 = class Array$1 extends Tracer {
1684
1773
  }
1685
1774
  #dataInline() {
1686
1775
  this.#check();
1687
- const exp$2 = this.#source;
1688
- const ar = new Array$1(exp$2, this.#st, this.dtype, getBackend("cpu"));
1776
+ if (!(this.#source instanceof AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
1777
+ const ar = this.#newArrayFrom({ backend: getBackend("cpu") });
1689
1778
  this.dispose();
1690
1779
  return ar.dataSync();
1691
1780
  }
@@ -1811,7 +1900,11 @@ var Array$1 = class Array$1 extends Tracer {
1811
1900
  x.#backend.incRef(x.#source);
1812
1901
  const pending = x.#pending;
1813
1902
  for (const exe of pending) exe.updateRc(1);
1814
- const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
1903
+ const y = x.#newArrayFrom({
1904
+ dtype,
1905
+ weakType: false,
1906
+ pending
1907
+ });
1815
1908
  x.dispose();
1816
1909
  return [y];
1817
1910
  }
@@ -1886,7 +1979,7 @@ var Array$1 = class Array$1 extends Tracer {
1886
1979
  },
1887
1980
  [Primitive.Compare]([x, y], { op }) {
1888
1981
  const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
1889
- return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: DType.Bool })];
1982
+ return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
1890
1983
  },
1891
1984
  [Primitive.Where]([cond, x, y]) {
1892
1985
  const custom = ([cond$1, x$1, y$1]) => AluExp.where(cond$1, x$1, y$1);
@@ -1932,7 +2025,14 @@ var Array$1 = class Array$1 extends Tracer {
1932
2025
  pending.splice(0, 0, ...prevPending);
1933
2026
  args.forEach((x) => x.dispose());
1934
2027
  return outputs.map((source, i) => {
1935
- return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
2028
+ return new Array$1({
2029
+ source,
2030
+ st: ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
2031
+ dtype: jaxpr.outs[i].aval.dtype,
2032
+ weakType: jaxpr.outs[i].aval.weakType,
2033
+ backend,
2034
+ pending
2035
+ });
1936
2036
  });
1937
2037
  }
1938
2038
  };
@@ -1942,33 +2042,11 @@ var Array$1 = class Array$1 extends Tracer {
1942
2042
  return this.#source;
1943
2043
  }
1944
2044
  };
1945
- /** Construct an array from a single scalar constant. */
1946
- function scalar(value, { dtype, device } = {}) {
1947
- if (typeof value === "number") {
1948
- dtype ??= DType.Float32;
1949
- if (![
1950
- DType.Float32,
1951
- DType.Float16,
1952
- DType.Int32,
1953
- DType.Uint32
1954
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1955
- } else if (typeof value === "boolean") {
1956
- dtype ??= DType.Bool;
1957
- if (![
1958
- DType.Float32,
1959
- DType.Float16,
1960
- DType.Int32,
1961
- DType.Uint32,
1962
- DType.Bool
1963
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1964
- } else throw new TypeError(`Invalid type for scalar ${value}`);
1965
- return new Array$1(AluExp.const(dtype, value), ShapeTracker.fromShape([]), dtype, getBackend(device));
1966
- }
1967
2045
  /** Constructor for creating a new array from data. */
1968
2046
  function array(values, { shape: shape$1, dtype, device } = {}) {
1969
2047
  if (values instanceof Tracer) {
1970
2048
  if (shape$1 && !deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
1971
- if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
2049
+ if (dtype && values.dtype !== dtype) values = values.astype(dtype);
1972
2050
  return values;
1973
2051
  } else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
1974
2052
  dtype,
@@ -1990,6 +2068,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1990
2068
  dtype,
1991
2069
  device
1992
2070
  });
2071
+ if (size$1 === 1) return full(shape$1, flat[0], {
2072
+ dtype,
2073
+ device
2074
+ });
1993
2075
  if (typeof flat[0] === "boolean") {
1994
2076
  dtype = dtype ?? DType.Bool;
1995
2077
  const data = new Int32Array(flat.map((x) => x ? 1 : 0));
@@ -1998,46 +2080,51 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1998
2080
  device
1999
2081
  });
2000
2082
  } else {
2083
+ const weakType = dtype == void 0;
2001
2084
  dtype = dtype ?? DType.Float32;
2002
2085
  const data = dtypedJsArray(dtype, flat);
2003
2086
  return arrayFromData(data, shape$1, {
2004
2087
  dtype,
2005
2088
  device
2006
- });
2089
+ }, weakType);
2007
2090
  }
2008
2091
  }
2009
2092
  }
2010
- function arrayFromData(data, shape$1, { dtype, device } = {}) {
2093
+ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2094
+ if (data instanceof Float32Array) {
2095
+ if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
2096
+ dtype ??= DType.Float32;
2097
+ } else if (data instanceof Int32Array) {
2098
+ if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2099
+ dtype ??= DType.Int32;
2100
+ } else if (data instanceof Uint32Array) {
2101
+ if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2102
+ dtype ??= DType.Uint32;
2103
+ } else if (data instanceof Float16Array) {
2104
+ if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
2105
+ dtype ??= DType.Float16;
2106
+ } else throw new Error("Unsupported data array type: " + data.constructor.name);
2011
2107
  if (data.length < inlineArrayLimit) {
2012
2108
  let allEqual = true;
2013
2109
  for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
2014
2110
  allEqual = false;
2015
2111
  break;
2016
2112
  }
2017
- if (allEqual) return full(shape$1, data[0], {
2018
- dtype,
2019
- device
2020
- });
2113
+ if (allEqual) {
2114
+ const sa = new ShapedArray(shape$1, dtype, weakType);
2115
+ return fullInternal(sa, data[0], device);
2116
+ }
2021
2117
  }
2022
2118
  const backend = getBackend(device);
2023
- if (ArrayBuffer.isView(data)) {
2024
- const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2025
- if (data instanceof Float32Array) {
2026
- if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
2027
- dtype ??= DType.Float32;
2028
- } else if (data instanceof Int32Array) {
2029
- if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2030
- dtype ??= DType.Int32;
2031
- } else if (data instanceof Uint32Array) {
2032
- if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2033
- dtype ??= DType.Uint32;
2034
- } else if (data instanceof Float16Array) {
2035
- if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
2036
- dtype ??= DType.Float16;
2037
- } else throw new Error("Unsupported data array type: " + data.constructor.name);
2038
- const slot = backend.malloc(data.byteLength, buf);
2039
- return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype, backend);
2040
- } else throw new Error("Unsupported data type: " + data.constructor.name);
2119
+ const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2120
+ const slot = backend.malloc(data.byteLength, buf);
2121
+ return new Array$1({
2122
+ source: slot,
2123
+ st: ShapeTracker.fromShape(shape$1),
2124
+ dtype,
2125
+ weakType,
2126
+ backend
2127
+ });
2041
2128
  }
2042
2129
  function dataToJs(dtype, data, shape$1) {
2043
2130
  if (shape$1.length === 0) return dtype === DType.Bool ? Boolean(data[0]) : data[0];
@@ -2053,7 +2140,7 @@ function dataToJs(dtype, data, shape$1) {
2053
2140
  /** If x is a value, lift it into an array, otherwise leave it be. */
2054
2141
  function pureArray(x) {
2055
2142
  if (x instanceof Tracer) return x;
2056
- else return scalar(x);
2143
+ else return array(x);
2057
2144
  }
2058
2145
  var EvalTrace = class extends Trace {
2059
2146
  pure = (x) => pureArray(x);
@@ -2064,20 +2151,27 @@ var EvalTrace = class extends Trace {
2064
2151
  };
2065
2152
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
2066
2153
  const implRules = Array$1._implRules();
2154
+ function fullInternal(aval, fillValue, device) {
2155
+ return new Array$1({
2156
+ source: AluExp.const(aval.dtype, fillValue),
2157
+ st: ShapeTracker.fromShape(aval.shape),
2158
+ dtype: aval.dtype,
2159
+ weakType: aval.weakType,
2160
+ backend: getBackend(device)
2161
+ });
2162
+ }
2067
2163
  function zerosLike$1(val, dtype) {
2068
- const aval = getAval(val);
2069
- if (val instanceof Tracer) val.dispose();
2070
- return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2164
+ return fullLike(val, 0, dtype);
2071
2165
  }
2072
2166
  function onesLike$1(val, dtype) {
2073
- const aval = getAval(val);
2074
- if (val instanceof Tracer) val.dispose();
2075
- return ones(aval.shape, { dtype: dtype ?? aval.dtype });
2167
+ return fullLike(val, 1, dtype);
2076
2168
  }
2077
2169
  function fullLike(val, fillValue, dtype) {
2078
2170
  const aval = getAval(val);
2079
2171
  if (val instanceof Tracer) val.dispose();
2080
- return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
2172
+ if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
2173
+ const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
2174
+ return fullInternal(sa, fillValue);
2081
2175
  }
2082
2176
  /** Return a new array of given shape and type, filled with zeros. */
2083
2177
  function zeros(shape$1, { dtype, device } = {}) {
@@ -2095,19 +2189,14 @@ function ones(shape$1, { dtype, device } = {}) {
2095
2189
  }
2096
2190
  /** Return a new array of given shape and type, filled with `fill_value`. */
2097
2191
  function full(shape$1, fillValue, { dtype, device } = {}) {
2098
- let source;
2099
- if (typeof fillValue === "number") {
2100
- dtype = dtype ?? DType.Float32;
2101
- source = AluExp.const(dtype, fillValue);
2102
- } else if (typeof fillValue === "bigint") {
2103
- dtype = dtype ?? DType.Int32;
2104
- source = AluExp.const(dtype, Number(fillValue));
2105
- } else if (typeof fillValue === "boolean") {
2192
+ let weakType = dtype == void 0;
2193
+ if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
2194
+ else if (typeof fillValue === "boolean") {
2106
2195
  dtype = dtype ?? DType.Bool;
2107
- source = AluExp.const(dtype, fillValue ? 1 : 0);
2196
+ weakType = false;
2108
2197
  } else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
2109
2198
  else throw new TypeError(`Invalid type for full: ${fillValue}`);
2110
- return new Array$1(source, ShapeTracker.fromShape(shape$1), dtype ?? DType.Float32, getBackend(device));
2199
+ return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
2111
2200
  }
2112
2201
  /**
2113
2202
  * Create an identity matrix.
@@ -2117,6 +2206,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
2117
2206
  */
2118
2207
  function eye(numRows, numCols, { dtype, device } = {}) {
2119
2208
  numCols = numCols ?? numRows;
2209
+ const weakType = dtype == void 0;
2120
2210
  dtype = dtype ?? DType.Float32;
2121
2211
  if (numCols < numRows) {
2122
2212
  const arr = eye(numCols, numRows, {
@@ -2130,7 +2220,13 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2130
2220
  device
2131
2221
  });
2132
2222
  const exp$2 = AluExp.cmplt(AluExp.mod(AluVar.idx, AluExp.i32(numCols + 1)), AluExp.i32(1));
2133
- return new Array$1(AluExp.cast(dtype, exp$2), ShapeTracker.fromShape([numRows, numCols]), dtype, getBackend(device));
2223
+ return new Array$1({
2224
+ source: AluExp.cast(dtype, exp$2),
2225
+ st: ShapeTracker.fromShape([numRows, numCols]),
2226
+ dtype,
2227
+ weakType,
2228
+ backend: getBackend(device)
2229
+ });
2134
2230
  }
2135
2231
  /** Return the identity matrix, with ones on the main diagonal. */
2136
2232
  function identity$1(n, { dtype, device } = {}) {
@@ -2167,7 +2263,13 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
2167
2263
  });
2168
2264
  const exp$2 = AluExp.add(AluExp.const(dtype, start), AluExp.mul(AluExp.cast(dtype, AluVar.idx), AluExp.const(dtype, step)));
2169
2265
  const st = ShapeTracker.fromShape([size$1]);
2170
- return new Array$1(exp$2, st, dtype, getBackend(device));
2266
+ return new Array$1({
2267
+ source: exp$2,
2268
+ st,
2269
+ dtype,
2270
+ weakType: false,
2271
+ backend: getBackend(device)
2272
+ });
2171
2273
  }
2172
2274
  /**
2173
2275
  * Return evenly spaced numbers over a specified interval.
@@ -2185,10 +2287,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2185
2287
  dtype,
2186
2288
  device
2187
2289
  });
2188
- else if (num === 1) return scalar(start, {
2290
+ else if (num === 1) return full([1], start, {
2189
2291
  dtype,
2190
2292
  device
2191
- }).reshape([1]);
2293
+ });
2192
2294
  else if (start === stop) return full([num], start, {
2193
2295
  dtype,
2194
2296
  device
@@ -2197,7 +2299,13 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2197
2299
  const denom = endpoint ? num - 1 : num;
2198
2300
  const exp$2 = AluExp.cast(dtype, AluExp.add(AluExp.f32(start), AluExp.mul(AluExp.f32(delta / denom), AluExp.cast(DType.Float32, AluVar.idx))));
2199
2301
  const st = ShapeTracker.fromShape([num]);
2200
- return new Array$1(exp$2, st, dtype, getBackend(device));
2302
+ return new Array$1({
2303
+ source: exp$2,
2304
+ st,
2305
+ dtype,
2306
+ weakType: false,
2307
+ backend: getBackend(device)
2308
+ });
2201
2309
  }
2202
2310
  function aluCompare(a, b, op) {
2203
2311
  switch (op) {
@@ -2209,35 +2317,6 @@ function aluCompare(a, b, op) {
2209
2317
  case CompareOp.LessEqual: return AluExp.add(AluExp.cmplt(a, b), AluExp.cmpne(a, b).not());
2210
2318
  }
2211
2319
  }
2212
- /**
2213
- * Implements a NumPy-style generalized broadcast rule on two array shapes.
2214
- *
2215
- * "When operating on two arrays, NumPy compares their shapes element-wise. It
2216
- * starts with the trailing (i.e. rightmost) dimension and works its way left.
2217
- * Two dimensions are compatible when:
2218
- * 1. they are equal, or
2219
- * 2. one of them is 1."
2220
- *
2221
- * Throws a TypeError if the broadcast is not possible.
2222
- *
2223
- * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
2224
- */
2225
- function generalBroadcast(a, b) {
2226
- const out = [];
2227
- let i = a.length - 1;
2228
- let j = b.length - 1;
2229
- for (; i >= 0 && j >= 0; i--, j--) {
2230
- const x = a[i];
2231
- const y = b[j];
2232
- if (x === y) out.push(x);
2233
- else if (x === 1) out.push(y);
2234
- else if (y === 1) out.push(x);
2235
- else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
2236
- }
2237
- for (; i >= 0; i--) out.push(a[i]);
2238
- for (; j >= 0; j--) out.push(b[j]);
2239
- return out.reverse();
2240
- }
2241
2320
 
2242
2321
  //#endregion
2243
2322
  //#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
@@ -2313,13 +2392,15 @@ var Var = class Var {
2313
2392
  };
2314
2393
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
2315
2394
  var Lit = class {
2316
- dtype;
2317
2395
  value;
2318
2396
  aval;
2319
- constructor(dtype, value) {
2320
- this.dtype = dtype;
2397
+ get dtype() {
2398
+ return this.aval.dtype;
2399
+ }
2400
+ constructor(aval, value) {
2401
+ if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
2321
2402
  this.value = value;
2322
- this.aval = new ShapedArray([], dtype);
2403
+ this.aval = ShapedArray.fromAval(aval);
2323
2404
  }
2324
2405
  };
2325
2406
  function atomIsLit(atom, literal) {
@@ -2443,14 +2524,19 @@ var Jaxpr = class Jaxpr {
2443
2524
  const c = eqn.outBinders[0];
2444
2525
  if (atomIsLit(a, 0)) context.set(c, b);
2445
2526
  else if (atomIsLit(b, 0)) context.set(c, a);
2446
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.dtype === DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2527
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2528
+ else newEqns.push(eqn);
2529
+ } else if (eqn.primitive === Primitive.Neg) {
2530
+ const [a] = inputs;
2531
+ const c = eqn.outBinders[0];
2532
+ if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
2447
2533
  else newEqns.push(eqn);
2448
2534
  } else if (eqn.primitive === Primitive.Mul) {
2449
2535
  const [a, b] = inputs;
2450
2536
  const c = eqn.outBinders[0];
2451
2537
  if (atomIsLit(a, 1)) context.set(c, b);
2452
2538
  else if (atomIsLit(b, 1)) context.set(c, a);
2453
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.value * b.value));
2539
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
2454
2540
  else newEqns.push(eqn);
2455
2541
  } else if (eqn.primitive === Primitive.Idiv) {
2456
2542
  const [a, b] = inputs;
@@ -2548,7 +2634,7 @@ function evalJaxpr(jaxpr, args) {
2548
2634
  if (x instanceof Var) {
2549
2635
  remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
2550
2636
  return env.get(x);
2551
- } else return scalar(x.value, { dtype: x.dtype });
2637
+ } else return array(x.value, { dtype: x.dtype });
2552
2638
  };
2553
2639
  const write = (v, val) => {
2554
2640
  if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
@@ -2607,7 +2693,7 @@ var JaxprTrace = class extends Trace {
2607
2693
  let tracer = this.builder.constTracers.get(val);
2608
2694
  if (tracer === void 0) {
2609
2695
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
2610
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : scalar(val));
2696
+ this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
2611
2697
  }
2612
2698
  return tracer;
2613
2699
  }
@@ -2676,7 +2762,7 @@ function _inlineLiterals(jaxpr, consts) {
2676
2762
  const newConsts = [];
2677
2763
  for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
2678
2764
  const ar = consts[i];
2679
- literals.set(jaxpr.inBinders[i], new Lit(ar.dtype, ar.dataSync()[0]));
2765
+ literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
2680
2766
  } else {
2681
2767
  constBinders.push(jaxpr.inBinders[i]);
2682
2768
  newConsts.push(consts[i]);
@@ -2689,13 +2775,12 @@ function _inlineLiterals(jaxpr, consts) {
2689
2775
  }
2690
2776
  function binopAbstractEval([x, y]) {
2691
2777
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
2692
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2693
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
2778
+ return [promoteAvals(x, y)];
2694
2779
  }
2695
2780
  function compareAbstractEval([x, y]) {
2696
2781
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
2697
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2698
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), DType.Bool)];
2782
+ const aval = promoteAvals(x, y);
2783
+ return [new ShapedArray(aval.shape, DType.Bool, false)];
2699
2784
  }
2700
2785
  function vectorizedUnopAbstractEval([x]) {
2701
2786
  return [ShapedArray.fromAval(x)];
@@ -2708,18 +2793,18 @@ const abstractEvalRules = {
2708
2793
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
2709
2794
  [Primitive.StopGradient]: vectorizedUnopAbstractEval,
2710
2795
  [Primitive.Cast]([x], { dtype }) {
2711
- return [new ShapedArray(x.shape, dtype)];
2796
+ return [new ShapedArray(x.shape, dtype, false)];
2712
2797
  },
2713
2798
  [Primitive.Bitcast]([x], { dtype }) {
2714
2799
  if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
2715
2800
  if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
2716
- return [new ShapedArray(x.shape, dtype)];
2801
+ return [new ShapedArray(x.shape, dtype, false)];
2717
2802
  },
2718
2803
  [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
2719
2804
  if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
2720
2805
  const keyShape = generalBroadcast(k0.shape, k1.shape);
2721
2806
  if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2722
- return [new ShapedArray(shape$1, DType.Uint32)];
2807
+ return [new ShapedArray(shape$1, DType.Uint32, false)];
2723
2808
  },
2724
2809
  [Primitive.Sin]: vectorizedUnopAbstractEval,
2725
2810
  [Primitive.Cos]: vectorizedUnopAbstractEval,
@@ -2733,55 +2818,54 @@ const abstractEvalRules = {
2733
2818
  [Primitive.Reduce]([x], { axis }) {
2734
2819
  const axisSet = new Set(axis);
2735
2820
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2736
- return [new ShapedArray(newShape, x.dtype)];
2821
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2737
2822
  },
2738
2823
  [Primitive.Pool]([x], { window, strides }) {
2739
2824
  const shape$1 = checkPoolShape(x.shape, window, strides);
2740
- return [new ShapedArray(shape$1, x.dtype)];
2825
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2741
2826
  },
2742
2827
  [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
2743
2828
  const shape$1 = checkPoolShape(inShape, window, strides);
2744
2829
  if (!deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
2745
- return [new ShapedArray(inShape, x.dtype)];
2830
+ return [new ShapedArray(inShape, x.dtype, x.weakType)];
2746
2831
  },
2747
2832
  [Primitive.Dot]([x, y]) {
2748
- if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
2749
2833
  if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
2750
- const shape$1 = generalBroadcast(x.shape, y.shape);
2834
+ const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
2751
2835
  shape$1.splice(-1, 1);
2752
- return [new ShapedArray(shape$1, x.dtype)];
2836
+ return [new ShapedArray(shape$1, dtype, weakType)];
2753
2837
  },
2754
2838
  [Primitive.Conv]([lhs, rhs], params) {
2755
- if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
2839
+ const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
2756
2840
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
2757
- return [new ShapedArray(shape$1, lhs.dtype)];
2841
+ return [new ShapedArray(shape$1, dtype, weakType)];
2758
2842
  },
2759
2843
  [Primitive.Compare]: compareAbstractEval,
2760
2844
  [Primitive.Where]([cond, x, y]) {
2761
2845
  if (cond.dtype !== DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
2762
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2763
- const shape$1 = generalBroadcast(cond.shape, generalBroadcast(x.shape, y.shape));
2764
- return [new ShapedArray(shape$1, x.dtype)];
2846
+ const xy = promoteAvals(x, y);
2847
+ const shape$1 = generalBroadcast(cond.shape, xy.shape);
2848
+ return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
2765
2849
  },
2766
2850
  [Primitive.Transpose]([x], { perm }) {
2767
- return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
2851
+ return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
2768
2852
  },
2769
2853
  [Primitive.Broadcast]([x], { shape: shape$1 }) {
2770
- return [new ShapedArray(shape$1, x.dtype)];
2854
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2771
2855
  },
2772
2856
  [Primitive.Reshape]([x], { shape: shape$1 }) {
2773
- return [new ShapedArray(shape$1, x.dtype)];
2857
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2774
2858
  },
2775
2859
  [Primitive.Flip]([x], _) {
2776
- return [new ShapedArray(x.shape, x.dtype)];
2860
+ return [ShapedArray.fromAval(x)];
2777
2861
  },
2778
2862
  [Primitive.Shrink]([x], { slice }) {
2779
2863
  const newShape = slice.map((s) => s[1] - s[0]);
2780
- return [new ShapedArray(newShape, x.dtype)];
2864
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2781
2865
  },
2782
2866
  [Primitive.Pad]([x], { width }) {
2783
2867
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
2784
- return [new ShapedArray(newShape, x.dtype)];
2868
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2785
2869
  },
2786
2870
  [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2787
2871
  for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
@@ -2794,7 +2878,7 @@ const abstractEvalRules = {
2794
2878
  const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
2795
2879
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2796
2880
  newShape.splice(outDim, 0, ...gatherShape);
2797
- return [new ShapedArray(newShape, x.dtype)];
2881
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2798
2882
  },
2799
2883
  [Primitive.JitCall](args, { jaxpr }) {
2800
2884
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
@@ -2861,6 +2945,7 @@ function jit$1(f, opts) {
2861
2945
  const cacheKey = JSON.stringify(jaxprArgs);
2862
2946
  const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
2863
2947
  const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
2948
+ name: f.name || "closure",
2864
2949
  jaxpr,
2865
2950
  numConsts: consts.length
2866
2951
  });
@@ -3022,13 +3107,14 @@ const jvpRules = {
3022
3107
  const indicesRef = indices.map((t) => t.ref);
3023
3108
  return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3024
3109
  },
3025
- [Primitive.JitCall](primals, tangents, { jaxpr }) {
3110
+ [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3026
3111
  const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3027
3112
  const outs = bind(Primitive.JitCall, [
3028
3113
  ...newConsts.map((c) => c.ref),
3029
3114
  ...primals,
3030
3115
  ...tangents
3031
3116
  ], {
3117
+ name: `${name}_jvp`,
3032
3118
  jaxpr: newJaxpr,
3033
3119
  numConsts: newConsts.length
3034
3120
  });
@@ -3082,7 +3168,7 @@ function jvp$1(f, primals, tangents) {
3082
3168
  function mappedAval(batchDim, aval) {
3083
3169
  const shape$1 = [...aval.shape];
3084
3170
  shape$1.splice(batchDim, 1);
3085
- return new ShapedArray(shape$1, aval.dtype);
3171
+ return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3086
3172
  }
3087
3173
  /** Move one axis to a different index. */
3088
3174
  function moveaxis$1(x, src, dst) {
@@ -3226,9 +3312,10 @@ const vmapRules = {
3226
3312
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3227
3313
  return [[pad$1(x, newWidth)], [xBdim]];
3228
3314
  },
3229
- [Primitive.JitCall](axisSize, args, dims, { jaxpr }) {
3315
+ [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3230
3316
  const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3231
3317
  const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3318
+ name: `${name}_vmap`,
3232
3319
  jaxpr: newJaxpr,
3233
3320
  numConsts: newConsts.length
3234
3321
  });
@@ -3244,7 +3331,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3244
3331
  if (dims[i] === null) return v.aval;
3245
3332
  const shape$1 = [...v.aval.shape];
3246
3333
  shape$1.splice(dims[i], 0, axisSize);
3247
- return new ShapedArray(shape$1, v.aval.dtype);
3334
+ return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3248
3335
  });
3249
3336
  const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3250
3337
  const result = {
@@ -3457,8 +3544,8 @@ var PartialEvalTrace = class extends Trace {
3457
3544
  processPrimitive(primitive, tracers, params) {
3458
3545
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3459
3546
  if (primitive === Primitive.JitCall) {
3460
- const { jaxpr, numConsts } = params;
3461
- return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
3547
+ const { name, jaxpr, numConsts } = params;
3548
+ return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3462
3549
  }
3463
3550
  const tracersIn = tracers.map((t) => this.instantiateConst(t));
3464
3551
  const avalsIn = tracersIn.map((t) => t.pval.aval);
@@ -3484,12 +3571,13 @@ var PartialEvalTrace = class extends Trace {
3484
3571
  *
3485
3572
  * Used when encountering a JitCall rule during the trace.
3486
3573
  */
3487
- #partialEvalJaxpr(jaxpr, numConsts, tracers) {
3574
+ #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3488
3575
  jaxpr = jaxpr.flatten();
3489
3576
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3490
3577
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3491
3578
  const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
3492
3579
  const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
3580
+ name: `${name}_peval`,
3493
3581
  jaxpr: jaxpr1,
3494
3582
  numConsts: 0
3495
3583
  });
@@ -3501,6 +3589,7 @@ var PartialEvalTrace = class extends Trace {
3501
3589
  prim: Primitive.JitCall,
3502
3590
  tracersIn: resTracers.concat(unknownTracers),
3503
3591
  params: {
3592
+ name: `${name}_resid`,
3504
3593
  jaxpr: jaxpr2,
3505
3594
  numConsts: 0
3506
3595
  },
@@ -3643,7 +3732,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
3643
3732
  }
3644
3733
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
3645
3734
  const eqn = jaxpr.eqns[i];
3646
- const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? scalar(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3735
+ const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3647
3736
  const cotangentsOut = eqn.outBinders.map(readCotangent);
3648
3737
  const rule = transposeRules[eqn.primitive];
3649
3738
  if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
@@ -3823,7 +3912,7 @@ const transposeRules = {
3823
3912
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
3824
3913
  throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
3825
3914
  },
3826
- [Primitive.JitCall](cts, args, { jaxpr }) {
3915
+ [Primitive.JitCall](cts, args, { name, jaxpr }) {
3827
3916
  const undefPrimals = args.map((x) => x instanceof UndefPrimal);
3828
3917
  const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
3829
3918
  const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
@@ -3832,6 +3921,7 @@ const transposeRules = {
3832
3921
  ...residuals,
3833
3922
  ...cts
3834
3923
  ], {
3924
+ name: `${name}_t`,
3835
3925
  jaxpr: newJaxpr,
3836
3926
  numConsts: newConsts.length
3837
3927
  });
@@ -3906,7 +3996,7 @@ function valueAndGrad$1(f) {
3906
3996
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3907
3997
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3908
3998
  if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3909
- const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
3999
+ const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
3910
4000
  for (const r of rest) dispose(r);
3911
4001
  fVjp.dispose();
3912
4002
  return [y, ct];
@@ -4276,7 +4366,7 @@ function argmin(a, axis, opts) {
4276
4366
  } else axis = checkAxis(axis, a.ndim);
4277
4367
  const shape$1 = a.shape;
4278
4368
  const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
4279
- const length = scalar(shape$1[axis], {
4369
+ const length = array(shape$1[axis], {
4280
4370
  dtype: int32,
4281
4371
  device: a.device
4282
4372
  });
@@ -4300,7 +4390,7 @@ function argmax(a, axis, opts) {
4300
4390
  } else axis = checkAxis(axis, a.ndim);
4301
4391
  const shape$1 = a.shape;
4302
4392
  const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
4303
- const length = scalar(shape$1[axis], {
4393
+ const length = array(shape$1[axis], {
4304
4394
  dtype: int32,
4305
4395
  device: a.device
4306
4396
  });
@@ -4716,7 +4806,7 @@ function acos(x) {
4716
4806
  * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4717
4807
  * improvements.
4718
4808
  */
4719
- const hypot = jit$1((x1, x2) => {
4809
+ const hypot = jit$1(function hypot$1(x1, x2) {
4720
4810
  return sqrt(square(x1).add(square(x2)));
4721
4811
  });
4722
4812
  /**
@@ -4732,7 +4822,7 @@ const hypot = jit$1((x1, x2) => {
4732
4822
  *
4733
4823
  * The output is ill-defined when both x and y are zero.
4734
4824
  */
4735
- const atan2 = jit$1((y, x) => {
4825
+ const atan2 = jit$1(function atan2$1(y, x) {
4736
4826
  const r = sqrt(square(x.ref).add(square(y.ref)));
4737
4827
  const xNeg = less(x.ref, 0);
4738
4828
  const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
@@ -4800,13 +4890,13 @@ const degrees = rad2deg;
4800
4890
  * @function
4801
4891
  * Computes first array raised to power of second array, element-wise.
4802
4892
  */
4803
- const power = jit$1((x1, x2) => {
4893
+ const power = jit$1(function power$1(x1, x2) {
4804
4894
  return exp(log(x1).mul(x2));
4805
4895
  });
4806
4896
  /** @function Alias of `jax.numpy.power()`. */
4807
4897
  const pow = power;
4808
4898
  /** @function Calculate the element-wise cube root of the input array. */
4809
- const cbrt = jit$1((x) => {
4899
+ const cbrt = jit$1(function cbrt$1(x) {
4810
4900
  const sgn = where(less(x.ref, 0), -1, 1);
4811
4901
  return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
4812
4902
  });
@@ -4816,7 +4906,7 @@ const cbrt = jit$1((x) => {
4816
4906
  *
4817
4907
  * `sinh(x) = (exp(x) - exp(-x)) / 2`
4818
4908
  */
4819
- const sinh = jit$1((x) => {
4909
+ const sinh = jit$1(function sinh$1(x) {
4820
4910
  const ex = exp(x);
4821
4911
  const emx = reciprocal(ex.ref);
4822
4912
  return ex.sub(emx).mul(.5);
@@ -4827,7 +4917,7 @@ const sinh = jit$1((x) => {
4827
4917
  *
4828
4918
  * `cosh(x) = (exp(x) + exp(-x)) / 2`
4829
4919
  */
4830
- const cosh = jit$1((x) => {
4920
+ const cosh = jit$1(function cosh$1(x) {
4831
4921
  const ex = exp(x);
4832
4922
  const emx = reciprocal(ex.ref);
4833
4923
  return ex.add(emx).mul(.5);
@@ -4838,7 +4928,7 @@ const cosh = jit$1((x) => {
4838
4928
  *
4839
4929
  * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4840
4930
  */
4841
- const tanh = jit$1((x) => {
4931
+ const tanh = jit$1(function tanh$1(x) {
4842
4932
  const negsgn = where(less(x.ref, 0), 1, -1);
4843
4933
  const en2x = exp(x.mul(negsgn.ref).mul(2));
4844
4934
  return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
@@ -4849,7 +4939,7 @@ const tanh = jit$1((x) => {
4849
4939
  *
4850
4940
  * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
4851
4941
  */
4852
- const arcsinh = jit$1((x) => {
4942
+ const arcsinh = jit$1(function arcsinh$1(x) {
4853
4943
  return log(x.ref.add(sqrt(square(x).add(1))));
4854
4944
  });
4855
4945
  /**
@@ -4858,7 +4948,7 @@ const arcsinh = jit$1((x) => {
4858
4948
  *
4859
4949
  * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
4860
4950
  */
4861
- const arccosh = jit$1((x) => {
4951
+ const arccosh = jit$1(function arccosh$1(x) {
4862
4952
  return log(x.ref.add(sqrt(square(x).sub(1))));
4863
4953
  });
4864
4954
  /**
@@ -4867,7 +4957,7 @@ const arccosh = jit$1((x) => {
4867
4957
  *
4868
4958
  * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
4869
4959
  */
4870
- const arctanh = jit$1((x) => {
4960
+ const arctanh = jit$1(function arctanh$1(x) {
4871
4961
  return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
4872
4962
  });
4873
4963
  /** @function Alias of `jax.numpy.arcsinh()`. */
@@ -4983,7 +5073,9 @@ function softSign(x) {
4983
5073
  *
4984
5074
  * Reference: https://en.wikipedia.org/wiki/Swish_function
4985
5075
  */
4986
- const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
5076
+ const silu = jit$1(function silu$1(x) {
5077
+ return x.ref.mul(sigmoid(x));
5078
+ });
4987
5079
  /**
4988
5080
  * @function
4989
5081
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
@@ -5044,7 +5136,7 @@ function celu(x, alpha = 1) {
5044
5136
  *
5045
5137
  * This will be improved in the future.
5046
5138
  */
5047
- const gelu = jit$1((x) => {
5139
+ const gelu = jit$1(function gelu$1(x) {
5048
5140
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5049
5141
  return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5050
5142
  });
@@ -5215,8 +5307,11 @@ function bits(key$1, shape$1 = []) {
5215
5307
  const keyShape = validateKeyShape(key$1);
5216
5308
  return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
5217
5309
  }
5218
- /** Sample uniform random values in [minval, maxval) with given shape. */
5219
- function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5310
+ /**
5311
+ * @function
5312
+ * Sample uniform random values in [minval, maxval) with given shape.
5313
+ */
5314
+ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5220
5315
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
5221
5316
  const mantissa = bits(key$1, shape$1).div(array(512, {
5222
5317
  dtype: DType.Uint32,
@@ -5229,7 +5324,7 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5229
5324
  const rand = bitcast(float12, DType.Float32).sub(1);
5230
5325
  if (minval === 0 && maxval === 1) return rand;
5231
5326
  else return rand.mul(maxval - minval).add(minval);
5232
- }
5327
+ }, { staticArgnums: [1, 2] });
5233
5328
  /**
5234
5329
  * Sample Bernoulli random variables with given mean (0,1 categorical).
5235
5330
  *
@@ -5240,26 +5335,30 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
5240
5335
  p = fudgeArray(p);
5241
5336
  return uniform(key$1, shape$1).less(p);
5242
5337
  }
5243
- /** Sample exponential random values according to `p(x) = exp(-x)`. */
5244
- function exponential(key$1, shape$1 = []) {
5338
+ /**
5339
+ * @function
5340
+ * Sample exponential random values according to `p(x) = exp(-x)`.
5341
+ */
5342
+ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
5245
5343
  const u = uniform(key$1, shape$1);
5246
5344
  return negative(log1p(negative(u)));
5247
- }
5345
+ }, { staticArgnums: [1] });
5248
5346
  /**
5347
+ * @function
5249
5348
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
5250
5349
  *
5251
5350
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
5252
5351
  * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
5253
5352
  * bitwise identical to JAX.
5254
5353
  */
5255
- function normal(key$1, shape$1 = []) {
5354
+ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
5256
5355
  const [k1, k2] = split(key$1, 2);
5257
5356
  const u1 = uniform(k1, shape$1);
5258
5357
  const u2 = uniform(k2, shape$1);
5259
5358
  const radius = sqrt(log1p(negative(u1)).mul(-2));
5260
5359
  const theta = u2.mul(2 * Math.PI);
5261
5360
  return radius.mul(cos(theta));
5262
- }
5361
+ }, { staticArgnums: [1] });
5263
5362
 
5264
5363
  //#endregion
5265
5364
  //#region src/polyfills.ts