@jax-js/jax 0.0.4 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
30
30
  }) : target, mod));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-Ss1Mev_-.cjs');
33
+ const require_backend = require('./backend-FtkbO6pI.cjs');
34
34
 
35
35
  //#region src/tree.ts
36
36
  var tree_exports = {};
@@ -60,6 +60,10 @@ var JsTreeDef = class JsTreeDef {
60
60
  this.nodeMetadata = nodeMetadata;
61
61
  this.childTreedefs = childTreedefs;
62
62
  }
63
+ /** Get the total number of leaves in the tree. */
64
+ get size() {
65
+ return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
66
+ }
63
67
  /** Returns a string representation of this tree definition. */
64
68
  toString(root = true) {
65
69
  if (root) return "JsTreeDef(" + this.toString(false) + ")";
@@ -215,6 +219,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
215
219
  const s_ = strides;
216
220
  const d_ = dilation;
217
221
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
222
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
223
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
224
+ st = st.reshape([...noop, ...require_backend.zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...require_backend.zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
225
+ st = st.permute([
226
+ ...require_backend.range(noop.length),
227
+ ...ks.map((_, j) => noop.length + 2 * j),
228
+ ...ks.map((_, j) => noop.length + 2 * j + 1)
229
+ ]);
230
+ return st;
231
+ }
218
232
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
219
233
  const kidf = require_backend.zipn(ks, i_, d_, f_);
220
234
  st = st.repeat([...require_backend.rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
@@ -249,6 +263,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
249
263
  const s_ = strides;
250
264
  const d_ = dilation;
251
265
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
266
+ if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
267
+ st = st.permute([...require_backend.range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
268
+ st = st.pad([...noop.map(() => [0, 0]), ...require_backend.zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...require_backend.zip(o_, s_).map(([o, s]) => o * s)]);
269
+ st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
270
+ return st.reshape(st.shape.concat(require_backend.rep(ks.length, 1)));
271
+ }
252
272
  if (!require_backend.deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
253
273
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
254
274
  const kidf = require_backend.zipn(ks, i_, d_, f_);
@@ -358,6 +378,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
358
378
  Primitive$1["Atan"] = "atan";
359
379
  Primitive$1["Exp"] = "exp";
360
380
  Primitive$1["Log"] = "log";
381
+ Primitive$1["Erf"] = "erf";
382
+ Primitive$1["Erfc"] = "erfc";
361
383
  Primitive$1["Sqrt"] = "sqrt";
362
384
  Primitive$1["Min"] = "min";
363
385
  Primitive$1["Max"] = "max";
@@ -435,6 +457,12 @@ function exp$1(x) {
435
457
  function log$1(x) {
436
458
  return bind1(Primitive.Log, [x]);
437
459
  }
460
+ function erf$1(x) {
461
+ return bind1(Primitive.Erf, [x]);
462
+ }
463
+ function erfc$1(x) {
464
+ return bind1(Primitive.Erfc, [x]);
465
+ }
438
466
  function sqrt$1(x) {
439
467
  return bind1(Primitive.Sqrt, [x]);
440
468
  }
@@ -596,6 +624,21 @@ var Trace = class {
596
624
  this.main = main;
597
625
  }
598
626
  };
627
+ /**
628
+ * Broadcast shapes and promote types with casting for two avals.
629
+ *
630
+ * This implements the weak type behavior described in `promoteTypes()`, but not
631
+ * implemented in that function as `weakType` is not passed.
632
+ */
633
+ function promoteAvals(a, b) {
634
+ const shape$1 = require_backend.generalBroadcast(a.shape, b.shape);
635
+ const weakType = a.weakType && b.weakType;
636
+ let dtype;
637
+ if (a.weakType === b.weakType) dtype = require_backend.promoteTypes(a.dtype, b.dtype);
638
+ else if (a.weakType) dtype = require_backend.promoteTypes(b.dtype, require_backend.DType.Uint32);
639
+ else dtype = require_backend.promoteTypes(a.dtype, require_backend.DType.Uint32);
640
+ return new ShapedArray(shape$1, dtype, weakType);
641
+ }
599
642
  var Tracer = class Tracer {
600
643
  /** @ignore */
601
644
  _trace;
@@ -610,10 +653,19 @@ var Tracer = class Tracer {
610
653
  get size() {
611
654
  return require_backend.prod(this.shape);
612
655
  }
613
- /** The dtype of the array. */
656
+ /** The dtype of elements stored in the array. */
614
657
  get dtype() {
615
658
  return this.aval.dtype;
616
659
  }
660
+ /**
661
+ * Whether the array is weakly typed.
662
+ *
663
+ * Weakly typed arrays will cast to the dtype of the other operand. See
664
+ * `promoteTypes()` for details.
665
+ */
666
+ get weakType() {
667
+ return this.aval.weakType;
668
+ }
617
669
  /** The number of dimensions of the array. */
618
670
  get ndim() {
619
671
  return this.shape.length;
@@ -850,12 +902,13 @@ function getShape(x) {
850
902
  return x instanceof Tracer ? x.shape : [];
851
903
  }
852
904
  var ShapedArray = class ShapedArray {
853
- constructor(shape$1, dtype) {
905
+ constructor(shape$1, dtype, weakType) {
854
906
  this.shape = shape$1;
855
907
  this.dtype = dtype;
908
+ this.weakType = weakType;
856
909
  }
857
910
  static fromAval(aval) {
858
- return new ShapedArray(aval.shape, aval.dtype);
911
+ return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
859
912
  }
860
913
  get ndim() {
861
914
  return this.shape.length;
@@ -869,7 +922,7 @@ var ShapedArray = class ShapedArray {
869
922
  };
870
923
  function getAval(x) {
871
924
  if (x instanceof Tracer) return x.aval;
872
- else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32);
925
+ else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32, typeof x === "boolean" ? false : true);
873
926
  else throw new TypeError(`Unknown value: ${x}`);
874
927
  }
875
928
  function bind(prim, args, params = {}) {
@@ -1152,12 +1205,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
1152
1205
  } else if (exp$3.op === require_backend.AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
1153
1206
  });
1154
1207
  }
1155
- function broadcastedJit(fn) {
1208
+ function broadcastedJit(fn, opts) {
1156
1209
  return (nargs, exps, avals, params) => {
1157
- const newShape = avals.map((aval) => aval.shape).reduce(generalBroadcast);
1158
- exps = exps.map((exp$3) => reshapeViews(exp$3, (st) => {
1159
- if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
1160
- }));
1210
+ let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
1211
+ const skipCastIdx = opts?.skipCastIdx ?? [];
1212
+ if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
1213
+ exps = exps.map((exp$3, i) => {
1214
+ exp$3 = reshapeViews(exp$3, (st) => {
1215
+ if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
1216
+ });
1217
+ if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = require_backend.AluExp.cast(newDtype, exp$3);
1218
+ return exp$3;
1219
+ });
1161
1220
  const exp$2 = fn(exps, params);
1162
1221
  return new require_backend.Kernel(nargs, require_backend.prod(newShape), exp$2);
1163
1222
  };
@@ -1191,7 +1250,7 @@ const jitRules = {
1191
1250
  const k1 = reshapeViews(keys[1], mapping);
1192
1251
  const c0 = require_backend.AluExp.u32(0);
1193
1252
  const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
1194
- const exp$2 = require_backend.AluExp.threefry2x32(c0, c1, k0, k1, mode);
1253
+ const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
1195
1254
  return new require_backend.Kernel(nargs, require_backend.prod(shape$1), exp$2);
1196
1255
  },
1197
1256
  [Primitive.Sin]: unopJit(require_backend.AluExp.sin),
@@ -1200,6 +1259,8 @@ const jitRules = {
1200
1259
  [Primitive.Atan]: unopJit(require_backend.AluExp.atan),
1201
1260
  [Primitive.Exp]: unopJit(require_backend.AluExp.exp),
1202
1261
  [Primitive.Log]: unopJit(require_backend.AluExp.log),
1262
+ [Primitive.Erf]: unopJit(require_backend.AluExp.erf),
1263
+ [Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
1203
1264
  [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
1204
1265
  [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
1205
1266
  [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
@@ -1232,7 +1293,7 @@ const jitRules = {
1232
1293
  [Primitive.Dot](nargs, [a, b], [as, bs]) {
1233
1294
  const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
1234
1295
  const c = k1.exp;
1235
- const cs = new ShapedArray(generalBroadcast(as.shape, bs.shape), c.dtype);
1296
+ const cs = promoteAvals(as, bs);
1236
1297
  return jitRules[Primitive.Reduce](nargs, [c], [cs], {
1237
1298
  op: require_backend.AluOp.Add,
1238
1299
  axis: [cs.ndim - 1]
@@ -1242,12 +1303,12 @@ const jitRules = {
1242
1303
  const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
1243
1304
  a = reshapeViews(a, (st) => st.compose(stX));
1244
1305
  b = reshapeViews(b, (st) => st.compose(stY));
1245
- as = new ShapedArray(stX.shape, as.dtype);
1246
- bs = new ShapedArray(stY.shape, bs.dtype);
1306
+ as = new ShapedArray(stX.shape, as.dtype, as.weakType);
1307
+ bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
1247
1308
  return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1248
1309
  },
1249
1310
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
1250
- [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b)),
1311
+ [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
1251
1312
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
1252
1313
  [Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
1253
1314
  [Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
@@ -1260,7 +1321,7 @@ const jitRules = {
1260
1321
  [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
1261
1322
  [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1262
1323
  const axisSet = new Set(axis);
1263
- const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
1324
+ const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
1264
1325
  const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
1265
1326
  finalShape.splice(outDim, 0, ...indexShape);
1266
1327
  const idxAll = require_backend.unravelAlu(finalShape, require_backend.AluVar.gidx);
@@ -1296,9 +1357,10 @@ function splitGraphDataflow(backend, jaxpr) {
1296
1357
  Primitive.Conv,
1297
1358
  Primitive.PoolTranspose
1298
1359
  ];
1360
+ const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
1299
1361
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
1300
1362
  const eqn = jaxpr.eqns[i];
1301
- if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
1363
+ if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
1302
1364
  for (const v of eqn.outBinders) {
1303
1365
  blackNodes.add(v);
1304
1366
  p1NextBlack.set(v, v);
@@ -1417,7 +1479,7 @@ var PendingExecute = class {
1417
1479
  /**
1418
1480
  * A multidimensional numeric array with data stored on CPU or GPU.
1419
1481
  *
1420
- * This is the library's core data type. Equivalent to `jnp.Array` from JAX, or
1482
+ * This is the library's core data type. Equivalent to `jax.Array` from JAX, or
1421
1483
  * `torch.Tensor`.
1422
1484
  *
1423
1485
  * Not to be confused with the JavaScript "Array" constructor. Avoid importing
@@ -1428,9 +1490,11 @@ var Array$1 = class Array$1 extends Tracer {
1428
1490
  static #nextId = 1001;
1429
1491
  id;
1430
1492
  #dtype;
1493
+ #weakType;
1431
1494
  #source;
1432
1495
  #st;
1433
1496
  #backend;
1497
+ #committed;
1434
1498
  #rc;
1435
1499
  #pendingSet;
1436
1500
  /**
@@ -1439,21 +1503,23 @@ var Array$1 = class Array$1 extends Tracer {
1439
1503
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1440
1504
  * will be freed when the array is disposed.
1441
1505
  */
1442
- constructor(source, st, dtype, backend, { pending = null } = {}) {
1506
+ constructor(args) {
1443
1507
  super(baseArrayTrace);
1444
1508
  this.id = Array$1.#nextId++;
1445
- this.#dtype = dtype;
1446
- this.#source = source;
1447
- this.#st = st;
1448
- this.#backend = backend;
1509
+ this.#dtype = args.dtype;
1510
+ this.#weakType = args.weakType;
1511
+ this.#source = args.source;
1512
+ this.#st = args.st;
1513
+ this.#backend = args.backend;
1514
+ this.#committed = args.committed;
1449
1515
  this.#rc = 1;
1450
- this.#pendingSet = new Set(pending);
1516
+ this.#pendingSet = new Set(args.pending);
1451
1517
  if (this.#pendingSet.size === 0) this.#pendingSet = null;
1452
- else if (source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1518
+ else if (this.#source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1453
1519
  }
1454
1520
  /** @ignore */
1455
1521
  get aval() {
1456
- return new ShapedArray(this.#st.shape, this.#dtype);
1522
+ return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
1457
1523
  }
1458
1524
  /** Return a simple string representation of the array's dimensions. */
1459
1525
  toString() {
@@ -1465,6 +1531,18 @@ var Array$1 = class Array$1 extends Tracer {
1465
1531
  #check() {
1466
1532
  if (this.#rc <= 0) throw new UseAfterFreeError(this);
1467
1533
  }
1534
+ /** Construct an array, copying fields from `this`. */
1535
+ #newArrayFrom(args) {
1536
+ return new Array$1({
1537
+ source: args.source ?? this.#source,
1538
+ st: args.st ?? this.#st,
1539
+ dtype: args.dtype ?? this.#dtype,
1540
+ weakType: this.#weakType,
1541
+ backend: args.backend ?? this.#backend,
1542
+ committed: args.committed ?? this.#committed,
1543
+ pending: args.pending ?? this.#pending ?? void 0
1544
+ });
1545
+ }
1468
1546
  get ref() {
1469
1547
  this.#check();
1470
1548
  this.#rc++;
@@ -1504,7 +1582,10 @@ var Array$1 = class Array$1 extends Tracer {
1504
1582
  const pending = this.#pending;
1505
1583
  for (const exe of pending) exe.updateRc(1);
1506
1584
  if (typeof this.#source === "number") this.#backend.incRef(this.#source);
1507
- const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
1585
+ const ar = this.#newArrayFrom({
1586
+ st,
1587
+ pending
1588
+ });
1508
1589
  this.dispose();
1509
1590
  return ar;
1510
1591
  }
@@ -1514,9 +1595,10 @@ var Array$1 = class Array$1 extends Tracer {
1514
1595
  */
1515
1596
  #gather(indices, axis, outDim) {
1516
1597
  this.#check();
1517
- if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1518
1598
  const axisSet = new Set(axis);
1519
1599
  if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
1600
+ if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
1601
+ indices = indices.map((ar) => ar._putSync(this.#backend));
1520
1602
  indices = Array$1.#broadcastArrays(indices);
1521
1603
  const indexShape = indices[0].shape;
1522
1604
  const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
@@ -1553,7 +1635,11 @@ var Array$1 = class Array$1 extends Tracer {
1553
1635
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1554
1636
  this.dispose();
1555
1637
  for (const ar of indices) ar.dispose();
1556
- return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
1638
+ return this.#newArrayFrom({
1639
+ source: output,
1640
+ st: require_backend.ShapeTracker.fromShape(finalShape),
1641
+ pending
1642
+ });
1557
1643
  }
1558
1644
  /** Move axes to the rightmost dimension of the shape. */
1559
1645
  #moveAxesDown(axis) {
@@ -1576,11 +1662,17 @@ var Array$1 = class Array$1 extends Tracer {
1576
1662
  return this.#reshape(this.#st.permute(perm));
1577
1663
  }
1578
1664
  #unary(op, dtypeOutput) {
1665
+ const weakType = !dtypeOutput && this.#weakType;
1579
1666
  dtypeOutput ??= this.#dtype;
1580
1667
  this.#check();
1581
1668
  if (this.#source instanceof require_backend.AluExp) {
1582
1669
  const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
1583
- return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
1670
+ this.dispose();
1671
+ return this.#newArrayFrom({
1672
+ source: exp$3.simplify(),
1673
+ dtype: dtypeOutput,
1674
+ weakType
1675
+ });
1584
1676
  }
1585
1677
  const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
1586
1678
  const exp$2 = new require_backend.AluExp(op, dtypeOutput, [require_backend.AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
@@ -1590,41 +1682,67 @@ var Array$1 = class Array$1 extends Tracer {
1590
1682
  for (const exe of pending) exe.updateRc(1);
1591
1683
  pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
1592
1684
  this.dispose();
1593
- return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
1685
+ return this.#newArrayFrom({
1686
+ source: output,
1687
+ st: require_backend.ShapeTracker.fromShape(this.shape),
1688
+ dtype: dtypeOutput,
1689
+ weakType,
1690
+ pending
1691
+ });
1594
1692
  }
1595
1693
  #binary(op, other) {
1596
- const custom = (src) => new require_backend.AluExp(op, this.#dtype, src);
1694
+ const custom = (src) => new require_backend.AluExp(op, src[0].dtype, src);
1597
1695
  return Array$1.#naryCustom(op, custom, [this, other]);
1598
1696
  }
1599
- static #naryCustom(name, custom, arrays, { dtypeOverride, dtypeOutput, reduceAxis } = {}) {
1697
+ static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
1600
1698
  const n = arrays.length;
1601
- const backend = arrays[0].#backend;
1602
1699
  if (n === 0) throw new TypeError(`No inputs for ${name}`);
1603
1700
  for (const ar of arrays) ar.#check();
1604
- let dtype;
1605
- for (let i = 0; i < n; i++) {
1606
- if (dtypeOverride?.[i]) {
1607
- if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1608
- } else if (!dtype) dtype = arrays[i].#dtype;
1609
- else if (arrays[i].#dtype !== dtype) throw new TypeError(`Dtype mismatch in ${name}: ${dtype} vs ${arrays[i].#dtype}`);
1610
- if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
1611
- }
1612
- dtypeOutput ??= dtype;
1613
- if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
1701
+ let castDtype;
1702
+ let castWeakType = true;
1703
+ for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
1704
+ if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
1705
+ } else if (castDtype === void 0) {
1706
+ castDtype = arrays[i].#dtype;
1707
+ castWeakType = arrays[i].#weakType;
1708
+ } else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
1709
+ const weakType = castWeakType && !strongTypeOutput;
1710
+ const { backend, committed } = Array$1.#computeBackend(name, arrays);
1711
+ arrays = arrays.map((ar) => ar._putSync(backend));
1614
1712
  arrays = Array$1.#broadcastArrays(arrays);
1615
1713
  const newShape = [...arrays[0].shape];
1616
1714
  if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
1715
+ const sources = arrays.map((ar, i) => {
1716
+ if (!dtypeOverride?.[i]) return require_backend.AluExp.cast(castDtype, ar.#source);
1717
+ else return ar.#source;
1718
+ });
1617
1719
  if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
1618
- const exp$4 = custom(arrays.map((ar) => ar.#source));
1619
- return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
1720
+ const exp$4 = custom(sources);
1721
+ arrays.forEach((ar) => ar.dispose());
1722
+ return new Array$1({
1723
+ source: exp$4.simplify(),
1724
+ st: arrays[0].#st,
1725
+ dtype: exp$4.dtype,
1726
+ weakType,
1727
+ backend,
1728
+ committed
1729
+ });
1620
1730
  }
1621
- const exp$3 = custom(arrays.map((ar) => {
1622
- const src$1 = ar.#source;
1731
+ const exp$3 = custom(arrays.map((ar, i) => {
1732
+ const src$1 = sources[i];
1623
1733
  if (ar.#st.contiguous) return src$1;
1624
1734
  return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
1625
1735
  }));
1626
1736
  const st = require_backend.ShapeTracker.fromShape(newShape);
1627
- return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
1737
+ arrays.forEach((ar) => ar.dispose());
1738
+ return new Array$1({
1739
+ source: exp$3.simplify(),
1740
+ st,
1741
+ dtype: exp$3.dtype,
1742
+ weakType,
1743
+ backend,
1744
+ committed
1745
+ });
1628
1746
  }
1629
1747
  let indices;
1630
1748
  if (!reduceAxis) indices = require_backend.unravelAlu(newShape, require_backend.AluVar.gidx);
@@ -1634,14 +1752,19 @@ var Array$1 = class Array$1 extends Tracer {
1634
1752
  }
1635
1753
  const inputs = [];
1636
1754
  const src = [];
1637
- for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#source, ar.#st, indices));
1638
- else {
1639
- let gid = inputs.indexOf(ar.#source);
1640
- if (gid === -1) {
1641
- gid = inputs.length;
1642
- inputs.push(ar.#source);
1755
+ for (const [i, ar] of arrays.entries()) {
1756
+ let nextSrc;
1757
+ if (ar.#source instanceof require_backend.AluExp) nextSrc = require_backend.accessorAluExp(ar.#source, ar.#st, indices);
1758
+ else {
1759
+ let gid = inputs.indexOf(ar.#source);
1760
+ if (gid === -1) {
1761
+ gid = inputs.length;
1762
+ inputs.push(ar.#source);
1763
+ }
1764
+ nextSrc = require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
1643
1765
  }
1644
- src.push(require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, indices));
1766
+ if (!dtypeOverride?.[i]) nextSrc = require_backend.AluExp.cast(castDtype, nextSrc);
1767
+ src.push(nextSrc);
1645
1768
  }
1646
1769
  const exp$2 = custom(src);
1647
1770
  let re = void 0;
@@ -1654,13 +1777,19 @@ var Array$1 = class Array$1 extends Tracer {
1654
1777
  const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
1655
1778
  for (const exe of pending) exe.updateRc(1);
1656
1779
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1657
- for (const ar of arrays) ar.dispose();
1658
- return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
1780
+ arrays.forEach((ar) => ar.dispose());
1781
+ return new Array$1({
1782
+ source: output,
1783
+ st: require_backend.ShapeTracker.fromShape(newShape),
1784
+ dtype: kernel.dtype,
1785
+ weakType,
1786
+ backend,
1787
+ committed,
1788
+ pending
1789
+ });
1659
1790
  }
1660
1791
  /** Reduce the last dimension of the array by an operation. */
1661
1792
  #reduce(op) {
1662
- this.#check();
1663
- if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
1664
1793
  const shape$1 = this.shape;
1665
1794
  const reduction = new require_backend.Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
1666
1795
  const newShape = shape$1.slice(0, -1);
@@ -1679,7 +1808,11 @@ var Array$1 = class Array$1 extends Tracer {
1679
1808
  for (const exe of pending) exe.updateRc(1);
1680
1809
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1681
1810
  this.dispose();
1682
- return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
1811
+ return this.#newArrayFrom({
1812
+ source: output,
1813
+ st: require_backend.ShapeTracker.fromShape(newShape),
1814
+ pending
1815
+ });
1683
1816
  }
1684
1817
  /**
1685
1818
  * Normalizes this array into one backed by a `Slot`.
@@ -1715,20 +1848,37 @@ var Array$1 = class Array$1 extends Tracer {
1715
1848
  }
1716
1849
  #dataInline() {
1717
1850
  this.#check();
1718
- const exp$2 = this.#source;
1719
- const ar = new Array$1(exp$2, this.#st, this.dtype, require_backend.getBackend("cpu"));
1851
+ if (!(this.#source instanceof require_backend.AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
1852
+ const ar = this.#newArrayFrom({ backend: require_backend.getBackend("cpu") });
1720
1853
  this.dispose();
1721
1854
  return ar.dataSync();
1722
1855
  }
1723
1856
  static #broadcastArrays(arrays) {
1724
1857
  if (arrays.length === 0) throw new Error("Need at least one array to broadcast");
1725
1858
  if (arrays.length === 1) return arrays;
1726
- const newShape = arrays.map((a) => a.shape).reduce(generalBroadcast);
1859
+ const newShape = arrays.map((a) => a.shape).reduce(require_backend.generalBroadcast);
1727
1860
  return arrays.map((ar) => {
1728
1861
  if (require_backend.deepEqual(ar.shape, newShape)) return ar;
1729
1862
  return ar.#reshape(ar.#st.broadcast(newShape, require_backend.range(newShape.length - ar.ndim)));
1730
1863
  });
1731
1864
  }
1865
+ static #computeBackend(name, arrays) {
1866
+ const committed = arrays.filter((ar) => ar.#committed);
1867
+ if (committed.length > 0) {
1868
+ const backend = committed[0].#backend;
1869
+ for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
1870
+ return {
1871
+ backend,
1872
+ committed: true
1873
+ };
1874
+ } else {
1875
+ const backend = arrays.length > 0 ? arrays[0].#backend : require_backend.getBackend();
1876
+ return {
1877
+ backend,
1878
+ committed: false
1879
+ };
1880
+ }
1881
+ }
1732
1882
  /** Realize the array and return it as data. */
1733
1883
  async data() {
1734
1884
  if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
@@ -1842,14 +1992,18 @@ var Array$1 = class Array$1 extends Tracer {
1842
1992
  x.#backend.incRef(x.#source);
1843
1993
  const pending = x.#pending;
1844
1994
  for (const exe of pending) exe.updateRc(1);
1845
- const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
1995
+ const y = x.#newArrayFrom({
1996
+ dtype,
1997
+ weakType: false,
1998
+ pending
1999
+ });
1846
2000
  x.dispose();
1847
2001
  return [y];
1848
2002
  }
1849
2003
  },
1850
2004
  [Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
1851
- const keyShape = generalBroadcast(k0.shape, k1.shape);
1852
- if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2005
+ const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2006
+ if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
1853
2007
  const c0 = zeros(shape$1, {
1854
2008
  dtype: require_backend.DType.Uint32,
1855
2009
  device: k0.device
@@ -1884,6 +2038,12 @@ var Array$1 = class Array$1 extends Tracer {
1884
2038
  [Primitive.Log]([x]) {
1885
2039
  return [x.#unary(require_backend.AluOp.Log)];
1886
2040
  },
2041
+ [Primitive.Erf]([x]) {
2042
+ return [x.#unary(require_backend.AluOp.Erf)];
2043
+ },
2044
+ [Primitive.Erfc]([x]) {
2045
+ return [x.#unary(require_backend.AluOp.Erfc)];
2046
+ },
1887
2047
  [Primitive.Sqrt]([x]) {
1888
2048
  return [x.#unary(require_backend.AluOp.Sqrt)];
1889
2049
  },
@@ -1917,7 +2077,7 @@ var Array$1 = class Array$1 extends Tracer {
1917
2077
  },
1918
2078
  [Primitive.Compare]([x, y], { op }) {
1919
2079
  const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
1920
- return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: require_backend.DType.Bool })];
2080
+ return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
1921
2081
  },
1922
2082
  [Primitive.Where]([cond, x, y]) {
1923
2083
  const custom = ([cond$1, x$1, y$1]) => require_backend.AluExp.where(cond$1, x$1, y$1);
@@ -1952,7 +2112,8 @@ var Array$1 = class Array$1 extends Tracer {
1952
2112
  },
1953
2113
  [Primitive.JitCall](args, { jaxpr, numConsts }) {
1954
2114
  if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
1955
- const backend = require_backend.getBackend();
2115
+ const { backend, committed } = Array$1.#computeBackend("jit_call", args);
2116
+ args = args.map((ar) => ar._putSync(backend));
1956
2117
  const consts = args.slice(0, numConsts);
1957
2118
  const tracers = args.slice(numConsts);
1958
2119
  const jp = jitCompile(backend, jaxpr, consts);
@@ -1963,43 +2124,66 @@ var Array$1 = class Array$1 extends Tracer {
1963
2124
  pending.splice(0, 0, ...prevPending);
1964
2125
  args.forEach((x) => x.dispose());
1965
2126
  return outputs.map((source, i) => {
1966
- return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
2127
+ return new Array$1({
2128
+ source,
2129
+ st: require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
2130
+ dtype: jaxpr.outs[i].aval.dtype,
2131
+ weakType: jaxpr.outs[i].aval.weakType,
2132
+ backend,
2133
+ committed,
2134
+ pending
2135
+ });
1967
2136
  });
1968
2137
  }
1969
2138
  };
1970
2139
  }
2140
+ /** @private */
1971
2141
  _realizeSource() {
1972
2142
  this.#realize();
1973
2143
  return this.#source;
1974
2144
  }
2145
+ /** @private Put this array on a new backend, asynchronously. */
2146
+ async _put(backend) {
2147
+ if (this.#backend === backend) return this;
2148
+ if (this.#source instanceof require_backend.AluExp) {
2149
+ const ar = this.#newArrayFrom({
2150
+ backend,
2151
+ committed: true
2152
+ });
2153
+ this.dispose();
2154
+ return ar;
2155
+ } else {
2156
+ const data = await this.data();
2157
+ return arrayFromData(data, this.shape, {
2158
+ dtype: this.#dtype,
2159
+ device: backend.type
2160
+ }, this.#weakType);
2161
+ }
2162
+ }
2163
+ /** @private Put this array on a new backend, synchronously. */
2164
+ _putSync(backend) {
2165
+ if (this.#backend === backend) return this;
2166
+ if (this.#source instanceof require_backend.AluExp) {
2167
+ const ar = this.#newArrayFrom({
2168
+ backend,
2169
+ committed: true
2170
+ });
2171
+ this.dispose();
2172
+ return ar;
2173
+ } else {
2174
+ const data = this.dataSync();
2175
+ return arrayFromData(data, this.shape, {
2176
+ dtype: this.#dtype,
2177
+ device: backend.type
2178
+ }, this.#weakType);
2179
+ }
2180
+ }
1975
2181
  };
1976
- /** Construct an array from a single scalar constant. */
1977
- function scalar(value, { dtype, device } = {}) {
1978
- if (typeof value === "number") {
1979
- dtype ??= require_backend.DType.Float32;
1980
- if (![
1981
- require_backend.DType.Float32,
1982
- require_backend.DType.Float16,
1983
- require_backend.DType.Int32,
1984
- require_backend.DType.Uint32
1985
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1986
- } else if (typeof value === "boolean") {
1987
- dtype ??= require_backend.DType.Bool;
1988
- if (![
1989
- require_backend.DType.Float32,
1990
- require_backend.DType.Float16,
1991
- require_backend.DType.Int32,
1992
- require_backend.DType.Uint32,
1993
- require_backend.DType.Bool
1994
- ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
1995
- } else throw new TypeError(`Invalid type for scalar ${value}`);
1996
- return new Array$1(require_backend.AluExp.const(dtype, value), require_backend.ShapeTracker.fromShape([]), dtype, require_backend.getBackend(device));
1997
- }
1998
2182
  /** Constructor for creating a new array from data. */
1999
2183
  function array(values, { shape: shape$1, dtype, device } = {}) {
2000
2184
  if (values instanceof Tracer) {
2001
2185
  if (shape$1 && !require_backend.deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
2002
- if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
2186
+ if (dtype && values.dtype !== dtype) values = values.astype(dtype);
2003
2187
  return values;
2004
2188
  } else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
2005
2189
  dtype,
@@ -2021,6 +2205,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
2021
2205
  dtype,
2022
2206
  device
2023
2207
  });
2208
+ if (size$1 === 1) return full(shape$1, flat[0], {
2209
+ dtype,
2210
+ device
2211
+ });
2024
2212
  if (typeof flat[0] === "boolean") {
2025
2213
  dtype = dtype ?? require_backend.DType.Bool;
2026
2214
  const data = new Int32Array(flat.map((x) => x ? 1 : 0));
@@ -2029,46 +2217,52 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
2029
2217
  device
2030
2218
  });
2031
2219
  } else {
2220
+ const weakType = dtype == void 0;
2032
2221
  dtype = dtype ?? require_backend.DType.Float32;
2033
2222
  const data = require_backend.dtypedJsArray(dtype, flat);
2034
2223
  return arrayFromData(data, shape$1, {
2035
2224
  dtype,
2036
2225
  device
2037
- });
2226
+ }, weakType);
2038
2227
  }
2039
2228
  }
2040
2229
  }
2041
- function arrayFromData(data, shape$1, { dtype, device } = {}) {
2230
+ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
2231
+ if (data instanceof Float32Array) {
2232
+ if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
2233
+ dtype ??= require_backend.DType.Float32;
2234
+ } else if (data instanceof Int32Array) {
2235
+ if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2236
+ dtype ??= require_backend.DType.Int32;
2237
+ } else if (data instanceof Uint32Array) {
2238
+ if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2239
+ dtype ??= require_backend.DType.Uint32;
2240
+ } else if (data instanceof Float16Array) {
2241
+ if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
2242
+ dtype ??= require_backend.DType.Float16;
2243
+ } else throw new Error("Unsupported data array type: " + data.constructor.name);
2042
2244
  if (data.length < inlineArrayLimit) {
2043
2245
  let allEqual = true;
2044
2246
  for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
2045
2247
  allEqual = false;
2046
2248
  break;
2047
2249
  }
2048
- if (allEqual) return full(shape$1, data[0], {
2049
- dtype,
2050
- device
2051
- });
2250
+ if (allEqual) {
2251
+ const sa = new ShapedArray(shape$1, dtype, weakType);
2252
+ return fullInternal(sa, data[0], device);
2253
+ }
2052
2254
  }
2053
2255
  const backend = require_backend.getBackend(device);
2054
- if (ArrayBuffer.isView(data)) {
2055
- const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2056
- if (data instanceof Float32Array) {
2057
- if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
2058
- dtype ??= require_backend.DType.Float32;
2059
- } else if (data instanceof Int32Array) {
2060
- if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2061
- dtype ??= require_backend.DType.Int32;
2062
- } else if (data instanceof Uint32Array) {
2063
- if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2064
- dtype ??= require_backend.DType.Uint32;
2065
- } else if (data instanceof Float16Array) {
2066
- if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
2067
- dtype ??= require_backend.DType.Float16;
2068
- } else throw new Error("Unsupported data array type: " + data.constructor.name);
2069
- const slot = backend.malloc(data.byteLength, buf);
2070
- return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype, backend);
2071
- } else throw new Error("Unsupported data type: " + data.constructor.name);
2256
+ const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2257
+ const slot = backend.malloc(data.byteLength, buf);
2258
+ return new Array$1({
2259
+ source: slot,
2260
+ st: require_backend.ShapeTracker.fromShape(shape$1),
2261
+ dtype,
2262
+ weakType,
2263
+ backend,
2264
+ committed: device != void 0
2265
+ });
2072
2266
  }
2073
2267
  function dataToJs(dtype, data, shape$1) {
2074
2268
  if (shape$1.length === 0) return dtype === require_backend.DType.Bool ? Boolean(data[0]) : data[0];
@@ -2084,7 +2278,7 @@ function dataToJs(dtype, data, shape$1) {
2084
2278
  /** If x is a value, lift it into an array, otherwise leave it be. */
2085
2279
  function pureArray(x) {
2086
2280
  if (x instanceof Tracer) return x;
2087
- else return scalar(x);
2281
+ else return array(x);
2088
2282
  }
2089
2283
  var EvalTrace = class extends Trace {
2090
2284
  pure = (x) => pureArray(x);
@@ -2095,20 +2289,28 @@ var EvalTrace = class extends Trace {
2095
2289
  };
2096
2290
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
2097
2291
  const implRules = Array$1._implRules();
2292
+ function fullInternal(aval, fillValue, device) {
2293
+ return new Array$1({
2294
+ source: require_backend.AluExp.const(aval.dtype, fillValue),
2295
+ st: require_backend.ShapeTracker.fromShape(aval.shape),
2296
+ dtype: aval.dtype,
2297
+ weakType: aval.weakType,
2298
+ backend: require_backend.getBackend(device),
2299
+ committed: device != void 0
2300
+ });
2301
+ }
2098
2302
  function zerosLike$1(val, dtype) {
2099
- const aval = getAval(val);
2100
- if (val instanceof Tracer) val.dispose();
2101
- return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2303
+ return fullLike(val, 0, dtype);
2102
2304
  }
2103
2305
  function onesLike$1(val, dtype) {
2104
- const aval = getAval(val);
2105
- if (val instanceof Tracer) val.dispose();
2106
- return ones(aval.shape, { dtype: dtype ?? aval.dtype });
2306
+ return fullLike(val, 1, dtype);
2107
2307
  }
2108
2308
  function fullLike(val, fillValue, dtype) {
2109
2309
  const aval = getAval(val);
2110
2310
  if (val instanceof Tracer) val.dispose();
2111
- return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
2311
+ if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
2312
+ const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
2313
+ return fullInternal(sa, fillValue);
2112
2314
  }
2113
2315
  /** Return a new array of given shape and type, filled with zeros. */
2114
2316
  function zeros(shape$1, { dtype, device } = {}) {
@@ -2126,19 +2328,14 @@ function ones(shape$1, { dtype, device } = {}) {
2126
2328
  }
2127
2329
  /** Return a new array of given shape and type, filled with `fill_value`. */
2128
2330
  function full(shape$1, fillValue, { dtype, device } = {}) {
2129
- let source;
2130
- if (typeof fillValue === "number") {
2131
- dtype = dtype ?? require_backend.DType.Float32;
2132
- source = require_backend.AluExp.const(dtype, fillValue);
2133
- } else if (typeof fillValue === "bigint") {
2134
- dtype = dtype ?? require_backend.DType.Int32;
2135
- source = require_backend.AluExp.const(dtype, Number(fillValue));
2136
- } else if (typeof fillValue === "boolean") {
2331
+ let weakType = dtype == void 0;
2332
+ if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
2333
+ else if (typeof fillValue === "boolean") {
2137
2334
  dtype = dtype ?? require_backend.DType.Bool;
2138
- source = require_backend.AluExp.const(dtype, fillValue ? 1 : 0);
2335
+ weakType = false;
2139
2336
  } else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
2140
2337
  else throw new TypeError(`Invalid type for full: ${fillValue}`);
2141
- return new Array$1(source, require_backend.ShapeTracker.fromShape(shape$1), dtype ?? require_backend.DType.Float32, require_backend.getBackend(device));
2338
+ return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
2142
2339
  }
2143
2340
  /**
2144
2341
  * Create an identity matrix.
@@ -2148,6 +2345,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
2148
2345
  */
2149
2346
  function eye(numRows, numCols, { dtype, device } = {}) {
2150
2347
  numCols = numCols ?? numRows;
2348
+ const weakType = dtype == void 0;
2151
2349
  dtype = dtype ?? require_backend.DType.Float32;
2152
2350
  if (numCols < numRows) {
2153
2351
  const arr = eye(numCols, numRows, {
@@ -2161,7 +2359,14 @@ function eye(numRows, numCols, { dtype, device } = {}) {
2161
2359
  device
2162
2360
  });
2163
2361
  const exp$2 = require_backend.AluExp.cmplt(require_backend.AluExp.mod(require_backend.AluVar.idx, require_backend.AluExp.i32(numCols + 1)), require_backend.AluExp.i32(1));
2164
- return new Array$1(require_backend.AluExp.cast(dtype, exp$2), require_backend.ShapeTracker.fromShape([numRows, numCols]), dtype, require_backend.getBackend(device));
2362
+ return new Array$1({
2363
+ source: require_backend.AluExp.cast(dtype, exp$2),
2364
+ st: require_backend.ShapeTracker.fromShape([numRows, numCols]),
2365
+ dtype,
2366
+ weakType,
2367
+ backend: require_backend.getBackend(device),
2368
+ committed: device != void 0
2369
+ });
2165
2370
  }
2166
2371
  /** Return the identity matrix, with ones on the main diagonal. */
2167
2372
  function identity$1(n, { dtype, device } = {}) {
@@ -2198,7 +2403,14 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
2198
2403
  });
2199
2404
  const exp$2 = require_backend.AluExp.add(require_backend.AluExp.const(dtype, start), require_backend.AluExp.mul(require_backend.AluExp.cast(dtype, require_backend.AluVar.idx), require_backend.AluExp.const(dtype, step)));
2200
2405
  const st = require_backend.ShapeTracker.fromShape([size$1]);
2201
- return new Array$1(exp$2, st, dtype, require_backend.getBackend(device));
2406
+ return new Array$1({
2407
+ source: exp$2,
2408
+ st,
2409
+ dtype,
2410
+ weakType: false,
2411
+ backend: require_backend.getBackend(device),
2412
+ committed: device != void 0
2413
+ });
2202
2414
  }
2203
2415
  /**
2204
2416
  * Return evenly spaced numbers over a specified interval.
@@ -2216,10 +2428,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2216
2428
  dtype,
2217
2429
  device
2218
2430
  });
2219
- else if (num === 1) return scalar(start, {
2431
+ else if (num === 1) return full([1], start, {
2220
2432
  dtype,
2221
2433
  device
2222
- }).reshape([1]);
2434
+ });
2223
2435
  else if (start === stop) return full([num], start, {
2224
2436
  dtype,
2225
2437
  device
@@ -2228,7 +2440,14 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
2228
2440
  const denom = endpoint ? num - 1 : num;
2229
2441
  const exp$2 = require_backend.AluExp.cast(dtype, require_backend.AluExp.add(require_backend.AluExp.f32(start), require_backend.AluExp.mul(require_backend.AluExp.f32(delta / denom), require_backend.AluExp.cast(require_backend.DType.Float32, require_backend.AluVar.idx))));
2230
2442
  const st = require_backend.ShapeTracker.fromShape([num]);
2231
- return new Array$1(exp$2, st, dtype, require_backend.getBackend(device));
2443
+ return new Array$1({
2444
+ source: exp$2,
2445
+ st,
2446
+ dtype,
2447
+ weakType: false,
2448
+ backend: require_backend.getBackend(device),
2449
+ committed: device != void 0
2450
+ });
2232
2451
  }
2233
2452
  function aluCompare(a, b, op) {
2234
2453
  switch (op) {
@@ -2240,35 +2459,6 @@ function aluCompare(a, b, op) {
2240
2459
  case CompareOp.LessEqual: return require_backend.AluExp.add(require_backend.AluExp.cmplt(a, b), require_backend.AluExp.cmpne(a, b).not());
2241
2460
  }
2242
2461
  }
2243
- /**
2244
- * Implements a NumPy-style generalized broadcast rule on two array shapes.
2245
- *
2246
- * "When operating on two arrays, NumPy compares their shapes element-wise. It
2247
- * starts with the trailing (i.e. rightmost) dimension and works its way left.
2248
- * Two dimensions are compatible when:
2249
- * 1. they are equal, or
2250
- * 2. one of them is 1."
2251
- *
2252
- * Throws a TypeError if the broadcast is not possible.
2253
- *
2254
- * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
2255
- */
2256
- function generalBroadcast(a, b) {
2257
- const out = [];
2258
- let i = a.length - 1;
2259
- let j = b.length - 1;
2260
- for (; i >= 0 && j >= 0; i--, j--) {
2261
- const x = a[i];
2262
- const y = b[j];
2263
- if (x === y) out.push(x);
2264
- else if (x === 1) out.push(y);
2265
- else if (y === 1) out.push(x);
2266
- else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
2267
- }
2268
- for (; i >= 0; i--) out.push(a[i]);
2269
- for (; j >= 0; j--) out.push(b[j]);
2270
- return out.reverse();
2271
- }
2272
2462
 
2273
2463
  //#endregion
2274
2464
  //#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
@@ -2348,13 +2538,15 @@ var Var = class Var {
2348
2538
  };
2349
2539
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
2350
2540
  var Lit = class {
2351
- dtype;
2352
2541
  value;
2353
2542
  aval;
2354
- constructor(dtype, value) {
2355
- this.dtype = dtype;
2543
+ get dtype() {
2544
+ return this.aval.dtype;
2545
+ }
2546
+ constructor(aval, value) {
2547
+ if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
2356
2548
  this.value = value;
2357
- this.aval = new ShapedArray([], dtype);
2549
+ this.aval = ShapedArray.fromAval(aval);
2358
2550
  }
2359
2551
  };
2360
2552
  function atomIsLit(atom, literal) {
@@ -2478,14 +2670,19 @@ var Jaxpr = class Jaxpr {
2478
2670
  const c = eqn.outBinders[0];
2479
2671
  if (atomIsLit(a, 0)) context.set(c, b);
2480
2672
  else if (atomIsLit(b, 0)) context.set(c, a);
2481
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.dtype === require_backend.DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2673
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === require_backend.DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
2674
+ else newEqns.push(eqn);
2675
+ } else if (eqn.primitive === Primitive.Neg) {
2676
+ const [a] = inputs;
2677
+ const c = eqn.outBinders[0];
2678
+ if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
2482
2679
  else newEqns.push(eqn);
2483
2680
  } else if (eqn.primitive === Primitive.Mul) {
2484
2681
  const [a, b] = inputs;
2485
2682
  const c = eqn.outBinders[0];
2486
2683
  if (atomIsLit(a, 1)) context.set(c, b);
2487
2684
  else if (atomIsLit(b, 1)) context.set(c, a);
2488
- else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.dtype, a.value * b.value));
2685
+ else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
2489
2686
  else newEqns.push(eqn);
2490
2687
  } else if (eqn.primitive === Primitive.Idiv) {
2491
2688
  const [a, b] = inputs;
@@ -2583,7 +2780,7 @@ function evalJaxpr(jaxpr, args) {
2583
2780
  if (x instanceof Var) {
2584
2781
  remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
2585
2782
  return env.get(x);
2586
- } else return scalar(x.value, { dtype: x.dtype });
2783
+ } else return array(x.value, { dtype: x.dtype });
2587
2784
  };
2588
2785
  const write = (v, val) => {
2589
2786
  if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
@@ -2642,7 +2839,7 @@ var JaxprTrace = class extends Trace {
2642
2839
  let tracer = this.builder.constTracers.get(val);
2643
2840
  if (tracer === void 0) {
2644
2841
  tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
2645
- this.builder.addConst(tracer, val instanceof Tracer ? val.ref : scalar(val));
2842
+ this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
2646
2843
  }
2647
2844
  return tracer;
2648
2845
  }
@@ -2711,7 +2908,7 @@ function _inlineLiterals(jaxpr, consts) {
2711
2908
  const newConsts = [];
2712
2909
  for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
2713
2910
  const ar = consts[i];
2714
- literals.set(jaxpr.inBinders[i], new Lit(ar.dtype, ar.dataSync()[0]));
2911
+ literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
2715
2912
  } else {
2716
2913
  constBinders.push(jaxpr.inBinders[i]);
2717
2914
  newConsts.push(consts[i]);
@@ -2724,13 +2921,12 @@ function _inlineLiterals(jaxpr, consts) {
2724
2921
  }
2725
2922
  function binopAbstractEval([x, y]) {
2726
2923
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
2727
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2728
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
2924
+ return [promoteAvals(x, y)];
2729
2925
  }
2730
2926
  function compareAbstractEval([x, y]) {
2731
2927
  if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
2732
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2733
- return [new ShapedArray(generalBroadcast(x.shape, y.shape), require_backend.DType.Bool)];
2928
+ const aval = promoteAvals(x, y);
2929
+ return [new ShapedArray(aval.shape, require_backend.DType.Bool, false)];
2734
2930
  }
2735
2931
  function vectorizedUnopAbstractEval([x]) {
2736
2932
  return [ShapedArray.fromAval(x)];
@@ -2743,18 +2939,18 @@ const abstractEvalRules = {
2743
2939
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
2744
2940
  [Primitive.StopGradient]: vectorizedUnopAbstractEval,
2745
2941
  [Primitive.Cast]([x], { dtype }) {
2746
- return [new ShapedArray(x.shape, dtype)];
2942
+ return [new ShapedArray(x.shape, dtype, false)];
2747
2943
  },
2748
2944
  [Primitive.Bitcast]([x], { dtype }) {
2749
2945
  if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
2750
2946
  if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
2751
- return [new ShapedArray(x.shape, dtype)];
2947
+ return [new ShapedArray(x.shape, dtype, false)];
2752
2948
  },
2753
2949
  [Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
2754
2950
  if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
2755
- const keyShape = generalBroadcast(k0.shape, k1.shape);
2756
- if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2757
- return [new ShapedArray(shape$1, require_backend.DType.Uint32)];
2951
+ const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
2952
+ if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
2953
+ return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
2758
2954
  },
2759
2955
  [Primitive.Sin]: vectorizedUnopAbstractEval,
2760
2956
  [Primitive.Cos]: vectorizedUnopAbstractEval,
@@ -2762,61 +2958,62 @@ const abstractEvalRules = {
2762
2958
  [Primitive.Atan]: vectorizedUnopAbstractEval,
2763
2959
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2764
2960
  [Primitive.Log]: vectorizedUnopAbstractEval,
2961
+ [Primitive.Erf]: vectorizedUnopAbstractEval,
2962
+ [Primitive.Erfc]: vectorizedUnopAbstractEval,
2765
2963
  [Primitive.Sqrt]: vectorizedUnopAbstractEval,
2766
2964
  [Primitive.Min]: binopAbstractEval,
2767
2965
  [Primitive.Max]: binopAbstractEval,
2768
2966
  [Primitive.Reduce]([x], { axis }) {
2769
2967
  const axisSet = new Set(axis);
2770
2968
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2771
- return [new ShapedArray(newShape, x.dtype)];
2969
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2772
2970
  },
2773
2971
  [Primitive.Pool]([x], { window, strides }) {
2774
2972
  const shape$1 = checkPoolShape(x.shape, window, strides);
2775
- return [new ShapedArray(shape$1, x.dtype)];
2973
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2776
2974
  },
2777
2975
  [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
2778
2976
  const shape$1 = checkPoolShape(inShape, window, strides);
2779
2977
  if (!require_backend.deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
2780
- return [new ShapedArray(inShape, x.dtype)];
2978
+ return [new ShapedArray(inShape, x.dtype, x.weakType)];
2781
2979
  },
2782
2980
  [Primitive.Dot]([x, y]) {
2783
- if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
2784
2981
  if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
2785
- const shape$1 = generalBroadcast(x.shape, y.shape);
2982
+ const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
2786
2983
  shape$1.splice(-1, 1);
2787
- return [new ShapedArray(shape$1, x.dtype)];
2984
+ return [new ShapedArray(shape$1, dtype, weakType)];
2788
2985
  },
2789
2986
  [Primitive.Conv]([lhs, rhs], params) {
2790
- if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
2987
+ const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
2791
2988
  const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
2792
- return [new ShapedArray(shape$1, lhs.dtype)];
2989
+ return [new ShapedArray(shape$1, dtype, weakType)];
2793
2990
  },
2794
2991
  [Primitive.Compare]: compareAbstractEval,
2795
2992
  [Primitive.Where]([cond, x, y]) {
2796
2993
  if (cond.dtype !== require_backend.DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
2797
- if (x.dtype !== y.dtype) throw new TypeError(`Mismatched dtypes: ${x.dtype} vs ${y.dtype}`);
2798
- const shape$1 = generalBroadcast(cond.shape, generalBroadcast(x.shape, y.shape));
2799
- return [new ShapedArray(shape$1, x.dtype)];
2994
+ const xy = promoteAvals(x, y);
2995
+ const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
2996
+ return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
2800
2997
  },
2801
2998
  [Primitive.Transpose]([x], { perm }) {
2802
- return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
2999
+ return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
2803
3000
  },
2804
3001
  [Primitive.Broadcast]([x], { shape: shape$1 }) {
2805
- return [new ShapedArray(shape$1, x.dtype)];
3002
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2806
3003
  },
2807
3004
  [Primitive.Reshape]([x], { shape: shape$1 }) {
2808
- return [new ShapedArray(shape$1, x.dtype)];
3005
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
2809
3006
  },
2810
3007
  [Primitive.Flip]([x], _) {
2811
- return [new ShapedArray(x.shape, x.dtype)];
3008
+ return [ShapedArray.fromAval(x)];
2812
3009
  },
2813
3010
  [Primitive.Shrink]([x], { slice }) {
2814
3011
  const newShape = slice.map((s) => s[1] - s[0]);
2815
- return [new ShapedArray(newShape, x.dtype)];
3012
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2816
3013
  },
2817
3014
  [Primitive.Pad]([x], { width }) {
2818
3015
  const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
2819
- return [new ShapedArray(newShape, x.dtype)];
3016
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2820
3017
  },
2821
3018
  [Primitive.Gather]([x, ...indices], { axis, outDim }) {
2822
3019
  for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
@@ -2826,10 +3023,10 @@ const abstractEvalRules = {
2826
3023
  if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
2827
3024
  const axisSet = new Set(axis);
2828
3025
  if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
2829
- const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
3026
+ const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
2830
3027
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2831
3028
  newShape.splice(outDim, 0, ...gatherShape);
2832
- return [new ShapedArray(newShape, x.dtype)];
3029
+ return [new ShapedArray(newShape, x.dtype, x.weakType)];
2833
3030
  },
2834
3031
  [Primitive.JitCall](args, { jaxpr }) {
2835
3032
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
@@ -2896,6 +3093,7 @@ function jit$1(f, opts) {
2896
3093
  const cacheKey = JSON.stringify(jaxprArgs);
2897
3094
  const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
2898
3095
  const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
3096
+ name: f.name || "closure",
2899
3097
  jaxpr,
2900
3098
  numConsts: consts.length
2901
3099
  });
@@ -3015,6 +3213,16 @@ const jvpRules = {
3015
3213
  [Primitive.Log]([x], [dx]) {
3016
3214
  return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
3017
3215
  },
3216
+ [Primitive.Erf]([x], [dx]) {
3217
+ const coeff = 2 / Math.sqrt(Math.PI);
3218
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3219
+ return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
3220
+ },
3221
+ [Primitive.Erfc]([x], [dx]) {
3222
+ const coeff = -2 / Math.sqrt(Math.PI);
3223
+ const expTerm = exp$1(neg(x.ref.mul(x.ref)));
3224
+ return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
3225
+ },
3018
3226
  [Primitive.Sqrt]([x], [dx]) {
3019
3227
  const z = sqrt$1(x);
3020
3228
  return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
@@ -3058,13 +3266,14 @@ const jvpRules = {
3058
3266
  const indicesRef = indices.map((t) => t.ref);
3059
3267
  return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
3060
3268
  },
3061
- [Primitive.JitCall](primals, tangents, { jaxpr }) {
3269
+ [Primitive.JitCall](primals, tangents, { name, jaxpr }) {
3062
3270
  const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
3063
3271
  const outs = bind(Primitive.JitCall, [
3064
3272
  ...newConsts.map((c) => c.ref),
3065
3273
  ...primals,
3066
3274
  ...tangents
3067
3275
  ], {
3276
+ name: `${name}_jvp`,
3068
3277
  jaxpr: newJaxpr,
3069
3278
  numConsts: newConsts.length
3070
3279
  });
@@ -3119,7 +3328,7 @@ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
3119
3328
  function mappedAval(batchDim, aval) {
3120
3329
  const shape$1 = [...aval.shape];
3121
3330
  shape$1.splice(batchDim, 1);
3122
- return new ShapedArray(shape$1, aval.dtype);
3331
+ return new ShapedArray(shape$1, aval.dtype, aval.weakType);
3123
3332
  }
3124
3333
  /** Move one axis to a different index. */
3125
3334
  function moveaxis$1(x, src, dst) {
@@ -3176,6 +3385,10 @@ var BatchTrace = class extends Trace {
3176
3385
  const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
3177
3386
  const vmapRule = vmapRules[primitive];
3178
3387
  if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
3388
+ if (bdimsIn.every((d) => d === null)) {
3389
+ const valOuts$1 = bind(primitive, valsIn, params);
3390
+ return valOuts$1.map((x) => new BatchTracer(this, x, null));
3391
+ }
3179
3392
  const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
3180
3393
  return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
3181
3394
  }
@@ -3183,24 +3396,28 @@ var BatchTrace = class extends Trace {
3183
3396
  return this.main.globalData;
3184
3397
  }
3185
3398
  };
3186
- function handleScalarBroadcasting(nd, x, d) {
3187
- if (d === null || nd === ndim$1(x)) return x;
3188
- else {
3189
- const axis = require_backend.range(ndim$1(x), nd);
3190
- const shape$1 = [...x.shape, ...axis.map(() => 1)];
3191
- return broadcast(x, shape$1, axis);
3192
- }
3193
- }
3194
- /** Process a primitive with built-in broadcasting. */
3399
+ /**
3400
+ * Process a primitive with built-in broadcasting.
3401
+ *
3402
+ * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
3403
+ */
3195
3404
  function broadcastBatcher(op) {
3196
3405
  return (axisSize, args, dims) => {
3197
3406
  if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
3198
- const idx = dims.findIndex((d) => d !== null);
3199
- if (idx === -1) return [[op(...args)], [null]];
3200
- if (require_backend.zip(args, dims).every(([x, d]) => ndim$1(x) === 0 || require_backend.deepEqual(x.shape, args[idx].shape) && d === dims[idx])) return [[op(...args)], [dims[idx]]];
3201
- args = args.map((x, i) => ndim$1(x) > 0 ? moveBatchAxis(axisSize, dims[i], 0, x) : x);
3202
- const nd = Math.max(...args.map(ndim$1));
3203
- args = args.map((x, i) => handleScalarBroadcasting(nd, x, dims[i]));
3407
+ const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
3408
+ const firstIdx = dims.findIndex((d) => d !== null);
3409
+ const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
3410
+ if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
3411
+ args = args.map((x, i) => {
3412
+ if (dims[i] === null) return x;
3413
+ x = moveBatchAxis(axisSize, dims[i], 0, x);
3414
+ if (x.ndim < nd) x = x.reshape([
3415
+ x.shape[0],
3416
+ ...require_backend.rep(nd - x.ndim, 1),
3417
+ ...x.shape.slice(1)
3418
+ ]);
3419
+ return x;
3420
+ });
3204
3421
  return [[op(...args)], [0]];
3205
3422
  };
3206
3423
  }
@@ -3224,17 +3441,18 @@ const vmapRules = {
3224
3441
  [Primitive.Atan]: unopBatcher(atan$1),
3225
3442
  [Primitive.Exp]: unopBatcher(exp$1),
3226
3443
  [Primitive.Log]: unopBatcher(log$1),
3444
+ [Primitive.Erf]: unopBatcher(erf$1),
3445
+ [Primitive.Erfc]: unopBatcher(erfc$1),
3227
3446
  [Primitive.Sqrt]: unopBatcher(sqrt$1),
3228
3447
  [Primitive.Min]: broadcastBatcher(min$1),
3229
3448
  [Primitive.Max]: broadcastBatcher(max$1),
3230
3449
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
3231
- if (xBdim === null) return [[reduce(x, op, axis)], [null]];
3450
+ require_backend.assertNonNull(xBdim);
3232
3451
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3233
3452
  const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3234
3453
  return [[reduce(x, op, newAxis)], [outBdim]];
3235
3454
  },
3236
3455
  [Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
3237
- if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
3238
3456
  x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
3239
3457
  y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
3240
3458
  const z = dot$1(x, y);
@@ -3243,29 +3461,72 @@ const vmapRules = {
3243
3461
  [Primitive.Compare](axisSize, args, dims, { op }) {
3244
3462
  return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
3245
3463
  },
3464
+ [Primitive.Where]: broadcastBatcher(where$1),
3465
+ [Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
3466
+ require_backend.assertNonNull(xBdim);
3467
+ const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
3468
+ newPerm.splice(xBdim, 0, xBdim);
3469
+ return [[transpose$1(x, newPerm)], [xBdim]];
3470
+ },
3471
+ [Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
3472
+ require_backend.assertNonNull(xBdim);
3473
+ const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
3474
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3475
+ return [[broadcast(x, newShape, newAxis)], [xBdim]];
3476
+ },
3246
3477
  [Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
3247
- if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
3248
3478
  x = moveBatchAxis(axisSize, xBdim, 0, x);
3249
3479
  return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
3250
3480
  },
3251
3481
  [Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
3252
- if (xBdim === null) return [[flip$1(x, axis)], [null]];
3482
+ require_backend.assertNonNull(xBdim);
3253
3483
  const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3254
3484
  return [[flip$1(x, newAxis)], [xBdim]];
3255
3485
  },
3256
3486
  [Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
3257
- if (xBdim === null) return [[shrink(x, slice)], [null]];
3487
+ require_backend.assertNonNull(xBdim);
3258
3488
  const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
3259
3489
  return [[shrink(x, newSlice)], [xBdim]];
3260
3490
  },
3261
3491
  [Primitive.Pad](axisSize, [x], [xBdim], { width }) {
3262
- if (xBdim === null) return [[pad$1(x, width)], [null]];
3492
+ require_backend.assertNonNull(xBdim);
3263
3493
  const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
3264
3494
  return [[pad$1(x, newWidth)], [xBdim]];
3265
3495
  },
3266
- [Primitive.JitCall](axisSize, args, dims, { jaxpr }) {
3496
+ [Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
3497
+ if (indicesBdim.every((d) => d === null)) {
3498
+ require_backend.assertNonNull(xBdim);
3499
+ const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
3500
+ let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
3501
+ let newOutDim = outDim;
3502
+ if (newOutDim < newBdim) newBdim += axis.length;
3503
+ else newOutDim += 1;
3504
+ return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
3505
+ }
3506
+ const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
3507
+ indices = indices.map((m, i) => {
3508
+ if (indicesBdim[i] === null) return m;
3509
+ m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
3510
+ if (m.ndim < nd) m = m.reshape([
3511
+ m.shape[0],
3512
+ ...require_backend.rep(nd - m.ndim, 1),
3513
+ ...m.shape.slice(1)
3514
+ ]);
3515
+ return m;
3516
+ });
3517
+ if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
3518
+ else {
3519
+ x = moveBatchAxis(axisSize, xBdim, 0, x);
3520
+ const newAxis = [0, ...axis.map((ax) => ax + 1)];
3521
+ const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
3522
+ indices.splice(0, 0, extraBatchIndex);
3523
+ return [[gather(x, indices, newAxis, outDim)], [outDim]];
3524
+ }
3525
+ },
3526
+ [Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
3267
3527
  const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
3268
3528
  const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
3529
+ name: `${name}_vmap`,
3269
3530
  jaxpr: newJaxpr,
3270
3531
  numConsts: newConsts.length
3271
3532
  });
@@ -3281,7 +3542,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
3281
3542
  if (dims[i] === null) return v.aval;
3282
3543
  const shape$1 = [...v.aval.shape];
3283
3544
  shape$1.splice(dims[i], 0, axisSize);
3284
- return new ShapedArray(shape$1, v.aval.dtype);
3545
+ return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
3285
3546
  });
3286
3547
  const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
3287
3548
  const result = {
@@ -3321,12 +3582,14 @@ function vmapFlat(f, inAxes, args) {
3321
3582
  function vmap$1(f, inAxes = 0) {
3322
3583
  return (...args) => {
3323
3584
  const [argsFlat, inTree] = flatten(args);
3324
- let inAxesFlat;
3585
+ let inAxesFlat = [];
3325
3586
  if (typeof inAxes === "number") inAxesFlat = require_backend.rep(argsFlat.length, inAxes);
3587
+ else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, null));
3588
+ else if (typeof inAxes[i] === "number") inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, inAxes[i]));
3326
3589
  else {
3327
- let inTree2;
3328
- [inAxesFlat, inTree2] = flatten(inAxes);
3329
- if (!inTree.equals(inTree2)) throw new TreeMismatchError("vmap", inTree, inTree2);
3590
+ const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
3591
+ if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
3592
+ inAxesFlat.push(...axesFlat);
3330
3593
  }
3331
3594
  const [fFlat, outTree] = flattenFun(f, inTree);
3332
3595
  const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
@@ -3494,8 +3757,8 @@ var PartialEvalTrace = class extends Trace {
3494
3757
  processPrimitive(primitive, tracers, params) {
3495
3758
  if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
3496
3759
  if (primitive === Primitive.JitCall) {
3497
- const { jaxpr, numConsts } = params;
3498
- return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
3760
+ const { name, jaxpr, numConsts } = params;
3761
+ return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
3499
3762
  }
3500
3763
  const tracersIn = tracers.map((t) => this.instantiateConst(t));
3501
3764
  const avalsIn = tracersIn.map((t) => t.pval.aval);
@@ -3521,12 +3784,13 @@ var PartialEvalTrace = class extends Trace {
3521
3784
  *
3522
3785
  * Used when encountering a JitCall rule during the trace.
3523
3786
  */
3524
- #partialEvalJaxpr(jaxpr, numConsts, tracers) {
3787
+ #partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
3525
3788
  jaxpr = jaxpr.flatten();
3526
3789
  const inUnknowns = tracers.map((t) => !t.pval.isKnown);
3527
3790
  const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
3528
3791
  const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
3529
3792
  const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
3793
+ name: `${name}_peval`,
3530
3794
  jaxpr: jaxpr1,
3531
3795
  numConsts: 0
3532
3796
  });
@@ -3538,6 +3802,7 @@ var PartialEvalTrace = class extends Trace {
3538
3802
  prim: Primitive.JitCall,
3539
3803
  tracersIn: resTracers.concat(unknownTracers),
3540
3804
  params: {
3805
+ name: `${name}_resid`,
3541
3806
  jaxpr: jaxpr2,
3542
3807
  numConsts: 0
3543
3808
  },
@@ -3680,7 +3945,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
3680
3945
  }
3681
3946
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
3682
3947
  const eqn = jaxpr.eqns[i];
3683
- const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? scalar(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3948
+ const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
3684
3949
  const cotangentsOut = eqn.outBinders.map(readCotangent);
3685
3950
  const rule = transposeRules[eqn.primitive];
3686
3951
  if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
@@ -3765,7 +4030,7 @@ const transposeRules = {
3765
4030
  },
3766
4031
  [Primitive.Dot]([ct], [x, y]) {
3767
4032
  if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
3768
- const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
4033
+ const axisSize = require_backend.generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
3769
4034
  ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
3770
4035
  return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
3771
4036
  },
@@ -3860,7 +4125,7 @@ const transposeRules = {
3860
4125
  if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
3861
4126
  throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
3862
4127
  },
3863
- [Primitive.JitCall](cts, args, { jaxpr }) {
4128
+ [Primitive.JitCall](cts, args, { name, jaxpr }) {
3864
4129
  const undefPrimals = args.map((x) => x instanceof UndefPrimal);
3865
4130
  const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
3866
4131
  const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
@@ -3869,6 +4134,7 @@ const transposeRules = {
3869
4134
  ...residuals,
3870
4135
  ...cts
3871
4136
  ], {
4137
+ name: `${name}_t`,
3872
4138
  jaxpr: newJaxpr,
3873
4139
  numConsts: newConsts.length
3874
4140
  });
@@ -3943,7 +4209,7 @@ function valueAndGrad$1(f) {
3943
4209
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3944
4210
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3945
4211
  if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3946
- const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
4212
+ const [ct, ...rest] = fVjp(onesLike$1(y.ref));
3947
4213
  for (const r of rest) dispose(r);
3948
4214
  fVjp.dispose();
3949
4215
  return [y, ct];
@@ -3971,7 +4237,10 @@ __export(lax_exports, {
3971
4237
  conv: () => conv$1,
3972
4238
  convGeneralDilated: () => convGeneralDilated,
3973
4239
  convWithGeneralPadding: () => convWithGeneralPadding,
3974
- reduceWindow: () => reduceWindow
4240
+ erf: () => erf,
4241
+ erfc: () => erfc,
4242
+ reduceWindow: () => reduceWindow,
4243
+ stopGradient: () => stopGradient$1
3975
4244
  });
3976
4245
  function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
3977
4246
  const padType = padding.toUpperCase();
@@ -4030,6 +4299,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4030
4299
  strides: windowStrides
4031
4300
  }));
4032
4301
  }
4302
+ /** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
4303
+ function erf(x) {
4304
+ return erf$1(x);
4305
+ }
4306
+ /**
4307
+ * The complementary error function: `erfc(x) = 1 - erf(x)`.
4308
+ *
4309
+ * This function is more accurate than `1 - erf(x)` for large values of `x`,
4310
+ * where `erf(x)` is very close to 1.
4311
+ */
4312
+ function erfc(x) {
4313
+ return erfc$1(x);
4314
+ }
4315
+ /**
4316
+ * Stops gradient computation.
4317
+ *
4318
+ * Behaves as the identity function but prevents the flow of gradients during
4319
+ * forward or reverse-mode automatic differentiation.
4320
+ */
4321
+ function stopGradient$1(x) {
4322
+ return stopGradient(x);
4323
+ }
4033
4324
 
4034
4325
  //#endregion
4035
4326
  //#region src/numpy.ts
@@ -4092,6 +4383,9 @@ __export(numpy_exports, {
4092
4383
  fullLike: () => fullLike$1,
4093
4384
  greater: () => greater,
4094
4385
  greaterEqual: () => greaterEqual,
4386
+ hamming: () => hamming,
4387
+ hann: () => hann,
4388
+ heaviside: () => heaviside,
4095
4389
  hstack: () => hstack,
4096
4390
  hypot: () => hypot,
4097
4391
  identity: () => identity$1,
@@ -4313,7 +4607,7 @@ function argmin(a, axis, opts) {
4313
4607
  } else axis = require_backend.checkAxis(axis, a.ndim);
4314
4608
  const shape$1 = a.shape;
4315
4609
  const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
4316
- const length = scalar(shape$1[axis], {
4610
+ const length = array(shape$1[axis], {
4317
4611
  dtype: int32,
4318
4612
  device: a.device
4319
4613
  });
@@ -4337,7 +4631,7 @@ function argmax(a, axis, opts) {
4337
4631
  } else axis = require_backend.checkAxis(axis, a.ndim);
4338
4632
  const shape$1 = a.shape;
4339
4633
  const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
4340
- const length = scalar(shape$1[axis], {
4634
+ const length = array(shape$1[axis], {
4341
4635
  dtype: int32,
4342
4636
  device: a.device
4343
4637
  });
@@ -4521,7 +4815,7 @@ function broadcastTo(a, shape$1) {
4521
4815
  /** Broadcast input shapes to a common output shape. */
4522
4816
  function broadcastShapes(...shapes) {
4523
4817
  if (shapes.length === 0) return [];
4524
- return shapes.reduce(generalBroadcast);
4818
+ return shapes.reduce(require_backend.generalBroadcast);
4525
4819
  }
4526
4820
  /** Broadcast arrays to a common shape. */
4527
4821
  function broadcastArrays(...arrays) {
@@ -4731,6 +5025,32 @@ function sign(x) {
4731
5025
  x = fudgeArray(x);
4732
5026
  return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
4733
5027
  }
5028
+ /**
5029
+ * Return the Hamming window of size M, a taper with a weighted cosine bell.
5030
+ *
5031
+ * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
5032
+ */
5033
+ function hamming(M) {
5034
+ return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
5035
+ }
5036
+ /**
5037
+ * Return the Hann window of size M, a taper with a weighted cosine bell.
5038
+ *
5039
+ * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
5040
+ */
5041
+ function hann(M) {
5042
+ return cos(linspace(0, 2 * Math.PI, M)).mul(-.5).add(.5);
5043
+ }
5044
+ /**
5045
+ * @function
5046
+ * Compute the Heaviside step function. It is defined piecewise:
5047
+ * - `heaviside(x1, x2) = 0` for `x1 < 0`,
5048
+ * - `heaviside(x1, x2) = x2` for `x1 == 0`,
5049
+ * - `heaviside(x1, x2) = 1` for `x1 > 0`.
5050
+ */
5051
+ const heaviside = jit$1(function heaviside$1(x1, x2) {
5052
+ return where(less(x1.ref, 0), 0, where(equal(x1, 0), x2, 1));
5053
+ });
4734
5054
  /** Calculate element-wise square of the input array. */
4735
5055
  function square(x) {
4736
5056
  x = fudgeArray(x);
@@ -4750,10 +5070,10 @@ function acos(x) {
4750
5070
  * Return element-wise hypotenuse for the given legs of a right triangle.
4751
5071
  *
4752
5072
  * In the original NumPy/JAX implementation, this function is more numerically
4753
- * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4754
- * improvements.
5073
+ * stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
5074
+ * stability improvements.
4755
5075
  */
4756
- const hypot = jit$1((x1, x2) => {
5076
+ const hypot = jit$1(function hypot$1(x1, x2) {
4757
5077
  return sqrt(square(x1).add(square(x2)));
4758
5078
  });
4759
5079
  /**
@@ -4769,7 +5089,7 @@ const hypot = jit$1((x1, x2) => {
4769
5089
  *
4770
5090
  * The output is ill-defined when both x and y are zero.
4771
5091
  */
4772
- const atan2 = jit$1((y, x) => {
5092
+ const atan2 = jit$1(function atan2$1(y, x) {
4773
5093
  const r = sqrt(square(x.ref).add(square(y.ref)));
4774
5094
  const xNeg = less(x.ref, 0);
4775
5095
  const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
@@ -4837,13 +5157,13 @@ const degrees = rad2deg;
4837
5157
  * @function
4838
5158
  * Computes first array raised to power of second array, element-wise.
4839
5159
  */
4840
- const power = jit$1((x1, x2) => {
5160
+ const power = jit$1(function power$1(x1, x2) {
4841
5161
  return exp(log(x1).mul(x2));
4842
5162
  });
4843
5163
  /** @function Alias of `jax.numpy.power()`. */
4844
5164
  const pow = power;
4845
5165
  /** @function Calculate the element-wise cube root of the input array. */
4846
- const cbrt = jit$1((x) => {
5166
+ const cbrt = jit$1(function cbrt$1(x) {
4847
5167
  const sgn = where(less(x.ref, 0), -1, 1);
4848
5168
  return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
4849
5169
  });
@@ -4853,7 +5173,7 @@ const cbrt = jit$1((x) => {
4853
5173
  *
4854
5174
  * `sinh(x) = (exp(x) - exp(-x)) / 2`
4855
5175
  */
4856
- const sinh = jit$1((x) => {
5176
+ const sinh = jit$1(function sinh$1(x) {
4857
5177
  const ex = exp(x);
4858
5178
  const emx = reciprocal(ex.ref);
4859
5179
  return ex.sub(emx).mul(.5);
@@ -4864,7 +5184,7 @@ const sinh = jit$1((x) => {
4864
5184
  *
4865
5185
  * `cosh(x) = (exp(x) + exp(-x)) / 2`
4866
5186
  */
4867
- const cosh = jit$1((x) => {
5187
+ const cosh = jit$1(function cosh$1(x) {
4868
5188
  const ex = exp(x);
4869
5189
  const emx = reciprocal(ex.ref);
4870
5190
  return ex.add(emx).mul(.5);
@@ -4875,7 +5195,7 @@ const cosh = jit$1((x) => {
4875
5195
  *
4876
5196
  * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4877
5197
  */
4878
- const tanh = jit$1((x) => {
5198
+ const tanh = jit$1(function tanh$1(x) {
4879
5199
  const negsgn = where(less(x.ref, 0), 1, -1);
4880
5200
  const en2x = exp(x.mul(negsgn.ref).mul(2));
4881
5201
  return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
@@ -4886,7 +5206,7 @@ const tanh = jit$1((x) => {
4886
5206
  *
4887
5207
  * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
4888
5208
  */
4889
- const arcsinh = jit$1((x) => {
5209
+ const arcsinh = jit$1(function arcsinh$1(x) {
4890
5210
  return log(x.ref.add(sqrt(square(x).add(1))));
4891
5211
  });
4892
5212
  /**
@@ -4895,7 +5215,7 @@ const arcsinh = jit$1((x) => {
4895
5215
  *
4896
5216
  * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
4897
5217
  */
4898
- const arccosh = jit$1((x) => {
5218
+ const arccosh = jit$1(function arccosh$1(x) {
4899
5219
  return log(x.ref.add(sqrt(square(x).sub(1))));
4900
5220
  });
4901
5221
  /**
@@ -4904,7 +5224,7 @@ const arccosh = jit$1((x) => {
4904
5224
  *
4905
5225
  * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
4906
5226
  */
4907
- const arctanh = jit$1((x) => {
5227
+ const arctanh = jit$1(function arctanh$1(x) {
4908
5228
  return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
4909
5229
  });
4910
5230
  /** @function Alias of `jax.numpy.arcsinh()`. */
@@ -5020,7 +5340,9 @@ function softSign(x) {
5020
5340
  *
5021
5341
  * Reference: https://en.wikipedia.org/wiki/Swish_function
5022
5342
  */
5023
- const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
5343
+ const silu = jit$1(function silu$1(x) {
5344
+ return x.ref.mul(sigmoid(x));
5345
+ });
5024
5346
  /**
5025
5347
  * @function
5026
5348
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
@@ -5073,18 +5395,20 @@ function celu(x, alpha = 1) {
5073
5395
  * @function
5074
5396
  * Gaussian error linear unit (GELU) activation function.
5075
5397
  *
5076
- * This is computed element-wise. Currently jax-js does not support the erf() or
5077
- * gelu() functions exactly as primitives, so an approximation is used:
5078
- * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
5398
+ * This is computed element-wise. There are two variants depending on whether
5399
+ * `approximate` is set (default true):
5079
5400
  *
5080
- * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5401
+ * - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
5402
+ * - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
5081
5403
  *
5082
- * This will be improved in the future.
5404
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5083
5405
  */
5084
- const gelu = jit$1((x) => {
5085
- const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5086
- return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5087
- });
5406
+ const gelu = jit$1(function gelu$1(x, opts) {
5407
+ if (opts?.approximate ?? true) {
5408
+ const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5409
+ return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5410
+ } else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
5411
+ }, { staticArgnums: [1] });
5088
5412
  /**
5089
5413
  * Gated linear unit (GLU) activation function.
5090
5414
  *
@@ -5252,8 +5576,11 @@ function bits(key$1, shape$1 = []) {
5252
5576
  const keyShape = validateKeyShape(key$1);
5253
5577
  return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
5254
5578
  }
5255
- /** Sample uniform random values in [minval, maxval) with given shape. */
5256
- function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5579
+ /**
5580
+ * @function
5581
+ * Sample uniform random values in [minval, maxval) with given shape.
5582
+ */
5583
+ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5257
5584
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
5258
5585
  const mantissa = bits(key$1, shape$1).div(array(512, {
5259
5586
  dtype: require_backend.DType.Uint32,
@@ -5266,7 +5593,7 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
5266
5593
  const rand = bitcast(float12, require_backend.DType.Float32).sub(1);
5267
5594
  if (minval === 0 && maxval === 1) return rand;
5268
5595
  else return rand.mul(maxval - minval).add(minval);
5269
- }
5596
+ }, { staticArgnums: [1, 2] });
5270
5597
  /**
5271
5598
  * Sample Bernoulli random variables with given mean (0,1 categorical).
5272
5599
  *
@@ -5277,26 +5604,49 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
5277
5604
  p = fudgeArray(p);
5278
5605
  return uniform(key$1, shape$1).less(p);
5279
5606
  }
5280
- /** Sample exponential random values according to `p(x) = exp(-x)`. */
5281
- function exponential(key$1, shape$1 = []) {
5607
+ /**
5608
+ * @function
5609
+ * Sample exponential random values according to `p(x) = exp(-x)`.
5610
+ */
5611
+ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
5282
5612
  const u = uniform(key$1, shape$1);
5283
5613
  return negative(log1p(negative(u)));
5284
- }
5614
+ }, { staticArgnums: [1] });
5285
5615
  /**
5616
+ * @function
5286
5617
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
5287
5618
  *
5288
5619
  * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
5289
5620
  * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
5290
5621
  * bitwise identical to JAX.
5291
5622
  */
5292
- function normal(key$1, shape$1 = []) {
5623
+ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
5293
5624
  const [k1, k2] = split(key$1, 2);
5294
5625
  const u1 = uniform(k1, shape$1);
5295
5626
  const u2 = uniform(k2, shape$1);
5296
5627
  const radius = sqrt(log1p(negative(u1)).mul(-2));
5297
5628
  const theta = u2.mul(2 * Math.PI);
5298
5629
  return radius.mul(cos(theta));
5299
- }
5630
+ }, { staticArgnums: [1] });
5631
+
5632
+ //#endregion
5633
+ //#region src/scipy-special.ts
5634
+ var scipy_special_exports = {};
5635
+ __export(scipy_special_exports, {
5636
+ erf: () => erf,
5637
+ erfc: () => erfc,
5638
+ logSoftmax: () => logSoftmax,
5639
+ logit: () => logit,
5640
+ logsumexp: () => logsumexp,
5641
+ softmax: () => softmax
5642
+ });
5643
+ /**
5644
+ * @function
5645
+ * The logit function, `logit(p) = log(p / (1-p))`.
5646
+ */
5647
+ const logit = jit$1(function logit$1(x) {
5648
+ return log(x.ref.div(subtract(1, x)));
5649
+ });
5300
5650
 
5301
5651
  //#endregion
5302
5652
  //#region src/polyfills.ts
@@ -5391,6 +5741,24 @@ async function blockUntilReady(x) {
5391
5741
  await Promise.all(promises);
5392
5742
  return x;
5393
5743
  }
5744
+ /**
5745
+ * Transfer `x` to `device`.
5746
+ *
5747
+ * `x` may be a nested container of arrays or scalars. The resulting structure
5748
+ * is committed to the device.
5749
+ *
5750
+ * If `device` is not specified, this function behaves as identity if the input
5751
+ * is already an `Array`, otherwise it places the scalar uncommitted on the
5752
+ * default device.
5753
+ */
5754
+ async function devicePut(x, device) {
5755
+ const [xflat, structure$1] = flatten(x);
5756
+ const yflat = await Promise.all(xflat.map((leaf) => {
5757
+ if (leaf instanceof Array$1) return device ? leaf._put(require_backend.getBackend(device)) : Promise.resolve(leaf);
5758
+ else return Promise.resolve(array(leaf, { device }));
5759
+ }));
5760
+ return unflatten(structure$1, yflat);
5761
+ }
5394
5762
 
5395
5763
  //#endregion
5396
5764
  exports.Array = Array$1;
@@ -5398,6 +5766,7 @@ exports.DType = require_backend.DType;
5398
5766
  exports.Jaxpr = Jaxpr;
5399
5767
  exports.blockUntilReady = blockUntilReady;
5400
5768
  exports.defaultDevice = require_backend.defaultDevice;
5769
+ exports.devicePut = devicePut;
5401
5770
  exports.devices = require_backend.devices;
5402
5771
  exports.grad = grad;
5403
5772
  exports.init = require_backend.init;
@@ -5432,6 +5801,12 @@ Object.defineProperty(exports, 'random', {
5432
5801
  return random_exports;
5433
5802
  }
5434
5803
  });
5804
+ Object.defineProperty(exports, 'scipySpecial', {
5805
+ enumerable: true,
5806
+ get: function () {
5807
+ return scipy_special_exports;
5808
+ }
5809
+ });
5435
5810
  exports.setDebug = require_backend.setDebug;
5436
5811
  Object.defineProperty(exports, 'tree', {
5437
5812
  enumerable: true,
@@ -5441,4 +5816,5 @@ Object.defineProperty(exports, 'tree', {
5441
5816
  });
5442
5817
  exports.valueAndGrad = valueAndGrad;
5443
5818
  exports.vjp = vjp;
5444
- exports.vmap = vmap;
5819
+ exports.vmap = vmap;
5820
+ //# sourceMappingURL=index.cjs.map